Skip to content

Commit

Permalink
Added function create_memref_view_from_dlpack (#71)
Browse files Browse the repository at this point in the history
This function greatly simplifies Tripy's Array implementation. We want
to be able to handle memref creation from all types that implement the
`__dlpack__()` interface rather than the limited set we currently
support. This function should allow us to achieve this. The
corresponding Tripy changes are #72.
  • Loading branch information
markkraay authored Aug 16, 2024
1 parent 93cc518 commit 7705c34
Show file tree
Hide file tree
Showing 4 changed files with 203 additions and 3 deletions.
13 changes: 13 additions & 0 deletions mlir-tensorrt/executor/include/mlir-executor-c/Runtime/Runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#ifndef MLIR_EXECUTOR_C_RUNTIME_RUNTIME
#define MLIR_EXECUTOR_C_RUNTIME_RUNTIME

#include "dlpack/dlpack.h"
#include "mlir-c/Support.h"
#include "mlir-executor-c/Common/Common.h"
#include "mlir-executor-c/Support/Status.h"
Expand Down Expand Up @@ -382,6 +383,18 @@ MLIR_CAPI_EXPORTED MTRT_Status mtrtRuntimeSessionExecuteFunction(
const MTRT_RuntimeValue *inArgs, size_t numInArgs,
const MTRT_RuntimeValue *outArgs, size_t numOutArgs, MTRT_Stream stream);

//===----------------------------------------------------------------------===//
// DLPack
//===----------------------------------------------------------------------===//

/// Converts a DLDeviceType to an MTRT_PointerType, which is written to
/// `*result` on success. Failure is reported through the returned MTRT_Status
/// (an error status is returned when the device type has no supported
/// conversion); `*result` is not written in that case.
MLIR_CAPI_EXPORTED MTRT_Status mtrtGetPointerTypeFromDLDeviceType(DLDeviceType device, MTRT_PointerType* result);

/// Converts a DLDataType to an MTRT_ScalarTypeCode, which is written to
/// `*result` on success. Failure is reported through the returned MTRT_Status
/// (an error status is returned when the data type has no supported
/// conversion) rather than by throwing.
MLIR_CAPI_EXPORTED MTRT_Status mtrtGetScalarTypeCodeFromDLDataType(DLDataType dtype, MTRT_ScalarTypeCode* result);

#ifdef __cplusplus
}
#endif
Expand Down
52 changes: 49 additions & 3 deletions mlir-tensorrt/executor/lib/CAPI/Runtime/Runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,22 @@ static StatusOr<DLDeviceType> toDLPackDeviceType(PointerType address) {
return DLDeviceType::kDLCPU;
}

/// Maps a DLPack device type onto the corresponding MTRT pointer type.
///
/// Supported conversions:
///   kDLCUDA        -> MTRT_PointerType_device
///   kDLCPU         -> MTRT_PointerType_host
///   kDLCUDAHost    -> MTRT_PointerType_host
///   kDLCUDAManaged -> MTRT_PointerType_unified
///
/// \param device the DLPack device type to convert.
/// \param result receives the pointer type on success; left untouched on
///        failure.
/// \return an OK status on success, otherwise an InvalidArgument status whose
///         message includes the numeric value of the unsupported device type.
MTRT_Status mtrtGetPointerTypeFromDLDeviceType(DLDeviceType device, MTRT_PointerType* result) {
  switch (device) {
  case DLDeviceType::kDLCUDA:
    *result = MTRT_PointerType_device;
    return mtrtStatusGetOk();
  case DLDeviceType::kDLCPU:
  case DLDeviceType::kDLCUDAHost:
    *result = MTRT_PointerType_host;
    return mtrtStatusGetOk();
  case DLDeviceType::kDLCUDAManaged:
    *result = MTRT_PointerType_unified;
    return mtrtStatusGetOk();
  default:
    // Report the raw enum value so the caller can tell which device type was
    // rejected (the message previously rendered as "DLDeviceType []").
    return wrap(getStatusWithMsg(
        StatusCode::InvalidArgument, "DLDeviceType [",
        static_cast<int>(device),
        "] conversion to MTRT_PointerType is not supported."));
  }
}

static StatusOr<DLDataTypeCode> toDLPackDataTypeCode(ScalarTypeCode type) {
switch (type) {
case ScalarTypeCode::i1:
Expand All @@ -347,13 +363,43 @@ static StatusOr<DLDataTypeCode> toDLPackDataTypeCode(ScalarTypeCode type) {
return DLDataTypeCode::kDLBfloat;
default:
return getStatusWithMsg(
StatusCode::InvalidArgument, "Scalar type code [",
EnumNameScalarTypeCode(type),
"] conversion to DLPackDataTypeCode is not supported.");
StatusCode::InvalidArgument, "Scalar type code conversion to DLPackDataTypeCode is not supported.");
}
return DLDataTypeCode::kDLFloat;
}

/// Maps a DLPack data type (type code + bit width) onto an MTRT scalar type.
///
/// Supported conversions:
///   kDLBool               -> i1
///   kDLInt   8/16/32/64   -> i8 / i16 / i32 / i64
///   kDLUInt  8            -> ui8
///   kDLFloat 8/16/32/64   -> f8e4m3fn / f16 / f32 / f64
///   kDLBfloat             -> bf16
///
/// \param dtype  the DLPack data type to convert.
/// \param result receives the scalar type code on success; left untouched on
///        failure.
/// \return an OK status on success, otherwise an InvalidArgument status.
///
/// Every unsupported code/width combination breaks out of the switch and
/// reaches the single error return at the bottom. The nested switches
/// previously had no `break`, so e.g. a 16-bit kDLUInt fell through into the
/// kDLFloat cases and was reported as f16.
MTRT_Status mtrtGetScalarTypeCodeFromDLDataType(DLDataType dtype, MTRT_ScalarTypeCode* result) {
  const auto ok = [result](MTRT_ScalarTypeCode code) {
    *result = code;
    return mtrtStatusGetOk();
  };

  switch (dtype.code) {
  case kDLBool:
    return ok(MTRT_ScalarTypeCode_i1);
  case kDLInt:
    switch (dtype.bits) {
    case 8:
      return ok(MTRT_ScalarTypeCode_i8);
    case 16:
      return ok(MTRT_ScalarTypeCode_i16);
    case 32:
      return ok(MTRT_ScalarTypeCode_i32);
    case 64:
      return ok(MTRT_ScalarTypeCode_i64);
    default:
      break;
    }
    break;
  case kDLUInt:
    switch (dtype.bits) {
    case 8:
      return ok(MTRT_ScalarTypeCode_ui8);
    default:
      break;
    }
    break;
  case kDLFloat:
    switch (dtype.bits) {
    case 8:
      return ok(MTRT_ScalarTypeCode_f8e4m3fn);
    case 16:
      return ok(MTRT_ScalarTypeCode_f16);
    case 32:
      return ok(MTRT_ScalarTypeCode_f32);
    case 64:
      return ok(MTRT_ScalarTypeCode_f64);
    default:
      break;
    }
    break;
  case kDLBfloat:
    return ok(MTRT_ScalarTypeCode_bf16);
  case kDLComplex:
  case kDLOpaqueHandle:
  default:
    break;
  }

  return wrap(getStatusWithMsg(
      StatusCode::InvalidArgument, "DLDataType conversion to MTRT_ScalarTypeCode is not supported."));
}

static void dlpackManagedTensorDeleter(DLManagedTensor *tensor) {
if (tensor) {
delete[] tensor->dl_tensor.shape;
Expand Down
76 changes: 76 additions & 0 deletions mlir-tensorrt/python/bindings/Runtime/RuntimePyBind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,76 @@ static std::unique_ptr<PyMemRefValue> createMemRef(
return std::make_unique<PyMemRefValue>(result);
}



/// Creates a memref from the tensor described by a DLPack capsule.
///
/// \param client  runtime client used to create the memref and, for device
///        memory, to look up the device by the tensor's device id.
/// \param capsule a PyCapsule named "dltensor" that wraps a DLManagedTensor.
/// \return the newly created memref value.
///
/// Raises the pending Python exception if the capsule does not hold a
/// "dltensor" pointer, and throws via THROW_IF_MTRT_ERROR if any MTRT call
/// fails (unsupported dtype or device type, device lookup, memref creation).
static std::unique_ptr<PyMemRefValue>
createMemRefViewFromDLPack(PyRuntimeClient &client,
                           py::capsule capsule) {
  DLManagedTensor *managedTensor = static_cast<DLManagedTensor *>(
      PyCapsule_GetPointer(capsule.ptr(), "dltensor"));

  // PyCapsule_GetPointer sets a Python exception when it fails, so propagate
  // it instead of returning null with the error indicator still set. No manual
  // Py_DECREF here: `capsule` is an owning pybind11 handle that drops its own
  // reference when it goes out of scope, so an extra decrement would
  // over-release the object.
  if (managedTensor == nullptr)
    throw py::error_already_set();

  MTRT_MemRefValue result{nullptr};

  // Extract the necessary information from the DLManagedTensor.
  const DLTensor &tensor = managedTensor->dl_tensor;
  void *data = tensor.data;
  int64_t *shape = tensor.shape;
  int rank = tensor.ndim;

  int64_t *strides = tensor.strides;
  std::vector<int64_t> stridesArray;
  if (!strides) {
    // A null stride array is replaced with a suffix-product stride array
    // (innermost stride 1, each outer stride = next extent * next stride).
    stridesArray.resize(rank);
    if (rank > 0) {
      stridesArray[rank - 1] = 1;
      for (int i = rank - 2; i >= 0; i--)
        stridesArray[i] = shape[i + 1] * stridesArray[i + 1];
    }
    strides = stridesArray.data();
  }

  int64_t offset = tensor.byte_offset;
  DLDataType dtype = tensor.dtype;
  DLDeviceType device_type = tensor.device.device_type;
  int device_id = tensor.device.device_id;

  MTRT_ScalarTypeCode elementType;
  MTRT_Status s = mtrtGetScalarTypeCodeFromDLDataType(dtype, &elementType);
  THROW_IF_MTRT_ERROR(s);

  // Element width is rounded up to a whole number of bytes.
  int64_t bytesPerElement = llvm::divideCeil(dtype.bits, 8);

  MTRT_PointerType addressSpace;
  s = mtrtGetPointerTypeFromDLDeviceType(device_type, &addressSpace);
  THROW_IF_MTRT_ERROR(s);

  // A device handle is only looked up for device memory; host and unified
  // memory pass a null device.
  MTRT_Device device{nullptr};
  if (addressSpace == MTRT_PointerType_device) {
    s = mtrtRuntimeClientGetDevice(client, device_id, &device);
    THROW_IF_MTRT_ERROR(s);
  }

  if (data) {
    // Wrap the existing buffer at `data` as an externally owned memref.
    s = mtrtMemRefCreateExternal(client, addressSpace, bytesPerElement * 8,
                                 reinterpret_cast<uintptr_t>(data), offset, rank,
                                 shape, strides, device, elementType, &result);
  } else {
    // No data pointer (the lit test exercises this path with empty arrays):
    // create a memref with the same layout instead of wrapping a pointer.
    s = mtrtMemRefCreate(client, addressSpace, bytesPerElement * 8, rank,
                         shape, strides, device, mtrtStreamGetNull(), elementType, &result);
  }

  THROW_IF_MTRT_ERROR(s);
  return std::make_unique<PyMemRefValue>(result);
}

static std::unique_ptr<PyMemRefValue> getMemRefFromHostBufferProtocol(
PyRuntimeClient &client, py::buffer array,
std::optional<std::vector<int64_t>> explicitShape,
Expand Down Expand Up @@ -716,6 +786,12 @@ PYBIND11_MODULE(_api, m) {
py::arg("shape"), py::arg("dtype"), py::arg("device") = py::none(),
py::arg("stream") = py::none(), py::keep_alive<0, 1>(),
"returns a new memref and allocates uninitialized backing storage")
.def(
"create_memref_view_from_dlpack",
[](PyRuntimeClient &self, py::capsule capsule) {
return createMemRefViewFromDLPack(self, capsule).release();
},
py::arg("dltensor") = py::none(), py::keep_alive<0, 1>(), py::keep_alive<0, 2>())
.def(
"create_device_memref_view",
[](PyRuntimeClient &self, uintptr_t ptr, std::vector<int64_t> shape,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -448,3 +448,68 @@ def memref_alloc():
# CHECK-NEXT: Memref released internally: True
# CHECK-NEXT: Number of External reference count: 1
# CHECK-NEXT: Numpy Array: [5. 4. 2.]


def create_memref_from_dlpack(arr, module):
    """Round-trip `arr` through a DLPack-backed memref and print the results.

    Prints the source array, the shape and dtype of the memref created from
    `arr.__dlpack__()`, and the array that `module.from_dlpack` builds back
    from that memref. `module` is the array library (e.g. numpy or cupy) used
    for the return trip.
    """
    print(f"Array: {arr}")
    capsule = arr.__dlpack__()
    memref = client.create_memref_view_from_dlpack(capsule)
    print(f"-- Memref shape: {memref.shape}")
    print(f"-- Memref dtype: {memref.dtype}")
    print(f"-- {module.__name__}.from_dlpack(): {module.from_dlpack(memref)}")

# NumPy inputs: a 1-D int32 array, a 3-D float32 array, and an empty int8 array.
print(f"Test np.array -> client.create_memref_from_dlpack")
create_memref_from_dlpack(np.array([1, 2, 3, 4], dtype=np.int32), np)
create_memref_from_dlpack(np.ones((1, 2, 3), dtype=np.float32), np)
create_memref_from_dlpack(np.ones(0, dtype=np.int8), np)
# The same three cases, this time sourced from CuPy arrays.
print(f"Test cp.array -> client.create_memref_from_dlpack")
create_memref_from_dlpack(cp.array([1, 2, 3, 4], dtype=cp.int32), cp)
create_memref_from_dlpack(cp.ones((1, 2, 3), dtype=cp.float32), cp)
create_memref_from_dlpack(cp.ones(0, dtype=cp.int8), cp)


# CHECK-LABEL: Test np.array -> client.create_memref_from_dlpack
# CHECK-NEXT: Array: [1 2 3 4]
# CHECK-NEXT: -- Memref shape: [4]
# CHECK-NEXT: -- Memref dtype: ScalarTypeCode.i32
# CHECK-NEXT: -- numpy.from_dlpack(): [1 2 3 4]
# CHECK-NEXT: Array: {{\[\[\[1. 1. 1.\][[:space:]]*\[1. 1. 1.\]\]\]}}
# CHECK-NEXT: -- Memref shape: [1, 2, 3]
# CHECK-NEXT: -- Memref dtype: ScalarTypeCode.f32
# CHECK-NEXT: -- numpy.from_dlpack(): {{\[\[\[1. 1. 1.\][[:space:]]*\[1. 1. 1.\]\]\]}}
# CHECK-NEXT: Array: []
# CHECK-NEXT: -- Memref shape: [0]
# CHECK-NEXT: -- Memref dtype: ScalarTypeCode.i8
# CHECK-NEXT: -- numpy.from_dlpack(): []
# CHECK-LABEL: Test cp.array -> client.create_memref_from_dlpack
# CHECK-NEXT: Array: [1 2 3 4]
# CHECK-NEXT: -- Memref shape: [4]
# CHECK-NEXT: -- Memref dtype: ScalarTypeCode.i32
# CHECK-NEXT: -- cupy.from_dlpack(): [1 2 3 4]
# CHECK-NEXT: Array: {{\[\[\[1. 1. 1.\][[:space:]]*\[1. 1. 1.\]\]\]}}
# CHECK-NEXT: -- Memref shape: [1, 2, 3]
# CHECK-NEXT: -- Memref dtype: ScalarTypeCode.f32
# CHECK-NEXT: -- cupy.from_dlpack(): {{\[\[\[1. 1. 1.\][[:space:]]*\[1. 1. 1.\]\]\]}}
# CHECK-NEXT: Array: []
# CHECK-NEXT: -- Memref shape: [0]
# CHECK-NEXT: -- Memref dtype: ScalarTypeCode.i8
# CHECK-NEXT: -- cupy.from_dlpack(): []


def create_dangling_memref():
    """Return a memref whose source array and capsule are locals of this call.

    The NumPy array and its DLPack capsule go out of scope when this function
    returns, so the returned memref is the only handle the caller has on the
    data. The memref contents are printed once before returning.
    """
    source = np.array([1, 2])
    capsule = source.__dlpack__()
    memref = client.create_memref_view_from_dlpack(capsule)
    print("-- Inner scope: np.from_dlpack(): ", np.from_dlpack(memref))
    return memref


# Lifetime test: the array and capsule that back the memref were locals of
# create_dangling_memref(), so only the memref can keep the data alive here.
print("Test memref maintains data's lifetime")
memref = create_dangling_memref()
# Declare a new array to overwrite the old memory
# (if the original buffer had been freed, this allocation could reuse it and
# the print below would no longer show the original values).
b = np.array([10, 10])
print("-- Outer scope: np.from_dlpack(): ", np.from_dlpack(memref))


# CHECK-LABEL: Test memref maintains data's lifetime
# CHECK-NEXT: -- Inner scope: np.from_dlpack(): [1 2]
# CHECK-NEXT: -- Outer scope: np.from_dlpack(): [1 2]

0 comments on commit 7705c34

Please sign in to comment.