From 4c658f9cedaca6aea8be8b122cf7eb6537b6f819 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?David=20Sch=C3=A4fer?= <david.schaefer@ufz.de>
Date: Tue, 31 Jan 2023 12:30:31 +0100
Subject: [PATCH] Support function call groups

---
 CHANGELOG.md                  |   1 +
 saqc/funcs/flagtools.py       | 140 +++++++++++++++++++++++++++++++++-
 tests/funcs/test_flagtools.py |  72 ++++++++++++++++-
 3 files changed, 210 insertions(+), 3 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 6731aa25e..839310af7 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,6 +9,7 @@ SPDX-License-Identifier: GPL-3.0-or-later
 ## Unreleased
 [List of commits](https://git.ufz.de/rdm-software/saqc/-/compare/v2.3.0...develop)
 ### Added
+- Methods `logicalAnd` and `logicalOr`
 ### Changed
 ### Removed
 ### Fixed
diff --git a/saqc/funcs/flagtools.py b/saqc/funcs/flagtools.py
index 80420a1cd..d378c70af 100644
--- a/saqc/funcs/flagtools.py
+++ b/saqc/funcs/flagtools.py
@@ -7,8 +7,9 @@
 # -*- coding: utf-8 -*-
 from __future__ import annotations
 
+import operator
 import warnings
-from typing import TYPE_CHECKING, Any, Union
+from typing import TYPE_CHECKING, Any, Callable, Sequence, Union
 
 import numpy as np
 import pandas as pd
@@ -17,6 +18,7 @@ from typing_extensions import Literal
 from dios import DictOfSeries
 from saqc.constants import BAD, FILTER_ALL, UNFLAGGED
 from saqc.core.register import _isflagged, flagging, register
+from saqc.lib.tools import toSequence
 
 if TYPE_CHECKING:
     from saqc.core.core import SaQC
@@ -531,3 +533,139 @@ class FlagtoolsMixin:
         self._flags[repeated, field] = flag
 
         return self
+
+    @register(
+        mask=["field"],
+        demask=["field"],
+        squeeze=["field"],
+        multivariate=False,
+        handles_target=True,
+    )
+    def andGroup(
+        self: "SaQC",
+        field: str,
+        group: Sequence["SaQC"] | dict["SaQC", str | Sequence[str]],
+        target: str | None = None,
+        flag: float = BAD,
+        **kwargs,
+    ) -> "SaQC":
+        """
+        Flag all values, if a given variable is also flagged in all other given SaQC objects.
+
+        Parameters
+        ----------
+        field : str
+            Name of the field to check for flags. 'field' needs to present in all
+            objects in 'qcs'.
+
+        qcs : list of SaQC
+            A list of SaQC objects to check for flags.
+
+        target : str, default none
+            Name of the field the generated flags will be written to. If None, the result
+            will be written to 'field',
+
+        flag: float, default ``BAD``
+            The quality flag to set.
+
+        Returns
+        -------
+        saqc.SaQC
+        """
+
+        return _groupOperation(
+            base=self,
+            field=field,
+            target=target,
+            func=operator.and_,
+            group=group,
+            flag=flag,
+            **kwargs,
+        )
+
+    @register(
+        mask=["field"],
+        demask=["field"],
+        squeeze=["field"],
+        multivariate=False,
+        handles_target=True,
+    )
+    def orGroup(
+        self: "SaQC",
+        field: str,
+        group: Sequence["SaQC"] | dict["SaQC", str | Sequence[str]],
+        target: str | None = None,
+        flag: float = BAD,
+        **kwargs,
+    ) -> "SaQC":
+        """
+        Flag all values, if a given variable is also flagged in at least one other of the given SaQC objects.
+
+        Parameters
+        ----------
+        field : str
+            Name of the field to check for flags. 'field' needs to present in all
+            objects in 'qcs'.
+
+        qcs : list of SaQC
+            A list of SaQC objects to check for flags.
+
+        target : str, default none
+            Name of the field the generated flags will be written to. If None, the result
+            will be written to 'field',
+
+        flag: float, default ``BAD``
+            The quality flag to set.
+
+        Returns
+        -------
+        saqc.SaQC
+        """
+        return _groupOperation(
+            base=self,
+            field=field,
+            target=target,
+            func=operator.or_,
+            group=group,
+            flag=flag,
+            **kwargs,
+        )
+
+
+def _groupOperation(
+    base: "SaQC",
+    field: str,
+    func: Callable[[pd.Series, pd.Series], pd.Series],
+    group: Sequence["SaQC"] | dict["SaQC", str | Sequence[str]],
+    target: str | None = None,
+    flag: float = BAD,
+    **kwargs,
+) -> "SaQC":
+    # Should this be multivariate? And what would multivariate mean in this context
+
+    dfilter = kwargs.get("dfilter", FILTER_ALL)
+    if target is None:
+        target = field
+
+    # harmonise `group` to type dict[SaQC, list[str]]
+    if not isinstance(group, dict):
+        group = {qc: field for qc in group}
+
+    for k, v in group.items():
+        group[k] = toSequence(v)
+
+    qcs_items: list[tuple["SaQC", list[str]]] = list(group.items())
+    # generate initial mask from the first `qc` object on the popped first field
+    mask = _isflagged(qcs_items[0][0]._flags[qcs_items[0][1].pop(0)], thresh=dfilter)
+
+    for qc, fields in qcs_items:
+        if field not in qc._flags:
+            raise KeyError(f"variable {field} is missing in given SaQC object")
+        for field in fields:
+            mask = func(mask, _isflagged(qc._flags[field], thresh=FILTER_ALL))
+
+    if target not in base._data:
+        base = base.copyField(field=field, target=target)
+
+    base._flags[mask, target] = flag
+    return base
diff --git a/tests/funcs/test_flagtools.py b/tests/funcs/test_flagtools.py
index 457edab8a..f885c91e1 100644
--- a/tests/funcs/test_flagtools.py
+++ b/tests/funcs/test_flagtools.py
@@ -4,15 +4,16 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
+import operator
+
 import numpy as np
 import pandas as pd
-
-# -*- coding: utf-8 -*-
 import pytest
 
 from saqc import BAD as B
 from saqc import UNFLAGGED as U
 from saqc import SaQC
+from saqc.funcs.flagtools import _groupOperation
 
 N = np.nan
 
@@ -98,3 +99,70 @@ def test_propagateFlagsIrregularIndex(got, expected, kwargs):
     saqc = SaQC(data=data, flags=flags).propagateFlags(field="x", **kwargs)
     result = saqc._flags.history["x"].hist[1].astype(float)
     assert result.equals(expected)
+
+
+@pytest.mark.parametrize(
+    "left,right,expected",
+    [
+        ([B, U, U, B], [B, B, U, U], [B, U, U, U]),
+        ([B, B, B, B], [B, B, B, B], [B, B, B, B]),
+        ([U, U, U, U], [U, U, U, U], [U, U, U, U]),
+    ],
+)
+def test_andGroup(left, right, expected):
+
+    data = pd.DataFrame({"data": [1, 2, 3, 4]})
+
+    base = SaQC(data=data)
+    this = SaQC(data=data, flags=pd.DataFrame({"data": pd.Series(left)}))
+    that = SaQC(data=data, flags=pd.DataFrame({"data": pd.Series(right)}))
+    result = base.andGroup(field="data", group=[this, that])
+
+    assert pd.Series(expected).equals(result.flags["data"])
+
+
+@pytest.mark.parametrize(
+    "left,right,expected",
+    [
+        ([B, U, U, B], [B, B, U, U], [B, B, U, B]),
+        ([B, B, B, B], [B, B, B, B], [B, B, B, B]),
+        ([U, U, U, U], [U, U, U, U], [U, U, U, U]),
+    ],
+)
+def test_orGroup(left, right, expected):
+
+    data = pd.DataFrame({"data": [1, 2, 3, 4]})
+
+    base = SaQC(data=data)
+    this = SaQC(data=data, flags=pd.DataFrame({"data": pd.Series(left)}))
+    that = SaQC(data=data, flags=pd.DataFrame({"data": pd.Series(right)}))
+    result = base.orGroup(field="data", group=[this, that])
+
+    assert pd.Series(expected).equals(result.flags["data"])
+
+
+@pytest.mark.parametrize(
+    "left,right,expected",
+    [
+        ([B, U, U, B], [B, B, U, U], [B, B, U, B]),
+        ([B, B, B, B], [B, B, B, B], [B, B, B, B]),
+        ([U, U, U, U], [U, U, U, U], [U, U, U, U]),
+    ],
+)
+def test__groupOperation(left, right, expected):
+
+    data = pd.DataFrame(
+        {"x": [0, 1, 2, 3], "y": [0, 11, 22, 33], "z": [0, 111, 222, 333]}
+    )
+    base = SaQC(data=data)
+    this = SaQC(
+        data=data, flags=pd.DataFrame({k: pd.Series(left) for k in data.columns})
+    )
+    that = SaQC(
+        data=data, flags=pd.DataFrame({k: pd.Series(right) for k in data.columns})
+    )
+    result = _groupOperation(
+        base=base, field="x", func=operator.or_, group={this: "y", that: ["y", "z"]}
+    )
+
+    assert pd.Series(expected).equals(result.flags["x"])
-- 
GitLab