Adding python changes for numpy reader #9

Open · wants to merge 1 commit into base: numpy_reader_cpp_changes
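For context, this PR wires the rocAL numpy reader through the Python bindings: fn.readers.numpy() creates the reader, and a new ROCALNumpyIterator (generic and PyTorch variants) consumes its output. A minimal usage sketch, mirroring the test added in tests/python_api/numpy_reader.py (the data path is a placeholder):

    from amd.rocal.pipeline import Pipeline
    from amd.rocal.plugin.generic import ROCALNumpyIterator
    import amd.rocal.fn as fn

    # Build a CPU pipeline that reads .npy files from a folder.
    pipe = Pipeline(batch_size=4, num_threads=1, device_id=0, rocal_cpu=True)
    with pipe:
        npy_output = fn.readers.numpy(file_root="/path/to/numpy_folder", shard_id=0, num_shards=1)
        pipe.set_outputs(npy_output)
    pipe.build()

    # Each iteration yields a list with one numpy array per pipeline output.
    for [batch] in ROCALNumpyIterator(pipe):
        print(batch.shape)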
2 changes: 1 addition & 1 deletion rocAL_pybind/amd/rocal/pipeline.py
@@ -83,7 +83,7 @@ def __init__(batch_size=-1, num_threads=0, device_id=0, seed=1,
self._check_ops_decoder = [
"ImageDecoder", "ImageDecoderSlice", "ImageDecoderRandomCrop", "ImageDecoderRaw"]
self._check_ops_reader = ["labelReader", "TFRecordReaderClassification", "TFRecordReaderDetection",
"COCOReader", "Caffe2Reader", "Caffe2ReaderDetection", "CaffeReader", "CaffeReaderDetection"]
"COCOReader", "Caffe2Reader", "Caffe2ReaderDetection", "CaffeReader", "CaffeReaderDetection", "NumpyReader"]
self._batch_size = batch_size
self._num_threads = num_threads
self._device_id = device_id
51 changes: 51 additions & 0 deletions rocAL_pybind/amd/rocal/plugin/generic.py
@@ -218,6 +218,57 @@ def __init__(self,
multiplier=pipe._multiplier, offset=pipe._offset, display=display, device=device, device_id=device_id)


class ROCALNumpyIterator(object):
    def __init__(self, pipeline, device="cpu", device_id=0):
        self.loader = pipeline
        self.device = device
        self.device_id = device_id
        self.output_list = None
        self.batch_size = self.loader._batch_size
        self.len = b.getRemainingImages(self.loader._handle)
        self.last_batch_policy = self.loader._last_batch_policy
        self.last_batch_size = None

    def next(self):
        return self.__next__()

    def __next__(self):
        if self.loader.rocal_run() != 0:
            raise StopIteration
        self.output_tensor_list = self.loader.get_output_tensors()

        if self.output_list is None:  # On the first call, allocate the output buffers
            self.output_list = []
            for i in range(len(self.output_tensor_list)):
                dimensions = self.output_tensor_list[i].dimensions()
                dtype = self.output_tensor_list[i].dtype()
                output = np.empty(dimensions, dtype=dtype)
                # Outputs are always returned as numpy arrays - no ROCm CuPy support is available
                self.output_tensor_list[i].copy_data(output)
                self.output_list.append(output)
        else:
            for i in range(len(self.output_tensor_list)):
                self.output_tensor_list[i].copy_data(self.output_list[i])
        # If the last batch policy is PARTIAL, return only the valid samples of the last batch
        if (self.last_batch_policy == types.LAST_BATCH_PARTIAL) and b.getRemainingImages(self.loader._handle) < self.batch_size:
            if self.last_batch_size is None:
                self.last_batch_size = self.batch_size - b.getLastBatchPaddedSize(self.loader._handle)
            return [inner_list[0:self.last_batch_size, :] for inner_list in self.output_list]
        return self.output_list

    def reset(self):
        b.rocalResetLoaders(self.loader._handle)

    def __iter__(self):
        return self

    def __len__(self):
        return self.len // self.batch_size

    def __del__(self):
        b.rocalRelease(self.loader._handle)

def draw_patches(img, idx):
# image is expected as an array
"""!Writes images to disk as a PNG file.
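A behavioral note on the iterator above, inferred from its __next__ implementation rather than from documentation: when the reader is created with last_batch_policy=types.LAST_BATCH_PARTIAL and the dataset size is not a multiple of the batch size, the final batch is trimmed to the valid samples. A sketch, continuing the example above with a hypothetical dataset of 10 files and batch_size=4:

    import amd.rocal.types as types

    pipe = Pipeline(batch_size=4, num_threads=1, device_id=0, rocal_cpu=True)
    with pipe:
        out = fn.readers.numpy(file_root="/path/to/numpy_folder",
                               last_batch_policy=types.LAST_BATCH_PARTIAL)
        pipe.set_outputs(out)
    pipe.build()

    for [batch] in ROCALNumpyIterator(pipe):
        print(batch.shape[0])  # 4, 4, then 2 for the trimmed last batch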
62 changes: 62 additions & 0 deletions rocAL_pybind/amd/rocal/plugin/pytorch.py
@@ -401,6 +401,68 @@ def __len__(self):
def __del__(self):
b.rocalRelease(self.loader._handle)


class ROCALNumpyIterator(object):
    def __init__(self, pipeline, device="cpu", device_id=0):
        self.loader = pipeline
        self.device = device
        self.device_id = device_id
        self.output_memory_type = self.loader._output_memory_type
        self.output_list = None
        self.batch_size = self.loader._batch_size
        self.len = b.getRemainingImages(self.loader._handle) // self.batch_size
        self.last_batch_policy = self.loader._last_batch_policy
        self.last_batch_size = None

    def next(self):
        return self.__next__()

    def __next__(self):
        if self.loader.rocal_run() != 0:
            raise StopIteration
        self.output_tensor_list = self.loader.get_output_tensors()

        if self.output_list is None:
            # Output list used to store the pipeline outputs - supports multiple augmentation outputs
            self.output_list = []
            for i in range(len(self.output_tensor_list)):
                dimensions = self.output_tensor_list[i].dimensions()
                torch_dtype = self.output_tensor_list[i].dtype()
                if self.device == "cpu":
                    output = torch.empty(dimensions, dtype=getattr(torch, torch_dtype))
                else:
                    torch_gpu_device = torch.device('cuda', self.device_id)
                    output = torch.empty(dimensions, dtype=getattr(torch, torch_dtype), device=torch_gpu_device)

                self.output_tensor_list[i].copy_data(ctypes.c_void_p(output.data_ptr()), self.output_memory_type)
                self.output_list.append(output)
        else:
            for i in range(len(self.output_tensor_list)):
                self.output_tensor_list[i].copy_data(ctypes.c_void_p(self.output_list[i].data_ptr()), self.output_memory_type)
        # If the last batch policy is PARTIAL, return only the valid samples of the last batch
        if (self.last_batch_policy == types.LAST_BATCH_PARTIAL) and b.getRemainingImages(self.loader._handle) < self.batch_size:
            if self.last_batch_size is None:
                self.last_batch_size = self.batch_size - b.getLastBatchPaddedSize(self.loader._handle)
            return [inner_list[0:self.last_batch_size, :] for inner_list in self.output_list]
        return self.output_list

    def reset(self):
        b.rocalResetLoaders(self.loader._handle)

    def __iter__(self):
        return self

    def __len__(self):
        return self.len

    def __del__(self):
        b.rocalRelease(self.loader._handle)

def draw_patches(img, idx, bboxes):
"""!Writes images to disk as a PNG file.

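The PyTorch variant above fills torch tensors instead of numpy arrays; any device value other than "cpu" makes it allocate the buffers on torch.device('cuda', device_id). A consumption sketch under the same placeholder data path:

    import torch
    from amd.rocal.pipeline import Pipeline
    from amd.rocal.plugin.pytorch import ROCALNumpyIterator
    import amd.rocal.fn as fn

    pipe = Pipeline(batch_size=4, num_threads=1, device_id=0, rocal_cpu=True)
    with pipe:
        out = fn.readers.numpy(file_root="/path/to/numpy_folder")
        pipe.set_outputs(out)
    pipe.build()

    # device="cpu" keeps the output buffers on the host.
    for [batch] in ROCALNumpyIterator(pipe, device="cpu"):
        assert isinstance(batch, torch.Tensor)
        print(batch.shape, batch.dtype)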
15 changes: 15 additions & 0 deletions rocAL_pybind/amd/rocal/readers.py
@@ -382,3 +382,18 @@ def webdataset(path, index_paths="", ext = None, missing_components_behavior = t
webdata_metadata = b.webDatasetReader(
Pipeline._current_pipeline._handle, *(kwargs_pybind.values()))
return webdata_metadata


def numpy(*inputs, file_root='', files=[], num_shards=1, output_layout=types.NONE,
          random_shuffle=False, shard_id=0, stick_to_shard=True, shard_size=-1,
          last_batch_policy=types.LAST_BATCH_FILL, pad_last_batch=True, seed=0):
    """!Creates a NumpyReader that loads data from numpy (.npy) files under file_root."""
    Pipeline._current_pipeline._reader = "NumpyReader"
    Pipeline._current_pipeline._last_batch_policy = last_batch_policy
    sharding_info = b.RocalShardingInfo(last_batch_policy, pad_last_batch, stick_to_shard, shard_size)
    # Output
    kwargs_pybind = {"source_path": file_root, "output_layout": output_layout, "files": files, "is_output": False, "shuffle": random_shuffle,
                     "loop": False, "shard_id": shard_id, "shard_count": num_shards, "seed": seed, "sharding_info": sharding_info}
    numpy_reader_output = b.numpyReaderSourceShard(
        Pipeline._current_pipeline._handle, *(kwargs_pybind.values()))
    return numpy_reader_output
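Usage note for the reader above (a sketch; local_rank and world_size are placeholders that a multi-process launcher would normally supply): each rank passes its own shard_id together with the total num_shards, and can shuffle within its shard:

    # Continuing the pipeline sketches above.
    local_rank, world_size = 0, 2  # placeholder values
    with pipe:
        out = fn.readers.numpy(file_root="/path/to/numpy_folder",
                               shard_id=local_rank, num_shards=world_size,
                               random_shuffle=True, seed=1234)
        pipe.set_outputs(out)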
4 changes: 4 additions & 0 deletions rocAL_pybind/amd/rocal/types.py
@@ -38,6 +38,7 @@
from rocal_pybind.types import FLOAT
from rocal_pybind.types import FLOAT16
from rocal_pybind.types import UINT8
from rocal_pybind.types import INT16

# RocalOutputMemType
from rocal_pybind.types import HOST_MEMORY
@@ -58,6 +59,7 @@
from rocal_pybind.types import RGB_PLANAR

# RocalTensorLayout
from rocal_pybind.types import NONE
from rocal_pybind.types import NHWC
from rocal_pybind.types import NCHW
from rocal_pybind.types import NFHWC
@@ -137,6 +139,7 @@
FLOAT: ("FLOAT", FLOAT),
FLOAT16: ("FLOAT16", FLOAT16),
UINT8: ("UINT8", UINT8),
INT16: ("INT16", INT16),
HOST_MEMORY: ("HOST_MEMORY", HOST_MEMORY),
DEVICE_MEMORY: ("DEVICE_MEMORY", DEVICE_MEMORY),
PINNED_MEMORY: ("PINNED_MEMORY", PINNED_MEMORY),
@@ -147,6 +150,7 @@
MAX_SIZE_ORIG: ("MAX_SIZE_ORIG", MAX_SIZE_ORIG),
USER_GIVEN_SIZE_ORIG: ("USER_GIVEN_SIZE_ORIG", USER_GIVEN_SIZE_ORIG),

NONE: ("NONE", NONE),
NHWC: ("NHWC", NHWC),
NCHW: ("NCHW", NCHW),
NFHWC: ("NFHWC", NFHWC),
14 changes: 14 additions & 0 deletions rocAL_pybind/rocal_pybind.cpp
@@ -158,6 +158,7 @@ std::unordered_map<int, std::string> rocalToPybindOutputDtype = {
{3, "int8"},
{4, "uint32"},
{5, "int32"},
{6, "int16"},
};

#if ENABLE_DLPACK
@@ -202,6 +203,8 @@ std::unordered_map<int, std::string> rocalToPybindOutputDtype = {
case 16:
if (dtype.code == kDLFloat) {
return RocalTensorOutputType::ROCAL_FP16;
} else if (dtype.code == kDLInt) {
return RocalTensorOutputType::ROCAL_INT16;
} else {
throw std::runtime_error("Data type code for 16 bit type is not supported.");
}
@@ -264,6 +267,11 @@ std::unordered_map<int, std::string> rocalToPybindOutputDtype = {
case RocalTensorOutputType::ROCAL_FP32:
out.bits = 32;
out.code = kDLFloat;
break;
case RocalTensorOutputType::ROCAL_INT16:
out.bits = 16;
out.code = kDLInt;
break;
case RocalTensorOutputType::ROCAL_FP16:
out.bits = 16;
out.code = kDLFloat;
@@ -660,6 +668,7 @@ py::class_<rocalListOfTensorList>(m, "rocalListOfTensorList")
.value("FLOAT", ROCAL_FP32)
.value("FLOAT16", ROCAL_FP16)
.value("UINT8", ROCAL_UINT8)
.value("INT16", ROCAL_INT16)
.export_values();
py::enum_<RocalOutputMemType>(types_m, "RocalOutputMemType", "Output memory types")
.value("HOST_MEMORY", ROCAL_MEMCPY_HOST)
@@ -695,6 +704,7 @@ py::class_<rocalListOfTensorList>(m, "rocalListOfTensorList")
.value("RGB_PLANAR", ROCAL_COLOR_RGB_PLANAR)
.export_values();
py::enum_<RocalTensorLayout>(types_m, "RocalTensorLayout", "Tensor layout type")
.value("NONE", ROCAL_NONE)
.value("NHWC", ROCAL_NHWC)
.value("NCHW", ROCAL_NCHW)
.value("NFHWC", ROCAL_NFHWC)
@@ -1024,6 +1034,10 @@ py::class_<rocalListOfTensorList>(m, "rocalListOfTensorList")
py::return_value_policy::reference);
m.def("audioDecoder", &rocalAudioFileSource, "Reads file from the source given and decodes it",
py::return_value_policy::reference);
m.def("numpyReaderSource", &rocalNumpyFileSource, "Reads data from numpy files",
py::return_value_policy::reference);
m.def("numpyReaderSourceShard", &rocalNumpyFileSourceSingleShard, "Reads data from numpy files according to the shard id and number of shards",
py::return_value_policy::reference);
m.def("rocalResetLoaders", &rocalResetLoaders);
m.def("videoMetaDataReader", &rocalCreateVideoLabelReader, py::return_value_policy::reference);
// rocal_api_augmentation.h
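A small illustration of what the new INT16 plumbing enables end to end (a sketch; the sample file is hypothetical, and the dtype string "int16" comes from the rocalToPybindOutputDtype map above):

    import numpy as np

    # Save a hypothetical int16 array; the reader picks it up from file_root.
    np.save("/path/to/numpy_folder/sample_int16.npy",
            np.random.randint(-1000, 1000, size=(64, 64), dtype=np.int16))

    # With the entries added above, the output tensor reports dtype "int16":
    # the generic iterator allocates np.empty(dims, dtype="int16") and the
    # PyTorch iterator allocates torch.empty(dims, dtype=torch.int16).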
68 changes: 68 additions & 0 deletions tests/python_api/numpy_reader.py
@@ -0,0 +1,68 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import random

from amd.rocal.pipeline import Pipeline
from amd.rocal.plugin.generic import ROCALNumpyIterator
import amd.rocal.fn as fn
import sys
import os


def draw_patches(image, idx, layout="nhwc", dtype="uint8"):
    # image is expected as a numpy array
    import cv2
    if layout == "nchw":
        image = image.transpose([1, 2, 0])
    if dtype == "fp16":
        image = image.astype("uint8")
    image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    cv2.imwrite("output_folder/numpy_reader/" + str(idx) + ".png", image)


def main():
    if len(sys.argv) < 4:  # requires numpy_folder, cpu/gpu and batch_size
        print('Please pass numpy_folder cpu/gpu batch_size')
        exit(0)
    data_path = sys.argv[1]
    try:
        path = "output_folder/numpy_reader/"
        isExist = os.path.exists(path)
        if not isExist:
            os.makedirs(path)
    except OSError as error:
        print(error)
    if (sys.argv[2] == "cpu"):
        rocal_cpu = True
    else:
        rocal_cpu = False
    batch_size = int(sys.argv[3])
    num_threads = 1
    device_id = 0
    local_rank = 0
    world_size = 1
    random_seed = random.SystemRandom().randint(0, 2**32 - 1)

    pipeline = Pipeline(batch_size=batch_size, num_threads=num_threads, device_id=device_id, seed=random_seed, rocal_cpu=rocal_cpu)

    with pipeline:
        numpy_reader_output = fn.readers.numpy(file_root=data_path, shard_id=local_rank, num_shards=world_size)
        pipeline.set_outputs(numpy_reader_output)

    pipeline.build()

    cnt = 0
    numpyIteratorPipeline = ROCALNumpyIterator(pipeline)
    print(len(numpyIteratorPipeline))
    for epoch in range(1):
        print("+++++++++++++++++++++++++++++EPOCH+++++++++++++++++++++++++++++++++++++", epoch)
        for i, [batch] in enumerate(numpyIteratorPipeline):
            print(batch.shape)
            for img in batch:
                draw_patches(img, cnt)
                cnt += 1
        numpyIteratorPipeline.reset()
        print("*********************************************************************")


if __name__ == '__main__':
    main()