From d339e348f9551d30860396a78baf2123c7f988cc Mon Sep 17 00:00:00 2001
From: David Schaefer <david.schaefer@ufz.de>
Date: Thu, 1 Feb 2024 00:13:17 +0100
Subject: [PATCH] first draft - squeeze not working

---
 saqc/funcs/flagtools.py | 65 +++++++++++++++++++++--------------------
 saqc/lib/tools.py       | 14 +++++++++
 2 files changed, 48 insertions(+), 31 deletions(-)

diff --git a/saqc/funcs/flagtools.py b/saqc/funcs/flagtools.py
index 2db1839ff..3373855dd 100644
--- a/saqc/funcs/flagtools.py
+++ b/saqc/funcs/flagtools.py
@@ -19,7 +19,7 @@ from saqc import BAD, FILTER_ALL, UNFLAGGED
 from saqc.core import DictOfSeries, flagging, register
 from saqc.core.history import History
 from saqc.lib.checking import validateChoice, validateWindow
-from saqc.lib.tools import initializeTargets, isflagged, isunflagged, toSequence
+from saqc.lib.tools import initializeTargets, isflagged, isunflagged, multivariateParameters, toSequence
 
 if TYPE_CHECKING:
     from saqc import SaQC
@@ -290,6 +290,7 @@ class FlagtoolsMixin:
         demask=[],
         squeeze=[],
         handles_target=True,  # function defines a target parameter, so it needs to handle it
+        multivariate=True,
     )
     def transferFlags(
         self: "SaQC",
@@ -349,37 +350,39 @@ class FlagtoolsMixin:
            0   -inf   -inf   -inf
            1  255.0  255.0  255.0
         """
-        history = self._flags.history[field]
-
-        if target is None:
-            target = field
-
-        if overwrite is False:
-            mask = isflagged(self._flags[target], thresh=kwargs["dfilter"])
-            history._hist[mask] = np.nan
-
-        # append a dummy column
-        meta = {
-            "func": f"transferFlags",
-            "args": (),
-            "kwargs": {
-                "field": field,
-                "target": target,
-                "squeeze": squeeze,
-                "overwrite": overwrite,
-                **kwargs,
-            },
-        }
-
-        if squeeze:
-            flags = history.squeeze(raw=True)
-            # init an empty history to which we later append the squeezed flags
-            history = History(index=history.index)
-        else:
-            flags = pd.Series(np.nan, index=history.index, dtype=float)
 
-        history.append(flags, meta)
-        self._flags.history[target].append(history)
+        fields, targets = multivariateParameters(field, target)
+
+        for field, target in zip(fields, targets):
+
+            history = self._flags.history[field]
+            # append a dummy column
+            meta = {
+                "func": f"transferFlags",
+                "args": (),
+                "kwargs": {
+                    "field": field,
+                    "target": target,
+                    "squeeze": squeeze,
+                    "overwrite": overwrite,
+                    **kwargs,
+                },
+            }
+
+            if overwrite is False:
+                mask = isflagged(self._flags[target], thresh=kwargs["dfilter"])
+                history._hist[mask] = np.nan
+
+            if squeeze:
+                flags = history.squeeze(raw=True)
+                # init an empty history to which we later append the squeezed flags
+                history = History(index=history.index)
+            else:
+                flags = pd.Series(np.nan, index=history.index, dtype=float)
+
+                history.append(flags, meta)
+                self._flags.history[target].append(history)
+
 
         return self
 
diff --git a/saqc/lib/tools.py b/saqc/lib/tools.py
index 003423570..9132012a2 100644
--- a/saqc/lib/tools.py
+++ b/saqc/lib/tools.py
@@ -652,3 +652,17 @@ def joinExt(sep: str, iterable: Iterable[str], last_sep: str | None = None) -> s
     if len(iterable) < 2:
         return sep.join(iterable)
     return f"{sep.join(iterable[:-1])}{last_sep}{iterable[-1]}"
+
+
+def multivariateParameters(field: str | list[str], target: str | list[str] | None = None) -> tuple[list[str], list[str]]:
+    fields = toSequence(field)
+    targets = fields if target is None else toSequence(target)
+
+    if len(targets) == 1:
+        targets = targets * len(fields)
+    if len(targets) != len(fields):
+        raise ValueError(
+            "expected a single 'target' or the same number of 'field' and 'target' values"
+        )
+
+    return fields, targets
-- 
GitLab