From 39102bb321ebeed2a601df6a2ebe459dc0bccc72 Mon Sep 17 00:00:00 2001
From: luenensc <peter.luenenschloss@ufz.de>
Date: Tue, 17 Jan 2023 18:23:49 +0100
Subject: [PATCH] further streamlining/found faster back prop trick/added
 support for offset defined gap limits

---
 saqc/funcs/interpolation.py    |  3 +-
 saqc/lib/ts_operators.py       | 60 +++++++++++++++-------------------
 tests/lib/test_ts_operators.py |  5 +--
 3 files changed, 29 insertions(+), 39 deletions(-)

diff --git a/saqc/funcs/interpolation.py b/saqc/funcs/interpolation.py
index 52f17bda2..60bdbf29d 100644
--- a/saqc/funcs/interpolation.py
+++ b/saqc/funcs/interpolation.py
@@ -144,7 +144,7 @@ class InterpolationMixin:
         method: _SUPPORTED_METHODS,
         order: int = 2,
         limit: int | None = None,
-        downgrade: bool = False,
+        extrapolate: Literal['forward', 'backward', 'both'] = None,
         flag: float = UNFLAGGED,
         **kwargs,
     ) -> "SaQC":
@@ -187,6 +187,7 @@ class InterpolationMixin:
             method,
             order=order,
             gap_limit=limit,
+            extrapolate=extrapolate
         )
 
         interpolated = self._data[field].isna() & inter_data.notna()
diff --git a/saqc/lib/ts_operators.py b/saqc/lib/ts_operators.py
index 3a51e750c..0347c757d 100644
--- a/saqc/lib/ts_operators.py
+++ b/saqc/lib/ts_operators.py
@@ -21,7 +21,6 @@ import pandas as pd
 from scipy.signal import butter, filtfilt
 from scipy.stats import iqr, median_abs_deviation
 from sklearn.neighbors import NearestNeighbors
-
 from saqc.lib.tools import getFreqDelta
 
 
@@ -317,30 +316,29 @@ def interpolateNANs(
 
     :return:
     """
+
+    # helper variable holding the gap limit only if it is numeric (NaN for str limits), to avoid comparing a str to a number
+    gap_check = np.nan if isinstance(gap_limit, str) else gap_limit
     data = pd.Series(data, copy=True)
     limit_area = "inside" if not extrapolate else "outside"
-    if gap_limit is None:
+    if gap_check is None:
         # if there is actually no limit set to the gaps to-be interpolated, generate a dummy mask for the gaps
         gap_mask = pd.Series(True, index=data.index, name=data.name)
-    elif gap_limit < 2:
-        return data
     else:
-        # if there is a limit to the gaps to be interpolated, generate a mask that evaluates to False at the right side
-        # of each too-large gap with a rolling.sum combo
-        gap_mask = data.isna().rolling(gap_limit, min_periods=0).sum() != gap_limit
-        if gap_limit == 2:
-            # for the common case of gap_limit=2 (default "harmonisation"), we efficiently back propagate the False
-            # value to fill the whole too-large gap by a shift and a conjunction.
-            gap_mask &= gap_mask & gap_mask.shift(-1, fill_value=True)
+        if gap_check < 2:
+            # a gap limit below 2 breaks execution down the line and is thus caught here, since it basically means "do nothing"
+            return data
         else:
-            # If the gap_size is bigger we use pandas backfill-interpolation to propagate the False values back.
-            # Therefor we replace the True values with np.nan so hat they are interpreted as missing periods.
-            gap_mask = (
-                gap_mask.replace(True, np.nan)
-                .fillna(method="bfill", limit=gap_limit - 1)
-                .replace(np.nan, True)
-                .astype(bool)
-            )
+            # if there is a limit to the gaps to be interpolated, generate a mask that evaluates to False at the right
+            # side of each too-large gap with a rolling.count combo
+            gap_mask = data.rolling(gap_limit, min_periods=0).count() > 0
+            if gap_limit == 2:
+                # for the common case of gap_limit=2 (default "harmonisation"), we efficiently back propagate the False
+                # value to fill the whole too-large gap by a shift and a conjunction.
+                gap_mask = gap_mask & gap_mask.shift(-1, fill_value=True)
+            else:
+                # If the gap_size is bigger we use a flip-rolling combo to back propagate the False values
+                gap_mask = ~((~gap_mask[::-1]).rolling(gap_limit, min_periods=0).sum() > 0)[::-1]
 
     # memorizing the index for later reindexing
     pre_index = data.index
@@ -361,21 +359,15 @@ def interpolateNANs(
         # with the .transform method of the grouper.
         gap_mask = (~gap_mask).cumsum()[data.index]
         chunk_groups = data.groupby(by=gap_mask)
-        if extrapolate:
-            if extrapolate in ['both', 'backward']:
-                lead_idx = gap_mask[gap_mask==gap_mask.min()].index
-                data[lead_idx] = _interpolWrapper(data[lead_idx], order=order, method=method, limit_area=limit_area, limit_direction='backward')
-            if extrapolate in ['both', 'forward']:
-                trail_idx = gap_mask[gap_mask==gap_mask.max()].index
-                data[trail_idx] = _interpolWrapper(data[lead_idx], order=order, method=method, limit_area=limit_area, limit_direction='forward')
-        else:
-            data = chunk_groups.groupby(by=gap_mask).transform(
-                _interpolWrapper,
-                **{
-                    "order": order,
-                    "method": method,
-                },
-            )
+        data = chunk_groups.transform(
+            _interpolWrapper,
+            **{
+                "order": order,
+                "method": method,
+                "limit_area": limit_area,
+                "limit_direction": extrapolate
+            },
+        )
     # finally reinsert the dropped data gaps
     data = data.reindex(pre_index)
     return data
diff --git a/tests/lib/test_ts_operators.py b/tests/lib/test_ts_operators.py
index dffb19aa6..0ef417dbb 100644
--- a/tests/lib/test_ts_operators.py
+++ b/tests/lib/test_ts_operators.py
@@ -229,7 +229,4 @@ def test_rateOfChange(data, expected):
 )
 def test_interpolatNANs(limit, data, expected):
     got = interpolateNANs(pd.Series(data), gap_limit=limit, method="linear")
-    try:
-        assert got.equals(pd.Series(expected, dtype=float))
-    except:
-        print("stop")
+    assert got.equals(pd.Series(expected, dtype=float))
-- 
GitLab