From 469c7734fff7a26a5f86aebec5ac176d28d40656 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?David=20Sch=C3=A4fer?= <david.schaefer@ufz.de>
Date: Thu, 21 Jan 2021 19:52:11 +0100
Subject: [PATCH] changepoints: type hints

---
 saqc/funcs/changepoints.py | 88 ++++++++++++++++++++++++--------------
 1 file changed, 57 insertions(+), 31 deletions(-)

diff --git a/saqc/funcs/changepoints.py b/saqc/funcs/changepoints.py
index a048f8142..7a92e470e 100644
--- a/saqc/funcs/changepoints.py
+++ b/saqc/funcs/changepoints.py
@@ -1,22 +1,38 @@
 #! /usr/bin/env python
 # -*- coding: utf-8 -*-
 
+import logging
+
 import pandas as pd
 import numpy as np
 import numba
+from typing import Callable, Union, Tuple
+from typing_extensions import Literal
+
+from dios import DictOfSeries
 
 from saqc.core.register import register
 from saqc.lib.tools import customRoller
-import logging
+from saqc.flagger.baseflagger import BaseFlagger
 
 logger = logging.getLogger("SaQC")
 
 
 @register(masking='field')
-def flagChangePoints(data, field, flagger, stat_func, thresh_func, bwd_window, min_periods_bwd,
-                     fwd_window=None, min_periods_fwd=None, closed='both', try_to_jit=True,
-                     reduce_window=None, reduce_func=lambda x, y: x.argmax(), flag_changepoints=False,
-                     _model_by_resids=False, _assign_cluster=True, **kwargs):
+def flagChangePoints(data: DictOfSeries, field: str, flagger: BaseFlagger,
+                     stat_func: Callable[[np.array], np.array],
+                     thresh_func: Callable[[np.array], np.array],
+                     bwd_window: str,
+                     min_periods_bwd: Union[str, int],
+                     fwd_window: str=None,
+                     min_periods_fwd: Union[str, int]=None,
+                     closed: Literal["right", "left", "both", "neither"]="both",
+                     try_to_jit: bool=True,
+                     reduce_window: str=None,
+                     reduce_func: Callable[[np.array, np.array], np.array]=lambda x, y: x.argmax(),
+                     model_by_resids: bool=False,
+                     assign_cluster: bool=True,
+                     **kwargs) -> Tuple[DictOfSeries, BaseFlagger]:
     """
     Flag datapoints, where the parametrization of the process, the data is assumed to generate by, significantly
     changes.
@@ -42,21 +58,21 @@ def flagChangePoints(data, field, flagger, stat_func, thresh_func, bwd_window, m
     min_periods_bwd : {str, int}
         Minimum number of periods that have to be present in a backwards facing window, for a changepoint test to be
         performed.
-    fwd_window : {Non/home/luenensc/PyPojects/testSpace/flagBasicMystery.pye, str}, default None
-        The right (fo/home/luenensc/PyPojects/testSpace/flagBasicMystery.pyrward facing) windows temporal extension (freq-string).
+    fwd_window : {None, str}, default None
+        The right (forward facing) windows temporal extension (freq-string).
     min_periods_fwd : {None, str, int}, default None
-        Minimum numbe/home/luenensc/PyPojects/testSpace/flagBasicMystery.pyr of periods that have to be present in a forward facing window, for a changepoint test to be
+        Minimum number of periods that have to be present in a forward facing window, for a changepoint test to be
         performed.
     closed : {'right', 'left', 'both', 'neither'}, default 'both'
         Determines the closure of the sliding windows.
-    reduce_window : {None, False, str}, default None
+    reduce_window : {None, str}, default None
         The sliding window search method is not an exact CP search method and usually there wont be
         detected a single changepoint, but a "region" of change around a changepoint.
-        If `reduce_window` is not False, for every window of size `reduce_window`, there
+        If `reduce_window` is given, for every window of size `reduce_window`, there
         will be selected the value with index `reduce_func(x, y)` and the others will be dropped.
         If `reduce_window` is None, the reduction window size equals the
         twin window size, the changepoints have been detected with.
-    reduce_func : Callable[numpy.array, numpy.array], default lambda x, y: x.argmax()
+    reduce_func : Callable[[numpy.array, numpy.array], numpy.array], default lambda x, y: x.argmax()
         A function that must return an index value upon input of two arrays x and y.
         First input parameter will hold the result from the stat_func evaluation for every
         reduction window. Second input parameter holds the result from the thresh_func evaluation.
@@ -72,16 +88,28 @@ def flagChangePoints(data, field, flagger, stat_func, thresh_func, bwd_window, m
                                              bwd_window=bwd_window, min_periods_bwd=min_periods_bwd,
                                              fwd_window=fwd_window, min_periods_fwd=min_periods_fwd, closed=closed,
                                              try_to_jit=try_to_jit, reduce_window=reduce_window,
-                                             reduce_func=reduce_func, flag_changepoints=True, _model_by_resids=False,
-                                             _assign_cluster=False)
+                                             reduce_func=reduce_func, flag_changepoints=True, model_by_resids=False,
+                                             assign_cluster=False, **kwargs)
     return data, flagger
 
 
 @register(masking='field')
-def assignChangePointCluster(data, field, flagger, stat_func, thresh_func, bwd_window, min_periods_bwd,
-                             fwd_window=None, min_periods_fwd=None, closed='both', try_to_jit=True,
-                             reduce_window=None, reduce_func=lambda x, y: x.argmax(), flag_changepoints=False,
-                             _model_by_resids=False, _assign_cluster=True, **kwargs):
+def assignChangePointCluster(data: DictOfSeries, field: str, flagger: BaseFlagger,
+                             stat_func: Callable[[np.array], np.array],
+                             thresh_func: Callable[[np.array], np.array],
+                             bwd_window: str,
+                             min_periods_bwd: Union[str, int],
+                             fwd_window: str=None,
+                             min_periods_fwd: Union[str, int]=None,
+                             closed: Literal["right", "left", "both", "neither"]="both",
+                             try_to_jit: bool=True,
+                             reduce_window: str=None,
+                             reduce_func: Callable[[np.array, np.array], np.array]=lambda x, y: x.argmax(),
+                             model_by_resids: bool=False,
+                             flag_changepoints: bool=False,
+                             assign_cluster: bool=True,
+                             **kwargs) -> Tuple[DictOfSeries, BaseFlagger]:
+
     """
     Assigns label to the data, aiming to reflect continous regimes of the processes the data is assumed to be
     generated by.
