diff --git a/src/finam/adapters/regrid.py b/src/finam/adapters/regrid.py index 31fb15bd5be49c651970cad5ab14031f5148e8b4..472f39a179334fa257e46703832992f94f7261e8 100644 --- a/src/finam/adapters/regrid.py +++ b/src/finam/adapters/regrid.py @@ -132,10 +132,8 @@ class RegridNearest(ARegridding): def _get_data(self, time, target): in_data = self.pull_data(time, target) - res = ( - dtools.get_data(in_data) - .reshape(-1, order=self.input_grid.order)[self.ids] - .reshape(self.output_grid.data_shape, order=self.output_grid.order) + res = in_data.reshape(-1, order=self.input_grid.order)[self.ids].reshape( + self.output_grid.data_shape, order=self.output_grid.order ) return res @@ -219,7 +217,9 @@ class RegridLinear(ARegridding): in_data = self.pull_data(time, target) if isinstance(self.input_grid, StructuredGrid): - self.inter.values = dtools.get_magnitude(dtools.strip_time(in_data)) + self.inter.values = dtools.get_magnitude( + dtools.strip_time(in_data, self.input_grid) + ) res = self.inter(self.out_coords) if self.fill_with_nearest: res[self.out_ids] = self.inter.values.flatten( diff --git a/src/finam/adapters/time.py b/src/finam/adapters/time.py index 7b3a816dfe8e3aea23fed35ef848c22b0d2ebb9b..73101267c42d4532c41a375ed64aa16dc23232b8 100644 --- a/src/finam/adapters/time.py +++ b/src/finam/adapters/time.py @@ -240,7 +240,7 @@ class TimeCachingAdapter(Adapter, NoBranchAdapter, ABC): """ check_time(self.logger, time) - data = dtools.strip_data(self.pull_data(time, self)) + data = dtools.strip_time(self.pull_data(time, self), self._input_info.grid) self.data.append((time, data)) def _get_data(self, time, _target): @@ -360,7 +360,7 @@ class StackTime(TimeCachingAdapter): break arr = np.stack([d[1] for d in extract]) - return dtools.to_xarray(arr, self.name, self.info, time_entries=len(extract)) + return dtools.to_xarray(arr, self.info, time_entries=len(extract)) class LinearTime(TimeCachingAdapter): diff --git a/src/finam/adapters/time_integration.py b/src/finam/adapters/time_integration.py index a7e474ce923e129fd4df38b1d0dd16695985c026..5d905c102322f21d1421f3e612d66900bda5000a 100644 --- a/src/finam/adapters/time_integration.py +++ b/src/finam/adapters/time_integration.py @@ -26,7 +26,7 @@ class TimeIntegrationAdapter(TimeCachingAdapter, ABC): """ check_time(self.logger, time) - data = tools.strip_data(self.pull_data(time, self)) + data = tools.strip_time(self.pull_data(time, self), self._input_info.grid) self.data.append((time, data)) if self._prev_time is None: diff --git a/src/finam/data/tools.py b/src/finam/data/tools.py index d78192c845c8824d2c03f705f11fbdf6188f3175..bfe1ed4d40861555d8b4efbbbb064cd283bc10e9 100644 --- a/src/finam/data/tools.py +++ b/src/finam/data/tools.py @@ -108,6 +108,11 @@ def to_xarray(data, info, time_entries=1, force_copy=False): def _check_input_shape(data, info, time_entries): # check correct data size if isinstance(info.grid, Grid): + time_entries = ( + data.shape[0] + if len(data.shape) == len(info.grid.data_shape) + 1 + else time_entries + ) data_size = data.size / time_entries if data_size != info.grid.data_size: raise FinamDataError( @@ -135,21 +140,25 @@ def _check_input_shape(data, info, time_entries): [time_entries] + list(info.grid.data_shape), order=info.grid.order ) elif isinstance(info.grid, grid_spec.NoGrid): - if len(data.shape) != info.grid.dim + 1: - if len(data.shape) == info.grid.dim: - data = np.expand_dims(data, 0) - else: - raise FinamDataError( - f"to_xarray: number of dimensions in data doesn't match expected number. " - f"Got {len(data.shape)}, expected {info.grid.dim}" - ) - else: - if data.shape[0] != time_entries: - raise FinamDataError( - f"to_xarray: number of time entries in data doesn't match expected number. " - f"Got {data.shape[0]}, expected {time_entries}" - ) + data = _check_input_shape_no_grid(data, info, time_entries) + return data + +def _check_input_shape_no_grid(data, info, time_entries): + if len(data.shape) != info.grid.dim + 1: + if len(data.shape) == info.grid.dim: + data = np.expand_dims(data, 0) + else: + raise FinamDataError( + f"to_xarray: number of dimensions in data doesn't match expected number. " + f"Got {len(data.shape)}, expected {info.grid.dim}" + ) + else: + if data.shape[0] != time_entries: + raise FinamDataError( + f"to_xarray: number of time entries in data doesn't match expected number. " + f"Got {data.shape[0]}, expected {time_entries}" + ) return data @@ -168,10 +177,11 @@ def has_time_axis(xdata, grid): bool Whether the data has a time axis. """ - if xdata.ndim == grid.dim: + grid_dim = len(grid.data_shape) if isinstance(grid, Grid) else grid.dim + if xdata.ndim == grid_dim: return False - if xdata.ndim == grid.dim + 1: + if xdata.ndim == grid_dim + 1: return True raise FinamDataError("Data dimension must be grid dimension or grid dimension + 1.") diff --git a/src/finam/modules/debug.py b/src/finam/modules/debug.py index cfd94a89a85419b87f789323e6632eedfac34640..b1207727b1fce35bd6c33f01b15730575f17b3a3 100644 --- a/src/finam/modules/debug.py +++ b/src/finam/modules/debug.py @@ -116,7 +116,11 @@ class DebugConsumer(TimeComponent): self.logger.debug("Pulled input data for %s", name) if self._log_data is not None: - pdata = data[0, ...] if self._strip_data else data + pdata = ( + tools.strip_time(data, self.inputs[name].info.grid) + if self._strip_data + else data + ) self.logger.log( self._log_data, 'Received "%s" - %s: %s', @@ -140,7 +144,11 @@ class DebugConsumer(TimeComponent): } for name, data in self._data.items(): if self._log_data is not None: - pdata = data[0, ...] if self._strip_data else data + pdata = ( + tools.strip_time(data, self.inputs[name].info.grid) + if self._strip_data + else data + ) self.logger.log( self._log_data, 'Received "%s" - %s: %s', @@ -253,7 +261,9 @@ class DebugPushConsumer(Component): data = caller.pull_data(time) self._data[caller.name] = data if self._log_data is not None: - pdata = tools.strip_data(data) if self._strip_data else data + pdata = ( + tools.strip_time(data, caller.info.grid) if self._strip_data else data + ) self.logger.log( self._log_data, 'Received "%s" - %s: %s', diff --git a/src/finam/modules/mergers.py b/src/finam/modules/mergers.py index 4220a9593e76c69bf5aa1e92d7fcd3fce62d7e6c..eee2614a05760c9bb19dc9c59c6f14ec0e9eb44d 100644 --- a/src/finam/modules/mergers.py +++ b/src/finam/modules/mergers.py @@ -140,8 +140,8 @@ class WeightedSum(Component): result = None for name in self._input_names: - value = strip_time(self._in_data[name]) - weight = strip_time(self._in_data[name + "_weight"]) + value = strip_time(self._in_data[name], self._grid) + weight = strip_time(self._in_data[name + "_weight"], self._grid) if result is None: result = value * weight diff --git a/src/finam/modules/noise.py b/src/finam/modules/noise.py index e95d1fca762833c9250e59250dbba062cea5adb1..2d4b82411f611113e22b9c14d3f82f093f71d3eb 100644 --- a/src/finam/modules/noise.py +++ b/src/finam/modules/noise.py @@ -307,7 +307,6 @@ def _generate_noise( data /= max_amp data = data * (high - low) / 2 + (high + low) / 2 - return data diff --git a/src/finam/modules/writers.py b/src/finam/modules/writers.py index e9547d4bfa18a91af3ad2359119135ee80f13053..1dcbe8757fb98b1899ca53265f1e9950d995e452 100644 --- a/src/finam/modules/writers.py +++ b/src/finam/modules/writers.py @@ -106,7 +106,7 @@ class CsvWriter(TimeComponent): if self.status == ComponentStatus.CONNECTED: values = [ - dtools.get_magnitude(dtools.strip_time(data)) + dtools.get_magnitude(dtools.strip_time(data, NoGrid())) for _, data in self.connector.in_data.items() ] @@ -128,7 +128,7 @@ class CsvWriter(TimeComponent): values = [ dtools.get_magnitude( - dtools.strip_time(self.inputs[inp].pull_data(self.time)) + dtools.strip_time(self.inputs[inp].pull_data(self.time), NoGrid()) ) for inp in self._input_names ] diff --git a/tests/adapters/test_stats.py b/tests/adapters/test_stats.py index c90d44ba6d5d7102fa7a089aae7c4510520bfbc0..939057928cbb42db753b653e3f3d8c110684db90 100644 --- a/tests/adapters/test_stats.py +++ b/tests/adapters/test_stats.py @@ -36,7 +36,7 @@ class TestHistogram(unittest.TestCase): data = sink.data["Input"] self.assertEqual(data.shape, (1, 20)) - self.assertEqual(data.data.sum(), 11 * 14) + self.assertEqual(data.sum(), 11 * 14) self.assertEqual(fm.data.get_units(data), fm.UNITS.dimensionless) composition.run(end_time=datetime(2000, 1, 10)) diff --git a/tests/adapters/test_time.py b/tests/adapters/test_time.py index 3f212fd16c39ad19ef8a989829e710907abac570..82bab928fbdd4e2c4c14e764d88c657938335aad 100644 --- a/tests/adapters/test_time.py +++ b/tests/adapters/test_time.py @@ -121,28 +121,28 @@ class TestDelayFixed(unittest.TestCase): def test_fixed_delay(self): data = self.adapter.get_data(datetime(2000, 1, 1), None) - self.assertEqual(tools.get_data(data), 0) + self.assertEqual(data, 0) self.source.update() self.source.update() data = self.adapter.get_data(datetime(2000, 1, 5), None) - self.assertEqual(tools.get_data(data), 0) + self.assertEqual(data, 0) for _ in range(20): self.source.update() data = self.adapter.get_data(datetime(2000, 1, 10), None) - self.assertEqual(tools.get_data(data), 0) + self.assertEqual(data, 0) data = self.adapter.get_data(datetime(2000, 1, 11), None) - self.assertEqual(tools.get_data(data), 0) + self.assertEqual(data, 0) data = self.adapter.get_data(datetime(2000, 1, 12), None) - self.assertEqual(tools.get_data(data), 1) + self.assertEqual(data, 1) data = self.adapter.get_data(datetime(2000, 1, 20), None) - self.assertEqual(tools.get_data(data), 9) + self.assertEqual(data, 9) class TestNextValue(unittest.TestCase): diff --git a/tests/modules/test_debug.py b/tests/modules/test_debug.py index aecef42edc125d468d6d03fa4308be6da194a0d1..09bd12cc7ef828e7d7998914677aeb8f26c3c4eb 100644 --- a/tests/modules/test_debug.py +++ b/tests/modules/test_debug.py @@ -21,7 +21,7 @@ class TestScheduleLogger(unittest.TestCase): outputs={ "Out": fm.Info(time=None, grid=fm.NoGrid()), }, - callback=lambda inp, _t: {"Out": fm.data.strip_data(inp["In"])}, + callback=lambda inp, _t: {"Out": inp["In"][0, ...]}, start=start, step=timedelta(days=5), ) @@ -32,7 +32,7 @@ class TestScheduleLogger(unittest.TestCase): outputs={ "Out": fm.Info(time=None, grid=fm.NoGrid()), }, - callback=lambda inp, _t: {"Out": fm.data.strip_data(inp["In"])}, + callback=lambda inp, _t: {"Out": inp["In"][0, ...]}, start=start, step=timedelta(days=8), ) @@ -87,8 +87,8 @@ class TestPushDebugConsumer(unittest.TestCase): module1.outputs["Out"] >> consumer.inputs["In"] composition.connect(start) - self.assertEqual(fm.data.strip_data(consumer.data["In"]), 1) + self.assertEqual(consumer.data["In"][0, ...], 1) composition.run(start_time=start, end_time=datetime(2000, 1, 10)) - self.assertEqual(fm.data.strip_data(consumer.data["In"]), 11) + self.assertEqual(consumer.data["In"][0, ...], 11) diff --git a/tests/modules/test_noise.py b/tests/modules/test_noise.py index 2932c2b08e93fa95d71e7e842f1ba3cd9557400a..bc1d7f1a5d7c90a5face9822f4e85ff8edd2917a 100644 --- a/tests/modules/test_noise.py +++ b/tests/modules/test_noise.py @@ -36,7 +36,7 @@ class TestNoise(unittest.TestCase): composition.connect() composition.run(end_time=datetime(2000, 1, 2)) - data = fm.data.strip_data(sink.data["Input"]) + data = sink.data["Input"][0, ...] self.assertEqual(data.shape, ()) def test_noise_uniform_1d(self): @@ -68,7 +68,7 @@ class TestNoise(unittest.TestCase): composition.connect(time) composition.run(end_time=datetime(2000, 1, 2)) - data = fm.data.strip_data(sink.data["Input"]) + data = sink.data["Input"][0, ...] self.assertEqual(data.shape, (19,)) def test_noise_uniform_2d(self): @@ -100,7 +100,7 @@ class TestNoise(unittest.TestCase): composition.connect(time) composition.run(end_time=datetime(2000, 1, 2)) - data = fm.data.strip_data(sink.data["Input"]) + data = sink.data["Input"][0, ...] self.assertEqual(data.shape, (19, 14)) def test_noise_uniform_3d(self): @@ -132,7 +132,7 @@ class TestNoise(unittest.TestCase): composition.connect(time) composition.run(end_time=datetime(2000, 1, 2)) - data = fm.data.strip_data(sink.data["Input"]) + data = sink.data["Input"][0, ...] self.assertEqual(data.shape, (19, 14, 9)) def test_noise_points_1d(self): @@ -164,7 +164,7 @@ class TestNoise(unittest.TestCase): composition.connect(time) composition.run(end_time=datetime(2000, 1, 2)) - data = fm.data.strip_data(sink.data["Input"]) + data = sink.data["Input"][0, ...] self.assertEqual(data.shape, (100,)) def test_noise_points_2d(self): @@ -196,7 +196,7 @@ class TestNoise(unittest.TestCase): composition.connect(time) composition.run(end_time=datetime(2000, 1, 2)) - data = fm.data.strip_data(sink.data["Input"]) + data = sink.data["Input"][0, ...] self.assertEqual(data.shape, (100,)) def test_noise_points_3d(self): @@ -228,7 +228,7 @@ class TestNoise(unittest.TestCase): composition.connect(time) composition.run(end_time=datetime(2000, 1, 2)) - data = fm.data.strip_data(sink.data["Input"]) + data = sink.data["Input"][0, ...] self.assertEqual(data.shape, (100,)) def test_noise_fail_nogrid(self): @@ -298,12 +298,12 @@ class TestStaticNoise(unittest.TestCase): source.outputs["Noise"] >> sink.inputs["Input"] composition.connect(None) - data_1 = fm.data.strip_data(sink.data["Input"]) + data_1 = sink.data["Input"][0, ...] self.assertEqual(data_1.shape, ()) composition.run(end_time=None) - data_2 = fm.data.strip_data(sink.data["Input"]) + data_2 = sink.data["Input"][0, ...] self.assertEqual(data_1, data_2) self.assertEqual(data_2.shape, ()) diff --git a/tests/modules/test_parametric.py b/tests/modules/test_parametric.py index 858cecde023a3f35decaeb6ca5af1c7a643103e1..bc17d71a13387ee89f08c56abe7d9bc02b4f7638 100644 --- a/tests/modules/test_parametric.py +++ b/tests/modules/test_parametric.py @@ -283,12 +283,12 @@ class TestStaticParametricGrid(unittest.TestCase): source.outputs["Grid"] >> sink.inputs["Input"] composition.connect(None) - data_1 = fm.data.strip_data(sink.data["Input"]) + data_1 = sink.data["Input"][0, ...] self.assertEqual(data_1.shape, (19, 14)) composition.run(end_time=None) - data_2 = fm.data.strip_data(sink.data["Input"]) + data_2 = sink.data["Input"][0, ...] assert_allclose(data_1, data_2) self.assertEqual(data_2.shape, (19, 14))