PR #19272: Revert "PR #15291: [NVIDIA GPU] Add Bitcast to collective pipeliner acceptable users"

Imported from GitHub PR #19272

This reverts commit 6c65d7a.

Accepting Bitcast in the collective pipeliner was a temporary solution for some workloads that rely on the post-layout collective pipeliner. We have recently seen cases where including Bitcast can break the pattern matcher. This PR reverts that change, since Bitcast will not show up in the pre-layout collective pipeliner, which is the default behavior moving forward.
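
For context, the pattern that the original change accepted is the one exercised by the test removed below: an all-reduce inside a while body whose result reaches the dynamic-update-slice only through bitcasts. A trimmed HLO sketch of that while-body pattern (excerpted from the removed test; shapes and names come from that test, not from any particular workload):

  sliced-input-buffer = bf16[1,8,128] dynamic-slice(input-buffer, current-loop-index, constant.0, constant.0), dynamic_slice_sizes={1,8,128}
  all-reduce = bf16[1,8,128] all-reduce(sliced-input-buffer), replica_groups={}, to_apply=add, channel_id=1
  bitcast.0 = u16[3,8,128] bitcast(all-reduce)
  bitcast.1 = bf16[3,8,128] bitcast(bitcast.0)
  dynamic-update-slice = bf16[3,8,128] dynamic-update-slice(output-buffer, bitcast.1, current-loop-index, constant.0, constant.0)

With this revert, CheckStoreIntoSliceIsCompatible no longer lists kBitcast among the acceptable users, so the post-layout pipeliner's store-into-slice check no longer accepts graphs of this shape; the pre-layout pipeliner, now the default, does not see bitcasts in the first place.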
Copybara import of the project:

--
ad05557 by Terry Sun <[email protected]>:

Revert "PR #15291: [NVIDIA GPU] Add Bitcast to collective pipeliner acceptable users"

This reverts commit 6c65d7a.

Merging this change closes #19272

COPYBARA_INTEGRATE_REVIEW=#19272 from terryysun:terryysun/revert_bitcast_in_cp ad05557
PiperOrigin-RevId: 696218794
terryysun authored and Google-ML-Automation committed Nov 13, 2024
1 parent 02a7447 commit 5ce5c41
Showing 3 changed files with 2 additions and 63 deletions.
2 changes: 1 addition & 1 deletion xla/service/collective_pipeliner.cc
@@ -337,7 +337,7 @@ CheckStoreIntoSliceIsCompatible(HloInstruction* instr,
   HloOpcode::kPad, HloOpcode::kCollectivePermute,
   HloOpcode::kConvert, HloOpcode::kReshape,
   HloOpcode::kAllReduce, HloOpcode::kTranspose,
-  HloOpcode::kBroadcast, HloOpcode::kBitcast>(i) ||
+  HloOpcode::kBroadcast>(i) ||
   (multi_uses_pipelining && i->IsElementwise()) ||
   i->IsCustomCall(CollectivePipeliner::kInsertedByPreviousStep) ||
   i->IsCustomCall(CollectivePipeliner::kSunkByPreviousStep);
60 changes: 0 additions & 60 deletions xla/service/collective_pipeliner_test.cc
@@ -184,66 +184,6 @@ ENTRY entry {
EXPECT_EQ(get_tuple_index->tuple_index(), 3);
}

// A case where Bitcast will become the user of a pipelined instruction and
// check if the DUS is pushed to the next iteration successfully. Absence of
// Bitcast in acceptable users will break this test.
TEST_F(CollectivePipelinerTest, BitcastAsUser) {
constexpr absl::string_view hlo_string = R"(
HloModule module
add {
lhs = bf16[] parameter(0)
rhs = bf16[] parameter(1)
ROOT add = bf16[] add(lhs, rhs)
}
while_cond {
param = (s32[], bf16[3,8,128], bf16[3,8,128]) parameter(0)
gte = s32[] get-tuple-element(param), index=0
constant.1 = s32[] constant(3)
ROOT cmp = pred[] compare(gte, constant.1), direction=LT
}
while_body {
param = (s32[], bf16[3,8,128], bf16[3,8,128]) parameter(0)
current-loop-index = s32[] get-tuple-element(param), index=0
output-buffer = bf16[3,8,128] get-tuple-element(param), index=1
input-buffer = bf16[3,8,128] get-tuple-element(param), index=2
constant.1 = s32[] constant(1)
next-loop-index = s32[] add(current-loop-index, constant.1)
constant.0 = s32[] constant(0)
sliced-input-buffer = bf16[1,8,128] dynamic-slice(input-buffer, current-loop-index, constant.0, constant.0), dynamic_slice_sizes={1,8,128}
all-reduce = bf16[1,8,128] all-reduce(sliced-input-buffer), replica_groups={}, to_apply=add, channel_id=1
bitcast.0 = u16[3,8,128] bitcast(all-reduce)
bitcast.1 = bf16[3,8,128] bitcast(bitcast.0)
dynamic-update-slice = bf16[3,8,128] dynamic-update-slice(output-buffer, bitcast.1, current-loop-index, constant.0, constant.0)
ROOT tuple = (s32[], bf16[3,8,128], bf16[3,8,128]) tuple(next-loop-index, dynamic-update-slice, input-buffer)
}
ENTRY entry {
c0 = s32[] constant(0)
p0 = bf16[3,8,128] parameter(0)
tuple = (s32[], bf16[3,8,128], bf16[3,8,128]) tuple(c0, p0, p0)
while = (s32[], bf16[3,8,128], bf16[3,8,128]) while(tuple), condition=while_cond, body=while_body
ROOT gte1 = bf16[3,8,128] get-tuple-element(while), index=1
}
)";
auto module = ParseAndReturnUnverifiedModule(hlo_string, config_).value();
EXPECT_TRUE(RunOptimizer(module.get(), /*last_run=*/true).value());
XLA_VLOG_LINES(1, module->ToString());
const HloInstruction* root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, op::DynamicUpdateSlice(_, op::Bitcast(), _, _, _));
const HloInstruction* cast_back = root->operand(1);
EXPECT_EQ(cast_back->opcode(), HloOpcode::kBitcast);
const HloInstruction* cast_to = cast_back->operand(0);
EXPECT_EQ(cast_to->opcode(), HloOpcode::kBitcast);
const HloInstruction* ar = cast_to->operand(0);
// check if all-reduce is pipelined
EXPECT_EQ(ar->opcode(), HloOpcode::kAllReduce);
}

TEST_F(CollectivePipelinerTest, TransformIncrementIndexByOneCollectivePermute) {
constexpr absl::string_view hlo_string = R"(
HloModule module
3 changes: 1 addition & 2 deletions xla/service/gpu/gpu_collective_combiner_utils_test.cc
@@ -175,9 +175,8 @@ TEST_F(CollectiveCombinerUtilsTest,
   current-loop-index, constant.0, constant.0), dynamic_slice_sizes={1,8,128}
   all-reduce = bf16[1,8,128] all-reduce(sliced-input-buffer),
     replica_groups={}, to_apply=add, channel_id=1
-  bitcast.0 = bf16[3,8,128] bitcast(all-reduce)
   dynamic-update-slice = bf16[3,8,128] dynamic-update-slice(output-buffer,
-    bitcast.0, current-loop-index, constant.0, constant.0)
+    all-reduce, current-loop-index, constant.0, constant.0)
   ROOT tuple = (s32[], bf16[3,8,128], bf16[3,8,128]) tuple(next-loop-index,
     dynamic-update-slice, input-buffer)
 }
