Skip to content
Snippets Groups Projects
Commit ad614ccf authored by David Schäfer's avatar David Schäfer
Browse files

Merge branch 'generic-rework' into 'develop'

Simplify the generics and make them policy compliant

See merge request !650
parents a4d4021a 9085905a
No related branches found
No related tags found
3 merge requests!685Release 2.4,!684Release 2.4,!650Simplify the generics and make them policy compliant
Pipeline #161503 passed with stages
in 10 minutes and 9 seconds
......@@ -8,7 +8,7 @@ from __future__ import annotations
import typing
import warnings
from typing import DefaultDict, Dict, Iterable, Mapping, Tuple, Type, Union
from typing import DefaultDict, Dict, Iterable, Mapping, Tuple, Type, Union, overload
import numpy as np
import pandas as pd
......@@ -320,6 +320,14 @@ class Flags:
# ----------------------------------------------------------------------
# item access
@overload
def __getitem__(self, key: str) -> pd.Series:
...
@overload
def __getitem__(self, key: list | pd.Index) -> Flags:
...
def __getitem__(self, key: str | list | pd.Index) -> pd.Series | Flags:
if isinstance(key, str):
return self._data[key].squeeze()
......
......@@ -7,23 +7,29 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
from typing import TYPE_CHECKING, Sequence, Tuple, Union
from typing import TYPE_CHECKING, Any, Protocol, Sequence
import numpy as np
import pandas as pd
from saqc import BAD, FILTER_ALL
from saqc.core import DictOfSeries, Flags, History, register
from saqc.core.register import _maskData
from saqc.core import DictOfSeries, Flags, register
from saqc.lib.tools import isAllBoolean, isflagged, toSequence
from saqc.lib.types import GenericFunction
from saqc.parsing.environ import ENVIRONMENT
if TYPE_CHECKING:
from saqc import SaQC
def _flagSelect(field, flags, label=None):
class GenericFunction(Protocol):
__name__: str
__globals__: dict[str, Any]
def __call__(self, *args: pd.Series) -> pd.Series | pd.DataFrame | DictOfSeries:
... # pragma: no cover
def _flagSelect(field: str, flags: Flags, label: str | None = None) -> pd.Series:
if label is None:
return flags[field]
......@@ -44,24 +50,12 @@ def _flagSelect(field, flags, label=None):
return out.fillna(-np.inf)
def _prepare(
data: DictOfSeries, flags: Flags, columns: Sequence[str], dfilter: float
) -> Tuple[DictOfSeries, Flags]:
fchunk = Flags({f: flags[f] for f in columns})
for f in fchunk.columns:
fchunk.history[f] = flags.history[f]
dchunk, _ = _maskData(
data=data[columns].copy(), flags=fchunk, columns=columns, thresh=dfilter
)
return dchunk, fchunk.copy()
def _execGeneric(
flags: Flags,
data: pd.DataFrame | pd.Series | DictOfSeries,
func: GenericFunction,
dfilter: float = FILTER_ALL,
) -> DictOfSeries:
) -> DictOfSeries | pd.DataFrame | pd.Series:
globs = {
"isflagged": lambda data, label=None: isflagged(
_flagSelect(data.name, flags, label), thresh=dfilter
......@@ -107,9 +101,9 @@ def _castResult(obj) -> DictOfSeries:
class GenericMixin:
@register(
mask=[],
demask=[],
squeeze=[],
mask=["field"],
demask=["field"],
squeeze=["field", "target"],
multivariate=True,
handles_target=True,
)
......@@ -175,27 +169,17 @@ class GenericMixin:
fields = toSequence(field)
targets = fields if target is None else toSequence(target)
dchunk, fchunk = _prepare(self._data, self._flags, fields, dfilter)
dchunk, fchunk = self._data[fields].copy(), self._flags[fields].copy()
result = _execGeneric(fchunk, dchunk, func, dfilter=dfilter)
result = _castResult(result)
meta = {
"func": "procGeneric",
"args": (field, target),
"kwargs": {
"func": func.__name__,
"dfilter": dfilter,
**kwargs,
},
}
# update data & flags
for i, col in enumerate(targets):
datacol = result[result.columns[i]]
self._data[col] = datacol
if col not in self._flags:
self._flags.history[col] = History(datacol.index)
self._flags[col] = pd.Series(np.nan, index=datacol.index)
if not self._flags[col].index.equals(datacol.index):
raise ValueError(
......@@ -203,16 +187,14 @@ class GenericMixin:
"because of incompatible indices, please choose another 'target'"
)
self._flags.history[col].append(
pd.Series(np.nan, index=datacol.index), meta
)
self._flags[:, col] = np.nan
return self
@register(
mask=[],
demask=[],
squeeze=[],
mask=["field"],
demask=["field"],
squeeze=["field", "target"],
multivariate=True,
handles_target=True,
)
......@@ -222,7 +204,6 @@ class GenericMixin:
func: GenericFunction,
target: str | Sequence[str] | None = None,
flag: float = BAD,
dfilter: float = FILTER_ALL,
**kwargs,
) -> "SaQC":
"""
......@@ -249,10 +230,6 @@ class GenericMixin:
flag: float, default ``BAD``
Quality flag to set.
dfilter: float, default ``FILTER_ALL``
Threshold flag. Flag values greater than ``dfilter`` indicate that the associated
data value is inappropiate for further usage.
Returns
-------
saqc.SaQC
......@@ -285,8 +262,9 @@ class GenericMixin:
fields = toSequence(field)
targets = fields if target is None else toSequence(target)
dfilter = kwargs.get("dfilter", BAD)
dchunk, fchunk = _prepare(self._data, self._flags, fields, dfilter)
dchunk, fchunk = self._data[fields].copy(), self._flags[fields].copy()
result = _execGeneric(fchunk, dchunk, func, dfilter=dfilter)
result = _castResult(result)
......@@ -299,43 +277,28 @@ class GenericMixin:
if not result.empty and not isAllBoolean(result):
raise TypeError(f"generic expression does not return a boolean array")
meta = {
"func": "flagGeneric",
"args": (field, target),
"kwargs": {
"func": func.__name__,
"flag": flag,
"dfilter": dfilter,
**kwargs,
},
}
# update flags & data
for i, col in enumerate(targets):
maskcol = result[result.columns[i]]
mask = result[result.columns[i]]
# make sure the column exists
if col not in self._flags:
self._flags.history[col] = History(maskcol.index)
self._flags[col] = pd.Series(np.nan, index=mask.index)
# respect existing flags
mask = ~isflagged(self._flags[col], thresh=dfilter) & mask
# dummy column to ensure consistency between flags and data
if col not in self._data:
self._data[col] = pd.Series(np.nan, index=maskcol.index, dtype=float)
# Note: big speedup for series, because replace works
# with a loop and setting with mask is vectorized.
# old code:
# >>> flagcol = maskcol.replace({False: np.nan, True: flag}).astype(float)
flagcol = pd.Series(np.nan, index=maskcol.index, dtype=float)
flagcol[maskcol] = flag
self._data[col] = pd.Series(np.nan, index=mask.index, dtype=float)
# we need equal indices to work on
if not self._flags[col].index.equals(maskcol.index):
if not self._flags[col].index.equals(mask.index):
raise ValueError(
f"cannot assign function result to the existing variable {repr(col)} "
"because of incompatible indices, please choose another 'target'"
)
self._flags.history[col].append(flagcol, meta)
self._flags[mask, col] = flag
return self
......@@ -14,8 +14,6 @@ import numpy as np
import pandas as pd
from typing_extensions import Protocol
from saqc.core import DictOfSeries
__all__ = [
"T",
"ArrayLike",
......@@ -36,14 +34,6 @@ class CurveFitter(Protocol):
... # pragma: no cover
class GenericFunction(Protocol):
__name__: str
__globals__: Dict[str, Any]
def __call__(self, *args: pd.Series) -> pd.Series | pd.DataFrame | DictOfSeries:
... # pragma: no cover
class Comparable(Protocol):
@abc.abstractmethod
def __gt__(self: CompT, other: CompT) -> bool:
......
......@@ -9,9 +9,9 @@
import pandas as pd
import pytest
from saqc import BAD, FILTER_ALL, UNFLAGGED, SaQC
from saqc import BAD, UNFLAGGED, SaQC
from saqc.constants import FILTER_NONE
from saqc.core import DictOfSeries, Flags
from saqc.lib.tools import toSequence
from tests.common import initData
......@@ -44,21 +44,10 @@ def test_emptyData():
],
)
def test_writeTargetFlagGeneric(data, targets, func):
expected_meta = {
"func": "flagGeneric",
"args": (data.columns.tolist(), targets),
"kwargs": {
"func": func.__name__,
"flag": BAD,
"dfilter": FILTER_ALL,
},
}
saqc = SaQC(data=data)
saqc = saqc.flagGeneric(field=data.columns, target=targets, func=func, flag=BAD)
for target in targets:
assert saqc._flags.history[target].hist.iloc[0].tolist() == [BAD]
assert saqc._flags.history[target].meta[0] == expected_meta
@pytest.mark.parametrize(
......@@ -74,16 +63,6 @@ def test_writeTargetFlagGeneric(data, targets, func):
def test_overwriteFieldFlagGeneric(data, fields, func):
flag = 12
expected_meta = {
"func": "flagGeneric",
"args": (fields, fields),
"kwargs": {
"func": func.__name__,
"flag": flag,
"dfilter": FILTER_ALL,
},
}
saqc = SaQC(
data=data.copy(),
flags=Flags(
......@@ -96,13 +75,12 @@ def test_overwriteFieldFlagGeneric(data, fields, func):
),
)
res = saqc.flagGeneric(field=fields, func=func, flag=flag)
res = saqc.flagGeneric(field=fields, func=func, flag=flag, dfilter=FILTER_NONE)
for field in fields:
histcol1 = res._flags.history[field].hist[1]
assert (histcol1 == flag).all()
assert (data[field] == res.data[field]).all(axis=None)
assert res._flags.history[field].meta[0] == {}
assert res._flags.history[field].meta[1] == expected_meta
@pytest.mark.parametrize(
......@@ -139,15 +117,6 @@ def test_writeTargetProcGeneric(data, targets, func, expected_data):
fields = data.columns.tolist()
dfilter = 128
expected_meta = {
"func": "procGeneric",
"args": (fields, targets),
"kwargs": {
"func": func.__name__,
"dfilter": dfilter,
"label": "generic",
},
}
saqc = SaQC(
data=data,
flags=Flags({k: pd.Series(127.0, index=data[k].index) for k in data.columns}),
......@@ -163,7 +132,6 @@ def test_writeTargetProcGeneric(data, targets, func, expected_data):
# check that new histories where created
for target in targets:
assert res._flags.history[target].hist.iloc[0].isna().all()
assert res._flags.history[target].meta[0] == expected_meta
@pytest.mark.parametrize(
......@@ -186,16 +154,6 @@ def test_writeTargetProcGeneric(data, targets, func, expected_data):
def test_overwriteFieldProcGeneric(data, fields, func, expected_data):
dfilter = 128
expected_meta = {
"func": "procGeneric",
"args": (fields, fields),
"kwargs": {
"func": func.__name__,
"dfilter": dfilter,
"label": "generic",
},
}
saqc = SaQC(
data=data,
flags=Flags({k: pd.Series(127.0, index=data[k].index) for k in data.columns}),
......@@ -208,7 +166,6 @@ def test_overwriteFieldProcGeneric(data, fields, func, expected_data):
assert (res._flags.history[field].hist[0] == 127.0).all()
assert res._flags.history[field].hist[1].isna().all()
assert res._flags.history[field].meta[0] == {}
assert res._flags.history[field].meta[1] == expected_meta
def test_label():
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment