Skip to content
Snippets Groups Projects
Commit b181aab4 authored by Bert Palm's avatar Bert Palm 🎇
Browse files

Merge branch 'patternFix' into 'develop'

fix for patternrec

See merge request !191
parents 0835c4b2 9675ffd8
No related branches found
No related tags found
1 merge request: !191 "fix for patternrec"
Pipeline #30711 passed with stage
in 2 minutes and 8 seconds
......@@ -14,8 +14,8 @@ class Pattern(ModuleBase):
self,
field: str,
ref_field: str,
widths: Sequence[int] = (1, 2, 4, 8),
waveform: str = "mexh",
max_distance: float = 0.0,
normalize=True,
flag: float = BAD,
**kwargs
) -> saqc.SaQC:
......@@ -25,8 +25,8 @@ class Pattern(ModuleBase):
self,
field: str,
ref_field: str,
max_distance: float = 0.03,
normalize: bool = True,
widths: Sequence[int] = (1, 2, 4, 8),
waveform: str = "mexh",
flag: float = BAD,
**kwargs
) -> saqc.SaQC:
......
#! /usr/bin/env python
# -*- coding: utf-8 -*-
from typing import Sequence, Union, Tuple, Optional
import numpy as np
import pandas as pd
import dtw
import pywt
from mlxtend.evaluate import permutation_test
from dios import DictOfSeries
from saqc.constants import *
from saqc.core import register, Flags
from saqc.core.register import register
from saqc.lib.tools import customRoller
@register(masking="field", module="pattern")
def flagPatternByDTW(
data: DictOfSeries,
field: str,
flags: Flags,
ref_field: str,
widths: Sequence[int] = (1, 2, 4, 8),
waveform: str = "mexh",
flag: float = BAD,
def flagPatternByWavelet(
data,
field,
flags,
ref_field,
widths=(1, 2, 4, 8),
waveform="mexh",
flag=BAD,
**kwargs
) -> Tuple[DictOfSeries, Flags]:
):
"""
Pattern recognition via wavelets.
The steps are:
1. work on chunks returned by a moving window
2. each chunk is compared to the given pattern, using the wavelet algorithm as presented in [1]
2. each chunk is compared to the given pattern, using the wavelet algorithm as
presented in [1]
3. if the compared chunk is equal to the given pattern it gets flagged
Parameters
......@@ -37,30 +37,31 @@ def flagPatternByDTW(
data : dios.DictOfSeries
A dictionary of pandas.Series, holding all the data.
field : str
The fieldname of the data column, you want to correct.
flags : saqc.Flags
Container to store quality flags to data.
The flags belongiong to `data`.
ref_field: str
The fieldname in `data' which holds the pattern.
widths: tuple of int
Widths for wavelet decomposition. [1] recommends a dyadic scale. Default: (1,2,4,8)
Widths for wavelet decomposition. [1] recommends a dyadic scale.
Default: (1,2,4,8)
waveform: str.
Wavelet to be used for decomposition. Default: 'mexh'. See [2] for a list.
flag : float, default BAD
flag to set.
kwargs
Returns
-------
data : dios.DictOfSeries
A dictionary of pandas.Series, holding all the data.
Data values may have changed relatively to the data input.
flags : saqc.Flags
The quality flags of data
Flags values may have changed relatively to the flags input.
flags : saqc.Flags
The flags belongiong to `data`.
References
----------
......@@ -72,15 +73,20 @@ def flagPatternByDTW(
[2] https://pywavelets.readthedocs.io/en/latest/ref/cwt.html#continuous-wavelet-families
"""
dat = data[field]
ref = data[ref_field].to_numpy()
cwtmat_ref, _ = pywt.cwt(ref, widths, waveform)
wavepower_ref = np.power(cwtmat_ref, 2)
len_width = len(widths)
sz = len(ref)
assert len_width
assert sz
def func(x, y):
return x.sum() / y.sum()
def isPattern(chunk):
def pvalue(chunk):
cwtmat_chunk, _ = pywt.cwt(chunk, widths, waveform)
wavepower_chunk = np.power(cwtmat_chunk, 2)
......@@ -91,63 +97,142 @@ def flagPatternByDTW(
pval = permutation_test(
x, y, method="approximate", num_rounds=200, func=func, seed=0
)
if min(pval, 1 - pval) > 0.01:
return True
return False
pval = min(pval, 1 - pval)
return pval # noqa # existence ensured by assert
dat = data[field]
sz = len(ref)
mask = customRoller(dat, window=sz, min_periods=sz).apply(isPattern, raw=True)
rolling = customRoller(dat, window=sz, min_periods=sz, forward=True)
pvals = rolling.apply(pvalue, raw=True)
markers = pvals > 0.01 # nans -> False
# the markers are set on the left edge of the window. thus we must propagate
# `sz`-many True's to the right of every marker.
rolling = customRoller(markers, window=sz, min_periods=sz)
mask = rolling.sum().fillna(0).astype(bool)
flags[mask, field] = flag
return data, flags
def calculateDistanceByDTW(
    data: pd.Series, reference: pd.Series, forward=True, normalize=True
):
    """
    Compute the rolling DTW-distance between ``data`` and a reference pattern.

    The data is compared chunk-wise to the reference in a rolling window. The
    window size equals the timespan covered by the reference's datetime index
    (first to last timestamp). For details see the functions linked in the
    `See Also` section.

    Parameters
    ----------
    data : pd.Series
        Data series. Must have a datetime-like index, and must be regularly
        sampled.
    reference : pd.Series
        Reference series. Must have a datetime-like index, must not contain
        NaNs and must not be empty.
    forward : bool, default True
        If ``True``, the distance value is set on the left edge of each data
        chunk, so with a perfect match ``0.0`` marks the beginning of the
        pattern in the data. If ``False``, ``0.0`` marks the end of the
        pattern.
    normalize : bool, default True
        If ``False``, return unmodified distances.
        If ``True``, divide each distance by the number of observations in the
        reference, yielding the mean distance per datapoint (in the data's
        units). This makes choosing a cutoff threshold easier.

    Returns
    -------
    distance : pd.Series

    Raises
    ------
    ValueError
        If the reference contains NaNs or is empty.

    Notes
    -----
    The data must be regularly sampled, otherwise a ValueError is raised.
    NaNs in the data are dropped before the dtw distance calculation and
    reinserted (as NaN) in the result.

    See Also
    --------
    flagPatternByDTW : flag data by DTW
    """
    if reference.hasnans or reference.empty:
        raise ValueError("reference must not have nan's and must not be empty.")

    ref_values = reference.to_numpy()

    # TODO: rm `+ pd.Timedelta('1ns')` as soon as #GL214 is fixed,
    # add closed=both to customRoller instead
    window_span = reference.index.max() - reference.index.min() + pd.Timedelta("1ns")

    def distance_to_ref(chunk):
        # scalar DTW distance of one data chunk to the reference pattern
        return dtw.accelerated_dtw(chunk, ref_values, "euclidean")[0]

    # roll over the NaN-free data, one distance per window position
    roller = customRoller(
        data.dropna(), window=window_span, forward=forward, expand=False
    )
    dists: pd.Series = roller.apply(distance_to_ref, raw=True)

    if normalize:
        # mean distance per datapoint of the reference
        dists = dists / len(ref_values)

    # reindexing reinserts NaNs at the positions dropped above
    return dists.reindex(index=data.index)
@register(masking="field", module="pattern")
def flagPatternByWavelet(
data: DictOfSeries,
field: str,
flags: Flags,
ref_field: str,
max_distance: float = 0.03,
normalize: bool = True,
flag: float = BAD,
**kwargs
) -> Tuple[DictOfSeries, Flags]:
def flagPatternByDTW(
data, field, flags, ref_field, max_distance=0.0, normalize=True, flag=BAD, **kwargs
):
"""Pattern Recognition via Dynamic Time Warping.
The steps are:
1. work on chunks returned by a moving window
2. each chunk is compared to the given pattern, using the dynamic time warping algorithm as presented in [1]
3. if the compared chunk is equal to the given pattern it gets flagged
1. work on a moving window
2. for each data chunk extracted from each window, a distance to the given pattern
is calculated, by the dynamic time warping algorithm [1]
3. if the distance is below the threshold, all the data in the window gets flagged
Parameters
----------
data : dios.DictOfSeries
A dictionary of pandas.Series, holding all the data.
field : str
The fieldname of the data column, you want to correct.
The name of the data column
flags : saqc.Flags
Container to store quality flags to data.
ref_field: str
The fieldname in `data` which holds the pattern.
max_distance: float
Maximum dtw-distance between partition and pattern, so that partition is recognized as pattern. Default: 0.03
normalize: boolean.
Normalizing dtw-distance (see [1]). Default: True
flag : float, default BAD
flag to set.
The flags belonging to `data`.
ref_field : str
The name in `data` which holds the pattern. The pattern must not have NaNs,
have a datetime index and must not be empty.
max_distance : float, default 0.0
Maximum dtw-distance between chunk and pattern, if the distance is lower than
``max_distance`` the data gets flagged. With default, ``0.0``, only exact
matches are flagged.
normalize : bool, default True
If `False`, return unmodified distances.
If `True`, normalize distances by the number of observations of the reference.
This helps to make it easier to find a good cutoff threshold for further
processing. The distances then refer to the mean distance per datapoint,
expressed in the datas units.
Returns
-------
data : dios.DictOfSeries
A dictionary of pandas.Series, holding all the data.
Data values may have changed relatively to the data input.
flags : saqc.Flags
The quality flags of data
Flags values may have changed relatively to the flags input.
The flags belonging to `data`.
Notes
-----
The window size of the moving window is set to equal the temporal extension of the
reference datas datetime index.
References
----------
......@@ -156,20 +241,24 @@ def flagPatternByWavelet(
[1] https://cran.r-project.org/web/packages/dtw/dtw.pdf
"""
ref = data[ref_field]
ref_var = ref.var()
dat = data[field]
def func(a, b):
return np.linalg.norm(a - b)
distances = calculateDistanceByDTW(dat, ref, forward=True, normalize=normalize)
# TODO: rm `+ pd.Timedelta('1ns')` as soon as #GL214 is fixed,
# add closed=both to customRoller instead
winsz = ref.index.max() - ref.index.min() + pd.Timedelta("1ns")
def isPattern(chunk):
dist, *_ = dtw.dtw(chunk, ref, func)
if normalize:
dist /= ref_var
return dist < max_distance
# prevent nan propagation
distances = distances.fillna(max_distance + 1)
dat = data[field]
sz = len(ref)
mask = customRoller(dat, window=sz, min_periods=sz).apply(isPattern, raw=True)
# find minima filter by threshold
fw = customRoller(distances, window=winsz, forward=True)
bw = customRoller(distances, window=winsz)
minima = (fw.min() == bw.min()) & (distances <= max_distance)
# Propagate True's to size of pattern.
rolling = customRoller(minima, window=winsz)
mask = rolling.sum() > 0
flags[mask, field] = flag
return data, flags
......@@ -21,35 +21,38 @@ def field(data):
return data.columns[0]
@pytest.mark.skip(reason="faulty implementation - wait for #GL216")
def test_flagPattern_wavelet():
    # constant series containing one distinct bump, which also serves as
    # the reference pattern
    data = pd.Series(0, index=pd.date_range(start="2000", end="2001", freq="1d"))
    data.iloc[10:18] = [0, 5, 6, 7, 6, 8, 5, 0]
    pattern = data.iloc[10:18]

    data = dios.DictOfSeries(dict(data=data, pattern_data=pattern))
    flags = initFlagsLike(data, name="data")
    data, flags = flagPatternByWavelet(
        data, "data", flags, ref_field="pattern_data", flag=BAD
    )

    # the pattern occurrence itself must be flagged BAD, everything else
    # must stay untouched
    assert all(flags["data"].iloc[10:18] == BAD)
    assert all(flags["data"].iloc[:9] == UNFLAGGED)
    assert all(flags["data"].iloc[18:] == UNFLAGGED)
def test_flagPattern_dtw():
    # constant series containing one distinct bump, which also serves as
    # the reference pattern
    data = pd.Series(0, index=pd.date_range(start="2000", end="2001", freq="1d"))
    data.iloc[10:18] = [0, 5, 6, 7, 6, 8, 5, 0]
    pattern = data.iloc[10:18]

    data = dios.DictOfSeries(dict(data=data, pattern_data=pattern))
    flags = initFlagsLike(data, name="data")
    data, flags = flagPatternByDTW(
        data, "data", flags, ref_field="pattern_data", flag=BAD
    )

    # the pattern occurrence itself must be flagged BAD, everything else
    # must stay untouched
    assert all(flags["data"].iloc[10:18] == BAD)
    assert all(flags["data"].iloc[:9] == UNFLAGGED)
    assert all(flags["data"].iloc[18:] == UNFLAGGED)

    # visualize:
    # data['data'].plot()
    # ((flags['data']>0) *5.).plot()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment