diff --git a/CHANGELOG.md b/CHANGELOG.md index 6731aa25e3a53a09da14156df029fb9f6976d5a7..839310af7c89e234921ce0d54afdbadd97558a88 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ SPDX-License-Identifier: GPL-3.0-or-later ## Unreleased [List of commits](https://git.ufz.de/rdm-software/saqc/-/compare/v2.3.0...develop) ### Added +- Methods `logicalAnd` and `logicalOr` ### Changed ### Removed ### Fixed diff --git a/saqc/funcs/flagtools.py b/saqc/funcs/flagtools.py index 80420a1cdd18ac939a7981a0cda961219e263b3a..d378c70af9e58009eaedcde790da754b5a3c4a95 100644 --- a/saqc/funcs/flagtools.py +++ b/saqc/funcs/flagtools.py @@ -7,8 +7,9 @@ # -*- coding: utf-8 -*- from __future__ import annotations +import operator import warnings -from typing import TYPE_CHECKING, Any, Union +from typing import TYPE_CHECKING, Any, Callable, Sequence, Union import numpy as np import pandas as pd @@ -17,6 +18,7 @@ from typing_extensions import Literal from dios import DictOfSeries from saqc.constants import BAD, FILTER_ALL, UNFLAGGED from saqc.core.register import _isflagged, flagging, register +from saqc.lib.tools import toSequence if TYPE_CHECKING: from saqc.core.core import SaQC @@ -531,3 +533,139 @@ class FlagtoolsMixin: self._flags[repeated, field] = flag return self + + @register( + mask=["field"], + demask=["field"], + squeeze=["field"], + multivariate=False, + handles_target=True, + ) + def andGroup( + self: "SaQC", + field: str, + group: Sequence["SaQC"] | dict["SaQC", str | Sequence[str]], + target: str | None = None, + flag: float = BAD, + **kwargs, + ) -> "SaQC": + """ + Flag all values, if a given variable is also flagged in all other given SaQC objects. + + Parameters + ---------- + field : str + Name of the field to check for flags. 'field' needs to present in all + objects in 'qcs'. + + qcs : list of SaQC + A list of SaQC objects to check for flags. + + target : str, default none + Name of the field the generated flags will be written to. If None, the result + will be written to 'field', + + flag: float, default ``BAD`` + The quality flag to set. + + Returns + ------- + saqc.SaQC + """ + + return _groupOperation( + base=self, + field=field, + target=target, + func=operator.and_, + group=group, + flag=flag, + **kwargs, + ) + + @register( + mask=["field"], + demask=["field"], + squeeze=["field"], + multivariate=False, + handles_target=True, + ) + def orGroup( + self: "SaQC", + field: str, + group: Sequence["SaQC"] | dict["SaQC", str | Sequence[str]], + target: str | None = None, + flag: float = BAD, + **kwargs, + ) -> "SaQC": + """ + Flag all values, if a given variable is also flagged in at least one other of the given SaQC objects. + + Parameters + ---------- + field : str + Name of the field to check for flags. 'field' needs to present in all + objects in 'qcs'. + + qcs : list of SaQC + A list of SaQC objects to check for flags. + + target : str, default none + Name of the field the generated flags will be written to. If None, the result + will be written to 'field', + + flag: float, default ``BAD`` + The quality flag to set. + + Returns + ------- + saqc.SaQC + """ + return _groupOperation( + base=self, + field=field, + target=target, + func=operator.or_, + group=group, + flag=flag, + **kwargs, + ) + + +def _groupOperation( + base: "SaQC", + field: str, + func: Callable[[pd.Series, pd.Series], pd.Series], + group: Sequence["SaQC"] | dict["SaQC", str | Sequence[str]], + target: str | None = None, + flag: float = BAD, + **kwargs, +) -> "SaQC": + # Should this be multivariate? And what would multivariate mean in this context + + dfilter = kwargs.get("dfilter", FILTER_ALL) + if target is None: + target = field + + # harmonise `group` to type dict[SaQC, list[str]] + if not isinstance(group, dict): + group = {qc: field for qc in group} + + for k, v in group.items(): + group[k] = toSequence(v) + + qcs_items: list[tuple["SaQC", list[str]]] = list(group.items()) + # generate initial mask from the first `qc` object on the popped first field + mask = _isflagged(qcs_items[0][0]._flags[qcs_items[0][1].pop(0)], thresh=dfilter) + + for qc, fields in qcs_items: + if field not in qc._flags: + raise KeyError(f"variable {field} is missing in given SaQC object") + for field in fields: + mask = func(mask, _isflagged(qc._flags[field], thresh=FILTER_ALL)) + + if target not in base._data: + base = base.copyField(field=field, target=target) + + base._flags[mask, target] = flag + return base diff --git a/tests/funcs/test_flagtools.py b/tests/funcs/test_flagtools.py index 457edab8ab666175cafc9bb5df6cf01f5bb70af1..f885c91e1a106d44ac69f2fe22784d4fcb6ff47d 100644 --- a/tests/funcs/test_flagtools.py +++ b/tests/funcs/test_flagtools.py @@ -4,15 +4,16 @@ # # SPDX-License-Identifier: GPL-3.0-or-later +import operator + import numpy as np import pandas as pd - -# -*- coding: utf-8 -*- import pytest from saqc import BAD as B from saqc import UNFLAGGED as U from saqc import SaQC +from saqc.funcs.flagtools import _groupOperation N = np.nan @@ -98,3 +99,70 @@ def test_propagateFlagsIrregularIndex(got, expected, kwargs): saqc = SaQC(data=data, flags=flags).propagateFlags(field="x", **kwargs) result = saqc._flags.history["x"].hist[1].astype(float) assert result.equals(expected) + + +@pytest.mark.parametrize( + "left,right,expected", + [ + ([B, U, U, B], [B, B, U, U], [B, U, U, U]), + ([B, B, B, B], [B, B, B, B], [B, B, B, B]), + ([U, U, U, U], [U, U, U, U], [U, U, U, U]), + ], +) +def test_andGroup(left, right, expected): + + data = pd.DataFrame({"data": [1, 2, 3, 4]}) + + base = SaQC(data=data) + this = SaQC(data=data, flags=pd.DataFrame({"data": pd.Series(left)})) + that = SaQC(data=data, flags=pd.DataFrame({"data": pd.Series(right)})) + result = base.andGroup(field="data", group=[this, that]) + + assert pd.Series(expected).equals(result.flags["data"]) + + +@pytest.mark.parametrize( + "left,right,expected", + [ + ([B, U, U, B], [B, B, U, U], [B, B, U, B]), + ([B, B, B, B], [B, B, B, B], [B, B, B, B]), + ([U, U, U, U], [U, U, U, U], [U, U, U, U]), + ], +) +def test_orGroup(left, right, expected): + + data = pd.DataFrame({"data": [1, 2, 3, 4]}) + + base = SaQC(data=data) + this = SaQC(data=data, flags=pd.DataFrame({"data": pd.Series(left)})) + that = SaQC(data=data, flags=pd.DataFrame({"data": pd.Series(right)})) + result = base.orGroup(field="data", group=[this, that]) + + assert pd.Series(expected).equals(result.flags["data"]) + + +@pytest.mark.parametrize( + "left,right,expected", + [ + ([B, U, U, B], [B, B, U, U], [B, B, U, B]), + ([B, B, B, B], [B, B, B, B], [B, B, B, B]), + ([U, U, U, U], [U, U, U, U], [U, U, U, U]), + ], +) +def test__groupOperation(left, right, expected): + + data = pd.DataFrame( + {"x": [0, 1, 2, 3], "y": [0, 11, 22, 33], "z": [0, 111, 222, 333]} + ) + base = SaQC(data=data) + this = SaQC( + data=data, flags=pd.DataFrame({k: pd.Series(left) for k in data.columns}) + ) + that = SaQC( + data=data, flags=pd.DataFrame({k: pd.Series(right) for k in data.columns}) + ) + result = _groupOperation( + base=base, field="x", func=operator.or_, group={this: "y", that: ["y", "z"]} + ) + + assert pd.Series(expected).equals(result.flags["x"])