Skip to content

Commit

Permalink
PR #18287: Pass the compile options for deserialization via PJRT C API
Browse files Browse the repository at this point in the history
Imported from GitHub PR #18287

This enables JAX to supply different device assignment when deserializing single-device executables from compilation cache.

Fixes #18286.
Copybara import of the project:

--
35b505b by Jaroslav Sevcik <[email protected]>:

Pass the compile options for deserialization via PJRT C API

--
c983f61 by Jaroslav Sevcik <[email protected]>:

Add compile options comment, reorder fields

--
182b4a6 by Jaroslav Sevcik <[email protected]>:

Fix a little use-after-free

--
cee9982 by Jaroslav Sevcik <[email protected]>:

Rename field, improve comments

--
2b9fbd9 by Jaroslav Sevcik <[email protected]>:

Bump minor version, changelog update

Merging this change closes #18287

COPYBARA_INTEGRATE_REVIEW=#18287 from jaro-sevcik:deserialize-compile-options 2b9fbd9
PiperOrigin-RevId: 688990985
  • Loading branch information
jaro-sevcik authored and Google-ML-Automation committed Oct 23, 2024
1 parent 1b459fd commit 3a82522
Show file tree
Hide file tree
Showing 6 changed files with 75 additions and 2 deletions.
6 changes: 6 additions & 0 deletions xla/pjrt/c/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# PJRT C API changelog

## 0.56

* Added `overridden_serialized_compile_options` and
`overridden_serialized_compile_options_size` fields to
`PJRT_Executable_DeserializeAndLoad_Args`.

## 0.55
* Added types F8E4M3 and F8E3M4.

Expand Down
7 changes: 6 additions & 1 deletion xla/pjrt/c/pjrt_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_Extension_Base, next);
// Changes include:
// * Adding a new field to the PJRT_Api or argument structs
// * Renaming a method or argument (doesn't affect ABI)
#define PJRT_API_MINOR 55
#define PJRT_API_MINOR 56

// The plugin should set the major_version and minor_version of
// PJRT_Api.pjrt_api_version to be the `PJRT_API_MAJOR` and `PJRT_API_MINOR` in
Expand Down Expand Up @@ -1577,6 +1577,11 @@ struct PJRT_Executable_DeserializeAndLoad_Args {
const char* serialized_executable;
size_t serialized_executable_size;
PJRT_LoadedExecutable* loaded_executable; // out
// Serialized CompileOptionsProto or null (to use the options
// from the serialized executable).
// (https://github.com/openxla/xla/blob/main/xla/pjrt/compile_options.proto)
const char* overridden_serialized_compile_options;
size_t overridden_serialized_compile_options_size;
};
PJRT_DEFINE_STRUCT_TRAITS(PJRT_Executable_DeserializeAndLoad_Args,
loaded_executable);
Expand Down
13 changes: 12 additions & 1 deletion xla/pjrt/c/pjrt_c_api_wrapper_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1570,9 +1570,20 @@ PJRT_Error* PJRT_Executable_DeserializeAndLoad(
absl::string_view serialized(args->serialized_executable,
args->serialized_executable_size);

std::optional<xla::CompileOptions> overriden_options;

if (args->overridden_serialized_compile_options &&
args->overridden_serialized_compile_options_size > 0) {
PJRT_ASSIGN_OR_RETURN(
overriden_options,
ParseCompileOptions(absl::string_view(
args->overridden_serialized_compile_options,
args->overridden_serialized_compile_options_size)));
}

PJRT_ASSIGN_OR_RETURN(std::unique_ptr<xla::PjRtLoadedExecutable> executable,
args->client->client->DeserializeExecutable(
serialized, /*options=*/std::nullopt));
serialized, overriden_options));

args->loaded_executable =
new PJRT_LoadedExecutable(std::move(executable), args->client);
Expand Down
11 changes: 11 additions & 0 deletions xla/pjrt/pjrt_c_api_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,17 @@ PjRtCApiClient::DeserializeExecutable(absl::string_view serialized,
des_args.client = c_client_.get();
des_args.serialized_executable = serialized.data();
des_args.serialized_executable_size = serialized.length();
des_args.overridden_serialized_compile_options = nullptr;
des_args.overridden_serialized_compile_options_size = 0;

std::string options_str;
if (options) {
TF_ASSIGN_OR_RETURN(const CompileOptionsProto options_proto,
options->ToProto());
options_str = options_proto.SerializeAsString();
des_args.overridden_serialized_compile_options = options_str.c_str();
des_args.overridden_serialized_compile_options_size = options_str.size();
}

const PJRT_Api* api = pjrt_c_api();

Expand Down
37 changes: 37 additions & 0 deletions xla/pjrt/pjrt_c_api_client_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -196,5 +196,42 @@ TEST(PjRtClientTest, CompileUsesStableHloVersion) {
const_cast<PJRT_Api*>(c_api)->PJRT_Client_Compile = PJRT_Client_Compile_Orig;
}

TEST(PjRtClientTest, DeserializeExecutableWithDifferentDeviceAssignment) {
SetUpCpuPjRtApi();
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<PjRtClient> client,
GetCApiClient("cpu"));
ASSERT_GT(client->addressable_devices().size(), 1);

XlaBuilder builder("Identity");
Shape shape = ShapeUtil::MakeShape(S32, {2, 3});
auto input = Parameter(&builder, 0, shape, "input");
auto computation = builder.Build(input).value();

auto compile_options_for_device = [](int id) -> xla::CompileOptions {
xla::DeviceAssignment device_assignment(1, 1);
device_assignment(0, 0) = id;
xla::CompileOptions options;
options.executable_build_options.set_device_assignment(device_assignment);
return options;
};

// Compile the executable for device 0 and serialize it.
std::unique_ptr<PjRtLoadedExecutable> executable =
client->Compile(computation, compile_options_for_device(0)).value();
TF_ASSERT_OK_AND_ASSIGN(std::string serialized_executable,
executable->SerializeExecutable());

// Deserialize the executable for device 1.
TF_ASSERT_OK_AND_ASSIGN(
auto deserialized_executable,
client->DeserializeExecutable(serialized_executable,
compile_options_for_device(1)));

// Check that the executable's compile options were overridden
// with device id 1.
EXPECT_EQ(
deserialized_executable->addressable_devices()[0]->global_device_id(), 1);
}

} // namespace
} // namespace xla
3 changes: 3 additions & 0 deletions xla/pjrt/pjrt_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -605,6 +605,9 @@ class PjRtClient {
// Pending completion of b/237720161, `options` is a mandatory argument in
// most implementations of this interface. They _are_ optional for
// implementations related to the PJRT C API.
//
// If `options` are provided, then they override the compile options
// from the serialized executable (`serialized`).
virtual absl::StatusOr<std::unique_ptr<PjRtLoadedExecutable>>
DeserializeExecutable(absl::string_view serialized,
std::optional<CompileOptions> options) {
Expand Down

0 comments on commit 3a82522

Please sign in to comment.