Skip to content
Snippets Groups Projects
Commit f8f46213 authored by Peter Lünenschloß's avatar Peter Lünenschloß
Browse files

fixed target handling

parent 72c882c0
No related branches found
No related tags found
1 merge request!850Horizontal axis rolling
Pipeline #207031 passed with stages
in 5 minutes and 11 seconds
This commit is part of merge request !850. Comments created here will be created in the context of that merge request.
...@@ -32,9 +32,9 @@ class RollingMixin: ...@@ -32,9 +32,9 @@ class RollingMixin:
) )
def rolling( def rolling(
self: "SaQC", self: "SaQC",
field: str, field: str | list[str],
window: str | int, window: str | int,
target: str = None, target: str | list[str] = None,
func: Callable[[pd.Series], np.ndarray] | str = "mean", func: Callable[[pd.Series], np.ndarray] | str = "mean",
min_periods: int = 0, min_periods: int = 0,
center: bool = True, center: bool = True,
...@@ -66,27 +66,14 @@ class RollingMixin: ...@@ -66,27 +66,14 @@ class RollingMixin:
If True, center the rolling window. If True, center the rolling window.
""" """
# HINT: checking in _roll # HINT: checking in _roll
if len(field) == 1: if target and (len(target) > 1) and (len(field) != len(target)):
if target: raise ValueError(
self = self.copyField(field[0], target=target) f"""If multiple targets are given, per-field application of rolling is conducted and the number of
field = target fields has to equal the number of targets.\n Got: \n Fields={field} \n Targets={target}"""
self._data, self._flags = _roll(
data=self._data,
field=field,
flags=self._flags,
window=window,
func=func,
min_periods=min_periods,
center=center,
**kwargs,
) )
else:
if not target:
raise ValueError(
"Target has to be assigned for cross statistics calculations."
)
if target and (len(field) > 1) and (len(target) == 1):
target = target[0]
if target not in self._data.columns: if target not in self._data.columns:
self[target] = saqc.SaQC( self[target] = saqc.SaQC(
pd.Series( pd.Series(
...@@ -105,6 +92,22 @@ class RollingMixin: ...@@ -105,6 +92,22 @@ class RollingMixin:
center=center, center=center,
) )
else:
if target:
for ft in zip(field, target):
self = self.copyField(ft[0], target=ft[1], overwrite=True)
field = target
for f in field:
self._data, self._flags = _roll(
data=self._data,
field=f,
flags=self._flags,
window=window,
func=func,
min_periods=min_periods,
center=center,
**kwargs,
)
return self return self
@register(mask=["field"], demask=[], squeeze=[]) @register(mask=["field"], demask=[], squeeze=[])
...@@ -221,9 +224,8 @@ def _hroll( ...@@ -221,9 +224,8 @@ def _hroll(
if center: if center:
f_out = f_out.shift(-int(np.floor(window / 2))) f_out = f_out.shift(-int(np.floor(window / 2)))
for f in target: data[target] = d_out
data[f] = d_out flags[target] = f_out
flags[f] = f_out
return data, flags return data, flags
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment