Skip to content

Commit

Permalink
[IREE][EP] Add support for rocm backend
Browse files Browse the repository at this point in the history
This commit adds support for rocm backend in iree-ep.

Signed-Off-by: Gaurav Shukla<[email protected]>
  • Loading branch information
Shukla-Gaurav committed Sep 24, 2024
1 parent 1d4576f commit 1c81f11
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 36 deletions.
6 changes: 4 additions & 2 deletions onnxruntime/core/providers/iree/compiler/jit_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
#include "mlir-c/BuiltinAttributes.h"

#include <cstring>
#include <filesystem>

namespace onnxruntime::iree_ep_jit {

Expand Down Expand Up @@ -208,12 +207,15 @@ common::Status CompilerInvocation::ImportSubgraph(const onnxruntime::GraphViewer
return common::Status::OK();
}

common::Status CompilerInvocation::CompileAndOutputVMFB(iree_compiler_output_t* output) {
common::Status CompilerInvocation::CompileAndOutputVMFB(iree_compiler_output_t* output, fs::path vmfb_path) {
// Main compilation.
if (!ireeCompilerInvocationPipeline(inv, IREE_COMPILER_PIPELINE_STD)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, "IREE compilation error.", ConsumeDiagnostics());
}

// Attach the compiled output to a file.
ireeCompilerOutputOpenFile(vmfb_path.c_str(), &output);

// Output.
if (auto* err = ireeCompilerInvocationOutputVMBytecode(inv, output)) {
return ErrorToStatus(err, "Failure emitting VM bytecode: ");
Expand Down
12 changes: 8 additions & 4 deletions onnxruntime/core/providers/iree/compiler/jit_compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,12 @@
#include "iree/compiler/embedding_api.h"
#include "iree/compiler/mlir_interop.h"

#include <filesystem>

Check warning on line 15 in onnxruntime/core/providers/iree/compiler/jit_compiler.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Found C++ system header after other header. Should be: jit_compiler.h, c system, c++ system, other. [build/include_order] [4] Raw Output: onnxruntime/core/providers/iree/compiler/jit_compiler.h:15: Found C++ system header after other header. Should be: jit_compiler.h, c system, c++ system, other. [build/include_order] [4]
#include <string>
#include <string_view>

namespace fs = std::filesystem;

namespace onnxruntime::iree_ep_jit {

common::Status ErrorToStatus(iree_compiler_error_t* err, std::string message_prefix);
Expand Down Expand Up @@ -44,13 +47,14 @@ struct CompilerOutput {

// Releases ownership of the output, returning a callback that can be used to
// destroy it at a later date.
std::function<void()> Release() {
iree_compiler_output_t* local_output = output;
std::function<void()> Release(fs::path vmfb_path) {
iree_compiler_output_t* local_output = this->output;
this->output = nullptr;
return [local_output]() {
return [local_output, &vmfb_path]() {
if (local_output) {
ireeCompilerOutputDestroy(local_output);
}
fs::remove(vmfb_path);
};
}

Expand Down Expand Up @@ -84,7 +88,7 @@ struct CompilerInvocation {
common::Status ImportSubgraph(const onnxruntime::GraphViewer& graph_view, const std::string& func_name);

// Compile and output a VMFB.
common::Status CompileAndOutputVMFB(iree_compiler_output_t* output);
common::Status CompileAndOutputVMFB(iree_compiler_output_t* output, fs::path vmfb_path);

// If there are any diagnostics, clears them and returns a loggable string.
std::string ConsumeDiagnostics();
Expand Down
38 changes: 20 additions & 18 deletions onnxruntime/core/providers/iree/iree_ep_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,7 @@ common::Status HandleFailingIREEStatus(iree_status_t iree_status) {
return common::Status::OK();
}

std::string buffer;
iree_host_size_t actual_len;
buffer.resize(1024);
if (!iree_status_format(iree_status, buffer.size(), buffer.data(),
&actual_len)) {
buffer.resize(actual_len);
if (!iree_status_format(iree_status, buffer.size(), buffer.data(),
&actual_len)) {
actual_len = 0;
}
}
buffer.resize(actual_len);
std::string buffer = iree::Status::ToString(iree_status);

return ORT_MAKE_STATUS(ONNXRUNTIME, RUNTIME_EXCEPTION, "IREE Runtime Error: ", std::move(buffer));
}
Expand All @@ -43,13 +32,13 @@ Instance::~Instance() {
}
}

iree_status_t Instance::Initialize() {
iree_status_t Instance::Initialize(std::string device_str) {
IREE_RETURN_IF_ERROR(iree_runtime_instance_create(
&options, iree_allocator_system(), &instance));

// TODO: Need real device selection.
IREE_RETURN_IF_ERROR(iree_runtime_instance_try_create_default_device(
instance, iree_make_cstring_view("local-task"), &device));
instance, iree_make_cstring_view(device_str.c_str()), &device));

return iree_ok_status();
}
Expand All @@ -74,11 +63,14 @@ iree_status_t Session::Initialize() {
&session);
}

iree_status_t Session::AppendBytecodeModule(void* contents, uint64_t size, std::function<void()> dispose_callback) {
iree_status_t Session::AppendBytecodeModule(fs::path vmfb_path, std::function<void()> dispose_callback) {
dispose_callbacks.push_back(std::move(dispose_callback));
return iree_runtime_session_append_bytecode_module_from_memory(
session, iree_make_const_byte_span(contents, size),
iree_allocator_null());
// TODO(Shukla-Gaurav): load from memory instead of file.
// return iree_runtime_session_append_bytecode_module_from_memory(
// session, iree_make_const_byte_span(contents, size),
// iree_allocator_null());
return iree_runtime_session_append_bytecode_module_from_file(
session, file_loc.c_str());
}

namespace {
Expand Down Expand Up @@ -245,6 +237,16 @@ common::Status Session::Call(const char* entrypoint_name, const OrtApi* ort_api,
iree_hal_buffer_t* ret_buffer = iree_hal_buffer_view_buffer(ret.bv);
// TODO: Synchronous mapping read, like everything in this function, is not a
// great idea. It isn't supported on all device types and will need a scrub.
iree_string_view_t device_val = iree_hal_device_id(device);
auto device_str = std::string(device_val.data, device_val.size);

Check warning on line 241 in onnxruntime/core/providers/iree/iree_ep_runtime.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <string> for string [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/iree/iree_ep_runtime.cc:241: Add #include <string> for string [build/include_what_you_use] [4]
if (device_str == "hip") {
ORT_RETURN_IF_ERROR(HandleIREEStatus(iree_hal_device_transfer_d2h(
iree_runtime_session_device(session),
ret_buffer, 0, output_tensor.GetTensorMutableRawData(),
iree_hal_buffer_view_byte_length(ret.bv), IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT,
iree_infinite_timeout())));
return common::Status::OK();
}
ORT_RETURN_IF_ERROR(HandleIREEStatus(iree_hal_buffer_map_read(ret_buffer, /*source_offset=*/0,
output_tensor.GetTensorMutableRawData(),
iree_hal_buffer_view_byte_length(ret.bv))));
Expand Down
8 changes: 6 additions & 2 deletions onnxruntime/core/providers/iree/iree_ep_runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@
#include "core/session/onnxruntime_c_api.h"
#include "iree/runtime/api.h"

#include <filesystem>

Check warning on line 10 in onnxruntime/core/providers/iree/iree_ep_runtime.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Found C++ system header after other header. Should be: iree_ep_runtime.h, c system, c++ system, other. [build/include_order] [4] Raw Output: onnxruntime/core/providers/iree/iree_ep_runtime.h:10: Found C++ system header after other header. Should be: iree_ep_runtime.h, c system, c++ system, other. [build/include_order] [4]

namespace fs = std::filesystem;

namespace onnxruntime::iree_ep_rt {

// Handles a failing IREE status.
Expand All @@ -27,7 +31,7 @@ struct Instance {

// Initializes the instance.
// TODO: We should probably pass the options in here and use it to set up.
iree_status_t Initialize();
iree_status_t Initialize(std::string device_str);

Check warning on line 34 in onnxruntime/core/providers/iree/iree_ep_runtime.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <string> for string [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/iree/iree_ep_runtime.h:34: Add #include <string> for string [build/include_what_you_use] [4]

// Instance globals.
iree_runtime_instance_options_t options;
Expand All @@ -48,7 +52,7 @@ struct Session {
// Append a user-compiled bytecode module buffer to the session, along with a dispose callback.
// The dispose callback will be invoked when Session is destroyed regardless of success/failure
// of this call.
iree_status_t AppendBytecodeModule(void* contents, uint64_t size, std::function<void()> dispose_callback);
iree_status_t AppendBytecodeModule(fs::path vmfb_path, std::function<void()> dispose_callback);

// Calls the entrypoint. This returns an ORT Status and normalizes any IREE statuses to that
// because that can arise from ORT interactions.
Expand Down
45 changes: 35 additions & 10 deletions onnxruntime/core/providers/iree/iree_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@ IREEExecutionProvider::~IREEExecutionProvider() {
}

common::Status IREEExecutionProvider::Initialize() {
ORT_RETURN_IF_ERROR(iree_ep_rt::HandleIREEStatus(rt_instance_->Initialize()));
if (info_.find("device") == info_.end())
info_["device"] = "local-task";
ORT_RETURN_IF_ERROR(iree_ep_rt::HandleIREEStatus(rt_instance_->Initialize(info_["device"])));
return common::Status::OK();
}

Expand Down Expand Up @@ -98,15 +100,25 @@ common::Status IREEExecutionProvider::Compile(const std::vector<FusedNodeAndGrap
// TODO: The target needs to be synchronized with the runtime based on EP options.
// TODO: We should just be adding the target to the module instead of specifying via
// flags.
std::string device_flag = "--iree-hal-target-backends=";
std::string device_flag = "--iree-hal-target-device=";
if (info_.find("hal_target_device") == info_.end()) {
// In case device info is absent, set `llvm-cpu` as default hal-target-backend.
// In case device info is absent, set `llvm-cpu` as default hal-target-device.
device_flag.append("llvm-cpu");
} else {
device_flag.append(info_["hal_target_device"]);
}
LOGS(*GetLogger(), INFO) << "IREEExecutionProvider compile: setting device flag as " << device_flag;
LOGS(*GetLogger(), INFO) << "IREEExecutionProvider compile: setting flag " << device_flag;
ORT_RETURN_IF_ERROR(compiler.SetFlag(device_flag.c_str()));

// Set all the compile-time flags.
// TODO(Shukla-Gaurav): Use ireeCompilerSessionSetFlags API to set all the flags at once.
// TODO(Shukla-Gaurav): support more than one extra flags by parsing the input string.
if (info_.find("compile_time_flags") != info_.end()) {
std::string extra_flag = info_["compile_time_flags"];
LOGS(*GetLogger(), INFO) << "IREEExecutionProvider compile: setting flag " << extra_flag;
ORT_RETURN_IF_ERROR(compiler.SetFlag(extra_flag.c_str()));
}

ORT_RETURN_IF_ERROR(compiler.Initialize());
std::string module_name = "ort";
iree_ep_jit::CompilerInvocation inv(compiler, module_name.c_str());
Expand All @@ -133,20 +145,33 @@ common::Status IREEExecutionProvider::Compile(const std::vector<FusedNodeAndGrap
if (auto* err = ireeCompilerOutputOpenMembuffer(&vmfb_output.output)) {
return iree_ep_jit::ErrorToStatus(err, "Failure opening compiler output buffer: ");
}
ORT_RETURN_IF_ERROR(inv.CompileAndOutputVMFB(vmfb_output.output));

// This will save the compiled module to temporary directory.
fs::path save_to = fs::temp_directory_path();
if (info_.find("save_to") != info_.end() && fs::is_directory(info_["save_to"])
save_to = fs::path(info_["save_to"]);

fs::path file_name("compiled_model.vmfb");
fs::path vmfb_path = save_to / file_name;


ORT_RETURN_IF_ERROR(inv.CompileAndOutputVMFB(vmfb_output.output, vmfb_path));
LOGS(*GetLogger(), INFO) << "IREEExecutionProvider compiled vmfb saved at this location " << vmfb_path;

// Map raw memory.
void* vmfb_contents;
uint64_t vmfb_size;
ORT_RETURN_IF_ERROR(vmfb_output.MapMemory(&vmfb_contents, &vmfb_size));
// void* vmfb_contents = nullptr;
// uint64_t vmfb_size = 0;
// TODO(Shukla-Gaurav): Map memory instead of storing the compiled module as a file
// ORT_RETURN_IF_ERROR(vmfb_output.MapMemory(&vmfb_contents, &vmfb_size));

// Create a new runtime session.
auto rt_session = std::make_shared<iree_ep_rt::Session>(rt_instance_);
// In case device info is absent, set `local-task` as default device.
ORT_RETURN_IF_ERROR(iree_ep_rt::HandleIREEStatus(rt_session->Initialize()));

// Load the compiled module, releasing our ownership of the CompilerOutput.
ORT_RETURN_IF_ERROR(iree_ep_rt::HandleIREEStatus(rt_session->AppendBytecodeModule(
vmfb_contents, vmfb_size, vmfb_output.Release())));
ORT_RETURN_IF_ERROR(iree_ep_rt::HandleIREEStatus(rt_session->AppendBytecodeModule(vmfb_path,
vmfb_output.Release(vmfb_path))));

for (auto& entrypoint_name : entrypoint_names) {
node_compute_funcs.push_back(CreateNodeComputeFunc(entrypoint_name, rt_session));
Expand Down

0 comments on commit 1c81f11

Please sign in to comment.