diff --git a/saqc/lib/rolling.py b/saqc/lib/rolling.py index 714545eabe51fe43f9bcf5a6c1c25722a91fbc18..d4afba65cc76ab097caf0baab12c085dc6e9fc9a 100644 --- a/saqc/lib/rolling.py +++ b/saqc/lib/rolling.py @@ -145,9 +145,9 @@ class _FixedWindowDirectionIndexer(_CustomBaseIndexer): num_values += offset if self.forward: - start, end = self._fw(num_values, min_periods, center, closed) + start, end = self._fw(num_values, min_periods, center, closed, offset) else: - start, end = self._bw(num_values, min_periods, center, closed) + start, end = self._bw(num_values, min_periods, center, closed, offset) if center: start, end = self._center_result(start, end, offset) @@ -160,8 +160,12 @@ class _FixedWindowDirectionIndexer(_CustomBaseIndexer): def _center_result(self, start, end, offset): if offset > 0: - start = start[offset:] - end = end[offset:] + if self.forward: + start = start[:-offset] + end = end[:-offset] + else: + start = start[offset:] + end = end[offset:] return start, end def _remove_ramps(self, start, end, center): @@ -177,7 +181,7 @@ class _FixedWindowDirectionIndexer(_CustomBaseIndexer): return start, end - def _bw(self, num_values=0, min_periods=None, center=False, closed=None): + def _bw(self, num_values=0, min_periods=None, center=False, closed=None, offset=0): # code taken from pd.core.windows.indexer.FixedWindowIndexer start_s = np.zeros(self.window_size, dtype="int64") start_e = (np.arange(self.window_size, num_values, dtype="int64") - self.window_size + 1) @@ -189,10 +193,10 @@ class _FixedWindowDirectionIndexer(_CustomBaseIndexer): # end stolen code return start, end - def _fw(self, num_values=0, min_periods=None, center=False, closed=None): - s, _ = self._bw(num_values, min_periods, center, closed) - start = np.arange(num_values) - end = num_values - s[::-1] + def _fw(self, num_values=0, min_periods=None, center=False, closed=None, offset=0): + start = np.arange(-offset, num_values, dtype="int64")[:num_values] + end = start + self.window_size + start[:offset] = 0 return start, end diff --git a/test/lib/test_rolling.py b/test/lib/test_rolling.py index 443aba123bfaf73a3ae4a3bfa8ce8f1492cd1670..d2980ff386814fa4e15df7b888ef2549cc91cef3 100644 --- a/test/lib/test_rolling.py +++ b/test/lib/test_rolling.py @@ -86,30 +86,30 @@ def runtest_for_kw_combi(s, kws): print_diff(s, result, expected) assert False -# -# @pytest.mark.parametrize("kws", make_num_kws()) -# def test_pandas_conform_num(data, kws): -# runtest_for_kw_combi(data, kws) -# -# -# @pytest.mark.parametrize("kws", make_dt_kws()) -# def test_pandas_conform_dt(data, kws): -# if kws.get('center', False) is True: -# pass -# else: -# runtest_for_kw_combi(data, kws) -# -# -# @pytest.mark.parametrize("kws", make_num_kws()) -# def test_forward_num(data, kws): -# kws.update(forward=True, center=False) -# runtest_for_kw_combi(data, kws) + +@pytest.mark.parametrize("kws", make_num_kws()) +def test_pandas_conform_num(data, kws): + runtest_for_kw_combi(data, kws) @pytest.mark.parametrize("kws", make_dt_kws()) -def test_forward_dt(data, kws): - kws.update(forward=True) - if kws['center'] is True: - pytest.skip('pandas has no center on dt-index') +def test_pandas_conform_dt(data, kws): + if kws.get('center', False) is True: + pass else: runtest_for_kw_combi(data, kws) + + +@pytest.mark.parametrize("kws", make_num_kws()) +def test_forward_num(data, kws): + kws.update(forward=True) + runtest_for_kw_combi(data, kws) + + +# @pytest.mark.parametrize("kws", make_dt_kws()) +# def test_forward_dt(data, kws): +# kws.update(forward=True) +# if kws['center'] is True: +# pytest.skip('pandas has no center on dt-index') +# else: +# runtest_for_kw_combi(data, kws)