[XLA:CollectivePipeliner] Avoid redundant broadcasts in the formatting ops of sunk collectives.

Before this CL, the same broadcast could be added multiple times
- to the formatting ops of a single sunk collective, and
- to the modified HLO computation if the same broadcast appears in the formatting ops of different sunk collectives.

PiperOrigin-RevId: 676593920
seherellis authored and Google-ML-Automation committed Sep 19, 2024
1 parent 8a15ad0 commit 7f938f1
Showing 2 changed files with 93 additions and 5 deletions.
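The core of the fix is a guarded push: each candidate instruction is first inserted into a companion hash set, and appended to the ordered vector only when that insertion reports the element as new. The snippet below is an illustrative sketch only, using standard containers and made-up operation names rather than the XLA types; absl::flat_hash_set::insert returns a std::pair<iterator, bool> whose bool is true only on first insertion, the same contract as the std::unordered_set used here.

// Illustrative sketch (not the XLA code): deduplicate while preserving order
// by pairing the ordered vector with a membership set. Operation names are
// made up.
#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

int main() {
  std::vector<std::string> formatting_ops;         // ordered, duplicate-free
  std::unordered_set<std::string> formatting_set;  // membership guard
  for (const char* op :
       {"broadcast.1", "add.1", "broadcast.1", "add.2", "broadcast.1"}) {
    if (formatting_set.insert(op).second) {  // true only on first insertion
      formatting_ops.push_back(op);
    }
  }
  for (const std::string& op : formatting_ops) {
    std::cout << op << "\n";  // prints broadcast.1, add.1, add.2
  }
  return 0;
}

Before this change the pushes were unconditional, so a broadcast reachable through several users of a sunk collective could end up in formatting_ops more than once.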
xla/service/collective_pipeliner.cc: 15 changes (10 additions, 5 deletions)
@@ -275,8 +275,8 @@ bool CollectSimpleDependencies(HloInstruction* i,
   for (HloInstruction* op : i->mutable_operands()) {
     absl::InlinedVector<HloInstruction*, 4> to_add;
     if (op->opcode() == HloOpcode::kBroadcast) {
-      to_add.push_back(op);
+      if (deps_set.insert(op).second) {
+        to_add.push_back(op);
       op = op->mutable_operand(0);
       if (op->opcode() == HloOpcode::kConstant) {
         if (deps_set.insert(op).second) {
@@ -318,6 +318,7 @@ CheckStoreIntoSliceIsCompatible(HloInstruction* instr,
   absl::flat_hash_set<HloInstruction*> added_instructions;
   HloInstruction* folded_instr = instr;
   std::vector<HloInstruction*> formatting_ops;
+  absl::flat_hash_set<HloInstruction*> formatting_set;
   // Returns if this is an acceptable user of a pipelined instruction.
   // Generic elementwise ops can have multiple operands that require the inputs
   // of being saved across the loop. So protect them through
@@ -411,11 +412,12 @@ CheckStoreIntoSliceIsCompatible(HloInstruction* instr,
     auto& data = stack.back();
     HloInstruction* instr = data.first;
     if (data.second == 0 && instr != folded_instr) {
-      if (!CollectSimpleDependencies(instr, formatting_ops,
-                                     added_instructions)) {
+      if (!CollectSimpleDependencies(instr, formatting_ops, formatting_set)) {
         return empty_pair;
       }
-      formatting_ops.push_back(instr);
+      if (formatting_set.insert(instr).second) {
+        formatting_ops.push_back(instr);
+      }
     }
     if (data.second == instr->user_count()) {
       stack.pop_back();
@@ -2330,9 +2332,9 @@ absl::Status TransformLoopForwardSink(const WhileLoopAnalysis& loop_analysis,
   // Create the new tuple with the original while tuple size.
   std::vector<HloInstruction*> new_output_tuple;
   new_output_tuple.resize(operands_indices_count, nullptr);
+  InstructionMap pipelined_map;
   // Reproduce computation to the output after the loop on the full shape.
   for (auto& to_move : loop_analysis.GetMoveInfos()) {
-    InstructionMap pipelined_map;
     for (int64_t i = 0; i < to_move.collectives_to_move.size(); ++i) {
       HloInstruction* collective = to_move.collectives_to_move[i];
       int64_t gte_index = collective_to_new_tuple_index[collective];
@@ -2419,6 +2421,9 @@ absl::Status TransformLoopForwardSink(const WhileLoopAnalysis& loop_analysis,
     // an effect on the instruction itself (like say broadcast, slices ...
     // etc).
     for (HloInstruction* formatting_op : to_move.formatting_ops) {
+      if (pipelined_map.contains(formatting_op)) {
+        continue;
+      }
       if (!to_add_batch_set.contains(formatting_op) &&
           formatting_op->opcode() != HloOpcode::kBroadcast) {
         HloInstruction* cloned_not_to_batch = loop_computation->AddInstruction(
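The second half of the change targets broadcasts shared across different sunk collectives: pipelined_map is now created once, before the loop over the move infos, and formatting ops already present in the map are skipped, so a shared broadcast is cloned into the modified computation only once. Below is a minimal sketch of that reuse pattern, with hypothetical names and a plain std::unordered_map standing in for InstructionMap.

// Illustrative sketch (hypothetical names): one clone map shared across all
// groups means an op that appears in several groups is materialized exactly
// once; later groups reuse the clone made for the first group.
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

int main() {
  const std::vector<std::vector<std::string>> groups = {
      {"broadcast", "add.1"}, {"broadcast", "add.2"}};
  std::unordered_map<std::string, int> clone_of;  // original -> clone id
  int next_clone_id = 0;
  for (const auto& group : groups) {         // analogous to the move infos
    for (const std::string& op : group) {    // analogous to formatting_ops
      if (clone_of.count(op) > 0) continue;  // already cloned for an earlier group
      clone_of[op] = next_clone_id++;
    }
  }
  std::cout << "clones created: " << clone_of.size() << "\n";  // 3, not 4
  return 0;
}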
xla/service/collective_pipeliner_test.cc: 83 changes (83 additions, 0 deletions)
@@ -3766,5 +3766,88 @@ ENTRY entry {
while_instr->while_body()->root_instruction()->operand(8)));
}

TEST_F(CollectivePipelinerTest, NoRedundantBroadcastsInFormattingOps) {
  constexpr absl::string_view hlo_string = R"(
HloModule module
add {
lhs = bf16[] parameter(0)
rhs = bf16[] parameter(1)
ROOT add = bf16[] add(lhs, rhs)
}
add.1 {
lhs = bf16[] parameter(0)
rhs = bf16[] parameter(1)
ROOT add = bf16[] add(lhs, rhs)
}
while_cond {
param = (s32[], bf16[3,8,128], bf16[3,8,128], bf16[3,8,128]) parameter(0)
gte = s32[] get-tuple-element(param), index=0
constant.1 = s32[] constant(3)
ROOT cmp = pred[] compare(gte, constant.1), direction=LT
}
while_body {
param = (s32[], bf16[3,8,128], bf16[3,8,128], bf16[3,8,128]) parameter(0)
get-tuple-element.394 = s32[] get-tuple-element(param), index=0
get-tuple-element.395 = bf16[3,8,128] get-tuple-element(param), index=1
get-tuple-element.396 = bf16[3,8,128] get-tuple-element(param), index=2
get-tuple-element.35 = bf16[3,8,128] get-tuple-element(param), index=3
constant.2557 = s32[] constant(1)
add.230 = s32[] add(get-tuple-element.394, constant.2557)
constant.2559 = s32[] constant(3)
subtract.139 = s32[] subtract(constant.2559, get-tuple-element.394)
constant.2560 = s32[] constant(-1)
add.231 = s32[] add(subtract.139, constant.2560)
constant.2561 = s32[] constant(0)
compare.747 = pred[] compare(add.231, constant.2561), direction=LT
constant.2562 = s32[] constant(2)
add.232 = s32[] add(subtract.139, constant.2562)
select.1348 = s32[] select(compare.747, add.232, add.231)
dynamic-slice.99 = bf16[1,8,128] dynamic-slice(get-tuple-element.35, select.1348, constant.2561, constant.2561), dynamic_slice_sizes={1,8,128}
mul = bf16[1,8,128] multiply(dynamic-slice.99, dynamic-slice.99)
ar.1 = bf16[1,8,128] all-reduce(mul), replica_groups={}, to_apply=add, channel_id=1
convert = bf16[] convert(add.232)
broadcast = bf16[1,8,128] broadcast(convert)
add.1 = bf16[1,8,128] add(ar.1, broadcast)
dynamic-update-slice.35 = bf16[3,8,128] dynamic-update-slice(get-tuple-element.395, add.1, select.1348, constant.2561, constant.2561)
ar.2 = bf16[1,8,128] all-reduce(mul), replica_groups={}, to_apply=add.1, channel_id=2
add.2 = bf16[1,8,128] add(ar.2, broadcast)
dynamic-update-slice.36 = bf16[3,8,128] dynamic-update-slice(get-tuple-element.396, add.2, select.1348, constant.2561, constant.2561)
ROOT tuple = (s32[], bf16[3,8,128], bf16[3,8,128], bf16[3,8,128]) tuple(add.230, dynamic-update-slice.35, dynamic-update-slice.36, get-tuple-element.35)
}
ENTRY entry {
c0 = s32[] constant(0)
p0 = bf16[3,8,128] parameter(0)
tuple = (s32[], bf16[3,8,128], bf16[3,8,128], bf16[3,8,128]) tuple(c0, p0, p0, p0)
while = (s32[], bf16[3,8,128], bf16[3,8,128], bf16[3,8,128]) while(tuple), condition=while_cond, body=while_body
ROOT gte1 = bf16[3,8,128] get-tuple-element(while), index=1
}
)";
  auto module = ParseAndReturnUnverifiedModule(hlo_string, config_).value();
  EXPECT_TRUE(RunOptimizer(module.get(), /*last_run=*/true,
                           /*level_to_operate_on=*/0,
                           /*pipeline_use_tree=*/true,
                           /*process_different_sized_ops=*/true,
                           CollectivePipeliner::kForwardSink)
                  .value());
  XLA_VLOG_LINES(1, module->ToString());
  // There should be only one broadcast instruction using a get-tuple-element
  // from the while instruction.
  EXPECT_EQ(absl::c_count_if(module->entry_computation()->instructions(),
                             [](const HloInstruction* instr) {
                               return instr->opcode() ==
                                          HloOpcode::kBroadcast &&
                                      instr->operand(0)->opcode() ==
                                          HloOpcode::kGetTupleElement &&
                                      instr->operand(0)->operand(0)->opcode() ==
                                          HloOpcode::kWhile;
                             }),
            1);
}

} // namespace
} // namespace xla
