Skip to content

Commit

Permalink
Propagate shardings to the root instruction of while condition.
Browse files Browse the repository at this point in the history
### Sharding Propagation
Although the root instruction of a while condition has the shape `pred[]`, it can have the following meaningful shardings.
1. {replicated}
2. {manual}
3. subgroup sharding, e.g., {devices=[2,2]<=[4] last_tile_dims={manual, replicated}}

Thus, we need to propagate the sharding to the root such that the partitioner can correctly handle the while condition.

### SPMD Partitioner
The condition root must be replicated so that all partitions follow the same control flow. It can also have manual tile dims. Thus, we need to replicate all data dims while keeping the manual dims.

PiperOrigin-RevId: 658455111
  • Loading branch information
ZixuanJiang authored and copybara-github committed Aug 1, 2024
1 parent 0609842 commit e500f12
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 15 deletions.
7 changes: 4 additions & 3 deletions xla/service/sharding_propagation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ bool SupportSpatialPartitioning(
computation_map.find(instruction->parent()) == computation_map.end() &&
!(is_entry_root && allow_spmd_sharding_propagation_to_output)) {
// We don't support sharding the root instruction of a computation yet,
// unless the computation is a while body.
// unless the computation is in computation_map.
return false;
}

Expand Down Expand Up @@ -2954,8 +2954,8 @@ absl::StatusOr<bool> ShardingPropagation::Run(
}
}

// Populate computation_map in order to associate while bodies to their
// while instructions.
// Populate computation_map in order to associate while bodies and conditions
// to their while instructions.
for (auto computation : module->computations(execution_threads)) {
for (auto instruction : computation->instructions()) {
if (instruction->opcode() == HloOpcode::kWhile ||
Expand All @@ -2982,6 +2982,7 @@ absl::StatusOr<bool> ShardingPropagation::Run(
}
if (instruction->opcode() == HloOpcode::kWhile) {
computation_map[instruction->while_body()] = instruction;
computation_map[instruction->while_condition()] = instruction;
} else {
for (HloComputation* c : instruction->called_computations()) {
computation_map[c] = instruction;
Expand Down
54 changes: 54 additions & 0 deletions xla/service/sharding_propagation_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2757,6 +2757,60 @@ ENTRY %entry {
}
}

// Checks that sharding propagation reaches every instruction of a while
// loop's condition computation -- including its pred[]-shaped root, which
// may carry a subgroup sharding with {manual, replicated} last-tile dims.
TEST_F(ShardingPropagationTest, PropagateShardingInWhileCondition) {
  const char* const kModuleStr = R"(
HloModule module
%cond {
%vars.cond = (u32[], f32[]) parameter(0)
%count.cond = u32[] get-tuple-element(%vars.cond), index=0
%limit = u32[] constant(10)
ROOT %lt = pred[] compare(%count.cond, %limit), direction=LT
}
%body {
%vars = (u32[], f32[]) parameter(0)
%count = u32[] get-tuple-element(%vars), index=0
%acc = f32[] get-tuple-element(%vars), index=1
%one = u32[] constant(1)
%count.1 = u32[] add(u32[] %count, u32[] %one)
%acc.1 = f32[] add(f32[] %acc, f32[] %acc)
ROOT %tuple = (u32[], f32[]) tuple(%count.1, %acc.1)
}
ENTRY %entry {
%p0 = f32[] parameter(0), sharding={devices=[2,2]<=[4] last_tile_dims={manual, replicated}}
%zero = u32[] constant(0), sharding={devices=[2,2]<=[4] last_tile_dims={manual, replicated}}
%init = (u32[], f32[]) tuple(%zero, %p0)
ROOT %while = (u32[], f32[]) while(%init), body=%body, condition=%cond
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kModuleStr));
  TF_ASSERT_OK_AND_ASSIGN(
      bool changed,
      ShardingPropagation(/*is_spmd=*/false, /*propagate_metadata=*/false,
                          /*allow_spmd_sharding_propagation_to_output=*/{true})
          .Run(module.get()));
  EXPECT_TRUE(changed);

  // Every instruction should end up with the same subgroup sharding;
  // tuple-shaped instructions get it replicated across all leaves.
  const HloSharding leaf_sharding =
      ParseSharding("{devices=[2,2]<=[4] last_tile_dims={manual, replicated}}")
          .value();
  const HloSharding tuple_sharding = HloSharding::SingleTuple(
      module->entry_computation()->root_instruction()->shape(), leaf_sharding);

  for (const HloComputation* computation : module->computations()) {
    for (const HloInstruction* instruction : computation->instructions()) {
      EXPECT_TRUE(instruction->has_sharding());
      if (instruction->shape().IsTuple()) {
        EXPECT_EQ(instruction->sharding(), tuple_sharding);
      } else {
        EXPECT_EQ(instruction->sharding(), leaf_sharding);
      }
    }
  }
}

TEST_P(ParameterizedMetadataTest, WhileGetShardingFromRecvInBody) {
const char* const hlo_string = R"(
HloModule module
Expand Down
25 changes: 13 additions & 12 deletions xla/service/spmd/spmd_partitioner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4041,20 +4041,21 @@ absl::Status SpmdPartitioningVisitor::HandleWhile(HloInstruction* hlo) {
const HloSharding& sharding = hlo->sharding();

// Shardings for the body parameter, body root, and cond parameter must be
// the same, and the condition root must be replicated so that all partitions
// follow the same control flow.
// the same.
hlo->while_condition()->parameter_instruction(0)->set_sharding(sharding);
hlo->while_body()->parameter_instruction(0)->set_sharding(sharding);
const HloSharding& cond_root_sharding =
hlo->while_condition()->root_instruction()->sharding();
TF_RETURN_IF_ERROR(partitioner_
->PartitionComputation(hlo->while_condition(),
cond_root_sharding.IsManual()
? cond_root_sharding
: HloSharding::Replicate(),
next_channel_id_, logger_,
call_graph_)
.status());

// The condition root must be replicated so that all partitions follow the
// same control flow.
HloInstruction* cond_root = hlo->while_condition()->root_instruction();
const HloSharding cond_root_sharding =
hlo_sharding_util::ReplicateAllDataDims(cond_root->sharding());
cond_root->set_sharding(cond_root_sharding);
TF_RETURN_IF_ERROR(
partitioner_
->PartitionComputation(hlo->while_condition(), cond_root_sharding,
next_channel_id_, logger_, call_graph_)
.status());
TF_RETURN_IF_ERROR(partitioner_
->PartitionComputation(hlo->while_body(), sharding,
next_channel_id_, logger_,
Expand Down
30 changes: 30 additions & 0 deletions xla/service/spmd/spmd_partitioner_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4474,6 +4474,36 @@ ENTRY entry {
EXPECT_THAT(root, AllOf(op::While(zero), op::Shape("s32[]")));
}

// A while loop whose operand, body, and condition all carry a partially
// manual sharding should partition cleanly: the condition root keeps its
// manual subgroup dims while its data dims are replicated.
TEST_P(SpmdPartitioningTest, WhilePartialManual) {
  absl::string_view kHloText = R"(
HloModule module

LoopCond {
x = s32[] parameter(0), sharding={devices=[2,2]<=[4] last_tile_dims={manual, replicated}}
const = s32[] constant(5), sharding={devices=[2,2]<=[4] last_tile_dims={manual, replicated}}
ROOT lt = pred[] compare(x, const), direction=LT, sharding={devices=[2,2]<=[4] last_tile_dims={manual, replicated}}
}

Inc {
x = s32[] parameter(0), sharding={devices=[2,2]<=[4] last_tile_dims={manual, replicated}}
const = s32[] constant(1), sharding={devices=[2,2]<=[4] last_tile_dims={manual, replicated}}
ROOT add = s32[] add(x, const), sharding={devices=[2,2]<=[4] last_tile_dims={manual, replicated}}
}

ENTRY entry {
zero = s32[] parameter(0), sharding={devices=[2,2]<=[4] last_tile_dims={manual, replicated}}
ROOT while = s32[] while(zero), body=Inc, condition=LoopCond, sharding={devices=[2,2]<=[4] last_tile_dims={manual, replicated}}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(kHloText, /*num_devices=*/4));
  VLOG(1) << module->ToString();

  // The loop state is scalar, so partitioning must leave the shape intact.
  const HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              AllOf(op::While(AllOf(op::Parameter(0), op::Shape("s32[]"))),
                    op::Shape("s32[]")));
}

TEST_P(SpmdPartitioningTest, TestWhileFrontendAttributes) {
absl::string_view hlo_string = R"(
HloModule module
Expand Down

0 comments on commit e500f12

Please sign in to comment.