diff --git a/torch_xla/csrc/runtime/disc/BUILD b/torch_xla/csrc/runtime/disc/BUILD
index 999aa85ea64..96f0eb8d279 100755
--- a/torch_xla/csrc/runtime/disc/BUILD
+++ b/torch_xla/csrc/runtime/disc/BUILD
@@ -1,3 +1,4 @@
+load("@com_google_protobuf//:protobuf.bzl", "cc_proto_library")
 load(
     "@tsl//tsl/platform/default:cuda_build_defs.bzl",
     "if_cuda_is_configured",
@@ -37,7 +38,12 @@ ptxla_cc_library(
         "-DGOOGLE_CUDA",
     ]
 )
-
+cc_proto_library(
+    name = "disc_compiler_result_proto",
+    srcs = [
+        "compile_result.proto",
+    ],
+)
 ptxla_cc_library(
     name = "disc_utils",
     srcs = ["disc_utils.cc"],
@@ -58,6 +64,7 @@ ptxla_cc_library(
     deps = [
         ":disc_ral",
         ":disc_utils",
+        ":disc_compiler_result_proto",
        "//torch_xla/csrc/runtime:tf_logging",
        "//torch_xla/csrc/runtime:sys_util",
        "//torch_xla/csrc/runtime:env_vars",
@@ -79,3 +86,4 @@ ptxla_cc_test(
        "@tsl//tsl/platform:test_main",
     ]
 )
+
diff --git a/torch_xla/csrc/runtime/disc/compile_result.proto b/torch_xla/csrc/runtime/disc/compile_result.proto
new file mode 100644
index 00000000000..acb9d18b87c
--- /dev/null
+++ b/torch_xla/csrc/runtime/disc/compile_result.proto
@@ -0,0 +1,17 @@
+syntax = "proto3";
+
+package torch_xla.runtime.disc;
+
+option cc_enable_arenas = true;
+
+message DataSpec {
+  string device = 1;
+  int32 dtype = 2;
+}
+message DISCCompileResult {
+  bytes ral_library = 1;
+  bytes ral_meta_pb = 2;
+  repeated DataSpec input_specs = 3;
+  repeated DataSpec output_specs = 4;
+  repeated string devices = 5;
+}
diff --git a/torch_xla/csrc/runtime/disc/disc_compile.cc b/torch_xla/csrc/runtime/disc/disc_compile.cc
index 053535f5e2e..2037efa0f55 100644
--- a/torch_xla/csrc/runtime/disc/disc_compile.cc
+++ b/torch_xla/csrc/runtime/disc/disc_compile.cc
@@ -4,6 +4,7 @@
 
 #include <filesystem>
 
+#include "torch_xla/csrc/runtime/disc/compile_result.pb.h"
 #include "torch_xla/csrc/runtime/env_vars.h"
 #include "torch_xla/csrc/runtime/sys_util.h"
 #include "torch_xla/csrc/runtime/tf_logging.h"
@@ -98,7 +99,6 @@ DISCComplationResult Compile(mlir::ModuleOp &module,
   res.ral_mate_pb = ReadFileBytes(absl::StrCat(output_fname, ".pbtxt"));
   res.inputs = inputs;
   res.outputs = outputs;
-
   return res;
 }
diff --git a/torch_xla/csrc/runtime/disc/disc_ral.h b/torch_xla/csrc/runtime/disc/disc_ral.h
index f47431689c5..b850a3c6ef6 100644
--- a/torch_xla/csrc/runtime/disc/disc_ral.h
+++ b/torch_xla/csrc/runtime/disc/disc_ral.h
@@ -33,6 +33,7 @@ class RalContext {
   ~RalContext();
 
   std::vector<at::Tensor> Execute(const std::vector<at::Tensor>& inputs);
+  DISCComplationResult GetDiscResult() { return disc_result_; }
 
  private:
   void BindingInputs(const std::vector<at::Tensor>& inputs,
diff --git a/torch_xla/csrc/runtime/disc_computation_client.cc b/torch_xla/csrc/runtime/disc_computation_client.cc
index 6465551dbde..44271bd1227 100644
--- a/torch_xla/csrc/runtime/disc_computation_client.cc
+++ b/torch_xla/csrc/runtime/disc_computation_client.cc
@@ -16,10 +16,12 @@
 #include "mlir/IR/Operation.h"  // from @llvm-project
 #include "mlir/Pass/Pass.h"  // from @llvm-project
 #include "torch_xla/csrc/runtime/computation_client.h"
+#include "torch_xla/csrc/runtime/disc/compile_result.pb.h"
 #include "torch_xla/csrc/runtime/disc/disc_compile.h"
 #include "torch_xla/csrc/runtime/env_vars.h"
 #include "torch_xla/csrc/runtime/stablehlo_helper.h"
 #include "torch_xla/csrc/runtime/sys_util.h"
+#include "xla/client/xla_computation.h"
 #include "xla/hlo/ir/hlo_module.h"
 #include "xla/service/float_normalization.h"
 #include "xla/service/gpu/gpu_float_support.h"
@@ -136,8 +138,6 @@ std::vector<ComputationClient::DataPtr> DISCComputationClient::TransferToDevice(
   auto dtype =
       at::TensorOptions(TorchTypeFromXlaType(tensor->shape().element_type()));
   auto ret = at::empty(sizes, dtype).contiguous();
-  // tensor->populate_fn(tensor, ret.data_ptr(),
-  //                     ret.element_size() * ret.numel());
   std::memcpy(ret.data_ptr(), tensor->data(),
               ret.element_size() * ret.numel());
 
@@ -406,5 +406,78 @@ int DISCComputationClient::GetProcessIndex() const { return local_rank_; }
 
 int DISCComputationClient::GetNumProcesses() const { return world_size_; }
 
+// Serializes a DISC computation as "<HloModuleProto>:::<DISCCompileResult>".
+// Both halves are binary protobuf payloads; ":::" is the separator that
+// DeserializeComputation splits on.
+std::string DISCComputationClient::SerializeComputation(
+    const ComputationPtr computation) {
+  auto client = dynamic_cast<DISCComputation*>(computation.get());
+  auto hlo_proto = client->computation().proto();
+  auto result = client->executable->GetDiscResult();
+  torch_xla::runtime::disc::DISCCompileResult result_pb;
+  result_pb.set_ral_library(result.ral_lib);
+  result_pb.set_ral_meta_pb(result.ral_mate_pb);
+  for (const auto& input : result.inputs) {
+    auto* data_meta = result_pb.add_input_specs();
+    data_meta->set_device(input.device);
+    data_meta->set_dtype(static_cast<int32_t>(input.scalar_type));
+  }
+  for (const auto& output : result.outputs) {
+    auto* data_meta = result_pb.add_output_specs();
+    data_meta->set_device(output.device);
+    data_meta->set_dtype(static_cast<int32_t>(output.scalar_type));
+  }
+  for (const auto& device : computation->devices()) {
+    result_pb.add_devices(device);
+  }
+  return absl::StrCat(hlo_proto.SerializeAsString(),
+                      ":::", result_pb.SerializeAsString());
+}
+
+// Rebuilds a DISCComputation (HLO proto + RAL execution context) from the
+// string produced by SerializeComputation above.
+ComputationClient::ComputationPtr DISCComputationClient::DeserializeComputation(
+    const std::string& serialized) {
+  std::vector<std::string> parts = absl::StrSplit(serialized, ":::");
+  if (parts.size() != 2) {
+    XLA_ERROR() << "Invalid serialized computation, should have 2 parts with "
+                   "separator ':::', got "
+                << parts.size();
+  }
+  // Protobuf parsing is limited to int-sized payloads; reject anything larger.
+  if (parts[1].size() > std::numeric_limits<int32_t>::max()) {
+    XLA_ERROR() << "Serialized DISCCompileResult proto too large (>2GB)\n";
+  }
+  xla::HloModuleProto hlo_proto;
+  disc::DISCCompileResult result_proto;
+  if (!hlo_proto.ParseFromString(parts[0]) ||
+      !result_proto.ParseFromString(parts[1])) {
+    XLA_ERROR() << "Failed to parse serialized computation protos";
+  }
+
+  disc::DISCComplationResult compile_result;
+  compile_result.ral_lib = result_proto.ral_library();
+  compile_result.ral_mate_pb = result_proto.ral_meta_pb();
+  for (const auto& input : result_proto.input_specs()) {
+    disc::DataMeta data_meta;
+    data_meta.device = input.device();
+    data_meta.scalar_type = static_cast<at::ScalarType>(input.dtype());
+    compile_result.inputs.push_back(data_meta);
+  }
+  for (const auto& output : result_proto.output_specs()) {
+    disc::DataMeta data_meta;
+    data_meta.device = output.device();
+    data_meta.scalar_type = static_cast<at::ScalarType>(output.dtype());
+    compile_result.outputs.push_back(data_meta);
+  }
+  std::vector<std::string> devices;
+  for (const auto& device : result_proto.devices()) {
+    devices.push_back(device);
+  }
+
+  auto ral_context = std::make_unique<disc::RalContext>(compile_result);
+  auto computation = std::make_shared<DISCComputation>(
+      xla::XlaComputation(hlo_proto), devices, std::move(ral_context));
+  return computation;
+}
+
 }  // namespace runtime
 }  // namespace torch_xla
diff --git a/torch_xla/csrc/runtime/disc_computation_client.h b/torch_xla/csrc/runtime/disc_computation_client.h
index 0701d7b3591..f95a184a41b 100644
--- a/torch_xla/csrc/runtime/disc_computation_client.h
+++ b/torch_xla/csrc/runtime/disc_computation_client.h
@@ -11,6 +11,8 @@ namespace runtime {
 
 class DISCComputationClient : public ComputationClient {
  public:
+  const std::string DefaultDevicePrefix = "CUDA:";
+
   DISCComputationClient();
 
   ~DISCComputationClient();
@@ -55,15 +57,10 @@ class DISCComputationClient : public ComputationClient {
     XLA_ERROR() << __FUNCTION__ << " not implemented";
   }
 
-  std::string SerializeComputation(const ComputationPtr computation) override {
-    XLA_ERROR() << __FUNCTION__ << " not implemented";
-  }
-
-  ComputationPtr DeserializeComputation(
-      const std::string& serialized) override {
-    XLA_ERROR() << __FUNCTION__ << " not implemented";
-  }
+  std::string SerializeComputation(const ComputationPtr computation) override;
+  ComputationClient::ComputationPtr DeserializeComputation(
+      const std::string& serialized) override;
 
   torch::lazy::hash_t HashCompilationEnv() override {
     // TODO(wangang.wa): Improve this function.
     return torch::lazy::hash_t();