Skip to content
Snippets Groups Projects
Commit d37a1b0d authored by David Schäfer's avatar David Schäfer
Browse files

possibly filter the argument `func_arguments` to `flagGeneric`

parent fdfc3cbb
No related branches found
No related tags found
1 merge request!13Evaluator rework - propagate flags through evaluation
......@@ -3,7 +3,7 @@
import ast
from functools import partial
from typing import Union, Dict, Any
from typing import Union, Dict, Any, Set
import numpy as np
import pandas as pd
......@@ -18,6 +18,7 @@ def _dslIsFlagged(flagger, data, flag=None):
def initLocalEnv(data: pd.DataFrame, field: str, flagger: BaseFlagger, nodata: float) -> Dict[str, Any]:
return {
"data": data,
"field": field,
......@@ -67,28 +68,43 @@ class DslTransformer(ast.NodeTransformer):
ast.Name,
)
def __init__(self, environment, variables):
def __init__(self, environment: Dict[str, Any], variables: Set[str]):
self.environment = environment
self.variables = variables
self.arguments = set()
self.invert = False
self.func_name = None
def transform(self, node):
# NOTE: should be done in __init__
self.arguments = set()
return self.visit(node)
def visit_Invert(self, node):
self.invert = True
return node
def visit_Call(self, node):
func_name = node.func.id
if func_name not in self.environment:
raise NameError(f"unspported function: '{func_name}'")
self.func_name = func_name
return ast.Call(func=node.func, args=[self.visit(arg) for arg in node.args], keywords=[],)
def visit_Name(self, node):
name = node.id
if name == "this":
name = self.environment["field"]
self.arguments.add(name)
# NOTE:
# we need a way to prevent some variables
# from ending up in `flagGeneric`, see the
# problem with np.all(~isflagged(x)) is True
if self.func_name == "isflagged" and self.invert:
self.invert = False
else:
self.arguments.add(name)
if name in self.variables:
value = ast.Constant(value=name)
......@@ -151,9 +167,14 @@ class ConfigTransformer(ast.NodeTransformer):
def visit_keyword(self, node):
key, value = node.arg, node.value
if self.func_name == Params.FLAG_GENERIC and key == Params.FUNC:
dsl_func = ast.keyword(
arg=key, value=self.dsl_transformer.transform(value))
# NOTE:
# Inject the additional `func_arguments` argument `flagGeneric`
# expects, to keep track of all the touched variables. We
# need this to propagate the flags from the independent variables
args = ast.keyword(
arg=Params.GENERIC_ARGS,
value=ast.List(
......
......@@ -14,7 +14,7 @@ from saqc.funcs.functions import (
flagIsolated,
)
from saqc.flagger.dmpflagger import DmpFlagger
from test.common import initData, TESTFLAGGER, initMetaDict
from test.common import initData, TESTFLAGGER
@pytest.fixture
......
......@@ -220,7 +220,6 @@ def test_invertIsFlagged(data, flagger):
flagger = flagger.initFlags(data)
var1, var2, *_ = data.columns
# flagger = flagger.setFlags(var1, iloc=slice(None, None, 2))
flagger = flagger.setFlags(var2, iloc=slice(None, None, 2))
_, flagger_result = evalExpression(
......@@ -229,7 +228,6 @@ def test_invertIsFlagged(data, flagger):
)
flags_result = flagger_result.isFlagged(var1)
flags = flagger.isFlagged(var2)
# import pdb; pdb.set_trace()
assert np.all(flags_result != flags)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment