Skip to content
Snippets Groups Projects

Numba riddance

Merged Peter Lünenschloß requested to merge numbaRiddance into develop
3 unresolved threads
1 file
+ 13
39
Compare changes
  • Side-by-side
  • Inline
+ 13
39
@@ -176,17 +176,13 @@ def _fitPolynomial(
fitted = to_fit.rolling(
pd.Timedelta(window), closed="both", min_periods=min_periods, center=True
).apply(polyRollerIrregular, args=(centers, order))
else:
else: # if regular
if isinstance(window, str):
window = pd.Timedelta(window) // regular
if window % 2 == 0:
window = int(window - 1)
if min_periods is None:
min_periods = window
if len(to_fit) < 200000:
numba = False
else:
numba = True
val_range = np.arange(0, window)
center_index = window // 2
@@ -202,43 +198,21 @@ def _fitPolynomial(
miss_marker = np.floor(miss_marker - 1)
na_mask = to_fit.isna()
to_fit[na_mask] = miss_marker
if numba:
fitted = to_fit.rolling(window).apply(
polyRollerNumba,
args=(miss_marker, val_range, center_index, order),
raw=True,
engine="numba",
engine_kwargs={"no_python": True},
)
# due to a tiny bug - rolling with center=True doesnt work
# when using numba engine.
fitted = fitted.shift(-int(center_index))
else:
fitted = to_fit.rolling(window, center=True).apply(
polyRoller,
args=(miss_marker, val_range, center_index, order),
raw=True,
)
fitted = to_fit.rolling(window, center=True).apply(
polyRoller,
args=(miss_marker, val_range, center_index, order),
raw=True,
)
fitted[na_mask] = np.nan
else:
# we only fit fully populated intervals:
if numba:
fitted = to_fit.rolling(window).apply(
polyRollerNoMissingNumba,
args=(val_range, center_index, order),
engine="numba",
engine_kwargs={"no_python": True},
raw=True,
)
# due to a tiny bug - rolling with center=True doesnt work
# when using numba engine.
fitted = fitted.shift(-int(center_index))
else:
fitted = to_fit.rolling(window, center=True).apply(
polyRollerNoMissing,
args=(val_range, center_index, order),
raw=True,
)
fitted = to_fit.rolling(window, center=True).apply(
polyRollerNoMissing,
args=(val_range, center_index, order),
raw=True,
)
data[field] = fitted
worst = flags[field].rolling(window, center=True, min_periods=min_periods).max()
Loading