From 4eb2ca53f59562823759e4169372c1a3a57ac039 Mon Sep 17 00:00:00 2001
From: luenensc <peter.luenenschloss@ufz.de>
Date: Sat, 9 Mar 2024 09:58:34 +0100
Subject: [PATCH] added trg broadcasting and numpy array support for generics

---
 CHANGELOG.md          |  1 +
 saqc/funcs/generic.py | 20 +++++++++++++++++++-
 2 files changed, 20 insertions(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 0f09eccdf..ec5745c7a 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -8,6 +8,7 @@ SPDX-License-Identifier: GPL-3.0-or-later
 ## Unreleased
 [List of commits](https://git.ufz.de/rdm-software/saqc/-/compare/v2.5.0...develop)
 ### Added
+- `generics`: target broadcasting and numpy array support
 - `flagGeneric`: target broadcasting
 - `SaQC`: automatic translation of incoming flags
 - Option to change the flagging scheme after initialization
diff --git a/saqc/funcs/generic.py b/saqc/funcs/generic.py
index 94fa02eea..165190ce8 100644
--- a/saqc/funcs/generic.py
+++ b/saqc/funcs/generic.py
@@ -77,10 +77,25 @@ def _execGeneric(
     return func(*cols)
 
 
+def _inferBroadcast(obj, trg_shape):
+    # broadcast a scalar to an array of `trg_shape`, pass anything else through
+    if pd.api.types.is_scalar(obj):
+        return np.full(trg_shape, obj)
+    return obj
+
+def _inferDF(obj, cols, index):
+    # infer dataframe from a fitting numpy array (1D: repeated per column)
+    if isinstance(obj, np.ndarray):
+        if obj.shape == (len(index),):
+            obj = np.repeat(obj[:, None], len(cols), axis=1)
+        if obj.shape == (len(index), len(cols)):
+            return pd.DataFrame(obj, columns=cols, index=index)
+    return obj
+
 def _castResult(obj) -> DictOfSeries:
     # Note: the actual keys aka. column names
     # we use here to create a DictOfSeries
-    # are never used, and only exists temporary.
+    # are never used and only exist temporarily.
 
     if isinstance(obj, pd.Series):
         return DictOfSeries({"0": obj})
@@ -151,7 +166,10 @@ class GenericMixin:
         targets = fields if target is None else toSequence(target)
 
         dchunk, fchunk = self._data[fields].copy(), self._flags[fields].copy()
+        trg_idx = dchunk[dchunk.columns[0]].index
         result = _execGeneric(fchunk, dchunk, func, dfilter=dfilter)
+        result = _inferBroadcast(result, (len(trg_idx), len(targets)))
+        result = _inferDF(result, cols=targets, index=trg_idx)
         result = _castResult(result)
 
         # update data & flags
-- 
GitLab