From 657d4b9f7cf5a5b1e9eaa6b269dbef9aff68e8e4 Mon Sep 17 00:00:00 2001 From: Huihuo Zheng Date: Fri, 15 Dec 2023 09:16:10 -0600 Subject: [PATCH] Adding Native Dali Data Loader support for TFRecord, Images, and NPZ files (#118) * fixed readthedoc build issue * partial merged the following PR: https://github.com/argonne-lcf/dlio_benchmark/pull/81 * added back npz_reader * fixed bugs * fixed bugs * fixed image reader issue * fixed Profile, PerfTrace * removed unnecessary logs * fixed dali_image_reader * fixed dali_image_reader * added support for npy format * added support for npy format * changed enumerations * added removed dali base reader * fixed a bug * added native-dali-loader tests in github action * corrected github action formats * fixed read return * removed abstractmethod * fixed bugs * added dont_use_mmap * fixed indent * fixed csvreader * native_dali test with npy format instead of npz * fixed issue of enum * modify action so that dlio will always be installed * [skip ci] added documentation for dali * removed read; and define it as pipeline * added exceptions for unimplemented methods * added preprocessing * conditional cache for DLIO installation * fixed bugs * fixed bugs * fixed bugs * fixing again * tests again --- .github/workflows/python-package-conda.yml | 29 ++++- README.md | 2 +- dlio_benchmark/common/enumerations.py | 6 +- .../data_generator/generator_factory.py | 3 + .../data_generator/npy_generator.py | 53 +++++++++ dlio_benchmark/data_generator/tf_generator.py | 15 ++- .../data_loader/data_loader_factory.py | 3 + .../data_loader/native_dali_data_loader.py | 60 ++++++++++ dlio_benchmark/reader/csv_reader.py | 2 +- dlio_benchmark/reader/dali_image_reader.py | 96 ++++++++++++++++ dlio_benchmark/reader/dali_npy_reader.py | 97 ++++++++++++++++ dlio_benchmark/reader/dali_tfrecord_reader.py | 108 ++++++++++++++++++ dlio_benchmark/reader/hdf5_reader.py | 9 +- .../reader/{png_reader.py => image_reader.py} | 4 +- .../reader/{jpeg_reader.py => 
npy_reader.py} | 10 +- dlio_benchmark/reader/reader_factory.py | 38 ++++-- dlio_benchmark/reader/reader_handler.py | 5 +- dlio_benchmark/reader/tf_reader.py | 9 +- dlio_benchmark/utils/config.py | 9 +- docs/source/config.rst | 4 +- 20 files changed, 515 insertions(+), 47 deletions(-) create mode 100644 dlio_benchmark/data_generator/npy_generator.py create mode 100644 dlio_benchmark/data_loader/native_dali_data_loader.py create mode 100644 dlio_benchmark/reader/dali_image_reader.py create mode 100644 dlio_benchmark/reader/dali_npy_reader.py create mode 100644 dlio_benchmark/reader/dali_tfrecord_reader.py rename dlio_benchmark/reader/{png_reader.py => image_reader.py} (96%) rename dlio_benchmark/reader/{jpeg_reader.py => npy_reader.py} (90%) diff --git a/.github/workflows/python-package-conda.yml b/.github/workflows/python-package-conda.yml index 169df873..40dd2091 100644 --- a/.github/workflows/python-package-conda.yml +++ b/.github/workflows/python-package-conda.yml @@ -47,8 +47,18 @@ jobs: - name: Install System Tools run: | sudo apt update - sudo apt-get install $CC $CXX libc6 + sudo apt-get install $CC $CXX libc6 git sudo apt-get install mpich libhwloc-dev + - name: Install DLIO code only + if: steps.cache-modules.outputs.cache-hit == 'true' + run: | + source ${VENV}/bin/activate + rm -rf *.egg* + rm -rf build + rm -rf dist + pip uninstall -y dlio_benchmark + python setup.py build + python setup.py install - name: Install DLIO if: steps.cache-modules.outputs.cache-hit != 'true' run: | @@ -57,8 +67,7 @@ jobs: pip install virtualenv python -m venv ${VENV} source ${VENV}/bin/activate - pip install .[test] - rm -rf dlio_benchmark + pip install .[test] - name: Install DLIO Profiler run: | echo "Profiler ${DLIO_PROFILER} gcc $CC" @@ -152,8 +161,18 @@ jobs: - name: test-tf-loader-npz run: | source ${VENV}/bin/activate - mpirun -np 2 dlio_benchmark workload=unet3d ++workload.framework=tensorflow ++workload.data_reader.data_loader=tensorflow 
++workload.train.computation_time=0.05 ++workload.evaluation.eval_time=0.01 ++workload.train.epochs=1 ++workload.workflow.train=False ++workload.workflow.generate_data=True ++workload.dataset.num_files_train=8 ++workload.dataset.num_files_eval=8 ++workload.reader.read_threads=2 ++workload.dataset.record_length=4096 ++workload.dataset.record_length_stdev=0 - mpirun -np 2 dlio_benchmark workload=unet3d ++workload.framework=tensorflow ++workload.data_reader.data_loader=tensorflow ++workload.train.computation_time=0.05 ++workload.evaluation.eval_time=0.01 ++workload.train.epochs=1 ++workload.workflow.train=True ++workload.workflow.generate_data=False ++workload.dataset.num_files_train=8 ++workload.dataset.num_files_eval=8 ++workload.reader.read_threads=2 ++workload.dataset.record_length=4096 ++workload.dataset.record_length_stdev=0 + mpirun -np 2 dlio_benchmark workload=unet3d ++workload.framework=tensorflow ++workload.data_reader.data_loader=tensorflow ++workload.train.computation_time=0.05 ++workload.evaluation.eval_time=0.01 ++workload.train.epochs=2 ++workload.workflow.train=False ++workload.workflow.generate_data=True ++workload.dataset.num_files_train=16 ++workload.dataset.num_files_eval=16 ++workload.reader.read_threads=2 ++workload.dataset.record_length=4096 ++workload.dataset.record_length_stdev=0 + mpirun -np 2 dlio_benchmark workload=unet3d ++workload.framework=tensorflow ++workload.data_reader.data_loader=tensorflow ++workload.train.computation_time=0.05 ++workload.evaluation.eval_time=0.01 ++workload.train.epochs=2 ++workload.workflow.train=True ++workload.workflow.generate_data=False ++workload.dataset.num_files_train=16 ++workload.dataset.num_files_eval=16 ++workload.reader.read_threads=2 ++workload.dataset.record_length=4096 ++workload.dataset.record_length_stdev=0 + - name: test-torch-native-dali-loader-npy + run: | + source ${VENV}/bin/activate + mpirun -np 2 dlio_benchmark workload=unet3d ++workload.reader.data_loader=native_dali 
++workload.dataset.format=npy ++workload.train.computation_time=0.05 ++workload.evaluation.eval_time=0.01 ++workload.train.epochs=1 ++workload.workflow.train=False ++workload.workflow.generate_data=True ++workload.dataset.num_files_train=16 ++workload.dataset.num_files_eval=16 ++workload.reader.read_threads=2 ++workload.dataset.record_length=4096 ++workload.dataset.record_length_stdev=0 + mpirun -np 2 dlio_benchmark workload=unet3d ++workload.reader.data_loader=native_dali ++workload.dataset.format=npy ++workload.train.computation_time=0.05 ++workload.evaluation.eval_time=0.01 ++workload.train.epochs=1 ++workload.workflow.train=True ++workload.workflow.generate_data=False ++workload.dataset.num_files_train=16 ++workload.dataset.num_files_eval=16 ++workload.reader.read_threads=2 ++workload.dataset.record_length=4096 ++workload.dataset.record_length_stdev=0 + - name: test-tf-native-dali-loader-npy + run: | + source ${VENV}/bin/activate + mpirun -np 2 dlio_benchmark workload=unet3d ++workload.framework=tensorflow ++workload.dataset.format=npy ++workload.reader.data_loader=native_dali ++workload.train.computation_time=0.05 ++workload.evaluation.eval_time=0.01 ++workload.train.epochs=1 ++workload.workflow.train=False ++workload.workflow.generate_data=True ++workload.dataset.num_files_train=16 ++workload.dataset.num_files_eval=16 ++workload.reader.read_threads=2 ++workload.dataset.record_length=4096 ++workload.dataset.record_length_stdev=0 + mpirun -np 2 dlio_benchmark workload=unet3d ++workload.framework=tensorflow ++workload.dataset.format=npy ++workload.reader.data_loader=native_dali ++workload.train.computation_time=0.05 ++workload.evaluation.eval_time=0.01 ++workload.train.epochs=1 ++workload.workflow.train=True ++workload.workflow.generate_data=False ++workload.dataset.num_files_train=16 ++workload.dataset.num_files_eval=16 ++workload.reader.read_threads=2 ++workload.dataset.record_length=4096 ++workload.dataset.record_length_stdev=0 - name: test_subset run: | 
source ${VENV}/bin/activate diff --git a/README.md b/README.md index 3ae53fbb..5990ff57 100644 --- a/README.md +++ b/README.md @@ -23,7 +23,7 @@ dlio_benchmark ++workload.workflow.generate_data=True git clone https://github.com/argonne-lcf/dlio_benchmark cd dlio_benchmark/ pip install .[dlio_profiler] - +``` ## Container ```bash diff --git a/dlio_benchmark/common/enumerations.py b/dlio_benchmark/common/enumerations.py index 64227772..9fab8f23 100644 --- a/dlio_benchmark/common/enumerations.py +++ b/dlio_benchmark/common/enumerations.py @@ -93,6 +93,7 @@ class FormatType(Enum): HDF5 = 'hdf5' CSV = 'csv' NPZ = 'npz' + NPY = 'npy' HDF5_OPT = 'hdf5_opt' JPEG = 'jpeg' PNG = 'png' @@ -100,7 +101,7 @@ class FormatType(Enum): def __str__(self): return self.value - @ staticmethod + @staticmethod def get_enum(value): if FormatType.TFRECORD.value == value: return FormatType.TFRECORD @@ -110,6 +111,8 @@ def get_enum(value): return FormatType.CSV elif FormatType.NPZ.value == value: return FormatType.NPZ + elif FormatType.NPY.value == value: + return FormatType.NPY elif FormatType.HDF5_OPT.value == value: return FormatType.HDF5_OPT elif FormatType.JPEG.value == value: @@ -124,6 +127,7 @@ class DataLoaderType(Enum): TENSORFLOW='tensorflow' PYTORCH='pytorch' DALI='dali' + NATIVE_DALI='native_dali' CUSTOM='custom' NONE='none' diff --git a/dlio_benchmark/data_generator/generator_factory.py b/dlio_benchmark/data_generator/generator_factory.py index 7c05d3a4..e61ead4c 100644 --- a/dlio_benchmark/data_generator/generator_factory.py +++ b/dlio_benchmark/data_generator/generator_factory.py @@ -38,6 +38,9 @@ def get_generator(type): elif type == FormatType.NPZ: from dlio_benchmark.data_generator.npz_generator import NPZGenerator return NPZGenerator() + elif type == FormatType.NPY: + from dlio_benchmark.data_generator.npy_generator import NPYGenerator + return NPYGenerator() elif type == FormatType.JPEG: from dlio_benchmark.data_generator.jpeg_generator import JPEGGenerator return 
JPEGGenerator() diff --git a/dlio_benchmark/data_generator/npy_generator.py b/dlio_benchmark/data_generator/npy_generator.py new file mode 100644 index 00000000..d60ec2f3 --- /dev/null +++ b/dlio_benchmark/data_generator/npy_generator.py @@ -0,0 +1,53 @@ +""" + Copyright (c) 2022, UChicago Argonne, LLC + All Rights Reserved + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +from dlio_benchmark.common.enumerations import Compression +from dlio_benchmark.data_generator.data_generator import DataGenerator + +import logging +import numpy as np + +from dlio_benchmark.utils.utility import progress, utcnow +from dlio_profiler.logger import fn_interceptor as Profile +from shutil import copyfile +from dlio_benchmark.common.constants import MODULE_DATA_GENERATOR + +dlp = Profile(MODULE_DATA_GENERATOR) + +""" +Generator for creating data in NPZ format. +""" +class NPYGenerator(DataGenerator): + def __init__(self): + super().__init__() + + @dlp.log + def generate(self): + """ + Generator for creating data in NPY format of 3d dataset. 
+ """ + super().generate() + np.random.seed(10) + record_labels = [0] * self.num_samples + for i in dlp.iter(range(self.my_rank, int(self.total_files_to_generate), self.comm_size)): + dim1, dim2 = self.get_dimension() + records = np.random.randint(255, size=(dim1, dim2, self.num_samples), dtype=np.uint8) + out_path_spec = self.storage.get_uri(self._file_list[i]) + progress(i+1, self.total_files_to_generate, "Generating NPY Data") + prev_out_spec = out_path_spec + np.save(out_path_spec, records) + np.random.seed() diff --git a/dlio_benchmark/data_generator/tf_generator.py b/dlio_benchmark/data_generator/tf_generator.py index f10a9621..b1a94c9c 100644 --- a/dlio_benchmark/data_generator/tf_generator.py +++ b/dlio_benchmark/data_generator/tf_generator.py @@ -14,12 +14,15 @@ See the License for the specific language governing permissions and limitations under the License. """ +import os +from subprocess import call from dlio_benchmark.data_generator.data_generator import DataGenerator import numpy as np import tensorflow as tf -from dlio_benchmark.utils.utility import progress, utcnow from dlio_profiler.logger import fn_interceptor as Profile + +from dlio_benchmark.utils.utility import progress, utcnow from shutil import copyfile from dlio_benchmark.common.constants import MODULE_DATA_GENERATOR @@ -64,4 +67,14 @@ def generate(self): serialized = example.SerializeToString() # Write the serialized data to the TFRecords file. 
writer.write(serialized) + tfrecord2idx_script = "tfrecord2idx" + folder = "train" + if "valid" in out_path_spec: + folder = "valid" + index_folder = f"{self._args.data_folder}/index/{folder}" + filename = os.path.basename(out_path_spec) + self.storage.create_node(index_folder, exist_ok=True) + tfrecord_idx = f"{index_folder}/{filename}.idx" + if not os.path.isfile(tfrecord_idx): + call([tfrecord2idx_script, out_path_spec, tfrecord_idx]) np.random.seed() diff --git a/dlio_benchmark/data_loader/data_loader_factory.py b/dlio_benchmark/data_loader/data_loader_factory.py index e8457450..13bf16b0 100644 --- a/dlio_benchmark/data_loader/data_loader_factory.py +++ b/dlio_benchmark/data_loader/data_loader_factory.py @@ -45,6 +45,9 @@ def get_loader(type, format_type, dataset_type, epoch): elif type == DataLoaderType.DALI: from dlio_benchmark.data_loader.dali_data_loader import DaliDataLoader return DaliDataLoader(format_type, dataset_type, epoch) + elif type == DataLoaderType.NATIVE_DALI: + from dlio_benchmark.data_loader.native_dali_data_loader import NativeDaliDataLoader + return NativeDaliDataLoader(format_type, dataset_type, epoch) else: print("Data Loader %s not supported or plugins not found" % type) raise Exception(str(ErrorCodes.EC1004)) diff --git a/dlio_benchmark/data_loader/native_dali_data_loader.py b/dlio_benchmark/data_loader/native_dali_data_loader.py new file mode 100644 index 00000000..900c8c6d --- /dev/null +++ b/dlio_benchmark/data_loader/native_dali_data_loader.py @@ -0,0 +1,60 @@ +from time import time +import logging +import math +import numpy as np +from nvidia.dali.pipeline import Pipeline +import nvidia.dali.fn as fn +import nvidia.dali.types as types +import nvidia.dali as dali +from nvidia.dali.plugin.pytorch import DALIGenericIterator + +from dlio_benchmark.common.constants import MODULE_DATA_LOADER +from dlio_benchmark.common.enumerations import Shuffle, DataLoaderType, DatasetType +from dlio_benchmark.data_loader.base_data_loader import 
BaseDataLoader +from dlio_benchmark.reader.reader_factory import ReaderFactory +from dlio_benchmark.utils.utility import utcnow, get_rank, timeit +from dlio_profiler.logger import dlio_logger as PerfTrace, fn_interceptor as Profile + +dlp = Profile(MODULE_DATA_LOADER) + + +class NativeDaliDataLoader(BaseDataLoader): + @dlp.log_init + def __init__(self, format_type, dataset_type, epoch): + super().__init__(format_type, dataset_type, epoch, DataLoaderType.NATIVE_DALI) + self.pipelines = [] + + @dlp.log + def read(self): + num_samples = self._args.total_samples_train if self.dataset_type is DatasetType.TRAIN else self._args.total_samples_eval + batch_size = self._args.batch_size if self.dataset_type is DatasetType.TRAIN else self._args.batch_size_eval + parallel = True if self._args.read_threads > 0 else False + self.pipelines = [] + num_threads = 1 + if self._args.read_threads > 0: + num_threads = self._args.read_threads + # None executes pipeline on CPU and the reader does the batching + pipeline = Pipeline(batch_size=batch_size, num_threads=num_threads, device_id=None, py_num_workers=num_threads, + exec_async=False, exec_pipelined=False) + with pipeline: + images = ReaderFactory.get_reader(type=self.format_type, + dataset_type=self.dataset_type, + thread_index=-1, + epoch_number=self.epoch_number).pipeline() + pipeline.set_outputs(images) + self.pipelines.append(pipeline) + + @dlp.log + def next(self): + super().next() + num_samples = self._args.total_samples_train if self.dataset_type is DatasetType.TRAIN else self._args.total_samples_eval + batch_size = self._args.batch_size if self.dataset_type is DatasetType.TRAIN else self._args.batch_size_eval + for step in range(num_samples // batch_size): + _dataset = DALIGenericIterator(self.pipelines, ['data']) + for batch in _dataset: + logging.debug(f"{utcnow()} Creating {len(batch)} batches by {self._args.my_rank} rank ") + yield batch + + @dlp.log + def finalize(self): + pass diff --git 
a/dlio_benchmark/reader/csv_reader.py b/dlio_benchmark/reader/csv_reader.py index be1e587b..0690fc69 100644 --- a/dlio_benchmark/reader/csv_reader.py +++ b/dlio_benchmark/reader/csv_reader.py @@ -57,4 +57,4 @@ def read_index(self, image_idx, step): @dlp.log def finalize(self): - return super().finalize() + return super().finalize() \ No newline at end of file diff --git a/dlio_benchmark/reader/dali_image_reader.py b/dlio_benchmark/reader/dali_image_reader.py new file mode 100644 index 00000000..032885cd --- /dev/null +++ b/dlio_benchmark/reader/dali_image_reader.py @@ -0,0 +1,96 @@ +""" + Copyright (c) 2022, UChicago Argonne, LLC + All Rights Reserved + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+""" +import math +import logging +from time import time, sleep +import numpy as np + +import nvidia.dali.fn as fn +from dlio_benchmark.common.constants import MODULE_DATA_READER +from dlio_benchmark.dlio_benchmark.reader.reader_handler import FormatReader +from dlio_benchmark.utils.utility import utcnow +from dlio_benchmark.common.enumerations import DatasetType, Shuffle +import nvidia.dali.tfrecord as tfrec +from dlio_profiler.logger import dlio_logger as PerfTrace, fn_interceptor as Profile + +dlp = Profile(MODULE_DATA_READER) + + +class DaliImageReader(FormatReader): + @dlp.log_init + def __init__(self, dataset_type, thread_index, epoch): + super().__init__(dataset_type, thread_index) + + @dlp.log + def open(self, filename): + super().open(filename) + + def close(self): + super().close() + + def get_sample(self, filename, sample_index): + super().get_sample(filename, sample_index) + raise Exception("get sample method is not implemented in dali readers") + + def next(self): + super().next() + raise Exception("next method is not implemented in dali readers") + + def read_index(self): + super().read_index() + raise Exception("read_index method is not implemented in dali readers") + + @dlp.log + def pipeline(self): + logging.debug( + f"{utcnow()} Reading {len(self._file_list)} files rank {self._args.my_rank}") + random_shuffle = False + seed = -1 + seed_change_epoch = False + if self._args.sample_shuffle is not Shuffle.OFF: + if self._args.sample_shuffle is not Shuffle.SEED: + seed = self._args.seed + random_shuffle = True + seed_change_epoch = True + initial_fill = 1024 + if self._args.shuffle_size > 0: + initial_fill = self._args.shuffle_size + prefetch_size = 1 + if self._args.prefetch_size > 0: + prefetch_size = self._args.prefetch_size + + stick_to_shard = True + if seed_change_epoch: + stick_to_shard = False + images, labels = fn.readers.file(files=self._file_list, num_shards=self._args.comm_size, + prefetch_queue_depth=prefetch_size, + 
initial_fill=initial_fill, random_shuffle=random_shuffle, + shuffle_after_epoch=seed_change_epoch, + stick_to_shard=stick_to_shard, pad_last_batch=True, + dont_use_mmap=self._args.dont_use_mmap) + images = fn.decoders.image(images, device='cpu') + fn.python_function(dataset, function=self.preprocess, num_outputs=0) + dataset = self._resize(images) + return dataset + + @dlp.log + def _resize(self, dataset): + return fn.resize(dataset, size=[self._args.max_dimension, self._args.max_dimension]) + + @dlp.log + def finalize(self): + pass \ No newline at end of file diff --git a/dlio_benchmark/reader/dali_npy_reader.py b/dlio_benchmark/reader/dali_npy_reader.py new file mode 100644 index 00000000..e68f4fb2 --- /dev/null +++ b/dlio_benchmark/reader/dali_npy_reader.py @@ -0,0 +1,97 @@ +""" + Copyright (c) 2022, UChicago Argonne, LLC + All Rights Reserved + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+""" +import math +import logging +from time import time, sleep +import numpy as np + +import nvidia.dali.fn as fn +from dlio_benchmark.common.constants import MODULE_DATA_READER +from dlio_benchmark.reader.reader_handler import FormatReader +from dlio_benchmark.utils.utility import utcnow +from dlio_benchmark.common.enumerations import DatasetType, Shuffle +import nvidia.dali.tfrecord as tfrec +from dlio_profiler.logger import dlio_logger as PerfTrace, fn_interceptor as Profile + +dlp = Profile(MODULE_DATA_READER) + + +class DaliNPYReader(FormatReader): + @dlp.log_init + def __init__(self, dataset_type, thread_index, epoch): + super().__init__(dataset_type, thread_index) + + @dlp.log + def open(self, filename): + super().open(filename) + + @dlp.log + def pipeline(self): + logging.debug( + f"{utcnow()} Reading {len(self._file_list)} files rank {self._args.my_rank}") + random_shuffle = False + seed = -1 + seed_change_epoch = False + if self._args.sample_shuffle is not Shuffle.OFF: + if self._args.sample_shuffle is not Shuffle.SEED: + seed = self._args.seed + random_shuffle = True + seed_change_epoch = True + initial_fill = 1024 + if self._args.shuffle_size > 0: + initial_fill = self._args.shuffle_size + prefetch_size = 1 + if self._args.prefetch_size > 0: + prefetch_size = self._args.prefetch_size + + stick_to_shard = True + if random_shuffle == True: + seed_change_epoch = False + if seed_change_epoch: + stick_to_shard = False + + dataset = fn.readers.numpy(device='cpu', files=self._file_list, num_shards=self._args.comm_size, + prefetch_queue_depth=prefetch_size, initial_fill=initial_fill, + random_shuffle=random_shuffle, seed=seed, shuffle_after_epoch=seed_change_epoch, + stick_to_shard=stick_to_shard, pad_last_batch=True, + dont_use_mmap=self._args.dont_use_mmap) + dataset = self._resize(dataset) + fn.python_function(dataset, function=self.preprocess, num_outputs=0) + return dataset + + def close(self): + super().close() + + def get_sample(self, filename, 
sample_index): + super().get_sample(filename, sample_index) + raise Exception("get sample method is not implemented in dali readers") + + def next(self): + super().next() + raise Exception("next method is not implemented in dali readers") + + def read_index(self): + super().read_index() + raise Exception("read_index method is not implemented in dali readers") + + @dlp.log + def _resize(self, dataset): + return fn.resize(dataset, size=[self._args.max_dimension, self._args.max_dimension]) + + @dlp.log + def finalize(self): + pass diff --git a/dlio_benchmark/reader/dali_tfrecord_reader.py b/dlio_benchmark/reader/dali_tfrecord_reader.py new file mode 100644 index 00000000..0a33f327 --- /dev/null +++ b/dlio_benchmark/reader/dali_tfrecord_reader.py @@ -0,0 +1,108 @@ +""" + Copyright (c) 2022, UChicago Argonne, LLC + All Rights Reserved + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+""" +import os.path + +import math +import logging +from time import time, sleep +import numpy as np + +import nvidia +import nvidia.dali.fn as fn +from dlio_benchmark.common.constants import MODULE_DATA_READER +from dlio_benchmark.dlio_benchmark.reader.reader_handler import FormatReader +from dlio_benchmark.utils.utility import utcnow +from dlio_benchmark.common.enumerations import DatasetType, Shuffle +import nvidia.dali.tfrecord as tfrec +from dlio_profiler.logger import dlio_logger as PerfTrace, fn_interceptor as Profile + +dlp = Profile(MODULE_DATA_READER) + + +class DaliTFRecordReader(FormatReader): + """ + Reader for NPZ files + """ + @dlp.log_init + def __init__(self, dataset_type, thread_index, epoch): + super().__init__(dataset_type, thread_index) + + @dlp.log + def open(self, filename): + super().open(filename) + + def close(self): + super().close() + + @dlp.log + def pipeline(self): + folder = "valid" + if self.dataset_type == DatasetType.TRAIN: + folder = "train" + index_folder = f"{self._args.data_folder}/index/{folder}" + index_files = [] + for file in self._file_list: + filename = os.path.basename(file) + index_files.append(f"{index_folder}/{filename}.idx") + logging.info( + f"{utcnow()} Reading {len(self.file_list)} files rank {self._args.my_rank}") + random_shuffle = False + seed = -1 + if self._args.sample_shuffle is not Shuffle.OFF: + if self._args.sample_shuffle is not Shuffle.SEED: + seed = self._args.seed + random_shuffle = True + initial_fill = 1024 + if self._args.shuffle_size > 0: + initial_fill = self._args.shuffle_size + prefetch_size = 1 + if self._args.prefetch_size > 0: + prefetch_size = self._args.prefetch_size + dataset = fn.readers.tfrecord(path=self._file_list, + index_path=index_files, + features={ + 'image': tfrec.FixedLenFeature((), tfrec.string, ""), + 'size': tfrec.FixedLenFeature([1], tfrec.int64, 0) + }, num_shards=self._args.comm_size, + prefetch_queue_depth=prefetch_size, + initial_fill=initial_fill, + 
random_shuffle=random_shuffle, seed=seed, + stick_to_shard=True, pad_last_batch=True, + dont_use_mmap=self._args.dont_use_mmap) + dataset = self._resize(dataset['image']) + fn.python_function(dataset, function=self.preprocess, num_outputs=0) + return dataset + + def get_sample(self, filename, sample_index): + super().get_sample(filename, sample_index) + raise Exception("get sample method is not implemented in dali readers") + + def next(self): + super().next() + raise Exception("next method is not implemented in dali readers") + + def read_index(self): + super().read_index() + raise Exception("read_index method is not implemented in dali readers") + + @dlp.log + def _resize(self, dataset): + return fn.resize(dataset, size=[self._args.max_dimension, self._args.max_dimension]) + + @dlp.log + def finalize(self): + pass diff --git a/dlio_benchmark/reader/hdf5_reader.py b/dlio_benchmark/reader/hdf5_reader.py index f95d9b94..77dd7301 100644 --- a/dlio_benchmark/reader/hdf5_reader.py +++ b/dlio_benchmark/reader/hdf5_reader.py @@ -24,13 +24,10 @@ dlp = Profile(MODULE_DATA_READER) -""" -Reader for HDF5 files for training file. -""" - - class HDF5Reader(FormatReader): - + """ + Reader for HDF5 files. 
+ """ @dlp.log_init def __init__(self, dataset_type, thread_index, epoch): super().__init__(dataset_type, thread_index) diff --git a/dlio_benchmark/reader/png_reader.py b/dlio_benchmark/reader/image_reader.py similarity index 96% rename from dlio_benchmark/reader/png_reader.py rename to dlio_benchmark/reader/image_reader.py index 64183dd3..1fe63a05 100644 --- a/dlio_benchmark/reader/png_reader.py +++ b/dlio_benchmark/reader/image_reader.py @@ -26,9 +26,9 @@ dlp = Profile(MODULE_DATA_READER) -class PNGReader(FormatReader): +class ImageReader(FormatReader): """ - Reader for PNG files + Reader for PNG / JPEG files """ @dlp.log_init diff --git a/dlio_benchmark/reader/jpeg_reader.py b/dlio_benchmark/reader/npy_reader.py similarity index 90% rename from dlio_benchmark/reader/jpeg_reader.py rename to dlio_benchmark/reader/npy_reader.py index 664cde04..d3cc46f1 100644 --- a/dlio_benchmark/reader/jpeg_reader.py +++ b/dlio_benchmark/reader/npy_reader.py @@ -14,9 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. 
""" - import numpy as np -from PIL import Image from dlio_benchmark.common.constants import MODULE_DATA_READER from dlio_benchmark.reader.reader_handler import FormatReader @@ -25,9 +23,9 @@ dlp = Profile(MODULE_DATA_READER) -class JPEGReader(FormatReader): +class NPYReader(FormatReader): """ - Reader for JPEG files + Reader for NPY files """ @dlp.log_init @@ -37,7 +35,7 @@ def __init__(self, dataset_type, thread_index, epoch): @dlp.log def open(self, filename): super().open(filename) - return np.asarray(Image.open(filename)) + return np.load(filename) @dlp.log def close(self, filename): @@ -46,7 +44,7 @@ def close(self, filename): @dlp.log def get_sample(self, filename, sample_index): super().get_sample(filename, sample_index) - image = self.open_file_map[filename] + image = self.open_file_map[filename][..., sample_index] dlp.update(image_size=image.nbytes) def next(self): diff --git a/dlio_benchmark/reader/reader_factory.py b/dlio_benchmark/reader/reader_factory.py index 74fc353e..e84db142 100644 --- a/dlio_benchmark/reader/reader_factory.py +++ b/dlio_benchmark/reader/reader_factory.py @@ -43,18 +43,32 @@ def get_reader(type, dataset_type, thread_index, epoch_number): elif type == FormatType.CSV: from dlio_benchmark.reader.csv_reader import CSVReader return CSVReader(dataset_type, thread_index, epoch_number) - elif type == FormatType.JPEG: - from dlio_benchmark.reader.jpeg_reader import JPEGReader - return JPEGReader(dataset_type, thread_index, epoch_number) - elif type == FormatType.PNG: - from dlio_benchmark.reader.png_reader import PNGReader - return PNGReader(dataset_type, thread_index, epoch_number) + elif type == FormatType.JPEG or type == FormatType.PNG: + if _args.data_loader == DataLoaderType.NATIVE_DALI: + from dlio_benchmark.reader.dali_image_reader import DaliImageReader + return DaliImageReader(dataset_type, thread_index, epoch_number) + else: + from dlio_benchmark.reader.image_reader import ImageReader + return ImageReader(dataset_type, 
thread_index, epoch_number) + elif type == FormatType.NPY: + if _args.data_loader == DataLoaderType.NATIVE_DALI: + from dlio_benchmark.reader.dali_npy_reader import DaliNPYReader + return DaliNPYReader(dataset_type, thread_index, epoch_number) + else: + from dlio_benchmark.reader.npy_reader import NPYReader + return NPYReader(dataset_type, thread_index, epoch_number) elif type == FormatType.NPZ: - from dlio_benchmark.reader.npz_reader import NPZReader - return NPZReader(dataset_type, thread_index, epoch_number) + if _args.data_loader == DataLoaderType.NATIVE_DALI: + raise Exception("Loading data of %s format is not supported without framework data loader; please use npy format instead." %type) + else: + from dlio_benchmark.reader.npz_reader import NPZReader + return NPZReader(dataset_type, thread_index, epoch_number) elif type == FormatType.TFRECORD: - from dlio_benchmark.reader.tf_reader import TFReader - return TFReader(dataset_type, thread_index, epoch_number) + if _args.data_loader == DataLoaderType.NATIVE_DALI: + from dlio_benchmark.reader.dali_tfrecord_reader import DaliTFRecordReader + return DaliTFRecordReader(dataset_type, thread_index, epoch_number) + else: + from dlio_benchmark.reader.tf_reader import TFReader + return TFReader(dataset_type, thread_index, epoch_number) else: - print("Loading data of %s format is not supported without framework data loader" %type) - raise Exception(type) + raise Exception("Loading data of %s format is not supported without framework data loader" %type) \ No newline at end of file diff --git a/dlio_benchmark/reader/reader_handler.py b/dlio_benchmark/reader/reader_handler.py index 48675ab3..57884dd5 100644 --- a/dlio_benchmark/reader/reader_handler.py +++ b/dlio_benchmark/reader/reader_handler.py @@ -49,17 +49,18 @@ def __init__(self, dataset_type, thread_index): FormatReader.read_images = 0 self.step = 1 self.image_idx = 0 + self._file_list = self._args.file_list_train if self.dataset_type is DatasetType.TRAIN else 
self._args.file_list_eval self.batch_size = self._args.batch_size if self.dataset_type is DatasetType.TRAIN else self._args.batch_size_eval @dlp.log - def preprocess(self): + def preprocess(self, a=None): if self._args.preprocess_time != 0. or self._args.preprocess_time_stdev != 0.: t = np.random.normal(self._args.preprocess_time, self._args.preprocess_time_stdev) sleep(max(t, 0.0)) @abstractmethod def open(self, filename): - pass + return @abstractmethod def close(self, filename): diff --git a/dlio_benchmark/reader/tf_reader.py b/dlio_benchmark/reader/tf_reader.py index a889076b..d19ea96a 100644 --- a/dlio_benchmark/reader/tf_reader.py +++ b/dlio_benchmark/reader/tf_reader.py @@ -37,8 +37,7 @@ class TFReader(FormatReader): def __init__(self, dataset_type, thread_index, epoch): super().__init__(dataset_type, thread_index) self._dataset = None - self._file_list = self._args.file_list_train if self.dataset_type is DatasetType.TRAIN else self._args.file_list_eval - self.batch_size = self._args.batch_size if self.dataset_type is DatasetType.TRAIN else self._args.batch_size_eval + @dlp.log def open(self, filename): pass @@ -56,7 +55,7 @@ def resize_sample(self, filename, sample_index): pass @dlp.log - def parse_image(self, serialized): + def _parse_image(self, serialized): """ performs deserialization of the tfrecord. 
:param serialized: is the serialized version using protobuf @@ -86,7 +85,7 @@ def next(self): self._dataset = tf.data.TFRecordDataset(filenames=self._file_list, buffer_size=self._args.transfer_size) self._dataset = self._dataset.shard(num_shards=self._args.comm_size, index=self._args.my_rank) self._dataset = self._dataset.map( - lambda x: tf.py_function(func=self.parse_image, inp=[x], Tout=[tf.uint8]) + lambda x: tf.py_function(func=self._parse_image, inp=[x], Tout=[tf.uint8]) , num_parallel_calls=self._args.computation_threads) self._dataset = self._dataset.batch(self.batch_size, drop_remainder=True) total = math.ceil(len(self._file_list)/self._args.comm_size / self.batch_size * self._args.num_samples_per_file) @@ -104,4 +103,4 @@ def read_index(self, image_idx, step): @dlp.log def finalize(self): - return super().finalize() + return super().finalize() \ No newline at end of file diff --git a/dlio_benchmark/utils/config.py b/dlio_benchmark/utils/config.py index c6d319d7..eebf8525 100644 --- a/dlio_benchmark/utils/config.py +++ b/dlio_benchmark/utils/config.py @@ -78,6 +78,7 @@ class ConfigArguments: steps_between_checkpoints: int = -1 transfer_size: int = None read_threads: int = 1 + dont_use_mmap: bool = False computation_threads: int = 1 computation_time: float = 0. computation_time_stdev: float = 0. 
@@ -149,10 +150,8 @@ def validate(self): if (self.do_profiling == True) and (self.profiler == Profiler('darshan')): if ('LD_PRELOAD' not in os.environ or os.environ["LD_PRELOAD"].find("libdarshan") == -1): raise Exception("Please set darshan runtime library in LD_PRELOAD") - if self.format is FormatType.TFRECORD and self.framework is not FrameworkType.TENSORFLOW: - raise Exception(f"{self.framework} support for tfrecord is not implemented.") - if self.format is FormatType.TFRECORD and self.data_loader is not DataLoaderType.TENSORFLOW: - raise Exception(f"{self.data_loader} support for tfrecord is not implemented.") + if self.format is FormatType.TFRECORD and (self.data_loader is DataLoaderType.PYTORCH): + raise Exception(f"{self.framework} support for tfrecord is not implemented for {self.data_loader}.") if (self.framework == FrameworkType.TENSORFLOW and self.data_loader == DataLoaderType.PYTORCH) or ( self.framework == FrameworkType.PYTORCH and self.data_loader == DataLoaderType.TENSORFLOW): raise Exception("Imcompatible between framework and data_loader setup.") @@ -351,6 +350,8 @@ def LoadConfig(args, config): elif 'reader' in config: reader = config['reader'] if reader is not None: + if 'dont_use_mmap' in reader: + args.dont_use_mmap = reader['dont_use_mmap'] if 'reader_classname' in reader: args.reader_classname = reader['reader_classname'] if 'multiprocessing_context' in reader: diff --git a/docs/source/config.rst b/docs/source/config.rst index 920b6590..9d245ca2 100644 --- a/docs/source/config.rst +++ b/docs/source/config.rst @@ -189,7 +189,7 @@ reader - Description * - data_loader - tensorflow - - select the data loader to use [tensorflow|pytorch|dali]. + - select the data loader to use [tensorflow|pytorch|dali|native_dali]. * - batch_size - 1 - batch size for training @@ -227,6 +227,8 @@ reader For pytorch, ``prefetch_size`` is set to be 0, it will be changed to 2. In other words, the default value for ``prefetch_size`` in pytorch is 2. 
+ For the Dali data loader, we support two options, ``dali`` and ``native_dali``. ``dali`` uses our internal readers, such as ``jpeg_reader``, ``hdf5_reader``, etc., and ``dali.fn.external_source``; whereas ``native_dali`` directly uses Dali readers, such as ``fn.readers.numpy``, ``fn.readers.tfrecord``, and ``fn.readers.file``. + train ------------------