PR #19363: Loop Counter Increment in Collective Pipeliner
Imported from GitHub PR #19363

Sets the loop iteration counter increment in the backward transformation of the collective pipeliner pass so that it also handles loops whose iteration counter has a non-zero initial value. See #16953 and #18568.
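For intuition, here is a minimal standalone C++ sketch, not XLA code: it assumes that after the backward transform peels one iteration the rewritten loop begins at start + step, and that the old code advanced the counter by next_loop_iteration (start + step) instead of by the loop's own step. All names and values below are illustrative.

#include <cstdio>

// Simulates the counter of the rewritten loop for start=1, step=1, limit=6,
// mirroring the StartFromOne tests added in this change (assumed values).
int main() {
  const int start = 1;  // non-zero initial counter
  const int step = 1;   // the loop's per-iteration increment
  const int limit = 6;  // while_cond: counter < limit

  int buggy_trips = 0;  // old increment: next_loop_iteration == start + step
  for (int i = start + step; i < limit; i += start + step) ++buggy_trips;

  int fixed_trips = 0;  // new increment: the loop's own step
  for (int i = start + step; i < limit; i += step) ++fixed_trips;

  // With start == 0 the two agree (step == start + step); with start == 1
  // the buggy loop runs 2 trips instead of the expected 4.
  std::printf("buggy trips: %d, fixed trips: %d\n", buggy_trips, fixed_trips);
  return 0;
}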
Copybara import of the project:

--
06137aa by Philipp Hack <[email protected]>:

Modifies the loop counter increment set in the backward transformation of the collective pipeliner.

--
6da45bc by Philipp Hack <[email protected]>:

Modifies the loop counter increment set in the backward transformation of the collective pipeliner.

Merging this change closes #19363

FUTURE_COPYBARA_INTEGRATE_REVIEW=#19363 from philipphack:u_pipeliner_increment_xla 6da45bc
PiperOrigin-RevId: 698294264
philipphack authored and Google-ML-Automation committed Nov 20, 2024
1 parent db0f426 · commit c034a2e
Showing 3 changed files with 247 additions and 1 deletion.
xla/service/collective_pipeliner.cc (3 additions, 1 deletion)
@@ -2825,7 +2825,9 @@ static absl::Status TransformLoopBackward(
           [new_root_operands[*loop_analysis.GetLoopIterationIdx()]],
           body_builder.AddInstruction(
               HloInstruction::CreateConstant(*CreateLiteralOfShape(
-                  loop_index_shape, next_loop_iteration.GetSignedValue())))));
+                  loop_index_shape,
+                  loop_analysis.GetLoopIncrement()->GetSignedValue())))));
+
   HloInstruction* new_loop_root =
       body_builder.AddInstruction(HloInstruction::CreateTuple(
           MapNewOperands(new_root_operands, while_body_replacement_map,
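Read concretely (our annotation, not part of the commit; the example values assume an initial counter of 1 and step 1, as in the new tests):

// Before: the constant folded into the counter update was the value of the
// next loop iteration, next_loop_iteration == start + step (2 here).
// After: it is the loop's actual increment from the loop analysis (1 here).
// The two coincide exactly when start == 0, which is why the miscount only
// surfaced for loops with a non-zero initial counter (#16953, #18568).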
xla/service/collective_pipeliner_test.cc (59 additions, 0 deletions)
@@ -1348,6 +1348,65 @@ ENTRY entry {
  EXPECT_EQ(add_instr_loop->opcode(), HloOpcode::kAdd);
}

TEST_F(CollectivePipelinerTest,
       TransformIncrementIndexByOneStartFromOneBackwards) {
  constexpr absl::string_view hlo_string = R"(
HloModule module
while_cond {
param = (s32[], bf16[5,8,128], bf16[5,1,2,128]) parameter(0)
loop_index = s32[] get-tuple-element(param), index=0
c4 = s32[] constant(4)
ROOT cmp = pred[] compare(loop_index, c4), direction=LT
}
while_body {
param = (s32[], bf16[5,8,128], bf16[5,1,2,128]) parameter(0)
loop_index = s32[] get-tuple-element(param), index=0
partial_output = bf16[5,8,128] get-tuple-element(param), index=1
slice_input = bf16[5,1,2,128] get-tuple-element(param), index=2
c0 = s32[] constant(0)
c1 = s32[] constant(1)
next_loop_index = s32[] add(loop_index, c1)
c3 = s32[] constant(3)
three_minus_loop_index = s32[] subtract(c3, loop_index)
dynamic_slice = bf16[1,1,2,128] dynamic-slice(slice_input, three_minus_loop_index, c0, c0, c0), dynamic_slice_sizes={1,1,2,128}
dynamic_slice_reshape = bf16[1,2,128] reshape(dynamic_slice)
add = bf16[1,2,128] add(dynamic_slice_reshape, dynamic_slice_reshape), control-predecessors={c3}
all_gather = bf16[1,8,128] all-gather(add), dimensions={1}, replica_groups={}
updated_partial_output = bf16[5,8,128] dynamic-update-slice(partial_output, all_gather, three_minus_loop_index, c0, c0)
ROOT tuple = (s32[], bf16[5,8,128], bf16[5,1,2,128]) tuple(next_loop_index, updated_partial_output, slice_input), control-predecessors={add}
}
ENTRY entry {
c1 = s32[] constant(1)
p0 = bf16[5,8,128] parameter(0)
p1 = bf16[5,1,2,128] parameter(1)
tuple = (s32[], bf16[5,8,128], bf16[5,1,2,128]) tuple(c1, p0, p1)
while = (s32[], bf16[5,8,128], bf16[5,1,2,128]) while(tuple), condition=while_cond, body=while_body
ROOT gte = bf16[5,8,128] get-tuple-element(while), index=1
}
)";
  auto module = ParseAndReturnUnverifiedModule(hlo_string, config_).value();
  EXPECT_TRUE(RunOptimizer(module.get(), /*last_run=*/true, 0,
                           /*pipeline_use_tree=*/false,
                           /*process_different_sized_ops=*/false,
                           CollectivePipeliner::PipeliningDirection::kBackward,
                           IsAllGather)
                  .value());
  XLA_VLOG_LINES(1, module->ToString());
  const HloInstruction* while_instr =
      FindInstruction(module.get(), HloOpcode::kWhile);
  const HloComputation* comp = while_instr->while_body();
  const HloInstruction* root_loop = comp->root_instruction();

  const HloInstruction* shifted_loop_counter = root_loop->operand(4);
  EXPECT_EQ(shifted_loop_counter->opcode(), HloOpcode::kAdd);
  const HloInstruction* loop_increment = shifted_loop_counter->operand(1);
  EXPECT_EQ(loop_increment->opcode(), HloOpcode::kConstant);
  EXPECT_TRUE(loop_increment->literal().IsEqualAt({}, 1));
}

TEST_F(CollectivePipelinerTest,
       TransformIncrementIndexByOneBackwardsWithTwoDependentClones) {
  constexpr absl::string_view hlo_string = R"(
xla/tests/collective_ops_e2e_test.cc (185 additions, 0 deletions)
@@ -1158,6 +1158,191 @@ ENTRY entry {
      absl::StrReplaceAll(kModuleReplicatedStr, replacements_), opts);
}

// E2E tests comparing the results with and without pipelining of collectives.
class CollectiveOpsTestE2EPipelinedNonPipelined : public CollectiveOpsTestE2E {
 public:
  void CollectiveOpsComparePipelinedNonPipelined(
      absl::string_view hlo_string) {
    const int64_t kNumReplicas = 1;
    const int64_t kNumPartitions = 2;
    SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas * kNumPartitions);

    HloModuleConfig config =
        GetModuleConfigForTest(kNumReplicas, kNumPartitions);
    auto opts = GetDebugOptionsForTest();
    opts.set_xla_gpu_enable_pipelined_collectives(true);
    config.set_debug_options(opts);
    TF_ASSERT_OK_AND_ASSIGN(auto module,
                            ParseAndReturnVerifiedModule(hlo_string, config));
    auto fake_arguments = xla::MakeFakeArguments(module.get()).value();
    std::vector<Literal*> fake_ptrs(fake_arguments.size());
    for (int i = 0; i < fake_arguments.size(); ++i) {
      fake_ptrs[i] = &fake_arguments[i];
    }

    DeviceAssignment assn(/*replica_count=*/kNumReplicas,
                          /*computation_count=*/kNumPartitions);
    for (int64_t i = 0; i < kNumPartitions; ++i) {
      assn(0, i) = i;
    }

    TF_ASSERT_OK_AND_ASSIGN(
        std::vector<Literal> results,
        HloTestBase::ExecuteReplicated(
            std::move(module), fake_ptrs, kNumPartitions, &assn,
            /*run_hlo_passes=*/true, /*use_threads=*/true));
    ASSERT_EQ(results.size(), kNumPartitions);

    HloModuleConfig ref_config =
        GetModuleConfigForTest(kNumReplicas, kNumPartitions);
    auto ref_opts = GetDebugOptionsForTest();
    ref_opts.set_xla_gpu_enable_pipelined_collectives(false);
    ref_config.set_debug_options(ref_opts);
    TF_ASSERT_OK_AND_ASSIGN(
        auto ref_module, ParseAndReturnVerifiedModule(hlo_string, ref_config));
    auto fake_ref_arguments = xla::MakeFakeArguments(ref_module.get()).value();
    std::vector<Literal*> ref_fake_ptrs(fake_ref_arguments.size());
    for (int i = 0; i < fake_ref_arguments.size(); ++i) {
      ref_fake_ptrs[i] = &fake_ref_arguments[i];
    }

    TF_ASSERT_OK_AND_ASSIGN(
        std::vector<Literal> ref_results,
        HloTestBase::ExecuteReplicated(
            std::move(ref_module), ref_fake_ptrs, kNumPartitions, &assn,
            /*run_hlo_passes=*/true, /*use_threads=*/true));
    ASSERT_EQ(ref_results.size(), kNumPartitions);
    ErrorSpec error_spec{1e-5, 1e-5};
    // Expect identical results with and without pipelining of collectives.
    for (int i = 0; i < kNumPartitions; ++i) {
      EXPECT_TRUE(
          LiteralTestUtil::Near(ref_results[i], results[i], error_spec));
    }
  }
};

TEST_F(CollectiveOpsTestE2EPipelinedNonPipelined, CollectivePipelinerForward) {
  constexpr absl::string_view hlo_string = R"(
HloModule module, entry_computation_layout={(bf16[5,8,16])->bf16[5,8,16]}, allow_spmd_sharding_propagation_to_parameters={false,false}, num_partitions=2
add {
lhs = bf16[] parameter(0)
rhs = bf16[] parameter(1)
ROOT add = bf16[] add(lhs, rhs)
}
while_cond {
param = (s32[], bf16[5,8,16], bf16[5,8,16]) parameter(0)
loop_index = s32[] get-tuple-element(param), index=0
c5 = s32[] constant(5)
ROOT cmp = pred[] compare(loop_index, c5), direction=LT
}
while_body {
param = (s32[], bf16[5,8,16], bf16[5,8,16]) parameter(0)
loop_index = s32[] get-tuple-element(param), index=0
partial_output = bf16[5,8,16] get-tuple-element(param), index=1
slice_input = bf16[5,8,16] get-tuple-element(param), index=2
c0 = s32[] constant(0)
c1 = s32[] constant(1)
next_loop_index = s32[] add(loop_index, c1)
dynamic_slice = bf16[1,8,16] dynamic-slice(slice_input, loop_index, c0, c0), dynamic_slice_sizes={1,8,16}
all_reduce = bf16[1,8,16] all-reduce(dynamic_slice), replica_groups={}, to_apply=add, channel_id=1
updated_partial_output = bf16[5,8,16] dynamic-update-slice(partial_output, all_reduce, loop_index, c0, c0)
ROOT tuple = (s32[], bf16[5,8,16], bf16[5,8,16]) tuple(next_loop_index, updated_partial_output, slice_input)
}
ENTRY entry {
c0 = s32[] constant(0)
p0 = bf16[5,8,16] parameter(0)
tuple = (s32[], bf16[5,8,16], bf16[5,8,16]) tuple(c0, p0, p0)
while = (s32[], bf16[5,8,16], bf16[5,8,16]) while(tuple), condition=while_cond, body=while_body
ROOT gte = bf16[5,8,16] get-tuple-element(while), index=1
}
)";

  CollectiveOpsComparePipelinedNonPipelined(hlo_string);
}

TEST_F(CollectiveOpsTestE2EPipelinedNonPipelined, CollectivePipelinerBackward) {
  constexpr absl::string_view hlo_string = R"(
HloModule module, entry_computation_layout={(bf16[5,4,16], bf16[5,1,2,16])->bf16[5,4,16]}, allow_spmd_sharding_propagation_to_parameters={false,false}, num_partitions=2
while_cond {
param = (s32[], bf16[5,4,16], bf16[5,1,2,16]) parameter(0)
loop_index = s32[] get-tuple-element(param), index=0
c5 = s32[] constant(5)
ROOT cmp = pred[] compare(loop_index, c5), direction=LT
}
while_body {
param = (s32[], bf16[5,4,16], bf16[5,1,2,16]) parameter(0)
loop_index = s32[] get-tuple-element(param), index=0
partial_output = bf16[5,4,16] get-tuple-element(param), index=1
slice_input = bf16[5,1,2,16] get-tuple-element(param), index=2
c0 = s32[] constant(0)
c1 = s32[] constant(1)
next_loop_index = s32[] add(loop_index, c1)
dynamic_slice = bf16[1,1,2,16] dynamic-slice(slice_input, loop_index, c0, c0, c0), dynamic_slice_sizes={1,1,2,16}
dynamic_slice_reshape = bf16[1,2,16] reshape(dynamic_slice)
all_gather = bf16[1,4,16] all-gather(dynamic_slice_reshape), dimensions={1}, replica_groups={}
updated_partial_output = bf16[5,4,16] dynamic-update-slice(partial_output, all_gather, loop_index, c0, c0)
ROOT tuple = (s32[], bf16[5,4,16], bf16[5,1,2,16]) tuple(next_loop_index, updated_partial_output, slice_input)
}
ENTRY entry {
c0 = s32[] constant(0)
p0 = bf16[5,4,16] parameter(0)
p1 = bf16[5,1,2,16] parameter(1)
tuple = (s32[], bf16[5,4,16], bf16[5,1,2,16]) tuple(c0, p0, p1)
while = (s32[], bf16[5,4,16], bf16[5,1,2,16]) while(tuple), condition=while_cond, body=while_body
ROOT gte = bf16[5,4,16] get-tuple-element(while), index=1
}
)";

  CollectiveOpsComparePipelinedNonPipelined(hlo_string);
}

TEST_F(CollectiveOpsTestE2EPipelinedNonPipelined,
       CollectivePipelinerBackwardStartFromOne) {
  constexpr absl::string_view hlo_string = R"(
HloModule module, entry_computation_layout={(bf16[5,4,16], bf16[5,1,2,16])->bf16[5,4,16]}, allow_spmd_sharding_propagation_to_parameters={false,false}, num_partitions=2
while_cond {
param = (s32[], bf16[5,4,16], bf16[5,1,2,16]) parameter(0)
loop_index = s32[] get-tuple-element(param), index=0
c6 = s32[] constant(6)
ROOT cmp = pred[] compare(loop_index, c6), direction=LT
}
while_body {
param = (s32[], bf16[5,4,16], bf16[5,1,2,16]) parameter(0)
loop_index = s32[] get-tuple-element(param), index=0
partial_output = bf16[5,4,16] get-tuple-element(param), index=1
slice_input = bf16[5,1,2,16] get-tuple-element(param), index=2
c0 = s32[] constant(0)
c1 = s32[] constant(1)
next_loop_index = s32[] add(loop_index, c1)
loop_index_minus_one = s32[] subtract(loop_index, c1)
dynamic_slice = bf16[1,1,2,16] dynamic-slice(slice_input, loop_index_minus_one, c0, c0, c0), dynamic_slice_sizes={1,1,2,16}
dynamic_slice_reshape = bf16[1,2,16] reshape(dynamic_slice)
all_gather = bf16[1,4,16] all-gather(dynamic_slice_reshape), dimensions={1}, replica_groups={}
updated_partial_output = bf16[5,4,16] dynamic-update-slice(partial_output, all_gather, loop_index_minus_one, c0, c0)
ROOT tuple = (s32[], bf16[5,4,16], bf16[5,1,2,16]) tuple(next_loop_index, updated_partial_output, slice_input)
}
ENTRY entry {
c1 = s32[] constant(1)
p0 = bf16[5,4,16] parameter(0)
p1 = bf16[5,1,2,16] parameter(1)
tuple = (s32[], bf16[5,4,16], bf16[5,1,2,16]) tuple(c1, p0, p1)
while = (s32[], bf16[5,4,16], bf16[5,1,2,16]) while(tuple), condition=while_cond, body=while_body
ROOT gte = bf16[5,4,16] get-tuple-element(while), index=1
}
)";

  CollectiveOpsComparePipelinedNonPipelined(hlo_string);
}

TEST_F(CollectiveOpsTestE2E, AllToAllQuantizeCollectiveQuantizer) {
  absl::string_view kModuleReplicatedStr = R"(
HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={()->bf16[2]}, num_partitions=2