From d339e348f9551d30860396a78baf2123c7f988cc Mon Sep 17 00:00:00 2001 From: David Schaefer <david.schaefer@ufz.de> Date: Thu, 1 Feb 2024 00:13:17 +0100 Subject: [PATCH] first draft - squeeze not working --- saqc/funcs/flagtools.py | 65 +++++++++++++++++++++-------------------- saqc/lib/tools.py | 14 +++++++++ 2 files changed, 48 insertions(+), 31 deletions(-) diff --git a/saqc/funcs/flagtools.py b/saqc/funcs/flagtools.py index 2db1839ff..3373855dd 100644 --- a/saqc/funcs/flagtools.py +++ b/saqc/funcs/flagtools.py @@ -19,7 +19,7 @@ from saqc import BAD, FILTER_ALL, UNFLAGGED from saqc.core import DictOfSeries, flagging, register from saqc.core.history import History from saqc.lib.checking import validateChoice, validateWindow -from saqc.lib.tools import initializeTargets, isflagged, isunflagged, toSequence +from saqc.lib.tools import initializeTargets, isflagged, isunflagged, multivariateParameters, toSequence if TYPE_CHECKING: from saqc import SaQC @@ -290,6 +290,7 @@ class FlagtoolsMixin: demask=[], squeeze=[], handles_target=True, # function defines a target parameter, so it needs to handle it + multivariate=True, ) def transferFlags( self: "SaQC", @@ -349,37 +350,39 @@ class FlagtoolsMixin: 0 -inf -inf -inf 1 255.0 255.0 255.0 """ - history = self._flags.history[field] - - if target is None: - target = field - - if overwrite is False: - mask = isflagged(self._flags[target], thresh=kwargs["dfilter"]) - history._hist[mask] = np.nan - - # append a dummy column - meta = { - "func": f"transferFlags", - "args": (), - "kwargs": { - "field": field, - "target": target, - "squeeze": squeeze, - "overwrite": overwrite, - **kwargs, - }, - } - - if squeeze: - flags = history.squeeze(raw=True) - # init an empty history to which we later append the squeezed flags - history = History(index=history.index) - else: - flags = pd.Series(np.nan, index=history.index, dtype=float) - history.append(flags, meta) - self._flags.history[target].append(history) + fields, targets = multivariateParameters(field, target) + + for field, target in zip(fields, targets): + + history = self._flags.history[field] + # append a dummy column + meta = { + "func": f"transferFlags", + "args": (), + "kwargs": { + "field": field, + "target": target, + "squeeze": squeeze, + "overwrite": overwrite, + **kwargs, + }, + } + + if overwrite is False: + mask = isflagged(self._flags[target], thresh=kwargs["dfilter"]) + history._hist[mask] = np.nan + + if squeeze: + flags = history.squeeze(raw=True) + # init an empty history to which we later append the squeezed flags + history = History(index=history.index) + else: + flags = pd.Series(np.nan, index=history.index, dtype=float) + + history.append(flags, meta) + self._flags.history[target].append(history) + return self diff --git a/saqc/lib/tools.py b/saqc/lib/tools.py index 003423570..9132012a2 100644 --- a/saqc/lib/tools.py +++ b/saqc/lib/tools.py @@ -652,3 +652,17 @@ def joinExt(sep: str, iterable: Iterable[str], last_sep: str | None = None) -> s if len(iterable) < 2: return sep.join(iterable) return f"{sep.join(iterable[:-1])}{last_sep}{iterable[-1]}" + + +def multivariateParameters(field: str | list[str], target: str | list[str] | None = None) -> tuple[list[str], list[str]]: + fields = toSequence(field) + targets = fields if target is None else toSequence(target) + + if len(targets) == 1: + targets = targets * len(fields) + if len(targets) != len(fields): + raise ValueError( + "expected a single 'target' or the same number of 'field' and 'target' values" + ) + + return fields, targets -- GitLab