diff --git a/osl_dynamics/data/base.py b/osl_dynamics/data/base.py
index 0ccd01fb..9f0881f2 100644
--- a/osl_dynamics/data/base.py
+++ b/osl_dynamics/data/base.py
@@ -1569,11 +1569,11 @@ def _check_rewrite():
                 v=len(str(self.n_sessions - 1)),
             )
             tfrecord_filenames.append(filepath)
-            if (
-                rewrite
-                or not os.path.exists(filepath.format(val=0))
-                or not os.path.exists(filepath.format(val=1))
-            ):
+
+            rewrite_ = rewrite or not os.path.exists(filepath.format(val=0))
+            if validation_split is not None:
+                rewrite_ = rewrite_ or not os.path.exists(filepath.format(val=1))
+            if rewrite_:
                 tfrecords_to_save.append((i, filepath))
 
         # Function for saving a single TFRecord