Skip to content
Snippets Groups Projects
Commit d04d26f9 authored by Bert Palm's avatar Bert Palm 🎇
Browse files

added 'BT.empty'. Improved tests.

parent a2b62a86
No related branches found
No related tags found
1 merge request!218Flags
......@@ -93,7 +93,7 @@ class Backtrack:
Returns
-------
index: pd.Index
index : pd.Index
"""
return self.bt.index
......@@ -107,10 +107,26 @@ class Backtrack:
Returns
-------
columns: pd.Index
columns : pd.Index
"""
return self.bt.columns
@property
def empty(self) -> bool:
"""
Indicator whether Backtrack is empty.
True if Backtrack is entirely empty (no items).
Returns
-------
bool
If Backtrack is empty, return True, if not return False.
"""
# we take self.mask here, because it cannot have NaN's,
# but self.bt could have -> see pd.DataFrame.empty
return self.mask.empty
def _insert(self, s: pd.Series, nr: int, force=False) -> Backtrack:
"""
Insert data at an arbitrary position in the BT.
......@@ -172,7 +188,7 @@ class Backtrack:
if s.empty:
raise ValueError('Cannot append empty pd.Series')
if not self.bt.empty and not s.index.equals(self.index):
if not self.empty and not s.index.equals(self.index):
raise ValueError("Index must be equal to BT's index")
self._insert(value, nr=len(self))
......@@ -280,12 +296,15 @@ class Backtrack:
if any(mask.dtypes != bool):
raise ValueError("dtype of all columns in 'mask' must be bool")
if not mask.empty and not mask.iloc[:, -1].all():
raise ValueError("the values in the last column in mask must be 'True' everywhere.")
# check combination of bt and mask
if not obj.columns.equals(mask.columns):
raise ValueError("'BT' and 'mask' must have same columns")
raise ValueError("'bt' and 'mask' must have same columns")
if not obj.index.equals(mask.columns):
raise ValueError("'BT' and 'mask' must have same index")
if not obj.index.equals(mask.index):
raise ValueError("'bt' and 'mask' must have same index")
return obj, mask
......
......@@ -31,8 +31,8 @@ data = [
def check_invariants(bt):
"""
this can be called for **any** BT and
should never fail.
This can be called for **any** BT.
The assertions must hold in any case.
"""
# basics
assert isinstance(bt, Backtrack)
......@@ -48,11 +48,31 @@ def check_invariants(bt):
# advanced
assert bt.columns.equals(pd.Index(range(len(bt))))
assert isinstance(bt.max(), pd.Series)
assert bt.mask.empty or bt.mask.iloc[:, -1].all()
# False propagation
# for each row this must hold:
# either the row has one change (False->True)
# or the entire row is True
if not bt.empty:
idxmax = bt.mask.idxmax(axis=1)
for row, col in idxmax.items():
assert all(bt.mask.iloc[row, :col] == False)
assert all(bt.mask.iloc[row, col:] == True)
def is_equal(bt1: Backtrack, bt2: Backtrack):
"""
Check if two BT are (considered) equal, namely
have equal 'bt' and equal 'mask'.
"""
return bt1.bt.equals(bt2.bt) and bt1.mask.equals(bt2.mask)
@pytest.mark.parametrize('data', data + [None])
def test_init(data: np.array):
# init
df = pd.DataFrame(data, dtype=float)
bt = Backtrack(bt=df)
......@@ -65,23 +85,38 @@ def test_init(data: np.array):
assert bt.mask.all(axis=None)
# check fastpath
bt = Backtrack(bt=bt)
check_invariants(bt)
fast = Backtrack(bt=bt)
check_invariants(fast)
assert is_equal(bt, fast)
@pytest.mark.parametrize('data', data + [None])
def test_init_with_mask(data: np.array):
# init
df = pd.DataFrame(data, dtype=float)
bt = Backtrack(bt=df)
mask = pd.DataFrame(data, dtype=bool)
if not mask.empty:
mask.iloc[:, -1] = True
bt = Backtrack(bt=df, mask=mask)
check_invariants(bt)
if data is None:
return
# shape would fail
if data is not None:
assert len(bt.index) == data.shape[0]
assert len(bt.columns) == data.shape[1]
# check fastpath
fast = Backtrack(bt=bt)
check_invariants(fast)
assert is_equal(bt, fast)
def test_append():
pass
assert len(bt.index) == data.shape[0]
assert len(bt.columns) == data.shape[1]
assert bt.mask.all(axis=None)
def test_squeeze():
pass
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment