Skip to content

Commit

Permalink
Add a test.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 674515588
  • Loading branch information
amitsabne1 authored and Google-ML-Automation committed Sep 14, 2024
1 parent 29bd19c commit f8423a3
Showing 1 changed file with 26 additions and 3 deletions.
29 changes: 26 additions & 3 deletions xla/service/algebraic_simplifier_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11819,10 +11819,34 @@ TEST_F(AlgebraicSimplifierTest, ReduceOfConstantBroadcastS32) {
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string));
auto clone = m->Clone();
HloPassFix<AlgebraicSimplifier> simplifier(default_options_);
EXPECT_TRUE(simplifier.Run(m.get()).value());
std::cout << m->ToString() << std::endl;
int64_t reduce_count =
absl::c_count_if(m->entry_computation()->instructions(),
HloPredicateIsOp<HloOpcode::kReduce>);
// Expect no Reduce operation after simplification.
EXPECT_EQ(0, reduce_count);
}

TEST_F(AlgebraicSimplifierTest, TrivialReduce) {
const std::string hlo_string = R"(
HloModule test
add_s32 {
p0 = s32[] parameter(0)
p1 = s32[] parameter(1)
ROOT r = s32[] add(p0, p1)
}
ENTRY test.1 {
bcast = s32[1,7,7,1] parameter(0)
init = s32[] constant(0)
ROOT out = s32[1,7,7] reduce(bcast, init), dimensions={3}, to_apply=add_s32
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string));
AlgebraicSimplifierOptions options = default_options_;
options.set_is_layout_sensitive(false);
HloPassFix<AlgebraicSimplifier> simplifier(options);
EXPECT_TRUE(simplifier.Run(m.get()).value());
int64_t reduce_count =
absl::c_count_if(m->entry_computation()->instructions(),
HloPredicateIsOp<HloOpcode::kReduce>);
Expand All @@ -11846,7 +11870,6 @@ TEST_F(AlgebraicSimplifierTest, ReduceOfConstantBroadcastBF16) {
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string));
auto clone = m->Clone();
HloPassFix<AlgebraicSimplifier> simplifier(default_options_);
EXPECT_TRUE(simplifier.Run(m.get()).value());
int64_t reduce_count =
Expand Down

0 comments on commit f8423a3

Please sign in to comment.