From 7b61179efcae07c0b1dd0ca026445478bc27f3b0 Mon Sep 17 00:00:00 2001
From: Zixuan Jiang
Date: Wed, 18 Sep 2024 14:49:32 -0700
Subject: [PATCH] [XLA:SPMD] Remove LookaheadUserSharding in sharding propagation.

When we infer the dot sharding from its operands, it is possible that both
operands can improve the dot sharding. LookaheadUserSharding iterates over the
dot users and decides which dot operand sharding is preferred. This CL removes
it for two reasons.

1. It is unnecessary. If we can predict the sharding from the dot users, we
   can wait for that sharding to be propagated from the users. The propagated
   sharding from the users can still help us choose between the dot operands.

2. The lookahead sharding may be wrong. LookaheadUserSharding is a heuristic;
   we cannot guarantee that the predicted sharding will hold in the dot users.

Reverts b4ea9792c7ac549352065c43e0dd0a42e7b2ffeb

PiperOrigin-RevId: 676140549
---
 .../auto_sharding_dot_handler.cc         |   8 +-
 xla/service/sharding_propagation.cc      | 153 +++++++-----------
 xla/service/sharding_propagation.h       |   9 +-
 xla/service/sharding_propagation_test.cc | 108 +++++++++++--
 4 files changed, 161 insertions(+), 117 deletions(-)

diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc b/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc
index 85898c8b4e6df..237ffa4e3d4f4 100644
--- a/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc
+++ b/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc
@@ -444,13 +444,13 @@ std::optional<HloSharding> HandlerBase::GetShardingFromUser(
   CHECK_OK(ins_clone->ReplaceOperandWith(1, rhs_clone.get()));
   if (ins_->opcode() == HloOpcode::kConvolution) {
     xla::InferConvolutionShardingFromOperands(
-        ins_clone.get(), call_graph_, 10,
-        /* may_combine_partial_sharding */ true, /* is_spmd */ true);
+        ins_clone.get(), /* aggressiveness */ 10,
+        /* may_combine_partial_sharding */ true);
   } else {
     xla::InferDotShardingFromOperands(
-        ins_clone.get(), call_graph_,
+        ins_clone.get(),
         dot_as_convolution_util::ParseDotGeneralFromDot(ins_clone.get()),
-        /* may_combine_partial_sharding/ */ true, /* is_spmd */ true);
+        /* aggressiveness */ 10, /* may_combine_partial_sharding */ true);
   }
   if (!ins_clone->has_sharding()) {
     return std::nullopt;
diff --git a/xla/service/sharding_propagation.cc b/xla/service/sharding_propagation.cc
index b9508c04021f7..316644bf87ea8 100644
--- a/xla/service/sharding_propagation.cc
+++ b/xla/service/sharding_propagation.cc
@@ -16,10 +16,10 @@ limitations under the License.
 #include "xla/service/sharding_propagation.h"
 
 #include
+#include
 #include
 #include
 #include
-#include
 #include
 #include
 #include
@@ -36,10 +36,12 @@ limitations under the License.
 #include "absl/log/check.h"
 #include "absl/status/status.h"
 #include "absl/strings/str_join.h"
+#include "absl/strings/string_view.h"
 #include "absl/types/span.h"
 #include "xla/array.h"
 #include "xla/hlo/ir/hlo_casting_utils.h"
 #include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_domain_metadata.h"
 #include "xla/hlo/ir/hlo_instruction.h"
 #include "xla/hlo/ir/hlo_instructions.h"
 #include "xla/hlo/ir/hlo_opcode.h"
@@ -47,6 +49,7 @@
#include "xla/hlo/ir/hlo_sharding_metadata.h" #include "xla/hlo/utils/hlo_sharding_util.h" #include "xla/protobuf_util.h" +#include "xla/service/call_graph.h" #include "xla/service/dot_as_convolution_util.h" #include "xla/service/host_memory_offload_annotations.h" #include "xla/service/spmd/shard_barrier_partitioner.h" @@ -416,55 +419,6 @@ bool SupportSpatialPartitioning( } } -// Helper to lookahead sharding of user of an instruction to be used as guidance -// for ambiguous cases. -std::optional LookaheadUserSharding(HloInstruction* instr, - bool is_spmd, - const CallGraph& call_graph) { - if (instr->user_count() != 1) { - return std::nullopt; - } - HloInstruction* current_user = instr->users()[0]; - std::optional sharding; - std::vector users_chain = {instr, current_user}; - // Collect single user instructions along the way. - while (!current_user->has_sharding()) { - // Only consider single user chains. - if (current_user->users().size() != 1) { - users_chain.clear(); - break; - } - current_user = current_user->users()[0]; - users_chain.push_back(current_user); - } - // Early exit for unsupported cases. - if (users_chain.empty()) { - return std::nullopt; - } - for (int i = users_chain.size() - 1; i >= 1; --i) { - HloInstruction* user = users_chain[i]; - HloInstruction* current = users_chain[i - 1]; - CHECK(user->has_sharding()); - sharding = ShardingPropagation::GetShardingFromUser( - *current, *user, INT64_MAX, is_spmd, call_graph, - /*sharding_helper=*/nullptr); - // We need to set the sharding to the instruction, because - // GetShardingFromUser() interface uses sharding from the instruction - // itself. It will be cleared out later. - if (sharding.has_value() && i != 1) { - current->set_sharding(*sharding); - continue; - } - break; - } - // Clear the sharding of the middle instructions we set the sharding of - // because they were unsharded. - for (int i = 1; i < users_chain.size() - 1; ++i) { - users_chain[i]->clear_sharding(); - } - return sharding; -} - // Infer output sharding on index parallel dimensions for gather from operand // and indices. bool InferGatherParallelShardingFromOperands( @@ -1071,9 +1025,9 @@ bool IsCSEPreventionSharding(const HloSharding& sharding) { } // namespace bool InferDotShardingFromOperands( - HloInstruction* instruction, const CallGraph& call_graph, + HloInstruction* instruction, const dot_as_convolution_util::DotConvolutionDimsInfo& dnums, - bool may_combine_partial_sharding, bool is_spmd) { + int64_t aggressiveness, bool may_combine_partial_sharding) { auto from_operand = [&](int64_t operand_index) { auto operand = instruction->operand(operand_index); const HloSharding& operand_sharding = operand->sharding(); @@ -1128,55 +1082,66 @@ bool InferDotShardingFromOperands( from_operand(1), instruction, may_combine_partial_sharding, /*allow_aggressive_resharding=*/false); } - // If not improved sharding found then do not set any sharding. + + // Four cases based on if improved_operand_0 and improved_operand_1 are + // available. + // Case 0. Both operands have no improved sharding. if (!improved_operand_0.has_value() && !improved_operand_1.has_value()) { return false; } - // Sharding found from operand 0 but not operand 1. Set sharding from operand - // 0 + // Case 1. Sharding found from operand 0 but not operand 1. Set sharding from + // operand 0. if (improved_operand_0.has_value() && !improved_operand_1.has_value()) { instruction->set_sharding(*improved_operand_0); return true; } - // Sharding found from operand 1 but not operand 0. 
-  // 1
+  // Case 2. Sharding found from operand 1 but not operand 0. Set sharding from
+  // operand 1.
   if (!improved_operand_0.has_value() && improved_operand_1.has_value()) {
     instruction->set_sharding(*improved_operand_1);
     return true;
   }
+  // Case 3. Both operands have improved shardings.
   CHECK(improved_operand_0.has_value() && improved_operand_1.has_value());
-  std::optional<HloSharding> lookahead_sharding =
-      LookaheadUserSharding(instruction, is_spmd, call_graph);
+
+  // If one of the improved shardings is a sub-tiling or equal to the other, use
+  // the better sharding with more tiles.
+  if (hlo_sharding_util::IsSubTilingOrEqualSharding(
+          instruction->shape(), *improved_operand_0, *improved_operand_1)) {
+    instruction->set_sharding(*improved_operand_0);
+    return true;
+  }
+  if (hlo_sharding_util::IsSubTilingOrEqualSharding(
+          instruction->shape(), *improved_operand_1, *improved_operand_0)) {
+    instruction->set_sharding(*improved_operand_1);
+    return true;
+  }
+
+  // If the two improved shardings are mergeable, there is no conflict.
+  if (std::optional<HloSharding> improved_sharding =
+          hlo_sharding_util::ReturnImprovedShardingImpl(
+              *improved_operand_0, &improved_operand_1.value(),
+              instruction->shape(), may_combine_partial_sharding,
+              /*allow_aggressive_resharding=*/false)) {
+    instruction->set_sharding(*improved_sharding);
+    return true;
+  }
+
+  if (aggressiveness < 3) {
+    // We can improve the dot with different shardings. Pause the propagation
+    // and wait for the winner between the two operands.
+    return false;
+  }
+
+  // The two improved shardings are different and we are at the highest
+  // aggressiveness. Prioritize the operand with larger size.
   std::array<HloSharding, 2> sharding_priority = {*improved_operand_0,
                                                   *improved_operand_1};
-  bool priority_defined_with_lookahead = false;
-  // Found sharding from lookahead.
-  if (lookahead_sharding.has_value()) {
-    const bool operand_0_is_lookahead_subtiling =
-        hlo_sharding_util::IsSubTilingOrEqualSharding(
-            instruction->shape(), *lookahead_sharding, *improved_operand_0);
-    const bool operand_1_is_lookahead_subtiling =
-        hlo_sharding_util::IsSubTilingOrEqualSharding(
-            instruction->shape(), *lookahead_sharding, *improved_operand_1);
-    // If the sharding from operand 0 is a subtiling of the user, but not the
-    // one from operand 1 prioritize that sharding.
-    if (operand_0_is_lookahead_subtiling && !operand_1_is_lookahead_subtiling) {
-      priority_defined_with_lookahead = true;
-    }
-    // If the sharding from operand 1 is a subtiling of the user, but not the
-    // one from operand 0 prioritize that sharding.
-    if (!operand_0_is_lookahead_subtiling && operand_1_is_lookahead_subtiling) {
-      instruction->set_sharding(*improved_operand_1);
-      std::swap(sharding_priority[0], sharding_priority[1]);
-      priority_defined_with_lookahead = true;
-    }
-  }
-  // If lookahead didn't define a priority then use size.
-  if (!priority_defined_with_lookahead &&
-      ShapeUtil::ByteSizeOf(instruction->operand(0)->shape()) <
-          ShapeUtil::ByteSizeOf(instruction->operand(1)->shape())) {
+  if (ShapeUtil::ByteSizeOf(instruction->operand(0)->shape()) <
+      ShapeUtil::ByteSizeOf(instruction->operand(1)->shape())) {
     std::swap(sharding_priority[0], sharding_priority[1]);
   }
+
   // Set primary sharding to the instruction and then try to improve it with
   // the secondary sharding.
   instruction->set_sharding(sharding_priority[0]);
@@ -1187,10 +1152,8 @@
 
 // Convolution handling for InferShardingFromOperands().
 bool InferConvolutionShardingFromOperands(HloInstruction* instruction,
-                                          const CallGraph& call_graph,
                                           int64_t aggressiveness,
-                                          bool may_combine_partial_sharding,
-                                          bool is_spmd) {
+                                          bool may_combine_partial_sharding) {
   auto get_partitions_for_dims =
       [&](const HloInstruction* inst,
           absl::Span<
@@ -1225,8 +1188,8 @@ bool InferConvolutionShardingFromOperands(HloInstruction* instruction,
       (lhs_conv_spatial_partitions == 1 && rhs_conv_spatial_partitions == 1 &&
        instruction->batch_group_count() == 1 &&
        instruction->feature_group_count() == 1)) {
-    return InferDotShardingFromOperands(instruction, call_graph, dot_dims,
-                                        may_combine_partial_sharding, is_spmd);
+    return InferDotShardingFromOperands(instruction, dot_dims, aggressiveness,
+                                        may_combine_partial_sharding);
   }
   const auto& dnums = instruction->convolution_dimension_numbers();
   const HloInstruction* lhs = instruction->operand(0);
@@ -2329,9 +2292,8 @@ bool ShardingPropagation::InferShardingFromOperands(
               1);
     }
     case HloOpcode::kConvolution:
-      return InferConvolutionShardingFromOperands(
-          instruction, call_graph, aggressiveness, may_combine_partial_sharding,
-          is_spmd_);
+      return InferConvolutionShardingFromOperands(instruction, aggressiveness,
+                                                  may_combine_partial_sharding);
     case HloOpcode::kTranspose: {
       const HloInstruction* input = instruction->operand(0);
       if (!hlo_sharding_util::IsSpatiallyPartitioned(input)) {
@@ -2420,9 +2382,8 @@ bool ShardingPropagation::InferShardingFromOperands(
     case HloOpcode::kDot: {
      const auto& dnums =
          dot_as_convolution_util::ParseDotGeneralFromDot(instruction);
-      return InferDotShardingFromOperands(instruction, call_graph, dnums,
-                                          may_combine_partial_sharding,
-                                          is_spmd_);
+      return InferDotShardingFromOperands(instruction, dnums, aggressiveness,
+                                          may_combine_partial_sharding);
     }
     case HloOpcode::kParameter: {
       auto parent_it = computation_map.find(instruction->parent());
diff --git a/xla/service/sharding_propagation.h b/xla/service/sharding_propagation.h
index 27cef82097743..2654a1fd7d335 100644
--- a/xla/service/sharding_propagation.h
+++ b/xla/service/sharding_propagation.h
@@ -16,6 +16,7 @@ limitations under the License.
 #ifndef XLA_SERVICE_SHARDING_PROPAGATION_H_
 #define XLA_SERVICE_SHARDING_PROPAGATION_H_
 
+#include
 #include
 #include
 #include
@@ -35,17 +36,15 @@ namespace xla {
 // Infers the shardings for a dot HLO op from the shardings on its operands,
 // which are expected to have sharding annotations.
 bool InferDotShardingFromOperands(
-    HloInstruction* instruction, const CallGraph& call_graph,
+    HloInstruction* instruction,
     const dot_as_convolution_util::DotConvolutionDimsInfo& dnums,
-    bool may_combine_partial_sharding, bool is_spmd);
+    int64_t aggressiveness, bool may_combine_partial_sharding);
 
 // Infers the shardings for a convolution HLO op from the shardings on its
 // operands, which are expected to have sharding annotations.
 bool InferConvolutionShardingFromOperands(HloInstruction* instruction,
-                                          const CallGraph& call_graph,
                                           int64_t aggressiveness,
-                                          bool may_combine_partial_sharding,
-                                          bool is_spmd);
+                                          bool may_combine_partial_sharding);
 
 // Remove Sharding custom-call instruction by folding the sharding attribute
 // to its operand. If the operand already has a different sharding, insert a
diff --git a/xla/service/sharding_propagation_test.cc b/xla/service/sharding_propagation_test.cc
index 5ca4b47d8ea15..565314d9150e3 100644
--- a/xla/service/sharding_propagation_test.cc
+++ b/xla/service/sharding_propagation_test.cc
@@ -3324,7 +3324,7 @@ ENTRY %conv {
   EXPECT_THAT(instruction, op::Sharding("{devices=[2,2,2]0,1,2,3,4,5,6,7}"));
   if (GetParam().propagate_metadata && !GetParam().clear_metadata) {
     EXPECT_THAT(instruction->sharding(),
-                ShardingMetadata({CreateMetadata("b"), CreateMetadata("a")}));
+                ShardingMetadata({CreateMetadata("a"), CreateMetadata("b")}));
   } else {
     EXPECT_THAT(instruction->sharding(), ShardingMetadata({}));
   }
@@ -3396,7 +3396,7 @@ ENTRY %conv {
   EXPECT_THAT(instruction, op::Sharding("{devices=[2,4]0,2,3,1,4,6,7,5}"));
   if (GetParam().propagate_metadata && !GetParam().clear_metadata) {
     EXPECT_THAT(instruction->sharding(),
-                ShardingMetadata({CreateMetadata("b"), CreateMetadata("a")}));
+                ShardingMetadata({CreateMetadata("a"), CreateMetadata("b")}));
   } else {
     EXPECT_THAT(instruction->sharding(), ShardingMetadata({}));
   }
@@ -11863,7 +11863,7 @@ ENTRY main.9 {
               op::Sharding("{{devices=[4]<=[4]}, {devices=[4]<=[4]}}"));
 }
 
-TEST_F(ShardingPropagationTest, LookaheadUsersOfDot) {
+TEST_F(ShardingPropagationTest, InferDotShardingFromOperands1) {
   const char* const hlo_string = R"(
 HloModule module
 
@@ -11880,24 +11880,108 @@ ENTRY %entry {
   TF_ASSERT_OK_AND_ASSIGN(auto module,
                           ParseAndReturnVerifiedModule(hlo_string));
   TF_ASSERT_OK_AND_ASSIGN(
-      bool changed,
-      ShardingPropagation(
-          /*is_spmd=*/true, /*propagate_metadata=*/true,
-          /*allow_spmd_sharding_propagation_to_output=*/{true},
-          /*allow_spmd_sharding_propagation_to_parameters=*/{true})
-          .Run(module.get()));
+      bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get()));
   EXPECT_TRUE(changed);
 
   XLA_VLOG_LINES(1, module->ToString());
-  // Check dangling sharding custom-call can be removed by DCE after
-  // propagation.
   auto* instruction = FindInstruction(module.get(), "dot.1");
-  // Check sharding is correctly propagated.
   EXPECT_THAT(instruction,
               op::Sharding(
                   "{devices=[4,4,1,4]<=[4,16]T(1,0) last_tile_dim_replicate}"));
 }
 
+TEST_F(ShardingPropagationTest, InferDotShardingFromOperands2) {
+  const char* const hlo_string = R"(
+HloModule module
+
+ENTRY %entry {
+  p0 = bf16[16,32] parameter(0), sharding={devices=[16,1]<=[16]}
+  p1 = bf16[32,64] parameter(1), sharding={devices=[1,16]<=[16]}
+  dot = bf16[16,64] dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  ROOT copy = bf16[16,64] copy(dot), sharding={replicated}
+})";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  TF_ASSERT_OK_AND_ASSIGN(
+      bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get()));
+  EXPECT_TRUE(changed);
+
+  XLA_VLOG_LINES(1, module->ToString());
+  auto* instruction = FindInstruction(module.get(), "dot");
+  EXPECT_THAT(instruction, op::Sharding("{devices=[1,16]<=[16]}"));
+}
+
+TEST_F(ShardingPropagationTest, InferDotShardingFromOperands3) {
+  const char* const hlo_string = R"(
+HloModule module
+
+ENTRY %entry {
+  p0 = bf16[4,16,32] parameter(0), sharding={devices=[2,4,2]<=[16]}
+  p1 = bf16[4,32,64] parameter(1), sharding={devices=[2,8,1]<=[16]}
+  dot = bf16[4,16,64] dot(p0, p1), lhs_batch_dims={0}, rhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_contracting_dims={1}
+  ROOT copy = bf16[4,16,64] copy(dot)
+})";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  TF_ASSERT_OK_AND_ASSIGN(
+      bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get()));
+  EXPECT_TRUE(changed);
+
+  XLA_VLOG_LINES(1, module->ToString());
+  auto* instruction = FindInstruction(module.get(), "dot");
+  EXPECT_THAT(
+      instruction,
+      op::Sharding("{devices=[2,4,1,2]<=[16] last_tile_dim_replicate}"));
+}
+
+TEST_F(ShardingPropagationTest, InferDotShardingFromOperands4) {
+  const char* const hlo_string = R"(
+HloModule module
+
+ENTRY %entry {
+  p0 = bf16[4,16,32] parameter(0), sharding={devices=[2,1,8]<=[16]}
+  p1 = bf16[4,32,64] parameter(1), sharding={devices=[4,1,4]<=[16]}
+  dot = bf16[4,16,64] dot(p0, p1), lhs_batch_dims={0}, rhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_contracting_dims={1}
+  ROOT copy = bf16[4,16,64] copy(dot)
+})";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  TF_ASSERT_OK_AND_ASSIGN(
+      bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get()));
+  EXPECT_TRUE(changed);
+
+  XLA_VLOG_LINES(1, module->ToString());
+  auto* instruction = FindInstruction(module.get(), "dot");
+  EXPECT_THAT(instruction, op::Sharding("{devices=[4,1,4]<=[16]}"));
+}
+
+TEST_F(ShardingPropagationTest, InferDotShardingFromOperands5) {
+  const char* const hlo_string = R"(
+HloModule module
+
+ENTRY %entry {
+  p0 = bf16[16,16] parameter(0), sharding={devices=[4,4]<=[4,4]T(1,0)}
+  p1 = bf16[16,16] parameter(1), sharding={devices=[4,4]<=[4,4]T(1,0)}
+  dot.0 = bf16[16,16] dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={1}
+  p2 = bf16[16,16] parameter(2), sharding={devices=[4,4]<=[16]}
+  p3 = bf16[16,16] parameter(3), sharding={devices=[4,4]<=[16]}
+  dot.1 = bf16[16,16] dot(p2, p3), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  add = bf16[16,16] add(dot.0, dot.1)
+  ROOT copy = bf16[16,16] copy(add)
+})";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  TF_ASSERT_OK_AND_ASSIGN(
+      bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get()));
+  EXPECT_TRUE(changed);
+
+  XLA_VLOG_LINES(1, module->ToString());
+  for (absl::string_view name : {"dot.0", "dot.1", "add"}) {
"dot.1", "add"}) { + auto* instruction = FindInstruction(module.get(), name); + EXPECT_THAT(instruction, op::Sharding("{devices=[4,4]<=[16]}")); + } +} + TEST_F(ShardingPropagationTest, AsyncInstructionManualShardingArray) { const char* const hlo_string = R"( HloModule module