From 8e1e8a959c412c20ce4d094a93f77b89b7663e44 Mon Sep 17 00:00:00 2001 From: Patrick Toulme <135739773+ptoulme-aws@users.noreply.github.com> Date: Mon, 16 Sep 2024 18:54:40 -0700 Subject: [PATCH] PR #15403: Handle multiple users in all-gather dynamic-slice simplification. Add AllGatherDynamicSliceSimplifier pass Imported from GitHub PR https://github.com/openxla/xla/pull/15403 In some models with poor SPMD partitioning, I have found the pattern below. ``` all-gather.1 = all-gather(x) dot.1 = dot(all-gather.1, y) dynamic-slice.1 = dynamic-slice(all-gather.1) // can be cancelled ``` In this case, the all-gather has multiple users but the dynamic-slice can be cancelled. This is applicable to all-reduce and reduce-scatter also. My changes now support multiple users, but it also depends on how this utility is used by the internal TPU compiler and the GPU ReduceScatterCreator pass. My changes assume the cancellation is run like this -- 1. Find a dynamic-slice 2. Check if the dynamic-slice can be cancelled 3. Delete the dynamic-slice but do not delete the collective 4. The collective is deleted by the DCE pass if it has no users The above workflow then supports removing dynamic-slices even if the collective has multiple users. The above is what we are using in our internal Neuron workflow. Interested to hear thoughts on this. Copybara import of the project: -- f518bd6e3164aa10b60b4689f2aa2ee8d8faa7ae by ptoulme-aws : Handle multiple users in all-gather dynamic-slice simplification. 
Add AllGatherDynamicSliceSimplifier pass Merging this change closes #15403 COPYBARA_INTEGRATE_REVIEW=https://github.com/openxla/xla/pull/15403 from ptoulme-aws:multiple_user_collectives f518bd6e3164aa10b60b4689f2aa2ee8d8faa7ae PiperOrigin-RevId: 675370754 --- xla/service/collective_opt_utils.cc | 25 +- xla/service/collective_opt_utils.h | 10 +- xla/service/gpu/BUILD | 1 + xla/service/gpu/gpu_compiler.cc | 2 + xla/service/gpu/transforms/BUILD | 25 ++ .../all_gather_dynamic_slice_simplifier.cc | 83 +++++++ .../all_gather_dynamic_slice_simplifier.h | 48 ++++ ...ll_gather_dynamic_slice_simplifier_test.cc | 233 ++++++++++++++++++ 8 files changed, 417 insertions(+), 10 deletions(-) create mode 100644 xla/service/gpu/transforms/all_gather_dynamic_slice_simplifier.cc create mode 100644 xla/service/gpu/transforms/all_gather_dynamic_slice_simplifier.h create mode 100644 xla/service/gpu/transforms/all_gather_dynamic_slice_simplifier_test.cc diff --git a/xla/service/collective_opt_utils.cc b/xla/service/collective_opt_utils.cc index 13b7d553adb22..8da8a37155b87 100644 --- a/xla/service/collective_opt_utils.cc +++ b/xla/service/collective_opt_utils.cc @@ -322,12 +322,14 @@ bool AllGatherDynamicSliceCancellation( const HloAllGatherInstruction* ag, int64_t num_partitions, int64_t num_replicas, bool allow_multiple_split_dims, bool allow_intervening_reshape, int64_t min_rank, - HloPredicate match_partition_id, HloPredicate match_replica_id) { + HloPredicate match_partition_id, HloPredicate match_replica_id, + bool allow_intervening_bitcast, bool allow_multiple_users) { auto spec = MatchWithDynamicSlice( ag, num_partitions, num_replicas, allow_multiple_split_dims, allow_intervening_reshape, min_rank, match_partition_id, match_replica_id, ag->constrain_layout(), ag->use_global_device_ids(), - ag->channel_id() && ag->opcode() == HloOpcode::kAllGather); + ag->channel_id() && ag->opcode() == HloOpcode::kAllGather, + allow_intervening_bitcast, allow_multiple_users); if (spec.has_value()) 
{ return true; } @@ -340,7 +342,7 @@ std::optional MatchWithDynamicSlice( bool allow_intervening_reshape, int64_t min_rank, HloPredicate match_partition_id, HloPredicate match_replica_id, bool is_constrain_layout, bool use_global_device_ids, bool is_cross_module, - bool allow_intervening_bitcast) { + bool allow_intervening_bitcast, bool allow_multiple_users) { if (!instruction->shape().IsArray() || is_constrain_layout || (is_cross_module && !instruction->GetModule()->config().use_spmd_partitioning())) { @@ -354,8 +356,8 @@ std::optional MatchWithDynamicSlice( << " excluding trivial dimensions " << instruction->ToString(); return std::nullopt; } - if (instruction->user_count() != 1) { - VLOG(2) << "All-gather user_count > 1 " << instruction->ToString(); + if (!allow_multiple_users && instruction->user_count() != 1) { + VLOG(2) << "All-gather user_count != 1 " << instruction->ToString(); return std::nullopt; } if (instruction->replica_groups().size() > 1) { @@ -371,8 +373,19 @@ std::optional MatchWithDynamicSlice( return std::nullopt; } } - + // Always assume first user to start. HloInstruction* user = instruction->users()[0]; + if (allow_multiple_users) { + // If we find a reshape or dynamic-slice use that. 
+ for (auto* some_user : instruction->users()) { + if ((allow_intervening_reshape && + some_user->opcode() == HloOpcode::kReshape) || + some_user->opcode() == HloOpcode::kDynamicSlice) { + user = some_user; + break; + } + } + } HloInstruction* reshape = nullptr; if (allow_intervening_reshape && user->opcode() == HloOpcode::kReshape) { // Allow the intervening reshape if it reshapes just the non scattered diff --git a/xla/service/collective_opt_utils.h b/xla/service/collective_opt_utils.h index 983969e41e607..6131028d5f684 100644 --- a/xla/service/collective_opt_utils.h +++ b/xla/service/collective_opt_utils.h @@ -43,16 +43,17 @@ std::optional MatchReduceScatter( HloPredicate match_replica_id = HloPredicateIsOp, bool allow_intervening_bitcast = false); -// Check whether AG(ICI) and its single user DS(ICI) can be canceled out. +// Check whether AG(ICI) and its user DS(ICI) can be canceled out. bool AllGatherDynamicSliceCancellation( const HloAllGatherInstruction* ag, int64_t num_partitions, int64_t num_replicas, bool allow_multiple_split_dims = false, bool allow_intervening_reshape = false, int64_t min_rank = 1, HloPredicate match_partition_id = HloPredicateIsOp, - HloPredicate match_replica_id = HloPredicateIsOp); + HloPredicate match_replica_id = HloPredicateIsOp, + bool allow_intervening_bitcast = false, bool allow_multiple_users = false); // Check if a given instruction (AllReduce or AllGather) matches a DynamicSlice; -// the DynamicSlice has to be the only user of the given instruction. +// the DynamicSlice has to be the user of the given instruction. 
std::optional MatchWithDynamicSlice( const HloChannelInstruction* instruction, int64_t num_partitions, int64_t num_replicas, bool allow_multiple_split_dims = false, @@ -60,7 +61,8 @@ std::optional MatchWithDynamicSlice( HloPredicate match_partition_id = HloPredicateIsOp, HloPredicate match_replica_id = HloPredicateIsOp, bool is_constrain_layout = false, bool use_global_device_ids = false, - bool is_cross_module = false, bool allow_intervening_bitcast = false); + bool is_cross_module = false, bool allow_intervening_bitcast = false, + bool allow_multiple_users = false); } // namespace xla diff --git a/xla/service/gpu/BUILD b/xla/service/gpu/BUILD index 151417710d85c..ad89a46c852e0 100644 --- a/xla/service/gpu/BUILD +++ b/xla/service/gpu/BUILD @@ -1416,6 +1416,7 @@ cc_library( "//xla/service/gpu/runtime:thunk", "//xla/service/gpu/transforms:algebraic_simplifier", "//xla/service/gpu/transforms:algorithm_checker", + "//xla/service/gpu/transforms:all_gather_dynamic_slice_simplifier", "//xla/service/gpu/transforms:all_gather_optimizer", "//xla/service/gpu/transforms:all_reduce_blueconnect", "//xla/service/gpu/transforms:all_reduce_splitter", diff --git a/xla/service/gpu/gpu_compiler.cc b/xla/service/gpu/gpu_compiler.cc index 725556404b47b..02d2e0833158a 100644 --- a/xla/service/gpu/gpu_compiler.cc +++ b/xla/service/gpu/gpu_compiler.cc @@ -142,6 +142,7 @@ limitations under the License. 
#include "xla/service/gpu/stream_executor_util.h" #include "xla/service/gpu/transforms/algebraic_simplifier.h" #include "xla/service/gpu/transforms/algorithm_checker.h" +#include "xla/service/gpu/transforms/all_gather_dynamic_slice_simplifier.h" #include "xla/service/gpu/transforms/all_gather_optimizer.h" #include "xla/service/gpu/transforms/all_reduce_blueconnect.h" #include "xla/service/gpu/transforms/all_reduce_splitter.h" @@ -903,6 +904,7 @@ absl::Status RunCollectiveOptimizationPasses( HloPassPipeline collectives_pipeline("collective-optimizations"); collectives_pipeline.AddPass(); collectives_pipeline.AddPass(); + collectives_pipeline.AddPass(); collectives_pipeline.AddPass(); collectives_pipeline.AddPass( debug_options.xla_gpu_enable_reassociation_for_converted_ar()); diff --git a/xla/service/gpu/transforms/BUILD b/xla/service/gpu/transforms/BUILD index c481e29083b5d..842dffa0028a6 100644 --- a/xla/service/gpu/transforms/BUILD +++ b/xla/service/gpu/transforms/BUILD @@ -310,6 +310,31 @@ xla_cc_test( ], ) +cc_library( + name = "all_gather_dynamic_slice_simplifier", + srcs = ["all_gather_dynamic_slice_simplifier.cc"], + hdrs = ["all_gather_dynamic_slice_simplifier.h"], + deps = [ + "//xla/hlo/ir:hlo", + "//xla/service:collective_opt_utils", + "//xla/service:hlo_creation_utils", + "//xla/service:op_expander_pass", + ], +) + +xla_cc_test( + name = "all_gather_dynamic_slice_simplifier_test", + srcs = ["all_gather_dynamic_slice_simplifier_test.cc"], + deps = [ + ":all_gather_dynamic_slice_simplifier", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/utils:hlo_matchers", + "//xla/tests:hlo_test_base", + "@tsl//tsl/platform:test_main", + ], +) + cc_library( name = "collective_permute_cycle_decomposer", srcs = ["collective_permute_cycle_decomposer.cc"], diff --git a/xla/service/gpu/transforms/all_gather_dynamic_slice_simplifier.cc b/xla/service/gpu/transforms/all_gather_dynamic_slice_simplifier.cc new file mode 100644 index 0000000000000..4035b80606cdf 
--- /dev/null +++ b/xla/service/gpu/transforms/all_gather_dynamic_slice_simplifier.cc @@ -0,0 +1,83 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/transforms/all_gather_dynamic_slice_simplifier.h" + +#include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/service/collective_opt_utils.h" + +namespace xla { +bool AllGatherDynamicSliceSimplifier::InstructionMatchesPattern( + HloInstruction* instruction) { + if (instruction->opcode() != HloOpcode::kDynamicSlice) { + return false; + } + + HloDynamicSliceInstruction* dynamic_slice = + Cast(instruction); + HloInstruction* operand = dynamic_slice->mutable_operand(0); + + // Check if the operand is a reshape or all-gather instruction + bool is_reshape = operand->opcode() == HloOpcode::kReshape; + bool is_all_gather = operand->opcode() == HloOpcode::kAllGather; + + if (!is_reshape && !is_all_gather) { + return false; + } + + if (is_reshape && operand->operand(0)->opcode() != HloOpcode::kAllGather) { + return false; + } + + const HloModuleConfig& config = instruction->GetModule()->config(); + HloAllGatherInstruction* all_gather = + is_reshape ? 
Cast(operand->mutable_operand(0)) + : Cast(operand); + + bool match = AllGatherDynamicSliceCancellation( + all_gather, config.num_partitions(), config.replica_count(), + /*allow_multiple_split_dims=*/true, + /*allow_intervening_reshape=*/true, /*min_rank=*/1, + HloPredicateIsOp, + HloPredicateIsOp, + /*allow_intervening_bitcast=*/false, + /*allow_multiple_users=*/true); + + return match; +} + +StatusOr AllGatherDynamicSliceSimplifier::ExpandInstruction( + HloInstruction* instruction) { + HloDynamicSliceInstruction* dynamic_slice = + Cast(instruction); + HloInstruction* operand = dynamic_slice->mutable_operand(0); + + if (operand->opcode() != HloOpcode::kReshape) { + // dynamic-slice(all-gather) case + return operand->mutable_operand(0); + } + + // dynamic-slice(reshape(all-gather)) case + HloReshapeInstruction* reshape = Cast(operand); + HloAllGatherInstruction* all_gather = + Cast(reshape->mutable_operand(0)); + HloInstruction* all_gather_input = all_gather->mutable_operand(0); + + auto* new_reshape = instruction->parent()->AddInstruction( + HloInstruction::CreateReshape(dynamic_slice->shape(), all_gather_input)); + return new_reshape; +} + +} // namespace xla diff --git a/xla/service/gpu/transforms/all_gather_dynamic_slice_simplifier.h b/xla/service/gpu/transforms/all_gather_dynamic_slice_simplifier.h new file mode 100644 index 0000000000000..f0fb673ad1f6f --- /dev/null +++ b/xla/service/gpu/transforms/all_gather_dynamic_slice_simplifier.h @@ -0,0 +1,48 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_TRANSFORMS_ALL_GATHER_DYNAMIC_SLICE_SIMPLIFIER_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_ALL_GATHER_DYNAMIC_SLICE_SIMPLIFIER_H_ + +#include "xla/service/op_expander_pass.h" + +namespace xla { + +// A pass that simplifies a dynamic-slice of an all-gather +// whose slice is the same as the original operand of the all-gather. +// As an example: +// +// ag = all-gather(x) replica_groups={{0,1,2,3,4,5,6,7}} +// offset = multiply(partition_id, slice_size) +// ds = dynamic-slice(ag, offset, 0, 0) +// +// Can be simplified to the all-gather operand. + +class AllGatherDynamicSliceSimplifier : public OpExpanderPass { + public: + absl::string_view name() const override { + return "all-gather-dynamic-slice-simplifier"; + } + + protected: + bool InstructionMatchesPattern(HloInstruction* instruction) override; + + StatusOr ExpandInstruction( + HloInstruction* instruction) override; +}; + +} // namespace xla + +#endif // XLA_SERVICE_GPU_TRANSFORMS_ALL_GATHER_DYNAMIC_SLICE_SIMPLIFIER_H_ diff --git a/xla/service/gpu/transforms/all_gather_dynamic_slice_simplifier_test.cc b/xla/service/gpu/transforms/all_gather_dynamic_slice_simplifier_test.cc new file mode 100644 index 0000000000000..c7f4391bc0092 --- /dev/null +++ b/xla/service/gpu/transforms/all_gather_dynamic_slice_simplifier_test.cc @@ -0,0 +1,233 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/transforms/all_gather_dynamic_slice_simplifier.h" + +#include +#include +#include + +#include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/utils/hlo_matchers.h" +#include "xla/tests/hlo_test_base.h" +#include "xla/xla_data.pb.h" + +namespace xla { +namespace gpu { +namespace { + +using ::testing::Matcher; +namespace op = xla::testing::opcode_matchers; + +class AllGatherDynamicSliceSimplifierTest : public HloTestBase { + public: + absl::StatusOr> RunPass( + absl::string_view hlo_module, int64_t num_replicas, + int64_t num_partitions, bool expect_change) { + HloModuleConfig config = GetModuleConfigForTest( + /*replica_count=*/num_replicas, + /*num_partitions=*/num_partitions); + config.set_use_spmd_partitioning(num_partitions > 1); + TF_ASSIGN_OR_RETURN(auto module, + ParseAndReturnVerifiedModule(hlo_module, config)); + auto changed = AllGatherDynamicSliceSimplifier().Run(module.get()); + if (!changed.ok()) { + return changed.status(); + } + EXPECT_EQ(changed.value(), expect_change); + return std::move(module); + } +}; + +// Test cancellation of all-gather followed by dynamic-slice across all +// partitions. 
+TEST_F(AllGatherDynamicSliceSimplifierTest, AllPartitions) { + absl::string_view hlo_string = R"( + HloModule AllGather + + ENTRY %AllGather { + %param = f32[32,8,128]{2,1,0} parameter(0) + %ag = f32[256,8,128]{2,1,0} all-gather(%param), replica_groups={{0,1,2,3,4,5,6,7}}, + dimensions={0}, channel_id=1, use_global_device_ids=true + %pid = u32[] partition-id() + %pid_s32 = s32[] convert(%pid) + %slice_size = s32[] constant(32) + %offset = s32[] multiply(%pid_s32, %slice_size) + %zero = s32[] constant(0) + ROOT %ds = f32[32,8,128]{2,1,0} dynamic-slice(%ag, %offset, %zero, %zero), + dynamic_slice_sizes={32,8,128} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string, + /*num_replicas=*/1, + /*num_partitions=*/8, + /*expect_change=*/true)); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Parameter(0)); +} + +// Test cancellation of all-gather followed by dynamic-slice across all replicas +// with reshape. +TEST_F(AllGatherDynamicSliceSimplifierTest, AllReplicasWithReshape) { + absl::string_view hlo_string = R"( + HloModule AllGather + + ENTRY %AllGather { + %param = f32[32,8,128]{2,1,0} parameter(0) + %ag = f32[256,8,128]{2,1,0} all-gather(%param), replica_groups={{0,1,2,3,4,5,6,7}}, + dimensions={0}, channel_id=1, use_global_device_ids=true + %reshape = f32[256,8,64,2]{3,2,1,0} reshape(%ag) + %pid = u32[] partition-id() + %pid_s32 = s32[] convert(%pid) + %slice_size = s32[] constant(32) + %offset = s32[] multiply(%pid_s32, %slice_size) + %zero = s32[] constant(0) + ROOT %ds = f32[32,8,64,2]{3,2,1,0} dynamic-slice(%reshape, %offset, %zero, %zero, %zero), + dynamic_slice_sizes={32,8,64,2} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string, + /*num_replicas=*/1, + /*num_partitions=*/8, + /*expect_change=*/true)); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Reshape(op::Parameter(0))); +} + +// Test no cancellation when reshape is on the slice dimension. 
+TEST_F(AllGatherDynamicSliceSimplifierTest, + AllPartitionsWithReshapeOnSliceDim) { + absl::string_view hlo_string = R"( + HloModule AllGather + + ENTRY %AllGather { + %param = f32[32,8,128]{2,1,0} parameter(0) + %ag = f32[256,8,128]{2,1,0} all-gather(%param), replica_groups={{0,1,2,3,4,5,6,7}}, + dimensions={0}, channel_id=1, use_global_device_ids=true + %reshape = f32[2048,128]{1,0} reshape(%ag) + %pid = u32[] partition-id() + %pid_s32 = s32[] convert(%pid) + %slice_size = s32[] constant(256) + %offset = s32[] multiply(%pid_s32, %slice_size) + %zero = s32[] constant(0) + ROOT %ds = f32[256,128]{1,0} dynamic-slice(%reshape, %offset, %zero), + dynamic_slice_sizes={256,128} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string, + /*num_replicas=*/1, + /*num_partitions=*/8, + /*expect_change=*/false)); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::DynamicSlice( + op::Reshape(op::AllGather(op::Parameter(0))), + op::Multiply(op::Convert(op::PartitionId()), op::Constant()), + op::Constant())); +} + +// Test no cancellation when there is no all-gather. 
+TEST_F(AllGatherDynamicSliceSimplifierTest, NoAllGather) { + absl::string_view hlo_string = R"( + HloModule NoAllGather + + ENTRY %NoAllGather { + %param = f32[32,8,128]{2,1,0} parameter(0) + %pid = u32[] partition-id() + %pid_s32 = s32[] convert(%pid) + %slice_size = s32[] constant(32) + %offset = s32[] multiply(%pid_s32, %slice_size) + %zero = s32[] constant(0) + ROOT %ds = f32[32,8,128]{2,1,0} dynamic-slice(%param, %offset, %zero, %zero), + dynamic_slice_sizes={32,8,128} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string, + /*num_replicas=*/1, + /*num_partitions=*/1, + /*expect_change=*/false)); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::DynamicSlice( + op::Parameter(0), + op::Multiply(op::Convert(op::PartitionId()), op::Constant()), + op::Constant(), op::Constant())); +} + +// Test no cancellation when the all-gather dimension is incorrect. +TEST_F(AllGatherDynamicSliceSimplifierTest, IncorrectAllGatherDimension) { + absl::string_view hlo_string = R"( + HloModule IncorrectAllGatherDimension + + ENTRY %IncorrectAllGatherDimension { + %param = f32[32,8,128]{2,1,0} parameter(0) + %ag = f32[32,64,128]{2,1,0} all-gather(%param), replica_groups={}, + dimensions={1}, channel_id=1 + %pid = u32[] partition-id() + %pid_s32 = s32[] convert(%pid) + %slice_size = s32[] constant(8) + %offset = s32[] multiply(%pid_s32, %slice_size) + %zero = s32[] constant(0) + ROOT %ds = f32[32,8,128]{2,1,0} dynamic-slice(%ag, %zero, %offset, %zero), + dynamic_slice_sizes={32,8,128} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string, + /*num_replicas=*/8, + /*num_partitions=*/1, + /*expect_change=*/false)); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::DynamicSlice( + op::AllGather(op::Parameter(0)), op::Constant(), + op::Multiply(op::Convert(op::PartitionId()), op::Constant()), + op::Constant())); +} + +// Test cancellation of all-gather followed by dynamic-slice across all replicas +// with reshape and 
multiple users of the all-gather. +TEST_F(AllGatherDynamicSliceSimplifierTest, + AllReplicasWithReshapeMultipleUsers) { + absl::string_view hlo_string = R"( + HloModule AllGather + + ENTRY %AllGather { + %param = f32[32,8,128]{2,1,0} parameter(0) + %ag = f32[256,8,128]{2,1,0} all-gather(%param), replica_groups={{0,1,2,3,4,5,6,7}}, + dimensions={0}, channel_id=1, use_global_device_ids=true + %reshape = f32[256,8,64,2]{3,2,1,0} reshape(%ag) + %pid = u32[] partition-id() + %pid_s32 = s32[] convert(%pid) + %slice_size = s32[] constant(32) + %offset = s32[] multiply(%pid_s32, %slice_size) + %zero = s32[] constant(0) + %ds = f32[32,8,64,2]{3,2,1,0} dynamic-slice(%reshape, %offset, %zero, %zero, %zero), + dynamic_slice_sizes={32,8,64,2} + ROOT %tuple = (f32[32,8,64,2]{3,2,1,0}, f32[256,8,128]{2,1,0}) tuple(%ds, %ag) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string, + /*num_replicas=*/1, + /*num_partitions=*/8, + /*expect_change=*/true)); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Tuple(op::Reshape(op::Parameter(0)), + op::AllGather(op::Parameter(0)))); +} +} // namespace +} // namespace gpu +} // namespace xla