Skip to content
Snippets Groups Projects
Commit bc70cf14 authored by David Schäfer's avatar David Schäfer
Browse files

bugfix: plotting failed for flagger with _flags attribute of type

        pd.DataFrame
parent ca5b36e6
No related branches found
No related tags found
No related merge requests found
......@@ -8,11 +8,13 @@ import pandas as pd
from saqc.core import run
from saqc.flagger import CategoricalFlagger
from saqc.flagger.dmpflagger import DmpFlagger, FlagFields
FLAGGERS = {
"numeric": CategoricalFlagger([-1, 0, 1]),
"category": CategoricalFlagger(["NIL", "OK", "BAD"]),
"dmp": DmpFlagger()
}
......@@ -44,9 +46,20 @@ def main(config, data, flagger, outfile, nodata, fail):
if outfile:
flags = flagger_result.getFlags()
flags_out = flags.where((flags.isnull() | flagger_result.isFlagged()), flagger_result.GOOD)
cols_out = sum([[c, c + "_flags"] for c in flags_out], [])
data_out = data_result.join(flags_out, rsuffix="_flags")
data_out[cols_out].to_csv(outfile, header=True, index=True)
if isinstance(flagger_result, DmpFlagger):
flags = flagger_result._flags
flags.loc[flags_out.index, (slice(None), FlagFields.FLAG)] = flags_out.values
flags_out = flags
if not isinstance(flags_out.columns, pd.MultiIndex):
flags_out.columns = pd.MultiIndex.from_product([flags.columns, ["flag"]])
data_result.columns = pd.MultiIndex.from_product([data_result.columns, ["data"]])
# flags_out.columns = flags_out.columns.map("_".join)
data_out = data_result.join(flags_out)
data_out.sort_index(axis="columns").to_csv(outfile, header=True, index=True, na_rep="")
if __name__ == "__main__":
......
......@@ -49,12 +49,14 @@ class DmpFlagger(CategoricalFlagger):
"""
if data is not None:
flags = pd.DataFrame(data=self.UNFLAGGED, columns=self._getColumnIndex(data.columns), index=data.index,)
flags = pd.DataFrame(data="", columns=self._getColumnIndex(data.columns), index=data.index,)
flags.loc[:, self._getColumnIndex(data.columns, [FlagFields.FLAG])] = self.UNFLAGGED
elif flags is not None:
if not isinstance(flags.columns, pd.MultiIndex):
flags = flags.T.set_index(keys=self._getColumnIndex(flags.columns, [FlagFields.FLAG])).T.reindex(
columns=self._getColumnIndex(flags.columns)
)
cols = flags.columns
flags = flags.copy()
flags.columns = self._getColumnIndex(cols, [FlagFields.FLAG])
flags = flags.reindex(columns=self._getColumnIndex(cols), fill_value="")
else:
raise TypeError("either 'data' or 'flags' are required")
......
......@@ -65,6 +65,8 @@ def plotHook(
return
mask = flags_old != flags_new
if isinstance(mask, pd.DataFrame):
mask = mask.any(axis=1)
__plotvars.append(varname)
_plot(data, flagger_new, mask, varname, title=flag_test, plot_nans=plot_nans)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment