From 03586180c13f95cd9e14ffbc7d385c03f4581843 Mon Sep 17 00:00:00 2001
From: Martin Lange <martin.lange@ufz.de>
Date: Thu, 8 Dec 2022 01:07:25 +0100
Subject: [PATCH] simply re-assign equivalen units, add benchmarks

---
 benchmarks/sdk/test_io.py | 65 ++++++++++++++++++++++++++++++---------
 src/finam/data/tools.py   |  9 ++++--
 src/finam/sdk/input.py    |  2 +-
 3 files changed, 59 insertions(+), 17 deletions(-)

diff --git a/benchmarks/sdk/test_io.py b/benchmarks/sdk/test_io.py
index 47d83248..3b294445 100644
--- a/benchmarks/sdk/test_io.py
+++ b/benchmarks/sdk/test_io.py
@@ -17,13 +17,14 @@ class TestPushPull(unittest.TestCase):
         data = self.data[self.counter % 2]
 
         self.out.push_data(data, self.time)
-        _ = self.inp.pull_data(self.time)
+        data = self.inp.pull_data(self.time)
         self.time += dt.timedelta(days=1)
         self.counter += 1
+        return data
 
     def setup_link(self, grid, target_units):
         self.time = dt.datetime(2000, 1, 1)
-        info1 = fm.Info(time=self.time, grid=grid, units="m")
+        info1 = fm.Info(time=self.time, grid=grid, units="mm")
         info2 = fm.Info(time=self.time, grid=grid, units=target_units)
 
         self.data = [
@@ -42,47 +43,83 @@ class TestPushPull(unittest.TestCase):
     @pytest.mark.benchmark(group="sdk-io")
     def test_push_pull_np_01_2x1(self):
         grid = fm.UniformGrid((2, 1))
-        self.setup_link(grid, target_units="m")
-        self.benchmark(self.push_pull)
+        self.setup_link(grid, target_units="mm")
+        data = self.benchmark(self.push_pull)
+        self.assertEqual(fm.UNITS.millimeter, data.units)
 
     @pytest.mark.benchmark(group="sdk-io")
     def test_push_pull_np_02_512x256(self):
         grid = fm.UniformGrid((512, 256))
-        self.setup_link(grid, target_units="m")
-        self.benchmark(self.push_pull)
+        self.setup_link(grid, target_units="mm")
+        data = self.benchmark(self.push_pull)
+        self.assertEqual(fm.UNITS.millimeter, data.units)
 
     @pytest.mark.benchmark(group="sdk-io")
     def test_push_pull_np_03_1024x512(self):
         grid = fm.UniformGrid((1024, 512))
-        self.setup_link(grid, target_units="m")
-        self.benchmark(self.push_pull)
+        self.setup_link(grid, target_units="mm")
+        data = self.benchmark(self.push_pull)
+        self.assertEqual(fm.UNITS.millimeter, data.units)
 
     @pytest.mark.benchmark(group="sdk-io")
     def test_push_pull_np_04_2048x1024(self):
         grid = fm.UniformGrid((2048, 1024))
-        self.setup_link(grid, target_units="m")
-        self.benchmark(self.push_pull)
+        self.setup_link(grid, target_units="mm")
+        data = self.benchmark(self.push_pull)
+        self.assertEqual(fm.UNITS.millimeter, data.units)
 
     @pytest.mark.benchmark(group="sdk-io")
     def test_push_pull_np_units_01_2x1(self):
         grid = fm.UniformGrid((2, 1))
         self.setup_link(grid, target_units="km")
-        self.benchmark(self.push_pull)
+        data = self.benchmark(self.push_pull)
+        self.assertEqual(fm.UNITS.kilometer, data.units)
 
     @pytest.mark.benchmark(group="sdk-io")
     def test_push_pull_np_units_02_512x256(self):
         grid = fm.UniformGrid((512, 256))
         self.setup_link(grid, target_units="km")
-        self.benchmark(self.push_pull)
+        data = self.benchmark(self.push_pull)
+        self.assertEqual(fm.UNITS.kilometer, data.units)
 
     @pytest.mark.benchmark(group="sdk-io")
     def test_push_pull_np_units_03_1024x512(self):
         grid = fm.UniformGrid((1024, 512))
         self.setup_link(grid, target_units="km")
-        self.benchmark(self.push_pull)
+        data = self.benchmark(self.push_pull)
+        self.assertEqual(fm.UNITS.kilometer, data.units)
 
     @pytest.mark.benchmark(group="sdk-io")
     def test_push_pull_np_units_04_2048x1024(self):
         grid = fm.UniformGrid((2048, 1024))
         self.setup_link(grid, target_units="km")
-        self.benchmark(self.push_pull)
+        data = self.benchmark(self.push_pull)
+        self.assertEqual(fm.UNITS.kilometer, data.units)
+
+    @pytest.mark.benchmark(group="sdk-io")
+    def test_push_pull_np_equiv_01_2x1(self):
+        grid = fm.UniformGrid((2, 1))
+        self.setup_link(grid, target_units="L/m^2")
+        data = self.benchmark(self.push_pull)
+        self.assertEqual(fm.UNITS.Unit("L/m^2"), data.units)
+
+    @pytest.mark.benchmark(group="sdk-io")
+    def test_push_pull_np_equiv_02_512x256(self):
+        grid = fm.UniformGrid((512, 256))
+        self.setup_link(grid, target_units="L/m^2")
+        data = self.benchmark(self.push_pull)
+        self.assertEqual(fm.UNITS.Unit("L/m^2"), data.units)
+
+    @pytest.mark.benchmark(group="sdk-io")
+    def test_push_pull_np_equiv_03_1024x512(self):
+        grid = fm.UniformGrid((1024, 512))
+        self.setup_link(grid, target_units="L/m^2")
+        data = self.benchmark(self.push_pull)
+        self.assertEqual(fm.UNITS.Unit("L/m^2"), data.units)
+
+    @pytest.mark.benchmark(group="sdk-io")
+    def test_push_pull_np_equiv_04_2048x1024(self):
+        grid = fm.UniformGrid((2048, 1024))
+        self.setup_link(grid, target_units="L/m^2")
+        data = self.benchmark(self.push_pull)
+        self.assertEqual(fm.UNITS.Unit("L/m^2"), data.units)
diff --git a/src/finam/data/tools.py b/src/finam/data/tools.py
index 7617c4ca..7989a63a 100644
--- a/src/finam/data/tools.py
+++ b/src/finam/data/tools.py
@@ -272,7 +272,7 @@ def get_dimensionality(xdata):
     return xdata.dimensionality
 
 
-def to_units(xdata, units):
+def to_units(xdata, units, check_equivalent=False):
     """
     Convert data to given units.
 
@@ -282,6 +282,8 @@ def to_units(xdata, units):
         The given data array.
     units : str or pint.Unit
         Desired units.
+    check_equivalent : bool, optional
+        Checks for equivalent units and simply re-assigns if possible.
 
     Returns
     -------
@@ -290,8 +292,11 @@ def to_units(xdata, units):
     """
     check_quantified(xdata, "to_units")
     units = _get_pint_units(units)
-    if units == xdata.units:
+    units2 = xdata.units
+    if units == units2:
         return xdata
+    if check_equivalent and equivalent_units(units, units2):
+        return UNITS.Quantity(xdata.magnitude, units)
     return xdata.to(units)
 
 
diff --git a/src/finam/sdk/input.py b/src/finam/sdk/input.py
index 7841f5a7..2a8a951a 100644
--- a/src/finam/sdk/input.py
+++ b/src/finam/sdk/input.py
@@ -122,7 +122,7 @@ class Input(IInput, Loggable):
             data = self.source.get_data(time, target or self)
 
         with ErrorLogger(self.logger):
-            data = tools.to_units(data, self._input_info.units)
+            data = tools.to_units(data, self._input_info.units, check_equivalent=True)
             tools.check(data, self._input_info)
 
         return data
-- 
GitLab