Skip to content
Snippets Groups Projects
Commit cfcf8a99 authored by David Schäfer's avatar David Schäfer
Browse files

some cleanups

parent 4e728f61
No related branches found
No related tags found
No related merge requests found
......@@ -41,6 +41,10 @@ def _plot(
# todo: try catch warn (once) return
# only import if plotting is requested by the user
import matplotlib as mpl
import matplotlib.pyplot as plt
from pandas.plotting import register_matplotlib_converters
# needed for datetime conversion
register_matplotlib_converters()
if not interactive_backend:
# Import plot libs without interactivity, if not needed. This ensures that this can
......@@ -49,13 +53,8 @@ def _plot(
mpl.use("Agg")
else:
mpl.use("TkAgg")
from matplotlib import pyplot as plt
# needed for datetime conversion
from pandas.plotting import register_matplotlib_converters
register_matplotlib_converters()
if not isinstance(varname, (list, set)):
if np.isscalar(varname):
varname = [varname]
varname = set(varname)
......@@ -68,18 +67,18 @@ def _plot(
logging.warning(f"Cannot plot column '{var}', because it is not present in data.")
if not tmp:
return
varname = tmp
varnames = tmp
plots = len(varname)
plots = len(varnames)
if plots > 1:
fig, axes = plt.subplots(plots, 1, sharex=True)
axes[0].set_title(title)
for i, v in enumerate(varname):
_plotByQualtyFlag(data, v, flagger, flagmask, axes[i], plot_nans)
for i, v in enumerate(varnames):
_plotByQualityFlag(data, v, flagger, flagmask, axes[i], plot_nans)
else:
fig, ax = plt.subplots()
plt.title(title)
_plotByQualtyFlag(data, varname.pop(), flagger, flagmask, ax, plot_nans)
_plotByQualityFlag(data, varnames.pop(), flagger, flagmask, ax, plot_nans)
# dummy plot for the label `missing` see plot_vline for more info
plt.plot([], [], ":", color="silver", label="missing data")
......@@ -91,7 +90,7 @@ def _plot(
plt.show()
def _plotByQualtyFlag(data, varname, flagger, flagmask, ax, plot_nans):
def _plotByQualityFlag(data, varname, flagger, flagmask, ax, plot_nans):
ax.set_ylabel(varname)
x = data.index
......@@ -106,7 +105,7 @@ def _plotByQualtyFlag(data, varname, flagger, flagmask, ax, plot_nans):
oldflags = flagged & ~flagmask
ax.plot(x[oldflags], y[oldflags], ".", color="black", label="flagged by other test")
if plot_nans:
_plot_nans(y[oldflags], 'black', ax)
_plotNans(y[oldflags], 'black', ax)
# now we just want to show data that was flagged
if flagmask is not True:
......@@ -117,29 +116,26 @@ def _plotByQualtyFlag(data, varname, flagger, flagmask, ax, plot_nans):
if x.empty:
return
suspicious = pd.Series(data=np.ones(len(y), dtype=bool), index=y.index)
# flag by categories
# plot UNFLAGGED (only nans are needed)
flag, color = flagger.UNFLAGGED, _colors['unflagged']
flagged = flagger.isFlagged(varname, flag=flag, comparator='==')
ax.plot(x[flagged], y[flagged], '.', color=color, label=f"flag: {flag}")
if plot_nans:
_plot_nans(y[flagged], color, ax)
_plotNans(y[flagged], color, ax)
# plot GOOD
flag, color = flagger.GOOD, _colors['good']
flagged = flagger.isFlagged(varname, flag=flag, comparator='==')
ax.plot(x[flagged], y[flagged], '.', color=color, label=f"flag: {flag}")
if plot_nans:
_plot_nans(y[flagged], color, ax)
_plotNans(y[flagged], color, ax)
# plot BAD
flag, color = flagger.BAD, _colors['bad']
flagged = flagger.isFlagged(varname, flag=flag, comparator='==')
ax.plot(x[flagged], y[flagged], '.', color=color, label=f"flag: {flag}")
if plot_nans:
_plot_nans(y[flagged], color, ax)
_plotNans(y[flagged], color, ax)
# plot SUSPICIOS
color = _colors['suspicious']
......@@ -147,18 +143,17 @@ def _plotByQualtyFlag(data, varname, flagger, flagmask, ax, plot_nans):
flagged &= flagger.isFlagged(varname, flag=flagger.BAD, comparator='<')
ax.plot(x[flagged], y[flagged], '.', color=color, label=f"{flagger.GOOD} < flag < {flagger.BAD}")
if plot_nans:
_plot_nans(y[flagged], color, ax)
_plotNans(y[flagged], color, ax)
def _plot_nans(y, color, ax):
def _plotNans(y, color, ax):
nans = y.isna()
_plotVline(ax, y[nans].index, color=color)
def _plotVline(plt, points, color="blue"):
def _plotVline(ax, points, color="blue"):
# workaround for ax.vlines() as this work unexpected
# normally this should work like so:
# ax.vlines(idx, *ylim, linestyles=':', color='silver', label="missing")
for point in points:
plt.axvline(point, color=color, linestyle=":")
ax.axvline(point, color=color, linestyle=":")
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment