From 13f833a9f4700e5a9ebe4797b00c61daa20e92ad Mon Sep 17 00:00:00 2001
From: Bert Palm <bert.palm@ufz.de>
Date: Sun, 16 Feb 2020 00:01:26 +0100
Subject: [PATCH] rework set again

---
 dios/dios.py     | 81 +++++++++++++++++++++++++++++++++++++++++-------
 test/run_dios.py |  3 ++
 2 files changed, 73 insertions(+), 11 deletions(-)

diff --git a/dios/dios.py b/dios/dios.py
index dca1354..e2bf03f 100644
--- a/dios/dios.py
+++ b/dios/dios.py
@@ -8,7 +8,8 @@ import operator as op
 
 from functools import wraps
 from collections import OrderedDict
-from pandas._libs.lib import to_object_array
+from pandas._libs.lib import to_object_array, is_bool_array
+from pandas.core.common import is_bool_indexer
 from pandas.core.dtypes.common import (
     is_list_like,
     is_nested_list_like,
@@ -20,6 +21,19 @@ from pandas.core.dtypes.common import (
 from pandas.core.dtypes.common import is_iterator as _is_iterator
 
 
+def is_dios_like(obj):
+    return isinstance(obj, DictOfSeries)
+
+def is_pandas_like(obj):
+    """We consider ourselves (dios) as pandas-like"""
+    return is_series_like(obj) or is_dataframe_like(obj) or is_dios_like(obj)
+
+def is_series_like(obj):
+    return isinstance(obj, pd.Series)
+
+def is_dataframe_like(obj):
+    return isinstance(obj, pd.DataFrame)
+
 def is_iterator(obj):
     """ This is only a dummy wrapper, to warn that the docu of this isnt't right.
     Unlike the example says,
@@ -226,26 +240,70 @@ class DictOfSeries:
         if is_iterator(key):
             key = list(key)
 
+        keys = None
+        indexers = None
+
+        ki = dict()
         # determine action by keys
+
         if isinstance(key, str):
+            # special case: insert a fresh new key
             if key not in self.columns:
                 self._setitem_new(key, value)
+                return
             else:
-                self._setitem(key, value)
-            return
+                ki[key] = slice(None)
+                keys = [key]
 
         elif isinstance(key, slice):
             keys = self.columns
-            kslicer = key
+            indexers = [key]
+
+        # nested list-likes: list, np.array, ... of lists, np.arrays, ...
+        elif is_nested_list_like(key):
+            keys = self.columns
+            indexers = []
+            # we only allow nested lists with bool entries
+            for i in range(len(key)):
+                arr = np.array(i)
+                if not is_bool_array(arr):
+                    raise ValueError("Must pass nested-list-like with boolean values only")
+                indexers.append(arr)
+
+        # ser, df, dios
+        elif is_pandas_like(key):
+            if is_series_like(key):
+                keys = key.to_list()
+            elif is_dataframe_like(key):
+                keys = key.columns.to_list()
+                indexers = key.values
+                testbool = True
+            elif is_dios_like(key):
+                keys = key.columns
+                indexers = list(key.values)
+
+        # list, np.array, np.ndarray, ...
+        elif is_list_like(key):
+            arr = np.array(key)
+            if is_bool_array(arr):
+                keys = self.columns
+                indexers = [arr]
+            else:
+                keys = key
 
-        elif is_list_like(key) and not is_nested_list_like(key):
-            self._check_keys(key)
-            keys = key
-            kslicer = None
 
         else:
             raise KeyError(f"{key}")
 
+        if length != len(keys):
+            raise ValueError(f"Length mismatch for nested list: expected {len(keys)}, got {length}")
+        if not indexers:
+            indexers = [slice(None)]
+        if len(indexers) == 1:
+            indexers = indexers * len(keys)
+
+        assert len(indexers) == len(keys)
+        # now we have an indexer for every series
         # determine action by value
 
         if isinstance(value, DictOfSeries):
@@ -282,9 +340,10 @@ class DictOfSeries:
         if isinstance(val, pd.Series):
             val = cast_to_itype(val, self._itype, policy=self._policy)
             left = self._data[key][sl]
-            l, r = left.align(val, join='inner')
-            if not r.empty:
-                left.loc[r.index] = r.copy(deep=True)
+            idx = left.index.intersection(val.index)
+            # l, r = left.align(val, join='inner')
+            if not idx.empty:
+                left.loc[idx] = val.loc[idx].copy()
             return
 
         item = self._data[key]
diff --git a/test/run_dios.py b/test/run_dios.py
index 3ed5da7..8aba9fd 100644
--- a/test/run_dios.py
+++ b/test/run_dios.py
@@ -5,6 +5,9 @@ import numpy as np
 if __name__ == '__main__':
     # dios_options[Options.mixed_itype_policy] = 'error'
 
+    df = pd.DataFrame([1,24,5,456,45], index=pd.date_range(periods=5, freq='1d', start='2000-01-01'))
+    df[[True, False]]
+
     dios = DictOfSeries(data=[234.54, 5, 5, 4, np.nan, 5, 4, 5])
 
     dios = abs(-dios)
-- 
GitLab