Skip to content
Snippets Groups Projects
Commit ad3dd7bb authored by David Schäfer's avatar David Schäfer
Browse files

allow ast.USub in test function arguments

parent efeb074c
No related branches found
No related tags found
No related merge requests found
......@@ -2,9 +2,6 @@
# -*- coding: utf-8 -*-
# maybe we should use dataclasses here
class Fields:
VARNAME = "varname"
START = "start_date"
......@@ -16,7 +13,5 @@ class Fields:
class Params:
GENERIC = "generic"
FUNC = "func"
FLAGPERIOD = "flag_period"
FLAGVALUES = "flag_values"
FLAG = "flag"
......@@ -115,18 +115,16 @@ class DslTransformer(ast.NodeTransformer):
class MetaTransformer(ast.NodeTransformer):
SUPPORTED = (
ast.Call,
ast.Num,
ast.Str,
ast.keyword,
ast.NameConstant,
ast.UnaryOp,
ast.Name,
ast.Load,
ast.Expression,
ast.Subscript,
ast.Index,
SUPPORTED_NODES = (
ast.Call, ast.Num, ast.Str, ast.keyword,
ast.NameConstant, ast.UnaryOp, ast.Name,
ast.Load, ast.Expression, ast.Subscript,
ast.Index, ast.USub
)
SUPPORTED_ARGUMENTS = (
ast.Str, ast.Num, ast.NameConstant, ast.Call,
ast.UnaryOp, ast.USub
)
def __init__(self, dsl_transformer, pass_parameter):
......@@ -156,22 +154,22 @@ class MetaTransformer(ast.NodeTransformer):
def visit_keyword(self, node):
key, value = node.arg, node.value
if self.func_name == "generic" and key == Params.FUNC:
if self.func_name == Params.GENERIC and key == Params.FUNC:
node = ast.keyword(arg=key, value=self.dsl_transformer.visit(value))
return node
if key not in FUNC_MAP[self.func_name].signature + self.pass_parameter:
raise TypeError(f"unknown function parameter '{node.arg}'")
if not isinstance(value, (ast.Str, ast.Num, ast.Call)):
if not isinstance(value, self.SUPPORTED_ARGUMENTS):
raise TypeError(
f"only concrete values and function calls are valid function arguments"
f"invalid argument type '{type(value)}'"
)
return self.generic_visit(node)
def generic_visit(self, node):
if not isinstance(node, self.SUPPORTED):
if not isinstance(node, self.SUPPORTED_NODES):
raise TypeError(f"invalid expression: '{node}'")
return super().generic_visit(node)
......
......@@ -4,6 +4,7 @@
import pytest
import numpy as np
from saqc.funcs import register
from saqc.core.evaluator import (
compileTree,
parseExpression,
......@@ -17,12 +18,16 @@ from test.common import TESTFLAGGER
def compileExpression(expr, flagger, nodata=np.nan):
tree = parseExpression(expr)
dsl_transformer = DslTransformer(initDslFuncMap(flagger, nodata, "target"))
transformed_tree = MetaTransformer(dsl_transformer).visit(tree)
dsl_transformer = DslTransformer(initDslFuncMap(nodata), {})
transformed_tree = MetaTransformer(dsl_transformer, flagger.signature).visit(tree)
code = compileTree(transformed_tree)
return code
def _dummyFunc(data, field, flagger, kwarg, **kwargs):
pass
@pytest.mark.parametrize("flagger", TESTFLAGGER)
def test_syntaxError(flagger):
exprs = [
......@@ -38,13 +43,43 @@ def test_syntaxError(flagger):
@pytest.mark.parametrize("flagger", TESTFLAGGER)
def test_typeError(flagger):
register("func")(_dummyFunc)
register("otherFunc")(_dummyFunc)
exprs = [
"range",
"nodata(x=[1, 2, 3])",
"nodata(func=ismissing(this))",
"range(deleteEverything())",
# "func",
"func(kwarg=[1, 2, 3])",
"func(x=5)",
"func(otherFunc())",
"func(kwarg=otherFunc(this))",
"func(kwarg=otherFunc(kwarg=this))",
]
for expr in exprs:
with pytest.raises(TypeError):
compileExpression(expr, flagger)
@pytest.mark.parametrize("flagger", TESTFLAGGER)
def test_supportedArguments(flagger):
register("func")(_dummyFunc)
register("otherFunc")(_dummyFunc)
exprs = [
"func(kwarg='str')",
"func(kwarg=5)",
"func(kwarg=5.5)",
"func(kwarg=-5)",
"func(kwarg=True)",
"func(kwarg=otherFunc())",
]
for expr in exprs:
compileExpression(expr, flagger)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment