Skip to content

Commit

Permalink
Extract common builder code into hlo_creation_utils
Browse files Browse the repository at this point in the history
  • Loading branch information
sergey-kozub committed Oct 2, 2024
1 parent a4cb02b commit 26a65b6
Show file tree
Hide file tree
Showing 15 changed files with 61 additions and 119 deletions.
7 changes: 7 additions & 0 deletions xla/service/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -2572,6 +2572,7 @@ cc_library(
srcs = ["triangular_solve_expander.cc"],
hdrs = ["triangular_solve_expander.h"],
deps = [
":hlo_creation_utils",
":hlo_module_config",
":op_expander_pass",
"//xla:shape_util",
Expand Down Expand Up @@ -2614,6 +2615,7 @@ cc_library(
srcs = ["cholesky_expander.cc"],
hdrs = ["cholesky_expander.h"],
deps = [
":hlo_creation_utils",
":op_expander_pass",
"//xla:literal",
"//xla:shape_util",
Expand All @@ -2637,6 +2639,7 @@ cc_library(
srcs = ["qr_expander.cc"],
hdrs = ["qr_expander.h"],
deps = [
":hlo_creation_utils",
":op_expander_pass",
"//xla:literal",
"//xla:shape_util",
Expand Down Expand Up @@ -2692,6 +2695,7 @@ cc_library(
srcs = ["eigh_expander.cc"],
hdrs = ["eigh_expander.h"],
deps = [
":hlo_creation_utils",
":op_expander_pass",
"//xla:literal_util",
"//xla:shape_util",
Expand Down Expand Up @@ -3068,6 +3072,7 @@ cc_library(
srcs = ["bitcast_dtypes_expander.cc"],
hdrs = ["bitcast_dtypes_expander.h"],
deps = [
":hlo_creation_utils",
":hlo_module_config",
":op_expander_pass",
"//xla:literal_util",
Expand Down Expand Up @@ -7477,6 +7482,7 @@ cc_library(
srcs = ["rng_bit_generator_expander.cc"],
hdrs = ["rng_bit_generator_expander.h"],
deps = [
":hlo_creation_utils",
":op_expander_pass",
"//xla:shape_util",
"//xla:util",
Expand Down Expand Up @@ -7591,6 +7597,7 @@ cc_library(
srcs = ["topk_rewriter.cc"],
hdrs = ["topk_rewriter.h"],
deps = [
":hlo_creation_utils",
":pattern_matcher",
"//xla:shape_util",
"//xla:util",
Expand Down
11 changes: 3 additions & 8 deletions xla/service/bitcast_dtypes_expander.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ limitations under the License.
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/primitive_util.h"
#include "xla/service/hlo_creation_utils.h"
#include "xla/service/hlo_module_config.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
Expand Down Expand Up @@ -104,14 +105,8 @@ absl::StatusOr<HloInstruction*> BitcastDtypesExpander::ExpandInstruction(
BitcastConvertType(input, to_shape.element_type());

TF_ASSIGN_OR_RETURN(XlaComputation xla_computation, b.Build());
TF_ASSIGN_OR_RETURN(ProgramShape program_shape,
xla_computation.GetProgramShape());
HloModuleConfig config(program_shape);
TF_ASSIGN_OR_RETURN(auto new_module, HloModule::CreateFromProto(
xla_computation.proto(), config));
HloCloneContext context(module);
computation =
module->DeepCloneComputation(new_module->entry_computation(), &context);
TF_ASSIGN_OR_RETURN(
computation, XlaComputationToHloComputation(xla_computation, module));
}

return instruction->parent()->AddInstruction(HloInstruction::CreateCall(
Expand Down
12 changes: 3 additions & 9 deletions xla/service/cholesky_expander.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ limitations under the License.
#include "xla/client/xla_builder.h"
#include "xla/literal.h"
#include "xla/primitive_util.h"
#include "xla/service/hlo_creation_utils.h"
#include "xla/shape_util.h"
#include "xla/status_macros.h"
#include "xla/util.h"
Expand Down Expand Up @@ -248,15 +249,8 @@ absl::StatusOr<HloInstruction*> CholeskyExpander::ExpandInstruction(
MaybeTransposeInMinorDims(l, !options.lower());

TF_ASSIGN_OR_RETURN(XlaComputation xla_computation, builder.Build());

TF_ASSIGN_OR_RETURN(ProgramShape program_shape,
xla_computation.GetProgramShape());
HloModuleConfig config(program_shape);
TF_ASSIGN_OR_RETURN(auto new_module, HloModule::CreateFromProto(
xla_computation.proto(), config));
HloCloneContext context(module);
computation =
module->DeepCloneComputation(new_module->entry_computation(), &context);
TF_ASSIGN_OR_RETURN(
computation, XlaComputationToHloComputation(xla_computation, module));
}

return instruction->parent()->AddInstruction(HloInstruction::CreateCall(
Expand Down
12 changes: 3 additions & 9 deletions xla/service/eigh_expander.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ limitations under the License.
#include "xla/client/xla_builder.h"
#include "xla/literal_util.h"
#include "xla/primitive_util.h"
#include "xla/service/hlo_creation_utils.h"
#include "xla/shape_util.h"
#include "xla/status_macros.h"
#include "xla/util.h"
Expand Down Expand Up @@ -582,15 +583,8 @@ absl::StatusOr<HloInstruction*> EighExpander::ExpandInstruction(
}
XlaOp result = BuildEigh(a, lower, max_iter, tol, sort_eigenvalues);
TF_ASSIGN_OR_RETURN(XlaComputation xla_computation, builder.Build(result));

TF_ASSIGN_OR_RETURN(ProgramShape program_shape,
xla_computation.GetProgramShape());
HloModuleConfig config(program_shape);
TF_ASSIGN_OR_RETURN(auto new_module, HloModule::CreateFromProto(
xla_computation.proto(), config));
HloCloneContext context(module);
computation =
module->DeepCloneComputation(new_module->entry_computation(), &context);
TF_ASSIGN_OR_RETURN(
computation, XlaComputationToHloComputation(xla_computation, module));
}

return instruction->parent()->AddInstruction(HloInstruction::CreateCall(
Expand Down
1 change: 1 addition & 0 deletions xla/service/gpu/transforms/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1135,6 +1135,7 @@ cc_library(
"//xla/client:xla_computation",
"//xla/hlo/ir:hlo",
"//xla/hlo/pass:hlo_pass",
"//xla/service:hlo_creation_utils",
"//xla/service:hlo_module_config",
"//xla/service/gpu:backend_configs_cc",
"//xla/service/gpu:cublas_cudnn",
Expand Down
31 changes: 7 additions & 24 deletions xla/service/gpu/transforms/cudnn_vectorize_convolutions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ limitations under the License.
#include "xla/service/gpu/cublas_cudnn.h"
#include "xla/service/gpu/cudnn_support_utils.h"
#include "xla/service/gpu/stream_executor_util.h"
#include "xla/service/hlo_creation_utils.h"
#include "xla/service/hlo_module_config.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
Expand Down Expand Up @@ -96,24 +97,6 @@ static std::vector<HloCustomCallInstruction*> GetRelevantConvs(
return convs;
}

// Converts an XlaBuilder into an HloComputation in the same module as
// `sibling_computation`.
//
// Yes, we serialize/deserialize as a proto. :)
//
// Steps: build the XlaBuilder (rooted at `root`) into an XlaComputation,
// re-materialize it as a standalone HloModule from its proto, then
// deep-clone that module's entry computation into the module that owns
// `sibling_computation`.
static absl::StatusOr<HloComputation*> BuilderToHloComputation(
XlaBuilder& b, XlaOp root, HloComputation* sibling_computation) {
// Finalize the builder; fails if the built graph is invalid.
TF_ASSIGN_OR_RETURN(XlaComputation comp, b.Build(root));
// The program shape seeds the temporary module's config.
TF_ASSIGN_OR_RETURN(ProgramShape program_shape, comp.GetProgramShape());
HloModuleConfig config(program_shape);
TF_ASSIGN_OR_RETURN(auto new_module,
HloModule::CreateFromProto(comp.proto(), config));

// Clone into the destination module; `new_module` is discarded afterwards.
HloModule* dest_module = sibling_computation->parent();
HloCloneContext context(dest_module);
return dest_module->DeepCloneComputation(new_module->entry_computation(),
&context);
}

// Reshapes `instr` so that it has an extra dimension of size `vect_size` right
// after `dim`.
static XlaOp SplitAtDim(XlaOp instr, int64_t dim, int64_t vect_size) {
Expand Down Expand Up @@ -460,11 +443,11 @@ static absl::StatusOr<bool> TryRevectorizeConv(
new_conv_result, dnums->output_feature_dimension(), *output_vect_dim,
/*orig_vect_size=*/output_shape.dimensions(*output_vect_dim));

XlaOp root = Tuple(&b, {new_conv_result_unrevectorized, new_conv_scratch});
TF_ASSIGN_OR_RETURN(XlaComputation comp, b.Build(root));
TF_ASSIGN_OR_RETURN(
HloComputation * new_conv_comp,
BuilderToHloComputation(
b, Tuple(&b, {new_conv_result_unrevectorized, new_conv_scratch}),
conv->parent()));
XlaComputationToHloComputation(comp, conv->parent()->parent()));

// Set the name on the new conv. This is purely cosmetic, but we attempt to
// preserve e.g. "cudnn-conv.42" instead of "custom-call.42".
Expand Down Expand Up @@ -599,11 +582,11 @@ static absl::StatusOr<bool> TryVectorizeConv(
Collapse(new_conv_result, {dnums->output_feature_dimension(),
dnums->output_feature_dimension() + 1});

XlaOp root = Tuple(&b, {conv_result_collapsed, new_conv_scratch});
TF_ASSIGN_OR_RETURN(XlaComputation comp, b.Build(root));
TF_ASSIGN_OR_RETURN(
HloComputation * new_conv_comp,
BuilderToHloComputation(
b, Tuple(&b, {conv_result_collapsed, new_conv_scratch}),
conv->parent()));
XlaComputationToHloComputation(comp, conv->parent()->parent()));

// Create a tuple and replace the old conv with it!
VLOG(1) << "Vectorized conv to: " << new_conv_comp->ToString();
Expand Down
21 changes: 13 additions & 8 deletions xla/service/hlo_creation_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -597,12 +597,22 @@ HloInstruction* MaybeMakeTuple(absl::Span<HloInstruction* const> operands) {
HloInstruction::CreateTuple(operands));
}

// Converts `src_comp` into an HloComputation owned by `dest_module`.
//
// The XlaComputation is round-tripped through its proto: a temporary
// HloModule is created from the proto (configured with the computation's
// program shape), and its entry computation is deep-cloned into
// `dest_module`. The temporary module is discarded on return.
//
// Returns the cloned computation, or an error if the computation has no
// valid program shape or the proto cannot be deserialized.
// NOTE(review): `src_comp` is only read here; presumably it could be taken
// by const reference — confirm against the header declaration before
// changing.
absl::StatusOr<HloComputation*> XlaComputationToHloComputation(
XlaComputation& src_comp, HloModule* dest_module) {
TF_ASSIGN_OR_RETURN(ProgramShape program_shape, src_comp.GetProgramShape());
HloModuleConfig config(program_shape);
TF_ASSIGN_OR_RETURN(auto new_module,
HloModule::CreateFromProto(src_comp.proto(), config));
// Deep-clone so the result is owned by `dest_module`, not the temporary.
HloCloneContext context(dest_module);
return dest_module->DeepCloneComputation(new_module->entry_computation(),
&context);
}

absl::StatusOr<HloInstruction*> MakeSortHlo(
const Shape& sort_shape, absl::Span<HloInstruction* const> operands,
int64_t dimension_to_sort, bool is_stable, HloComputation::Builder* builder,
HloModule* module, const OpMetadata* metadata) {
CHECK(!operands.empty()) << "Sort Hlo requires at least one operand.";
HloComputation* compare_computation;
XlaBuilder b("Sort.Compare");
if (metadata != nullptr) {
b.SetOpMetadata(*metadata);
Expand All @@ -612,13 +622,8 @@ absl::StatusOr<HloInstruction*> MakeSortHlo(
operand_types[i] = operands[i]->shape().element_type();
}
XlaComputation comparator = CreateScalarLtComputation(operand_types, &b);
TF_ASSIGN_OR_RETURN(ProgramShape program_shape, comparator.GetProgramShape());
HloModuleConfig config(program_shape);
TF_ASSIGN_OR_RETURN(auto new_module,
HloModule::CreateFromProto(comparator.proto(), config));
HloCloneContext context(module);
compare_computation =
module->DeepCloneComputation(new_module->entry_computation(), &context);
TF_ASSIGN_OR_RETURN(HloComputation * compare_computation,
XlaComputationToHloComputation(comparator, module));
return builder->AddInstruction(HloInstruction::CreateSort(
sort_shape, dimension_to_sort, operands, compare_computation, is_stable));
}
Expand Down
6 changes: 6 additions & 0 deletions xla/service/hlo_creation_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ limitations under the License.
#include <vector>

#include "absl/types/span.h"
#include "xla/client/xla_computation.h"
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/literal_util.h"
Expand Down Expand Up @@ -257,6 +258,11 @@ absl::StatusOr<HloInstruction*> MakeSelectHlo(
// instruction with all the operands. Crashes if `operands` is empty.
HloInstruction* MaybeMakeTuple(absl::Span<HloInstruction* const> operands);

// Creates a HloComputation in the destination module from a builder's
// XlaComputation.
absl::StatusOr<HloComputation*> XlaComputationToHloComputation(
XlaComputation& src_comp, HloModule* dest_module);

// Creates a Sort HLO instruction and adds it to the computation containing the
// operands. All operands must be in the same computation. Also creates a
// default compare sub-computation which sorts the first operand into ascending
Expand Down
12 changes: 3 additions & 9 deletions xla/service/qr_expander.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ limitations under the License.
#include "xla/client/xla_builder.h"
#include "xla/literal.h"
#include "xla/primitive_util.h"
#include "xla/service/hlo_creation_utils.h"
#include "xla/shape_util.h"
#include "xla/status_macros.h"
#include "xla/util.h"
Expand Down Expand Up @@ -551,15 +552,8 @@ absl::StatusOr<HloInstruction*> QrExpander::ExpandInstruction(
}

TF_ASSIGN_OR_RETURN(XlaComputation xla_computation, builder.Build(result));

TF_ASSIGN_OR_RETURN(ProgramShape program_shape,
xla_computation.GetProgramShape());
HloModuleConfig config(program_shape);
TF_ASSIGN_OR_RETURN(auto new_module, HloModule::CreateFromProto(
xla_computation.proto(), config));
HloCloneContext context(module);
computation =
module->DeepCloneComputation(new_module->entry_computation(), &context);
TF_ASSIGN_OR_RETURN(
computation, XlaComputationToHloComputation(xla_computation, module));
}

return instruction->parent()->AddInstruction(HloInstruction::CreateCall(
Expand Down
12 changes: 3 additions & 9 deletions xla/service/rng_bit_generator_expander.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ limitations under the License.
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_instructions.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/service/hlo_creation_utils.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/util.h"
Expand Down Expand Up @@ -86,15 +87,8 @@ RngBitGeneratorExpander::GetGeneratorComputation(const Shape& data_shape,
ConcatInDim(&builder, {Reshape(key_op, {1}), output.state}, 0);
Tuple(&builder, {final_state, output.value});
TF_ASSIGN_OR_RETURN(XlaComputation xla_computation, builder.Build());

TF_ASSIGN_OR_RETURN(ProgramShape program_shape,
xla_computation.GetProgramShape());
HloModuleConfig config(program_shape);
TF_ASSIGN_OR_RETURN(auto new_module, HloModule::CreateFromProto(
xla_computation.proto(), config));
HloCloneContext context(module);
HloComputation* new_computation =
module->DeepCloneComputation(new_module->entry_computation(), &context);
TF_ASSIGN_OR_RETURN(HloComputation * new_computation,
XlaComputationToHloComputation(xla_computation, module));
computation_cache_.emplace(cache_key, new_computation);
return new_computation;
}
Expand Down
11 changes: 1 addition & 10 deletions xla/service/rng_expander.cc
Original file line number Diff line number Diff line change
Expand Up @@ -111,16 +111,7 @@ absl::StatusOr<HloComputation*> GetComputationForRng(HloInstruction* rng) {
}

TF_ASSIGN_OR_RETURN(XlaComputation xla_computation, builder.Build());

TF_ASSIGN_OR_RETURN(ProgramShape program_shape,
xla_computation.GetProgramShape());
HloModuleConfig config(program_shape);
TF_ASSIGN_OR_RETURN(auto new_module, HloModule::CreateFromProto(
xla_computation.proto(), config));
HloModule* module = rng->GetModule();
HloCloneContext context(module);
return module->DeepCloneComputation(new_module->entry_computation(),
&context);
return XlaComputationToHloComputation(xla_computation, rng->GetModule());
}

} // namespace
Expand Down
1 change: 1 addition & 0 deletions xla/service/spmd/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ cc_library(
"//xla/service:custom_call_sharding_helper",
"//xla/service:dot_as_convolution_util",
"//xla/service:flatten_call_graph",
"//xla/service:hlo_creation_utils",
"//xla/service:hlo_cse",
"//xla/service:hlo_dce",
"//xla/service:hlo_lexer",
Expand Down
10 changes: 3 additions & 7 deletions xla/service/spmd/custom_call_handler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ limitations under the License.
#include "xla/hlo/utils/hlo_sharding_util.h"
#include "xla/literal_util.h"
#include "xla/service/custom_call_sharding_helper.h"
#include "xla/service/hlo_creation_utils.h"
#include "xla/service/hlo_lexer.h"
#include "xla/service/hlo_module_config.h"
#include "xla/service/host_memory_offload_annotations.h"
Expand Down Expand Up @@ -207,13 +208,8 @@ absl::Status SpmdPartitioningVisitor::HandleCustomCallTopK(
XlaComputation comparator = CreateScalarComparisonComputation(
"compare-value-and-index", {input->shape().element_type(), S32}, {Gt, Lt},
&b);
TF_ASSIGN_OR_RETURN(ProgramShape program_shape, comparator.GetProgramShape());
HloModuleConfig config(program_shape);
TF_ASSIGN_OR_RETURN(auto new_module,
HloModule::CreateFromProto(comparator.proto(), config));
HloCloneContext context(module_);
auto compare_computation =
module_->DeepCloneComputation(new_module->entry_computation(), &context);
TF_ASSIGN_OR_RETURN(HloComputation * compare_computation,
XlaComputationToHloComputation(comparator, module_));
// Each partition needs to do TopK separately, thus the base shape for sort
// becomes [ceil(batch_size / batch_dim_partition), k * shard_count].
const Shape sort_shape = ShapeUtil::MakeTupleShape(
Expand Down
Loading

0 comments on commit 26a65b6

Please sign in to comment.