Skip to content
Snippets Groups Projects
Commit 6a328972 authored by David Schäfer's avatar David Schäfer
Browse files

flagger: typehints

parent cb12c41e
No related branches found
No related tags found
No related merge requests found
#! /usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import annotations
import operator as op
from copy import deepcopy
from abc import ABC, abstractmethod
from typing import TypeVar, Union, Any, List, Optional
from typing import TypeVar, Union, Any, List, Optional, Tuple, Dict, Sequence
import pandas as pd
import numpy as np
import dios
from dios import DictOfSeries
from saqc.lib.tools import assertScalar, mergeDios, toSequence, customRoller
COMPARATOR_MAP = {
"!=": op.ne,
"==": op.eq,
......@@ -22,13 +25,12 @@ COMPARATOR_MAP = {
"<": op.lt,
}
# TODO: get some real types here (could be tricky...)
LocT = Any
LocT = Union[pd.Series, pd.Index, slice]
FlagT = Any
diosT = dios.DictOfSeries
BaseFlaggerT = TypeVar("BaseFlaggerT")
PandasT = Union[pd.Series, diosT]
FieldsT = Union[str, List[str]]
BaseFlaggerT = TypeVar("BaseFlaggerT", bound="BaseFlagger")
PandasT = Union[pd.Series, DictOfSeries]
FieldsT = Union[str, List[str], slice]
class BaseFlagger(ABC):
......@@ -40,7 +42,7 @@ class BaseFlagger(ABC):
# NOTE: the arguments of setFlags supported from
# the configuration functions
self.signature = ("flag",)
self._flags: Optional[diosT] = None
self._flags: Optional[DictOfSeries] = None
@property
def initialized(self):
......@@ -50,7 +52,7 @@ class BaseFlagger(ABC):
def flags(self):
return self._flags.copy()
def initFlags(self, data: diosT = None, flags: diosT = None) -> BaseFlaggerT:
def initFlags(self, data: DictOfSeries = None, flags: DictOfSeries = None) -> BaseFlagger:
"""
initialize a flagger based on the given 'data' or 'flags'
if 'data' is not None: return a flagger with flagger.UNFLAGGED values
......@@ -64,21 +66,21 @@ class BaseFlagger(ABC):
raise TypeError("either 'data' or 'flags' can be given")
if data is not None:
if not isinstance(data, diosT):
data = dios.DictOfSeries(data)
if not isinstance(data, DictOfSeries):
data = DictOfSeries(data)
flags = dios.DictOfSeries(columns=data.columns)
for c in flags.columns:
flags = DictOfSeries(columns=data.columns)
for c in flags.columns.tolist():
flags[c] = pd.Series(self.UNFLAGGED, index=data[c].index)
else:
if not isinstance(flags, diosT):
flags = dios.DictOfSeries(flags)
if not isinstance(flags, DictOfSeries):
flags = DictOfSeries(flags)
flags = flags.astype(self.dtype)
newflagger = self.copy(flags=flags)
return newflagger
def merge(self, other: BaseFlaggerT, subset: Optional[List] = None, join: str = "merge", inplace=False):
def merge(self, other: BaseFlagger, subset: Optional[List] = None, join: str = "merge", inplace=False) -> BaseFlagger:
"""
Merge the given flagger 'other' into self
"""
......@@ -92,13 +94,13 @@ class BaseFlagger(ABC):
else:
return self.copy(flags=mergeDios(self._flags, other._flags, subset=subset, join=join))
def slice(self, field: FieldsT = None, loc: LocT = None, drop: FieldsT = None, inplace=False) -> BaseFlaggerT:
def slice(self, field: FieldsT = None, loc: LocT = None, drop: FieldsT = None, inplace=False) -> BaseFlagger:
""" Return a potentially trimmed down copy of self. """
if drop is not None:
if field is not None:
raise TypeError("either 'field' or 'drop' can be given, but not both")
field = self._flags.columns.drop(drop, errors="ignore")
flags = self.getFlags(field=field, loc=loc).to_dios()
flags = self.getFlags(field=field, loc=loc).to_dios() # type: ignore
if inplace:
self._flags = flags
......@@ -120,7 +122,7 @@ class BaseFlagger(ABC):
"""
return self._flags.to_df()
def getFlags(self, field: FieldsT = None, loc: LocT = None, full=False):
def getFlags(self, field: FieldsT = None, loc: LocT = None, full=False) -> Union[PandasT, Tuple[DictOfSeries, Dict[str, PandasT]]]:
""" Return a version of flags, potentially trimmed down to `loc`.
Parameters
......@@ -155,13 +157,12 @@ class BaseFlagger(ABC):
# loc should be a valid 2D-indexer and
# then field must be None. Otherwise aloc
# will fail and throw the correct Error.
if isinstance(loc, diosT) and field is None:
if isinstance(loc, DictOfSeries) and field is None:
indexer = loc
else:
loc = slice(None) if loc is None else loc
field = slice(None) if field is None else self._check_field(field)
indexer = (loc, field)
rows = slice(None) if loc is None else loc
cols = slice(None) if field is None else self._check_field(field)
indexer = (rows, cols)
# This is a bug in `dios.aloc`, which may return a shallow-copied dios if `slice(None)` is passed
# as row indexer. This is because pandas `.loc` returns a shallow copy if a null-slice is passed to a series.
......@@ -183,7 +184,7 @@ class BaseFlagger(ABC):
flag_before: Union[str, int] = None,
win_flag: FlagT = None,
**kwargs
) -> BaseFlaggerT:
) -> BaseFlagger:
"""Overwrite existing flags at loc.
If `force=False` (default) only flags with a lower priority are overwritten,
......@@ -212,7 +213,7 @@ class BaseFlagger(ABC):
trimmed = self.getFlags(field=field, loc=loc)
if force:
mask = pd.Series(True, index=trimmed.index, dtype=bool)
mask = pd.Series(True, index=trimmed.index, dtype=bool) # type: ignore / `trimmed` is asserted to be a Series
else:
mask = trimmed < flag
......@@ -298,7 +299,7 @@ class BaseFlagger(ABC):
return mask, win_flag
def clearFlags(self, field: str, loc: LocT = None, inplace=False, **kwargs) -> BaseFlaggerT:
def clearFlags(self, field: str, loc: LocT = None, inplace: bool=False, **kwargs) -> BaseFlagger:
assertScalar("field", field, optional=False)
if "force" in kwargs:
raise ValueError("Keyword 'force' is not allowed here.")
......@@ -306,7 +307,7 @@ class BaseFlagger(ABC):
raise ValueError("Keyword 'flag' is not allowed here.")
return self.setFlags(field=field, loc=loc, flag=self.UNFLAGGED, force=True, inplace=inplace, **kwargs)
def isFlagged(self, field=None, loc: LocT = None, flag: FlagT = None, comparator: str = ">") -> PandasT:
def isFlagged(self, field=None, loc: LocT = None, flag: FlagT = None, comparator: str=">") -> PandasT:
"""
Returns boolean data that indicate where data has been flagged.
......@@ -339,12 +340,12 @@ class BaseFlagger(ABC):
raise TypeError("flag: pd.Series is not allowed")
flags_to_compare = set(toSequence(flag, self.GOOD))
flags = self.getFlags(field, loc)
flags = self.getFlags(field, loc, full=False)
cp = COMPARATOR_MAP[comparator]
# notna() to prevent NaNs from becoming True,
# e.g.: `np.nan != 0 -> True`
flagged = flags.notna()
flagged = flags.notna() # type: ignore / we asserted, that flags is of `PandasT`
# passing an empty list must result
# in everywhere-False data
......@@ -358,7 +359,7 @@ class BaseFlagger(ABC):
return flagged
def copy(self, flags=None) -> BaseFlaggerT:
def copy(self, flags: Optional[PandasT]=None) -> BaseFlagger:
if flags is None:
out = deepcopy(self)
else:
......@@ -391,7 +392,7 @@ class BaseFlagger(ABC):
# version of it.
return flag == self.BAD or flag == self.GOOD or flag == self.UNFLAGGED or self.isSUSPICIOUS(flag)
def replaceField(self, field, flags, inplace=False, **kwargs):
def replaceField(self, field: str, flags: Optional[pd.Series], inplace: bool=False, **kwargs) -> BaseFlagger:
""" Replace or delete all data for a given field.
Parameters
......@@ -436,7 +437,7 @@ class BaseFlagger(ABC):
out._flags[field] = flags.astype(self.dtype)
return out
def _check_field(self, field):
def _check_field(self, field: Union[str, Sequence[str]]) -> Union[str, Sequence[str]]:
""" Check if (all) field(s) are in self._flags. """
# wait for outcome of
......
......@@ -6,7 +6,6 @@ from collections import OrderedDict
import pandas as pd
from saqc.flagger.baseflagger import BaseFlagger
from saqc.lib.tools import assertDictOfSeries
class Flags(pd.CategoricalDtype):
......
#! /usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import annotations
import json
from copy import deepcopy
......@@ -7,11 +8,10 @@ from typing import TypeVar, Optional, List
import pandas as pd
import dios
from dios import DictOfSeries
from saqc.flagger.baseflagger import diosT
from saqc.flagger.categoricalflagger import CategoricalFlagger
from saqc.lib.tools import assertScalar, mergeDios, mutateIndex
from saqc.lib.tools import assertScalar, mergeDios
DmpFlaggerT = TypeVar("DmpFlaggerT")
......@@ -63,7 +63,7 @@ class DmpFlagger(CategoricalFlagger):
out = out.reorder_levels(order=[1, 0], axis=1).sort_index(axis=1, level=0, sort_remaining=False)
return out
def initFlags(self, data: dios.DictOfSeries = None, flags: dios.DictOfSeries = None):
def initFlags(self, data: DictOfSeries = None, flags: DictOfSeries = None):
"""
initialize a flagger based on the given 'data' or 'flags'
if 'data' is not None: return a flagger with flagger.UNFLAGGED values
......@@ -84,7 +84,7 @@ class DmpFlagger(CategoricalFlagger):
newflagger._comments = self._comments.aloc[flags, ...]
return newflagger
def merge(self, other: DmpFlaggerT, subset: Optional[List] = None, join: str = "merge", inplace=False):
def merge(self, other: DmpFlagger, subset: Optional[List] = None, join: str = "merge", inplace=False):
assert isinstance(other, DmpFlagger)
flags = mergeDios(self._flags, other._flags, subset=subset, join=join)
causes = mergeDios(self._causes, other._causes, subset=subset, join=join)
......@@ -101,7 +101,7 @@ class DmpFlagger(CategoricalFlagger):
# loc should be a valid 2D-indexer and
# then field must be None. Otherwise aloc
# will fail and throw the correct Error.
if isinstance(loc, diosT) and field is None:
if isinstance(loc, DictOfSeries) and field is None:
indexer = loc
else:
loc = slice(None) if loc is None else loc
......@@ -235,7 +235,7 @@ class DmpFlagger(CategoricalFlagger):
out._comments[field] = comments.astype(str)
return out
def _construct_new(self, flags, causes, comments) -> DmpFlaggerT:
def _construct_new(self, flags, causes, comments) -> DmpFlagger:
new = DmpFlagger()
new._global_comments = self._global_comments
new._flags = flags
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment