Skip to content
Snippets Groups Projects
Commit c4bac556 authored by David Schäfer
Browse files

Auto translation of input flags

parent 252d0c10
No related branches found
No related tags found
1 merge request: !785 Auto translation of input flags
......@@ -9,6 +9,8 @@ SPDX-License-Identifier: GPL-3.0-or-later
[List of commits](https://git.ufz.de/rdm-software/saqc/-/compare/v2.5.0...develop)
### Added
- `flagGeneric`: target broadcasting
- `SaQC`: automatic translation of incoming flags
- Option to change the flagging scheme after initialization
### Changed
### Removed
### Fixed
......
......@@ -10,6 +10,7 @@ from __future__ import annotations
import warnings
from copy import copy as shallowcopy
from copy import deepcopy
from functools import partial
from typing import Any, Hashable, MutableMapping
import numpy as np
......@@ -52,13 +53,21 @@ class SaQC(FunctionsMixin):
def __init__(
# NOTE(review): this span is a rendered diff hunk without +/- markers and
# without indentation, so replaced and replacing lines appear side by side.
self,
# NOTE(review): un-annotated `data`/`flags` — presumably the replaced
# (pre-change) parameter lines, since the same names follow again with
# annotations; confirm against the repository.
data=None,
flags=None,
# `data` may be a single Series/DataFrame/DictOfSeries, a list of those, or None.
data: pd.Series
| pd.DataFrame
| DictOfSeries
| list[pd.Series | pd.DataFrame | DictOfSeries]
| None = None,
# `flags` may be a single DataFrame/DictOfSeries/Flags, a list of those, or None.
flags: pd.DataFrame
| DictOfSeries
| Flags
| list[pd.DataFrame | DictOfSeries | Flags]
| None = None,
# Either a scheme name or a TranslationScheme object; defaults to "float".
scheme: str | TranslationScheme = "float",
):
# The scheme is assigned before data and flags are initialized.
# NOTE(review): presumably required because the flags conversion shown
# further down reads `self._scheme` — confirm.
self.scheme: TranslationScheme = scheme
self._data: DictOfSeries = self._initData(data)
self._flags: Flags = self._initFlags(flags)
# NOTE(review): second scheme assignment via `_initTranslationScheme` —
# presumably the replaced (pre-change) line of this hunk; confirm.
self._scheme: TranslationScheme = self._initTranslationScheme(scheme)
self._attrs: dict = {}
# Raises RuntimeError if data and flags do not share the same columns.
self._validate(reason="init")
......@@ -86,7 +95,7 @@ class SaQC(FunctionsMixin):
def _validate(self, reason=None):
    """
    Check that data and flags hold the same columns.

    Parameters
    ----------
    reason : str, optional
        Short description of the operation that triggered the check. If
        given, it is appended to the error message.

    Raises
    ------
    RuntimeError
        If the columns of ``self._data`` and ``self._flags`` are not equal.
    """
    if not self._data.columns.equals(self._flags.columns):
        # Single message assignment: an earlier wording was assigned to
        # ``msg`` and immediately overwritten, so it never reached the caller.
        msg = "Data and flags don't contain the same columns."
        if reason:
            msg += f" This was most likely caused by: {reason}"
        raise RuntimeError(msg)
......@@ -114,6 +123,21 @@ class SaQC(FunctionsMixin):
flags.attrs = self._attrs.copy()
return flags
@property
def scheme(self) -> TranslationScheme:
    """
    The translation scheme currently attached to this object.

    Assigning to this attribute accepts either a key of
    ``TRANSLATION_SCHEMES`` or an initialized ``TranslationScheme`` object.

    Raises
    ------
    TypeError
        On assignment, if the value is neither a key of
        ``TRANSLATION_SCHEMES`` nor a ``TranslationScheme`` instance.
    """
    return self._scheme


@scheme.setter
def scheme(self, scheme: str | TranslationScheme) -> None:
    # A registered name is resolved to a freshly instantiated scheme object.
    if isinstance(scheme, str) and scheme in TRANSLATION_SCHEMES:
        scheme = TRANSLATION_SCHEMES[scheme]()
    # Unregistered names and any other object end up here and are rejected.
    if not isinstance(scheme, TranslationScheme):
        raise TypeError(
            f"expected one of the following translation schemes "
            f"{list(TRANSLATION_SCHEMES)} or an initialized "
            f"TranslationScheme object, got '{scheme}'"
        )
    self._scheme = scheme
@property
def _history(self) -> _HistAccess:
    """Return the ``history`` attribute of the internal flags object."""
    internal_flags = self._flags
    return internal_flags.history
......@@ -124,7 +148,6 @@ class SaQC(FunctionsMixin):
We use this mechanism to make the registered functions appear
as `SaQC`-methods without actually implementing them.
"""
from functools import partial
if key not in FUNC_MAP:
raise AttributeError(f"SaQC has no attribute {repr(key)}")
......@@ -248,5 +271,5 @@ class SaQC(FunctionsMixin):
if isinstance(idx, pd.MultiIndex):
raise TypeError("'flags' should not have MultiIndex")
if not isinstance(flags, Flags):
flags = Flags(flags)
flags = Flags(self._scheme.toInternal(flags))
return flags
......@@ -71,6 +71,23 @@ class DictOfSeries(DictOfPandas):
return self.shared_index()
raise ValueError("method must be one of 'shared' or 'union'.")
def astype(self, dtype: str | type) -> DictOfSeries:
    """
    Return a new DictOfSeries with every entry cast to ``dtype``.

    Parameters
    ----------
    dtype : str or type
        Target data type; passed unchanged to the ``astype`` method of
        each stored value.

    Returns
    -------
    DictOfSeries
        A newly created object holding the converted values under their
        original keys.
    """
    converted = DictOfSeries()
    for name, series in self.data.items():
        converted[name] = series.astype(dtype)
    return converted
DictOfSeries.empty.__doc__ = """
Indicator whether DictOfSeries is empty.
......
......@@ -65,6 +65,22 @@ def test_dtypes(data, flags):
assert pflags[c].dtype == flags[c].dtype
def test_autoTranslation():
    """
    Check both flag views of a SaQC object built from "simple" scheme flags.

    With ``scheme="simple"`` the public flags must compare equal to the
    string labels and the internal flags to ``BAD``/``UNFLAGGED``. After
    assigning ``scheme = "float"``, both views must compare equal to
    ``BAD``/``UNFLAGGED``.
    """
    labels = ["BAD", "UNFLAGGED"]
    internal = [BAD, UNFLAGGED]

    index = pd.date_range("2000", periods=2, freq="D")
    data = pd.Series([1, 2], index=index, name="data")
    flags = pd.DataFrame(labels, index=data.index, columns=["data"])

    qc = SaQC(data=data, flags=flags, scheme="simple")
    assert (qc.flags["data"] == labels).all()  # external flags
    assert (qc._flags["data"] == internal).all()  # internal flags

    # Switching the scheme changes only the external representation.
    qc.scheme = "float"
    assert (qc.flags["data"] == internal).all()  # external flags
    assert (qc._flags["data"] == internal).all()  # internal flags
def test_new_call(data):
qc = SaQC(data)
qc = qc.flagRange("var1", max=5)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment