Skip to content
Snippets Groups Projects
Commit 5bff58fa authored by Sebastian Müller's avatar Sebastian Müller 🐈
Browse files

Merge branch 'add_grid_copy' into 'main'

Add grid.copy

See merge request !278
parents 978b8691 8ba23e42
No related branches found
No related tags found
1 merge request!278Add grid.copy
Pipeline #208064 passed with stages
in 5 minutes and 40 seconds
......@@ -20,6 +20,12 @@
* `get_cells_matrix`: convert `cells_connectivity` or `cells_definition` back to the default cells matrix used in the Grid class (can be used to convert VTK-grids into FINAM-grids)
* `INV_VTK_TYPE_MAP`: inverse mapping to `VTK_TYPE_MAP` - FINAM cell type to VTK cell type
* `VTK_CELL_DIM`: parametric dimension for each VTK cell type
* Grid class now reusable when having different data locations and better grid type casting (!278)
* added `copy` method to grids with optional argument `deep` (`False` by default) to create a copy of a grid
* added setter for `data_location` in order to set a new data location (e.g. after copying a grid)
* added class attribute `valid_locations` in order to check the set data location (esri-grid only supports cells, unstructured-points only support points)
* added missing casting methods to convert esri to uniform and uniform to rectilinear (when you want to use point data on an esri-grid, you can cast it to uniform first)
* added `axes_attributes` also to unstructured grids
### Documentation
......
"""Grid abstract base classes for FINAM."""
import copy as cp
from abc import ABC, abstractmethod
from pathlib import Path
......@@ -34,6 +36,22 @@ class GridBase(ABC):
def dim(self):
"""int: Dimension of the grid or data."""
def copy(self, deep=False):
"""
Copy of this grid.
Parameters
----------
deep : bool, optional
If false, only a shallow copy is returned to save memory, by default False
Returns
-------
Grid
The grid copy.
"""
return cp.deepcopy(self) if deep else cp.copy(self)
def to_canonical(self, data):
"""Convert grid specific data to canonical form."""
return data
......@@ -51,6 +69,9 @@ class GridBase(ABC):
class Grid(GridBase):
"""Abstract grid specification."""
valid_locations = (Location.CELLS, Location.POINTS)
"""tuple: Valid locations for the grid."""
@property
@abstractmethod
def crs(self):
......@@ -124,6 +145,11 @@ class Grid(GridBase):
def data_location(self):
"""Location of the associated data (either CELLS or POINTS)."""
@data_location.setter
@abstractmethod
def data_location(self, data_location):
"""Set location of the associated data (either CELLS or POINTS)."""
@property
def data_points(self):
"""Points of the associated data (either cell_centers or points)."""
......@@ -151,6 +177,11 @@ class Grid(GridBase):
def axes_names(self):
"""list of str: Axes names (xyz order)."""
@property
@abstractmethod
def axes_attributes(self):
"""list of dict: Axes attributes following the CF convention (xyz order)."""
@property
def data_axes_names(self):
"""list of str: Axes names of the data."""
......@@ -272,11 +303,6 @@ class StructuredGrid(Grid):
"""list of bool: False to indicate a bottom up axis (xyz order)."""
# esri grids and some netcdf are given bottom up (northing/lat inverted)
@property
@abstractmethod
def axes_attributes(self):
"""list of dict: Axes attributes following the CF convention (xyz order)."""
@property
@abstractmethod
def order(self):
......
"""Grid specifications to handle spatial data with FINAM."""
from pathlib import Path
import numpy as np
......@@ -17,6 +18,15 @@ from .grid_tools import (
)
def _check_location(grid, data_location):
# need to define this here to prevent circular imports
location = get_enum_value(data_location, Location)
if location not in grid.valid_locations:
msg = f"{grid.name}: data location {location} not valid."
raise ValueError(msg)
return location
class NoGrid(GridBase):
"""Indicator for data without a spatial grid."""
......@@ -89,7 +99,8 @@ class RectilinearGrid(StructuredGrid):
# all axes made increasing
self._axes_increase = check_axes_monotonicity(self.axes)
self._dim = len(self.dims)
self._data_location = get_enum_value(data_location, Location)
self._data_location = None
self.data_location = data_location
self._order = order
self._axes_reversed = bool(axes_reversed)
self._axes_attributes = axes_attributes or (self.dim * [{}])
......@@ -118,6 +129,7 @@ class RectilinearGrid(StructuredGrid):
cell_types=self.cell_types,
data_location=self.data_location,
order=self.order,
axes_attributes=self.axes_attributes,
axes_names=self.axes_names,
crs=self.crs,
)
......@@ -187,6 +199,11 @@ class RectilinearGrid(StructuredGrid):
"""Location of the associated data (either CELLS or POINTS)."""
return self._data_location
@data_location.setter
def data_location(self, data_location):
"""Set location of the associated data (either CELLS or POINTS)."""
self._data_location = _check_location(self, data_location)
class UniformGrid(RectilinearGrid):
"""Regular grid with uniform spacing in up to three coordinate directions.
......@@ -296,6 +313,28 @@ class UniformGrid(RectilinearGrid):
spacing = self.spacing + (0.0,) * (3 - self.dim)
imageToVTK(path, origin, spacing, **kw)
def to_rectilinear(self):
"""
Cast grid to a rectilinear grid.
Returns
-------
UniformGrid
Grid as rectilinear grid.
"""
grid = RectilinearGrid(
axes=self.axes,
data_location=self.data_location,
order=self.order,
axes_reversed=self.axes_reversed,
axes_attributes=self.axes_attributes,
axes_names=self.axes_names,
crs=self.crs,
)
# pylint: disable-next=protected-access
grid._axes_increase = self.axes_increase
return grid
class EsriGrid(UniformGrid):
"""
......@@ -324,6 +363,9 @@ class EsriGrid(UniformGrid):
The coordinate reference system, by default None
"""
valid_locations = (Location.CELLS,)
"""tuple: Valid locations for the grid."""
def __init__(
self,
ncols,
......@@ -378,6 +420,28 @@ class EsriGrid(UniformGrid):
header["axes_attributes"] = axes_attributes
return cls(**header)
def to_uniform(self):
"""
Cast grid to an uniform grid.
Returns
-------
UniformGrid
Grid as uniform grid.
"""
return UniformGrid(
dims=self.dims,
spacing=self.spacing,
origin=self.origin,
data_location=self.data_location,
order=self.order,
axes_reversed=self.axes_reversed,
axes_increase=self.axes_increase,
axes_attributes=self.axes_attributes,
axes_names=self.axes_names,
crs=self.crs,
)
class UnstructuredGrid(Grid):
"""
......@@ -396,6 +460,8 @@ class UnstructuredGrid(Grid):
order : str, optional
Data ordering.
Either Fortran-like ("F") or C-like ("C"), by default "C"
axes_attributes : list of dict or None, optional
Axes attributes following the CF convention (in xyz order), by default None
axes_names : list of str or None, optional
Axes names (in xyz order), by default ["x", "y", "z"]
crs : str or None, optional
......@@ -409,6 +475,7 @@ class UnstructuredGrid(Grid):
cell_types,
data_location=Location.CELLS,
order="C",
axes_attributes=None,
axes_names=None,
crs=None,
):
......@@ -416,8 +483,12 @@ class UnstructuredGrid(Grid):
self._points = np.asarray(np.atleast_2d(points), dtype=float)[:, :3]
self._cells = np.asarray(np.atleast_2d(cells), dtype=int)
self._cell_types = np.asarray(np.atleast_1d(cell_types), dtype=int)
self._data_location = get_enum_value(data_location, Location)
self._data_location = None
self.data_location = data_location
self._order = order
self._axes_attributes = axes_attributes or (self.dim * [{}])
if len(self.axes_attributes) != self.dim:
raise ValueError("UnstructuredGrid: wrong length of 'axes_attributes'")
self._axes_names = axes_names or ["x", "y", "z"][: self.dim]
if len(self.axes_names) != self.dim:
raise ValueError("UnstructuredGrid: wrong length of 'axes_names'")
......@@ -482,11 +553,21 @@ class UnstructuredGrid(Grid):
"""Location of the associated data (either CELLS or POINTS)."""
return self._data_location
@data_location.setter
def data_location(self, data_location):
"""Set location of the associated data (either CELLS or POINTS)."""
self._data_location = _check_location(self, data_location)
@property
def order(self):
"""str: Point, cell and data order (C-like or F-like for flatten)."""
return self._order
@property
def axes_attributes(self):
"""list of dict: Axes attributes following the CF convention (xyz order)."""
return self._axes_attributes
@property
def axes_names(self):
"""list of str: Axes names (xyz order)."""
......@@ -504,16 +585,22 @@ class UnstructuredPoints(UnstructuredGrid):
order : str, optional
Data ordering.
Either Fortran-like ("F") or C-like ("C"), by default "C"
axes_attributes : list of dict or None, optional
Axes attributes following the CF convention (in xyz order), by default None
axes_names : list of str or None, optional
Axes names (in xyz order), by default ["x", "y", "z"]
crs : str or None, optional
The coordinate reference system, by default None
"""
valid_locations = (Location.POINTS,)
"""tuple: Valid locations for the grid."""
def __init__(
self,
points,
order="C",
axes_attributes=None,
axes_names=None,
crs=None,
):
......@@ -525,6 +612,7 @@ class UnstructuredPoints(UnstructuredGrid):
cell_types=np.full(pnt_cnt, CellType.VERTEX, dtype=int),
data_location=Location.POINTS,
order=order,
axes_attributes=axes_attributes,
axes_names=axes_names,
crs=crs,
)
......
......@@ -196,6 +196,9 @@ class TestGridSpec(unittest.TestCase):
axes_names=["to_few"],
)
with self.assertRaises(ValueError):
UnstructuredPoints(points=[[0.0, 0.0]], axes_attributes=[{"too": "short"}])
self.assertEqual(grid.name, "UnstructuredGrid")
self.assertEqual(grid2.name, "UnstructuredPoints")
self.assertIsNone(grid.crs)
......@@ -219,6 +222,9 @@ class TestGridSpec(unittest.TestCase):
self.assertEqual(grid2.mesh_dim, 0)
assert_allclose(grid2.cell_node_counts, 1)
with self.assertRaises(ValueError):
grid2.data_location = Location.CELLS
def test_esri(self):
header = {
"ncols": 520,
......@@ -238,6 +244,20 @@ class TestGridSpec(unittest.TestCase):
self.assertAlmostEqual(grid.xllcorner, 4375000)
self.assertAlmostEqual(grid.yllcorner, 2700000)
# casting
grid2 = grid.to_uniform()
grid3 = grid.to_rectilinear()
grid4 = grid2.to_rectilinear()
self.assertTrue(grid == grid2)
self.assertTrue(grid == grid3)
self.assertTrue(grid == grid4)
self.assertTrue(grid2 == grid3)
self.assertTrue(grid2 == grid4)
self.assertTrue(grid3 == grid4)
with self.assertRaises(ValueError):
grid.data_location = Location.POINTS
def test_data_location(self):
grid1 = UniformGrid((1,), data_location=0)
grid2 = UniformGrid((2, 2), data_location="CELLS")
......@@ -258,6 +278,20 @@ class TestGridSpec(unittest.TestCase):
assert_allclose(grid.data_points, us_grid.data_points)
self.assertIsInstance(us_grid, UnstructuredGrid)
def test_copy(self):
grid = EsriGrid(3, 2)
us_grid = grid.to_unstructured()
cp_grid1 = us_grid.copy()
cp_grid2 = us_grid.copy(deep=True)
self.assertTrue(us_grid == cp_grid1)
self.assertTrue(us_grid == cp_grid2)
# shallow copy shares info
cp_grid1.points[0, 0] = 0.1
self.assertTrue(us_grid == cp_grid1)
self.assertFalse(us_grid == cp_grid2)
def test_equality(self):
grid1 = UniformGrid((2, 2), data_location=0)
grid2 = UnstructuredGrid(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment