Skip to content
Snippets Groups Projects
Commit f8f46213 authored by Peter Lünenschloß's avatar Peter Lünenschloß
Browse files

fixed target handling

parent 72c882c0
No related branches found
No related tags found
1 merge request!850Horizontal axis rolling
Pipeline #207031 passed with stages
in 5 minutes and 11 seconds
This commit is part of merge request !850. Comments created here will be created in the context of that merge request.
......@@ -32,9 +32,9 @@ class RollingMixin:
)
def rolling(
self: "SaQC",
field: str,
field: str | list[str],
window: str | int,
target: str = None,
target: str | list[str] = None,
func: Callable[[pd.Series], np.ndarray] | str = "mean",
min_periods: int = 0,
center: bool = True,
......@@ -66,27 +66,14 @@ class RollingMixin:
If True, center the rolling window.
"""
# HINT: checking in _roll
if len(field) == 1:
if target:
self = self.copyField(field[0], target=target)
field = target
self._data, self._flags = _roll(
data=self._data,
field=field,
flags=self._flags,
window=window,
func=func,
min_periods=min_periods,
center=center,
**kwargs,
if target and (len(target) > 1) and (len(field) != len(target)):
raise ValueError(
f"""If multiple targets are given, per-field application of rolling is conducted and the number of
fields has to equal the number of targets.\n Got: \n Fields={field} \n Targets={target}"""
)
else:
if not target:
raise ValueError(
"Target has to be assigned for cross statistics calculations."
)
if target and (len(field) > 1) and (len(target) == 1):
target = target[0]
if target not in self._data.columns:
self[target] = saqc.SaQC(
pd.Series(
......@@ -105,6 +92,22 @@ class RollingMixin:
center=center,
)
else:
if target:
for ft in zip(field, target):
self = self.copyField(ft[0], target=ft[1], overwrite=True)
field = target
for f in field:
self._data, self._flags = _roll(
data=self._data,
field=f,
flags=self._flags,
window=window,
func=func,
min_periods=min_periods,
center=center,
**kwargs,
)
return self
@register(mask=["field"], demask=[], squeeze=[])
......@@ -221,9 +224,8 @@ def _hroll(
if center:
f_out = f_out.shift(-int(np.floor(window / 2)))
for f in target:
data[f] = d_out
flags[f] = f_out
data[target] = d_out
flags[target] = f_out
return data, flags
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment