diff --git a/saqc/flagger/baseflagger.py b/saqc/flagger/baseflagger.py index b46515d68543bcfe1f4caf987440c621d4821005..af0b1b7a5a4936362451488f7629800e82b3f908 100644 --- a/saqc/flagger/baseflagger.py +++ b/saqc/flagger/baseflagger.py @@ -1,18 +1,21 @@ #! /usr/bin/env python # -*- coding: utf-8 -*- +from __future__ import annotations import operator as op from copy import deepcopy from abc import ABC, abstractmethod -from typing import TypeVar, Union, Any, List, Optional +from typing import TypeVar, Union, Any, List, Optional, Tuple, Dict, Sequence import pandas as pd import numpy as np -import dios + +from dios import DictOfSeries from saqc.lib.tools import assertScalar, mergeDios, toSequence, customRoller + COMPARATOR_MAP = { "!=": op.ne, "==": op.eq, @@ -22,13 +25,12 @@ COMPARATOR_MAP = { "<": op.lt, } -# TODO: get some real types here (could be tricky...) -LocT = Any + +LocT = Union[pd.Series, pd.Index, slice] FlagT = Any -diosT = dios.DictOfSeries -BaseFlaggerT = TypeVar("BaseFlaggerT") -PandasT = Union[pd.Series, diosT] -FieldsT = Union[str, List[str]] +BaseFlaggerT = TypeVar("BaseFlaggerT", bound="BaseFlagger") +PandasT = Union[pd.Series, DictOfSeries] +FieldsT = Union[str, List[str], slice] class BaseFlagger(ABC): @@ -40,7 +42,7 @@ class BaseFlagger(ABC): # NOTE: the arggumens of setFlags supported from # the configuration functions self.signature = ("flag",) - self._flags: Optional[diosT] = None + self._flags: Optional[DictOfSeries] = None @property def initialized(self): @@ -50,7 +52,7 @@ class BaseFlagger(ABC): def flags(self): return self._flags.copy() - def initFlags(self, data: diosT = None, flags: diosT = None) -> BaseFlaggerT: + def initFlags(self, data: DictOfSeries = None, flags: DictOfSeries = None) -> BaseFlagger: """ initialize a flagger based on the given 'data' or 'flags' if 'data' is not None: return a flagger with flagger.UNFLAGGED values @@ -64,21 +66,21 @@ class BaseFlagger(ABC): raise TypeError("either 'data' or 'flags' can be given") if data is not None: - if not isinstance(data, diosT): - data = dios.DictOfSeries(data) + if not isinstance(data, DictOfSeries): + data = DictOfSeries(data) - flags = dios.DictOfSeries(columns=data.columns) - for c in flags.columns: + flags = DictOfSeries(columns=data.columns) + for c in flags.columns.tolist(): flags[c] = pd.Series(self.UNFLAGGED, index=data[c].index) else: - if not isinstance(flags, diosT): - flags = dios.DictOfSeries(flags) + if not isinstance(flags, DictOfSeries): + flags = DictOfSeries(flags) flags = flags.astype(self.dtype) newflagger = self.copy(flags=flags) return newflagger - def merge(self, other: BaseFlaggerT, subset: Optional[List] = None, join: str = "merge", inplace=False): + def merge(self, other: BaseFlagger, subset: Optional[List] = None, join: str = "merge", inplace=False) -> BaseFlagger: """ Merge the given flagger 'other' into self """ @@ -92,13 +94,13 @@ class BaseFlagger(ABC): else: return self.copy(flags=mergeDios(self._flags, other._flags, subset=subset, join=join)) - def slice(self, field: FieldsT = None, loc: LocT = None, drop: FieldsT = None, inplace=False) -> BaseFlaggerT: + def slice(self, field: FieldsT = None, loc: LocT = None, drop: FieldsT = None, inplace=False) -> BaseFlagger: """ Return a potentially trimmed down copy of self. """ if drop is not None: if field is not None: raise TypeError("either 'field' or 'drop' can be given, but not both") field = self._flags.columns.drop(drop, errors="ignore") - flags = self.getFlags(field=field, loc=loc).to_dios() + flags = self.getFlags(field=field, loc=loc).to_dios() # type: ignore if inplace: self._flags = flags @@ -120,7 +122,7 @@ class BaseFlagger(ABC): """ return self._flags.to_df() - def getFlags(self, field: FieldsT = None, loc: LocT = None, full=False): + def getFlags(self, field: FieldsT = None, loc: LocT = None, full=False) -> Union[PandasT, Tuple[DictOfSeries, Dict[str, PandasT]]]: """ Return a potentially, to `loc`, trimmed down version of flags. Parameters @@ -155,13 +157,12 @@ class BaseFlagger(ABC): # loc should be a valid 2D-indexer and # then field must be None. Otherwise aloc # will fail and throw the correct Error. - if isinstance(loc, diosT) and field is None: + if isinstance(loc, DictOfSeries) and field is None: indexer = loc - else: - loc = slice(None) if loc is None else loc - field = slice(None) if field is None else self._check_field(field) - indexer = (loc, field) + rows = slice(None) if loc is None else loc + cols = slice(None) if field is None else self._check_field(field) + indexer = (rows, cols) # this is a bug in `dios.aloc`, which may return a shallow copied dios, if `slice(None)` is passed # as row indexer. Thus is because pandas `.loc` return a shallow copy if a null-slice is passed to a series. @@ -183,7 +184,7 @@ class BaseFlagger(ABC): flag_before: Union[str, int] = None, win_flag: FlagT = None, **kwargs - ) -> BaseFlaggerT: + ) -> BaseFlagger: """Overwrite existing flags at loc. If `force=False` (default) only flags with a lower priority are overwritten, @@ -212,7 +213,7 @@ class BaseFlagger(ABC): trimmed = self.getFlags(field=field, loc=loc) if force: - mask = pd.Series(True, index=trimmed.index, dtype=bool) + mask = pd.Series(True, index=trimmed.index, dtype=bool) # type: ignore / `trimmed` is asserted to be a Series else: mask = trimmed < flag @@ -298,7 +299,7 @@ class BaseFlagger(ABC): return mask, win_flag - def clearFlags(self, field: str, loc: LocT = None, inplace=False, **kwargs) -> BaseFlaggerT: + def clearFlags(self, field: str, loc: LocT = None, inplace: bool=False, **kwargs) -> BaseFlagger: assertScalar("field", field, optional=False) if "force" in kwargs: raise ValueError("Keyword 'force' is not allowed here.") @@ -306,7 +307,7 @@ class BaseFlagger(ABC): raise ValueError("Keyword 'flag' is not allowed here.") return self.setFlags(field=field, loc=loc, flag=self.UNFLAGGED, force=True, inplace=inplace, **kwargs) - def isFlagged(self, field=None, loc: LocT = None, flag: FlagT = None, comparator: str = ">") -> PandasT: + def isFlagged(self, field=None, loc: LocT = None, flag: FlagT = None, comparator: str=">") -> PandasT: """ Returns boolean data that indicate where data has been flagged. @@ -339,12 +340,12 @@ class BaseFlagger(ABC): raise TypeError("flag: pd.Series is not allowed") flags_to_compare = set(toSequence(flag, self.GOOD)) - flags = self.getFlags(field, loc) + flags = self.getFlags(field, loc, full=False) cp = COMPARATOR_MAP[comparator] # notna() to prevent nans to become True, # eg.: `np.nan != 0 -> True` - flagged = flags.notna() + flagged = flags.notna() # type: ignore / we asserted, that flags is of `PandasT` # passing an empty list must result # in a everywhere-False data @@ -358,7 +359,7 @@ class BaseFlagger(ABC): return flagged - def copy(self, flags=None) -> BaseFlaggerT: + def copy(self, flags: Optional[PandasT]=None) -> BaseFlagger: if flags is None: out = deepcopy(self) else: @@ -391,7 +392,7 @@ class BaseFlagger(ABC): # version of it. return flag == self.BAD or flag == self.GOOD or flag == self.UNFLAGGED or self.isSUSPICIOUS(flag) - def replaceField(self, field, flags, inplace=False, **kwargs): + def replaceField(self, field: str, flags: Optional[pd.Series], inplace: bool=False, **kwargs) -> BaseFlagger: """ Replace or delete all data for a given field. Parameters @@ -436,7 +437,7 @@ class BaseFlagger(ABC): out._flags[field] = flags.astype(self.dtype) return out - def _check_field(self, field): + def _check_field(self, field: Union[str, Sequence[str]]) -> Union[str, Sequence[str]]: """ Check if (all) field(s) in self._flags. """ # wait for outcome of diff --git a/saqc/flagger/categoricalflagger.py b/saqc/flagger/categoricalflagger.py index 20d2680343bfc659a5e95809732852b26e913a23..598cdb8681ad1c0c5696217abaaabf76936be630 100644 --- a/saqc/flagger/categoricalflagger.py +++ b/saqc/flagger/categoricalflagger.py @@ -6,7 +6,6 @@ from collections import OrderedDict import pandas as pd from saqc.flagger.baseflagger import BaseFlagger -from saqc.lib.tools import assertDictOfSeries class Flags(pd.CategoricalDtype): diff --git a/saqc/flagger/dmpflagger.py b/saqc/flagger/dmpflagger.py index d4ff7cd5b4d54171df97a66d9d5d13b3812172e8..aabbcf215e5bd275f0e3212c7439673e0e3fd33f 100644 --- a/saqc/flagger/dmpflagger.py +++ b/saqc/flagger/dmpflagger.py @@ -1,5 +1,6 @@ #! /usr/bin/env python # -*- coding: utf-8 -*- +from __future__ import annotations import json from copy import deepcopy @@ -7,11 +8,10 @@ from typing import TypeVar, Optional, List import pandas as pd -import dios +from dios import DictOfSeries -from saqc.flagger.baseflagger import diosT from saqc.flagger.categoricalflagger import CategoricalFlagger -from saqc.lib.tools import assertScalar, mergeDios, mutateIndex +from saqc.lib.tools import assertScalar, mergeDios DmpFlaggerT = TypeVar("DmpFlaggerT") @@ -63,7 +63,7 @@ class DmpFlagger(CategoricalFlagger): out = out.reorder_levels(order=[1, 0], axis=1).sort_index(axis=1, level=0, sort_remaining=False) return out - def initFlags(self, data: dios.DictOfSeries = None, flags: dios.DictOfSeries = None): + def initFlags(self, data: DictOfSeries = None, flags: DictOfSeries = None): """ initialize a flagger based on the given 'data' or 'flags' if 'data' is not None: return a flagger with flagger.UNFALGGED values @@ -84,7 +84,7 @@ class DmpFlagger(CategoricalFlagger): newflagger._comments = self._comments.aloc[flags, ...] return newflagger - def merge(self, other: DmpFlaggerT, subset: Optional[List] = None, join: str = "merge", inplace=False): + def merge(self, other: DmpFlagger, subset: Optional[List] = None, join: str = "merge", inplace=False): assert isinstance(other, DmpFlagger) flags = mergeDios(self._flags, other._flags, subset=subset, join=join) causes = mergeDios(self._causes, other._causes, subset=subset, join=join) @@ -101,7 +101,7 @@ class DmpFlagger(CategoricalFlagger): # loc should be a valid 2D-indexer and # then field must be None. Otherwise aloc # will fail and throw the correct Error. - if isinstance(loc, diosT) and field is None: + if isinstance(loc, DictOfSeries) and field is None: indexer = loc else: loc = slice(None) if loc is None else loc @@ -235,7 +235,7 @@ class DmpFlagger(CategoricalFlagger): out._comments[field] = comments.astype(str) return out - def _construct_new(self, flags, causes, comments) -> DmpFlaggerT: + def _construct_new(self, flags, causes, comments) -> DmpFlagger: new = DmpFlagger() new._global_comments = self._global_comments new._flags = flags