Skip to content
Snippets Groups Projects
Commit f3137c26 authored by David Schäfer's avatar David Schäfer
Browse files

allow target broadcasting in flagGeneric

parent 3e0ab5d8
No related branches found
No related tags found
1 merge request!752allow target broadcasting in flagGeneric
......@@ -8,6 +8,7 @@ SPDX-License-Identifier: GPL-3.0-or-later
## Unreleased
[List of commits](https://git.ufz.de/rdm-software/saqc/-/compare/v2.5.0...develop)
### Added
- `flagGeneric`: target broadcasting
### Changed
### Removed
### Fixed
......
......@@ -11,10 +11,9 @@ from __future__ import annotations
from abc import abstractmethod
from typing import Any, Dict
import numpy as np
import pandas as pd
from saqc import BAD, FILTER_ALL, GOOD, UNFLAGGED
from saqc import BAD, FILTER_ALL, UNFLAGGED
from saqc.core import DictOfSeries, Flags
from saqc.lib.types import ExternalFlag
......
......@@ -233,7 +233,7 @@ class GenericMixin:
result = _execGeneric(fchunk, dchunk, func, dfilter=dfilter)
result = _castResult(result)
if len(targets) != len(result.columns):
if len(result.columns) > 1 and len(targets) != len(result.columns):
raise ValueError(
f"the generic function returned {len(result.columns)} field(s), "
f"but {len(targets)} target(s) were given"
......@@ -244,7 +244,8 @@ class GenericMixin:
# update flags & data
for i, col in enumerate(targets):
mask = result[result.columns[i]]
# broadcast one column results to all targets
mask = result[result.columns[i if len(result.columns) > 1 else 0]]
# make sure the column exists
if col not in self._flags:
......
......@@ -41,6 +41,7 @@ def test_emptyData():
["tmp1", "tmp2"],
lambda x, y: [pd.Series(True, index=x.index.union(y.index))] * 2,
),
(["tmp1", "tmp2"], lambda x, y: pd.Series(True, index=x.index.union(y.index))),
],
)
def test_writeTargetFlagGeneric(data, targets, func):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment