From ccf2f228ff2185f8efbc317b627252d9e59f45f1 Mon Sep 17 00:00:00 2001
From: Bert Palm <bert.palm@ufz.de>
Date: Wed, 1 Mar 2023 14:37:52 +0100
Subject: [PATCH] added slice support for Flags

---
Reviewer notes (text between the three-dash line and the first `diff --git`
line is not part of the commit message):

`Flags.__getitem__` dispatches on the type of `key`:
  * `str`: unchanged, returns `self._data[key].squeeze()`.
  * `slice`: resolved against `self.columns`, then handled like a `pd.Index`.
  * `list` / `pd.Index`: returns a new `Flags` object whose `_data` holds
    copies of the selected columns only.
  * any other type: raises `TypeError`.

 CHANGELOG.md             |  1 +
 saqc/core/flags.py       | 24 ++++++++++++++++++--
 tests/core/test_flags.py | 48 ++++++++++++++++++++++++++++++++++++++++
 3 files changed, 71 insertions(+), 2 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index ddd4980e1..7d348a4b5 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -10,6 +10,7 @@ SPDX-License-Identifier: GPL-3.0-or-later
 [List of commits](https://git.ufz.de/rdm-software/saqc/-/compare/v2.3.0...develop)
 ### Added
 - Methods `logicalAnd` and `logicalOr`
+- `Flags` supports slicing and column selection with `list` or a `pd.Index`.
 ### Changed
 ### Removed
 ### Fixed
diff --git a/saqc/core/flags.py b/saqc/core/flags.py
index 3e7e7dcc0..b1ec0b580 100644
--- a/saqc/core/flags.py
+++ b/saqc/core/flags.py
@@ -320,8 +320,28 @@ class Flags:
     # ----------------------------------------------------------------------
     # item access
 
-    def __getitem__(self, key: str) -> pd.Series:
-        return self._data[key].squeeze()
+    def __getitem__(self, key: str | slice | list | pd.Index) -> pd.Series | Flags:
+        if isinstance(key, str):
+            return self._data[key].squeeze()
+
+        if isinstance(key, slice):
+            key = self.columns[key]
+
+        if isinstance(key, (list, pd.Index)):
+            # empty `_data` while copying, so only the selected columns get copied
+            data = self._data
+            try:
+                self._data = {}
+                new = self.copy()
+            finally:
+                self._data = data
+            new._data = {k: self._data[k].copy() for k in key}
+            return new
+
+        raise TypeError(
+            "Key must be of type str, list or index of string or slice, "
+            f"not {type(key)}."
+        )
 
     def __setitem__(self, key: SelectT, value: ValueT):
         # force-KW is only internally available
diff --git a/tests/core/test_flags.py b/tests/core/test_flags.py
index 74e3f4d90..6b80d4ddb 100644
--- a/tests/core/test_flags.py
+++ b/tests/core/test_flags.py
@@ -49,6 +49,7 @@ for d in _arrays:
 
 
 def is_equal(f1, f2):
+    """assert Flags instance equals other"""
     assert f1.columns.equals(f2.columns)
     for c in f1.columns:
         assert test_hist.is_equal(f1.history[c], f2.history[c])
@@ -334,3 +335,50 @@ def test_columns_setter_raises(columns, err):
     )
     with pytest.raises(err):
         flags.columns = columns
+
+
+@pytest.mark.parametrize(
+    "data,key,expected",
+    [
+        (dict(a=[0, 1], b=[]), "a", pd.Series([0, 1], dtype=float)),
+        (dict(a=[0, 1], b=[]), "b", pd.Series([], dtype=float)),
+    ],
+)
+def test__getitem__scalar(data, key, expected):
+    flags = Flags({k: pd.Series(v, dtype=float) for k, v in data.items()})
+    result: pd.Series = flags[key]
+    assert isinstance(result, pd.Series)
+    assert result.equals(expected)
+    # assert copying
+    assert flags[key] is not flags[key]
+
+
+@pytest.mark.parametrize(
+    "data,key,expected",
+    [
+        (dict(a=[0, 1], b=[]), [], dict()),
+        (dict(a=[0, 1], b=[]), ["a"], dict(a=[0, 1])),
+        (dict(a=[0, 1], b=[]), ["a", "b"], dict(a=[0, 1], b=[])),
+        (dict(a=[0, 1], b=[]), pd.Index([]), dict()),
+        (dict(a=[0, 1], b=[]), pd.Index(["a"]), dict(a=[0, 1])),
+        (dict(a=[0, 1], b=[]), pd.Index(["a", "b"]), dict(a=[0, 1], b=[])),
+        (dict(a=[0, 1], b=[]), slice(None), dict(a=[0, 1], b=[])),
+        (dict(a=[0, 1], b=[]), slice(0, 1), dict(a=[0, 1])),
+        (dict(a=[0, 1], b=[]), slice(1, 99), dict(b=[])),
+        (dict(a=[0, 1], b=[]), slice(5, 99), dict()),
+    ],
+)
+def test__getitem__listlike_and_slice(data, key, expected):
+    flags = Flags({k: pd.Series(v, dtype=float) for k, v in data.items()})
+    result: Flags = flags[key]
+    assert isinstance(result, Flags)
+    # assert that a new Flags object was created
+    assert flags[key] is not flags[key]
+    # assert that internal data is copied
+    if len(result):
+        left = result._data[result.columns[0]]
+        right = flags._data[result.columns[0]]
+        assert left is not right
+
+    expected = Flags({k: pd.Series(v, dtype=float) for k, v in expected.items()})
+    is_equal(result, expected)
-- 
GitLab