diff --git a/saqc/core/modules/pattern.py b/saqc/core/modules/pattern.py index 1d64e8e461655c0af544d209ccd621fe1f05fde0..5f1ffd07aed9a7a8dc31fe619d993a64f5036a69 100644 --- a/saqc/core/modules/pattern.py +++ b/saqc/core/modules/pattern.py @@ -14,8 +14,8 @@ class Pattern(ModuleBase): self, field: str, ref_field: str, - widths: Sequence[int] = (1, 2, 4, 8), - waveform: str = "mexh", + max_distance: float = 0.0, + normalize=True, flag: float = BAD, **kwargs ) -> saqc.SaQC: @@ -25,8 +25,8 @@ class Pattern(ModuleBase): self, field: str, ref_field: str, - max_distance: float = 0.03, - normalize: bool = True, + widths: Sequence[int] = (1, 2, 4, 8), + waveform: str = "mexh", flag: float = BAD, **kwargs ) -> saqc.SaQC: diff --git a/saqc/funcs/pattern.py b/saqc/funcs/pattern.py index 01933a730a5adbf25ab7a061b38b17049e5f00ae..be633b9af98ff18c3e8c55ad40361325d671bce2 100644 --- a/saqc/funcs/pattern.py +++ b/saqc/funcs/pattern.py @@ -1,35 +1,35 @@ #! /usr/bin/env python # -*- coding: utf-8 -*- -from typing import Sequence, Union, Tuple, Optional import numpy as np +import pandas as pd import dtw import pywt from mlxtend.evaluate import permutation_test -from dios import DictOfSeries from saqc.constants import * -from saqc.core import register, Flags +from saqc.core.register import register from saqc.lib.tools import customRoller @register(masking="field", module="pattern") -def flagPatternByDTW( - data: DictOfSeries, - field: str, - flags: Flags, - ref_field: str, - widths: Sequence[int] = (1, 2, 4, 8), - waveform: str = "mexh", - flag: float = BAD, +def flagPatternByWavelet( + data, + field, + flags, + ref_field, + widths=(1, 2, 4, 8), + waveform="mexh", + flag=BAD, **kwargs -) -> Tuple[DictOfSeries, Flags]: +): """ Pattern recognition via wavelets. The steps are: 1. work on chunks returned by a moving window - 2. each chunk is compared to the given pattern, using the wavelet algorithm as presented in [1] + 2. 
each chunk is compared to the given pattern, using the wavelet algorithm as + presented in [1] 3. if the compared chunk is equal to the given pattern it gets flagged Parameters @@ -37,30 +37,31 @@ def flagPatternByDTW( data : dios.DictOfSeries A dictionary of pandas.Series, holding all the data. + field : str The fieldname of the data column, you want to correct. + flags : saqc.Flags - Container to store quality flags to data. + The flags belonging to `data`. + ref_field: str The fieldname in `data' which holds the pattern. + widths: tuple of int - Widths for wavelet decomposition. [1] recommends a dyadic scale. Default: (1,2,4,8) + Widths for wavelet decomposition. [1] recommends a dyadic scale. + Default: (1,2,4,8) + waveform: str. Wavelet to be used for decomposition. Default: 'mexh'. See [2] for a list. - flag : float, default BAD - flag to set. - - kwargs Returns ------- data : dios.DictOfSeries A dictionary of pandas.Series, holding all the data. Data values may have changed relatively to the data input. - flags : saqc.Flags - The quality flags of data - Flags values may have changed relatively to the flags input. + flags : saqc.Flags + The flags belonging to `data`. 
References ---------- @@ -72,15 +73,20 @@ def flagPatternByDTW( [2] https://pywavelets.readthedocs.io/en/latest/ref/cwt.html#continuous-wavelet-families """ + dat = data[field] ref = data[ref_field].to_numpy() cwtmat_ref, _ = pywt.cwt(ref, widths, waveform) wavepower_ref = np.power(cwtmat_ref, 2) len_width = len(widths) + sz = len(ref) + + assert len_width + assert sz def func(x, y): return x.sum() / y.sum() - def isPattern(chunk): + def pvalue(chunk): cwtmat_chunk, _ = pywt.cwt(chunk, widths, waveform) wavepower_chunk = np.power(cwtmat_chunk, 2) @@ -91,63 +97,142 @@ def flagPatternByDTW( pval = permutation_test( x, y, method="approximate", num_rounds=200, func=func, seed=0 ) - if min(pval, 1 - pval) > 0.01: - return True - return False + pval = min(pval, 1 - pval) + return pval # noqa # existence ensured by assert - dat = data[field] - sz = len(ref) - mask = customRoller(dat, window=sz, min_periods=sz).apply(isPattern, raw=True) + rolling = customRoller(dat, window=sz, min_periods=sz, forward=True) + pvals = rolling.apply(pvalue, raw=True) + markers = pvals > 0.01 # nans -> False + + # the markers are set on the left edge of the window. thus we must propagate + # `sz`-many True's to the right of every marker. + rolling = customRoller(markers, window=sz, min_periods=sz) + mask = rolling.sum().fillna(0).astype(bool) flags[mask, field] = flag return data, flags +def calculateDistanceByDTW( + data: pd.Series, reference: pd.Series, forward=True, normalize=True +): + """ + Calculate the DTW-distance of data to pattern in a rolling calculation. + + The data is compared to pattern in a rolling window. + The size of the rolling window is determined by the timespan defined + by the first and last timestamp of the reference data's datetime index. + + For details see the linked functions in the `See Also` section. + + Parameters + ---------- + data : pd.Series + Data series. Must have datetime-like index, and must be regularly sampled. 
+ + reference : pd.Series + Reference series. Must have datetime-like index, must not contain NaNs + and must not be empty. + + forward: bool, default True + If `True`, the distance value is set on the left edge of the data chunk. This + means, with a perfect match, `0.0` marks the beginning of the pattern in + the data. If `False`, `0.0` would mark the end of the pattern. + + normalize : bool, default True + If `False`, return unmodified distances. + If `True`, normalize distances by the number of observations in the reference. + This helps to make it easier to find a good cutoff threshold for further + processing. The distances then refer to the mean distance per datapoint, + expressed in the data's units. + + Returns + ------- + distance : pd.Series + + Notes + ----- + The data must be regularly sampled, otherwise a ValueError is raised. + NaNs in the data will be dropped before dtw distance calculation. + + See Also + -------- + flagPatternByDTW : flag data by DTW + """ + if reference.hasnans or reference.empty: + raise ValueError("reference must not have nan's and must not be empty.") + + # TODO: rm `+ pd.Timedelta('1ns')` as soon as #GL214 is fixed, + # add closed=both to customRoller instead + winsz = reference.index.max() - reference.index.min() + pd.Timedelta("1ns") + reference = reference.to_numpy() + + def isPattern(chunk): + return dtw.accelerated_dtw(chunk, reference, "euclidean")[0] + + # generate distances, excluding NaNs + rolling = customRoller(data.dropna(), window=winsz, forward=forward, expand=False) + distances: pd.Series = rolling.apply(isPattern, raw=True) + + if normalize: + distances /= len(reference) + + return distances.reindex(index=data.index) # reinsert NaNs + + @register(masking="field", module="pattern") -def flagPatternByWavelet( - data: DictOfSeries, - field: str, - flags: Flags, - ref_field: str, - max_distance: float = 0.03, - normalize: bool = True, - flag: float = BAD, - **kwargs -) -> Tuple[DictOfSeries, Flags]: +def 
flagPatternByDTW( + data, field, flags, ref_field, max_distance=0.0, normalize=True, flag=BAD, **kwargs +): """Pattern Recognition via Dynamic Time Warping. The steps are: - 1. work on chunks returned by a moving window - 2. each chunk is compared to the given pattern, using the dynamic time warping algorithm as presented in [1] - 3. if the compared chunk is equal to the given pattern it gets flagged + 1. work on a moving window + 2. for each data chunk extracted from each window, a distance to the given pattern + is calculated, by the dynamic time warping algorithm [1] + 3. if the distance is below the threshold, all the data in the window gets flagged Parameters ---------- - data : dios.DictOfSeries A dictionary of pandas.Series, holding all the data. + field : str - The fieldname of the data column, you want to correct. + The name of the data column + flags : saqc.Flags - Container to store quality flags to data. - ref_field: str - The fieldname in `data` which holds the pattern. - max_distance: float - Maximum dtw-distance between partition and pattern, so that partition is recognized as pattern. Default: 0.03 - normalize: boolean. - Normalizing dtw-distance (see [1]). Default: True - flag : float, default BAD - flag to set. + The flags belonging to `data`. + + ref_field : str + The name in `data` which holds the pattern. The pattern must not have NaNs, + must have a datetime index and must not be empty. + + max_distance : float, default 0.0 + Maximum dtw-distance between chunk and pattern, if the distance is lower than or + equal to ``max_distance`` the data gets flagged. With default, ``0.0``, only exact + matches are flagged. + + normalize : bool, default True + If `False`, return unmodified distances. + If `True`, normalize distances by the number of observations of the reference. + This helps to make it easier to find a good cutoff threshold for further + processing. The distances then refer to the mean distance per datapoint, + expressed in the data's units. 
+ Returns ------- data : dios.DictOfSeries A dictionary of pandas.Series, holding all the data. Data values may have changed relatively to the data input. + flags : saqc.Flags - The quality flags of data - Flags values may have changed relatively to the flags input. + The flags belonging to `data`. + Notes + ----- + The window size of the moving window is set to equal the temporal extension of the + reference data's datetime index. References ---------- @@ -156,20 +241,24 @@ def flagPatternByWavelet( [1] https://cran.r-project.org/web/packages/dtw/dtw.pdf """ ref = data[ref_field] - ref_var = ref.var() + dat = data[field] - def func(a, b): - return np.linalg.norm(a - b) + distances = calculateDistanceByDTW(dat, ref, forward=True, normalize=normalize) + # TODO: rm `+ pd.Timedelta('1ns')` as soon as #GL214 is fixed, + # add closed=both to customRoller instead + winsz = ref.index.max() - ref.index.min() + pd.Timedelta("1ns") - def isPattern(chunk): - dist, *_ = dtw.dtw(chunk, ref, func) - if normalize: - dist /= ref_var - return dist < max_distance + # prevent nan propagation + distances = distances.fillna(max_distance + 1) - dat = data[field] - sz = len(ref) - mask = customRoller(dat, window=sz, min_periods=sz).apply(isPattern, raw=True) + # find minima and filter by threshold + fw = customRoller(distances, window=winsz, forward=True) + bw = customRoller(distances, window=winsz) + minima = (fw.min() == bw.min()) & (distances <= max_distance) + + # Propagate True's to size of pattern. 
+ rolling = customRoller(minima, window=winsz) + mask = rolling.sum() > 0 flags[mask, field] = flag return data, flags diff --git a/tests/funcs/test_pattern_rec.py b/tests/funcs/test_pattern_rec.py index d8d67e324e3e8376adb276d40ecbd498697364ab..0af516990eb242580d1b9ad7c733a5772040f04c 100644 --- a/tests/funcs/test_pattern_rec.py +++ b/tests/funcs/test_pattern_rec.py @@ -21,35 +21,38 @@ def field(data): return data.columns[0] -@pytest.mark.skip(reason="faulty implementation - will get fixed by GL-MR191") +@pytest.mark.skip(reason="faulty implementation - wait for #GL216") def test_flagPattern_wavelet(): data = pd.Series(0, index=pd.date_range(start="2000", end="2001", freq="1d")) - data.iloc[2:4] = 7 - pattern = data.iloc[1:6] + data.iloc[10:18] = [0, 5, 6, 7, 6, 8, 5, 0] + pattern = data.iloc[10:18] data = dios.DictOfSeries(dict(data=data, pattern_data=pattern)) flags = initFlagsLike(data, name="data") - data, flags = flagPatternByDTW( + data, flags = flagPatternByWavelet( data, "data", flags, ref_field="pattern_data", flag=BAD ) - assert all(flags["data"][1:6]) - assert any(flags["data"][:1]) - assert any(flags["data"][7:]) + assert all(flags["data"].iloc[10:18] == BAD) + assert all(flags["data"].iloc[:9] == UNFLAGGED) + assert all(flags["data"].iloc[18:] == UNFLAGGED) -@pytest.mark.skip(reason="faulty implementation - will get fixed by GL-MR191") def test_flagPattern_dtw(): data = pd.Series(0, index=pd.date_range(start="2000", end="2001", freq="1d")) - data.iloc[2:4] = 7 - pattern = data.iloc[1:6] + data.iloc[10:18] = [0, 5, 6, 7, 6, 8, 5, 0] + pattern = data.iloc[10:18] data = dios.DictOfSeries(dict(data=data, pattern_data=pattern)) flags = initFlagsLike(data, name="data") - data, flags = flagPatternByWavelet( + data, flags = flagPatternByDTW( data, "data", flags, ref_field="pattern_data", flag=BAD ) - assert all(flags["data"][1:6]) - assert any(flags["data"][:1]) - assert any(flags["data"][7:]) + assert all(flags["data"].iloc[10:18] == BAD) + assert 
all(flags["data"].iloc[:9] == UNFLAGGED) + assert all(flags["data"].iloc[18:] == UNFLAGGED) + + # visualize: + # data['data'].plot() + # ((flags['data']>0) *5.).plot()