diff --git a/saqc/funcs/spikes_detection.py b/saqc/funcs/spikes_detection.py index 9b1ae4274ca32faab7c4d188c184cf691c691434..666f97da7f4c395e55f0b2d0129d461b4cde3484 100644 --- a/saqc/funcs/spikes_detection.py +++ b/saqc/funcs/spikes_detection.py @@ -141,7 +141,12 @@ def spikes_flagMultivarScores(data, field, flagger, fields, trafo='normScale', a expfit_binning='auto', stray_partition=None, stray_partition_min=0, **kwargs): - trafo = composeFunction(trafo.split(',')) + trafo_list = trafo.split(',') + if len(trafo_list) == 1: + trafo_list = trafo_list * len(fields) + trafo_dict = {var_name: composeFunction(traffo.split('-')) for (var_name, traffo) + in dict(zip(fields, trafo_list)).items()} + # data fransformation/extraction val_frame = data[fields[0]] @@ -153,7 +158,7 @@ def spikes_flagMultivarScores(data, field, flagger, fields, trafo='normScale', a ) val_frame.dropna(inplace=True) - val_frame = val_frame.transform(trafo) + val_frame = val_frame.transform(trafo_dict) if threshing == 'stray': to_flag_index = _stray(val_frame,