diff --git a/saqc/funcs/ml.py b/saqc/funcs/ml.py
index cbbd7a2822e7dae6f085633ce02cde368f4e28d9..75f36ef3ce711a57943c01c024a36e9cd57cc0ca 100644
--- a/saqc/funcs/ml.py
+++ b/saqc/funcs/ml.py
@@ -676,7 +676,10 @@ def trainDeepModel(
     if hasattr(model, 'compile'):
         model.compile(**compilation_kwargs)
 
-    if fit_kwargs.pop('shuffle_all', False):
+    shuffle_val = fit_kwargs.pop('shuffle_all', False)
+    if shuffle_val:
+        if isinstance(shuffle_val, int):
+            np.random.seed(shuffle_val)
         shuffle_ix = np.arange(x_train.shape[0])
         np.random.shuffle(shuffle_ix)
         x_train = x_train[shuffle_ix,...]
@@ -685,8 +688,8 @@ def trainDeepModel(
 
     history = model.fit(x_train, y_train, **fit_kwargs)
 
-    if os.path.exists(os.path.join(path,'bestCheck.index')):
-        model.load_weights(os.path.join(path,'bestCheck'))
+    if os.path.exists(os.path.join(path, 'bestCheck.index')):
+        model.load_weights(os.path.join(path, 'bestCheck'))
 
     y_pred_test, y_pred_train = model.predict(x_test), model.predict(x_train)