diff --git a/saqc/flagger/flags.py b/saqc/flagger/flags.py
index c06d8da4e6f884bcee9876524839278bac96d2bd..ee365703c493024fcfd09510d73cdb0621ee05eb 100644
--- a/saqc/flagger/flags.py
+++ b/saqc/flagger/flags.py
@@ -178,8 +178,15 @@ class Flags:
             raise KeyError("a single 'column' or a tuple of 'mask, column' must be passt")
         mask, key = key
 
-        # raises (correct) KeyError
         tmp = pd.Series(UNTOUCHED, index=self._data[key].index, dtype=float)
+
+        # a pd.Index passed as `mask` is converted to a boolean Series aligned to
+        # the column's index: True at the given labels, False everywhere else
+        if isinstance(mask, pd.Index):
+            mask = pd.Series(True, index=mask, dtype=bool)
+            mask = mask.reindex(tmp.index, fill_value=False)
+
+        # NOTE(review): the KeyError for an unknown column presumably comes from `self._data[key]` above, not from this try
         try:
             tmp[mask] = value
         except Exception:
diff --git a/test/flagger/test_flags.py b/test/flagger/test_flags.py
index 1f68c115bfb9079d838eb09646b182ec2cfa4791..83c156011099db48ced9d0e4d520fb64320fad31 100644
--- a/test/flagger/test_flags.py
+++ b/test/flagger/test_flags.py
@@ -188,6 +188,34 @@ def test_set_flags_with_mask(data: np.array):
         flags[mask, c] = vector
 
 
+@pytest.mark.parametrize('data', data)
+def test_set_flags_with_index(data: np.array):
+    flags = Flags(data)
+
+    for c in flags.columns:
+        var = flags[c]
+        mask = var == UNFLAGGED
+        index = mask[mask].index
+
+        scalar = 222.
+        flags[index, c] = scalar
+        assert all(flags[c].loc[mask] == 222.)
+        assert all(flags[c].loc[~mask] != 222.)
+
+        vector = var.copy()
+        vector[:] = 333.
+        flags[index, c] = vector
+        assert all(flags[c].loc[mask] == 333.)
+        assert all(flags[c].loc[~mask] != 333.)
+
+        # works with anything pandas accepts as a value, e.g. a numpy array
+        vector[:] = 444.
+        vector = vector.to_numpy()
+        flags[index, c] = vector
+        assert all(flags[c].loc[mask] == 444.)
+        assert all(flags[c].loc[~mask] != 444.)
+
+
 def test_cache():
     arr = np.array([
         [0, 0, 0, 0],