Skip to content
Snippets Groups Projects
Commit cfcf8a99 authored by David Schäfer's avatar David Schäfer
Browse files

some cleanups

parent 4e728f61
No related branches found
No related tags found
No related merge requests found
...@@ -41,6 +41,10 @@ def _plot( ...@@ -41,6 +41,10 @@ def _plot(
# todo: try catch warn (once) return # todo: try catch warn (once) return
# only import if plotting is requested by the user # only import if plotting is requested by the user
import matplotlib as mpl import matplotlib as mpl
import matplotlib.pyplot as plt
from pandas.plotting import register_matplotlib_converters
# needed for datetime conversion
register_matplotlib_converters()
if not interactive_backend: if not interactive_backend:
# Import plot libs without interactivity, if not needed. This ensures that this can # Import plot libs without interactivity, if not needed. This ensures that this can
...@@ -49,13 +53,8 @@ def _plot( ...@@ -49,13 +53,8 @@ def _plot(
mpl.use("Agg") mpl.use("Agg")
else: else:
mpl.use("TkAgg") mpl.use("TkAgg")
from matplotlib import pyplot as plt
# needed for datetime conversion if np.isscalar(varname):
from pandas.plotting import register_matplotlib_converters
register_matplotlib_converters()
if not isinstance(varname, (list, set)):
varname = [varname] varname = [varname]
varname = set(varname) varname = set(varname)
...@@ -68,18 +67,18 @@ def _plot( ...@@ -68,18 +67,18 @@ def _plot(
logging.warning(f"Cannot plot column '{var}', because it is not present in data.") logging.warning(f"Cannot plot column '{var}', because it is not present in data.")
if not tmp: if not tmp:
return return
varname = tmp varnames = tmp
plots = len(varname) plots = len(varnames)
if plots > 1: if plots > 1:
fig, axes = plt.subplots(plots, 1, sharex=True) fig, axes = plt.subplots(plots, 1, sharex=True)
axes[0].set_title(title) axes[0].set_title(title)
for i, v in enumerate(varname): for i, v in enumerate(varnames):
_plotByQualtyFlag(data, v, flagger, flagmask, axes[i], plot_nans) _plotByQualityFlag(data, v, flagger, flagmask, axes[i], plot_nans)
else: else:
fig, ax = plt.subplots() fig, ax = plt.subplots()
plt.title(title) plt.title(title)
_plotByQualtyFlag(data, varname.pop(), flagger, flagmask, ax, plot_nans) _plotByQualityFlag(data, varnames.pop(), flagger, flagmask, ax, plot_nans)
# dummy plot for the label `missing` see plot_vline for more info # dummy plot for the label `missing` see plot_vline for more info
plt.plot([], [], ":", color="silver", label="missing data") plt.plot([], [], ":", color="silver", label="missing data")
...@@ -91,7 +90,7 @@ def _plot( ...@@ -91,7 +90,7 @@ def _plot(
plt.show() plt.show()
def _plotByQualtyFlag(data, varname, flagger, flagmask, ax, plot_nans): def _plotByQualityFlag(data, varname, flagger, flagmask, ax, plot_nans):
ax.set_ylabel(varname) ax.set_ylabel(varname)
x = data.index x = data.index
...@@ -106,7 +105,7 @@ def _plotByQualtyFlag(data, varname, flagger, flagmask, ax, plot_nans): ...@@ -106,7 +105,7 @@ def _plotByQualtyFlag(data, varname, flagger, flagmask, ax, plot_nans):
oldflags = flagged & ~flagmask oldflags = flagged & ~flagmask
ax.plot(x[oldflags], y[oldflags], ".", color="black", label="flagged by other test") ax.plot(x[oldflags], y[oldflags], ".", color="black", label="flagged by other test")
if plot_nans: if plot_nans:
_plot_nans(y[oldflags], 'black', ax) _plotNans(y[oldflags], 'black', ax)
# now we just want to show data that was flagged # now we just want to show data that was flagged
if flagmask is not True: if flagmask is not True:
...@@ -117,29 +116,26 @@ def _plotByQualtyFlag(data, varname, flagger, flagmask, ax, plot_nans): ...@@ -117,29 +116,26 @@ def _plotByQualtyFlag(data, varname, flagger, flagmask, ax, plot_nans):
if x.empty: if x.empty:
return return
suspicious = pd.Series(data=np.ones(len(y), dtype=bool), index=y.index)
# flag by categories
# plot UNFLAGGED (only nans are needed) # plot UNFLAGGED (only nans are needed)
flag, color = flagger.UNFLAGGED, _colors['unflagged'] flag, color = flagger.UNFLAGGED, _colors['unflagged']
flagged = flagger.isFlagged(varname, flag=flag, comparator='==') flagged = flagger.isFlagged(varname, flag=flag, comparator='==')
ax.plot(x[flagged], y[flagged], '.', color=color, label=f"flag: {flag}") ax.plot(x[flagged], y[flagged], '.', color=color, label=f"flag: {flag}")
if plot_nans: if plot_nans:
_plot_nans(y[flagged], color, ax) _plotNans(y[flagged], color, ax)
# plot GOOD # plot GOOD
flag, color = flagger.GOOD, _colors['good'] flag, color = flagger.GOOD, _colors['good']
flagged = flagger.isFlagged(varname, flag=flag, comparator='==') flagged = flagger.isFlagged(varname, flag=flag, comparator='==')
ax.plot(x[flagged], y[flagged], '.', color=color, label=f"flag: {flag}") ax.plot(x[flagged], y[flagged], '.', color=color, label=f"flag: {flag}")
if plot_nans: if plot_nans:
_plot_nans(y[flagged], color, ax) _plotNans(y[flagged], color, ax)
# plot BAD # plot BAD
flag, color = flagger.BAD, _colors['bad'] flag, color = flagger.BAD, _colors['bad']
flagged = flagger.isFlagged(varname, flag=flag, comparator='==') flagged = flagger.isFlagged(varname, flag=flag, comparator='==')
ax.plot(x[flagged], y[flagged], '.', color=color, label=f"flag: {flag}") ax.plot(x[flagged], y[flagged], '.', color=color, label=f"flag: {flag}")
if plot_nans: if plot_nans:
_plot_nans(y[flagged], color, ax) _plotNans(y[flagged], color, ax)
# plot SUSPICIOS # plot SUSPICIOS
color = _colors['suspicious'] color = _colors['suspicious']
...@@ -147,18 +143,17 @@ def _plotByQualtyFlag(data, varname, flagger, flagmask, ax, plot_nans): ...@@ -147,18 +143,17 @@ def _plotByQualtyFlag(data, varname, flagger, flagmask, ax, plot_nans):
flagged &= flagger.isFlagged(varname, flag=flagger.BAD, comparator='<') flagged &= flagger.isFlagged(varname, flag=flagger.BAD, comparator='<')
ax.plot(x[flagged], y[flagged], '.', color=color, label=f"{flagger.GOOD} < flag < {flagger.BAD}") ax.plot(x[flagged], y[flagged], '.', color=color, label=f"{flagger.GOOD} < flag < {flagger.BAD}")
if plot_nans: if plot_nans:
_plot_nans(y[flagged], color, ax) _plotNans(y[flagged], color, ax)
def _plot_nans(y, color, ax): def _plotNans(y, color, ax):
nans = y.isna() nans = y.isna()
_plotVline(ax, y[nans].index, color=color) _plotVline(ax, y[nans].index, color=color)
def _plotVline(plt, points, color="blue"): def _plotVline(ax, points, color="blue"):
# workaround for ax.vlines() as this work unexpected # workaround for ax.vlines() as this work unexpected
# normally this should work like so: # normally this should work like so:
# ax.vlines(idx, *ylim, linestyles=':', color='silver', label="missing") # ax.vlines(idx, *ylim, linestyles=':', color='silver', label="missing")
for point in points: for point in points:
plt.axvline(point, color=color, linestyle=":") ax.axvline(point, color=color, linestyle=":")
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment