From 3b2264432377b195661e474839faafc4cd1ee0dc Mon Sep 17 00:00:00 2001
From: Bert Palm <bert.palm@ufz.de>
Date: Sun, 16 Feb 2020 02:57:48 +0100
Subject: [PATCH] finest setitem

---
 dios/dios.py    | 155 ++++++++++++++++++++++++++----------------------
 dios/options.py |  31 ++++------
 2 files changed, 95 insertions(+), 91 deletions(-)

diff --git a/dios/dios.py b/dios/dios.py
index e2bf03f..2d2d2dc 100644
--- a/dios/dios.py
+++ b/dios/dios.py
@@ -24,16 +24,20 @@ from pandas.core.dtypes.common import is_iterator as _is_iterator
 def is_dios_like(obj):
     return isinstance(obj, DictOfSeries)
 
+
 def is_pandas_like(obj):
     """We consider ourselfs (dios) as pandas-like"""
     return is_series_like(obj) or is_dataframe_like(obj) or is_dios_like(obj)
 
+
 def is_series_like(obj):
     return isinstance(obj, pd.Series)
 
+
 def is_dataframe_like(obj):
     return isinstance(obj, pd.DataFrame)
 
+
 def is_iterator(obj):
     """ This is only a dummy wrapper, to warn that the docu of this isnt't right.
     Unlike the example says,
@@ -235,29 +239,39 @@ class DictOfSeries:
             in the ``options`` dictionary.
           - [3] If ``iterable`` contains any(!) label that does not exist, a KeyError is raised.
         """
+        # special case: insert a fresh new key
+        if isinstance(key, str) and key not in self.columns:
+            self._insert(key, value)
+            return
 
-        # prepare
-        if is_iterator(key):
-            key = list(key)
+        k, i = self._get_keys_and_indexer(key)
+        gen = self._setitem_stage2(k,i, value)
+        for tup in gen:
+            self._set_item(*tup)
 
+    def _get_keys_and_indexer(self, key):
+        """ Determine keys and indexer
+        Notes:
+            Which keys we get, depends on the policy in dios_options
+        """
+
+        err_bool = "only boolen values are allowed"
         keys = None
         indexers = None
+        blowup = False
 
-        ki = dict()
-        # determine action by keys
+        # prevent consuming of a generator
+        if is_iterator(key):
+            key = list(key)
 
         if isinstance(key, str):
-            # special case: insert a fresh new key
             if key not in self.columns:
-                self._setitem_new(key, value)
-                return
-            else:
-                ki[key] = slice(None)
-                keys = [key]
+                raise KeyError(key)
+            keys = [key]
 
         elif isinstance(key, slice):
             keys = self.columns
-            indexers = [key]
+            indexers, blowup = [key], True
 
         # list, np.arrays, ... of list, np.arrays..
         elif is_nested_list_like(key):
@@ -267,59 +281,81 @@ class DictOfSeries:
             for i in range(len(key)):
                 arr = np.array(i)
                 if not is_bool_array(arr):
-                    raise ValueError("Must pass nested-list-like with boolean values only")
+                    raise ValueError("nested list: " + err_bool)
                 indexers.append(arr)
 
         # ser, df, dios
         elif is_pandas_like(key):
             if is_series_like(key):
                 keys = key.to_list()
+
             elif is_dataframe_like(key):
                 keys = key.columns.to_list()
                 indexers = key.values
-                testbool = True
+                if not is_bool_array(indexers):
+                    raise ValueError("df: " + err_bool)
+
             elif is_dios_like(key):
                 keys = key.columns
                 indexers = list(key.values)
+                if not is_bool_array(indexers):
+                    raise ValueError("dios: " + err_bool)
 
         # list, np.array, np.ndarray, ...
+        # Note: series considered list-like,
+        # so we handle lists last
         elif is_list_like(key):
             arr = np.array(key)
             if is_bool_array(arr):
                 keys = self.columns
-                indexers = [arr]
+                if len(arr) != len(keys):
+                    keys = np.array(keys)[arr]
             else:
                 keys = key
 
-
         else:
             raise KeyError(f"{key}")
 
-        if length != len(keys):
-            raise ValueError(f"Length mismatch for nested list: expected {len(keys)}, got {length}")
-        if not indexers:
-            indexers = [slice(None)]
-        if len(indexers) == 1:
-            indexers = indexers * len(keys)
-
-        assert len(indexers) == len(keys)
-        # now we have a indexer for every series
-        # determine action by value
-
-        if isinstance(value, DictOfSeries):
-            method = dios_options[Options.dios_to_dios_method]
-            keys = get_dios_to_dios_keys(keys, value, method)
-            for k in keys:
-                self._setitem(k, value[k], sl=kslicer)
-        else:
-
-            if is_iterator(value):
-                value = list(value)
+        # check keys
+        method = dios_options[Options.dios_to_dios_method]
+        keys = check_keys_by_policy(keys, self.columns, method)
 
-            for k in keys:
-                self._setitem(k, value, sl=kslicer)
+        # check indexer
+        if indexers is None:
+            indexers, blowup = [slice(None)], True
+        if blowup:
+            indexers = indexers * len(keys)
+        if len(indexers) != len(keys):
+            raise ValueError
+
+        # now we have a valid indexer (a slice or a bool array) for every series
+        return keys, indexers
+
+    def _setitem_stage2(self, keys, ixs, val):
+        "determine looping and .."
+
+        if is_iterator(val):
+            val = list(val)
+
+        diosl, dfl, nlistl = is_dios_like(val), is_dataframe_like(val), is_nested_list_like(val)
+        if diosl or dfl or nlistl and len(val) != len(keys):
+            raise ValueError(f"could not broadcast input array with length {len(val)}"
+                             f" into dios of length {len(keys)}")
+
+        # now we have everything we need: key, indexer, value
+        # so we just pack it nice and cosy and let setitem
+        # do the dirty work.
+        for i, _ in enumerate(keys):
+            key, ix = keys[i], ixs[i]
+            if dfl or diosl:
+                yield key, ix, val[val.columns[i]]
+            elif nlistl:
+                yield key, ix, val[i]
+            else:
+                yield key, ix, val
 
-    def _setitem_new(self, key, val):
+    def _insert(self, key, val):
+        """"""
         if isinstance(val, DictOfSeries):
             val = val.squeeze()
         elif is_list_like(val) and not is_nested_list_like(val):
@@ -331,39 +367,16 @@ class DictOfSeries:
         val = cast_to_itype(val, self._itype, policy=self._policy)
         self._data[key] = val.copy(deep=True)
 
-    def _setitem(self, key, val, sl=None):
-        """ Set a value or a set of values to a single(!) key in self k"""
-        sl = sl or slice(None)
-
-        # series, dios['a'] = series, 'a' exist !
-        # diosA[slice] = diosB --> dios[slice][k] = diosB[k] for all k
-        if isinstance(val, pd.Series):
-            val = cast_to_itype(val, self._itype, policy=self._policy)
-            left = self._data[key][sl]
-            idx = left.index.intersection(val.index)
-            # l, r = left.align(val, join='inner')
-            if not idx.empty:
-                left.loc[idx] = val.loc[idx].copy()
-            return
-
-        item = self._data[key]
-
-        # label <- scalar: dios['a'] = 3.9 or
-        # slice <- scalar: dios[0:3] = 4.0
-        if is_scalar(val):
-            item[sl] = val
-
-        # label  <- list: dios['a'] = [0.0, 0.3, 0.0]
-        # sclice <- list: dios[0:3] = [0.0, 0.3, 0.0]
-        elif is_list_like(val) and not is_nested_list_like(val):
-            # ensure same size # fixme: is this neccessary, wouldnt pd.Series raise a Valuerror ?
-            if len(item[sl]) == len(val):
-                item[sl] = val
-            else:
-                raise ValueError(f'Length of values does not match length of sliced for the key {key}')
+    def _set_item(self, key, ix, val):
+        "Set a value (scalar or list or series)"
+        ser = self._data[key]
+        if is_series_like(val):
+            left = ser[ix]
+            index = left.index.intersection(val.index)
+            if not index.empty:
+                left.loc[index] = val.loc[index].copy()
         else:
-            raise ValueError(f"assignments with a values of type {type(val)} are not supported")
-        return
+            ser[ix] = val
 
     @property
     def loc(self):
diff --git a/dios/options.py b/dios/options.py
index 9350c8e..a75be73 100644
--- a/dios/options.py
+++ b/dios/options.py
@@ -34,30 +34,21 @@ dios_options = {
 }
 
 
-def get_dios_to_dios_keys(keys, other, method):
+def check_keys_by_policy(check, keys, policy):
 
-    err_append = "consider changing dios.option['dios_to_dios_method']"
+    if policy == OptionsDiosToDios.any_matching:
+        check = [k for k in check if k in keys]
 
-    if method == OptionsDiosToDios.any_matching:
-        keys = [k for k in keys if k in other.columns]
+    elif policy == OptionsDiosToDios.at_least_one:
+        check = [k for k in check if k in keys]
+        if not check:
+            raise KeyError("policy says: at least one key must be shared.")
 
-    elif method == OptionsDiosToDios.at_least_one:
-        keys = [k for k in keys if k in other.columns]
-        if not keys:
-            raise KeyError("src-DioS and dest-DioS need to share at least one key, " + err_append)
-
-    # elif method == 2:
-    #     fail = [k for k in keys if k not in other.columns]
-    #     if fail:
-    #         raise KeyError(f"{fail} are missing in the destiny-dios, " + err_append)
-
-    # keys in both dios's must be equal
     elif OptionsDiosToDios.all_must_match:
-        fail = set(keys).symmetric_difference(set(other.columns))
+        fail = set(check).symmetric_difference(set(keys))
         if fail:
-            raise KeyError(f"{fail} is not in both of src- and dest-dios, " + err_append)
-
+            raise KeyError(f"{fail}. policy says: all keys must be present.")
     else:
-        raise ValueError(method)
+        raise ValueError(policy)
 
-    return keys
+    return check
-- 
GitLab