From e183f4279f384516ce9941267cfd65c1188c3fa3 Mon Sep 17 00:00:00 2001 From: SundarRajan28 Date: Thu, 20 Feb 2025 08:30:20 +0000 Subject: [PATCH] Adding python changes for numpy reader --- rocAL_pybind/amd/rocal/pipeline.py | 2 +- rocAL_pybind/amd/rocal/plugin/generic.py | 51 ++++++++++++++++++ rocAL_pybind/amd/rocal/plugin/pytorch.py | 62 +++++++++++++++++++++ rocAL_pybind/amd/rocal/readers.py | 15 ++++++ rocAL_pybind/amd/rocal/types.py | 4 ++ rocAL_pybind/rocal_pybind.cpp | 14 +++++ tests/python_api/numpy_reader.py | 68 ++++++++++++++++++++++++ 7 files changed, 215 insertions(+), 1 deletion(-) create mode 100644 tests/python_api/numpy_reader.py diff --git a/rocAL_pybind/amd/rocal/pipeline.py b/rocAL_pybind/amd/rocal/pipeline.py index 3723278e3..70419fa43 100644 --- a/rocAL_pybind/amd/rocal/pipeline.py +++ b/rocAL_pybind/amd/rocal/pipeline.py @@ -83,7 +83,7 @@ def __init__(self, batch_size=-1, num_threads=0, device_id=0, seed=1, self._check_ops_decoder = [ "ImageDecoder", "ImageDecoderSlice", "ImageDecoderRandomCrop", "ImageDecoderRaw"] self._check_ops_reader = ["labelReader", "TFRecordReaderClassification", "TFRecordReaderDetection", - "COCOReader", "Caffe2Reader", "Caffe2ReaderDetection", "CaffeReader", "CaffeReaderDetection"] + "COCOReader", "Caffe2Reader", "Caffe2ReaderDetection", "CaffeReader", "CaffeReaderDetection", "NumpyReader"] self._batch_size = batch_size self._num_threads = num_threads self._device_id = device_id diff --git a/rocAL_pybind/amd/rocal/plugin/generic.py b/rocAL_pybind/amd/rocal/plugin/generic.py index b883ae69d..3ec1f70f8 100644 --- a/rocAL_pybind/amd/rocal/plugin/generic.py +++ b/rocAL_pybind/amd/rocal/plugin/generic.py @@ -218,6 +218,57 @@ def __init__(self, multiplier=pipe._multiplier, offset=pipe._offset, display=display, device=device, device_id=device_id) +class ROCALNumpyIterator(object): + def __init__(self, pipeline, device="cpu", device_id=0): + self.loader = pipeline + self.device = device + self.device_id = device_id + self.output_list = None + self.batch_size = self.loader._batch_size + self.len = b.getRemainingImages(self.loader._handle) + self.last_batch_policy = self.loader._last_batch_policy + self.last_batch_size = None + + def next(self): + return self.__next__() + + def __next__(self): + if self.loader.rocal_run() != 0: + raise StopIteration + self.output_tensor_list = self.loader.get_output_tensors() + + if self.output_list is None: # Checking if output_list is empty and initializing the buffers + self.output_list = [] + for i in range(len(self.output_tensor_list)): + dimensions = self.output_tensor_list[i].dimensions() + dtype = self.output_tensor_list[i].dtype() + output = np.empty(dimensions, dtype=dtype) + # returned as numpy always - no ROCM CuPy support available + self.output_tensor_list[i].copy_data(output) + self.output_list.append(output) + else: + for i in range(len(self.output_tensor_list)): + self.output_tensor_list[i].copy_data(self.output_list[i]) + # Check if last batch policy is partial and only return the valid images in last batch + if (self.last_batch_policy is (types.LAST_BATCH_PARTIAL)) and b.getRemainingImages(self.loader._handle) < self.batch_size: + if (self.last_batch_size is None): + self.last_batch_size = self.batch_size - \ + b.getLastBatchPaddedSize(self.loader._handle) + return [inner_list[0:self.last_batch_size, :] for inner_list in self.output_list] + return self.output_list + + def reset(self): + b.rocalResetLoaders(self.loader._handle) + + def __iter__(self): + return self + + def __len__(self): + return self.len // self.batch_size + + def __del__(self): + b.rocalRelease(self.loader._handle) + def draw_patches(img, idx): # image is expected as an array """!Writes images to disk as a PNG file. diff --git a/rocAL_pybind/amd/rocal/plugin/pytorch.py b/rocAL_pybind/amd/rocal/plugin/pytorch.py index d78a322c4..e0d55c857 100644 --- a/rocAL_pybind/amd/rocal/plugin/pytorch.py +++ b/rocAL_pybind/amd/rocal/plugin/pytorch.py @@ -401,6 +401,68 @@ def __len__(self): def __del__(self): b.rocalRelease(self.loader._handle) + +class ROCALNumpyIterator(object): + def __init__(self, pipeline, device="cpu", device_id=0): + self.loader = pipeline + self.device = device + self.device_id = device_id + self.output_memory_type = self.loader._output_memory_type + self.output_list = None + self.batch_size = self.loader._batch_size + self.len = b.getRemainingImages(self.loader._handle) // self.batch_size + self.last_batch_policy = self.loader._last_batch_policy + self.last_batch_size = None + + def next(self): + return self.__next__() + + def __next__(self): + if self.loader.rocal_run() != 0: + raise StopIteration + self.output_tensor_list = self.loader.get_output_tensors() + + if self.output_list is None: + # Output list used to store pipeline outputs - can support multiple augmentation outputs + self.output_list = [] + for i in range(len(self.output_tensor_list)): + dimensions = self.output_tensor_list[i].dimensions() + torch_dtype = self.output_tensor_list[i].dtype() + if self.device == "cpu": + output = torch.empty( + dimensions, dtype=getattr(torch, torch_dtype)) + else: + torch_gpu_device = torch.device('cuda', self.device_id) + output = torch.empty(dimensions, dtype=getattr( + torch, torch_dtype), device=torch_gpu_device) + + self.output_tensor_list[i].copy_data(ctypes.c_void_p( + output.data_ptr()), self.output_memory_type) + self.output_list.append(output) + else: + for i in range(len(self.output_tensor_list)): + self.output_tensor_list[i].copy_data(ctypes.c_void_p( + self.output_list[i].data_ptr()), self.output_memory_type) + # Check if last batch policy is partial and only return the valid images in last batch + if (self.last_batch_policy is (types.LAST_BATCH_PARTIAL)) and b.getRemainingImages(self.loader._handle) < self.batch_size: + if (self.last_batch_size is None): + self.last_batch_size = self.batch_size - \ + b.getLastBatchPaddedSize(self.loader._handle) + return [inner_list[0:self.last_batch_size, :] for inner_list in self.output_list] + return self.output_list + + def reset(self): + b.rocalResetLoaders(self.loader._handle) + + def __iter__(self): + return self + + def __len__(self): + return self.len + + def __del__(self): + b.rocalRelease(self.loader._handle) + def draw_patches(img, idx, bboxes): """!Writes images to disk as a PNG file. diff --git a/rocAL_pybind/amd/rocal/readers.py b/rocAL_pybind/amd/rocal/readers.py index f91158487..531987ab6 100644 --- a/rocAL_pybind/amd/rocal/readers.py +++ b/rocAL_pybind/amd/rocal/readers.py @@ -382,3 +382,18 @@ def webdataset(path, index_paths="", ext = None, missing_components_behavior = t webdata_metadata = b.webDatasetReader( Pipeline._current_pipeline._handle, *(kwargs_pybind.values())) return webdata_metadata + + +def numpy(*inputs, file_root='', files=[], num_shards=1, output_layout=types.NONE, + random_shuffle=False, shard_id=0, stick_to_shard=True, shard_size=-1, + last_batch_policy=types.LAST_BATCH_FILL, pad_last_batch=True, seed=0): + + Pipeline._current_pipeline._reader = "NumpyReader" + Pipeline._current_pipeline._last_batch_policy = last_batch_policy + sharding_info = b.RocalShardingInfo(last_batch_policy, pad_last_batch, stick_to_shard, shard_size) + # Output + kwargs_pybind = {"source_path": file_root, "output_layout": output_layout, "files": files, "is_output": False, "shuffle": random_shuffle, + "loop": False, "shard_id": shard_id, "shard_count": num_shards, "seed": seed, "sharding_info": sharding_info} + numpy_reader_output = b.numpyReaderSourceShard( + Pipeline._current_pipeline._handle, *(kwargs_pybind.values())) + return (numpy_reader_output) diff --git a/rocAL_pybind/amd/rocal/types.py b/rocAL_pybind/amd/rocal/types.py index 060fc2fe0..ff0bdd7ca 100644 --- a/rocAL_pybind/amd/rocal/types.py +++ b/rocAL_pybind/amd/rocal/types.py @@ -38,6 +38,7 @@ from rocal_pybind.types import FLOAT from rocal_pybind.types import FLOAT16 from rocal_pybind.types import UINT8 +from rocal_pybind.types import INT16 # RocalOutputMemType from rocal_pybind.types import HOST_MEMORY @@ -58,6 +59,7 @@ from rocal_pybind.types import RGB_PLANAR # RocalTensorLayout +from rocal_pybind.types import NONE from rocal_pybind.types import NHWC from rocal_pybind.types import NCHW from rocal_pybind.types import NFHWC @@ -137,6 +139,7 @@ FLOAT: ("FLOAT", FLOAT), FLOAT16: ("FLOAT16", FLOAT16), UINT8: ("UINT8", UINT8), + INT16: ("INT16", INT16), HOST_MEMORY: ("HOST_MEMORY", HOST_MEMORY), DEVICE_MEMORY: ("DEVICE_MEMORY", DEVICE_MEMORY), PINNED_MEMORY: ("PINNED_MEMORY", PINNED_MEMORY), @@ -147,6 +150,7 @@ MAX_SIZE_ORIG: ("MAX_SIZE_ORIG", MAX_SIZE_ORIG), USER_GIVEN_SIZE_ORIG: ("USER_GIVEN_SIZE_ORIG", USER_GIVEN_SIZE_ORIG), + NONE: ("NONE", NONE), NHWC: ("NHWC", NHWC), NCHW: ("NCHW", NCHW), NFHWC: ("NFHWC", NFHWC), diff --git a/rocAL_pybind/rocal_pybind.cpp b/rocAL_pybind/rocal_pybind.cpp index f94562e88..9bf656a11 100644 --- a/rocAL_pybind/rocal_pybind.cpp +++ b/rocAL_pybind/rocal_pybind.cpp @@ -158,6 +158,7 @@ std::unordered_map rocalToPybindOutputDtype = { {3, "int8"}, {4, "uint32"}, {5, "int32"}, + {6, "int16"}, }; #if ENABLE_DLPACK @@ -202,6 +203,8 @@ std::unordered_map rocalToPybindOutputDtype = { case 16: if (dtype.code == kDLFloat) { return RocalTensorOutputType::ROCAL_FP16; + } else if (dtype.code == kDLInt) { + return RocalTensorOutputType::ROCAL_INT16; } else { throw std::runtime_error("Data type code for 16 bit type is not supported."); } @@ -264,6 +267,11 @@ std::unordered_map rocalToPybindOutputDtype = { case RocalTensorOutputType::ROCAL_FP32: out.bits = 32; out.code = kDLFloat; + break; + case RocalTensorOutputType::ROCAL_INT16: + out.bits = 16; + out.code = kDLInt; + break; case RocalTensorOutputType::ROCAL_FP16: out.bits = 16; out.code = kDLFloat; @@ -660,6 +668,7 @@ py::class_(m, "rocalListOfTensorList") .value("FLOAT", ROCAL_FP32) .value("FLOAT16", ROCAL_FP16) .value("UINT8", ROCAL_UINT8) + .value("INT16", ROCAL_INT16) .export_values(); py::enum_(types_m, "RocalOutputMemType", "Output memory types") .value("HOST_MEMORY", ROCAL_MEMCPY_HOST) @@ -695,6 +704,7 @@ py::class_(m, "rocalListOfTensorList") .value("RGB_PLANAR", ROCAL_COLOR_RGB_PLANAR) .export_values(); py::enum_(types_m, "RocalTensorLayout", "Tensor layout type") + .value("NONE", ROCAL_NONE) .value("NHWC", ROCAL_NHWC) .value("NCHW", ROCAL_NCHW) .value("NFHWC", ROCAL_NFHWC) @@ -1024,6 +1034,10 @@ py::class_(m, "rocalListOfTensorList") py::return_value_policy::reference); m.def("audioDecoder", &rocalAudioFileSource, "Reads file from the source given and decodes it", py::return_value_policy::reference); + m.def("numpyReaderSource", &rocalNumpyFileSource, "Reads data from numpy files", + py::return_value_policy::reference); + m.def("numpyReaderSourceShard", &rocalNumpyFileSourceSingleShard, "Reads data from numpy files according to the shard id and number of shards", + py::return_value_policy::reference); m.def("rocalResetLoaders", &rocalResetLoaders); m.def("videoMetaDataReader", &rocalCreateVideoLabelReader, py::return_value_policy::reference); // rocal_api_augmentation.h diff --git a/tests/python_api/numpy_reader.py b/tests/python_api/numpy_reader.py new file mode 100644 index 000000000..1216b65ff --- /dev/null +++ b/tests/python_api/numpy_reader.py @@ -0,0 +1,68 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import random + +from amd.rocal.pipeline import Pipeline +from amd.rocal.plugin.generic import ROCALNumpyIterator +import amd.rocal.fn as fn +import sys +import os + +def draw_patches(image, idx, layout="nhwc", dtype="uint8"): + # image is expected as a numpy array + import cv2 + if layout == "nchw": + image = image.transpose([1, 2, 0]) + if dtype == "fp16": + image = image.astype("uint8") + image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) + cv2.imwrite("output_folder/numpy_reader/" + str(idx) + ".png", image) + +def main(): + if len(sys.argv) < 3: + print('Please pass numpy_folder cpu/gpu batch_size') + exit(0) + data_path = sys.argv[1] + try: + path = "output_folder/numpy_reader/" + isExist = os.path.exists(path) + if not isExist: + os.makedirs(path) + except OSError as error: + print(error) + if (sys.argv[2] == "cpu"): + rocal_cpu = True + else: + rocal_cpu = False + batch_size = int(sys.argv[3]) + num_threads = 1 + device_id = 0 + local_rank = 0 + world_size = 1 + random_seed = random.SystemRandom().randint(0, 2**32 - 1) + + pipeline = Pipeline(batch_size=batch_size, num_threads=num_threads, device_id=device_id, seed=random_seed, rocal_cpu=rocal_cpu) + + with pipeline: + numpy_reader_output = fn.readers.numpy(file_root=data_path, shard_id=local_rank, num_shards=world_size) + pipeline.set_outputs(numpy_reader_output) + + pipeline.build() + + cnt = 0 + numpyIteratorPipeline = ROCALNumpyIterator(pipeline) + print(len(numpyIteratorPipeline)) + for epoch in range(1): + print("+++++++++++++++++++++++++++++EPOCH+++++++++++++++++++++++++++++++++++++",epoch) + for i , [batch] in enumerate(numpyIteratorPipeline): + print(batch.shape) + for img in batch: + draw_patches(img, cnt) + cnt += 1 + numpyIteratorPipeline.reset() + print("*********************************************************************") + + +if __name__ == '__main__': + main()