# Patch summary (reviewer commentary placed ahead of the first "diff --git" header;
# it is not part of any hunk).
#
# Target: dlio_benchmark/utils/config.py, class ConfigArguments.
#   1. Adds a class-level default `native_data_loader = False`.
#   2. derive_configurations(): resets the flag to False, then sets it to True when
#      data_loader is TENSORFLOW with format TFRECORD, or when data_loader is NATIVE_DALI
#      with format JPEG, PNG, NPY or TFRECORD.
#   3. reconfigure(): nests the existing seeding/shuffling and sample-map building under
#      `if not self.native_data_loader:`, so when the flag is True the method changes no
#      state (no np.random seeding or shuffle; file_map / global_index_map not reassigned).
#
# NOTE(review): this patch arrived with line breaks and leading whitespace collapsed.
# Line breaks were restored from the hunk line counts (6/7, 6/13, 26/28) and indentation
# from Python syntax. The depth of the three context lines opening the second hunk
# (logging.info ... break) and the alignment of wrapped continuation lines are inferred,
# not visible -- verify against the base file before applying.
# NOTE(review): reconfigure() tests `data_loader_sampler == DataLoaderSampler.ITERATIVE`
# twice in a row and uses a conditional expression purely for its np.random.shuffle side
# effect; both are carried over unchanged from the removed lines.
# NOTE(review): presumably nothing reads file_map / global_index_map on the native-loader
# paths, since they are no longer refreshed there -- confirm against the consumers.
diff --git a/dlio_benchmark/utils/config.py b/dlio_benchmark/utils/config.py
index 8e25b148..e0d1a945 100644
--- a/dlio_benchmark/utils/config.py
+++ b/dlio_benchmark/utils/config.py
@@ -138,6 +138,7 @@ class ConfigArguments:
     data_loader_class = None
     reader_class = None
     checkpoint_mechanism_class = None
+    native_data_loader = False
 
     def __init__(self):
         """ Virtually private constructor. """
@@ -300,6 +301,13 @@ def derive_configurations(self, file_list_train=None, file_list_eval=None):
                     logging.info(f"Discovered custom data reader {class_name}")
                     self.reader_class = obj
                     break
+        self.native_data_loader = False
+        if self.data_loader == DataLoaderType.TENSORFLOW:
+            if self.format == FormatType.TFRECORD:
+                self.native_data_loader = True
+        elif self.data_loader == DataLoaderType.NATIVE_DALI:
+            if self.format in [FormatType.JPEG, FormatType.PNG, FormatType.NPY, FormatType.TFRECORD]:
+                self.native_data_loader = True
 
     @dlp.log
     def build_sample_map_iter(self, file_list, total_samples, epoch_number):
@@ -361,26 +369,28 @@ def get_global_map_index(self, file_list, total_samples):
 
     @dlp.log
     def reconfigure(self, epoch_number, dataset_type):
-        if self.data_loader_sampler == DataLoaderSampler.ITERATIVE:
-            if self.file_shuffle is not Shuffle.OFF:
-                if self.seed_change_epoch:
-                    np.random.seed(self.seed + epoch_number)
+        # the code assumes that file and sample shuffling is handled by the native data loader code.
+        if not self.native_data_loader:
+            if self.data_loader_sampler == DataLoaderSampler.ITERATIVE:
+                if self.file_shuffle is not Shuffle.OFF:
+                    if self.seed_change_epoch:
+                        np.random.seed(self.seed + epoch_number)
+                    else:
+                        np.random.seed(self.seed)
+                    np.random.shuffle(self.file_list_train) if dataset_type is DatasetType.TRAIN else np.random.shuffle(
+                        self.file_list_eval)
+            if self.data_loader_sampler == DataLoaderSampler.ITERATIVE:
+                if dataset_type is DatasetType.TRAIN:
+                    global_file_map = self.build_sample_map_iter(self.file_list_train, self.total_samples_train,
+                                                                 epoch_number)
                 else:
-                    np.random.seed(self.seed)
-                np.random.shuffle(self.file_list_train) if dataset_type is DatasetType.TRAIN else np.random.shuffle(
-                    self.file_list_eval)
-        if self.data_loader_sampler == DataLoaderSampler.ITERATIVE:
-            if dataset_type is DatasetType.TRAIN:
-                global_file_map = self.build_sample_map_iter(self.file_list_train, self.total_samples_train,
-                                                             epoch_number)
-            else:
-                global_file_map = self.build_sample_map_iter(self.file_list_eval, self.total_samples_eval, epoch_number)
-            self.file_map = global_file_map[self.my_rank]
-        elif self.data_loader_sampler == DataLoaderSampler.INDEX:
-            if dataset_type is DatasetType.TRAIN:
-                self.global_index_map = self.get_global_map_index(self.file_list_train, self.total_samples_train)
-            else:
-                self.global_index_map = self.get_global_map_index(self.file_list_eval, self.total_samples_eval)
+                    global_file_map = self.build_sample_map_iter(self.file_list_eval, self.total_samples_eval, epoch_number)
+                self.file_map = global_file_map[self.my_rank]
+            elif self.data_loader_sampler == DataLoaderSampler.INDEX:
+                if dataset_type is DatasetType.TRAIN:
+                    self.global_index_map = self.get_global_map_index(self.file_list_train, self.total_samples_train)
+                else:
+                    self.global_index_map = self.get_global_map_index(self.file_list_eval, self.total_samples_eval)
 
 
 def LoadConfig(args, config):