Skip to content
Snippets Groups Projects

Auto-transfer infos

Merged Martin Lange requested to merge auto-transfer-info into main
1 file
+ 2
5
Compare changes
  • Side-by-side
  • Inline
"""Iterative connection helpers."""
import copy
import logging
from abc import ABC
from finam.interfaces import ComponentStatus, Loggable
from ..data.tools import Info
from ..errors import FinamNoDataError
from ..tools.log_helper import ErrorLogger
class MissingInfoError(Exception):
"""Internal error type for handling missing infos for transfer rules"""
class InfoSource(ABC):
"""Base class for info transfer rules from inputs or outputs"""
def __init__(self, name, *fields):
self.name = name
self.fields = list(*fields) or []
class FromInput(InfoSource):
"""Info transfer rule from an input.
See :meth:`.Component.create_connector` for usage details.
Parameters
----------
name : str
Name of the input to take info from
*fields : str, optional
Info fields to take from the input.
Takes all fields if this is empty.
"""
def __init__(self, name, *fields):
super().__init__(name, fields)
class FromOutput(InfoSource):
"""Info transfer rule from an output.
See :meth:`.Component.create_connector` for usage details.
Parameters
----------
name : str
Name of the output to take info from
*fields : str, optional
Info fields to take from the output.
Takes all fields if this is empty.
"""
def __init__(self, name, *fields):
super().__init__(name, fields)
class FromValue:
"""
Info transfer rule from a given value
Parameters
----------
field : str
Field to set.
value : any
Value to set.
"""
def __init__(self, field, value):
self.field = field
self.value = value
class ConnectHelper(Loggable):
"""Helper for iterative connect.
Warning:
This class is not intended for direct use!
Use :meth:`.Components.create_connector` and :meth:`.Components.try_connect` instead.
Use :meth:`.Component.create_connector` and :meth:`.Component.try_connect` instead.
Parameters
----------
@@ -22,6 +90,10 @@ class ConnectHelper(Loggable):
All inputs of the component.
outputs : dict
All outputs of the component.
in_info_rules : dict
Info transfer rules for inputs.
out_info_rules : dict
Info transfer rules for outputs.
pull_data : arraylike
Names of the inputs that are to be pulled.
cache : bool
@@ -35,6 +107,8 @@ class ConnectHelper(Loggable):
inputs,
outputs,
pull_data=None,
in_info_rules=None,
out_info_rules=None,
cache=True,
):
@@ -46,7 +120,7 @@ class ConnectHelper(Loggable):
with ErrorLogger(self.logger):
for name in pull_data or []:
if name not in self._inputs:
raise ValueError(
raise KeyError(
f"No input named '{name}' available to get info for."
)
@@ -62,10 +136,105 @@ class ConnectHelper(Loggable):
name: False for name, out in self.outputs.items() if out.needs_push
}
self._in_info_rules = in_info_rules or {}
self._out_info_rules = out_info_rules or {}
with ErrorLogger(self.logger):
self._check_info_rules()
self._in_info_cache = {}
self._out_info_cache = {}
self._out_data_cache = {}
def add_in_info_rule(self, in_name, rule):
"""
Add an input info rule.
Parameters
----------
in_name : str
Name of the input to add an info rule to.
rule : FromOutput or FromInput or FromValue
Rule to add.
"""
if in_name in self._in_info_rules:
self._in_info_rules[in_name].append(rule)
else:
self._in_info_rules[in_name] = [rule]
with ErrorLogger(self.logger):
self._check_info_rules()
def add_out_info_rule(self, out_name, rule):
"""
Add an output info rule.
Parameters
----------
out_name : str
Name of the output to add an info rule to.
rule : FromInput or FromOutput or FromValue
Rule to add.
"""
if out_name in self._out_info_rules:
self._out_info_rules[out_name].append(rule)
else:
self._out_info_rules[out_name] = [rule]
with ErrorLogger(self.logger):
self._check_info_rules()
def _apply_rules(self, rules):
info = Info(time=None, grid=None)
for rule in rules:
if isinstance(rule, FromInput):
in_info = self.in_infos[rule.name]
if in_info is None:
raise MissingInfoError()
_transfer_fields(in_info, info, rule.fields)
elif isinstance(rule, FromOutput):
out_info = self.out_infos[rule.name]
if out_info is None:
raise MissingInfoError()
_transfer_fields(out_info, info, rule.fields)
elif isinstance(rule, FromValue):
if rule.field == "time":
info.time = rule.value
elif rule.field == "grid":
info.grid = rule.value
else:
info.meta[rule.field] = rule.value
return info
def _check_info_rules(self):
for name, rules in self._in_info_rules.items():
if name not in self._inputs:
raise KeyError(f"No input named '{name}' to apply info transfer rule.")
for rule in rules:
self._check_rule(rule)
for name, rules in self._out_info_rules.items():
if name not in self._outputs:
raise KeyError(f"No output named '{name}' to apply info transfer rule.")
for rule in rules:
self._check_rule(rule)
def _check_rule(self, rule):
if isinstance(rule, FromInput):
if rule.name not in self._inputs:
raise KeyError(
f"No input named '{rule.name}' to use in info transfer rule."
)
elif isinstance(rule, FromOutput):
if rule.name not in self._outputs:
raise KeyError(
f"No output named '{rule.name}' to use in info transfer rule."
)
elif not isinstance(rule, FromValue):
raise TypeError(
f"Rules must be one of the types FromInput, FromOutput or FromValue. "
f"Got '{rule.__class__.__name__}'."
)
@property
def logger(self):
"""Logger for this component."""
@@ -108,6 +277,11 @@ class ConnectHelper(Loggable):
"""dict: The pulled input data so far. May contain None values."""
return self._pulled_data
@property
def all_data_pulled(self):
"""bool: True if all expected data is pulled."""
return all(data is not None for data in self.in_data.values())
@property
def infos_pushed(self):
"""dict: If an info was pushed for outputs so far."""
@@ -149,6 +323,7 @@ class ConnectHelper(Loggable):
with ErrorLogger(self.logger):
self._check_names(exchange_infos, push_infos, push_data)
self._check_in_rules(exchange_infos, push_infos)
exchange_infos = {
k: v for k, v in exchange_infos.items() if self.in_infos[k] is None
@@ -156,6 +331,11 @@ class ConnectHelper(Loggable):
push_infos = {k: v for k, v in push_infos.items() if self.out_infos[k] is None}
push_data = {k: v for k, v in push_data.items() if not self.data_pushed[k]}
# Try to generate infos from transfer rules
with ErrorLogger(self.logger):
exchange_infos.update(self._apply_in_info_rules())
push_infos.update(self._apply_out_info_rules())
if self._cache:
self._in_info_cache.update(exchange_infos)
self._out_info_cache.update(push_infos)
@@ -205,18 +385,54 @@ class ConnectHelper(Loggable):
return ComponentStatus.CONNECTING_IDLE
def _apply_in_info_rules(self):
exchange_infos = {}
for name, rules in self._in_info_rules.items():
if self.in_infos[name] is None and name not in self._in_info_cache:
try:
info = self._apply_rules(rules)
exchange_infos[name] = info
except MissingInfoError:
pass
return exchange_infos
def _apply_out_info_rules(self):
push_infos = {}
for name, rules in self._out_info_rules.items():
if not self.infos_pushed[name] and name not in self._out_info_cache:
try:
info = self._apply_rules(rules)
push_infos[name] = info
except MissingInfoError:
pass
return push_infos
def _check_names(self, exchange_infos, push_infos, push_data):
for name in exchange_infos:
if name not in self._inputs:
raise ValueError(
raise KeyError(
f"No input named '{name}' available to exchange info for."
)
for name in push_infos:
if name not in self._outputs:
raise ValueError(f"No output named '{name}' available to push info.")
raise KeyError(f"No output named '{name}' available to push info.")
for name in push_data:
if name not in self._outputs:
raise ValueError(f"No output named '{name}' available to push data.")
raise KeyError(f"No output named '{name}' available to push data.")
def _check_in_rules(self, exchange_infos, push_infos):
for name in exchange_infos:
if name in self._in_info_rules:
raise ValueError(
f"There are info transfer rules given for input `{name}`. "
f"Can't provide the info directly."
)
for name in push_infos:
if name in self._out_info_rules:
raise ValueError(
f"There are info transfer rules given for output `{name}`. "
f"Can't provide the info directly."
)
def _exchange_in_infos(self):
any_done = False
@@ -271,3 +487,18 @@ class ConnectHelper(Loggable):
self.logger.debug("Failed to push output data for %s", name)
return any_done
def _transfer_fields(source_info, target_info, fields):
if len(fields) == 0:
target_info.time = source_info.time
target_info.grid = source_info.grid
target_info.meta = copy.copy(source_info.meta)
else:
for field in fields:
if field == "time":
target_info.time = source_info.time
elif field == "grid":
target_info.grid = source_info.grid
else:
target_info.meta[field] = source_info.meta[field]
Loading