diff --git a/xla/service/collective_opt_utils.cc b/xla/service/collective_opt_utils.cc index 13b7d553adb22..8da8a37155b87 100644 --- a/xla/service/collective_opt_utils.cc +++ b/xla/service/collective_opt_utils.cc @@ -322,12 +322,14 @@ bool AllGatherDynamicSliceCancellation( const HloAllGatherInstruction* ag, int64_t num_partitions, int64_t num_replicas, bool allow_multiple_split_dims, bool allow_intervening_reshape, int64_t min_rank, - HloPredicate match_partition_id, HloPredicate match_replica_id) { + HloPredicate match_partition_id, HloPredicate match_replica_id, + bool allow_intervening_bitcast, bool allow_multiple_users) { auto spec = MatchWithDynamicSlice( ag, num_partitions, num_replicas, allow_multiple_split_dims, allow_intervening_reshape, min_rank, match_partition_id, match_replica_id, ag->constrain_layout(), ag->use_global_device_ids(), - ag->channel_id() && ag->opcode() == HloOpcode::kAllGather); + ag->channel_id() && ag->opcode() == HloOpcode::kAllGather, + allow_intervening_bitcast, allow_multiple_users); if (spec.has_value()) { return true; } @@ -340,7 +342,7 @@ std::optional MatchWithDynamicSlice( bool allow_intervening_reshape, int64_t min_rank, HloPredicate match_partition_id, HloPredicate match_replica_id, bool is_constrain_layout, bool use_global_device_ids, bool is_cross_module, - bool allow_intervening_bitcast) { + bool allow_intervening_bitcast, bool allow_multiple_users) { if (!instruction->shape().IsArray() || is_constrain_layout || (is_cross_module && !instruction->GetModule()->config().use_spmd_partitioning())) { @@ -354,8 +356,8 @@ std::optional MatchWithDynamicSlice( << " excluding trivial dimensions " << instruction->ToString(); return std::nullopt; } - if (instruction->user_count() != 1) { - VLOG(2) << "All-gather user_count > 1 " << instruction->ToString(); + if (!allow_multiple_users && instruction->user_count() != 1) { + VLOG(2) << "All-gather user_count != 1 " << instruction->ToString(); return std::nullopt; } if (instruction->replica_groups().size() > 1) { @@ -371,8 +373,19 @@ std::optional MatchWithDynamicSlice( return std::nullopt; } } - + // Always assume first user to start. HloInstruction* user = instruction->users()[0]; + if (allow_multiple_users) { + // If we find a reshape or dynamic-slice use that. + for (auto* some_user : instruction->users()) { + if ((allow_intervening_reshape && + some_user->opcode() == HloOpcode::kReshape) || + some_user->opcode() == HloOpcode::kDynamicSlice) { + user = some_user; + break; + } + } + } HloInstruction* reshape = nullptr; if (allow_intervening_reshape && user->opcode() == HloOpcode::kReshape) { // Allow the intervening reshape if it reshapes just the non scattered diff --git a/xla/service/collective_opt_utils.h b/xla/service/collective_opt_utils.h index 983969e41e607..6131028d5f684 100644 --- a/xla/service/collective_opt_utils.h +++ b/xla/service/collective_opt_utils.h @@ -43,16 +43,17 @@ std::optional MatchReduceScatter( HloPredicate match_replica_id = HloPredicateIsOp, bool allow_intervening_bitcast = false); -// Check whether AG(ICI) and its single user DS(ICI) can be canceled out. +// Check whether AG(ICI) and its user DS(ICI) can be canceled out. bool AllGatherDynamicSliceCancellation( const HloAllGatherInstruction* ag, int64_t num_partitions, int64_t num_replicas, bool allow_multiple_split_dims = false, bool allow_intervening_reshape = false, int64_t min_rank = 1, HloPredicate match_partition_id = HloPredicateIsOp, - HloPredicate match_replica_id = HloPredicateIsOp); + HloPredicate match_replica_id = HloPredicateIsOp, + bool allow_intervening_bitcast = false, bool allow_multiple_users = false); // Check if a given instruction (AllReduce or AllGather) matches a DynamicSlice; -// the DynamicSlice has to be the only user of the given instruction. +// the DynamicSlice has to be the user of the given instruction. std::optional MatchWithDynamicSlice( const HloChannelInstruction* instruction, int64_t num_partitions, int64_t num_replicas, bool allow_multiple_split_dims = false, @@ -60,7 +61,8 @@ std::optional MatchWithDynamicSlice( HloPredicate match_partition_id = HloPredicateIsOp, HloPredicate match_replica_id = HloPredicateIsOp, bool is_constrain_layout = false, bool use_global_device_ids = false, - bool is_cross_module = false, bool allow_intervening_bitcast = false); + bool is_cross_module = false, bool allow_intervening_bitcast = false, + bool allow_multiple_users = false); } // namespace xla diff --git a/xla/service/gpu/BUILD b/xla/service/gpu/BUILD index 151417710d85c..ad89a46c852e0 100644 --- a/xla/service/gpu/BUILD +++ b/xla/service/gpu/BUILD @@ -1416,6 +1416,7 @@ cc_library( "//xla/service/gpu/runtime:thunk", "//xla/service/gpu/transforms:algebraic_simplifier", "//xla/service/gpu/transforms:algorithm_checker", + "//xla/service/gpu/transforms:all_gather_dynamic_slice_simplifier", "//xla/service/gpu/transforms:all_gather_optimizer", "//xla/service/gpu/transforms:all_reduce_blueconnect", "//xla/service/gpu/transforms:all_reduce_splitter", diff --git a/xla/service/gpu/gpu_compiler.cc b/xla/service/gpu/gpu_compiler.cc index 725556404b47b..02d2e0833158a 100644 --- a/xla/service/gpu/gpu_compiler.cc +++ b/xla/service/gpu/gpu_compiler.cc @@ -142,6 +142,7 @@ limitations under the License. #include "xla/service/gpu/stream_executor_util.h" #include "xla/service/gpu/transforms/algebraic_simplifier.h" #include "xla/service/gpu/transforms/algorithm_checker.h" +#include "xla/service/gpu/transforms/all_gather_dynamic_slice_simplifier.h" #include "xla/service/gpu/transforms/all_gather_optimizer.h" #include "xla/service/gpu/transforms/all_reduce_blueconnect.h" #include "xla/service/gpu/transforms/all_reduce_splitter.h" @@ -903,6 +904,7 @@ absl::Status RunCollectiveOptimizationPasses( HloPassPipeline collectives_pipeline("collective-optimizations"); collectives_pipeline.AddPass(); collectives_pipeline.AddPass(); + collectives_pipeline.AddPass(); collectives_pipeline.AddPass(); collectives_pipeline.AddPass( debug_options.xla_gpu_enable_reassociation_for_converted_ar()); diff --git a/xla/service/gpu/transforms/BUILD b/xla/service/gpu/transforms/BUILD index c481e29083b5d..842dffa0028a6 100644 --- a/xla/service/gpu/transforms/BUILD +++ b/xla/service/gpu/transforms/BUILD @@ -310,6 +310,31 @@ xla_cc_test( ], ) +cc_library( + name = "all_gather_dynamic_slice_simplifier", + srcs = ["all_gather_dynamic_slice_simplifier.cc"], + hdrs = ["all_gather_dynamic_slice_simplifier.h"], + deps = [ + "//xla/hlo/ir:hlo", + "//xla/service:collective_opt_utils", + "//xla/service:hlo_creation_utils", + "//xla/service:op_expander_pass", + ], +) + +xla_cc_test( + name = "all_gather_dynamic_slice_simplifier_test", + srcs = ["all_gather_dynamic_slice_simplifier_test.cc"], + deps = [ + ":all_gather_dynamic_slice_simplifier", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/utils:hlo_matchers", + "//xla/tests:hlo_test_base", + "@tsl//tsl/platform:test_main", + ], +) + cc_library( name = "collective_permute_cycle_decomposer", srcs = ["collective_permute_cycle_decomposer.cc"], diff --git a/xla/service/gpu/transforms/all_gather_dynamic_slice_simplifier.cc b/xla/service/gpu/transforms/all_gather_dynamic_slice_simplifier.cc new file mode 100644 index 0000000000000..4035b80606cdf --- /dev/null +++ b/xla/service/gpu/transforms/all_gather_dynamic_slice_simplifier.cc @@ -0,0 +1,83 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/transforms/all_gather_dynamic_slice_simplifier.h" + +#include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/service/collective_opt_utils.h" + +namespace xla { +bool AllGatherDynamicSliceSimplifier::InstructionMatchesPattern( + HloInstruction* instruction) { + if (instruction->opcode() != HloOpcode::kDynamicSlice) { + return false; + } + + HloDynamicSliceInstruction* dynamic_slice = + Cast(instruction); + HloInstruction* operand = dynamic_slice->mutable_operand(0); + + // Check if the operand is a reshape or all-gather instruction + bool is_reshape = operand->opcode() == HloOpcode::kReshape; + bool is_all_gather = operand->opcode() == HloOpcode::kAllGather; + + if (!is_reshape && !is_all_gather) { + return false; + } + + if (is_reshape && operand->operand(0)->opcode() != HloOpcode::kAllGather) { + return false; + } + + const HloModuleConfig& config = instruction->GetModule()->config(); + HloAllGatherInstruction* all_gather = + is_reshape ? Cast(operand->mutable_operand(0)) + : Cast(operand); + + bool match = AllGatherDynamicSliceCancellation( + all_gather, config.num_partitions(), config.replica_count(), + /*allow_multiple_split_dims=*/true, + /*allow_intervening_reshape=*/true, /*min_rank=*/1, + HloPredicateIsOp, + HloPredicateIsOp, + /*allow_intervening_bitcast=*/false, + /*allow_multiple_users=*/true); + + return match; +} + +StatusOr AllGatherDynamicSliceSimplifier::ExpandInstruction( + HloInstruction* instruction) { + HloDynamicSliceInstruction* dynamic_slice = + Cast(instruction); + HloInstruction* operand = dynamic_slice->mutable_operand(0); + + if (operand->opcode() != HloOpcode::kReshape) { + // dynamic-slice(all-gather) case + return operand->mutable_operand(0); + } + + // dynamic-slice(reshape(all-gather)) case + HloReshapeInstruction* reshape = Cast(operand); + HloAllGatherInstruction* all_gather = + Cast(reshape->mutable_operand(0)); + HloInstruction* all_gather_input = all_gather->mutable_operand(0); + + auto* new_reshape = instruction->parent()->AddInstruction( + HloInstruction::CreateReshape(dynamic_slice->shape(), all_gather_input)); + return new_reshape; +} + +} // namespace xla diff --git a/xla/service/gpu/transforms/all_gather_dynamic_slice_simplifier.h b/xla/service/gpu/transforms/all_gather_dynamic_slice_simplifier.h new file mode 100644 index 0000000000000..f0fb673ad1f6f --- /dev/null +++ b/xla/service/gpu/transforms/all_gather_dynamic_slice_simplifier.h @@ -0,0 +1,48 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_TRANSFORMS_ALL_GATHER_DYNAMIC_SLICE_SIMPLIFIER_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_ALL_GATHER_DYNAMIC_SLICE_SIMPLIFIER_H_ + +#include "xla/service/op_expander_pass.h" + +namespace xla { + +// A pass that simplifies a dynamic-slice of an all-gather +// whose slice is the same as the original operand of the all-gather. +// As an example: +// +// ag = all-gather(x) replica_groups={{0,1,2,3,4,5,6,7}} +// offset = multiply(partition_id, slice_size) +// ds = dynamic-slice(ag, offset, 0, 0) +// +// Can be simplified to the all-gather operand. + +class AllGatherDynamicSliceSimplifier : public OpExpanderPass { + public: + absl::string_view name() const override { + return "all-gather-dynamic-slice-simplifier"; + } + + protected: + bool InstructionMatchesPattern(HloInstruction* instruction) override; + + StatusOr ExpandInstruction( + HloInstruction* instruction) override; +}; + +} // namespace xla + +#endif // XLA_SERVICE_GPU_TRANSFORMS_ALL_GATHER_DYNAMIC_SLICE_SIMPLIFIER_H_ diff --git a/xla/service/gpu/transforms/all_gather_dynamic_slice_simplifier_test.cc b/xla/service/gpu/transforms/all_gather_dynamic_slice_simplifier_test.cc new file mode 100644 index 0000000000000..c7f4391bc0092 --- /dev/null +++ b/xla/service/gpu/transforms/all_gather_dynamic_slice_simplifier_test.cc @@ -0,0 +1,233 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/transforms/all_gather_dynamic_slice_simplifier.h" + +#include +#include +#include + +#include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/utils/hlo_matchers.h" +#include "xla/tests/hlo_test_base.h" +#include "xla/xla_data.pb.h" + +namespace xla { +namespace gpu { +namespace { + +using ::testing::Matcher; +namespace op = xla::testing::opcode_matchers; + +class AllGatherDynamicSliceSimplifierTest : public HloTestBase { + public: + absl::StatusOr> RunPass( + absl::string_view hlo_module, int64_t num_replicas, + int64_t num_partitions, bool expect_change) { + HloModuleConfig config = GetModuleConfigForTest( + /*replica_count=*/num_replicas, + /*num_partitions=*/num_partitions); + config.set_use_spmd_partitioning(num_partitions > 1); + TF_ASSIGN_OR_RETURN(auto module, + ParseAndReturnVerifiedModule(hlo_module, config)); + auto changed = AllGatherDynamicSliceSimplifier().Run(module.get()); + if (!changed.ok()) { + return changed.status(); + } + EXPECT_EQ(changed.value(), expect_change); + return std::move(module); + } +}; + +// Test cancellation of all-gather followed by dynamic-slice across all +// partitions. +TEST_F(AllGatherDynamicSliceSimplifierTest, AllPartitions) { + absl::string_view hlo_string = R"( + HloModule AllGather + + ENTRY %AllGather { + %param = f32[32,8,128]{2,1,0} parameter(0) + %ag = f32[256,8,128]{2,1,0} all-gather(%param), replica_groups={{0,1,2,3,4,5,6,7}}, + dimensions={0}, channel_id=1, use_global_device_ids=true + %pid = u32[] partition-id() + %pid_s32 = s32[] convert(%pid) + %slice_size = s32[] constant(32) + %offset = s32[] multiply(%pid_s32, %slice_size) + %zero = s32[] constant(0) + ROOT %ds = f32[32,8,128]{2,1,0} dynamic-slice(%ag, %offset, %zero, %zero), + dynamic_slice_sizes={32,8,128} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string, + /*num_replicas=*/1, + /*num_partitions=*/8, + /*expect_change=*/true)); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Parameter(0)); +} + +// Test cancellation of all-gather followed by dynamic-slice across all replicas +// with reshape. +TEST_F(AllGatherDynamicSliceSimplifierTest, AllReplicasWithReshape) { + absl::string_view hlo_string = R"( + HloModule AllGather + + ENTRY %AllGather { + %param = f32[32,8,128]{2,1,0} parameter(0) + %ag = f32[256,8,128]{2,1,0} all-gather(%param), replica_groups={{0,1,2,3,4,5,6,7}}, + dimensions={0}, channel_id=1, use_global_device_ids=true + %reshape = f32[256,8,64,2]{3,2,1,0} reshape(%ag) + %pid = u32[] partition-id() + %pid_s32 = s32[] convert(%pid) + %slice_size = s32[] constant(32) + %offset = s32[] multiply(%pid_s32, %slice_size) + %zero = s32[] constant(0) + ROOT %ds = f32[32,8,64,2]{3,2,1,0} dynamic-slice(%reshape, %offset, %zero, %zero, %zero), + dynamic_slice_sizes={32,8,64,2} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string, + /*num_replicas=*/1, + /*num_partitions=*/8, + /*expect_change=*/true)); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Reshape(op::Parameter(0))); +} + +// Test no cancellation when reshape is on the slice dimension. +TEST_F(AllGatherDynamicSliceSimplifierTest, + AllPartitionsWithReshapeOnSliceDim) { + absl::string_view hlo_string = R"( + HloModule AllGather + + ENTRY %AllGather { + %param = f32[32,8,128]{2,1,0} parameter(0) + %ag = f32[256,8,128]{2,1,0} all-gather(%param), replica_groups={{0,1,2,3,4,5,6,7}}, + dimensions={0}, channel_id=1, use_global_device_ids=true + %reshape = f32[2048,128]{1,0} reshape(%ag) + %pid = u32[] partition-id() + %pid_s32 = s32[] convert(%pid) + %slice_size = s32[] constant(256) + %offset = s32[] multiply(%pid_s32, %slice_size) + %zero = s32[] constant(0) + ROOT %ds = f32[256,128]{1,0} dynamic-slice(%reshape, %offset, %zero), + dynamic_slice_sizes={256,128} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string, + /*num_replicas=*/1, + /*num_partitions=*/8, + /*expect_change=*/false)); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::DynamicSlice( + op::Reshape(op::AllGather(op::Parameter(0))), + op::Multiply(op::Convert(op::PartitionId()), op::Constant()), + op::Constant())); +} + +// Test no cancellation when there is no all-gather. +TEST_F(AllGatherDynamicSliceSimplifierTest, NoAllGather) { + absl::string_view hlo_string = R"( + HloModule NoAllGather + + ENTRY %NoAllGather { + %param = f32[32,8,128]{2,1,0} parameter(0) + %pid = u32[] partition-id() + %pid_s32 = s32[] convert(%pid) + %slice_size = s32[] constant(32) + %offset = s32[] multiply(%pid_s32, %slice_size) + %zero = s32[] constant(0) + ROOT %ds = f32[32,8,128]{2,1,0} dynamic-slice(%param, %offset, %zero, %zero), + dynamic_slice_sizes={32,8,128} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string, + /*num_replicas=*/1, + /*num_partitions=*/1, + /*expect_change=*/false)); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::DynamicSlice( + op::Parameter(0), + op::Multiply(op::Convert(op::PartitionId()), op::Constant()), + op::Constant(), op::Constant())); +} + +// Test no cancellation when the all-gather dimension is incorrect. +TEST_F(AllGatherDynamicSliceSimplifierTest, IncorrectAllGatherDimension) { + absl::string_view hlo_string = R"( + HloModule IncorrectAllGatherDimension + + ENTRY %IncorrectAllGatherDimension { + %param = f32[32,8,128]{2,1,0} parameter(0) + %ag = f32[32,64,128]{2,1,0} all-gather(%param), replica_groups={}, + dimensions={1}, channel_id=1 + %pid = u32[] partition-id() + %pid_s32 = s32[] convert(%pid) + %slice_size = s32[] constant(8) + %offset = s32[] multiply(%pid_s32, %slice_size) + %zero = s32[] constant(0) + ROOT %ds = f32[32,8,128]{2,1,0} dynamic-slice(%ag, %zero, %offset, %zero), + dynamic_slice_sizes={32,8,128} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string, + /*num_replicas=*/8, + /*num_partitions=*/1, + /*expect_change=*/false)); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::DynamicSlice( + op::AllGather(op::Parameter(0)), op::Constant(), + op::Multiply(op::Convert(op::PartitionId()), op::Constant()), + op::Constant())); +} + +// Test cancellation of all-gather followed by dynamic-slice across all replicas +// with reshape and multiple users of the all-gather. +TEST_F(AllGatherDynamicSliceSimplifierTest, + AllReplicasWithReshapeMultipleUsers) { + absl::string_view hlo_string = R"( + HloModule AllGather + + ENTRY %AllGather { + %param = f32[32,8,128]{2,1,0} parameter(0) + %ag = f32[256,8,128]{2,1,0} all-gather(%param), replica_groups={{0,1,2,3,4,5,6,7}}, + dimensions={0}, channel_id=1, use_global_device_ids=true + %reshape = f32[256,8,64,2]{3,2,1,0} reshape(%ag) + %pid = u32[] partition-id() + %pid_s32 = s32[] convert(%pid) + %slice_size = s32[] constant(32) + %offset = s32[] multiply(%pid_s32, %slice_size) + %zero = s32[] constant(0) + %ds = f32[32,8,64,2]{3,2,1,0} dynamic-slice(%reshape, %offset, %zero, %zero, %zero), + dynamic_slice_sizes={32,8,64,2} + ROOT %tuple = (f32[32,8,64,2]{3,2,1,0}, f32[256,8,128]{2,1,0}) tuple(%ds, %ag) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string, + /*num_replicas=*/1, + /*num_partitions=*/8, + /*expect_change=*/true)); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Tuple(op::Reshape(op::Parameter(0)), + op::AllGather(op::Parameter(0)))); +} +} // namespace +} // namespace gpu +} // namespace xla