From 847d0f24e7541ef01c0a71d1bda7a4dcdcd9a7ca Mon Sep 17 00:00:00 2001
From: David Schaefer <david.schaefer@ufz.de>
Date: Thu, 1 Feb 2024 21:58:18 +0100
Subject: [PATCH] make transferFlags a multivariate function

---
 CHANGELOG.md                  |  1 +
 saqc/core/history.py          |  3 ---
 saqc/funcs/flagtools.py       | 50 ++++++++++++++++++++++-------------
 saqc/lib/tools.py             |  8 +++---
 tests/funcs/test_flagtools.py | 43 +++++++++++++++++++++++-------
 5 files changed, 70 insertions(+), 35 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 849df4b84..bfc336e9f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -14,6 +14,7 @@ SPDX-License-Identifier: GPL-3.0-or-later
 - `SaQC`: support for selection, slicing and setting of items by use of subscription on SaQC objects (e.g. `qc[key]` and `qc[key] = value`).
    Selection works with single keys, collections of keys and string slices (e.g. `qc["a":"f"]`).  Values can be SaQC objects, pd.Series, 
    Iterable of Series and dict-like with series values.
+- `transferFlags` is now a multivariate function, accepting multiple `field` and `target` values
 ### Changed
 ### Removed
 ### Fixed
diff --git a/saqc/core/history.py b/saqc/core/history.py
index 23e6fa59a..3391697c9 100644
--- a/saqc/core/history.py
+++ b/saqc/core/history.py
@@ -435,9 +435,6 @@ class History:
         new._meta = copyfunc(self._meta)
         return new
 
-    def equals(self, other: History) -> bool:
-        return self._hist.equals(other._hist) and self.meta == other.meta
-
     def __copy__(self):
         return self.copy(deep=False)
 
diff --git a/saqc/funcs/flagtools.py b/saqc/funcs/flagtools.py
index b8c1fbc10..5b9c3c646 100644
--- a/saqc/funcs/flagtools.py
+++ b/saqc/funcs/flagtools.py
@@ -352,46 +352,58 @@ class FlagtoolsMixin:
            1  255.0  255.0  255.0
         """
 
-        fields, targets = multivariateParameters(field, target)
+        fields, targets, broadcasting = multivariateParameters(field, target)
+        meta = {
+            "func": f"transferFlags",
+            "args": (),
+            "kwargs": {
+                "field": field,
+                "target": target,
+                "squeeze": squeeze,
+                "overwrite": overwrite,
+                **kwargs,
+            },
+        }
+
 
         for field, target in zip(fields, targets):
 
+            # initialize non existing targets
             if target not in self._data:
                 self._data[target] = pd.Series(np.nan, index=self._data[field].index)
                 self._flags._data[target] = History(self._data[target].index)
 
-            history = self._flags.history[field]
-            # append a dummy column
-            meta = {
-                "func": f"transferFlags",
-                "args": (),
-                "kwargs": {
-                    "field": field,
-                    "target": target,
-                    "squeeze": squeeze,
-                    "overwrite": overwrite,
-                    **kwargs,
-                },
-            }
+            history = self._flags.history[field].copy(deep=True)
 
             if overwrite is False:
                 mask = isflagged(self._flags[target], thresh=kwargs["dfilter"])
                 history._hist[mask] = np.nan
 
             if squeeze:
+                # add squeezed flags
                 flags = history.squeeze(raw=True)
-                # init an empty history to which we later append the squeezed flags
-                history = History(index=history.index)
-            else:
+                history = History(index=history.index).append(flags, meta)
+            elif broadcasting is False:
+                # add an empty flags column
                 flags = pd.Series(np.nan, index=history.index, dtype=float)
+                history.append(flags, meta)
+            # else:
+            #    broadcasting -> multiple fields will be written to one target
+            #    only add the fields' histories and add an empty column later
 
-            history.append(flags, meta)
             self._flags.history[target].append(history)
-            import ipdb; ipdb.set_trace()
 
+        if broadcasting and not squeeze:
+            # add one final history column
+            # all targets are identical, if we broadcast fields -> target
+            target = targets[0]
+            history = self._flags.history[target]
+            flags = pd.Series(np.nan, index=history.index, dtype=float)
+            self._flags.history[target].append(flags, meta)
 
         return self
 
+
     @flagging()
     def propagateFlags(
         self: "SaQC",
diff --git a/saqc/lib/tools.py b/saqc/lib/tools.py
index ad55cfb52..4dbcf7108 100644
--- a/saqc/lib/tools.py
+++ b/saqc/lib/tools.py
@@ -554,7 +554,7 @@ def initializeTargets(
     index: pd.Index,
 ):
     """
