Skip to content

Commit

Permalink
Remove shuffling from tfreader as it is already correctly done during…
Browse files Browse the repository at this point in the history
… reconfigure.
  • Loading branch information
hariharan-devarajan committed Jul 30, 2024
1 parent 82c0796 commit 1510dfd
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
2 changes: 1 addition & 1 deletion dlio_benchmark/data_loader/tf_data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def read(self):
self._dataset = ReaderFactory.get_reader(type=self.format_type,
dataset_type=self.dataset_type,
thread_index=-1,
epoch_number=0).next()
epoch_number=self.epoch_number).next()

@dlp.log
def next(self):
Expand Down
7 changes: 4 additions & 3 deletions dlio_benchmark/reader/tf_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,13 @@ def _parse_image(self, serialized):

@dlp.log
def next(self):
logging.debug(f"{utcnow()} Reading {len(self._file_list)} files thread {self.thread_index} rank {self._args.my_rank}")
filenames = tf.data.Dataset.list_files(self._file_list, shuffle=True)
logging.debug(f"{utcnow()} Reading {self._file_list} files thread {self.thread_index} rank {self._args.my_rank}")
filenames = tf.data.Dataset.list_files(self._file_list, shuffle=False)
# sharding in the file list if we have enough files.
if (len(self._file_list) >= self._args.comm_size):
filenames = filenames.shard(num_shards=self._args.comm_size, index=self._args.my_rank)

logging.debug(f"{utcnow()} shard {filenames} files index {self._args.my_rank} number {self._args.comm_size}")

self._dataset = tf.data.TFRecordDataset(filenames=filenames, buffer_size=self._args.transfer_size,
num_parallel_reads=self._args.read_threads)

Expand Down

0 comments on commit 1510dfd

Please sign in to comment.