From 5576f0c98c9522f6c9d892364cff86d3374613ca Mon Sep 17 00:00:00 2001
From: Peter Luenenschloss <peter.luenenschloss@ufz.de>
Date: Mon, 6 Jul 2020 12:03:08 +0200
Subject: [PATCH] linear correction implemented

---
 saqc/funcs/proc_functions.py | 54 ++++++++++++++++++++++--------------
 1 file changed, 33 insertions(+), 21 deletions(-)

diff --git a/saqc/funcs/proc_functions.py b/saqc/funcs/proc_functions.py
index 2c002eb79..6d6ee9c2b 100644
--- a/saqc/funcs/proc_functions.py
+++ b/saqc/funcs/proc_functions.py
@@ -10,7 +10,8 @@ import dios
 import functools
 import matplotlib.pyplot as plt
 from scipy.optimize import curve_fit
-import pickle
+from sklearn.linear_model import LinearRegression
+
 ORIGINAL_SUFFIX = '_original'
 
 METHOD2ARGS = {'inverse_fshift': ('backward', pd.Timedelta),
@@ -158,8 +159,8 @@ def proc_interpolateGrid(data, field, flagger, freq, method, inter_order=2, drop
 
     Note, it is possible to interpolate unregular "grids" (with no frequencies). In fact, any date index
     can be target of the interpolation. Just pass the field name of the variable, holding the index
-    you want to interpolate, to "grid_field". The feature is currently regarded experimental. Interpolation
-    range can not be controlled.
+    you want to interpolate, to "grid_field". 'freq' is then use to determine the maximum gap size for
+    a grid point to be interpolated.
 
     Parameters
     ---------.copy()
@@ -742,13 +743,6 @@ def proc_seefoExpDriftCorrecture(data, field, flagger, maint_data_field, cal_mea
     # define target values for correction
     shift_targets = drift_grouper.aggregate(lambda x: x[:cal_mean].mean()).shift(-1)
 
-    ########################### plotting stuff for testing phase #############################################
-    fig, axes = plt.subplots(nrows=2, ncols=1, sharex=True)
-    axes[0].plot(to_correct[drift_frame.index[0]:drift_frame.index[-1]])
-    axes[0].set(ylabel='sak')
-    axes[1].set(ylabel='shifted - sak')
-    ##########################################################################################################
-
     for k, group in drift_grouper:
         dataSeries = group[to_correct.name]
         dataFit, dataShiftTarget = _drift_fit(dataSeries, shift_targets.loc[k, :][0], cal_mean)
@@ -757,21 +751,39 @@ def proc_seefoExpDriftCorrecture(data, field, flagger, maint_data_field, cal_mea
         dataShiftVektor = dataShiftTarget - dataFit
         shiftedData = dataSeries + dataShiftVektor
         to_correct[shiftedData.index] = shiftedData
-    ########################### plotting stuff for testing phase ##################################################
-        axes[0].plot(dataFit, color='red')
-        axes[0].plot(dataShiftTarget, color='yellow')
-        axes[1].plot(shiftedData, color='green')
-
-    axes[0].vlines(maint_data[drift_frame.index[0]:drift_frame.index[-1]].index, to_correct.min(), to_correct.max(), color='black')
-    axes[0].vlines(maint_data[drift_frame.index[0]:drift_frame.index[-1]].values, to_correct.min(), to_correct.max(), color='black')
-    fig.autofmt_xdate()
-    with open('/home/luenensc/PyPojects/testSpace/SEEFOPics/DriftCorrecture2.pkl', 'wb') as file:
-        pickle.dump(fig, file)
-    ################################################################################################################
 
     if flag_maint_period:
         to_flag = drift_frame['drift_group']
         to_flag = to_flag.drop(to_flag[:maint_data.index[0]].index)
         to_flag = to_flag[to_flag.isna()]
         flagger = flagger.setFlags(field, loc=to_flag, **kwargs)
+    return data, flagger
+
+
+@register
+def proc_seefoLinearDriftCorrecture(data, field, flagger, x_field, y_field, **kwargs):
+    """
+    Train a linear model that predicts data[y_field] by x_1*(data[x_field]) + x_0. (Least squares fit)
+
+    Then correct the data[field] via:
+
+    data[field] = data[field]*x_1 + x_0
+
+    Note, that data[x_field] and data[y_field] must be of equal length.
+    (Also, you might want them to be sampled at same timestamps.)
+
+    Parameters
+    ----------
+    x_field : String
+        Field name of x - data.
+    y_field : String
+        Field name of y - data.
+
+    """
+    data = data.copy()
+    datcol = data[field]
+    reg = LinearRegression()
+    reg.fit(data[x_field].values.reshape(-1,1), data[y_field].values)
+    datcol = (datcol * reg.coef_[0]) + reg.intercept_
+    data[field] = datcol
     return data, flagger
\ No newline at end of file
-- 
GitLab