From 6417ad43e4ed89e5c08d3ad08ed571b0c9c769d6 Mon Sep 17 00:00:00 2001 From: Gaurav Shukla Date: Sun, 27 Oct 2024 23:06:03 -0500 Subject: [PATCH] [IREE-EP] Integrate iree async module in the IREE-EP Signed-Off-by: Gaurav Shukla --- .../core/providers/iree/iree_ep_runtime.cc | 83 ++++++++++++++++--- .../core/providers/iree/iree_ep_runtime.h | 3 + .../providers/iree/iree_execution_provider.cc | 2 + onnxruntime/test/perftest/ort_test_session.cc | 4 + 4 files changed, 79 insertions(+), 13 deletions(-) diff --git a/onnxruntime/core/providers/iree/iree_ep_runtime.cc b/onnxruntime/core/providers/iree/iree_ep_runtime.cc index 086ef9962465a..7489e4d39cb54 100644 --- a/onnxruntime/core/providers/iree/iree_ep_runtime.cc +++ b/onnxruntime/core/providers/iree/iree_ep_runtime.cc @@ -4,6 +4,7 @@ #include "core/providers/iree/iree_ep_runtime.h" #include "core/session/onnxruntime_cxx_api.h" +#include namespace onnxruntime::iree_ep_rt { @@ -57,10 +58,18 @@ Session::~Session() { } iree_status_t Session::Initialize() { - return iree_runtime_session_create_with_device( + iree_status_t res_status = iree_runtime_session_create_with_device( instance->instance, &session_options, instance->device, iree_runtime_instance_host_allocator(instance->instance), &session); + iree_vm_module_t* custom_module = NULL; + iree_allocator_t host_allocator = iree_allocator_system(); + IREE_CHECK_OK(iree_custom_module_async_create( + iree_runtime_instance_vm_instance(instance->instance), instance->device, + host_allocator, &custom_module)); + IREE_CHECK_OK(iree_runtime_session_append_module(session, custom_module)); + iree_vm_module_release(custom_module); + return res_status; } iree_status_t Session::AppendBytecodeModule(fs::path vmfb_path, std::function dispose_callback) { @@ -147,6 +156,13 @@ iree_hal_element_type_t ConvertOrtElementType(ONNXTensorElementDataType et) { common::Status Session::Call(const char* entrypoint_name, const OrtApi* ort_api, OrtKernelContext* ort_context_c) { // TODO: This is far from the most efficient way to make a call. Synchronous and copying. We can do // better but this gets points for simplicity and lets us bootstrap the tests. + iree_vm_list_t* inputs = NULL; + iree_allocator_t host_allocator = iree_allocator_system(); + IREE_CHECK_OK(iree_vm_list_create(iree_vm_make_undefined_type_def(), 1, + host_allocator, &inputs)); + iree_vm_list_t* outputs = NULL; + IREE_CHECK_OK(iree_vm_list_create(iree_vm_make_undefined_type_def(), 1, + host_allocator, &outputs)); Ort::KernelContext context(ort_context_c); SynchronousCall call(session); ORT_RETURN_IF_ERROR(HandleIREEStatus(call.InitializeByName(entrypoint_name))); @@ -161,8 +177,10 @@ common::Status Session::Call(const char* entrypoint_name, const OrtApi* ort_api, // Process inputs. We could be smarter about this in a lot of ways, including carrying // more state from compilation so we are doing less munging here. - for (size_t i = 0; i < context.GetInputCount(); ++i) { - auto input_tensor = context.GetInput(i); + + std::cout<<"input count: "< result: // Invoke. - ORT_RETURN_IF_ERROR(HandleIREEStatus(iree_runtime_call_invoke(&call.call, /*flags=*/0))); + // ORT_RETURN_IF_ERROR(HandleIREEStatus(iree_runtime_call_invoke(&call.call, [>flags=<]0))); // Marshal the outputs. // TODO: Accessing the ORT output requires the shape and then we could get zero copy @@ -222,16 +272,19 @@ common::Status Session::Call(const char* entrypoint_name, const OrtApi* ort_api, // convention, which allows passing in slabs of result buffers. Further, that would // run the host-side computation (which would compute output metadata) inline. // For static cases, we could also side-load the shape from the compile time. - std::vector shape; - for (size_t i = 0; i < context.GetOutputCount(); ++i) { + // std::vector shape; + std::cout<<"output count: "< namespace fs = std::filesystem; diff --git a/onnxruntime/core/providers/iree/iree_execution_provider.cc b/onnxruntime/core/providers/iree/iree_execution_provider.cc index d504561707e60..bad069632941a 100644 --- a/onnxruntime/core/providers/iree/iree_execution_provider.cc +++ b/onnxruntime/core/providers/iree/iree_execution_provider.cc @@ -118,6 +118,8 @@ common::Status IREEExecutionProvider::Compile(const std::vector