From 4d8d3d86b2537bd64b1c1ea98e38ad57e7deb5a4 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sebastian=20M=C3=BCller?= <mueller.seb@posteo.de>
Date: Mon, 17 Jul 2023 17:15:54 +0200
Subject: [PATCH] tests: test masking adapter

---
 tests/adapters/test_masking.py | 62 ++++++++++++++++++++++++++++++++++
 1 file changed, 62 insertions(+)
 create mode 100644 tests/adapters/test_masking.py

diff --git a/tests/adapters/test_masking.py b/tests/adapters/test_masking.py
new file mode 100644
index 00000000..4ebc1b9c
--- /dev/null
+++ b/tests/adapters/test_masking.py
@@ -0,0 +1,62 @@
+"""
+Unit tests for masking adapter.
+"""
+import unittest
+from datetime import datetime, timedelta
+
+import numpy as np
+
+from finam import (
+    Composition,
+    EsriGrid,
+    Info,
+)
+from finam.adapters.masking import Masking
+from finam.modules import debug, generators
+
+
+class TestMasking(unittest.TestCase):
+    def test_masking(self):
+        time = datetime(2000, 1, 1)
+
+        mask = [
+            [True, False, True],
+            [False, False, True],
+            [False, False, False],
+            [True, False, False],
+        ]
+
+        in_grid = EsriGrid(ncols=3, nrows=4, order="F")
+        out_grid = EsriGrid(ncols=3, nrows=4, mask=mask, order="F")
+
+        in_info = Info(time=time, grid=in_grid, units="m")
+
+        in_data = np.zeros(shape=in_info.grid.data_shape, order=in_info.grid.order)
+        in_data.data[0, 0] = 1.0
+        in_data.data[0, 1] = 2.0
+
+        source = generators.CallbackGenerator(
+            callbacks={"Output": (lambda t: in_data, in_info)},
+            start=datetime(2000, 1, 1),
+            step=timedelta(days=1),
+        )
+
+        sink = debug.DebugConsumer(
+            {"Input": Info(None, grid=out_grid, units=None)},
+            start=datetime(2000, 1, 1),
+            step=timedelta(days=1),
+        )
+
+        composition = Composition([source, sink], log_level="DEBUG")
+        composition.initialize()
+
+        source.outputs["Output"] >> Masking() >> sink.inputs["Input"]
+
+        composition.connect()
+        self.assertTrue(np.isnan(sink.data["Input"][0][0, 0]))
+        self.assertAlmostEqual(sink.data["Input"][0][0, 1].magnitude, 2.0)
+        print(sink.data["Input"][0])
+
+
+if __name__ == "__main__":
+    unittest.main()
-- 
GitLab