From 68aa6792e5ef165f7347198ef12df4e3d778f02d Mon Sep 17 00:00:00 2001
From: David Schaefer <david.schaefer@ufz.de>
Date: Wed, 11 Dec 2019 22:42:34 +0100
Subject: [PATCH] clearer speration between global and local compilation maps

---
 saqc/core/evaluator.py               | 38 +++++++++++++++-------------
 test/core/test_evaluator.py          |  4 +--
 test/funcs/test_generic_functions.py | 27 ++++++++++++--------
 3 files changed, 40 insertions(+), 29 deletions(-)

diff --git a/saqc/core/evaluator.py b/saqc/core/evaluator.py
index 93e72aa00..bdda36d5a 100644
--- a/saqc/core/evaluator.py
+++ b/saqc/core/evaluator.py
@@ -25,8 +25,8 @@ def _dslIsFlagged(data, field, flagger):
     return flagger.isFlagged(field)
 
 
-def initDslFuncMap(nodata):
-    return {
+def initGlobalMap():
+    out = {
         "abs": partial(_dslInner, np.abs),
         "max": partial(_dslInner, np.nanmax),
         "min": partial(_dslInner, np.nanmin),
@@ -35,6 +35,18 @@ def initDslFuncMap(nodata):
         "std": partial(_dslInner, np.nanstd),
         "len": partial(_dslInner, len),
         "isflagged": _dslIsFlagged,
+        "nan": np.nan,
+        **FUNC_MAP,
+    }
+    return out
+
+def initLocalMap(data, field, flagger, nodata):
+    return {
+        "data": data,
+        "field": field,
+        "this": field,
+        "flagger": flagger,
+        "NODATA": nodata,
         "ismissing": lambda data, field, flagger: ((data == nodata) | pd.isnull(data)),
     }
 
@@ -184,29 +196,21 @@ def compileTree(tree: ast.Expression):
     return compile(ast.fix_missing_locations(tree), "<ast>", mode="eval")
 
 
-def evalCode(code, data, field, flagger, nodata):
-    global_env = initDslFuncMap(nodata)
-    local_env = {
-        **FUNC_MAP,
-        "data": data,
-        "field": field,
-        "this": field,
-        "flagger": flagger,
-        "NODATA": nodata,
-    }
-
+def evalCode(code, global_env, local_env):
     return eval(code, global_env, local_env)
 
 
-def compileExpression(expr, data, flagger, nodata):
+def compileExpression(expr, data, flagger, env):
     varmap = set(data.columns.tolist() + flagger.getFlags().columns.tolist())
     tree = parseExpression(expr)
-    dsl_transformer = DslTransformer(initDslFuncMap(nodata), varmap)
+    dsl_transformer = DslTransformer(env, varmap)
     transformed_tree = MetaTransformer(dsl_transformer, flagger.signature).visit(tree)
     return compileTree(transformed_tree)
 
 
 def evalExpression(expr, data, field, flagger, nodata=np.nan):
 
-    code = compileExpression(expr, data, flagger, nodata)
-    return evalCode(code, data, field, flagger, nodata)
+    global_env = initGlobalMap()
+    local_env = initLocalMap(data, field, flagger, nodata)
+    code = compileExpression(expr, data, flagger, {**global_env, **local_env})
+    return evalCode(code, global_env, local_env)
diff --git a/test/core/test_evaluator.py b/test/core/test_evaluator.py
index 8089bf5f8..237b84f69 100644
--- a/test/core/test_evaluator.py
+++ b/test/core/test_evaluator.py
@@ -8,7 +8,7 @@ from saqc.funcs import register
 from saqc.core.evaluator import (
     compileTree,
     parseExpression,
-    initDslFuncMap,
+    initGlobalMap,
     DslTransformer,
     MetaTransformer,
 )
@@ -18,7 +18,7 @@ from test.common import TESTFLAGGER
 
 def compileExpression(expr, flagger, nodata=np.nan):
     tree = parseExpression(expr)
-    dsl_transformer = DslTransformer(initDslFuncMap(nodata), {})
+    dsl_transformer = DslTransformer(initGlobalMap(), {})
     transformed_tree = MetaTransformer(dsl_transformer, flagger.signature).visit(tree)
     code = compileTree(transformed_tree)
     return code
diff --git a/test/funcs/test_generic_functions.py b/test/funcs/test_generic_functions.py
index 934cff3bb..7a5096f64 100644
--- a/test/funcs/test_generic_functions.py
+++ b/test/funcs/test_generic_functions.py
@@ -9,7 +9,8 @@ from test.common import initData, TESTFLAGGER, TESTNODATA
 
 from saqc.core.evaluator import (
     DslTransformer,
-    initDslFuncMap,
+    initGlobalMap,
+    initLocalMap,
     parseExpression,
     evalExpression,
     compileTree,
@@ -17,12 +18,14 @@ from saqc.core.evaluator import (
 )
 
 
-def _evalExpression(expr, data, field, flagger, nodata=np.nan):
+def _evalDslExpression(expr, data, field, flagger, nodata=np.nan):
+    global_env = initGlobalMap()
+    local_env = initLocalMap(data, field, flagger, nodata)
     tree = parseExpression(expr)
-    dsl_transformer = DslTransformer(initDslFuncMap(nodata), data.columns)
+    dsl_transformer = DslTransformer({**global_env, **local_env}, data.columns)
     transformed_tree = dsl_transformer.visit(tree)
     code = compileTree(transformed_tree)
-    return evalCode(code, data, field, flagger, nodata)
+    return evalCode(code, global_env, local_env)
 
 
 @pytest.fixture
@@ -50,8 +53,12 @@ def test_flagPropagation(data, flagger):
 
 @pytest.mark.parametrize("flagger", TESTFLAGGER)
 def test_missingIdentifier(data, flagger):
+
     flagger = flagger.initFlags(data)
-    tests = ["generic(func=func(var2) < 5)", "generic(func=var3 != NODATA)"]
+    tests = [
+        "generic(func=fff(var2) < 5)",
+        "generic(func=var3 != NODATA)"
+    ]
     for expr in tests:
         with pytest.raises(NameError):
             evalExpression(expr, data, data.columns[0], flagger, np.nan)
@@ -90,7 +97,7 @@ def test_nonReduncingBuiltins(data, flagger):
     ]
 
     for expr, expected in tests:
-        result = _evalExpression(expr, data, this, flagger)
+        result = _evalDslExpression(expr, data, this, flagger)
         assert (result == expected).all()
 
 
@@ -112,7 +119,7 @@ def test_reduncingBuiltins(data, flagger, nodata):
     ]
 
     for expr, expected in tests:
-        result = _evalExpression(expr, data, this, flagger, nodata)
+        result = _evalDslExpression(expr, data, this, flagger, nodata)
         assert result == expected
 
 
@@ -135,7 +142,7 @@ def test_ismissing(data, flagger, nodata):
     ]
 
     for expr, checkFunc in tests:
-        idx = _evalExpression(expr, data, var1, flagger, nodata)
+        idx = _evalDslExpression(expr, data, var1, flagger, nodata)
         assert checkFunc(data.loc[idx, var1])
 
 
@@ -173,7 +180,7 @@ def test_isflagged(data, flagger):
     flagger = flagger.setFlags(var1, iloc=slice(None, None, 2))
     flagger = flagger.setFlags(var2, iloc=slice(None, None, 2))
 
-    idx = _evalExpression(f"isflagged({var1})", data, var2, flagger)
+    idx = _evalDslExpression(f"isflagged({var1})", data, var2, flagger)
 
     flagged = flagger.isFlagged(var1)
     assert (flagged == idx).all
@@ -188,7 +195,7 @@ def test_isflaggedArgument(data, flagger):
         var1, iloc=slice(None, None, 2), flag=flagger.BAD
     )
 
-    idx = _evalExpression(f"isflagged({var1}, {flagger.BAD})", data, var2, flagger)
+    idx = _evalDslExpression(f"isflagged({var1}, {flagger.BAD})", data, var2, flagger)
 
     flagged = flagger.isFlagged(var1, flag=flagger.BAD, comparator=">=")
     assert (flagged == idx).all()
-- 
GitLab