From bd1dc8606c6b94f30654d326e8d89a85e547f934 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sebastian=20M=C3=BCller?= <mueller.seb@posteo.de>
Date: Tue, 29 Aug 2023 17:43:07 +0200
Subject: [PATCH] masking adapter: make mask adapter work with masked arrays

---
 src/finam/adapters/masking.py  | 27 ++++++++++++++-----
 src/finam/data/tools.py        | 24 +++++++++++++++++
 tests/adapters/test_masking.py | 49 +++++++++++++++++++++++++++++++++-
 3 files changed, 92 insertions(+), 8 deletions(-)

diff --git a/src/finam/adapters/masking.py b/src/finam/adapters/masking.py
index ba5ffe50..52102015 100644
--- a/src/finam/adapters/masking.py
+++ b/src/finam/adapters/masking.py
@@ -42,6 +42,7 @@ class Masking(Adapter):
         self._canonical_mask = None
         self._sup_grid = None
         self._sub_grid = None
+        self.masked = True
 
     def _get_data(self, time, target):
         """Get the output's data-set for the given time.
@@ -86,12 +87,13 @@ class Masking(Adapter):
         self._sup_grid = in_info.grid
         self._sub_grid = info.grid
 
+        # check no-data value
+        if self.nodata is None:
+            self.nodata = out_nodata if out_nodata is not None else in_nodata
+
         # create_selection
         if self._sub_grid.mask is not None:
             self._canonical_mask = self._sub_grid.to_canonical(self._sub_grid.mask)
-            # check no-data value
-            if self.nodata is None:
-                self.nodata = out_nodata if out_nodata is not None else in_nodata
             if self.nodata is None:
                 with ErrorLogger(self.logger):
                     raise FinamMetaDataError("Couldn't determine no-data value.")
@@ -100,8 +102,10 @@ class Masking(Adapter):
 
         # return output info
         self._canonical_mask = None
-        if out_nodata is None:
+        if self.nodata is None:
+            self.masked = False  # no masked array created
             return in_info.copy_with(grid=info.grid)
+
         # if missing value was present, add it again
         return in_info.copy_with(grid=info.grid, missing_value=self.nodata)
 
@@ -118,6 +122,15 @@ class Masking(Adapter):
     def _transform(self, data):
         if self._canonical_mask is not None:
             data = np.copy(self._sup_grid.to_canonical(data))
-            data[self._canonical_mask] = tools.UNITS.Quantity(self.nodata, data.units)
-            return self._sub_grid.from_canonical(data)
-        return self._sub_grid.from_canonical(self._sup_grid.to_canonical(data))
+            return self._sub_grid.from_canonical(
+                tools.to_masked(data, mask=self._canonical_mask, fill_value=self.nodata)
+            )
+
+        out = self._sub_grid.from_canonical(self._sup_grid.to_canonical(data))
+        # self.masked is set: a no-data value is known, so return a masked array
+        # otherwise no masked array is announced: return the data via tools.filled
+        return (
+            tools.to_masked(out, fill_value=self.nodata)
+            if self.masked
+            else tools.filled(out)
+        )
diff --git a/src/finam/data/tools.py b/src/finam/data/tools.py
index ac13348b..e9339c1a 100644
--- a/src/finam/data/tools.py
+++ b/src/finam/data/tools.py
@@ -500,6 +500,30 @@ def filled(xdata, fill_value=None):
     return xdata.filled(fill_value)
 
 
+def to_masked(xdata, **kwargs):
+    """
+    Return the data as a masked array, keeping units if present.
+
+    Parameters
+    ----------
+    xdata : :class:`pint.Quantity` or :class:`numpy.ndarray` or :class:`numpy.ma.MaskedArray`
+        The data to convert to a masked array.
+    **kwargs
+        Keyword arguments forwarded to :any:`numpy.ma.array` (e.g. mask, fill_value).
+
+    Returns
+    -------
+    pint.Quantity or numpy.ma.MaskedArray
+        The masked data; the input itself if it is already masked and no kwargs given.
+        Units will be taken from the input if present.
+    """
+    if is_masked_array(xdata) and not kwargs:
+        return xdata
+    if is_quantified(xdata):
+        return UNITS.Quantity(np.ma.array(xdata.magnitude, **kwargs), xdata.units)
+    return np.ma.array(xdata, **kwargs)
+
+
 def quantify(xdata, units=None):
     """
     Quantifies data.
diff --git a/tests/adapters/test_masking.py b/tests/adapters/test_masking.py
index b073c3d7..ab38da30 100644
--- a/tests/adapters/test_masking.py
+++ b/tests/adapters/test_masking.py
@@ -52,7 +52,54 @@ class TestMasking(unittest.TestCase):
 
         composition.connect()
 
-        self.assertAlmostEqual(sink.data["Input"][0][0, 0].magnitude, -9999)
+        self.assertTrue(sink.data["Input"][0].magnitude.mask[0, 0])
+        self.assertAlmostEqual(sink.data["Input"][0].fill_value, -9999)
+        self.assertAlmostEqual(sink.data["Input"][0][0, 1].magnitude, 2.0)
+
+    def test_masked_arrays(self):
+        time = datetime(2000, 1, 1)
+
+        mask = [
+            [True, False, True],
+            [False, False, True],
+            [False, False, False],
+            [True, False, False],
+        ]
+
+        in_grid = EsriGrid(ncols=3, nrows=4, order="F")
+        out_grid = EsriGrid(ncols=3, nrows=4, mask=mask, order="F")
+
+        # missing_value of the source info serves as no-data value in the adapter
+        in_info = Info(time=time, grid=in_grid, units="m", missing_value=-9999)
+
+        in_data = np.ma.zeros(shape=in_info.grid.data_shape, order=in_info.grid.order)
+        in_data.mask = mask
+        in_data.fill_value = -9999
+        in_data[0, 0] = 1.0
+        in_data[0, 1] = 2.0
+
+        source = generators.CallbackGenerator(
+            callbacks={"Output": (lambda t: in_data, in_info)},
+            start=datetime(2000, 1, 1),
+            step=timedelta(days=1),
+        )
+
+        sink = debug.DebugConsumer(
+            {"Input": Info(None, grid=out_grid, units=None)},
+            start=datetime(2000, 1, 1),
+            step=timedelta(days=1),
+        )
+
+        composition = Composition([source, sink], log_level="DEBUG")
+        composition.initialize()
+
+        # nodata=None: the adapter has to derive the no-data value from the infos
+        source.outputs["Output"] >> Masking(nodata=None) >> sink.inputs["Input"]
+
+        composition.connect()
+
+        self.assertTrue(sink.data["Input"][0].magnitude.mask[0, 0])
+        self.assertAlmostEqual(sink.data["Input"][0].fill_value, -9999)
         self.assertAlmostEqual(sink.data["Input"][0][0, 1].magnitude, 2.0)
 
 
-- 
GitLab