Skip to content

Commit

Permalink
Add a rule to algebraic simplifier to replace A - A with 0.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 675732243
  • Loading branch information
fhoushmand authored and Google-ML-Automation committed Sep 17, 2024
1 parent 891d972 commit 9f64d25
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 0 deletions.
4 changes: 4 additions & 0 deletions xla/service/algebraic_simplifier.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2072,6 +2072,10 @@ absl::Status AlgebraicSimplifierVisitor::HandleConstant(
absl::Status AlgebraicSimplifierVisitor::HandleSubtract(HloInstruction* sub) {
HloInstruction *lhs, *rhs;
CHECK(Match(sub, m::Subtract(m::Op(&lhs), m::Op(&rhs))));
// A - A => 0
if (lhs == rhs) {
return ReplaceInstruction(sub, MakeScalarLike(sub, 0));
}
// A - 0 => A
VLOG(10) << "trying transform [A - 0 => A]: " << sub->ToString();
if (IsAll(rhs, 0) && ReplaceInstructionIfCompatible(sub, lhs)) {
Expand Down
30 changes: 30 additions & 0 deletions xla/service/algebraic_simplifier_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1343,6 +1343,36 @@ TEST_F(AlgebraicSimplifierTest, AddReassociateMergeBroadcastedConstants) {
m::ConstantScalar(2.0))))));
}

TEST_F(AlgebraicSimplifierTest, ReplaceSubtractOfEqualOperandsWithZero) {
const char* kModuleStr = R"(
HloModule m
test {
p0 = f32[] parameter(0)
ROOT sub = f32[] subtract(p0, p0)
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).value());
EXPECT_THAT(m->entry_computation()->root_instruction(),
GmockMatch(m::ConstantScalar(0.0)));
}

TEST_F(AlgebraicSimplifierTest,
ReplaceSubtractOfEqualOperandsWithBroadcastZero) {
const char* kModuleStr = R"(
HloModule m
test {
p0 = f32[512,20] parameter(0)
ROOT sub = f32[512,20] subtract(p0, p0)
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).value());
std::cout << m->ToString() << std::endl;
EXPECT_THAT(m->entry_computation()->root_instruction(),
GmockMatch(m::Broadcast()));
}

TEST_F(AlgebraicSimplifierTest, SubAddReassociateMergeConstants) {
const char* kModuleStr = R"(
HloModule m
Expand Down

0 comments on commit 9f64d25

Please sign in to comment.