Skip to content
Snippets Groups Projects
Commit 3a05303b authored by David Schäfer's avatar David Schäfer
Browse files

fixed clearFlags and its test

parent 63df0dc9
No related branches found
No related tags found
No related merge requests found
......@@ -97,9 +97,10 @@ class BaseFlagger:
def clearFlags(self, flags, field, loc=None, iloc=None, **kwargs):
check_isdf(flags, 'flags', allow_multiindex=False)
flags_loc, rows, col = self._getIndexer(flags, field, loc, iloc)
out = flags.copy()
flags_loc, rows, col = self._getIndexer(out, field, loc, iloc)
flags_loc[rows, col] = self.UNFLAGGED
return self._assureDtype(flags, field)
return self._assureDtype(out, field)
def _checkFlag(self, flag):
if isinstance(flag, pd.Series):
......
......@@ -40,7 +40,7 @@ class DmpFlagger(BaseFlagger):
colindex = pd.MultiIndex.from_product(
[data.columns, self.flags_fields],
names=[ColumnLevels.VARIABLES, ColumnLevels.FLAGS])
flags = pd.DataFrame(data=self.categories[0], columns=colindex, index=data.index)
flags = pd.DataFrame(data=self.UNFLAGGED, columns=colindex, index=data.index)
return self._assureDtype(flags)
def isFlagged(self, flags: PandasLike, flag=None, comparator=">") -> PandasLike:
......@@ -91,10 +91,10 @@ class DmpFlagger(BaseFlagger):
def clearFlags(self, flags, field, loc=None, iloc=None, **kwargs):
check_isdfmi(flags, 'flags')
flags = flags.copy()
indexer, rows, col = self._getIndexer(flags, field, loc, iloc)
indexer[rows, col] = self.UNFLAGGED, '', ''
return self._assureDtype(flags, field)
out = flags.copy()
indexer, rows, col = self._getIndexer(out, field, loc, iloc)
indexer[rows, col] = self.UNFLAGGED
return self._assureDtype(out, field)
def _assureDtype(self, flags, field=None):
if field is None:
......
......@@ -56,19 +56,13 @@ def test_flagSesonalRange(flagger):
assert len(flags[flagged]) == 31 * 4 / 2
@pytest.mark.parametrize('flagger', TESTFLAGGERS)
def test_clearFlags(flagger):
# prepare
field = 'testdata'
index = pd.date_range(start='2011-01-01', end='2011-01-10', freq='1d')
data = pd.DataFrame(data={field: np.linspace(0, index.size - 1, index.size)}, index=index)
@pytest.mark.parametrize('flagger', TESTFLAGGER)
def test_clearFlags(data, field, flagger):
orig = flagger.initFlags(data)
flags = orig.copy()
# test
flags = flagger.setFlags(flags, field)
assert (orig != flags).all
flags = flagger.setFlags(orig.copy(), field, flag=flagger.BAD)
_, cleared = clearFlags(data, flags, field, flagger)
assert (orig == cleared).all
assert np.all(orig != flags)
assert np.all(orig == cleared)
@pytest.mark.parametrize('flagger', TESTFLAGGER)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment