From c498f0b4f125832606ca6a1cca71d90c484598a3 Mon Sep 17 00:00:00 2001
From: Bert Palm <bert.palm@ufz.de>
Date: Fri, 26 Feb 2021 14:34:45 +0100
Subject: [PATCH] init core adjusted

---
 saqc/core/core.py          | 62 ++++++++++++++++++--------------------
 saqc/flagger/__init__.py   |  2 ++
 saqc/flagger/flags.py      |  4 +++
 test/core/test_core_new.py | 20 ++++++++++++
 4 files changed, 55 insertions(+), 33 deletions(-)
 create mode 100644 test/core/test_core_new.py

diff --git a/saqc/core/core.py b/saqc/core/core.py
index 8a8b6283c..df90d80a0 100644
--- a/saqc/core/core.py
+++ b/saqc/core/core.py
@@ -10,7 +10,7 @@ TODOS:
 
 import logging
 import copy as stdcopy
-from typing import List, Tuple, Sequence
+from typing import List, Tuple, Sequence, Union
 from typing_extensions import Literal
 
 import pandas as pd
@@ -19,7 +19,8 @@ import numpy as np
 import timeit
 import inspect
 
-from saqc.flagger import BaseFlagger, CategoricalFlagger, SimpleFlagger, DmpFlagger
+from saqc.common import *
+from saqc.flagger.flags import init_flags_like, Flagger
 from saqc.core.lib import APIController, ColumnSelector
 from saqc.core.register import FUNC_MAP, SaQCFunction
 from saqc.core.modules import FuncModules
@@ -49,7 +50,8 @@ def _handleErrors(exc: Exception, field: str, control: APIController, func: SaQC
         raise exc
 
 
-def _prepInput(flagger, data, flags):
+# TODO: shouldn't this go to SaQC.__init__?
+def _prepInput(data, flags):
     dios_like = (dios.DictOfSeries, pd.DataFrame)
 
     if isinstance(data, pd.Series):
@@ -66,30 +68,23 @@ def _prepInput(flagger, data, flags):
     if not hasattr(data.columns, "str"):
         raise TypeError("expected dataframe columns of type string")
 
-    if not isinstance(flagger, BaseFlagger):
-        # NOTE: we should generate that list automatically,
-        #       it won't ever be complete otherwise
-        flaggerlist = [CategoricalFlagger, SimpleFlagger, DmpFlagger]
-        raise TypeError(f"'flagger' must be of type {flaggerlist} or a subclass of {BaseFlagger}")
-
     if flags is not None:
-        if not isinstance(flags, dios_like):
-            raise TypeError("'flags' must be of type dios.DictOfSeries or pd.DataFrame")
 
         if isinstance(flags, pd.DataFrame):
             if isinstance(flags.index, pd.MultiIndex) or isinstance(flags.columns, pd.MultiIndex):
                 raise TypeError("'flags' should not use MultiIndex")
-            flags = dios.to_dios(flags)
 
-        # NOTE: do not test all columns as they not necessarily need to be the same
-        cols = flags.columns & data.columns
-        if not (flags[cols].lengths == data[cols].lengths).all():
-            raise ValueError("the length of 'flags' and 'data' need to be equal")
+        if isinstance(flags, (dios.DictOfSeries, pd.DataFrame, Flagger)):
+            # NOTE: only test common columns, data as well as flags could
+            # have more columns than the respective other.
+            cols = flags.columns & data.columns
+            for c in cols:
+                if not flags[c].index.equals(data[c].index):
+                    raise ValueError(f"the index of 'flags' and 'data' missmatch in column {c}")
 
-    if flagger.initialized:
-        diff = data.columns.difference(flagger.getFlags().columns)
-        if not diff.empty:
-            raise ValueError("Missing columns in 'flagger': '{list(diff)}'")
+        # this also ensures float dtype
+        if not isinstance(flags, Flagger):
+            flags = Flagger(flags, copy=True)
 
     return data, flags
 
@@ -110,31 +105,32 @@ _setup()
 
 class SaQC(FuncModules):
 
-    def __init__(self, flagger, data, flags=None, nodata=np.nan, to_mask=None, error_policy="raise"):
+    def __init__(self, data, flags=None, nodata=np.nan, to_mask=None, error_policy="raise"):
         super().__init__(self)
-        data, flags = _prepInput(flagger, data, flags)
+        data, flagger = _prepInput(data, flags)  # flagger: None, or `flags` as a Flagger
         self._data = data
         self._nodata = nodata
         self._to_mask = to_mask
-        self._flagger = self._initFlagger(data, flagger, flags)
+        self._flagger = self._initFlagger(data, flagger)  # use the prepared flagger, not the raw `flags`
         self._error_policy = error_policy
         # NOTE: will be filled by calls to `_wrap`
         self._to_call: List[Tuple[ColumnSelector, APIController, SaQCFunction]] = []
 
-    def _initFlagger(self, data, flagger, flags):
+    def _initFlagger(self, data, flagger: Union[Flagger, None]):
         """ Init the internal flagger object.
 
         Ensures that all data columns are present and user passed flags from
-        a flags frame and/or an already initialised flagger are used.
-        If columns overlap the passed flagger object is prioritised.
+        a flags frame or an already initialised flagger are used.
         """
-        # ensure all data columns
-        merged = flagger.initFlags(data)
-        if flags is not None:
-            merged = merged.merge(flagger.initFlags(flags=flags), inplace=True)
-        if flagger.initialized:
-            merged = merged.merge(flagger, inplace=True)
-        return merged
+        if flagger is None:  # no flags were passed: initialise them from data
+            return init_flags_like(data)
+
+        for c in flagger.columns.union(data.columns):
+            if c in flagger:  # column already has flags, leave them untouched
+                continue
+            if c in data:  # expected to always hold here: c is in the union but not in flagger
+                flagger[c] = pd.Series(UNFLAGGED, index=data[c].index, dtype=float)
+        return flagger
 
     def readConfig(self, fname):
         from saqc.core.reader import readConfig
diff --git a/saqc/flagger/__init__.py b/saqc/flagger/__init__.py
index d5124fb9d..774f2ec2b 100644
--- a/saqc/flagger/__init__.py
+++ b/saqc/flagger/__init__.py
@@ -1,6 +1,8 @@
 #! /usr/bin/env python
 # -*- coding: utf-8 -*-
 
+from .flags import Flagger
+from .history import History
 from saqc.flagger.baseflagger import BaseFlagger
 from saqc.flagger.categoricalflagger import CategoricalFlagger
 from saqc.flagger.simpleflagger import SimpleFlagger
diff --git a/saqc/flagger/flags.py b/saqc/flagger/flags.py
index bf64ec556..15b8a4efc 100644
--- a/saqc/flagger/flags.py
+++ b/saqc/flagger/flags.py
@@ -311,3 +311,7 @@ def init_flags_like(reference: Union[pd.Series, DictLike, Flags], initial_value:
 
     return Flags(result)
 
+
+# NOTE: for now we keep the name `Flagger` as an alias of `Flags`
+Flagger = Flags
+
diff --git a/test/core/test_core_new.py b/test/core/test_core_new.py
new file mode 100644
index 000000000..b16714c8e
--- /dev/null
+++ b/test/core/test_core_new.py
@@ -0,0 +1,20 @@
+#!/usr/bin/env python
+
+import pandas as pd
+import numpy as np
+import dios
+
+
+def test_init():
+    from saqc import SaQC, Flagger
+
+    arr = np.array([
+        [0, 1, 2],
+        [0, 1, 3],
+    ])
+    data = pd.DataFrame(arr, columns=list('abc'))  # 2 rows, columns 'a', 'b', 'c'
+    qc = SaQC(data)  # only data is passed, no flags
+
+    assert isinstance(qc, SaQC)
+    assert isinstance(qc._flagger, Flagger)  # a flagger must exist even without user flags
+    assert isinstance(qc._data, dios.DictOfSeries)  # data is expected to be stored as dios
-- 
GitLab