diff --git a/saqc/funcs/pattern.py b/saqc/funcs/pattern.py index 52437b89bcfd87e77b677161f455faa2d8472c50..a33cdceaef496a8d4bcb1087da297b6167403419 100644 --- a/saqc/funcs/pattern.py +++ b/saqc/funcs/pattern.py @@ -1,17 +1,11 @@ #! /usr/bin/env python # -*- coding: utf-8 -*- -from saqc.core.modules import base from typing import Sequence, Union, Tuple, Optional -from typing_extensions import Literal - import numpy as np - import dtw import pywt - from mlxtend.evaluate import permutation_test - from dios.dios import DictOfSeries from saqc.core.register import register @@ -100,7 +94,7 @@ def flagPatternByDTW( sz = len(ref) mask = customRoller(dat, window=sz, min_periods=sz).apply(isPattern, raw=True) - flagger = flagger.setFlags(field, loc=mask, **kwargs) + flagger[mask, field] = kwargs['flag'] return data, flagger @@ -108,7 +102,7 @@ def flagPatternByDTW( def flagPatternByWavelet( data: DictOfSeries, field: str, - flagger: base, + flagger: Flagger, ref_field: str, max_distance: float=0.03, normalize: bool=True, @@ -172,5 +166,5 @@ def flagPatternByWavelet( sz = len(ref) mask = customRoller(dat, window=sz, min_periods=sz).apply(isPattern, raw=True) - flagger = flagger.setFlags(field, loc=mask, **kwargs) + flagger[mask, field] = kwargs['flag'] return data, flagger