From 05542884e8f310b5eb52f9936c9c0a82f01e3275 Mon Sep 17 00:00:00 2001
From: Bert Palm <bert.palm@ufz.de>
Date: Thu, 6 Jun 2019 12:35:51 +0200
Subject: [PATCH] plot fixes and final plot

---
 core.py | 27 ++++++++++++++++++++++++---
 1 file changed, 24 insertions(+), 3 deletions(-)

diff --git a/core.py b/core.py
index dd1df1828..4268fb98d 100644
--- a/core.py
+++ b/core.py
@@ -4,6 +4,7 @@
 import numpy as np
 import pandas as pd
 import matplotlib as mpl
+from warnings import warn
 
 from config import Fields, Params
 from funcs import flagDispatch
@@ -46,6 +47,8 @@ def flagNext(flagger, flags, mask=True, flag_values=0, **kwargs) -> pd.Series:
 
 
 def runner(meta, flagger, data, flags=None, nodata=np.nan):
+    plotvars = []
+
     if flags is None:
         flags = pd.DataFrame(index=data.index)
 
@@ -118,7 +121,8 @@ def runner(meta, flagger, data, flags=None, nodata=np.nan):
                 fchunk = fchunk.astype({
                     c: flagger.flags for c in fchunk.columns if flagger.flag_fields[0] in c})
 
-            if Params.PLOT in flag_params:
+            if flag_params.get(Params.PLOT, False):
+                plotvars.append(varname)
                 new = flagger.getFlags(fchunk[varname])
                 mask = old != new
                 plot(dchunk, fchunk, mask, varname, flagger, title=flag_test)
@@ -127,6 +131,11 @@ def runner(meta, flagger, data, flags=None, nodata=np.nan):
             flags[start_date:end_date] = fchunk.squeeze()
 
         flagger.nextTest()
+
+    # plot all together
+    if plotvars:
+        plot(data, flags, True, set(plotvars), flagger)
+
     return data, flags
 
 
@@ -144,6 +153,20 @@ def plot(data, flags, flagmask, varname, flagger, interactive_backend=True, titl
     from pandas.plotting import register_matplotlib_converters
     register_matplotlib_converters()
 
+    if not isinstance(varname, (list, set)):
+        varname = set([varname])
+
+    tmp = []
+    for var in varname:
+        if var not in data.columns:
+            warn(f"Cannot plot column '{var}' that is not present in data.", UserWarning)
+        else:
+            tmp.append(var)
+    if tmp:
+        varname = tmp
+    else:
+        return
+
     def plot_vline(plt, points, color='blue'):
         # workaround for ax.vlines() as this work unexpected
         for point in points:
@@ -189,8 +212,6 @@ def plot(data, flags, flagmask, varname, flagger, interactive_backend=True, titl
             # ax.vlines(idx, *ylim, linestyles=':', color=colors[i])
             plot_vline(ax, idx, color=colors[i])
 
-    if not isinstance(varname, (list, set)):
-        varname = set([varname])
     plots = len(varname)
     if plots > 1:
         fig, axes = plt.subplots(plots, 1, sharex=True)
-- 
GitLab