diff --git a/xla/service/BUILD b/xla/service/BUILD index e72fff7038d06c..796215cb8a735e 100644 --- a/xla/service/BUILD +++ b/xla/service/BUILD @@ -2572,6 +2572,7 @@ cc_library( srcs = ["triangular_solve_expander.cc"], hdrs = ["triangular_solve_expander.h"], deps = [ + ":hlo_creation_utils", ":hlo_module_config", ":op_expander_pass", "//xla:shape_util", @@ -2614,6 +2615,7 @@ cc_library( srcs = ["cholesky_expander.cc"], hdrs = ["cholesky_expander.h"], deps = [ + ":hlo_creation_utils", ":op_expander_pass", "//xla:literal", "//xla:shape_util", @@ -2637,6 +2639,7 @@ cc_library( srcs = ["qr_expander.cc"], hdrs = ["qr_expander.h"], deps = [ + ":hlo_creation_utils", ":op_expander_pass", "//xla:literal", "//xla:shape_util", @@ -2692,6 +2695,7 @@ cc_library( srcs = ["eigh_expander.cc"], hdrs = ["eigh_expander.h"], deps = [ + ":hlo_creation_utils", ":op_expander_pass", "//xla:literal_util", "//xla:shape_util", @@ -3068,6 +3072,7 @@ cc_library( srcs = ["bitcast_dtypes_expander.cc"], hdrs = ["bitcast_dtypes_expander.h"], deps = [ + ":hlo_creation_utils", ":hlo_module_config", ":op_expander_pass", "//xla:literal_util", @@ -7477,6 +7482,7 @@ cc_library( srcs = ["rng_bit_generator_expander.cc"], hdrs = ["rng_bit_generator_expander.h"], deps = [ + ":hlo_creation_utils", ":op_expander_pass", "//xla:shape_util", "//xla:util", @@ -7591,6 +7597,7 @@ cc_library( srcs = ["topk_rewriter.cc"], hdrs = ["topk_rewriter.h"], deps = [ + ":hlo_creation_utils", ":pattern_matcher", "//xla:shape_util", "//xla:util", diff --git a/xla/service/bitcast_dtypes_expander.cc b/xla/service/bitcast_dtypes_expander.cc index 6a5d12c25c32fa..e12f69f56bc4e2 100644 --- a/xla/service/bitcast_dtypes_expander.cc +++ b/xla/service/bitcast_dtypes_expander.cc @@ -26,6 +26,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/primitive_util.h" +#include "xla/service/hlo_creation_utils.h" #include "xla/service/hlo_module_config.h" #include "xla/shape.h" #include "xla/shape_util.h" @@ -104,14 +105,8 @@ absl::StatusOr BitcastDtypesExpander::ExpandInstruction( BitcastConvertType(input, to_shape.element_type()); TF_ASSIGN_OR_RETURN(XlaComputation xla_computation, b.Build()); - TF_ASSIGN_OR_RETURN(ProgramShape program_shape, - xla_computation.GetProgramShape()); - HloModuleConfig config(program_shape); - TF_ASSIGN_OR_RETURN(auto new_module, HloModule::CreateFromProto( - xla_computation.proto(), config)); - HloCloneContext context(module); - computation = - module->DeepCloneComputation(new_module->entry_computation(), &context); + TF_ASSIGN_OR_RETURN( + computation, XlaComputationToHloComputation(xla_computation, module)); } return instruction->parent()->AddInstruction(HloInstruction::CreateCall( diff --git a/xla/service/cholesky_expander.cc b/xla/service/cholesky_expander.cc index d70e0211103fff..d5a3053e7168a3 100644 --- a/xla/service/cholesky_expander.cc +++ b/xla/service/cholesky_expander.cc @@ -28,6 +28,7 @@ limitations under the License. #include "xla/client/xla_builder.h" #include "xla/literal.h" #include "xla/primitive_util.h" +#include "xla/service/hlo_creation_utils.h" #include "xla/shape_util.h" #include "xla/status_macros.h" #include "xla/util.h" @@ -248,15 +249,8 @@ absl::StatusOr CholeskyExpander::ExpandInstruction( MaybeTransposeInMinorDims(l, !options.lower()); TF_ASSIGN_OR_RETURN(XlaComputation xla_computation, builder.Build()); - - TF_ASSIGN_OR_RETURN(ProgramShape program_shape, - xla_computation.GetProgramShape()); - HloModuleConfig config(program_shape); - TF_ASSIGN_OR_RETURN(auto new_module, HloModule::CreateFromProto( - xla_computation.proto(), config)); - HloCloneContext context(module); - computation = - module->DeepCloneComputation(new_module->entry_computation(), &context); + TF_ASSIGN_OR_RETURN( + computation, XlaComputationToHloComputation(xla_computation, module)); } return instruction->parent()->AddInstruction(HloInstruction::CreateCall( diff --git a/xla/service/eigh_expander.cc b/xla/service/eigh_expander.cc index e95b268c1f3d8b..e34dbbf96d22c0 100644 --- a/xla/service/eigh_expander.cc +++ b/xla/service/eigh_expander.cc @@ -34,6 +34,7 @@ limitations under the License. #include "xla/client/xla_builder.h" #include "xla/literal_util.h" #include "xla/primitive_util.h" +#include "xla/service/hlo_creation_utils.h" #include "xla/shape_util.h" #include "xla/status_macros.h" #include "xla/util.h" @@ -582,15 +583,8 @@ absl::StatusOr EighExpander::ExpandInstruction( } XlaOp result = BuildEigh(a, lower, max_iter, tol, sort_eigenvalues); TF_ASSIGN_OR_RETURN(XlaComputation xla_computation, builder.Build(result)); - - TF_ASSIGN_OR_RETURN(ProgramShape program_shape, - xla_computation.GetProgramShape()); - HloModuleConfig config(program_shape); - TF_ASSIGN_OR_RETURN(auto new_module, HloModule::CreateFromProto( - xla_computation.proto(), config)); - HloCloneContext context(module); - computation = - module->DeepCloneComputation(new_module->entry_computation(), &context); + TF_ASSIGN_OR_RETURN( + computation, XlaComputationToHloComputation(xla_computation, module)); } return instruction->parent()->AddInstruction(HloInstruction::CreateCall( diff --git a/xla/service/gpu/transforms/BUILD b/xla/service/gpu/transforms/BUILD index ed276412453607..d81d7b0478ecba 100644 --- a/xla/service/gpu/transforms/BUILD +++ b/xla/service/gpu/transforms/BUILD @@ -1191,6 +1191,7 @@ cc_library( "//xla/client:xla_computation", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", + "//xla/service:hlo_creation_utils", "//xla/service:hlo_module_config", "//xla/service/gpu:backend_configs_cc", "//xla/service/gpu:cublas_cudnn", diff --git a/xla/service/gpu/transforms/cudnn_vectorize_convolutions.cc b/xla/service/gpu/transforms/cudnn_vectorize_convolutions.cc index 698b8fb73dd579..d35fcd105844d8 100644 --- a/xla/service/gpu/transforms/cudnn_vectorize_convolutions.cc +++ b/xla/service/gpu/transforms/cudnn_vectorize_convolutions.cc @@ -40,6 +40,7 @@ limitations under the License. #include "xla/service/gpu/cublas_cudnn.h" #include "xla/service/gpu/cudnn_support_utils.h" #include "xla/service/gpu/stream_executor_util.h" +#include "xla/service/hlo_creation_utils.h" #include "xla/service/hlo_module_config.h" #include "xla/shape.h" #include "xla/shape_util.h" @@ -96,24 +97,6 @@ static std::vector GetRelevantConvs( return convs; } -// Converts an XlaBuilder into an HloComputation in the same module as -// `sibling_computation`. -// -// Yes, we serialize/deserialize as a proto. :) -static absl::StatusOr BuilderToHloComputation( - XlaBuilder& b, XlaOp root, HloComputation* sibling_computation) { - TF_ASSIGN_OR_RETURN(XlaComputation comp, b.Build(root)); - TF_ASSIGN_OR_RETURN(ProgramShape program_shape, comp.GetProgramShape()); - HloModuleConfig config(program_shape); - TF_ASSIGN_OR_RETURN(auto new_module, - HloModule::CreateFromProto(comp.proto(), config)); - - HloModule* dest_module = sibling_computation->parent(); - HloCloneContext context(dest_module); - return dest_module->DeepCloneComputation(new_module->entry_computation(), - &context); -} - // Reshapes `instr` so that it has an extra dimension of size `vect_size` right // after `dim`. static XlaOp SplitAtDim(XlaOp instr, int64_t dim, int64_t vect_size) { @@ -460,11 +443,11 @@ static absl::StatusOr TryRevectorizeConv( new_conv_result, dnums->output_feature_dimension(), *output_vect_dim, /*orig_vect_size=*/output_shape.dimensions(*output_vect_dim)); + XlaOp root = Tuple(&b, {new_conv_result_unrevectorized, new_conv_scratch}); + TF_ASSIGN_OR_RETURN(XlaComputation comp, b.Build(root)); TF_ASSIGN_OR_RETURN( HloComputation * new_conv_comp, - BuilderToHloComputation( - b, Tuple(&b, {new_conv_result_unrevectorized, new_conv_scratch}), - conv->parent())); + XlaComputationToHloComputation(comp, conv->parent()->parent())); // Set the name on the new conv. This is purely cosmetic, but we attempt to // preserve e.g. "cudnn-conv.42" instead of "custom-call.42". @@ -599,11 +582,11 @@ static absl::StatusOr TryVectorizeConv( Collapse(new_conv_result, {dnums->output_feature_dimension(), dnums->output_feature_dimension() + 1}); + XlaOp root = Tuple(&b, {conv_result_collapsed, new_conv_scratch}); + TF_ASSIGN_OR_RETURN(XlaComputation comp, b.Build(root)); TF_ASSIGN_OR_RETURN( HloComputation * new_conv_comp, - BuilderToHloComputation( - b, Tuple(&b, {conv_result_collapsed, new_conv_scratch}), - conv->parent())); + XlaComputationToHloComputation(comp, conv->parent()->parent())); // Create a tuple and replace the old conv with it! VLOG(1) << "Vectorized conv to: " << new_conv_comp->ToString(); diff --git a/xla/service/hlo_creation_utils.cc b/xla/service/hlo_creation_utils.cc index a94e23d21066e5..c9b5d5b2be361e 100644 --- a/xla/service/hlo_creation_utils.cc +++ b/xla/service/hlo_creation_utils.cc @@ -597,12 +597,22 @@ HloInstruction* MaybeMakeTuple(absl::Span operands) { HloInstruction::CreateTuple(operands)); } +absl::StatusOr XlaComputationToHloComputation( + XlaComputation& src_comp, HloModule* dest_module) { + TF_ASSIGN_OR_RETURN(ProgramShape program_shape, src_comp.GetProgramShape()); + HloModuleConfig config(program_shape); + TF_ASSIGN_OR_RETURN(auto new_module, + HloModule::CreateFromProto(src_comp.proto(), config)); + HloCloneContext context(dest_module); + return dest_module->DeepCloneComputation(new_module->entry_computation(), + &context); +} + absl::StatusOr MakeSortHlo( const Shape& sort_shape, absl::Span operands, int64_t dimension_to_sort, bool is_stable, HloComputation::Builder* builder, HloModule* module, const OpMetadata* metadata) { CHECK(!operands.empty()) << "Sort Hlo requires at least one operand."; - HloComputation* compare_computation; XlaBuilder b("Sort.Compare"); if (metadata != nullptr) { b.SetOpMetadata(*metadata); @@ -612,13 +622,8 @@ absl::StatusOr MakeSortHlo( operand_types[i] = operands[i]->shape().element_type(); } XlaComputation comparator = CreateScalarLtComputation(operand_types, &b); - TF_ASSIGN_OR_RETURN(ProgramShape program_shape, comparator.GetProgramShape()); - HloModuleConfig config(program_shape); - TF_ASSIGN_OR_RETURN(auto new_module, - HloModule::CreateFromProto(comparator.proto(), config)); - HloCloneContext context(module); - compare_computation = - module->DeepCloneComputation(new_module->entry_computation(), &context); + TF_ASSIGN_OR_RETURN(HloComputation * compare_computation, + XlaComputationToHloComputation(comparator, module)); return builder->AddInstruction(HloInstruction::CreateSort( sort_shape, dimension_to_sort, operands, compare_computation, is_stable)); } diff --git a/xla/service/hlo_creation_utils.h b/xla/service/hlo_creation_utils.h index 2db4a7045fc0e2..d9599663ea7fea 100644 --- a/xla/service/hlo_creation_utils.h +++ b/xla/service/hlo_creation_utils.h @@ -22,6 +22,7 @@ limitations under the License. #include #include "absl/types/span.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/literal_util.h" @@ -257,6 +258,11 @@ absl::StatusOr MakeSelectHlo( // instruction with all the operands. Crashes if `operands` is empty. HloInstruction* MaybeMakeTuple(absl::Span operands); +// Creates a HloComputation in the destination module from a builder's +// XlaComputation. +absl::StatusOr XlaComputationToHloComputation( + XlaComputation& src_comp, HloModule* dest_module); + // Creates a Sort HLO instruction and adds it to the computation containing the // operands. All operands must be in the same computation. Also creates a // default compare sub-computation which sorts the first operand into ascending diff --git a/xla/service/qr_expander.cc b/xla/service/qr_expander.cc index 4f79769d7c6bf8..7f32a0bc1628bd 100644 --- a/xla/service/qr_expander.cc +++ b/xla/service/qr_expander.cc @@ -29,6 +29,7 @@ limitations under the License. #include "xla/client/xla_builder.h" #include "xla/literal.h" #include "xla/primitive_util.h" +#include "xla/service/hlo_creation_utils.h" #include "xla/shape_util.h" #include "xla/status_macros.h" #include "xla/util.h" @@ -551,15 +552,8 @@ absl::StatusOr QrExpander::ExpandInstruction( } TF_ASSIGN_OR_RETURN(XlaComputation xla_computation, builder.Build(result)); - - TF_ASSIGN_OR_RETURN(ProgramShape program_shape, - xla_computation.GetProgramShape()); - HloModuleConfig config(program_shape); - TF_ASSIGN_OR_RETURN(auto new_module, HloModule::CreateFromProto( - xla_computation.proto(), config)); - HloCloneContext context(module); - computation = - module->DeepCloneComputation(new_module->entry_computation(), &context); + TF_ASSIGN_OR_RETURN( + computation, XlaComputationToHloComputation(xla_computation, module)); } return instruction->parent()->AddInstruction(HloInstruction::CreateCall( diff --git a/xla/service/rng_bit_generator_expander.cc b/xla/service/rng_bit_generator_expander.cc index 0d78762f47b964..88758e3c0b3667 100644 --- a/xla/service/rng_bit_generator_expander.cc +++ b/xla/service/rng_bit_generator_expander.cc @@ -23,6 +23,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/service/hlo_creation_utils.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/util.h" @@ -86,15 +87,8 @@ RngBitGeneratorExpander::GetGeneratorComputation(const Shape& data_shape, ConcatInDim(&builder, {Reshape(key_op, {1}), output.state}, 0); Tuple(&builder, {final_state, output.value}); TF_ASSIGN_OR_RETURN(XlaComputation xla_computation, builder.Build()); - - TF_ASSIGN_OR_RETURN(ProgramShape program_shape, - xla_computation.GetProgramShape()); - HloModuleConfig config(program_shape); - TF_ASSIGN_OR_RETURN(auto new_module, HloModule::CreateFromProto( - xla_computation.proto(), config)); - HloCloneContext context(module); - HloComputation* new_computation = - module->DeepCloneComputation(new_module->entry_computation(), &context); + TF_ASSIGN_OR_RETURN(HloComputation * new_computation, + XlaComputationToHloComputation(xla_computation, module)); computation_cache_.emplace(cache_key, new_computation); return new_computation; } diff --git a/xla/service/rng_expander.cc b/xla/service/rng_expander.cc index cbc5a1d4549db9..294916f8fb68a9 100644 --- a/xla/service/rng_expander.cc +++ b/xla/service/rng_expander.cc @@ -111,16 +111,7 @@ absl::StatusOr GetComputationForRng(HloInstruction* rng) { } TF_ASSIGN_OR_RETURN(XlaComputation xla_computation, builder.Build()); - - TF_ASSIGN_OR_RETURN(ProgramShape program_shape, - xla_computation.GetProgramShape()); - HloModuleConfig config(program_shape); - TF_ASSIGN_OR_RETURN(auto new_module, HloModule::CreateFromProto( - xla_computation.proto(), config)); - HloModule* module = rng->GetModule(); - HloCloneContext context(module); - return module->DeepCloneComputation(new_module->entry_computation(), - &context); + return XlaComputationToHloComputation(xla_computation, rng->GetModule()); } } // namespace diff --git a/xla/service/spmd/BUILD b/xla/service/spmd/BUILD index 6c2a9321819535..fccd701ab5277c 100644 --- a/xla/service/spmd/BUILD +++ b/xla/service/spmd/BUILD @@ -61,6 +61,7 @@ cc_library( "//xla/service:custom_call_sharding_helper", "//xla/service:dot_as_convolution_util", "//xla/service:flatten_call_graph", + "//xla/service:hlo_creation_utils", "//xla/service:hlo_cse", "//xla/service:hlo_dce", "//xla/service:hlo_lexer", diff --git a/xla/service/spmd/custom_call_handler.cc b/xla/service/spmd/custom_call_handler.cc index dab26f5985a0c5..018a6f7f444337 100644 --- a/xla/service/spmd/custom_call_handler.cc +++ b/xla/service/spmd/custom_call_handler.cc @@ -43,6 +43,7 @@ limitations under the License. #include "xla/hlo/utils/hlo_sharding_util.h" #include "xla/literal_util.h" #include "xla/service/custom_call_sharding_helper.h" +#include "xla/service/hlo_creation_utils.h" #include "xla/service/hlo_lexer.h" #include "xla/service/hlo_module_config.h" #include "xla/service/host_memory_offload_annotations.h" @@ -207,13 +208,8 @@ absl::Status SpmdPartitioningVisitor::HandleCustomCallTopK( XlaComputation comparator = CreateScalarComparisonComputation( "compare-value-and-index", {input->shape().element_type(), S32}, {Gt, Lt}, &b); - TF_ASSIGN_OR_RETURN(ProgramShape program_shape, comparator.GetProgramShape()); - HloModuleConfig config(program_shape); - TF_ASSIGN_OR_RETURN(auto new_module, - HloModule::CreateFromProto(comparator.proto(), config)); - HloCloneContext context(module_); - auto compare_computation = - module_->DeepCloneComputation(new_module->entry_computation(), &context); + TF_ASSIGN_OR_RETURN(HloComputation * compare_computation, + XlaComputationToHloComputation(comparator, module_)); // Each partition needs to do TopK separately, thus the base shape for sort // becomes [ceil(batch_size / batch_dim_partition), k * shard_count]. const Shape sort_shape = ShapeUtil::MakeTupleShape( diff --git a/xla/service/topk_rewriter.cc b/xla/service/topk_rewriter.cc index bb65d436acedbd..d25076f2bc938e 100644 --- a/xla/service/topk_rewriter.cc +++ b/xla/service/topk_rewriter.cc @@ -31,6 +31,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/primitive_util.h" +#include "xla/service/hlo_creation_utils.h" #include "xla/service/pattern_matcher.h" #include "xla/shape_util.h" #include "xla/util.h" @@ -41,20 +42,6 @@ namespace xla { namespace m = match; -// TODO(cheshire): Avoid duplication w/ cudnn_vectorize_convolutions. -static absl::StatusOr BuilderToHloComputation( - XlaComputation& comp, HloComputation* sibling_computation) { - TF_ASSIGN_OR_RETURN(ProgramShape program_shape, comp.GetProgramShape()); - HloModuleConfig config(program_shape); - TF_ASSIGN_OR_RETURN(auto new_module, - HloModule::CreateFromProto(comp.proto(), config)); - - HloModule* dest_module = sibling_computation->parent(); - HloCloneContext context(dest_module); - return dest_module->DeepCloneComputation(new_module->entry_computation(), - &context); -} - static bool IsNanSafeGt(HloComputation* comp) { namespace m = match; auto match_bitcast_f32 = [](int64_t parameter_number) { @@ -500,9 +487,9 @@ class TopkDecomposerVisitor : public DfsHloRewriteVisitor { XlaComputation comparison = topk->largest() ? CreateScalarGtComputation(ptypes, &b) : CreateScalarLtComputation(ptypes, &b); - - TF_ASSIGN_OR_RETURN(HloComputation * comparator, - BuilderToHloComputation(comparison, topk->parent())); + TF_ASSIGN_OR_RETURN( + HloComputation * comparator, + XlaComputationToHloComputation(comparison, topk->parent()->parent())); return comparator; } diff --git a/xla/service/triangular_solve_expander.cc b/xla/service/triangular_solve_expander.cc index c61dc148c0ec33..3bc7ba36b60b81 100644 --- a/xla/service/triangular_solve_expander.cc +++ b/xla/service/triangular_solve_expander.cc @@ -34,6 +34,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_clone_context.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/hlo_creation_utils.h" #include "xla/service/hlo_module_config.h" #include "xla/shape.h" #include "xla/shape_util.h" @@ -599,15 +600,8 @@ absl::StatusOr TriangularSolveExpander::ExpandInstruction( /*block_size=*/block_size_, /*precision=*/PrecisionConfig::HIGHEST); TF_ASSIGN_OR_RETURN(XlaComputation xla_computation, builder.Build()); - - TF_ASSIGN_OR_RETURN(ProgramShape program_shape, - xla_computation.GetProgramShape()); - HloModuleConfig config(program_shape); - TF_ASSIGN_OR_RETURN(auto new_module, HloModule::CreateFromProto( - xla_computation.proto(), config)); - HloCloneContext context(module); - computation = - module->DeepCloneComputation(new_module->entry_computation(), &context); + TF_ASSIGN_OR_RETURN( + computation, XlaComputationToHloComputation(xla_computation, module)); } return instruction->parent()->AddInstruction(HloInstruction::CreateCall(