Adding Native Dali Data Loader support for TFRecord, Images, and NPZ files (#118)

* fixed readthedocs build issue
* partially merged the following PR: #81
* added back npz_reader
* fixed bugs
* fixed image reader issue
* fixed Profile, PerfTrace
* removed unnecessary logs
* fixed dali_image_reader
* added support for npy format
* changed enumerations
* added removed dali base reader
* fixed a bug
* added native-dali-loader tests in github action
* corrected github action formats
* fixed read return
* removed abstractmethod
* fixed bugs
* added dont_use_mmap
* fixed indent
* fixed csvreader
* native_dali test with npy format instead of npz
* fixed enum issue
* modified action so that dlio will always be installed
* [skip ci] added documentation for dali
* removed read; defined it as a pipeline
* added exceptions for unimplemented methods
* added preprocessing
* conditional cache for DLIO installation
* fixed bugs
* fixing again
* tests again
Showing 20 changed files with 515 additions and 47 deletions.
@@ -0,0 +1,53 @@
""" | ||
Copyright (c) 2022, UChicago Argonne, LLC | ||
All Rights Reserved | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
""" | ||
|
||
from dlio_benchmark.common.enumerations import Compression | ||
from dlio_benchmark.data_generator.data_generator import DataGenerator | ||
|
||
import logging | ||
import numpy as np | ||
|
||
from dlio_benchmark.utils.utility import progress, utcnow | ||
from dlio_profiler.logger import fn_interceptor as Profile | ||
from shutil import copyfile | ||
from dlio_benchmark.common.constants import MODULE_DATA_GENERATOR | ||
|
||
dlp = Profile(MODULE_DATA_GENERATOR) | ||
|
||
""" | ||
Generator for creating data in NPZ format. | ||
""" | ||
class NPYGenerator(DataGenerator): | ||
def __init__(self): | ||
super().__init__() | ||
|
||
@dlp.log | ||
def generate(self): | ||
""" | ||
Generator for creating data in NPY format of 3d dataset. | ||
""" | ||
super().generate() | ||
np.random.seed(10) | ||
record_labels = [0] * self.num_samples | ||
for i in dlp.iter(range(self.my_rank, int(self.total_files_to_generate), self.comm_size)): | ||
dim1, dim2 = self.get_dimension() | ||
records = np.random.randint(255, size=(dim1, dim2, self.num_samples), dtype=np.uint8) | ||
out_path_spec = self.storage.get_uri(self._file_list[i]) | ||
progress(i+1, self.total_files_to_generate, "Generating NPY Data") | ||
prev_out_spec = out_path_spec | ||
np.save(out_path_spec, records) | ||
np.random.seed() |
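
For orientation, here is a minimal sketch of reading back one sample from a file written by this generator. The file name is hypothetical (real names come from self._file_list), and note that np.save appends a .npy extension when the path lacks one.

import numpy as np

# Hypothetical file name; the generator derives real names from self._file_list.
arr = np.load("img_0.npy", mmap_mode="r")  # memory-map instead of loading the whole file
dim1, dim2, num_samples = arr.shape        # layout written above: (dim1, dim2, num_samples)
sample = arr[:, :, 0]                      # slice out one 2-D sample
print(sample.shape, sample.dtype)          # e.g. (4096, 4096) uint8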
@@ -0,0 +1,60 @@
from time import time
import logging
import math
import numpy as np
from nvidia.dali.pipeline import Pipeline
import nvidia.dali.fn as fn
import nvidia.dali.types as types
import nvidia.dali as dali
from nvidia.dali.plugin.pytorch import DALIGenericIterator

from dlio_benchmark.common.constants import MODULE_DATA_LOADER
from dlio_benchmark.common.enumerations import Shuffle, DataLoaderType, DatasetType
from dlio_benchmark.data_loader.base_data_loader import BaseDataLoader
from dlio_benchmark.reader.reader_factory import ReaderFactory
from dlio_benchmark.utils.utility import utcnow, get_rank, timeit
from dlio_profiler.logger import dlio_logger as PerfTrace, fn_interceptor as Profile

dlp = Profile(MODULE_DATA_LOADER)


class NativeDaliDataLoader(BaseDataLoader):
    @dlp.log_init
    def __init__(self, format_type, dataset_type, epoch):
        super().__init__(format_type, dataset_type, epoch, DataLoaderType.NATIVE_DALI)
        self.pipelines = []

    @dlp.log
    def read(self):
        batch_size = self._args.batch_size if self.dataset_type is DatasetType.TRAIN else self._args.batch_size_eval
        self.pipelines = []
        num_threads = 1
        if self._args.read_threads > 0:
            num_threads = self._args.read_threads
        # device_id=None executes the pipeline on the CPU and the reader does the batching
        pipeline = Pipeline(batch_size=batch_size, num_threads=num_threads, device_id=None,
                            py_num_workers=num_threads, exec_async=False, exec_pipelined=False)
        with pipeline:
            images = ReaderFactory.get_reader(type=self.format_type,
                                              dataset_type=self.dataset_type,
                                              thread_index=-1,
                                              epoch_number=self.epoch_number).pipeline()
            pipeline.set_outputs(images)
        self.pipelines.append(pipeline)

    @dlp.log
    def next(self):
        super().next()
        num_samples = self._args.total_samples_train if self.dataset_type is DatasetType.TRAIN else self._args.total_samples_eval
        batch_size = self._args.batch_size if self.dataset_type is DatasetType.TRAIN else self._args.batch_size_eval
        # Build the iterator once; re-creating it inside the step loop would
        # restart the pipelines and yield duplicate batches.
        _dataset = DALIGenericIterator(self.pipelines, ['data'])
        for step in range(num_samples // batch_size):
            try:
                batch = next(_dataset)
            except StopIteration:
                break
            logging.debug(f"{utcnow()} Creating {len(batch)} batches by {self._args.my_rank} rank ")
            yield batch

    @dlp.log
    def finalize(self):
        pass
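
As a standalone illustration of the read()/next() split above, the following is a minimal sketch of a CPU-only DALI pipeline driven through DALIGenericIterator. The image directory and sizes are hypothetical; the loader above builds its graph from a format-specific reader's pipeline() method instead.

import nvidia.dali.fn as fn
from nvidia.dali.pipeline import Pipeline
from nvidia.dali.plugin.pytorch import DALIGenericIterator

# device_id=None keeps the pipeline on the CPU, as in read() above.
pipe = Pipeline(batch_size=4, num_threads=2, device_id=None,
                exec_async=False, exec_pipelined=False)
with pipe:
    jpegs, labels = fn.readers.file(file_root="./images")  # hypothetical image directory
    images = fn.decoders.image(jpegs, device="cpu")
    images = fn.resize(images, size=[224, 224])  # uniform shape so samples stack into a batch
    pipe.set_outputs(images)

# The iterator runs the pipeline and returns batches as framework tensors.
for batch in DALIGenericIterator([pipe], ['data']):
    print(batch[0]['data'].shape)  # (4, 224, 224, 3)
    break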
@@ -0,0 +1,96 @@
""" | ||
Copyright (c) 2022, UChicago Argonne, LLC | ||
All Rights Reserved | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
""" | ||
import math | ||
import logging | ||
from time import time, sleep | ||
import numpy as np | ||
|
||
import nvidia.dali.fn as fn | ||
from dlio_benchmark.common.constants import MODULE_DATA_READER | ||
from dlio_benchmark.dlio_benchmark.reader.reader_handler import FormatReader | ||
from dlio_benchmark.utils.utility import utcnow | ||
from dlio_benchmark.common.enumerations import DatasetType, Shuffle | ||
import nvidia.dali.tfrecord as tfrec | ||
from dlio_profiler.logger import dlio_logger as PerfTrace, fn_interceptor as Profile | ||
|
||
dlp = Profile(MODULE_DATA_READER) | ||
|
||
|
||
class DaliImageReader(FormatReader): | ||
@dlp.log_init | ||
def __init__(self, dataset_type, thread_index, epoch): | ||
super().__init__(dataset_type, thread_index) | ||
|
||
@dlp.log | ||
def open(self, filename): | ||
super().open(filename) | ||
|
||
def close(self): | ||
super().close() | ||
|
||
def get_sample(self, filename, sample_index): | ||
super().get_sample(filename, sample_index) | ||
raise Exception("get sample method is not implemented in dali readers") | ||
|
||
def next(self): | ||
super().next() | ||
raise Exception("next method is not implemented in dali readers") | ||
|
||
def read_index(self): | ||
super().read_index() | ||
raise Exception("read_index method is not implemented in dali readers") | ||
|
||
@dlp.log | ||
def pipeline(self): | ||
logging.debug( | ||
f"{utcnow()} Reading {len(self._file_list)} files rank {self._args.my_rank}") | ||
random_shuffle = False | ||
seed = -1 | ||
seed_change_epoch = False | ||
if self._args.sample_shuffle is not Shuffle.OFF: | ||
if self._args.sample_shuffle is not Shuffle.SEED: | ||
seed = self._args.seed | ||
random_shuffle = True | ||
seed_change_epoch = True | ||
initial_fill = 1024 | ||
if self._args.shuffle_size > 0: | ||
initial_fill = self._args.shuffle_size | ||
prefetch_size = 1 | ||
if self._args.prefetch_size > 0: | ||
prefetch_size = self._args.prefetch_size | ||
|
||
stick_to_shard = True | ||
if seed_change_epoch: | ||
stick_to_shard = False | ||
images, labels = fn.readers.file(files=self._file_list, num_shards=self._args.comm_size, | ||
prefetch_queue_depth=prefetch_size, | ||
initial_fill=initial_fill, random_shuffle=random_shuffle, | ||
shuffle_after_epoch=seed_change_epoch, | ||
stick_to_shard=stick_to_shard, pad_last_batch=True, | ||
dont_use_mmap=self._args.dont_use_mmap) | ||
images = fn.decoders.image(images, device='cpu') | ||
fn.python_function(dataset, function=self.preprocess, num_outputs=0) | ||
dataset = self._resize(images) | ||
return dataset | ||
|
||
@dlp.log | ||
def _resize(self, dataset): | ||
return fn.resize(dataset, size=[self._args.max_dimension, self._args.max_dimension]) | ||
|
||
@dlp.log | ||
def finalize(self): | ||
pass |
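
The sharding flags passed to fn.readers.file interact: stick_to_shard pins a rank to one shard across epochs, while shuffle_after_epoch reshuffles the whole dataset each epoch, which is why pipeline() above disables the former when the latter is on. A short sketch of the two configurations follows; the file list and shard counts are hypothetical, and note that pipeline() above relies on the default shard_id of 0.

import nvidia.dali.fn as fn
from nvidia.dali.pipeline import Pipeline

files = [f"img_{i}.jpg" for i in range(8)]  # hypothetical file list

pipe = Pipeline(batch_size=2, num_threads=1, device_id=None)
with pipe:
    # Configuration A: each rank keeps its shard; shuffling happens only
    # within the reader's buffer (initial_fill, cf. shuffle_size above).
    data, labels = fn.readers.file(files=files, num_shards=2, shard_id=0,
                                   random_shuffle=True, initial_fill=4,
                                   stick_to_shard=True, pad_last_batch=True)
    # Configuration B (alternative): a new global shuffle every epoch; the
    # reader must then give up its fixed shard, as in pipeline() above.
    #   fn.readers.file(files=files, num_shards=2, shard_id=0,
    #                   shuffle_after_epoch=True, stick_to_shard=False)
    pipe.set_outputs(data)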