Skip to content
Snippets Groups Projects
test_functions.py 3.35 KiB
Newer Older
Bert Palm's avatar
Bert Palm committed
#! /usr/bin/env python
# -*- coding: utf-8 -*-

import pytest
import numpy as np
import pandas as pd
Bert Palm's avatar
Bert Palm committed

from saqc.core.evaluator import evalExpression
Bert Palm's avatar
Bert Palm committed
from saqc.funcs.functions import flagRange, flagSesonalRange, forceFlags, clearFlags

from test.common import initData, TESTFLAGGER, initMetaDict
Bert Palm's avatar
Bert Palm committed


David Schäfer's avatar
David Schäfer committed
@pytest.fixture
def data():
    """Provide the shared test frame built by ``initData``.

    Requests one column of daily values (``freq="1D"``) from 2016-01-01 to
    2018-12-31. The values themselves are whatever ``initData`` generates.
    """
    return initData(
        cols=1,
        start_date="2016-01-01",
        end_date="2018-12-31",
        freq="1D",
    )
David Schäfer's avatar
David Schäfer committed


@pytest.fixture
def field(data):
    """Return the label of the first column of the ``data`` fixture."""
    first_label = data.columns[0]
    return first_label

David Schäfer's avatar
David Schäfer committed

@pytest.mark.parametrize('flagger', TESTFLAGGER)
def test_flagAfter(data, field, flagger):
    """Check the window-after-flag expressions against a plain range test.

    ``flagRange`` is run once directly with bounds taken from the data. Each
    expression wraps the same ``range`` test in ``flagWindowAfterFlag`` (time
    window ``"3D"``) or ``flagNextAfterFlag`` (``n=4``). For every rolling
    window over the directly range-flagged timestamps, the expression's
    flagger must report all timestamps of that window as flagged.
    """
    flagger = flagger.initFlags(data)

    # Bounds are the values at 30% and 60% of the series length. The trailing
    # underscore avoids shadowing the builtins ``min``/``max``; the values
    # interpolated into the expressions below are unchanged.
    min_ = data.iloc[int(len(data) * .3), 0]
    max_ = data.iloc[int(len(data) * .6), 0]
    _, flagger_range = flagRange(data, field, flagger, min_, max_)
    # Restrict to the timestamps that the direct range test flagged.
    flagged_range = flagger_range.isFlagged(
        field, loc=flagger_range.isFlagged(field))

    tests = [
        (f"flagWindowAfterFlag(window='3D', func=range(min={min_}, max={max_}))", "3D"),
        (f"flagNextAfterFlag(n=4, func=range(min={min_}, max={max_}))", 4),
    ]

    for expr, window in tests:
        _, flagger_range_repeated = evalExpression(expr, data, field, flagger)

        # raw=False hands each window to the lambda as a Series, so its index
        # can be used to look up flags in the expression's result.
        check = (flagged_range
                 .rolling(window=window)
                 .apply(
                     lambda df: (flagger_range_repeated
                                 .isFlagged(field, loc=df.index)
                                 .all()),
                     raw=False))
        # NOTE(review): with the integer window the leading rows are NaN,
        # which Series.all skips by default (skipna=True).
        assert check.all()
David Schäfer's avatar
David Schäfer committed
@pytest.mark.parametrize('flagger', TESTFLAGGER)
def test_range(data, field, flagger):
    """Check that ``flagRange`` flags exactly the values outside the bounds.

    The expected mask is: values strictly below the lower bound or at/above
    the upper bound (i.e. outside ``[10, 90)``).
    """
    # Trailing underscore avoids shadowing the builtins ``min``/``max``.
    min_, max_ = 10, 90
    flagger = flagger.initFlags(data)
    data, flagger = flagRange(data, field, flagger, min=min_, max=max_)
    flagged = flagger.isFlagged(field)

    expected = (data[field] < min_) | (data[field] >= max_)
    assert np.all(flagged == expected)
Bert Palm's avatar
Bert Palm committed


# @pytest.mark.parametrize('flagger', TESTFLAGGER)
# def test_missing(data, field, flagger):
#     pass
@pytest.mark.parametrize('flagger', TESTFLAGGER)
def test_flagSesonalRange(data, field, flagger):
    """Check ``flagSesonalRange`` on data alternating between 0 and 50.

    Rows at even positions are set to 0 (below ``min=1``) and rows at odd
    positions to 50 (inside ``[1, 100]``). Two seasonal windows are checked:
    July 1 - August 31, and December 16 - January 15, which wraps the turn
    of the year.
    """
    # Even positions -> 0, odd positions -> 50.
    data.loc[::2] = 0
    data.loc[1::2] = 50
    n_years = len(data.index.year.unique())

    summer = dict(min=1, max=100, startmonth=7, startday=1, endmonth=8, endday=31)
    # July + August = 62 days per year, every second one holds a 0.
    summer_expected = (31 * 2 * n_years) // 2

    winter = dict(min=1, max=100, startmonth=12, startday=16, endmonth=1, endday=15)
    # Dec 16-31 plus Jan 1-15 = 31 days per year.
    # NOTE(review): the +1 presumably reflects how the alternating pattern
    # lines up with this window; confirm against the data range.
    winter_expected = (31 * n_years) // 2 + 1

    for kwargs, n_expected in ((summer, summer_expected), (winter, winter_expected)):
        flagger = flagger.initFlags(data)
        data, flagger = flagSesonalRange(data, field, flagger, **kwargs)
        assert flagger.isFlagged(field).sum() == n_expected
Bert Palm's avatar
Bert Palm committed


@pytest.mark.parametrize('flagger', TESTFLAGGER)
def test_clearFlags(data, field, flagger):
    """Check that ``clearFlags`` resets previously set flags.

    ``field`` is first flagged ``BAD``; ``clearFlags`` is then applied to that
    flagged flagger, and the result must equal the freshly initialized flags.
    """
    flagger = flagger.initFlags(data)
    flags_orig = flagger.getFlags()
    # Rebind to the flagger returned by setFlags so that clearFlags operates
    # on the flagged state; clearing an unflagged flagger would make the
    # final comparison pass without testing anything.
    flagger = flagger.setFlags(field, flag=flagger.BAD)
    flags_set = flagger.getFlags()
    _, flagger = clearFlags(data, field, flagger)
    flags_cleared = flagger.getFlags()
    assert np.all(flags_orig != flags_set)
    assert np.all(flags_orig == flags_cleared)
David Schäfer's avatar
David Schäfer committed
@pytest.mark.parametrize('flagger', TESTFLAGGER)
def test_forceFlags(data, flagger):
    """Check that ``forceFlags`` overwrites already existing flags.

    The first column is flagged via ``setFlags`` (with its default flag).
    ``forceFlags`` with ``GOOD`` is then applied to that flagged flagger, and
    every resulting flag must differ from the pre-force flags.
    """
    flagger = flagger.initFlags(data)
    field, *_ = data.columns
    # Keep the flagger returned by setFlags so forceFlags has to overwrite
    # existing flags rather than merely filling unflagged ones.
    flagger = flagger.setFlags(field)
    flags_orig = flagger.getFlags(field)
    _, flagger = forceFlags(data, field, flagger, flag=flagger.GOOD)
    flags_forced = flagger.getFlags(field)
    assert np.all(flags_orig != flags_forced)