Skip to content
Snippets Groups Projects
Commit d38f166f authored by Martin Lange's avatar Martin Lange
Browse files

also cache the units themselves

parent 97b17e91
No related branches found
No related tags found
1 merge request!229Cache units, compatibility and equivalence
Pipeline #133704 passed with stages
in 5 minutes and 30 seconds
......@@ -25,7 +25,8 @@ from .grid_tools import Grid, GridBase
pint_xarray.unit_registry.default_format = "cf"
UNITS = pint_xarray.unit_registry
_UNITS_CACHE = {}
_UNIT_CACHE = {}
_UNIT_PAIRS_CACHE = {}
def _gen_dims(ndim, info):
......@@ -98,7 +99,7 @@ def to_xarray(data, name, info, time_entries=1, force_copy=False):
return to_units(data, info.units)
units = UNITS.Unit(info.units)
units = _get_pint_units(info.units)
if isinstance(data, pint.Quantity):
if not compatible_units(data.units, units):
raise FinamDataError(
......@@ -329,7 +330,7 @@ def to_units(xdata, units):
Converted data.
"""
check_quantified(xdata, "to_units")
units = UNITS.Unit(units)
units = _get_pint_units(units)
if units == xdata.pint.units:
return xdata
return xdata.pint.to(units)
......@@ -435,11 +436,10 @@ def check(
f"check: given data has wrong meta data.\nData: {xdata.attrs}\nMeta: {meta}"
)
# check units
units = UNITS.Unit(info.units)
if not compatible_units(units, xdata):
if not compatible_units(info.units, xdata):
raise FinamDataError(
f"check: given data has incompatible units. "
f"Got {get_units(xdata)}, expected {units}."
f"Got {get_units(xdata)}, expected {UNITS.Unit(info.units)}."
)
......@@ -518,7 +518,12 @@ def _get_pint_units(var):
if isinstance(var, xr.DataArray):
return var.pint.units or UNITS.dimensionless
return UNITS.Unit(var)
units = _UNIT_CACHE.get(var)
if units is None:
units = UNITS.Unit(var)
_UNIT_CACHE[var] = units
return units
def compatible_units(unit1, unit2):
......@@ -538,7 +543,7 @@ def compatible_units(unit1, unit2):
Unit compatibility.
"""
unit1, unit2 = _get_pint_units(unit1), _get_pint_units(unit2)
comp_equiv = _UNITS_CACHE.get((unit1, unit2))
comp_equiv = _UNIT_PAIRS_CACHE.get((unit1, unit2))
if comp_equiv is None:
comp_equiv = _cache_units(unit1, unit2)
......@@ -562,7 +567,7 @@ def equivalent_units(unit1, unit2):
Unit equivalence.
"""
unit1, unit2 = _get_pint_units(unit1), _get_pint_units(unit2)
comp_equiv = _UNITS_CACHE.get((unit1, unit2))
comp_equiv = _UNIT_PAIRS_CACHE.get((unit1, unit2))
if comp_equiv is None:
comp_equiv = _cache_units(unit1, unit2)
......@@ -578,13 +583,14 @@ def _cache_units(unit1, unit2):
except pint.errors.DimensionalityError:
pass
_UNITS_CACHE[(unit1, unit2)] = compat, equiv
_UNIT_PAIRS_CACHE[(unit1, unit2)] = compat, equiv
return compat, equiv
def clear_units_cache():
"""Clears the units cache"""
_UNITS_CACHE.clear()
_UNIT_CACHE.clear()
_UNIT_PAIRS_CACHE.clear()
def assert_type(cls, slot, obj, types):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment