diff --git a/benchmarks/README.md b/benchmarks/README.md index 2c6828b9cc26f1089591569c9978bb1cd7836e0a..ae54c8d5423ba024ce6e7121de6479de73aa55df 100644 --- a/benchmarks/README.md +++ b/benchmarks/README.md @@ -18,9 +18,7 @@ Simple run over one year with two coupled components with daily time step. Groups left to right: * Using numpy arrays, no data copy, no units conversion * Using numpy arrays, with data copy, no units conversion -* Using xarray arrays, no data copy, no units conversion -* Using xarray arrays, with data copy, no units conversion -* Using xarray arrays, no data copy, with units conversion +* Using numpy arrays, no data copy, with units conversion  @@ -28,10 +26,13 @@ Groups left to right: ### Push & pull -Push & pull using numpy arrays (`np`) and xarray arrays (`xr`). -(xarray benchmarks include a call to `fm.tools.assign_time`) +Push & pull using numpy arrays, with and without units conversion. - + + +Push & pull using zero memory limit. I.e. everything written to and re-read from file. + + ## Data diff --git a/benchmarks/data/test_tools.py b/benchmarks/data/test_tools.py index b6cfbb48d66f4195572e898cd18011ef51898132..0fbaa5167df2f73704b950b217d3c92dc7a4c18c 100644 --- a/benchmarks/data/test_tools.py +++ b/benchmarks/data/test_tools.py @@ -69,6 +69,27 @@ class TestPrepare(unittest.TestCase): xdata = full(0.0, info) _result = self.benchmark(prepare, data=xdata, info=info) + @pytest.mark.benchmark(group="data-tools") + def test_prepare_np_units_01_2x1(self): + time = dt.datetime(2000, 1, 1) + info = fm.Info(time=time, grid=fm.UniformGrid((2, 1)), units="m") + xdata = full(0.0, info).magnitude + _result = self.benchmark(prepare, data=xdata, info=info) + + @pytest.mark.benchmark(group="data-tools") + def test_prepare_np_units_02_512x256(self): + time = dt.datetime(2000, 1, 1) + info = fm.Info(time=time, grid=fm.UniformGrid((512, 256)), units="m") + xdata = full(0.0, info).magnitude + _result = self.benchmark(prepare, data=xdata, info=info) + + @pytest.mark.benchmark(group="data-tools") + def test_prepare_np_units_03_2048x1024(self): + time = dt.datetime(2000, 1, 1) + info = fm.Info(time=time, grid=fm.UniformGrid((2048, 1024)), units="m") + xdata = full(0.0, info).magnitude + _result = self.benchmark(prepare, data=xdata, info=info) + @pytest.mark.benchmark(group="data-tools-slow") def test_cp_prepare_np_01_2x1(self): time = dt.datetime(2000, 1, 1) diff --git a/benchmarks/numpy/test_save_load.py b/benchmarks/numpy/test_save_load.py new file mode 100644 index 0000000000000000000000000000000000000000..3db7271f5d5b7123010b0b72e1aab12b652f4b34 --- /dev/null +++ b/benchmarks/numpy/test_save_load.py @@ -0,0 +1,74 @@ +import os.path +import tempfile +import unittest + +import numpy as np +import pytest + +import finam as fm + + +class TestCreateUniform(unittest.TestCase): + @pytest.fixture(autouse=True) + def setupBenchmark(self, benchmark): + self.benchmark = benchmark + + @pytest.mark.benchmark(group="np-save-load") + def test_save_01_64x32(self): + xdata = np.full((1, 64, 32), 1.0, dtype=np.dtype(np.float64)) + with tempfile.TemporaryDirectory() as d: + fp = os.path.join(d, "temp.npy") + _result = self.benchmark(np.save, file=fp, arr=xdata) + + @pytest.mark.benchmark(group="np-save-load") + def test_save_02_512x256(self): + xdata = np.full((1, 512, 256), 1.0, dtype=np.dtype(np.float64)) + with tempfile.TemporaryDirectory() as d: + fp = os.path.join(d, "temp.npy") + _result = self.benchmark(np.save, file=fp, arr=xdata) + + @pytest.mark.benchmark(group="np-save-load") + def test_save_03_1024x512(self): + xdata = np.full((1, 1024, 512), 1.0, dtype=np.dtype(np.float64)) + with tempfile.TemporaryDirectory() as d: + fp = os.path.join(d, "temp.npy") + _result = self.benchmark(np.save, file=fp, arr=xdata) + + @pytest.mark.benchmark(group="np-save-load") + def test_save_04_2048x1024(self): + xdata = np.full((1, 2048, 1024), 1.0, dtype=np.dtype(np.float64)) + with tempfile.TemporaryDirectory() as d: + fp = os.path.join(d, "temp.npy") + _result = self.benchmark(np.save, file=fp, arr=xdata) + + @pytest.mark.benchmark(group="np-save-load") + def test_load_01_64x32(self): + xdata = np.full((1, 64, 32), 1.0, dtype=np.dtype(np.float64)) + with tempfile.TemporaryDirectory() as d: + fp = os.path.join(d, "temp.npy") + np.save(fp, xdata) + _result = self.benchmark(np.load, file=fp) + + @pytest.mark.benchmark(group="np-save-load") + def test_load_02_512x256(self): + xdata = np.full((1, 512, 256), 1.0, dtype=np.dtype(np.float64)) + with tempfile.TemporaryDirectory() as d: + fp = os.path.join(d, "temp.npy") + np.save(fp, xdata) + _result = self.benchmark(np.load, file=fp) + + @pytest.mark.benchmark(group="np-save-load") + def test_load_03_1024x512(self): + xdata = np.full((1, 1024, 512), 1.0, dtype=np.dtype(np.float64)) + with tempfile.TemporaryDirectory() as d: + fp = os.path.join(d, "temp.npy") + np.save(fp, xdata) + _result = self.benchmark(np.load, file=fp) + + @pytest.mark.benchmark(group="np-save-load") + def test_load_04_2048x1024(self): + xdata = np.full((1, 2048, 1024), 1.0, dtype=np.dtype(np.float64)) + with tempfile.TemporaryDirectory() as d: + fp = os.path.join(d, "temp.npy") + np.save(fp, xdata) + _result = self.benchmark(np.load, file=fp) diff --git a/benchmarks/profiling/mem_huge_memory.py b/benchmarks/profiling/profile_huge_memory.py similarity index 65% rename from benchmarks/profiling/mem_huge_memory.py rename to benchmarks/profiling/profile_huge_memory.py index 5d7d18b55fb9736234f8ac5cb756a621d06cc890..3ec5831a60a617f0fb3910c077f684162ed73580 100644 --- a/benchmarks/profiling/mem_huge_memory.py +++ b/benchmarks/profiling/profile_huge_memory.py @@ -1,10 +1,18 @@ +import cProfile import datetime as dt +import io +import pstats +import sys +import time import numpy as np import finam as fm -if __name__ == "__main__": + +def run_model(): + t = time.time() + start_time = dt.datetime(2000, 1, 1) end_time = dt.datetime(2002, 1, 1) @@ -29,9 +37,23 @@ if __name__ == "__main__": step=dt.timedelta(days=365), ) - composition = fm.Composition([source, sink]) + composition = fm.Composition([source, sink], slot_memory_limit=500 * 2**20) composition.initialize() source["Out"] >> sink["In"] composition.run(end_time=end_time) + + print("Total time:", time.time() - t) + + +if __name__ == "__main__": + pr = cProfile.Profile() + pr.enable() + + run_model() + + pr.disable() + s = io.StringIO() + ps = pstats.Stats(pr, stream=s).sort_stats(pstats.SortKey.CUMULATIVE) + ps.dump_stats(sys.argv[1]) diff --git a/benchmarks/sdk/test_io.py b/benchmarks/sdk/test_io.py index 3b294445ef093adb15afc6ed4006eb4b395952bd..d9b20453e6784ea212c4dc2ae52df46116d81a94 100644 --- a/benchmarks/sdk/test_io.py +++ b/benchmarks/sdk/test_io.py @@ -1,4 +1,5 @@ import datetime as dt +import tempfile import unittest import pytest @@ -6,12 +7,7 @@ import pytest import finam as fm -class TestPushPull(unittest.TestCase): - @pytest.fixture(autouse=True) - def setupBenchmark(self, benchmark): - self.benchmark = benchmark - self.counter = 0 - +class TestPushPullBase(unittest.TestCase): def push_pull(self): # Trick the shared memory check in the output data = self.data[self.counter % 2] @@ -22,7 +18,7 @@ class TestPushPull(unittest.TestCase): self.counter += 1 return data - def setup_link(self, grid, target_units): + def setup_link(self, grid, target_units, memory_limit=None, tempdir=None): self.time = dt.datetime(2000, 1, 1) info1 = fm.Info(time=self.time, grid=grid, units="mm") info2 = fm.Info(time=self.time, grid=grid, units=target_units) @@ -35,11 +31,21 @@ class TestPushPull(unittest.TestCase): self.out = fm.Output(name="Output") self.inp = fm.Input(name="Input") + self.out.memory_limit = memory_limit + self.out.memory_location = tempdir + self.out >> self.inp self.inp.ping() self.out.push_info(info1) self.inp.exchange_info(info2) + +class TestPushPull(TestPushPullBase): + @pytest.fixture(autouse=True) + def setupBenchmark(self, benchmark): + self.benchmark = benchmark + self.counter = 0 + @pytest.mark.benchmark(group="sdk-io") def test_push_pull_np_01_2x1(self): grid = fm.UniformGrid((2, 1)) @@ -123,3 +129,35 @@ class TestPushPull(unittest.TestCase): self.setup_link(grid, target_units="L/m^2") data = self.benchmark(self.push_pull) self.assertEqual(fm.UNITS.Unit("L/m^2"), data.units) + + @pytest.mark.benchmark(group="sdk-io-mem") + def test_push_pull_file_01_2x1(self): + grid = fm.UniformGrid((2, 1)) + with tempfile.TemporaryDirectory() as td: + self.setup_link(grid, target_units="m", memory_limit=0, tempdir=td) + self.benchmark(self.push_pull) + self.out.finalize() + + @pytest.mark.benchmark(group="sdk-io-mem") + def test_push_pull_file_02_512x256(self): + grid = fm.UniformGrid((512, 256)) + with tempfile.TemporaryDirectory() as td: + self.setup_link(grid, target_units="m", memory_limit=0, tempdir=td) + self.benchmark(self.push_pull) + self.out.finalize() + + @pytest.mark.benchmark(group="sdk-io-mem") + def test_push_pull_file_03_1024x512(self): + grid = fm.UniformGrid((1024, 512)) + with tempfile.TemporaryDirectory() as td: + self.setup_link(grid, target_units="m", memory_limit=0, tempdir=td) + self.benchmark(self.push_pull) + self.out.finalize() + + @pytest.mark.benchmark(group="sdk-io-mem") + def test_push_pull_file_04_2048x1024(self): + grid = fm.UniformGrid((2048, 1024)) + with tempfile.TemporaryDirectory() as td: + self.setup_link(grid, target_units="m", memory_limit=0, tempdir=td) + self.benchmark(self.push_pull) + self.out.finalize() diff --git a/src/finam/adapters/time.py b/src/finam/adapters/time.py index 16387d224be1e1edd921e04c56f535573457b460..cd9f5e1775a6c4907ff6e552cf047fa6d6646230 100644 --- a/src/finam/adapters/time.py +++ b/src/finam/adapters/time.py @@ -1,6 +1,7 @@ """ Adapters that deal with time, like temporal interpolation and integration. """ +import os from abc import ABC, abstractmethod from datetime import datetime, timedelta @@ -241,7 +242,7 @@ class TimeCachingAdapter(Adapter, NoBranchAdapter, ABC): check_time(self.logger, time) data = dtools.strip_time(self.pull_data(time, self), self._input_info.grid) - self.data.append((time, data)) + self.data.append((time, self._pack(data))) def _get_data(self, time, _target): """Get the output's data-set for the given time. @@ -267,7 +268,11 @@ class TimeCachingAdapter(Adapter, NoBranchAdapter, ABC): def _clear_cached_data(self, time): while len(self.data) > 1 and self.data[1][0] <= time: - self.data.pop(0) + d = self.data.pop(0) + if isinstance(d[1], str): + os.remove(d[1]) + else: + self._total_mem -= d[1].nbytes @abstractmethod def _interpolate(self, time): @@ -289,13 +294,13 @@ class NextTime(TimeCachingAdapter): def _interpolate(self, time): if len(self.data) == 1: - return self.data[0][1] + return self._unpack(self.data[0][1]) for t, data in self.data: if time > t: continue - return data + return self._unpack(data) raise FinamTimeError( f"Time interpolation failed. This should not happen and is probably a bug. " @@ -318,16 +323,16 @@ class PreviousTime(TimeCachingAdapter): def _interpolate(self, time): if len(self.data) == 1: - return self.data[0][1] + return self._unpack(self.data[0][1]) for i, (t, data) in enumerate(self.data): if time > t: continue if time == t: - return data + return self._unpack(data) _, data_prev = self.data[i - 1] - return data_prev + return self._unpack(data_prev) raise FinamTimeError( f"Time interpolation failed. This should not happen and is probably a bug. " @@ -353,10 +358,10 @@ class StackTime(TimeCachingAdapter): for t, data in self.data: if time > t: - extract.append((t, data)) + extract.append((t, self._unpack(data))) continue - extract.append((t, data)) + extract.append((t, self._unpack(data))) break arr = np.stack([d[1] for d in extract]) @@ -396,13 +401,13 @@ class LinearTime(TimeCachingAdapter): if time > t: continue if time == t: - return data + return self._unpack(data) t_prev, data_prev = self.data[i - 1] dt = (time - t_prev) / (t - t_prev) - result = interpolate(data_prev, data, dt) + result = interpolate(self._unpack(data_prev), self._unpack(data), dt) return result @@ -459,7 +464,7 @@ class StepTime(TimeCachingAdapter): if time > t: continue if time == t: - return data + return self._unpack(data) t_prev, data_prev = self.data[i - 1] @@ -467,7 +472,7 @@ class StepTime(TimeCachingAdapter): result = interpolate_step(data_prev, data, dt, self.step) - return result + return self._unpack(result) raise FinamTimeError( f"Time interpolation failed. This should not happen and is probably a bug. " diff --git a/src/finam/adapters/time_integration.py b/src/finam/adapters/time_integration.py index 45e4111dba4407c19484c45304ba64f308fff897..5d18cd5957766f41b0d315a042b7349331df1a61 100644 --- a/src/finam/adapters/time_integration.py +++ b/src/finam/adapters/time_integration.py @@ -27,7 +27,7 @@ class TimeIntegrationAdapter(TimeCachingAdapter, ABC): check_time(self.logger, time) data = tools.strip_time(self.pull_data(time, self), self._input_info.grid) - self.data.append((time, data)) + self.data.append((time, self._pack(data))) if self._prev_time is None: self._prev_time = time @@ -111,18 +111,21 @@ class AvgOverTime(TimeIntegrationAdapter): def _interpolate(self, time): if len(self.data) == 1: - return self.data[0][1] + return self._unpack(self.data[0][1]) if time <= self.data[0][0]: - return self.data[0][1] + return self._unpack(self.data[0][1]) sum_value = None + t_old, v_old = self.data[0] + v_old = self._unpack(v_old) for i in range(len(self.data) - 1): - t_old, v_old = self.data[i] t_new, v_new = self.data[i + 1] + v_new = self._unpack(v_new) if self._prev_time >= t_new: + t_old, v_old = t_new, v_new continue if time <= t_old: break @@ -147,6 +150,8 @@ class AvgOverTime(TimeIntegrationAdapter): sum_value = value if sum_value is None else sum_value + value + t_old, v_old = t_new, v_new + dt = time - self._prev_time if dt.total_seconds() > 0: sum_value /= dt.total_seconds() * tools.UNITS.Unit("s") @@ -239,20 +244,22 @@ class SumOverTime(TimeIntegrationAdapter): if len(self.data) == 1 or time <= self.data[0][0]: if self._per_time: return ( - self.data[0][1] - * self._initial_interval.total_seconds() - * tools.UNITS.Unit("s") + self._unpack(self.data[0][1]) + * (self._initial_interval.total_seconds() * tools.UNITS.Unit("s")) ).to_reduced_units() - return self.data[0][1] + return self._unpack(self.data[0][1]) sum_value = None + t_old, v_old = self.data[0] + v_old = self._unpack(v_old) for i in range(len(self.data) - 1): - t_old, v_old = self.data[i] t_new, v_new = self.data[i + 1] + v_new = self._unpack(v_new) if self._prev_time >= t_new: + t_old, v_old = t_new, v_new continue if time <= t_old: break @@ -278,6 +285,8 @@ class SumOverTime(TimeIntegrationAdapter): sum_value = value if sum_value is None else sum_value + value + t_old, v_old = t_new, v_new + if self._per_time: return sum_value.to_reduced_units() diff --git a/src/finam/interfaces.py b/src/finam/interfaces.py index 5f371a013a2fe96f6a960338680322b12d8495a9..1fdeb3f90fe9565a5b0713cd3385715f439db29b 100644 --- a/src/finam/interfaces.py +++ b/src/finam/interfaces.py @@ -278,6 +278,26 @@ class IOutput(ABC): def needs_push(self): """bool: if the output needs push.""" + @property + @abstractmethod + def memory_limit(self): + """The memory limit for this slot""" + + @memory_limit.setter + @abstractmethod + def memory_limit(self, limit): + """The memory limit for this slot""" + + @property + @abstractmethod + def memory_location(self): + """The memory-mapping location for this slot""" + + @memory_location.setter + @abstractmethod + def memory_location(self, directory): + """The memory-mapping location for this slot""" + @abstractmethod def has_info(self): """Returns if the output has a data info. @@ -407,6 +427,10 @@ class IOutput(ABC): The last element of the chain. """ + @abstractmethod + def finalize(self): + """Finalize the output""" + def __rshift__(self, other): return self.chain(other) diff --git a/src/finam/schedule.py b/src/finam/schedule.py index d575ca4ef78a20a9b9644d1569b760e4585d5d73..1b98b6750a9a4c05d7b6bb16016a2573e6a135af 100644 --- a/src/finam/schedule.py +++ b/src/finam/schedule.py @@ -10,6 +10,7 @@ Composition :noindex: Composition """ import logging +import os import sys from datetime import datetime from pathlib import Path @@ -70,8 +71,13 @@ class Composition(Loggable): Whether to write a log file, by default None log_level : int or str, optional Logging level, by default logging.INFO - mpi_rank : int, default 0 - MPI rank of the composition. + slot_memory_limit : int, optional + Memory limit per output and adapter data, in bytes. + When the limit is exceeded, data is stored to disk under the path of `slot_memory_location`. + Default: no limit (``None``). + slot_memory_location : str, optional + Location for storing data when exceeding ``slot_memory_limit``. + Default: "temp". """ def __init__( @@ -81,7 +87,8 @@ class Composition(Loggable): print_log=True, log_file=None, log_level=logging.INFO, - mpi_rank=0, + slot_memory_limit=None, + slot_memory_location="temp", ): super().__init__() # setup logger @@ -116,7 +123,9 @@ class Composition(Loggable): self.output_owners = None self.is_initialized = False self.is_connected = False - self.mpi_rank = mpi_rank + + self.slot_memory_limit = slot_memory_limit + self.slot_memory_location = slot_memory_location def initialize(self): """Initialize all modules. @@ -130,6 +139,9 @@ class Composition(Loggable): for mod in self.modules: self._check_status(mod, [ComponentStatus.CREATED]) + if self.slot_memory_location is not None: + os.makedirs(self.slot_memory_location, exist_ok=True) + for mod in self.modules: if is_loggable(mod) and mod.uses_base_logger_name: mod.base_logger_name = self.logger_name @@ -139,6 +151,12 @@ class Composition(Loggable): mod.inputs.set_logger(mod) mod.outputs.set_logger(mod) + for _, out in mod.outputs.items(): + if out.memory_limit is None: + out.memory_limit = self.slot_memory_limit + if out.memory_location is None: + out.memory_location = self.slot_memory_location + self._check_status(mod, [ComponentStatus.INITIALIZED]) self.is_initialized = True @@ -168,17 +186,7 @@ class Composition(Loggable): ) else: if start_time is None: - t_min = None - for mod in time_modules: - if mod.time is not None: - if t_min is None or mod.time < t_min: - t_min = mod.time - if t_min is None: - raise ValueError( - "Unable to determine starting time of the composition." - "Please provide a starting time in ``run()`` or ``connect()``" - ) - start_time = t_min + start_time = _get_start_time(time_modules) if not isinstance(start_time, datetime): raise ValueError( "start must be of type datetime for a composition with time components" @@ -187,6 +195,12 @@ class Composition(Loggable): self._collect_adapters() self._validate_composition() + for ada in self.adapters: + if ada.memory_limit is None: + ada.memory_limit = self.slot_memory_limit + if ada.memory_location is None: + ada.memory_location = self.slot_memory_location + self._connect_components(start_time) self.logger.debug("validate components") @@ -432,6 +446,20 @@ def _collect_adapters_output(out: IOutput, out_adapters: set): _collect_adapters_output(trg, out_adapters) +def _get_start_time(time_modules): + t_min = None + for mod in time_modules: + if mod.time is not None: + if t_min is None or mod.time < t_min: + t_min = mod.time + if t_min is None: + raise ValueError( + "Unable to determine starting time of the composition." + "Please provide a starting time in ``run()`` or ``connect()``" + ) + return t_min + + def _check_missing_modules(modules): inputs, outputs = _collect_inputs_outputs(modules) diff --git a/src/finam/sdk/component.py b/src/finam/sdk/component.py index 3fd7a3beb4616aa065e79439b3e9f014b821c400..24c1fd7316c02a3aca50a9dd24b141bbd70b0e4a 100644 --- a/src/finam/sdk/component.py +++ b/src/finam/sdk/component.py @@ -173,6 +173,10 @@ class Component(IComponent, Loggable, ABC): """ self.logger.debug("finalize") self._finalize() + + for _n, out in self.outputs.items(): + out.finalize() + if self.status != ComponentStatus.FAILED: self.status = ComponentStatus.FINALIZED diff --git a/src/finam/sdk/output.py b/src/finam/sdk/output.py index f2bbaf7665465c1f1c016271cf214f1eaa7730e1..bc1250e9665cdfc4a9965ed75f7c5c1df498bf4d 100644 --- a/src/finam/sdk/output.py +++ b/src/finam/sdk/output.py @@ -2,6 +2,7 @@ Implementations of IOutput """ import logging +import os from datetime import datetime import numpy as np @@ -19,6 +20,7 @@ from ..interfaces import IAdapter, IInput, IOutput, Loggable from ..tools.log_helper import ErrorLogger +# pylint: disable=too-many-public-methods class Output(IOutput, Loggable): """Default output implementation.""" @@ -44,6 +46,10 @@ class Output(IOutput, Loggable): self._out_infos_exchanged = 0 self._time = None + self._mem_limit = None + self._mem_location = None + self._total_mem = 0 + self._mem_counter = 0 @property def time(self): @@ -75,6 +81,26 @@ class Output(IOutput, Loggable): """bool: if the output needs push.""" return True + @property + def memory_limit(self): + """The memory limit for this slot""" + return self._mem_limit + + @memory_limit.setter + def memory_limit(self, limit): + """The memory limit for this slot""" + self._mem_limit = limit + + @property + def memory_location(self): + """The memory-mapping location for this slot""" + return self._mem_location + + @memory_location.setter + def memory_location(self, directory): + """The memory-mapping location for this slot""" + self._mem_location = directory + def has_info(self): """Returns if the output has a data info. @@ -149,13 +175,13 @@ class Output(IOutput, Loggable): with ErrorLogger(self.logger): xdata = tools.prepare(data, self.info) - if len(self.data) > 0: + if len(self.data) > 0 and not isinstance(self.data[-1][1], str): d = self.data[-1][1] if np.may_share_memory(d.data, xdata.data): raise FinamDataError( "Received data that shares memory with previously received data." ) - + xdata = self._pack(xdata) self.data.append((time, xdata)) self._time = time @@ -227,7 +253,11 @@ class Output(IOutput, Loggable): raise FinamNoDataError(f"No data available in {self.name}") with ErrorLogger(self.logger): - data = self.data[0][1] if self.is_static else self._interpolate(time) + data = ( + self._unpack(self.data[0][1]) + if self.is_static + else self._interpolate(time) + ) if not self.is_static: data_count = len(self.data) @@ -240,6 +270,37 @@ class Output(IOutput, Loggable): return data + def _pack(self, data): + data_size = data.nbytes + if self.memory_limit is not None and 0 <= self.memory_limit < ( + self._total_mem + data_size + ): + fn = os.path.join( + self.memory_location or "", f"{id(self)}-{self._mem_counter}.npy" + ) + self.logger.debug( + "dumping data to file %s (total RAM %0.2f MB)", + fn, + self._total_mem / 1048576, + ) + self._mem_counter += 1 + np.save(fn, data.magnitude) + return fn + + self._total_mem += data_size + self.logger.debug( + "keeping data in RAM (total RAM %0.2f MB)", self._total_mem / 1048576 + ) + return data + + def _unpack(self, where): + if isinstance(where, str): + self.logger.debug("reading data from file %s", where) + data = np.load(where, allow_pickle=True) + return tools.UNITS.Quantity(data, self.info.units) + + return where + def _clear_data(self, time, target): self._connected_inputs[target] = time if any(t is None for t in self._connected_inputs.values()): @@ -247,7 +308,18 @@ class Output(IOutput, Loggable): t_min = min(self._connected_inputs.values()) while len(self.data) > 1 and self.data[1][0] <= t_min: - self.data.pop(0) + d = self.data.pop(0) + if isinstance(d[1], str): + os.remove(d[1]) + else: + self._total_mem -= d[1].nbytes + + def finalize(self): + """Finalize the output""" + for _t, d in self.data: + if isinstance(d, str): + os.remove(d) + self.data.clear() def _interpolate(self, time): if time < self.data[0][0] or time > self.data[-1][0]: @@ -258,16 +330,16 @@ class Output(IOutput, Loggable): if time > t: continue if time == t: - return data + return self._unpack(data) t_prev, data_prev = self.data[i - 1] diff = t - t_prev t_half = t_prev + diff / 2 if time < t_half: - return data_prev + return self._unpack(data_prev) - return data + return self._unpack(data) raise FinamTimeError( f"Time interpolation failed. This should not happen and is probably a bug. " @@ -453,6 +525,9 @@ class CallbackOutput(Output): self.last_data = xdata return xdata + def finalize(self): + """Finalize the output""" + def _check_time(time, is_static): if is_static: diff --git a/tests/core/test_sdk.py b/tests/core/test_sdk.py index 9c0928bd7a53eda4f06f7fc1d49327eee7277ec6..cdd1efba94de44c4d18359048fb7d95a736e0492 100644 --- a/tests/core/test_sdk.py +++ b/tests/core/test_sdk.py @@ -2,9 +2,13 @@ Unit tests for the sdk implementations. """ import logging +import os.path +import tempfile import unittest from datetime import datetime, timedelta +import numpy as np + import finam as fm from finam import ( Adapter, @@ -397,9 +401,50 @@ class TestOutput(unittest.TestCase): out_data = in1.pull_data(t, in1) self.assertEqual(out_data[0, 0, 0], 0.0 * fm.UNITS("km")) - in_data[0, 0, 0] = 1.0 * fm.UNITS("m") + in_data[0, 0] = 1.0 * fm.UNITS("m") self.assertEqual(out_data[0, 0, 0], 0.0 * fm.UNITS("km")) + def test_memory_limit(self): + t = datetime(2000, 1, 1) + info = Info(time=t, grid=fm.UniformGrid((100, 100))) + + with tempfile.TemporaryDirectory() as td: + + out = Output(name="Output") + out.memory_limit = 0 + out.memory_location = td + oid = id(out) + + in1 = Input(name="Input") + + out >> in1 + + in1.ping() + + out.push_info(info) + in1.exchange_info(info) + + in_data = fm.data.full(0.0, info) + out.push_data(np.copy(in_data), datetime(2000, 1, 1)) + out.push_data(np.copy(in_data), datetime(2000, 1, 2)) + + self.assertTrue(os.path.isfile(os.path.join(td, f"{oid}-{0}.npy"))) + self.assertTrue(os.path.isfile(os.path.join(td, f"{oid}-{1}.npy"))) + + data = in1.pull_data(datetime(2000, 1, 2), in1) + + np.testing.assert_allclose(data.magnitude, in_data.magnitude) + self.assertEqual(data.units, in_data.units) + self.assertEqual(data.units, info.units) + + self.assertFalse(os.path.isfile(os.path.join(td, f"{oid}-{0}.npy"))) + self.assertTrue(os.path.isfile(os.path.join(td, f"{oid}-{1}.npy"))) + + out.finalize() + + self.assertFalse(os.path.isfile(os.path.join(td, f"{oid}-{0}.npy"))) + self.assertFalse(os.path.isfile(os.path.join(td, f"{oid}-{1}.npy"))) + class TestInput(unittest.TestCase): def test_fail_set_source(self):