From 4bd39f3134f2283463e0c0689a7418f89a6a4047 Mon Sep 17 00:00:00 2001
From: luenensc <peter.luenenschloss@ufz.de>
Date: Thu, 28 Jul 2022 13:01:56 +0200
Subject: [PATCH] changes

---
 saqc/funcs/ml.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/saqc/funcs/ml.py b/saqc/funcs/ml.py
index cbbd7a282..75f36ef3c 100644
--- a/saqc/funcs/ml.py
+++ b/saqc/funcs/ml.py
@@ -676,7 +676,10 @@ def trainDeepModel(
     if hasattr(model, 'compile'):
         model.compile(**compilation_kwargs)
 
-    if fit_kwargs.pop('shuffle_all', False):
+    shuffle_val = fit_kwargs.pop('shuffle_all', False)
+    if shuffle_val:
+        if isinstance(shuffle_val, int):
+            np.random.seed(shuffle_val)
         shuffle_ix = np.arange(x_train.shape[0])
         np.random.shuffle(shuffle_ix)
         x_train = x_train[shuffle_ix,...]
@@ -685,8 +688,8 @@ def trainDeepModel(
 
     history = model.fit(x_train, y_train, **fit_kwargs)
 
-    if os.path.exists(os.path.join(path,'bestCheck.index')):
-        model.load_weights(os.path.join(path,'bestCheck'))
+    if os.path.exists(os.path.join(path, 'bestCheck.index')):
+        model.load_weights(os.path.join(path, 'bestCheck'))
 
     y_pred_test, y_pred_train = model.predict(x_test), model.predict(x_train)
 
-- 
GitLab