Skip to content

Commit

Permalink
Ignore file shuffling and indexing for native data loader.
Browse files Browse the repository at this point in the history
The sample building and shuffling logic is needed only for data loaders created by DLIO. Native data loaders, which provide their own APIs, do their own indexing, so this sampling can be ignored for them.
  • Loading branch information
hariharan-devarajan committed Jul 29, 2024
1 parent 33ec34e commit bcfd98c
Showing 1 changed file with 29 additions and 19 deletions.
48 changes: 29 additions & 19 deletions dlio_benchmark/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ class ConfigArguments:
data_loader_class = None
reader_class = None
checkpoint_mechanism_class = None
native_data_loader = False

def __init__(self):
""" Virtually private constructor. """
Expand Down Expand Up @@ -300,6 +301,13 @@ def derive_configurations(self, file_list_train=None, file_list_eval=None):
logging.info(f"Discovered custom data reader {class_name}")
self.reader_class = obj
break
self.native_data_loader = False
if self.data_loader == DataLoaderType.TENSORFLOW:
if self.format == FormatType.TFRECORD:
self.native_data_loader = True
elif self.data_loader == DataLoaderType.NATIVE_DALI:
if self.format in [FormatType.JPEG, FormatType.PNG, FormatType.NPY, FormatType.TFRECORD]:
self.native_data_loader = True

@dlp.log
def build_sample_map_iter(self, file_list, total_samples, epoch_number):
Expand Down Expand Up @@ -361,26 +369,28 @@ def get_global_map_index(self, file_list, total_samples):

@dlp.log
def reconfigure(self, epoch_number: int, dataset_type) -> None:
    """Rebuild the per-rank file/sample maps for a new epoch.

    For iterative samplers this (optionally) reshuffles the file list and
    rebuilds ``self.file_map`` for this rank; for index samplers it rebuilds
    ``self.global_index_map``.

    :param epoch_number: current epoch, mixed into the RNG seed when
        ``seed_change_epoch`` is set so each epoch shuffles differently.
    :param dataset_type: DatasetType.TRAIN or DatasetType.EVAL — selects
        which file list / sample count is used.
    """
    # Native data loaders (e.g. tf.data on TFRecord, DALI) handle their own
    # file shuffling and indexing, so DLIO's sampling is skipped for them.
    if self.native_data_loader:
        return
    if self.data_loader_sampler == DataLoaderSampler.ITERATIVE:
        if self.file_shuffle is not Shuffle.OFF:
            # Re-seed so every rank produces the same shuffled order.
            if self.seed_change_epoch:
                np.random.seed(self.seed + epoch_number)
            else:
                np.random.seed(self.seed)
            # Shuffle in place the list for the active split only.
            if dataset_type is DatasetType.TRAIN:
                np.random.shuffle(self.file_list_train)
            else:
                np.random.shuffle(self.file_list_eval)
        if dataset_type is DatasetType.TRAIN:
            global_file_map = self.build_sample_map_iter(self.file_list_train, self.total_samples_train,
                                                         epoch_number)
        else:
            global_file_map = self.build_sample_map_iter(self.file_list_eval, self.total_samples_eval,
                                                         epoch_number)
        # Keep only this rank's share of the global map.
        self.file_map = global_file_map[self.my_rank]
    elif self.data_loader_sampler == DataLoaderSampler.INDEX:
        if dataset_type is DatasetType.TRAIN:
            self.global_index_map = self.get_global_map_index(self.file_list_train, self.total_samples_train)
        else:
            self.global_index_map = self.get_global_map_index(self.file_list_eval, self.total_samples_eval)


def LoadConfig(args, config):
Expand Down

0 comments on commit bcfd98c

Please sign in to comment.