Skip to content

Commit

Permalink
[IREE-EP] Integrate iree async module in the IREE-EP
Browse files Browse the repository at this point in the history
Signed-Off-by: Gaurav Shukla <[email protected]>
  • Loading branch information
Shukla-Gaurav committed Nov 11, 2024
1 parent 7b2046f commit 6417ad4
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 13 deletions.
83 changes: 70 additions & 13 deletions onnxruntime/core/providers/iree/iree_ep_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "core/providers/iree/iree_ep_runtime.h"

#include "core/session/onnxruntime_cxx_api.h"
#include <iostream>

Check warning on line 7 in onnxruntime/core/providers/iree/iree_ep_runtime.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Found C++ system header after other header. Should be: iree_ep_runtime.h, c system, c++ system, other. [build/include_order] [4] Raw Output: onnxruntime/core/providers/iree/iree_ep_runtime.cc:7: Found C++ system header after other header. Should be: iree_ep_runtime.h, c system, c++ system, other. [build/include_order] [4]

namespace onnxruntime::iree_ep_rt {

Expand Down Expand Up @@ -57,10 +58,18 @@ Session::~Session() {
}

iree_status_t Session::Initialize() {
return iree_runtime_session_create_with_device(
iree_status_t res_status = iree_runtime_session_create_with_device(
instance->instance, &session_options, instance->device,
iree_runtime_instance_host_allocator(instance->instance),
&session);
iree_vm_module_t* custom_module = NULL;
iree_allocator_t host_allocator = iree_allocator_system();
IREE_CHECK_OK(iree_custom_module_async_create(
iree_runtime_instance_vm_instance(instance->instance), instance->device,
host_allocator, &custom_module));
IREE_CHECK_OK(iree_runtime_session_append_module(session, custom_module));
iree_vm_module_release(custom_module);
return res_status;
}

iree_status_t Session::AppendBytecodeModule(fs::path vmfb_path, std::function<void()> dispose_callback) {
Expand Down Expand Up @@ -147,6 +156,13 @@ iree_hal_element_type_t ConvertOrtElementType(ONNXTensorElementDataType et) {
common::Status Session::Call(const char* entrypoint_name, const OrtApi* ort_api, OrtKernelContext* ort_context_c) {
// TODO: This is far from the most efficient way to make a call. Synchronous and copying. We can do
// better but this gets points for simplicity and lets us bootstrap the tests.
iree_vm_list_t* inputs = NULL;
iree_allocator_t host_allocator = iree_allocator_system();
IREE_CHECK_OK(iree_vm_list_create(iree_vm_make_undefined_type_def(), 1,
host_allocator, &inputs));
iree_vm_list_t* outputs = NULL;
IREE_CHECK_OK(iree_vm_list_create(iree_vm_make_undefined_type_def(), 1,
host_allocator, &outputs));
Ort::KernelContext context(ort_context_c);
SynchronousCall call(session);
ORT_RETURN_IF_ERROR(HandleIREEStatus(call.InitializeByName(entrypoint_name)));
Expand All @@ -161,8 +177,10 @@ common::Status Session::Call(const char* entrypoint_name, const OrtApi* ort_api,

// Process inputs. We could be smarter about this in a lot of ways, including carrying
// more state from compilation so we are doing less munging here.
for (size_t i = 0; i < context.GetInputCount(); ++i) {
auto input_tensor = context.GetInput(i);

Check warning on line 180 in onnxruntime/core/providers/iree/iree_ep_runtime.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Line ends in whitespace. Consider deleting these extra spaces. [whitespace/end_of_line] [4] Raw Output: onnxruntime/core/providers/iree/iree_ep_runtime.cc:180: Line ends in whitespace. Consider deleting these extra spaces. [whitespace/end_of_line] [4]
std::cout<<"input count: "<<context.GetInputCount()<<"\n";

Check warning on line 181 in onnxruntime/core/providers/iree/iree_ep_runtime.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Missing spaces around << [whitespace/operators] [3] Raw Output: onnxruntime/core/providers/iree/iree_ep_runtime.cc:181: Missing spaces around << [whitespace/operators] [3]
// for (size_t i = 0; i < context.GetInputCount(); ++i) {
auto input_tensor = context.GetInput(0);
ORT_ENFORCE(input_tensor.IsTensor());

// The device type is rather... sparse... CPU, GPU and FPGA. Not sure how that
Expand Down Expand Up @@ -207,13 +225,45 @@ common::Status Session::Call(const char* entrypoint_name, const OrtApi* ort_api,
// Buffer view + storage are returned and owned by the caller:
&arg.bv)));

iree_vm_ref_t input_view_ref = iree_hal_buffer_view_move_ref(arg.bv);
IREE_CHECK_OK(iree_vm_list_push_ref_move(inputs, &input_view_ref));

iree_hal_semaphore_t* semaphore = NULL;
IREE_CHECK_OK(iree_hal_semaphore_create(
device, 0ull, IREE_HAL_SEMAPHORE_FLAG_NONE, &semaphore));
iree_hal_fence_t* fence_t1 = NULL;
IREE_CHECK_OK(
iree_hal_fence_create_at(semaphore, 1ull, host_allocator, &fence_t1));
iree_hal_fence_t* fence_t2 = NULL;
IREE_CHECK_OK(
iree_hal_fence_create_at(semaphore, 2ull, host_allocator, &fence_t2));
iree_hal_semaphore_release(semaphore);
std::cout<<"\n semaphore released";

Check warning on line 241 in onnxruntime/core/providers/iree/iree_ep_runtime.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Missing spaces around << [whitespace/operators] [3] Raw Output: onnxruntime/core/providers/iree/iree_ep_runtime.cc:241: Missing spaces around << [whitespace/operators] [3]
iree_vm_ref_t fence_t1_ref = iree_hal_fence_retain_ref(fence_t1);
std::cout<<"\n semaphore released1";

Check warning on line 243 in onnxruntime/core/providers/iree/iree_ep_runtime.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Missing spaces around << [whitespace/operators] [3] Raw Output: onnxruntime/core/providers/iree/iree_ep_runtime.cc:243: Missing spaces around << [whitespace/operators] [3]
IREE_CHECK_OK(iree_vm_list_push_ref_move(inputs, &fence_t1_ref));
std::cout<<"\n semaphore released2";

Check warning on line 245 in onnxruntime/core/providers/iree/iree_ep_runtime.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Missing spaces around << [whitespace/operators] [3] Raw Output: onnxruntime/core/providers/iree/iree_ep_runtime.cc:245: Missing spaces around << [whitespace/operators] [3]
iree_vm_ref_t fence_t2_ref = iree_hal_fence_retain_ref(fence_t2);
std::cout<<"\n semaphore released3";

Check warning on line 247 in onnxruntime/core/providers/iree/iree_ep_runtime.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Missing spaces around << [whitespace/operators] [3] Raw Output: onnxruntime/core/providers/iree/iree_ep_runtime.cc:247: Missing spaces around << [whitespace/operators] [3]
IREE_CHECK_OK(iree_vm_list_push_ref_move(inputs, &fence_t2_ref));
std::cout<<"\n semaphore released4";

Check warning on line 249 in onnxruntime/core/providers/iree/iree_ep_runtime.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Missing spaces around << [whitespace/operators] [3] Raw Output: onnxruntime/core/providers/iree/iree_ep_runtime.cc:249: Missing spaces around << [whitespace/operators] [3]
IREE_CHECK_OK(iree_hal_fence_signal(fence_t1));
std::cout<<"\n T=1 reached";

Check warning on line 251 in onnxruntime/core/providers/iree/iree_ep_runtime.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Missing spaces around << [whitespace/operators] [3] Raw Output: onnxruntime/core/providers/iree/iree_ep_runtime.cc:251: Missing spaces around << [whitespace/operators] [3]
// Add it to the call.
iree_status_t status = iree_runtime_call_inputs_push_back_buffer_view(&call.call, arg.bv);
ORT_RETURN_IF_ERROR(HandleIREEStatus(status));
}
iree_string_view_t entry_point = iree_make_cstring_view(entrypoint_name);
IREE_CHECK_OK(
iree_runtime_session_call_by_name(session, entry_point, inputs, outputs));
// We could go do other things now while the async work progresses. Here we
// just immediately wait.
IREE_CHECK_OK(iree_hal_fence_wait(fence_t2, iree_infinite_timeout()));
std::cout<<"\n T=2 reached";

Check warning on line 259 in onnxruntime/core/providers/iree/iree_ep_runtime.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Missing spaces around << [whitespace/operators] [3] Raw Output: onnxruntime/core/providers/iree/iree_ep_runtime.cc:259: Missing spaces around << [whitespace/operators] [3]
// iree_status_t status = iree_runtime_call_inputs_push_back_buffer_view(&call.call, arg.bv);
// ORT_RETURN_IF_ERROR(HandleIREEStatus(status));
// }
// Read back the tensor<?xi32> result:

// Invoke.
ORT_RETURN_IF_ERROR(HandleIREEStatus(iree_runtime_call_invoke(&call.call, /*flags=*/0)));
// ORT_RETURN_IF_ERROR(HandleIREEStatus(iree_runtime_call_invoke(&call.call, [>flags=<]0)));

// Marshal the outputs.
// TODO: Accessing the ORT output requires the shape and then we could get zero copy
Expand All @@ -222,16 +272,19 @@ common::Status Session::Call(const char* entrypoint_name, const OrtApi* ort_api,
// convention, which allows passing in slabs of result buffers. Further, that would
// run the host-side computation (which would compute output metadata) inline.
// For static cases, we could also side-load the shape from the compile time.
std::vector<int64_t> shape;
for (size_t i = 0; i < context.GetOutputCount(); ++i) {
// std::vector<int64_t> shape;
std::cout<<"output count: "<<context.GetOutputCount()<<"\n";
// for (size_t i = 0; i < context.GetOutputCount(); ++i) {
HalBufferView ret;
ORT_RETURN_IF_ERROR(HandleIREEStatus(
iree_runtime_call_outputs_pop_front_buffer_view(&call.call, &ret.bv)));
ret.bv = iree_vm_list_get_buffer_view_assign(outputs, 0);
// ORT_RETURN_IF_ERROR(HandleIREEStatus(
// iree_runtime_call_outputs_pop_front_buffer_view(&call.call, &ret.bv)));
size_t ret_rank = iree_hal_buffer_view_shape_rank(ret.bv);
const iree_hal_dim_t* ret_dims = iree_hal_buffer_view_shape_dims(ret.bv);
shape.clear();
shape.resize(ret_rank);
std::copy(ret_dims, ret_dims + ret_rank, shape.begin());
auto output_tensor = context.GetOutput(i, shape.data(), shape.size());
auto output_tensor = context.GetOutput(0, shape.data(), shape.size());
ORT_ENFORCE(output_tensor.IsTensor());

iree_hal_buffer_t* ret_buffer = iree_hal_buffer_view_buffer(ret.bv);
Expand All @@ -250,8 +303,12 @@ common::Status Session::Call(const char* entrypoint_name, const OrtApi* ort_api,
ORT_RETURN_IF_ERROR(HandleIREEStatus(iree_hal_buffer_map_read(ret_buffer, /*source_offset=*/0,
output_tensor.GetTensorMutableRawData(),
iree_hal_buffer_view_byte_length(ret.bv))));
}
// }

iree_vm_list_release(inputs);
iree_vm_list_release(outputs);
iree_hal_fence_release(fence_t1);
iree_hal_fence_release(fence_t2);
return common::Status::OK();
}

Expand Down
3 changes: 3 additions & 0 deletions onnxruntime/core/providers/iree/iree_ep_runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@

#include "core/common/common.h"
#include "core/session/onnxruntime_c_api.h"
#include "iree/modules/hal/types.h"
#include "iree/runtime/api.h"

#include "module.h"

#include <filesystem>

namespace fs = std::filesystem;
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/core/providers/iree/iree_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,8 @@ common::Status IREEExecutionProvider::Compile(const std::vector<FusedNodeAndGrap
LOGS(*GetLogger(), INFO) << "IREEExecutionProvider compile: setting flag " << extra_flag;
ORT_RETURN_IF_ERROR(compiler.SetFlag(extra_flag.c_str()));
}
std::string extra_flag_2 = "--iree-execution-model=async-external";
ORT_RETURN_IF_ERROR(compiler.SetFlag(extra_flag_2.c_str()));

ORT_RETURN_IF_ERROR(compiler.Initialize());
std::string module_name = "ort";
Expand Down
4 changes: 4 additions & 0 deletions onnxruntime/test/perftest/ort_test_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@
#include "core/providers/dml/dml_session_options_config_keys.h"
#endif

#ifdef USE_IREE
#include "core/providers/iree/iree_provider_factory.h"
#endif

#ifdef _WIN32
#define strdup _strdup
#endif
Expand Down

0 comments on commit 6417ad4

Please sign in to comment.