From f8f462139299a6b2bf63665eb7aa74ccc1eeabfa Mon Sep 17 00:00:00 2001
From: luenensc <peter.luenenschloss@ufz.de>
Date: Tue, 16 Apr 2024 21:13:08 +0200
Subject: [PATCH] fixed target handling

---
 saqc/funcs/rolling.py | 50 ++++++++++++++++++++++---------------------
 1 file changed, 26 insertions(+), 24 deletions(-)

diff --git a/saqc/funcs/rolling.py b/saqc/funcs/rolling.py
index 0b0666714..327c32f13 100644
--- a/saqc/funcs/rolling.py
+++ b/saqc/funcs/rolling.py
@@ -32,9 +32,9 @@ class RollingMixin:
     )
     def rolling(
         self: "SaQC",
-        field: str,
+        field: str | list[str],
         window: str | int,
-        target: str = None,
+        target: str | list[str] = None,
         func: Callable[[pd.Series], np.ndarray] | str = "mean",
         min_periods: int = 0,
         center: bool = True,
@@ -66,27 +66,14 @@ class RollingMixin:
             If True, center the rolling window.
         """
         # HINT: checking in  _roll
-        if len(field) == 1:
-            if target:
-                self = self.copyField(field[0], target=target)
-                field = target
-
-            self._data, self._flags = _roll(
-                data=self._data,
-                field=field,
-                flags=self._flags,
-                window=window,
-                func=func,
-                min_periods=min_periods,
-                center=center,
-                **kwargs,
+        if target and (len(target) > 1) and (len(field) != len(target)):
+            raise ValueError(
+                f"""If multiple targets are given, per-field application of rolling is conducted and the number of 
+                fields has to equal the number of targets.\n Got: \n Fields={field} \n Targets={target}"""
             )
-        else:
-            if not target:
-                raise ValueError(
-                    "Target has to be assigned for cross statistics calculations."
-                )
 
+        if target and (len(field) > 1) and (len(target) == 1):
+            target = target[0]
             if target not in self._data.columns:
                 self[target] = saqc.SaQC(
                     pd.Series(
@@ -105,6 +92,22 @@ class RollingMixin:
                 center=center,
             )
 
+        else:
+            if target:
+                for ft in zip(field, target):
+                    self = self.copyField(ft[0], target=ft[1], overwrite=True)
+                field = target
+            for f in field:
+                self._data, self._flags = _roll(
+                    data=self._data,
+                    field=f,
+                    flags=self._flags,
+                    window=window,
+                    func=func,
+                    min_periods=min_periods,
+                    center=center,
+                    **kwargs,
+                )
         return self
 
     @register(mask=["field"], demask=[], squeeze=[])
@@ -221,9 +224,8 @@ def _hroll(
     if center:
         f_out = f_out.shift(-int(np.floor(window / 2)))
 
-    for f in target:
-        data[f] = d_out
-        flags[f] = f_out
+    data[target] = d_out
+    flags[target] = f_out
 
     return data, flags
 
-- 
GitLab