diff --git a/dios/dios.py b/dios/dios.py index 6341be12ece17a590f31ca2f5220fbebf81d1015..4b17832db2f9e552d92cb8fb91532b13da090e9e 100644 --- a/dios/dios.py +++ b/dios/dios.py @@ -38,6 +38,10 @@ def is_dios_like(obj): return isinstance(obj, DictOfSeries) or isinstance(obj, pd.DataFrame) +def is_bool_series(obj): + return isinstance(obj, pd.Series) and obj.dtype == bool + + def is_iterator(obj): """ This is only a dummy wrapper, to warn that the docu of this isnt't right. Unlike the example says, @@ -180,7 +184,6 @@ class DictOfSeries: keys, ixs, ixalign = self._unpack_key(key) new = self.copy_empty() for i, k in enumerate(keys): - # fixme: optimasation: new._data.loc[k] = self._get_item(k, ixs[i], ixalign) new[k] = self._get_item(k, ixs[i], ixalign) return new @@ -230,8 +233,7 @@ class DictOfSeries: ix = ser.index.intersection(ix.index) if isinstance(right, pd.Series): left = ser[ix] - right = align_index_by_policy(left, right) - ix = right.index + right, ix = align_index_by_policy(left, right) ser.loc[ix] = right def _insert(self, col, val): @@ -249,13 +251,13 @@ class DictOfSeries: def _unpack_value(self, keys, ixs, val): """Return a generator that yield (key, indexer, value) for all keys""" - val = list(val) if is_iterator(val) else val - - if is_dios_like(val): - val = val.squeeze() + # prepare value + val = list(val) if is_iterator(val) else val + val = val.squeeze() if is_dios_like(val) else val dioslike, nlistlike = is_dios_like(val), is_nested_list_like(val) + # check value if (dioslike or nlistlike) and len(val) != len(keys): raise ValueError(f"could not broadcast input array with length {len(val)}" f" into dios of length {len(keys)}") @@ -267,7 +269,8 @@ class DictOfSeries: key, ix = keys[i], ixs[i] if dioslike: # we explicitly do not align keys here. 
usr can use .loc for this - # purpose, (but we do align on rows, later in the setting chain) + # purpose, a modified version of this function for .loc can + # be found in locator.py yield key, ix, val[val.columns[i]] elif nlistlike: yield key, ix, val[i] diff --git a/dios/locator.py b/dios/locator.py index 78e1643a759d0fcd6581196a17b7f8c5e914ca6e..2dd0be33e6df0b5f1f7f1250d68d0ec2f6a7877d 100644 --- a/dios/locator.py +++ b/dios/locator.py @@ -1,58 +1,66 @@ from dios.dios import * + class _Indexer: def __init__(self, _dios): self._dios = _dios - self.columns = _dios.columns self._data = _dios._data - # self._unpack_value = _dios._unpack_value + self._columns = _dios.columns class _LocIndexer(_Indexer): def __init__(self, _dios): super().__init__(_dios) + # we can use set item here, as this + # also uses .loc for setting values self._set_item = _dios._set_item + def __setitem__(self, key, val): + keys, rkey, lowdim = self._unpack_key(key) + ix, ixalign = self._unpack_rowkey(rkey) + gen = _unpack_value(keys, ix, val) - def _series(self, rkey, cols, lowdim): - if lowdim: - return self._scalar(rkey[0], cols[0]) - new = pd.Series() - for c in cols: - try: - new[c] = self._data[c].loc[rkey] - except KeyError: - new[c] = np.nan - - def _scalar(self, r, c): - return self._data[c].loc[r] - - def __setitem__(self, key, value): - data, rkey = self._getdata(key) - if data.empty: - return - if isinstance(data, pd.Series): - pass + for tup in gen: + self._set_item(*tup, ixalign=ixalign) def __getitem__(self, key): - data, rkey, lowdim = self._getdata(key) + keys, rkey, lowdim = self._unpack_key(key) + ix, ixalign = self._unpack_rowkey(rkey) - if is_hashable(rkey): + if is_hashable(ix): new = pd.Series() - data.name = rkey + new.name = ix else: new = self._dios.copy_empty() - if not data.empty: - if lowdim: - new = data.loc[rkey] - else: - for s in data.index: - new[s] = data[s].loc[rkey] + # set series in new dios OR set values in + # new series if ix is hashable (see above) 
+ for k in keys: + new[k] = self._get_item(self._data.loc[k], ix, ixalign=ixalign) + maby_set_series_name(new[k], k) + + # squeeze to series if a single label was given + # OR squeeze to val if additional ix is hashable + if lowdim: + new = new.squeeze() + return new - def _getdata(self, key): + def _get_item(self, ser, ix, ixalign=False): + if ixalign: + ix = ser.index.intersection(ix.index) + return ser.loc[ix] + + def _unpack_rowkey(self, rkey): + align = False + if is_dios_like(rkey) or is_nested_list_like(rkey): + raise ValueError("Cannot index with multidimensional key") + if is_bool_series(rkey): + rkey, align = rkey[rkey], True # kill `False` + return rkey, align + + def _unpack_key(self, key): lowdim = False if isinstance(key, tuple): key, ckey, *fail = key @@ -61,28 +69,13 @@ class _LocIndexer(_Indexer): if is_dios_like(ckey): raise ValueError("Cannot index with multidimensional key") if is_hashable(ckey): + keys = [ckey] lowdim = True - try: - data = self._data.loc[ckey] - except Exception as e: - raise e + else: + keys = self._data.loc[ckey].index.to_list() else: - data = self._data - return data, key, lowdim - - def _col_slice_to_col_list(self, cslice): - """ see here: - https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#indexing-slicing-with-labels - """ - keys = list(self._data.index) - try: - start = keys.index(cslice.start) if cslice.start is not None else None - stop = keys.index(cslice.stop) if cslice.stop is not None else None - except ValueError: - raise KeyError("The slice start label, or the slice stop label, is not present in columns.") - if not is_integer(cslice.step) or cslice.step <= 0: - return [] - return keys[slice(start, stop + 1, cslice.step)] + keys = self._columns.to_list() + return keys, key, lowdim class _iLocIndexer(_Indexer): @@ -189,3 +182,34 @@ class _iLocIndexer(_Indexer): if not is_integer(s): raise TypeError(f"positional indexing with slice must be integers, passed type was {type(s)}") return 
list(self._data.index)[sl] + + +def _unpack_value(keys, ix, val): + """Return a generator that yields (column key, row indexer, corresponding value) + for all columns. + This is analogous to DictOfSeries._unpack_value, but with some modifications.""" + + # prepare value + val = list(val) if is_iterator(val) else val + val = val.squeeze() if is_dios_like(val) else val + dioslike, nlistlike = is_dios_like(val), is_nested_list_like(val) + + # check value + if nlistlike and len(val) != len(keys): + raise ValueError(f"could not broadcast input array with length {len(val)}" + f" into dios of length {len(keys)}") + if dioslike: + keys = val.columns.intersection(keys) + + for i, k in enumerate(keys): + if dioslike: + yield k, ix, val[k] + elif nlistlike: + yield k, ix, val[i] + else: + yield k, ix, val + + +def maby_set_series_name(maybe_ser, name): + if isinstance(maybe_ser, pd.Series): + maybe_ser.name = name diff --git a/dios/options.py b/dios/options.py index 3e1e57c0c0897cc674c69a419c7ee18c6920591b..f06e9cdd9cd4b2edbf6100a9d211d9b4bcec96f4 100644 --- a/dios/options.py +++ b/dios/options.py @@ -78,14 +78,20 @@ dios_options = { } -def align_index_by_policy(left, right): - policy = dios_options[OptsFields.setitem_nan_policy] +def align_index_by_policy(left, right, policy=None): + if policy is None: + policy = dios_options[OptsFields.setitem_nan_policy] + if policy in [Opts.keep_nans, Opts.pdlike_nans]: - # return right.align(left, join='right')[0] - return right.reindex_like(left) + # r = right.align(left, join='right')[0] + r = right.reindex_like(left) elif policy in [Opts.drop_nans]: - # return right.align(left, join='inner')[0] - return right.loc[left.index.intersection(right.index)] + # r = right.align(left, join='inner')[0] + r = right.loc[left.index.intersection(right.index)] + else: + raise ValueError(policy) + + return r, r.index def get_keys_by_policy(tocheck, keys, policy): diff --git a/test/run_dios.py b/test/run_dios.py index 
18251c43b615329542465a1cd6246ff6d2334e68..2c2849083e4fa8dd2d8331d409909262012ee45e 100644 --- a/test/run_dios.py +++ b/test/run_dios.py @@ -21,7 +21,7 @@ if __name__ == '__main__': print(d, type(d)) a = d.loc[:,'a'] print(a, type(a)) - x = d.loc[1,['a', 'ss']] + x = d.loc[1,['a', 'ss', 'z']] print(x, type(x)) diff --git a/test/test__getsetitem__.py b/test/test__getitem__.py similarity index 98% rename from test/test__getsetitem__.py rename to test/test__getitem__.py index 2941de41603dd00201adbea506566bba748ca01b..52dd8130ac69981de3ac5e4422c40f3645985cf4 100644 --- a/test/test__getsetitem__.py +++ b/test/test__getitem__.py @@ -50,7 +50,7 @@ def test__getitem_single_fail(idxer): @pytest.mark.parametrize('idxer', ['x', '2', 1, None, ]) def test__getitem_single_loc_fail(idxer): - with pytest.raises(KeyError): + with pytest.raises((KeyError, TypeError)): a = d1.loc[:, idxer]