Fixed iterator to store only that rank's data.
- We fix the reconfiguration logic to build per-rank maps for both the train and validation datasets.
- Reconfiguration now runs once per epoch to reshuffle the sample list (see the sketch after the changed-files summary below).
hariharan-devarajan committed Jul 29, 2024
1 parent 33ec34e commit 9cdd425
Showing 5 changed files with 56 additions and 62 deletions.
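The core idea of the commit, sketched as standalone Python (a minimal illustration, not the benchmark's exact code; num_samples, size, and rank mirror the sampler arguments in the torch_data_loader.py diff below):

    import math

    def rank_shard(num_samples, size, rank):
        # Ceil division: every rank gets the same slice width even when
        # num_samples is not divisible by the number of ranks.
        samples_per_proc = int(math.ceil(num_samples / size))
        start_sample = rank * samples_per_proc
        end_sample = (rank + 1) * samples_per_proc
        return list(range(start_sample, end_sample))

    # Example: 10 samples across 4 ranks -> slices of width 3; note that
    # the last rank's slice can run past num_samples - 1.
    print(rank_shard(10, 4, 3))  # [9, 10, 11]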
12 changes: 6 additions & 6 deletions dlio_benchmark/data_loader/torch_data_loader.py
@@ -92,7 +92,10 @@ def __init__(self, rank, size, num_samples, shuffle, epochs, seed):
self.shuffle = shuffle
self.epochs = epochs
self.seed = seed
- self.indices = list(range(self.num_samples))
+ samples_per_proc = int(math.ceil(num_samples/size))
+ start_sample = self.rank * samples_per_proc
+ end_sample = (self.rank + 1) * samples_per_proc
+ self.indices = list(range(start_sample, end_sample))


def __len__(self):
@@ -103,11 +106,8 @@ def __iter__(self):
if self.shuffle == Shuffle.SEED:
np.random.seed(self.seed)
np.random.shuffle(self.indices)
- samples_per_gpu = self.num_samples // self.size
- start = self.rank * samples_per_gpu
- end = (self.rank + 1) * samples_per_gpu
- for i in range(start, end):
- yield self.indices[i % self.num_samples]
+ for sample in self.indices:
+ yield sample


class TorchDataLoader(BaseDataLoader):
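Taken together, the two hunks above replace global index bookkeeping with a rank-local slice. A self-contained sampler in the same spirit (the class name and Shuffle enum here are stand-ins, since the hunks do not show the real class declaration):

    import math
    from enum import Enum
    import numpy as np

    class Shuffle(Enum):
        OFF = 0
        SEED = 1

    class RankShardSampler:
        def __init__(self, rank, size, num_samples, shuffle, seed):
            self.shuffle = shuffle
            self.seed = seed
            samples_per_proc = int(math.ceil(num_samples / size))
            # Materialize only this rank's slice of the global index space.
            self.indices = list(range(rank * samples_per_proc,
                                      (rank + 1) * samples_per_proc))

        def __len__(self):
            return len(self.indices)

        def __iter__(self):
            if self.shuffle == Shuffle.SEED:
                np.random.seed(self.seed)
                np.random.shuffle(self.indices)
            # Yield the rank-local indices directly; no modulo wrap-around.
            yield from self.indices

    # Rank 1 of 2 over 6 samples yields a permutation of [3, 4, 5].
    print(list(RankShardSampler(1, 2, 6, Shuffle.SEED, 42)))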
6 changes: 2 additions & 4 deletions dlio_benchmark/main.py
@@ -223,7 +223,6 @@ def _eval(self, epoch):
"""
Evaluation loop will read a separate dataset and has its own computation time.
"""
- self.args.reconfigure(epoch, DatasetType.VALID)
step = 1
total = math.floor(self.num_samples * self.num_files_eval / self.batch_size_eval / self.comm_size)
loader = self.framework.get_loader(DatasetType.VALID)
@@ -331,27 +330,26 @@ def run(self):
self.next_checkpoint_epoch = self.checkpoint_after_epoch
epoch = 1
# Initialize the dataset
- self.args.reconfigure(epoch, DatasetType.TRAIN)
+ self.args.reconfigure(epoch)
self.framework.init_loader(self.args.format, epoch=epoch, data_loader=self.args.data_loader)
self.framework.get_loader(dataset_type=DatasetType.TRAIN).read()
if self.do_eval:
self.framework.get_loader(dataset_type=DatasetType.VALID).read()
for epoch in range(1, self.epochs + 1):
self.next_checkpoint_step = self.steps_between_checkpoints
self.stats.start_train(epoch)
- self.args.reconfigure(epoch, DatasetType.TRAIN)
steps = self._train(epoch)
self.stats.end_train(epoch, steps)
logging.debug(f"{utcnow()} Rank {self.my_rank} returned after {steps} steps.")
self.framework.get_loader(DatasetType.TRAIN).finalize()
# Perform evaluation if enabled
if self.do_eval and epoch >= next_eval_epoch:
next_eval_epoch += self.epochs_between_evals
- self.args.reconfigure(epoch, DatasetType.VALID)
self.stats.start_eval(epoch)
self._eval(epoch)
self.stats.end_eval(epoch)
self.framework.get_loader(DatasetType.VALID).finalize()
+ self.args.reconfigure(epoch + 1) # reconfigure once per epoch

self.stats.end_run()

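The run() changes collapse the per-dataset reconfigure calls into a single call per epoch. Schematically (all names here are hypothetical stand-ins; reconfigure now rebuilds both the train and validation maps in one pass):

    # Hypothetical stand-ins for the benchmark's real objects.
    class Args:
        def reconfigure(self, epoch):
            print(f"rebuild + reshuffle train/valid maps for epoch {epoch}")

    def train(epoch): print(f"train epoch {epoch}")
    def evaluate(epoch): print(f"eval epoch {epoch}")

    args, epochs, do_eval = Args(), 3, True
    args.reconfigure(1)              # build the epoch-1 maps up front
    for epoch in range(1, epochs + 1):
        train(epoch)
        if do_eval:
            evaluate(epoch)          # uses the maps built for this epoch
        args.reconfigure(epoch + 1)  # reshuffle once for the next epoch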
13 changes: 10 additions & 3 deletions dlio_benchmark/reader/reader_handler.py
@@ -51,6 +51,12 @@ def __init__(self, dataset_type, thread_index):
self.image_idx = 0
self._file_list = self._args.file_list_train if self.dataset_type is DatasetType.TRAIN else self._args.file_list_eval
self.batch_size = self._args.batch_size if self.dataset_type is DatasetType.TRAIN else self._args.batch_size_eval
+ if dataset_type is DatasetType.TRAIN:
+ self.global_index_map = self._args.train_global_index_map
+ self.file_map = self._args.train_file_map
+ else:
+ self.file_map = self._args.val_file_map
+ self.global_index_map = self._args.val_global_index_map

@dlp.log
def preprocess(self, a=None):
@@ -76,10 +82,10 @@ def next(self):
batch = []
image_processed = 0
self.step = 1
- total_images = len(self._args.file_map[self.thread_index])
+ total_images = len(self.file_map[self.thread_index])
logging.debug(f"{utcnow()} Reading {total_images} images thread {self.thread_index} rank {self._args.my_rank}")

- for global_sample_idx, filename, sample_index in self._args.file_map[self.thread_index]:
+ for global_sample_idx, filename, sample_index in self.file_map[self.thread_index]:
self.image_idx = global_sample_idx
if filename not in self.open_file_map or self.open_file_map[filename] is None:
self.open_file_map[filename] = self.open(filename)
@@ -106,7 +112,8 @@
def read_index(self, global_sample_idx, step):
self.step = step
self.image_idx = global_sample_idx
- filename, sample_index = self._args.global_index_map[global_sample_idx]
+ logging.debug(f"{self.global_index_map}")
+ filename, sample_index = self.global_index_map[global_sample_idx]
logging.debug(f"{utcnow()} read_index {filename}, {sample_index}")
FormatReader.read_images += 1
if self._args.read_type is ReadType.ON_DEMAND or filename not in self.open_file_map or self.open_file_map[filename] is None:
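The reader now selects its maps once, by dataset type, instead of reaching into self._args on every access. A hedged sketch of the data shape behind read_index (paths and indices are illustrative; the real maps are built in utils/config.py below):

    from typing import Dict, Tuple

    # global sample index -> (absolute file path, sample index within that file)
    GlobalIndexMap = Dict[int, Tuple[str, int]]

    train_global_index_map: GlobalIndexMap = {
        0: ("/data/train/img_0.npz", 0),
        1: ("/data/train/img_0.npz", 1),
        2: ("/data/train/img_1.npz", 0),
    }

    def read_index(global_sample_idx: int, index_map: GlobalIndexMap) -> Tuple[str, int]:
        # Mirrors the shape of FormatReader.read_index: resolve the global
        # index to a (file, offset) pair; the caller then opens and reads.
        return index_map[global_sample_idx]

    print(read_index(2, train_global_index_map))  # ('/data/train/img_1.npz', 0)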
83 changes: 36 additions & 47 deletions dlio_benchmark/utils/config.py
@@ -300,87 +300,76 @@ def derive_configurations(self, file_list_train=None, file_list_eval=None):
logging.info(f"Discovered custom data reader {class_name}")
self.reader_class = obj
break
+ self.train_file_map = {self.my_rank : {}}
+ self.val_file_map = {self.my_rank : {}}
+ self.train_global_index_map = {}
+ self.val_global_index_map = {}

@dlp.log
def build_sample_map_iter(self, file_list, total_samples, epoch_number):
logging.debug(f"ranks {self.comm_size} threads {self.read_threads} tensors")
- num_files = len(file_list)
num_threads = 1
if self.read_threads > 0 and self.data_loader is not DataLoaderType.DALI:
num_threads = self.read_threads
- samples_per_proc = total_samples//self.comm_size
+ samples_per_proc = int(math.ceil(total_samples/self.comm_size))
self.samples_per_thread = samples_per_proc // num_threads
- file_index = 0
- sample_index = 0
- sample_global_list = np.arange(total_samples)
- if self.file_shuffle is not Shuffle.OFF:
+ start_sample_index = samples_per_proc * self.my_rank
+ end_sample_index = samples_per_proc * (self.my_rank + 1)
+ sample_list = np.arange(start_sample_index, end_sample_index)
+ if self.sample_shuffle is not Shuffle.OFF:
if self.seed_change_epoch:
np.random.seed(self.seed + epoch_number)
else:
np.random.seed(self.seed)
- np.random.shuffle(sample_global_list)
- process_thread_file_map = {}
- abs_path = os.path.abspath(file_list[file_index])
-
- for rank in range(self.comm_size):
- if rank not in process_thread_file_map:
- process_thread_file_map[rank] = {}
+ np.random.shuffle(sample_list)
+ sample_index = 0
+ num_files = len(file_list)
+ files_per_rank = (num_files // self.comm_size) % num_files
+ file_index = self.my_rank * files_per_rank
+ process_thread_file_map = {self.my_rank:{}}
+ for thread_index in range(num_threads):
+ process_thread_file_map[self.my_rank][thread_index] = []
+ for sample in sample_list:
for thread_index in range(num_threads):
- if (thread_index < samples_per_proc%num_threads):
- addition = 1
- else:
- addition = 0
- if thread_index not in process_thread_file_map[rank]:
- process_thread_file_map[rank][thread_index] = []
- selected_samples = 0
- while selected_samples < self.samples_per_thread+addition:
- process_thread_file_map[rank][thread_index].append((sample_global_list[sample_index],
- abs_path,
- sample_global_list[
- sample_index] % self.num_samples_per_file))
-
- sample_index += 1
- selected_samples += 1
- if sample_index >= self.num_samples_per_file:
- sample_index = 0
- file_index += 1
- if file_index >= num_files:
- break
- abs_path = os.path.abspath(file_list[file_index])
+ abs_path = os.path.abspath(file_list[file_index])
+ process_thread_file_map[self.my_rank][thread_index].append((sample,
+ abs_path,
+ sample_list[sample_index] % self.num_samples_per_file))
+ sample_index += 1
+ file_index = (sample_index // self.num_samples_per_file) % num_files
return process_thread_file_map

@dlp.log
def get_global_map_index(self, file_list, total_samples):
process_thread_file_map = {}
- for global_sample_index in range(total_samples):
+ samples_per_proc = int(math.ceil(total_samples/self.comm_size))
+ start_sample = self.my_rank * samples_per_proc
+ end_sample = (self.my_rank + 1) * samples_per_proc
+ for global_sample_index in range(start_sample, end_sample):
file_index = global_sample_index//self.num_samples_per_file
abs_path = os.path.abspath(file_list[file_index])
sample_index = global_sample_index % self.num_samples_per_file
process_thread_file_map[global_sample_index] = (abs_path, sample_index)
logging.debug(f"{self.my_rank} {process_thread_file_map}")
return process_thread_file_map

@dlp.log
- def reconfigure(self, epoch_number, dataset_type):
+ def reconfigure(self, epoch_number):
if self.data_loader_sampler == DataLoaderSampler.ITERATIVE:
if self.file_shuffle is not Shuffle.OFF:
if self.seed_change_epoch:
np.random.seed(self.seed + epoch_number)
else:
np.random.seed(self.seed)
- np.random.shuffle(self.file_list_train) if dataset_type is DatasetType.TRAIN else np.random.shuffle(
- self.file_list_eval)
+ np.random.shuffle(self.file_list_train)
+ np.random.shuffle(self.file_list_eval)
if self.data_loader_sampler == DataLoaderSampler.ITERATIVE:
- if dataset_type is DatasetType.TRAIN:
- global_file_map = self.build_sample_map_iter(self.file_list_train, self.total_samples_train,
+ self.train_file_map = self.build_sample_map_iter(self.file_list_train, self.total_samples_train,
epoch_number)
- else:
- global_file_map = self.build_sample_map_iter(self.file_list_eval, self.total_samples_eval, epoch_number)
- self.file_map = global_file_map[self.my_rank]
+ self.val_file_map = self.build_sample_map_iter(self.file_list_eval, self.total_samples_eval, epoch_number)
elif self.data_loader_sampler == DataLoaderSampler.INDEX:
- if dataset_type is DatasetType.TRAIN:
- self.global_index_map = self.get_global_map_index(self.file_list_train, self.total_samples_train)
- else:
- self.global_index_map = self.get_global_map_index(self.file_list_eval, self.total_samples_eval)
+ self.train_global_index_map = self.get_global_map_index(self.file_list_train, self.total_samples_train)
+ self.val_global_index_map = self.get_global_map_index(self.file_list_eval, self.total_samples_eval)


def LoadConfig(args, config):
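To make the rank-local map construction concrete, here is a simplified version of the get_global_map_index logic with small numbers (assumptions: 8 total samples, 4 samples per file, 2 ranks; file names are illustrative):

    import math

    def rank_local_index_map(file_list, total_samples, num_samples_per_file,
                             comm_size, my_rank):
        # Each rank builds entries only for its ceil-divided slice of the
        # global index space, instead of the full range(total_samples).
        samples_per_proc = int(math.ceil(total_samples / comm_size))
        index_map = {}
        for g in range(my_rank * samples_per_proc,
                       (my_rank + 1) * samples_per_proc):
            file_index = g // num_samples_per_file
            index_map[g] = (file_list[file_index], g % num_samples_per_file)
        return index_map

    # Rank 1 of 2 over 8 samples owns global indices 4..7, all in f1.npz.
    print(rank_local_index_map(["f0.npz", "f1.npz"], 8, 4, 2, 1))
    # {4: ('f1.npz', 0), 5: ('f1.npz', 1), 6: ('f1.npz', 2), 7: ('f1.npz', 3)}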
4 changes: 2 additions & 2 deletions dlio_benchmark/utils/utility.py
@@ -66,8 +66,6 @@ def iter(self, a):
# MPI cannot be initialized automatically, or read_thread spawn/forkserver
# child processes will abort trying to open a non-existent PMI_fd file.
import mpi4py
- mpi4py.rc.initialize = False
- from mpi4py import MPI

p = psutil.Process()

@@ -112,6 +110,7 @@ def classname(cls):
return cls.__qualname__

def initialize(self):
+ from mpi4py import MPI
if self.mpi_state == MPIState.UNINITIALIZED:
# MPI may have already been initialized by dlio_benchmark_test.py
if not MPI.Is_initialized():
@@ -181,6 +180,7 @@ def nnodes(self):
else:
return self.mpi_size//self.mpi_ppn
def finalize(self):
+ from mpi4py import MPI
if self.mpi_state == MPIState.MPI_INITIALIZED and MPI.Is_initialized():
MPI.Finalize()

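The utility.py hunks move the MPI import from module scope into the methods that need it, so a spawned/forked reader process that merely imports the module never touches MPI. A sketch of the deferred-import pattern (class and method bodies abridged; whether mpi4py.rc.initialize is still cleared elsewhere is not visible in these hunks):

    class MPIHelper:
        # No module-level `from mpi4py import MPI`: with default mpi4py
        # settings, that import would initialize MPI as a side effect.
        def initialize(self):
            from mpi4py import MPI  # deferred until explicitly requested
            if not MPI.Is_initialized():
                MPI.Init()

        def finalize(self):
            from mpi4py import MPI
            if MPI.Is_initialized():
                MPI.Finalize()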
