Skip to content
Snippets Groups Projects
Commit 16be7ea2 authored by David Schaefer's avatar David Schaefer
Browse files

fixed formatting

parent 54e15ce3
No related branches found
No related tags found
1 merge request!13Evaluator rework - propagate flags through evaluation
Pipeline #2427 passed with stage
in 6 minutes and 45 seconds
...@@ -88,7 +88,7 @@ def run( ...@@ -88,7 +88,7 @@ def run(
config = readConfig(config_file, data) config = readConfig(config_file, data)
# split config into the test and some 'meta' data # split config into the test and some 'meta' data
tests = config.filter(regex=Fields.TESTS + '*') tests = config.filter(regex=Fields.TESTS + "*")
meta = config[config.columns.difference(tests.columns)] meta = config[config.columns.difference(tests.columns)]
# prepapre the flags # prepapre the flags
......
...@@ -10,15 +10,9 @@ from saqc.core.evaluator.evaluator import ( ...@@ -10,15 +10,9 @@ from saqc.core.evaluator.evaluator import (
DslTransformer, DslTransformer,
ConfigChecker, ConfigChecker,
ConfigTransformer, ConfigTransformer,
evalCode evalCode,
) )
from saqc.core.evaluator.checker import ( from saqc.core.evaluator.checker import DslChecker, ConfigChecker
DslChecker,
ConfigChecker
)
from saqc.core.evaluator.transformer import ( from saqc.core.evaluator.transformer import DslTransformer, ConfigTransformer
DslTransformer,
ConfigTransformer
)
...@@ -6,6 +6,7 @@ import ast ...@@ -6,6 +6,7 @@ import ast
from saqc.funcs.register import FUNC_MAP from saqc.funcs.register import FUNC_MAP
from saqc.core.config import Params from saqc.core.config import Params
class DslChecker(ast.NodeVisitor): class DslChecker(ast.NodeVisitor):
SUPPORTED = ( SUPPORTED = (
...@@ -32,7 +33,7 @@ class DslChecker(ast.NodeVisitor): ...@@ -32,7 +33,7 @@ class DslChecker(ast.NodeVisitor):
ast.Invert, ast.Invert,
ast.Name, ast.Name,
ast.Load, ast.Load,
ast.Call ast.Call,
) )
def __init__(self, environment): def __init__(self, environment):
...@@ -46,9 +47,7 @@ class DslChecker(ast.NodeVisitor): ...@@ -46,9 +47,7 @@ class DslChecker(ast.NodeVisitor):
def visit_Name(self, node): def visit_Name(self, node):
name = node.id name = node.id
if (name != "this" and if name != "this" and name not in self.environment and name not in self.environment["variables"]:
name not in self.environment and
name not in self.environment["variables"]):
raise NameError(f"unknown variable: '{name}'") raise NameError(f"unknown variable: '{name}'")
self.generic_visit(node) self.generic_visit(node)
...@@ -113,4 +112,3 @@ class ConfigChecker(ast.NodeVisitor): ...@@ -113,4 +112,3 @@ class ConfigChecker(ast.NodeVisitor):
if not isinstance(node, self.SUPPORTED_NODES): if not isinstance(node, self.SUPPORTED_NODES):
raise TypeError(f"invalid node: '{node}'") raise TypeError(f"invalid node: '{node}'")
return super().generic_visit(node) return super().generic_visit(node)
...@@ -68,7 +68,7 @@ class DslChecker(ast.NodeVisitor): ...@@ -68,7 +68,7 @@ class DslChecker(ast.NodeVisitor):
ast.Invert, ast.Invert,
ast.Name, ast.Name,
ast.Load, ast.Load,
ast.Call ast.Call,
) )
def __init__(self, environment): def __init__(self, environment):
...@@ -82,9 +82,7 @@ class DslChecker(ast.NodeVisitor): ...@@ -82,9 +82,7 @@ class DslChecker(ast.NodeVisitor):
def visit_Name(self, node): def visit_Name(self, node):
name = node.id name = node.id
if (name != "this" and if name != "this" and name not in self.environment and name not in self.environment["variables"]:
name not in self.environment and
name not in self.environment["variables"]):
raise NameError(f"unknown variable: '{name}'") raise NameError(f"unknown variable: '{name}'")
self.generic_visit(node) self.generic_visit(node)
...@@ -152,7 +150,6 @@ class ConfigChecker(ast.NodeVisitor): ...@@ -152,7 +150,6 @@ class ConfigChecker(ast.NodeVisitor):
class DslTransformer(ast.NodeTransformer): class DslTransformer(ast.NodeTransformer):
def __init__(self, environment: Dict[str, Any]): def __init__(self, environment: Dict[str, Any]):
self.environment = environment self.environment = environment
self.arguments = set() self.arguments = set()
...@@ -165,11 +162,7 @@ class DslTransformer(ast.NodeTransformer): ...@@ -165,11 +162,7 @@ class DslTransformer(ast.NodeTransformer):
def visit_Call(self, node): def visit_Call(self, node):
self.func_name = node.func.id self.func_name = node.func.id
return ast.Call( return ast.Call(func=node.func, args=[self.visit(arg) for arg in node.args], keywords=[])
func=node.func,
args=[self.visit(arg) for arg in node.args],
keywords=[]
)
def visit_Name(self, node): def visit_Name(self, node):
name = node.id name = node.id
...@@ -197,8 +190,6 @@ class DslTransformer(ast.NodeTransformer): ...@@ -197,8 +190,6 @@ class DslTransformer(ast.NodeTransformer):
class ConfigTransformer(ast.NodeTransformer): class ConfigTransformer(ast.NodeTransformer):
def __init__(self, environment): def __init__(self, environment):
self.environment = environment self.environment = environment
self.func_name = None self.func_name = None
...@@ -230,9 +221,8 @@ class ConfigTransformer(ast.NodeTransformer): ...@@ -230,9 +221,8 @@ class ConfigTransformer(ast.NodeTransformer):
# need this to propagate the flags from the independent variables # need this to propagate the flags from the independent variables
args = ast.keyword( args = ast.keyword(
arg=Params.GENERIC_ARGS, arg=Params.GENERIC_ARGS,
value=ast.List( value=ast.List(elts=[ast.Str(s=v) for v in dsl_transformer.arguments], ctx=ast.Load()),
elts=[ast.Str(s=v) for v in dsl_transformer.arguments], )
ctx=ast.Load()))
return [dsl_func, args] return [dsl_func, args]
return self.generic_visit(node) return self.generic_visit(node)
......
...@@ -5,8 +5,8 @@ import ast ...@@ -5,8 +5,8 @@ import ast
from saqc.core.config import Params from saqc.core.config import Params
from typing import Dict, Any from typing import Dict, Any
class DslTransformer(ast.NodeTransformer):
class DslTransformer(ast.NodeTransformer):
def __init__(self, environment: Dict[str, Any]): def __init__(self, environment: Dict[str, Any]):
self.environment = environment self.environment = environment
self.arguments = set() self.arguments = set()
...@@ -19,11 +19,7 @@ class DslTransformer(ast.NodeTransformer): ...@@ -19,11 +19,7 @@ class DslTransformer(ast.NodeTransformer):
def visit_Call(self, node): def visit_Call(self, node):
self.func_name = node.func.id self.func_name = node.func.id
return ast.Call( return ast.Call(func=node.func, args=[self.visit(arg) for arg in node.args], keywords=[])
func=node.func,
args=[self.visit(arg) for arg in node.args],
keywords=[]
)
def visit_Name(self, node): def visit_Name(self, node):
name = node.id name = node.id
...@@ -51,8 +47,6 @@ class DslTransformer(ast.NodeTransformer): ...@@ -51,8 +47,6 @@ class DslTransformer(ast.NodeTransformer):
class ConfigTransformer(ast.NodeTransformer): class ConfigTransformer(ast.NodeTransformer):
def __init__(self, environment): def __init__(self, environment):
self.environment = environment self.environment = environment
self.func_name = None self.func_name = None
...@@ -84,9 +78,8 @@ class ConfigTransformer(ast.NodeTransformer): ...@@ -84,9 +78,8 @@ class ConfigTransformer(ast.NodeTransformer):
# need this to propagate the flags from the independent variables # need this to propagate the flags from the independent variables
args = ast.keyword( args = ast.keyword(
arg=Params.GENERIC_ARGS, arg=Params.GENERIC_ARGS,
value=ast.List( value=ast.List(elts=[ast.Str(s=v) for v in dsl_transformer.arguments], ctx=ast.Load()),
elts=[ast.Str(s=v) for v in dsl_transformer.arguments], )
ctx=ast.Load()))
return [dsl_func, args] return [dsl_func, args]
return self.generic_visit(node) return self.generic_visit(node)
...@@ -125,7 +125,7 @@ def checkConfig(config_df: pd.DataFrame, data: pd.DataFrame, flagger: BaseFlagge ...@@ -125,7 +125,7 @@ def checkConfig(config_df: pd.DataFrame, data: pd.DataFrame, flagger: BaseFlagge
if pd.isnull(config_row[F.VARNAME]) or not var_name: if pd.isnull(config_row[F.VARNAME]) or not var_name:
_raise(config_row, SyntaxError, f"non-optional column '{F.VARNAME}' is missing or empty") _raise(config_row, SyntaxError, f"non-optional column '{F.VARNAME}' is missing or empty")
test_fields = config_row.filter(regex=F.TESTS + '*').dropna() test_fields = config_row.filter(regex=F.TESTS + "*").dropna()
if test_fields.empty: if test_fields.empty:
_raise( _raise(
config_row, SyntaxError, f"at least one test needs to be given for variable", config_row, SyntaxError, f"at least one test needs to be given for variable",
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment