From c498f0b4f125832606ca6a1cca71d90c484598a3 Mon Sep 17 00:00:00 2001 From: Bert Palm <bert.palm@ufz.de> Date: Fri, 26 Feb 2021 14:34:45 +0100 Subject: [PATCH] init core adjusted --- saqc/core/core.py | 62 ++++++++++++++++++-------------------- saqc/flagger/__init__.py | 2 ++ saqc/flagger/flags.py | 4 +++ test/core/test_core_new.py | 20 ++++++++++++ 4 files changed, 55 insertions(+), 33 deletions(-) create mode 100644 test/core/test_core_new.py diff --git a/saqc/core/core.py b/saqc/core/core.py index 8a8b6283c..df90d80a0 100644 --- a/saqc/core/core.py +++ b/saqc/core/core.py @@ -10,7 +10,7 @@ TODOS: import logging import copy as stdcopy -from typing import List, Tuple, Sequence +from typing import List, Tuple, Sequence, Union from typing_extensions import Literal import pandas as pd @@ -19,7 +19,8 @@ import numpy as np import timeit import inspect -from saqc.flagger import BaseFlagger, CategoricalFlagger, SimpleFlagger, DmpFlagger +from saqc.common import * +from saqc.flagger.flags import init_flags_like, Flagger from saqc.core.lib import APIController, ColumnSelector from saqc.core.register import FUNC_MAP, SaQCFunction from saqc.core.modules import FuncModules @@ -49,7 +50,8 @@ def _handleErrors(exc: Exception, field: str, control: APIController, func: SaQC raise exc -def _prepInput(flagger, data, flags): +# todo: shouldt this go to Saqc.__init__ ? +def _prepInput(data, flags): dios_like = (dios.DictOfSeries, pd.DataFrame) if isinstance(data, pd.Series): @@ -66,30 +68,23 @@ def _prepInput(flagger, data, flags): if not hasattr(data.columns, "str"): raise TypeError("expected dataframe columns of type string") - if not isinstance(flagger, BaseFlagger): - # NOTE: we should generate that list automatically, - # it won't ever be complete otherwise - flaggerlist = [CategoricalFlagger, SimpleFlagger, DmpFlagger] - raise TypeError(f"'flagger' must be of type {flaggerlist} or a subclass of {BaseFlagger}") - if flags is not None: - if not isinstance(flags, dios_like): - raise TypeError("'flags' must be of type dios.DictOfSeries or pd.DataFrame") if isinstance(flags, pd.DataFrame): if isinstance(flags.index, pd.MultiIndex) or isinstance(flags.columns, pd.MultiIndex): raise TypeError("'flags' should not use MultiIndex") - flags = dios.to_dios(flags) - # NOTE: do not test all columns as they not necessarily need to be the same - cols = flags.columns & data.columns - if not (flags[cols].lengths == data[cols].lengths).all(): - raise ValueError("the length of 'flags' and 'data' need to be equal") + if isinstance(flags, (dios.DictOfSeries, pd.DataFrame, Flagger)): + # NOTE: only test common columns, data as well as flags could + # have more columns than the respective other. + cols = flags.columns & data.columns + for c in cols: + if not flags[c].index.equals(data[c].index): + raise ValueError(f"the index of 'flags' and 'data' missmatch in column {c}") - if flagger.initialized: - diff = data.columns.difference(flagger.getFlags().columns) - if not diff.empty: - raise ValueError("Missing columns in 'flagger': '{list(diff)}'") + # this also ensures float dtype + if not isinstance(flags, Flagger): + flags = Flagger(flags, copy=True) return data, flags @@ -110,31 +105,32 @@ _setup() class SaQC(FuncModules): - def __init__(self, flagger, data, flags=None, nodata=np.nan, to_mask=None, error_policy="raise"): + def __init__(self, data, flags=None, nodata=np.nan, to_mask=None, error_policy="raise"): super().__init__(self) - data, flags = _prepInput(flagger, data, flags) + data, flagger = _prepInput(data, flags) self._data = data self._nodata = nodata self._to_mask = to_mask - self._flagger = self._initFlagger(data, flagger, flags) + self._flagger = self._initFlagger(data, flags) self._error_policy = error_policy # NOTE: will be filled by calls to `_wrap` self._to_call: List[Tuple[ColumnSelector, APIController, SaQCFunction]] = [] - def _initFlagger(self, data, flagger, flags): + def _initFlagger(self, data, flagger: Union[Flagger, None]): """ Init the internal flagger object. Ensures that all data columns are present and user passed flags from - a flags frame and/or an already initialised flagger are used. - If columns overlap the passed flagger object is prioritised. + a flags frame or an already initialised flagger are used. """ - # ensure all data columns - merged = flagger.initFlags(data) - if flags is not None: - merged = merged.merge(flagger.initFlags(flags=flags), inplace=True) - if flagger.initialized: - merged = merged.merge(flagger, inplace=True) - return merged + if flagger is None: + return init_flags_like(data) + + for c in flagger.columns.union(data.columns): + if c in flagger: + continue + if c in data: + flagger[c] = pd.Series(UNFLAGGED, index=data[c].index, dtype=float) + return flagger def readConfig(self, fname): from saqc.core.reader import readConfig diff --git a/saqc/flagger/__init__.py b/saqc/flagger/__init__.py index d5124fb9d..774f2ec2b 100644 --- a/saqc/flagger/__init__.py +++ b/saqc/flagger/__init__.py @@ -1,6 +1,8 @@ #! /usr/bin/env python # -*- coding: utf-8 -*- +from .flags import Flagger +from .history import History from saqc.flagger.baseflagger import BaseFlagger from saqc.flagger.categoricalflagger import CategoricalFlagger from saqc.flagger.simpleflagger import SimpleFlagger diff --git a/saqc/flagger/flags.py b/saqc/flagger/flags.py index bf64ec556..15b8a4efc 100644 --- a/saqc/flagger/flags.py +++ b/saqc/flagger/flags.py @@ -311,3 +311,7 @@ def init_flags_like(reference: Union[pd.Series, DictLike, Flags], initial_value: return Flags(result) + +# for now we keep this name +Flagger = Flags + diff --git a/test/core/test_core_new.py b/test/core/test_core_new.py new file mode 100644 index 000000000..b16714c8e --- /dev/null +++ b/test/core/test_core_new.py @@ -0,0 +1,20 @@ +#!/usr/bin/env python + +import pandas as pd +import numpy as np +import dios + + +def test_init(): + from saqc import SaQC, Flagger + + arr = np.array([ + [0, 1, 2], + [0, 1, 3], + ]) + data = pd.DataFrame(arr, columns=list('abc')) + qc = SaQC(data) + + assert isinstance(qc, SaQC) + assert isinstance(qc._flagger, Flagger) + assert isinstance(qc._data, dios.DictOfSeries) -- GitLab