From 44e7a32b22a919871829a60c8db5ca91fbe14b11 Mon Sep 17 00:00:00 2001
From: Martin Lange <martin.lange@ufz.de>
Date: Tue, 6 Dec 2022 16:24:34 +0100
Subject: [PATCH] fix components and tests

---
 src/finam/adapters/regrid.py           | 10 +++---
 src/finam/adapters/time.py             |  4 +--
 src/finam/adapters/time_integration.py |  2 +-
 src/finam/data/tools.py                | 42 ++++++++++++++++----------
 src/finam/modules/debug.py             | 16 ++++++++--
 src/finam/modules/mergers.py           |  4 +--
 src/finam/modules/noise.py             |  1 -
 src/finam/modules/writers.py           |  4 +--
 tests/adapters/test_stats.py           |  2 +-
 tests/adapters/test_time.py            | 12 ++++----
 tests/modules/test_debug.py            |  8 ++---
 tests/modules/test_noise.py            | 18 +++++------
 tests/modules/test_parametric.py       |  4 +--
 13 files changed, 73 insertions(+), 54 deletions(-)

diff --git a/src/finam/adapters/regrid.py b/src/finam/adapters/regrid.py
index 31fb15bd..472f39a1 100644
--- a/src/finam/adapters/regrid.py
+++ b/src/finam/adapters/regrid.py
@@ -132,10 +132,8 @@ class RegridNearest(ARegridding):
     def _get_data(self, time, target):
         in_data = self.pull_data(time, target)
 
-        res = (
-            dtools.get_data(in_data)
-            .reshape(-1, order=self.input_grid.order)[self.ids]
-            .reshape(self.output_grid.data_shape, order=self.output_grid.order)
+        res = in_data.reshape(-1, order=self.input_grid.order)[self.ids].reshape(
+            self.output_grid.data_shape, order=self.output_grid.order
         )
         return res
 
@@ -219,7 +217,9 @@ class RegridLinear(ARegridding):
         in_data = self.pull_data(time, target)
 
         if isinstance(self.input_grid, StructuredGrid):
-            self.inter.values = dtools.get_magnitude(dtools.strip_time(in_data))
+            self.inter.values = dtools.get_magnitude(
+                dtools.strip_time(in_data, self.input_grid)
+            )
             res = self.inter(self.out_coords)
             if self.fill_with_nearest:
                 res[self.out_ids] = self.inter.values.flatten(
diff --git a/src/finam/adapters/time.py b/src/finam/adapters/time.py
index 7b3a816d..73101267 100644
--- a/src/finam/adapters/time.py
+++ b/src/finam/adapters/time.py
@@ -240,7 +240,7 @@ class TimeCachingAdapter(Adapter, NoBranchAdapter, ABC):
         """
         check_time(self.logger, time)
 
-        data = dtools.strip_data(self.pull_data(time, self))
+        data = dtools.strip_time(self.pull_data(time, self), self._input_info.grid)
         self.data.append((time, data))
 
     def _get_data(self, time, _target):
@@ -360,7 +360,7 @@ class StackTime(TimeCachingAdapter):
             break
 
         arr = np.stack([d[1] for d in extract])
-        return dtools.to_xarray(arr, self.name, self.info, time_entries=len(extract))
+        return dtools.to_xarray(arr, self.info, time_entries=len(extract))
 
 
 class LinearTime(TimeCachingAdapter):
diff --git a/src/finam/adapters/time_integration.py b/src/finam/adapters/time_integration.py
index a7e474ce..5d905c10 100644
--- a/src/finam/adapters/time_integration.py
+++ b/src/finam/adapters/time_integration.py
@@ -26,7 +26,7 @@ class TimeIntegrationAdapter(TimeCachingAdapter, ABC):
         """
         check_time(self.logger, time)
 
-        data = tools.strip_data(self.pull_data(time, self))
+        data = tools.strip_time(self.pull_data(time, self), self._input_info.grid)
         self.data.append((time, data))
 
         if self._prev_time is None:
diff --git a/src/finam/data/tools.py b/src/finam/data/tools.py
index d78192c8..bfe1ed4d 100644
--- a/src/finam/data/tools.py
+++ b/src/finam/data/tools.py
@@ -108,6 +108,11 @@ def to_xarray(data, info, time_entries=1, force_copy=False):
 def _check_input_shape(data, info, time_entries):
     # check correct data size
     if isinstance(info.grid, Grid):
+        time_entries = (
+            data.shape[0]
+            if len(data.shape) == len(info.grid.data_shape) + 1
+            else time_entries
+        )
         data_size = data.size / time_entries
         if data_size != info.grid.data_size:
             raise FinamDataError(
@@ -135,21 +140,25 @@ def _check_input_shape(data, info, time_entries):
                     [time_entries] + list(info.grid.data_shape), order=info.grid.order
                 )
     elif isinstance(info.grid, grid_spec.NoGrid):
-        if len(data.shape) != info.grid.dim + 1:
-            if len(data.shape) == info.grid.dim:
-                data = np.expand_dims(data, 0)
-            else:
-                raise FinamDataError(
-                    f"to_xarray: number of dimensions in data doesn't match expected number. "
-                    f"Got {len(data.shape)}, expected {info.grid.dim}"
-                )
-        else:
-            if data.shape[0] != time_entries:
-                raise FinamDataError(
-                    f"to_xarray: number of time entries in data doesn't match expected number. "
-                    f"Got {data.shape[0]}, expected {time_entries}"
-                )
+        data = _check_input_shape_no_grid(data, info, time_entries)
+    return data
+
 
+def _check_input_shape_no_grid(data, info, time_entries):
+    if len(data.shape) != info.grid.dim + 1:
+        if len(data.shape) == info.grid.dim:
+            data = np.expand_dims(data, 0)
+        else:
+            raise FinamDataError(
+                f"to_xarray: number of dimensions in data doesn't match expected number. "
+                f"Got {len(data.shape)}, expected {info.grid.dim}"
+            )
+    else:
+        if data.shape[0] != time_entries:
+            raise FinamDataError(
+                f"to_xarray: number of time entries in data doesn't match expected number. "
+                f"Got {data.shape[0]}, expected {time_entries}"
+            )
     return data
 
 
@@ -168,10 +177,11 @@ def has_time_axis(xdata, grid):
     bool
         Whether the data has a time axis.
     """
-    if xdata.ndim == grid.dim:
+    grid_dim = len(grid.data_shape) if isinstance(grid, Grid) else grid.dim
+    if xdata.ndim == grid_dim:
         return False
 
-    if xdata.ndim == grid.dim + 1:
+    if xdata.ndim == grid_dim + 1:
         return True
 
     raise FinamDataError("Data dimension must be grid dimension or grid dimension + 1.")
diff --git a/src/finam/modules/debug.py b/src/finam/modules/debug.py
index cfd94a89..b1207727 100644
--- a/src/finam/modules/debug.py
+++ b/src/finam/modules/debug.py
@@ -116,7 +116,11 @@ class DebugConsumer(TimeComponent):
                 self.logger.debug("Pulled input data for %s", name)
 
                 if self._log_data is not None:
-                    pdata = data[0, ...] if self._strip_data else data
+                    pdata = (
+                        tools.strip_time(data, self.inputs[name].info.grid)
+                        if self._strip_data
+                        else data
+                    )
                     self.logger.log(
                         self._log_data,
                         'Received "%s" - %s: %s',
@@ -140,7 +144,11 @@ class DebugConsumer(TimeComponent):
         }
         for name, data in self._data.items():
             if self._log_data is not None:
-                pdata = data[0, ...] if self._strip_data else data
+                pdata = (
+                    tools.strip_time(data, self.inputs[name].info.grid)
+                    if self._strip_data
+                    else data
+                )
                 self.logger.log(
                     self._log_data,
                     'Received "%s" - %s: %s',
@@ -253,7 +261,9 @@ class DebugPushConsumer(Component):
         data = caller.pull_data(time)
         self._data[caller.name] = data
         if self._log_data is not None:
-            pdata = tools.strip_data(data) if self._strip_data else data
+            pdata = (
+                tools.strip_time(data, caller.info.grid) if self._strip_data else data
+            )
             self.logger.log(
                 self._log_data,
                 'Received "%s" - %s: %s',
diff --git a/src/finam/modules/mergers.py b/src/finam/modules/mergers.py
index 4220a959..eee2614a 100644
--- a/src/finam/modules/mergers.py
+++ b/src/finam/modules/mergers.py
@@ -140,8 +140,8 @@ class WeightedSum(Component):
             result = None
 
             for name in self._input_names:
-                value = strip_time(self._in_data[name])
-                weight = strip_time(self._in_data[name + "_weight"])
+                value = strip_time(self._in_data[name], self._grid)
+                weight = strip_time(self._in_data[name + "_weight"], self._grid)
 
                 if result is None:
                     result = value * weight
diff --git a/src/finam/modules/noise.py b/src/finam/modules/noise.py
index e95d1fca..2d4b8241 100644
--- a/src/finam/modules/noise.py
+++ b/src/finam/modules/noise.py
@@ -307,7 +307,6 @@ def _generate_noise(
 
     data /= max_amp
     data = data * (high - low) / 2 + (high + low) / 2
-
     return data
 
 
diff --git a/src/finam/modules/writers.py b/src/finam/modules/writers.py
index e9547d4b..1dcbe875 100644
--- a/src/finam/modules/writers.py
+++ b/src/finam/modules/writers.py
@@ -106,7 +106,7 @@ class CsvWriter(TimeComponent):
 
         if self.status == ComponentStatus.CONNECTED:
             values = [
-                dtools.get_magnitude(dtools.strip_time(data))
+                dtools.get_magnitude(dtools.strip_time(data, NoGrid()))
                 for _, data in self.connector.in_data.items()
             ]
 
@@ -128,7 +128,7 @@ class CsvWriter(TimeComponent):
 
         values = [
             dtools.get_magnitude(
-                dtools.strip_time(self.inputs[inp].pull_data(self.time))
+                dtools.strip_time(self.inputs[inp].pull_data(self.time), NoGrid())
             )
             for inp in self._input_names
         ]
diff --git a/tests/adapters/test_stats.py b/tests/adapters/test_stats.py
index c90d44ba..93905792 100644
--- a/tests/adapters/test_stats.py
+++ b/tests/adapters/test_stats.py
@@ -36,7 +36,7 @@ class TestHistogram(unittest.TestCase):
 
         data = sink.data["Input"]
         self.assertEqual(data.shape, (1, 20))
-        self.assertEqual(data.data.sum(), 11 * 14)
+        self.assertEqual(data.sum(), 11 * 14)
         self.assertEqual(fm.data.get_units(data), fm.UNITS.dimensionless)
 
         composition.run(end_time=datetime(2000, 1, 10))
diff --git a/tests/adapters/test_time.py b/tests/adapters/test_time.py
index 3f212fd1..82bab928 100644
--- a/tests/adapters/test_time.py
+++ b/tests/adapters/test_time.py
@@ -121,28 +121,28 @@ class TestDelayFixed(unittest.TestCase):
 
     def test_fixed_delay(self):
         data = self.adapter.get_data(datetime(2000, 1, 1), None)
-        self.assertEqual(tools.get_data(data), 0)
+        self.assertEqual(data, 0)
 
         self.source.update()
         self.source.update()
 
         data = self.adapter.get_data(datetime(2000, 1, 5), None)
-        self.assertEqual(tools.get_data(data), 0)
+        self.assertEqual(data, 0)
 
         for _ in range(20):
             self.source.update()
 
         data = self.adapter.get_data(datetime(2000, 1, 10), None)
-        self.assertEqual(tools.get_data(data), 0)
+        self.assertEqual(data, 0)
 
         data = self.adapter.get_data(datetime(2000, 1, 11), None)
-        self.assertEqual(tools.get_data(data), 0)
+        self.assertEqual(data, 0)
 
         data = self.adapter.get_data(datetime(2000, 1, 12), None)
-        self.assertEqual(tools.get_data(data), 1)
+        self.assertEqual(data, 1)
 
         data = self.adapter.get_data(datetime(2000, 1, 20), None)
-        self.assertEqual(tools.get_data(data), 9)
+        self.assertEqual(data, 9)
 
 
 class TestNextValue(unittest.TestCase):
diff --git a/tests/modules/test_debug.py b/tests/modules/test_debug.py
index aecef42e..09bd12cc 100644
--- a/tests/modules/test_debug.py
+++ b/tests/modules/test_debug.py
@@ -21,7 +21,7 @@ class TestScheduleLogger(unittest.TestCase):
             outputs={
                 "Out": fm.Info(time=None, grid=fm.NoGrid()),
             },
-            callback=lambda inp, _t: {"Out": fm.data.strip_data(inp["In"])},
+            callback=lambda inp, _t: {"Out": inp["In"][0, ...]},
             start=start,
             step=timedelta(days=5),
         )
@@ -32,7 +32,7 @@ class TestScheduleLogger(unittest.TestCase):
             outputs={
                 "Out": fm.Info(time=None, grid=fm.NoGrid()),
             },
-            callback=lambda inp, _t: {"Out": fm.data.strip_data(inp["In"])},
+            callback=lambda inp, _t: {"Out": inp["In"][0, ...]},
             start=start,
             step=timedelta(days=8),
         )
@@ -87,8 +87,8 @@ class TestPushDebugConsumer(unittest.TestCase):
         module1.outputs["Out"] >> consumer.inputs["In"]
 
         composition.connect(start)
-        self.assertEqual(fm.data.strip_data(consumer.data["In"]), 1)
+        self.assertEqual(consumer.data["In"][0, ...], 1)
 
         composition.run(start_time=start, end_time=datetime(2000, 1, 10))
 
-        self.assertEqual(fm.data.strip_data(consumer.data["In"]), 11)
+        self.assertEqual(consumer.data["In"][0, ...], 11)
diff --git a/tests/modules/test_noise.py b/tests/modules/test_noise.py
index 2932c2b0..bc1d7f1a 100644
--- a/tests/modules/test_noise.py
+++ b/tests/modules/test_noise.py
@@ -36,7 +36,7 @@ class TestNoise(unittest.TestCase):
         composition.connect()
         composition.run(end_time=datetime(2000, 1, 2))
 
-        data = fm.data.strip_data(sink.data["Input"])
+        data = sink.data["Input"][0, ...]
         self.assertEqual(data.shape, ())
 
     def test_noise_uniform_1d(self):
@@ -68,7 +68,7 @@ class TestNoise(unittest.TestCase):
         composition.connect(time)
         composition.run(end_time=datetime(2000, 1, 2))
 
-        data = fm.data.strip_data(sink.data["Input"])
+        data = sink.data["Input"][0, ...]
         self.assertEqual(data.shape, (19,))
 
     def test_noise_uniform_2d(self):
@@ -100,7 +100,7 @@ class TestNoise(unittest.TestCase):
         composition.connect(time)
         composition.run(end_time=datetime(2000, 1, 2))
 
-        data = fm.data.strip_data(sink.data["Input"])
+        data = sink.data["Input"][0, ...]
         self.assertEqual(data.shape, (19, 14))
 
     def test_noise_uniform_3d(self):
@@ -132,7 +132,7 @@ class TestNoise(unittest.TestCase):
         composition.connect(time)
         composition.run(end_time=datetime(2000, 1, 2))
 
-        data = fm.data.strip_data(sink.data["Input"])
+        data = sink.data["Input"][0, ...]
         self.assertEqual(data.shape, (19, 14, 9))
 
     def test_noise_points_1d(self):
@@ -164,7 +164,7 @@ class TestNoise(unittest.TestCase):
         composition.connect(time)
         composition.run(end_time=datetime(2000, 1, 2))
 
-        data = fm.data.strip_data(sink.data["Input"])
+        data = sink.data["Input"][0, ...]
         self.assertEqual(data.shape, (100,))
 
     def test_noise_points_2d(self):
@@ -196,7 +196,7 @@ class TestNoise(unittest.TestCase):
         composition.connect(time)
         composition.run(end_time=datetime(2000, 1, 2))
 
-        data = fm.data.strip_data(sink.data["Input"])
+        data = sink.data["Input"][0, ...]
         self.assertEqual(data.shape, (100,))
 
     def test_noise_points_3d(self):
@@ -228,7 +228,7 @@ class TestNoise(unittest.TestCase):
         composition.connect(time)
         composition.run(end_time=datetime(2000, 1, 2))
 
-        data = fm.data.strip_data(sink.data["Input"])
+        data = sink.data["Input"][0, ...]
         self.assertEqual(data.shape, (100,))
 
     def test_noise_fail_nogrid(self):
@@ -298,12 +298,12 @@ class TestStaticNoise(unittest.TestCase):
         source.outputs["Noise"] >> sink.inputs["Input"]
 
         composition.connect(None)
-        data_1 = fm.data.strip_data(sink.data["Input"])
+        data_1 = sink.data["Input"][0, ...]
 
         self.assertEqual(data_1.shape, ())
 
         composition.run(end_time=None)
 
-        data_2 = fm.data.strip_data(sink.data["Input"])
+        data_2 = sink.data["Input"][0, ...]
         self.assertEqual(data_1, data_2)
         self.assertEqual(data_2.shape, ())
diff --git a/tests/modules/test_parametric.py b/tests/modules/test_parametric.py
index 858cecde..bc17d71a 100644
--- a/tests/modules/test_parametric.py
+++ b/tests/modules/test_parametric.py
@@ -283,12 +283,12 @@ class TestStaticParametricGrid(unittest.TestCase):
         source.outputs["Grid"] >> sink.inputs["Input"]
 
         composition.connect(None)
-        data_1 = fm.data.strip_data(sink.data["Input"])
+        data_1 = sink.data["Input"][0, ...]
 
         self.assertEqual(data_1.shape, (19, 14))
 
         composition.run(end_time=None)
 
-        data_2 = fm.data.strip_data(sink.data["Input"])
+        data_2 = sink.data["Input"][0, ...]
         assert_allclose(data_1, data_2)
         self.assertEqual(data_2.shape, (19, 14))
-- 
GitLab