From 16447abf0da80d35dbcd663c00850bff50b3e6a0 Mon Sep 17 00:00:00 2001
From: Martin Lange <martin.lange@ufz.de>
Date: Tue, 6 Dec 2022 14:46:30 +0100
Subject: [PATCH] fix tests for tools and sdk

---
 benchmarks/data/test_tools.py  |   1 -
 src/finam/modules/control.py   |   3 +-
 src/finam/modules/debug.py     |   4 +-
 src/finam/modules/mergers.py   |   6 +-
 tests/core/test_schedule.py    |  15 ++--
 tests/core/test_sdk.py         |  12 +--
 tests/data/test_tools.py       | 139 ++++++++++++++-------------------
 tests/modules/test_callback.py |   2 +-
 8 files changed, 76 insertions(+), 106 deletions(-)

diff --git a/benchmarks/data/test_tools.py b/benchmarks/data/test_tools.py
index 42f28f03..4bf066aa 100644
--- a/benchmarks/data/test_tools.py
+++ b/benchmarks/data/test_tools.py
@@ -14,7 +14,6 @@ from finam.data.tools import (
     get_magnitude,
     get_units,
     is_quantified,
-    strip_data,
     strip_time,
     to_units,
     to_xarray,
diff --git a/src/finam/modules/control.py b/src/finam/modules/control.py
index e3ba593c..280b57d4 100644
--- a/src/finam/modules/control.py
+++ b/src/finam/modules/control.py
@@ -2,7 +2,6 @@
 import datetime as dt
 
 from ..data.grid_spec import NoGrid
-from ..data.tools import strip_data
 from ..errors import FinamMetaDataError
 from ..sdk import TimeComponent
 from ..tools.connect_helper import FromInput, FromOutput
@@ -156,7 +155,7 @@ class TimeTrigger(TimeComponent):
     def _update(self):
         self.time += self._step
 
-        data = strip_data(self.inputs["In"].pull_data(self.time))
+        data = self.inputs["In"].pull_data(self.time)
         self.outputs["Out"].push_data(data, self.time)
 
     def _finalize(self):
diff --git a/src/finam/modules/debug.py b/src/finam/modules/debug.py
index d2ad2e54..cfd94a89 100644
--- a/src/finam/modules/debug.py
+++ b/src/finam/modules/debug.py
@@ -116,7 +116,7 @@ class DebugConsumer(TimeComponent):
                 self.logger.debug("Pulled input data for %s", name)
 
                 if self._log_data is not None:
-                    pdata = tools.strip_data(data) if self._strip_data else data
+                    pdata = data[0, ...] if self._strip_data else data
                     self.logger.log(
                         self._log_data,
                         'Received "%s" - %s: %s',
@@ -140,7 +140,7 @@ class DebugConsumer(TimeComponent):
         }
         for name, data in self._data.items():
             if self._log_data is not None:
-                pdata = tools.strip_data(data) if self._strip_data else data
+                pdata = data[0, ...] if self._strip_data else data
                 self.logger.log(
                     self._log_data,
                     'Received "%s" - %s: %s',
diff --git a/src/finam/modules/mergers.py b/src/finam/modules/mergers.py
index c46cf94f..4220a959 100644
--- a/src/finam/modules/mergers.py
+++ b/src/finam/modules/mergers.py
@@ -1,7 +1,7 @@
 """Pull-based components for merging multiple inputs into a single output"""
 from finam.interfaces import ComponentStatus
 
-from ..data.tools import compatible_units, strip_data
+from ..data.tools import compatible_units, strip_time
 from ..errors import FinamMetaDataError
 from ..sdk import CallbackOutput, Component
 from ..tools.log_helper import ErrorLogger
@@ -140,8 +140,8 @@ class WeightedSum(Component):
             result = None
 
             for name in self._input_names:
-                value = strip_data(self._in_data[name])
-                weight = strip_data(self._in_data[name + "_weight"])
+                value = strip_time(self._in_data[name])
+                weight = strip_time(self._in_data[name + "_weight"])
 
                 if result is None:
                     result = value * weight
diff --git a/tests/core/test_schedule.py b/tests/core/test_schedule.py
index a80ebb62..1c984949 100644
--- a/tests/core/test_schedule.py
+++ b/tests/core/test_schedule.py
@@ -27,7 +27,6 @@ from finam import (
     Output,
     TimeComponent,
 )
-from finam import data as tools
 from finam.adapters.base import Scale
 from finam.adapters.time import DelayFixed, NextTime
 from finam.modules import CallbackComponent, CallbackGenerator, DebugPushConsumer, debug
@@ -148,7 +147,7 @@ class MockupCircularComponent(TimeComponent):
         push_data = {}
         pulled_data = self.connector.in_data["Input"]
         if pulled_data is not None:
-            push_data["Output"] = tools.get_data(tools.strip_time(pulled_data))
+            push_data["Output"] = pulled_data[0, ...]
 
         self.try_connect(
             start_time,
@@ -162,9 +161,7 @@ class MockupCircularComponent(TimeComponent):
     def _update(self):
         self._time += self._step
         pulled = self.inputs["Input"].pull_data(self.time)
-        self.outputs["Output"].push_data(
-            tools.get_data(tools.strip_time(pulled)), self.time
-        )
+        self.outputs["Output"].push_data(pulled, self.time)
 
     def _finalize(self):
         pass
@@ -291,7 +288,7 @@ class TestComposition(unittest.TestCase):
             log_file = os.path.join(tmp, "test.log")
 
             module1 = MockupComponent(
-                callbacks={"Output": lambda t: t}, step=timedelta(1.0)
+                callbacks={"Output": lambda t: t.day}, step=timedelta(1.0)
             )
             module2 = MockupDependentComponent(step=timedelta(1.0))
 
@@ -363,7 +360,7 @@ class TestComposition(unittest.TestCase):
 
     def test_iterative_connect(self):
         module1 = MockupComponent(
-            callbacks={"Output": lambda t: t}, step=timedelta(1.0)
+            callbacks={"Output": lambda t: t.day}, step=timedelta(1.0)
         )
         module2 = MockupDependentComponent(step=timedelta(1.0))
 
@@ -376,7 +373,7 @@ class TestComposition(unittest.TestCase):
 
     def test_iterative_connect_multi(self):
         module1 = MockupComponent(
-            callbacks={"Output": lambda t: t}, step=timedelta(1.0)
+            callbacks={"Output": lambda t: t.day}, step=timedelta(1.0)
         )
         module2 = MockupCircularComponent(step=timedelta(1.0))
         module3 = MockupDependentComponent(step=timedelta(1.0))
@@ -866,7 +863,7 @@ class TestComposition(unittest.TestCase):
             return t.day
 
         def lambda_component(inp, t):
-            return {"Out": fm.data.strip_data(inp["In"])}
+            return {"Out": inp["In"][0, ...]}
 
         def lambda_debugger(name, data, t):
             updates[name].append(t.day)
diff --git a/tests/core/test_sdk.py b/tests/core/test_sdk.py
index 4b67ca39..eec532fb 100644
--- a/tests/core/test_sdk.py
+++ b/tests/core/test_sdk.py
@@ -372,7 +372,7 @@ class TestOutput(unittest.TestCase):
         out.push_info(info)
         in1.exchange_info(info)
 
-        in_data = fm.data.full(0.0, "test", info)
+        in_data = fm.data.full(0.0, info)
         out.push_data(in_data, t)
         with self.assertRaises(FinamDataError):
             out.push_data(in_data, t)
@@ -392,7 +392,7 @@ class TestOutput(unittest.TestCase):
         out.push_info(info1)
         in1.exchange_info(info2)
 
-        in_data = fm.data.full(0.0, "test", info1)
+        in_data = fm.data.full(0.0, info1)
         out.push_data(in_data, t)
         out_data = in1.pull_data(t, in1)
 
@@ -414,7 +414,7 @@ class TestOutput(unittest.TestCase):
         out.push_info(info)
         in1.exchange_info(info)
 
-        in_data = fm.data.strip_data(fm.data.full(0.0, "test", info))
+        in_data = fm.data.full(0.0, info)
         out.push_data(in_data, t)
         with self.assertRaises(FinamDataError):
             out.push_data(in_data, t)
@@ -434,7 +434,7 @@ class TestOutput(unittest.TestCase):
         out.push_info(info1)
         in1.exchange_info(info2)
 
-        in_data = fm.data.strip_data(fm.data.full(0.0, "test", info1))
+        in_data = fm.data.full(0.0, info1)
         out.push_data(in_data, t)
         out_data = in1.pull_data(t, in1)
 
@@ -496,7 +496,7 @@ class TestInput(unittest.TestCase):
         out.push_data(0, None)
         data = in1.pull_data(None)
 
-        self.assertTrue(fm.data.has_time_axis(data))
+        self.assertTrue(fm.data.has_time_axis(data, info.grid))
 
         data_2 = in1.pull_data(None)
 
@@ -520,7 +520,7 @@ class TestInput(unittest.TestCase):
         out.push_data(0, None)
         data = in1.pull_data(t)
 
-        self.assertTrue(fm.data.has_time_axis(data))
+        self.assertTrue(fm.data.has_time_axis(data, info.grid))
 
 
 class TestCallbackInput(unittest.TestCase):
diff --git a/tests/data/test_tools.py b/tests/data/test_tools.py
index b0269278..7c221818 100644
--- a/tests/data/test_tools.py
+++ b/tests/data/test_tools.py
@@ -4,7 +4,6 @@ from datetime import datetime as dt
 
 import numpy as np
 import pint
-import xarray as xr
 
 import finam
 import finam.errors
@@ -29,68 +28,56 @@ class TestDataTools(unittest.TestCase):
         tim0 = dt(year=2021, month=10, day=12)
 
         data = np.arange(6).reshape(3, 2)
-        dar0 = finam.data.to_xarray(data, "data", info)
-        dar1 = finam.data.to_xarray(data, "data", info)
+        dar0 = finam.data.to_xarray(data, info)
+        dar1 = finam.data.to_xarray(data, info)
 
         # assert stuff
         self.assertIsInstance(finam.data.get_magnitude(dar0), np.ndarray)
-        self.assertIsInstance(finam.data.get_data(dar0), pint.Quantity)
+        self.assertIsInstance(finam.data.strip_time(dar0, info.grid), pint.Quantity)
         self.assertIsInstance(
             finam.data.get_dimensionality(dar0), pint.util.UnitsContainer
         )
 
         # should work
-        finam.data.to_xarray(dar0, "data", info)
-        finam.data.to_xarray(dar1, "data", info)
-        finam.data.check(dar0, "data", info)
-        finam.data.check(dar1, "data", info)
+        finam.data.to_xarray(dar0, info)
+        finam.data.to_xarray(dar1, info)
+        finam.data.check(dar0, info)
+        finam.data.check(dar1, info)
         finam.data.to_units(dar0, "km")
 
         # wrong shape
         with self.assertRaises(finam.errors.FinamDataError):
-            finam.data.to_xarray(1, "data", info)
+            finam.data.to_xarray(1, info)
 
         # no DataArray
         with self.assertRaises(finam.errors.FinamDataError):
-            finam.data.check(None, "data", info)
+            finam.data.check(None, info)
 
         # not qunatified
         with self.assertRaises(finam.errors.FinamDataError):
-            finam.data.check(dar0.pint.dequantify(), "data", info)
+            finam.data.check(dar0.magnitude, info)
 
-        # wrong name
-        with self.assertRaises(finam.errors.FinamDataError):
-            finam.data.check(dar0, "wrong", info)
-
-        finam.data.check(dar1, "data", inf0)
+        finam.data.check(dar1, inf0)
 
         # other units format should work
-        finam.data.check(dar0, "data", inf0)
-
-        # wrong meta
-        with self.assertRaises(finam.errors.FinamDataError):
-            finam.data.check(dar0, "data", inf1)
+        finam.data.check(dar0, inf0)
 
         # wrong units
         with self.assertRaises(finam.errors.FinamDataError):
-            finam.data.check(dar0, "data", inf2)
-
-        # wrong dims
-        with self.assertRaises(finam.errors.FinamDataError):
-            finam.data.check(dar0, "data", inf3)
+            finam.data.check(dar0, inf2)
 
         # wrong shape
         with self.assertRaises(finam.errors.FinamDataError):
-            finam.data.check(dar0, "data", inf4)
+            finam.data.check(dar0, inf4)
         with self.assertRaises(finam.errors.FinamDataError):
-            finam.data.check(dar1, "data", inf4)
+            finam.data.check(dar1, inf4)
 
         # check full_like
         dar2 = finam.data.full_like(dar0, 0)
-        finam.data.check(dar2, "data", info)
+        finam.data.check(dar2, info)
 
-        dar3 = finam.data.full(0, "data", info)
-        finam.data.check(dar3, "data", info)
+        dar3 = finam.data.full(0, info)
+        finam.data.check(dar3, info)
 
     def test_other_grids(self):
         time = dt(2000, 1, 1)
@@ -99,82 +86,67 @@ class TestDataTools(unittest.TestCase):
         gri1 = finam.UnstructuredPoints(points=[[0, 0], [0, 2], [2, 2]])
         info = finam.Info(time, gri0, units="s")
         data = np.arange(3)
-        dar0 = finam.data.to_xarray(data, "data", info)
-        dar1 = finam.data.to_xarray(data, "data", info.copy_with(grid=gri1))
+        dar0 = finam.data.to_xarray(data, info)
+        dar1 = finam.data.to_xarray(data, info.copy_with(grid=gri1))
 
-        self.assertTrue("dim_0" in dar0.dims)
-        self.assertTrue("id" in dar1.dims)
+        self.assertEqual((1, 3), dar0.shape)
+        self.assertEqual((1, 3), dar1.shape)
 
     def test_strip_time(self):
         time = dt(2000, 1, 1)
+        grid = finam.NoGrid()
 
-        xdata = finam.data.to_xarray(1.0, "data", finam.Info(time, grid=finam.NoGrid()))
-        self.assertEqual(xdata.shape, (1,))
-        stripped = finam.data.strip_time(xdata)
-        self.assertEqual(stripped.shape, ())
-
-        xdata = finam.data.to_xarray(
-            1.0,
-            "data",
-            finam.Info(time, grid=finam.NoGrid()),
-        )
+        xdata = finam.data.to_xarray(1.0, finam.Info(time, grid=grid))
         self.assertEqual(xdata.shape, (1,))
-        stripped = finam.data.strip_time(xdata)
+        stripped = finam.data.strip_time(xdata, grid)
         self.assertEqual(stripped.shape, ())
 
         xdata = finam.data.to_xarray(
             [1.0, 2.0, 3.0],
-            "data",
             finam.Info(time, grid=finam.NoGrid(dim=1)),
         )
         self.assertEqual(xdata.shape, (1, 3))
-        stripped = finam.data.strip_time(xdata)
+        stripped = finam.data.strip_time(xdata, finam.NoGrid(dim=1))
         self.assertEqual(stripped.shape, (3,))
-        stripped2 = finam.data.strip_time(xdata)
+        stripped2 = finam.data.strip_time(xdata, finam.NoGrid(dim=1))
         self.assertEqual(stripped2.shape, stripped.shape)
 
-        with self.assertRaises(finam.errors.FinamDataError):
-            stripped_ = finam.data.strip_time(np.asarray([1.0, 2.0]))
-
         arr1 = finam.data.to_xarray(
             1.0,
-            "A",
             finam.Info(time, grid=finam.NoGrid()),
         )
         arr2 = finam.data.to_xarray(
             1.0,
-            "A",
             finam.Info(time, grid=finam.NoGrid()),
         )
-        data = xr.concat([arr1, arr2], dim="time")
+        data = np.concatenate([arr1, arr2], axis=0)
         with self.assertRaises(finam.errors.FinamDataError):
-            stripped_ = finam.data.strip_time(data)
-
-    def test_strip_data(self):
-        time = dt(2000, 1, 1)
-        xdata = finam.data.to_xarray(1.0, "data", finam.Info(time, grid=finam.NoGrid()))
-        self.assertEqual(xdata.shape, (1,))
-        stripped = finam.data.strip_data(xdata)
-        self.assertEqual(stripped.shape, ())
-        self.assertTrue(isinstance(stripped, pint.Quantity))
-        self.assertFalse(isinstance(stripped, xr.DataArray))
+            stripped_ = finam.data.strip_time(data, finam.NoGrid())
 
     def test_to_xarray(self):
         time = dt(2000, 1, 1)
+
+        data = finam.data.to_xarray(1.0, finam.Info(time, grid=finam.NoGrid()))
+        self.assertEqual(np.asarray([1.0]) * finam.UNITS(""), data)
+
+        data = finam.data.to_xarray(
+            [[1.0, 1.0], [1.0, 1.0]], finam.Info(time, grid=finam.UniformGrid((3, 3)))
+        )
+        self.assertEqual((1, 2, 2), data.shape)
+
         with self.assertRaises(finam.errors.FinamDataError):
             finam.data.to_xarray(
-                np.asarray([1, 2]), "A", finam.Info(time, grid=finam.NoGrid())
+                np.asarray([1, 2]), finam.Info(time, grid=finam.NoGrid())
             )
 
         with self.assertRaises(finam.errors.FinamDataError):
             finam.data.to_xarray(
-                1.0 * finam.UNITS.meter, "A", finam.Info(time, grid=finam.NoGrid())
+                1.0 * finam.UNITS.meter, finam.Info(time, grid=finam.NoGrid())
             )
 
         with self.assertRaises(finam.errors.FinamDataError):
             finam.data.to_xarray(
                 1.0 * finam.UNITS.meter,
-                "A",
                 finam.Info(time, grid=finam.NoGrid(), units="m^3"),
             )
 
@@ -185,39 +157,39 @@ class TestDataTools(unittest.TestCase):
 
         # using numpy arrays without units
         data = np.asarray([1, 2])
-        xdata = finam.data.to_xarray(data, "test", info_1, force_copy=True)
+        xdata = finam.data.to_xarray(data, info_1, force_copy=True)
         data[0] = 0
         self.assertNotEqual(xdata[0, 0], data[0])
 
         # using numpy arrays with units
         data = np.asarray([1, 2]) * finam.UNITS("m")
-        xdata = finam.data.to_xarray(data, "test", info_1)
+        xdata = finam.data.to_xarray(data, info_1)
         data[0] = 0 * finam.UNITS("m")
         self.assertEqual(xdata[0, 0], data[0])
 
         data = np.asarray([1, 2]) * finam.UNITS("m")
-        xdata = finam.data.to_xarray(data, "test", info_1, force_copy=True)
+        xdata = finam.data.to_xarray(data, info_1, force_copy=True)
         data[0] = 0 * finam.UNITS("m")
         self.assertNotEqual(xdata[0, 0], data[0])
 
         data = np.asarray([1, 2]) * finam.UNITS("m")
-        xdata = finam.data.to_xarray(data, "test", info_2)
+        xdata = finam.data.to_xarray(data, info_2)
         data[0] = 0 * finam.UNITS("m")
         self.assertNotEqual(finam.data.get_magnitude(xdata[0, 0]), 0.0)
 
         # using xarray arrays
-        xdata = finam.data.to_xarray(np.asarray([1, 2]), "test", info_1)
-        xdata2 = finam.data.to_xarray(xdata, "test", info_1)
+        xdata = finam.data.to_xarray(np.asarray([1, 2]), info_1)
+        xdata2 = finam.data.to_xarray(xdata, info_1)
         xdata[0, 0] = 0 * finam.UNITS("m")
         self.assertEqual(xdata2[0, 0], xdata[0, 0])
 
-        xdata = finam.data.to_xarray(np.asarray([1, 2]), "test", info_1)
-        xdata2 = finam.data.to_xarray(xdata, "test", info_1, force_copy=True)
+        xdata = finam.data.to_xarray(np.asarray([1, 2]), info_1)
+        xdata2 = finam.data.to_xarray(xdata, info_1, force_copy=True)
         xdata[0, 0] = 0 * finam.UNITS("m")
         self.assertNotEqual(xdata2[0, 0], xdata[0, 0])
 
-        xdata = finam.data.to_xarray(np.asarray([1, 2]), "test", info_1)
-        xdata2 = finam.data.to_xarray(xdata, "test", info_2)
+        xdata = finam.data.to_xarray(np.asarray([1, 2]), info_1)
+        xdata2 = finam.data.to_xarray(xdata, info_2)
         xdata[0, 0] = 0 * finam.UNITS("m")
         self.assertNotEqual(finam.data.get_magnitude(xdata2[0, 0]), 0.0)
 
@@ -272,7 +244,6 @@ class TestDataTools(unittest.TestCase):
         time = dt(2000, 1, 1)
         xdata = finam.data.to_xarray(
             1.0,
-            "data",
             finam.Info(time, grid=finam.NoGrid()),
         )
 
@@ -282,14 +253,18 @@ class TestDataTools(unittest.TestCase):
             finam.data.tools._check_shape(xdata, finam.NoGrid(dim=1))
 
     def test_quantify(self):
-        xdata = xr.DataArray(1.0, attrs={"units": "m"})
-        xdata = finam.data.quantify(xdata)
+        xdata = np.asarray([1.0])
+        xdata = finam.data.quantify(xdata, "m")
         self.assertEqual(finam.data.get_units(xdata), finam.UNITS.meter)
 
-        xdata = xr.DataArray(1.0)
+        xdata = np.asarray([1.0])
         xdata = finam.data.quantify(xdata)
         self.assertEqual(finam.data.get_units(xdata), finam.UNITS.dimensionless)
 
+        xdata = np.asarray([1.0]) * finam.UNITS("")
+        with self.assertRaises(finam.FinamDataError):
+            xdata = finam.data.quantify(xdata)
+
     def test_to_datetime(self):
         t = np.datetime64("1900-01-01")
         self.assertEqual(datetime.datetime(1900, 1, 1), finam.data.to_datetime(t))
diff --git a/tests/modules/test_callback.py b/tests/modules/test_callback.py
index 6d322dab..bb865bf7 100644
--- a/tests/modules/test_callback.py
+++ b/tests/modules/test_callback.py
@@ -9,7 +9,7 @@ from finam.modules import CallbackComponent, CallbackGenerator, DebugConsumer
 
 
 def transform(inputs, _time):
-    return {"Out1": fm.data.strip_data(inputs["In1"]) * 2.0}
+    return {"Out1": inputs["In1"][0, ...] * 2.0}
 
 
 def consume(_inputs, _time):
-- 
GitLab