From 7fa14a1325081392d73dcfd51275be16c9693246 Mon Sep 17 00:00:00 2001
From: Bert Palm <bert.palm@ufz.de>
Date: Tue, 2 Mar 2021 21:54:09 +0100
Subject: [PATCH] drift: port flagging to the new Flagger indexing API
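
Port the flagging calls in saqc/funcs/drift.py from the deprecated
flagger.setFlags(field, loc=index, **kwargs) method to in-place
subscript assignment (flagger[index, field] = kwargs['flag']), fix the
docstring type references (saqc.flagger -> saqc.flagger.Flagger) and
tidy up blank lines and statement ordering.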

---
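Reviewer note: every former setFlags call now uses the subscript
contract flagger[row_index, field] = flag_value. Below is a minimal,
self-contained sketch of that contract; ToyFlagger is a hypothetical
stand-in for illustration only, not the real saqc.flagger.Flagger, and
NaN-as-unflagged is an assumption of the sketch:

    import pandas as pd

    class ToyFlagger:
        """Hypothetical stand-in mimicking the (index, field)
        __setitem__ contract targeted by this patch."""

        def __init__(self, data):
            # one float flag column per data column; NaN == unflagged here
            self._flags = pd.DataFrame(
                float("nan"), index=data.index, columns=data.columns
            )

        def __setitem__(self, key, value):
            rows, field = key  # unpack (row index, column name)
            self._flags.loc[rows, field] = value

    idx = pd.date_range("2021-03-02", periods=4, freq="D")
    data = pd.DataFrame({"temp": [1.0, 2.0, 99.0, 3.0]}, index=idx)
    flagger = ToyFlagger(data)

    # old style (removed):  flagger = flagger.setFlags("temp", loc=idx[2:3], ...)
    # new style (this patch): assign the flag value in place
    flagger[idx[2:3], "temp"] = 255.0
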
 saqc/funcs/drift.py | 83 ++++++++++++++++++++++++++++-----------------
 1 file changed, 51 insertions(+), 32 deletions(-)

diff --git a/saqc/funcs/drift.py b/saqc/funcs/drift.py
index fd7d0dcb9..173f3ce51 100644
--- a/saqc/funcs/drift.py
+++ b/saqc/funcs/drift.py
@@ -126,14 +126,17 @@ def flagDriftFromNorm(
 
     data_to_flag = data[fields].to_df()
     data_to_flag.dropna(inplace=True)
+
     segments = data_to_flag.groupby(pd.Grouper(freq=segment_freq))
     for segment in segments:
+
         if segment[1].shape[0] <= 1:
             continue
+
         drifters = detectDeviants(segment[1], metric, norm_spread, norm_frac, linkage_method, 'variables')
 
         for var in drifters:
-            flagger = flagger.setFlags(fields[var], loc=segment[1].index, **kwargs)
+            flagger[segment[1].index, fields[var]] = kwargs['flag']
 
     return data, flagger
 
@@ -193,20 +196,24 @@ def flagDriftFromReference(
 
     data_to_flag = data[fields].to_df()
     data_to_flag.dropna(inplace=True)
+
     fields = list(fields)
     if field not in fields:
         fields.append(field)
+
     var_num = len(fields)
-    segments = data_to_flag.groupby(pd.Grouper(freq=segment_freq))
 
+    segments = data_to_flag.groupby(pd.Grouper(freq=segment_freq))
     for segment in segments:
 
         if segment[1].shape[0] <= 1:
             continue
+
         for i in range(var_num):
             dist = metric(segment[1].iloc[:, i].values, segment[1].loc[:, field].values)
+
             if dist > thresh:
-                flagger = flagger.setFlags(fields[i], loc=segment[1].index, **kwargs)
+                flagger[segment[1].index, fields[i]] = kwargs['flag']
 
     return data, flagger
 
@@ -246,7 +253,7 @@ def flagDriftFromScaledNorm(
         A dictionary of pandas.Series, holding all the data.
     field : str
         A dummy parameter.
-    flagger : saqc.flagger
+    flagger : saqc.flagger.Flagger
         A flagger object, holding flags and additional informations related to `data`.
     fields_scale1 : str
         List of fieldnames in data to be included into the flagging process which are scaled according to scaling
@@ -280,7 +287,7 @@ def flagDriftFromScaledNorm(
     -------
     data : dios.DictOfSeries
         A dictionary of pandas.Series, holding all the data.
-    flagger : saqc.flagger
+    flagger : saqc.flagger.Flagger
         The flagger object, holding flags and additional Informations related to `data`.
         Flags values may have changed relatively to the input flagger.
 
@@ -318,11 +325,14 @@ def flagDriftFromScaledNorm(
 
     segments = dat_to_flag.groupby(pd.Grouper(freq=segment_freq))
     for segment in segments:
+
         if segment[1].shape[0] <= 1:
             continue
+
         drifters = detectDeviants(segment[1], metric, norm_spread, norm_frac, linkage_method, 'variables')
+
         for var in drifters:
-            flagger = flagger.setFlags(fields[var], loc=segment[1].index, **kwargs)
+            flagger[segment[1].index, fields[var]] = kwargs['flag']
 
     return data, flagger
 
@@ -395,22 +405,25 @@ def correctExponentialDrift(
         The flagger object, holding flags and additional Informations related to `data`.
         Flags values may have changed relatively to the flagger input.
     """
-
     # 1: extract fit intervals:
     if data[maint_data_field].empty:
         return data, flagger
+
     data = data.copy()
     to_correct = data[field]
     maint_data = data[maint_data_field]
-    drift_frame = pd.DataFrame({"drift_group": np.nan, to_correct.name: to_correct.values}, index=to_correct.index)
+
+    d = {"drift_group": np.nan, to_correct.name: to_correct.values}
+    drift_frame = pd.DataFrame(d, index=to_correct.index)
 
     # group the drift frame
     for k in range(0, maint_data.shape[0] - 1):
         # assign group numbers for the timespans in between one maintenance ending and the beginning of the next
         # maintenance time itself remains np.nan assigned
         drift_frame.loc[maint_data.values[k] : pd.Timestamp(maint_data.index[k + 1]), "drift_group"] = k
-    drift_grouper = drift_frame.groupby("drift_group")
+
     # define target values for correction
+    drift_grouper = drift_frame.groupby("drift_group")
     shift_targets = drift_grouper.aggregate(lambda x: x[:cal_mean].mean()).shift(-1)
 
     for k, group in drift_grouper:
@@ -422,13 +435,13 @@ def correctExponentialDrift(
         shiftedData = dataSeries + dataShiftVektor
         to_correct[shiftedData.index] = shiftedData
 
+    data[field] = to_correct
+
     if flag_maint_period:
         to_flag = drift_frame["drift_group"]
         to_flag = to_flag.drop(to_flag[: maint_data.index[0]].index)
-        to_flag = to_flag[to_flag.isna()]
-        flagger = flagger.setFlags(field, loc=to_flag, **kwargs)
-
-    data[field] = to_correct
+        to_flag = to_flag[to_flag.isna()]  # maintenance periods kept the NaN group label
+        flagger[to_flag.index, field] = kwargs['flag']
 
     return data, flagger
 
@@ -461,7 +474,7 @@ def correctRegimeAnomaly(
         A dictionary of pandas.Series, holding all the data.
     field : str
         The fieldname of the data column, you want to correct.
-    flagger : saqc.flagger
+    flagger : saqc.flagger.Flagger
         A flagger object, holding flags and additional Informations related to `data`.
     cluster_field : str
         A string denoting the field in data, holding the cluster label for the data you want to correct.
@@ -484,7 +497,7 @@ def correctRegimeAnomaly(
     data : dios.DictOfSeries
         A dictionary of pandas.Series, holding all the data.
         Data values may have changed relatively to the data input.
-    flagger : saqc.flagger
+    flagger : saqc.flagger.Flagger
         The flagger object, holding flags and additional Informations related to `data`.
     """
 
@@ -566,7 +579,7 @@ def correctOffset(
         A dictionary of pandas.Series, holding all the data.
     field : str
         The fieldname of the data column, you want to correct.
-    flagger : saqc.flagger
+    flagger : saqc.flagger.Flagger
         A flagger object, holding flags and additional Informations related to `data`.
     max_mean_jump : float
         when searching for changepoints in mean - this is the threshold a mean difference in the
@@ -590,7 +603,7 @@ def correctOffset(
     data : dios.DictOfSeries
         A dictionary of pandas.Series, holding all the data.
         Data values may have changed relatively to the data input.
-    flagger : saqc.flagger
+    flagger : saqc.flagger.Flagger
         The flagger object, holding flags and additional Informations related to `data`.
 
     """
@@ -674,7 +687,7 @@ def flagRegimeAnomaly(
         A dictionary of pandas.Series, holding all the data.
     field : str
         The fieldname of the column, holding the data-to-be-flagged.
-    flagger : saqc.flagger
+    flagger : saqc.flagger.Flagger
         A flagger object, holding flags and additional Informations related to `data`.
     cluster_field : str
         The name of the column in data, holding the cluster labels for the samples in field. (has to be indexed
@@ -694,17 +707,23 @@ def flagRegimeAnomaly(
 
     data : dios.DictOfSeries
         A dictionary of pandas.Series, holding all the data.
-    flagger : saqc.flagger
+    flagger : saqc.flagger.Flagger
         The flagger object, holding flags and additional informations related to `data`.
         Flags values may have changed, relatively to the flagger input.
 
     """
 
-    data, flagger = assignRegimeAnomaly(data, field, flagger, cluster_field, norm_spread,
-                                        linkage_method=linkage_method, metric=metric, norm_frac=norm_frac,
-                                        set_cluster=False, set_flags=True, **kwargs)
-
-    return data, flagger
+    return assignRegimeAnomaly(
+        data, field, flagger,
+        cluster_field,
+        norm_spread,
+        linkage_method=linkage_method,
+        metric=metric,
+        norm_frac=norm_frac,
+        set_cluster=False,
+        set_flags=True,
+        **kwargs
+    )
 
 
 @register(masking='all', module="drift")
@@ -744,7 +763,7 @@ def assignRegimeAnomaly(
         A dictionary of pandas.Series, holding all the data.
     field : str
         The fieldname of the column, holding the data-to-be-flagged.
-    flagger : saqc.flagger
+    flagger : saqc.flagger.Flagger
         A flagger object, holding flags and additional Informations related to `data`.
     cluster_field : str
         The name of the column in data, holding the cluster labels for the samples in field. (has to be indexed
@@ -770,25 +789,25 @@ def assignRegimeAnomaly(
 
     data : dios.DictOfSeries
         A dictionary of pandas.Series, holding all the data.
-    flagger : saqc.flagger
+    flagger : saqc.flagger.Flagger
         The flagger object, holding flags and additional informations related to `data`.
         Flags values may have changed, relatively to the flagger input.
 
     """
 
-    clusterser = data[cluster_field]
-    cluster = np.unique(clusterser)
-    cluster_dios = DictOfSeries({i: data[field][clusterser == i] for i in cluster})
+    series = data[cluster_field]
+    cluster = np.unique(series)
+    cluster_dios = DictOfSeries({i: data[field][series == i] for i in cluster})
     plateaus = detectDeviants(cluster_dios, metric, norm_spread, norm_frac, linkage_method, 'samples')
 
     if set_flags:
         for p in plateaus:
-            flagger = flagger.setFlags(field, loc=cluster_dios.iloc[:, p].index, **kwargs)
+            flagger[cluster_dios.iloc[:, p].index, field] = kwargs['flag']
 
     if set_cluster:
         for p in plateaus:
             if cluster[p] > 0:
-                clusterser[clusterser == cluster[p]] = -cluster[p]
+                series[series == cluster[p]] = -cluster[p]
 
-    data[cluster_field] = clusterser
+    data[cluster_field] = series
     return data, flagger
-- 
GitLab