Skip to content
Snippets Groups Projects
Commit 24d636b5 authored by David Schäfer
Browse files

better evaluation errors

parent 3929c9a6
No related branches found
No related tags found
No related merge requests found
......@@ -56,15 +56,19 @@ class DslTransformer(ast.NodeTransformer):
ast.Invert,
)
def __init__(self, func_map):
def __init__(self, func_map, variables):
self.func_map = func_map
self.variables = set(variables)
def _rename(self, node: ast.Name, target: str) -> ast.Subscript:
if node.id == "this":
name = node.id
if name == "this":
slice = ast.Index(value=ast.Name(id="field", ctx=ast.Load()))
else:
slice = ast.Index(value=ast.Constant(value=node.id))
if name not in self.variables:
raise NameError(f"unknown variable: '{name}'")
slice = ast.Index(value=ast.Constant(value=name))
return ast.Subscript(
value=ast.Name(id=target, ctx=ast.Load()),
......@@ -74,7 +78,7 @@ class DslTransformer(ast.NodeTransformer):
def visit_Call(self, node):
func_name = node.func.id
if func_name not in self.func_map:
raise TypeError(f"unspported function: {func_name}")
raise NameError(f"unspported function: {func_name}")
node = ast.Call(
func=node.func,
......@@ -180,7 +184,7 @@ def evalCode(code, data, flags, field, flagger, nodata):
def evalExpression(expr, data, flags, field, flagger, nodata):
tree = parseExpression(expr)
dsl_transformer = DslTransformer(initDslFuncMap(nodata))
dsl_transformer = DslTransformer(initDslFuncMap(nodata), data.columns)
transformed_tree = MetaTransformer(dsl_transformer).visit(tree)
code = compileTree(transformed_tree)
return evalCode(code, data, flags, field, flagger, nodata)
......@@ -16,17 +16,18 @@ from .register import register
@register("generic")
def flagGeneric(data, flags, field, flagger, func, **kwargs):
"""
NOTE:
The naming of the func parameter is pretty confusing
as it actually holds the result of a generic expression
"""
result = func.squeeze()
if np.isscalar(result):
# NOTE:
# - The naming of the func parameter is pretty confusing
# as it actually holds the result of a generic expression
# - if the result series carries a name, it was explicitly created
# from a single column, so we need to preserve that column's
# properties
mask = func.squeeze() | flagger.isFlagged(flags[func.name or field])
if np.isscalar(mask):
raise TypeError(f"generic expression does not return an array")
if not np.issubdtype(result.dtype, np.bool_):
if not np.issubdtype(mask.dtype, np.bool_):
raise TypeError(f"generic expression does not return a boolean array")
flags = flagger.setFlags(flags, field, result, **kwargs)
flags = flagger.setFlags(flags, field, mask, **kwargs)
return data, flags
......
......@@ -18,7 +18,7 @@ from saqc.dsl.parser import (
def _evalExpression(expr, data, flags, field, flagger, nodata=np.nan):
dsl_transformer = DslTransformer(initDslFuncMap(nodata))
dsl_transformer = DslTransformer(initDslFuncMap(nodata), data.columns)
tree = ast.parse(expr, mode="eval")
transformed_tree = dsl_transformer.visit(tree)
code = compileTree(transformed_tree)
......@@ -35,38 +35,38 @@ def nodata():
return -99990
# @pytest.mark.parametrize("flagger", TESTFLAGGER)
# def test_flagPropagation(data, flagger):
# flags = flagger.setFlags(
# flagger.initFlags(data),
# 'var2', iloc=slice(None, None, 5))
# var1, var2, *_ = data.columns
# this = var1
# var2_flags = flagger.isFlagged(flags[var2])
# var2_data = data[var2].mask(var2_flags)
# data, flags = evalExpression(
# "generic(func=var2 < mean(var2))",
# data, flags,
# this,
# flagger, np.nan
# )
# expected = (var2_flags | (var2_data < var2_data.mean()))
# result = flagger.isFlagged(flags[this])
# assert (result == expected).all()
# @pytest.mark.parametrize("flagger", TESTFLAGGER)
# def test_missingIdentifier(data, flagger):
# flags = flagger.initFlags(data)
# tests = [
# "func(var2) < 5",
# "var3 != NODATA"
# ]
# for expr in tests:
# with pytest.raises(NameError):
# _evalExpression(expr, data, flags, data.columns[0], flagger)
@pytest.mark.parametrize("flagger", TESTFLAGGER)
def test_flagPropagation(data, flagger):
    """Check the flags produced by a generic expression over a pre-flagged column.

    Every fifth row of 'var2' is flagged up front. The generic test is then
    evaluated with the first column as the field, and the flags on that field
    are expected to equal the pre-existing 'var2' flags combined with the
    result of the comparison on the masked 'var2' data.
    """
    every_fifth = slice(None, None, 5)
    initial_flags = flagger.initFlags(data)
    flags = flagger.setFlags(initial_flags, 'var2', iloc=every_fifth)

    target, source, *_ = data.columns

    # Snapshot the flag state and the masked values before evaluation,
    # as `data` and `flags` are rebound by the call below.
    preset = flagger.isFlagged(flags[source])
    masked = data[source].mask(preset)

    data, flags = evalExpression(
        "generic(func=var2 < mean(var2))",
        data,
        flags,
        target,
        flagger,
        np.nan,
    )

    below_mean = masked < masked.mean()
    expected = preset | below_mean
    observed = flagger.isFlagged(flags[target])
    assert (observed == expected).all()
@pytest.mark.parametrize("flagger", TESTFLAGGER)
def test_missingIdentifier(data, flagger):
    """Check that each of the listed expressions raises a NameError."""
    flags = flagger.initFlags(data)
    field = data.columns[0]

    expressions = (
        "func(var2) < 5",
        "var3 != NODATA",
    )
    for expression in expressions:
        with pytest.raises(NameError):
            _evalExpression(expression, data, flags, field, flagger)
@pytest.mark.parametrize("flagger", TESTFLAGGER)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment