From c2049c158b1a1586b6f433e81e7fce39ede42618 Mon Sep 17 00:00:00 2001
From: David Schaefer <david.schaefer@ufz.de>
Date: Tue, 29 Nov 2022 16:16:18 +0100
Subject: [PATCH] multivariate plotting

---
 CHANGELOG.md         |   1 +
 saqc/funcs/tools.py  |   8 +-
 saqc/lib/plotting.py | 290 +++++++++++++++++++++++--------------------
 3 files changed, 159 insertions(+), 140 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index b28a550b0..eb9d21990 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -10,6 +10,7 @@ SPDX-License-Identifier: GPL-3.0-or-later
 [List of commits](https://git.ufz.de/rdm-software/saqc/-/compare/v2.2.1...develop)
 ### Added
 - add option to not overwrite existing flags to `concatFlags`
+- add multivariate plotting: `plot` accepts several fields and draws them into a single figure
 ### Changed
 ### Removed
 ### Fixed
diff --git a/saqc/funcs/tools.py b/saqc/funcs/tools.py
index 0967a823a..1d3f7c380 100644
--- a/saqc/funcs/tools.py
+++ b/saqc/funcs/tools.py
@@ -18,7 +18,7 @@ from typing_extensions import Literal
 from saqc.constants import FILTER_NONE, UNFLAGGED
 from saqc.core.register import processing, register
 from saqc.lib.plotting import makeFig
-from saqc.lib.tools import periodicMask
+from saqc.lib.tools import periodicMask, toSequence
 
 if TYPE_CHECKING:
     from saqc.core.core import SaQC
@@ -224,7 +224,7 @@ class ToolsMixin:
         self._flags[mask, field] = UNFLAGGED
         return self
 
-    @register(mask=[], demask=[], squeeze=[])
+    @register(mask=[], demask=[], squeeze=[], multivariate=True)
     def plot(
         self: "SaQC",
         field: str,
@@ -317,13 +317,13 @@ class ToolsMixin:
 
         fig = makeFig(
             data=data,
-            field=field,
+            fields=toSequence(field),
             flags=flags,
             level=level,
             max_gap=max_gap,
             history=history,
             xscope=xscope,
-            phaseplot=phaseplot,
+            phaseplots=phaseplot,
             ax_kwargs=ax_kwargs,
         )
 
diff --git a/saqc/lib/plotting.py b/saqc/lib/plotting.py
index edd2df9c2..8236d73d5 100644
--- a/saqc/lib/plotting.py
+++ b/saqc/lib/plotting.py
@@ -7,7 +7,7 @@
 # -*- coding: utf-8 -*-
 
 import itertools
-from typing import Optional, Union
+from typing import Sequence
 
 import matplotlib as mpl
 import matplotlib.pyplot as plt
@@ -15,10 +15,11 @@ import numpy as np
 import pandas as pd
 from typing_extensions import Literal
 
-from saqc.core.flags import Flags
+from saqc.core.flags import Flags, History
 from saqc.lib.tools import toSequence
 from saqc.lib.types import DiosLikeT
 
+
 STATSDICT = {
     "values total": lambda x, y, z: len(x),
     "invalid total (=NaN)": lambda x, y, z: x.isna().sum(),
@@ -30,7 +31,7 @@ STATSDICT = {
 PLOT_KWARGS = {"alpha": 0.8, "linewidth": 1}
 FIG_KWARGS = {"figsize": (16, 9)}
 
-_seaborn_color_palette = [
+_SEABORN_COLOR_PALETTE = [
     (0.00784313725490196, 0.24313725490196078, 1.0),
     (1.0, 0.48627450980392156, 0.0),
     (0.10196078431372549, 0.788235294117647, 0.2196078431372549),
@@ -43,9 +44,12 @@ _seaborn_color_palette = [
     (0.0, 0.8431372549019608, 1.0),
 ]
 
+_PLOT_COLORS = itertools.cycle([str(val) for val in np.arange(0, 1.0, 0.2)])
+
+
 SCATTER_KWARGS = {
     "marker": ["s", "D", "^", "o", "v"],
-    "color": _seaborn_color_palette,
+    "color": _SEABORN_COLOR_PALETTE,
     "alpha": 0.7,
     "zorder": 10,
     "edgecolors": "black",
@@ -55,14 +59,14 @@ SCATTER_KWARGS = {
 
 def makeFig(
     data: DiosLikeT,
-    field: str,
+    fields: Sequence[str],
     flags: Flags,
     level: float,
-    max_gap: Optional[str] = None,
-    history: Union[Optional[Literal["valid", "complete"]], list] = "valid",
-    xscope: Optional[slice] = None,
-    phaseplot: Optional[str] = None,
-    ax_kwargs: Optional[dict] = None,
+    max_gap: str | None = None,
+    history: Literal["valid", "complete"] | Sequence[str] | None = "valid",
+    xscope: slice | None = None,
+    phaseplots: Sequence[str] | None = None,
+    ax_kwargs: dict | None = None,
 ):
     """
     Returns a figure object, containing data graph with flag marks for field.
@@ -115,155 +119,169 @@ def makeFig(
 
     if ax_kwargs is None:
         ax_kwargs = {}
-    # data retrieval
-    d = data[field]
-    # data slicing:
-    xscope = xscope or slice(xscope)
-    d = d[xscope]
-    flags_vals = flags[field][xscope]
-    flags_hist = flags.history[field].hist.loc[xscope]
-    flags_meta = flags.history[field].meta
 
-    # set fontsize:
-    default = plt.rcParams["font.size"]
-    plt.rcParams["font.size"] = ax_kwargs.pop("fontsize", None) or default
-
-    # set shapecycle start:
-    cyclestart = ax_kwargs.pop("cycleskip", 0)
+    if phaseplots is None:
+        phaseplots: list[None] = [None] * len(fields)
+        if len(phaseplots) != len(fields):
+            raise ValueError(
+                "expected identical number of 'field' and phaseplot values"
+            )
 
-    na_mask = d.isna()
-    d = d[~na_mask]
-    if phaseplot:
-        flags_vals = flags_vals.copy()
-        flags_hist = flags_hist.copy()
-        phase_index = data[phaseplot][xscope].values
-        phase_index_d = phase_index[~na_mask]
-        na_mask.index = phase_index
-        d.index = phase_index_d
-        flags_vals.index = phase_index
-        flags_hist.index = phase_index
-        plot_kwargs = {**PLOT_KWARGS, **{"marker": "o", "linewidth": 0}}
-        ax_kwargs = {**{"xlabel": phaseplot, "ylabel": d.name}, **ax_kwargs}
-    else:
-        plot_kwargs = PLOT_KWARGS
-
-    # insert nans between values mutually spaced > max_gap
-    if max_gap and not d.empty:
-        d = _insertBlockingNaNs(d, max_gap)
+    xscope = xscope or slice(xscope)
 
     # figure composition
     fig = mpl.pyplot.figure(constrained_layout=True, **FIG_KWARGS)
     grid = fig.add_gridspec()
     ax = fig.add_subplot(grid[0])
+    ax.set_title(", ".join(data[fields].columns))
 
-    _plotVarWithFlags(
-        ax,
-        d,
-        flags_vals,
-        flags_hist,
-        flags_meta,
-        history,
-        level,
-        na_mask,
-        plot_kwargs,
-        ax_kwargs,
-        SCATTER_KWARGS,
-        cyclestart,
+    # fontsize
+    default = plt.rcParams["font.size"]
+    plt.rcParams["font.size"] = ax_kwargs.pop("fontsize", None) or default
+
+    # shapecycle start
+    cyclestart = ax_kwargs.pop("cycleskip", 0)
+
+    markers = _getMarkers(
+        histories=[flags.history[f] for f in fields], cyclestart=cyclestart
     )
 
+    for field, phaseplot in zip(fields, phaseplots):
+
+        # data retrieval
+        d = pd.Series(data[field])
+
+        # data slicing:
+        d = pd.Series(d[xscope])
+
+        flags_vals = flags[field][xscope]
+        flags_hist = flags.history[field].hist.loc[xscope]
+        flags_meta = flags.history[field].meta
+
+        na_mask = d.isna()
+        d = d[~na_mask]
+        if phaseplot:
+            flags_vals = flags_vals.copy()
+            flags_hist = flags_hist.copy()
+            phase_index = data[phaseplot][xscope].to_numpy()
+            phase_index_d = phase_index[~na_mask]
+            na_mask.index = phase_index
+            d.index = phase_index_d
+            flags_vals.index = phase_index
+            flags_hist.index = phase_index
+            plot_kwargs = {**PLOT_KWARGS, **{"marker": "o", "linewidth": 0}}
+            ax_kwargs = {**{"xlabel": phaseplot, "ylabel": d.name}, **ax_kwargs}
+        else:
+            plot_kwargs = PLOT_KWARGS
+
+        # insert nans between values mutually spaced > max_gap
+        if max_gap and not d.empty:
+            d = _insertBlockingNaNs(d, max_gap)
+
+        _plotVarWithFlags(
+            ax=ax,
+            data=d,
+            markers=markers,
+            flags=flags_vals,
+            history=flags_hist,
+            meta=flags_meta,
+            level=level,
+            na_mask=na_mask,
+            plot_kwargs=plot_kwargs,
+            ax_kwargs=ax_kwargs,
+            scatter_kwargs=SCATTER_KWARGS,
+        )
+
     plt.rcParams["font.size"] = default
+
+    # remove duplicates from legend
+    handles, labels = ax.get_legend_handles_labels()
+    lgd = dict(zip(labels, handles))
+    plt.legend(lgd.values(), lgd.keys())
+
     return fig
 
 
+def _getMarkers(
+    histories: Sequence[History], cyclestart: int
+) -> dict[str, dict[str, str]]:
+
+    shapes = SCATTER_KWARGS.get("marker", "o")
+    shapes = itertools.cycle(toSequence(shapes))
+
+    colors = SCATTER_KWARGS.get(
+        "color", plt.rcParams["axes.prop_cycle"].by_key()["color"]
+    )
+    colors = itertools.cycle(toSequence(colors))
+    for _ in range(0, cyclestart):
+        next(colors)
+        next(shapes)
+
+    out = {}
+    for hist in histories:
+        for i in hist.columns:
+            # don't account for empty histories
+            if not len(hist.meta):
+                continue
+
+            # just to be sure
+            meta_field = "label" if "label" in hist.meta[i].keys() else "func"
+            if meta_field not in hist.meta[i]:
+                continue
+
+            key = hist.meta[i][meta_field]
+            if key not in out:
+                out[key] = {"color": next(colors), "marker": next(shapes)}
+
+    return out
+
+
 def _plotVarWithFlags(
     ax,
-    datser,
-    flags_vals,
-    flags_hist,
-    flags_meta,
-    history,
+    data: pd.Series,
+    markers: dict[str, dict[str, str]],
+    flags: Flags,
+    history: History,
+    meta: list[dict],
     level,
     na_mask,
     plot_kwargs,
     ax_kwargs,
     scatter_kwargs,
-    cyclestart,
 ):
-    scatter_kwargs = scatter_kwargs.copy()
-    ax.set_title(datser.name)
-    ax.plot(datser, color="black", label="data", **plot_kwargs)
+
+    ax.plot(data, color=next(_PLOT_COLORS), label=data.name, **plot_kwargs)
     ax.set(**ax_kwargs)
-    shape_cycle = scatter_kwargs.get("marker", "o")
-    shape_cycle = itertools.cycle(toSequence(shape_cycle))
-    color_cycle = scatter_kwargs.get(
-        "color", plt.rcParams["axes.prop_cycle"].by_key()["color"]
-    )
-    color_cycle = itertools.cycle(toSequence(color_cycle))
-    for k in range(0, cyclestart):
-        next(color_cycle)
-        next(shape_cycle)
-
-    if history:
-        for i in flags_hist.columns:
-            if isinstance(history, list):
-                meta_field = "label" if "label" in flags_meta[i].keys() else "func"
-                to_plot = (
-                    flags_meta[i][meta_field]
-                    if flags_meta[i][meta_field] in history
-                    else None
-                )
-                if not to_plot:
-                    continue
-                else:
-                    hist_key = "valid"
-            else:
-                hist_key = history
-            # catch empty but existing history case (flags_meta={})
-            if len(flags_meta[i]) == 0:
-                continue
-            label = (
-                flags_meta[i]["kwargs"].get("label", None)
-                or flags_meta[i]["func"].split(".")[-1]
-            )
-            scatter_kwargs.update({"label": label})
-            flags_i = flags_hist[i].astype(float)
-            if hist_key == "complete":
-                scatter_kwargs.update(
-                    {"color": next(color_cycle), "marker": next(shape_cycle)}
-                )
-                _plotFlags(ax, datser, flags_i, na_mask, level, scatter_kwargs)
-            if hist_key == "valid":
-                # only plot those flags, that do not get altered later on:
-                mask = flags_i.eq(flags_vals)
-                flags_i[~mask] = np.nan
-                # Skip plot, if the test did not have no effect on the all over flagging result. This avoids
-                # legend overflow
-                if ~(flags_i > level).any():
-                    continue
-
-                # Also skip plot, if all flagged values are np.nans (to catch flag missing and masked results mainly)
-                temp_i = datser.index.join(flags_i.index, how="inner")
-                if datser[temp_i][flags_i[temp_i].notna()].isna().all() or (
-                    "flagMissing" in flags_meta[i]["func"]
-                ):
-                    continue
-
-                scatter_kwargs.update(
-                    {"color": next(color_cycle), "marker": next(shape_cycle)}
-                )
-                _plotFlags(
-                    ax,
-                    datser,
-                    flags_i,
-                    na_mask,
-                    level,
-                    scatter_kwargs,
-                )
-
-        ax.legend()
-    else:
-        scatter_kwargs.update({"color": next(color_cycle), "marker": next(shape_cycle)})
-        _plotFlags(ax, datser, flags_vals, na_mask, level, scatter_kwargs)
+
+    for i in history.columns:
+        key = "label" if "label" in meta[i].keys() else "func"
+        label = meta[i][key]
+
+        # only plot those flags that are not altered later on:
+        hist_col = history[i].astype(float)
+        hist_col[~hist_col.eq(flags)] = np.nan
+        # Skip the plot if the test had no effect on the overall flagging result. This avoids
+        # legend overflow
+        if ~(hist_col > level).any():
+            continue
+
+        # Also skip the plot if all flagged values are NaN (mainly to catch flagMissing and masked results)
+        temp_i = data.index.join(hist_col.index, how="inner")
+        if data[temp_i][hist_col[temp_i].notna()].isna().all() or (
+            "flagMissing" in meta[i]["func"]
+        ):
+            continue
+
+        scatter_kwargs.update(label=label, **markers[label])
+
+        _plotFlags(
+            ax,
+            data,
+            hist_col,
+            na_mask,
+            level,
+            scatter_kwargs,
+        )
 
 
 def _plotFlags(ax, datser, flags, na_mask, level, scatter_kwargs):
-- 
GitLab