Skip to content
Snippets Groups Projects
Commit d04d26f9 authored by Bert Palm's avatar Bert Palm 🎇
Browse files

added 'BT.empty'. Improved tests.

parent a2b62a86
No related branches found
No related tags found
1 merge request!218Flags
...@@ -93,7 +93,7 @@ class Backtrack: ...@@ -93,7 +93,7 @@ class Backtrack:
Returns Returns
------- -------
index: pd.Index index : pd.Index
""" """
return self.bt.index return self.bt.index
...@@ -107,10 +107,26 @@ class Backtrack: ...@@ -107,10 +107,26 @@ class Backtrack:
Returns Returns
------- -------
columns: pd.Index columns : pd.Index
""" """
return self.bt.columns return self.bt.columns
@property
def empty(self) -> bool:
"""
Indicator whether Backtrack is empty.
True if Backtrack is entirely empty (no items).
Returns
-------
bool
If Backtrack is empty, return True, if not return False.
"""
# we take self.mask here, because it cannot have NaN's,
# but self.bt could have -> see pd.DataFrame.empty
return self.mask.empty
def _insert(self, s: pd.Series, nr: int, force=False) -> Backtrack: def _insert(self, s: pd.Series, nr: int, force=False) -> Backtrack:
""" """
Insert data at an arbitrary position in the BT. Insert data at an arbitrary position in the BT.
...@@ -172,7 +188,7 @@ class Backtrack: ...@@ -172,7 +188,7 @@ class Backtrack:
if s.empty: if s.empty:
raise ValueError('Cannot append empty pd.Series') raise ValueError('Cannot append empty pd.Series')
if not self.bt.empty and not s.index.equals(self.index): if not self.empty and not s.index.equals(self.index):
raise ValueError("Index must be equal to BT's index") raise ValueError("Index must be equal to BT's index")
self._insert(value, nr=len(self)) self._insert(value, nr=len(self))
...@@ -280,12 +296,15 @@ class Backtrack: ...@@ -280,12 +296,15 @@ class Backtrack:
if any(mask.dtypes != bool): if any(mask.dtypes != bool):
raise ValueError("dtype of all columns in 'mask' must be bool") raise ValueError("dtype of all columns in 'mask' must be bool")
if not mask.empty and not mask.iloc[:, -1].all():
raise ValueError("the values in the last column in mask must be 'True' everywhere.")
# check combination of bt and mask # check combination of bt and mask
if not obj.columns.equals(mask.columns): if not obj.columns.equals(mask.columns):
raise ValueError("'BT' and 'mask' must have same columns") raise ValueError("'bt' and 'mask' must have same columns")
if not obj.index.equals(mask.columns): if not obj.index.equals(mask.index):
raise ValueError("'BT' and 'mask' must have same index") raise ValueError("'bt' and 'mask' must have same index")
return obj, mask return obj, mask
......
...@@ -31,8 +31,8 @@ data = [ ...@@ -31,8 +31,8 @@ data = [
def check_invariants(bt): def check_invariants(bt):
""" """
this can be called for **any** BT and This can be called for **any** BT.
should never fail. The assertions must hold in any case.
""" """
# basics # basics
assert isinstance(bt, Backtrack) assert isinstance(bt, Backtrack)
...@@ -48,11 +48,31 @@ def check_invariants(bt): ...@@ -48,11 +48,31 @@ def check_invariants(bt):
# advanced # advanced
assert bt.columns.equals(pd.Index(range(len(bt)))) assert bt.columns.equals(pd.Index(range(len(bt))))
assert isinstance(bt.max(), pd.Series) assert isinstance(bt.max(), pd.Series)
assert bt.mask.empty or bt.mask.iloc[:, -1].all()
# False propagation
# for each row this must hold:
# either the row has one change (False->True)
# or the entire row is True
if not bt.empty:
idxmax = bt.mask.idxmax(axis=1)
for row, col in idxmax.items():
assert all(bt.mask.iloc[row, :col] == False)
assert all(bt.mask.iloc[row, col:] == True)
def is_equal(bt1: Backtrack, bt2: Backtrack):
"""
Check if two BT are (considered) equal, namely
have equal 'bt' and equal 'mask'.
"""
return bt1.bt.equals(bt2.bt) and bt1.mask.equals(bt2.mask)
@pytest.mark.parametrize('data', data + [None]) @pytest.mark.parametrize('data', data + [None])
def test_init(data: np.array): def test_init(data: np.array):
# init
df = pd.DataFrame(data, dtype=float) df = pd.DataFrame(data, dtype=float)
bt = Backtrack(bt=df) bt = Backtrack(bt=df)
...@@ -65,23 +85,38 @@ def test_init(data: np.array): ...@@ -65,23 +85,38 @@ def test_init(data: np.array):
assert bt.mask.all(axis=None) assert bt.mask.all(axis=None)
# check fastpath # check fastpath
bt = Backtrack(bt=bt) fast = Backtrack(bt=bt)
check_invariants(bt) check_invariants(fast)
assert is_equal(bt, fast)
@pytest.mark.parametrize('data', data + [None]) @pytest.mark.parametrize('data', data + [None])
def test_init_with_mask(data: np.array): def test_init_with_mask(data: np.array):
# init
df = pd.DataFrame(data, dtype=float) df = pd.DataFrame(data, dtype=float)
mask = pd.DataFrame(data, dtype=bool)
bt = Backtrack(bt=df) if not mask.empty:
mask.iloc[:, -1] = True
bt = Backtrack(bt=df, mask=mask)
check_invariants(bt) check_invariants(bt)
if data is None: # shape would fail
return if data is not None:
assert len(bt.index) == data.shape[0]
assert len(bt.columns) == data.shape[1]
# check fastpath
fast = Backtrack(bt=bt)
check_invariants(fast)
assert is_equal(bt, fast)
def test_append():
pass
assert len(bt.index) == data.shape[0]
assert len(bt.columns) == data.shape[1]
assert bt.mask.all(axis=None)
def test_squeeze():
pass
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment