diff --git a/src/finam_mhm_module/lai_adapter.py b/src/finam_mhm_module/lai_adapter.py index d70e69b74a881645d9634e9c2039cff50b4a7d65..11da3bbe91e5a3dcd46fd3692aba3eb07f83b967 100644 --- a/src/finam_mhm_module/lai_adapter.py +++ b/src/finam_mhm_module/lai_adapter.py @@ -9,7 +9,6 @@ class YearlyToMonthly(AAdapter): def __init__(self, lai_curve): super().__init__() self.data = None - self.month = 0 if len(lai_curve) != 12: raise ValueError("LAI curve must be an array of 12 values") self.lai_curve = lai_curve @@ -19,7 +18,6 @@ class YearlyToMonthly(AAdapter): raise ValueError("Time must be of type datetime") self.data = self.pull_data(time) - self.month = time.month self.notify_targets(time) @@ -27,9 +25,6 @@ class YearlyToMonthly(AAdapter): if not isinstance(time, datetime): raise ValueError("Time must be of type datetime") - new_month = time.month - lai = [] - for m in range(self.month - 1, new_month): - lai.append(self.data * self.lai_curve[m]) + lai = [self.data * self.lai_curve[m] for m in range(0, len(self.lai_curve))] return lai diff --git a/tests/test_lai_adapter.py b/tests/test_lai_adapter.py index d5defe70bab56bc6d7c83d7a51b4bc96e50f9f58..319681d45c7f8266739c8cbb1457c2a712d4f9e1 100644 --- a/tests/test_lai_adapter.py +++ b/tests/test_lai_adapter.py @@ -1,9 +1,45 @@ import unittest from datetime import datetime, timedelta +from finam.modules.generators import CallbackGenerator +from finam.data.grid import Grid, GridSpec from finam_mhm_module import YearlyToMonthly class TestLaiAdapter(unittest.TestCase): def test_adapter(self): - pass + grid = Grid(GridSpec(3, 2)) + grid.fill(1.0) + for c in range(3): + grid.set(c, 0, 2.0) + + source = CallbackGenerator( + callbacks={"LAI": lambda t: grid}, + start=datetime(2000, 1, 1), + step=timedelta(365.0), + ) + lai_curve = [0.0, 0.0, 0.2, 0.5, 0.8, 1.0, 1.0, 1.0, 0.8, 0.5, 0.2, 0.0] + + adapter = YearlyToMonthly(lai_curve) + source.initialize() + source.outputs["LAI"] >> adapter + + source.connect() + source.validate() + + lai_months = adapter.get_data(datetime(2000, 1, 1, 0)) + self.assertEqual(len(lai_months), 12) + + for m in range(len(lai_months)): + self.assertTrue(isinstance(lai_months[m], Grid)) + self.assertEqual(lai_months[m], grid * lai_curve[m]) + + grid.fill(2.0) + source.update() + + lai_months_2 = adapter.get_data(datetime(2001, 1, 1, 0)) + self.assertEqual(len(lai_months_2), 12) + + for m in range(len(lai_months_2)): + self.assertTrue(isinstance(lai_months_2[m], Grid)) + self.assertEqual(lai_months_2[m], grid * lai_curve[m])