PR #17754: [NFC] Extract common builder code into hlo_creation_utils
Imported from GitHub PR #17754

Added an `XlaComputationToHloComputation` helper function to `hlo_creation_utils` and replaced the duplicated call-site code with it.

Copybara import of the project:

--
f9cbfd2 by Sergey Kozub <[email protected]>:

Extract common builder code into hlo_creation_utils

Merging this change closes #17754

FUTURE_COPYBARA_INTEGRATE_REVIEW=#17754 from openxla:skozub/xla_builder f9cbfd2
PiperOrigin-RevId: 680912058
sergey-kozub authored and Google-ML-Automation committed Oct 7, 2024
1 parent a53b7e7 commit bd4ab19
Showing 20 changed files with 226 additions and 135 deletions.
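For context, the new helper centralizes the XlaComputation-to-HloComputation conversion that each expander previously open-coded. The sketch below is reconstructed from the removed call-site code in this diff; the parameter names and exact signature are assumptions, and the authoritative definition lives in xla/service/hlo_creation_utils.h, which is not shown in this excerpt.

// Sketch only: reconstructed from the removed call-site code below.
absl::StatusOr<HloComputation*> XlaComputationToHloComputation(
    XlaComputation& computation, HloModule* module) {
  TF_ASSIGN_OR_RETURN(ProgramShape program_shape, computation.GetProgramShape());
  HloModuleConfig config(program_shape);
  TF_ASSIGN_OR_RETURN(auto new_module,
                      HloModule::CreateFromProto(computation.proto(), config));
  HloCloneContext context(module);
  // Deep-clone the deserialized entry computation into the destination module
  // so it survives the temporary HloModule going out of scope.
  return module->DeepCloneComputation(new_module->entry_computation(), &context);
}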
7 changes: 7 additions & 0 deletions xla/service/BUILD
@@ -2575,6 +2575,7 @@ cc_library(
srcs = ["triangular_solve_expander.cc"],
hdrs = ["triangular_solve_expander.h"],
deps = [
":hlo_creation_utils",
":hlo_module_config",
":op_expander_pass",
"//xla:shape_util",
@@ -2617,6 +2618,7 @@ cc_library(
srcs = ["cholesky_expander.cc"],
hdrs = ["cholesky_expander.h"],
deps = [
":hlo_creation_utils",
":op_expander_pass",
"//xla:literal",
"//xla:shape_util",
@@ -2640,6 +2642,7 @@ cc_library(
srcs = ["qr_expander.cc"],
hdrs = ["qr_expander.h"],
deps = [
":hlo_creation_utils",
":op_expander_pass",
"//xla:literal",
"//xla:shape_util",
@@ -2695,6 +2698,7 @@ cc_library(
srcs = ["eigh_expander.cc"],
hdrs = ["eigh_expander.h"],
deps = [
":hlo_creation_utils",
":op_expander_pass",
"//xla:literal_util",
"//xla:shape_util",
@@ -3071,6 +3075,7 @@ cc_library(
srcs = ["bitcast_dtypes_expander.cc"],
hdrs = ["bitcast_dtypes_expander.h"],
deps = [
":hlo_creation_utils",
":hlo_module_config",
":op_expander_pass",
"//xla:literal_util",
@@ -7483,6 +7488,7 @@ cc_library(
srcs = ["rng_bit_generator_expander.cc"],
hdrs = ["rng_bit_generator_expander.h"],
deps = [
":hlo_creation_utils",
":op_expander_pass",
"//xla:shape_util",
"//xla:util",
@@ -7597,6 +7603,7 @@ cc_library(
srcs = ["topk_rewriter.cc"],
hdrs = ["topk_rewriter.h"],
deps = [
":hlo_creation_utils",
":pattern_matcher",
"//xla:shape_util",
"//xla:util",
11 changes: 3 additions & 8 deletions xla/service/bitcast_dtypes_expander.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/primitive_util.h"
#include "xla/service/hlo_creation_utils.h"
#include "xla/service/hlo_module_config.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
@@ -104,14 +105,8 @@ absl::StatusOr<HloInstruction*> BitcastDtypesExpander::ExpandInstruction(
BitcastConvertType(input, to_shape.element_type());

TF_ASSIGN_OR_RETURN(XlaComputation xla_computation, b.Build());
TF_ASSIGN_OR_RETURN(ProgramShape program_shape,
xla_computation.GetProgramShape());
HloModuleConfig config(program_shape);
TF_ASSIGN_OR_RETURN(auto new_module, HloModule::CreateFromProto(
xla_computation.proto(), config));
HloCloneContext context(module);
computation =
module->DeepCloneComputation(new_module->entry_computation(), &context);
TF_ASSIGN_OR_RETURN(
computation, XlaComputationToHloComputation(xla_computation, module));
}

return instruction->parent()->AddInstruction(HloInstruction::CreateCall(
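Because the rendered diff above does not mark which lines were removed and which were added, the net effect at this call site (and at the matching sites in cholesky_expander.cc and eigh_expander.cc below) is easier to read side by side:

// Before: round-trip through a proto and deep-clone into `module`.
TF_ASSIGN_OR_RETURN(ProgramShape program_shape,
                    xla_computation.GetProgramShape());
HloModuleConfig config(program_shape);
TF_ASSIGN_OR_RETURN(auto new_module, HloModule::CreateFromProto(
                                         xla_computation.proto(), config));
HloCloneContext context(module);
computation =
    module->DeepCloneComputation(new_module->entry_computation(), &context);

// After: one call to the shared helper.
TF_ASSIGN_OR_RETURN(
    computation, XlaComputationToHloComputation(xla_computation, module));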
12 changes: 3 additions & 9 deletions xla/service/cholesky_expander.cc
@@ -28,6 +28,7 @@ limitations under the License.
#include "xla/hlo/builder/xla_builder.h"
#include "xla/literal.h"
#include "xla/primitive_util.h"
#include "xla/service/hlo_creation_utils.h"
#include "xla/shape_util.h"
#include "xla/status_macros.h"
#include "xla/util.h"
@@ -248,15 +249,8 @@ absl::StatusOr<HloInstruction*> CholeskyExpander::ExpandInstruction(
MaybeTransposeInMinorDims(l, !options.lower());

TF_ASSIGN_OR_RETURN(XlaComputation xla_computation, builder.Build());

TF_ASSIGN_OR_RETURN(ProgramShape program_shape,
xla_computation.GetProgramShape());
HloModuleConfig config(program_shape);
TF_ASSIGN_OR_RETURN(auto new_module, HloModule::CreateFromProto(
xla_computation.proto(), config));
HloCloneContext context(module);
computation =
module->DeepCloneComputation(new_module->entry_computation(), &context);
TF_ASSIGN_OR_RETURN(
computation, XlaComputationToHloComputation(xla_computation, module));
}

return instruction->parent()->AddInstruction(HloInstruction::CreateCall(
12 changes: 3 additions & 9 deletions xla/service/eigh_expander.cc
@@ -34,6 +34,7 @@ limitations under the License.
#include "xla/hlo/builder/xla_builder.h"
#include "xla/literal_util.h"
#include "xla/primitive_util.h"
#include "xla/service/hlo_creation_utils.h"
#include "xla/shape_util.h"
#include "xla/status_macros.h"
#include "xla/util.h"
@@ -582,15 +583,8 @@ absl::StatusOr<HloInstruction*> EighExpander::ExpandInstruction(
}
XlaOp result = BuildEigh(a, lower, max_iter, tol, sort_eigenvalues);
TF_ASSIGN_OR_RETURN(XlaComputation xla_computation, builder.Build(result));

TF_ASSIGN_OR_RETURN(ProgramShape program_shape,
xla_computation.GetProgramShape());
HloModuleConfig config(program_shape);
TF_ASSIGN_OR_RETURN(auto new_module, HloModule::CreateFromProto(
xla_computation.proto(), config));
HloCloneContext context(module);
computation =
module->DeepCloneComputation(new_module->entry_computation(), &context);
TF_ASSIGN_OR_RETURN(
computation, XlaComputationToHloComputation(xla_computation, module));
}

return instruction->parent()->AddInstruction(HloInstruction::CreateCall(
1 change: 1 addition & 0 deletions xla/service/gpu/transforms/BUILD
@@ -1191,6 +1191,7 @@ cc_library(
"//xla/client:xla_computation",
"//xla/hlo/ir:hlo",
"//xla/hlo/pass:hlo_pass",
"//xla/service:hlo_creation_utils",
"//xla/service:hlo_module_config",
"//xla/service/gpu:backend_configs_cc",
"//xla/service/gpu:cublas_cudnn",
31 changes: 7 additions & 24 deletions xla/service/gpu/transforms/cudnn_vectorize_convolutions.cc
@@ -40,6 +40,7 @@ limitations under the License.
#include "xla/service/gpu/cublas_cudnn.h"
#include "xla/service/gpu/cudnn_support_utils.h"
#include "xla/service/gpu/stream_executor_util.h"
#include "xla/service/hlo_creation_utils.h"
#include "xla/service/hlo_module_config.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
@@ -96,24 +97,6 @@ static std::vector<HloCustomCallInstruction*> GetRelevantConvs(
return convs;
}

// Converts an XlaBuilder into an HloComputation in the same module as
// `sibling_computation`.
//
// Yes, we serialize/deserialize as a proto. :)
static absl::StatusOr<HloComputation*> BuilderToHloComputation(
XlaBuilder& b, XlaOp root, HloComputation* sibling_computation) {
TF_ASSIGN_OR_RETURN(XlaComputation comp, b.Build(root));
TF_ASSIGN_OR_RETURN(ProgramShape program_shape, comp.GetProgramShape());
HloModuleConfig config(program_shape);
TF_ASSIGN_OR_RETURN(auto new_module,
HloModule::CreateFromProto(comp.proto(), config));

HloModule* dest_module = sibling_computation->parent();
HloCloneContext context(dest_module);
return dest_module->DeepCloneComputation(new_module->entry_computation(),
&context);
}

// Reshapes `instr` so that it has an extra dimension of size `vect_size` right
// after `dim`.
static XlaOp SplitAtDim(XlaOp instr, int64_t dim, int64_t vect_size) {
@@ -460,11 +443,11 @@ static absl::StatusOr<bool> TryRevectorizeConv(
new_conv_result, dnums->output_feature_dimension(), *output_vect_dim,
/*orig_vect_size=*/output_shape.dimensions(*output_vect_dim));

XlaOp root = Tuple(&b, {new_conv_result_unrevectorized, new_conv_scratch});
TF_ASSIGN_OR_RETURN(XlaComputation comp, b.Build(root));
TF_ASSIGN_OR_RETURN(
HloComputation * new_conv_comp,
BuilderToHloComputation(
b, Tuple(&b, {new_conv_result_unrevectorized, new_conv_scratch}),
conv->parent()));
XlaComputationToHloComputation(comp, conv->parent()->parent()));

// Set the name on the new conv. This is purely cosmetic, but we attempt to
// preserve e.g. "cudnn-conv.42" instead of "custom-call.42".
@@ -599,11 +582,11 @@ static absl::StatusOr<bool> TryVectorizeConv(
Collapse(new_conv_result, {dnums->output_feature_dimension(),
dnums->output_feature_dimension() + 1});

XlaOp root = Tuple(&b, {conv_result_collapsed, new_conv_scratch});
TF_ASSIGN_OR_RETURN(XlaComputation comp, b.Build(root));
TF_ASSIGN_OR_RETURN(
HloComputation * new_conv_comp,
BuilderToHloComputation(
b, Tuple(&b, {conv_result_collapsed, new_conv_scratch}),
conv->parent()));
XlaComputationToHloComputation(comp, conv->parent()->parent()));

// Create a tuple and replace the old conv with it!
VLOG(1) << "Vectorized conv to: " << new_conv_comp->ToString();
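Note that this file previously carried its own static BuilderToHloComputation wrapper (removed above), which took the sibling HloComputation and looked up its parent module internally. The call sites now build the XlaComputation themselves and pass the destination module directly; a minimal sketch of the argument they pass, assuming the surrounding variables from the diff:

// `conv` is the custom-call instruction being rewritten; its parent() is the
// enclosing HloComputation, whose parent() is the owning HloModule.
HloModule* dest_module = conv->parent()->parent();
TF_ASSIGN_OR_RETURN(HloComputation * new_conv_comp,
                    XlaComputationToHloComputation(comp, dest_module));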
145 changes: 145 additions & 0 deletions xla/service/gpu/transforms/windowed_einsum_handler.cc
@@ -499,6 +499,136 @@ bool ShouldAddToChain(const HloInstruction* inst) {
return false;
}
}

HloComputation* MakeSumComputation(PrimitiveType type, HloModule* module) {
HloComputation::Builder sum_b("add");
auto x = sum_b.AddInstruction(HloInstruction::CreateParameter(
/*parameter_number=*/0, ShapeUtil::MakeShape(type, {}), "x"));
auto y = sum_b.AddInstruction(HloInstruction::CreateParameter(
/*parameter_number=*/1, ShapeUtil::MakeShape(type, {}), "y"));
sum_b.AddInstruction(HloInstruction::CreateBinary(
ShapeUtil::MakeShape(type, {}), HloOpcode::kAdd, x, y));
HloComputation* reduction = module->AddEmbeddedComputation(sum_b.Build());
return reduction;
}

// Transform partial accumulations into a reduction on a contiguous buffer.
// Partial accumulations will impact the overlap between dots because the
// dot+add pattern will be fused into a single gemm later by the gemm
// rewriter, which adds data dependencies between gemms. Instead we write all
// intermediate results into a larger buffer and perform a one-shot reduction.
// The high-level transformation is:
//
// 'prev_res' is the previously partially accumulated result.
//
// shape(x,y) prev_res shape(x,y) dot0
// \ /
// \ /
// shape(x,y) add0 shape(x,y) dot1
// \ /
// \ /
// shape(x,y) add1
// |
// shape(x,y) loop output
//
// transformed into:
// shape(x,y) prev_res shape(x,y) dot0 shape(x,y) dot1
// \ / /
// \ / /
// shape(n,x,y) concatenate on first axis, n is the number of partitions
// |
// shape(n,x,y) loop output
// |
// shape(x,y) reduction on first axis
//
// The final reduction is pulled outside of the loop to overlap with other
// collectives.
absl::Status MoveAccumulationOutsideLoop(
std::vector<HloInstruction*>& partial_accumulations,
HloComputation* while_body, HloInstruction* loop) {
// The input of the while loop will be modified and must have no other users.
if (!loop || loop->operand(0)->user_count() != 1) {
return absl::OkStatus();
}

std::vector<HloInstruction*> partials_to_concat;

// We reshape each partial result to an (N+1)-dimensional tensor with a leading dim of 1.
Shape shape = partial_accumulations[0]->shape();
shape = ShapeUtil::PrependMajorDimension(1, shape);

for (auto& inst : partial_accumulations) {
HloInstruction* reshaped_partial =
while_body->AddInstruction(HloInstruction::CreateReshape(shape, inst));
partials_to_concat.push_back(reshaped_partial);
}
Shape concat_shape = partial_accumulations[0]->shape();
concat_shape = ShapeUtil::PrependMajorDimension(partial_accumulations.size(),
concat_shape);

HloInstruction* concat = while_body->AddInstruction(
HloInstruction::CreateConcatenate(concat_shape, partials_to_concat, 0));

HloComputation* comp = loop->parent();
HloInstruction* windowed_lhs = loop->mutable_operand(0)->mutable_operand(0);
// Add a broadcasted zero of the same type as windowed_lhs. This holds all
// the partial accumulations and will be fed to a global reduction after
// this windowed einsum loop. We move the reduction outside of the loop so
// it can be fused or overlap with other instructions in the main
// computation.
Literal zero_literal =
LiteralUtil::Zero(windowed_lhs->shape().element_type());
HloInstruction* zero = comp->AddInstruction(
HloInstruction::CreateConstant(std::move(zero_literal)));
Shape zero_bcast_shape = ShapeUtil::ChangeElementType(
concat_shape, windowed_lhs->shape().element_type());
HloInstruction* zero_bcast = MakeBroadcastHlo(zero, {}, zero_bcast_shape);
loop->mutable_operand(0)->AppendOperand(zero_bcast);
ShapeUtil::AppendShapeToTuple(zero_bcast->shape(),
loop->mutable_operand(0)->mutable_shape());

// Update the parameter tuples of while's body and condition
// computations.
for (HloComputation* while_comp : {while_body, loop->while_condition()}) {
while_comp->ReplaceParameter(
0, HloInstruction::CreateParameter(
0, loop->mutable_operand(0)->shape(),
while_comp->parameter_instruction(0)->name()));
}
HloInstruction* root = while_body->root_instruction();
std::vector<HloInstruction*> original_operands(root->operands().begin(),
root->operands().end());
original_operands.push_back(concat);
HloInstruction* new_output_tuple = while_body->AddInstruction(
HloInstruction::CreateTuple(original_operands));
TF_RETURN_IF_ERROR(
while_body->ReplaceInstructionWithDifferentShape(root, new_output_tuple));

// Update the shape of the while loop instruction.
*loop->mutable_shape() = loop->operand(0)->shape();

// The final reduction
HloInstruction* concat_result_gte =
comp->AddInstruction(HloInstruction::CreateGetTupleElement(
loop, (loop->operand(0)->shape().tuple_shapes_size() - 1)));
HloInstruction* reduced_result =
comp->AddInstruction(HloInstruction::CreateReduce(
partial_accumulations[0]->shape(), concat_result_gte, zero, {0},
MakeSumComputation(shape.element_type(), loop->GetModule())));

// Replace the original output if present.
HloInstruction* original_output_gte;
auto it = absl::c_find_if(loop->users(), [&](HloInstruction* instr) {
// Index of the original output. It's fixed to be the third element in the
// tuple.
return instr->tuple_index() == 2;
});
if (it != loop->users().end()) {
original_output_gte = *it;
TF_RETURN_IF_ERROR(original_output_gte->ReplaceAllUsesWith(reduced_result));
}
return absl::OkStatus();
}
absl::Status PostProcessUnrolledLoop(HloInstruction* loop, int64_t stream_id) {
HloComputation* while_body = loop->while_body();
// This is to set force delay for the first collective permute so it can
@@ -509,6 +639,7 @@ absl::Status PostProcessUnrolledLoop(HloInstruction* loop, int64_t stream_id) {
WindowedEinsumHandler::kWindowedEinsumRsLoopName) == 0
? 2
: 0;
std::vector<HloInstruction*> partial_accumulations;
for (HloInstruction* inst : while_body->MakeInstructionPostOrder()) {
HloInstruction* matched_cp;
if (Match(inst,
@@ -524,6 +655,20 @@ absl::Status PostProcessUnrolledLoop(HloInstruction* loop, int64_t stream_id) {
TF_RETURN_IF_ERROR(UpdateDotAndConsumerConfig(inst, stream_id));
++stream_id;
}
// If the dot's result is accumulated, we have found a loop whose
// contracting dimension is sharded.
HloInstruction* partial_dot;
if (Match(inst, m::AddAnyOrder(m::Op(),
m::Dot(&partial_dot, m::Op(), m::Op())))) {
partial_accumulations.push_back(partial_dot);
}
}
if (partial_accumulations.size() > 0 &&
while_body->name().find(
WindowedEinsumHandler::kWindowedEinsumAgLoopName) !=
std::string::npos) {
TF_RETURN_IF_ERROR(
MoveAccumulationOutsideLoop(partial_accumulations, while_body, loop));
}
return absl::OkStatus();
}
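To make the shape bookkeeping in MoveAccumulationOutsideLoop above concrete, here is a small, self-contained sketch using XLA's ShapeUtil; the dimension sizes (x, y, n) are made up for illustration.

#include <cstdint>

#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/xla_data.pb.h"

// Traces the shapes used when hoisting the partial-dot accumulation:
//   per-iteration partial dot result:  f32[x, y]
//   reshaped partial (leading 1):      f32[1, x, y]
//   in-loop concatenation of n parts:  f32[n, x, y]
//   one-shot Reduce over dim 0 after the loop restores f32[x, y].
void AccumulationShapeWalkthrough() {
  const int64_t x = 128, y = 256, n = 4;  // illustrative values
  xla::Shape partial = xla::ShapeUtil::MakeShape(xla::F32, {x, y});
  xla::Shape reshaped = xla::ShapeUtil::PrependMajorDimension(1, partial);
  xla::Shape concatenated = xla::ShapeUtil::PrependMajorDimension(n, partial);
  // The while loop carries `concatenated` as an extra tuple element; outside
  // the loop a Reduce over dimension 0 (with the MakeSumComputation reducer)
  // collapses it back to `partial`'s shape, where it can overlap with other
  // collectives.
  (void)reshaped;
  (void)concatenated;
}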
