Skip to content
Snippets Groups Projects
Commit 4c658f9c authored by David Schäfer's avatar David Schäfer
Browse files

Support function call groups

parent 213986d7
No related branches found
No related tags found
3 merge requests!685Release 2.4,!684Release 2.4,!607Support function call groups
......@@ -9,6 +9,7 @@ SPDX-License-Identifier: GPL-3.0-or-later
## Unreleased
[List of commits](https://git.ufz.de/rdm-software/saqc/-/compare/v2.3.0...develop)
### Added
- Methods `andGroup` and `orGroup`
### Changed
### Removed
### Fixed
......
......@@ -7,8 +7,9 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
import operator
import warnings
from typing import TYPE_CHECKING, Any, Union
from typing import TYPE_CHECKING, Any, Callable, Sequence, Union
import numpy as np
import pandas as pd
......@@ -17,6 +18,7 @@ from typing_extensions import Literal
from dios import DictOfSeries
from saqc.constants import BAD, FILTER_ALL, UNFLAGGED
from saqc.core.register import _isflagged, flagging, register
from saqc.lib.tools import toSequence
if TYPE_CHECKING:
from saqc.core.core import SaQC
......@@ -531,3 +533,139 @@ class FlagtoolsMixin:
self._flags[repeated, field] = flag
return self
@register(
    mask=["field"],
    demask=["field"],
    squeeze=["field"],
    multivariate=False,
    handles_target=True,
)
def andGroup(
    self: "SaQC",
    field: str,
    group: Sequence["SaQC"] | dict["SaQC", str | Sequence[str]],
    target: str | None = None,
    flag: float = BAD,
    **kwargs,
) -> "SaQC":
    """
    Flag the values that are flagged in all members of a group of SaQC objects.

    The flag masks of the group members are combined with ``operator.and_``
    and ``flag`` is set at 'target' wherever the combined mask is true.

    Parameters
    ----------
    field : str
        Name of the variable to process. If 'group' is a sequence, 'field' is
        the variable checked for flags in every object of 'group'.

    group : sequence of SaQC or dict mapping SaQC to str or sequence of str
        The SaQC objects to check for flags. A sequence checks 'field' in
        every given object, a dict maps each object to the name(s) of the
        variable(s) to check in that object.

    target : str, default None
        Name of the field the generated flags will be written to. If None, the
        result will be written to 'field'.

    flag : float, default ``BAD``
        The quality flag to set.

    Returns
    -------
    saqc.SaQC
    """
    # positional order follows the signature of `_groupOperation`:
    # (base, field, func, group)
    return _groupOperation(
        self, field, operator.and_, group, target=target, flag=flag, **kwargs
    )
@register(
    mask=["field"],
    demask=["field"],
    squeeze=["field"],
    multivariate=False,
    handles_target=True,
)
def orGroup(
    self: "SaQC",
    field: str,
    group: Sequence["SaQC"] | dict["SaQC", str | Sequence[str]],
    target: str | None = None,
    flag: float = BAD,
    **kwargs,
) -> "SaQC":
    """
    Flag the values that are flagged in at least one member of a group of SaQC objects.

    The flag masks of the group members are combined with ``operator.or_``
    and ``flag`` is set at 'target' wherever the combined mask is true.

    Parameters
    ----------
    field : str
        Name of the variable to process. If 'group' is a sequence, 'field' is
        the variable checked for flags in every object of 'group'.

    group : sequence of SaQC or dict mapping SaQC to str or sequence of str
        The SaQC objects to check for flags. A sequence checks 'field' in
        every given object, a dict maps each object to the name(s) of the
        variable(s) to check in that object.

    target : str, default None
        Name of the field the generated flags will be written to. If None, the
        result will be written to 'field'.

    flag : float, default ``BAD``
        The quality flag to set.

    Returns
    -------
    saqc.SaQC
    """
    # positional order follows the signature of `_groupOperation`:
    # (base, field, func, group)
    return _groupOperation(
        self, field, operator.or_, group, target=target, flag=flag, **kwargs
    )
def _groupOperation(
    base: "SaQC",
    field: str,
    func: Callable[[pd.Series, pd.Series], pd.Series],
    group: Sequence["SaQC"] | dict["SaQC", str | Sequence[str]],
    target: str | None = None,
    flag: float = BAD,
    **kwargs,
) -> "SaQC":
    """
    Combine the flag masks of a group of SaQC objects and flag 'target' in 'base'.

    For every object in 'group' the masks returned by ``_isflagged`` for the
    selected variables are folded into a single mask with 'func'. ``flag`` is
    then written to 'target' of 'base' wherever that mask is true.

    Parameters
    ----------
    base : SaQC
        The object the resulting flags are written to.

    field : str
        Name of the variable in 'base'. If 'group' is not a dict, 'field' is
        also the variable read from every object in 'group'.

    func : callable
        Binary function combining two masks into one, e.g. ``operator.and_``
        or ``operator.or_``.

    group : sequence of SaQC or dict mapping SaQC to str or sequence of str
        The objects to read flags from. A dict maps each object to the
        name(s) of the variable(s) to read from it. The passed object is
        not modified.

    target : str, default None
        Name of the variable the flags are written to. If None, 'field' is
        used. If 'target' is not present in ``base._data`` it is created
        from 'field' with ``copyField``.

    flag : float, default ``BAD``
        The quality flag to set.

    Returns
    -------
    saqc.SaQC

    Raises
    ------
    KeyError
        If a variable to read is missing in the flags of its SaQC object.
    ValueError
        If 'group' does not name a single variable to read.
    """
    # Should this be multivariate? And what would multivariate mean in this context
    dfilter = kwargs.get("dfilter", FILTER_ALL)
    if target is None:
        target = field

    # harmonise `group` to type dict[SaQC, str | Sequence[str]]
    if not isinstance(group, dict):
        group = {qc: field for qc in group}

    # Build fresh lists of field names: the caller's `group` must stay
    # untouched, otherwise a second call with the same dict would see
    # altered (or even empty) field lists.
    qcs_items: list[tuple["SaQC", list[str]]] = [
        (qc, list(toSequence(names))) for qc, names in group.items()
    ]

    mask = None
    for qc, names in qcs_items:
        # `name` must not reuse the identifier `field`, which is still
        # needed below to create `target`
        for name in names:
            if name not in qc._flags:
                raise KeyError(f"variable {name} is missing in given SaQC object")
            if mask is None:
                # the very first variable seeds the mask
                # NOTE(review): only the seed honours `dfilter`, all further
                # masks use FILTER_ALL - kept as is, confirm this is intended
                mask = _isflagged(qc._flags[name], thresh=dfilter)
            else:
                mask = func(mask, _isflagged(qc._flags[name], thresh=FILTER_ALL))

    if mask is None:
        raise ValueError("'group' needs to provide at least one variable to check")

    if target not in base._data:
        base = base.copyField(field=field, target=target)

    base._flags[mask, target] = flag
    return base
......@@ -4,15 +4,16 @@
#
# SPDX-License-Identifier: GPL-3.0-or-later
import operator
import numpy as np
import pandas as pd
# -*- coding: utf-8 -*-
import pytest
from saqc import BAD as B
from saqc import UNFLAGGED as U
from saqc import SaQC
from saqc.funcs.flagtools import _groupOperation
N = np.nan
......@@ -98,3 +99,70 @@ def test_propagateFlagsIrregularIndex(got, expected, kwargs):
saqc = SaQC(data=data, flags=flags).propagateFlags(field="x", **kwargs)
result = saqc._flags.history["x"].hist[1].astype(float)
assert result.equals(expected)
@pytest.mark.parametrize(
    "left,right,expected",
    [
        ([B, U, U, B], [B, B, U, U], [B, U, U, U]),
        ([B, B, B, B], [B, B, B, B], [B, B, B, B]),
        ([U, U, U, U], [U, U, U, U], [U, U, U, U]),
    ],
)
def test_andGroup(left, right, expected):
    """Check ``andGroup`` on two group members flagged with `left` and `right`."""
    frame = pd.DataFrame({"data": [1, 2, 3, 4]})
    qc = SaQC(data=frame)
    # one group member per flag column
    members = [
        SaQC(data=frame, flags=pd.DataFrame({"data": pd.Series(values)}))
        for values in (left, right)
    ]
    flagged = qc.andGroup(field="data", group=members)
    assert pd.Series(expected).equals(flagged.flags["data"])
@pytest.mark.parametrize(
    "left,right,expected",
    [
        ([B, U, U, B], [B, B, U, U], [B, B, U, B]),
        ([B, B, B, B], [B, B, B, B], [B, B, B, B]),
        ([U, U, U, U], [U, U, U, U], [U, U, U, U]),
    ],
)
def test_orGroup(left, right, expected):
    """Check ``orGroup`` on two group members flagged with `left` and `right`."""
    frame = pd.DataFrame({"data": [1, 2, 3, 4]})
    qc = SaQC(data=frame)
    # one group member per flag column
    members = [
        SaQC(data=frame, flags=pd.DataFrame({"data": pd.Series(values)}))
        for values in (left, right)
    ]
    flagged = qc.orGroup(field="data", group=members)
    assert pd.Series(expected).equals(flagged.flags["data"])
@pytest.mark.parametrize(
    "left,right,expected",
    [
        ([B, U, U, B], [B, B, U, U], [B, B, U, B]),
        ([B, B, B, B], [B, B, B, B], [B, B, B, B]),
        ([U, U, U, U], [U, U, U, U], [U, U, U, U]),
    ],
)
def test__groupOperation(left, right, expected):
    """Check ``_groupOperation`` with ``operator.or_`` and a dict-typed group."""
    frame = pd.DataFrame(
        {"x": [0, 1, 2, 3], "y": [0, 11, 22, 33], "z": [0, 111, 222, 333]}
    )

    def _withFlags(values):
        # every column of `frame` receives the same flags
        flags = pd.DataFrame({name: pd.Series(values) for name in frame.columns})
        return SaQC(data=frame, flags=flags)

    qc = SaQC(data=frame)
    first = _withFlags(left)
    second = _withFlags(right)

    # the group maps each object to the variable(s) read from it
    members = {first: "y", second: ["y", "z"]}
    outcome = _groupOperation(base=qc, field="x", func=operator.or_, group=members)
    assert pd.Series(expected).equals(outcome.flags["x"])
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment