From 13614dbce4da09c6167d21c51d949808ce5d61e7 Mon Sep 17 00:00:00 2001 From: Bert Palm <bert.palm@ufz.de> Date: Sat, 11 Apr 2020 21:05:22 +0200 Subject: [PATCH] implemented axis=1 for squeeze, all and any --- dios/dios.py | 39 +++++++++++++++++++++++++-------------- 1 file changed, 25 insertions(+), 14 deletions(-) diff --git a/dios/dios.py b/dios/dios.py index 9de0890..0284846 100644 --- a/dios/dios.py +++ b/dios/dios.py @@ -423,11 +423,20 @@ class DictOfSeries: except Exception: return False + def _reduce_horizontal(self, function, initializer_value): + res = pd.Series(data=initializer_value, index=self.index_of('all')) + for d in self._data: + base = res.loc[d.index] + if len(base) > 0: + res.loc[d.index] = function(base, d) + return res + def all(self, axis=0): if axis in [0, 'index']: return self._data.apply(all) elif axis in [1, 'columns']: - raise NotImplementedError + func = lambda s1, s2: s1.astype(bool) & s2.astype(bool) + return self._reduce_horizontal(func, True) elif axis is None: return self._data.apply(all).all() raise ValueError(axis) @@ -436,23 +445,27 @@ class DictOfSeries: if axis in [0, 'index']: return self._data.apply(any) elif axis in [1, 'columns']: - raise NotImplementedError + func = lambda s1, s2: s1.astype(bool) | s2.astype(bool) + return self._reduce_horizontal(func, False) elif axis is None: return self._data.apply(any).any() raise ValueError(axis) def squeeze(self, axis=None): if axis in [0, 'index']: - raise NotImplementedError - - if len(self) > 1: + if (self.lengths == 1).all(): + return self._data.apply(pd.Series.squeeze) + return self + elif axis in [1, 'columns']: + if len(self) == 1: + return self._data.squeeze() return self - - if axis in [1, 'columns']: - return self._data.squeeze() elif axis is None: - return self._data.squeeze().squeeze() - + if len(self) == 1: + return self._data.squeeze().squeeze() + if (self.lengths == 1).all(): + return self._data.apply(pd.Series.squeeze).squeeze() + return self raise ValueError(axis) @property @@ -590,10 +603,8 @@ class DictOfSeries: for c in self.columns: s = func(self._data.at[c].values if raw else self._data.at[c], *args, **kwds) new.append(s) - try: - need_dios = True if not _is_scalar(s) else need_dios - except TypeError: - pass + if not _is_scalar(s): + need_dios = True if need_dios: data = pd.Series(dtype='O', index=self.columns) -- GitLab