diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding.cc b/xla/hlo/experimental/auto_sharding/auto_sharding.cc
index 49958936358d3..6938239fb1fdb 100644
--- a/xla/hlo/experimental/auto_sharding/auto_sharding.cc
+++ b/xla/hlo/experimental/auto_sharding/auto_sharding.cc
@@ -212,8 +212,8 @@ GenerateReshardingCostsAndMissingShardingsForAllOperands(
     const CallGraph& call_graph, InputShardings& input_shardings) {
   ReshardingCosts communication_resharding_costs;
   ReshardingCosts memory_resharding_costs;
-  if (input_shardings.empty() && ins->operand_count() > 0) {
-    input_shardings.resize(ins->operand_count());
+  if (input_shardings.shardings.empty() && ins->operand_count() > 0) {
+    input_shardings.shardings.resize(ins->operand_count());
   }
   for (int64_t k = 0; k < ins->operand_count(); ++k) {
     const HloInstruction* operand = ins->operand(k);
@@ -224,14 +224,14 @@ GenerateReshardingCostsAndMissingShardingsForAllOperands(
     if (operand_shape.IsToken() || operand_shape.rank() == 0) {
       communication_resharding_costs.push_back(zeros);
       memory_resharding_costs.push_back(zeros);
-      if (!input_shardings[k].has_value()) {
-        input_shardings[k] = HloSharding::Replicate();
+      if (!input_shardings.shardings[k].has_value()) {
+        input_shardings.shardings[k] = HloSharding::Replicate();
       }
     } else {
       std::optional<HloSharding> cur_input_sharding;
-      CHECK_EQ(input_shardings.size(), ins->operand_count());
-      if (input_shardings[k].has_value()) {
-        cur_input_sharding = input_shardings[k];
+      CHECK_EQ(input_shardings.shardings.size(), ins->operand_count());
+      if (input_shardings.shardings[k].has_value()) {
+        cur_input_sharding = input_shardings.shardings[k];
       } else {
         cur_input_sharding = GetInputSharding(
             ins, k, output_sharding, call_graph, cluster_env.NumDevices());
@@ -250,8 +250,8 @@ GenerateReshardingCostsAndMissingShardingsForAllOperands(
         }
       }
       CHECK(cur_input_sharding.has_value());
-      if (!input_shardings[k].has_value()) {
-        input_shardings[k] = cur_input_sharding;
+      if (!input_shardings.shardings[k].has_value()) {
+        input_shardings.shardings[k] = cur_input_sharding;
       }
       if (ins->opcode() == HloOpcode::kGather && k == 0 &&
           is_sharding_default_replicated) {
@@ -259,7 +259,7 @@ GenerateReshardingCostsAndMissingShardingsForAllOperands(
             << output_sharding.ToString();
         communication_resharding_costs.push_back(zeros);
         memory_resharding_costs.push_back(zeros);
-        input_shardings[k] = std::nullopt;
+        input_shardings.shardings[k] = std::nullopt;
       } else {
         communication_resharding_costs.push_back(
            CommunicationReshardingCostVector(
@@ -275,8 +275,7 @@ GenerateReshardingCostsAndMissingShardingsForAllOperands(
       memory_resharding_costs);
 }
 
-std::tuple<ReshardingCosts, ReshardingCosts,
-           std::vector<std::optional<HloSharding>>>
+std::tuple<ReshardingCosts, ReshardingCosts, InputShardings>
 GenerateReshardingCostsAndShardingsForAllOperands(
     const HloInstruction* ins, const HloSharding& output_sharding,
     const StrategyMap& strategy_map, const ClusterEnvironment& cluster_env,
@@ -286,7 +285,7 @@ GenerateReshardingCostsAndShardingsForAllOperands(
       GenerateReshardingCostsAndMissingShardingsForAllOperands(
           ins, output_sharding, strategy_map, cluster_env, call_graph,
           input_shardings_optional);
-  for (const auto& sharding_optional : input_shardings_optional) {
+  for (const auto& sharding_optional : input_shardings_optional.shardings) {
     CHECK(sharding_optional.has_value());
   }
 
@@ -333,7 +332,7 @@ void FollowArrayOrTokenStrategyGroup(
       double compute_cost = 0, communication_cost = 0;
       double memory_cost = ByteSizeOfShapeWithSharding(shape, *output_spec);
       size_t num_in_nodes = strategy_group.in_nodes.size();
-      InputShardings input_shardings(num_in_nodes, *output_spec);
+      InputShardings input_shardings{name, {num_in_nodes, *output_spec}};
       ReshardingCosts communication_resharding_costs;
       ReshardingCosts memory_resharding_costs;
       for (size_t i = 0; i < strategy_group.in_nodes.size(); ++i) {
@@ -345,7 +344,7 @@ void FollowArrayOrTokenStrategyGroup(
         }
       }
       strategy_group.AddStrategy(
-          ShardingStrategy({name, *output_spec, compute_cost, communication_cost,
+          ShardingStrategy({*output_spec, compute_cost, communication_cost,
                             memory_cost, communication_resharding_costs,
                             memory_resharding_costs}),
           input_shardings);
@@ -386,16 +385,16 @@ std::unique_ptr<StrategyGroup> HandlePartialReduce(
     }
 
     // Get a list of input shardings, each corresponds to an operand.
-    InputShardings input_shardings;
+    std::string name = ToStringSimple(output_spec);
+    InputShardings input_shardings = {std::move(name)};
     for (int64_t k = 0; k < output_size * 2; ++k) {
       if (k < output_size) {
-        input_shardings.push_back(input_spec);
+        input_shardings.shardings.push_back(input_spec);
       } else {
-        input_shardings.push_back(HloSharding::Replicate());
+        input_shardings.shardings.push_back(HloSharding::Replicate());
       }
     }
-    std::string name = ToStringSimple(output_spec);
 
     double compute_cost = 0, communication_cost = 0;
     double memory_cost = ByteSizeOfShapeWithSharding(
         ins->shape().tuple_shapes(i), output_spec);
@@ -405,8 +404,8 @@ std::unique_ptr<StrategyGroup> HandlePartialReduce(
             input_shardings);
 
     child_strategy_group->AddStrategy(
-        ShardingStrategy({std::move(name), std::move(output_spec),
-                          compute_cost, communication_cost, memory_cost,
+        ShardingStrategy({std::move(output_spec), compute_cost,
+                          communication_cost, memory_cost,
                           std::move(resharding_costs.first),
                           std::move(resharding_costs.second)}),
        std::move(input_shardings));
@@ -553,9 +552,9 @@ absl::StatusOr<std::unique_ptr<StrategyGroup>> FollowReduceStrategy(
         }
       }
       const ShardingStrategy strategy = ShardingStrategy(
-          {name, output_spec, compute_cost, communication_cost, memory_cost,
+          {output_spec, compute_cost, communication_cost, memory_cost,
           communication_resharding_costs, memory_resharding_costs});
-      strategy_group->AddStrategy(strategy, {input_sharding});
+      strategy_group->AddStrategy(strategy, {name, {input_sharding}});
     }
   } else {
     LOG(FATAL) << "Unhandled kReduce shape: " << ins->shape().ToString();
@@ -574,8 +573,7 @@ std::vector<size_t> FindReplicateStrategyIndices(
   return indices;
 }
 
-std::tuple<ReshardingCosts, ReshardingCosts,
-           std::vector<std::optional<HloSharding>>>
+std::tuple<ReshardingCosts, ReshardingCosts, InputShardings>
 ReshardingCostsForTupleOperand(const HloInstruction* operand,
                                const StrategyGroup& operand_strategy_vector) {
   // TODO(yuemmawang) Support instructions with more than one tuple operand.
@@ -606,9 +604,10 @@ ReshardingCostsForTupleOperand(const HloInstruction* operand,
       communication_resharding_costs.back().at(i) = 0.0;
     }
   }
-  return {communication_resharding_costs, memory_resharding_costs,
-          std::vector<std::optional<HloSharding>>(
-              {HloSharding::Tuple(operand->shape(), tuple_element_shardings)})};
+  return {
+      communication_resharding_costs,
+      memory_resharding_costs,
+      {{}, {HloSharding::Tuple(operand->shape(), tuple_element_shardings)}}};
 }
 
 ReshardingCosts CreateZeroReshardingCostsForAllOperands(
@@ -650,7 +649,7 @@ void GenerateOutfeedStrategy(const HloInstruction* ins, const Shape& shape,
   HloSharding output_spec = HloSharding::Replicate();
   ReshardingCosts communication_resharding_costs;
   ReshardingCosts memory_resharding_costs;
-  InputShardings input_shardings;
+  InputShardings input_shardings = {"R"};
 
   const int tuple_size = ins->operand(0)->shape().tuple_shapes_size();
   const auto& operand_strategy_group = strategy_map.at(ins->operand(0));
@@ -677,7 +676,7 @@ void GenerateOutfeedStrategy(const HloInstruction* ins, const Shape& shape,
       const StrategyGroup& child = *operand_children[i];
       const Shape& tuple_shape = ins->operand(0)->shape().tuple_shapes(i);
       const HloSharding& input_sharding = get_input_sharding(i);
-      input_shardings.push_back(input_sharding);
+      input_shardings.shardings.push_back(input_sharding);
       communication_resharding_costs.push_back(
           CommunicationReshardingCostVector(child, tuple_shape, input_sharding,
                                             cluster_env));
@@ -685,7 +684,7 @@ void GenerateOutfeedStrategy(const HloInstruction* ins, const Shape& shape,
           child, tuple_shape, input_sharding, cluster_env));
     }
     const HloSharding& input_sharding = get_input_sharding(-1);
-    input_shardings.push_back(input_sharding);
+    input_shardings.shardings.push_back(input_sharding);
   } else {
     for (size_t i = 0; i < tuple_size; ++i) {
       const StrategyGroup& child = *operand_children[i];
@@ -698,20 +697,19 @@ void GenerateOutfeedStrategy(const HloInstruction* ins, const Shape& shape,
   memory_resharding_costs.push_back({});
   double memory_cost = ByteSizeOfShapeWithSharding(shape, output_spec);
   strategy_group.AddStrategy(
-      ShardingStrategy({"R", HloSharding::Replicate(), replicated_penalty, 0,
+      ShardingStrategy({HloSharding::Replicate(), replicated_penalty, 0,
                         memory_cost, std::move(communication_resharding_costs),
                         std::move(memory_resharding_costs)}),
       input_shardings);
 }
 
-double ComputeCommunicationCost(
-    const HloInstruction* ins,
-    const std::vector<std::optional<HloSharding>>& operand_shardings,
-    const ClusterEnvironment& cluster_env) {
+double ComputeCommunicationCost(const HloInstruction* ins,
+                                const InputShardings& operand_shardings,
+                                const ClusterEnvironment& cluster_env) {
   switch (ins->opcode()) {
     case HloOpcode::kGather: {
-      if (operand_shardings[0].has_value() &&
-          !operand_shardings[0]->IsReplicated()) {
+      if (operand_shardings.shardings[0].has_value() &&
+          !operand_shardings.shardings[0]->IsReplicated()) {
         auto mesh_shape = cluster_env.device_mesh_.dimensions();
         auto mesh_dim = std::distance(
             mesh_shape.begin(),
@@ -761,9 +759,10 @@ void AddReplicatedStrategy(
     CHECK(!operand->shape().IsTuple());
     const auto& operand_strategy_group = strategy_map.at(operand).get();
     const auto& operand_strategies = operand_strategy_group->GetStrategies();
+    InputShardings input_shardings = {"R"};
+    input_shardings.shardings.resize(ins->operand_count());
     std::vector<InputShardings> possible_input_shardings(
-        operand_strategies.size(),
-        std::vector<std::optional<HloSharding>>(ins->operand_count()));
+        operand_strategies.size(), input_shardings);
     std::vector<ReshardingCosts> possible_communication_resharding_costs(
         operand_strategies.size(), ReshardingCosts(ins->operand_count()));
     std::vector<ReshardingCosts> possible_memory_resharding_costs(
@@ -778,7 +777,7 @@ void AddReplicatedStrategy(
         CHECK_EQ(possible_input_shardings.size(), operand_strategies.size());
         for (size_t j = 0; j < possible_input_shardings.size(); ++j) {
           const auto& operand_sharding = operand_strategies[j].output_sharding;
-          possible_input_shardings[j][k] = operand_sharding;
+          possible_input_shardings[j].shardings[k] = operand_sharding;
           possible_communication_resharding_costs[j][k] =
               CommunicationReshardingCostVector(operand_strategy_group,
                                                 operand_shape, operand_sharding,
@@ -789,7 +788,7 @@ void AddReplicatedStrategy(
         }
       } else {
         for (size_t j = 0; j < possible_input_shardings.size(); ++j) {
-          possible_input_shardings[j][k] = replicated_strategy;
+          possible_input_shardings[j].shardings[k] = replicated_strategy;
           possible_communication_resharding_costs[j][k] =
               CommunicationReshardingCostVector(
                   operand_strategy_group, operand_shape, replicated_strategy,
@@ -806,7 +805,7 @@ void AddReplicatedStrategy(
           ins, possible_input_shardings[j], cluster_env);
       strategy_group.AddStrategy(
           ShardingStrategy(
-              {"R", replicated_strategy, replicated_penalty, communication_cost,
+              {replicated_strategy, replicated_penalty, communication_cost,
               memory_cost,
               std::move(possible_communication_resharding_costs[j]),
               std::move(possible_memory_resharding_costs[j])}),
@@ -815,7 +814,7 @@ void AddReplicatedStrategy(
   } else {
     ReshardingCosts communication_resharding_costs;
     ReshardingCosts memory_resharding_costs;
-    InputShardings input_shardings;
+    InputShardings input_shardings = {"R"};
 
     if (ins->operand_count() > 0 && ins->operand(0)->shape().IsTuple()) {
       CHECK_EQ(ins->operand_count(), 1)
@@ -843,12 +842,12 @@ void AddReplicatedStrategy(
                 cluster_env));
         memory_resharding_costs.push_back(MemoryReshardingCostVector(
             operand_strategy_group, operand_shape, output_spec, cluster_env));
-        input_shardings.push_back(output_spec);
+        input_shardings.shardings.push_back(output_spec);
      }
    }
   }
   strategy_group.AddStrategy(
-      ShardingStrategy({"R", HloSharding::Replicate(), replicated_penalty, 0,
+      ShardingStrategy({HloSharding::Replicate(), replicated_penalty, 0,
                         memory_cost, std::move(communication_resharding_costs),
                         std::move(memory_resharding_costs)}),
@@ -897,7 +896,7 @@ void EnumerateAll1DPartition(
 
     ReshardingCosts communication_resharding_costs;
     ReshardingCosts memory_resharding_costs;
-    InputShardings input_shardings;
+    InputShardings input_shardings = {name};
     if (ins->opcode() == HloOpcode::kConditional) {
       // TODO(pratikf): Compute input_shardings for kConditional ops
       communication_resharding_costs =
@@ -915,7 +914,7 @@ void EnumerateAll1DPartition(
               *strategy_map.at(ins->operand(0)));
     } else if (ins->opcode() == HloOpcode::kRngBitGenerator &&
               ins->operand(0)->shape().IsArray()) {
-      input_shardings.push_back(HloSharding::Replicate());
+      input_shardings.shardings.push_back(HloSharding::Replicate());
       std::tie(communication_resharding_costs, memory_resharding_costs) =
           GenerateReshardingCostsAndMissingShardingsForAllOperands(
               ins, output_spec, strategy_map, cluster_env, call_graph,
@@ -939,7 +938,7 @@ void EnumerateAll1DPartition(
             ins->operand(0)->shape().rank() - 1, i, j, shape, cluster_env);
       }
       strategy_group.AddStrategy(
-          ShardingStrategy({name, output_spec, compute_cost, communication_cost,
+          ShardingStrategy({output_spec, compute_cost, communication_cost,
                             memory_cost,
                             std::move(communication_resharding_costs),
                             std::move(memory_resharding_costs)}),
@@ -1008,7 +1007,7 @@ void BuildStrategyAndCostForOp(const HloInstruction* ins, const Shape& shape,
   HloSharding output_spec = Tile(shape, tensor_dims, mesh_dims, device_mesh);
   double compute_cost = 0, communication_cost = 0;
   double memory_cost = ByteSizeOfShapeWithSharding(shape, output_spec);
-  InputShardings input_shardings;
+  InputShardings input_shardings = {name};
   ReshardingCosts communication_resharding_costs;
   ReshardingCosts memory_resharding_costs;
   if (ins->opcode() == HloOpcode::kConditional) {
@@ -1051,7 +1050,7 @@ void BuildStrategyAndCostForOp(const HloInstruction* ins, const Shape& shape,
   }
 
   strategy_group.AddStrategy(
-      ShardingStrategy({name, output_spec, compute_cost, communication_cost,
+      ShardingStrategy({output_spec, compute_cost, communication_cost,
                         memory_cost, std::move(communication_resharding_costs),
                         std::move(memory_resharding_costs)}),
       input_shardings);
@@ -1101,11 +1100,11 @@ void EnumerateAll1DPartitionReshape(const HloInstruction* ins,
     ReshardingCosts memory_resharding_costs{MemoryReshardingCostVector(
         operand_strategy_group, operand_shape, *input_spec, cluster_env)};
     strategy_group.AddStrategy(
-        ShardingStrategy({name, output_spec, compute_cost, communication_cost,
+        ShardingStrategy({output_spec, compute_cost, communication_cost,
                           memory_cost,
                          std::move(communication_resharding_costs),
                          std::move(memory_resharding_costs)}),
-        {*input_spec});
+        {name, {*input_spec}});
    }
  }
 }
@@ -1370,7 +1369,7 @@ void TrimOrGenerateStrategiesBasedOnExistingSharding(
       std::string name = ToStringSimple(existing_sharding);
       ReshardingCosts communication_resharding_costs;
       ReshardingCosts memory_resharding_costs;
-      InputShardings input_shardings;
+      InputShardings input_shardings = {name};
       if (!strategy_group.in_nodes.empty()) {
         HloInstruction* ins = instructions.at(strategy_group.instruction_id);
         for (size_t i = 0; i < strategy_group.in_nodes.size(); i++) {
@@ -1403,7 +1402,7 @@ void TrimOrGenerateStrategiesBasedOnExistingSharding(
 
           CHECK(input_sharding.has_value());
 
-          input_shardings.push_back(*input_sharding);
+          input_shardings.shardings.push_back(*input_sharding);
           communication_resharding_costs.push_back(
               CommunicationReshardingCostVector(
                   *operand_strategy_group, operand_shape, *input_sharding,
@@ -1421,7 +1420,7 @@ void TrimOrGenerateStrategiesBasedOnExistingSharding(
       }
       strategy_group.ClearStrategies();
       strategy_group.AddStrategy(
-          ShardingStrategy({name, existing_sharding, 0, 0, memory_cost,
+          ShardingStrategy({existing_sharding, 0, 0, memory_cost,
                             communication_resharding_costs,
                             memory_resharding_costs}),
           input_shardings);
@@ -1652,7 +1651,7 @@ std::unique_ptr<StrategyGroup> HandleManuallyShardedInstruction(
                                        strategy_groups);
   ReshardingCosts communication_resharding_costs;
   ReshardingCosts memory_resharding_costs;
-  InputShardings input_shardings;
+  InputShardings input_shardings = {"MANUAL"};
 
   if (ins->operand_count() > 0 && ins->operand(0)->shape().IsTuple()) {
     CHECK_EQ(ins->operand_count(), 1)
@@ -1674,7 +1673,7 @@ std::unique_ptr<StrategyGroup> HandleManuallyShardedInstruction(
     }
   }
   strategy_group->AddStrategy(
-      ShardingStrategy({"MANUAL", HloSharding::Replicate(), 0, 0,
+      ShardingStrategy({HloSharding::Replicate(), 0, 0,
                         static_cast<double>(ShapeUtil::ByteSizeOf(shape)),
                         std::move(communication_resharding_costs),
                         std::move(memory_resharding_costs)}),
@@ -1725,14 +1724,13 @@ std::unique_ptr<StrategyGroup> CreateReshapeStrategies(
                                     operand_strategy_group, operand->shape(),
                                     operand_strategy.output_sharding,
                                     cluster_env);
      strategy_group->AddStrategy(
-          ShardingStrategy({name,
-                            *output_sharding,
+          ShardingStrategy({*output_sharding,
                             compute_cost,
                             communication_cost,
                             memory_cost,
                             {communication_resharding_costs},
                             {memory_resharding_costs}}),
-          {operand_strategy.output_sharding});
+          {name, {operand_strategy.output_sharding}});
    }
 
   if (strategy_group->GetStrategies().empty()) {
@@ -2159,8 +2157,6 @@ absl::Status InsertReshardReshapes(
         // spmd partitioner generate correct code.
         if (inst->opcode() == HloOpcode::kDot ||
             inst->opcode() == HloOpcode::kConvolution) {
-          const ShardingStrategy& stra =
-              GetShardingStrategy(inst, strategy_map, cost_graph, s_val);
           const HloInstruction* lhs = inst->operand(0);
           const HloInstruction* rhs = inst->operand(1);
           const HloSharding& lhs_sharding = lhs->sharding();
@@ -2193,7 +2189,9 @@ absl::Status InsertReshardReshapes(
                 "Cannot generate tensor dim to mesh dim mapping");
           }
 
-          if (absl::StrContains(stra.name, "allreduce") &&
+          const InputShardings& input_shardings =
+              GetInputShardings(inst, strategy_map, cost_graph, s_val);
+          if (absl::StrContains(input_shardings.name, "allreduce") &&
              std::any_of(lhs_con_dims.begin(), lhs_con_dims.end(),
                          [&lhs_tensor_dim_to_mesh_dim](int64_t dim) {
                            return lhs_tensor_dim_to_mesh_dim[dim] == -1;
@@ -2205,19 +2203,20 @@ absl::Status InsertReshardReshapes(
             // Allow duplicated dot computation in this case to reduce
             // communication
           } else {
-            const InputShardings& input_shardings =
-                GetInputShardings(inst, strategy_map, cost_graph, s_val);
-            CHECK(input_shardings.size() == 2)
+            CHECK(input_shardings.shardings.size() == 2)
                 << "Dot op requires both operands to have input shardings, "
                    "but get instruction: "
-                << inst->ToString() << ", strategy : " << stra.ToString();
-            if (input_shardings[0].has_value()) {
+                << inst->ToString()
+                << ", input shardings : " << input_shardings.ToString();
+            if (input_shardings.shardings[0].has_value()) {
               TF_RETURN_IF_ERROR(FixMixedMeshShapeResharding(
-                  inst, 0, *input_shardings[0], device_mesh, resharding_cache));
+                  inst, 0, *input_shardings.shardings[0], device_mesh,
+                  resharding_cache));
             }
-            if (input_shardings[1].has_value()) {
+            if (input_shardings.shardings[1].has_value()) {
               TF_RETURN_IF_ERROR(FixMixedMeshShapeResharding(
-                  inst, 1, *input_shardings[1], device_mesh, resharding_cache));
+                  inst, 1, *input_shardings.shardings[1], device_mesh,
+                  resharding_cache));
             }
           }
         }
@@ -2251,11 +2250,11 @@ absl::Status InsertReshardReshapes(
              const InputShardings& input_shardings =
                  GetInputShardingsForTuple(inst, {static_cast<int64_t>(i)},
                                            strategy_map, cost_graph, s_val);
-              if (input_shardings.size() > i &&
-                  input_shardings[i].has_value()) {
-                TF_RETURN_IF_ERROR(
-                    FixMixedMeshShapeResharding(inst, i, *input_shardings[i],
-                                                device_mesh, resharding_cache));
+              if (input_shardings.shardings.size() > i &&
+                  input_shardings.shardings[i].has_value()) {
+                TF_RETURN_IF_ERROR(FixMixedMeshShapeResharding(
+                    inst, i, *input_shardings.shardings[i], device_mesh,
+                    resharding_cache));
              }
            }
            break;
@@ -2265,10 +2264,11 @@ absl::Status InsertReshardReshapes(
              const InputShardings& input_shardings =
                  GetInputShardingsForTuple(inst, {static_cast<int64_t>(i)},
                                            strategy_map, cost_graph, s_val);
-              CHECK_EQ(input_shardings.size(), 1);
-              CHECK(input_shardings[0].has_value());
+              CHECK_EQ(input_shardings.shardings.size(), 1);
+              CHECK(input_shardings.shardings[0].has_value());
              TF_RETURN_IF_ERROR(FixMixedMeshShapeResharding(
-                  inst, i, *input_shardings[0], device_mesh, resharding_cache));
+                  inst, i, *input_shardings.shardings[0], device_mesh,
+                  resharding_cache));
            }
            break;
          }
@@ -2282,8 +2282,9 @@ absl::Status InsertReshardReshapes(
          const InputShardings& input_shardings =
              GetInputShardingsForTuple(inst, {static_cast<int64_t>(i)},
                                        strategy_map, cost_graph, s_val);
-          if (!input_shardings.empty() && input_shardings[0].has_value()) {
-            dst_shardings[i] = *input_shardings[0];
+          if (!input_shardings.shardings.empty() &&
+              input_shardings.shardings[0].has_value()) {
+            dst_shardings[i] = *input_shardings.shardings[0];
          }
        }
        TF_RETURN_IF_ERROR(
@@ -2305,7 +2306,7 @@ absl::Status InsertReshardReshapes(
       } else {
         const InputShardings& input_shardings =
             GetInputShardings(inst, strategy_map, cost_graph, s_val);
-        if (input_shardings.empty()) {
+        if (input_shardings.shardings.empty()) {
           continue;
         }
         if (inst->opcode() == HloOpcode::kGetTupleElement) {
@@ -2315,9 +2316,11 @@ absl::Status InsertReshardReshapes(
         }
 
         for (size_t i = 0; i < inst->operand_count(); ++i) {
-          if (input_shardings.size() > i && input_shardings[i].has_value()) {
+          if (input_shardings.shardings.size() > i &&
+              input_shardings.shardings[i].has_value()) {
             TF_RETURN_IF_ERROR(FixMixedMeshShapeResharding(
-                inst, i, *input_shardings[i], device_mesh, resharding_cache));
+                inst, i, *input_shardings.shardings[i], device_mesh,
+                resharding_cache));
           }
         }
       }
@@ -2881,7 +2884,9 @@ absl::Status GenerateReduceScatter(
     }
     const ShardingStrategy& strategy =
         GetShardingStrategy(inst, strategy_map, cost_graph, s_val);
-    if (!absl::StrContains(strategy.name, "allreduce")) {
+    const InputShardings& input_shardings =
+        GetInputShardings(inst, strategy_map, cost_graph, s_val);
+    if (!absl::StrContains(input_shardings.name, "allreduce")) {
       continue;
     }
 
@@ -2985,14 +2990,14 @@ absl::Status GenerateReduceScatter(
     if (num_replicated_parameters >= 1 && need_all_gather.size() <= 1 &&
         replicated_set.size() >= 5) {
       HloSharding output_spec =
-          GetReduceScatterOutput(inst, strategy, cluster_env);
+          GetReduceScatterOutput(inst, input_shardings, strategy, cluster_env);
       if (IsUndefined(output_spec)) {
         continue;
       }
       VLOG(10) << "SET: " << output_spec.ToString();
 
-      if (absl::StartsWith(strategy.name, "RR = RS x SR")) {
+      if (absl::StartsWith(input_shardings.name, "RR = RS x SR")) {
         // If set the sharding for this dot instruction, the SPMD
         // partitioner will generate bad fallback code.
         replicated_set.erase(inst);
@@ -3103,6 +3108,7 @@ absl::Status GenerateReduceScatter(
 
 // Return the output sharding of the reduce-scatter variant of a given strategy.
 HloSharding GetReduceScatterOutput(const HloInstruction* ins,
+                                   const InputShardings& input_shardings,
                                    const ShardingStrategy& strategy,
                                    const ClusterEnvironment& cluster_env) {
   const DeviceMesh& device_mesh = cluster_env.device_mesh_;
@@ -3112,10 +3118,10 @@ HloSharding GetReduceScatterOutput(const HloInstruction* ins,
     const DotDimensionNumbers& dot_dnums = ins->dot_dimension_numbers();
     int64_t space_base_dim = dot_dnums.lhs_batch_dimensions_size();
 
-    if (absl::StartsWith(strategy.name, "SR = SS x SR") ||
-        absl::StartsWith(strategy.name, "RS = RS x SS")) {
+    if (absl::StartsWith(input_shardings.name, "SR = SS x SR") ||
+        absl::StartsWith(input_shardings.name, "RS = RS x SS")) {
       int mesh_dim0, mesh_dim1;
-      std::tie(mesh_dim0, mesh_dim1) = ParseMeshDims(strategy.name);
+      std::tie(mesh_dim0, mesh_dim1) = ParseMeshDims(input_shardings.name);
 
       if (!IsDivisible(ins, device_mesh, {space_base_dim, space_base_dim + 1},
                        {mesh_dim0, mesh_dim1})) {
@@ -3128,9 +3134,9 @@ HloSharding GetReduceScatterOutput(const HloInstruction* ins,
       return Tile(ins->shape(), {space_base_dim, space_base_dim + 1},
                   {mesh_dim0, mesh_dim1}, device_mesh);
     }
-    if (absl::StartsWith(strategy.name, "SbR = SbSk x SbSk")) {
+    if (absl::StartsWith(input_shardings.name, "SbR = SbSk x SbSk")) {
       int mesh_dim0, mesh_dim1;
-      std::tie(mesh_dim0, mesh_dim1) = ParseMeshDims(strategy.name);
+      std::tie(mesh_dim0, mesh_dim1) = ParseMeshDims(input_shardings.name);
 
       if (!IsDivisible(ins, device_mesh, {0, space_base_dim},
                        {mesh_dim0, mesh_dim1})) {
@@ -3143,8 +3149,8 @@ HloSharding GetReduceScatterOutput(const HloInstruction* ins,
       return Tile(ins->shape(), {0, space_base_dim}, {mesh_dim0, mesh_dim1},
                   device_mesh);
     }
-    if (absl::StartsWith(strategy.name, "RR = RS x SR")) {
-      int mesh_dim = absl::StrContains(strategy.name, "{0}") ? 0 : 1;
+    if (absl::StartsWith(input_shardings.name, "RR = RS x SR")) {
+      int mesh_dim = absl::StrContains(input_shardings.name, "{0}") ? 0 : 1;
 
       if (!IsDivisible(ins, device_mesh, {space_base_dim}, {mesh_dim})) {
         return Undefined();
       }
@@ -3152,7 +3158,7 @@ HloSharding GetReduceScatterOutput(const HloInstruction* ins,
       return Tile(ins->shape(), {space_base_dim}, {mesh_dim}, device_mesh);
     }
 
-    if (absl::StartsWith(strategy.name, "R = Sk x Sk")) {
+    if (absl::StartsWith(input_shardings.name, "R = Sk x Sk")) {
       int mesh_dim = 0;
 
       if (!IsDivisible(ins, device_mesh_1d, {space_base_dim}, {mesh_dim})) {
@@ -3167,10 +3173,10 @@ HloSharding GetReduceScatterOutput(const HloInstruction* ins,
     int out_batch_dim = conv_dnums.output_batch_dimension();
     int out_out_channel_dim = conv_dnums.output_feature_dimension();
 
-    if (absl::StartsWith(strategy.name, "SR = SS x SR") ||
-        absl::StartsWith(strategy.name, "RS = RS x SS")) {
+    if (absl::StartsWith(input_shardings.name, "SR = SS x SR") ||
+        absl::StartsWith(input_shardings.name, "RS = RS x SS")) {
       int mesh_dim0, mesh_dim1;
-      std::tie(mesh_dim0, mesh_dim1) = ParseMeshDims(strategy.name);
+      std::tie(mesh_dim0, mesh_dim1) = ParseMeshDims(input_shardings.name);
 
       if (!IsDivisible(ins, device_mesh, {out_batch_dim, out_out_channel_dim},
                        {mesh_dim0, mesh_dim1})) {
@@ -3180,7 +3186,7 @@ HloSharding GetReduceScatterOutput(const HloInstruction* ins,
       return Tile(ins->shape(), {out_batch_dim, out_out_channel_dim},
                   {mesh_dim0, mesh_dim1}, device_mesh);
     }
-    if (absl::StartsWith(strategy.name, "R = Sk x Sk")) {
+    if (absl::StartsWith(input_shardings.name, "R = Sk x Sk")) {
       int mesh_dim = 0;
 
       if (!IsDivisible(ins, device_mesh_1d, {out_batch_dim}, {mesh_dim})) {
@@ -3194,14 +3200,14 @@ HloSharding GetReduceScatterOutput(const HloInstruction* ins,
     CHECK_EQ(ins->shape().rank(), 1);
 
     int mesh_dim;
-    if (absl::StrContains(strategy.name, "allreduce @ [0]")) {
+    if (absl::StrContains(input_shardings.name, "allreduce @ [0]")) {
       mesh_dim = 0;
     } else {
       mesh_dim = 1;
     }
 
     if (strategy.output_sharding.IsReplicated()) {
-      if (absl::StrContains(strategy.name, "1d")) {
+      if (absl::StrContains(input_shardings.name, "1d")) {
         if (!IsDivisible(ins, device_mesh_1d, {0}, {mesh_dim})) {
           return Undefined();
         }
diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding.h b/xla/hlo/experimental/auto_sharding/auto_sharding.h
index e37aecaa46898..2af64b64ef848 100644
--- a/xla/hlo/experimental/auto_sharding/auto_sharding.h
+++ b/xla/hlo/experimental/auto_sharding/auto_sharding.h
@@ -207,6 +207,7 @@ bool HasReduceScatterOpportunity(const HloInstruction* inst,
                                  const ConstInstructionSet& modified);
 
 HloSharding GetReduceScatterOutput(const HloInstruction* ins,
+                                   const InputShardings& input_shardings,
                                    const ShardingStrategy& strategy,
                                    const ClusterEnvironment& cluster_env);
 
diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.cc b/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.cc
index 67a67a2a27884..5d1016830c1c6 100644
--- a/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.cc
+++ b/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.cc
@@ -123,16 +123,26 @@ CostGraph::CostGraph(const StrategyGroups& strategy_groups,
                node_lens_[dst_idx]);
 
       absl::flat_hash_map<std::string, NodeStrategyIdx> src_strategy_name_to_idx_map;
-      for (NodeStrategyIdx i = 0; i < node_lens_[src_idx]; ++i) {
+      const auto& src_strategy_input_shardings =
+          src_strategy_group.GetStrategyInputShardings();
+      for (size_t iid = 0; iid < src_strategy_input_shardings.size(); ++iid) {
+        const InputShardings& input_shardings = src_strategy_input_shardings[iid];
+        NodeStrategyIdx i =
+            src_strategy_group.GetStrategyIdxForInputShardings(iid);
         const ShardingStrategy& strategy = src_strategy_group.GetStrategy(i);
         if (strategy.communication_cost > 0) {
-          src_strategy_name_to_idx_map[strategy.name] = i;
+          src_strategy_name_to_idx_map[input_shardings.name] = i;
         }
       }
 
-      for (NodeStrategyIdx i = 0; i < node_lens_[dst_idx]; ++i) {
+      const auto& dst_strategy_input_shardings =
+          dst_strategy_group.GetStrategyInputShardings();
+      for (size_t iid = 0; iid < dst_strategy_input_shardings.size(); ++iid) {
+        const InputShardings& input_shardings = dst_strategy_input_shardings[iid];
+        NodeStrategyIdx i =
+            dst_strategy_group.GetStrategyIdxForInputShardings(iid);
         const ShardingStrategy& dst_strategy = dst_strategy_group.GetStrategy(i);
         if (dst_strategy.communication_cost > 0) {
-          auto it = src_strategy_name_to_idx_map.find(dst_strategy.name);
+          auto it = src_strategy_name_to_idx_map.find(input_shardings.name);
           if (it != src_strategy_name_to_idx_map.end()) {
             const auto& src_strategy = src_strategy_group.GetStrategy(it->second);
             CHECK_LE(std::abs(src_strategy.communication_cost -
diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc b/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc
index d5d008ee19442..9d8ba9e1ec36a 100644
--- a/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc
+++ b/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc
@@ -335,12 +335,12 @@ void HandlerBase::AppendNewStrategy(const std::string& name,
   }
 
   strategy_group_->AddStrategy(
-      ShardingStrategy({name, output_spec, compute_cost, communication_cost,
+      ShardingStrategy({output_spec, compute_cost, communication_cost,
                         static_cast<double>(ByteSizeOfShapeWithSharding(
                             ins_->shape(), output_spec)),
                         communication_resharding_costs,
                         memory_resharding_costs}),
-      {input_specs.begin(), input_specs.end()});
+      {name, {input_specs.begin(), input_specs.end()}});
 }
 
 // Given lhs and rhs dim maps, infers a sharding for the output by relying
@@ -467,7 +467,7 @@ void HandlerBase::SortStrategies() {
             [](const std::pair<ShardingStrategy, InputShardings>& s1,
                const std::pair<ShardingStrategy, InputShardings>& s2) {
               if (s1.first.memory_cost == s2.first.memory_cost) {
-                return s1.first.name < s2.first.name;
+                return s1.second.name < s2.second.name;
              } else {
                return s1.first.memory_cost < s2.first.memory_cost;
              }
diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc b/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc
index 0ebdd990e92cf..4b0218f807fe2 100644
--- a/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc
+++ b/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc
@@ -346,14 +346,14 @@ BuildStrategyAndCost(
                 ByteSizeOfShapeWithSharding(ins->shape(), scatter_sharding);
 
             InputShardings input_shardings_optional(
-                {data_sharding, indices_sharding, update_sharding});
+                {name, {data_sharding, indices_sharding, update_sharding}});
             std::pair<ReshardingCosts, ReshardingCosts> resharding_costs =
                 GenerateReshardingCostsAndMissingShardingsForAllOperands(
                     ins, scatter_sharding, strategy_map, cluster_env,
                     call_graph, input_shardings_optional);
 
             strategy_group->AddStrategy(
-                ShardingStrategy({name, scatter_sharding, compute_cost,
+                ShardingStrategy({scatter_sharding, compute_cost,
                                   communication_cost, memory_cost,
                                   std::move(resharding_costs.first),
                                   std::move(resharding_costs.second)}),
@@ -397,15 +397,14 @@ BuildStrategyAndCost(
             double memory_cost =
                 ByteSizeOfShapeWithSharding(gather_shape, output_sharding);
 
             InputShardings input_shardings_optional(
-                {data_sharding, indices_sharding});
+                {output_sharding.ToString(), {data_sharding, indices_sharding}});
             std::pair<ReshardingCosts, ReshardingCosts> resharding_costs =
                 GenerateReshardingCostsAndMissingShardingsForAllOperands(
                     ins, output_sharding, strategy_map, cluster_env,
                     call_graph, input_shardings_optional);
 
             strategy_group->AddStrategy(
-                ShardingStrategy({std::string(output_sharding.ToString()),
-                                  output_sharding, compute_cost,
+                ShardingStrategy({output_sharding, compute_cost,
                                   communication_cost, memory_cost,
                                   std::move(resharding_costs.first),
                                   std::move(resharding_costs.second)}),
@@ -565,14 +564,13 @@ BuildStrategyAndCost(
             MemoryReshardingCostVector(src_strategy_group, operand->shape(),
                                        input_spec, cluster_env);
         strategy_group->AddStrategy(
-            ShardingStrategy({name,
-                              output_spec,
+            ShardingStrategy({output_spec,
                               compute_cost,
                               communication_cost,
                               memory_cost,
                               {communication_resharding_costs},
                               {memory_resharding_costs}}),
-            {input_spec});
+            {name, {input_spec}});
       }
       break;
     }
@@ -685,9 +683,9 @@ BuildStrategyAndCost(
           if (k == follow_idx ||
               ToString(ins->operand(k)->shape().dimensions()) ==
                   ToString(operand->shape().dimensions())) {
-            input_shardings.push_back(input_spec);
+            input_shardings.shardings.push_back(input_spec);
           } else {
-            input_shardings.push_back(std::nullopt);
+            input_shardings.shardings.push_back(std::nullopt);
           }
         }
         if (!output_spec.has_value()) {
@@ -703,11 +701,10 @@ BuildStrategyAndCost(
                 input_shardings);
 
         strategy_group->AddStrategy(
-            ShardingStrategy({name, *output_spec, compute_cost,
-                              communication_cost, memory_cost,
-                              std::move(resharding_costs.first),
+            ShardingStrategy({*output_spec, compute_cost, communication_cost,
+                              memory_cost, std::move(resharding_costs.first),
                               std::move(resharding_costs.second)}),
-            {input_spec});
+            {name, {input_spec}});
       }
 
       if (strategy_group->GetStrategies().empty()) {
@@ -1094,7 +1091,7 @@ BuildStrategyAndCost(
         const InputShardings& input_shardings = strategy_input_shardings[iid];
         const ShardingStrategy& strategy =
             strategy_group->GetStrategyForInputShardings(iid);
-        if (strategy.name == stra_names[idx]) {
+        if (input_shardings.name == stra_names[idx]) {
          new_strategies.push_back({strategy, input_shardings});
        }
      }
diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h b/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h
index 04c15b20e9aa1..49212fe84ce65 100644
--- a/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h
+++ b/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h
@@ -77,12 +77,37 @@ using ReshardingCache = ConstInstructionMap<std::vector<std::pair<HloSharding, HloInstruction*>>>;
 
 // Resharding costs for each operand
 using ReshardingCosts = std::vector<std::vector<double>>;
-// Optional shardings for each operand
-using InputShardings = std::vector<std::optional<HloSharding>>;
+
+// A named vector of optional shardings for each operand.
+struct InputShardings {
+  std::string name;
+  std::vector<std::optional<HloSharding>> shardings;
+
+  std::string ToString() const {
+    std::string str = absl::StrCat(name, " ");
+    for (const auto& s : shardings) {
+      if (!s.has_value()) {
+        absl::StrAppend(&str, "[*],");
+      } else if (s->IsReplicated()) {
+        absl::StrAppend(&str, "[R],");
+      } else {
+        if (s->ReplicateOnLastTileDim()) {
+          absl::StrAppend(
+              &str, "[", absl::StrJoin(s->tile_assignment().dimensions(), ", "),
+              "]last_tile_dim_replicate,");
+        } else {
+          absl::StrAppend(
+              &str, "[", absl::StrJoin(s->tile_assignment().dimensions(), ", "),
+              "],");
+        }
+      }
+    }
+    return str;
+  }
+};
 
 // One sharding strategy
 struct ShardingStrategy {
-  std::string name;
   HloSharding output_sharding;
   double compute_cost;
   double communication_cost;
@@ -94,9 +119,7 @@ struct ShardingStrategy {
   ReshardingCosts communication_resharding_costs;
   ReshardingCosts memory_resharding_costs;
 
-  std::string ToString() const {
-    return absl::StrCat(name, ", ", output_sharding.ToString());
-  }
+  std::string ToString() const { return output_sharding.ToString(); }
 
   std::string ToStringLong() const {
     std::vector<std::string> communication_resharding_vector_strings;
@@ -119,7 +142,7 @@ struct ShardingStrategy {
         "{", absl::StrJoin(memory_resharding_vector_strings, ", "), "}");
 
     return absl::StrCat(
-        name, ", ", output_sharding.ToString(), ", compute_cost=", compute_cost,
+        output_sharding.ToString(), ", compute_cost=", compute_cost,
         ", communication_cost=", communication_cost,
         ", memory_cost=", memory_cost,
         ", communication_resharding_costs=", communication_resharding_cost_str,
@@ -127,7 +150,7 @@ struct ShardingStrategy {
   }
 
   bool operator==(const ShardingStrategy& other) const {
-    return name == other.name && output_sharding == other.output_sharding &&
+    return output_sharding == other.output_sharding &&
           compute_cost == other.compute_cost &&
           communication_cost == other.communication_cost &&
           memory_cost == other.memory_cost &&
@@ -221,25 +244,8 @@ struct StrategyGroup {
     }
     if (!is_tuple) {
       for (const auto& input_shardings : strategy_input_shardings) {
-        std::string input_sharding_str = "{";
-        for (const auto& s : input_shardings) {
-          if (!s.has_value()) {
-            input_sharding_str += "[*],";
-          } else if (s->IsReplicated()) {
-            input_sharding_str += "[R],";
-          } else {
-            if (s->ReplicateOnLastTileDim()) {
-              input_sharding_str +=
-                  "[" + absl::StrJoin(s->tile_assignment().dimensions(), ", ") +
-                  "]last_tile_dim_replicate,";
-            } else {
-              input_sharding_str +=
-                  "[" + absl::StrJoin(s->tile_assignment().dimensions(), ", ") +
-                  "],";
-            }
-          }
-        }
-        input_sharding_str += "}\n";
+        const std::string input_sharding_str =
+            absl::StrCat("{", input_shardings.ToString(), "}\n");
         absl::StrAppend(&str, indent, "Input Sharding ", input_sharding_str);
       }
     }
@@ -313,6 +319,10 @@ struct StrategyGroup {
     return strategies[strategy_idx];
   }
 
+  size_t GetStrategyIdxForInputShardings(size_t input_sharding_idx) const {
+    return input_sharding_idx_to_strategy_idx[input_sharding_idx];
+  }
+
   const InputShardings& GetInputShardings(size_t input_sharding_idx) const {
     return strategy_input_shardings[input_sharding_idx];
   }
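An illustrative sketch of how the new InputShardings struct defined above is built and printed. This is not part of the patch: it assumes the XLA auto-sharding headers and the xla::spmd namespace, and ExampleInputShardings is a hypothetical helper used only for demonstration. It shows that the strategy name now travels with the per-operand shardings instead of with ShardingStrategy.

// Illustrative only; assumes the auto_sharding_strategy.h definitions from
// this patch and the XLA build environment.
#include <iostream>
#include <optional>

#include "xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h"
#include "xla/hlo/ir/hlo_sharding.h"

namespace xla::spmd {

void ExampleInputShardings() {
  // The first field is the strategy name that previously lived on
  // ShardingStrategy.
  InputShardings input_shardings = {"S0 = S0 x R"};
  input_shardings.shardings.push_back(HloSharding::Replicate());  // operand 0
  input_shardings.shardings.push_back(std::nullopt);  // operand 1 unconstrained
  // Prints the name followed by one bracketed entry per operand,
  // e.g. "S0 = S0 x R [R],[*],".
  std::cout << input_shardings.ToString() << std::endl;
}

}  // namespace xla::spmd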
diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc b/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc
index 641f02ef2e918..5b3e43053ac2f 100644
--- a/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc
+++ b/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc
@@ -851,8 +851,8 @@ void RemoveDuplicatedStrategy(StrategyGroup& strategy_group) {
       continue;
     }
     std::string key = strategy.output_sharding.ToString();
-    if (!input_shardings.empty()) {
-      for (const auto& sharding : input_shardings) {
+    if (!input_shardings.shardings.empty()) {
+      for (const auto& sharding : input_shardings.shardings) {
        key += "/" + (sharding.has_value() ? sharding->ToString() : "none");
      }
    }
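A hedged sketch of the call-site migration the patch performs across these files. The wrapper function and cost values are illustrative, assuming the same XLA environment: ShardingStrategy no longer carries a name, so the name is supplied as the first member of the InputShardings passed to StrategyGroup::AddStrategy.

// Illustrative only; AddReplicatedExample is a hypothetical stand-in for the
// real AddStrategy call sites touched by this patch.
#include <utility>

#include "xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h"
#include "xla/hlo/ir/hlo_sharding.h"

namespace xla::spmd {

void AddReplicatedExample(StrategyGroup& strategy_group, double memory_cost,
                          ReshardingCosts communication_resharding_costs,
                          ReshardingCosts memory_resharding_costs) {
  // Before this patch the first brace element was the name, e.g.
  //   ShardingStrategy({"R", HloSharding::Replicate(), ...}),
  // and the second AddStrategy argument was a bare vector of shardings.
  // After the patch, the name moves into InputShardings.
  strategy_group.AddStrategy(
      ShardingStrategy({HloSharding::Replicate(), /*compute_cost=*/0,
                        /*communication_cost=*/0, memory_cost,
                        std::move(communication_resharding_costs),
                        std::move(memory_resharding_costs)}),
      {"R", {HloSharding::Replicate()}});
}

}  // namespace xla::spmd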