From 8bc66a122573222a8c63114a88ca8feb63f16575 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?David=20Sch=C3=A4fer?= <david.schaefer@ufz.de>
Date: Wed, 6 Mar 2024 09:02:29 +0100
Subject: [PATCH] Make transferFlags a multivariate function

---
 CHANGELOG.md                  |  1 +
 saqc/funcs/flagtools.py       | 66 +++++++++++++++++++++++++----------
 saqc/lib/tools.py             | 20 ++++++++++-
 tests/funcs/test_flagtools.py | 43 +++++++++++++++++++++++
 4 files changed, 111 insertions(+), 19 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index ef5c8a0ec..0f09eccdf 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -15,6 +15,7 @@ SPDX-License-Identifier: GPL-3.0-or-later
 - `SaQC`: support for selection, slicing and setting of items by use of subscription on SaQC objects (e.g. `qc[key]` and `qc[key] = value`).
    Selection works with single keys, collections of keys and string slices (e.g. `qc["a":"f"]`).  Values can be SaQC objects, pd.Series, 
    Iterable of Series and dict-like with series values.
+- `transferFlags`: is now a multivariate function, i.e. `field` and `target` accept lists of variable names
 - `plot`: added `yscope` keyword
 - `setFlags`: function to replace `flagManual`
 - `flagUniLOF`: added defaultly applied correction to mitigate phenomenon of overflagging at relatively steep data value slopes. (parameter `slope_correct`). 
diff --git a/saqc/funcs/flagtools.py b/saqc/funcs/flagtools.py
index 9e01e74e2..029665a3f 100644
--- a/saqc/funcs/flagtools.py
+++ b/saqc/funcs/flagtools.py
@@ -17,9 +17,16 @@ from typing_extensions import Literal
 
 from saqc import BAD, FILTER_ALL, UNFLAGGED
 from saqc.core import DictOfSeries, flagging, register
+from saqc.core.flags import Flags
 from saqc.core.history import History
 from saqc.lib.checking import validateChoice, validateWindow
-from saqc.lib.tools import initializeTargets, isflagged, isunflagged, toSequence
+from saqc.lib.tools import (
+    initializeTargets,
+    isflagged,
+    isunflagged,
+    multivariateParameters,
+    toSequence,
+)
 
 if TYPE_CHECKING:
     from saqc import SaQC
@@ -356,6 +363,7 @@ class FlagtoolsMixin:
         demask=[],
         squeeze=[],
         handles_target=True,  # function defines a target parameter, so it needs to handle it
+        multivariate=True,
     )
     def transferFlags(
         self: "SaQC",
@@ -415,16 +423,8 @@ class FlagtoolsMixin:
            0   -inf   -inf   -inf
            1  255.0  255.0  255.0
         """
-        history = self._flags.history[field]
-
-        if target is None:
-            target = field
 
-        if overwrite is False:
-            mask = isflagged(self._flags[target], thresh=kwargs["dfilter"])
-            history._hist[mask] = np.nan
-
-        # append a dummy column
+        fields, targets, broadcasting = multivariateParameters(field, target)
         meta = {
             "func": f"transferFlags",
             "args": (),
@@ -437,15 +437,45 @@ class FlagtoolsMixin:
             },
         }
 
-        if squeeze:
-            flags = history.squeeze(raw=True)
-            # init an empty history to which we later append the squeezed flags
-            history = History(index=history.index)
-        else:
+        for field, target in zip(fields, targets):
+            # initialize non-existing targets
+            if target not in self._data:
+                self._data[target] = pd.Series(np.nan, index=self._data[field].index)
+                self._flags._data[target] = History(self._data[target].index)
+            if not self._data[field].index.equals(self._data[target].index):
+                raise ValueError(
+                    f"All Field and Target indices must match!\n"
+                    f"Indices of {field} and {target} seem to be not congruent within the context of the given\n"
+                    f"- fields: {fields}\n "
+                    f"- and targets: {targets}"
+                )
+            history = self._flags.history[field].copy(deep=True)
+
+            if overwrite is False:
+                mask = isflagged(self._flags[target], thresh=kwargs["dfilter"])
+                history._hist[mask] = np.nan
+
+            if squeeze:
+                # add squeezed flags
+                flags = history.squeeze(raw=True)
+                history = History(index=history.index).append(flags, meta)
+            elif broadcasting is False:
+                # add an empty flags column
+                flags = pd.Series(np.nan, index=history.index, dtype=float)
+                history.append(flags, meta)
+            # else:
+            #    broadcasting -> multiple fields will be written to one target
+            #    only add the fields' histories and add an empty column later
+
+            self._flags.history[target].append(history)
+
+        if broadcasting and not squeeze:
+            # add one final history column
+            # all targets are identical, if we broadcast fields -> target
+            target = targets[0]
+            history = self._flags.history[target]
             flags = pd.Series(np.nan, index=history.index, dtype=float)
-
-        history.append(flags, meta)
-        self._flags.history[target].append(history)
+            self._flags.history[target].append(flags, meta)
 
         return self
 
diff --git a/saqc/lib/tools.py b/saqc/lib/tools.py
index 003423570..4dbcf7108 100644
--- a/saqc/lib/tools.py
+++ b/saqc/lib/tools.py
@@ -554,7 +554,7 @@ def initializeTargets(
     index: pd.Index,
 ):
     """
-    Initialize all targets based on field.
+    Initialize all targets based on fields.
 
     Note
     ----
@@ -652,3 +652,21 @@ def joinExt(sep: str, iterable: Iterable[str], last_sep: str | None = None) -> s
     if len(iterable) < 2:
         return sep.join(iterable)
     return f"{sep.join(iterable[:-1])}{last_sep}{iterable[-1]}"
+
+
+def multivariateParameters(  # normalize field/target into two equally long sequences
+    field: str | list[str], target: str | list[str] | None = None
+) -> tuple[list[str], list[str], bool]:  # returns (fields, targets, broadcasting)
+    fields = toSequence(field)
+    targets = fields if target is None else toSequence(target)  # no target: reuse fields
+    broadcasting = False
+
+    if len(targets) == 1:  # NOTE: also true for exactly one field and one target
+        targets = targets * len(fields)  # repeat the single target once per field
+        broadcasting = True
+    if len(targets) != len(fields):  # checked after the repetition above
+        raise ValueError(
+            "expected a single 'target' or the same number of 'field' and 'target' values"
+        )
+
+    return fields, targets, broadcasting
diff --git a/tests/funcs/test_flagtools.py b/tests/funcs/test_flagtools.py
index 9c650ab65..6bda00301 100644
--- a/tests/funcs/test_flagtools.py
+++ b/tests/funcs/test_flagtools.py
@@ -178,6 +178,49 @@ def test__groupOperation(field, target, expected, copy):
             assert (result._data[f] == result._data[t]).all(axis=None)
 
 
+def test_transferFlags():
+    qc = SaQC(
+        data=pd.DataFrame(
+            {"x": [0, 1, 2, 3], "y": [0, 11, 22, 33], "z": [0, 111, 222, 333]}
+        ),
+        flags=pd.DataFrame({"x": [B, U, U, B], "y": [B, B, U, U], "z": [B, B, U, B]}),
+    )
+
+    # no squeeze: target history == field history plus one trailing all-NaN column
+    qc1 = qc.transferFlags("x", target="a")
+    assert qc1._history["a"].hist.iloc[:, :-1].equals(qc1._history["x"].hist)
+    assert qc1._history["a"].hist.iloc[:, -1].isna().all()
+
+    qc2 = qc.transferFlags(["x", "y"], target=["a", "b"])
+    assert qc2._history["a"].hist.iloc[:, :-1].equals(qc2._history["x"].hist)
+    assert qc2._history["a"].hist.iloc[:, -1].isna().all()
+    assert qc2._history["b"].hist.iloc[:, :-1].equals(qc2._history["y"].hist)
+    assert qc2._history["b"].hist.iloc[:, -1].isna().all()
+
+    # we use the overwrite option here for easy checking against the original
+    # flags, because otherwise we would need to respect the inserted NaN values
+    qc3 = qc.transferFlags(["x", "y", "z"], target="a", overwrite=True)
+    assert qc3._history["a"].hist.iloc[:, 0].equals(qc3._history["x"].hist.squeeze())
+    assert qc3._history["a"].hist.iloc[:, 1].equals(qc3._history["y"].hist.squeeze())
+    assert qc3._history["a"].hist.iloc[:, 2].equals(qc3._history["z"].hist.squeeze())
+    assert qc3._history["a"].hist.iloc[:, -1].isna().all()
+
+    # squeeze: the target history is compared against the field history directly
+    qc1 = qc.transferFlags("x", target="a", squeeze=True)
+    assert qc1._history["a"].hist.equals(qc1._history["x"].hist)
+
+    qc2 = qc.transferFlags(["x", "y"], target=["a", "b"], squeeze=True)
+    assert qc2._history["a"].hist.equals(qc2._history["x"].hist)
+    assert qc2._history["b"].hist.equals(qc2._history["y"].hist)
+
+    # we use the overwrite option here for easy checking against the original
+    # flags, because otherwise we would need to respect the inserted NaN values
+    qc3 = qc.transferFlags(["x", "y", "z"], target="a", overwrite=True, squeeze=True)
+    assert qc3._history["a"].hist.iloc[:, 0].equals(qc3._history["x"].hist.squeeze())
+    assert qc3._history["a"].hist.iloc[:, 1].equals(qc3._history["y"].hist.squeeze())
+    assert qc3._history["a"].hist.iloc[:, 2].equals(qc3._history["z"].hist.squeeze())
+
+
+
 @pytest.mark.parametrize(
     "f_data",
     [
-- 
GitLab