diff --git a/CHANGELOG.md b/CHANGELOG.md index a6b95ae2178b2293444a95273aa3c821df3787d2..4456fd18c8c25f45a7f130ceed5e008909455625 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -26,6 +26,7 @@ * added class attribute `valid_locations` in order to check the set data location (esri-grid only supports cells, unstructured-points only support points) * added missing casting methods to convert esri to uniform and uniform to rectilinear (when you want to use point data on an esri-grid, you can cast it to uniform first) * added `axes_attributes` also to unstructured grids +* Grid method `compatible_with` now has a `check_location` argument to optionally check data location (!280) ### Documentation diff --git a/src/finam/data/grid_base.py b/src/finam/data/grid_base.py index 5e4f2abf80af8f35ba04135a2b9c3965f952b570..0ca332a4041bc4f877d970881c9853170e64d886 100644 --- a/src/finam/data/grid_base.py +++ b/src/finam/data/grid_base.py @@ -190,7 +190,7 @@ class Grid(GridBase): def __repr__(self): return f"{self.__class__.__name__} ({self.dim}D) {self.data_shape}" - def compatible_with(self, other): + def compatible_with(self, other, check_location=True): """ Check for compatibility with other Grid. @@ -198,6 +198,8 @@ class Grid(GridBase): ---------- other : instance of Grid Other grid to compatibility with. + check_location : bool, optional + Whether to check location for equality, by default True Returns ------- @@ -214,14 +216,18 @@ class Grid(GridBase): self.dim == other.dim and self.crs == other.crs and self.order == other.order - and self.data_location == other.data_location + and (not check_location or self.data_location == other.data_location) ): return False - if self.data_shape != other.data_shape: + if check_location and self.data_shape != other.data_shape: return False - return np.allclose(self.data_points, other.data_points) + return ( + np.allclose(self.points, other.points) + and np.all(self.cells == other.cells) + and np.all(self.cell_types == other.cell_types) + ) def __eq__(self, other): return self.compatible_with(other) @@ -393,7 +399,7 @@ class StructuredGrid(Grid): np.maximum(dims - 1, 1) if self.data_location == Location.CELLS else dims ) - def compatible_with(self, other): + def compatible_with(self, other, check_location=True): """ Check for compatibility with other Grid. @@ -401,6 +407,8 @@ class StructuredGrid(Grid): ---------- other : instance of Grid Other grid to compatibility with. + check_location : bool, optional + Whether to check location for equality, by default True Returns ------- @@ -416,11 +424,11 @@ class StructuredGrid(Grid): if not ( self.dim == other.dim and self.crs == other.crs - and self.data_location == other.data_location + and (not check_location or self.data_location == other.data_location) ): return False - if self.data_shape != ( + if check_location and self.data_shape != ( other.data_shape[::-1] if self.axes_reversed != other.axes_reversed else other.data_shape diff --git a/src/finam/data/grid_spec.py b/src/finam/data/grid_spec.py index 1e0389eb5e28ab1735b4d5a9220667b303ce7861..6ef840cdccb6069bd3d613a67bd83d532e53d42f 100644 --- a/src/finam/data/grid_spec.py +++ b/src/finam/data/grid_spec.py @@ -42,7 +42,7 @@ class NoGrid(GridBase): return f"{self.__class__.__name__} ({self.dim}D)" # pylint: disable-next=unused-argument - def compatible_with(self, other): + def compatible_with(self, other, check_location=True): """ Check for compatibility with other Grid. @@ -50,6 +50,8 @@ class NoGrid(GridBase): ---------- other : instance of Grid Other grid to compatibility with. + check_location : bool, optional + Whether to check location for equality, by default True Returns ------- diff --git a/tests/data/test_grid_spec.py b/tests/data/test_grid_spec.py index 814c5ed25e320acb1468f7713e7f0c6f6cb45278..fe97d0b513c06832d03ee3890ec7887f939da3bd 100644 --- a/tests/data/test_grid_spec.py +++ b/tests/data/test_grid_spec.py @@ -292,6 +292,16 @@ class TestGridSpec(unittest.TestCase): self.assertTrue(us_grid == cp_grid1) self.assertFalse(us_grid == cp_grid2) + def test_location_check(self): + grid_s1 = UniformGrid((2, 2), data_location="CELLS") + grid_s2 = UniformGrid((2, 2), data_location="POINTS") + grid_u1 = grid_s1.to_unstructured() + grid_u2 = grid_s2.to_unstructured() + self.assertTrue(grid_s1.compatible_with(grid_s2, check_location=False)) + self.assertFalse(grid_s1.compatible_with(grid_s2, check_location=True)) + self.assertTrue(grid_u1.compatible_with(grid_u2, check_location=False)) + self.assertFalse(grid_u1.compatible_with(grid_u2, check_location=True)) + def test_equality(self): grid1 = UniformGrid((2, 2), data_location=0) grid2 = UnstructuredGrid(