Skip to content
Snippets Groups Projects
Commit 54e15ce3 authored by David Schäfer's avatar David Schäfer
Browse files

moved the evaluator into separate sub module

parent f4359147
No related branches found
No related tags found
1 merge request!13Evaluator rework - propagate flags through evaluation
Pipeline #2424 passed with stage
in 6 minutes and 43 seconds
#! /usr/bin/env python
# -*- coding: utf-8 -*-
from saqc.core.evaluator.evaluator import (
compileExpression,
evalExpression,
compileTree,
parseExpression,
initLocalEnv,
DslTransformer,
ConfigChecker,
ConfigTransformer,
evalCode
)
from saqc.core.evaluator.checker import (
DslChecker,
ConfigChecker
)
from saqc.core.evaluator.transformer import (
DslTransformer,
ConfigTransformer
)
#! /usr/bin/env python
# -*- coding: utf-8 -*-
import ast
from saqc.funcs.register import FUNC_MAP
from saqc.core.config import Params
class DslChecker(ast.NodeVisitor):
SUPPORTED = (
ast.Expression,
ast.UnaryOp,
ast.BinOp,
ast.BitOr,
ast.BitAnd,
ast.Num,
ast.Compare,
ast.Add,
ast.Sub,
ast.Mult,
ast.Div,
ast.Pow,
ast.Mod,
ast.USub,
ast.Eq,
ast.NotEq,
ast.Gt,
ast.Lt,
ast.GtE,
ast.LtE,
ast.Invert,
ast.Name,
ast.Load,
ast.Call
)
def __init__(self, environment):
self.environment = environment
def visit_Call(self, node):
func_name = node.func.id
if func_name not in self.environment:
raise NameError(f"unspported function: '{func_name}'")
self.generic_visit(node)
def visit_Name(self, node):
name = node.id
if (name != "this" and
name not in self.environment and
name not in self.environment["variables"]):
raise NameError(f"unknown variable: '{name}'")
self.generic_visit(node)
def generic_visit(self, node):
if not isinstance(node, self.SUPPORTED):
raise TypeError(f"invalid expression: '{node}'")
return super().generic_visit(node)
class ConfigChecker(ast.NodeVisitor):
SUPPORTED_NODES = (
ast.Call,
ast.Num,
ast.Str,
ast.keyword,
ast.NameConstant,
ast.UnaryOp,
ast.Name,
ast.Load,
ast.Expression,
ast.Subscript,
ast.Index,
ast.USub,
)
SUPPORTED_ARGUMENTS = (ast.Str, ast.Num, ast.NameConstant, ast.Call, ast.UnaryOp, ast.USub, ast.Name)
def __init__(self, environment, pass_parameter):
self.pass_parameter = pass_parameter
self.environment = environment
self.func_name = None
def visit_Call(self, node):
func_name = node.func.id
if func_name not in FUNC_MAP:
raise NameError(f"unknown test function: '{func_name}'")
if node.args:
raise TypeError("only keyword arguments are supported")
self.func_name = func_name
return self.generic_visit(node)
def visit_keyword(self, node):
key, value = node.arg, node.value
if self.func_name == Params.FLAG_GENERIC and key == Params.FUNC:
DslChecker(self.environment).visit(value)
return
if key not in FUNC_MAP[self.func_name].signature + self.pass_parameter:
raise TypeError(f"unknown function parameter '{node.arg}'")
if not isinstance(value, self.SUPPORTED_ARGUMENTS):
raise TypeError(f"invalid argument type '{type(value)}'")
if isinstance(value, ast.Name) and value.id not in self.environment:
raise NameError(f"unknown variable: {value.id}")
return self.generic_visit(node)
def generic_visit(self, node):
if not isinstance(node, self.SUPPORTED_NODES):
raise TypeError(f"invalid node: '{node}'")
return super().generic_visit(node)
......@@ -3,7 +3,7 @@
import ast
from functools import partial
from typing import Union, Dict, Any, Set
from typing import Any, Dict
import numpy as np
import pandas as pd
......
#! /usr/bin/env python
# -*- coding: utf-8 -*-
import ast
from saqc.core.config import Params
from typing import Dict, Any
class DslTransformer(ast.NodeTransformer):
def __init__(self, environment: Dict[str, Any]):
self.environment = environment
self.arguments = set()
self.invert = False
self.func_name = None
def visit_Invert(self, node):
self.invert = True
return node
def visit_Call(self, node):
self.func_name = node.func.id
return ast.Call(
func=node.func,
args=[self.visit(arg) for arg in node.args],
keywords=[]
)
def visit_Name(self, node):
name = node.id
if name == "this":
name = self.environment["field"]
# NOTE:
# we need a way to prevent some variables
# from ending up in `flagGeneric`, see the
# problem with np.all(~isflagged(x)) is True
if self.func_name == "isflagged" and self.invert:
self.invert = False
else:
self.arguments.add(name)
if name in self.environment["variables"]:
value = ast.Constant(value=name)
node = ast.Subscript(
value=ast.Name(id="data", ctx=ast.Load()), slice=ast.Index(value=value), ctx=ast.Load(),
)
elif name in self.environment:
node = ast.Constant(value=name)
return node
class ConfigTransformer(ast.NodeTransformer):
def __init__(self, environment):
self.environment = environment
self.func_name = None
def visit_Call(self, node):
func_name = node.func.id
self.func_name = func_name
new_args = [
ast.Name(id="data", ctx=ast.Load()),
ast.Name(id="field", ctx=ast.Load()),
ast.Name(id="flagger", ctx=ast.Load()),
]
node = ast.Call(func=node.func, args=new_args + node.args, keywords=node.keywords)
return self.generic_visit(node)
def visit_keyword(self, node):
key, value = node.arg, node.value
if self.func_name == Params.FLAG_GENERIC and key == Params.FUNC:
dsl_transformer = DslTransformer(self.environment)
value = dsl_transformer.visit(value)
dsl_func = ast.keyword(arg=key, value=value)
# NOTE:
# Inject the additional `func_arguments` argument `flagGeneric`
# expects, to keep track of all the touched variables. We
# need this to propagate the flags from the independent variables
args = ast.keyword(
arg=Params.GENERIC_ARGS,
value=ast.List(
elts=[ast.Str(s=v) for v in dsl_transformer.arguments],
ctx=ast.Load()))
return [dsl_func, args]
return self.generic_visit(node)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment