From 4bd39f3134f2283463e0c0689a7418f89a6a4047 Mon Sep 17 00:00:00 2001 From: luenensc <peter.luenenschloss@ufz.de> Date: Thu, 28 Jul 2022 13:01:56 +0200 Subject: [PATCH] changes --- saqc/funcs/ml.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/saqc/funcs/ml.py b/saqc/funcs/ml.py index cbbd7a282..75f36ef3c 100644 --- a/saqc/funcs/ml.py +++ b/saqc/funcs/ml.py @@ -676,7 +676,10 @@ def trainDeepModel( if hasattr(model, 'compile'): model.compile(**compilation_kwargs) - if fit_kwargs.pop('shuffle_all', False): + shuffle_val = fit_kwargs.pop('shuffle_all', False) + if shuffle_val: + if isinstance(shuffle_val, int): + np.random.seed(shuffle_val) shuffle_ix = np.arange(x_train.shape[0]) np.random.shuffle(shuffle_ix) x_train = x_train[shuffle_ix,...] @@ -685,8 +688,8 @@ def trainDeepModel( history = model.fit(x_train, y_train, **fit_kwargs) - if os.path.exists(os.path.join(path,'bestCheck.index')): - model.load_weights(os.path.join(path,'bestCheck')) + if os.path.exists(os.path.join(path, 'bestCheck.index')): + model.load_weights(os.path.join(path, 'bestCheck')) y_pred_test, y_pred_train = model.predict(x_test), model.predict(x_train) -- GitLab