Skip to content

Commit

Permalink
[XLA:ALGEBRAIC_SIMPLIFIER] Turn constant all-gather into broadcast
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 698114685
  • Loading branch information
blakehechtman authored and Google-ML-Automation committed Nov 19, 2024
1 parent 1edbb92 commit 3f032e1
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 0 deletions.
12 changes: 12 additions & 0 deletions xla/hlo/transforms/simplifiers/algebraic_simplifier.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1221,6 +1221,18 @@ absl::StatusOr<bool> AlgebraicSimplifierVisitor::TrySimplifyTautologicalCompare(
return false;
}

absl::Status AlgebraicSimplifierVisitor::HandleAllGather(
HloInstruction* all_gather) {
if (all_gather->shape().IsArray() &&
Match(all_gather->mutable_operand(0),
m::Broadcast(m::ConstantScalar()))) {
return ReplaceWithNewInstruction(
all_gather,
all_gather->mutable_operand(0)->CloneWithNewShape(all_gather->shape()));
}
return absl::OkStatus();
}

absl::Status AlgebraicSimplifierVisitor::HandleAllToAll(
HloInstruction* all_to_all) {
if (all_to_all->shape().IsArray() &&
Expand Down
2 changes: 2 additions & 0 deletions xla/hlo/transforms/simplifiers/algebraic_simplifier.h
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,8 @@ class AlgebraicSimplifierVisitor : public DfsHloRewriteVisitor {

absl::Status HandleAdd(HloInstruction* add) override;

absl::Status HandleAllGather(HloInstruction* all_gather) override;

absl::Status HandleAllToAll(HloInstruction* all_to_all) override;

absl::Status HandleAnd(HloInstruction* logical_and) override;
Expand Down
14 changes: 14 additions & 0 deletions xla/hlo/transforms/simplifiers/algebraic_simplifier_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12586,6 +12586,20 @@ TEST_F(AlgebraicSimplifierTest, BitcastBroadcastDifferentLayout) {
EXPECT_FALSE(simplifier.Run(module.get()).value());
}

TEST_F(AlgebraicSimplifierTest, AllGatherOfBroadcast) {
const char* kModuleStr = R"(
HloModule m
test {
z = f32[] constant(0)
b = f32[4,4] broadcast(z), dimensions={}
ROOT ag = f32[16,4] all-gather(b), dimensions={0}, replica_groups={{0, 1, 2, 3}}
})";
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).value());
EXPECT_THAT(m->entry_computation()->root_instruction(),
GmockMatch(m::Broadcast(m::Constant())));
}

TEST_F(AlgebraicSimplifierTest, TrivialMin) {
const char* kModuleStr = R"(
HloModule m
Expand Down

0 comments on commit 3f032e1

Please sign in to comment.