Skip to content
Snippets Groups Projects
Commit 16231fb0 authored by Sebastian Müller's avatar Sebastian Müller 🐈
Browse files

info.accepts: indicate info origin; better mask checking

parent 6130fa76
No related branches found
No related tags found
1 merge request!286Add mask to Info object
......@@ -634,20 +634,19 @@ def from_compressed(xdata, shape, order="C", mask=None, **kwargs):
-----
If both `mask` and `shape` are given, they need to match in size.
"""
if mask is None or not mask_specified(mask):
if mask is None or mask is np.ma.nomask or not mask_specified(mask):
if kwargs and mask is Mask.NONE:
msg = "from_compressed: Can't create masked array with mask=Mask.NONE"
raise FinamDataError(msg)
data = np.reshape(xdata, shape, order=order)
return to_masked(data, **kwargs) if kwargs else data
mask = mask if mask is np.ma.nomask else np.ravel(mask, order=order)
return to_masked(data, **kwargs) if kwargs or mask is np.ma.nomask else data
if is_quantified(xdata):
# pylint: disable-next=unexpected-keyword-arg
data = quantify(np.empty_like(xdata, shape=np.prod(shape)), xdata.units)
else:
# pylint: disable-next=unexpected-keyword-arg
data = np.empty_like(xdata, shape=np.prod(shape))
data[~mask] = xdata
data[np.logical_not(np.ravel(mask, order=order))] = xdata
return to_masked(np.reshape(data, shape, order=order), mask=mask, **kwargs)
......@@ -818,7 +817,7 @@ def assert_type(cls, slot, obj, types):
)
def masks_compatible(this, incoming):
def masks_compatible(this, incoming, incoming_donwstream):
"""
Check if an incoming mask is compatible with a given mask.
......@@ -828,25 +827,31 @@ def masks_compatible(this, incoming):
mask specification to check against
incoming : :any:`Mask` value or valid boolean mask for :any:`MaskedArray` or None
incoming mask to check for compatibility
incoming_donwstream : bool
Whether the incoming mask is from downstream data
Returns
-------
bool
mask compatibility
"""
if incoming_donwstream:
upstream, downstream = this, incoming
else:
upstream, downstream = incoming, this
# None is incompatible
if incoming is None:
if upstream is None:
return False
# Mask.FLEX accepts anything, Mask.NONE only Mask.NONE
if this in list(Mask):
if incoming in list(Mask):
return this == Mask.FLEX or incoming == Mask.NONE
return this == Mask.FLEX
# if mask is specified, incoming mask must also be specified
if incoming in list(Mask):
if not mask_specified(downstream):
if not mask_specified(upstream):
return downstream == Mask.FLEX or upstream == Mask.NONE
return downstream == Mask.FLEX
# if mask is specified, upstream mask must also be specified
if not mask_specified(upstream):
return False
# if both mask given, compare them
return masks_equal(this, incoming)
return masks_equal(downstream, upstream)
def masks_equal(this, other):
......@@ -868,7 +873,7 @@ def masks_equal(this, other):
"""
if this is None and other is None:
return True
if this in list(Mask) and other in list(Mask):
if not mask_specified(this) and not mask_specified(other):
return this == other
# need a valid mask at this point
if not np.ma.is_mask(this) or not np.ma.is_mask(other):
......@@ -906,7 +911,7 @@ def mask_specified(mask):
def _format_mask(mask):
if mask not in list(Mask) + [None, np.ma.nomask]:
if mask_specified(mask) and mask is not np.ma.nomask:
return "<ndarray>"
if mask is np.ma.nomask:
return "nomask"
......@@ -950,7 +955,7 @@ class Info:
self.time = time
self.grid = grid
if mask not in list(Mask) + [None]:
if mask_specified(mask) and mask is not None:
mask = np.ma.make_mask(mask, shrink=False)
self.mask = mask
self.meta = meta or {}
......@@ -1008,7 +1013,7 @@ class Info:
return other
def accepts(self, incoming, fail_info, ignore_none=False):
def accepts(self, incoming, fail_info, incoming_donwstream=False):
"""
Tests whether this info can accept/is compatible with an incoming info.
......@@ -1020,8 +1025,8 @@ class Info:
Incoming/source info to check. This is the info from upstream.
fail_info : dict
Dictionary that will be filled with failed properties; name: (source, target).
ignore_none : bool
Ignores ``None`` values in the incoming info.
incoming_donwstream : bool, optional
Whether the incoming info is from downstream data. Default: False
Returns
-------
......@@ -1034,19 +1039,21 @@ class Info:
success = True
if self.grid is not None and not self.grid.compatible_with(incoming.grid):
if not (ignore_none and incoming.grid is None):
if not (incoming_donwstream and incoming.grid is None):
fail_info["grid"] = (incoming.grid, self.grid)
success = False
if self.mask is not None and not masks_compatible(self.mask, incoming.mask):
if not (ignore_none and incoming.mask is None):
if self.mask is not None and not masks_compatible(
self.mask, incoming.mask, incoming_donwstream
):
if not (incoming_donwstream and incoming.mask is None):
fail_info["mask"] = (incoming.mask, self.mask)
success = False
u1_none = (u1 := self.units) is None
u2_none = (u2 := incoming.units) is None
if not u1_none and (u2_none or not compatible_units(u1, u2)):
if not (ignore_none and u2_none):
if not (incoming_donwstream and u2_none):
fail_info["units"] = (u2, u1)
success = False
......
"""
Implementations of IOutput
"""
import logging
import os
from datetime import datetime
......@@ -382,7 +383,7 @@ class Output(IOutput, Loggable):
raise FinamNoDataError("No data info available")
fail_info = {}
if not self._output_info.accepts(info, fail_info, ignore_none=True):
if not self._output_info.accepts(info, fail_info, incoming_donwstream=True):
fail_info = "\n".join(
[
f"{name} - got {got}, expected {exp}"
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment