Fixed iterator to only store data for that rank. #216

Merged · 4 commits · Aug 27, 2024
18 changes: 11 additions & 7 deletions dlio_benchmark/data_loader/dali_data_loader.py
@@ -52,23 +52,25 @@ def __init__(self, format_type, dataset_type, epoch, worker_index,
                                epoch_number=self.epoch)
         assert(self.reader.is_index_based())
         self.seed = seed
+        start_sample = self.worker_index * samples_per_worker
+        end_sample = (self.worker_index + 1) * samples_per_worker
         if not hasattr(self, 'indices'):
-            self.indices = np.arange(self.total_num_samples, dtype=np.int64)
+            self.indices = list(range(start_sample, end_sample))
         if self.shuffle != Shuffle.OFF:
             if self.shuffle == Shuffle.SEED:
                 np.random.seed(self.seed)
             np.random.shuffle(self.indices)
     def __call__(self, sample_info):
         logging.debug(
-            f"{utcnow()} Reading {sample_info.idx_in_epoch} out of {self.samples_per_worker} by worker {self.worker_index}")
-        sample_idx = sample_info.idx_in_epoch + self.samples_per_worker * self.worker_index
+            f"{utcnow()} Reading {sample_info.idx_in_epoch} out of {self.samples_per_worker} by worker {self.worker_index} with {self.indices} indices")
         step = sample_info.iteration
-        if step >= self.total_num_steps or sample_idx >= self.total_num_samples:
+        if step >= self.total_num_steps or sample_info.idx_in_epoch >= self.samples_per_worker:
             # Indicate end of the epoch
             raise StopIteration()
+        sample_idx = self.indices[sample_info.idx_in_epoch]
         with Profile(MODULE_DATA_LOADER, epoch=self.epoch, image_idx=sample_idx, step=step):
-            image = self.reader.read_index(self.indices[sample_idx], step)
-            return image, np.uint8([self.indices[sample_idx]])
+            image = self.reader.read_index(sample_idx, step)
+            return image, np.uint8([sample_idx])
 
 class DaliIteratorDataset(object):
     def __init__(self, format_type, dataset_type, epoch, worker_index,
@@ -89,8 +91,10 @@ def __init__(self, format_type, dataset_type, epoch, worker_index,
                                epoch_number=self.epoch)
         assert(self.reader.is_iterator_based())
         self.seed = seed
+        start_sample = self.worker_index * samples_per_worker
+        end_sample = (self.worker_index + 1) * samples_per_worker
         if not hasattr(self, 'indices'):
-            self.indices = np.arange(self.total_num_samples, dtype=np.int64)
+            self.indices = list(range(start_sample, end_sample))
         if self.shuffle != Shuffle.OFF:
             if self.shuffle == Shuffle.SEED:
                 np.random.seed(self.seed)
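
Both constructors now shard identically: each DALI worker keeps only its own `[start_sample, end_sample)` slice, shuffles it locally, and `__call__` translates DALI's worker-local `idx_in_epoch` into a global sample id. A minimal standalone sketch of that lookup path (hypothetical values; none of this is DALI API):

```python
import numpy as np

# Assumed example values mirroring the constructor arguments in the diff.
samples_per_worker = 4
worker_index = 2
seed = 123

# Each worker materializes only its own slice of the global sample space.
start_sample = worker_index * samples_per_worker
end_sample = (worker_index + 1) * samples_per_worker
indices = list(range(start_sample, end_sample))  # [8, 9, 10, 11]

# Shuffling permutes the worker-local slice only, so a worker can never
# draw a sample that belongs to another worker's range.
np.random.seed(seed)
np.random.shuffle(indices)

# __call__ maps the worker-local position to a global sample id.
idx_in_epoch = 1
sample_idx = indices[idx_in_epoch]
print(sample_idx)  # always one of 8..11, regardless of the shuffle
```
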
12 changes: 6 additions & 6 deletions dlio_benchmark/data_loader/torch_data_loader.py
@@ -92,7 +92,10 @@ def __init__(self, rank, size, num_samples, shuffle, epochs, seed):
         self.shuffle = shuffle
         self.epochs = epochs
         self.seed = seed
-        self.indices = list(range(self.num_samples))
+        samples_per_proc = int(math.ceil(num_samples/size))
+        start_sample = self.rank * samples_per_proc
+        end_sample = (self.rank + 1) * samples_per_proc
+        self.indices = list(range(start_sample, end_sample))
 
 
     def __len__(self):
@@ -103,11 +106,8 @@ def __iter__(self):
             if self.shuffle == Shuffle.SEED:
                 np.random.seed(self.seed)
             np.random.shuffle(self.indices)
-        samples_per_gpu = self.num_samples // self.size
-        start = self.rank * samples_per_gpu
-        end = (self.rank + 1) * samples_per_gpu
-        for i in range(start, end):
-            yield self.indices[i % self.num_samples]
+        for sample in self.indices:
+            yield sample
 
 
 class TorchDataLoader(BaseDataLoader):
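
The sampler applies the same idea at rank granularity: the constructor materializes only this rank's index range, and `__iter__` simply walks it. A quick illustration of how the ranges fall out (toy numbers, not benchmark code). One property worth noting: because `samples_per_proc` rounds up with `math.ceil`, the last rank's range can extend past `num_samples - 1` when `size` does not divide `num_samples` evenly, and unlike the old modulo-wrapped loop these out-of-range indices are yielded as-is:

```python
import math

num_samples, size = 10, 4  # toy values: 10 samples across 4 ranks
samples_per_proc = int(math.ceil(num_samples / size))  # 3

for rank in range(size):
    start_sample = rank * samples_per_proc
    end_sample = (rank + 1) * samples_per_proc
    print(rank, list(range(start_sample, end_sample)))
# 0 [0, 1, 2]
# 1 [3, 4, 5]
# 2 [6, 7, 8]
# 3 [9, 10, 11]  <- 10 and 11 exceed num_samples - 1
```
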
6 changes: 2 additions & 4 deletions dlio_benchmark/main.py
@@ -220,7 +220,6 @@ def _eval(self, epoch):
         """
         Evaluation loop will read a separate dataset and has its own own computation time.
         """
-        self.args.reconfigure(epoch, DatasetType.VALID)
         step = 1
         total = math.floor(self.num_samples * self.num_files_eval / self.batch_size_eval / self.comm_size)
         loader = self.framework.get_loader(DatasetType.VALID)
@@ -328,27 +327,26 @@ def run(self):
             self.next_checkpoint_epoch = self.checkpoint_after_epoch
             epoch = 1
             # Initialize the dataset
-            self.args.reconfigure(epoch, DatasetType.TRAIN)
+            self.args.reconfigure(epoch)
             self.framework.init_loader(self.args.format, epoch=epoch, data_loader=self.args.data_loader)
             self.framework.get_loader(dataset_type=DatasetType.TRAIN).read()
             if self.do_eval:
                 self.framework.get_loader(dataset_type=DatasetType.VALID).read()
             for epoch in range(1, self.epochs + 1):
                 self.next_checkpoint_step = self.steps_between_checkpoints
                 self.stats.start_train(epoch)
-                self.args.reconfigure(epoch, DatasetType.TRAIN)
                 steps = self._train(epoch)
                 self.stats.end_train(epoch, steps)
                 logging.debug(f"{utcnow()} Rank {self.my_rank} returned after {steps} steps.")
                 self.framework.get_loader(DatasetType.TRAIN).finalize()
                 # Perform evaluation if enabled
                 if self.do_eval and epoch >= next_eval_epoch:
                     next_eval_epoch += self.epochs_between_evals
-                    self.args.reconfigure(epoch, DatasetType.VALID)
                     self.stats.start_eval(epoch)
                     self._eval(epoch)
                     self.stats.end_eval(epoch)
                     self.framework.get_loader(DatasetType.VALID).finalize()
+                self.args.reconfigure(epoch + 1) # reconfigure once per epoch
 
         self.stats.end_run()

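Net effect: `reconfigure` no longer takes a `DatasetType` and is no longer sprinkled through the phases; it runs once before the loaders are built and once at the end of each epoch. A runnable skeleton of the resulting control flow (the stub functions here are placeholders, not DLIO APIs):

```python
epochs, do_eval = 2, True

def reconfigure(epoch): print(f"reconfigure(epoch={epoch})")
def train(epoch): print(f"train(epoch={epoch})")
def evaluate(epoch): print(f"eval(epoch={epoch})")

reconfigure(1)                 # initial configuration, no DatasetType argument
for epoch in range(1, epochs + 1):
    train(epoch)               # _train no longer does its own reconfigure
    if do_eval:
        evaluate(epoch)        # _eval no longer reconfigures for VALID
    reconfigure(epoch + 1)     # one reconfigure per epoch covers both phases
```
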
18 changes: 9 additions & 9 deletions dlio_benchmark/reader/indexed_binary_mmap_reader.py
@@ -35,7 +35,7 @@ class IndexedBinaryMMapReader(FormatReader):
     @dlp.log_init
     def __init__(self, dataset_type, thread_index, epoch):
         super().__init__(dataset_type, thread_index)
-        self.file_map = {}
+        self.file_map_ibr = {}
         self.load_index()
         self.buffer_map = {}
 
@@ -51,24 +51,24 @@ def read_longs(self, f, n):
         return a
 
     def load_index_file(self, global_sample_idx, filename, sample_index):
-        if filename not in self.file_map:
+        if filename not in self.file_map_ibr:
             offset_file = self.index_file_path_off(filename)
             sz_file = self.index_file_path_size(filename)
-            self.file_map[filename] = []
+            self.file_map_ibr[filename] = []
             bin_buffer_mmap = np.memmap(offset_file, mode='r', order='C')
             bin_buffer = memoryview(bin_buffer_mmap)
-            self.file_map[filename].append(np.frombuffer(bin_buffer, dtype=np.uint8))
+            self.file_map_ibr[filename].append(np.frombuffer(bin_buffer, dtype=np.uint8))
             bin_buffer_mmap = np.memmap(sz_file, mode='r', order='C')
             bin_buffer = memoryview(bin_buffer_mmap)
-            self.file_map[filename].append(np.frombuffer(bin_buffer, dtype=np.uint8))
+            self.file_map_ibr[filename].append(np.frombuffer(bin_buffer, dtype=np.uint8))
 
     @dlp.log
     def load_index(self):
         if self._args.data_loader_sampler == DataLoaderSampler.ITERATIVE:
-            for global_sample_idx, filename, sample_index in self._args.file_map[self.thread_index]:
+            for global_sample_idx, filename, sample_index in self.file_map[self.thread_index]:
                 self.load_index_file(global_sample_idx, filename, sample_index)
         elif self._args.data_loader_sampler == DataLoaderSampler.INDEX:
-            for global_sample_idx, (filename, sample_index) in self._args.global_index_map.items():
+            for global_sample_idx, (filename, sample_index) in self.global_index_map.items():
                 self.load_index_file(global_sample_idx, filename, sample_index)
 
 
@@ -91,8 +91,8 @@ def close(self, filename):
     def get_sample(self, filename, sample_index):
         super().get_sample(filename, sample_index)
         buffer = self.buffer_map[filename]
-        offset = self.file_map[filename][0][sample_index]
-        size = self.file_map[filename][1][sample_index]
+        offset = self.file_map_ibr[filename][0][sample_index]
+        size = self.file_map_ibr[filename][1][sample_index]
         logging.debug(f"reading sample from offset {offset} of size {size} from file {filename}")
         image = buffer[offset:offset+size]
         dlp.update(image_size=size)
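
The rename from `self.file_map` to `self.file_map_ibr` is not cosmetic: `FormatReader.__init__` (see reader_handler.py below) now sets a rank-local `self.file_map`, which a subclass attribute of the same name would shadow. The offset/size index itself keeps its `{filename: [offsets, sizes]}` shape; a self-contained sketch of the lookup `get_sample` performs (synthetic arrays instead of memmaps):

```python
import numpy as np

# file_map_ibr maps each data file to the two arrays load_index_file builds:
# per-sample byte offsets and per-sample byte sizes.
file_map_ibr = {
    "img.bin": [
        np.array([0, 100, 250], dtype=np.int64),   # offsets
        np.array([100, 150, 80], dtype=np.int64),  # sizes
    ]
}

buffer = np.zeros(330, dtype=np.uint8)  # stand-in for the memmapped data file
sample_index = 1
offset = file_map_ibr["img.bin"][0][sample_index]
size = file_map_ibr["img.bin"][1][sample_index]
image = buffer[offset:offset + size]    # bytes 100..249 of the file
print(image.size)  # 150
```
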
18 changes: 9 additions & 9 deletions dlio_benchmark/reader/indexed_binary_reader.py
@@ -35,7 +35,7 @@ class IndexedBinaryReader(FormatReader):
     @dlp.log_init
     def __init__(self, dataset_type, thread_index, epoch):
         super().__init__(dataset_type, thread_index)
-        self.file_map = {}
+        self.file_map_ibr = {}
         self.load_index()
 
     def index_file_path_off(self, prefix_path):
@@ -50,25 +50,25 @@ def read_longs(self, f, n):
         return a
 
     def load_index_file(self, global_sample_idx, filename, sample_index):
-        if filename not in self.file_map:
+        if filename not in self.file_map_ibr:
             offset_file = self.index_file_path_off(filename)
             sz_file = self.index_file_path_size(filename)
-            self.file_map[filename] = []
+            self.file_map_ibr[filename] = []
             with open(offset_file, 'rb') as f:
                 offsets = self.read_longs(f, self._args.num_samples_per_file)
                 logging.debug(f"read offsets {offsets} from file {offset_file}")
-                self.file_map[filename].append(offsets)
+                self.file_map_ibr[filename].append(offsets)
             with open(sz_file, 'rb') as f:
                 sizes = self.read_longs(f, self._args.num_samples_per_file)
                 logging.debug(f"read sizes {sizes} from file {sz_file}")
-                self.file_map[filename].append(sizes)
+                self.file_map_ibr[filename].append(sizes)
     @dlp.log
     def load_index(self):
         if self._args.data_loader_sampler == DataLoaderSampler.ITERATIVE:
-            for global_sample_idx, filename, sample_index in self._args.file_map[self.thread_index]:
+            for global_sample_idx, filename, sample_index in self.file_map[self.thread_index]:
                 self.load_index_file(global_sample_idx, filename, sample_index)
         elif self._args.data_loader_sampler == DataLoaderSampler.INDEX:
-            for global_sample_idx, (filename, sample_index) in self._args.global_index_map.items():
+            for global_sample_idx, (filename, sample_index) in self.global_index_map.items():
                 self.load_index_file(global_sample_idx, filename, sample_index)
 
 
@@ -88,8 +88,8 @@ def close(self, filename):
     def get_sample(self, filename, sample_index):
         super().get_sample(filename, sample_index)
         file = self.open_file_map[filename]
-        offset = self.file_map[filename][0][sample_index]
-        size = self.file_map[filename][1][sample_index]
+        offset = self.file_map_ibr[filename][0][sample_index]
+        size = self.file_map_ibr[filename][1][sample_index]
         logging.debug(f"reading sample from offset {offset} of size {size} from file {filename}")
         file.seek(offset)
         image = np.empty(size, dtype=np.uint8)
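
Same rename for the non-mmap variant, whose `get_sample` seeks instead of slicing. A runnable sketch of that seek-and-read path against a throwaway file (toy offsets and sizes; the benchmark reads these from the companion offset/size index files):

```python
import numpy as np
import tempfile, os

offsets = np.array([0, 100, 250], dtype=np.int64)  # toy index data
sizes = np.array([100, 150, 70], dtype=np.int64)

# Throwaway data file standing in for an indexed binary file.
with tempfile.NamedTemporaryFile(delete=False) as f:
    f.write(bytes(320))
    path = f.name

sample_index = 2
with open(path, 'rb') as file:
    file.seek(offsets[sample_index])            # jump to the record start
    raw = file.read(int(sizes[sample_index]))   # read exactly `size` bytes
image = np.frombuffer(raw, dtype=np.uint8)
print(image.size)  # 70
os.remove(path)
```
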
13 changes: 10 additions & 3 deletions dlio_benchmark/reader/reader_handler.py
@@ -51,6 +51,12 @@ def __init__(self, dataset_type, thread_index):
         self.image_idx = 0
         self._file_list = self._args.file_list_train if self.dataset_type is DatasetType.TRAIN else self._args.file_list_eval
         self.batch_size = self._args.batch_size if self.dataset_type is DatasetType.TRAIN else self._args.batch_size_eval
+        if dataset_type is DatasetType.TRAIN:
+            self.global_index_map = self._args.train_global_index_map
+            self.file_map = self._args.train_file_map
+        else:
+            self.file_map = self._args.val_file_map
+            self.global_index_map = self._args.val_global_index_map
 
     @dlp.log
     def preprocess(self, a=None):
@@ -76,10 +82,10 @@ def next(self):
         batch = []
         image_processed = 0
         self.step = 1
-        total_images = len(self._args.file_map[self.thread_index])
+        total_images = len(self.file_map[self.thread_index])
         logging.debug(f"{utcnow()} Reading {total_images} images thread {self.thread_index} rank {self._args.my_rank}")
 
-        for global_sample_idx, filename, sample_index in self._args.file_map[self.thread_index]:
+        for global_sample_idx, filename, sample_index in self.file_map[self.thread_index]:
             self.image_idx = global_sample_idx
             if filename not in self.open_file_map or self.open_file_map[filename] is None:
                 self.open_file_map[filename] = self.open(filename)
@@ -106,7 +112,8 @@ def read_index(self, global_sample_idx, step):
     def read_index(self, global_sample_idx, step):
         self.step = step
         self.image_idx = global_sample_idx
-        filename, sample_index = self._args.global_index_map[global_sample_idx]
+        logging.debug(f"{self.global_index_map}")
+        filename, sample_index = self.global_index_map[global_sample_idx]
         logging.debug(f"{utcnow()} read_index {filename}, {sample_index}")
         FormatReader.read_images += 1
         if self._args.read_type is ReadType.ON_DEMAND or filename not in self.open_file_map or self.open_file_map[filename] is None:
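
This is the hub of the change: each `FormatReader` binds the per-dataset maps (`train_*` vs `val_*`) once at construction, and `next`/`read_index` consult those attributes instead of the shared `self._args` maps. The selection pattern in isolation (stand-in data; attribute names taken from the diff):

```python
from enum import Enum

class DatasetType(Enum):
    TRAIN = "train"
    VALID = "valid"

class Args:
    # Stand-ins for the per-dataset maps carried by self._args.
    train_file_map = {0: [(0, "train_0.bin", 0)]}
    val_file_map = {0: [(0, "valid_0.bin", 0)]}
    train_global_index_map = {0: ("train_0.bin", 0)}
    val_global_index_map = {0: ("valid_0.bin", 0)}

def select_maps(args, dataset_type):
    # Mirrors the new branch in FormatReader.__init__: bind the maps for
    # this reader's dataset type once, up front.
    if dataset_type is DatasetType.TRAIN:
        return args.train_file_map, args.train_global_index_map
    return args.val_file_map, args.val_global_index_map

file_map, global_index_map = select_maps(Args, DatasetType.VALID)
print(global_index_map[0])  # ('valid_0.bin', 0)
```
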
2 changes: 1 addition & 1 deletion dlio_benchmark/reader/tf_reader.py
@@ -81,7 +81,7 @@ def _parse_image(self, serialized):
 
     @dlp.log
     def next(self):
-        logging.debug(f"{utcnow()} Reading {self._file_list} files thread {self.thread_index} rank {self._args.my_rank}")
+        logging.debug(f"{utcnow()} Reading {len(self._file_list)} files thread {self.thread_index} rank {self._args.my_rank}")
         filenames = tf.data.Dataset.list_files(self._file_list, shuffle=False)
         # sharding in the file list if we have enought files.
         if (len(self._file_list) >= self._args.comm_size):
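
The log line now prints the file count rather than dumping the whole list. For context, the sharding the comment refers to is file-level: when there are at least as many files as ranks, each rank keeps a disjoint subset. A small tf.data sketch of that pattern (`comm_size`/`my_rank` are stand-ins for the benchmark's MPI values; `from_tensor_slices` is used so the sketch runs without real files, where the benchmark uses `tf.data.Dataset.list_files(..., shuffle=False)`):

```python
import tensorflow as tf

comm_size, my_rank = 4, 1  # stand-ins for MPI world size and rank
file_list = [f"file_{i}.tfrecord" for i in range(8)]

filenames = tf.data.Dataset.from_tensor_slices(file_list)
if len(file_list) >= comm_size:
    # Each rank keeps every comm_size-th file, starting at its own rank.
    filenames = filenames.shard(num_shards=comm_size, index=my_rank)

print([f.numpy().decode() for f in filenames])
# ['file_1.tfrecord', 'file_5.tfrecord']
```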