diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding.cc b/xla/hlo/experimental/auto_sharding/auto_sharding.cc
index f486e7d94baba..ea48e81f1a460 100644
--- a/xla/hlo/experimental/auto_sharding/auto_sharding.cc
+++ b/xla/hlo/experimental/auto_sharding/auto_sharding.cc
@@ -48,7 +48,6 @@ limitations under the License.
 #include "absl/time/clock.h"
 #include "absl/time/time.h"
 #include "absl/types/span.h"
-#include "xla/array.h"
 #include "xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h"
 #include "xla/hlo/experimental/auto_sharding/auto_sharding_device_mesh.h"
 #include "xla/hlo/experimental/auto_sharding/auto_sharding_memory.h"
@@ -89,6 +88,7 @@ limitations under the License.
 #include "xla/shape_util.h"
 #include "xla/util.h"
 #include "xla/xla_data.pb.h"
+#include "tsl/platform/errors.h"

 namespace xla {
 namespace spmd {
@@ -209,8 +209,7 @@ std::pair<ReshardingCosts, ReshardingCosts>
 GenerateReshardingCostsAndMissingShardingsForAllOperands(
     const HloInstruction* ins, const HloSharding& output_sharding,
     const StrategyMap& strategy_map, const ClusterEnvironment& cluster_env,
-    const CallGraph& call_graph,
-    std::vector<std::optional<HloSharding>>& input_shardings) {
+    const CallGraph& call_graph, InputShardings& input_shardings) {
   ReshardingCosts communication_resharding_costs;
   ReshardingCosts memory_resharding_costs;
   if (input_shardings.empty() && ins->operand_count() > 0) {
@@ -282,7 +281,7 @@ GenerateReshardingCostsAndShardingsForAllOperands(
     const HloInstruction* ins, const HloSharding& output_sharding,
     const StrategyMap& strategy_map, const ClusterEnvironment& cluster_env,
     const CallGraph& call_graph) {
-  std::vector<std::optional<HloSharding>> input_shardings_optional;
+  InputShardings input_shardings_optional;
   std::pair<ReshardingCosts, ReshardingCosts> resharding_costs =
       GenerateReshardingCostsAndMissingShardingsForAllOperands(
           ins, output_sharding, strategy_map, cluster_env, call_graph,
@@ -334,8 +333,7 @@ void FollowArrayOrTokenStrategyGroup(
     double compute_cost = 0, communication_cost = 0;
     double memory_cost = ByteSizeOfShapeWithSharding(shape, *output_spec);
     size_t num_in_nodes = strategy_group.in_nodes.size();
-    std::vector<std::optional<HloSharding>> input_shardings(num_in_nodes,
-                                                            *output_spec);
+    InputShardings input_shardings(num_in_nodes, *output_spec);
     ReshardingCosts communication_resharding_costs;
     ReshardingCosts memory_resharding_costs;
     for (size_t i = 0; i < strategy_group.in_nodes.size(); ++i) {
@@ -349,7 +347,8 @@ void FollowArrayOrTokenStrategyGroup(
     strategy_group.AddStrategy(
         ShardingStrategy({name, *output_spec, compute_cost, communication_cost,
                           memory_cost, communication_resharding_costs,
-                          memory_resharding_costs, input_shardings}));
+                          memory_resharding_costs}),
+        input_shardings);
   }
 }

@@ -387,7 +386,7 @@ std::unique_ptr<StrategyGroup> HandlePartialReduce(
   }

   // Get a list of input shardings, each corresponds to an operand.
-  std::vector<std::optional<HloSharding>> input_shardings;
+  InputShardings input_shardings;
   for (int64_t k = 0; k < output_size * 2; ++k) {
     if (k < output_size) {
       input_shardings.push_back(input_spec);
@@ -405,10 +404,12 @@ std::unique_ptr<StrategyGroup> HandlePartialReduce(
             ins, output_spec, strategy_map, cluster_env, call_graph,
             input_shardings);

-    child_strategy_group->AddStrategy(ShardingStrategy(
-        {std::move(name), std::move(output_spec), compute_cost,
-         communication_cost, memory_cost, std::move(resharding_costs.first),
-         std::move(resharding_costs.second), std::move(input_shardings)}));
+    child_strategy_group->AddStrategy(
+        ShardingStrategy({std::move(name), std::move(output_spec),
+                          compute_cost, communication_cost, memory_cost,
+                          std::move(resharding_costs.first),
+                          std::move(resharding_costs.second)}),
+        std::move(input_shardings));
   }

   strategy_group->AddChild(std::move(child_strategy_group));
@@ -551,16 +552,10 @@ absl::StatusOr<std::unique_ptr<StrategyGroup>> FollowReduceStrategy(
           memory_resharding_costs.push_back(zeros);
         }
       }
-      const ShardingStrategy strategy =
-          ShardingStrategy({name,
-                            output_spec,
-                            compute_cost,
-                            communication_cost,
-                            memory_cost,
-                            communication_resharding_costs,
-                            memory_resharding_costs,
-                            {input_sharding}});
-      strategy_group->AddStrategy(strategy);
+      const ShardingStrategy strategy = ShardingStrategy(
+          {name, output_spec, compute_cost, communication_cost, memory_cost,
+           communication_resharding_costs, memory_resharding_costs});
+      strategy_group->AddStrategy(strategy, {input_sharding});
     }
   } else {
     LOG(FATAL) << "Unhandled kReduce shape: " << ins->shape().ToString();
@@ -655,7 +650,7 @@ void GenerateOutfeedStrategy(const HloInstruction* ins, const Shape& shape,
   HloSharding output_spec = HloSharding::Replicate();
   ReshardingCosts communication_resharding_costs;
   ReshardingCosts memory_resharding_costs;
-  std::vector<std::optional<HloSharding>> input_shardings;
+  InputShardings input_shardings;

   const int tuple_size = ins->operand(0)->shape().tuple_shapes_size();
   const auto& operand_strategy_group = strategy_map.at(ins->operand(0));
@@ -705,7 +700,8 @@ void GenerateOutfeedStrategy(const HloInstruction* ins, const Shape& shape,
   strategy_group.AddStrategy(
       ShardingStrategy({"R", HloSharding::Replicate(), replicated_penalty, 0,
                         memory_cost, std::move(communication_resharding_costs),
-                        std::move(memory_resharding_costs), input_shardings}));
+                        std::move(memory_resharding_costs)}),
+      input_shardings);
 }

 double ComputeCommunicationCost(
@@ -765,10 +761,9 @@ void AddReplicatedStrategy(
     CHECK(!operand->shape().IsTuple());
     const auto& operand_strategy_group = strategy_map.at(operand).get();
     const auto& operand_strategies = operand_strategy_group->GetStrategies();
-    std::vector<std::vector<std::optional<HloSharding>>>
-        possible_input_shardings(
-            operand_strategies.size(),
-            std::vector<std::optional<HloSharding>>(ins->operand_count()));
+    std::vector<InputShardings> possible_input_shardings(
+        operand_strategies.size(),
+        std::vector<std::optional<HloSharding>>(ins->operand_count()));
     std::vector<ReshardingCosts> possible_communication_resharding_costs(
         operand_strategies.size(), ReshardingCosts(ins->operand_count()));
     std::vector<ReshardingCosts> possible_memory_resharding_costs(
         operand_strategies.size(), ReshardingCosts(ins->operand_count()));
@@ -809,16 +804,18 @@ void AddReplicatedStrategy(
     for (size_t j = 0; j < possible_input_shardings.size(); ++j) {
       double communication_cost = ComputeCommunicationCost(
           ins, possible_input_shardings[j], cluster_env);
-      strategy_group.AddStrategy(ShardingStrategy(
-          {"R", replicated_strategy, replicated_penalty, communication_cost,
-           memory_cost, std::move(possible_communication_resharding_costs[j]),
-           std::move(possible_memory_resharding_costs[j]),
-           std::move(possible_input_shardings[j])}));
+      strategy_group.AddStrategy(
+          ShardingStrategy(
+              {"R", replicated_strategy, replicated_penalty, communication_cost,
+               memory_cost,
+               std::move(possible_communication_resharding_costs[j]),
+               std::move(possible_memory_resharding_costs[j])}),
+          std::move(possible_input_shardings[j]));
     }
   } else {
     ReshardingCosts communication_resharding_costs;
     ReshardingCosts memory_resharding_costs;
-    std::vector<std::optional<HloSharding>> input_shardings;
+    InputShardings input_shardings;

     if (ins->operand_count() > 0 && ins->operand(0)->shape().IsTuple()) {
       CHECK_EQ(ins->operand_count(), 1)
@@ -850,10 +847,12 @@ void AddReplicatedStrategy(
         }
       }
     }
-    strategy_group.AddStrategy(ShardingStrategy(
-        {"R", HloSharding::Replicate(), replicated_penalty, 0, memory_cost,
-         std::move(communication_resharding_costs),
-         std::move(memory_resharding_costs), input_shardings}));
+    strategy_group.AddStrategy(
+        ShardingStrategy({"R", HloSharding::Replicate(), replicated_penalty, 0,
+                          memory_cost,
+                          std::move(communication_resharding_costs),
+                          std::move(memory_resharding_costs)}),
+        input_shardings);
   }
 }

@@ -894,7 +893,7 @@ void EnumerateAll1DPartition(const HloInstruction* ins, const Shape& shape,

       ReshardingCosts communication_resharding_costs;
       ReshardingCosts memory_resharding_costs;
-      std::vector<std::optional<HloSharding>> input_shardings;
+      InputShardings input_shardings;
       if (ins->opcode() == HloOpcode::kConditional) {
         // TODO(pratikf): Compute input_shardings for kConditional ops
         communication_resharding_costs =
@@ -935,10 +934,12 @@ void EnumerateAll1DPartition(const HloInstruction* ins, const Shape& shape,
         communication_cost = ComputeSortCommunicationCost(
             ins->operand(0)->shape().rank() - 1, i, j, shape, cluster_env);
       }
-      strategy_group.AddStrategy(ShardingStrategy(
-          {name, output_spec, compute_cost, communication_cost, memory_cost,
-           std::move(communication_resharding_costs),
-           std::move(memory_resharding_costs), input_shardings}));
+      strategy_group.AddStrategy(
+          ShardingStrategy({name, output_spec, compute_cost, communication_cost,
+                            memory_cost,
+                            std::move(communication_resharding_costs),
+                            std::move(memory_resharding_costs)}),
+          input_shardings);
     }
   }
 }

@@ -1010,7 +1011,7 @@ void BuildStrategyAndCostForOp(const HloInstruction* ins, const Shape& shape,
   HloSharding output_spec = Tile(shape, tensor_dims, mesh_dims, device_mesh);
   double compute_cost = 0, communication_cost = 0;
   double memory_cost = ByteSizeOfShapeWithSharding(shape, output_spec);
-  std::vector<std::optional<HloSharding>> input_shardings;
+  InputShardings input_shardings;
   ReshardingCosts communication_resharding_costs;
   ReshardingCosts memory_resharding_costs;
   if (ins->opcode() == HloOpcode::kConditional) {
@@ -1055,7 +1056,8 @@ void BuildStrategyAndCostForOp(const HloInstruction* ins, const Shape& shape,
   strategy_group.AddStrategy(
       ShardingStrategy({name, output_spec, compute_cost, communication_cost,
                         memory_cost, std::move(communication_resharding_costs),
-                        std::move(memory_resharding_costs), input_shardings}));
+                        std::move(memory_resharding_costs)}),
+      input_shardings);
 }

 void EnumerateAll1DPartitionReshape(const HloInstruction* ins,
@@ -1102,14 +1104,11 @@ void EnumerateAll1DPartitionReshape(const HloInstruction* ins,
       ReshardingCosts memory_resharding_costs{MemoryReshardingCostVector(
           operand_strategy_group, operand_shape, *input_spec, cluster_env)};
       strategy_group.AddStrategy(
-          ShardingStrategy({name,
-                            output_spec,
-                            compute_cost,
-                            communication_cost,
+          ShardingStrategy({name, output_spec, compute_cost, communication_cost,
                             memory_cost,
                             std::move(communication_resharding_costs),
-                            std::move(memory_resharding_costs),
-                            {*input_spec}}));
+                            std::move(memory_resharding_costs)}),
+          {*input_spec});
     }
   }
 }
@@ -1194,14 +1193,10 @@ void BuildStrategyAndCostForReshape(const HloInstruction* ins,
   ReshardingCosts memory_resharding_costs{MemoryReshardingCostVector(
       operand_strategy_group, operand_shape, *input_spec, cluster_env)};
   strategy_group.AddStrategy(
-      ShardingStrategy({name,
-                        output_spec,
-                        compute_cost,
-                        communication_cost,
-                        memory_cost,
-                        std::move(communication_resharding_costs),
-                        std::move(memory_resharding_costs),
-                        {*input_spec}}));
+      ShardingStrategy({name, output_spec, compute_cost, communication_cost,
+                        memory_cost, std::move(communication_resharding_costs),
+                        std::move(memory_resharding_costs)}),
+      {*input_spec});
 }

 // Return the maximum number of tiles among all strategies of an instruction.
@@ -1472,11 +1467,13 @@ void TrimOrGenerateStrategiesBasedOnExistingSharding(
             cluster_env.device_mesh_.num_elements())) {
       // Sharding provided by XLA users, we need to keep them.
       strategy_group.following = nullptr;
-      std::vector<ShardingStrategy> new_strategies;
-      for (const ShardingStrategy& strategy : strategy_group.GetStrategies()) {
+      std::vector<std::pair<ShardingStrategy, InputShardings>> new_strategies;
+      for (size_t sid = 0; sid < strategy_group.GetStrategies().size(); ++sid) {
+        const ShardingStrategy& strategy = strategy_group.GetStrategy(sid);
+        const auto& input_shardings = strategy_group.GetInputShardings(sid);
         if (strategy.output_sharding == existing_sharding) {
           VLOG(1) << "Keeping strategy: " << strategy.ToString();
-          new_strategies.push_back(strategy);
+          new_strategies.push_back({strategy, input_shardings});
         }
       }
       if (!new_strategies.empty()) {
@@ -1485,15 +1482,15 @@ void TrimOrGenerateStrategiesBasedOnExistingSharding(
         pretrimmed_strategy_map[strategy_group.node_idx] =
             strategy_group.GetStrategies();
         strategy_group.ClearStrategies();
-        for (const ShardingStrategy& strategy : new_strategies) {
-          strategy_group.AddStrategy(strategy);
+        for (const auto& [strategy, input_shardings] : new_strategies) {
+          strategy_group.AddStrategy(strategy, input_shardings);
         }
       } else {
         VLOG(1) << "Generate a new strategy based on user sharding.";
         std::string name = ToStringSimple(existing_sharding);
         ReshardingCosts communication_resharding_costs;
         ReshardingCosts memory_resharding_costs;
-        std::vector<std::optional<HloSharding>> input_shardings;
+        InputShardings input_shardings;
         if (!strategy_group.in_nodes.empty()) {
           HloInstruction* ins = instructions.at(strategy_group.instruction_id);
           for (size_t i = 0; i < strategy_group.in_nodes.size(); i++) {
@@ -1546,7 +1543,8 @@ void TrimOrGenerateStrategiesBasedOnExistingSharding(
         strategy_group.AddStrategy(
             ShardingStrategy({name, existing_sharding, 0, 0, memory_cost,
                               communication_resharding_costs,
-                              memory_resharding_costs, input_shardings}));
+                              memory_resharding_costs}),
+            input_shardings);
       }
       // If there is only one option for resharding, and the cost computed for
       // that option is kInfinityCost, set the cost to zero. This is okay
@@ -1567,8 +1565,10 @@ void TrimOrGenerateStrategiesBasedOnExistingSharding(
     // sharding.
     // It is IMPORTANT that we do this only for instructions that do no follow
     // others, to keep the number of ILP variable small.
-    std::vector<ShardingStrategy> new_vector;
-    for (const ShardingStrategy& strategy : strategy_group.GetStrategies()) {
+    std::vector<std::pair<ShardingStrategy, InputShardings>> new_vector;
+    for (size_t sid = 0; sid < strategy_group.GetStrategies().size(); ++sid) {
+      const ShardingStrategy& strategy = strategy_group.GetStrategy(sid);
+      const auto& input_shardings = strategy_group.GetInputShardings(sid);
       if (strategy.output_sharding.IsReplicated() ||
           ShardingIsConsistent(existing_sharding, strategy.output_sharding,
                                strict) ||
@@ -1578,7 +1578,7 @@ void TrimOrGenerateStrategiesBasedOnExistingSharding(
            spmd::ShardingIsComplete(
                strategy.output_sharding,
                cluster_env.original_device_mesh_.num_elements()))) {
-        new_vector.push_back(strategy);
+        new_vector.push_back({strategy, input_shardings});
       }
     }
     // If no sharding strategy left, just keep the original set, because we do
@@ -1588,8 +1588,8 @@ void TrimOrGenerateStrategiesBasedOnExistingSharding(
         new_vector.size() != strategy_group.GetStrategies().size()) {
       strategy_group.following = nullptr;
       strategy_group.ClearStrategies();
-      for (const ShardingStrategy& strategy : new_vector) {
-        strategy_group.AddStrategy(strategy);
+      for (const auto& [strategy, input_shardings] : new_vector) {
+        strategy_group.AddStrategy(strategy, input_shardings);
       }
     }
   }
@@ -1769,7 +1769,7 @@ std::unique_ptr<StrategyGroup> HandleManuallyShardedInstruction(
                                                 strategy_groups);
       ReshardingCosts communication_resharding_costs;
       ReshardingCosts memory_resharding_costs;
-      std::vector<std::optional<HloSharding>> input_shardings;
+      InputShardings input_shardings;

       if (ins->operand_count() > 0 && ins->operand(0)->shape().IsTuple()) {
         CHECK_EQ(ins->operand_count(), 1)
@@ -1790,11 +1790,12 @@ std::unique_ptr<StrategyGroup> HandleManuallyShardedInstruction(
           memory_resharding_costs.push_back(zeros);
         }
       }
-      strategy_group->AddStrategy(ShardingStrategy(
-          {"MANUAL", HloSharding::Replicate(), 0, 0,
-           static_cast<double>(ShapeUtil::ByteSizeOf(shape)),
-           std::move(communication_resharding_costs),
-           std::move(memory_resharding_costs), std::move(input_shardings)}));
+      strategy_group->AddStrategy(
+          ShardingStrategy({"MANUAL", HloSharding::Replicate(), 0, 0,
+                            static_cast<double>(ShapeUtil::ByteSizeOf(shape)),
+                            std::move(communication_resharding_costs),
+                            std::move(memory_resharding_costs)}),
+          std::move(input_shardings));
     } else {
       LOG(FATAL) << "Unsupported instruction shape: " << shape.DebugString();
     }
@@ -1856,8 +1857,8 @@ std::unique_ptr<StrategyGroup> CreateReshapeStrategies(
                             communication_cost,
                             memory_cost,
                             {communication_resharding_costs},
-                            {memory_resharding_costs},
-                            {src_strategy.output_sharding}}));
+                            {memory_resharding_costs}}),
+          {src_strategy.output_sharding});
     }
   }

@@ -2332,19 +2333,19 @@ absl::Status InsertReshardReshapes(
             // Allow duplicated dot computation in this case to reduce
             // communication
           } else {
-            CHECK(stra.input_shardings.size() == 2)
+            const InputShardings& input_shardings =
+                GetInputShardings(inst, strategy_map, cost_graph, s_val);
+            CHECK(input_shardings.size() == 2)
                 << "Dot op requires both operands to have input shardings, "
                    "but get instruction: "
                 << inst->ToString() << ", strategy : " << stra.ToString();
-            if (stra.input_shardings[0].has_value()) {
-              TF_RETURN_IF_ERROR(
-                  FixMixedMeshShapeResharding(inst, 0, *stra.input_shardings[0],
-                                              device_mesh, resharding_cache));
+            if (input_shardings[0].has_value()) {
+              TF_RETURN_IF_ERROR(FixMixedMeshShapeResharding(
+                  inst, 0, *input_shardings[0], device_mesh, resharding_cache));
             }
-            if (stra.input_shardings[1].has_value()) {
-              TF_RETURN_IF_ERROR(
-                  FixMixedMeshShapeResharding(inst, 1, *stra.input_shardings[1],
-                                              device_mesh, resharding_cache));
+            if (input_shardings[1].has_value()) {
+              TF_RETURN_IF_ERROR(FixMixedMeshShapeResharding(
+                  inst, 1, *input_shardings[1], device_mesh, resharding_cache));
             }
           }
         }
@@ -2375,28 +2376,27 @@ absl::Status InsertReshardReshapes(
         case HloOpcode::kRngBitGenerator:
         case HloOpcode::kSort: {
           for (size_t i = 0; i < inst->shape().tuple_shapes_size(); ++i) {
-            const ShardingStrategy& stra =
-                GetShardingStrategyForTuple(inst, {static_cast<int64_t>(i)},
-                                            strategy_map, cost_graph, s_val);
-            if (stra.input_shardings.size() > i &&
-                stra.input_shardings[i].has_value()) {
-              TF_RETURN_IF_ERROR(FixMixedMeshShapeResharding(
-                  inst, i, *stra.input_shardings[i], device_mesh,
-                  resharding_cache));
+            const InputShardings& input_shardings =
+                GetInputShardingsForTuple(inst, {static_cast<int64_t>(i)},
+                                          strategy_map, cost_graph, s_val);
+            if (input_shardings.size() > i &&
+                input_shardings[i].has_value()) {
+              TF_RETURN_IF_ERROR(
+                  FixMixedMeshShapeResharding(inst, i, *input_shardings[i],
+                                              device_mesh, resharding_cache));
             }
           }
           break;
         }
         case HloOpcode::kTuple: {
           for (size_t i = 0; i < inst->shape().tuple_shapes_size(); ++i) {
-            const ShardingStrategy& stra =
-                GetShardingStrategyForTuple(inst, {static_cast<int64_t>(i)},
-                                            strategy_map, cost_graph, s_val);
-            CHECK_EQ(stra.input_shardings.size(), 1);
-            CHECK(stra.input_shardings[0].has_value());
-            TF_RETURN_IF_ERROR(
-                FixMixedMeshShapeResharding(inst, i, *stra.input_shardings[0],
-                                            device_mesh, resharding_cache));
+            const InputShardings& input_shardings =
+                GetInputShardingsForTuple(inst, {static_cast<int64_t>(i)},
+                                          strategy_map, cost_graph, s_val);
+            CHECK_EQ(input_shardings.size(), 1);
+            CHECK(input_shardings[0].has_value());
+            TF_RETURN_IF_ERROR(FixMixedMeshShapeResharding(
+                inst, i, *input_shardings[0], device_mesh, resharding_cache));
           }
           break;
         }
@@ -2407,12 +2407,11 @@ absl::Status InsertReshardReshapes(
             CHECK(!inst->shape().tuple_shapes(i).IsTuple())
                 << "We currently do not support ops with nested tuples as "
                   "output. See b/332951306.";
-            const ShardingStrategy& stra =
-                GetShardingStrategyForTuple(inst, {static_cast<int64_t>(i)},
-                                            strategy_map, cost_graph, s_val);
-            if (!stra.input_shardings.empty() &&
-                stra.input_shardings[0].has_value()) {
-              dst_shardings[i] = *stra.input_shardings[0];
+            const InputShardings& input_shardings =
+                GetInputShardingsForTuple(inst, {static_cast<int64_t>(i)},
+                                          strategy_map, cost_graph, s_val);
+            if (!input_shardings.empty() && input_shardings[0].has_value()) {
+              dst_shardings[i] = *input_shardings[0];
             }
           }
           TF_RETURN_IF_ERROR(
@@ -2432,9 +2431,9 @@ absl::Status InsertReshardReshapes(
             LOG(FATAL) << "Unhandled instruction: " + inst->ToString();
         }
       } else {
-        const ShardingStrategy& stra =
-            GetShardingStrategy(inst, strategy_map, cost_graph, s_val);
-        if (stra.input_shardings.empty()) {
+        const InputShardings& input_shardings =
+            GetInputShardings(inst, strategy_map, cost_graph, s_val);
+        if (input_shardings.empty()) {
          continue;
        }
        if (inst->opcode() == HloOpcode::kGetTupleElement) {
@@ -2444,11 +2443,9 @@ absl::Status InsertReshardReshapes(
        }

        for (size_t i = 0; i < inst->operand_count(); ++i) {
-          if (stra.input_shardings.size() > i &&
-              stra.input_shardings[i].has_value()) {
-            TF_RETURN_IF_ERROR(
-                FixMixedMeshShapeResharding(inst, i, *stra.input_shardings[i],
-                                            device_mesh, resharding_cache));
+          if (input_shardings.size() > i && input_shardings[i].has_value()) {
+            TF_RETURN_IF_ERROR(FixMixedMeshShapeResharding(
+                inst, i, *input_shardings[i], device_mesh, resharding_cache));
          }
        }
      }
@@ -3357,8 +3354,10 @@ absl::Status FilterStrategy(const HloInstruction* ins, const Shape& shape,
                         "not divisible by the number of devices");
   }

-  std::vector<ShardingStrategy> new_strategies;
-  for (const ShardingStrategy& strategy : strategy_group.GetStrategies()) {
+  std::vector<std::pair<ShardingStrategy, InputShardings>> new_strategies;
+  for (size_t sid = 0; sid < strategy_group.GetStrategies().size(); ++sid) {
+    const ShardingStrategy& strategy = strategy_group.GetStrategy(sid);
+    const auto& input_shardings = strategy_group.GetInputShardings(sid);
     const HloSharding& output_sharding = strategy.output_sharding;
     const std::vector<int64_t> tensor_dim_to_mesh_dim =
         cluster_env.GetTensorDimToMeshDimWrapper(shape, output_sharding);
@@ -3367,21 +3366,21 @@ absl::Status FilterStrategy(const HloInstruction* ins, const Shape& shape,
       // If the mesh dim is not one, the output tensor must be
      // tiled along the mesh dim.
      if (tensor_dim_to_mesh_dim[batch_dim] == mesh_dim) {
-        new_strategies.push_back(strategy);
+        new_strategies.push_back({strategy, input_shardings});
      }
    } else {
      // If the mesh dim is one, the output tensor must be replicated
      // on the mesh dim.
      if (tensor_dim_to_mesh_dim[batch_dim] == -1) {
-        new_strategies.push_back(strategy);
+        new_strategies.push_back({strategy, input_shardings});
      }
    }
  }

  CHECK(!new_strategies.empty())
      << ins->ToString() << " does not have any valid strategies";

  strategy_group.ClearStrategies();
-  for (const ShardingStrategy& strategy : new_strategies) {
-    strategy_group.AddStrategy(strategy);
+  for (const auto& [strategy, input_shardings] : new_strategies) {
+    strategy_group.AddStrategy(strategy, input_shardings);
  }

  return absl::OkStatus();
diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding.h b/xla/hlo/experimental/auto_sharding/auto_sharding.h
index c9791841e5eba..71a85b83136ea 100644
--- a/xla/hlo/experimental/auto_sharding/auto_sharding.h
+++ b/xla/hlo/experimental/auto_sharding/auto_sharding.h
@@ -341,8 +341,7 @@ std::pair<ReshardingCosts, ReshardingCosts>
 GenerateReshardingCostsAndMissingShardingsForAllOperands(
     const HloInstruction* ins, const HloSharding& output_sharding,
     const StrategyMap& strategy_map, const ClusterEnvironment& cluster_env,
-    const CallGraph& call_graph,
-    std::vector<std::optional<HloSharding>>& input_shardings);
+    const CallGraph& call_graph, InputShardings& input_shardings);

 std::unique_ptr<StrategyGroup> MaybeFollowInsStrategyGroup(
     const StrategyGroup& src_strategy_group, const Shape& shape,
diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h b/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h
index 802d780707b67..d63e5bee65007 100644
--- a/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h
+++ b/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h
@@ -139,9 +139,20 @@ inline const ShardingStrategy& GetShardingStrategy(
   return strategy_group->GetStrategies()[stra_idx];
 }

+// Get the input shardings according to the ILP solution.
+inline const InputShardings& GetInputShardings(
+    const HloInstruction* inst, const StrategyMap& strategy_map,
+    const CostGraph& cost_graph, absl::Span<const NodeStrategyIdx> s_val) {
+  const StrategyGroup* strategy_group = strategy_map.at(inst).get();
+  CHECK(!strategy_group->is_tuple);
+  NodeIdx node_idx = strategy_group->node_idx;
+  NodeStrategyIdx stra_idx = cost_graph.RemapIndex(node_idx, s_val[node_idx]);
+  return strategy_group->GetInputShardings(stra_idx);
+}
+
 // Get the final sharding strategy according to the ILP solution.
 inline const ShardingStrategy& GetShardingStrategyForTuple(
-    const HloInstruction* inst, ShapeIndex index,
+    const HloInstruction* inst, const ShapeIndex& index,
     const StrategyMap& strategy_map, const CostGraph& cost_graph,
     absl::Span<const NodeStrategyIdx> s_val) {
   const StrategyGroup* strategy_group = strategy_map.at(inst).get();
@@ -156,6 +167,23 @@ inline const ShardingStrategy& GetShardingStrategyForTuple(
   return strategy_group->GetStrategies()[stra_idx];
 }

+// Get the input shardings according to the ILP solution.
+inline const InputShardings& GetInputShardingsForTuple(
+    const HloInstruction* inst, const ShapeIndex& index,
+    const StrategyMap& strategy_map, const CostGraph& cost_graph,
+    absl::Span<const NodeStrategyIdx> s_val) {
+  const StrategyGroup* strategy_group = strategy_map.at(inst).get();
+  CHECK(strategy_group->is_tuple);
+  for (auto index_element : index) {
+    CHECK_LT(index_element, strategy_group->GetChildren().size());
+    const auto& strategies = strategy_group->GetChildren()[index_element];
+    strategy_group = strategies.get();
+  }
+  NodeIdx node_idx = strategy_group->node_idx;
+  NodeStrategyIdx stra_idx = cost_graph.RemapIndex(node_idx, s_val[node_idx]);
+  return strategy_group->GetInputShardings(stra_idx);
+}
+
 }  // namespace spmd
 }  // namespace xla
diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc b/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc
index 15d0a03b0592f..85898c8b4e6df 100644
--- a/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc
+++ b/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc
@@ -21,6 +21,7 @@ limitations under the License.
 #include
 #include
 #include
+#include
 #include

 #include "absl/algorithm/container.h"
@@ -339,17 +340,13 @@ void HandlerBase::AppendNewStrategy(const std::string& name,
         operand_strategy_group, operand_shape, input_specs[i], cluster_env_));
   }

-  strategy_group_->AddStrategy(ShardingStrategy({
-      name,
-      output_spec,
-      compute_cost,
-      communication_cost,
-      static_cast<double>(
-          ByteSizeOfShapeWithSharding(ins_->shape(), output_spec)),
-      communication_resharding_costs,
-      memory_resharding_costs,
-      {input_specs.begin(), input_specs.end()},
-  }));
+  strategy_group_->AddStrategy(
+      ShardingStrategy({name, output_spec, compute_cost, communication_cost,
+                        static_cast<double>(ByteSizeOfShapeWithSharding(
+                            ins_->shape(), output_spec)),
+                        communication_resharding_costs,
+                        memory_resharding_costs}),
+      {input_specs.begin(), input_specs.end()});
 }

 // Given lhs and rhs dim maps, infers a sharding for the output by relying
@@ -462,18 +459,25 @@ std::optional<HloSharding> HandlerBase::GetShardingFromUser(
 }

 void HandlerBase::SortStrategies() {
-  auto strategies = strategy_group_->GetStrategies();
+  std::vector<std::pair<ShardingStrategy, InputShardings>> strategy_shardings;
+  for (size_t sid = 0; sid < strategy_group_->GetStrategies().size(); ++sid) {
+    const ShardingStrategy& strategy = strategy_group_->GetStrategy(sid);
+    const auto& input_shardings = strategy_group_->GetInputShardings(sid);
+    strategy_shardings.push_back({strategy, input_shardings});
+  }
   absl::c_stable_sort(
-      strategies, [](const ShardingStrategy& s1, const ShardingStrategy& s2) {
-        if (s1.memory_cost == s2.memory_cost) {
-          return s1.name < s2.name;
+      strategy_shardings,
+      [](const std::pair<ShardingStrategy, InputShardings>& s1,
+         const std::pair<ShardingStrategy, InputShardings>& s2) {
+        if (s1.first.memory_cost == s2.first.memory_cost) {
+          return s1.first.name < s2.first.name;
         } else {
-          return s1.memory_cost < s2.memory_cost;
+          return s1.first.memory_cost < s2.first.memory_cost;
         }
       });
   strategy_group_->ClearStrategies();
-  for (const ShardingStrategy& strategy : strategies) {
-    strategy_group_->AddStrategy(strategy);
+  for (const auto& [strategy, input_shardings] : strategy_shardings) {
+    strategy_group_->AddStrategy(strategy, input_shardings);
   }
 }
diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc b/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc
index 3f1a9727cf210..121917b3cc0b4 100644
--- a/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc
+++ b/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc
@@ -336,17 +336,19 @@ BuildStrategyAndCost(
           double memory_cost =
               ByteSizeOfShapeWithSharding(ins->shape(), scatter_sharding);

-          std::vector<std::optional<HloSharding>> input_shardings_optional(
+          InputShardings input_shardings_optional(
               {data_sharding, indices_sharding, update_sharding});
           std::pair<ReshardingCosts, ReshardingCosts> resharding_costs =
               GenerateReshardingCostsAndMissingShardingsForAllOperands(
                   ins, scatter_sharding, strategy_map, cluster_env, call_graph,
                   input_shardings_optional);

-          strategy_group->AddStrategy(ShardingStrategy(
-              {name, scatter_sharding, compute_cost, communication_cost,
-               memory_cost, std::move(resharding_costs.first),
-               std::move(resharding_costs.second), input_shardings_optional}));
+          strategy_group->AddStrategy(
+              ShardingStrategy({name, scatter_sharding, compute_cost,
+                                communication_cost, memory_cost,
+                                std::move(resharding_costs.first),
+                                std::move(resharding_costs.second)}),
+              input_shardings_optional);
         };

         const HloScatterInstruction* scatter = Cast<HloScatterInstruction>(ins);
@@ -388,18 +390,20 @@ BuildStrategyAndCost(
           double compute_cost = 0, communication_cost = 0;
           double memory_cost =
               ByteSizeOfShapeWithSharding(gather_shape, output_sharding);
-          std::vector<std::optional<HloSharding>> input_shardings_optional(
+          InputShardings input_shardings_optional(
               {data_sharding, indices_sharding});
           std::pair<ReshardingCosts, ReshardingCosts> resharding_costs =
               GenerateReshardingCostsAndMissingShardingsForAllOperands(
                   ins, output_sharding, strategy_map, cluster_env, call_graph,
                   input_shardings_optional);

-          strategy_group->AddStrategy(ShardingStrategy(
-              {std::string(output_sharding.ToString()), output_sharding,
-               compute_cost, communication_cost, memory_cost,
-               std::move(resharding_costs.first),
-               std::move(resharding_costs.second), input_shardings_optional}));
+          strategy_group->AddStrategy(
+              ShardingStrategy({std::string(output_sharding.ToString()),
+                                output_sharding, compute_cost,
+                                communication_cost, memory_cost,
+                                std::move(resharding_costs.first),
+                                std::move(resharding_costs.second)}),
+              input_shardings_optional);
         };

         for (const ShardingStrategy& indices_strategy :
@@ -562,8 +566,8 @@ BuildStrategyAndCost(
                                   communication_cost,
                                   memory_cost,
                                   {communication_resharding_costs},
-                                  {memory_resharding_costs},
-                                  {input_spec}}));
+                                  {memory_resharding_costs}}),
+                {input_spec});
           }
           break;
         }
@@ -671,7 +675,7 @@ BuildStrategyAndCost(
         }

         // Get a list of input shardings, each corresponds to an operand.
-        std::vector<std::optional<HloSharding>> input_shardings;
+        InputShardings input_shardings;
         for (int64_t k = 0; k < ins->operand_count(); ++k) {
           if (k == follow_idx ||
               ToString(ins->operand(k)->shape().dimensions()) ==
@@ -694,14 +698,11 @@ BuildStrategyAndCost(
                 input_shardings);

         strategy_group->AddStrategy(
-            ShardingStrategy({name,
-                              *output_spec,
-                              compute_cost,
-                              communication_cost,
-                              memory_cost,
+            ShardingStrategy({name, *output_spec, compute_cost,
+                              communication_cost, memory_cost,
                               std::move(resharding_costs.first),
-                              std::move(resharding_costs.second),
-                              {input_spec}}));
+                              std::move(resharding_costs.second)}),
+            {input_spec});
       }

       if (strategy_group->GetStrategies().empty()) {
@@ -1082,16 +1083,19 @@ BuildStrategyAndCost(
       auto it = absl::c_find(inst_indices, strategy_group->node_idx);
       if (it != inst_indices.end()) {
         CHECK(!strategy_group->is_tuple);
-        std::vector<ShardingStrategy> new_strategies;
+        std::vector<std::pair<ShardingStrategy, InputShardings>> new_strategies;
         int64_t idx = it - inst_indices.begin();
-        for (const auto& strategy : strategy_group->GetStrategies()) {
+        const auto& strategies = strategy_group->GetStrategies();
+        for (size_t sid = 0; sid < strategies.size(); ++sid) {
+          const ShardingStrategy& strategy = strategy_group->GetStrategy(sid);
+          const auto& input_shardings = strategy_group->GetInputShardings(sid);
           if (strategy.name == stra_names[idx]) {
-            new_strategies.push_back(strategy);
+            new_strategies.push_back({strategy, input_shardings});
           }
         }
         strategy_group->ClearStrategies();
-        for (const ShardingStrategy& strategy : new_strategies) {
-          strategy_group->AddStrategy(strategy);
+        for (const auto& [strategy, input_shardings] : new_strategies) {
+          strategy_group->AddStrategy(strategy, input_shardings);
         }
       }
     }
diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h b/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h
index 3a17e6f3d46b9..8327f6b3e58ef 100644
--- a/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h
+++ b/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h
@@ -75,6 +75,8 @@ using ReshardingCache =
     ConstInstructionMap<std::vector<std::pair<HloSharding, HloInstruction*>>>;
 // Resharding costs for each operand
 using ReshardingCosts = std::vector<std::vector<double>>;
+// Optional shardings for each operand
+using InputShardings = std::vector<std::optional<HloSharding>>;

 // One sharding strategy
 struct ShardingStrategy {
@@ -89,9 +91,6 @@ struct ShardingStrategy {
   // cost from i-th tuple element's j-th strategy.
   ReshardingCosts communication_resharding_costs;
   ReshardingCosts memory_resharding_costs;
-  // Optional: the required shardings of operands.
-  // This is used to guide the SPMD partitioner.
-  std::vector<std::optional<HloSharding>> input_shardings;

   std::string ToString() const {
     return absl::StrCat(name, ", ", output_sharding.ToString());
@@ -117,32 +116,12 @@ struct ShardingStrategy {
     std::string memory_resharding_cost_str = absl::StrCat(
         "{", absl::StrJoin(memory_resharding_vector_strings, ", "), "}");

-    std::string input_sharding_str = "{";
-    for (const auto& s : input_shardings) {
-      if (!s.has_value()) {
-        input_sharding_str += "[*],";
-      } else if (s->IsReplicated()) {
-        input_sharding_str += "[R],";
-      } else {
-        if (s->ReplicateOnLastTileDim()) {
-          input_sharding_str +=
-              "[" + absl::StrJoin(s->tile_assignment().dimensions(), ", ") +
-              "]last_tile_dim_replicate,";
-        } else {
-          input_sharding_str +=
-              "[" + absl::StrJoin(s->tile_assignment().dimensions(), ", ") +
-              "],";
-        }
-      }
-    }
-    input_sharding_str += "}\n";
     return absl::StrCat(
         name, ", ", output_sharding.ToString(), ", compute_cost=", compute_cost,
         ", communication_cost=", communication_cost,
         ", memory_cost=", memory_cost,
         ", communication_resharding_costs=", communication_resharding_cost_str,
-        ", memory_resharding_costs=", memory_resharding_cost_str,
-        ", input_shardings=", input_sharding_str);
+        ", memory_resharding_costs=", memory_resharding_cost_str);
   }

   bool operator==(const ShardingStrategy& other) const {
@@ -152,8 +131,7 @@ struct ShardingStrategy {
            memory_cost == other.memory_cost &&
            communication_resharding_costs ==
                other.communication_resharding_costs &&
-           memory_resharding_costs == other.memory_resharding_costs &&
-           input_shardings == other.input_shardings;
+           memory_resharding_costs == other.memory_resharding_costs;
   }
 };

@@ -273,20 +251,33 @@ struct StrategyGroup {

   //////// Accessor methods for strategies ////////

-  void AddStrategy(const ShardingStrategy& strategy) {
+  void AddStrategy(const ShardingStrategy& strategy,
+                   const InputShardings& input_shardings = {}) {
     strategies.push_back(strategy);
+    strategy_input_shardings.push_back(input_shardings);
   }

-  void ClearStrategies() { strategies.clear(); }
+  void ClearStrategies() {
+    strategies.clear();
+    strategy_input_shardings.clear();
+  }

   ShardingStrategy& GetStrategy(size_t strategy_idx) {
     return strategies[strategy_idx];
   }

+  const InputShardings& GetInputShardings(size_t strategy_idx) const {
+    return strategy_input_shardings[strategy_idx];
+  }
+
   const std::vector<ShardingStrategy>& GetStrategies() const {
     return strategies;
   }

+  const std::vector<InputShardings>& GetStrategyInputShardings() const {
+    return strategy_input_shardings;
+  }
+
   //////// Accessor methods for children ////////

   void AddChild(std::unique_ptr<StrategyGroup> child) {
@@ -305,6 +296,8 @@ struct StrategyGroup {
   // Used when is_tuple == False. Leaf strategy vector.
   // A vector of strategy choices for the non-tuple output.
   std::vector<ShardingStrategy> strategies;
+  std::vector<InputShardings> strategy_input_shardings;
+
   // Used when is_tuple == True. A vector of pointers, each pointer is one
   // StrategyGroup for one value in the output Tuple
   std::vector<std::unique_ptr<StrategyGroup>> children;
diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc b/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc
index d12de663a115d..adcc185f969dd 100644
--- a/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc
+++ b/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc
@@ -836,18 +836,21 @@ void RemoveDuplicatedStrategy(StrategyGroup& strategy_group) {
   if (strategy_group.following || strategy_group.GetStrategies().empty()) {
     return;
   }
-  std::vector<ShardingStrategy> new_vector;
-  std::vector<ShardingStrategy> deduped_replicated_strategies;
+  std::vector<std::pair<ShardingStrategy, InputShardings>> new_vector;
+  std::vector<std::pair<ShardingStrategy, InputShardings>>
+      deduped_replicated_strategies;
   absl::flat_hash_set<std::string> added;
   size_t num_skipped_due_to_infinity_costs = 0;
-  for (const ShardingStrategy& strategy : strategy_group.GetStrategies()) {
+  for (size_t sid = 0; sid < strategy_group.GetStrategies().size(); ++sid) {
+    const ShardingStrategy& strategy = strategy_group.GetStrategy(sid);
     if (AllInfinityCosts(strategy.communication_resharding_costs)) {
       num_skipped_due_to_infinity_costs++;
       continue;
     }
     std::string key = strategy.output_sharding.ToString();
-    if (!strategy.input_shardings.empty()) {
-      for (const auto& sharding : strategy.input_shardings) {
+    const auto& input_shardings = strategy_group.GetInputShardings(sid);
+    if (!input_shardings.empty()) {
+      for (const auto& sharding : input_shardings) {
         key += "/" + (sharding.has_value() ? sharding->ToString() : "none");
       }
     }
@@ -856,9 +859,9 @@ void RemoveDuplicatedStrategy(StrategyGroup& strategy_group) {
     }
     added.insert(key);
     if (!strategy.output_sharding.IsReplicated()) {
-      new_vector.push_back(strategy);
+      new_vector.push_back({strategy, input_shardings});
     } else {
-      deduped_replicated_strategies.push_back(strategy);
+      deduped_replicated_strategies.push_back({strategy, input_shardings});
     }
   }
   CHECK_LT(num_skipped_due_to_infinity_costs,
@@ -871,8 +874,8 @@ void RemoveDuplicatedStrategy(StrategyGroup& strategy_group) {
     }
   }
   strategy_group.ClearStrategies();
-  for (const ShardingStrategy& strategy : new_vector) {
-    strategy_group.AddStrategy(strategy);
+  for (const auto& [strategy, input_shardings] : new_vector) {
+    strategy_group.AddStrategy(strategy, input_shardings);
   }
 }
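The patch above moves the per-strategy operand shardings out of ShardingStrategy and into a second vector owned by StrategyGroup, kept in lockstep with the strategy vector through AddStrategy/ClearStrategies and read back by index via GetInputShardings. The stand-alone sketch below illustrates that parallel-vector pattern under simplified assumptions: Sharding, Strategy, and Group are hypothetical stand-ins, not the real XLA classes (those live in auto_sharding_strategy.h), and the filter loop mirrors the rewritten FilterStrategy / RemoveDuplicatedStrategy loops only in shape.

// Minimal sketch of the strategies / strategy_input_shardings pairing.
// All types here are simplified stand-ins, not the real XLA classes.
#include <cassert>
#include <optional>
#include <string>
#include <utility>
#include <vector>

struct Sharding { std::string spec; };                      // stand-in for HloSharding
struct Strategy { std::string name; double memory_cost; };  // stand-in for ShardingStrategy
using InputShardings = std::vector<std::optional<Sharding>>;

class Group {  // stand-in for StrategyGroup
 public:
  // The two vectors are only ever grown or cleared together, so index i in
  // strategies_ always matches index i in input_shardings_.
  void AddStrategy(const Strategy& s, const InputShardings& in = {}) {
    strategies_.push_back(s);
    input_shardings_.push_back(in);
  }
  void ClearStrategies() {
    strategies_.clear();
    input_shardings_.clear();
  }
  const Strategy& GetStrategy(size_t i) const { return strategies_[i]; }
  const InputShardings& GetInputShardings(size_t i) const {
    return input_shardings_[i];
  }
  size_t NumStrategies() const { return strategies_.size(); }

 private:
  std::vector<Strategy> strategies_;
  std::vector<InputShardings> input_shardings_;
};

int main() {
  Group g;
  g.AddStrategy({"R", 1.0}, {Sharding{"replicated"}, std::nullopt});
  g.AddStrategy({"S0", 0.5});  // no operand shardings recorded for this one
  assert(g.GetInputShardings(0).size() == 2);
  assert(g.GetInputShardings(1).empty());

  // A filter-and-rebuild pass now carries the (strategy, shardings) pair
  // through ClearStrategies/AddStrategy instead of a single struct.
  std::vector<std::pair<Strategy, InputShardings>> kept;
  for (size_t i = 0; i < g.NumStrategies(); ++i) {
    if (g.GetStrategy(i).memory_cost < 0.9) {
      kept.push_back({g.GetStrategy(i), g.GetInputShardings(i)});
    }
  }
  g.ClearStrategies();
  for (const auto& [strategy, input_shardings] : kept) {
    g.AddStrategy(strategy, input_shardings);
  }
  assert(g.NumStrategies() == 1);
  return 0;
}

Keeping the operand shardings in a side vector lets ShardingStrategy stay a smaller value type (cheaper to copy, compare, and print) while every rebuild site remains responsible for carrying the pair, which is exactly what the rewritten loops in the patch do.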