diff --git a/saqc/core/history.py b/saqc/core/history.py index 3391697c9b85326204623b82200702367647e01a..23e6fa59a6592628dde4b630dba094e4c4b74497 100644 --- a/saqc/core/history.py +++ b/saqc/core/history.py @@ -435,6 +435,9 @@ class History: new._meta = copyfunc(self._meta) return new + def equals(self, other: History) -> bool: + return self._hist.equals(other._hist) and self.meta == other.meta + def __copy__(self): return self.copy(deep=False) diff --git a/saqc/funcs/flagtools.py b/saqc/funcs/flagtools.py index 3373855dd3c93e3adcf380142ae06a8b6c19d915..b8c1fbc101b08d98dcadc2bdb7d43b1476988587 100644 --- a/saqc/funcs/flagtools.py +++ b/saqc/funcs/flagtools.py @@ -17,6 +17,7 @@ from typing_extensions import Literal from saqc import BAD, FILTER_ALL, UNFLAGGED from saqc.core import DictOfSeries, flagging, register +from saqc.core.flags import Flags from saqc.core.history import History from saqc.lib.checking import validateChoice, validateWindow from saqc.lib.tools import initializeTargets, isflagged, isunflagged, multivariateParameters, toSequence @@ -355,6 +356,10 @@ class FlagtoolsMixin: for field, target in zip(fields, targets): + if target not in self._data: + self._data[target] = pd.Series(np.nan, index=self._data[field].index) + self._flags._data[target] = History(self._data[target].index) + history = self._flags.history[field] # append a dummy column meta = { @@ -380,8 +385,9 @@ class FlagtoolsMixin: else: flags = pd.Series(np.nan, index=history.index, dtype=float) - history.append(flags, meta) - self._flags.history[target].append(history) + history.append(flags, meta) + self._flags.history[target].append(history) + import ipdb; ipdb.set_trace() return self diff --git a/saqc/lib/tools.py b/saqc/lib/tools.py index 9132012a2d564dff11c2c623b6e1c11a3c8e8f7a..ad55cfb526994e0a1938e1b65e43f23bc8195dd9 100644 --- a/saqc/lib/tools.py +++ b/saqc/lib/tools.py @@ -654,7 +654,9 @@ def joinExt(sep: str, iterable: Iterable[str], last_sep: str | None = None) -> s return f"{sep.join(iterable[:-1])}{last_sep}{iterable[-1]}" -def multivariateParameters(field: str | list[str], target: str | list[str] | None = None) -> tuple[list[str], list[str]]: +def multivariateParameters( + field: str | list[str], target: str | list[str] | None = None +) -> tuple[list[str], list[str]]: fields = toSequence(field) targets = fields if target is None else toSequence(target) diff --git a/tests/funcs/test_flagtools.py b/tests/funcs/test_flagtools.py index b5de1df5b627ee360502fc34810ded4085b5a748..0103314dda85a6f827fbf6db25e1e36e7a655e48 100644 --- a/tests/funcs/test_flagtools.py +++ b/tests/funcs/test_flagtools.py @@ -175,3 +175,26 @@ def test__groupOperation(field, target, expected, copy): fields = toSequence(itertools.chain.from_iterable(field)) for f, t in zip(fields, targets): assert (result._data[f] == result._data[t]).all(axis=None) + + +def test_transferFlags(): + qc = SaQC( + data=pd.DataFrame( + {"x": [0, 1, 2, 3], "y": [0, 11, 22, 33], "z": [0, 111, 222, 333]} + ), + flags=pd.DataFrame({"x": [B, U, U, B], "y": [B, B, U, U], "z": [B, B, U, B]}), + ) + + # qc1 = qc.transferFlags("x", target="a") + # assert qc1._history["a"].equals(qc1._history["x"]) + + # qc2 = qc.transferFlags(["x", "y"], target=["a", "b"]) + # assert qc2._history["a"].equals(qc2._history["x"]) + # assert qc2._history["b"].equals(qc2._history["y"]) + + qc3 = qc.transferFlags(["x", "y", "z"], target="a") + import ipdb; ipdb.set_trace() + assert qc3._history["a"].equals(qc2._history["x"].append(qc2._history["y"]).append(qc2._history["z"])) + + +