From 886861186c5aa0190750b5b3619e8f14e50ccb1a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20M=C3=BCller?= <mueller.seb@posteo.de> Date: Tue, 4 Jul 2023 11:47:28 +0200 Subject: [PATCH] Grid: respect pint when in from_compressed; better mask setting --- src/finam/data/grid_base.py | 11 ++++++++--- src/finam/data/grid_spec.py | 18 +++--------------- src/finam/data/grid_tools.py | 36 +++++++++++++++++++++++++++++++++++- 3 files changed, 46 insertions(+), 19 deletions(-) diff --git a/src/finam/data/grid_base.py b/src/finam/data/grid_base.py index 756b5d96..eddb89d2 100644 --- a/src/finam/data/grid_base.py +++ b/src/finam/data/grid_base.py @@ -3,6 +3,7 @@ from abc import ABC, abstractmethod from pathlib import Path import numpy as np +import pint from pyevtk.hl import gridToVTK, unstructuredGridToVTK from .grid_tools import ( @@ -287,12 +288,16 @@ class Grid(GridBase): Grid specific Data. """ if self.mask is None: + # reshape works with quantities return np.reshape(data, self.data_shape, order=self.order) - data = np.asanyarray(data) - out = np.empty(self.data_size, dtype=data.dtype) + if isinstance(data, pint.Quantity): + out = np.empty(self.data_size, dtype=data.dtype) * data.units + else: + data = np.asarray(data) + out = np.empty(self.data_size, dtype=data.dtype) mask = np.reshape(self.mask, -1, order=self.order) - out[mask] = nodata out[~mask] = data + out[mask] = nodata return np.reshape(out, self.data_shape, order=self.order) @property diff --git a/src/finam/data/grid_spec.py b/src/finam/data/grid_spec.py index a65d9e5e..ed684ff2 100644 --- a/src/finam/data/grid_spec.py +++ b/src/finam/data/grid_spec.py @@ -10,10 +10,10 @@ from .grid_tools import ( CellType, Location, check_axes_monotonicity, - check_mask_shape, gen_axes, prepare_vtk_data, prepare_vtk_kwargs, + set_mask, ) @@ -119,13 +119,7 @@ class RectilinearGrid(StructuredGrid): @mask.setter def mask(self, mask): - if mask is not None: - mask = np.asarray(mask, dtype=bool) - if not check_mask_shape(mask, self.data_shape): - msg = "Grid.mask: given mask has wrong shape." - msg += f" Expected: {self.data_shape}, Got: {np.shape(mask)}" - raise ValueError(msg) - self._mask = mask + self._mask = set_mask(self, mask) def to_unstructured(self): """ @@ -478,13 +472,7 @@ class UnstructuredGrid(Grid): @mask.setter def mask(self, mask): - if mask is not None: - mask = np.asarray(mask, dtype=bool) - if not check_mask_shape(mask, self.data_shape): - msg = "Grid.mask: given mask has wrong shape." - msg += f" Expected: {self.data_shape}, Got: {np.shape(mask)}" - raise ValueError(msg) - self._mask = mask + self._mask = set_mask(self, mask) @property def dim(self): diff --git a/src/finam/data/grid_tools.py b/src/finam/data/grid_tools.py index dc76adc4..bedf94fc 100644 --- a/src/finam/data/grid_tools.py +++ b/src/finam/data/grid_tools.py @@ -49,7 +49,41 @@ def check_mask_shape(mask, shape): # None is always ok if mask is None: return True - return mask.shape == shape + return np.shape(mask) == shape + + +def set_mask(grid, mask): + """ + Set and check mask for given grid. + + Parameters + ---------- + grid : Grid + Grid to be masked. + mask : np.ndarray or None + Given mask. + + Returns + ------- + np.ndarray or None + Mask to set + + Raises + ------ + ValueError + If mask shape is not matching grid. + """ + if mask is not None: + mask = np.asarray(mask, dtype=bool) + if not check_mask_shape(mask, grid.data_shape): + msg = "Grid.mask: given mask has wrong shape." + msg += f" Expected: {grid.data_shape}, Got: {np.shape(mask)}" + raise ValueError(msg) + if grid.order == "C": + return np.ascontiguousarray(mask) + if grid.order == "F": + return np.asfortranarray(mask) + return mask def point_order(order, axes_reversed=False): -- GitLab