From 40464599e51a8dbc7e368537e8f1785172808063 Mon Sep 17 00:00:00 2001 From: Toli Yevtushenko Date: Mon, 6 Jan 2025 12:15:09 -0800 Subject: [PATCH] Remove redundant test and add channel id test. PiperOrigin-RevId: 712616509 --- .../collective_permute_decomposer_test.cc | 98 ++++++++----------- 1 file changed, 40 insertions(+), 58 deletions(-) diff --git a/xla/service/collective_permute_decomposer_test.cc b/xla/service/collective_permute_decomposer_test.cc index 85e13e8085411..974d95bf45c82 100644 --- a/xla/service/collective_permute_decomposer_test.cc +++ b/xla/service/collective_permute_decomposer_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "xla/service/collective_permute_decomposer.h" +#include #include #include @@ -42,25 +43,55 @@ using Pass = CollectivePermuteDecomposer; class DecomposerTest : public HloHardwareIndependentTestBase { protected: - void AssertNoTranform(absl::string_view hlo) { - TF_ASSERT_OK(RunAndCheckHloRewrite(hlo, Pass(0), false)); + void AssertNoTranform(absl::string_view hlo, int64_t threshold = 0) { + TF_ASSERT_OK(RunAndCheckHloRewrite(hlo, Pass(threshold), false)); }; - auto Transform(absl::string_view hlo) { - return RunAndCheckHloRewrite(hlo, Pass(0), true); + auto Transform(absl::string_view hlo, int64_t threshold = 0) { + return RunAndCheckHloRewrite(hlo, Pass(threshold), true); }; + void AssertTransform(absl::string_view hlo, int64_t threshold = 0) { + TF_ASSERT_OK(RunAndCheckHloRewrite(hlo, Pass(threshold), true)); + } }; TEST_F(DecomposerTest, WithCycleNotTransformed) { AssertNoTranform(R"(HloModule test ENTRY test_computation { - p = u32[] replica-id() - ROOT cp = u32[] collective-permute(p), channel_id=1, + data = u32[] parameter(0) + ROOT cp = u32[] collective-permute(data), channel_id=1, source_target_pairs={{0,1}, {1,0}} - } - )"); + })"); +} + +TEST_F(DecomposerTest, ThresholdNotTransformed) { + AssertNoTranform(R"(HloModule test + ENTRY test_computation { + p = u32[] replica-id() + ROOT cp = u32[] collective-permute(p), + source_target_pairs={{0,1}, {1,2}, {2,3}, {3,4}} + })", + 8); +} + +TEST_F(DecomposerTest, Basic) { + AssertTransform(R"(HloModule test + ENTRY test_computation { + data = u32[] parameter(0) + ROOT cp = u32[] collective-permute(data), channel_id=1, + source_target_pairs={{0,1}, {1,2}} + })"); +} + +TEST_F(DecomposerTest, NoChannelId) { + AssertTransform(R"(HloModule test + ENTRY test_computation { + data = u32[] parameter(0) + ROOT cp = u32[] collective-permute(data), + source_target_pairs={{0,1}, {1,2}} + })"); } -TEST_F(DecomposerTest, TransformedExplicitChannelId) { +TEST_F(DecomposerTest, WithMetadata) { absl::string_view hlo = R"( HloModule test ENTRY test_computation { @@ -113,55 +144,6 @@ TEST_F(DecomposerTest, TransformedExplicitChannelId) { EXPECT_THAT(root, op::GetTupleElement(recv_done, 0)); } -TEST_F(DecomposerTest, TransformedDefaultNoChannelId) { - absl::string_view hlo = R"( - HloModule test - ENTRY test_computation { - p = u32[] replica-id() - ROOT cp = u32[] collective-permute(p), - source_target_pairs={{0,1}, {1,2}, {2,3}, {3,4}} - } - )"; - - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, Transform(hlo)); - - HloInstruction* after_all = FindInstruction(module.get(), "after-all"); - HloInstruction* recv = FindInstruction(module.get(), "recv"); - EXPECT_EQ(recv->operand(0), after_all); - EXPECT_FALSE(recv->channel_id().has_value()); - EXPECT_THAT( - recv->ToString(), - HasSubstr( - "_xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3},{3,4}}")); - HloInstruction* recv_done = FindInstruction(module.get(), "recv-done"); - EXPECT_EQ(recv_done->operand(0), recv); - - HloInstruction* send = FindInstruction(module.get(), "send"); - EXPECT_EQ(send->operand(1), after_all); - EXPECT_FALSE(send->channel_id().has_value()); - EXPECT_THAT( - send->ToString(), - HasSubstr( - "_xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3},{3,4}}")); - HloInstruction* send_done = FindInstruction(module.get(), "send-done"); - EXPECT_EQ(send_done->operand(0), send); - - HloInstruction* root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, op::GetTupleElement(recv_done, 0)); -} - -TEST_F(DecomposerTest, ThresholdNotTransformed) { - absl::string_view hlo = R"(HloModule test - ENTRY test_computation { - p = u32[] replica-id() - ROOT cp = u32[] collective-permute(p), channel_id=1, - source_target_pairs={{0,1}, {1,2}, {2,3}, {3,4}}, - metadata={op_name="op1/op2/add" source_file="foo/bar/mysource.py" source_line=35} - })"; - TF_ASSERT_OK( - RunAndCheckHloRewrite(hlo, Pass(/*threshold_in_bytes=*/8), false)); -} - TEST_F(DecomposerTest, Pipeline1) { absl::string_view hlo = R"( HloModule module