From 4eb2ca53f59562823759e4169372c1a3a57ac039 Mon Sep 17 00:00:00 2001 From: luenensc <peter.luenenschloss@ufz.de> Date: Sat, 9 Mar 2024 09:58:34 +0100 Subject: [PATCH] added trg broadcasting and numpy array support for generics --- CHANGELOG.md | 1 + saqc/funcs/generic.py | 20 +++++++++++++++++++- 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0f09eccdf..ec5745c7a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ SPDX-License-Identifier: GPL-3.0-or-later ## Unreleased [List of commits](https://git.ufz.de/rdm-software/saqc/-/compare/v2.5.0...develop) ### Added +- `generics`: target broadcasting and numpy array support - `flagGeneric`: target broadcasting - `SaQC`: automatic translation of incoming flags - Option to change the flagging scheme after initialization diff --git a/saqc/funcs/generic.py b/saqc/funcs/generic.py index 94fa02eea..165190ce8 100644 --- a/saqc/funcs/generic.py +++ b/saqc/funcs/generic.py @@ -77,10 +77,25 @@ def _execGeneric( return func(*cols) +def _inferBroadcast(obj, trg_shape) -> pd.DataFrame: + # simple single value broadcasting + if pd.api.types.is_scalar(obj): + return np.full(trg_shape, obj) + return obj + +def _inferDF(obj, cols, index): + # infer dataframe if result is numpy array of fitting shape + if isinstance(obj, np.ndarray): + lc = len(cols) + li = len(index) + if (obj.shape == (li,lc)) or (obj.shape == (li,)): + return pd.DataFrame(obj, columns=cols, index=index) + return obj + def _castResult(obj) -> DictOfSeries: # Note: the actual keys aka. column names # we use here to create a DictOfSeries - # are never used, and only exists temporary. + # are never used and only exist temporarily. if isinstance(obj, pd.Series): return DictOfSeries({"0": obj}) @@ -151,7 +166,10 @@ class GenericMixin: targets = fields if target is None else toSequence(target) dchunk, fchunk = self._data[fields].copy(), self._flags[fields].copy() + trg_idx = dchunk[dchunk.columns[0]].index result = _execGeneric(fchunk, dchunk, func, dfilter=dfilter) + result = _inferBroadcast(result, (len(trg_idx), len(targets))) + result = _inferDF(result, cols=targets, index=trg_idx) result = _castResult(result) # update data & flags -- GitLab