diff --git a/saqc/core/register.py b/saqc/core/register.py index 83c845e0abb27edfa93d9ef4ee16e3d32aa85a2b..89bf09bb9ba0d9c7b072db979dfcb69a7ca8a170 100644 --- a/saqc/core/register.py +++ b/saqc/core/register.py @@ -162,7 +162,8 @@ class SaQCFunc(Func): if field not in flagger.getFlags(): flagger = flagger.merge(flagger.initFlags(data=pd.Series(name=field))) - data_in = self._maskData(data, flagger) + columns_in = data.columns.intersection([field]) + data_in = self._maskData(data.loc[:, columns_in], flagger) data_result, flagger_result = self.func(data_in, field, flagger, *self.args, **self.kwargs) @@ -189,20 +190,20 @@ class SaQCFunc(Func): mask_old = flagger_old.isFlagged(flag=to_mask, comparator="==") mask_new = flagger_new.isFlagged(flag=to_mask, comparator="==") - for col, left in data_new.indexes.iteritems(): + for col, right in data_new.indexes.iteritems(): if col not in mask_old: continue - right = mask_old[col].index + left = mask_old[col].index + col_data = data_new[col].values # NOTE: ignore columns with changed indices (assumption: harmonization) if left.equals(right): # NOTE: Don't overwrite data, that was masked, but is not considered # flagged anymore and also respect newly set data on masked locations. mask = mask_old[col].values & mask_new[col].values & data_new[col].isna().values if np.any(mask): - col_data = data_new[col].values col_data[mask] = data_old[col].values[mask] - data_new[col] = col_data - return data_new + data_old[col] = col_data + return data_old # NOTE: diff --git a/test/core/test_core.py b/test/core/test_core.py index 21a9624a3d84e2c435a988e00d0fe1deed2be976..4b1a166c67db02aa30c3c43b4cb55489cf684316 100644 --- a/test/core/test_core.py +++ b/test/core/test_core.py @@ -79,7 +79,7 @@ def test_assignVariable(flagger): pdata, pflagger = SaQC(flagger, data).flagAll(var1).flagAll(var2).getResult() pflags = pflagger.getFlags() - assert (pflags.columns == [var1, var2]).all() + assert (set(pflags.columns) == {var1, var2}) assert pflagger.isFlagged(var2).empty @@ -105,7 +105,6 @@ def test_masking(data, flagger): test if flagged values are exluded during the preceding tests """ flagger = flagger.initFlags(data) - flags = flagger.getFlags() var1 = 'var1' mn = min(data[var1]) mx = max(data[var1]) / 2