From 32ddf520a9331f88c2ac466566f125f918daa9c7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20M=C3=BCller?= <mueller.seb@posteo.de> Date: Sat, 11 Jan 2025 11:46:23 +0100 Subject: [PATCH] NoGrid: allow explict setting of data_shape --- src/finam/data/grid_base.py | 6 +++--- src/finam/data/grid_spec.py | 40 +++++++++++++++++++++++++++++------- src/finam/data/tools/core.py | 22 ++++++++++++++++---- 3 files changed, 54 insertions(+), 14 deletions(-) diff --git a/src/finam/data/grid_base.py b/src/finam/data/grid_base.py index f6fc3132..e74f697a 100644 --- a/src/finam/data/grid_base.py +++ b/src/finam/data/grid_base.py @@ -70,6 +70,9 @@ class GridBase(ABC): """Transformation between compatible grids.""" return None + def __repr__(self): + return f"{self.name} ({self.dim}D) {self.data_shape}" + class Grid(GridBase): """Abstract grid specification.""" @@ -187,9 +190,6 @@ class Grid(GridBase): """list of str: Axes names of the data.""" return ["id"] - def __repr__(self): - return f"{self.__class__.__name__} ({self.dim}D) {self.data_shape}" - def compatible_with(self, other, check_location=True): """ Check for compatibility with other Grid. diff --git a/src/finam/data/grid_spec.py b/src/finam/data/grid_spec.py index b5364570..9a4f43e3 100644 --- a/src/finam/data/grid_spec.py +++ b/src/finam/data/grid_spec.py @@ -28,10 +28,39 @@ def _check_location(grid, data_location): class NoGrid(GridBase): - """Indicator for data without a spatial grid.""" + """ + Indicator for data without a spatial grid. + + Either dim or data_shape needed. + + Parameters + ---------- + dim : int or None, optional + Data dimensionality. Should match the length of data_shape. + data_shape : tuple of int or None, optional + Data shape. Can contain -1 to indicate flexible axis. + + Raises + ------ + ValueError + If none of dim or data_shape are given. + ValueError + If dim does not match the length of data_shape. + """ - def __init__(self, dim=0): + def __init__(self, dim=None, data_shape=None): + if dim is None and data_shape is None: + msg = "NoGrid: either dim or data_shape needed." + raise ValueError(msg) + if data_shape is None: + data_shape = (-1,) * dim + if dim is None: + dim = len(data_shape) + if dim != len(data_shape): + msg = "NoGrid: dim needs to match the length of data_shape." + raise ValueError(msg) self._dim = dim + self._data_shape = data_shape @property def dim(self): @@ -41,10 +70,7 @@ class NoGrid(GridBase): @property def data_shape(self): """tuple: Shape of the associated data.""" - return tuple() - - def __repr__(self): - return f"{self.__class__.__name__} ({self.dim}D)" + return self._data_shape # pylint: disable-next=unused-argument def compatible_with(self, other, check_location=True): @@ -63,7 +89,7 @@ class NoGrid(GridBase): bool compatibility """ - return isinstance(other, NoGrid) and self.dim == other.dim + return isinstance(other, NoGrid) and self.data_shape == other.data_shape def __eq__(self, other): return self.compatible_with(other) diff --git a/src/finam/data/tools/core.py b/src/finam/data/tools/core.py index df3958a0..525d5663 100644 --- a/src/finam/data/tools/core.py +++ b/src/finam/data/tools/core.py @@ -154,22 +154,36 @@ def _check_input_shape(data, info, time_entries): def _check_input_shape_no_grid(data, info, time_entries): if len(data.shape) != info.grid.dim + 1: - if len(data.shape) == info.grid.dim: + if _no_grid_shape_valid(data.shape, info.grid): data = np.expand_dims(data, 0) else: raise FinamDataError( - f"quantify: number of dimensions in data doesn't match expected number. " - f"Got {len(data.shape)}, expected {info.grid.dim}" + f"Data shape not valid. " + f"Got {data.shape}, expected {info.grid.data_shape}" ) else: + if not _no_grid_shape_valid(data.shape[1:], info.grid): + raise FinamDataError( + f"Data shape not valid. " + f"Got {data.shape[1:]}, expected {info.grid.data_shape}" + ) if data.shape[0] != time_entries: raise FinamDataError( - f"quantify: number of time entries in data doesn't match expected number. " + f"Number of time entries in data doesn't match expected number. " f"Got {data.shape[0]}, expected {time_entries}" ) return data +def _no_grid_shape_valid(data_shape, grid): + if len(data_shape) != grid.dim: + return False + dshp = np.array(data_shape) + gshp = np.array(grid.data_shape) + check = gshp != -1 + return np.all(dshp[check] == gshp[check]) + + def has_time_axis(xdata, grid): """ Check if the data array has a time axis. -- GitLab