From 44e7a32b22a919871829a60c8db5ca91fbe14b11 Mon Sep 17 00:00:00 2001 From: Martin Lange <martin.lange@ufz.de> Date: Tue, 6 Dec 2022 16:24:34 +0100 Subject: [PATCH] fix components and tests --- src/finam/adapters/regrid.py | 10 +++--- src/finam/adapters/time.py | 4 +-- src/finam/adapters/time_integration.py | 2 +- src/finam/data/tools.py | 42 ++++++++++++++++---------- src/finam/modules/debug.py | 16 ++++++++-- src/finam/modules/mergers.py | 4 +-- src/finam/modules/noise.py | 1 - src/finam/modules/writers.py | 4 +-- tests/adapters/test_stats.py | 2 +- tests/adapters/test_time.py | 12 ++++---- tests/modules/test_debug.py | 8 ++--- tests/modules/test_noise.py | 18 +++++------ tests/modules/test_parametric.py | 4 +-- 13 files changed, 73 insertions(+), 54 deletions(-) diff --git a/src/finam/adapters/regrid.py b/src/finam/adapters/regrid.py index 31fb15bd..472f39a1 100644 --- a/src/finam/adapters/regrid.py +++ b/src/finam/adapters/regrid.py @@ -132,10 +132,8 @@ class RegridNearest(ARegridding): def _get_data(self, time, target): in_data = self.pull_data(time, target) - res = ( - dtools.get_data(in_data) - .reshape(-1, order=self.input_grid.order)[self.ids] - .reshape(self.output_grid.data_shape, order=self.output_grid.order) + res = in_data.reshape(-1, order=self.input_grid.order)[self.ids].reshape( + self.output_grid.data_shape, order=self.output_grid.order ) return res @@ -219,7 +217,9 @@ class RegridLinear(ARegridding): in_data = self.pull_data(time, target) if isinstance(self.input_grid, StructuredGrid): - self.inter.values = dtools.get_magnitude(dtools.strip_time(in_data)) + self.inter.values = dtools.get_magnitude( + dtools.strip_time(in_data, self.input_grid) + ) res = self.inter(self.out_coords) if self.fill_with_nearest: res[self.out_ids] = self.inter.values.flatten( diff --git a/src/finam/adapters/time.py b/src/finam/adapters/time.py index 7b3a816d..73101267 100644 --- a/src/finam/adapters/time.py +++ b/src/finam/adapters/time.py @@ -240,7 +240,7 @@ class TimeCachingAdapter(Adapter, NoBranchAdapter, ABC): """ check_time(self.logger, time) - data = dtools.strip_data(self.pull_data(time, self)) + data = dtools.strip_time(self.pull_data(time, self), self._input_info.grid) self.data.append((time, data)) def _get_data(self, time, _target): @@ -360,7 +360,7 @@ class StackTime(TimeCachingAdapter): break arr = np.stack([d[1] for d in extract]) - return dtools.to_xarray(arr, self.name, self.info, time_entries=len(extract)) + return dtools.to_xarray(arr, self.info, time_entries=len(extract)) class LinearTime(TimeCachingAdapter): diff --git a/src/finam/adapters/time_integration.py b/src/finam/adapters/time_integration.py index a7e474ce..5d905c10 100644 --- a/src/finam/adapters/time_integration.py +++ b/src/finam/adapters/time_integration.py @@ -26,7 +26,7 @@ class TimeIntegrationAdapter(TimeCachingAdapter, ABC): """ check_time(self.logger, time) - data = tools.strip_data(self.pull_data(time, self)) + data = tools.strip_time(self.pull_data(time, self), self._input_info.grid) self.data.append((time, data)) if self._prev_time is None: diff --git a/src/finam/data/tools.py b/src/finam/data/tools.py index d78192c8..bfe1ed4d 100644 --- a/src/finam/data/tools.py +++ b/src/finam/data/tools.py @@ -108,6 +108,11 @@ def to_xarray(data, info, time_entries=1, force_copy=False): def _check_input_shape(data, info, time_entries): # check correct data size if isinstance(info.grid, Grid): + time_entries = ( + data.shape[0] + if len(data.shape) == len(info.grid.data_shape) + 1 + else time_entries + ) data_size = data.size / time_entries if data_size != info.grid.data_size: raise FinamDataError( @@ -135,21 +140,25 @@ def _check_input_shape(data, info, time_entries): [time_entries] + list(info.grid.data_shape), order=info.grid.order ) elif isinstance(info.grid, grid_spec.NoGrid): - if len(data.shape) != info.grid.dim + 1: - if len(data.shape) == info.grid.dim: - data = np.expand_dims(data, 0) - else: - raise FinamDataError( - f"to_xarray: number of dimensions in data doesn't match expected number. " - f"Got {len(data.shape)}, expected {info.grid.dim}" - ) - else: - if data.shape[0] != time_entries: - raise FinamDataError( - f"to_xarray: number of time entries in data doesn't match expected number. " - f"Got {data.shape[0]}, expected {time_entries}" - ) + data = _check_input_shape_no_grid(data, info, time_entries) + return data + +def _check_input_shape_no_grid(data, info, time_entries): + if len(data.shape) != info.grid.dim + 1: + if len(data.shape) == info.grid.dim: + data = np.expand_dims(data, 0) + else: + raise FinamDataError( + f"to_xarray: number of dimensions in data doesn't match expected number. " + f"Got {len(data.shape)}, expected {info.grid.dim}" + ) + else: + if data.shape[0] != time_entries: + raise FinamDataError( + f"to_xarray: number of time entries in data doesn't match expected number. " + f"Got {data.shape[0]}, expected {time_entries}" + ) return data @@ -168,10 +177,11 @@ def has_time_axis(xdata, grid): bool Whether the data has a time axis. """ - if xdata.ndim == grid.dim: + grid_dim = len(grid.data_shape) if isinstance(grid, Grid) else grid.dim + if xdata.ndim == grid_dim: return False - if xdata.ndim == grid.dim + 1: + if xdata.ndim == grid_dim + 1: return True raise FinamDataError("Data dimension must be grid dimension or grid dimension + 1.") diff --git a/src/finam/modules/debug.py b/src/finam/modules/debug.py index cfd94a89..b1207727 100644 --- a/src/finam/modules/debug.py +++ b/src/finam/modules/debug.py @@ -116,7 +116,11 @@ class DebugConsumer(TimeComponent): self.logger.debug("Pulled input data for %s", name) if self._log_data is not None: - pdata = data[0, ...] if self._strip_data else data + pdata = ( + tools.strip_time(data, self.inputs[name].info.grid) + if self._strip_data + else data + ) self.logger.log( self._log_data, 'Received "%s" - %s: %s', @@ -140,7 +144,11 @@ class DebugConsumer(TimeComponent): } for name, data in self._data.items(): if self._log_data is not None: - pdata = data[0, ...] if self._strip_data else data + pdata = ( + tools.strip_time(data, self.inputs[name].info.grid) + if self._strip_data + else data + ) self.logger.log( self._log_data, 'Received "%s" - %s: %s', @@ -253,7 +261,9 @@ class DebugPushConsumer(Component): data = caller.pull_data(time) self._data[caller.name] = data if self._log_data is not None: - pdata = tools.strip_data(data) if self._strip_data else data + pdata = ( + tools.strip_time(data, caller.info.grid) if self._strip_data else data + ) self.logger.log( self._log_data, 'Received "%s" - %s: %s', diff --git a/src/finam/modules/mergers.py b/src/finam/modules/mergers.py index 4220a959..eee2614a 100644 --- a/src/finam/modules/mergers.py +++ b/src/finam/modules/mergers.py @@ -140,8 +140,8 @@ class WeightedSum(Component): result = None for name in self._input_names: - value = strip_time(self._in_data[name]) - weight = strip_time(self._in_data[name + "_weight"]) + value = strip_time(self._in_data[name], self._grid) + weight = strip_time(self._in_data[name + "_weight"], self._grid) if result is None: result = value * weight diff --git a/src/finam/modules/noise.py b/src/finam/modules/noise.py index e95d1fca..2d4b8241 100644 --- a/src/finam/modules/noise.py +++ b/src/finam/modules/noise.py @@ -307,7 +307,6 @@ def _generate_noise( data /= max_amp data = data * (high - low) / 2 + (high + low) / 2 - return data diff --git a/src/finam/modules/writers.py b/src/finam/modules/writers.py index e9547d4b..1dcbe875 100644 --- a/src/finam/modules/writers.py +++ b/src/finam/modules/writers.py @@ -106,7 +106,7 @@ class CsvWriter(TimeComponent): if self.status == ComponentStatus.CONNECTED: values = [ - dtools.get_magnitude(dtools.strip_time(data)) + dtools.get_magnitude(dtools.strip_time(data, NoGrid())) for _, data in self.connector.in_data.items() ] @@ -128,7 +128,7 @@ class CsvWriter(TimeComponent): values = [ dtools.get_magnitude( - dtools.strip_time(self.inputs[inp].pull_data(self.time)) + dtools.strip_time(self.inputs[inp].pull_data(self.time), NoGrid()) ) for inp in self._input_names ] diff --git a/tests/adapters/test_stats.py b/tests/adapters/test_stats.py index c90d44ba..93905792 100644 --- a/tests/adapters/test_stats.py +++ b/tests/adapters/test_stats.py @@ -36,7 +36,7 @@ class TestHistogram(unittest.TestCase): data = sink.data["Input"] self.assertEqual(data.shape, (1, 20)) - self.assertEqual(data.data.sum(), 11 * 14) + self.assertEqual(data.sum(), 11 * 14) self.assertEqual(fm.data.get_units(data), fm.UNITS.dimensionless) composition.run(end_time=datetime(2000, 1, 10)) diff --git a/tests/adapters/test_time.py b/tests/adapters/test_time.py index 3f212fd1..82bab928 100644 --- a/tests/adapters/test_time.py +++ b/tests/adapters/test_time.py @@ -121,28 +121,28 @@ class TestDelayFixed(unittest.TestCase): def test_fixed_delay(self): data = self.adapter.get_data(datetime(2000, 1, 1), None) - self.assertEqual(tools.get_data(data), 0) + self.assertEqual(data, 0) self.source.update() self.source.update() data = self.adapter.get_data(datetime(2000, 1, 5), None) - self.assertEqual(tools.get_data(data), 0) + self.assertEqual(data, 0) for _ in range(20): self.source.update() data = self.adapter.get_data(datetime(2000, 1, 10), None) - self.assertEqual(tools.get_data(data), 0) + self.assertEqual(data, 0) data = self.adapter.get_data(datetime(2000, 1, 11), None) - self.assertEqual(tools.get_data(data), 0) + self.assertEqual(data, 0) data = self.adapter.get_data(datetime(2000, 1, 12), None) - self.assertEqual(tools.get_data(data), 1) + self.assertEqual(data, 1) data = self.adapter.get_data(datetime(2000, 1, 20), None) - self.assertEqual(tools.get_data(data), 9) + self.assertEqual(data, 9) class TestNextValue(unittest.TestCase): diff --git a/tests/modules/test_debug.py b/tests/modules/test_debug.py index aecef42e..09bd12cc 100644 --- a/tests/modules/test_debug.py +++ b/tests/modules/test_debug.py @@ -21,7 +21,7 @@ class TestScheduleLogger(unittest.TestCase): outputs={ "Out": fm.Info(time=None, grid=fm.NoGrid()), }, - callback=lambda inp, _t: {"Out": fm.data.strip_data(inp["In"])}, + callback=lambda inp, _t: {"Out": inp["In"][0, ...]}, start=start, step=timedelta(days=5), ) @@ -32,7 +32,7 @@ class TestScheduleLogger(unittest.TestCase): outputs={ "Out": fm.Info(time=None, grid=fm.NoGrid()), }, - callback=lambda inp, _t: {"Out": fm.data.strip_data(inp["In"])}, + callback=lambda inp, _t: {"Out": inp["In"][0, ...]}, start=start, step=timedelta(days=8), ) @@ -87,8 +87,8 @@ class TestPushDebugConsumer(unittest.TestCase): module1.outputs["Out"] >> consumer.inputs["In"] composition.connect(start) - self.assertEqual(fm.data.strip_data(consumer.data["In"]), 1) + self.assertEqual(consumer.data["In"][0, ...], 1) composition.run(start_time=start, end_time=datetime(2000, 1, 10)) - self.assertEqual(fm.data.strip_data(consumer.data["In"]), 11) + self.assertEqual(consumer.data["In"][0, ...], 11) diff --git a/tests/modules/test_noise.py b/tests/modules/test_noise.py index 2932c2b0..bc1d7f1a 100644 --- a/tests/modules/test_noise.py +++ b/tests/modules/test_noise.py @@ -36,7 +36,7 @@ class TestNoise(unittest.TestCase): composition.connect() composition.run(end_time=datetime(2000, 1, 2)) - data = fm.data.strip_data(sink.data["Input"]) + data = sink.data["Input"][0, ...] self.assertEqual(data.shape, ()) def test_noise_uniform_1d(self): @@ -68,7 +68,7 @@ class TestNoise(unittest.TestCase): composition.connect(time) composition.run(end_time=datetime(2000, 1, 2)) - data = fm.data.strip_data(sink.data["Input"]) + data = sink.data["Input"][0, ...] self.assertEqual(data.shape, (19,)) def test_noise_uniform_2d(self): @@ -100,7 +100,7 @@ class TestNoise(unittest.TestCase): composition.connect(time) composition.run(end_time=datetime(2000, 1, 2)) - data = fm.data.strip_data(sink.data["Input"]) + data = sink.data["Input"][0, ...] self.assertEqual(data.shape, (19, 14)) def test_noise_uniform_3d(self): @@ -132,7 +132,7 @@ class TestNoise(unittest.TestCase): composition.connect(time) composition.run(end_time=datetime(2000, 1, 2)) - data = fm.data.strip_data(sink.data["Input"]) + data = sink.data["Input"][0, ...] self.assertEqual(data.shape, (19, 14, 9)) def test_noise_points_1d(self): @@ -164,7 +164,7 @@ class TestNoise(unittest.TestCase): composition.connect(time) composition.run(end_time=datetime(2000, 1, 2)) - data = fm.data.strip_data(sink.data["Input"]) + data = sink.data["Input"][0, ...] self.assertEqual(data.shape, (100,)) def test_noise_points_2d(self): @@ -196,7 +196,7 @@ class TestNoise(unittest.TestCase): composition.connect(time) composition.run(end_time=datetime(2000, 1, 2)) - data = fm.data.strip_data(sink.data["Input"]) + data = sink.data["Input"][0, ...] self.assertEqual(data.shape, (100,)) def test_noise_points_3d(self): @@ -228,7 +228,7 @@ class TestNoise(unittest.TestCase): composition.connect(time) composition.run(end_time=datetime(2000, 1, 2)) - data = fm.data.strip_data(sink.data["Input"]) + data = sink.data["Input"][0, ...] self.assertEqual(data.shape, (100,)) def test_noise_fail_nogrid(self): @@ -298,12 +298,12 @@ class TestStaticNoise(unittest.TestCase): source.outputs["Noise"] >> sink.inputs["Input"] composition.connect(None) - data_1 = fm.data.strip_data(sink.data["Input"]) + data_1 = sink.data["Input"][0, ...] self.assertEqual(data_1.shape, ()) composition.run(end_time=None) - data_2 = fm.data.strip_data(sink.data["Input"]) + data_2 = sink.data["Input"][0, ...] self.assertEqual(data_1, data_2) self.assertEqual(data_2.shape, ()) diff --git a/tests/modules/test_parametric.py b/tests/modules/test_parametric.py index 858cecde..bc17d71a 100644 --- a/tests/modules/test_parametric.py +++ b/tests/modules/test_parametric.py @@ -283,12 +283,12 @@ class TestStaticParametricGrid(unittest.TestCase): source.outputs["Grid"] >> sink.inputs["Input"] composition.connect(None) - data_1 = fm.data.strip_data(sink.data["Input"]) + data_1 = sink.data["Input"][0, ...] self.assertEqual(data_1.shape, (19, 14)) composition.run(end_time=None) - data_2 = fm.data.strip_data(sink.data["Input"]) + data_2 = sink.data["Input"][0, ...] assert_allclose(data_1, data_2) self.assertEqual(data_2.shape, (19, 14)) -- GitLab