Skip to content
Snippets Groups Projects
Commit 31762a1a authored by Bert Palm's avatar Bert Palm 🎇
Browse files

prepared flagger refactoring

parent 85032515
No related branches found
No related tags found
3 merge requests!271Static expansion of regular expressions,!260Follow-Up Translations,!237Flagger Translations
......@@ -95,12 +95,12 @@ def main(config, data, flagger, outfile, nodata, log_level, fail):
data_result, flagger_result = saqc.readConfig(config).getResult(raw=True)
if outfile:
data_result = data_result.to_df()
flags = flagger_result.toFrame()
unflagged = (flags == UNFLAGGED) | flags.isna()
flags[unflagged] = GOOD
data_frame = data_result.to_df()
flags_frame = flagger_result.toFrame()
unflagged = (flags_frame == UNFLAGGED) | flags_frame.isna()
flags_frame[unflagged] = GOOD
fields = {"data": data_result, "flags": flags}
fields = {"data": data_frame, "flags": flags_frame}
out = (
pd.concat(fields.values(), axis=1, keys=fields.keys())
......
......@@ -249,14 +249,14 @@ def _maskData(data, flagger, columns, thresh) -> Tuple[dios.DictOfSeries, dios.D
return data, mask
def _isflagged(flags: Union[np.array, pd.Series], thresh: float) -> Union[np.array, pd.Series]:
def _isflagged(flagscol: Union[np.array, pd.Series], thresh: float) -> Union[np.array, pd.Series]:
"""
Return a mask of flags accordingly to `thresh`. Return type is same as flags.
"""
if thresh == UNFLAGGED:
return flags > UNFLAGGED
return flagscol > UNFLAGGED
return flags >= thresh
return flagscol >= thresh
def _prepareFlags(flagger: Flagger, masking) -> Flagger:
......
......@@ -124,7 +124,7 @@ def flagIsolated(
mask = data[field].isna()
flags = pd.Series(data=0, index=mask.index, dtype=bool)
bools = pd.Series(data=0, index=mask.index, dtype=bool)
for srs in groupConsecutives(mask):
if np.all(~srs):
start = srs.index[0]
......@@ -134,7 +134,7 @@ def flagIsolated(
if left.all():
right = mask[stop: stop + gap_window].iloc[1:]
if right.all():
flags[start:stop] = True
bools[start:stop] = True
flagger[mask, field] = flag
return data, flagger
......
......@@ -86,10 +86,10 @@ def test_dtypes(data, flags):
Test if the categorical dtype is preserved through the core functionality
"""
flagger = initFlagsLike(data)
flags = flagger.toDios()
flags_raw = flagger.toDios()
var1, var2 = data.columns[:2]
pdata, pflagger = SaQC(data, flags=flags).flagAll(var1).flagAll(var2).getResult(raw=True)
pdata, pflagger = SaQC(data, flags=flags_raw).flagAll(var1).flagAll(var2).getResult(raw=True)
for c in pflagger.columns:
assert pflagger[c].dtype == flagger[c].dtype
......
......@@ -23,8 +23,8 @@ def test_constants_flagBasic(data):
field, *_ = data.columns
flagger = initFlagsLike(data)
data, flagger_result = flagConstants(data, field, flagger, window="15Min", thresh=0.1, flag=BAD)
flags = flagger_result[field]
assert np.all(flags[expected] == BAD)
flagscol = flagger_result[field]
assert np.all(flagscol[expected] == BAD)
def test_constants_flagVarianceBased(data):
......
......@@ -54,20 +54,20 @@ def test_modelling_mask(dat):
common = dict(data=data, field=field, flagger=flagger, mode='periodic')
data_seasonal, flagger_seasonal = mask(**common, period_start="20:00", period_end="40:00", include_bounds=False)
flags = flagger_seasonal[field]
m = (20 <= flags.index.minute) & (flags.index.minute <= 40)
flagscol = flagger_seasonal[field]
m = (20 <= flagscol.index.minute) & (flagscol.index.minute <= 40)
assert all(flagger_seasonal[field][m] == UNFLAGGED)
assert all(data_seasonal[field][m].isna())
data_seasonal, flagger_seasonal = mask(**common, period_start="15:00:00", period_end="02:00:00")
flags = flagger_seasonal[field]
m = (15 <= flags.index.hour) & (flags.index.hour <= 2)
flagscol = flagger_seasonal[field]
m = (15 <= flagscol.index.hour) & (flagscol.index.hour <= 2)
assert all(flagger_seasonal[field][m] == UNFLAGGED)
assert all(data_seasonal[field][m].isna())
data_seasonal, flagger_seasonal = mask(**common, period_start="03T00:00:00", period_end="10T00:00:00")
flags = flagger_seasonal[field]
m = (3 <= flags.index.hour) & (flags.index.hour <= 10)
flagscol = flagger_seasonal[field]
m = (3 <= flagscol.index.hour) & (flagscol.index.hour <= 10)
assert all(flagger_seasonal[field][m] == UNFLAGGED)
assert all(data_seasonal[field][m].isna())
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment