From 519572e88f4f0bd72f095208567e4101aa9681b0 Mon Sep 17 00:00:00 2001
From: Peter Luenenschloss <peter.luenenschloss@ufz.de>
Date: Mon, 28 Sep 2020 13:58:59 +0200
Subject: [PATCH] CP cluster reduction added to CPD algorithm

---
 saqc/funcs/functions.py | 30 ++++++++++++++++++++++++------
 saqc/lib/tools.py       | 12 +++++++++---
 2 files changed, 33 insertions(+), 9 deletions(-)

diff --git a/saqc/funcs/functions.py b/saqc/funcs/functions.py
index 9e22e3fd7..91982c1af 100644
--- a/saqc/funcs/functions.py
+++ b/saqc/funcs/functions.py
@@ -1052,7 +1052,7 @@ def flagDriftFromReference(data, field, flagger, fields, segment_freq, thresh,
 def _slidingWindowSearch(data_arr, bwd_start, fwd_end, stat_func, thresh_func, num_val):
     stat_arr = np.zeros(num_val)
     thresh_arr = np.zeros(num_val)
-    for win_i in numba.prange(1, len(data_arr)):
+    for win_i in numba.prange(1, num_val):
         x = data_arr[bwd_start[win_i - 1]:win_i]
         y = data_arr[win_i:fwd_end[win_i - 1]]
         stat_arr[win_i - 1] = stat_func(x, y)
@@ -1060,10 +1060,22 @@ def _slidingWindowSearch(data_arr, bwd_start, fwd_end, stat_func, thresh_func, n
     return stat_arr, thresh_arr
 
 
+@numba.jit(nopython=True, parallel=True)
+def _reduceCPCluster(stat_arr, thresh_arr, start, end, obj_func, num_val, out_arr):
+    for win_i in numba.prange(0, num_val):
+        s, e = start[win_i], end[win_i]
+        x = stat_arr[s:e]
+        y = thresh_arr[s:e]
+        pos = s + obj_func(x, y)
+        out_arr[pos] = True
+
+    return out_arr
+
+
 @register(masking='field')
 def flagChangePoints(data, field, flagger, stat_func, thresh_func, bwd_window, min_periods_bwd,
                      fwd_window=None, min_periods_fwd=None, closed='both', try_to_jit=True,
-                     agg_range=None):
+                     agg_range=None, reduce_func=lambda x, y: x.argmax()):
     """
     Function for change point detection based on sliding window search.
 
@@ -1106,6 +1118,7 @@ def flagChangePoints(data, field, flagger, stat_func, thresh_func, bwd_window, m
     if try_to_jit:
         stat_func = numba.jit(stat_func)
         thresh_func = numba.jit(thresh_func)
+        reduce_func = numba.jit(reduce_func)
 
     indexer = FreqIndexer()
     indexer.index_array = data_ser.index.to_numpy(int)
@@ -1124,8 +1137,13 @@ def flagChangePoints(data, field, flagger, stat_func, thresh_func, bwd_window, m
     stat_arr, thresh_arr = _slidingWindowSearch(data_arr, bwd_start, fwd_end, stat_func, thresh_func, var_len)
 
     result_arr = stat_arr > thresh_arr
-    detected = pd.Series(True, index=data_ser[result_arr].index)
-    cp_cluster = customRolling(detected, agg_range, count, closed='both', min_periods=1, center=True)
-
-    flagger = flagger.setFlags(field, loc=result_arr)
+    det_index = data_ser[result_arr].index
+    detected = pd.Series(True, index=det_index)
+    start, end = customRolling(detected, agg_range, count, closed='both', min_periods=1, center=True, index_only=True)
+    out_arr = np.zeros(shape=detected.shape[0], dtype=bool)
+    detected = _reduceCPCluster(stat_arr[result_arr], thresh_arr[result_arr], start, end, reduce_func,
+                                detected.shape[0], out_arr)
+
+    det_index = det_index[detected]
+    flagger = flagger.setFlags(field, loc=det_index)
     return data, flagger
\ No newline at end of file
diff --git a/saqc/lib/tools.py b/saqc/lib/tools.py
index b3396373a..4987dae11 100644
--- a/saqc/lib/tools.py
+++ b/saqc/lib/tools.py
@@ -439,7 +439,7 @@ class PeriodsIndexer(BaseIndexer):
 
 
 def customRolling(to_roll, winsz, func, roll_mask=None, min_periods=1, center=False, closed=None, raw=True, engine=None,
-                  forward=False):
+                  forward=False, index_only=False):
     """
     A wrapper around pandas.rolling.apply(), that allows for skipping func application on
     arbitrary selections of windows.
@@ -476,8 +476,10 @@ def customRolling(to_roll, winsz, func, roll_mask=None, min_periods=1, center=Fa
         If true, roll with forward facing windows. (not yet implemented for
         integer defined windows.)
     center : bool, default False
-        If true, set the label to the center of the rolling window. Although available
-        for windows defined by sample rates! (yeah!)
+        If true, set the label to the center of the rolling window. Also available
+        for frequencie defined rolling windows! (yeah!)
+    index_only : bool, default False
+        Only return rolling window indices.
 
     Returns
     -------
@@ -502,6 +504,10 @@ def customRolling(to_roll, winsz, func, roll_mask=None, min_periods=1, center=Fa
                                  center=center,
                                  closed=closed)
 
+    if index_only:
+        num_values = to_roll.shape[0]
+        return indexer.get_window_bounds(num_values, min_periods, center, closed)
+
     i_roller = i_roll.rolling(indexer,
                             min_periods=min_periods,
                             center=center,
-- 
GitLab