Skip to content
Snippets Groups Projects
Commit 7fac418f authored by Martin Lange's avatar Martin Lange
Browse files

Merge branch 'cache-units' into 'main'

Cache units, compatibility and equivalence

Closes #91

See merge request !229
parents 4bd14d02 d38f166f
No related branches found
No related tags found
1 merge request!229Cache units, compatibility and equivalence
Pipeline #133931 passed with stages
in 6 minutes and 47 seconds
......@@ -25,6 +25,9 @@ from .grid_tools import Grid, GridBase
pint_xarray.unit_registry.default_format = "cf"
UNITS = pint_xarray.unit_registry
_UNIT_CACHE = {}
_UNIT_PAIRS_CACHE = {}
def _gen_dims(ndim, info):
"""
......@@ -96,7 +99,7 @@ def to_xarray(data, name, info, time_entries=1, force_copy=False):
return to_units(data, info.units)
units = UNITS.Unit(info.units)
units = _get_pint_units(info.units)
if isinstance(data, pint.Quantity):
if not compatible_units(data.units, units):
raise FinamDataError(
......@@ -327,7 +330,7 @@ def to_units(xdata, units):
Converted data.
"""
check_quantified(xdata, "to_units")
units = UNITS.Unit(units)
units = _get_pint_units(units)
if units == xdata.pint.units:
return xdata
return xdata.pint.to(units)
......@@ -433,11 +436,10 @@ def check(
f"check: given data has wrong meta data.\nData: {xdata.attrs}\nMeta: {meta}"
)
# check units
units = UNITS.Unit(info.units)
if not compatible_units(units, xdata):
if not compatible_units(info.units, xdata):
raise FinamDataError(
f"check: given data has incompatible units. "
f"Got {get_units(xdata)}, expected {units}."
f"Got {get_units(xdata)}, expected {UNITS.Unit(info.units)}."
)
......@@ -516,7 +518,12 @@ def _get_pint_units(var):
if isinstance(var, xr.DataArray):
return var.pint.units or UNITS.dimensionless
return UNITS.Unit(var)
units = _UNIT_CACHE.get(var)
if units is None:
units = UNITS.Unit(var)
_UNIT_CACHE[var] = units
return units
def compatible_units(unit1, unit2):
......@@ -533,10 +540,14 @@ def compatible_units(unit1, unit2):
Returns
-------
bool
Uint compatibility.
Unit compatibility.
"""
unit1, unit2 = _get_pint_units(unit1), _get_pint_units(unit2)
return unit1.is_compatible_with(unit2)
comp_equiv = _UNIT_PAIRS_CACHE.get((unit1, unit2))
if comp_equiv is None:
comp_equiv = _cache_units(unit1, unit2)
return comp_equiv[0]
def equivalent_units(unit1, unit2):
......@@ -556,10 +567,30 @@ def equivalent_units(unit1, unit2):
Unit equivalence.
"""
unit1, unit2 = _get_pint_units(unit1), _get_pint_units(unit2)
comp_equiv = _UNIT_PAIRS_CACHE.get((unit1, unit2))
if comp_equiv is None:
comp_equiv = _cache_units(unit1, unit2)
return comp_equiv[1]
def _cache_units(unit1, unit2):
equiv = False
compat = False
try:
return np.isclose((1.0 * unit1).to(unit2).magnitude, 1.0)
equiv = np.isclose((1.0 * unit1).to(unit2).magnitude, 1.0)
compat = True
except pint.errors.DimensionalityError:
return False
pass
_UNIT_PAIRS_CACHE[(unit1, unit2)] = compat, equiv
return compat, equiv
def clear_units_cache():
"""Clears the units cache"""
_UNIT_CACHE.clear()
_UNIT_PAIRS_CACHE.clear()
def assert_type(cls, slot, obj, types):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment