diff --git a/CHANGELOG.md b/CHANGELOG.md index 3fc3ff4f95d47bcc7db6f9327949c005a36c8567..a6b95ae2178b2293444a95273aa3c821df3787d2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,12 @@ * `get_cells_matrix`: convert `cells_connectivity` or `cells_definition` back to the default cells matrix used in the Grid class (can be used to convert VTK-grids into FINAM-grids) * `INV_VTK_TYPE_MAP`: inverse mapping to `VTK_TYPE_MAP` - FINAM cell type to VTK cell type * `VTK_CELL_DIM`: parametric dimension for each VTK cell type +* Grid class now reusable when having different data locations and better grid type casting (!278) + * added `copy` method to grids with optional argument `deep` (`False` by default) to create a copy of a grid + * added setter for `data_location` in order to set a new data location (e.g. after copying a grid) + * added class attribute `valid_locations` in order to check the set data location (esri-grid only supports cells, unstructured-points only support points) + * added missing casting methods to convert esri to uniform and uniform to rectilinear (when you want to use point data on an esri-grid, you can cast it to uniform first) + * added `axes_attributes` also to unstructured grids ### Documentation diff --git a/src/finam/data/grid_base.py b/src/finam/data/grid_base.py index 8d51159123f618c2ddae6ad34b7c6808e68c73e1..5e4f2abf80af8f35ba04135a2b9c3965f952b570 100644 --- a/src/finam/data/grid_base.py +++ b/src/finam/data/grid_base.py @@ -1,4 +1,6 @@ """Grid abstract base classes for FINAM.""" + +import copy as cp from abc import ABC, abstractmethod from pathlib import Path @@ -34,6 +36,22 @@ class GridBase(ABC): def dim(self): """int: Dimension of the grid or data.""" + def copy(self, deep=False): + """ + Copy of this grid. + + Parameters + ---------- + deep : bool, optional + If false, only a shallow copy is returned to save memory, by default False + + Returns + ------- + Grid + The grid copy. + """ + return cp.deepcopy(self) if deep else cp.copy(self) + def to_canonical(self, data): """Convert grid specific data to canonical form.""" return data @@ -51,6 +69,9 @@ class GridBase(ABC): class Grid(GridBase): """Abstract grid specification.""" + valid_locations = (Location.CELLS, Location.POINTS) + """tuple: Valid locations for the grid.""" + @property @abstractmethod def crs(self): @@ -124,6 +145,11 @@ class Grid(GridBase): def data_location(self): """Location of the associated data (either CELLS or POINTS).""" + @data_location.setter + @abstractmethod + def data_location(self, data_location): + """Set location of the associated data (either CELLS or POINTS).""" + @property def data_points(self): """Points of the associated data (either cell_centers or points).""" @@ -151,6 +177,11 @@ class Grid(GridBase): def axes_names(self): """list of str: Axes names (xyz order).""" + @property + @abstractmethod + def axes_attributes(self): + """list of dict: Axes attributes following the CF convention (xyz order).""" + @property def data_axes_names(self): """list of str: Axes names of the data.""" @@ -272,11 +303,6 @@ class StructuredGrid(Grid): """list of bool: False to indicate a bottom up axis (xyz order).""" # esri grids and some netcdf are given bottom up (northing/lat inverted) - @property - @abstractmethod - def axes_attributes(self): - """list of dict: Axes attributes following the CF convention (xyz order).""" - @property @abstractmethod def order(self): diff --git a/src/finam/data/grid_spec.py b/src/finam/data/grid_spec.py index cf2153a0fb1bfeb6aee69e9b88ca9ada514ba0c7..1e0389eb5e28ab1735b4d5a9220667b303ce7861 100644 --- a/src/finam/data/grid_spec.py +++ b/src/finam/data/grid_spec.py @@ -1,4 +1,5 @@ """Grid specifications to handle spatial data with FINAM.""" + from pathlib import Path import numpy as np @@ -17,6 +18,15 @@ from .grid_tools import ( ) +def _check_location(grid, data_location): + # need to define this here to prevent circular imports + location = get_enum_value(data_location, Location) + if location not in grid.valid_locations: + msg = f"{grid.name}: data location {location} not valid." + raise ValueError(msg) + return location + + class NoGrid(GridBase): """Indicator for data without a spatial grid.""" @@ -89,7 +99,8 @@ class RectilinearGrid(StructuredGrid): # all axes made increasing self._axes_increase = check_axes_monotonicity(self.axes) self._dim = len(self.dims) - self._data_location = get_enum_value(data_location, Location) + self._data_location = None + self.data_location = data_location self._order = order self._axes_reversed = bool(axes_reversed) self._axes_attributes = axes_attributes or (self.dim * [{}]) @@ -118,6 +129,7 @@ class RectilinearGrid(StructuredGrid): cell_types=self.cell_types, data_location=self.data_location, order=self.order, + axes_attributes=self.axes_attributes, axes_names=self.axes_names, crs=self.crs, ) @@ -187,6 +199,11 @@ class RectilinearGrid(StructuredGrid): """Location of the associated data (either CELLS or POINTS).""" return self._data_location + @data_location.setter + def data_location(self, data_location): + """Set location of the associated data (either CELLS or POINTS).""" + self._data_location = _check_location(self, data_location) + class UniformGrid(RectilinearGrid): """Regular grid with uniform spacing in up to three coordinate directions. @@ -296,6 +313,28 @@ class UniformGrid(RectilinearGrid): spacing = self.spacing + (0.0,) * (3 - self.dim) imageToVTK(path, origin, spacing, **kw) + def to_rectilinear(self): + """ + Cast grid to a rectilinear grid. + + Returns + ------- + UniformGrid + Grid as rectilinear grid. + """ + grid = RectilinearGrid( + axes=self.axes, + data_location=self.data_location, + order=self.order, + axes_reversed=self.axes_reversed, + axes_attributes=self.axes_attributes, + axes_names=self.axes_names, + crs=self.crs, + ) + # pylint: disable-next=protected-access + grid._axes_increase = self.axes_increase + return grid + class EsriGrid(UniformGrid): """ @@ -324,6 +363,9 @@ class EsriGrid(UniformGrid): The coordinate reference system, by default None """ + valid_locations = (Location.CELLS,) + """tuple: Valid locations for the grid.""" + def __init__( self, ncols, @@ -378,6 +420,28 @@ class EsriGrid(UniformGrid): header["axes_attributes"] = axes_attributes return cls(**header) + def to_uniform(self): + """ + Cast grid to an uniform grid. + + Returns + ------- + UniformGrid + Grid as uniform grid. + """ + return UniformGrid( + dims=self.dims, + spacing=self.spacing, + origin=self.origin, + data_location=self.data_location, + order=self.order, + axes_reversed=self.axes_reversed, + axes_increase=self.axes_increase, + axes_attributes=self.axes_attributes, + axes_names=self.axes_names, + crs=self.crs, + ) + class UnstructuredGrid(Grid): """ @@ -396,6 +460,8 @@ class UnstructuredGrid(Grid): order : str, optional Data ordering. Either Fortran-like ("F") or C-like ("C"), by default "C" + axes_attributes : list of dict or None, optional + Axes attributes following the CF convention (in xyz order), by default None axes_names : list of str or None, optional Axes names (in xyz order), by default ["x", "y", "z"] crs : str or None, optional @@ -409,6 +475,7 @@ class UnstructuredGrid(Grid): cell_types, data_location=Location.CELLS, order="C", + axes_attributes=None, axes_names=None, crs=None, ): @@ -416,8 +483,12 @@ class UnstructuredGrid(Grid): self._points = np.asarray(np.atleast_2d(points), dtype=float)[:, :3] self._cells = np.asarray(np.atleast_2d(cells), dtype=int) self._cell_types = np.asarray(np.atleast_1d(cell_types), dtype=int) - self._data_location = get_enum_value(data_location, Location) + self._data_location = None + self.data_location = data_location self._order = order + self._axes_attributes = axes_attributes or (self.dim * [{}]) + if len(self.axes_attributes) != self.dim: + raise ValueError("UnstructuredGrid: wrong length of 'axes_attributes'") self._axes_names = axes_names or ["x", "y", "z"][: self.dim] if len(self.axes_names) != self.dim: raise ValueError("UnstructuredGrid: wrong length of 'axes_names'") @@ -482,11 +553,21 @@ class UnstructuredGrid(Grid): """Location of the associated data (either CELLS or POINTS).""" return self._data_location + @data_location.setter + def data_location(self, data_location): + """Set location of the associated data (either CELLS or POINTS).""" + self._data_location = _check_location(self, data_location) + @property def order(self): """str: Point, cell and data order (C-like or F-like for flatten).""" return self._order + @property + def axes_attributes(self): + """list of dict: Axes attributes following the CF convention (xyz order).""" + return self._axes_attributes + @property def axes_names(self): """list of str: Axes names (xyz order).""" @@ -504,16 +585,22 @@ class UnstructuredPoints(UnstructuredGrid): order : str, optional Data ordering. Either Fortran-like ("F") or C-like ("C"), by default "C" + axes_attributes : list of dict or None, optional + Axes attributes following the CF convention (in xyz order), by default None axes_names : list of str or None, optional Axes names (in xyz order), by default ["x", "y", "z"] crs : str or None, optional The coordinate reference system, by default None """ + valid_locations = (Location.POINTS,) + """tuple: Valid locations for the grid.""" + def __init__( self, points, order="C", + axes_attributes=None, axes_names=None, crs=None, ): @@ -525,6 +612,7 @@ class UnstructuredPoints(UnstructuredGrid): cell_types=np.full(pnt_cnt, CellType.VERTEX, dtype=int), data_location=Location.POINTS, order=order, + axes_attributes=axes_attributes, axes_names=axes_names, crs=crs, ) diff --git a/tests/data/test_grid_spec.py b/tests/data/test_grid_spec.py index aed4b11bc5d87c0d53012524641cb5f834a14afe..814c5ed25e320acb1468f7713e7f0c6f6cb45278 100644 --- a/tests/data/test_grid_spec.py +++ b/tests/data/test_grid_spec.py @@ -196,6 +196,9 @@ class TestGridSpec(unittest.TestCase): axes_names=["to_few"], ) + with self.assertRaises(ValueError): + UnstructuredPoints(points=[[0.0, 0.0]], axes_attributes=[{"too": "short"}]) + self.assertEqual(grid.name, "UnstructuredGrid") self.assertEqual(grid2.name, "UnstructuredPoints") self.assertIsNone(grid.crs) @@ -219,6 +222,9 @@ class TestGridSpec(unittest.TestCase): self.assertEqual(grid2.mesh_dim, 0) assert_allclose(grid2.cell_node_counts, 1) + with self.assertRaises(ValueError): + grid2.data_location = Location.CELLS + def test_esri(self): header = { "ncols": 520, @@ -238,6 +244,20 @@ class TestGridSpec(unittest.TestCase): self.assertAlmostEqual(grid.xllcorner, 4375000) self.assertAlmostEqual(grid.yllcorner, 2700000) + # casting + grid2 = grid.to_uniform() + grid3 = grid.to_rectilinear() + grid4 = grid2.to_rectilinear() + self.assertTrue(grid == grid2) + self.assertTrue(grid == grid3) + self.assertTrue(grid == grid4) + self.assertTrue(grid2 == grid3) + self.assertTrue(grid2 == grid4) + self.assertTrue(grid3 == grid4) + + with self.assertRaises(ValueError): + grid.data_location = Location.POINTS + def test_data_location(self): grid1 = UniformGrid((1,), data_location=0) grid2 = UniformGrid((2, 2), data_location="CELLS") @@ -258,6 +278,20 @@ class TestGridSpec(unittest.TestCase): assert_allclose(grid.data_points, us_grid.data_points) self.assertIsInstance(us_grid, UnstructuredGrid) + def test_copy(self): + grid = EsriGrid(3, 2) + us_grid = grid.to_unstructured() + cp_grid1 = us_grid.copy() + cp_grid2 = us_grid.copy(deep=True) + + self.assertTrue(us_grid == cp_grid1) + self.assertTrue(us_grid == cp_grid2) + + # shallow copy shares info + cp_grid1.points[0, 0] = 0.1 + self.assertTrue(us_grid == cp_grid1) + self.assertFalse(us_grid == cp_grid2) + def test_equality(self): grid1 = UniformGrid((2, 2), data_location=0) grid2 = UnstructuredGrid(