XLA: vendor the runtime mlir backend
wsmoses committed Mar 9, 2024
1 parent abd21e3 commit 513492f
Showing 2 changed files with 90 additions and 5 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
@@ -34,7 +34,7 @@ jobs:
with:
path: "~/.cache/bazel"
key: bazel-${{ matrix.os }}
- run: find ~/.cache/bazel ~/.cache/bazelisk -iname "*.whl" -exec rm {} \;
- run: sudo find ~/.cache/bazel ~/.cache/bazelisk -iname "*.whl" -exec rm {} \;
- run: |
bazel build :enzyme_ad @llvm-project//llvm:FileCheck
bazel cquery "allpaths(//src/enzyme_ad/jax:enzyme_call,@xla//xla/stream_executor:executor_cache)" --notool_deps
93 changes: 89 additions & 4 deletions src/enzyme_ad/jax/compile_with_xla.cc
@@ -2,7 +2,15 @@
#include "xla/service/service.h"
#undef protected

// Needed to access CompileXlaRuntimeCpuExecutable/etc
#define private public
#include "xla/service/cpu/cpu_compiler.h"
#undef private

#include "xla/service/compiler.h"
#include "xla/service/cpu/cpu_executable.h"
#include "xla/service/hlo_module_util.h"
#include "xla/service/hlo_proto_util.h"
#include "xla/service/local_service_utils.h"

#include "absl/status/statusor.h"
@@ -27,6 +35,8 @@
#include "xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h"
#include "xla/translate/mhlo_to_hlo/type_to_shape.h"

#include "xla/statusor.h"

#include "pybind11/pybind11.h"

#include "compile_with_xla.h"
@@ -161,6 +171,80 @@ run_pass_pipeline(const std::vector<std::string> &oldsym_vec,
return std::make_pair(entryfn.str(), ss.str());
}

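// Vendored/adapted copy of xla::cpu::CpuCompiler::RunBackend: the compiler is
// passed explicitly as `self` so the private Compile*CpuExecutable helpers
// (reachable via the access-override #define above) can be called, and the
// added xla_runtime flag picks between the MLIR runtime and legacy CPU
// executables.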
absl::StatusOr<std::unique_ptr<xla::Executable>>
RunBackend(xla::cpu::CpuCompiler *self, std::unique_ptr<xla::HloModule> module,
[[maybe_unused]] xla::se::StreamExecutor *stream_exec,
const xla::Compiler::CompileOptions &options, bool xla_runtime) {

std::unique_ptr<xla::cpu::CpuExecutable> cpu_executable;
if (xla_runtime) {
TF_ASSIGN_OR_RETURN(cpu_executable,
self->CompileXlaRuntimeCpuExecutable(std::move(module),
options.registry));
} else {
TF_ASSIGN_OR_RETURN(cpu_executable,
self->CompileLegacyCpuExecutable(std::move(module)));
}

return std::unique_ptr<xla::Executable>(std::move(cpu_executable));
}

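// Vendored/adapted copy of xla::Service::BuildExecutable: same
// proto-to-HloModule flow (CreateModuleFromProto, entry-layout update,
// optional RunHloPasses), except that the backend is invoked through the
// RunBackend shim above so the xla_runtime choice reaches the CPU compiler.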
absl::StatusOr<std::unique_ptr<xla::Executable>>
BuildExecutable(xla::Service *self, const xla::HloModuleProto &module_proto,
std::unique_ptr<xla::HloModuleConfig> module_config,
xla::Backend *backend, xla::se::StreamExecutor *executor,
const xla::Compiler::CompileOptions &options,
bool run_backend_only, bool xla_runtime) {

TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::HloModule> module,
xla::CreateModuleFromProto(module_proto, *module_config,
run_backend_only));
xla::UpdateEntryComputationLayout(
module.get(), std::bind(&xla::Compiler::DefaultDeviceShapeRepresentation,
backend->compiler(), std::placeholders::_1));
// xla::DumpHloModuleIfEnabled(*module, xla::kBeforeOptimizationsDumpName);

std::unique_ptr<xla::HloProto> hlo_proto_before_opt;
if (!run_backend_only) {
// Save proto state before optimizations if we want a snapshot.
// When run_backend_only is enabled the post-optimization HLO will be the
// same as the pre-optimization HLO.
// if (xla::DumpingEnabledForHloModule(*module)) {
// hlo_proto_before_opt =
// std::make_unique<xla::HloProto>(MakeHloProto(*module));
// }
TF_ASSIGN_OR_RETURN(module, backend->compiler()->RunHloPasses(
std::move(module), executor, options));
}

/*
TF_ASSIGN_OR_RETURN(
std::unique_ptr<xla::Executable> executable,
backend->compiler()->RunBackend(std::move(module), executor, options));
*/
TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Executable> executable,
RunBackend((xla::cpu::CpuCompiler *)backend->compiler(),
std::move(module), executor, options,
xla_runtime));

const xla::BufferAssignmentProto *buffer_assignment_proto_after_opt =
executable->buffer_assignment_proto();

// If dumping is enabled RunBackend(...) will emit a hlo_proto in the
// executable. This contains the buffer_assignment that is only available
// after RunBackend(). If hlo_proto_before_opt is not null, then we replace
// its buffer_assignment with the one from after_opt and then store it into
// the executable.
if (hlo_proto_before_opt != nullptr &&
buffer_assignment_proto_after_opt != nullptr) {
// CHECK(xla::DumpingEnabledForHloModule(executable->module()));
*hlo_proto_before_opt->mutable_buffer_assignment() =
std::move(*buffer_assignment_proto_after_opt);
executable->set_hlo_proto(std::move(hlo_proto_before_opt));
}
return std::move(executable);
}

// Compile an MHLO module given as a string to LLVM IR using XLA.
std::unique_ptr<xla::LocalExecutable>
compile_mhlo_to_llvm_with_xla(llvm::StringRef mhlo_text, std::string &output,
@@ -288,10 +372,11 @@ compile_mhlo_to_llvm_with_xla(llvm::StringRef mhlo_text, std::string &output,
build_options.device_allocator(), build_options.compile_thread_pool(),
build_options.layout_canonicalization_callback()};
opts.registry = &registry;
auto executable = local_client->local_service()->BuildExecutable(
xla_computation.proto(), std::move(module_config_or_error.value()),
local_client->mutable_backend(), executor.value(), opts,
build_options.run_backend_only());
auto executable =
BuildExecutable(local_client->local_service(), xla_computation.proto(),
std::move(module_config_or_error.value()),
local_client->mutable_backend(), executor.value(), opts,
build_options.run_backend_only(), xla_runtime);
if (!executable.ok()) {
throw pybind11::value_error(executable.status().ToString());
}
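
The #define private public / #undef private lines near the top of this file are an include-time access override: the preprocessor rewrites the access-specifier tokens before the compiler parses cpu_compiler.h, so this translation unit can call CpuCompiler's otherwise-private Compile*CpuExecutable helpers without patching upstream XLA headers. Redefining a keyword this way is formally non-conforming C++, but it compiles on common toolchains and is a familiar vendoring shortcut. A minimal, self-contained sketch with a hypothetical Widget class (illustrative only, not from this repository):

#include <iostream>

#define private public  // rewrites the access specifier in everything defined/included below
class Widget {          // stand-in for an upstream header such as cpu_compiler.h
public:
  Widget() = default;
private:                // preprocessed to 'public:', so the member below becomes accessible
  int internal_state_ = 42;
};
#undef private          // restore the normal keyword for the rest of the file

int main() {
  Widget w;
  std::cout << w.internal_state_ << "\n";  // compiles only because of the macro override
  return 0;
}

Applied to the real header, the same pattern is what lets the vendored RunBackend above reach CompileXlaRuntimeCpuExecutable and CompileLegacyCpuExecutable.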
