From cfcf8a995cc17dbf9d6d0a2106977627b9623d4b Mon Sep 17 00:00:00 2001
From: David Schaefer <david.schaefer@ufz.de>
Date: Wed, 11 Dec 2019 21:00:13 +0100
Subject: [PATCH] some cleanups

---
 saqc/lib/plotting.py | 43 +++++++++++++++++++------------------------
 1 file changed, 19 insertions(+), 24 deletions(-)

diff --git a/saqc/lib/plotting.py b/saqc/lib/plotting.py
index 4d3f85bae..72ae15a72 100644
--- a/saqc/lib/plotting.py
+++ b/saqc/lib/plotting.py
@@ -41,6 +41,10 @@ def _plot(
     # todo: try catch warn (once) return
     # only import if plotting is requested by the user
     import matplotlib as mpl
+    import matplotlib.pyplot as plt
+    from pandas.plotting import register_matplotlib_converters
+    # needed for datetime conversion
+    register_matplotlib_converters()
 
     if not interactive_backend:
         # Import plot libs without interactivity, if not needed. This ensures that this can
@@ -49,13 +53,8 @@ def _plot(
         mpl.use("Agg")
     else:
         mpl.use("TkAgg")
-    from matplotlib import pyplot as plt
 
-    # needed for datetime conversion
-    from pandas.plotting import register_matplotlib_converters
-    register_matplotlib_converters()
-
-    if not isinstance(varname, (list, set)):
+    if np.isscalar(varname):
         varname = [varname]
     varname = set(varname)
 
@@ -68,18 +67,18 @@ def _plot(
             logging.warning(f"Cannot plot column '{var}', because it is not present in data.")
     if not tmp:
         return
-    varname = tmp
+    varnames = tmp
 
-    plots = len(varname)
+    plots = len(varnames)
     if plots > 1:
         fig, axes = plt.subplots(plots, 1, sharex=True)
         axes[0].set_title(title)
-        for i, v in enumerate(varname):
-            _plotByQualtyFlag(data, v, flagger, flagmask, axes[i], plot_nans)
+        for i, v in enumerate(varnames):
+            _plotByQualityFlag(data, v, flagger, flagmask, axes[i], plot_nans)
     else:
         fig, ax = plt.subplots()
         plt.title(title)
-        _plotByQualtyFlag(data, varname.pop(), flagger, flagmask, ax, plot_nans)
+        _plotByQualityFlag(data, varnames.pop(), flagger, flagmask, ax, plot_nans)
 
     # dummy plot for the label `missing` see plot_vline for more info
     plt.plot([], [], ":", color="silver", label="missing data")
@@ -91,7 +90,7 @@ def _plot(
         plt.show()
 
 
-def _plotByQualtyFlag(data, varname, flagger, flagmask, ax, plot_nans):
+def _plotByQualityFlag(data, varname, flagger, flagmask, ax, plot_nans):
     ax.set_ylabel(varname)
 
     x = data.index
@@ -106,7 +105,7 @@ def _plotByQualtyFlag(data, varname, flagger, flagmask, ax, plot_nans):
     oldflags = flagged & ~flagmask
     ax.plot(x[oldflags], y[oldflags], ".", color="black", label="flagged by other test")
     if plot_nans:
-        _plot_nans(y[oldflags], 'black', ax)
+        _plotNans(y[oldflags], 'black', ax)
 
     # now we just want to show data that was flagged
     if flagmask is not True:
@@ -117,29 +116,26 @@ def _plotByQualtyFlag(data, varname, flagger, flagmask, ax, plot_nans):
     if x.empty:
         return
 
-    suspicious = pd.Series(data=np.ones(len(y), dtype=bool), index=y.index)
-    # flag by categories
-
     # plot UNFLAGGED (only nans are needed)
     flag, color = flagger.UNFLAGGED, _colors['unflagged']
     flagged = flagger.isFlagged(varname, flag=flag, comparator='==')
     ax.plot(x[flagged], y[flagged], '.', color=color, label=f"flag: {flag}")
     if plot_nans:
-        _plot_nans(y[flagged], color, ax)
+        _plotNans(y[flagged], color, ax)
 
     # plot GOOD
     flag, color = flagger.GOOD, _colors['good']
     flagged = flagger.isFlagged(varname, flag=flag, comparator='==')
     ax.plot(x[flagged], y[flagged], '.', color=color, label=f"flag: {flag}")
     if plot_nans:
-        _plot_nans(y[flagged], color, ax)
+        _plotNans(y[flagged], color, ax)
 
     # plot BAD
     flag, color = flagger.BAD, _colors['bad']
     flagged = flagger.isFlagged(varname, flag=flag, comparator='==')
     ax.plot(x[flagged], y[flagged], '.', color=color, label=f"flag: {flag}")
     if plot_nans:
-        _plot_nans(y[flagged], color, ax)
+        _plotNans(y[flagged], color, ax)
 
     # plot SUSPICIOS
     color = _colors['suspicious']
@@ -147,18 +143,17 @@ def _plotByQualtyFlag(data, varname, flagger, flagmask, ax, plot_nans):
     flagged &= flagger.isFlagged(varname, flag=flagger.BAD, comparator='<')
     ax.plot(x[flagged], y[flagged], '.', color=color, label=f"{flagger.GOOD} < flag < {flagger.BAD}")
     if plot_nans:
-        _plot_nans(y[flagged], color, ax)
+        _plotNans(y[flagged], color, ax)
 
 
-def _plot_nans(y, color, ax):
+def _plotNans(y, color, ax):
     nans = y.isna()
     _plotVline(ax, y[nans].index, color=color)
 
 
-def _plotVline(plt, points, color="blue"):
+def _plotVline(ax, points, color="blue"):
     # workaround for ax.vlines() as this work unexpected
     # normally this should work like so:
     #   ax.vlines(idx, *ylim, linestyles=':', color='silver', label="missing")
     for point in points:
-        plt.axvline(point, color=color, linestyle=":")
-
+        ax.axvline(point, color=color, linestyle=":")
-- 
GitLab