diff --git a/saqc/funcs/harm_functions.py b/saqc/funcs/harm_functions.py index 9c50bfaba7f5c52253f1ccd467b00135c58d6782..afa30e616483150869a31af7402c610be9bf83d5 100644 --- a/saqc/funcs/harm_functions.py +++ b/saqc/funcs/harm_functions.py @@ -7,7 +7,7 @@ import logging from saqc.funcs.functions import flagMissing from saqc.funcs.register import register -from saqc.lib.tools import toSequence +from saqc.lib.tools import toSequence, funcInput_2_func # todo: frequencie estimation function @@ -44,10 +44,10 @@ def harmWrapper(heap={}): freq, inter_method, reshape_method, - inter_agg=np.mean, + inter_agg="mean", inter_order=1, inter_downcast=False, - reshape_agg=max, + reshape_agg="max", reshape_missing_flag=None, reshape_shift_comment=False, drop_flags=None, @@ -55,6 +55,10 @@ def harmWrapper(heap={}): **kwargs ): + # get funcs from strings: + inter_agg = funcInput_2_func(inter_agg) + reshape_agg = funcInput_2_func(reshape_agg) + # for some tingle tangle reasons, resolving the harmonization will not be sound, if not all missing/np.nan # values get flagged initially: data, flagger = flagMissing( @@ -830,7 +834,7 @@ def linear2Grid(data, field, flagger, freq, flag_assignment_method='nearest_agg' @register('harmonize_interpolate2Grid') def interpolate2Grid(data, field, flagger, freq, interpolation_method, interpolation_order=1, - flag_assignment_method='nearest_agg', flag_agg_func=max, drop_flags=None, **kwargs): + flag_assignment_method='nearest_agg', flag_agg_func="max", drop_flags=None, **kwargs): return harmonize( data, field, @@ -845,9 +849,14 @@ def interpolate2Grid(data, field, flagger, freq, interpolation_method, interpola @register('harmonize_downsample') -def downsample(data, field, flagger, sample_freq, agg_freq, sample_func=np.mean, agg_func=np.mean, +def downsample(data, field, flagger, sample_freq, agg_freq, sample_func="mean", agg_func="mean", invalid_flags=None, max_invalid=np.inf, **kwargs): + agg_func = funcInput_2_func(agg_func) + + if sample_func is not None: + sample_func = funcInput_2_func(sample_func) + # define the "fastest possible" aggregator if sample_func is None: if max_invalid < np.inf: diff --git a/saqc/lib/tools.py b/saqc/lib/tools.py index a3950cc91c6137145c5c4d8e28ea283c3069942d..321b59c2aaf2718d3a0cb7a73919d93f71097376 100644 --- a/saqc/lib/tools.py +++ b/saqc/lib/tools.py @@ -10,6 +10,15 @@ import numba as nb from saqc.lib.types import T +STRING_2_FUNC = { + 'sum': np.sum, + 'mean': np.mean, + 'median': np.median, + 'min': np.min, + 'max': np.max, + 'first': pd.Series(np.nan, index=pd.DatetimeIndex([])).resample('0min').first, + 'last': pd.Series(np.nan, index=pd.DatetimeIndex([])).resample('0min').last +} def assertScalar(name, value, optional=False): if (not np.isscalar(value)) and (value is not None) and (optional is True): @@ -286,3 +295,19 @@ def assertSingleColumns(df, argname=""): raise TypeError( f"given pd.DataFrame {argname} is not allowed to have a muliindex on columns" ) + +def funcInput_2_func(func): + """ + Aggregation functions passed by the user, are selected by looking them up in the STRING_2_DICT dictionary - + But since there are wrappers, that dynamically generate aggregation functions and pass those on ,the parameter + interfaces must as well be capable of processing real functions passed. This function does that. + + :param func: A key to the STRING_2_FUNC dict, or an actual function + """ + # if input is a callable - than just pass it: + if hasattr(func, "__call__"): + return func + elif func in STRING_2_FUNC.keys(): + return STRING_2_FUNC[func] + else: + raise ValueError("Function input not a callable nor a known key to internal the func dictionary.") diff --git a/test/funcs/test_harm_funcs.py b/test/funcs/test_harm_funcs.py index 1170981b10589db94c89ef4937c019fccea22bc1..d6e34a3e5dc64a37b1d0a5d25d3a1a6ead007ed1 100644 --- a/test/funcs/test_harm_funcs.py +++ b/test/funcs/test_harm_funcs.py @@ -161,7 +161,7 @@ def test_harmSingleVarInterpolations(data, flagger, interpolation, freq): interpolation, "fshift", reshape_shift_comment=False, - inter_agg=np.sum, + inter_agg="sum", ) if interpolation is "fshift": @@ -261,8 +261,8 @@ def test_multivariatHarmonization(multi_data, flagger, shift_comment): freq, "bagg", "bshift", - inter_agg=sum, - reshape_agg=max, + inter_agg="sum", + reshape_agg="max", reshape_shift_comment=shift_comment, ) @@ -295,14 +295,14 @@ def test_gridInterpolation(data, method): data = data.squeeze() # we are just testing if the interpolation gets passed to the series without causing an error: _interpolateGrid( - data, freq, method, order=1, agg_method=sum, downcast_interpolation=True + data, freq, method, order=1, agg_method="sum", downcast_interpolation=True ) if method == "polynomial": _interpolateGrid( - data, freq, method, order=2, agg_method=sum, downcast_interpolation=True + data, freq, method, order=2, agg_method="sum", downcast_interpolation=True ) _interpolateGrid( - data, freq, method, order=10, agg_method=sum, downcast_interpolation=True + data, freq, method, order=10, agg_method="sum", downcast_interpolation=True ) data = _insertGrid(data, freq) _interpolate(data, method, inter_limit=3) @@ -335,10 +335,10 @@ def test_wrapper(data, flagger): field = data.columns[0] freq = '15min' flagger = flagger.initFlags(data) - downsample(data, field, flagger, '15min', '30min', agg_func=np.sum, sample_func=np.mean) + downsample(data, field, flagger, '15min', '30min', agg_func="sum", sample_func="mean") - linear2Grid(data, field, flagger, freq, flag_assignment_method='nearest_agg', flag_agg_func=max, + linear2Grid(data, field, flagger, freq, flag_assignment_method='nearest_agg', flag_agg_func="max", drop_flags=None) - aggregate2Grid(data, field, flagger, freq, agg_func=sum, agg_method='nearest_agg', - flag_agg_func=max, drop_flags=None) + aggregate2Grid(data, field, flagger, freq, agg_func="sum", agg_method='nearest_agg', + flag_agg_func="max", drop_flags=None) shift2Grid(data, field, flagger, freq, shift_method='nearest_shift', drop_flags=None)