Address pr comments
Tixxx committed Oct 3, 2024
1 parent b90997c commit a55b2f4
Showing 1 changed file with 121 additions and 114 deletions.
235 changes: 121 additions & 114 deletions xla/service/gpu/transforms/windowed_einsum_handler.cc
@@ -474,17 +474,129 @@ HloComputation* MakeSumComputation(PrimitiveType type, HloModule* module) {
/*parameter_number=*/0, ShapeUtil::MakeShape(type, {}), "x"));
auto y = sum_b.AddInstruction(HloInstruction::CreateParameter(
/*parameter_number=*/1, ShapeUtil::MakeShape(type, {}), "y"));
if (type == PRED) {
sum_b.AddInstruction(HloInstruction::CreateBinary(
ShapeUtil::MakeShape(type, {}), HloOpcode::kOr, x, y));
} else {
sum_b.AddInstruction(HloInstruction::CreateBinary(
ShapeUtil::MakeShape(type, {}), HloOpcode::kAdd, x, y));
}
sum_b.AddInstruction(HloInstruction::CreateBinary(
ShapeUtil::MakeShape(type, {}), HloOpcode::kAdd, x, y));
HloComputation* reduction = module->AddEmbeddedComputation(sum_b.Build());
return reduction;
}
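
// Illustrative usage sketch (not part of this commit): `module`, `comp`,
// `input`, and `init` are hypothetical placeholders, and the dimensions are
// example values. The computation built above is typically consumed by a
// reduce over the leading dimension:
//
//   HloComputation* add = MakeSumComputation(F32, module);
//   HloInstruction* out = comp->AddInstruction(HloInstruction::CreateReduce(
//       ShapeUtil::MakeShape(F32, {128, 256}), /*operand=*/input,
//       /*init_value=*/init, /*dimensions_to_reduce=*/{0}, add));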

// Transforms partial accumulations into a reduction over a contiguous buffer.
// Partial accumulations hurt the overlap between dots because the dot+add
// pattern is later fused into a single gemm by the gemm rewriter, which adds
// data dependencies between gemms. Instead, we write all intermediate results
// into a larger buffer and perform a one-shot reduction.
// The high-level transformation is:
//
// 'prev_res' is the previously partially accumulated result.
//
// shape(x,y) prev_res shape(x,y) dot0
// \ /
// \ /
// shape(x,y) add0 shape(x,y) dot1
// \ /
// \ /
// shape(x,y) add1
// |
// shape(x,y) loop output
//
// transformed into:
// shape(x,y) prev_res shape(x,y) dot0 shape(x,y) dot1
// \ / /
// \ / /
// shape(n,x,y) concatenate on first axis, n is the number of partitions
// |
// shape(n,x,y) loop output
// |
// shape(x,y) reduction on first axis
//
// The final reduction is pulled outside of the loop to overlap with other
// collectives.
absl::Status MoveAccumulationOutsideLoop(
std::vector<HloInstruction*>& partial_accumulations,
HloComputation* while_body, HloInstruction* loop) {
// The input of the while loop will be modified and must have no other users.
if (!loop || loop->operand(0)->user_count() != 1) {
return absl::OkStatus();
}

std::vector<HloInstruction*> partials_to_concat;

// Reshape each partial to an (N+1)-dimensional tensor with a leading dim of 1.
Shape shape = partial_accumulations[0]->shape();
shape = ShapeUtil::PrependMajorDimension(1, shape);
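// For example, a partial of shape f32[x,y] becomes f32[1,x,y] here, so the
// n partials can be concatenated into f32[n,x,y] below.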

for (auto& inst : partial_accumulations) {
HloInstruction* reshaped_partial =
while_body->AddInstruction(HloInstruction::CreateReshape(shape, inst));
partials_to_concat.push_back(reshaped_partial);
}
Shape concat_shape = partial_accumulations[0]->shape();
concat_shape = ShapeUtil::PrependMajorDimension(partial_accumulations.size(),
concat_shape);

HloInstruction* concat = while_body->AddInstruction(
HloInstruction::CreateConcatenate(concat_shape, partials_to_concat, 0));

HloComputation* comp = loop->parent();
HloInstruction* windowed_lhs = loop->mutable_operand(0)->mutable_operand(0);
// Add a broadcasted zero of the same type as windowed_lhs. This holds all
// the partial accumulations and will be fed to a global reduction after
// this windowed einsum loop. We move the reduction outside of the loop so
// it can be fused or overlap with other instructions in the main
// computation.
Literal zero_literal =
LiteralUtil::Zero(windowed_lhs->shape().element_type());
HloInstruction* zero = comp->AddInstruction(
HloInstruction::CreateConstant(std::move(zero_literal)));
Shape zero_bcast_shape = ShapeUtil::ChangeElementType(
concat_shape, windowed_lhs->shape().element_type());
HloInstruction* zero_bcast = MakeBroadcastHlo(zero, {}, zero_bcast_shape);
loop->mutable_operand(0)->AppendOperand(zero_bcast);
ShapeUtil::AppendShapeToTuple(zero_bcast->shape(),
loop->mutable_operand(0)->mutable_shape());
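// The loop's input tuple now carries the zero-initialized accumulation buffer
// as its last element; the body and condition parameters are updated below to
// match the new tuple shape.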

// Update the parameter tuples of while's body and condition
// computations.
for (HloComputation* while_comp : {while_body, loop->while_condition()}) {
while_comp->ReplaceParameter(
0, HloInstruction::CreateParameter(
0, loop->mutable_operand(0)->shape(),
while_comp->parameter_instruction(0)->name()));
}
HloInstruction* root = while_body->root_instruction();
std::vector<HloInstruction*> original_operands(root->operands().begin(),
root->operands().end());
original_operands.push_back(concat);
HloInstruction* new_output_tuple = while_body->AddInstruction(
HloInstruction::CreateTuple(original_operands));
TF_RETURN_IF_ERROR(
while_body->ReplaceInstructionWithDifferentShape(root, new_output_tuple));

// Update the shape of the while loop instruction.
*loop->mutable_shape() = loop->operand(0)->shape();

// Build the final reduction of the concatenated partials outside of the loop.
HloInstruction* concat_result_gte =
comp->AddInstruction(HloInstruction::CreateGetTupleElement(
loop, (loop->operand(0)->shape().tuple_shapes_size() - 1)));
HloInstruction* reduced_result =
comp->AddInstruction(HloInstruction::CreateReduce(
partial_accumulations[0]->shape(), concat_result_gte, {zero}, {0},
MakeSumComputation(shape.element_type(), loop->GetModule())));

// Replace the original output if present.
HloInstruction* original_output_gte;
auto it = absl::c_find_if(loop->users(), [&](HloInstruction* instr) {
// Index of the original output. It's fixed to be the third element in the
// tuple.
return instr->tuple_index() == 2;
});
if (it != loop->users().end()) {
original_output_gte = *it;
TF_RETURN_IF_ERROR(original_output_gte->ReplaceAllUsesWith(reduced_result));
}
return absl::OkStatus();
}
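
// A minimal sketch (illustrative only, not part of this commit) of the shape
// bookkeeping performed above, assuming four partial accumulations of shape
// f32[128,256]:
//
//   Shape partial = ShapeUtil::MakeShape(F32, {128, 256});
//   Shape reshaped = ShapeUtil::PrependMajorDimension(1, partial);      // f32[1,128,256]
//   Shape concat_shape = ShapeUtil::PrependMajorDimension(4, partial);  // f32[4,128,256]
//
// Reducing the concatenated f32[4,128,256] buffer over dimension 0 with the
// computation from MakeSumComputation restores f32[128,256] outside the loop.
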
absl::Status PostProcessUnrolledLoop(HloInstruction* loop, int64_t stream_id) {
HloComputation* while_body = loop->while_body();
// This is to set force delay for the first collective permute so it can
@@ -519,117 +631,12 @@ absl::Status PostProcessUnrolledLoop(HloInstruction* loop, int64_t stream_id) {
partial_accumulations.push_back(partial_dot);
}
}

// Transform partial accumulations into a reduction on a contiguous buffer.
// Partial accumulations will impact the overlap between dots because the
// dot+add pattern will be fused into a single gemm later in gemm rewriter
// which adds data dependencies between gemms. Instead we write all
// intermediate results into a larger buffer and perform a one-shot reduction.
// The high-level transformation is:
// shape(x,y) cp shape(x,y) dot0
// \ /
// \ /
// shape(x,y) add0 shape(x,y) dot1
// \ /
// \ /
// shape(x,y) add1
// |
// shape(x,y) loop output
//
// transformed into:
// shape(x,y) cp shape(x,y) dot0 shape(x,y) dot1
// \ / /
// \ / /
// shape(n,x,y) concatenate on first axis, n is the number of partitions
// |
// shape(n,x,y) loop output
// |
// shape(x,y) reduction on first axis
//
// The final reduction is pulled outside of the loop to overlap with other
// collectives.
if (partial_accumulations.size() > 0 &&
while_body->name().find(
WindowedEinsumHandler::kWindowedEinsumAgLoopName) !=
std::string::npos) {
std::vector<HloInstruction*> partials_to_concat;

// We reshape it to a N+1 dimensioned tensor with left-most dim being 1.
Shape shape = partial_accumulations[0]->shape();
shape = ShapeUtil::PrependMajorDimension(1, shape);

for (auto& inst : partial_accumulations) {
CHECK(inst->user_count() == 1);
HloInstruction* reshaped_partial = while_body->AddInstruction(
HloInstruction::CreateReshape(shape, inst));
partials_to_concat.push_back(reshaped_partial);
}
Shape concat_shape = partial_accumulations[0]->shape();
concat_shape = ShapeUtil::PrependMajorDimension(
partial_accumulations.size(), concat_shape);

HloInstruction* concat = while_body->AddInstruction(
HloInstruction::CreateConcatenate(concat_shape, partials_to_concat, 0));

HloComputation* comp = loop->parent();
HloInstruction* windowed_lhs = loop->mutable_operand(0)->mutable_operand(0);
// Add a broadcasted zero of the same type as windowed_lhs. This holds all
// the partial accumulations and will be fed to a global reduction after
// this windowed einsum loop. We move the reduction outside of the loop so
// it can be fused or overlap with other instructions in the main
// computation.
Literal zero_literal =
LiteralUtil::Zero(windowed_lhs->shape().element_type());
HloInstruction* zero = comp->AddInstruction(
HloInstruction::CreateConstant(std::move(zero_literal)));
Shape zero_bcast_shape = ShapeUtil::ChangeElementType(
concat_shape, windowed_lhs->shape().element_type());
HloInstruction* zero_bcast = MakeBroadcastHlo(zero, {}, zero_bcast_shape);
loop->mutable_operand(0)->AppendOperand(zero_bcast);
ShapeUtil::AppendShapeToTuple(zero_bcast->shape(),
loop->mutable_operand(0)->mutable_shape());

// Update the parameter tuples of while's body and condition
// computations.
for (HloComputation* while_comp : {while_body, loop->while_condition()}) {
while_comp->ReplaceParameter(
0, HloInstruction::CreateParameter(
0, loop->mutable_operand(0)->shape(),
while_comp->parameter_instruction(0)->name()));
}
HloInstruction* root = while_body->root_instruction();
std::vector<HloInstruction*> original_operands(root->operands().begin(),
root->operands().end());
original_operands.push_back(concat);
HloInstruction* new_output_tuple = while_body->AddInstruction(
HloInstruction::CreateTuple(original_operands));
TF_RETURN_IF_ERROR(while_body->ReplaceInstructionWithDifferentShape(
root, new_output_tuple));

// Update the shape of the while loop instruction.
*loop->mutable_shape() = loop->operand(0)->shape();

// The final reduction
HloInstruction* concat_result_gte =
comp->AddInstruction(HloInstruction::CreateGetTupleElement(
loop, (loop->operand(0)->shape().tuple_shapes_size() - 1)));
HloInstruction* reduced_result =
comp->AddInstruction(HloInstruction::CreateReduce(
partial_accumulations[0]->shape(), concat_result_gte, {zero}, {0},
MakeSumComputation(shape.element_type(), loop->GetModule())));

// Replace the original output if present.
HloInstruction* original_output_gte;
auto it = absl::c_find_if(loop->users(), [&](HloInstruction* instr) {
// Index of the original output. It's fixed to be the third element in the
// tuple.
return instr->tuple_index() == 2;
});
if (it != loop->users().end()) {
original_output_gte = *it;
TF_RETURN_IF_ERROR(
original_output_gte->ReplaceAllUsesWith(reduced_result));
}
TF_RETURN_IF_ERROR(
MoveAccumulationOutsideLoop(partial_accumulations, while_body, loop));
}
return absl::OkStatus();
}
