Fixed iterator to only store data for that rank.
- We fix the reconfiguration logic to build the rank map for both train and valid datasets.
- We now reconfigure once per epoch to reshuffle the index list (see the sketch below).
hariharan-devarajan committed Jul 29, 2024
1 parent 33ec34e commit 4e88768
Showing 8 changed files with 92 additions and 91 deletions.
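The gist of the change is that each rank (or DALI worker) now keeps only its own contiguous slice of the global sample indices and reshuffles that slice every epoch. Below is a minimal standalone sketch of that partitioning; `start_sample`/`end_sample` mirror the diff, but the `min()` clamp and the `seed + epoch` reseeding scheme are illustrative assumptions, not the benchmark's exact code.

```python
import numpy as np

def build_rank_indices(total_num_samples, size, rank, epoch, seed=None):
    # Each rank owns a contiguous slice of the global index space.
    samples_per_rank = int(np.ceil(total_num_samples / size))
    start_sample = rank * samples_per_rank
    end_sample = min((rank + 1) * samples_per_rank, total_num_samples)
    indices = list(range(start_sample, end_sample))
    if seed is not None:
        # Reseed per epoch so the slice is reshuffled deterministically
        # each time reconfigure() runs.
        np.random.seed(seed + epoch)
        np.random.shuffle(indices)
    return indices

# Example: rank 1 of 4 ranks sharing 16 samples, epoch 1.
print(build_rank_indices(16, size=4, rank=1, epoch=1, seed=42))
```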
20 changes: 12 additions & 8 deletions dlio_benchmark/data_loader/dali_data_loader.py
@@ -52,23 +52,25 @@ def __init__(self, format_type, dataset_type, epoch, worker_index,
epoch_number=self.epoch)
assert(self.reader.is_index_based())
self.seed = seed
start_sample = self.worker_index * samples_per_worker
end_sample = (self.worker_index + 1) * samples_per_worker
if not hasattr(self, 'indices'):
self.indices = np.arange(self.total_num_samples, dtype=np.int64)
self.indices = list(range(start_sample, end_sample))
if self.shuffle != Shuffle.OFF:
if self.shuffle == Shuffle.SEED:
np.random.seed(self.seed)
np.random.shuffle(self.indices)
def __call__(self, sample_info):
logging.debug(
f"{utcnow()} Reading {sample_info.idx_in_epoch} out of {self.samples_per_worker} by worker {self.worker_index}")
sample_idx = sample_info.idx_in_epoch + self.samples_per_worker * self.worker_index
logging.info(
f"{utcnow()} Reading {sample_info.idx_in_epoch} out of {self.samples_per_worker} by worker {self.worker_index} with {self.indices} indices")
step = sample_info.iteration
if step >= self.total_num_steps or sample_idx >= self.total_num_samples:
if step >= self.total_num_steps or sample_info.idx_in_epoch >= self.samples_per_worker:
# Indicate end of the epoch
raise StopIteration()
sample_idx = self.indices[sample_info.idx_in_epoch]
with Profile(MODULE_DATA_LOADER, epoch=self.epoch, image_idx=sample_idx, step=step):
image = self.reader.read_index(self.indices[sample_idx], step)
return image, np.uint8([self.indices[sample_idx]])
image = self.reader.read_index(sample_idx, step)
return image, np.uint8([sample_idx])

class DaliIteratorDataset(object):
def __init__(self, format_type, dataset_type, epoch, worker_index,
@@ -89,8 +91,10 @@ def __init__(self, format_type, dataset_type, epoch, worker_index,
epoch_number=self.epoch)
assert(self.reader.is_iterator_based())
self.seed = seed
start_sample = self.worker_index * samples_per_worker
end_sample = (self.worker_index + 1) * samples_per_worker
if not hasattr(self, 'indices'):
self.indices = np.arange(self.total_num_samples, dtype=np.int64)
self.indices = list(range(start_sample, end_sample))
if self.shuffle != Shuffle.OFF:
if self.shuffle == Shuffle.SEED:
np.random.seed(self.seed)
12 changes: 6 additions & 6 deletions dlio_benchmark/data_loader/torch_data_loader.py
@@ -92,7 +92,10 @@ def __init__(self, rank, size, num_samples, shuffle, epochs, seed):
self.shuffle = shuffle
self.epochs = epochs
self.seed = seed
self.indices = list(range(self.num_samples))
samples_per_proc = int(math.ceil(num_samples/size))
start_sample = self.rank * samples_per_proc
end_sample = (self.rank + 1) * samples_per_proc
self.indices = list(range(start_sample, end_sample))


def __len__(self):
@@ -103,11 +106,8 @@ def __iter__(self):
if self.shuffle == Shuffle.SEED:
np.random.seed(self.seed)
np.random.shuffle(self.indices)
samples_per_gpu = self.num_samples // self.size
start = self.rank * samples_per_gpu
end = (self.rank + 1) * samples_per_gpu
for i in range(start, end):
yield self.indices[i % self.num_samples]
for sample in self.indices:
yield sample


class TorchDataLoader(BaseDataLoader):
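For context, the sampler change above follows the standard `torch.utils.data.Sampler` pattern: the per-rank slice is computed once in `__init__`, and `__iter__` only shuffles and yields it. A minimal sketch under those assumptions (the class name, the `min()` clamp for the last rank, and the usage line are illustrative, not the benchmark's actual code):

```python
import math
import numpy as np
from torch.utils.data import DataLoader, Sampler

class RankLocalSampler(Sampler):
    """Yields only the sample indices owned by this rank (sketch)."""

    def __init__(self, rank, size, num_samples, seed=0):
        samples_per_proc = int(math.ceil(num_samples / size))
        start_sample = rank * samples_per_proc
        # Clamp so the last rank does not run past num_samples.
        end_sample = min((rank + 1) * samples_per_proc, num_samples)
        self.indices = list(range(start_sample, end_sample))
        self.seed = seed

    def __len__(self):
        return len(self.indices)

    def __iter__(self):
        np.random.seed(self.seed)        # seeded shuffle, as in the diff
        np.random.shuffle(self.indices)
        return iter(self.indices)

# Usage sketch:
# loader = DataLoader(dataset, sampler=RankLocalSampler(rank, size, len(dataset)), batch_size=4)
```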
6 changes: 2 additions & 4 deletions dlio_benchmark/main.py
@@ -223,7 +223,6 @@ def _eval(self, epoch):
"""
Evaluation loop will read a separate dataset and has its own computation time.
"""
self.args.reconfigure(epoch, DatasetType.VALID)
step = 1
total = math.floor(self.num_samples * self.num_files_eval / self.batch_size_eval / self.comm_size)
loader = self.framework.get_loader(DatasetType.VALID)
@@ -331,27 +330,26 @@ def run(self):
self.next_checkpoint_epoch = self.checkpoint_after_epoch
epoch = 1
# Initialize the dataset
self.args.reconfigure(epoch, DatasetType.TRAIN)
self.args.reconfigure(epoch)
self.framework.init_loader(self.args.format, epoch=epoch, data_loader=self.args.data_loader)
self.framework.get_loader(dataset_type=DatasetType.TRAIN).read()
if self.do_eval:
self.framework.get_loader(dataset_type=DatasetType.VALID).read()
for epoch in range(1, self.epochs + 1):
self.next_checkpoint_step = self.steps_between_checkpoints
self.stats.start_train(epoch)
self.args.reconfigure(epoch, DatasetType.TRAIN)
steps = self._train(epoch)
self.stats.end_train(epoch, steps)
logging.debug(f"{utcnow()} Rank {self.my_rank} returned after {steps} steps.")
self.framework.get_loader(DatasetType.TRAIN).finalize()
# Perform evaluation if enabled
if self.do_eval and epoch >= next_eval_epoch:
next_eval_epoch += self.epochs_between_evals
self.args.reconfigure(epoch, DatasetType.VALID)
self.stats.start_eval(epoch)
self._eval(epoch)
self.stats.end_eval(epoch)
self.framework.get_loader(DatasetType.VALID).finalize()
self.args.reconfigure(epoch + 1) # reconfigure once per epoch

self.stats.end_run()

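The run-loop change above consolidates the per-dataset `reconfigure` calls into a single call per epoch, issued at the end of the loop for the next epoch, so train and valid share the same freshly shuffled maps. A self-contained toy illustrating the resulting call order (all functions below are stubs, not the benchmark's API):

```python
# Toy illustration of the new call order; every function here is a stub.
epochs, do_eval, next_eval_epoch, epochs_between_evals = 3, True, 1, 1

def reconfigure(epoch): print(f"reconfigure maps + shuffle for epoch {epoch}")
def train(epoch):       print(f"train epoch {epoch}")
def evaluate(epoch):    print(f"eval  epoch {epoch}")

reconfigure(1)                              # build train and valid rank maps once
for epoch in range(1, epochs + 1):
    train(epoch)
    if do_eval and epoch >= next_eval_epoch:
        next_eval_epoch += epochs_between_evals
        evaluate(epoch)                     # valid map came from the same reconfigure
    reconfigure(epoch + 1)                  # reshuffle index lists for the next epoch
```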
18 changes: 9 additions & 9 deletions dlio_benchmark/reader/indexed_binary_mmap_reader.py
@@ -35,7 +35,7 @@ class IndexedBinaryMMapReader(FormatReader):
@dlp.log_init
def __init__(self, dataset_type, thread_index, epoch):
super().__init__(dataset_type, thread_index)
self.file_map = {}
self.file_map_ibr = {}
self.load_index()
self.buffer_map = {}

@@ -51,24 +51,24 @@ def read_longs(self, f, n):
return a

def load_index_file(self, global_sample_idx, filename, sample_index):
if filename not in self.file_map:
if filename not in self.file_map_ibr:
offset_file = self.index_file_path_off(filename)
sz_file = self.index_file_path_size(filename)
self.file_map[filename] = []
self.file_map_ibr[filename] = []
bin_buffer_mmap = np.memmap(offset_file, mode='r', order='C')
bin_buffer = memoryview(bin_buffer_mmap)
self.file_map[filename].append(np.frombuffer(bin_buffer, dtype=np.uint8))
self.file_map_ibr[filename].append(np.frombuffer(bin_buffer, dtype=np.uint8))
bin_buffer_mmap = np.memmap(sz_file, mode='r', order='C')
bin_buffer = memoryview(bin_buffer_mmap)
self.file_map[filename].append(np.frombuffer(bin_buffer, dtype=np.uint8))
self.file_map_ibr[filename].append(np.frombuffer(bin_buffer, dtype=np.uint8))

@dlp.log
def load_index(self):
if self._args.data_loader_sampler == DataLoaderSampler.ITERATIVE:
for global_sample_idx, filename, sample_index in self._args.file_map[self.thread_index]:
for global_sample_idx, filename, sample_index in self.file_map[self.thread_index]:
self.load_index_file(global_sample_idx, filename, sample_index)
elif self._args.data_loader_sampler == DataLoaderSampler.INDEX:
for global_sample_idx, (filename, sample_index) in self._args.global_index_map.items():
for global_sample_idx, (filename, sample_index) in self.global_index_map.items():
self.load_index_file(global_sample_idx, filename, sample_index)


@@ -91,8 +91,8 @@ def close(self, filename):
def get_sample(self, filename, sample_index):
super().get_sample(filename, sample_index)
buffer = self.buffer_map[filename]
offset = self.file_map[filename][0][sample_index]
size = self.file_map[filename][1][sample_index]
offset = self.file_map_ibr[filename][0][sample_index]
size = self.file_map_ibr[filename][1][sample_index]
logging.debug(f"reading sample from offset {offset} of size {size} from file {filename}")
image = buffer[offset:offset+size]
dlp.update(image_size=size)
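The mmap reader above keeps each file's offset and size indexes memory-mapped and slices samples out of the data buffer by offset. A self-contained sketch of that lookup pattern, assuming each index file stores one int64 per sample (the real on-disk dtype may differ; the diff views the raw buffers as uint8):

```python
import numpy as np

def load_offsets_and_sizes(offset_file, size_file, num_samples):
    # Memory-map the two index files and view them as int64 arrays.
    off_mmap = np.memmap(offset_file, mode='r', order='C')
    sz_mmap = np.memmap(size_file, mode='r', order='C')
    offsets = np.frombuffer(memoryview(off_mmap), dtype=np.int64, count=num_samples)
    sizes = np.frombuffer(memoryview(sz_mmap), dtype=np.int64, count=num_samples)
    return offsets, sizes

def read_sample(data_buffer, offsets, sizes, sample_index):
    # Slice one sample's bytes out of the mapped (or pre-read) data buffer.
    offset = offsets[sample_index]
    size = sizes[sample_index]
    return data_buffer[offset:offset + size]
```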
18 changes: 9 additions & 9 deletions dlio_benchmark/reader/indexed_binary_reader.py
@@ -35,7 +35,7 @@ class IndexedBinaryReader(FormatReader):
@dlp.log_init
def __init__(self, dataset_type, thread_index, epoch):
super().__init__(dataset_type, thread_index)
self.file_map = {}
self.file_map_ibr = {}
self.load_index()

def index_file_path_off(self, prefix_path):
@@ -50,25 +50,25 @@ def read_longs(self, f, n):
return a

def load_index_file(self, global_sample_idx, filename, sample_index):
if filename not in self.file_map:
if filename not in self.file_map_ibr:
offset_file = self.index_file_path_off(filename)
sz_file = self.index_file_path_size(filename)
self.file_map[filename] = []
self.file_map_ibr[filename] = []
with open(offset_file, 'rb') as f:
offsets = self.read_longs(f, self._args.num_samples_per_file)
logging.debug(f"read offsets {offsets} from file {offset_file}")
self.file_map[filename].append(offsets)
self.file_map_ibr[filename].append(offsets)
with open(sz_file, 'rb') as f:
sizes = self.read_longs(f, self._args.num_samples_per_file)
logging.debug(f"read sizes {sizes} from file {sz_file}")
self.file_map[filename].append(sizes)
self.file_map_ibr[filename].append(sizes)
@dlp.log
def load_index(self):
if self._args.data_loader_sampler == DataLoaderSampler.ITERATIVE:
for global_sample_idx, filename, sample_index in self._args.file_map[self.thread_index]:
for global_sample_idx, filename, sample_index in self.file_map[self.thread_index]:
self.load_index_file(global_sample_idx, filename, sample_index)
elif self._args.data_loader_sampler == DataLoaderSampler.INDEX:
for global_sample_idx, (filename, sample_index) in self._args.global_index_map.items():
for global_sample_idx, (filename, sample_index) in self.global_index_map.items():
self.load_index_file(global_sample_idx, filename, sample_index)


@@ -88,8 +88,8 @@ def close(self, filename):
def get_sample(self, filename, sample_index):
super().get_sample(filename, sample_index)
file = self.open_file_map[filename]
offset = self.file_map[filename][0][sample_index]
size = self.file_map[filename][1][sample_index]
offset = self.file_map_ibr[filename][0][sample_index]
size = self.file_map_ibr[filename][1][sample_index]
logging.debug(f"reading sample from offset {offset} of size {size} from file {filename}")
file.seek(offset)
image = np.empty(size, dtype=np.uint8)
13 changes: 10 additions & 3 deletions dlio_benchmark/reader/reader_handler.py
@@ -51,6 +51,12 @@ def __init__(self, dataset_type, thread_index):
self.image_idx = 0
self._file_list = self._args.file_list_train if self.dataset_type is DatasetType.TRAIN else self._args.file_list_eval
self.batch_size = self._args.batch_size if self.dataset_type is DatasetType.TRAIN else self._args.batch_size_eval
if dataset_type is DatasetType.TRAIN:
self.global_index_map = self._args.train_global_index_map
self.file_map = self._args.train_file_map
else:
self.file_map = self._args.val_file_map
self.global_index_map = self._args.val_global_index_map

@dlp.log
def preprocess(self, a=None):
@@ -76,10 +82,10 @@ def next(self):
batch = []
image_processed = 0
self.step = 1
total_images = len(self._args.file_map[self.thread_index])
total_images = len(self.file_map[self.thread_index])
logging.debug(f"{utcnow()} Reading {total_images} images thread {self.thread_index} rank {self._args.my_rank}")

for global_sample_idx, filename, sample_index in self._args.file_map[self.thread_index]:
for global_sample_idx, filename, sample_index in self.file_map[self.thread_index]:
self.image_idx = global_sample_idx
if filename not in self.open_file_map or self.open_file_map[filename] is None:
self.open_file_map[filename] = self.open(filename)
@@ -106,7 +112,8 @@ def next(self):
def read_index(self, global_sample_idx, step):
self.step = step
self.image_idx = global_sample_idx
filename, sample_index = self._args.global_index_map[global_sample_idx]
logging.debug(f"{self.global_index_map}")
filename, sample_index = self.global_index_map[global_sample_idx]
logging.debug(f"{utcnow()} read_index {filename}, {sample_index}")
FormatReader.read_images += 1
if self._args.read_type is ReadType.ON_DEMAND or filename not in self.open_file_map or self.open_file_map[filename] is None:
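For reference, `global_index_map` (now selected per dataset type in the handler above) maps a global sample index to a `(filename, sample-within-file)` pair, and `read_index` resolves one entry per call. A toy illustration with invented contents:

```python
# Toy map; filenames and layout are invented for illustration.
global_index_map = {
    0: ("train_file_0.bin", 0),
    1: ("train_file_0.bin", 1),
    2: ("train_file_1.bin", 0),
    3: ("train_file_1.bin", 1),
}

filename, sample_index = global_index_map[2]
print(filename, sample_index)   # -> train_file_1.bin 0
```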
