# Skip to content
# Snippets Groups Projects
# test_generic_api_functions.py 1.98 KiB
#! /usr/bin/env python
# -*- coding: utf-8 -*-
import ast

import pytest
import numpy as np
import pandas as pd

from dios import DictOfSeries

from test.common import TESTFLAGGER, TESTNODATA, initData, writeIO, flagAll
from saqc.core.visitor import ConfigFunctionParser
from saqc.core.config import Fields as F
from saqc.core.register import register
from saqc import SaQC, SimpleFlagger
from saqc.funcs.generic import _execGeneric
from saqc.funcs.tools import mask


# Pass the ``flagAll`` test helper through ``register`` configured with
# ``masking='field'`` at import time.
# NOTE(review): presumably this makes ``flagAll`` available as a saqc
# function for the tests -- confirm against ``saqc.core.register``.
register(masking='field')(flagAll)


@pytest.fixture
def data():
    """Return the dataset built by ``initData`` from ``test.common``."""
    dataset = initData()
    return dataset


@pytest.mark.parametrize("flagger", TESTFLAGGER)
def test_addFieldFlagGeneric(data, flagger):
    """Check ``generic.flag`` on the field name "tmp1".

    After the call, "tmp1" must show up in the returned flags' columns
    while the returned data must not contain it.
    """
    field = "tmp1"
    qc = SaQC(data=data, flagger=flagger)

    # The callback closes over the *input* ``data`` (the fixture value) and
    # returns an all-False series indexed like the variable it receives.
    outcome = qc.generic.flag(
        field,
        func=lambda var1: pd.Series(False, index=data[var1.name].index),
    ).getResult()
    data_out, flags_out = outcome

    assert field in flags_out.columns
    assert field not in data_out


@pytest.mark.parametrize("flagger", TESTFLAGGER)
def test_addFieldProcGeneric(data, flagger):
    """Check that ``generic.process`` adds data columns for "tmp1" and "tmp2".

    Two cases are exercised:

    * a zero-argument function returning an empty series -> "tmp1" must be
      present in the result and empty;
    * a two-argument function adding ``var1`` and ``var2`` -> "tmp2" must
      equal the element-wise sum of the two source columns.
    """
    saqc = SaQC(data=data, flagger=flagger)

    # Give the empty series an explicit dtype: a bare ``pd.Series([])`` has a
    # pandas-version dependent default dtype (and warns on pandas 1.x).
    # The second element of the result is not needed here; binding it to ``_``
    # avoids overwriting the parametrized ``flagger`` argument.
    result, _ = saqc.generic.process(
        "tmp1", func=lambda: pd.Series([], dtype=float)
    ).getResult(raw=True)
    assert "tmp1" in result.columns
    assert result["tmp1"].empty

    result, _ = saqc.generic.process(
        "tmp2", func=lambda var1, var2: var1 + var2
    ).getResult()
    assert "tmp2" in result.columns
    assert (result["tmp2"] == result["var1"] + result["var2"]).all(axis=None)


@pytest.mark.parametrize("flagger", TESTFLAGGER)
def test_mask(data, flagger):
    """Run a ``mask`` expression through ``generic.process`` and check NaNs.

    The same expression is written once to "var1" and once to the field
    name "tmp"; in both cases the NaN pattern of the result is compared
    against an expectation derived from a deep copy of the input data.
    """
    qc = SaQC(data=data, flagger=flagger)
    pristine = data.copy(deep=True)
    threshold = data["var1"] / 2

    def expected_nans():
        # ``&`` binds tighter than ``==`` in Python, so the comparison in the
        # assertions below is against this whole conjunction.
        # NOTE(review): ``(x < 10) & x.isna()`` looks like it can never be
        # True, and the constant 10 is unrelated to the mask condition
        # ``var1 < threshold`` -- confirm this is the intended expectation.
        return (pristine["var1"] < 10) & pristine["var1"].isna()

    # Case 1: write the masked values back to "var1".
    masked, _ = qc.generic.process(
        "var1", lambda var1: mask(var1 < threshold)
    ).getResult()
    assert (masked["var1"].isna() == expected_nans()).all(axis=None)

    # Case 2: write the masked values to the field name "tmp".
    masked, flags = qc.generic.process(
        "tmp", lambda var1: mask(var1 < threshold)
    ).getResult()
    assert "tmp" in masked.columns
    assert "tmp" in flags.columns
    assert (masked["tmp"].isna() == expected_nans()).all(axis=None)