Skip to content
Snippets Groups Projects
Commit 39d932e4 authored by David Schäfer's avatar David Schäfer
Browse files

some cleanups

parent 08b5b1a6
No related branches found
No related tags found
No related merge requests found
......@@ -159,7 +159,7 @@ class MetaTransformer(ast.NodeTransformer):
def parseExpression(expr: str) -> ast.Expression:
tree = ast.parse(expr, mode="eval")
if not isinstance(tree.body, ast.Call):
if not isinstance(tree.body, (ast.Call, ast.Compare)):
raise TypeError('function call needed')
return tree
......
#! /usr/bin/env python
# -*- coding: utf-8 -*-
import ast
import pytest
import numpy as np
import pandas as pd
......@@ -12,6 +10,7 @@ from ..common import initData, TESTFLAGGER
from saqc.dsl.parser import (
DslTransformer,
initDslFuncMap,
parseExpression,
evalExpression,
compileTree,
evalCode)
......@@ -19,10 +18,10 @@ from saqc.dsl.parser import (
def _evalExpression(expr, data, flags, field, flagger, nodata=np.nan):
dsl_transformer = DslTransformer(initDslFuncMap(nodata), data.columns)
tree = ast.parse(expr, mode="eval")
tree = parseExpression(expr)
transformed_tree = dsl_transformer.visit(tree)
code = compileTree(transformed_tree)
return evalCode(code, data, flags, "var1", flagger, nodata)
return evalCode(code, data, flags, field, flagger, nodata)
@pytest.fixture
......@@ -61,12 +60,12 @@ def test_flagPropagation(data, flagger):
def test_missingIdentifier(data, flagger):
flags = flagger.initFlags(data)
tests = [
"func(var2) < 5",
"var3 != NODATA"
"generic(func=func(var2) < 5)",
"generic(func=var3 != NODATA)"
]
for expr in tests:
with pytest.raises(NameError):
_evalExpression(expr, data, flags, data.columns[0], flagger)
evalExpression(expr, data, flags, data.columns[0], flagger, np.nan)
@pytest.mark.parametrize("flagger", TESTFLAGGER)
......@@ -82,11 +81,6 @@ def test_comparisons(data, flagger):
(f"this <= {var2}", data[this] <= data[var2])
]
# check directly
for expr, expected in tests:
result = _evalExpression(expr, data, flags, this, flagger, np.nan)
assert (result == expected).all()
# check within the usually enclosing scope
for expr, mask in tests:
_, result_flags = evalExpression(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment