From 7705c34194a78ec87ff5ee123471de701bc0aae0 Mon Sep 17 00:00:00 2001 From: Mark Kraay <63085173+markkraay@users.noreply.github.com> Date: Fri, 16 Aug 2024 13:59:00 -0700 Subject: [PATCH] Added function `create_memref_view_from_dlpack` (#71) This function greatly simplifies Tripy's Array implementation. We want to be able to handle memref creation from all types that implement the `__dlpack__()` interface rather than the limited set we currently support. This function should allow us to achieve this. The corresponding Tripy changes are #72. --- .../include/mlir-executor-c/Runtime/Runtime.h | 13 ++++ .../executor/lib/CAPI/Runtime/Runtime.cpp | 52 ++++++++++++- .../python/bindings/Runtime/RuntimePyBind.cpp | 76 +++++++++++++++++++ .../test_create_memref.py | 65 ++++++++++++++++ 4 files changed, 203 insertions(+), 3 deletions(-) diff --git a/mlir-tensorrt/executor/include/mlir-executor-c/Runtime/Runtime.h b/mlir-tensorrt/executor/include/mlir-executor-c/Runtime/Runtime.h index 014da6ed4..11ae93519 100644 --- a/mlir-tensorrt/executor/include/mlir-executor-c/Runtime/Runtime.h +++ b/mlir-tensorrt/executor/include/mlir-executor-c/Runtime/Runtime.h @@ -25,6 +25,7 @@ #ifndef MLIR_EXECUTOR_C_RUNTIME_RUNTIME #define MLIR_EXECUTOR_C_RUNTIME_RUNTIME +#include "dlpack/dlpack.h" #include "mlir-c/Support.h" #include "mlir-executor-c/Common/Common.h" #include "mlir-executor-c/Support/Status.h" @@ -382,6 +383,18 @@ MLIR_CAPI_EXPORTED MTRT_Status mtrtRuntimeSessionExecuteFunction( const MTRT_RuntimeValue *inArgs, size_t numInArgs, const MTRT_RuntimeValue *outArgs, size_t numOutArgs, MTRT_Stream stream); +//===----------------------------------------------------------------------===// +// DLPack +//===----------------------------------------------------------------------===// + +/// Converts a DLDeviceType to MTRT_PointerType. This function will throw a runtime +/// error if the device type is invalid. 
+MLIR_CAPI_EXPORTED MTRT_Status mtrtGetPointerTypeFromDLDeviceType(DLDeviceType device, MTRT_PointerType* result); + +/// Converts a DLDataType to MTRT_ScalarTypeCode. This function will throw a runtime +/// error if the data type is invalid. +MLIR_CAPI_EXPORTED MTRT_Status mtrtGetScalarTypeCodeFromDLDataType(DLDataType dtype, MTRT_ScalarTypeCode* result); + #ifdef __cplusplus } #endif diff --git a/mlir-tensorrt/executor/lib/CAPI/Runtime/Runtime.cpp b/mlir-tensorrt/executor/lib/CAPI/Runtime/Runtime.cpp index 83d94f09d..2b0fbc35c 100644 --- a/mlir-tensorrt/executor/lib/CAPI/Runtime/Runtime.cpp +++ b/mlir-tensorrt/executor/lib/CAPI/Runtime/Runtime.cpp @@ -326,6 +326,22 @@ static StatusOr toDLPackDeviceType(PointerType address) { return DLDeviceType::kDLCPU; } +MTRT_Status mtrtGetPointerTypeFromDLDeviceType(DLDeviceType device, MTRT_PointerType* result) { + #define RETURN_OK(v) *result = v; return mtrtStatusGetOk(); + switch (device) { + case DLDeviceType::kDLCUDA: RETURN_OK(MTRT_PointerType_device) + case DLDeviceType::kDLCPU: RETURN_OK(MTRT_PointerType_host) + case DLDeviceType::kDLCUDAHost: RETURN_OK(MTRT_PointerType_host) + case DLDeviceType::kDLCUDAManaged: RETURN_OK(MTRT_PointerType_unified) + default: + return wrap(getStatusWithMsg( + StatusCode::InvalidArgument, "DLDeviceType [", + // device, + "] conversion to MTRT_PointerType is not supported.")); + } + #undef RETURN_OK +} + static StatusOr toDLPackDataTypeCode(ScalarTypeCode type) { switch (type) { case ScalarTypeCode::i1: @@ -347,13 +363,43 @@ static StatusOr toDLPackDataTypeCode(ScalarTypeCode type) { return DLDataTypeCode::kDLBfloat; default: return getStatusWithMsg( - StatusCode::InvalidArgument, "Scalar type code [", - EnumNameScalarTypeCode(type), - "] conversion to DLPackDataTypeCode is not supported."); + StatusCode::InvalidArgument, "Scalar type code conversion to DLPackDataTypeCode is not supported."); } return DLDataTypeCode::kDLFloat; } +MTRT_Status mtrtGetScalarTypeCodeFromDLDataType(DLDataType 
dtype, MTRT_ScalarTypeCode* result) { + #define RETURN_OK(v) *result = v; return mtrtStatusGetOk(); + switch (dtype.code) { + case kDLBool: RETURN_OK(MTRT_ScalarTypeCode_i1) + case kDLInt: + switch (dtype.bits) { + case 8: RETURN_OK(MTRT_ScalarTypeCode_i8) + case 16: RETURN_OK(MTRT_ScalarTypeCode_i16) + case 32: RETURN_OK(MTRT_ScalarTypeCode_i32) + case 64: RETURN_OK(MTRT_ScalarTypeCode_i64) + } + case kDLUInt: + switch (dtype.bits) { + case 8: RETURN_OK(MTRT_ScalarTypeCode_ui8); + } + case kDLFloat: + switch (dtype.bits) { + case 8: RETURN_OK(MTRT_ScalarTypeCode_f8e4m3fn) + case 16: RETURN_OK(MTRT_ScalarTypeCode_f16) + case 32: RETURN_OK(MTRT_ScalarTypeCode_f32) + case 64: RETURN_OK(MTRT_ScalarTypeCode_f64) + } + case kDLBfloat: RETURN_OK(MTRT_ScalarTypeCode_bf16) + case kDLComplex: + case kDLOpaqueHandle: + default: + return wrap(getStatusWithMsg( + StatusCode::InvalidArgument, "DLDataType conversion to MTRT_ScalarTypeCode is not supported.")); + } + #undef RETURN_OK +} + static void dlpackManagedTensorDeleter(DLManagedTensor *tensor) { if (tensor) { delete[] tensor->dl_tensor.shape; diff --git a/mlir-tensorrt/python/bindings/Runtime/RuntimePyBind.cpp b/mlir-tensorrt/python/bindings/Runtime/RuntimePyBind.cpp index ce793d05b..fef5ad868 100644 --- a/mlir-tensorrt/python/bindings/Runtime/RuntimePyBind.cpp +++ b/mlir-tensorrt/python/bindings/Runtime/RuntimePyBind.cpp @@ -312,6 +312,76 @@ static std::unique_ptr createMemRef( return std::make_unique(result); } + + +static std::unique_ptr +createMemRefViewFromDLPack(PyRuntimeClient &client, + py::capsule capsule) { + DLManagedTensor *managedTensor = + static_cast(PyCapsule_GetPointer(capsule.ptr(), "dltensor")); + + if (managedTensor == nullptr) { + Py_DECREF(capsule); + return nullptr; + } + + MTRT_MemRefValue result{nullptr}; + + // Extract the necessary information from the DLManagedTensor + void *data = managedTensor->dl_tensor.data; + int64_t *shape = managedTensor->dl_tensor.shape; + + int64_t* strides = 
managedTensor->dl_tensor.strides; + std::vector stridesArray; + if (!strides) { + // Create a suffix product stride array in the event that the DLPack object's stride array is set to `null` + auto ndim = managedTensor->dl_tensor.ndim; + stridesArray.resize(ndim); + if (ndim > 0) { + stridesArray[ndim - 1] = 1; + for (int i = ndim - 2; i >= 0; i--) { + stridesArray[i] = shape[i+1] * stridesArray[i+1]; + } + } + strides = stridesArray.data(); + } + + int64_t offset = managedTensor->dl_tensor.byte_offset; + int rank = managedTensor->dl_tensor.ndim; + DLDataType dtype = managedTensor->dl_tensor.dtype; + DLDeviceType device_type = managedTensor->dl_tensor.device.device_type; + int device_id = managedTensor->dl_tensor.device.device_id; + + MTRT_ScalarTypeCode elementType; + MTRT_Status s; + s = mtrtGetScalarTypeCodeFromDLDataType(dtype, &elementType); + THROW_IF_MTRT_ERROR(s); + + int64_t bytesPerElement = llvm::divideCeil(dtype.bits, 8); + + MTRT_PointerType addressSpace; + s = mtrtGetPointerTypeFromDLDeviceType(device_type, &addressSpace); + THROW_IF_MTRT_ERROR(s); + + MTRT_Device device{nullptr}; + if (addressSpace == MTRT_PointerType_device) { + s = mtrtRuntimeClientGetDevice(client, device_id, &device); + THROW_IF_MTRT_ERROR(s); + } + + if (data) { + s = mtrtMemRefCreateExternal(client, addressSpace, bytesPerElement * 8, + reinterpret_cast(data), offset, rank, + shape, strides, device, elementType, &result); + } else { + s = mtrtMemRefCreate(client, addressSpace, bytesPerElement * 8, rank, + shape, strides, device, mtrtStreamGetNull(), elementType, &result); + } + + THROW_IF_MTRT_ERROR(s); + return std::make_unique(result); +} + static std::unique_ptr getMemRefFromHostBufferProtocol( PyRuntimeClient &client, py::buffer array, std::optional> explicitShape, @@ -716,6 +786,12 @@ PYBIND11_MODULE(_api, m) { py::arg("shape"), py::arg("dtype"), py::arg("device") = py::none(), py::arg("stream") = py::none(), py::keep_alive<0, 1>(), "returns a new memref and allocates 
uninitialized backing storage") + .def( + "create_memref_view_from_dlpack", + [](PyRuntimeClient &self, py::capsule capsule) { + return createMemRefViewFromDLPack(self, capsule).release(); + }, + py::arg("dltensor") = py::none(), py::keep_alive<0, 1>(), py::keep_alive<0, 2>()) .def( "create_device_memref_view", [](PyRuntimeClient &self, uintptr_t ptr, std::vector shape, diff --git a/mlir-tensorrt/test/python/mlir_tensorrt_runtime/test_create_memref.py b/mlir-tensorrt/test/python/mlir_tensorrt_runtime/test_create_memref.py index f43e58787..ad4f44f83 100644 --- a/mlir-tensorrt/test/python/mlir_tensorrt_runtime/test_create_memref.py +++ b/mlir-tensorrt/test/python/mlir_tensorrt_runtime/test_create_memref.py @@ -448,3 +448,68 @@ def memref_alloc(): # CHECK-NEXT: Memref released internally: True # CHECK-NEXT: Number of External reference count: 1 # CHECK-NEXT: Numpy Array: [5. 4. 2.] + + +def create_memref_from_dlpack(arr, module): + print(f"Array: {arr}") + memref = client.create_memref_view_from_dlpack(arr.__dlpack__()) + print(f"-- Memref shape: {memref.shape}") + print(f"-- Memref dtype: {memref.dtype}") + print(f"-- {module.__name__}.from_dlpack(): {module.from_dlpack(memref)}") + +print(f"Test np.array -> client.create_memref_from_dlpack") +create_memref_from_dlpack(np.array([1, 2, 3, 4], dtype=np.int32), np) +create_memref_from_dlpack(np.ones((1, 2, 3), dtype=np.float32), np) +create_memref_from_dlpack(np.ones(0, dtype=np.int8), np) +print(f"Test cp.array -> client.create_memref_from_dlpack") +create_memref_from_dlpack(cp.array([1, 2, 3, 4], dtype=cp.int32), cp) +create_memref_from_dlpack(cp.ones((1, 2, 3), dtype=cp.float32), cp) +create_memref_from_dlpack(cp.ones(0, dtype=cp.int8), cp) + + +# CHECK-LABEL: Test np.array -> client.create_memref_from_dlpack +# CHECK-NEXT: Array: [1 2 3 4] +# CHECK-NEXT: -- Memref shape: [4] +# CHECK-NEXT: -- Memref dtype: ScalarTypeCode.i32 +# CHECK-NEXT: -- numpy.from_dlpack(): [1 2 3 4] +# CHECK-NEXT: Array: {{\[\[\[1. 1. 
1.\][[:space:]]*\[1. 1. 1.\]\]\]}}
+# CHECK-NEXT: -- Memref shape: [1, 2, 3]
+# CHECK-NEXT: -- Memref dtype: ScalarTypeCode.f32
+# CHECK-NEXT: -- numpy.from_dlpack(): {{\[\[\[1. 1. 1.\][[:space:]]*\[1. 1. 1.\]\]\]}}
+# CHECK-NEXT: Array: []
+# CHECK-NEXT: -- Memref shape: [0]
+# CHECK-NEXT: -- Memref dtype: ScalarTypeCode.i8
+# CHECK-NEXT: -- numpy.from_dlpack(): []
+# CHECK-LABEL: Test cp.array -> client.create_memref_from_dlpack
+# CHECK-NEXT: Array: [1 2 3 4]
+# CHECK-NEXT: -- Memref shape: [4]
+# CHECK-NEXT: -- Memref dtype: ScalarTypeCode.i32
+# CHECK-NEXT: -- cupy.from_dlpack(): [1 2 3 4]
+# CHECK-NEXT: Array: {{\[\[\[1. 1. 1.\][[:space:]]*\[1. 1. 1.\]\]\]}}
+# CHECK-NEXT: -- Memref shape: [1, 2, 3]
+# CHECK-NEXT: -- Memref dtype: ScalarTypeCode.f32
+# CHECK-NEXT: -- cupy.from_dlpack(): {{\[\[\[1. 1. 1.\][[:space:]]*\[1. 1. 1.\]\]\]}}
+# CHECK-NEXT: Array: []
+# CHECK-NEXT: -- Memref shape: [0]
+# CHECK-NEXT: -- Memref dtype: ScalarTypeCode.i8
+# CHECK-NEXT: -- cupy.from_dlpack(): []
+
+
+def create_dangling_memref():
+    array = np.array([1, 2])
+    dlpack_capsule = array.__dlpack__()
+    memref = client.create_memref_view_from_dlpack(dlpack_capsule)
+    print("-- Inner scope: np.from_dlpack(): ", np.from_dlpack(memref))
+    return memref
+
+
+print("Test memref maintains data's lifetime")
+memref = create_dangling_memref()
+# Declare a new array to overwrite the old memory
+b = np.array([10, 10])
+print("-- Outer scope: np.from_dlpack(): ", np.from_dlpack(memref))
+
+
+# CHECK-LABEL: Test memref maintains data's lifetime
+# CHECK-NEXT: -- Inner scope: np.from_dlpack(): [1 2]
+# CHECK-NEXT: -- Outer scope: np.from_dlpack(): [1 2]
\ No newline at end of file