diff --git a/CHANGELOG.md b/CHANGELOG.md
index 3fc3ff4f95d47bcc7db6f9327949c005a36c8567..a6b95ae2178b2293444a95273aa3c821df3787d2 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -20,6 +20,12 @@
   * `get_cells_matrix`: convert `cells_connectivity` or `cells_definition` back to the default cells matrix used in the Grid class (can be used to convert VTK-grids into FINAM-grids)
   * `INV_VTK_TYPE_MAP`: inverse mapping to `VTK_TYPE_MAP` - FINAM cell type to VTK cell type
   * `VTK_CELL_DIM`: parametric dimension for each VTK cell type
+* Grid class now reusable when having different data locations and better grid type casting (!278)
+  * added `copy` method to grids with optional argument `deep` (`False` by default) to create a copy of a grid
+  * added setter for `data_location` in order to set a new data location (e.g. after copying a grid)
+  * added class attribute `valid_locations` in order to check the set data location (esri-grid only supports cells, unstructured-points only support points)
+  * added missing casting methods to convert esri to uniform and uniform to rectilinear (when you want to use point data on an esri-grid, you can cast it to uniform first)
+  * added `axes_attributes` also to unstructured grids
 
 ### Documentation
 
diff --git a/src/finam/data/grid_base.py b/src/finam/data/grid_base.py
index 8d51159123f618c2ddae6ad34b7c6808e68c73e1..5e4f2abf80af8f35ba04135a2b9c3965f952b570 100644
--- a/src/finam/data/grid_base.py
+++ b/src/finam/data/grid_base.py
@@ -1,4 +1,6 @@
 """Grid abstract base classes for FINAM."""
+
+import copy as cp
 from abc import ABC, abstractmethod
 from pathlib import Path
 
@@ -34,6 +36,22 @@ class GridBase(ABC):
     def dim(self):
         """int: Dimension of the grid or data."""
 
+    def copy(self, deep=False):
+        """
+        Copy of this grid.
+
+        Parameters
+        ----------
+        deep : bool, optional
+            If false, only a shallow copy is returned to save memory, by default False
+
+        Returns
+        -------
+        Grid
+            The grid copy.
+        """
+        return cp.deepcopy(self) if deep else cp.copy(self)
+
     def to_canonical(self, data):
         """Convert grid specific data to canonical form."""
         return data
@@ -51,6 +69,9 @@ class GridBase(ABC):
 class Grid(GridBase):
     """Abstract grid specification."""
 
+    valid_locations = (Location.CELLS, Location.POINTS)
+    """tuple: Valid locations for the grid."""
+
     @property
     @abstractmethod
     def crs(self):
@@ -124,6 +145,11 @@ class Grid(GridBase):
     def data_location(self):
         """Location of the associated data (either CELLS or POINTS)."""
 
+    @data_location.setter
+    @abstractmethod
+    def data_location(self, data_location):
+        """Set location of the associated data (either CELLS or POINTS)."""
+
     @property
     def data_points(self):
         """Points of the associated data (either cell_centers or points)."""
@@ -151,6 +177,11 @@ class Grid(GridBase):
     def axes_names(self):
         """list of str: Axes names (xyz order)."""
 
+    @property
+    @abstractmethod
+    def axes_attributes(self):
+        """list of dict: Axes attributes following the CF convention (xyz order)."""
+
     @property
     def data_axes_names(self):
         """list of str: Axes names of the data."""
@@ -272,11 +303,6 @@ class StructuredGrid(Grid):
         """list of bool: False to indicate a bottom up axis (xyz order)."""
         # esri grids and some netcdf are given bottom up (northing/lat inverted)
 
-    @property
-    @abstractmethod
-    def axes_attributes(self):
-        """list of dict: Axes attributes following the CF convention (xyz order)."""
-
     @property
     @abstractmethod
     def order(self):
diff --git a/src/finam/data/grid_spec.py b/src/finam/data/grid_spec.py
index cf2153a0fb1bfeb6aee69e9b88ca9ada514ba0c7..1e0389eb5e28ab1735b4d5a9220667b303ce7861 100644
--- a/src/finam/data/grid_spec.py
+++ b/src/finam/data/grid_spec.py
@@ -1,4 +1,5 @@
 """Grid specifications to handle spatial data with FINAM."""
+
 from pathlib import Path
 
 import numpy as np
@@ -17,6 +18,15 @@ from .grid_tools import (
 )
 
 
+def _check_location(grid, data_location):
+    # need to define this here to prevent circular imports
+    location = get_enum_value(data_location, Location)
+    if location not in grid.valid_locations:
+        msg = f"{grid.name}: data location {location} not valid."
+        raise ValueError(msg)
+    return location
+
+
 class NoGrid(GridBase):
     """Indicator for data without a spatial grid."""
 
@@ -89,7 +99,8 @@ class RectilinearGrid(StructuredGrid):
         # all axes made increasing
         self._axes_increase = check_axes_monotonicity(self.axes)
         self._dim = len(self.dims)
-        self._data_location = get_enum_value(data_location, Location)
+        self._data_location = None
+        self.data_location = data_location
         self._order = order
         self._axes_reversed = bool(axes_reversed)
         self._axes_attributes = axes_attributes or (self.dim * [{}])
@@ -118,6 +129,7 @@ class RectilinearGrid(StructuredGrid):
             cell_types=self.cell_types,
             data_location=self.data_location,
             order=self.order,
+            axes_attributes=self.axes_attributes,
             axes_names=self.axes_names,
             crs=self.crs,
         )
@@ -187,6 +199,11 @@ class RectilinearGrid(StructuredGrid):
         """Location of the associated data (either CELLS or POINTS)."""
         return self._data_location
 
+    @data_location.setter
+    def data_location(self, data_location):
+        """Set location of the associated data (either CELLS or POINTS)."""
+        self._data_location = _check_location(self, data_location)
+
 
 class UniformGrid(RectilinearGrid):
     """Regular grid with uniform spacing in up to three coordinate directions.
@@ -296,6 +313,28 @@ class UniformGrid(RectilinearGrid):
             spacing = self.spacing + (0.0,) * (3 - self.dim)
             imageToVTK(path, origin, spacing, **kw)
 
+    def to_rectilinear(self):
+        """
+        Cast grid to a rectilinear grid.
+
+        Returns
+        -------
+        UniformGrid
+            Grid as rectilinear grid.
+        """
+        grid = RectilinearGrid(
+            axes=self.axes,
+            data_location=self.data_location,
+            order=self.order,
+            axes_reversed=self.axes_reversed,
+            axes_attributes=self.axes_attributes,
+            axes_names=self.axes_names,
+            crs=self.crs,
+        )
+        # pylint: disable-next=protected-access
+        grid._axes_increase = self.axes_increase
+        return grid
+
 
 class EsriGrid(UniformGrid):
     """
@@ -324,6 +363,9 @@ class EsriGrid(UniformGrid):
         The coordinate reference system, by default None
     """
 
+    valid_locations = (Location.CELLS,)
+    """tuple: Valid locations for the grid."""
+
     def __init__(
         self,
         ncols,
@@ -378,6 +420,28 @@ class EsriGrid(UniformGrid):
         header["axes_attributes"] = axes_attributes
         return cls(**header)
 
+    def to_uniform(self):
+        """
+        Cast grid to an uniform grid.
+
+        Returns
+        -------
+        UniformGrid
+            Grid as uniform grid.
+        """
+        return UniformGrid(
+            dims=self.dims,
+            spacing=self.spacing,
+            origin=self.origin,
+            data_location=self.data_location,
+            order=self.order,
+            axes_reversed=self.axes_reversed,
+            axes_increase=self.axes_increase,
+            axes_attributes=self.axes_attributes,
+            axes_names=self.axes_names,
+            crs=self.crs,
+        )
+
 
 class UnstructuredGrid(Grid):
     """
@@ -396,6 +460,8 @@ class UnstructuredGrid(Grid):
     order : str, optional
         Data ordering.
         Either Fortran-like ("F") or C-like ("C"), by default "C"
+    axes_attributes : list of dict or None, optional
+        Axes attributes following the CF convention (in xyz order), by default None
     axes_names : list of str or None, optional
         Axes names (in xyz order), by default ["x", "y", "z"]
     crs : str or None, optional
@@ -409,6 +475,7 @@ class UnstructuredGrid(Grid):
         cell_types,
         data_location=Location.CELLS,
         order="C",
+        axes_attributes=None,
         axes_names=None,
         crs=None,
     ):
@@ -416,8 +483,12 @@ class UnstructuredGrid(Grid):
         self._points = np.asarray(np.atleast_2d(points), dtype=float)[:, :3]
         self._cells = np.asarray(np.atleast_2d(cells), dtype=int)
         self._cell_types = np.asarray(np.atleast_1d(cell_types), dtype=int)
-        self._data_location = get_enum_value(data_location, Location)
+        self._data_location = None
+        self.data_location = data_location
         self._order = order
+        self._axes_attributes = axes_attributes or (self.dim * [{}])
+        if len(self.axes_attributes) != self.dim:
+            raise ValueError("UnstructuredGrid: wrong length of 'axes_attributes'")
         self._axes_names = axes_names or ["x", "y", "z"][: self.dim]
         if len(self.axes_names) != self.dim:
             raise ValueError("UnstructuredGrid: wrong length of 'axes_names'")
@@ -482,11 +553,21 @@ class UnstructuredGrid(Grid):
         """Location of the associated data (either CELLS or POINTS)."""
         return self._data_location
 
+    @data_location.setter
+    def data_location(self, data_location):
+        """Set location of the associated data (either CELLS or POINTS)."""
+        self._data_location = _check_location(self, data_location)
+
     @property
     def order(self):
         """str: Point, cell and data order (C-like or F-like for flatten)."""
         return self._order
 
+    @property
+    def axes_attributes(self):
+        """list of dict: Axes attributes following the CF convention (xyz order)."""
+        return self._axes_attributes
+
     @property
     def axes_names(self):
         """list of str: Axes names (xyz order)."""
@@ -504,16 +585,22 @@ class UnstructuredPoints(UnstructuredGrid):
     order : str, optional
         Data ordering.
         Either Fortran-like ("F") or C-like ("C"), by default "C"
+    axes_attributes : list of dict or None, optional
+        Axes attributes following the CF convention (in xyz order), by default None
     axes_names : list of str or None, optional
         Axes names (in xyz order), by default ["x", "y", "z"]
     crs : str or None, optional
         The coordinate reference system, by default None
     """
 
+    valid_locations = (Location.POINTS,)
+    """tuple: Valid locations for the grid."""
+
     def __init__(
         self,
         points,
         order="C",
+        axes_attributes=None,
         axes_names=None,
         crs=None,
     ):
@@ -525,6 +612,7 @@ class UnstructuredPoints(UnstructuredGrid):
             cell_types=np.full(pnt_cnt, CellType.VERTEX, dtype=int),
             data_location=Location.POINTS,
             order=order,
+            axes_attributes=axes_attributes,
             axes_names=axes_names,
             crs=crs,
         )
diff --git a/tests/data/test_grid_spec.py b/tests/data/test_grid_spec.py
index aed4b11bc5d87c0d53012524641cb5f834a14afe..814c5ed25e320acb1468f7713e7f0c6f6cb45278 100644
--- a/tests/data/test_grid_spec.py
+++ b/tests/data/test_grid_spec.py
@@ -196,6 +196,9 @@ class TestGridSpec(unittest.TestCase):
                 axes_names=["to_few"],
             )
 
+        with self.assertRaises(ValueError):
+            UnstructuredPoints(points=[[0.0, 0.0]], axes_attributes=[{"too": "short"}])
+
         self.assertEqual(grid.name, "UnstructuredGrid")
         self.assertEqual(grid2.name, "UnstructuredPoints")
         self.assertIsNone(grid.crs)
@@ -219,6 +222,9 @@ class TestGridSpec(unittest.TestCase):
         self.assertEqual(grid2.mesh_dim, 0)
         assert_allclose(grid2.cell_node_counts, 1)
 
+        with self.assertRaises(ValueError):
+            grid2.data_location = Location.CELLS
+
     def test_esri(self):
         header = {
             "ncols": 520,
@@ -238,6 +244,20 @@ class TestGridSpec(unittest.TestCase):
         self.assertAlmostEqual(grid.xllcorner, 4375000)
         self.assertAlmostEqual(grid.yllcorner, 2700000)
 
+        # casting
+        grid2 = grid.to_uniform()
+        grid3 = grid.to_rectilinear()
+        grid4 = grid2.to_rectilinear()
+        self.assertTrue(grid == grid2)
+        self.assertTrue(grid == grid3)
+        self.assertTrue(grid == grid4)
+        self.assertTrue(grid2 == grid3)
+        self.assertTrue(grid2 == grid4)
+        self.assertTrue(grid3 == grid4)
+
+        with self.assertRaises(ValueError):
+            grid.data_location = Location.POINTS
+
     def test_data_location(self):
         grid1 = UniformGrid((1,), data_location=0)
         grid2 = UniformGrid((2, 2), data_location="CELLS")
@@ -258,6 +278,20 @@ class TestGridSpec(unittest.TestCase):
         assert_allclose(grid.data_points, us_grid.data_points)
         self.assertIsInstance(us_grid, UnstructuredGrid)
 
+    def test_copy(self):
+        grid = EsriGrid(3, 2)
+        us_grid = grid.to_unstructured()
+        cp_grid1 = us_grid.copy()
+        cp_grid2 = us_grid.copy(deep=True)
+
+        self.assertTrue(us_grid == cp_grid1)
+        self.assertTrue(us_grid == cp_grid2)
+
+        # shallow copy shares info
+        cp_grid1.points[0, 0] = 0.1
+        self.assertTrue(us_grid == cp_grid1)
+        self.assertFalse(us_grid == cp_grid2)
+
     def test_equality(self):
         grid1 = UniformGrid((2, 2), data_location=0)
         grid2 = UnstructuredGrid(