From 68aa6792e5ef165f7347198ef12df4e3d778f02d Mon Sep 17 00:00:00 2001 From: David Schaefer <david.schaefer@ufz.de> Date: Wed, 11 Dec 2019 22:42:34 +0100 Subject: [PATCH] clearer separation between global and local compilation maps --- saqc/core/evaluator.py | 38 +++++++++++++++------------- test/core/test_evaluator.py | 4 +-- test/funcs/test_generic_functions.py | 27 ++++++++++++-------- 3 files changed, 40 insertions(+), 29 deletions(-) diff --git a/saqc/core/evaluator.py b/saqc/core/evaluator.py index 93e72aa00..bdda36d5a 100644 --- a/saqc/core/evaluator.py +++ b/saqc/core/evaluator.py @@ -25,8 +25,8 @@ def _dslIsFlagged(data, field, flagger): return flagger.isFlagged(field) -def initDslFuncMap(nodata): - return { +def initGlobalMap(): + out = { "abs": partial(_dslInner, np.abs), "max": partial(_dslInner, np.nanmax), "min": partial(_dslInner, np.nanmin), @@ -35,6 +35,18 @@ def initDslFuncMap(nodata): "std": partial(_dslInner, np.nanstd), "len": partial(_dslInner, len), "isflagged": _dslIsFlagged, + "nan": np.nan, + **FUNC_MAP, + } + return out + +def initLocalMap(data, field, flagger, nodata): + return { + "data": data, + "field": field, + "this": field, + "flagger": flagger, + "NODATA": nodata, "ismissing": lambda data, field, flagger: ((data == nodata) | pd.isnull(data)), } @@ -184,29 +196,21 @@ def compileTree(tree: ast.Expression): return compile(ast.fix_missing_locations(tree), "<ast>", mode="eval") -def evalCode(code, data, field, flagger, nodata): - global_env = initDslFuncMap(nodata) - local_env = { - **FUNC_MAP, - "data": data, - "field": field, - "this": field, - "flagger": flagger, - "NODATA": nodata, - } - +def evalCode(code, global_env, local_env): return eval(code, global_env, local_env) -def compileExpression(expr, data, flagger, nodata): +def compileExpression(expr, data, flagger, env): varmap = set(data.columns.tolist() + flagger.getFlags().columns.tolist()) tree = parseExpression(expr) - dsl_transformer = DslTransformer(initDslFuncMap(nodata), 
varmap) + dsl_transformer = DslTransformer(env, varmap) transformed_tree = MetaTransformer(dsl_transformer, flagger.signature).visit(tree) return compileTree(transformed_tree) def evalExpression(expr, data, field, flagger, nodata=np.nan): - code = compileExpression(expr, data, flagger, nodata) - return evalCode(code, data, field, flagger, nodata) + global_env = initGlobalMap() + local_env = initLocalMap(data, field, flagger, nodata) + code = compileExpression(expr, data, flagger, {**global_env, **local_env}) + return evalCode(code, global_env, local_env) diff --git a/test/core/test_evaluator.py b/test/core/test_evaluator.py index 8089bf5f8..237b84f69 100644 --- a/test/core/test_evaluator.py +++ b/test/core/test_evaluator.py @@ -8,7 +8,7 @@ from saqc.funcs import register from saqc.core.evaluator import ( compileTree, parseExpression, - initDslFuncMap, + initGlobalMap, DslTransformer, MetaTransformer, ) @@ -18,7 +18,7 @@ from test.common import TESTFLAGGER def compileExpression(expr, flagger, nodata=np.nan): tree = parseExpression(expr) - dsl_transformer = DslTransformer(initDslFuncMap(nodata), {}) + dsl_transformer = DslTransformer(initGlobalMap(), {}) transformed_tree = MetaTransformer(dsl_transformer, flagger.signature).visit(tree) code = compileTree(transformed_tree) return code diff --git a/test/funcs/test_generic_functions.py b/test/funcs/test_generic_functions.py index 934cff3bb..7a5096f64 100644 --- a/test/funcs/test_generic_functions.py +++ b/test/funcs/test_generic_functions.py @@ -9,7 +9,8 @@ from test.common import initData, TESTFLAGGER, TESTNODATA from saqc.core.evaluator import ( DslTransformer, - initDslFuncMap, + initGlobalMap, + initLocalMap, parseExpression, evalExpression, compileTree, @@ -17,12 +18,14 @@ from saqc.core.evaluator import ( ) -def _evalExpression(expr, data, field, flagger, nodata=np.nan): +def _evalDslExpression(expr, data, field, flagger, nodata=np.nan): + global_env = initGlobalMap() + local_env = initLocalMap(data, field, 
flagger, nodata) tree = parseExpression(expr) - dsl_transformer = DslTransformer(initDslFuncMap(nodata), data.columns) + dsl_transformer = DslTransformer({**global_env, **local_env}, data.columns) transformed_tree = dsl_transformer.visit(tree) code = compileTree(transformed_tree) - return evalCode(code, data, field, flagger, nodata) + return evalCode(code, global_env, local_env) @pytest.fixture @@ -50,8 +53,12 @@ def test_flagPropagation(data, flagger): @pytest.mark.parametrize("flagger", TESTFLAGGER) def test_missingIdentifier(data, flagger): + flagger = flagger.initFlags(data) - tests = ["generic(func=func(var2) < 5)", "generic(func=var3 != NODATA)"] + tests = [ + "generic(func=fff(var2) < 5)", + "generic(func=var3 != NODATA)" + ] for expr in tests: with pytest.raises(NameError): evalExpression(expr, data, data.columns[0], flagger, np.nan) @@ -90,7 +97,7 @@ def test_nonReduncingBuiltins(data, flagger): ] for expr, expected in tests: - result = _evalExpression(expr, data, this, flagger) + result = _evalDslExpression(expr, data, this, flagger) assert (result == expected).all() @@ -112,7 +119,7 @@ def test_reduncingBuiltins(data, flagger, nodata): ] for expr, expected in tests: - result = _evalExpression(expr, data, this, flagger, nodata) + result = _evalDslExpression(expr, data, this, flagger, nodata) assert result == expected @@ -135,7 +142,7 @@ def test_ismissing(data, flagger, nodata): ] for expr, checkFunc in tests: - idx = _evalExpression(expr, data, var1, flagger, nodata) + idx = _evalDslExpression(expr, data, var1, flagger, nodata) assert checkFunc(data.loc[idx, var1]) @@ -173,7 +180,7 @@ def test_isflagged(data, flagger): flagger = flagger.setFlags(var1, iloc=slice(None, None, 2)) flagger = flagger.setFlags(var2, iloc=slice(None, None, 2)) - idx = _evalExpression(f"isflagged({var1})", data, var2, flagger) + idx = _evalDslExpression(f"isflagged({var1})", data, var2, flagger) flagged = flagger.isFlagged(var1) assert (flagged == idx).all @@ -188,7 +195,7 @@ 
def test_isflaggedArgument(data, flagger): var1, iloc=slice(None, None, 2), flag=flagger.BAD ) - idx = _evalExpression(f"isflagged({var1}, {flagger.BAD})", data, var2, flagger) + idx = _evalDslExpression(f"isflagged({var1}, {flagger.BAD})", data, var2, flagger) flagged = flagger.isFlagged(var1, flag=flagger.BAD, comparator=">=") assert (flagged == idx).all() -- GitLab