From ccf2f228ff2185f8efbc317b627252d9e59f45f1 Mon Sep 17 00:00:00 2001
From: Bert Palm <bert.palm@ufz.de>
Date: Wed, 1 Mar 2023 14:37:52 +0100
Subject: [PATCH] added slice support for Flags

---
 CHANGELOG.md             |  1 +
 saqc/core/flags.py       | 24 ++++++++++++++++++--
 tests/core/test_flags.py | 48 ++++++++++++++++++++++++++++++++++++++++
 3 files changed, 71 insertions(+), 2 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index ddd4980e1..7d348a4b5 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -10,6 +10,7 @@ SPDX-License-Identifier: GPL-3.0-or-later
 [List of commits](https://git.ufz.de/rdm-software/saqc/-/compare/v2.3.0...develop)
 ### Added
 - Methods `logicalAnd` and `logicalOr`
+- `Flags` supports slicing and column selection with `list` or a `pd.Index`.
 ### Changed
 ### Removed
 ### Fixed
diff --git a/saqc/core/flags.py b/saqc/core/flags.py
index 3e7e7dcc0..b1ec0b580 100644
--- a/saqc/core/flags.py
+++ b/saqc/core/flags.py
@@ -320,8 +320,28 @@ class Flags:
     # ----------------------------------------------------------------------
     # item access
 
-    def __getitem__(self, key: str) -> pd.Series:
-        return self._data[key].squeeze()
+    def __getitem__(self, key: str | list | pd.Index) -> pd.Series | Flags:
+        if isinstance(key, str):
+            return self._data[key].squeeze()
+
+        if isinstance(key, slice):
+            key = self.columns[key]
+
+        if isinstance(key, (list, pd.Index)):
+            # only copy necessary data
+            data = self._data
+            try:
+                self._data = {}
+                new = self.copy()
+            finally:
+                self._data = data
+            new._data = {k: self._data[k].copy() for k in key}
+            return new
+
+        raise TypeError(
+            "Key must be of type str, list or index of string or slice,"
+            f"not {type(key)}."
+        )
 
     def __setitem__(self, key: SelectT, value: ValueT):
         # force-KW is only internally available
diff --git a/tests/core/test_flags.py b/tests/core/test_flags.py
index 74e3f4d90..6b80d4ddb 100644
--- a/tests/core/test_flags.py
+++ b/tests/core/test_flags.py
@@ -49,6 +49,7 @@ for d in _arrays:
 
 
 def is_equal(f1, f2):
+    """assert Flags instance equals other"""
     assert f1.columns.equals(f2.columns)
     for c in f1.columns:
         assert test_hist.is_equal(f1.history[c], f2.history[c])
@@ -334,3 +335,50 @@ def test_columns_setter_raises(columns, err):
     )
     with pytest.raises(err):
         flags.columns = columns
+
+
+@pytest.mark.parametrize(
+    "data,key,expected",
+    [
+        (dict(a=[0, 1], b=[]), "a", pd.Series([0, 1], dtype=float)),
+        (dict(a=[0, 1], b=[]), "b", pd.Series([], dtype=float)),
+    ],
+)
+def test__getitem__scalar(data, key, expected):
+    flags = Flags({k: pd.Series(v, dtype=float) for k, v in data.items()})
+    result: pd.Series = flags[key]
+    assert isinstance(result, pd.Series)
+    assert result.equals(expected)
+    # assert copying
+    assert flags[key] is not flags[key]
+
+
+@pytest.mark.parametrize(
+    "data,key,expected",
+    [
+        (dict(a=[0, 1], b=[]), [], dict()),
+        (dict(a=[0, 1], b=[]), ["a"], dict(a=[0, 1])),
+        (dict(a=[0, 1], b=[]), ["a", "b"], dict(a=[0, 1], b=[])),
+        (dict(a=[0, 1], b=[]), pd.Index([]), dict()),
+        (dict(a=[0, 1], b=[]), pd.Index(["a"]), dict(a=[0, 1])),
+        (dict(a=[0, 1], b=[]), pd.Index(["a", "b"]), dict(a=[0, 1], b=[])),
+        (dict(a=[0, 1], b=[]), slice(None), dict(a=[0, 1], b=[])),
+        (dict(a=[0, 1], b=[]), slice(0, 1), dict(a=[0, 1])),
+        (dict(a=[0, 1], b=[]), slice(1, 99), dict(b=[])),
+        (dict(a=[0, 1], b=[]), slice(5, 99), dict()),
+    ],
+)
+def test__getitem__listlike_and_slice(data, key, expected):
+    flags = Flags({k: pd.Series(v, dtype=float) for k, v in data.items()})
+    result: Flags = flags[key]
+    assert isinstance(result, Flags)
+    # assert that a new Flags object was created
+    assert flags[key] is not flags[key]
+    # assert that internal data is copied
+    if len(result):
+        left = result._data[result.columns[0]]
+        right = flags._data[result.columns[0]]
+        assert left is not right
+
+    expected = Flags({k: pd.Series(v, dtype=float) for k, v in expected.items()})
+    is_equal(result, expected)
-- 
GitLab