diff --git a/dlio_benchmark/data_loader/torch_data_loader.py b/dlio_benchmark/data_loader/torch_data_loader.py index c18cdbbe..7cbcb82a 100644 --- a/dlio_benchmark/data_loader/torch_data_loader.py +++ b/dlio_benchmark/data_loader/torch_data_loader.py @@ -92,7 +92,10 @@ def __init__(self, rank, size, num_samples, shuffle, epochs, seed): self.shuffle = shuffle self.epochs = epochs self.seed = seed - self.indices = list(range(self.num_samples)) + samples_per_proc = int(math.ceil(num_samples/size)) + start_sample = self.rank * samples_per_proc + end_sample = (self.rank + 1) * samples_per_proc + self.indices = list(range(start_sample, end_sample)) def __len__(self): @@ -103,11 +106,8 @@ def __iter__(self): if self.shuffle == Shuffle.SEED: np.random.seed(self.seed) np.random.shuffle(self.indices) - samples_per_gpu = self.num_samples // self.size - start = self.rank * samples_per_gpu - end = (self.rank + 1) * samples_per_gpu - for i in range(start, end): - yield self.indices[i % self.num_samples] + for sample in self.indices: + yield sample class TorchDataLoader(BaseDataLoader): diff --git a/dlio_benchmark/main.py b/dlio_benchmark/main.py index 2246edbf..8ca224a6 100644 --- a/dlio_benchmark/main.py +++ b/dlio_benchmark/main.py @@ -223,7 +223,6 @@ def _eval(self, epoch): """ Evaluation loop will read a separate dataset and has its own own computation time. """ - self.args.reconfigure(epoch, DatasetType.VALID) step = 1 total = math.floor(self.num_samples * self.num_files_eval / self.batch_size_eval / self.comm_size) loader = self.framework.get_loader(DatasetType.VALID) @@ -331,7 +330,7 @@ def run(self): self.next_checkpoint_epoch = self.checkpoint_after_epoch epoch = 1 # Initialize the dataset - self.args.reconfigure(epoch, DatasetType.TRAIN) + self.args.reconfigure(epoch) self.framework.init_loader(self.args.format, epoch=epoch, data_loader=self.args.data_loader) self.framework.get_loader(dataset_type=DatasetType.TRAIN).read() if self.do_eval: @@ -339,7 +338,6 @@ def run(self): for epoch in range(1, self.epochs + 1): self.next_checkpoint_step = self.steps_between_checkpoints self.stats.start_train(epoch) - self.args.reconfigure(epoch, DatasetType.TRAIN) steps = self._train(epoch) self.stats.end_train(epoch, steps) logging.debug(f"{utcnow()} Rank {self.my_rank} returned after {steps} steps.") @@ -347,11 +345,11 @@ def run(self): # Perform evaluation if enabled if self.do_eval and epoch >= next_eval_epoch: next_eval_epoch += self.epochs_between_evals - self.args.reconfigure(epoch, DatasetType.VALID) self.stats.start_eval(epoch) self._eval(epoch) self.stats.end_eval(epoch) self.framework.get_loader(DatasetType.VALID).finalize() + self.args.reconfigure(epoch + 1) # reconfigure once per epoch self.stats.end_run() diff --git a/dlio_benchmark/reader/reader_handler.py b/dlio_benchmark/reader/reader_handler.py index fd6c295d..b8f8dfc4 100644 --- a/dlio_benchmark/reader/reader_handler.py +++ b/dlio_benchmark/reader/reader_handler.py @@ -51,6 +51,12 @@ def __init__(self, dataset_type, thread_index): self.image_idx = 0 self._file_list = self._args.file_list_train if self.dataset_type is DatasetType.TRAIN else self._args.file_list_eval self.batch_size = self._args.batch_size if self.dataset_type is DatasetType.TRAIN else self._args.batch_size_eval + if dataset_type is DatasetType.TRAIN: + self.global_index_map = self._args.train_global_index_map + self.file_map = self._args.train_file_map + else: + self.file_map = self._args.val_file_map + self.global_index_map = self._args.val_global_index_map @dlp.log def preprocess(self, a=None): @@ -76,10 +82,10 @@ def next(self): batch = [] image_processed = 0 self.step = 1 - total_images = len(self._args.file_map[self.thread_index]) + total_images = len(self.file_map[self.thread_index]) logging.debug(f"{utcnow()} Reading {total_images} images thread {self.thread_index} rank {self._args.my_rank}") - for global_sample_idx, filename, sample_index in self._args.file_map[self.thread_index]: + for global_sample_idx, filename, sample_index in self.file_map[self.thread_index]: self.image_idx = global_sample_idx if filename not in self.open_file_map or self.open_file_map[filename] is None: self.open_file_map[filename] = self.open(filename) @@ -106,7 +112,8 @@ def next(self): def read_index(self, global_sample_idx, step): self.step = step self.image_idx = global_sample_idx - filename, sample_index = self._args.global_index_map[global_sample_idx] + logging.debug(f"{self.global_index_map}") + filename, sample_index = self.global_index_map[global_sample_idx] logging.debug(f"{utcnow()} read_index {filename}, {sample_index}") FormatReader.read_images += 1 if self._args.read_type is ReadType.ON_DEMAND or filename not in self.open_file_map or self.open_file_map[filename] is None: diff --git a/dlio_benchmark/utils/config.py b/dlio_benchmark/utils/config.py index 8e25b148..a7f477f0 100644 --- a/dlio_benchmark/utils/config.py +++ b/dlio_benchmark/utils/config.py @@ -300,87 +300,76 @@ def derive_configurations(self, file_list_train=None, file_list_eval=None): logging.info(f"Discovered custom data reader {class_name}") self.reader_class = obj break + self.train_file_map = {self.my_rank : {}} + self.val_file_map = {self.my_rank : {}} + self.train_global_index_map = {} + self.val_global_index_map = {} @dlp.log def build_sample_map_iter(self, file_list, total_samples, epoch_number): logging.debug(f"ranks {self.comm_size} threads {self.read_threads} tensors") - num_files = len(file_list) num_threads = 1 if self.read_threads > 0 and self.data_loader is not DataLoaderType.DALI: num_threads = self.read_threads - samples_per_proc = total_samples//self.comm_size + samples_per_proc = int(math.ceil(total_samples/self.comm_size)) self.samples_per_thread = samples_per_proc // num_threads - file_index = 0 - sample_index = 0 - sample_global_list = np.arange(total_samples) - if self.file_shuffle is not Shuffle.OFF: + start_sample_index = samples_per_proc * self.my_rank + end_sample_index = samples_per_proc * (self.my_rank + 1) + sample_list = np.arange(start_sample_index, end_sample_index) + if self.sample_shuffle is not Shuffle.OFF: if self.seed_change_epoch: np.random.seed(self.seed + epoch_number) else: np.random.seed(self.seed) - np.random.shuffle(sample_global_list) + np.random.shuffle(sample_list) + sample_index = 0 + num_files = len(file_list) + files_per_rank = (num_files // self.comm_size) % num_files + file_index = self.my_rank * files_per_rank process_thread_file_map = {} - abs_path = os.path.abspath(file_list[file_index]) - - for rank in range(self.comm_size): - if rank not in process_thread_file_map: - process_thread_file_map[rank] = {} + for thread_index in range(num_threads): + process_thread_file_map[thread_index] = [] + for sample in sample_list: for thread_index in range(num_threads): - if (thread_index < samples_per_proc%num_threads): - addition = 1 - else: - addition = 0 - if thread_index not in process_thread_file_map[rank]: - process_thread_file_map[rank][thread_index] = [] - selected_samples = 0 - while selected_samples < self.samples_per_thread+addition: - process_thread_file_map[rank][thread_index].append((sample_global_list[sample_index], - abs_path, - sample_global_list[ - sample_index] % self.num_samples_per_file)) - - sample_index += 1 - selected_samples += 1 - if sample_index >= self.num_samples_per_file: - sample_index = 0 - file_index += 1 - if file_index >= num_files: - break - abs_path = os.path.abspath(file_list[file_index]) + abs_path = os.path.abspath(file_list[file_index]) + process_thread_file_map[thread_index].append((sample, + abs_path, + sample_list[sample_index] % self.num_samples_per_file)) + sample_index += 1 + file_index = (sample_index // self.num_samples_per_file) % num_files return process_thread_file_map @dlp.log def get_global_map_index(self, file_list, total_samples): process_thread_file_map = {} - for global_sample_index in range(total_samples): + samples_per_proc = int(math.ceil(total_samples/self.comm_size)) + start_sample = self.my_rank * samples_per_proc + end_sample = (self.my_rank + 1) * samples_per_proc + for global_sample_index in range(start_sample, end_sample): file_index = global_sample_index//self.num_samples_per_file abs_path = os.path.abspath(file_list[file_index]) sample_index = global_sample_index % self.num_samples_per_file process_thread_file_map[global_sample_index] = (abs_path, sample_index) + logging.debug(f"{self.my_rank} {process_thread_file_map}") return process_thread_file_map @dlp.log - def reconfigure(self, epoch_number, dataset_type): + def reconfigure(self, epoch_number): if self.data_loader_sampler == DataLoaderSampler.ITERATIVE: if self.file_shuffle is not Shuffle.OFF: if self.seed_change_epoch: np.random.seed(self.seed + epoch_number) else: np.random.seed(self.seed) - np.random.shuffle(self.file_list_train) if dataset_type is DatasetType.TRAIN else np.random.shuffle( - self.file_list_eval) + np.random.shuffle(self.file_list_train) + np.random.shuffle(self.file_list_eval) if self.data_loader_sampler == DataLoaderSampler.ITERATIVE: - if dataset_type is DatasetType.TRAIN: - global_file_map = self.build_sample_map_iter(self.file_list_train, self.total_samples_train, + self.train_file_map = self.build_sample_map_iter(self.file_list_train, self.total_samples_train, epoch_number) - else: - global_file_map = self.build_sample_map_iter(self.file_list_eval, self.total_samples_eval, epoch_number) - self.file_map = global_file_map[self.my_rank] + self.val_file_map = self.build_sample_map_iter(self.file_list_eval, self.total_samples_eval, epoch_number) elif self.data_loader_sampler == DataLoaderSampler.INDEX: - if dataset_type is DatasetType.TRAIN: - self.global_index_map = self.get_global_map_index(self.file_list_train, self.total_samples_train) - else: - self.global_index_map = self.get_global_map_index(self.file_list_eval, self.total_samples_eval) + self.train_global_index_map = self.get_global_map_index(self.file_list_train, self.total_samples_train) + self.val_global_index_map = self.get_global_map_index(self.file_list_eval, self.total_samples_eval) def LoadConfig(args, config): diff --git a/dlio_benchmark/utils/utility.py b/dlio_benchmark/utils/utility.py index 8872f2ec..e4ebe1a7 100644 --- a/dlio_benchmark/utils/utility.py +++ b/dlio_benchmark/utils/utility.py @@ -66,8 +66,6 @@ def iter(self, a): # MPI cannot be initialized automatically, or read_thread spawn/forkserver # child processes will abort trying to open a non-existant PMI_fd file. import mpi4py -mpi4py.rc.initialize = False -from mpi4py import MPI p = psutil.Process() @@ -112,6 +110,7 @@ def classname(cls): return cls.__qualname__ def initialize(self): + from mpi4py import MPI if self.mpi_state == MPIState.UNINITIALIZED: # MPI may have already been initialized by dlio_benchmark_test.py if not MPI.Is_initialized(): @@ -181,6 +180,7 @@ def nnodes(self): else: return self.mpi_size//self.mpi_ppn def finalize(self): + from mpi4py import MPI if self.mpi_state == MPIState.MPI_INITIALIZED and MPI.Is_initialized(): MPI.Finalize()