diff --git a/src/finam/interfaces.py b/src/finam/interfaces.py index 30efe0ba89209cdcfec29f8b4489e3007098cb82..ceb8a774f29a631942b52c028c81e5e9feeb77a8 100644 --- a/src/finam/interfaces.py +++ b/src/finam/interfaces.py @@ -402,6 +402,10 @@ class IOutput(ABC): class IAdapter(IInput, IOutput, ABC): """Interface for adapters.""" + @abstractmethod + def finalize(self): + """Called at the end of each run. Can be used for cleanup.""" + class NoBranchAdapter: """Interface to mark adapters as allowing only a single end point.""" diff --git a/src/finam/schedule.py b/src/finam/schedule.py index 15cff66ec698ad0e7eb58ba3f0430c90fae645d5..9f34e3d5078491bfdbbf53ef28c300090a7878a1 100644 --- a/src/finam/schedule.py +++ b/src/finam/schedule.py @@ -23,6 +23,7 @@ from .errors import ( ) from .interfaces import ( ComponentStatus, + IAdapter, IComponent, IInput, IOutput, @@ -109,6 +110,7 @@ class Composition(Loggable): "Composition: modules need to be instances of 'IComponent'." ) self.modules = modules + self.adapters = set() self.dependencies = None self.output_owners = None self.is_initialized = False @@ -181,6 +183,7 @@ class Composition(Loggable): "start must be of type datetime for a composition with time components" ) + self._collect_adapters() self._validate_composition() self._connect_components(start_time) @@ -297,6 +300,13 @@ class Composition(Loggable): return None + def _collect_adapters(self): + for mod in self.modules: + for _, inp in mod.inputs.items(): + _collect_adapters_input(inp, self.adapters) + for _, out in mod.outputs.items(): + _collect_adapters_output(out, self.adapters) + def _validate_composition(self): """Validates the coupling setup by checking for dangling inputs and disallowed branching connections.""" self.logger.debug("validate composition") @@ -375,6 +385,9 @@ class Composition(Loggable): mod.finalize() self._check_status(mod, [ComponentStatus.FINALIZED]) + for ada in self.adapters: + ada.finalize() + def _finalize_composition(self): self.logger.debug("finalize composition") handlers = self.logger.handlers[:] @@ -401,6 +414,23 @@ class Composition(Loggable): ) +def _collect_adapters_input(inp: IInput, out_adapters: set): + src = inp.get_source() + if src is None: + return + + if isinstance(src, IAdapter): + out_adapters.add(src) + _collect_adapters_input(src, out_adapters) + + +def _collect_adapters_output(out: IOutput, out_adapters: set): + for trg in out.get_targets(): + if isinstance(trg, IAdapter): + out_adapters.add(trg) + _collect_adapters_output(trg, out_adapters) + + def _check_missing_modules(modules): inputs, outputs = _collect_inputs_outputs(modules) diff --git a/src/finam/sdk/adapter.py b/src/finam/sdk/adapter.py index 2005ce328a1792ea57022e031b8ba32589d52bda..2e78ccababc50f75e219fdfa9a05d385fdfbf2ed 100644 --- a/src/finam/sdk/adapter.py +++ b/src/finam/sdk/adapter.py @@ -283,6 +283,14 @@ class Adapter(IAdapter, Input, Output, ABC): base_logger = logging.getLogger(self.base_logger_name) return ".".join(([base_logger.name, " >> ", self.name])) + def finalize(self): + """Called at the end of each run. Calls :meth:`._finalize`.""" + self.logger.debug("finalize") + self._finalize() + + def _finalize(self): + """Called at the end of each run. Overwrite this for cleanup.""" + class TimeDelayAdapter(Adapter, ITimeDelayAdapter, ABC): """Base class for adapters that delay/offset time to resolve dependency cycles.""" diff --git a/tests/core/test_schedule.py b/tests/core/test_schedule.py index e1736171f51f70cac9fd142e5f20b315e5db6c78..a80ebb623e648ebd2dcd5e276e42168c49d16e6d 100644 --- a/tests/core/test_schedule.py +++ b/tests/core/test_schedule.py @@ -310,6 +310,22 @@ class TestComposition(unittest.TestCase): lines = f.readlines() self.assertNotEqual(len(lines), 0) + def test_collect_adapters(self): + module1 = MockupComponent( + callbacks={"Output": lambda t: t.day}, step=timedelta(1.0) + ) + module2 = MockupDependentComponent(step=timedelta(1.0)) + + composition = Composition([module2, module1]) + composition.initialize() + + ada = fm.adapters.Scale(1.0) + module1.outputs["Output"] >> ada >> module2.inputs["Input"] + + composition.connect() + + self.assertEqual({ada}, composition.adapters) + def test_fail_time(self): module1 = MockupComponent( callbacks={"Output": lambda t: t}, step=timedelta(1.0)