diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc b/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc index 85898c8b4e6dfd..237ffa4e3d4f40 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc +++ b/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc @@ -444,13 +444,13 @@ std::optional HandlerBase::GetShardingFromUser( CHECK_OK(ins_clone->ReplaceOperandWith(1, rhs_clone.get())); if (ins_->opcode() == HloOpcode::kConvolution) { xla::InferConvolutionShardingFromOperands( - ins_clone.get(), call_graph_, 10, - /* may_combine_partial_sharding */ true, /* is_spmd */ true); + ins_clone.get(), /* aggressiveness */ 10, + /* may_combine_partial_sharding */ true); } else { xla::InferDotShardingFromOperands( - ins_clone.get(), call_graph_, + ins_clone.get(), dot_as_convolution_util::ParseDotGeneralFromDot(ins_clone.get()), - /* may_combine_partial_sharding/ */ true, /* is_spmd */ true); + /* aggressiveness */ 10, /* may_combine_partial_sharding */ true); } if (!ins_clone->has_sharding()) { return std::nullopt; diff --git a/xla/service/sharding_propagation.cc b/xla/service/sharding_propagation.cc index b9508c04021f7e..316644bf87ea8b 100644 --- a/xla/service/sharding_propagation.cc +++ b/xla/service/sharding_propagation.cc @@ -16,10 +16,10 @@ limitations under the License. #include "xla/service/sharding_propagation.h" #include +#include #include #include #include -#include #include #include #include @@ -36,10 +36,12 @@ limitations under the License. #include "absl/log/check.h" #include "absl/status/status.h" #include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/array.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_domain_metadata.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" @@ -47,6 +49,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_sharding_metadata.h" #include "xla/hlo/utils/hlo_sharding_util.h" #include "xla/protobuf_util.h" +#include "xla/service/call_graph.h" #include "xla/service/dot_as_convolution_util.h" #include "xla/service/host_memory_offload_annotations.h" #include "xla/service/spmd/shard_barrier_partitioner.h" @@ -416,55 +419,6 @@ bool SupportSpatialPartitioning( } } -// Helper to lookahead sharding of user of an instruction to be used as guidance -// for ambiguous cases. -std::optional LookaheadUserSharding(HloInstruction* instr, - bool is_spmd, - const CallGraph& call_graph) { - if (instr->user_count() != 1) { - return std::nullopt; - } - HloInstruction* current_user = instr->users()[0]; - std::optional sharding; - std::vector users_chain = {instr, current_user}; - // Collect single user instructions along the way. - while (!current_user->has_sharding()) { - // Only consider single user chains. - if (current_user->users().size() != 1) { - users_chain.clear(); - break; - } - current_user = current_user->users()[0]; - users_chain.push_back(current_user); - } - // Early exit for unsupported cases. - if (users_chain.empty()) { - return std::nullopt; - } - for (int i = users_chain.size() - 1; i >= 1; --i) { - HloInstruction* user = users_chain[i]; - HloInstruction* current = users_chain[i - 1]; - CHECK(user->has_sharding()); - sharding = ShardingPropagation::GetShardingFromUser( - *current, *user, INT64_MAX, is_spmd, call_graph, - /*sharding_helper=*/nullptr); - // We need to set the sharding to the instruction, because - // GetShardingFromUser() interface uses sharding from the instruction - // itself. It will be cleared out later. - if (sharding.has_value() && i != 1) { - current->set_sharding(*sharding); - continue; - } - break; - } - // Clear the sharding of the middle instructions we set the sharding of - // because they were unsharded. - for (int i = 1; i < users_chain.size() - 1; ++i) { - users_chain[i]->clear_sharding(); - } - return sharding; -} - // Infer output sharding on index parallel dimensions for gather from operand // and indices. bool InferGatherParallelShardingFromOperands( @@ -1071,9 +1025,9 @@ bool IsCSEPreventionSharding(const HloSharding& sharding) { } // namespace bool InferDotShardingFromOperands( - HloInstruction* instruction, const CallGraph& call_graph, + HloInstruction* instruction, const dot_as_convolution_util::DotConvolutionDimsInfo& dnums, - bool may_combine_partial_sharding, bool is_spmd) { + int64_t aggressiveness, bool may_combine_partial_sharding) { auto from_operand = [&](int64_t operand_index) { auto operand = instruction->operand(operand_index); const HloSharding& operand_sharding = operand->sharding(); @@ -1128,55 +1082,66 @@ bool InferDotShardingFromOperands( from_operand(1), instruction, may_combine_partial_sharding, /*allow_aggressive_resharding=*/false); } - // If not improved sharding found then do not set any sharding. + + // Four cases based on if improved_operand_0 and improved_operand_1 are + // available. + // Case 0. Both operands have no improved sharding. if (!improved_operand_0.has_value() && !improved_operand_1.has_value()) { return false; } - // Sharding found from operand 0 but not operand 1. Set sharding from operand - // 0 + // Case 1. Sharding found from operand 0 but not operand 1. Set sharding from + // operand 0. if (improved_operand_0.has_value() && !improved_operand_1.has_value()) { instruction->set_sharding(*improved_operand_0); return true; } - // Sharding found from operand 1 but not operand 0. Set sharding from operand - // 1 + // Case 2. Sharding found from operand 1 but not operand 0. Set sharding from + // operand 1. if (!improved_operand_0.has_value() && improved_operand_1.has_value()) { instruction->set_sharding(*improved_operand_1); return true; } + // Case 3. Both operands have improved shardings. CHECK(improved_operand_0.has_value() && improved_operand_1.has_value()); - std::optional lookahead_sharding = - LookaheadUserSharding(instruction, is_spmd, call_graph); + + // If one of the improved shardings is a sub-tiling or equal to the other, use + // the better sharding with more tiles. + if (hlo_sharding_util::IsSubTilingOrEqualSharding( + instruction->shape(), *improved_operand_0, *improved_operand_1)) { + instruction->set_sharding(*improved_operand_0); + return true; + } + if (hlo_sharding_util::IsSubTilingOrEqualSharding( + instruction->shape(), *improved_operand_1, *improved_operand_0)) { + instruction->set_sharding(*improved_operand_1); + return true; + } + + // If the two improved shardings are mergeable, there is no conflict. + if (std::optional improved_sharding = + hlo_sharding_util::ReturnImprovedShardingImpl( + *improved_operand_0, &improved_operand_1.value(), + instruction->shape(), may_combine_partial_sharding, + /*allow_aggressive_resharding=*/false)) { + instruction->set_sharding(*improved_sharding); + return true; + } + + if (aggressiveness < 3) { + // We can improve the dot with different shardings. Pause the propagation + // and wait for the winner between the two operands. + return false; + } + + // The two improved sharding are different and we are at the highest + // aggressiveness. Prioritize the operand with larger size. std::array sharding_priority = {*improved_operand_0, *improved_operand_1}; - bool priority_defined_with_lookahead = false; - // Found sharding from lookahead. - if (lookahead_sharding.has_value()) { - const bool operand_0_is_lookahead_subtiling = - hlo_sharding_util::IsSubTilingOrEqualSharding( - instruction->shape(), *lookahead_sharding, *improved_operand_0); - const bool operand_1_is_lookahead_subtiling = - hlo_sharding_util::IsSubTilingOrEqualSharding( - instruction->shape(), *lookahead_sharding, *improved_operand_1); - // If the sharding from operand 0 is a subtiling of the user, but not the - // one from operand 1 prioritize that sharding. - if (operand_0_is_lookahead_subtiling && !operand_1_is_lookahead_subtiling) { - priority_defined_with_lookahead = true; - } - // If the sharding from operand 1 is a subtiling of the user, but not the - // one from operand 0 prioritize that sharding. - if (!operand_0_is_lookahead_subtiling && operand_1_is_lookahead_subtiling) { - instruction->set_sharding(*improved_operand_1); - std::swap(sharding_priority[0], sharding_priority[1]); - priority_defined_with_lookahead = true; - } - } - // If lookahead didn't define a priority then use size. - if (!priority_defined_with_lookahead && - ShapeUtil::ByteSizeOf(instruction->operand(0)->shape()) < - ShapeUtil::ByteSizeOf(instruction->operand(1)->shape())) { + if (ShapeUtil::ByteSizeOf(instruction->operand(0)->shape()) < + ShapeUtil::ByteSizeOf(instruction->operand(1)->shape())) { std::swap(sharding_priority[0], sharding_priority[1]); } + // Set primary sharding to the instruction and then try to improve it with // the secondary sharding. instruction->set_sharding(sharding_priority[0]); @@ -1187,10 +1152,8 @@ bool InferDotShardingFromOperands( // Convolution handling for InferShardingFromOperands(). bool InferConvolutionShardingFromOperands(HloInstruction* instruction, - const CallGraph& call_graph, int64_t aggressiveness, - bool may_combine_partial_sharding, - bool is_spmd) { + bool may_combine_partial_sharding) { auto get_partitions_for_dims = [&](const HloInstruction* inst, absl::Span< @@ -1225,8 +1188,8 @@ bool InferConvolutionShardingFromOperands(HloInstruction* instruction, (lhs_conv_spatial_partitions == 1 && rhs_conv_spatial_partitions == 1 && instruction->batch_group_count() == 1 && instruction->feature_group_count() == 1)) { - return InferDotShardingFromOperands(instruction, call_graph, dot_dims, - may_combine_partial_sharding, is_spmd); + return InferDotShardingFromOperands(instruction, dot_dims, aggressiveness, + may_combine_partial_sharding); } const auto& dnums = instruction->convolution_dimension_numbers(); const HloInstruction* lhs = instruction->operand(0); @@ -2329,9 +2292,8 @@ bool ShardingPropagation::InferShardingFromOperands( 1); } case HloOpcode::kConvolution: - return InferConvolutionShardingFromOperands( - instruction, call_graph, aggressiveness, may_combine_partial_sharding, - is_spmd_); + return InferConvolutionShardingFromOperands(instruction, aggressiveness, + may_combine_partial_sharding); case HloOpcode::kTranspose: { const HloInstruction* input = instruction->operand(0); if (!hlo_sharding_util::IsSpatiallyPartitioned(input)) { @@ -2420,9 +2382,8 @@ bool ShardingPropagation::InferShardingFromOperands( case HloOpcode::kDot: { const auto& dnums = dot_as_convolution_util::ParseDotGeneralFromDot(instruction); - return InferDotShardingFromOperands(instruction, call_graph, dnums, - may_combine_partial_sharding, - is_spmd_); + return InferDotShardingFromOperands(instruction, dnums, aggressiveness, + may_combine_partial_sharding); } case HloOpcode::kParameter: { auto parent_it = computation_map.find(instruction->parent()); diff --git a/xla/service/sharding_propagation.h b/xla/service/sharding_propagation.h index 27cef820977436..2654a1fd7d335b 100644 --- a/xla/service/sharding_propagation.h +++ b/xla/service/sharding_propagation.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_SHARDING_PROPAGATION_H_ #define XLA_SERVICE_SHARDING_PROPAGATION_H_ +#include #include #include #include @@ -35,17 +36,15 @@ namespace xla { // Infers the shardings for a dot HLO op from the shardings on its operands, // which are expected to have sharding annotations. bool InferDotShardingFromOperands( - HloInstruction* instruction, const CallGraph& call_graph, + HloInstruction* instruction, const dot_as_convolution_util::DotConvolutionDimsInfo& dnums, - bool may_combine_partial_sharding, bool is_spmd); + int64_t aggressiveness, bool may_combine_partial_sharding); // Infers the shardings for a convolution HLO op from the shardings on its // operands, which are expected to have sharding annotations. bool InferConvolutionShardingFromOperands(HloInstruction* instruction, - const CallGraph& call_graph, int64_t aggressiveness, - bool may_combine_partial_sharding, - bool is_spmd); + bool may_combine_partial_sharding); // Remove Sharding custom-call instruction by folding the sharding attribute // to its operand. If the operand already has a different sharding, insert a diff --git a/xla/service/sharding_propagation_test.cc b/xla/service/sharding_propagation_test.cc index 5ca4b47d8ea15c..565314d9150e33 100644 --- a/xla/service/sharding_propagation_test.cc +++ b/xla/service/sharding_propagation_test.cc @@ -3324,7 +3324,7 @@ ENTRY %conv { EXPECT_THAT(instruction, op::Sharding("{devices=[2,2,2]0,1,2,3,4,5,6,7}")); if (GetParam().propagate_metadata && !GetParam().clear_metadata) { EXPECT_THAT(instruction->sharding(), - ShardingMetadata({CreateMetadata("b"), CreateMetadata("a")})); + ShardingMetadata({CreateMetadata("a"), CreateMetadata("b")})); } else { EXPECT_THAT(instruction->sharding(), ShardingMetadata({})); } @@ -3396,7 +3396,7 @@ ENTRY %conv { EXPECT_THAT(instruction, op::Sharding("{devices=[2,4]0,2,3,1,4,6,7,5}")); if (GetParam().propagate_metadata && !GetParam().clear_metadata) { EXPECT_THAT(instruction->sharding(), - ShardingMetadata({CreateMetadata("b"), CreateMetadata("a")})); + ShardingMetadata({CreateMetadata("a"), CreateMetadata("b")})); } else { EXPECT_THAT(instruction->sharding(), ShardingMetadata({})); } @@ -11863,7 +11863,7 @@ ENTRY main.9 { op::Sharding("{{devices=[4]<=[4]}, {devices=[4]<=[4]}}")); } -TEST_F(ShardingPropagationTest, LookaheadUsersOfDot) { +TEST_F(ShardingPropagationTest, InferDotShardingFromOperands1) { const char* const hlo_string = R"( HloModule module @@ -11880,24 +11880,108 @@ ENTRY %entry { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); TF_ASSERT_OK_AND_ASSIGN( - bool changed, - ShardingPropagation( - /*is_spmd=*/true, /*propagate_metadata=*/true, - /*allow_spmd_sharding_propagation_to_output=*/{true}, - /*allow_spmd_sharding_propagation_to_parameters=*/{true}) - .Run(module.get())); + bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get())); EXPECT_TRUE(changed); XLA_VLOG_LINES(1, module->ToString()); - // Check dangling sharding custom-call can be removed by DCE after - // propagation. auto* instruction = FindInstruction(module.get(), "dot.1"); - // Check sharding is correctly propagated. EXPECT_THAT(instruction, op::Sharding( "{devices=[4,4,1,4]<=[4,16]T(1,0) last_tile_dim_replicate}")); } +TEST_F(ShardingPropagationTest, InferDotShardingFromOperands2) { + const char* const hlo_string = R"( +HloModule module + +ENTRY %entry { + p0 = bf16[16,32] parameter(0), sharding={devices=[16,1]<=[16]} + p1 = bf16[32,64] parameter(1), sharding={devices=[1,16]<=[16]} + dot = bf16[16,64] dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={0} + ROOT copy = bf16[16,64] copy(dot), sharding={replicated} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get())); + EXPECT_TRUE(changed); + + XLA_VLOG_LINES(1, module->ToString()); + auto* instruction = FindInstruction(module.get(), "dot"); + EXPECT_THAT(instruction, op::Sharding("{devices=[1,16]<=[16]}")); +} + +TEST_F(ShardingPropagationTest, InferDotShardingFromOperands3) { + const char* const hlo_string = R"( +HloModule module + +ENTRY %entry { + p0 = bf16[4,16,32] parameter(0), sharding={devices=[2,4,2]<=[16]} + p1 = bf16[4,32,64] parameter(1), sharding={devices=[2,8,1]<=[16]} + dot = bf16[4,16,64] dot(p0, p1), lhs_batch_dims={0}, rhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_contracting_dims={1} + ROOT copy = bf16[4,16,64] copy(dot) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get())); + EXPECT_TRUE(changed); + + XLA_VLOG_LINES(1, module->ToString()); + auto* instruction = FindInstruction(module.get(), "dot"); + EXPECT_THAT( + instruction, + op::Sharding("{devices=[2,4,1,2]<=[16] last_tile_dim_replicate}")); +} + +TEST_F(ShardingPropagationTest, InferDotShardingFromOperands4) { + const char* const hlo_string = R"( +HloModule module + +ENTRY %entry { + p0 = bf16[4,16,32] parameter(0), sharding={devices=[2,1,8]<=[16]} + p1 = bf16[4,32,64] parameter(1), sharding={devices=[4,1,4]<=[16]} + dot = bf16[4,16,64] dot(p0, p1), lhs_batch_dims={0}, rhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_contracting_dims={1} + ROOT copy = bf16[4,16,64] copy(dot) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get())); + EXPECT_TRUE(changed); + + XLA_VLOG_LINES(1, module->ToString()); + auto* instruction = FindInstruction(module.get(), "dot"); + EXPECT_THAT(instruction, op::Sharding("{devices=[4,1,4]<=[16]}")); +} + +TEST_F(ShardingPropagationTest, InferDotShardingFromOperands5) { + const char* const hlo_string = R"( +HloModule module + +ENTRY %entry { + p0 = bf16[16,16] parameter(0), sharding={devices=[4,4]<=[4,4]T(1,0)} + p1 = bf16[16,16] parameter(1), sharding={devices=[4,4]<=[4,4]T(1,0)} + dot.0 = bf16[16,16] dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={1} + p2 = bf16[16,16] parameter(2), sharding={devices=[4,4]<=[16]} + p3 = bf16[16,16] parameter(3), sharding={devices=[4,4]<=[16]} + dot.1 = bf16[16,16] dot(p2, p3), lhs_contracting_dims={1}, rhs_contracting_dims={0} + add = bf16[16,16] add(dot.0, dot.1) + ROOT copy = bf16[16,16] copy(add) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get())); + EXPECT_TRUE(changed); + + XLA_VLOG_LINES(1, module->ToString()); + for (absl::string_view name : {"dot.0", "dot.1", "add"}) { + auto* instruction = FindInstruction(module.get(), name); + EXPECT_THAT(instruction, op::Sharding("{devices=[4,4]<=[16]}")); + } +} + TEST_F(ShardingPropagationTest, AsyncInstructionManualShardingArray) { const char* const hlo_string = R"( HloModule module