diff --git a/saqc/funcs/functions.py b/saqc/funcs/functions.py index 9e22e3fd7ae8b0b66210a4cf8e790a70ddb9decf..91982c1af235bd90cfdfd22829654d695f435393 100644 --- a/saqc/funcs/functions.py +++ b/saqc/funcs/functions.py @@ -1052,7 +1052,7 @@ def flagDriftFromReference(data, field, flagger, fields, segment_freq, thresh, def _slidingWindowSearch(data_arr, bwd_start, fwd_end, stat_func, thresh_func, num_val): stat_arr = np.zeros(num_val) thresh_arr = np.zeros(num_val) - for win_i in numba.prange(1, len(data_arr)): + for win_i in numba.prange(1, num_val): x = data_arr[bwd_start[win_i - 1]:win_i] y = data_arr[win_i:fwd_end[win_i - 1]] stat_arr[win_i - 1] = stat_func(x, y) @@ -1060,10 +1060,22 @@ def _slidingWindowSearch(data_arr, bwd_start, fwd_end, stat_func, thresh_func, n return stat_arr, thresh_arr +@numba.jit(nopython=True, parallel=True) +def _reduceCPCluster(stat_arr, thresh_arr, start, end, obj_func, num_val, out_arr): + for win_i in numba.prange(0, num_val): + s, e = start[win_i], end[win_i] + x = stat_arr[s:e] + y = thresh_arr[s:e] + pos = s + obj_func(x, y) + out_arr[pos] = True + + return out_arr + + @register(masking='field') def flagChangePoints(data, field, flagger, stat_func, thresh_func, bwd_window, min_periods_bwd, fwd_window=None, min_periods_fwd=None, closed='both', try_to_jit=True, - agg_range=None): + agg_range=None, reduce_func=lambda x, y: x.argmax()): """ Function for change point detection based on sliding window search. 
@@ -1106,6 +1118,7 @@ def flagChangePoints(data, field, flagger, stat_func, thresh_func, bwd_window, m if try_to_jit: stat_func = numba.jit(stat_func) thresh_func = numba.jit(thresh_func) + reduce_func = numba.jit(reduce_func) indexer = FreqIndexer() indexer.index_array = data_ser.index.to_numpy(int) @@ -1124,8 +1137,13 @@ def flagChangePoints(data, field, flagger, stat_func, thresh_func, bwd_window, m stat_arr, thresh_arr = _slidingWindowSearch(data_arr, bwd_start, fwd_end, stat_func, thresh_func, var_len) result_arr = stat_arr > thresh_arr - detected = pd.Series(True, index=data_ser[result_arr].index) - cp_cluster = customRolling(detected, agg_range, count, closed='both', min_periods=1, center=True) - - flagger = flagger.setFlags(field, loc=result_arr) + det_index = data_ser[result_arr].index + detected = pd.Series(True, index=det_index) + start, end = customRolling(detected, agg_range, count, closed='both', min_periods=1, center=True, index_only=True) + out_arr = np.zeros(shape=detected.shape[0], dtype=bool) + detected = _reduceCPCluster(stat_arr[result_arr], thresh_arr[result_arr], start, end, reduce_func, + detected.shape[0], out_arr) + + det_index = det_index[detected] + flagger = flagger.setFlags(field, loc=det_index) return data, flagger \ No newline at end of file diff --git a/saqc/lib/tools.py b/saqc/lib/tools.py index b3396373a4c751156920895fcfe613f49b20906c..4987dae119b784f0d2a3dfabd036d4281bc5c82c 100644 --- a/saqc/lib/tools.py +++ b/saqc/lib/tools.py @@ -439,7 +439,7 @@ class PeriodsIndexer(BaseIndexer): def customRolling(to_roll, winsz, func, roll_mask=None, min_periods=1, center=False, closed=None, raw=True, engine=None, - forward=False): + forward=False, index_only=False): """ A wrapper around pandas.rolling.apply(), that allows for skipping func application on arbitrary selections of windows. @@ -476,8 +476,10 @@ def customRolling(to_roll, winsz, func, roll_mask=None, min_periods=1, center=Fa If true, roll with forward facing windows. 
(not yet implemented for integer defined windows.) center : bool, default False - If true, set the label to the center of the rolling window. Although available - for windows defined by sample rates! (yeah!) + If true, set the label to the center of the rolling window. Also available + for frequency defined rolling windows! (yeah!) + index_only : bool, default False + Only return rolling window indices. Returns ------- @@ -502,6 +504,10 @@ def customRolling(to_roll, winsz, func, roll_mask=None, min_periods=1, center=Fa center=center, closed=closed) + if index_only: + num_values = to_roll.shape[0] + return indexer.get_window_bounds(num_values, min_periods, center, closed) + i_roller = i_roll.rolling(indexer, min_periods=min_periods, center=center,