Skip to content
Snippets Groups Projects

Cache units, compatibility and equivalence

Merged Martin Lange requested to merge cache-units into main
+ 41
10
@@ -25,6 +25,9 @@ from .grid_tools import Grid, GridBase
pint_xarray.unit_registry.default_format = "cf"
UNITS = pint_xarray.unit_registry
_UNIT_CACHE = {}
_UNIT_PAIRS_CACHE = {}
def _gen_dims(ndim, info):
"""
@@ -96,7 +99,7 @@ def to_xarray(data, name, info, time_entries=1, force_copy=False):
return to_units(data, info.units)
units = UNITS.Unit(info.units)
units = _get_pint_units(info.units)
if isinstance(data, pint.Quantity):
if not compatible_units(data.units, units):
raise FinamDataError(
@@ -327,7 +330,7 @@ def to_units(xdata, units):
Converted data.
"""
check_quantified(xdata, "to_units")
units = UNITS.Unit(units)
units = _get_pint_units(units)
if units == xdata.pint.units:
return xdata
return xdata.pint.to(units)
@@ -433,11 +436,10 @@ def check(
f"check: given data has wrong meta data.\nData: {xdata.attrs}\nMeta: {meta}"
)
# check units
units = UNITS.Unit(info.units)
if not compatible_units(units, xdata):
if not compatible_units(info.units, xdata):
raise FinamDataError(
f"check: given data has incompatible units. "
f"Got {get_units(xdata)}, expected {units}."
f"Got {get_units(xdata)}, expected {UNITS.Unit(info.units)}."
)
@@ -516,7 +518,12 @@ def _get_pint_units(var):
if isinstance(var, xr.DataArray):
return var.pint.units or UNITS.dimensionless
return UNITS.Unit(var)
units = _UNIT_CACHE.get(var)
if units is None:
units = UNITS.Unit(var)
_UNIT_CACHE[var] = units
return units
def compatible_units(unit1, unit2):
@@ -533,10 +540,14 @@ def compatible_units(unit1, unit2):
Returns
-------
bool
Uint compatibility.
Unit compatibility.
"""
unit1, unit2 = _get_pint_units(unit1), _get_pint_units(unit2)
return unit1.is_compatible_with(unit2)
comp_equiv = _UNIT_PAIRS_CACHE.get((unit1, unit2))
if comp_equiv is None:
comp_equiv = _cache_units(unit1, unit2)
return comp_equiv[0]
def equivalent_units(unit1, unit2):
@@ -556,10 +567,30 @@ def equivalent_units(unit1, unit2):
Unit equivalence.
"""
unit1, unit2 = _get_pint_units(unit1), _get_pint_units(unit2)
comp_equiv = _UNIT_PAIRS_CACHE.get((unit1, unit2))
if comp_equiv is None:
comp_equiv = _cache_units(unit1, unit2)
return comp_equiv[1]
def _cache_units(unit1, unit2):
equiv = False
compat = False
try:
return np.isclose((1.0 * unit1).to(unit2).magnitude, 1.0)
equiv = np.isclose((1.0 * unit1).to(unit2).magnitude, 1.0)
compat = True
except pint.errors.DimensionalityError:
return False
pass
_UNIT_PAIRS_CACHE[(unit1, unit2)] = compat, equiv
return compat, equiv
def clear_units_cache():
"""Clears the units cache"""
_UNIT_CACHE.clear()
_UNIT_PAIRS_CACHE.clear()
def assert_type(cls, slot, obj, types):
Loading