From f385057f1ed55112917ec3717a981d0d1ef24084 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?David=20Sch=C3=A4fer?= <david.schaefer@ufz.de>
Date: Thu, 21 Jan 2021 09:47:10 +0100
Subject: [PATCH] drift: type hints

---
 saqc/funcs/drift.py | 115 ++++++++++++++++++++++++++++----------------
 1 file changed, 74 insertions(+), 41 deletions(-)

diff --git a/saqc/funcs/drift.py b/saqc/funcs/drift.py
index 71061240b..a84ddf657 100644
--- a/saqc/funcs/drift.py
+++ b/saqc/funcs/drift.py
@@ -1,16 +1,19 @@
 #! /usr/bin/env python
 # -*- coding: utf-8 -*-
 import functools
+from typing import Optional, Tuple, Sequence, Callable, Any, Optional
+from typing_extensions import Literal
 
-import dios
 import numpy as np
 import pandas as pd
-import scipy
 from scipy import stats
 from scipy.optimize import curve_fit
+from scipy.spatial.distance import pdist
 
+from dios import DictOfSeries
 
 from saqc.core.register import register
+from saqc.flagger.baseflagger import BaseFlagger
 from saqc.funcs.resampling import shift
 from saqc.funcs.changepoints import assignChangePointCluster
 from saqc.funcs.tools import drop, copy
@@ -19,10 +22,14 @@ from saqc.lib.ts_operators import expModelFunc
 
 
 @register(masking='all')
-def flagDriftFromNorm(data, field, flagger, fields, segment_freq, norm_spread, norm_frac=0.5,
-                      metric=lambda x, y: scipy.spatial.distance.pdist(np.array([x, y]),
-                                                                       metric='cityblock') / len(x),
-                      linkage_method='single', **kwargs):
+def flagDriftFromNorm(data: DictOfSeries, field: str, flagger: BaseFlagger,
+                      fields: Sequence[str],
+                      segment_freq: str,
+                      norm_spread: float,
+                      norm_frac: float=0.5,
+                      metric: Callable[[np.array, np.array], float]=lambda x, y: pdist(np.array([x, y]), metric='cityblock') / len(x),
+                      linkage_method: Literal["single", "complete", "average", "weighted", "centroid", "median", "ward"]="single",
+                      **kwargs) -> Tuple[DictOfSeries, BaseFlagger]:
     """
     The function flags value courses that significantly deviate from a group of normal value courses.
 
@@ -52,7 +59,7 @@ def flagDriftFromNorm(data, field, flagger, fields, segment_freq, norm_spread, n
         Has to be in [0,1]. Determines the minimum percentage of variables, the "normal" group has to comprise to be the
         normal group actually. The higher that value, the more stable the algorithm will be with respect to false
         positives. Also, nobody knows what happens, if this value is below 0.5.
-    metric : Callable[(numpyp.array, numpy-array), float]
+    metric : Callable[[numpy.array, numpy.array], float]
         A distance function. It should be a function of 2 1-dimensional arrays and return a float scalar value.
         This value is interpreted as the distance of the two input arrays. The default is the averaged manhatten metric.
         See the Notes section to get an idea of why this could be a good choice.
@@ -124,10 +131,12 @@ def flagDriftFromNorm(data, field, flagger, fields, segment_freq, norm_spread, n
 
 
 @register(masking='all')
-def flagDriftFromReference(data, field, flagger, fields, segment_freq, thresh,
-                      metric=lambda x, y: scipy.spatial.distance.pdist(np.array([x, y]),
-                                                                    metric='cityblock')/len(x),
-                       **kwargs):
+def flagDriftFromReference(data: DictOfSeries, field: str, flagger: BaseFlagger,
+                           fields: Sequence[str],
+                           segment_freq: str,
+                           thresh: float,
+                           metric: Callable[[np.array, np.array], float]=lambda x, y: pdist(np.array([x, y]), metric='cityblock') / len(x),
+                           **kwargs) -> Tuple[DictOfSeries, BaseFlagger]:
     """
     The function flags value courses that deviate from a reference course by a margin exceeding a certain threshold.
 
@@ -172,6 +181,7 @@ def flagDriftFromReference(data, field, flagger, fields, segment_freq, thresh,
 
     data_to_flag = data[fields].to_df()
     data_to_flag.dropna(inplace=True)
+    fields = list(fields)
     if field not in fields:
         fields.append(field)
     var_num = len(fields)
@@ -190,10 +200,15 @@ def flagDriftFromReference(data, field, flagger, fields, segment_freq, thresh,
 
 
 @register(masking='all')
-def flagDriftFromScaledNorm(data, field, flagger, fields_scale1, fields_scale2, segment_freq, norm_spread, norm_frac=0.5,
-                            metric=lambda x, y: scipy.spatial.distance.pdist(np.array([x, y]),
-                                                                                    metric='cityblock')/len(x),
-                            linkage_method='single', **kwargs):
+def flagDriftFromScaledNorm(data: DictOfSeries, field: str, flagger: BaseFlagger,
+                            fields_scale1: Sequence[str],
+                            fields_scale2: Sequence[str],
+                            segment_freq: str,
+                            norm_spread: float,
+                            norm_frac: float=0.5,
+                            metric: Callable[[np.array, np.array], float]=lambda x, y: pdist(np.array([x, y]), metric='cityblock') / len(x),
+                            linkage_method: Literal["single", "complete", "average", "weighted", "centroid", "median", "ward"]="single",
+                            **kwargs) -> Tuple[DictOfSeries, BaseFlagger]:
 
 
     """
@@ -261,7 +276,7 @@ def flagDriftFromScaledNorm(data, field, flagger, fields_scale1, fields_scale2,
         [2] https://en.wikipedia.org/wiki/Hierarchical_clustering
     """
 
-    fields = fields_scale1 + fields_scale2
+    fields = list(fields_scale1) + list(fields_scale2)
     data_to_flag = data[fields].to_df()
     data_to_flag.dropna(inplace=True)
 
@@ -270,14 +285,14 @@ def flagDriftFromScaledNorm(data, field, flagger, fields_scale1, fields_scale2,
 
     for field1 in fields_scale1:
         for field2 in fields_scale2:
-            slope, intercept, r_value, p_value, std_err = stats.linregress(data_to_flag[field1], data_to_flag[field2])
+            slope, intercept, _, _, _ = stats.linregress(data_to_flag[field1], data_to_flag[field2])
             convert_slope.append(slope)
             convert_intercept.append(intercept)
 
     factor_slope = np.median(convert_slope)
     factor_intercept = np.median(convert_intercept)
 
-    dat = dios.DictOfSeries()
+    dat = DictOfSeries()
     for field1 in fields_scale1:
         dat[field1] = factor_intercept + factor_slope * data_to_flag[field1]
     for field2 in fields_scale2:
@@ -297,8 +312,9 @@ def flagDriftFromScaledNorm(data, field, flagger, fields_scale1, fields_scale2,
 
 
 @register(masking='all')
-def correctExponentialDrift(data, field, flagger, maint_data_field, cal_mean=5, flag_maint_period=False,
-                            check_maint='1h', **kwargs):
+def correctExponentialDrift(data: DictOfSeries, field: str, flagger: BaseFlagger,
+                            maint_data_field: str, cal_mean: int=5, flag_maint_period: bool=False,
+                            **kwargs) -> Tuple[DictOfSeries, BaseFlagger]:
     """
     The function fits an exponential model to chunks of data[field].
     It is assumed, that between maintenance events, there is a drift effect shifting the meassurements in a way, that
@@ -347,8 +363,6 @@ def correctExponentialDrift(data, field, flagger, maint_data_field, cal_mean=5,
         directly before maintenance event. This values are needed for shift calibration. (see above description)
     flag_maint_period : bool, default False
         Wheather or not to flag BAD the values directly obtained while maintenance.
-    check_maint : bool, default True
-        Wheather or not to check, if the reported maintenance intervals match are plausible
 
     Returns
     -------
@@ -360,7 +374,6 @@ def correctExponentialDrift(data, field, flagger, maint_data_field, cal_mean=5,
         Flags values may have changed relatively to the flagger input.
     """
 
-
     # 1: extract fit intervals:
     if data[maint_data_field].empty:
         return data, flagger
@@ -399,7 +412,11 @@ def correctExponentialDrift(data, field, flagger, maint_data_field, cal_mean=5,
 
 
 @register(masking='all')
-def correctRegimeAnomaly(data, field, flagger, cluster_field, model, regime_transmission=None, x_date=False):
+def correctRegimeAnomaly(data: DictOfSeries, field: str, flagger: BaseFlagger,
+                         cluster_field: str,
+                         model: Callable[[np.array, Any], np.array],
+                         regime_transmission: Optional[str]=None,
+                         x_date: bool=False) -> Tuple[DictOfSeries, BaseFlagger]:
     """
     Function fits the passed model to the different regimes in data[field] and tries to correct
     those values, that have assigned a negative label by data[cluster_field].
@@ -420,7 +437,7 @@ def correctRegimeAnomaly(data, field, flagger, cluster_field, model, regime_tran
         The fieldname of the data column, you want to correct.
     flagger : saqc.flagger
         A flagger object, holding flags and additional Informations related to `data`.
-    clusterfield : str
+    cluster_field : str
         A string denoting the field in data, holding the cluster label for the data you want to correct.
     model : Callable
         The model function to be fitted to the regimes.
@@ -468,7 +485,7 @@ def correctRegimeAnomaly(data, field, flagger, cluster_field, model, regime_tran
             valid_mask &= (xdata > xdata[0] + regime_transmission)
             valid_mask &= (xdata < xdata[-1] - regime_transmission)
         try:
-            p, pcov = curve_fit(model, xdata[valid_mask], ydata[valid_mask])
+            p, *_ = curve_fit(model, xdata[valid_mask], ydata[valid_mask])
         except (RuntimeError, ValueError):
             p = np.array([np.nan])
         para_dict[label] = p
@@ -504,8 +521,14 @@ def correctRegimeAnomaly(data, field, flagger, cluster_field, model, regime_tran
 
 
 @register(masking='all')
-def correctOffset(data, field, flagger, max_mean_jump, normal_spread, search_winsz, min_periods,
-                  regime_transmission=None):
+def correctOffset(data: DictOfSeries, field: str, flagger: BaseFlagger,
+                  max_mean_jump: float,
+                  normal_spread: float,
+                  search_winsz: str,
+                  min_periods: int,
+                  regime_transmission: Optional[str]=None,
+                  **kwargs
+                  ) -> Tuple[DictOfSeries, BaseFlagger]:
     """
 
     Parameters
@@ -574,7 +597,7 @@ def _drift_fit(x, shift_target, cal_mean):
     dataFitFunc = functools.partial(modelWrapper, a=origin_mean, target_mean=target_mean)
 
     try:
-        fitParas, _ = curve_fit(dataFitFunc, x_data, y_data, bounds=([0], [np.inf]))
+        fitParas, *_ = curve_fit(dataFitFunc, x_data, y_data, bounds=([0], [np.inf]))
         dataFit = dataFitFunc(x_data, fitParas[0])
         b_val = (shift_target - origin_mean) / (np.exp(fitParas[0]) - 1)
         dataShiftFunc = functools.partial(expModelFunc, a=origin_mean, b=b_val, c=fitParas[0])
@@ -587,9 +610,13 @@ def _drift_fit(x, shift_target, cal_mean):
 
 
 @register(masking='all')
-def flagRegimeAnomaly(data, field, flagger, cluster_field, norm_spread, linkage_method='single',
-                       metric=lambda x, y: np.abs(np.nanmean(x) - np.nanmean(y)),
-                       norm_frac=0.5, **kwargs):
+def flagRegimeAnomaly(data: DictOfSeries, field: str, flagger: BaseFlagger,
+                      cluster_field: str,
+                      norm_spread: float,
+                      linkage_method: Literal["single", "complete", "average", "weighted", "centroid", "median", "ward"]="single",
+                      metric: Callable[[np.array, np.array], float]=lambda x, y: np.abs(np.nanmean(x) - np.nanmean(y)),
+                      norm_frac: float=0.5,
+                      **kwargs) -> Tuple[DictOfSeries, BaseFlagger]:
     """
     A function to flag values belonging to an anomalous regime regarding modelling regimes of field.
 
@@ -638,15 +665,21 @@ def flagRegimeAnomaly(data, field, flagger, cluster_field, norm_spread, linkage_
 
     data, flagger = assignRegimeAnomaly(data, field, flagger, cluster_field, norm_spread,
                                         linkage_method=linkage_method, metric=metric, norm_frac=norm_frac,
-                                        _set_cluster=False, _set_flags=True, **kwargs)
+                                        set_cluster=False, set_flags=True, **kwargs)
 
     return data, flagger
 
 
 @register(masking='all')
-def assignRegimeAnomaly(data, field, flagger, cluster_field, norm_spread, linkage_method='single',
-                        metric=lambda x, y: np.abs(np.nanmean(x) - np.nanmean(y)),
-                        norm_frac=0.5, _set_cluster=True, _set_flags=False, **kwargs):
+def assignRegimeAnomaly(data: DictOfSeries, field: str, flagger: BaseFlagger,
+                        cluster_field: str,
+                        norm_spread: float,
+                        linkage_method: Literal["single", "complete", "average", "weighted", "centroid", "median", "ward"]="single",
+                        metric: Callable[[np.array, np.array], float]=lambda x, y: np.abs(np.nanmean(x) - np.nanmean(y)),
+                        norm_frac: float=0.5,
+                        set_cluster: bool=True,
+                        set_flags: bool=False,
+                        **kwargs) -> Tuple[DictOfSeries, BaseFlagger]:
     """
     A function to detect values belonging to an anomalous regime regarding modelling regimes of field.
 
@@ -684,10 +717,10 @@ def assignRegimeAnomaly(data, field, flagger, cluster_field, norm_spread, linkag
     norm_frac : float
         Has to be in [0,1]. Determines the minimum percentage of samples,
         the "normal" group has to comprise to be the normal group actually.
-    _set_cluster : bool, default False
+    set_cluster : bool, default False
         If True, all data, considered "anormal", gets assigned a negative clusterlabel. This option
         is present for further use (correction) of the anomaly information.
-    _set_flags : bool, default True
+    set_flags : bool, default True
         Wheather or not to flag abnormal values (do not flag them, if you want to correct them
         afterwards, becasue flagged values usually are not visible in further tests.).
 
@@ -704,14 +737,14 @@ def assignRegimeAnomaly(data, field, flagger, cluster_field, norm_spread, linkag
 
     clusterser = data[cluster_field]
     cluster = np.unique(clusterser)
-    cluster_dios = dios.DictOfSeries({i: data[field][clusterser == i] for i in cluster})
+    cluster_dios = DictOfSeries({i: data[field][clusterser == i] for i in cluster})
     plateaus = detectDeviants(cluster_dios, metric, norm_spread, norm_frac, linkage_method, 'samples')
 
-    if _set_flags:
+    if set_flags:
         for p in plateaus:
             flagger = flagger.setFlags(field, loc=cluster_dios.iloc[:, p].index, **kwargs)
 
-    if _set_cluster:
+    if set_cluster:
         for p in plateaus:
             if cluster[p] > 0:
                 clusterser[clusterser == cluster[p]] = -cluster[p]
-- 
GitLab