Skip to content

Commit

Permalink
fixed iterator to only store data for that rank.
Browse files Browse the repository at this point in the history
  • Loading branch information
hariharan-devarajan committed Jul 29, 2024
1 parent 33ec34e commit 37aee2b
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 41 deletions.
12 changes: 6 additions & 6 deletions dlio_benchmark/data_loader/torch_data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,10 @@ def __init__(self, rank, size, num_samples, shuffle, epochs, seed):
self.shuffle = shuffle
self.epochs = epochs
self.seed = seed
self.indices = list(range(self.num_samples))
samples_per_proc = int(math.ceil(num_samples/size))
start_sample = self.rank * samples_per_proc
end_sample = (self.rank + 1) * samples_per_proc
self.indices = list(range(start_sample, end_sample))


def __len__(self):
Expand All @@ -103,11 +106,8 @@ def __iter__(self):
if self.shuffle == Shuffle.SEED:
np.random.seed(self.seed)
np.random.shuffle(self.indices)
samples_per_gpu = self.num_samples // self.size
start = self.rank * samples_per_gpu
end = (self.rank + 1) * samples_per_gpu
for i in range(start, end):
yield self.indices[i % self.num_samples]
for sample in self.indices:
yield sample


class TorchDataLoader(BaseDataLoader):
Expand Down
59 changes: 24 additions & 35 deletions dlio_benchmark/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,55 +304,44 @@ def derive_configurations(self, file_list_train=None, file_list_eval=None):
@dlp.log
def build_sample_map_iter(self, file_list, total_samples, epoch_number):
logging.debug(f"ranks {self.comm_size} threads {self.read_threads} tensors")
num_files = len(file_list)
num_threads = 1
if self.read_threads > 0 and self.data_loader is not DataLoaderType.DALI:
num_threads = self.read_threads
samples_per_proc = total_samples//self.comm_size
samples_per_proc = int(math.ceil(total_samples/self.comm_size))
self.samples_per_thread = samples_per_proc // num_threads
file_index = 0
sample_index = 0
sample_global_list = np.arange(total_samples)
if self.file_shuffle is not Shuffle.OFF:
start_sample_index = samples_per_proc * self.my_rank
end_sample_index = samples_per_proc * (self.my_rank + 1)
sample_list = np.arange(start_sample_index, end_sample_index)
if self.sample_shuffle is not Shuffle.OFF:
if self.seed_change_epoch:
np.random.seed(self.seed + epoch_number)
else:
np.random.seed(self.seed)
np.random.shuffle(sample_global_list)
process_thread_file_map = {}
abs_path = os.path.abspath(file_list[file_index])

for rank in range(self.comm_size):
if rank not in process_thread_file_map:
process_thread_file_map[rank] = {}
np.random.shuffle(sample_list)
sample_index = 0
num_files = len(file_list)
files_per_rank = (num_files // self.comm_size) % num_files
file_index = self.my_rank * files_per_rank
process_thread_file_map = {self.my_rank:{}}
for thread_index in range(num_threads):
process_thread_file_map[self.my_rank][thread_index] = []
for sample in sample_list:
for thread_index in range(num_threads):
if (thread_index < samples_per_proc%num_threads):
addition = 1
else:
addition = 0
if thread_index not in process_thread_file_map[rank]:
process_thread_file_map[rank][thread_index] = []
selected_samples = 0
while selected_samples < self.samples_per_thread+addition:
process_thread_file_map[rank][thread_index].append((sample_global_list[sample_index],
abs_path,
sample_global_list[
sample_index] % self.num_samples_per_file))

sample_index += 1
selected_samples += 1
if sample_index >= self.num_samples_per_file:
sample_index = 0
file_index += 1
if file_index >= num_files:
break
abs_path = os.path.abspath(file_list[file_index])
abs_path = os.path.abspath(file_list[file_index])
process_thread_file_map[self.my_rank][thread_index].append((sample,
abs_path,
sample_list[sample_index] % self.num_samples_per_file))
sample_index += 1
file_index = (sample_index // self.num_samples_per_file) % num_files
return process_thread_file_map

@dlp.log
def get_global_map_index(self, file_list, total_samples):
process_thread_file_map = {}
for global_sample_index in range(total_samples):
samples_per_proc = int(math.ceil(total_samples/self.comm_size))
start_sample = self.my_rank * samples_per_proc
end_sample = (self.my_rank + 1) * samples_per_proc
for global_sample_index in range(start_sample, end_sample):
file_index = global_sample_index//self.num_samples_per_file
abs_path = os.path.abspath(file_list[file_index])
sample_index = global_sample_index % self.num_samples_per_file
Expand Down

0 comments on commit 37aee2b

Please sign in to comment.