Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[pjrt] Add PjRtClient::CompileAndLoad variant #23534

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions xla/pjrt/cpu/cpu_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -680,8 +680,8 @@ static absl::StatusOr<std::unique_ptr<xla::Executable>> JitCompile(
compile_options);
}

absl::StatusOr<std::unique_ptr<PjRtLoadedExecutable>> TfrtCpuClient::Compile(
mlir::ModuleOp module, CompileOptions options) {
absl::StatusOr<std::unique_ptr<PjRtLoadedExecutable>>
TfrtCpuClient::CompileAndLoad(mlir::ModuleOp module, CompileOptions options) {
XlaComputation xla_computation;
const ExecutableBuildOptions& exec_build_options =
options.executable_build_options;
Expand Down Expand Up @@ -728,8 +728,9 @@ absl::StatusOr<std::unique_ptr<PjRtLoadedExecutable>> TfrtCpuClient::Compile(
layout_callback, options);
}

absl::StatusOr<std::unique_ptr<PjRtLoadedExecutable>> TfrtCpuClient::Compile(
const XlaComputation& computation, CompileOptions options) {
absl::StatusOr<std::unique_ptr<PjRtLoadedExecutable>>
TfrtCpuClient::CompileAndLoad(const XlaComputation& computation,
CompileOptions options) {
std::vector<const Shape*> argument_layout_pointers;
const ExecutableBuildOptions& build_options =
options.executable_build_options;
Expand Down
4 changes: 2 additions & 2 deletions xla/pjrt/cpu/cpu_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,9 @@ class TfrtCpuClient final : public PjRtClient {
absl::StatusOr<std::unique_ptr<HloCostAnalysis>> GetHloCostAnalysis()
const override;

absl::StatusOr<std::unique_ptr<PjRtLoadedExecutable>> Compile(
absl::StatusOr<std::unique_ptr<PjRtLoadedExecutable>> CompileAndLoad(
const XlaComputation& computation, CompileOptions options) override;
absl::StatusOr<std::unique_ptr<PjRtLoadedExecutable>> Compile(
absl::StatusOr<std::unique_ptr<PjRtLoadedExecutable>> CompileAndLoad(
mlir::ModuleOp module, CompileOptions options) override;

// For TfrtCpuClient, `options` is mandatory.
Expand Down
7 changes: 4 additions & 3 deletions xla/pjrt/interpreter/interpreter_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -333,8 +333,8 @@ absl::StatusOr<Layout> InterpreterClient::GetDefaultLayout(
}

absl::StatusOr<std::unique_ptr<PjRtLoadedExecutable>>
InterpreterClient::Compile(const XlaComputation& computation,
CompileOptions options) {
InterpreterClient::CompileAndLoad(const XlaComputation& computation,
CompileOptions options) {
std::vector<const Shape*> argument_layout_pointers;
const ExecutableBuildOptions& build_options =
options.executable_build_options;
Expand All @@ -356,7 +356,8 @@ InterpreterClient::Compile(const XlaComputation& computation,
}

absl::StatusOr<std::unique_ptr<PjRtLoadedExecutable>>
InterpreterClient::Compile(mlir::ModuleOp module, CompileOptions options) {
InterpreterClient::CompileAndLoad(mlir::ModuleOp module,
CompileOptions options) {
XlaComputation xla_computation;
const ExecutableBuildOptions& exec_build_options =
options.executable_build_options;
Expand Down
4 changes: 2 additions & 2 deletions xla/pjrt/interpreter/interpreter_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -453,10 +453,10 @@ class InterpreterClient final : public PjRtClient {
return std::make_unique<HloCostAnalysis>(ShapeSizeBytes);
}

absl::StatusOr<std::unique_ptr<PjRtLoadedExecutable>> Compile(
absl::StatusOr<std::unique_ptr<PjRtLoadedExecutable>> CompileAndLoad(
const XlaComputation& computation, CompileOptions options) override;

absl::StatusOr<std::unique_ptr<PjRtLoadedExecutable>> Compile(
absl::StatusOr<std::unique_ptr<PjRtLoadedExecutable>> CompileAndLoad(
mlir::ModuleOp module, CompileOptions options) override;

absl::StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostLiteral(
Expand Down
9 changes: 5 additions & 4 deletions xla/pjrt/pjrt_c_api_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -384,16 +384,17 @@ InitializeArgsAndCompile(PjRtCApiClient* api_client, const PJRT_Api* c_api,
return ret;
}

absl::StatusOr<std::unique_ptr<PjRtLoadedExecutable>> PjRtCApiClient::Compile(
const XlaComputation& computation, CompileOptions options) {
absl::StatusOr<std::unique_ptr<PjRtLoadedExecutable>>
PjRtCApiClient::CompileAndLoad(const XlaComputation& computation,
CompileOptions options) {
std::string module_str = computation.proto().SerializeAsString();
std::string format(pjrt::kHloFormat);
return InitializeArgsAndCompile(this, c_api_, c_client_.get(), options,
module_str, format);
}

absl::StatusOr<std::unique_ptr<PjRtLoadedExecutable>> PjRtCApiClient::Compile(
mlir::ModuleOp module, CompileOptions options) {
absl::StatusOr<std::unique_ptr<PjRtLoadedExecutable>>
PjRtCApiClient::CompileAndLoad(mlir::ModuleOp module, CompileOptions options) {
if (!pjrt_c_api()) llvm::report_fatal_error("pjrt_c_api is null");

auto attributes = plugin_attributes()->attributes;
Expand Down
4 changes: 2 additions & 2 deletions xla/pjrt/pjrt_c_api_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -309,10 +309,10 @@ class PjRtCApiClient : public PjRtClient {
absl::StatusOr<Layout> GetDefaultLayout(
PrimitiveType element_type, absl::Span<const int64_t> dims) override;

absl::StatusOr<std::unique_ptr<PjRtLoadedExecutable>> Compile(
absl::StatusOr<std::unique_ptr<PjRtLoadedExecutable>> CompileAndLoad(
const XlaComputation& computation, CompileOptions options) override;

absl::StatusOr<std::unique_ptr<PjRtLoadedExecutable>> Compile(
absl::StatusOr<std::unique_ptr<PjRtLoadedExecutable>> CompileAndLoad(
mlir::ModuleOp module, CompileOptions options) override;

// `PjRtCApiClient::DeserializeExecutable()` ignores `CompileOptions` arg
Expand Down
10 changes: 9 additions & 1 deletion xla/pjrt/pjrt_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -596,12 +596,20 @@ class PjRtClient {
// Compile `computation` with given `options`.
virtual absl::StatusOr<std::unique_ptr<PjRtLoadedExecutable>> Compile(
const XlaComputation& computation, CompileOptions options) {
return Unimplemented("Compile with options is not supported.");
return CompileAndLoad(computation, options);
}
virtual absl::StatusOr<std::unique_ptr<PjRtLoadedExecutable>> CompileAndLoad(
const XlaComputation& computation, CompileOptions options) {
return Unimplemented("Compile with computation is not supported.");
}

// Variant of `Compile` that accepts an MLIR module.
virtual absl::StatusOr<std::unique_ptr<PjRtLoadedExecutable>> Compile(
mlir::ModuleOp module, CompileOptions options) {
return CompileAndLoad(module, options);
}
virtual absl::StatusOr<std::unique_ptr<PjRtLoadedExecutable>> CompileAndLoad(
mlir::ModuleOp module, CompileOptions options) {
return Unimplemented("Compile with MLIR Module is not supported.");
}

Expand Down
8 changes: 4 additions & 4 deletions xla/pjrt/pjrt_stream_executor_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3525,8 +3525,8 @@ PjRtStreamExecutorClient::CompileInternal(
}

absl::StatusOr<std::unique_ptr<PjRtLoadedExecutable>>
PjRtStreamExecutorClient::Compile(mlir::ModuleOp module,
CompileOptions options) {
PjRtStreamExecutorClient::CompileAndLoad(mlir::ModuleOp module,
CompileOptions options) {
XlaComputation xla_computation;
const ExecutableBuildOptions& exec_build_options =
options.executable_build_options;
Expand Down Expand Up @@ -3586,8 +3586,8 @@ PjRtStreamExecutorClient::Compile(mlir::ModuleOp module,
}

absl::StatusOr<std::unique_ptr<PjRtLoadedExecutable>>
PjRtStreamExecutorClient::Compile(const XlaComputation& computation,
CompileOptions options) {
PjRtStreamExecutorClient::CompileAndLoad(const XlaComputation& computation,
CompileOptions options) {
std::vector<const Shape*> argument_layout_pointers;
const ExecutableBuildOptions& build_options =
options.executable_build_options;
Expand Down
4 changes: 2 additions & 2 deletions xla/pjrt/pjrt_stream_executor_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -260,9 +260,9 @@ class PjRtStreamExecutorClient : public PjRtClient {
absl::StatusOr<Layout> GetDefaultLayout(
PrimitiveType element_type, absl::Span<const int64_t> dims) override;

absl::StatusOr<std::unique_ptr<PjRtLoadedExecutable>> Compile(
absl::StatusOr<std::unique_ptr<PjRtLoadedExecutable>> CompileAndLoad(
const XlaComputation& computation, CompileOptions options) override;
absl::StatusOr<std::unique_ptr<PjRtLoadedExecutable>> Compile(
absl::StatusOr<std::unique_ptr<PjRtLoadedExecutable>> CompileAndLoad(
mlir::ModuleOp mlir_module, CompileOptions options) override;

virtual absl::StatusOr<std::string> SerializeExecutable(
Expand Down
8 changes: 4 additions & 4 deletions xla/pjrt/tf_pjrt_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -244,13 +244,13 @@ class TfPjRtClient : public PjRtClient {
const override {
return wrapped_->GetHloCostAnalysis();
}
absl::StatusOr<std::unique_ptr<PjRtLoadedExecutable>> Compile(
absl::StatusOr<std::unique_ptr<PjRtLoadedExecutable>> CompileAndLoad(
const XlaComputation& computation, CompileOptions options) override {
return WrapExecutable(wrapped_->Compile(computation, options));
return WrapExecutable(wrapped_->CompileAndLoad(computation, options));
}
absl::StatusOr<std::unique_ptr<PjRtLoadedExecutable>> Compile(
absl::StatusOr<std::unique_ptr<PjRtLoadedExecutable>> CompileAndLoad(
mlir::ModuleOp module, CompileOptions options) override {
return WrapExecutable(wrapped_->Compile(std::move(module), options));
return WrapExecutable(wrapped_->CompileAndLoad(std::move(module), options));
}

absl::StatusOr<std::unique_ptr<PjRtLoadedExecutable>> DeserializeExecutable(
Expand Down
Loading