diff --git a/saqc/funcs/rolling.py b/saqc/funcs/rolling.py index 0b0666714f2c06a64d8a357d14cbb535dd6044b7..327c32f132e1dee4916ca031fc52f1eab51ec683 100644 --- a/saqc/funcs/rolling.py +++ b/saqc/funcs/rolling.py @@ -32,9 +32,9 @@ class RollingMixin: ) def rolling( self: "SaQC", - field: str, + field: str | list[str], window: str | int, - target: str = None, + target: str | list[str] = None, func: Callable[[pd.Series], np.ndarray] | str = "mean", min_periods: int = 0, center: bool = True, @@ -66,27 +66,14 @@ class RollingMixin: If True, center the rolling window. """ # HINT: checking in _roll - if len(field) == 1: - if target: - self = self.copyField(field[0], target=target) - field = target - - self._data, self._flags = _roll( - data=self._data, - field=field, - flags=self._flags, - window=window, - func=func, - min_periods=min_periods, - center=center, - **kwargs, + if target and (len(target) > 1) and (len(field) != len(target)): + raise ValueError( + f"""If multiple targets are given, per-field application of rolling is conducted and the number of + fields has to equal the number of targets.\n Got: \n Fields={field} \n Targets={target}""" ) - else: - if not target: - raise ValueError( - "Target has to be assigned for cross statistics calculations." - ) + if target and (len(field) > 1) and (len(target) == 1): + target = target[0] if target not in self._data.columns: self[target] = saqc.SaQC( pd.Series( @@ -105,6 +92,22 @@ class RollingMixin: center=center, ) + else: + if target: + for ft in zip(field, target): + self = self.copyField(ft[0], target=ft[1], overwrite=True) + field = target + for f in field: + self._data, self._flags = _roll( + data=self._data, + field=f, + flags=self._flags, + window=window, + func=func, + min_periods=min_periods, + center=center, + **kwargs, + ) return self @register(mask=["field"], demask=[], squeeze=[]) @@ -221,9 +224,8 @@ def _hroll( if center: f_out = f_out.shift(-int(np.floor(window / 2))) - for f in target: - data[f] = d_out - flags[f] = f_out + data[target] = d_out + flags[target] = f_out return data, flags