Skip to content

Commit

Permalink
Remove redundant test and add channel id test.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 712616509
  • Loading branch information
toli-y authored and Google-ML-Automation committed Jan 6, 2025
1 parent 85fb36e commit 4046459
Showing 1 changed file with 40 additions and 58 deletions.
98 changes: 40 additions & 58 deletions xla/service/collective_permute_decomposer_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ limitations under the License.

#include "xla/service/collective_permute_decomposer.h"

#include <cstdint>
#include <memory>

#include <gmock/gmock.h>
Expand Down Expand Up @@ -42,25 +43,55 @@ using Pass = CollectivePermuteDecomposer;

class DecomposerTest : public HloHardwareIndependentTestBase {
protected:
void AssertNoTranform(absl::string_view hlo) {
TF_ASSERT_OK(RunAndCheckHloRewrite(hlo, Pass(0), false));
void AssertNoTranform(absl::string_view hlo, int64_t threshold = 0) {
TF_ASSERT_OK(RunAndCheckHloRewrite(hlo, Pass(threshold), false));
};
auto Transform(absl::string_view hlo) {
return RunAndCheckHloRewrite(hlo, Pass(0), true);
auto Transform(absl::string_view hlo, int64_t threshold = 0) {
return RunAndCheckHloRewrite(hlo, Pass(threshold), true);
};
void AssertTransform(absl::string_view hlo, int64_t threshold = 0) {
TF_ASSERT_OK(RunAndCheckHloRewrite(hlo, Pass(threshold), true));
}
};

TEST_F(DecomposerTest, WithCycleNotTransformed) {
AssertNoTranform(R"(HloModule test
ENTRY test_computation {
p = u32[] replica-id()
ROOT cp = u32[] collective-permute(p), channel_id=1,
data = u32[] parameter(0)
ROOT cp = u32[] collective-permute(data), channel_id=1,
source_target_pairs={{0,1}, {1,0}}
}
)");
})");
}

TEST_F(DecomposerTest, ThresholdNotTransformed) {
AssertNoTranform(R"(HloModule test
ENTRY test_computation {
p = u32[] replica-id()
ROOT cp = u32[] collective-permute(p),
source_target_pairs={{0,1}, {1,2}, {2,3}, {3,4}}
})",
8);
}

TEST_F(DecomposerTest, Basic) {
AssertTransform(R"(HloModule test
ENTRY test_computation {
data = u32[] parameter(0)
ROOT cp = u32[] collective-permute(data), channel_id=1,
source_target_pairs={{0,1}, {1,2}}
})");
}

TEST_F(DecomposerTest, NoChannelId) {
AssertTransform(R"(HloModule test
ENTRY test_computation {
data = u32[] parameter(0)
ROOT cp = u32[] collective-permute(data),
source_target_pairs={{0,1}, {1,2}}
})");
}

TEST_F(DecomposerTest, TransformedExplicitChannelId) {
TEST_F(DecomposerTest, WithMetadata) {
absl::string_view hlo = R"(
HloModule test
ENTRY test_computation {
Expand Down Expand Up @@ -113,55 +144,6 @@ TEST_F(DecomposerTest, TransformedExplicitChannelId) {
EXPECT_THAT(root, op::GetTupleElement(recv_done, 0));
}

TEST_F(DecomposerTest, TransformedDefaultNoChannelId) {
absl::string_view hlo = R"(
HloModule test
ENTRY test_computation {
p = u32[] replica-id()
ROOT cp = u32[] collective-permute(p),
source_target_pairs={{0,1}, {1,2}, {2,3}, {3,4}}
}
)";

TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, Transform(hlo));

HloInstruction* after_all = FindInstruction(module.get(), "after-all");
HloInstruction* recv = FindInstruction(module.get(), "recv");
EXPECT_EQ(recv->operand(0), after_all);
EXPECT_FALSE(recv->channel_id().has_value());
EXPECT_THAT(
recv->ToString(),
HasSubstr(
"_xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3},{3,4}}"));
HloInstruction* recv_done = FindInstruction(module.get(), "recv-done");
EXPECT_EQ(recv_done->operand(0), recv);

HloInstruction* send = FindInstruction(module.get(), "send");
EXPECT_EQ(send->operand(1), after_all);
EXPECT_FALSE(send->channel_id().has_value());
EXPECT_THAT(
send->ToString(),
HasSubstr(
"_xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3},{3,4}}"));
HloInstruction* send_done = FindInstruction(module.get(), "send-done");
EXPECT_EQ(send_done->operand(0), send);

HloInstruction* root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, op::GetTupleElement(recv_done, 0));
}

TEST_F(DecomposerTest, ThresholdNotTransformed) {
absl::string_view hlo = R"(HloModule test
ENTRY test_computation {
p = u32[] replica-id()
ROOT cp = u32[] collective-permute(p), channel_id=1,
source_target_pairs={{0,1}, {1,2}, {2,3}, {3,4}},
metadata={op_name="op1/op2/add" source_file="foo/bar/mysource.py" source_line=35}
})";
TF_ASSERT_OK(
RunAndCheckHloRewrite(hlo, Pass(/*threshold_in_bytes=*/8), false));
}

TEST_F(DecomposerTest, Pipeline1) {
absl::string_view hlo = R"(
HloModule module
Expand Down

0 comments on commit 4046459

Please sign in to comment.