Partially merged the following PR: #81
Showing 10 changed files with 297 additions and 134 deletions.
New file (+60 lines): NativeDaliDataLoader, a DALI-native data loader.
@@ -0,0 +1,60 @@
from time import time
import logging
import math
import numpy as np
from nvidia.dali.pipeline import Pipeline
import nvidia.dali.fn as fn
import nvidia.dali.types as types
import nvidia.dali as dali
from nvidia.dali.plugin.pytorch import DALIGenericIterator

from dlio_benchmark.common.constants import MODULE_DATA_LOADER
from dlio_benchmark.common.enumerations import Shuffle, DataLoaderType, DatasetType
from dlio_benchmark.data_loader.base_data_loader import BaseDataLoader
from dlio_benchmark.reader.reader_factory import ReaderFactory
from dlio_benchmark.utils.utility import utcnow, get_rank, timeit, Profile

dlp = Profile(MODULE_DATA_LOADER)


class NativeDaliDataLoader(BaseDataLoader):
    @dlp.log_init
    def __init__(self, format_type, dataset_type, epoch):
        super().__init__(format_type, dataset_type, epoch, DataLoaderType.NATIVE_DALI)
        self.pipelines = []

    @dlp.log
    def read(self):
        num_samples = self._args.total_samples_train if self.dataset_type is DatasetType.TRAIN else self._args.total_samples_eval
        batch_size = self._args.batch_size if self.dataset_type is DatasetType.TRAIN else self._args.batch_size_eval
        parallel = True if self._args.read_threads > 0 else False
        self.pipelines = []
        num_threads = 1
        if self._args.read_threads > 0:
            num_threads = self._args.read_threads
        # device_id=None keeps the pipeline on the CPU; the reader operator does the batching.
        pipeline = Pipeline(batch_size=batch_size, num_threads=num_threads, device_id=None, py_num_workers=num_threads,
                            exec_async=False, exec_pipelined=False)
        with pipeline:
            images = ReaderFactory.get_reader(type=self.format_type,
                                              dataset_type=self.dataset_type,
                                              thread_index=-1,
                                              epoch_number=self.epoch_number).read()
            pipeline.set_outputs(images)
        self.pipelines.append(pipeline)
        logging.info(f"{utcnow()} Rank {self._args.my_rank} created a DALI pipeline with {num_threads} threads")

    @dlp.log
    def next(self):
        super().next()
        num_samples = self._args.total_samples_train if self.dataset_type is DatasetType.TRAIN else self._args.total_samples_eval
        batch_size = self._args.batch_size if self.dataset_type is DatasetType.TRAIN else self._args.batch_size_eval
        # Build the iterator once; constructing it inside the step loop would
        # replay the whole dataset on every step.
        _dataset = DALIGenericIterator(self.pipelines, ['data'])
        step = 0
        for batch in _dataset:
            if step >= num_samples // batch_size:
                break
            logging.debug(f"{utcnow()} Rank {self._args.my_rank} received batch {step}")
            yield batch
            step += 1

    @dlp.log
    def finalize(self):
        pass
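For context, here is a minimal, self-contained sketch of the pattern read() and next() rely on: build a CPU-only DALI pipeline whose reader emits whole batches, then drain it through DALIGenericIterator. This is an illustration, not the DLIO API; the data directory, file pattern, reader name, and batch size are assumptions.

```python
from nvidia.dali.pipeline import Pipeline
import nvidia.dali.fn as fn
from nvidia.dali.plugin.pytorch import DALIGenericIterator

# CPU-only pipeline: device_id=None, synchronous execution, as in read() above.
pipe = Pipeline(batch_size=4, num_threads=2, device_id=None,
                exec_async=False, exec_pipelined=False)
with pipe:
    # Placeholder dataset: all .npy files under "data/".
    samples = fn.readers.numpy(device='cpu', file_root='data', file_filter='*.npy',
                               pad_last_batch=True, name='Reader')
    pipe.set_outputs(samples)
pipe.build()

# reader_name lets the iterator learn the epoch length from the named reader.
iterator = DALIGenericIterator([pipe], ['data'], reader_name='Reader')
for batch in iterator:
    # batch is a list with one dict per pipeline, keyed by the output map.
    print(batch[0]['data'].shape)
```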
New file (+69 lines): DaliImageReader, an image-file reader built on DALI.
@@ -0,0 +1,69 @@
""" | ||
Copyright (c) 2022, UChicago Argonne, LLC | ||
All Rights Reserved | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
""" | ||
import math | ||
import logging | ||
from time import time | ||
|
||
import nvidia.dali.fn as fn | ||
from dlio_benchmark.common.constants import MODULE_DATA_READER | ||
from dlio_benchmark.reader.dali_base_reader import DaliBaseReader | ||
from dlio_benchmark.reader.tf_base_reader import TFBaseReader | ||
from dlio_benchmark.utils.utility import utcnow, PerfTrace, Profile | ||
from dlio_benchmark.common.enumerations import DatasetType, Shuffle | ||
import nvidia.dali.tfrecord as tfrec | ||
|
||
dlp = Profile(MODULE_DATA_READER) | ||
|
||
|
||
class DaliImageReader(DaliBaseReader):
    @dlp.log_init
    def __init__(self, dataset_type):
        super().__init__(dataset_type)

    @dlp.log
    def _load(self):
        logging.debug(
            f"{utcnow()} Reading {len(self.file_list)} files rank {self._args.my_rank}")
        random_shuffle = False
        seed = -1
        seed_change_epoch = False
        if self._args.sample_shuffle is not Shuffle.OFF:
            if self._args.sample_shuffle is not Shuffle.SEED:
                seed = self._args.seed
            random_shuffle = True
            seed_change_epoch = True
        initial_fill = 1024
        if self._args.shuffle_size > 0:
            initial_fill = self._args.shuffle_size
        prefetch_size = 1
        if self._args.prefetch_size > 0:
            prefetch_size = self._args.prefetch_size

        # stick_to_shard and shuffle_after_epoch are mutually exclusive in DALI.
        stick_to_shard = True
        if seed_change_epoch:
            stick_to_shard = False
        # The reader returns (encoded bytes, label); decode the images on the CPU.
        images, labels = fn.readers.file(files=self.file_list, num_shards=self._args.comm_size,
                                         prefetch_queue_depth=prefetch_size,
                                         initial_fill=initial_fill, random_shuffle=random_shuffle,
                                         shuffle_after_epoch=seed_change_epoch, seed=seed,
                                         stick_to_shard=stick_to_shard, pad_last_batch=True)
        dataset = fn.decoders.image(images, device='cpu')
        return dataset

    @dlp.log
    def finalize(self):
        pass
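A hedged, standalone sketch of the reader/decoder pair that _load() wires into the pipeline follows; it lives outside the DLIO class hierarchy, and the "images" directory (JPEGs in per-class subdirectories, the layout fn.readers.file expects when given file_root) and the batch size are assumptions.

```python
from nvidia.dali.pipeline import Pipeline
import nvidia.dali.fn as fn

pipe = Pipeline(batch_size=8, num_threads=2, device_id=None,
                exec_async=False, exec_pipelined=False)
with pipe:
    # fn.readers.file yields (encoded file bytes, integer label) per sample.
    encoded, labels = fn.readers.file(file_root='images', random_shuffle=True,
                                      initial_fill=1024, pad_last_batch=True)
    images = fn.decoders.image(encoded, device='cpu')  # decode JPEGs on the CPU
    pipe.set_outputs(images, labels)
pipe.build()

images_batch, labels_batch = pipe.run()  # one batch of decoded HWC uint8 images
```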
New file (+68 lines): DaliNPZReader, a NumPy-array reader built on DALI.
@@ -0,0 +1,68 @@
""" | ||
Copyright (c) 2022, UChicago Argonne, LLC | ||
All Rights Reserved | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
""" | ||
import math | ||
import logging | ||
from time import time | ||
|
||
import nvidia.dali.fn as fn | ||
from dlio_benchmark.common.constants import MODULE_DATA_READER | ||
from dlio_benchmark.reader.dali_base_reader import DaliBaseReader | ||
from dlio_benchmark.reader.tf_base_reader import TFBaseReader | ||
from dlio_benchmark.utils.utility import utcnow, PerfTrace, Profile | ||
from dlio_benchmark.common.enumerations import DatasetType, Shuffle | ||
import nvidia.dali.tfrecord as tfrec | ||
|
||
dlp = Profile(MODULE_DATA_READER) | ||
|
||
|
||
class DaliNPZReader(DaliBaseReader):
    @dlp.log_init
    def __init__(self, dataset_type):
        super().__init__(dataset_type)

    @dlp.log
    def _load(self):
        logging.debug(
            f"{utcnow()} Reading {len(self.file_list)} files rank {self._args.my_rank}")
        random_shuffle = False
        seed = -1
        seed_change_epoch = False
        if self._args.sample_shuffle is not Shuffle.OFF:
            if self._args.sample_shuffle is not Shuffle.SEED:
                seed = self._args.seed
            random_shuffle = True
            seed_change_epoch = True
        initial_fill = 1024
        if self._args.shuffle_size > 0:
            initial_fill = self._args.shuffle_size
        prefetch_size = 1
        if self._args.prefetch_size > 0:
            prefetch_size = self._args.prefetch_size

        stick_to_shard = True
        if seed_change_epoch:
            stick_to_shard = False

        dataset = fn.readers.numpy(device='cpu', files=self.file_list, num_shards=self._args.comm_size,
                                   prefetch_queue_depth=prefetch_size, initial_fill=initial_fill,
                                   random_shuffle=random_shuffle, seed=seed, shuffle_after_epoch=seed_change_epoch,
                                   stick_to_shard=stick_to_shard, pad_last_batch=True)
        return dataset

    @dlp.log
    def finalize(self):
        pass
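The sharding arguments passed above (num_shards, stick_to_shard, shuffle_after_epoch) are easiest to see in isolation. The sketch below is an assumption-laden illustration, not benchmark code: it passes shard_id explicitly so each rank reads a distinct slice of the file list, and the rank, world size, and file list are placeholders.

```python
from nvidia.dali.pipeline import Pipeline
import nvidia.dali.fn as fn

def build_shard_pipeline(file_list, rank, world_size, batch_size=4):
    """One pipeline per rank; num_shards/shard_id split file_list world_size ways."""
    pipe = Pipeline(batch_size=batch_size, num_threads=2, device_id=None,
                    exec_async=False, exec_pipelined=False)
    with pipe:
        data = fn.readers.numpy(device='cpu', files=file_list,
                                num_shards=world_size, shard_id=rank,
                                random_shuffle=True, initial_fill=1024,
                                # stick_to_shard pins the rank to one shard for
                                # every epoch; shuffle_after_epoch would instead
                                # reshuffle globally and cannot be combined with it.
                                stick_to_shard=True, pad_last_batch=True)
        pipe.set_outputs(data)
    pipe.build()
    return pipe
```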
New file (+78 lines): DaliTFRecordReader, a TFRecord reader built on DALI.
@@ -0,0 +1,78 @@
""" | ||
Copyright (c) 2022, UChicago Argonne, LLC | ||
All Rights Reserved | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
""" | ||
import os.path | ||
|
||
import math | ||
import logging | ||
from time import time | ||
|
||
import nvidia | ||
import nvidia.dali.fn as fn | ||
from dlio_benchmark.common.constants import MODULE_DATA_READER | ||
from dlio_benchmark.reader.dali_base_reader import DaliBaseReader | ||
from dlio_benchmark.reader.tf_base_reader import TFBaseReader | ||
from dlio_benchmark.utils.utility import utcnow, PerfTrace, Profile | ||
from dlio_benchmark.common.enumerations import DatasetType, Shuffle | ||
import nvidia.dali.tfrecord as tfrec | ||
|
||
dlp = Profile(MODULE_DATA_READER) | ||
|
||
|
||
class DaliTFRecordReader(DaliBaseReader):
    @dlp.log_init
    def __init__(self, dataset_type):
        super().__init__(dataset_type)

    @dlp.log
    def _load(self):
        folder = "valid"
        if self.dataset_type == DatasetType.TRAIN:
            folder = "train"
        index_folder = f"{self._args.data_folder}/index/{folder}"
        index_files = []
        for file in self.file_list:
            filename = os.path.basename(file)
            index_files.append(f"{index_folder}/{filename}.idx")
        logging.info(
            f"{utcnow()} Reading {len(self.file_list)} files rank {self._args.my_rank}")
        random_shuffle = False
        seed = -1
        if self._args.sample_shuffle is not Shuffle.OFF:
            if self._args.sample_shuffle is not Shuffle.SEED:
                seed = self._args.seed
            random_shuffle = True
        initial_fill = 1024
        if self._args.shuffle_size > 0:
            initial_fill = self._args.shuffle_size
        prefetch_size = 1
        if self._args.prefetch_size > 0:
            prefetch_size = self._args.prefetch_size
        dataset = fn.readers.tfrecord(path=self.file_list,
                                      index_path=index_files,
                                      features={
                                          'image': tfrec.FixedLenFeature((), tfrec.string, ""),
                                          'size': tfrec.FixedLenFeature([1], tfrec.int64, 0)
                                      }, num_shards=self._args.comm_size,
                                      prefetch_queue_depth=prefetch_size,
                                      initial_fill=initial_fill,
                                      random_shuffle=random_shuffle, seed=seed,
                                      stick_to_shard=True, pad_last_batch=True)
        return dataset["image"]

    @dlp.log
    def finalize(self):
        pass
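fn.readers.tfrecord needs a DALI index (.idx) file next to every TFRecord, which is why _load() builds index_files under the data_folder/index/&lt;split&gt; layout. DALI ships a tfrecord2idx script for generating them; the sketch below is an assumed helper, with placeholder paths, and expects tfrecord2idx on the PATH.

```python
import os
import subprocess

data_folder = "data"   # placeholder for the benchmark's data_folder setting
split = "train"        # or "valid"
tfrecord_dir = os.path.join(data_folder, split)
index_dir = os.path.join(data_folder, "index", split)
os.makedirs(index_dir, exist_ok=True)

for name in sorted(os.listdir(tfrecord_dir)):
    src = os.path.join(tfrecord_dir, name)
    idx = os.path.join(index_dir, f"{name}.idx")
    # tfrecord2idx records the byte offset of each record so the reader can
    # shard and shuffle without scanning whole TFRecord files.
    subprocess.run(["tfrecord2idx", src, idx], check=True)
```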