From 886861186c5aa0190750b5b3619e8f14e50ccb1a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sebastian=20M=C3=BCller?= <mueller.seb@posteo.de>
Date: Tue, 4 Jul 2023 11:47:28 +0200
Subject: [PATCH] Grid: respect pint when in from_compressed; better mask
 setting

---
 src/finam/data/grid_base.py  | 11 ++++++++---
 src/finam/data/grid_spec.py  | 18 +++---------------
 src/finam/data/grid_tools.py | 36 +++++++++++++++++++++++++++++++++++-
 3 files changed, 46 insertions(+), 19 deletions(-)

diff --git a/src/finam/data/grid_base.py b/src/finam/data/grid_base.py
index 756b5d96..eddb89d2 100644
--- a/src/finam/data/grid_base.py
+++ b/src/finam/data/grid_base.py
@@ -3,6 +3,7 @@ from abc import ABC, abstractmethod
 from pathlib import Path
 
 import numpy as np
+import pint
 from pyevtk.hl import gridToVTK, unstructuredGridToVTK
 
 from .grid_tools import (
@@ -287,12 +288,16 @@ class Grid(GridBase):
             Grid specific Data.
         """
         if self.mask is None:
+            # reshape works with quantities
             return np.reshape(data, self.data_shape, order=self.order)
-        data = np.asanyarray(data)
-        out = np.empty(self.data_size, dtype=data.dtype)
+        if isinstance(data, pint.Quantity):
+            out = np.empty(self.data_size, dtype=data.dtype) * data.units
+        else:
+            data = np.asarray(data)
+            out = np.empty(self.data_size, dtype=data.dtype)
         mask = np.reshape(self.mask, -1, order=self.order)
-        out[mask] = nodata
         out[~mask] = data
+        out[mask] = nodata
         return np.reshape(out, self.data_shape, order=self.order)
 
     @property
diff --git a/src/finam/data/grid_spec.py b/src/finam/data/grid_spec.py
index a65d9e5e..ed684ff2 100644
--- a/src/finam/data/grid_spec.py
+++ b/src/finam/data/grid_spec.py
@@ -10,10 +10,10 @@ from .grid_tools import (
     CellType,
     Location,
     check_axes_monotonicity,
-    check_mask_shape,
     gen_axes,
     prepare_vtk_data,
     prepare_vtk_kwargs,
+    set_mask,
 )
 
 
@@ -119,13 +119,7 @@ class RectilinearGrid(StructuredGrid):
 
     @mask.setter
     def mask(self, mask):
-        if mask is not None:
-            mask = np.asarray(mask, dtype=bool)
-            if not check_mask_shape(mask, self.data_shape):
-                msg = "Grid.mask: given mask has wrong shape."
-                msg += f" Expected: {self.data_shape}, Got: {np.shape(mask)}"
-                raise ValueError(msg)
-        self._mask = mask
+        self._mask = set_mask(self, mask)
 
     def to_unstructured(self):
         """
@@ -478,13 +472,7 @@ class UnstructuredGrid(Grid):
 
     @mask.setter
     def mask(self, mask):
-        if mask is not None:
-            mask = np.asarray(mask, dtype=bool)
-            if not check_mask_shape(mask, self.data_shape):
-                msg = "Grid.mask: given mask has wrong shape."
-                msg += f" Expected: {self.data_shape}, Got: {np.shape(mask)}"
-                raise ValueError(msg)
-        self._mask = mask
+        self._mask = set_mask(self, mask)
 
     @property
     def dim(self):
diff --git a/src/finam/data/grid_tools.py b/src/finam/data/grid_tools.py
index dc76adc4..bedf94fc 100644
--- a/src/finam/data/grid_tools.py
+++ b/src/finam/data/grid_tools.py
@@ -49,7 +49,41 @@ def check_mask_shape(mask, shape):
     # None is always ok
     if mask is None:
         return True
-    return mask.shape == shape
+    return np.shape(mask) == shape
+
+
+def set_mask(grid, mask):
+    """
+    Set and check mask for given grid.
+
+    Parameters
+    ----------
+    grid : Grid
+        Grid to be masked.
+    mask : np.ndarray or None
+        Given mask.
+
+    Returns
+    -------
+    np.ndarray or None
+        Mask to set
+
+    Raises
+    ------
+    ValueError
+        If mask shape is not matching grid.
+    """
+    if mask is not None:
+        mask = np.asarray(mask, dtype=bool)
+        if not check_mask_shape(mask, grid.data_shape):
+            msg = "Grid.mask: given mask has wrong shape."
+            msg += f" Expected: {grid.data_shape}, Got: {np.shape(mask)}"
+            raise ValueError(msg)
+        if grid.order == "C":
+            return np.ascontiguousarray(mask)
+        if grid.order == "F":
+            return np.asfortranarray(mask)
+    return mask
 
 
 def point_order(order, axes_reversed=False):
-- 
GitLab