From c9402f5e8396795bbf6a48e81e7f70fe70749c69 Mon Sep 17 00:00:00 2001 From: Peter Luenenschloss <peter.luenenschloss@ufz.de> Date: Sat, 26 Sep 2020 15:43:41 +0200 Subject: [PATCH] jitted changepoint detection - inner loop --- saqc/funcs/functions.py | 32 ++++++++++++++++++++++---------- saqc/lib/tools.py | 2 -- 2 files changed, 22 insertions(+), 12 deletions(-) diff --git a/saqc/funcs/functions.py b/saqc/funcs/functions.py index 63c30a3be..8a07acb86 100644 --- a/saqc/funcs/functions.py +++ b/saqc/funcs/functions.py @@ -11,6 +11,7 @@ import dtw import pywt import itertools import collections +import numba from mlxtend.evaluate import permutation_test from scipy.cluster.hierarchy import linkage, fcluster @@ -1047,13 +1048,25 @@ def flagDriftFromReference(data, field, flagger, fields, segment_freq, thresh, return data, flagger +@numba.jit(nopython=True) +def _slidingWindowSearch(data_arr, bwd_start, fwd_end, stat_func, thresh_func): + result_arr = np.zeros(len(data_arr) - 1) + + for win_i in range(1, len(data_arr)): + x = data_arr[bwd_start[win_i - 1]:win_i] + y = data_arr[win_i:fwd_end[win_i - 1]] + result_arr[win_i - 1] = stat_func(x, y) > thresh_func(x, y) + + return result_arr + + @register(masking='field') def flagChangePoints(data, field, flagger, stat_func, thresh_func, bwd_window, min_periods_bwd, - fwd_window=None, min_periods_fwd=None, closed='both'): + fwd_window=None, min_periods_fwd=None, closed='both', try_to_jit=True): """ Function for change point detection based on sliding window search. - The function provides general basic architecture for applying two-sided t-test, + The function provides general basic architecture for applying two-sided t-tests, max-likelyhood modelling or piecewise regression modelling in order to detect changepoints via a sliding "twin window" search. @@ -1083,12 +1096,16 @@ def flagChangePoints(data, field, flagger, stat_func, thresh_func, bwd_window, m data_ser = data[field] center = False var_len = data_ser.shape[0] + if try_to_jit: + stat_func = numba.jit(stat_func) + thresh_func = numba.jit(thresh_func) + FI = FreqIndexer() - FI.index_array=data_ser.index.to_numpy(int) + FI.index_array = data_ser.index.to_numpy(int) FI.win_points = np.array([True]*var_len) FI.window_size = int(pd.Timedelta(bwd_window).total_seconds() * 10 ** 9) FI.forward = False - bwd_start, bwd_end = FI.get_window_bounds(var_len, min_periods_bwd, center, closed) + bwd_start, bwd_end = FI.get_window_bounds(var_len, min_periods_bwd, center, closed) FI.window_size = int(pd.Timedelta(fwd_window).total_seconds() * 10 ** 9) FI.forward = True @@ -1096,11 +1113,6 @@ def flagChangePoints(data, field, flagger, stat_func, thresh_func, bwd_window, m fwd_start, fwd_end = np.roll(fwd_start, -1), np.roll(fwd_end, -1) data_arr = data_ser.values - result_arr = np.zeros(len(data_arr) - 1) - for win_i in range(len(data_arr) - 1): - x = data_arr[bwd_start[win_i]:bwd_end[win_i]] - y = data_arr[fwd_start[win_i]:fwd_end[win_i]] - result_arr[win_i] = stat_func(x, y) > thresh_func(x, y) - + result_arr = _slidingWindowSearch(data_arr, bwd_start, fwd_end, stat_func, thresh_func) flagger = flagger.setFlags(field, loc=result_arr[result_arr]) return data, flagger \ No newline at end of file diff --git a/saqc/lib/tools.py b/saqc/lib/tools.py index d051ec22b..11c1adb45 100644 --- a/saqc/lib/tools.py +++ b/saqc/lib/tools.py @@ -469,5 +469,3 @@ def customRolling(to_roll, winsz, func, roll_mask, min_periods=1, center=False, return pd.Series(i_roll.values, index=to_roll.index) - - -- GitLab