Skip to content

Commit

Permalink
Fix device list of loaded executable in PJRT plugin for multiple GPUs (
Browse files Browse the repository at this point in the history
…iree-org#19369)

It closes iree-org#19366, and blocks iree-org#19279.

After this PR, `ClientInstance::Compile` first checks the device
assignment in the compile options and attaches the corresponding device
list to the loaded executable it returns. If no usable device assignment
is present, it falls back to all addressable devices.

To achieve this, this PR introduces protobuf via `FetchContent`; the
dependency is scoped to the PJRT plugin. The PJRT client passes the
compile options as a serialized protobuf message, so the plugin first
decodes that message and then reads the fields it needs (currently the
device assignment).

ci-exactly: build_packages, test_pjrt

---------

Signed-off-by: PragmaTwice <[email protected]>
Co-authored-by: Scott Todd <[email protected]>
  • Loading branch information
PragmaTwice and ScottTodd authored Dec 6, 2024
1 parent 094ea05 commit d88d0a7
Show file tree
Hide file tree
Showing 9 changed files with 1,462 additions and 23 deletions.
4 changes: 4 additions & 0 deletions integrations/pjrt/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ add_subdirectory("${IREE_ROOT_DIR}" "iree_core" EXCLUDE_FROM_ALL)
# toolchain level.
iree_setup_toolchain()

# Setup protoc and protobuf library
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_LIST_DIR}/cmake")
include(protobuf_cc_library)

add_subdirectory(src)
add_subdirectory(third_party/pjrt_c_api)

Expand Down
91 changes: 91 additions & 0 deletions integrations/pjrt/cmake/protobuf_cc_library.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# Copyright 2024 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

include(CMakeParseArguments)
include(FetchContent)

# Configure and fetch protobuf. The variables below are set before
# FetchContent_MakeAvailable() so that they are already defined when
# protobuf's own CMake project is configured.

# disable some targets we don't use
set(protobuf_INSTALL OFF)
set(protobuf_BUILD_TESTS OFF)

# to prevent protobuf itself from using `find_package`
set(protobuf_FORCE_FETCH_DEPENDENCIES ON)

# pin the version of protobuf; used below to form the git tag `v<version>`
set(protobuf_VERSION 29.1)

# Declare where protobuf comes from: the upstream repository at the pinned
# tag. GIT_SHALLOW requests a shallow clone instead of the full history.
FetchContent_Declare(
protobuf
GIT_REPOSITORY https://github.com/protocolbuffers/protobuf
GIT_TAG v${protobuf_VERSION}
GIT_SHALLOW ON
)

# Populate protobuf and add it to the build. This also defines
# protobuf_SOURCE_DIR, which the include() below relies on.
FetchContent_MakeAvailable(protobuf)

# make protobuf_generate() function available
include(${protobuf_SOURCE_DIR}/cmake/protobuf-generate.cmake)

# iree_pjrt_protobuf_cc_library()
#
# Defines a C++ library target for a set of .proto files. The C++ sources are
# produced by protoc, which is invoked through protobuf_generate().
#
# Parameters:
#   NAME: Name of the target. The real CMake target is prefixed with the
#     current IREE package name, and an alias `<package_ns>::<NAME>` is added.
#   SRCS: List of source files (.proto) for the library.
#   PROTOC_ARGS: List of arguments forwarded to protoc.
#   PUBLIC: Accepted as a flag; this function does not branch on it.
#   TESTONLY: When given, the target is only defined if IREE_BUILD_TESTS is
#     enabled (i.e. the user passes -DIREE_BUILD_TESTS=ON to CMake).
#
# Example:
#   iree_pjrt_protobuf_cc_library(
#     NAME
#       some_def
#     SRCS
#       some_def.proto
#     PUBLIC
#   )
function(iree_pjrt_protobuf_cc_library)
  cmake_parse_arguments(_PROTO
    "PUBLIC;TESTONLY"
    "NAME"
    "SRCS;PROTOC_ARGS"
    ${ARGN}
  )

  # Test-only libraries are dropped entirely when tests are not being built.
  if(_PROTO_TESTONLY)
    if(NOT IREE_BUILD_TESTS)
      return()
    endif()
  endif()

  # Resolve both spellings of the enclosing package up front:
  #   _PKG    -> underscore form, used for the real target name
  #   _PKG_NS -> `::` form, used for the alias target
  iree_package_name(_PKG)
  iree_package_ns(_PKG_NS)
  set(_TARGET "${_PKG}_${_PROTO_NAME}")

  # The target is created with the .proto files as its sources;
  # protobuf_generate(TARGET ...) then runs protoc for that target.
  add_library(${_TARGET} ${_PROTO_SRCS})
  protobuf_generate(
    TARGET ${_TARGET}
    LANGUAGE cpp
    PROTOC_OPTIONS ${_PROTO_PROTOC_ARGS}
  )

  # Expose the current binary dir to dependents at build time so they can
  # include headers located there.
  target_include_directories(${_TARGET}
    PUBLIC
      $<BUILD_INTERFACE:${CMAKE_CURRENT_BINARY_DIR}>
  )
  target_link_libraries(${_TARGET}
    PUBLIC
      protobuf::libprotobuf
      ${IREE_DEFAULT_LINKOPTS}
  )
  iree_install_targets(
    TARGETS ${_TARGET}
  )

  # Alias iree_package_name to iree::package::name. This maps more clearly to
  # Bazel and disambiguates underscores in paths from namespace separators.
  iree_add_alias_library(${_PKG_NS}::${_PROTO_NAME} ${_TARGET})
endfunction()
1 change: 1 addition & 0 deletions integrations/pjrt/src/iree_pjrt/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ iree_cc_library(
iree::vm
iree::vm::bytecode::module
iree_pjrt_deps::headers
iree_pjrt_deps::protos
PUBLIC
)

Expand Down
48 changes: 29 additions & 19 deletions integrations/pjrt/src/iree_pjrt/common/api_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include "iree_pjrt/common/api_impl.h"

#include <iterator>
#include <optional>
#include <sstream>
#include <utility>
Expand Down Expand Up @@ -1002,7 +1003,7 @@ iree_status_t DeviceInstance::TransposeBroadcastDeviceBuffer(

// Compile program and check for errors:
LoadedExecutableInstance* executable;
auto* error = this->client().Compile(&program, &executable);
auto* error = this->client().Compile(&program, {}, &executable);
if (error) {
auto errinst = ErrorInstance::FromError(error);
auto ret = iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
Expand Down Expand Up @@ -1351,22 +1352,14 @@ void ClientInstance::BindApi(PJRT_Api* api) {
LoadedExecutableInstance* executable;

// Read compilation options.
// TODO: Port CompileOptionsProto into the project or leave ommitted.
// xla::CompileOptionsProto options_proto;
// if (!options_proto.ParseFromArray(args->compile_options,
// args->compile_options_size)) {
// return MakeError(iree_make_status(IREE_STATUS_INTERNAL,
// "could not parse compilation
// options"));
// }
// auto options = xla::CompileOptions::FromProto(options_proto);
// if (!options.ok()) {
// return MakeError(
// iree_make_status(IREE_STATUS_INTERNAL,
// std::string(options.status().message()).c_str()));
// }

auto* error = client->Compile(args->program, /**options,*/ &executable);
xla::CompileOptionsProto options_proto;
if (!options_proto.ParseFromArray(args->compile_options,
args->compile_options_size)) {
return MakeError(iree_make_status(IREE_STATUS_INTERNAL,
"could not parse compilation options"));
}

auto* error = client->Compile(args->program, options_proto, &executable);
if (error) return error;
args->executable = *executable;
return nullptr;
Expand Down Expand Up @@ -1451,7 +1444,7 @@ iree_status_t ClientInstance::PopulateDevices() {
}

PJRT_Error* ClientInstance::Compile(const PJRT_Program* program,
/*xla::CompileOptions options,*/
xla::CompileOptionsProto options,
LoadedExecutableInstance** out_executable) {
std::unique_ptr<ArtifactDumper::Transaction> artifact_tx;
if (platform().artifact_dumper().enabled()) {
Expand Down Expand Up @@ -1570,11 +1563,28 @@ PJRT_Error* ClientInstance::Compile(const PJRT_Program* program,
output->GetDataSize()));
}

// calculate devices for this computation from device assignment
std::vector<DeviceInstance*> devices;

const auto& build_options = options.executable_build_options();
if (build_options.has_device_assignment()) {
const auto& device_assignment = build_options.device_assignment();
for (auto id :
device_assignment.computation_devices(0).replica_device_ids()) {
if (id < addressable_devices_.size())
devices.push_back(addressable_devices_[id]);
}
}

if (devices.empty()) {
devices = addressable_devices_;
}

auto executable = std::make_unique<LoadedExecutableInstance>(
*this,
new ExecutableImage(std::move(output),
std::string(program->code, program->code_size)),
addressable_devices_);
devices);
status = executable->LoadAll();
if (!iree_status_is_ok(status)) {
return MakeError(status);
Expand Down
7 changes: 4 additions & 3 deletions integrations/pjrt/src/iree_pjrt/common/api_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "iree_pjrt/common/layout_utils.h"
#include "iree_pjrt/common/platform.h"
#include "xla/pjrt/c/pjrt_c_api.h"
#include "xla/pjrt/compile_options.pb.h"

namespace iree::pjrt {

Expand Down Expand Up @@ -451,9 +452,9 @@ class ClientInstance {

// Compiles.
// See TODOs in PJRT_Client_Compile.
PJRT_Error* Compile(
const PJRT_Program* program, /*xla::CompileOptions options, */
LoadedExecutableInstance** executable);
PJRT_Error* Compile(const PJRT_Program* program,
xla::CompileOptionsProto options,
LoadedExecutableInstance** executable);

// ---------------------------------------------------------------------------
// Subclass hooks.
Expand Down
9 changes: 9 additions & 0 deletions integrations/pjrt/third_party/pjrt_c_api/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,12 @@ iree_cc_library(
"xla/pjrt/c/pjrt_c_api.h"
PUBLIC
)

iree_pjrt_protobuf_cc_library(
NAME
protos
SRCS
"xla/pjrt/compile_options.proto"
"xla/xla_data.proto"
PUBLIC
)
3 changes: 2 additions & 1 deletion integrations/pjrt/third_party/pjrt_c_api/README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# pjrt_c_api

This directory contains a fork of C headers needed to build a PJRT plugin.
This directory contains a fork of C headers and .proto files
needed to build a PJRT plugin.

It is intended to be sync'd with upstream for major/breaking changes and
releases.
Expand Down
Loading

0 comments on commit d88d0a7

Please sign in to comment.