@@ -109,30 +137,30 @@ def assignChangePointCluster(data, field, flagger, stat_func, thresh_func, bwd_w
     min_periods_bwd : {str, int}
         Minimum number of periods that have to be present in a backwards facing window, for a changepoint test to be
         performed.
-    fwd_window : {Non/home/luenensc/PyPojects/testSpace/flagBasicMystery.pye, str}, default None
-        The right (fo/home/luenensc/PyPojects/testSpace/flagBasicMystery.pyrward facing) windows temporal extension (freq-string).
+    fwd_window : {None, str}, default None
+        The right (forward facing) windows temporal extension (freq-string).
     min_periods_fwd : {None, str, int}, default None
-        Minimum numbe/home/luenensc/PyPojects/testSpace/flagBasicMystery.pyr of periods that have to be present in a forward facing window, for a changepoint test to be
+        Minimum number of periods that have to be present in a forward facing window, for a changepoint test to be
         performed.
     closed : {'right', 'left', 'both', 'neither'}, default 'both'
         Determines the closure of the sliding windows.
-    reduce_window : {None, False, str}, default None
+    reduce_window : {None, str}, default None
         The sliding window search method is not an exact CP search method and usually there wont be
         detected a single changepoint, but a "region" of change around a changepoint.
-        If `reduce_window` is not False, for every window of size `reduce_window`, there
+        If `reduce_window` is given, for every window of size `reduce_window`, there
         will be selected the value with index `reduce_func(x, y)` and the others will be dropped.
         If `reduce_window` is None, the reduction window size equals the
         twin window size, the changepoints have been detected with.
-    reduce_func : Callable[numpy.array, numpy.array], default lambda x, y: x.argmax()
+    reduce_func : Callable[[numpy.array, numpy.array], numpy.array], default lambda x, y: x.argmax()
         A function that must return an index value upon input of two arrays x and y.
         First input parameter will hold the result from the stat_func evaluation for every
         reduction window. Second input parameter holds the result from the thresh_func evaluation.
         The default reduction function just selects the value that maximizes the stat_func.
     flag_changepoints : bool, default False
         If true, the points, where there is a change in data modelling regime detected get flagged bad.
-    _model_by_resids : bool, default False
+    model_by_resids : bool, default False
         If True, the data is replaced by the stat_funcs results instead of regime labels.
-    _assign_cluster : bool, default True
+    assign_cluster : bool, default True
         Is set to False, if called by function that oly wants to calculate flags.
 
     Returns
@@ -141,8 +169,6 @@ def assignChangePointCluster(data, field, flagger, stat_func, thresh_func, bwd_w
     """
     data = data.copy()
     data_ser = data[field].dropna()
-    center = False
-    var_len = data_ser.shape[0]
     if fwd_window is None:
         fwd_window = bwd_window
     if min_periods_fwd is None:
@@ -185,7 +211,7 @@ def assignChangePointCluster(data, field, flagger, stat_func, thresh_func, bwd_w
                                                     check_len)
     result_arr = stat_arr > thresh_arr
 
-    if _model_by_resids:
+    if model_by_resids:
         residues = pd.Series(np.nan, index=data[field].index)
         residues[masked_index] = stat_arr
         data[field] = residues
@@ -194,7 +220,7 @@ def assignChangePointCluster(data, field, flagger, stat_func, thresh_func, bwd_w
 
     det_index = masked_index[result_arr]
     detected = pd.Series(True, index=det_index)
-    if reduce_window is not False:
+    if reduce_window:
         l = detected.shape[0]
         roller = customRoller(detected, window=reduce_window)
         start, end = roller.window.get_window_bounds(num_values=l, min_periods=1, closed='both', center=True)
@@ -202,7 +228,7 @@ def assignChangePointCluster(data, field, flagger, stat_func, thresh_func, bwd_w
         detected = _reduceCPCluster(stat_arr[result_arr], thresh_arr[result_arr], start, end, reduce_func, l)
         det_index = det_index[detected]
 
-    if _assign_cluster:
+    if assign_cluster:
         cluster = pd.Series(False, index=data[field].index)
         cluster[det_index] = True
         cluster = cluster.cumsum()
@@ -248,4 +274,4 @@ def _reduceCPCluster(stat_arr, thresh_arr, start, end, obj_func, num_val):
         pos = s + obj_func(x, y) + 1
         out_arr[s:e] = False
         out_arr[pos] = True
-    return out_arr
\ No newline at end of file
+    return out_arr
-- 
GitLab