From e6f927ed3b7e9e5cefb2f6cf5b37b78881a51b10 Mon Sep 17 00:00:00 2001 From: Bert Palm <bert.palm@ufz.de> Date: Wed, 2 Dec 2020 15:09:32 +0100 Subject: [PATCH] everything works, but idiotic pandas count(), because they handle us differently then self --- saqc/lib/rolling.py | 70 ++++++++++++++++------------------------ test/lib/test_rolling.py | 2 +- 2 files changed, 29 insertions(+), 43 deletions(-) diff --git a/saqc/lib/rolling.py b/saqc/lib/rolling.py index 62502533f..4a5516ff8 100644 --- a/saqc/lib/rolling.py +++ b/saqc/lib/rolling.py @@ -36,15 +36,12 @@ def is_slice(k): return isinstance(k, slice) class _CustomBaseIndexer(BaseIndexer): is_datetimelike = None - def __init__(self, index_array, window_size, min_periods=None, center=False, closed=None, forward=False, + def __init__(self, index_array, window_size, center=False, forward=False, expand=False, step=None, mask=None): super().__init__() self.index_array = index_array - self.num_values = len(index_array) self.window_size = window_size - self.min_periods = min_periods - self.center = center - self.closed = closed + self._center = center self.forward = forward self.expand = expand self.step = step @@ -52,7 +49,9 @@ class _CustomBaseIndexer(BaseIndexer): self.validate() def validate(self) -> None: - if self.center is not None and not is_bool(self.center): + if self._center is None: + self._center = False + if not is_bool(self._center): raise ValueError("center must be a boolean") if not is_bool(self.forward): raise ValueError("forward must be a boolean") @@ -74,11 +73,17 @@ class _CustomBaseIndexer(BaseIndexer): raise TypeError('mask must have boolean values only.') self.skip = ~self.skip - def get_window_bounds(self, num_values=0, min_periods=None, center=False, closed=None): - num_values = self.num_values - min_periods = self.min_periods - center = self.center - closed = self.closed + def get_window_bounds(self, num_values=0, min_periods=None, center=None, closed=None): + if min_periods is None: + assert self.is_datetimelike is False + min_periods = 1 + + # if one call us directly, one may pass a center value we should consider. + # pandas instead (via customRoller) will always pass None and the correct + # center value is set in __init__. This is because pandas cannot center on + # dt_like windows and would fail before even call us. + if center is None: + center = self._center start, end = self._get_bounds(num_values, min_periods, center, closed) start, end = self._apply_skipmask(start, end) @@ -95,9 +100,9 @@ class _CustomBaseIndexer(BaseIndexer): # end[end - start < self.min_periods] = 0 return start, end - def _get_center_window_sizes(self, winsz): + def _get_center_window_sizes(self, center, winsz): ws1 = ws2 = winsz - if self.center: + if center: # centering of dtlike windows is just looking left and right # with half amount of window-size ws1 = (winsz + 1) // 2 @@ -130,11 +135,6 @@ class _FixedWindowDirectionIndexer(_CustomBaseIndexer): # set here is_datetimelike = False - def validate(self) -> None: - super().validate() - if self.min_periods is None: - self.min_periods = self.window_size - def _get_bounds(self, num_values=0, min_periods=None, center=False, closed=None): # closed is always ignored and handled as 'both' other cases not implemented offset = calculate_center_offset(self.window_size) if center else 0 @@ -168,7 +168,7 @@ class _FixedWindowDirectionIndexer(_CustomBaseIndexer): def _remove_ramps(self, start, end, center): fw, bw = self.forward, not self.forward - ramp_l, ramp_r = self._get_center_window_sizes(self.window_size - 1) + ramp_l, ramp_r = self._get_center_window_sizes(center, self.window_size - 1) if center: fw = bw = True @@ -199,15 +199,8 @@ class _VariableWindowDirectionIndexer(_CustomBaseIndexer): # set here is_datetimelike = True - def validate(self) -> None: - super().validate() - if self.min_periods is None: - self.min_periods = 1 - if self.window_size == 0: - self.min_periods = 0 - def _get_bounds(self, num_values=0, min_periods=None, center=False, closed=None): - ws_bw, ws_fw = self._get_center_window_sizes(self.window_size) + ws_bw, ws_fw = self._get_center_window_sizes(center, self.window_size) if center: c1 = c2 = closed if closed == 'neither': @@ -227,7 +220,7 @@ class _VariableWindowDirectionIndexer(_CustomBaseIndexer): return start, end def _remove_ramps(self, start, end, center): - ws_bw, ws_fw = self._get_center_window_sizes(self.window_size) + ws_bw, ws_fw = self._get_center_window_sizes(center, self.window_size) if center or not self.forward: # remove (up) ramp @@ -344,30 +337,22 @@ def customRoller(obj, window, min_periods=None, # aka minimum non-nan values if not isinstance(obj, (ABCSeries, ABCDataFrame)): raise TypeError(f"invalid type: {type(obj)}") - theirs = dict(min_periods=min_periods, center=center, win_type=win_type, on=on, axis=axis, closed=closed) - ours = dict(forward=forward, expand=expand, step=step, mask=mask) - assert len(theirs) + len(ours) == num_params, "not all params covert (!)" - # center is the only param from the pandas rolling implementation # that we advance, namely we allow center=True on dt-indexed data - ours.update(center=theirs.pop('center')) + # that's why we take it as ours + theirs = dict(min_periods=min_periods, win_type=win_type, on=on, axis=axis, closed=closed) + ours = dict(center=center, forward=forward, expand=expand, step=step, mask=mask) + assert len(theirs) + len(ours) == num_params, "not all params covert (!)" # use .rolling to do all the checks like if closed is one of [left, right, neither, both], # closed not allowed for integer windows, index is monotonic (in- or decreasing), if freq-based # windows can be transformed to nanoseconds (eg. fails for `1y` - it could have 364 or 365 days), etc. # Also it converts window and the index to numpy-arrays (so we don't have to do it :D). try: - x = obj.rolling(window, center=False, **theirs) + x = obj.rolling(window, **theirs) except Exception: raise - if theirs.pop('win_type') is not None: - raise NotImplementedError("customRoller() does not implemented win_type.") - num_params -= 1 - - ours.update(min_periods=theirs.pop('min_periods'), closed=theirs.pop('closed')) - assert len(theirs) + len(ours) == num_params, "not all params covert (!)" - indexer = _VariableWindowDirectionIndexer if x.is_freq_type else _FixedWindowDirectionIndexer indexer = indexer(index_array=x._on.asi8, window_size=x.window, **ours) @@ -385,4 +370,5 @@ def customRoller(obj, window, min_periods=None, # aka minimum non-nan values # Lastly, it is necessary to pass min_periods at all (!) and do not set it to a fix value (1, 0, None,...). This # is, because we cannot throw out values by ourself in the indexer, because min_periods also evaluates NA values # in its count and we have no control over the actual values, just their indexes. - return obj.rolling(indexer, min_periods=x.min_periods, **theirs) + theirs.update(min_periods=x.min_periods) + return obj.rolling(indexer, center=None, **theirs) diff --git a/test/lib/test_rolling.py b/test/lib/test_rolling.py index dbb2f7ca4..d8365f4eb 100644 --- a/test/lib/test_rolling.py +++ b/test/lib/test_rolling.py @@ -158,7 +158,7 @@ def dt_center_kws(): return l -@pytest.mark.parametrize("kws", make_num_kws(), ids=lambda x: str(x)) +@pytest.mark.parametrize("kws", dt_center_kws(), ids=lambda x: str(x)) def test_centering_w_dtindex(kws): print(kws) s = pd.Series(0., index=pd.date_range("2000", periods=10, freq='1H')) -- GitLab