From c9402f5e8396795bbf6a48e81e7f70fe70749c69 Mon Sep 17 00:00:00 2001
From: Peter Luenenschloss <peter.luenenschloss@ufz.de>
Date: Sat, 26 Sep 2020 15:43:41 +0200
Subject: [PATCH] jitted changepoint detection - inner loop

---
 saqc/funcs/functions.py | 32 ++++++++++++++++++++++----------
 saqc/lib/tools.py       |  2 --
 2 files changed, 22 insertions(+), 12 deletions(-)

diff --git a/saqc/funcs/functions.py b/saqc/funcs/functions.py
index 63c30a3be..8a07acb86 100644
--- a/saqc/funcs/functions.py
+++ b/saqc/funcs/functions.py
@@ -11,6 +11,7 @@ import dtw
 import pywt
 import itertools
 import collections
+import numba
 from mlxtend.evaluate import permutation_test
 from scipy.cluster.hierarchy import linkage, fcluster
 
@@ -1047,13 +1048,25 @@ def flagDriftFromReference(data, field, flagger, fields, segment_freq, thresh,
 
     return data, flagger
 
+@numba.jit(nopython=True)
+def _slidingWindowSearch(data_arr, bwd_start, fwd_end, stat_func, thresh_func):
+    # hits[k] becomes 1.0 if stat_func > thresh_func for the windows split between positions k and k + 1.
+    n_splits = len(data_arr) - 1
+    hits = np.zeros(n_splits)
+    for k in range(n_splits):
+        bwd = data_arr[bwd_start[k]:k + 1]
+        fwd = data_arr[k + 1:fwd_end[k]]
+        hits[k] = stat_func(bwd, fwd) > thresh_func(bwd, fwd)
+    return hits
+
+
 @register(masking='field')
 def flagChangePoints(data, field, flagger, stat_func, thresh_func, bwd_window, min_periods_bwd,
-                     fwd_window=None, min_periods_fwd=None, closed='both'):
+                     fwd_window=None, min_periods_fwd=None, closed='both', try_to_jit=True):
     """
     Function for change point detection based on sliding window search.
 
-    The function provides general basic architecture for applying two-sided t-test,
+    The function provides general basic architecture for applying two-sided t-tests,
     max-likelyhood modelling or piecewise regression modelling in order to detect changepoints
     via a sliding "twin window" search.
 
@@ -1083,12 +1096,16 @@ def flagChangePoints(data, field, flagger, stat_func, thresh_func, bwd_window, m
     data_ser = data[field]
     center = False
     var_len = data_ser.shape[0]
+    if try_to_jit:
+        stat_func = numba.jit(stat_func)
+        thresh_func = numba.jit(thresh_func)
+
     FI = FreqIndexer()
-    FI.index_array=data_ser.index.to_numpy(int)
+    FI.index_array = data_ser.index.to_numpy(int)
     FI.win_points = np.array([True]*var_len)
     FI.window_size = int(pd.Timedelta(bwd_window).total_seconds() * 10 ** 9)
     FI.forward = False
-    bwd_start,  bwd_end = FI.get_window_bounds(var_len, min_periods_bwd, center, closed)
+    bwd_start, bwd_end = FI.get_window_bounds(var_len, min_periods_bwd, center, closed)
 
     FI.window_size = int(pd.Timedelta(fwd_window).total_seconds() * 10 ** 9)
     FI.forward = True
@@ -1096,11 +1113,6 @@ def flagChangePoints(data, field, flagger, stat_func, thresh_func, bwd_window, m
     fwd_start, fwd_end = np.roll(fwd_start, -1), np.roll(fwd_end, -1)
 
     data_arr = data_ser.values
-    result_arr = np.zeros(len(data_arr) - 1)
-    for win_i in range(len(data_arr) - 1):
-        x = data_arr[bwd_start[win_i]:bwd_end[win_i]]
-        y = data_arr[fwd_start[win_i]:fwd_end[win_i]]
-        result_arr[win_i] = stat_func(x, y) > thresh_func(x, y)
-
+    result_arr = _slidingWindowSearch(data_arr, bwd_start, fwd_end, stat_func, thresh_func)
     flagger = flagger.setFlags(field, loc=result_arr[result_arr])
     return data, flagger
\ No newline at end of file
diff --git a/saqc/lib/tools.py b/saqc/lib/tools.py
index d051ec22b..11c1adb45 100644
--- a/saqc/lib/tools.py
+++ b/saqc/lib/tools.py
@@ -469,5 +469,3 @@ def customRolling(to_roll, winsz, func, roll_mask, min_periods=1, center=False,
 
     return pd.Series(i_roll.values, index=to_roll.index)
 
-
-
-- 
GitLab