-    Initialize all targets based on field.
+    Initialize all targets based on fields.
 
     Note
     ----
@@ -656,15 +656,17 @@ def joinExt(sep: str, iterable: Iterable[str], last_sep: str | None = None) -> s
 
 def multivariateParameters(
     field: str | list[str], target: str | list[str] | None = None
-) -> tuple[list[str], list[str]]:
+) -> tuple[list[str], list[str], bool]:
     fields = toSequence(field)
     targets = fields if target is None else toSequence(target)
+    broadcasting = False
 
     if len(targets) == 1:
         targets = targets * len(fields)
+        broadcasting = True
     if len(targets) != len(fields):
         raise ValueError(
             "expected a single 'target' or the same number of 'field' and 'target' values"
         )
 
-    return fields, targets
+    return fields, targets, broadcasting
diff --git a/tests/funcs/test_flagtools.py b/tests/funcs/test_flagtools.py
index 0103314dd..efadb0219 100644
--- a/tests/funcs/test_flagtools.py
+++ b/tests/funcs/test_flagtools.py
@@ -185,16 +185,39 @@ def test_transferFlags():
         flags=pd.DataFrame({"x": [B, U, U, B], "y": [B, B, U, U], "z": [B, B, U, B]}),
     )
 
-    # qc1 = qc.transferFlags("x", target="a")
-    # assert qc1._history["a"].equals(qc1._history["x"])
-
-    # qc2 = qc.transferFlags(["x", "y"], target=["a", "b"])
-    # assert qc2._history["a"].equals(qc2._history["x"])
-    # assert qc2._history["b"].equals(qc2._history["y"])
-
-    qc3 = qc.transferFlags(["x", "y", "z"], target="a")
-    import ipdb; ipdb.set_trace()
-    assert qc3._history["a"].equals(qc2._history["x"].append(qc2._history["y"]).append(qc2._history["z"]))
+    # no squeeze
+    qc1 = qc.transferFlags("x", target="a")
+    assert qc1._history["a"].hist.iloc[:, :-1].equals(qc1._history["x"].hist)
+    assert qc1._history["a"].hist.iloc[:, -1].isna().all()
+
+    qc2 = qc.transferFlags(["x", "y"], target=["a", "b"])
+    assert qc2._history["a"].hist.iloc[:, :-1].equals(qc2._history["x"].hist)
+    assert qc2._history["a"].hist.iloc[:, -1].isna().all()
+    assert qc2._history["b"].hist.iloc[:, :-1].equals(qc2._history["y"].hist)
+    assert qc2._history["b"].hist.iloc[:, -1].isna().all()
+
+    # we use the overwrite option here for easy checking against the original
+    # flags, because otherwise we would have to account for the inserted NaNs
+    qc3 = qc.transferFlags(["x", "y", "z"], target="a", overwrite=True)
+    assert qc3._history["a"].hist.iloc[:, 0].equals(qc3._history["x"].hist.squeeze())
+    assert qc3._history["a"].hist.iloc[:, 1].equals(qc3._history["y"].hist.squeeze())
+    assert qc3._history["a"].hist.iloc[:, 2].equals(qc3._history["z"].hist.squeeze())
+    assert qc3._history["a"].hist.iloc[:, -1].isna().all()
+
+    # squeeze
+    qc1 = qc.transferFlags("x", target="a", squeeze=True)
+    assert qc1._history["a"].hist.equals(qc1._history["x"].hist)
+
+    qc2 = qc.transferFlags(["x", "y"], target=["a", "b"], squeeze=True)
+    assert qc2._history["a"].hist.equals(qc2._history["x"].hist)
+    assert qc2._history["b"].hist.equals(qc2._history["y"].hist)
+
+    # we use the overwrite option here for easy checking against the original
+    # flags, because otherwise we would have to account for the inserted NaNs
+    qc3 = qc.transferFlags(["x", "y", "z"], target="a", overwrite=True, squeeze=True)
+    assert qc3._history["a"].hist.iloc[:, 0].equals(qc3._history["x"].hist.squeeze())
+    assert qc3._history["a"].hist.iloc[:, 1].equals(qc3._history["y"].hist.squeeze())
+    assert qc3._history["a"].hist.iloc[:, 2].equals(qc3._history["z"].hist.squeeze())
 
 
 
-- 
GitLab