diff --git a/xla/service/gpu/transforms/BUILD b/xla/service/gpu/transforms/BUILD index 842dffa0028a64..0c8705b059d84b 100644 --- a/xla/service/gpu/transforms/BUILD +++ b/xla/service/gpu/transforms/BUILD @@ -378,6 +378,40 @@ xla_cc_test( ], ) +cc_library( + name = "collective_select_folder", + srcs = ["collective_select_folder.cc"], + hdrs = ["collective_select_folder.h"], + deps = [ + "//xla:comparison_util", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "collective_select_folder_test", + srcs = ["collective_select_folder_test.cc"], + deps = [ + ":collective_select_folder", + "//xla/hlo/ir:hlo", + "//xla/tests:hlo_test_base", + "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + "@tsl//tsl/platform:statusor", + ], +) + cc_library( name = "collective_permute_valid_iteration_annotator", srcs = ["collective_permute_valid_iteration_annotator.cc"], diff --git a/xla/service/gpu/transforms/collective_select_folder.cc b/xla/service/gpu/transforms/collective_select_folder.cc new file mode 100644 index 00000000000000..a336170c83284f --- /dev/null +++ b/xla/service/gpu/transforms/collective_select_folder.cc @@ -0,0 +1,164 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/transforms/collective_select_folder.h" + +#include +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/comparison_util.h" +#include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace { + +using SourceTargetPair = std::pair; +using SourceTargetPairs = std::vector; + +struct SelectPredInfo { + int64_t constant; + Comparison::Direction direction; + HloOpcode device_id_type; // kReplicaId or kPartitionId + HloInstruction* true_operand; + HloInstruction* false_operand; +}; + +// Returns handy references to %constant, %true_operand, %false_operand of the +// select(broadcast(compare(current_device_id, constant)), true_operand, +// false_operand) +// or +// select(compare(current_device_id, constant), true_operand, +// false_operand) +std::optional GetPredSelectInfo(HloInstruction* select) { + if (select->opcode() != HloOpcode::kSelect) { + return std::nullopt; + } + + const HloCompareInstruction* compare; + if (select->operand(0)->opcode() == HloOpcode::kCompare) { + compare = Cast(select->operand(0)); + } else if (select->operand(0)->opcode() == HloOpcode::kBroadcast && + select->operand(0)->operand(0)->opcode() == HloOpcode::kCompare) { + compare = Cast(select->operand(0)->operand(0)); + } else { + return std::nullopt; + } + + bool is_replica_or_partition_compare = + (compare->operand(0)->opcode() == HloOpcode::kReplicaId || + compare->operand(0)->opcode() == HloOpcode::kPartitionId) && + compare->operand(1)->opcode() == HloOpcode::kConstant; + + if (!is_replica_or_partition_compare) return std::nullopt; + + int64_t id_value = + compare->operand(1)->literal().GetFirstInteger().value_or(-1); + + return SelectPredInfo{id_value, compare->direction(), + compare->operand(0)->opcode(), + select->mutable_operand(1), select->mutable_operand(2)}; +} + +bool IsUniqueSource(int64_t device_id, const SourceTargetPairs& pairs) { + if (pairs.size() == 1 && pairs[0].first == device_id) return true; + return false; +} + +bool IsNotPresentInSource(int64_t device_id, const SourceTargetPairs& pairs) { + for (const auto& pair : pairs) { + if (pair.first == device_id) return false; + } + return true; +} + +inline absl::StatusOr update(HloInstruction* cp, HloInstruction* data) { + TF_RETURN_IF_ERROR(cp->ReplaceOperandWith(0, data)); + return true; +} + +// We have to maintain integrity of relationship between partition/replica +// and collective-permute's channel_id. +// That is we can only fold select when +// 1. cp has channel_id and condition is based on partition_id +// 2. cp has no channel_id and condition is based on replica_id +// see enum class CollectiveOpGroupMode for details. +bool IsShardingConsistent(HloCollectivePermuteInstruction* cp, + HloOpcode device_id_type) { + auto id = cp->channel_id(); + return (device_id_type == HloOpcode::kPartitionId && id.has_value()) || + (device_id_type == HloOpcode::kReplicaId && !id.has_value()); +} + +// Recognizer the pattern and update if applicable. +absl::StatusOr TryFoldSelect(HloInstruction* in) { + if (in->opcode() != HloOpcode::kCollectivePermute) return false; + auto select_info_opt = GetPredSelectInfo(in->mutable_operand(0)); + if (!select_info_opt.has_value()) return false; + auto select_info = select_info_opt.value(); + + HloCollectivePermuteInstruction* cp = + Cast(in); + if (!IsShardingConsistent(cp, select_info.device_id_type)) return false; + + int64_t device_id = select_info.constant; + SourceTargetPairs pairs = cp->source_target_pairs(); + + if (select_info.direction == Comparison::Direction::kEq) { + if (IsUniqueSource(device_id, pairs)) { + return update(cp, select_info.true_operand); + } else if (IsNotPresentInSource(device_id, pairs)) { + return update(cp, select_info.false_operand); + } + } + + if (select_info.direction == Comparison::Direction::kNe) { + if (IsNotPresentInSource(device_id, pairs)) { + return update(cp, select_info.true_operand); + } else if (IsUniqueSource(device_id, pairs)) { + return update(cp, select_info.false_operand); + } + } + return false; +} + +} // namespace + +absl::StatusOr CollectiveSelectFolder::Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) { + bool changed = false; + for (HloComputation* computation : module->computations()) { + for (HloInstruction* instruction : computation->instructions()) { + TF_ASSIGN_OR_RETURN(bool local_changed, TryFoldSelect(instruction)); + changed |= local_changed; + } + } + return changed; +} + +} // namespace xla diff --git a/xla/service/gpu/transforms/collective_select_folder.h b/xla/service/gpu/transforms/collective_select_folder.h new file mode 100644 index 00000000000000..3e14ecbf054e1b --- /dev/null +++ b/xla/service/gpu/transforms/collective_select_folder.h @@ -0,0 +1,69 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_TRANSFORMS_COLLECTIVE_SELECT_FOLDER_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_COLLECTIVE_SELECT_FOLDER_H_ + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/pass/hlo_pass_interface.h" + +namespace xla { + +// When collective-permute operates on a comparison to a device id +// and the senders match the condition's branch +// we can link collective-permute to the original data skipping the comparison. +// For example +// condition = broadcast(compare(replica_id, X), direction=EQ +// data_snd = select(condition, compare_true_data, compare_false_data) +// rcv = collective-permute(data_snd compare_true_data), pairs={{X,0}} +// can be transformed to +// rcv = collective-permute(compare_true_data), pairs={{X,0}} +// +// The pass is *only* handling compare direction={EQ,NE}. +// The pass handles Compare with and without preceding Broadcast. +// +// This pass is particularly useful in the pipeline parallelism generated module +// such as: +// fwd_data = ... +// bwd_data = +// is_first_device = ... +// is_last_device = ... +// data_snd = select(is_last_device, bwd_data, fwd_data) +// bwd_data_rcv = collective-permute(data_snd), pairs={{3,0}} +// fwd_data_rcv = collective-permute(data_snd), pairs={{0,1},{1,2},{2,3}} +// ROOT data_rcv = select(is_first_device, bwd_data_rcv, fwd_data_rcv) +// +// After the transformation, the module will become: +// fwd_data_snd = ... +// bwd_data_snd = ... +// is_first_device = ... +// bwd_data_rcv = collective-permute(bwd_data_snd), pairs={{3,0}} +// fwd_data_rcv = collective-permute(fwd_data_snd), pairs={{0,1},{1,2},{2,3}} +// ROOT data_rcv = select(is_first_device, bwd_data_rcv, fwd_data_rcv) +class CollectiveSelectFolder : public HloModulePass { + public: + absl::string_view name() const override { return "collective-select-folder"; } + + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; +}; + +} // namespace xla + +#endif // XLA_SERVICE_GPU_TRANSFORMS_COLLECTIVE_SELECT_FOLDER_H_ diff --git a/xla/service/gpu/transforms/collective_select_folder_test.cc b/xla/service/gpu/transforms/collective_select_folder_test.cc new file mode 100644 index 00000000000000..21aca72d8eeea0 --- /dev/null +++ b/xla/service/gpu/transforms/collective_select_folder_test.cc @@ -0,0 +1,316 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/transforms/collective_select_folder.h" + +#include +#include +#include +#include + +#include +#include +#include "absl/log/log.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_replace.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace { + +using ::testing::HasSubstr; + +HloPrintOptions Least() { + HloPrintOptions options; + options.set_print_operand_shape(false) + .set_include_layout_in_shapes(false) + .set_print_percent(false); + return options; +} + +class CollectiveSelectFolderTest : public HloTestBase { + public: + using FixedMapping = + std::initializer_list>; + + absl::StatusOr> RunTranform( + bool expect_changed, std::string_view hlo_template, FixedMapping params) { + std::string hlo_string = absl::StrReplaceAll(hlo_template, params); + SCOPED_TRACE("Input HLO: " + hlo_string); + VLOG(7) << "Input HLO: " << hlo_string; + + TF_ASSIGN_OR_RETURN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSIGN_OR_RETURN(bool changed, + RunHloPass(CollectiveSelectFolder(), module.get())); + VLOG(7) << "Output HLO: " << module->ToString(Least()); + EXPECT_EQ(changed, expect_changed); + return module; + } + + absl::Status ExpectNoTranform(std::string_view hlo_template) { + return RunTranform(/*expect_changed=*/false, hlo_template, {}).status(); + } +}; + +void VerifyDirectDataFeedSPMD(HloModule* module, + std::string_view expected_fwd_operand, + std::string_view expected_bwd_operand) { + auto root = module->entry_computation()->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kSelect); + EXPECT_EQ(root->operand(1)->opcode(), HloOpcode::kCollectivePermute); + EXPECT_EQ(root->operand(2)->opcode(), HloOpcode::kCollectivePermute); + // first cp is sending backward per template + EXPECT_THAT(root->operand(1)->operand(0)->name(), + HasSubstr(expected_bwd_operand)) + << root->operand(1)->name() << " is expected to operate on " + << expected_bwd_operand; + // second cp is sending forward per template + EXPECT_THAT(root->operand(2)->operand(0)->name(), + HasSubstr(expected_fwd_operand)) + << root->operand(2)->name() << " is expected to operate on " + << expected_fwd_operand; +} + +// HLO segment as would be generated in SPMD pipeline containing two collective +// permutes forming a cycle. +const char* kSPMD2cp = R"( + HloModule test + ENTRY circular_exchange { + in_tpl = (f32[16], f32[16]) parameter(0) + fwd_data = f32[16]{0} get-tuple-element(in_tpl), index=0 + bwd_data = f32[16]{0} get-tuple-element(in_tpl), index=1 + + c_first_id = u32[] constant($first_id_constant) + c_last_id = u32[] constant($last_id_constant) + repl_id = u32[] partition-id() + + pred_first_id = pred[] compare(repl_id, c_first_id), direction=EQ + is_first = pred[] broadcast(pred_first_id), dimensions={} + + pred_last_id = pred[] compare(repl_id, c_last_id), direction=EQ + is_last = pred[] broadcast(pred_last_id), dimensions={} + + // select data to send (redundant!) + data_snd = f32[16] select(is_last, bwd_data, fwd_data) + + bwd_data_rcv = f32[16] collective-permute(data_snd), channel_id=1, source_target_pairs=$backward_pairs + fwd_data_rcv = f32[16] collective-permute(data_snd), channel_id=2, source_target_pairs=$forward_pairs + ROOT data_rcv = f32[16] select(is_first, bwd_data_rcv, fwd_data_rcv) + } +)"; + +TEST_F(CollectiveSelectFolderTest, SimpleForwardCycle) { + TF_ASSERT_OK_AND_ASSIGN( + auto module, RunTranform(/*expect_changed=*/true, kSPMD2cp, + {{"$first_id_constant", "0"}, + {"$last_id_constant", "3"}, + {"$forward_pairs", "{{0,1},{1,2},{2,3}}"}, + {"$backward_pairs", "{{3,0}}"}})); + + VerifyDirectDataFeedSPMD(module.get(), "fwd_data", "bwd_data"); +} + +TEST_F(CollectiveSelectFolderTest, SimpleBackwardCycle) { + TF_ASSERT_OK_AND_ASSIGN( + auto module, RunTranform(/*expect_changed=*/true, kSPMD2cp, + {{"$first_id_constant", "3"}, + {"$last_id_constant", "0"}, + {"$forward_pairs", "{{3,2},{2,1},{1,0}}"}, + {"$backward_pairs", "{{0,3}}"}})); + VerifyDirectDataFeedSPMD(module.get(), "fwd_data", "bwd_data"); +} + +TEST_F(CollectiveSelectFolderTest, CompareNEForwardCycle) { + TF_ASSERT_OK_AND_ASSIGN( + auto module, RunTranform(/*expect_changed=*/true, kSPMD2cp, + {{"$first_id_constant", "0"}, + {"$last_id_constant", "3"}, + {"$forward_pairs", "{{0,1},{1,2},{2,3}}"}, + {"$backward_pairs", "{{3,0}}"}, + {"direction=EQ", "direction=NE"}})); + // flip of the condition should result in flip of input data + VerifyDirectDataFeedSPMD(module.get(), "bwd_data", "fwd_data"); +} + +// Forceful case when select constant is not equal to the backward edge. +// In this case, backward collective-permute is expected to be linked +// to fwd_data while forward collective-permute is expected remain linked +// to the select. +TEST_F(CollectiveSelectFolderTest, LastDeviceIdMismatch) { + TF_ASSERT_OK_AND_ASSIGN( + auto module, RunTranform(/*expect_changed=*/true, kSPMD2cp, + {{"$first_id_constant", "0"}, + {"$last_id_constant", "2"}, // mismatch + {"$forward_pairs", "{{0,1},{1,2},{2,3}}"}, + {"$backward_pairs", "{{3,0}}"}})); + VerifyDirectDataFeedSPMD(module.get(), "data_snd", "fwd_data"); +} + +const char* kSelectBasecase = R"( + HloModule test + ENTRY computation1 { + compare_true_data = f32[16] parameter(0) + compare_false_data = f32[16] parameter(1) + device_id_constant = u32[] constant($device_id_constant) + repl_id = u32[] replica-id() + + prd = pred[] compare(repl_id, device_id_constant), direction=$direction + bcast = pred[] broadcast(prd), dimensions={} + selected_data = f32[16] select(bcast, compare_true_data, compare_false_data) + ROOT data_rcv = f32[16] collective-permute(selected_data), source_target_pairs=$pairs + } +)"; + +TEST_F(CollectiveSelectFolderTest, EqualTrueBranchTransform) { + TF_ASSERT_OK_AND_ASSIGN(auto module, + RunTranform(/*expect_changed=*/true, kSelectBasecase, + {{"$device_id_constant", "3"}, + {"$direction", "EQ"}, + {"$pairs", "{{3,0}}"}})); + auto root = module->entry_computation()->root_instruction(); + EXPECT_EQ(root->operand(0)->name(), "compare_true_data"); +} + +TEST_F(CollectiveSelectFolderTest, EqualFalseBranchTransform) { + TF_ASSERT_OK_AND_ASSIGN(auto module, + RunTranform(/*expect_changed=*/true, kSelectBasecase, + {{"$device_id_constant", "3"}, + {"$direction", "EQ"}, + {"$pairs", "{{0,1},{1,2}}"}})); + auto root = module->entry_computation()->root_instruction(); + EXPECT_EQ(root->operand(0)->name(), "compare_false_data"); +} + +TEST_F(CollectiveSelectFolderTest, NotEqualFalseBranchTransform) { + TF_ASSERT_OK_AND_ASSIGN(auto module, + RunTranform(/*expect_changed=*/true, kSelectBasecase, + {{"$device_id_constant", "3"}, + {"$direction", "NE"}, + {"$pairs", "{{3,0}}"}})); + auto root = module->entry_computation()->root_instruction(); + EXPECT_EQ(root->operand(0)->name(), "compare_false_data"); +} + +TEST_F(CollectiveSelectFolderTest, NotEqualTrueTrueTransform) { + TF_ASSERT_OK_AND_ASSIGN( + auto module, RunTranform(/*expect_changed=*/true, kSelectBasecase, + {{"$device_id_constant", "3"}, + {"$direction", "NE"}, + {"$pairs", "{{0,1},{1,2},{4,5},{5,6}}"}})); + auto root = module->entry_computation()->root_instruction(); + EXPECT_EQ(root->operand(0)->name(), "compare_true_data"); +} + +TEST_F(CollectiveSelectFolderTest, MoreThanOnePair_NotTransformed) { + // cp contains more than one pair and + // therefore is not identical to equal(1) + TF_ASSERT_OK(RunTranform(/*expect_changed=*/false, kSelectBasecase, + {{"$device_id_constant", "1"}, + {"$direction", "EQ"}, + {"$pairs", "{{0,1},{1,2}}"}})); + + // cp falls under not_qual(1) but has more than one pair + TF_ASSERT_OK(RunTranform(/*expect_changed=*/false, kSelectBasecase, + {{"$device_id_constant", "1"}, + {"$direction", "NE"}, + {"$pairs", "{{0,1},{1,2}}"}})); +} + +const char* kSelectNoBroadcast = R"( + HloModule test + ENTRY computation1 { + compare_true_data = f32[16] parameter(0) + compare_false_data = f32[16] parameter(1) + device_id_constant = u32[] constant($device_id_constant) + repl_id = u32[] replica-id() + + prd = pred[] compare(repl_id, device_id_constant), direction=$direction + selected_data = f32[16] select(prd, compare_true_data, compare_false_data) + ROOT data_rcv = f32[16] collective-permute(selected_data), source_target_pairs=$pairs + } +)"; + +TEST_F(CollectiveSelectFolderTest, SelectNoBroadcastTransform) { + TF_ASSERT_OK_AND_ASSIGN( + auto module, RunTranform(/*expect_changed=*/true, kSelectNoBroadcast, + {{"$device_id_constant", "3"}, + {"$direction", "EQ"}, + {"$pairs", "{{3,0}}"}})); + auto root = module->entry_computation()->root_instruction(); + EXPECT_EQ(root->operand(0)->name(), "compare_true_data"); +} + +TEST_F(CollectiveSelectFolderTest, ReplicaIdChannelIdMismatch_NotTransformed) { + const absl::string_view hlo = R"( + HloModule test + ENTRY computation1 { + compare_true_data = f32[16] parameter(0) + compare_false_data = f32[16] parameter(1) + device_id_constant = u32[] constant(0) + repl_id = u32[] replica-id() + + prd = pred[] compare(repl_id, device_id_constant), direction=EQ + selected_data = f32[16] select(prd, compare_true_data, compare_false_data) + ROOT data_rcv = f32[16] collective-permute(selected_data), channel_id=1, source_target_pairs={{0,1}} + } + )"; + TF_ASSERT_OK(ExpectNoTranform(hlo)); +} + +TEST_F(CollectiveSelectFolderTest, PartIdChannelIdMismatch_NotTransformed) { + const absl::string_view hlo = R"( + HloModule test + ENTRY computation1 { + compare_true_data = f32[16] parameter(0) + compare_false_data = f32[16] parameter(1) + device_id_constant = u32[] constant(0) + repl_id = u32[] partition-id() + + prd = pred[] compare(repl_id, device_id_constant), direction=EQ + selected_data = f32[16] select(prd, compare_true_data, compare_false_data) + ROOT data_rcv = f32[16] collective-permute(selected_data), source_target_pairs={{0,1}} + } + )"; + TF_ASSERT_OK(ExpectNoTranform(hlo)); +} + +TEST_F(CollectiveSelectFolderTest, WrongNesting_NotTransformed) { + const absl::string_view hlo = R"( + HloModule test + ENTRY computation1 { + compare_true_data = f32[16] parameter(0) + compare_false_data = f32[16] parameter(1) + device_id_constant = u32[] constant(0) + repl_id = u32[] replica-id() + sum = u32[] add(device_id_constant, repl_id) // additional op + + prd = pred[] compare(sum, device_id_constant), direction=EQ + selected_data = f32[16] select(prd, compare_true_data, compare_false_data) + ROOT data_rcv = f32[16] collective-permute(selected_data), source_target_pairs={{0,1}} + } + )"; + TF_ASSERT_OK(ExpectNoTranform(hlo)); +} + +} // namespace +} // namespace xla