diff --git a/core.py b/core.py index dd1df1828acf22e10f78899041388c8c02be99b0..4268fb98de2a544e78dfaf99da318b445ca03313 100644 --- a/core.py +++ b/core.py @@ -4,6 +4,7 @@ import numpy as np import pandas as pd import matplotlib as mpl +from warnings import warn from config import Fields, Params from funcs import flagDispatch @@ -46,6 +47,8 @@ def flagNext(flagger, flags, mask=True, flag_values=0, **kwargs) -> pd.Series: def runner(meta, flagger, data, flags=None, nodata=np.nan): + plotvars = [] + if flags is None: flags = pd.DataFrame(index=data.index) @@ -118,7 +121,8 @@ def runner(meta, flagger, data, flags=None, nodata=np.nan): fchunk = fchunk.astype({ c: flagger.flags for c in fchunk.columns if flagger.flag_fields[0] in c}) - if Params.PLOT in flag_params: + if flag_params.get(Params.PLOT, False): + plotvars.append(varname) new = flagger.getFlags(fchunk[varname]) mask = old != new plot(dchunk, fchunk, mask, varname, flagger, title=flag_test) @@ -127,6 +131,11 @@ def runner(meta, flagger, data, flags=None, nodata=np.nan): flags[start_date:end_date] = fchunk.squeeze() flagger.nextTest() + + # plot all together + if plotvars: + plot(data, flags, True, set(plotvars), flagger) + return data, flags @@ -144,6 +153,20 @@ def plot(data, flags, flagmask, varname, flagger, interactive_backend=True, titl from pandas.plotting import register_matplotlib_converters register_matplotlib_converters() + if not isinstance(varname, (list, set)): + varname = set([varname]) + + tmp = [] + for var in varname: + if var not in data.columns: + warn(f"Cannot plot column '{var}' that is not present in data.", UserWarning) + else: + tmp.append(var) + if tmp: + varname = tmp + else: + return + def plot_vline(plt, points, color='blue'): # workaround for ax.vlines() as this work unexpected for point in points: @@ -189,8 +212,6 @@ def plot(data, flags, flagmask, varname, flagger, interactive_backend=True, titl # ax.vlines(idx, *ylim, linestyles=':', color=colors[i]) plot_vline(ax, idx, color=colors[i]) - if not isinstance(varname, (list, set)): - varname = set([varname]) plots = len(varname) if plots > 1: fig, axes = plt.subplots(plots, 1, sharex=True)