From 0a14d0b565fb825cc5e65823dd8a208130f5be96 Mon Sep 17 00:00:00 2001
From: Bert Palm <bert.palm@ufz.de>
Date: Thu, 11 Feb 2021 21:02:22 +0100
Subject: [PATCH] added copy, method doku, force method, restructured class,
 added tests

---
 saqc/flagger/flags.py      | 136 +++++++++++++++++++++++++++---------
 test/flagger/test_flags.py | 138 +++++++++++++++++++++++++++++++++++++
 2 files changed, 243 insertions(+), 31 deletions(-)

diff --git a/saqc/flagger/flags.py b/saqc/flagger/flags.py
index 1b75d08e6..21ef9d698 100644
--- a/saqc/flagger/flags.py
+++ b/saqc/flagger/flags.py
@@ -7,7 +7,7 @@ import dios
 from saqc.flagger.history import History
 import numpy as np
 import pandas as pd
-from typing import Union, Dict, DefaultDict, Iterable, Tuple, Optional
+from typing import Union, Dict, DefaultDict, Iterable, Tuple, Optional, Type
 
 UNTOUCHED = np.nan
 UNFLAGGED = 0
@@ -101,6 +101,11 @@ class Flags:
             if k in result:
                 raise ValueError('raw_data must not have duplicate keys')
 
+            # No, means no ! (copy)
+            if isinstance(item, History) and not copy:
+                result[k] = item
+                continue
+
             if isinstance(item, pd.Series):
                 item = item.to_frame(name=0)
             elif isinstance(item, History):
@@ -112,30 +117,12 @@ class Flags:
 
         return result
 
-    def __getitem__(self, key: str) -> pd.Series:
-
-        if key not in self._cache:
-            self._cache[key] = self._data[key].max()
-
-        return self._cache[key].copy()
-
-    def __setitem__(self, key: str, value: pd.Series):
-
-        if key not in self._data:
-            hist = History()
-
-        else:
-            hist = self._data[key]
-
-        hist.append(value)
-        self._cache.pop(key, None)
-
-    def __delitem__(self, key):
-        del self._data[key]
-        self._cache.pop(key, None)
+    @property
+    def _constructor(self) -> Type['Flags']:
+        return Flags
 
-    def drop(self, key):
-        self.__delitem__(key)
+    # ----------------------------------------------------------------------
+    # mata data
 
     @property
     def columns(self) -> pd.Index:
@@ -166,10 +153,104 @@ class Flags:
         self._data = _data
         self._cache = _cache
 
+    @property
+    def empty(self) -> bool:
+        return len(self._data) == 0
+
+    def __len__(self) -> int:
+        return len(self._data)
+
+    # ----------------------------------------------------------------------
+    # item access
+
+    def __getitem__(self, key: str) -> pd.Series:
+
+        if key not in self._cache:
+            self._cache[key] = self._data[key].max()
+
+        return self._cache[key].copy()
+
+    def __setitem__(self, key: str, value: pd.Series, force=False):
+        # force is internal available only
+
+        if key not in self._data:
+            hist = History()
+
+        else:
+            hist = self._data[key]
+
+        hist.append(value, force=force)
+        self._cache.pop(key, None)
+
+    def force(self, key: str, value: pd.Series) -> Flags:
+        """
+        Overwrite existing flags, regardless if they are better
+        or worse than the existing flags.
+
+        Parameters
+        ----------
+        key : str
+            column name
+
+        value : pandas.Series
+            A series of float flags to force
+
+        Returns
+        -------
+        Flags
+            the same flags object with altered flags, no copy
+        """
+        self.__setitem__(key, value, force=True)
+        return self
+
+    def __delitem__(self, key):
+        del self._data[key]
+        self._cache.pop(key, None)
+
+    def drop(self, key: str):
+        """
+        Delete a flags column.
+
+        Parameters
+        ----------
+        key : str
+            column name
+
+        Returns
+        -------
+        Flags
+            the same flags object with dropeed column, no copy
+        """
+        self.__delitem__(key)
+
+    # ----------------------------------------------------------------------
+    # accessor
+
     @property
     def history(self) -> _HistAccess:
         return _HistAccess(self)
 
+    # ----------------------------------------------------------------------
+    # copy
+
+    def copy(self, deep=True):
+        return self._constructor(self, copy=deep)
+
+    def __copy__(self, deep=True):
+        return self.copy(deep=deep)
+
+    def __deepcopy__(self, memo=None):
+        """
+        Parameters
+        ----------
+        memo, default None
+            Standard signature. Unused
+        """
+        return self.copy(deep=True)
+
+    # ----------------------------------------------------------------------
+    # transformation and representation
+
     def to_dios(self) -> dios.DictOfSeries:
         di = dios.DictOfSeries(columns=self.columns)
 
@@ -181,13 +262,6 @@ class Flags:
     def to_frame(self) -> pd.DataFrame:
         return self.to_dios().to_df()
 
-    @property
-    def empty(self) -> bool:
-        return len(self._data) == 0
-
-    def __len__(self) -> int:
-        return len(self._data)
-
     def __repr__(self) -> str:
         return str(self.to_dios()).replace('DictOfSeries', type(self).__name__)
 
diff --git a/test/flagger/test_flags.py b/test/flagger/test_flags.py
index 7dcfc97dc..999f3f29d 100644
--- a/test/flagger/test_flags.py
+++ b/test/flagger/test_flags.py
@@ -5,6 +5,10 @@ import numpy as np
 import pandas as pd
 from pandas.api.types import is_bool_dtype
 from test.common import TESTFLAGGER, initData
+from test.flagger.test_history import (
+    History,
+    is_equal as hist_equal,
+)
 from saqc.flagger.flags import Flags
 
 _data = [
@@ -47,6 +51,115 @@ def test_init(data: np.array):
     assert len(data.keys()) == len(flags)
 
 
+def is_equal(f1, f2):
+    assert f1.columns.equals(f2.columns)
+    for c in f1.columns:
+        assert hist_equal(f1.history[c], f2.history[c])
+
+
+@pytest.mark.parametrize('data', data)
+def test_copy(data: np.array):
+    flags = Flags(data)
+    shallow = flags.copy(deep=False)
+    deep = flags.copy(deep=True)
+
+    # checks
+
+    for copy in [deep, shallow]:
+        assert isinstance(copy, Flags)
+        assert copy is not flags
+        assert copy._data is not flags._data
+        is_equal(copy, flags)
+
+    assert deep is not shallow
+    is_equal(deep, shallow)
+
+    for c in shallow.columns:
+        assert shallow._data[c] is flags._data[c]
+
+    for c in deep.columns:
+        assert deep._data[c] is not flags._data[c]
+
+
+@pytest.mark.parametrize('data', data)
+def test_flags_history(data: np.array):
+    flags = Flags(data)
+
+    # get
+    for c in flags.columns:
+        hist = flags.history[c]
+        assert isinstance(hist, History)
+        assert len(hist) > 0
+
+    # set
+    for c in flags.columns:
+        hist = flags.history[c]
+        hlen = len(hist)
+        hist.append(pd.Series(888., index=hist.index, dtype=float))
+        flags.history[c] = hist
+        assert isinstance(hist, History)
+        assert len(hist) == hlen + 1
+
+
+@pytest.mark.parametrize('data', data)
+def test_get_flags(data: np.array):
+    flags = Flags(data)
+
+    for c in flags.columns:
+        # check obvious
+        var = flags[c]
+        assert isinstance(var, pd.Series)
+        assert not var.empty
+        assert var.equals(flags._data[c].max())
+
+        # always a copy
+        assert var is not flags[c]
+
+        # in particular, a deep copy
+        var[:] = 9999.
+        assert all(flags[c] != var)
+
+
+@pytest.mark.parametrize('data', data)
+def test_set_flags_and_force(data: np.array):
+    flags = Flags(data)
+
+    for c in flags.columns:
+        var = flags[c]
+        hlen = len(flags.history[c])
+        new = pd.Series(9999., index=var.index, dtype=float)
+
+        flags[c] = new
+        assert len(flags.history[c]) == hlen + 1
+        assert all(flags.history[c].max() == 9999.)
+        assert all(flags.history[c].max() == flags[c])
+
+        # check if deep-copied correctly
+        new[:] = 8888.
+        assert all(flags.history[c].max() == 9999.)
+
+        # no overwrite if flag-values are not worse
+        flags[c] = new
+        assert len(flags.history[c]) == hlen + 2
+        assert all(flags.history[c].max() == 9999.)
+        assert all(flags.history[c].max() == flags[c])
+
+        # but overwrite with force
+        flags.force(c, new)
+        assert len(flags.history[c]) == hlen + 3
+        assert all(flags.history[c].max() == 8888.)
+        assert all(flags.history[c].max() == flags[c])
+
+        # check if deep-copied correctly
+        new[:] = 7777.
+        assert all(flags.history[c].max() == 8888.)
+
+
+@pytest.mark.parametrize('data', data)
+def test_force_flags(data: np.array):
+    pass
+
+
 def test_cache():
     arr = np.array([
         [0, 0, 0, 0],
@@ -77,3 +190,28 @@ def test_cache():
     for c in flags.columns:
         assert c in flags._cache
 
+
+def _validate_flags_equals_frame(flags, df):
+    assert df.columns.equals(flags.columns)
+
+    for c in flags.columns:
+        assert df[c].index.equals(flags[c].index)
+        assert df[c].equals(flags[c])  # respects nan's
+
+
+@pytest.mark.parametrize('data', data)
+def test_to_dios(data: np.array):
+    flags = Flags(data)
+    df = flags.to_dios()
+
+    assert isinstance(df, dios.DictOfSeries)
+    _validate_flags_equals_frame(flags, df)
+
+
+@pytest.mark.parametrize('data', data)
+def test_to_frame(data: np.array):
+    flags = Flags(data)
+    df = flags.to_frame()
+
+    assert isinstance(df, pd.DataFrame)
+    _validate_flags_equals_frame(flags, df)
-- 
GitLab