From bcfd98ccf7966a04c3fda7b4e01c3808c55a19dc Mon Sep 17 00:00:00 2001
From: hariharandev1
Date: Sun, 28 Jul 2024 23:31:59 -0700
Subject: [PATCH] Ignore file shuffling and indexing for native data loaders.

Sample-map building and file shuffling are needed only for DLIO-created
data loaders. Native data loaders expose their own APIs and handle their
own indexing, so this sampling can be skipped for them.
---
 dlio_benchmark/utils/config.py | 48 ++++++++++++++++++++--------------
 1 file changed, 29 insertions(+), 19 deletions(-)

diff --git a/dlio_benchmark/utils/config.py b/dlio_benchmark/utils/config.py
index 8e25b148..e0d1a945 100644
--- a/dlio_benchmark/utils/config.py
+++ b/dlio_benchmark/utils/config.py
@@ -138,6 +138,7 @@ class ConfigArguments:
     data_loader_class = None
     reader_class = None
     checkpoint_mechanism_class = None
+    native_data_loader = False
 
     def __init__(self):
         """ Virtually private constructor. """
@@ -300,6 +301,13 @@ def derive_configurations(self, file_list_train=None, file_list_eval=None):
                     logging.info(f"Discovered custom data reader {class_name}")
                     self.reader_class = obj
                     break
+        self.native_data_loader = False
+        if self.data_loader == DataLoaderType.TENSORFLOW:
+            if self.format == FormatType.TFRECORD:
+                self.native_data_loader = True
+        elif self.data_loader == DataLoaderType.NATIVE_DALI:
+            if self.format in [FormatType.JPEG, FormatType.PNG, FormatType.NPY, FormatType.TFRECORD]:
+                self.native_data_loader = True
 
     @dlp.log
     def build_sample_map_iter(self, file_list, total_samples, epoch_number):
@@ -361,26 +369,28 @@ def get_global_map_index(self, file_list, total_samples):
 
     @dlp.log
     def reconfigure(self, epoch_number, dataset_type):
-        if self.data_loader_sampler == DataLoaderSampler.ITERATIVE:
-            if self.file_shuffle is not Shuffle.OFF:
-                if self.seed_change_epoch:
-                    np.random.seed(self.seed + epoch_number)
+        # For native data loaders, file and sample shuffling is handled by the loader itself.
+        if not self.native_data_loader:
+            if self.data_loader_sampler == DataLoaderSampler.ITERATIVE:
+                if self.file_shuffle is not Shuffle.OFF:
+                    if self.seed_change_epoch:
+                        np.random.seed(self.seed + epoch_number)
+                    else:
+                        np.random.seed(self.seed)
+                    np.random.shuffle(self.file_list_train) if dataset_type is DatasetType.TRAIN else np.random.shuffle(
+                        self.file_list_eval)
+            if self.data_loader_sampler == DataLoaderSampler.ITERATIVE:
+                if dataset_type is DatasetType.TRAIN:
+                    global_file_map = self.build_sample_map_iter(self.file_list_train, self.total_samples_train,
+                                                                 epoch_number)
                 else:
-                    np.random.seed(self.seed)
-                np.random.shuffle(self.file_list_train) if dataset_type is DatasetType.TRAIN else np.random.shuffle(
-                    self.file_list_eval)
-        if self.data_loader_sampler == DataLoaderSampler.ITERATIVE:
-            if dataset_type is DatasetType.TRAIN:
-                global_file_map = self.build_sample_map_iter(self.file_list_train, self.total_samples_train,
-                                                             epoch_number)
-            else:
-                global_file_map = self.build_sample_map_iter(self.file_list_eval, self.total_samples_eval, epoch_number)
-            self.file_map = global_file_map[self.my_rank]
-        elif self.data_loader_sampler == DataLoaderSampler.INDEX:
-            if dataset_type is DatasetType.TRAIN:
-                self.global_index_map = self.get_global_map_index(self.file_list_train, self.total_samples_train)
-            else:
-                self.global_index_map = self.get_global_map_index(self.file_list_eval, self.total_samples_eval)
+                    global_file_map = self.build_sample_map_iter(self.file_list_eval, self.total_samples_eval, epoch_number)
+                self.file_map = global_file_map[self.my_rank]
+            elif self.data_loader_sampler == DataLoaderSampler.INDEX:
+                if dataset_type is DatasetType.TRAIN:
+                    self.global_index_map = self.get_global_map_index(self.file_list_train, self.total_samples_train)
+                else:
+                    self.global_index_map = self.get_global_map_index(self.file_list_eval, self.total_samples_eval)
 
 
 def LoadConfig(args, config):
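Reviewer note: the snippet below restates the detection logic added in derive_configurations() as a standalone sketch. The enum values and the is_native_data_loader() helper are illustrative stand-ins, not part of this patch or of dlio_benchmark's actual DataLoaderType/FormatType definitions.

    from enum import Enum

    # Simplified stand-ins for dlio_benchmark's enums, included only so the
    # sketch runs on its own.
    class DataLoaderType(Enum):
        TENSORFLOW = "tensorflow"
        NATIVE_DALI = "native_dali"
        PYTORCH = "pytorch"

    class FormatType(Enum):
        TFRECORD = "tfrecord"
        JPEG = "jpeg"
        PNG = "png"
        NPY = "npy"
        HDF5 = "hdf5"

    def is_native_data_loader(data_loader, fmt):
        # Mirrors the check added in derive_configurations(): a loader counts as
        # "native" only when the framework reads the format through its own
        # reader API (tf.data for TFRecord; DALI for JPEG/PNG/NPY/TFRecord).
        if data_loader == DataLoaderType.TENSORFLOW:
            return fmt == FormatType.TFRECORD
        if data_loader == DataLoaderType.NATIVE_DALI:
            return fmt in (FormatType.JPEG, FormatType.PNG, FormatType.NPY, FormatType.TFRECORD)
        return False

    # DALI reading JPEGs shuffles and indexes on its own, so reconfigure()
    # skips DLIO's sample-map and index construction; a PyTorch loader over
    # NPY files still goes through DLIO's own sampling path.
    assert is_native_data_loader(DataLoaderType.NATIVE_DALI, FormatType.JPEG)
    assert not is_native_data_loader(DataLoaderType.PYTORCH, FormatType.NPY)

Centralizing the check in derive_configurations() means reconfigure() only branches on the single native_data_loader flag each epoch instead of repeating the loader/format matrix.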