diff --git a/osl_dynamics/data/base.py b/osl_dynamics/data/base.py index 0ccd01fb..9f0881f2 100644 --- a/osl_dynamics/data/base.py +++ b/osl_dynamics/data/base.py @@ -1569,11 +1569,11 @@ def _check_rewrite(): v=len(str(self.n_sessions - 1)), ) tfrecord_filenames.append(filepath) - if ( - rewrite - or not os.path.exists(filepath.format(val=0)) - or not os.path.exists(filepath.format(val=1)) - ): + + rewrite_ = rewrite or not os.path.exists(filepath.format(val=0)) + if validation_split is not None: + rewrite_ = rewrite_ or not os.path.exists(filepath.format(val=1)) + if rewrite_: tfrecords_to_save.append((i, filepath)) # Function for saving a single TFRecord