From 847d0f24e7541ef01c0a71d1bda7a4dcdcd9a7ca Mon Sep 17 00:00:00 2001 From: David Schaefer <david.schaefer@ufz.de> Date: Thu, 1 Feb 2024 21:58:18 +0100 Subject: [PATCH] make transferFlags a multivariate function --- CHANGELOG.md | 1 + saqc/core/history.py | 3 --- saqc/funcs/flagtools.py | 50 ++++++++++++++++++++++------------- saqc/lib/tools.py | 8 +++--- tests/funcs/test_flagtools.py | 43 +++++++++++++++++++++++------- 5 files changed, 70 insertions(+), 35 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 849df4b84..bfc336e9f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,7 @@ SPDX-License-Identifier: GPL-3.0-or-later - `SaQC`: support for selection, slicing and setting of items by use of subscription on SaQC objects (e.g. `qc[key]` and `qc[key] = value`). Selection works with single keys, collections of keys and string slices (e.g. `qc["a":"f"]`). Values can be SaQC objects, pd.Series, Iterable of Series and dict-like with series values. +- `transferFlags` becomes a multivariate function ### Changed ### Removed ### Fixed diff --git a/saqc/core/history.py b/saqc/core/history.py index 23e6fa59a..3391697c9 100644 --- a/saqc/core/history.py +++ b/saqc/core/history.py @@ -435,9 +435,6 @@ class History: new._meta = copyfunc(self._meta) return new - def equals(self, other: History) -> bool: - return self._hist.equals(other._hist) and self.meta == other.meta - def __copy__(self): return self.copy(deep=False) diff --git a/saqc/funcs/flagtools.py b/saqc/funcs/flagtools.py index b8c1fbc10..5b9c3c646 100644 --- a/saqc/funcs/flagtools.py +++ b/saqc/funcs/flagtools.py @@ -352,46 +352,58 @@ class FlagtoolsMixin: 1 255.0 255.0 255.0 """ - fields, targets = multivariateParameters(field, target) + fields, targets, broadcasting = multivariateParameters(field, target) + meta = { + "func": f"transferFlags", + "args": (), + "kwargs": { + "field": field, + "target": target, + "squeeze": squeeze, + "overwrite": overwrite, + **kwargs, + }, + } + for field, target in 
zip(fields, targets): + # initialize non existing targets if target not in self._data: self._data[target] = pd.Series(np.nan, index=self._data[field].index) self._flags._data[target] = History(self._data[target].index) - history = self._flags.history[field] - # append a dummy column - meta = { - "func": f"transferFlags", - "args": (), - "kwargs": { - "field": field, - "target": target, - "squeeze": squeeze, - "overwrite": overwrite, - **kwargs, - }, - } + history = self._flags.history[field].copy(deep=True) if overwrite is False: mask = isflagged(self._flags[target], thresh=kwargs["dfilter"]) history._hist[mask] = np.nan if squeeze: + # add squeezed flags flags = history.squeeze(raw=True) - # init an empty history to which we later append the squeezed flags - history = History(index=history.index) - else: + history = History(index=history.index).append(flags, meta) + elif broadcasting is False: + # add an empty flags flags = pd.Series(np.nan, index=history.index, dtype=float) + history.append(flags, meta) + # else: + # broadcasting -> multiple fields will be written to one target + # only add the fields' histories and add an empty column later - history.append(flags, meta) self._flags.history[target].append(history) - import ipdb; ipdb.set_trace() + if broadcasting and not squeeze: + # add one final history column + # all targets are identical, if we broadcast fields -> target + target = targets[0] + history = self._flags.history[target] + flags = pd.Series(np.nan, index=history.index, dtype=float) + self._flags.history[target].append(flags, meta) return self + @flagging() def propagateFlags( self: "SaQC", diff --git a/saqc/lib/tools.py b/saqc/lib/tools.py index ad55cfb52..4dbcf7108 100644 --- a/saqc/lib/tools.py +++ b/saqc/lib/tools.py @@ -554,7 +554,7 @@ def initializeTargets( index: pd.Index, ): """ - Initialize all targets based on field. + Initialize all targets based on fields. 
Note ---- @@ -656,15 +656,17 @@ def joinExt(sep: str, iterable: Iterable[str], last_sep: str | None = None) -> s def multivariateParameters( field: str | list[str], target: str | list[str] | None = None -) -> tuple[list[str], list[str]]: +) -> tuple[list[str], list[str], bool]: fields = toSequence(field) targets = fields if target is None else toSequence(target) + broadcasting = False if len(targets) == 1: targets = targets * len(fields) + broadcasting = True if len(targets) != len(fields): raise ValueError( "expected a single 'target' or the same number of 'field' and 'target' values" ) - return fields, targets + return fields, targets, broadcasting diff --git a/tests/funcs/test_flagtools.py b/tests/funcs/test_flagtools.py index 0103314dd..efadb0219 100644 --- a/tests/funcs/test_flagtools.py +++ b/tests/funcs/test_flagtools.py @@ -185,16 +185,39 @@ def test_transferFlags(): flags=pd.DataFrame({"x": [B, U, U, B], "y": [B, B, U, U], "z": [B, B, U, B]}), ) - # qc1 = qc.transferFlags("x", target="a") - # assert qc1._history["a"].equals(qc1._history["x"]) - - # qc2 = qc.transferFlags(["x", "y"], target=["a", "b"]) - # assert qc2._history["a"].equals(qc2._history["x"]) - # assert qc2._history["b"].equals(qc2._history["y"]) - - qc3 = qc.transferFlags(["x", "y", "z"], target="a") - import ipdb; ipdb.set_trace() - assert qc3._history["a"].equals(qc2._history["x"].append(qc2._history["y"]).append(qc2._history["z"])) + # no squeeze + qc1 = qc.transferFlags("x", target="a") + assert qc1._history["a"].hist.iloc[:, :-1].equals(qc1._history["x"].hist) + assert qc1._history["a"].hist.iloc[:, -1].isna().all() + + qc2 = qc.transferFlags(["x", "y"], target=["a", "b"]) + assert qc2._history["a"].hist.iloc[:, :-1].equals(qc2._history["x"].hist) + assert qc2._history["a"].hist.iloc[:, -1].isna().all() + assert qc2._history["b"].hist.iloc[:, :-1].equals(qc2._history["y"].hist) + assert qc2._history["b"].hist.iloc[:, -1].isna().all() + + # we use the overwrite option here for easy 
checking against the original + # flags, because otherwise we would need to respect the inserted nan + qc3 = qc.transferFlags(["x", "y", "z"], target="a", overwrite=True) + assert qc3._history["a"].hist.iloc[:, 0].equals(qc3._history["x"].hist.squeeze()) + assert qc3._history["a"].hist.iloc[:, 1].equals(qc3._history["y"].hist.squeeze()) + assert qc3._history["a"].hist.iloc[:, 2].equals(qc3._history["z"].hist.squeeze()) + assert qc3._history["a"].hist.iloc[:, -1].isna().all() + + # squeeze + qc1 = qc.transferFlags("x", target="a", squeeze=True) + assert qc1._history["a"].hist.equals(qc1._history["x"].hist) + + qc2 = qc.transferFlags(["x", "y"], target=["a", "b"], squeeze=True) + assert qc2._history["a"].hist.equals(qc2._history["x"].hist) + assert qc2._history["b"].hist.equals(qc2._history["y"].hist) + + # we use the overwrite option here for easy checking against the original + # flags, because otherwise we would need to respect the inserted nan + qc3 = qc.transferFlags(["x", "y", "z"], target="a", overwrite=True, squeeze=True) + assert qc3._history["a"].hist.iloc[:, 0].equals(qc3._history["x"].hist.squeeze()) + assert qc3._history["a"].hist.iloc[:, 1].equals(qc3._history["y"].hist.squeeze()) + assert qc3._history["a"].hist.iloc[:, 2].equals(qc3._history["z"].hist.squeeze()) -- GitLab