Skip to content
Snippets Groups Projects
Commit b3da4544 authored by Bert Palm's avatar Bert Palm 🎇
Browse files

merge-request reviewed, one bug left to fix, tests show it

parent c7414bd7
No related branches found
No related tags found
No related merge requests found
......@@ -80,6 +80,8 @@ class BaseFlagger(ABC):
return flags.loc[mask, field]
def setFlags(self, field, loc=None, iloc=None, flag=None, force=False, **kwargs):
if field is None:
raise ValueError('field cannot be None')
flag = self.BAD if flag is None else flag
......@@ -96,7 +98,6 @@ class BaseFlagger(ABC):
def clearFlags(self, field, loc=None, iloc=None, **kwargs):
if field is None:
# NOTE: I don't see a need for this restriction
raise ValueError('field cannot be None')
return self.setFlags(field=field, loc=loc, iloc=iloc, flag=self.UNFLAGGED, force=True)
......
......@@ -56,13 +56,13 @@ class CategoricalBaseFlagger(BaseFlagger):
def setFlags(self, field, loc=None, iloc=None, flag=None, force=False, **kwargs):
return super().setFlags(
field=field, loc=loc, iloc=iloc,
flag=self._checkFlags(flag), force=force,
flag=self._checkFlag(flag), force=force,
**kwargs)
def isFlagged(self, field=None, loc=None, iloc=None, flag=None, comparator: str = ">", **kwargs):
return super().isFlagged(
field=field, loc=loc, iloc=iloc,
flag=self._checkFlags(flag), comparator=comparator,
flag=self._checkFlag(flag), comparator=comparator,
**kwargs)
def _assureDtype(self, flags):
......@@ -83,7 +83,7 @@ class CategoricalBaseFlagger(BaseFlagger):
return isinstance(f.dtype, pd.CategoricalDtype) and f.dtype == self.dtype
return f in self.dtype.categories
def _checkFlags(self, flag):
def _checkFlag(self, flag):
if flag is not None and not self._isCategorical(flag):
raise TypeError(
f"invalid flag '{flag}', possible choices are '{list(self.dtype.categories)}'")
......
......@@ -55,8 +55,10 @@ class DmpFlagger(CategoricalBaseFlagger):
return super()._assureDtype(flags.loc[mask, field])
def setFlags(self, field, loc=None, iloc=None, flag=None, force=False, comment='', cause='', **kwargs):
if field is None:
raise ValueError('field cannot be None')
flag = self.BAD if flag is None else self._checkFlags(flag)
flag = self.BAD if flag is None else self._checkFlag(flag)
comment = json.dumps({"comment": comment,
"commit": self.project_version,
......
......@@ -26,16 +26,16 @@ def get_dataset(rows, cols):
field = 'var0'
DATASETS = [
get_dataset(0, 1),
get_dataset(1, 1),
# get_dataset(0, 1),
# get_dataset(1, 1),
get_dataset(100, 1),
get_dataset(1000, 1),
get_dataset(0, 4),
get_dataset(1, 4),
# get_dataset(1000, 1),
# get_dataset(0, 4),
# get_dataset(1, 4),
get_dataset(100, 4),
get_dataset(1000, 4),
get_dataset(10000, 40),
get_dataset(20, 4),
# get_dataset(1000, 4),
# get_dataset(10000, 40),
# get_dataset(20, 4),
]
......@@ -95,6 +95,11 @@ def test_isFlagged(data, flagger):
# both the same
assert (flagged0[field] == flagged1).all()
flag = pd.Series(index=data.index, data=flagger.BAD).astype(flagger.dtype)
# fixme !!
flagger.isFlagged(flag=flag)
@pytest.mark.parametrize('data', DATASETS)
@pytest.mark.parametrize('flagger', TESTFLAGGER)
......@@ -121,6 +126,10 @@ def test_setFlags(data, flagger):
flagger_forced_good = flagger_bad.setFlags(field, flag=flagger.GOOD, force=True)
assert (flagger_forced_good.getFlags(field) == flagger.GOOD).all()
with pytest.raises(ValueError):
flagger.setFlags(field=None, flag=flagger.BAD)
@pytest.mark.parametrize('data', DATASETS)
@pytest.mark.parametrize('flagger', TESTFLAGGER)
......@@ -166,16 +175,16 @@ def test_dtype(data, flagger):
@pytest.mark.parametrize('data', DATASETS)
@pytest.mark.parametrize('flagger', TESTFLAGGER)
def test_returnCopy(data, flagger):
flagger.initFlags(data)
origin = flagger.getFlags()
origin = flagger.initFlags(data)
origin_data = origin.getFlags()
f = flagger.getFlags()
assert f is not origin
assert f is not origin_data
f = flagger.isFlagged()
assert f is not origin_data
f = flagger.setFlags(field)
assert f is not origin
f = flagger.setFlags(field).getFlags()
assert f is not origin
f = flagger.clearFlags(field).getFlags()
f = flagger.clearFlags(field)
assert f is not origin
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment