Skip to content
Snippets Groups Projects

Cache units, compatibility and equivalence

Merged Martin Lange requested to merge cache-units into main
+ 17
11
@@ -25,7 +25,8 @@ from .grid_tools import Grid, GridBase
pint_xarray.unit_registry.default_format = "cf"
UNITS = pint_xarray.unit_registry
_UNITS_CACHE = {}
_UNIT_CACHE = {}
_UNIT_PAIRS_CACHE = {}
def _gen_dims(ndim, info):
@@ -98,7 +99,7 @@ def to_xarray(data, name, info, time_entries=1, force_copy=False):
return to_units(data, info.units)
units = UNITS.Unit(info.units)
units = _get_pint_units(info.units)
if isinstance(data, pint.Quantity):
if not compatible_units(data.units, units):
raise FinamDataError(
@@ -329,7 +330,7 @@ def to_units(xdata, units):
Converted data.
"""
check_quantified(xdata, "to_units")
units = UNITS.Unit(units)
units = _get_pint_units(units)
if units == xdata.pint.units:
return xdata
return xdata.pint.to(units)
@@ -435,11 +436,10 @@ def check(
f"check: given data has wrong meta data.\nData: {xdata.attrs}\nMeta: {meta}"
)
# check units
units = UNITS.Unit(info.units)
if not compatible_units(units, xdata):
if not compatible_units(info.units, xdata):
raise FinamDataError(
f"check: given data has incompatible units. "
f"Got {get_units(xdata)}, expected {units}."
f"Got {get_units(xdata)}, expected {UNITS.Unit(info.units)}."
)
@@ -518,7 +518,12 @@ def _get_pint_units(var):
if isinstance(var, xr.DataArray):
return var.pint.units or UNITS.dimensionless
return UNITS.Unit(var)
units = _UNIT_CACHE.get(var)
if units is None:
units = UNITS.Unit(var)
_UNIT_CACHE[var] = units
return units
def compatible_units(unit1, unit2):
@@ -538,7 +543,7 @@ def compatible_units(unit1, unit2):
Unit compatibility.
"""
unit1, unit2 = _get_pint_units(unit1), _get_pint_units(unit2)
comp_equiv = _UNITS_CACHE.get((unit1, unit2))
comp_equiv = _UNIT_PAIRS_CACHE.get((unit1, unit2))
if comp_equiv is None:
comp_equiv = _cache_units(unit1, unit2)
@@ -562,7 +567,7 @@ def equivalent_units(unit1, unit2):
Unit equivalence.
"""
unit1, unit2 = _get_pint_units(unit1), _get_pint_units(unit2)
comp_equiv = _UNITS_CACHE.get((unit1, unit2))
comp_equiv = _UNIT_PAIRS_CACHE.get((unit1, unit2))
if comp_equiv is None:
comp_equiv = _cache_units(unit1, unit2)
@@ -578,13 +583,14 @@ def _cache_units(unit1, unit2):
except pint.errors.DimensionalityError:
pass
_UNITS_CACHE[(unit1, unit2)] = compat, equiv
_UNIT_PAIRS_CACHE[(unit1, unit2)] = compat, equiv
return compat, equiv
def clear_units_cache():
"""Clears the units cache"""
_UNITS_CACHE.clear()
_UNIT_CACHE.clear()
_UNIT_PAIRS_CACHE.clear()
def assert_type(cls, slot, obj, types):
Loading