diff --git a/CHANGELOG.md b/CHANGELOG.md index 0f09eccdf09c944bd13dc1bd003de1e65ecd6933..9296a9dd5c6de68d61665315ce9b758b8bbaf068 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,7 +8,7 @@ SPDX-License-Identifier: GPL-3.0-or-later ## Unreleased [List of commits](https://git.ufz.de/rdm-software/saqc/-/compare/v2.5.0...develop) ### Added -- `flagGeneric`: target broadcasting +- `flagGeneric`, `processGeneric`: target broadcasting and numpy array support - `SaQC`: automatic translation of incoming flags - Option to change the flagging scheme after initialization - `flagByClick`: manually assign flags using a graphical user interface diff --git a/saqc/funcs/generic.py b/saqc/funcs/generic.py index 94fa02eea9796a6e3d3891821d282d2fbc8bf751..7f85016a0251d43a87191e3a2ef339c965439e20 100644 --- a/saqc/funcs/generic.py +++ b/saqc/funcs/generic.py @@ -77,10 +77,27 @@ def _execGeneric( return func(*cols) +def _inferBroadcast(obj, trg_shape) -> pd.DataFrame: + # simple single value broadcasting + if pd.api.types.is_scalar(obj): + return np.full(trg_shape, obj) + return obj + + +def _inferDF(obj, cols, index): + # infer dataframe if result is numpy array of fitting shape + if isinstance(obj, np.ndarray): + lc = len(cols) + li = len(index) + if (obj.shape == (li, lc)) or (obj.shape == (li,)): + return pd.DataFrame(obj, columns=cols, index=index) + return obj + + def _castResult(obj) -> DictOfSeries: # Note: the actual keys aka. column names # we use here to create a DictOfSeries - # are never used, and only exists temporary. + # are never used and only exist temporarily. if isinstance(obj, pd.Series): return DictOfSeries({"0": obj}) @@ -151,7 +168,10 @@ class GenericMixin: targets = fields if target is None else toSequence(target) dchunk, fchunk = self._data[fields].copy(), self._flags[fields].copy() + trg_idx = dchunk[dchunk.columns[0]].index result = _execGeneric(fchunk, dchunk, func, dfilter=dfilter) + result = _inferBroadcast(result, (len(trg_idx), len(targets))) + result = _inferDF(result, cols=targets, index=trg_idx) result = _castResult(result) # update data & flags @@ -231,7 +251,10 @@ class GenericMixin: dfilter = kwargs.get("dfilter", BAD) dchunk, fchunk = self._data[fields].copy(), self._flags[fields].copy() + trg_idx = dchunk[dchunk.columns[0]].index result = _execGeneric(fchunk, dchunk, func, dfilter=dfilter) + result = _inferBroadcast(result, (len(trg_idx), len(targets))) + result = _inferDF(result, cols=targets, index=trg_idx) result = _castResult(result) if len(result.columns) > 1 and len(targets) != len(result.columns):