Skip to content
Snippets Groups Projects
Commit 0a14d0b5 authored by Bert Palm's avatar Bert Palm 🎇
Browse files

added copy, method doku, force method, restructured class, added tests

parent 3737e639
No related branches found
No related tags found
1 merge request!218Flags
......@@ -7,7 +7,7 @@ import dios
from saqc.flagger.history import History
import numpy as np
import pandas as pd
from typing import Union, Dict, DefaultDict, Iterable, Tuple, Optional
from typing import Union, Dict, DefaultDict, Iterable, Tuple, Optional, Type
UNTOUCHED = np.nan
UNFLAGGED = 0
......@@ -101,6 +101,11 @@ class Flags:
if k in result:
raise ValueError('raw_data must not have duplicate keys')
# No, means no ! (copy)
if isinstance(item, History) and not copy:
result[k] = item
continue
if isinstance(item, pd.Series):
item = item.to_frame(name=0)
elif isinstance(item, History):
......@@ -112,30 +117,12 @@ class Flags:
return result
def __getitem__(self, key: str) -> pd.Series:
if key not in self._cache:
self._cache[key] = self._data[key].max()
return self._cache[key].copy()
def __setitem__(self, key: str, value: pd.Series):
if key not in self._data:
hist = History()
else:
hist = self._data[key]
hist.append(value)
self._cache.pop(key, None)
def __delitem__(self, key):
del self._data[key]
self._cache.pop(key, None)
@property
def _constructor(self) -> Type['Flags']:
return Flags
def drop(self, key):
self.__delitem__(key)
# ----------------------------------------------------------------------
# mata data
@property
def columns(self) -> pd.Index:
......@@ -166,10 +153,104 @@ class Flags:
self._data = _data
self._cache = _cache
@property
def empty(self) -> bool:
return len(self._data) == 0
def __len__(self) -> int:
return len(self._data)
# ----------------------------------------------------------------------
# item access
def __getitem__(self, key: str) -> pd.Series:
if key not in self._cache:
self._cache[key] = self._data[key].max()
return self._cache[key].copy()
def __setitem__(self, key: str, value: pd.Series, force=False):
# force is internal available only
if key not in self._data:
hist = History()
else:
hist = self._data[key]
hist.append(value, force=force)
self._cache.pop(key, None)
def force(self, key: str, value: pd.Series) -> Flags:
"""
Overwrite existing flags, regardless if they are better
or worse than the existing flags.
Parameters
----------
key : str
column name
value : pandas.Series
A series of float flags to force
Returns
-------
Flags
the same flags object with altered flags, no copy
"""
self.__setitem__(key, value, force=True)
return self
def __delitem__(self, key):
del self._data[key]
self._cache.pop(key, None)
def drop(self, key: str):
"""
Delete a flags column.
Parameters
----------
key : str
column name
Returns
-------
Flags
the same flags object with dropeed column, no copy
"""
self.__delitem__(key)
# ----------------------------------------------------------------------
# accessor
@property
def history(self) -> _HistAccess:
return _HistAccess(self)
# ----------------------------------------------------------------------
# copy
def copy(self, deep=True):
return self._constructor(self, copy=deep)
def __copy__(self, deep=True):
return self.copy(deep=deep)
def __deepcopy__(self, memo=None):
"""
Parameters
----------
memo, default None
Standard signature. Unused
"""
return self.copy(deep=True)
# ----------------------------------------------------------------------
# transformation and representation
def to_dios(self) -> dios.DictOfSeries:
di = dios.DictOfSeries(columns=self.columns)
......@@ -181,13 +262,6 @@ class Flags:
def to_frame(self) -> pd.DataFrame:
return self.to_dios().to_df()
@property
def empty(self) -> bool:
return len(self._data) == 0
def __len__(self) -> int:
return len(self._data)
def __repr__(self) -> str:
return str(self.to_dios()).replace('DictOfSeries', type(self).__name__)
......
......@@ -5,6 +5,10 @@ import numpy as np
import pandas as pd
from pandas.api.types import is_bool_dtype
from test.common import TESTFLAGGER, initData
from test.flagger.test_history import (
History,
is_equal as hist_equal,
)
from saqc.flagger.flags import Flags
_data = [
......@@ -47,6 +51,115 @@ def test_init(data: np.array):
assert len(data.keys()) == len(flags)
def is_equal(f1, f2):
assert f1.columns.equals(f2.columns)
for c in f1.columns:
assert hist_equal(f1.history[c], f2.history[c])
@pytest.mark.parametrize('data', data)
def test_copy(data: np.array):
flags = Flags(data)
shallow = flags.copy(deep=False)
deep = flags.copy(deep=True)
# checks
for copy in [deep, shallow]:
assert isinstance(copy, Flags)
assert copy is not flags
assert copy._data is not flags._data
is_equal(copy, flags)
assert deep is not shallow
is_equal(deep, shallow)
for c in shallow.columns:
assert shallow._data[c] is flags._data[c]
for c in deep.columns:
assert deep._data[c] is not flags._data[c]
@pytest.mark.parametrize('data', data)
def test_flags_history(data: np.array):
flags = Flags(data)
# get
for c in flags.columns:
hist = flags.history[c]
assert isinstance(hist, History)
assert len(hist) > 0
# set
for c in flags.columns:
hist = flags.history[c]
hlen = len(hist)
hist.append(pd.Series(888., index=hist.index, dtype=float))
flags.history[c] = hist
assert isinstance(hist, History)
assert len(hist) == hlen + 1
@pytest.mark.parametrize('data', data)
def test_get_flags(data: np.array):
flags = Flags(data)
for c in flags.columns:
# check obvious
var = flags[c]
assert isinstance(var, pd.Series)
assert not var.empty
assert var.equals(flags._data[c].max())
# always a copy
assert var is not flags[c]
# in particular, a deep copy
var[:] = 9999.
assert all(flags[c] != var)
@pytest.mark.parametrize('data', data)
def test_set_flags_and_force(data: np.array):
flags = Flags(data)
for c in flags.columns:
var = flags[c]
hlen = len(flags.history[c])
new = pd.Series(9999., index=var.index, dtype=float)
flags[c] = new
assert len(flags.history[c]) == hlen + 1
assert all(flags.history[c].max() == 9999.)
assert all(flags.history[c].max() == flags[c])
# check if deep-copied correctly
new[:] = 8888.
assert all(flags.history[c].max() == 9999.)
# no overwrite if flag-values are not worse
flags[c] = new
assert len(flags.history[c]) == hlen + 2
assert all(flags.history[c].max() == 9999.)
assert all(flags.history[c].max() == flags[c])
# but overwrite with force
flags.force(c, new)
assert len(flags.history[c]) == hlen + 3
assert all(flags.history[c].max() == 8888.)
assert all(flags.history[c].max() == flags[c])
# check if deep-copied correctly
new[:] = 7777.
assert all(flags.history[c].max() == 8888.)
@pytest.mark.parametrize('data', data)
def test_force_flags(data: np.array):
pass
def test_cache():
arr = np.array([
[0, 0, 0, 0],
......@@ -77,3 +190,28 @@ def test_cache():
for c in flags.columns:
assert c in flags._cache
def _validate_flags_equals_frame(flags, df):
assert df.columns.equals(flags.columns)
for c in flags.columns:
assert df[c].index.equals(flags[c].index)
assert df[c].equals(flags[c]) # respects nan's
@pytest.mark.parametrize('data', data)
def test_to_dios(data: np.array):
flags = Flags(data)
df = flags.to_dios()
assert isinstance(df, dios.DictOfSeries)
_validate_flags_equals_frame(flags, df)
@pytest.mark.parametrize('data', data)
def test_to_frame(data: np.array):
flags = Flags(data)
df = flags.to_frame()
assert isinstance(df, pd.DataFrame)
_validate_flags_equals_frame(flags, df)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment