From 3e97e4d3f36123493c6103434a71dc2b1625dd13 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sebastian=20M=C3=BCller?= <mueller.seb@posteo.de>
Date: Mon, 18 Nov 2024 10:25:46 +0100
Subject: [PATCH] Info.mask: compare masks on canonical grid

---
 src/finam/data/tools/info.py |  2 +-
 src/finam/data/tools/mask.py | 26 +++++++++++++++++++++-----
 2 files changed, 22 insertions(+), 6 deletions(-)

diff --git a/src/finam/data/tools/info.py b/src/finam/data/tools/info.py
index b652ab3f..40debdc1 100644
--- a/src/finam/data/tools/info.py
+++ b/src/finam/data/tools/info.py
@@ -145,7 +145,7 @@ class Info:
                 success = False
 
         if self.mask is not None and not masks_compatible(
-            self.mask, incoming.mask, incoming_donwstream
+            self.mask, incoming.mask, incoming_donwstream, self.grid, incoming.grid
         ):
             if not (incoming_donwstream and incoming.mask is None):
                 fail_info["mask"] = (incoming.mask, self.mask)
diff --git a/src/finam/data/tools/mask.py b/src/finam/data/tools/mask.py
index 444776b4..f7eb3663 100644
--- a/src/finam/data/tools/mask.py
+++ b/src/finam/data/tools/mask.py
@@ -240,7 +240,9 @@ def _is_single_mask_value(mask):
     return mask is None or mask is np.ma.nomask or mask is False or mask is True
 
 
-def masks_compatible(this, incoming, incoming_donwstream):
+def masks_compatible(
+    this, incoming, incoming_donwstream, this_grid=None, incoming_grid=None
+):
     """
     Check if an incoming mask is compatible with a given mask.
 
@@ -252,6 +254,10 @@ def masks_compatible(this, incoming, incoming_donwstream):
         incoming mask to check for compatibility
     incoming_donwstream : bool
         Whether the incoming mask is from downstream data
+    this_grid : Grid or NoGrid or None, optional
+        grid for first mask (to check shape and value equality)
+    incoming_grid : Grid or NoGrid or None, optional
+        grid for second mask (to check shape and value equality)
 
     Returns
     -------
@@ -260,8 +266,10 @@ def masks_compatible(this, incoming, incoming_donwstream):
     """
     if incoming_donwstream:
         upstream, downstream = this, incoming
+        up_grid, down_grid = this_grid, incoming_grid
     else:
         upstream, downstream = incoming, this
+        up_grid, down_grid = incoming_grid, this_grid
     # None is incompatible
     if upstream is None:
         return False
@@ -274,10 +282,10 @@ def masks_compatible(this, incoming, incoming_donwstream):
     if not mask_specified(upstream):
         return False
     # if both mask given, compare them
-    return masks_equal(downstream, upstream)
+    return masks_equal(downstream, upstream, down_grid, up_grid)
 
 
-def masks_equal(this, other):
+def masks_equal(this, other, this_grid=None, other_grid=None):
     """
     Check two masks for equality.
 
@@ -285,9 +293,12 @@ def masks_equal(this, other):
     ----------
     this : :any:`Mask` value or valid boolean mask for :any:`MaskedArray` or None
         first mask
-    incoming : :any:`Mask` value or valid boolean mask for :any:`MaskedArray` or None
+    other : :any:`Mask` value or valid boolean mask for :any:`MaskedArray` or None
         second mask
-
+    this_grid : Grid or NoGrid or None, optional
+        grid for first mask (to check shape and value equality)
+    other_grid : Grid or NoGrid or None, optional
+        grid for second mask (to check shape and value equality)
 
     Returns
     -------
@@ -311,6 +322,11 @@ def masks_equal(this, other):
     # compare masks
     if not np.ndim(this) == np.ndim(other):
         return False
+    # mask shape is grid specific (reversed axes, decreasing axis)
+    if this_grid is None or other_grid is None:
+        return True
+    this = this_grid.to_canonical(this)
+    other = other_grid.to_canonical(other)
     if not np.all(np.shape(this) == np.shape(other)):
         return False
     return np.all(this == other)
-- 
GitLab