From f4a8c3670d8654d682d70415224e5d2fc21f393f Mon Sep 17 00:00:00 2001 From: xla authors Date: Mon, 16 Sep 2024 13:33:39 -0700 Subject: [PATCH] Adds accessor methods to StrategyGroup (so that clients can't directly manipulate the vectors containing sharding strategies and child groups). PiperOrigin-RevId: 675271442 --- .../auto_sharding/auto_sharding.cc | 805 ++++++++---------- .../auto_sharding/auto_sharding.h | 75 +- .../auto_sharding/auto_sharding_cost_graph.cc | 32 +- .../auto_sharding/auto_sharding_cost_graph.h | 8 +- .../auto_sharding_dot_handler.cc | 24 +- .../auto_sharding/auto_sharding_strategy.cc | 187 ++-- .../auto_sharding/auto_sharding_strategy.h | 81 +- .../auto_sharding/auto_sharding_util.cc | 105 +-- 8 files changed, 658 insertions(+), 659 deletions(-) diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding.cc b/xla/hlo/experimental/auto_sharding/auto_sharding.cc index 3a39630c4cad0..f486e7d94baba 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding.cc +++ b/xla/hlo/experimental/auto_sharding/auto_sharding.cc @@ -100,16 +100,16 @@ constexpr double kSaltiplier = 0.0; // This value (0.0) disables salting. // Compute the resharding cost vector from multiple possible strategies to a // desired sharding spec. std::vector CommunicationReshardingCostVector( - const StrategyGroup* strategy_group, const Shape& operand_shape, + const StrategyGroup& strategy_group, const Shape& operand_shape, const HloSharding& required_sharding, const ClusterEnvironment& cluster_env) { - CHECK(!strategy_group->is_tuple) << "Only works with strategy vector."; + CHECK(!strategy_group.is_tuple) << "Only works with strategy vector."; std::vector ret; - ret.reserve(strategy_group->strategies.size()); + ret.reserve(strategy_group.GetStrategies().size()); auto required_sharding_for_resharding = required_sharding.IsTileMaximal() ? HloSharding::Replicate() : required_sharding; - for (const auto& x : strategy_group->strategies) { + for (const ShardingStrategy& x : strategy_group.GetStrategies()) { ret.push_back(cluster_env.ReshardingCost(operand_shape, x.output_sharding, required_sharding_for_resharding)); } @@ -151,18 +151,18 @@ double ComputeMemoryReshardingCost(const Shape& shape, } std::vector MemoryReshardingCostVector( - const StrategyGroup* strategy_group, const Shape& operand_shape, + const StrategyGroup& strategy_group, const Shape& operand_shape, const HloSharding& required_sharding, const ClusterEnvironment& cluster_env) { - CHECK(!strategy_group->is_tuple) << "Only works with strategy vector."; + CHECK(!strategy_group.is_tuple) << "Only works with strategy vector."; std::vector ret; - ret.reserve(strategy_group->strategies.size()); + ret.reserve(strategy_group.GetStrategies().size()); auto required_sharding_for_resharding = required_sharding.IsTileMaximal() ? HloSharding::Replicate() : required_sharding; CHECK_OK(required_sharding.Validate(operand_shape)) - << strategy_group->ToString(); - for (const auto& x : strategy_group->strategies) { + << strategy_group.ToString(); + for (const ShardingStrategy& x : strategy_group.GetStrategies()) { ret.push_back(ComputeMemoryReshardingCost(operand_shape, x.output_sharding, required_sharding_for_resharding, cluster_env.device_mesh_)); @@ -217,12 +217,14 @@ GenerateReshardingCostsAndMissingShardingsForAllOperands( input_shardings.resize(ins->operand_count()); } for (int64_t k = 0; k < ins->operand_count(); ++k) { - auto operand = ins->operand(k); - if (operand->shape().IsToken() || operand->shape().rank() == 0) { - communication_resharding_costs.push_back(std::vector( - strategy_map.at(operand)->strategies.size(), 0.0)); - memory_resharding_costs.push_back(std::vector( - strategy_map.at(operand)->strategies.size(), 0.0)); + const HloInstruction* operand = ins->operand(k); + const Shape& operand_shape = operand->shape(); + const StrategyGroup& operand_strategy_group = *strategy_map.at(operand); + const auto& operand_strategies = operand_strategy_group.GetStrategies(); + const std::vector zeros(operand_strategies.size(), 0.0); + if (operand_shape.IsToken() || operand_shape.rank() == 0) { + communication_resharding_costs.push_back(zeros); + memory_resharding_costs.push_back(zeros); if (!input_shardings[k].has_value()) { input_shardings[k] = HloSharding::Replicate(); } @@ -252,25 +254,21 @@ GenerateReshardingCostsAndMissingShardingsForAllOperands( if (!input_shardings[k].has_value()) { input_shardings[k] = cur_input_sharding; } - auto operand_strategies = strategy_map.at(operand).get(); - auto operand_shape = operand->shape(); if (ins->opcode() == HloOpcode::kGather && k == 0 && is_sharding_default_replicated) { VLOG(2) << "Zeroing out operand 0 resharding costs for gather sharding " << output_sharding.ToString(); - communication_resharding_costs.push_back( - std::vector(operand_strategies->strategies.size(), 0)); - memory_resharding_costs.push_back( - std::vector(operand_strategies->strategies.size(), 0)); + communication_resharding_costs.push_back(zeros); + memory_resharding_costs.push_back(zeros); input_shardings[k] = std::nullopt; } else { communication_resharding_costs.push_back( CommunicationReshardingCostVector( - operand_strategies, ins->operand(k)->shape(), - *cur_input_sharding, cluster_env)); - memory_resharding_costs.push_back(MemoryReshardingCostVector( - operand_strategies, ins->operand(k)->shape(), *cur_input_sharding, - cluster_env)); + operand_strategy_group, operand_shape, *cur_input_sharding, + cluster_env)); + memory_resharding_costs.push_back( + MemoryReshardingCostVector(operand_strategy_group, operand_shape, + *cur_input_sharding, cluster_env)); } } } @@ -319,18 +317,16 @@ void FollowArrayOrTokenStrategyGroup( strategy_group.following = &src_strategy_group; } - strategy_group.strategies.reserve(src_strategy_group.strategies.size()); + const auto& src_strategies = src_strategy_group.GetStrategies(); // Creates the sharding strategies and restores trimmed strategies, if any. - for (int64_t sid = 0; sid < src_strategy_group.strategies.size() + - pretrimmed_strategies.size(); - ++sid) { + for (int64_t sid = 0; + sid < src_strategies.size() + pretrimmed_strategies.size(); ++sid) { const HloSharding* output_spec; - if (sid < src_strategy_group.strategies.size()) { - output_spec = &src_strategy_group.strategies[sid].output_sharding; + if (sid < src_strategies.size()) { + output_spec = &src_strategies[sid].output_sharding; } else { output_spec = - &pretrimmed_strategies[sid - src_strategy_group.strategies.size()] - .output_sharding; + &pretrimmed_strategies[sid - src_strategies.size()].output_sharding; VLOG(1) << "Adding outspec from the trimmed strategy map: " << output_spec->ToString(); } @@ -344,13 +340,13 @@ void FollowArrayOrTokenStrategyGroup( ReshardingCosts memory_resharding_costs; for (size_t i = 0; i < strategy_group.in_nodes.size(); ++i) { communication_resharding_costs.push_back( - CommunicationReshardingCostVector(strategy_group.in_nodes[i], shape, + CommunicationReshardingCostVector(*strategy_group.in_nodes[i], shape, *output_spec, cluster_env)); memory_resharding_costs.push_back(MemoryReshardingCostVector( - strategy_group.in_nodes[i], shape, *output_spec, cluster_env)); + *strategy_group.in_nodes[i], shape, *output_spec, cluster_env)); } - strategy_group.strategies.push_back( + strategy_group.AddStrategy( ShardingStrategy({name, *output_spec, compute_cost, communication_cost, memory_cost, communication_resharding_costs, memory_resharding_costs, input_shardings})); @@ -375,9 +371,8 @@ std::unique_ptr HandlePartialReduce( CreateLeafStrategyGroupWithoutInNodes(instruction_id, strategy_groups); child_strategy_group->in_nodes.push_back(src_strategy_group); child_strategy_group->following = src_strategy_group; - for (int64_t sid = 0; sid < src_strategy_group->strategies.size(); ++sid) { - const HloSharding& input_spec = - src_strategy_group->strategies[sid].output_sharding; + for (const auto& src_strategy : src_strategy_group->GetStrategies()) { + const HloSharding& input_spec = src_strategy.output_sharding; // There is no way for us to handle manual sharding. if (input_spec.IsManual() || input_spec.IsManualSubgroup()) { continue; @@ -410,42 +405,41 @@ std::unique_ptr HandlePartialReduce( ins, output_spec, strategy_map, cluster_env, call_graph, input_shardings); - child_strategy_group->strategies.push_back(ShardingStrategy( + child_strategy_group->AddStrategy(ShardingStrategy( {std::move(name), std::move(output_spec), compute_cost, communication_cost, memory_cost, std::move(resharding_costs.first), std::move(resharding_costs.second), std::move(input_shardings)})); } - strategy_group->childs.push_back(std::move(child_strategy_group)); + strategy_group->AddChild(std::move(child_strategy_group)); } return strategy_group; } std::unique_ptr MaybeFollowInsStrategyGroup( - const StrategyGroup* src_strategy_group, const Shape& shape, + const StrategyGroup& src_strategy_group, const Shape& shape, const size_t instruction_id, StrategyGroups& strategy_groups, const ClusterEnvironment& cluster_env, const StableMap>& pretrimmed_strategy_map) { + const auto& children = src_strategy_group.GetChildren(); std::unique_ptr strategy_group; - if (src_strategy_group->is_tuple) { + if (src_strategy_group.is_tuple) { CHECK(shape.IsTuple()); - CHECK_EQ(shape.tuple_shapes_size(), src_strategy_group->childs.size()); + CHECK_EQ(shape.tuple_shapes_size(), children.size()); strategy_group = CreateTupleStrategyGroup(instruction_id); - strategy_group->childs.reserve(src_strategy_group->childs.size()); - for (size_t i = 0; i < src_strategy_group->childs.size(); ++i) { + for (size_t i = 0; i < children.size(); ++i) { auto child_strategies = MaybeFollowInsStrategyGroup( - src_strategy_group->childs[i].get(), shape.tuple_shapes(i), - instruction_id, strategy_groups, cluster_env, - pretrimmed_strategy_map); + *children[i], shape.tuple_shapes(i), instruction_id, strategy_groups, + cluster_env, pretrimmed_strategy_map); child_strategies->tuple_element_idx = i; - strategy_group->childs.push_back(std::move(child_strategies)); + strategy_group->AddChild(std::move(child_strategies)); } } else { strategy_group = CreateLeafStrategyGroupWithoutInNodes(instruction_id, strategy_groups); - strategy_group->in_nodes.push_back(src_strategy_group); - FollowArrayOrTokenStrategyGroup(*src_strategy_group, shape, instruction_id, + strategy_group->in_nodes.push_back(&src_strategy_group); + FollowArrayOrTokenStrategyGroup(src_strategy_group, shape, instruction_id, cluster_env, pretrimmed_strategy_map, *strategy_group); } @@ -461,7 +455,6 @@ absl::StatusOr> FollowReduceStrategy( std::unique_ptr strategy_group; if (output_shape.IsTuple()) { strategy_group = CreateTupleStrategyGroup(instruction_id); - strategy_group->childs.reserve(ins->shape().tuple_shapes_size()); for (size_t i = 0; i < ins->shape().tuple_shapes_size(); ++i) { TF_ASSIGN_OR_RETURN( std::unique_ptr child_strategy, @@ -471,7 +464,7 @@ absl::StatusOr> FollowReduceStrategy( instruction_id, strategy_map, strategy_groups, cluster_env, allow_mixed_mesh_shape, crash_at_error)); child_strategy->tuple_element_idx = i; - strategy_group->childs.push_back(std::move(child_strategy)); + strategy_group->AddChild(std::move(child_strategy)); } } else if (output_shape.IsArray()) { strategy_group = CreateLeafStrategyGroup(instruction_id, ins, strategy_map, @@ -479,7 +472,6 @@ absl::StatusOr> FollowReduceStrategy( const StrategyGroup* src_strategy_group = strategy_map.at(operand).get(); // Follows the strategy of the operand. strategy_group->following = src_strategy_group; - strategy_group->strategies.reserve(src_strategy_group->strategies.size()); // Map operand dims to inst dim // Example: f32[1,16]{1,0} reduce(f32[1,16,4096]{2,1,0} %param0, // f32[] %param1), dimensions={2} @@ -491,9 +483,8 @@ absl::StatusOr> FollowReduceStrategy( operand->shape().rank()) << "Invalid kReduce: output size + reduced dimensions size != op count"; - for (size_t sid = 0; sid < src_strategy_group->strategies.size(); ++sid) { - HloSharding input_sharding = - src_strategy_group->strategies[sid].output_sharding; + for (const auto& src_strategy : src_strategy_group->GetStrategies()) { + const HloSharding& input_sharding = src_strategy.output_sharding; const auto& tensor_dim_to_mesh = cluster_env.GetTensorDimToMeshDimWrapper( operand->shape(), input_sharding, /* consider_reverse_device_meshes */ true, @@ -521,8 +512,7 @@ absl::StatusOr> FollowReduceStrategy( std::unique_ptr new_reduce = HloInstruction::CreateReduce( output_shape, operand_clone.get(), unit_clone.get(), ins->dimensions(), ins->to_apply()); - operand_clone->set_sharding( - src_strategy_group->strategies[sid].output_sharding); + operand_clone->set_sharding(src_strategy.output_sharding); if (!new_reduce->ReplaceOperandWith(0, operand_clone.get()).ok()) { continue; } @@ -544,22 +534,21 @@ absl::StatusOr> FollowReduceStrategy( ReshardingCosts memory_resharding_costs; for (int64_t k = 0; k < ins->operand_count(); ++k) { const HloInstruction* cur_operand = ins->operand(k); + const auto& operand_strategy_group = *strategy_map.at(cur_operand); + const auto& operand_strategies = operand_strategy_group.GetStrategies(); if (ToString(cur_operand->shape().dimensions()) == ToString(operand->shape().dimensions())) { - const StrategyGroup* operand_strategies = - strategy_map.at(cur_operand).get(); communication_resharding_costs.push_back( - CommunicationReshardingCostVector(operand_strategies, + CommunicationReshardingCostVector(operand_strategy_group, cur_operand->shape(), input_sharding, cluster_env)); memory_resharding_costs.push_back(MemoryReshardingCostVector( - operand_strategies, cur_operand->shape(), input_sharding, + operand_strategy_group, cur_operand->shape(), input_sharding, cluster_env)); } else { - communication_resharding_costs.push_back(std::vector( - strategy_map.at(cur_operand)->strategies.size(), 0.0)); - memory_resharding_costs.push_back(std::vector( - strategy_map.at(cur_operand)->strategies.size(), 0.0)); + const std::vector zeros(operand_strategies.size(), 0); + communication_resharding_costs.push_back(zeros); + memory_resharding_costs.push_back(zeros); } } const ShardingStrategy strategy = @@ -571,7 +560,7 @@ absl::StatusOr> FollowReduceStrategy( communication_resharding_costs, memory_resharding_costs, {input_sharding}}); - strategy_group->strategies.push_back(strategy); + strategy_group->AddStrategy(strategy); } } else { LOG(FATAL) << "Unhandled kReduce shape: " << ins->shape().ToString(); @@ -593,7 +582,7 @@ std::vector FindReplicateStrategyIndices( std::tuple>> ReshardingCostsForTupleOperand(const HloInstruction* operand, - StrategyGroup* operand_strategy_vector) { + const StrategyGroup& operand_strategy_vector) { // TODO(yuemmawang) Support instructions with more than one tuple operand. // Creates resharding costs such that favors when operand strategies are // replicated. @@ -603,18 +592,20 @@ ReshardingCostsForTupleOperand(const HloInstruction* operand, for (size_t tuple_element_idx = 0; tuple_element_idx < operand->shape().tuple_shapes_size(); tuple_element_idx++) { - auto tuple_element_strategies = - operand_strategy_vector->childs.at(tuple_element_idx).get(); + const StrategyGroup& tuple_element_strategy_group = + *operand_strategy_vector.GetChildren()[tuple_element_idx]; + const auto& tuple_element_strategies = + tuple_element_strategy_group.GetStrategies(); std::vector indices = - FindReplicateStrategyIndices(tuple_element_strategies->strategies); + FindReplicateStrategyIndices(tuple_element_strategies); CHECK_GT(indices.size(), 0) << "There is no replicated strategy in instruction " << operand->ToString() << ".\nStrategies:\n" - << tuple_element_strategies->ToString(); + << tuple_element_strategy_group.ToString(); memory_resharding_costs.push_back( - std::vector(tuple_element_strategies->strategies.size(), 0)); - communication_resharding_costs.push_back(std::vector( - tuple_element_strategies->strategies.size(), kInfinityCost)); + std::vector(tuple_element_strategies.size(), 0)); + communication_resharding_costs.push_back( + std::vector(tuple_element_strategies.size(), kInfinityCost)); tuple_element_shardings.push_back(HloSharding::Replicate()); for (const size_t i : indices) { communication_resharding_costs.back().at(i) = 0.0; @@ -629,8 +620,8 @@ ReshardingCosts CreateZeroReshardingCostsForAllOperands( const HloInstruction* ins, const StrategyMap& strategy_map) { ReshardingCosts resharding_costs; for (size_t i = 0; i < ins->operand_count(); ++i) { - auto operand = ins->operand(i); - const auto& operand_strategies = strategy_map.at(operand); + const HloInstruction* operand = ins->operand(i); + const StrategyGroup& operand_strategy_group = *strategy_map.at(operand); if (operand->shape().IsTuple()) { if (ins->opcode() == HloOpcode::kConditional || ins->opcode() == HloOpcode::kOutfeed) { @@ -642,15 +633,15 @@ ReshardingCosts CreateZeroReshardingCostsForAllOperands( for (size_t tuple_element_idx = 0; tuple_element_idx < operand->shape().tuple_shapes_size(); tuple_element_idx++) { - auto tuple_element_strategies = - operand_strategies->childs.at(tuple_element_idx).get(); + const StrategyGroup& tuple_element_strategy_group = + *operand_strategy_group.GetChildren().at(tuple_element_idx); resharding_costs.push_back(std::vector( - tuple_element_strategies->strategies.size(), 0)); + tuple_element_strategy_group.GetStrategies().size(), 0)); } } } else { - resharding_costs.push_back( - std::vector(operand_strategies->strategies.size(), 0)); + const auto& strategies = operand_strategy_group.GetStrategies(); + resharding_costs.push_back(std::vector(strategies.size(), 0)); } } return resharding_costs; @@ -659,14 +650,16 @@ ReshardingCosts CreateZeroReshardingCostsForAllOperands( void GenerateOutfeedStrategy(const HloInstruction* ins, const Shape& shape, const ClusterEnvironment& cluster_env, const StrategyMap& strategy_map, - std::unique_ptr& strategy_group, - const double replicated_penalty) { + const double replicated_penalty, + StrategyGroup& strategy_group) { HloSharding output_spec = HloSharding::Replicate(); ReshardingCosts communication_resharding_costs; ReshardingCosts memory_resharding_costs; std::vector> input_shardings; const int tuple_size = ins->operand(0)->shape().tuple_shapes_size(); + const auto& operand_strategy_group = strategy_map.at(ins->operand(0)); + const auto& operand_children = operand_strategy_group->GetChildren(); if (ins->has_sharding()) { std::vector operand_shapes(ins->operand_count()); for (int i = 0; i < ins->operand_count(); ++i) { @@ -686,34 +679,30 @@ void GenerateOutfeedStrategy(const HloInstruction* ins, const Shape& shape, }; for (size_t i = 0; i < tuple_size; ++i) { - auto input_sharding = get_input_sharding(i); + const StrategyGroup& child = *operand_children[i]; + const Shape& tuple_shape = ins->operand(0)->shape().tuple_shapes(i); + const HloSharding& input_sharding = get_input_sharding(i); input_shardings.push_back(input_sharding); communication_resharding_costs.push_back( - CommunicationReshardingCostVector( - strategy_map.at(ins->operand(0))->childs[i].get(), - ins->operand(0)->shape().tuple_shapes(i), input_sharding, - cluster_env)); + CommunicationReshardingCostVector(child, tuple_shape, input_sharding, + cluster_env)); memory_resharding_costs.push_back(MemoryReshardingCostVector( - strategy_map.at(ins->operand(0))->childs[i].get(), - ins->operand(0)->shape().tuple_shapes(i), input_sharding, - cluster_env)); + child, tuple_shape, input_sharding, cluster_env)); } - auto input_sharding = get_input_sharding(-1); + const HloSharding& input_sharding = get_input_sharding(-1); input_shardings.push_back(input_sharding); } else { for (size_t i = 0; i < tuple_size; ++i) { - communication_resharding_costs.push_back(std::vector( - strategy_map.at(ins->operand(0))->childs[i].get()->strategies.size(), - 0)); - memory_resharding_costs.push_back(std::vector( - strategy_map.at(ins->operand(0))->childs[i].get()->strategies.size(), - 0)); + const StrategyGroup& child = *operand_children[i]; + const std::vector zeros(child.GetStrategies().size(), 0); + communication_resharding_costs.push_back(zeros); + memory_resharding_costs.push_back(zeros); } } communication_resharding_costs.push_back({}); memory_resharding_costs.push_back({}); double memory_cost = ByteSizeOfShapeWithSharding(shape, output_spec); - strategy_group->strategies.push_back( + strategy_group.AddStrategy( ShardingStrategy({"R", HloSharding::Replicate(), replicated_penalty, 0, memory_cost, std::move(communication_resharding_costs), std::move(memory_resharding_costs), input_shardings})); @@ -761,9 +750,9 @@ double ComputeCommunicationCost( void AddReplicatedStrategy( const HloInstruction* ins, const Shape& shape, const ClusterEnvironment& cluster_env, const StrategyMap& strategy_map, - std::unique_ptr& strategy_group, const double replicated_penalty, - absl::flat_hash_set operands_to_consider_all_strategies_for) { + absl::flat_hash_set operands_to_consider_all_strategies_for, + StrategyGroup& strategy_group) { HloSharding replicated_strategy = HloSharding::Replicate(); HloSharding output_spec = replicated_strategy; double memory_cost = ByteSizeOfShapeWithSharding(shape, output_spec); @@ -774,47 +763,45 @@ void AddReplicatedStrategy( *operands_to_consider_all_strategies_for.begin(); auto operand = ins->operand(operand_to_consider_all_strategies_for); CHECK(!operand->shape().IsTuple()); - auto operand_strategies_to_consider = strategy_map.at(operand).get(); + const auto& operand_strategy_group = strategy_map.at(operand).get(); + const auto& operand_strategies = operand_strategy_group->GetStrategies(); std::vector>> possible_input_shardings( - operand_strategies_to_consider->strategies.size(), + operand_strategies.size(), std::vector>(ins->operand_count())); std::vector possible_communication_resharding_costs( - operand_strategies_to_consider->strategies.size(), - ReshardingCosts(ins->operand_count())); + operand_strategies.size(), ReshardingCosts(ins->operand_count())); std::vector possible_memory_resharding_costs( - operand_strategies_to_consider->strategies.size(), - ReshardingCosts(ins->operand_count())); + operand_strategies.size(), ReshardingCosts(ins->operand_count())); for (int64_t k = 0; k < ins->operand_count(); ++k) { - CHECK(!ins->operand(k)->shape().IsTuple()); + const HloInstruction* operand = ins->operand(k); + const Shape& operand_shape = operand->shape(); + CHECK(!operand_shape.IsTuple()); + const StrategyGroup& operand_strategy_group = *strategy_map.at(operand); if (k == operand_to_consider_all_strategies_for) { - CHECK_EQ(possible_input_shardings.size(), - operand_strategies_to_consider->strategies.size()); + CHECK_EQ(possible_input_shardings.size(), operand_strategies.size()); for (size_t j = 0; j < possible_input_shardings.size(); ++j) { - possible_input_shardings[j][k] = - operand_strategies_to_consider->strategies[j].output_sharding; + const auto& operand_sharding = operand_strategies[j].output_sharding; + possible_input_shardings[j][k] = operand_sharding; possible_communication_resharding_costs[j][k] = - CommunicationReshardingCostVector( - strategy_map.at(ins->operand(k)).get(), - ins->operand(k)->shape(), - operand_strategies_to_consider->strategies[j].output_sharding, - cluster_env); - possible_memory_resharding_costs[j][k] = MemoryReshardingCostVector( - strategy_map.at(ins->operand(k)).get(), ins->operand(k)->shape(), - operand_strategies_to_consider->strategies[j].output_sharding, - cluster_env); + CommunicationReshardingCostVector(operand_strategy_group, + operand_shape, operand_sharding, + cluster_env); + possible_memory_resharding_costs[j][k] = + MemoryReshardingCostVector(operand_strategy_group, operand_shape, + operand_sharding, cluster_env); } } else { for (size_t j = 0; j < possible_input_shardings.size(); ++j) { possible_input_shardings[j][k] = replicated_strategy; possible_communication_resharding_costs[j][k] = CommunicationReshardingCostVector( - strategy_map.at(ins->operand(k)).get(), - ins->operand(k)->shape(), replicated_strategy, cluster_env); - possible_memory_resharding_costs[j][k] = MemoryReshardingCostVector( - strategy_map.at(ins->operand(k)).get(), ins->operand(k)->shape(), - replicated_strategy, cluster_env); + operand_strategy_group, operand_shape, replicated_strategy, + cluster_env); + possible_memory_resharding_costs[j][k] = + MemoryReshardingCostVector(operand_strategy_group, operand_shape, + replicated_strategy, cluster_env); } } } @@ -822,7 +809,7 @@ void AddReplicatedStrategy( for (size_t j = 0; j < possible_input_shardings.size(); ++j) { double communication_cost = ComputeCommunicationCost( ins, possible_input_shardings[j], cluster_env); - strategy_group->strategies.push_back(ShardingStrategy( + strategy_group.AddStrategy(ShardingStrategy( {"R", replicated_strategy, replicated_penalty, communication_cost, memory_cost, std::move(possible_communication_resharding_costs[j]), std::move(possible_memory_resharding_costs[j]), @@ -840,29 +827,30 @@ void AddReplicatedStrategy( "b/233412625."; std::tie(communication_resharding_costs, memory_resharding_costs, input_shardings) = - ReshardingCostsForTupleOperand( - ins->operand(0), strategy_map.at(ins->operand(0)).get()); + ReshardingCostsForTupleOperand(ins->operand(0), + *strategy_map.at(ins->operand(0))); } else { for (int64_t k = 0; k < ins->operand_count(); ++k) { - auto operand = ins->operand(k); + const HloInstruction* operand = ins->operand(k); + const Shape& operand_shape = operand->shape(); + const StrategyGroup& operand_strategy_group = *strategy_map.at(operand); + const auto& operand_strategies = operand_strategy_group.GetStrategies(); if (ins->opcode() == HloOpcode::kConditional) { - communication_resharding_costs.push_back(std::vector( - strategy_map.at(operand)->strategies.size(), 0)); - memory_resharding_costs.push_back(std::vector( - strategy_map.at(operand)->strategies.size(), 0)); + std::vector zeros(operand_strategies.size(), 0); + communication_resharding_costs.push_back(zeros); + memory_resharding_costs.push_back(zeros); } else { communication_resharding_costs.push_back( - CommunicationReshardingCostVector(strategy_map.at(operand).get(), - ins->operand(k)->shape(), - output_spec, cluster_env)); + CommunicationReshardingCostVector(operand_strategy_group, + operand_shape, output_spec, + cluster_env)); memory_resharding_costs.push_back(MemoryReshardingCostVector( - strategy_map.at(operand).get(), ins->operand(k)->shape(), - output_spec, cluster_env)); + operand_strategy_group, operand_shape, output_spec, cluster_env)); input_shardings.push_back(output_spec); } } } - strategy_group->strategies.push_back(ShardingStrategy( + strategy_group.AddStrategy(ShardingStrategy( {"R", HloSharding::Replicate(), replicated_penalty, 0, memory_cost, std::move(communication_resharding_costs), std::move(memory_resharding_costs), input_shardings})); @@ -887,10 +875,10 @@ void EnumerateAll1DPartition(const HloInstruction* ins, const Shape& shape, const DeviceMesh& device_mesh, const ClusterEnvironment& cluster_env, const StrategyMap& strategy_map, - std::unique_ptr& strategy_group, const bool only_allow_divisible, const std::string& suffix, - const CallGraph& call_graph) { + const CallGraph& call_graph, + StrategyGroup& strategy_group) { for (int64_t i = 0; i < shape.rank(); ++i) { for (int64_t j = 0; j < device_mesh.num_dimensions(); ++j) { if (device_mesh.dim(j) == 1 || shape.dimensions(i) < device_mesh.dim(j) || @@ -920,8 +908,8 @@ void EnumerateAll1DPartition(const HloInstruction* ins, const Shape& shape, "operand."; std::tie(communication_resharding_costs, memory_resharding_costs, input_shardings) = - ReshardingCostsForTupleOperand( - ins->operand(0), strategy_map.at(ins->operand(0)).get()); + ReshardingCostsForTupleOperand(ins->operand(0), + *strategy_map.at(ins->operand(0))); } else if (ins->opcode() == HloOpcode::kRngBitGenerator && ins->operand(0)->shape().IsArray()) { input_shardings.push_back(HloSharding::Replicate()); @@ -947,7 +935,7 @@ void EnumerateAll1DPartition(const HloInstruction* ins, const Shape& shape, communication_cost = ComputeSortCommunicationCost( ins->operand(0)->shape().rank() - 1, i, j, shape, cluster_env); } - strategy_group->strategies.push_back(ShardingStrategy( + strategy_group.AddStrategy(ShardingStrategy( {name, output_spec, compute_cost, communication_cost, memory_cost, std::move(communication_resharding_costs), std::move(memory_resharding_costs), input_shardings})); @@ -959,25 +947,25 @@ void BuildStrategyAndCostForOp(const HloInstruction* ins, const Shape& shape, const DeviceMesh& device_mesh, const ClusterEnvironment& cluster_env, const StrategyMap& strategy_map, - std::unique_ptr& strategy_group, const CallGraph& call_graph, - absl::Span tensor_dims); + absl::Span tensor_dims, + StrategyGroup& strategy_group); void EnumerateAllPartition(const HloInstruction* ins, const Shape& shape, const DeviceMesh& device_mesh, const ClusterEnvironment& cluster_env, const StrategyMap& strategy_map, - std::unique_ptr& strategy_group, const InstructionBatchDimMap& batch_dim_map, const bool only_allow_divisible, const CallGraph& call_graph, const int64_t partition_dimensions, - const std::vector& tensor_dims) { + const std::vector& tensor_dims, + StrategyGroup& strategy_group) { const auto tensor_dims_size = tensor_dims.size(); if (tensor_dims_size == partition_dimensions) { BuildStrategyAndCostForOp(ins, shape, device_mesh, cluster_env, - strategy_map, strategy_group, call_graph, - tensor_dims); + strategy_map, call_graph, tensor_dims, + strategy_group); return; } auto iter = batch_dim_map.find(GetBatchDimMapKey(ins)); @@ -1001,8 +989,9 @@ void EnumerateAllPartition(const HloInstruction* ins, const Shape& shape, std::vector next_tensor_dims = tensor_dims; next_tensor_dims.push_back(i); EnumerateAllPartition(ins, shape, device_mesh, cluster_env, strategy_map, - strategy_group, batch_dim_map, only_allow_divisible, - call_graph, partition_dimensions, next_tensor_dims); + batch_dim_map, only_allow_divisible, call_graph, + partition_dimensions, next_tensor_dims, + strategy_group); } } @@ -1010,9 +999,9 @@ void BuildStrategyAndCostForOp(const HloInstruction* ins, const Shape& shape, const DeviceMesh& device_mesh, const ClusterEnvironment& cluster_env, const StrategyMap& strategy_map, - std::unique_ptr& strategy_group, const CallGraph& call_graph, - absl::Span tensor_dims) { + absl::Span tensor_dims, + StrategyGroup& strategy_group) { std::vector mesh_dims(tensor_dims.size()); std::iota(mesh_dims.begin(), mesh_dims.end(), 0); const std::string name = @@ -1038,7 +1027,7 @@ void BuildStrategyAndCostForOp(const HloInstruction* ins, const Shape& shape, std::tie(communication_resharding_costs, memory_resharding_costs, input_shardings) = ReshardingCostsForTupleOperand(ins->operand(0), - strategy_map.at(ins->operand(0)).get()); + *strategy_map.at(ins->operand(0))); } else { std::tie(communication_resharding_costs, memory_resharding_costs, input_shardings) = @@ -1063,18 +1052,22 @@ void BuildStrategyAndCostForOp(const HloInstruction* ins, const Shape& shape, } } - strategy_group->strategies.push_back( + strategy_group.AddStrategy( ShardingStrategy({name, output_spec, compute_cost, communication_cost, memory_cost, std::move(communication_resharding_costs), std::move(memory_resharding_costs), input_shardings})); } -void EnumerateAll1DPartitionReshape( - const HloInstruction* ins, const DeviceMesh& device_mesh, - const ClusterEnvironment& cluster_env, const StrategyMap& strategy_map, - std::unique_ptr& strategy_group, bool only_allow_divisible, - const std::string& suffix) { +void EnumerateAll1DPartitionReshape(const HloInstruction* ins, + const DeviceMesh& device_mesh, + const ClusterEnvironment& cluster_env, + const StrategyMap& strategy_map, + bool only_allow_divisible, + const std::string& suffix, + StrategyGroup& strategy_group) { const HloInstruction* operand = ins->operand(0); + const Shape& operand_shape = operand->shape(); + const StrategyGroup& operand_strategy_group = *strategy_map.at(operand); for (int64_t i = 0; i < ins->shape().rank(); ++i) { for (int64_t j = 0; j < device_mesh.num_dimensions(); ++j) { @@ -1086,7 +1079,7 @@ void EnumerateAll1DPartitionReshape( HloSharding output_spec = Tile(ins->shape(), {i}, {j}, device_mesh); std::optional input_spec = - hlo_sharding_util::ReshapeSharding(ins->shape(), operand->shape(), + hlo_sharding_util::ReshapeSharding(ins->shape(), operand_shape, output_spec); if (!input_spec.has_value()) { // invalid reshape continue; @@ -1104,13 +1097,11 @@ void EnumerateAll1DPartitionReshape( ByteSizeOfShapeWithSharding(ins->shape(), output_spec); ReshardingCosts communication_resharding_costs{ - CommunicationReshardingCostVector(strategy_map.at(operand).get(), - operand->shape(), *input_spec, - cluster_env)}; + CommunicationReshardingCostVector( + operand_strategy_group, operand_shape, *input_spec, cluster_env)}; ReshardingCosts memory_resharding_costs{MemoryReshardingCostVector( - strategy_map.at(operand).get(), operand->shape(), *input_spec, - cluster_env)}; - strategy_group->strategies.push_back( + operand_strategy_group, operand_shape, *input_spec, cluster_env)}; + strategy_group.AddStrategy( ShardingStrategy({name, output_spec, compute_cost, @@ -1123,26 +1114,24 @@ void EnumerateAll1DPartitionReshape( } } -void BuildStrategyAndCostForReshape( - const HloInstruction* ins, const DeviceMesh& device_mesh, - const ClusterEnvironment& cluster_env, const StrategyMap& strategy_map, - std::unique_ptr& strategy_group, - absl::Span tensor_dims); +void BuildStrategyAndCostForReshape(const HloInstruction* ins, + const DeviceMesh& device_mesh, + const ClusterEnvironment& cluster_env, + const StrategyMap& strategy_map, + absl::Span tensor_dims, + StrategyGroup& strategy_group); // Enumerate all partitions for reshape. Batch dim is always partitioned. -void EnumeratePartitionReshape(const HloInstruction* ins, - const DeviceMesh& device_mesh, - const ClusterEnvironment& cluster_env, - const StrategyMap& strategy_map, - const InstructionBatchDimMap& batch_dim_map, - std::unique_ptr& strategy_group, - const bool only_allow_divisible, - const int64_t partition_dimensions, - const std::vector& tensor_dims = {}) { +void EnumeratePartitionReshape( + const HloInstruction* ins, const DeviceMesh& device_mesh, + const ClusterEnvironment& cluster_env, const StrategyMap& strategy_map, + const InstructionBatchDimMap& batch_dim_map, + const bool only_allow_divisible, const int64_t partition_dimensions, + const std::vector& tensor_dims, StrategyGroup& strategy_group) { const auto tensor_dims_size = tensor_dims.size(); if (tensor_dims_size == partition_dimensions) { BuildStrategyAndCostForReshape(ins, device_mesh, cluster_env, strategy_map, - strategy_group, tensor_dims); + tensor_dims, strategy_group); return; } auto iter = batch_dim_map.find(GetBatchDimMapKey(ins)); @@ -1169,24 +1158,27 @@ void EnumeratePartitionReshape(const HloInstruction* ins, std::vector next_tensor_dims = tensor_dims; next_tensor_dims.push_back(i); EnumeratePartitionReshape(ins, device_mesh, cluster_env, strategy_map, - batch_dim_map, strategy_group, - only_allow_divisible, partition_dimensions, - next_tensor_dims); + batch_dim_map, only_allow_divisible, + partition_dimensions, next_tensor_dims, + strategy_group); } } -void BuildStrategyAndCostForReshape( - const HloInstruction* ins, const DeviceMesh& device_mesh, - const ClusterEnvironment& cluster_env, const StrategyMap& strategy_map, - std::unique_ptr& strategy_group, - absl::Span tensor_dims) { +void BuildStrategyAndCostForReshape(const HloInstruction* ins, + const DeviceMesh& device_mesh, + const ClusterEnvironment& cluster_env, + const StrategyMap& strategy_map, + absl::Span tensor_dims, + StrategyGroup& strategy_group) { const HloInstruction* operand = ins->operand(0); + const Shape& operand_shape = operand->shape(); + const StrategyGroup& operand_strategy_group = *strategy_map.at(operand); std::vector mesh_dims(tensor_dims.size()); std::iota(mesh_dims.begin(), mesh_dims.end(), 0); - HloSharding output_spec = + const HloSharding output_spec = Tile(ins->shape(), tensor_dims, mesh_dims, device_mesh); std::optional input_spec = hlo_sharding_util::ReshapeSharding( - ins->shape(), operand->shape(), output_spec); + ins->shape(), operand_shape, output_spec); if (!input_spec.has_value()) { // invalid reshape return; } @@ -1195,16 +1187,13 @@ void BuildStrategyAndCostForReshape( absl::StrJoin(mesh_dims, ",")); double compute_cost = 0, communication_cost = 0; double memory_cost = ByteSizeOfShapeWithSharding(ins->shape(), output_spec); - ; ReshardingCosts communication_resharding_costs{ - CommunicationReshardingCostVector(strategy_map.at(operand).get(), - operand->shape(), *input_spec, - cluster_env)}; - ReshardingCosts memory_resharding_costs{ - MemoryReshardingCostVector(strategy_map.at(operand).get(), - operand->shape(), *input_spec, cluster_env)}; - strategy_group->strategies.push_back( + CommunicationReshardingCostVector(operand_strategy_group, operand_shape, + *input_spec, cluster_env)}; + ReshardingCosts memory_resharding_costs{MemoryReshardingCostVector( + operand_strategy_group, operand_shape, *input_spec, cluster_env)}; + strategy_group.AddStrategy( ShardingStrategy({name, output_spec, compute_cost, @@ -1224,10 +1213,9 @@ int64_t MaxNumTiles(const StrategyMap& strategy_map, strategy_group = strategy_group->following; } int64_t max_num_tiles = -1; - for (size_t i = 0; i < strategy_group->strategies.size(); ++i) { + for (const ShardingStrategy& strategy : strategy_group->GetStrategies()) { max_num_tiles = - std::max(max_num_tiles, - strategy_group->strategies[i].output_sharding.NumTiles()); + std::max(max_num_tiles, strategy.output_sharding.NumTiles()); } return max_num_tiles; } @@ -1333,53 +1321,55 @@ void DisableIncompatibleMixedMeshShapeAndForceBatchDim( } void FillAllStrategiesForArray( - std::unique_ptr& strategy_group, const HloInstruction* ins, - const Shape& shape, const ClusterEnvironment& cluster_env, - const StrategyMap& strategy_map, const AutoShardingOption& option, - const double replicated_penalty, + const HloInstruction* ins, const Shape& shape, + const ClusterEnvironment& cluster_env, const StrategyMap& strategy_map, + const AutoShardingOption& option, const double replicated_penalty, const InstructionBatchDimMap& batch_dim_map, const CallGraph& call_graph, const bool only_allow_divisible, const bool create_replicated_strategies, - const bool create_partially_replicated_strategies) { + const bool create_partially_replicated_strategies, + StrategyGroup& strategy_group) { if (create_partially_replicated_strategies || cluster_env.IsDeviceMesh1D()) { EnumerateAll1DPartition(ins, shape, cluster_env.device_mesh_, cluster_env, - strategy_map, strategy_group, only_allow_divisible, - "", call_graph); + strategy_map, only_allow_divisible, "", call_graph, + strategy_group); } // Split 2 dims if (cluster_env.IsDeviceMesh2D()) { EnumerateAllPartition(ins, shape, cluster_env.device_mesh_, cluster_env, - strategy_map, strategy_group, batch_dim_map, - only_allow_divisible, call_graph, /*partitions*/ 2); + strategy_map, batch_dim_map, only_allow_divisible, + call_graph, /*partitions*/ 2, /*tensor_dims*/ {}, + strategy_group); } // Split 3 dims if (cluster_env.IsDeviceMesh3D()) { EnumerateAllPartition(ins, shape, cluster_env.device_mesh_, cluster_env, - strategy_map, strategy_group, batch_dim_map, - only_allow_divisible, call_graph, /*partitions*/ 3); + strategy_map, batch_dim_map, only_allow_divisible, + call_graph, /*partitions*/ 3, /*tensor_dims*/ {}, + strategy_group); } if (option.allow_mixed_mesh_shape && cluster_env.IsDeviceMesh2D()) { // Set penalty for 1d partial tiled layout - for (size_t i = 0; i < strategy_group->strategies.size(); ++i) { - strategy_group->strategies[i].compute_cost += replicated_penalty * 0.8; + for (size_t i = 0; i < strategy_group.GetStrategies().size(); ++i) { + strategy_group.GetStrategy(i).compute_cost += replicated_penalty * 0.8; } // Split 1 dim, but for 1d mesh EnumerateAll1DPartition(ins, shape, cluster_env.device_mesh_1d_, - cluster_env, strategy_map, strategy_group, - only_allow_divisible, " 1d", call_graph); + cluster_env, strategy_map, only_allow_divisible, + " 1d", call_graph, strategy_group); } - if (create_replicated_strategies || strategy_group->strategies.empty()) { - AddReplicatedStrategy(ins, shape, cluster_env, strategy_map, strategy_group, - replicated_penalty); + if (create_replicated_strategies || strategy_group.GetStrategies().empty()) { + AddReplicatedStrategy(ins, shape, cluster_env, strategy_map, + replicated_penalty, {}, strategy_group); } // If force_batch_dim_to_mesh_dim is set, filter out invalid strategies // and only keep the data parallel strategies. if (option.force_batch_dim_to_mesh_dim >= 0 && batch_dim_map.contains(GetBatchDimMapKey(ins))) { - CHECK_OK(FilterStrategy(ins, shape, strategy_group, cluster_env, - batch_dim_map, option)); + CHECK_OK(FilterStrategy(ins, shape, cluster_env, batch_dim_map, option, + strategy_group)); } } @@ -1394,7 +1384,6 @@ absl::StatusOr> CreateAllStrategiesGroup( std::unique_ptr strategy_group; if (shape.IsTuple()) { strategy_group = CreateTupleStrategyGroup(instruction_id); - strategy_group->childs.reserve(shape.tuple_shapes_size()); for (size_t i = 0; i < shape.tuple_shapes_size(); ++i) { auto child_strategies = CreateAllStrategiesGroup(ins, shape.tuple_shapes(i), instruction_id, @@ -1405,21 +1394,22 @@ absl::StatusOr> CreateAllStrategiesGroup( create_partially_replicated_strategies) .value(); child_strategies->tuple_element_idx = i; - strategy_group->childs.push_back(std::move(child_strategies)); + strategy_group->AddChild(std::move(child_strategies)); } } else if (shape.IsArray()) { strategy_group = CreateLeafStrategyGroup(instruction_id, ins, strategy_map, strategy_groups); FillAllStrategiesForArray( - strategy_group, ins, shape, cluster_env, strategy_map, option, - replicated_penalty, batch_dim_map, call_graph, only_allow_divisible, - create_replicated_strategies, create_partially_replicated_strategies); + ins, shape, cluster_env, strategy_map, option, replicated_penalty, + batch_dim_map, call_graph, only_allow_divisible, + create_replicated_strategies, create_partially_replicated_strategies, + *strategy_group); } else if (shape.IsToken()) { strategy_group = CreateLeafStrategyGroup(instruction_id, ins, strategy_map, strategy_groups); - AddReplicatedStrategy(ins, shape, cluster_env, strategy_map, strategy_group, - replicated_penalty); + AddReplicatedStrategy(ins, shape, cluster_env, strategy_map, + replicated_penalty, {}, *strategy_group); } else { LOG(FATAL) << "Unsupported instruction shape: " << shape.DebugString(); } @@ -1460,18 +1450,19 @@ bool ShardingIsConsistent(const HloSharding& partial_sharding, // HloSharding. // These two are distinguished by spmd::ShardingIsComplete(). void TrimOrGenerateStrategiesBasedOnExistingSharding( - const Shape& output_shape, StrategyGroup* strategy_group, - const StrategyMap& strategy_map, + const Shape& output_shape, const StrategyMap& strategy_map, const std::vector& instructions, const HloSharding& existing_sharding, const ClusterEnvironment& cluster_env, StableMap>& pretrimmed_strategy_map, - const CallGraph& call_graph, const bool strict) { - if (strategy_group->is_tuple) { - for (size_t i = 0; i < strategy_group->childs.size(); ++i) { + const CallGraph& call_graph, const bool strict, + StrategyGroup& strategy_group) { + if (strategy_group.is_tuple) { + for (size_t i = 0; i < strategy_group.GetChildren().size(); ++i) { TrimOrGenerateStrategiesBasedOnExistingSharding( - output_shape.tuple_shapes(i), strategy_group->childs.at(i).get(), - strategy_map, instructions, existing_sharding.tuple_elements().at(i), - cluster_env, pretrimmed_strategy_map, call_graph, strict); + output_shape.tuple_shapes(i), strategy_map, instructions, + existing_sharding.tuple_elements().at(i), cluster_env, + pretrimmed_strategy_map, call_graph, strict, + strategy_group.GetChild(i)); } } else { if (existing_sharding.IsUnknown()) { @@ -1480,34 +1471,34 @@ void TrimOrGenerateStrategiesBasedOnExistingSharding( if (spmd::ShardingIsComplete(existing_sharding, cluster_env.device_mesh_.num_elements())) { // Sharding provided by XLA users, we need to keep them. - strategy_group->following = nullptr; + strategy_group.following = nullptr; std::vector new_strategies; - for (size_t i = 0; i < strategy_group->strategies.size(); i++) { - if (strategy_group->strategies[i].output_sharding == - existing_sharding) { - VLOG(1) << "Keeping strategy index: " << i; - ShardingStrategy found_strategy = strategy_group->strategies[i]; - new_strategies.push_back(found_strategy); + for (const ShardingStrategy& strategy : strategy_group.GetStrategies()) { + if (strategy.output_sharding == existing_sharding) { + VLOG(1) << "Keeping strategy: " << strategy.ToString(); + new_strategies.push_back(strategy); } } if (!new_strategies.empty()) { // Stores other strategies in the map, removes them in the vector and // only keeps the one we found. - pretrimmed_strategy_map[strategy_group->node_idx] = - strategy_group->strategies; - strategy_group->strategies.clear(); - strategy_group->strategies = new_strategies; + pretrimmed_strategy_map[strategy_group.node_idx] = + strategy_group.GetStrategies(); + strategy_group.ClearStrategies(); + for (const ShardingStrategy& strategy : new_strategies) { + strategy_group.AddStrategy(strategy); + } } else { VLOG(1) << "Generate a new strategy based on user sharding."; std::string name = ToStringSimple(existing_sharding); ReshardingCosts communication_resharding_costs; ReshardingCosts memory_resharding_costs; std::vector> input_shardings; - if (!strategy_group->in_nodes.empty()) { - HloInstruction* ins = instructions.at(strategy_group->instruction_id); - for (size_t i = 0; i < strategy_group->in_nodes.size(); i++) { + if (!strategy_group.in_nodes.empty()) { + HloInstruction* ins = instructions.at(strategy_group.instruction_id); + for (size_t i = 0; i < strategy_group.in_nodes.size(); i++) { HloInstruction* operand = - instructions.at(strategy_group->in_nodes.at(i)->instruction_id); + instructions.at(strategy_group.in_nodes.at(i)->instruction_id); std::optional input_sharding = ShardingPropagation::GetShardingFromUser( *operand, *ins, 10, true, call_graph, @@ -1521,7 +1512,7 @@ void TrimOrGenerateStrategiesBasedOnExistingSharding( operand->shape(), {ins->tuple_index()}); } operand_strategy_group = - operand_strategy_group->childs[ins->tuple_index()].get(); + &operand_strategy_group->GetChild(ins->tuple_index()); operand_shape = operand->shape().tuple_shapes(ins->tuple_index()); } @@ -1538,21 +1529,21 @@ void TrimOrGenerateStrategiesBasedOnExistingSharding( input_shardings.push_back(*input_sharding); communication_resharding_costs.push_back( CommunicationReshardingCostVector( - operand_strategy_group, operand_shape, *input_sharding, + *operand_strategy_group, operand_shape, *input_sharding, cluster_env)); memory_resharding_costs.push_back(MemoryReshardingCostVector( - operand_strategy_group, operand_shape, *input_sharding, + *operand_strategy_group, operand_shape, *input_sharding, cluster_env)); } } double memory_cost = ByteSizeOfShapeWithSharding(output_shape, existing_sharding); - if (!strategy_group->strategies.empty()) { - pretrimmed_strategy_map[strategy_group->node_idx] = - strategy_group->strategies; + if (!strategy_group.GetStrategies().empty()) { + pretrimmed_strategy_map[strategy_group.node_idx] = + strategy_group.GetStrategies(); } - strategy_group->strategies.clear(); - strategy_group->strategies.push_back( + strategy_group.ClearStrategies(); + strategy_group.AddStrategy( ShardingStrategy({name, existing_sharding, 0, 0, memory_cost, communication_resharding_costs, memory_resharding_costs, input_shardings})); @@ -1561,23 +1552,23 @@ void TrimOrGenerateStrategiesBasedOnExistingSharding( // that option is kInfinityCost, set the cost to zero. This is okay // because there is only one option anyway, and having the costs set to // kInfinityCost is problematic for the solver. - if (strategy_group->strategies.size() == 1) { + if (strategy_group.GetStrategies().size() == 1) { for (auto& operand_communication_resharding_costs : - strategy_group->strategies[0].communication_resharding_costs) { + strategy_group.GetStrategy(0).communication_resharding_costs) { if (operand_communication_resharding_costs.size() == 1 && operand_communication_resharding_costs[0] >= kInfinityCost) { operand_communication_resharding_costs[0] = 0; } } } - } else if (!strategy_group->following) { + } else if (!strategy_group.following) { // If existing sharding is a partial sharding from previous iteration, // find the strategies that are 1D&&complete or align with user // sharding. // It is IMPORTANT that we do this only for instructions that do no follow // others, to keep the number of ILP variable small. std::vector new_vector; - for (const auto& strategy : strategy_group->strategies) { + for (const ShardingStrategy& strategy : strategy_group.GetStrategies()) { if (strategy.output_sharding.IsReplicated() || ShardingIsConsistent(existing_sharding, strategy.output_sharding, strict) || @@ -1587,37 +1578,40 @@ void TrimOrGenerateStrategiesBasedOnExistingSharding( spmd::ShardingIsComplete( strategy.output_sharding, cluster_env.original_device_mesh_.num_elements()))) { - new_vector.push_back(std::move(strategy)); + new_vector.push_back(strategy); } } // If no sharding strategy left, just keep the original set, because we do // not have to strictly keep those shardings and the only purpose is to // reduce problem size for the last iteration. if (!new_vector.empty() && - new_vector.size() != strategy_group->strategies.size()) { - strategy_group->following = nullptr; - strategy_group->strategies = std::move(new_vector); + new_vector.size() != strategy_group.GetStrategies().size()) { + strategy_group.following = nullptr; + strategy_group.ClearStrategies(); + for (const ShardingStrategy& strategy : new_vector) { + strategy_group.AddStrategy(strategy); + } } } } } -void CheckMemoryCosts(StrategyGroup* strategy_group, const Shape& shape) { - if (strategy_group->is_tuple) { - for (size_t i = 0; i < strategy_group->childs.size(); i++) { - CheckMemoryCosts(strategy_group->childs[i].get(), +void CheckMemoryCosts(const StrategyGroup& strategy_group, const Shape& shape) { + if (strategy_group.is_tuple) { + for (size_t i = 0; i < strategy_group.GetChildren().size(); i++) { + CheckMemoryCosts(*strategy_group.GetChildren()[i], shape.tuple_shapes().at(i)); } } else { double full_mem = 0.0; - for (const auto& strategy : strategy_group->strategies) { + for (const ShardingStrategy& strategy : strategy_group.GetStrategies()) { if (strategy.output_sharding.IsReplicated()) { full_mem = strategy.memory_cost; size_t size = ByteSizeOfShape(shape); CHECK_EQ(strategy.memory_cost, size); } } - for (const auto& strategy : strategy_group->strategies) { + for (const ShardingStrategy& strategy : strategy_group.GetStrategies()) { if (!strategy.output_sharding.IsReplicated() && full_mem > 0.0) { CHECK_GE(strategy.memory_cost * strategy.output_sharding.NumTiles(), full_mem); @@ -1627,17 +1621,19 @@ void CheckMemoryCosts(StrategyGroup* strategy_group, const Shape& shape) { } void RemoveShardingsWhereSmallDimsShardedAcrossManyDevices( - const Shape& shape, StrategyGroup* strategy_group, - const bool instruction_has_user_sharding) { - if (strategy_group->is_tuple) { - for (size_t i = 0; i < strategy_group->childs.size(); i++) { + const Shape& shape, const bool instruction_has_user_sharding, + StrategyGroup& strategy_group) { + if (strategy_group.is_tuple) { + const auto& children = strategy_group.GetChildren(); + for (size_t i = 0; i < children.size(); i++) { RemoveShardingsWhereSmallDimsShardedAcrossManyDevices( - shape.tuple_shapes().at(i), strategy_group->childs[i].get(), - instruction_has_user_sharding); + shape.tuple_shapes().at(i), instruction_has_user_sharding, + *children[i]); } return; } - if (instruction_has_user_sharding && strategy_group->strategies.size() == 1) { + if (instruction_has_user_sharding && + strategy_group.GetStrategies().size() == 1) { // If an instruction has a specified user sharding, and there is only a // single strategy, removing that strategy would mean we won't have any // strategy for that instruction. Further, given that the user has @@ -1646,9 +1642,8 @@ void RemoveShardingsWhereSmallDimsShardedAcrossManyDevices( return; } std::vector invalid_strategy_indices; - for (int strategy_idx = 0; strategy_idx < strategy_group->strategies.size(); - strategy_idx++) { - const ShardingStrategy& strategy = strategy_group->strategies[strategy_idx]; + for (size_t sid = 0; sid < strategy_group.GetStrategies().size(); ++sid) { + const ShardingStrategy& strategy = strategy_group.GetStrategy(sid); if (strategy.output_sharding.IsReplicated()) { continue; } @@ -1656,76 +1651,29 @@ void RemoveShardingsWhereSmallDimsShardedAcrossManyDevices( for (int64_t i = 0; i < shape.rank(); ++i) { if (tile_assignment.dim(i) > 1 && tile_assignment.dim(i) > shape.dimensions(i)) { - invalid_strategy_indices.push_back(strategy_idx); + invalid_strategy_indices.push_back(sid); break; } } } - if (invalid_strategy_indices.size() < strategy_group->strategies.size()) { - for (int strategy_idx : invalid_strategy_indices) { - VLOG(1) << "Removing invalid strategy: " - << strategy_group->strategies[strategy_idx].ToString(); - strategy_group->strategies[strategy_idx].compute_cost = kInfinityCost; - } - } -} - -void CheckReshardingCostsShape(StrategyGroup* strategy_group) { - if (strategy_group->is_tuple) { - for (size_t i = 0; i < strategy_group->childs.size(); i++) { - CheckReshardingCostsShape(strategy_group->childs[i].get()); - } - } else { - for (const auto& strategy : strategy_group->strategies) { - if (strategy_group->in_nodes.size() == 1 && - strategy_group->in_nodes.at(0)->is_tuple) { - // This is when current instruction's only operand is tuple, and the - // first dimension of communication_resharding_costs should equal its - // number of tuple elements. - CHECK_EQ(strategy.communication_resharding_costs.size(), - strategy_group->in_nodes.at(0)->childs.size()) - << "Instruction ID: " << strategy_group->instruction_id << "\n" - << strategy_group->ToString(); - } else { - // The rest of the time, the first dimension of - // communication_resharding_costs should equal its number of operands - // (in_nodes). - CHECK_EQ(strategy.communication_resharding_costs.size(), - strategy_group->in_nodes.size()) - << "Instruction ID: " << strategy_group->instruction_id << "\n" - << strategy_group->ToString(); - } - for (size_t i = 0; i < strategy.communication_resharding_costs.size(); - i++) { - size_t to_compare; - if (strategy_group->in_nodes.size() == 1 && - strategy_group->in_nodes.at(0)->is_tuple) { - to_compare = - strategy_group->in_nodes.at(0)->childs.at(i)->strategies.size(); - } else if (strategy_group->is_tuple) { - to_compare = strategy_group->in_nodes.at(i)->childs.size(); - } else { - to_compare = strategy_group->in_nodes.at(i)->strategies.size(); - } - CHECK_EQ(strategy.communication_resharding_costs[i].size(), to_compare) - << "\nIndex of communication_resharding_costs: " << i - << "\nInstruction ID: " << strategy_group->instruction_id - << "\nCurrent strategies:\n" - << strategy_group->ToString(); - } + if (invalid_strategy_indices.size() < strategy_group.GetStrategies().size()) { + for (size_t sid : invalid_strategy_indices) { + ShardingStrategy& strategy = strategy_group.GetStrategy(sid); + VLOG(1) << "Removing invalid strategy: " << strategy.ToString(); + strategy.compute_cost = kInfinityCost; } } } -void ScaleCostsWithExecutionCounts(StrategyGroup* strategy_group, - const int64_t execution_count) { - if (strategy_group->is_tuple) { - for (size_t i = 0; i < strategy_group->childs.size(); ++i) { - ScaleCostsWithExecutionCounts(strategy_group->childs[i].get(), - execution_count); +void ScaleCostsWithExecutionCounts(const int64_t execution_count, + StrategyGroup& strategy_group) { + if (strategy_group.is_tuple) { + for (const auto& child : strategy_group.GetChildren()) { + ScaleCostsWithExecutionCounts(execution_count, *child); } } else { - for (auto& strategy : strategy_group->strategies) { + for (size_t sid = 0; sid < strategy_group.GetStrategies().size(); ++sid) { + ShardingStrategy& strategy = strategy_group.GetStrategy(sid); strategy.compute_cost *= execution_count; strategy.communication_cost *= execution_count; for (auto i = 0; i < strategy.communication_resharding_costs.size(); @@ -1808,14 +1756,13 @@ std::unique_ptr HandleManuallyShardedInstruction( std::unique_ptr strategy_group; if (shape.IsTuple()) { strategy_group = CreateTupleStrategyGroup(instruction_id); - strategy_group->childs.reserve(shape.tuple_shapes_size()); for (size_t i = 0; i < shape.tuple_shapes_size(); ++i) { std::unique_ptr child_strategies = HandleManuallyShardedInstruction(ins, shape.tuple_shapes(i), instruction_id, strategy_groups, strategy_map); child_strategies->tuple_element_idx = i; - strategy_group->childs.push_back(std::move(child_strategies)); + strategy_group->AddChild(std::move(child_strategies)); } } else if (shape.IsToken() || shape.IsArray()) { strategy_group = CreateLeafStrategyGroup(instruction_id, ins, strategy_map, @@ -1831,18 +1778,19 @@ std::unique_ptr HandleManuallyShardedInstruction( "b/233412625."; std::tie(communication_resharding_costs, memory_resharding_costs, input_shardings) = - ReshardingCostsForTupleOperand( - ins->operand(0), strategy_map.at(ins->operand(0)).get()); + ReshardingCostsForTupleOperand(ins->operand(0), + *strategy_map.at(ins->operand(0))); } else { for (int64_t k = 0; k < ins->operand_count(); ++k) { const HloInstruction* operand = ins->operand(k); - communication_resharding_costs.push_back(std::vector( - strategy_map.at(operand)->strategies.size(), 0)); - memory_resharding_costs.push_back(std::vector( - strategy_map.at(operand)->strategies.size(), 0)); + const StrategyGroup& operand_strategy_group = *strategy_map.at(operand); + const auto& strategies = operand_strategy_group.GetStrategies(); + const std::vector zeros(strategies.size(), 0); + communication_resharding_costs.push_back(zeros); + memory_resharding_costs.push_back(zeros); } } - strategy_group->strategies.push_back(ShardingStrategy( + strategy_group->AddStrategy(ShardingStrategy( {"MANUAL", HloSharding::Replicate(), 0, 0, static_cast(ShapeUtil::ByteSizeOf(shape)), std::move(communication_resharding_costs), @@ -1870,15 +1818,14 @@ std::unique_ptr CreateReshapeStrategies( const HloInstruction* operand = ins->operand(0); // Create follow strategies - const StrategyGroup* src_strategy_group = strategy_map.at(operand).get(); - CHECK(!src_strategy_group->is_tuple); - strategy_group->following = src_strategy_group; + const StrategyGroup& src_strategy_group = *strategy_map.at(operand); + CHECK(!src_strategy_group.is_tuple); + strategy_group->following = &src_strategy_group; - for (int64_t sid = 0; sid < src_strategy_group->strategies.size(); ++sid) { + for (const auto& src_strategy : src_strategy_group.GetStrategies()) { std::optional output_spec = - hlo_sharding_util::ReshapeSharding( - operand->shape(), ins->shape(), - src_strategy_group->strategies[sid].output_sharding); + hlo_sharding_util::ReshapeSharding(operand->shape(), ins->shape(), + src_strategy.output_sharding); if (!output_spec.has_value()) { continue; @@ -1898,30 +1845,30 @@ std::unique_ptr CreateReshapeStrategies( std::vector communication_resharding_costs = CommunicationReshardingCostVector( src_strategy_group, operand->shape(), - src_strategy_group->strategies[sid].output_sharding, cluster_env); - std::vector memory_resharding_costs = MemoryReshardingCostVector( - src_strategy_group, operand->shape(), - src_strategy_group->strategies[sid].output_sharding, cluster_env); - strategy_group->strategies.push_back(ShardingStrategy( - {name, - *output_spec, - compute_cost, - communication_cost, - memory_cost, - {communication_resharding_costs}, - {memory_resharding_costs}, - {src_strategy_group->strategies[sid].output_sharding}})); - } - } - - if (strategy_group->strategies.empty()) { + src_strategy.output_sharding, cluster_env); + std::vector memory_resharding_costs = + MemoryReshardingCostVector(src_strategy_group, operand->shape(), + src_strategy.output_sharding, cluster_env); + strategy_group->AddStrategy( + ShardingStrategy({name, + *output_spec, + compute_cost, + communication_cost, + memory_cost, + {communication_resharding_costs}, + {memory_resharding_costs}, + {src_strategy.output_sharding}})); + } + } + + if (strategy_group->GetStrategies().empty()) { // Fail to create follow strategies, enumerate all possible cases VLOG(2) << "Enumerating all strategies for reshape"; FillAllStrategiesForArray( - strategy_group, ins, ins->shape(), cluster_env, strategy_map, option, + ins, ins->shape(), cluster_env, strategy_map, option, replicated_penalty, batch_dim_map, call_graph, only_allow_divisible, /* create_replicated_strategies */ true, - /* create_partially_replicated_strategies */ true); + /* create_partially_replicated_strategies */ true, *strategy_group); } return strategy_group; @@ -2018,8 +1965,8 @@ AutoShardingSolverResult CallSolver( tuple_elements.at(*strategy_group->tuple_element_idx); } } - for (NodeStrategyIdx j = 0; j < strategy_group->strategies.size(); ++j) { - const ShardingStrategy& strategy = strategy_group->strategies[j]; + for (auto j = 0; j < strategy_group->GetStrategies().size(); ++j) { + const ShardingStrategy& strategy = strategy_group->GetStrategies()[j]; const HloSharding& sharding = strategy.output_sharding; ci.add_costs(strategy.compute_cost); di.add_costs(strategy.communication_cost + @@ -2049,14 +1996,13 @@ AutoShardingSolverResult CallSolver( for (const auto& pair : alias_set) { const StrategyGroup* src_strategy_group = strategy_groups[pair.first]; const StrategyGroup* dst_strategy_group = strategy_groups[pair.second]; - Matrix raw_cost(src_strategy_group->strategies.size(), - dst_strategy_group->strategies.size()); - for (NodeStrategyIdx i = 0; i < src_strategy_group->strategies.size(); - ++i) { - for (NodeStrategyIdx j = 0; j < dst_strategy_group->strategies.size(); - ++j) { - if (src_strategy_group->strategies[i].output_sharding == - dst_strategy_group->strategies[j].output_sharding) { + const auto& src_strategies = src_strategy_group->GetStrategies(); + const auto& dst_strategies = dst_strategy_group->GetStrategies(); + Matrix raw_cost(src_strategies.size(), dst_strategies.size()); + for (NodeStrategyIdx i = 0; i < src_strategies.size(); ++i) { + for (NodeStrategyIdx j = 0; j < dst_strategies.size(); ++j) { + if (src_strategies[i].output_sharding == + dst_strategies[j].output_sharding) { raw_cost(i, j) = 0.0; } else { raw_cost(i, j) = 1.0; @@ -2271,20 +2217,18 @@ void SetHloSharding( extract_tuple_shardings = [&](const StrategyGroup* strategy_group) { if (strategy_group->is_tuple) { - for (const auto& child_strategies : strategy_group->childs) { + for (const auto& child_strategies : strategy_group->GetChildren()) { extract_tuple_shardings(child_strategies.get()); } } else { NodeIdx node_idx = strategy_group->node_idx; NodeStrategyIdx stra_idx = s_val[node_idx]; + const auto& strategy = strategy_group->GetStrategies()[stra_idx]; // Do not set completed sharding before the last iteration - if (strategy_group->strategies[stra_idx] - .output_sharding.IsReplicated() && - !last_iteration) { + if (strategy.output_sharding.IsReplicated() && !last_iteration) { set_tuple_sharding = false; } - output_flattened_shardings.push_back( - strategy_group->strategies[stra_idx].output_sharding); + output_flattened_shardings.push_back(strategy.output_sharding); } }; extract_tuple_shardings(strategy_group); @@ -2639,21 +2583,17 @@ std::string PrintAutoShardingSolution(const HloInstructionSequence& sequence, // Print the chosen strategy for (NodeIdx node_idx = 0; node_idx < N; ++node_idx) { + const StrategyGroup& strategy_group = *strategy_groups[node_idx]; absl::StrAppend( &str, node_idx, " ", - ToAdaptiveString( - instructions[strategy_groups[node_idx]->instruction_id]), - " "); + ToAdaptiveString(instructions[strategy_group.instruction_id]), " "); NodeStrategyIdx stra_idx = cost_graph.RemapIndex(node_idx, s_val[node_idx]); - if (cost_graph.follow_idx_[node_idx] < 0) { - absl::StrAppend( - &str, strategy_groups[node_idx]->strategies[stra_idx].ToString(), - "\n"); - } else { - absl::StrAppend( - &str, strategy_groups[node_idx]->strategies[stra_idx].ToString(), - " follow ", cost_graph.follow_idx_[node_idx], "\n"); + const ShardingStrategy& strategy = strategy_group.GetStrategies()[stra_idx]; + absl::StrAppend(&str, strategy.ToString()); + if (cost_graph.follow_idx_[node_idx] >= 0) { + absl::StrAppend(&str, " follow ", cost_graph.follow_idx_[node_idx]); } + absl::StrAppend(&str, "\n"); } return str; @@ -2668,25 +2608,26 @@ std::string PrintSolutionMemoryUsage(const LivenessSet& liveness_set, std::vector> time_memory_usage; // Function that gets the memory usage of a StrategyGroup belongs to one // tensor. - std::function calculate_memory_usage; - calculate_memory_usage = [&](const StrategyGroup* strategy_group) { - if (strategy_group->is_tuple) { + std::function calculate_memory_usage; + calculate_memory_usage = [&](const StrategyGroup& strategy_group) { + if (strategy_group.is_tuple) { double m = 0.0; - for (const auto& child : strategy_group->childs) { - m += calculate_memory_usage(child.get()); + for (const auto& child : strategy_group.GetChildren()) { + m += calculate_memory_usage(*child); } return m; } - NodeIdx ins_idx = strategy_group->node_idx; + NodeIdx ins_idx = strategy_group.node_idx; NodeStrategyIdx stra_idx = cost_graph.RemapIndex(ins_idx, s_val[ins_idx]); - const ShardingStrategy& strategy = strategy_group->strategies[stra_idx]; + const auto& strategies = strategy_group.GetStrategies(); + const ShardingStrategy& strategy = strategies[stra_idx]; return strategy.memory_cost; }; for (LivenessIdx time_idx = 0; time_idx < liveness_set.size(); ++time_idx) { double mem = 0.0; for (const auto& val : liveness_set.at(time_idx)) { const HloInstruction* ins = val->instruction(); - auto tmp = calculate_memory_usage(strategy_map.at(ins).get()); + auto tmp = calculate_memory_usage(*strategy_map.at(ins)); mem += tmp; if (VLOG_IS_ON(6) && tmp / (1024 * 1024) > 1) { @@ -2724,7 +2665,7 @@ std::string PrintSolutionMemoryUsage(const LivenessSet& liveness_set, for (LivenessIdx time_idx = 0; time_idx < k; time_idx++) { for (const auto& val : liveness_set[time_memory_usage.at(time_idx).first]) { const HloInstruction* ins = val->instruction(); - auto mem = calculate_memory_usage(strategy_map.at(ins).get()); + auto mem = calculate_memory_usage(*strategy_map.at(ins)); if (mem > 100 * 1024 * 1024) { instruction_mem.push_back( {absl::StrCat(ins->name(), val->index().ToString()), mem}); @@ -3402,10 +3343,10 @@ void AnnotateShardingWithSimpleHeuristic( // Filter strategies according to the option.force_batch_dim_to_mesh_dim. // This can be used to forcibly generate data-parallel strategies. absl::Status FilterStrategy(const HloInstruction* ins, const Shape& shape, - std::unique_ptr& strategy_group, const ClusterEnvironment& cluster_env, const InstructionBatchDimMap& batch_map, - const AutoShardingOption& option) { + const AutoShardingOption& option, + StrategyGroup& strategy_group) { int mesh_dim = option.force_batch_dim_to_mesh_dim; int batch_dim = batch_map.at(GetBatchDimMapKey(ins)); const DeviceMesh& device_mesh = cluster_env.device_mesh_; @@ -3417,27 +3358,31 @@ absl::Status FilterStrategy(const HloInstruction* ins, const Shape& shape, } std::vector new_strategies; - for (auto& stra : strategy_group->strategies) { - std::vector tensor_dim_to_mesh_dim = - cluster_env.GetTensorDimToMeshDimWrapper(shape, stra.output_sharding); + for (const ShardingStrategy& strategy : strategy_group.GetStrategies()) { + const HloSharding& output_sharding = strategy.output_sharding; + const std::vector tensor_dim_to_mesh_dim = + cluster_env.GetTensorDimToMeshDimWrapper(shape, output_sharding); if (device_mesh.dim(mesh_dim) > 1) { // If the mesh dim is not one, the output tensor must be // tiled along the mesh dim. if (tensor_dim_to_mesh_dim[batch_dim] == mesh_dim) { - new_strategies.push_back(std::move(stra)); + new_strategies.push_back(strategy); } } else { // If the mesh dim is one, the output tensor must be replicated // on the mesh dim. if (tensor_dim_to_mesh_dim[batch_dim] == -1) { - new_strategies.push_back(std::move(stra)); + new_strategies.push_back(strategy); } } } CHECK(!new_strategies.empty()) << ins->ToString() << " does not have any valid strategies"; - strategy_group->strategies = std::move(new_strategies); + strategy_group.ClearStrategies(); + for (const ShardingStrategy& strategy : new_strategies) { + strategy_group.AddStrategy(strategy); + } return absl::OkStatus(); } diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding.h b/xla/hlo/experimental/auto_sharding/auto_sharding.h index 6b05b35d73a7f..c9791841e5eba 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding.h +++ b/xla/hlo/experimental/auto_sharding/auto_sharding.h @@ -29,6 +29,7 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h" @@ -131,12 +132,12 @@ HloSharding Tile(const Shape& shape, absl::Span tensor_dims, const DeviceMesh& device_mesh); std::vector CommunicationReshardingCostVector( - const StrategyGroup* strategy_group, const Shape& shape, + const StrategyGroup& strategy_group, const Shape& shape, const HloSharding& required_sharding, const ClusterEnvironment& cluster_env); std::vector MemoryReshardingCostVector( - const StrategyGroup* strategy_group, const Shape& operand_shape, + const StrategyGroup& strategy_group, const Shape& operand_shape, const HloSharding& required_sharding, const ClusterEnvironment& cluster_env); @@ -146,17 +147,13 @@ std::unique_ptr CreateLeafStrategyGroup( size_t instruction_id, const HloInstruction* ins, const StrategyMap& strategy_map, StrategyGroups& strategy_groups); -void SetInNodesWithInstruction(std::unique_ptr& strategy_group, - const HloInstruction* ins, - const StrategyMap& strategy_map); - -void RemoveDuplicatedStrategy(std::unique_ptr& strategy_group); +void RemoveDuplicatedStrategy(StrategyGroup& strategy_group); absl::Status FilterStrategy(const HloInstruction* ins, const Shape& shape, - std::unique_ptr& strategy_group, const ClusterEnvironment& cluster_env, const InstructionBatchDimMap& batch_map, - const AutoShardingOption& option); + const AutoShardingOption& option, + StrategyGroup& strategy_group); absl::Status HandleDot(std::unique_ptr& strategy_group, StrategyGroups& strategy_groups, @@ -242,10 +239,11 @@ void PopulateTemporalValues(const CostGraph& cost_graph, void AddReplicatedStrategy( const HloInstruction* ins, const Shape& shape, const ClusterEnvironment& cluster_env, const StrategyMap& strategy_map, - std::unique_ptr& strategy_group, double replicated_penalty, - absl::flat_hash_set operands_to_consider_all_strategies_for = {}); + double replicated_penalty, + absl::flat_hash_set operands_to_consider_all_strategies_for, + StrategyGroup& strategy_group); -void CheckMemoryCosts(StrategyGroup* strategy_group, const Shape& shape); +void CheckMemoryCosts(const StrategyGroup& strategy_group, const Shape& shape); // Choose an operand to follow. We choose to follow the operand with the highest // priority. @@ -254,13 +252,12 @@ std::pair ChooseOperandToFollow( const AliasMap& alias_map, int64_t max_depth, const HloInstruction* ins); void FillAllStrategiesForArray( - std::unique_ptr& strategy_group, const HloInstruction* ins, - const Shape& shape, const ClusterEnvironment& cluster_env, - const StrategyMap& strategy_map, const AutoShardingOption& option, - double replicated_penalty, const InstructionBatchDimMap& batch_dim_map, - const CallGraph& call_graph, bool only_allow_divisible, - bool create_replicated_strategies, - bool create_partially_replicated_strategies); + const HloInstruction* ins, const Shape& shape, + const ClusterEnvironment& cluster_env, const StrategyMap& strategy_map, + const AutoShardingOption& option, double replicated_penalty, + const InstructionBatchDimMap& batch_dim_map, const CallGraph& call_graph, + bool only_allow_divisible, bool create_replicated_strategies, + bool create_partially_replicated_strategies, StrategyGroup& strategy_group); absl::StatusOr> CreateAllStrategiesGroup( const HloInstruction* ins, const Shape& shape, size_t instruction_id, @@ -313,22 +310,19 @@ void EnumerateAll1DPartition(const HloInstruction* ins, const Shape& shape, const DeviceMesh& device_mesh, const ClusterEnvironment& cluster_env, const StrategyMap& strategy_map, - std::unique_ptr& strategy_group, bool only_allow_divisible, const std::string& suffix, - const CallGraph& call_graph); + const CallGraph& call_graph, + StrategyGroup& strategy_group); // Enumerate all partitions recursively. -void EnumerateAllPartition(const HloInstruction* ins, const Shape& shape, - const DeviceMesh& device_mesh, - const ClusterEnvironment& cluster_env, - const StrategyMap& strategy_map, - std::unique_ptr& strategy_group, - const InstructionBatchDimMap& batch_dim_map, - bool only_allow_divisible, - const CallGraph& call_graph, - int64_t partition_dimensions, - const std::vector& tensor_dims = {}); +void EnumerateAllPartition( + const HloInstruction* ins, const Shape& shape, + const DeviceMesh& device_mesh, const ClusterEnvironment& cluster_env, + const StrategyMap& strategy_map, + const InstructionBatchDimMap& batch_dim_map, bool only_allow_divisible, + const CallGraph& call_graph, int64_t partition_dimensions, + const std::vector& tensor_dims, StrategyGroup& strategy_group); absl::StatusOr> FollowReduceStrategy( const HloInstruction* ins, const Shape& output_shape, @@ -340,8 +334,8 @@ absl::StatusOr> FollowReduceStrategy( void GenerateOutfeedStrategy(const HloInstruction* ins, const Shape& shape, const ClusterEnvironment& cluster_env, const StrategyMap& strategy_map, - std::unique_ptr& strategy_group, - double replicated_penalty); + double replicated_penalty, + StrategyGroup& strategy_group); std::pair GenerateReshardingCostsAndMissingShardingsForAllOperands( @@ -351,28 +345,27 @@ GenerateReshardingCostsAndMissingShardingsForAllOperands( std::vector>& input_shardings); std::unique_ptr MaybeFollowInsStrategyGroup( - const StrategyGroup* src_strategy_group, const Shape& shape, + const StrategyGroup& src_strategy_group, const Shape& shape, size_t instruction_id, StrategyGroups& strategy_groups, const ClusterEnvironment& cluster_env, const StableMap>& pretrimmed_strategy_map); void RemoveShardingsWhereSmallDimsShardedAcrossManyDevices( - const Shape& shape, StrategyGroup* strategy_group, - bool instruction_has_user_sharding); + const Shape& shape, bool instruction_has_user_sharding, + StrategyGroup& strategy_group); -void ScaleCostsWithExecutionCounts(StrategyGroup* strategy_group, - int64_t execution_count); +void ScaleCostsWithExecutionCounts(int64_t execution_count, + StrategyGroup& strategy_group); // Existing shardings refer to the HloSharding field in the given // HloInstruction. void TrimOrGenerateStrategiesBasedOnExistingSharding( - const Shape& output_shape, StrategyGroup* strategy_group, - const StrategyMap& strategy_map, + const Shape& output_shape, const StrategyMap& strategy_map, const std::vector& instructions, const HloSharding& existing_sharding, const ClusterEnvironment& cluster_env, StableMap>& pretrimmed_strategy_map, - const CallGraph& call_graph, bool strict); + const CallGraph& call_graph, bool strict, StrategyGroup& strategy_group); // Build possible sharding strategies and their costs for all instructions. absl::StatusOr> diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.cc b/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.cc index a613bb0b582c7..67a67a2a27884 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.cc +++ b/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.cc @@ -61,9 +61,9 @@ CostGraph::CostGraph(const StrategyGroups& strategy_groups, // Build the cost graph. for (StrategyGroup* strategy_group : strategy_groups) { - node_lens_.push_back(strategy_group->strategies.size()); + node_lens_.push_back(strategy_group->GetStrategies().size()); extra_node_costs_.push_back( - std::vector(strategy_group->strategies.size(), 0.0)); + std::vector(strategy_group->GetStrategies().size(), 0.0)); const auto& in_nodes = strategy_group->in_nodes; for (size_t i = 0; i < in_nodes.size(); ++i) { @@ -74,8 +74,8 @@ CostGraph::CostGraph(const StrategyGroups& strategy_groups, CreateEdgeCost(src_idx, dst_idx, i, strategy_group); AddEdgeCost(src_idx, dst_idx, edge_cost); } else if (in_nodes[i]->is_tuple && in_nodes.size() > 1) { - for (size_t l = 0; l < in_nodes[i]->childs.size(); ++l) { - NodeIdx src_idx = in_nodes[i]->childs[l]->node_idx; + for (const auto& child : in_nodes[i]->GetChildren()) { + NodeIdx src_idx = child->node_idx; NodeIdx dst_idx = strategy_group->node_idx; EdgeReshardingCostMatrix edge_cost = CreateEdgeCost(src_idx, dst_idx, i, strategy_group, true); @@ -86,8 +86,8 @@ CostGraph::CostGraph(const StrategyGroups& strategy_groups, << "Do not support instructions with more than one tuple " "operand. If this CHECK fails, we will need to fix " "b/233412625."; - for (size_t l = 0; l < in_nodes[i]->childs.size(); ++l) { - NodeIdx src_idx = in_nodes[i]->childs[l]->node_idx; + for (size_t l = 0; l < in_nodes[i]->GetChildren().size(); ++l) { + NodeIdx src_idx = in_nodes[i]->GetChildren()[l]->node_idx; NodeIdx dst_idx = strategy_group->node_idx; // TODO(b/233412625) Support more general case, e.g., multiple tuple // operands. If there is only one operand and it's a tuple, the @@ -101,8 +101,8 @@ CostGraph::CostGraph(const StrategyGroups& strategy_groups, } if (strategy_group->following) { - CHECK_EQ(strategy_group->strategies.size(), - strategy_group->following->strategies.size()) + CHECK_EQ(strategy_group->GetStrategies().size(), + strategy_group->following->GetStrategies().size()) << "Different strategy counts for instruction ID " << strategy_group->instruction_id << " and following instruction ID " << strategy_group->following->instruction_id; @@ -116,26 +116,25 @@ CostGraph::CostGraph(const StrategyGroups& strategy_groups, for (const auto& pair : associative_dot_pairs) { NodeIdx src_idx = pair.first->node_idx; NodeIdx dst_idx = pair.second->node_idx; + StrategyGroup& src_strategy_group = *strategy_groups[src_idx]; + StrategyGroup& dst_strategy_group = *strategy_groups[dst_idx]; EdgeReshardingCostMatrix edge_cost(node_lens_[src_idx], node_lens_[dst_idx]); absl::flat_hash_map src_strategy_name_to_idx_map; for (NodeStrategyIdx i = 0; i < node_lens_[src_idx]; ++i) { - const ShardingStrategy& strategy = - strategy_groups[src_idx]->strategies[i]; + const ShardingStrategy& strategy = src_strategy_group.GetStrategy(i); if (strategy.communication_cost > 0) { src_strategy_name_to_idx_map[strategy.name] = i; } } for (NodeStrategyIdx i = 0; i < node_lens_[dst_idx]; ++i) { - const ShardingStrategy& dst_strategy = - strategy_groups[dst_idx]->strategies[i]; + const ShardingStrategy& dst_strategy = dst_strategy_group.GetStrategy(i); if (dst_strategy.communication_cost > 0) { auto it = src_strategy_name_to_idx_map.find(dst_strategy.name); if (it != src_strategy_name_to_idx_map.end()) { - const ShardingStrategy& src_strategy = - strategy_groups[src_idx]->strategies[it->second]; + const auto& src_strategy = src_strategy_group.GetStrategy(it->second); CHECK_LE(std::abs(src_strategy.communication_cost - dst_strategy.communication_cost), 1e-6); @@ -154,8 +153,9 @@ EdgeReshardingCostMatrix CostGraph::CreateEdgeCost( CHECK_LT(src_idx, node_lens_.size()); CHECK_LT(dst_idx, node_lens_.size()); EdgeReshardingCostMatrix edge_cost(node_lens_[src_idx], node_lens_[dst_idx]); - for (NodeStrategyIdx k = 0; k < strategy_group->strategies.size(); ++k) { - const ShardingStrategy& strategy = strategy_group->strategies[k]; + const auto& strategies = strategy_group->GetStrategies(); + for (NodeStrategyIdx k = 0; k < strategies.size(); ++k) { + const ShardingStrategy& strategy = strategies[k]; size_t start_idx = 0; CHECK_LT(in_node_idx, strategy.memory_resharding_costs.size()) << strategy_group->node_idx; diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h b/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h index 3d6bac1b13919..802d780707b67 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h +++ b/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h @@ -136,7 +136,7 @@ inline const ShardingStrategy& GetShardingStrategy( CHECK(!strategy_group->is_tuple); NodeIdx node_idx = strategy_group->node_idx; NodeStrategyIdx stra_idx = cost_graph.RemapIndex(node_idx, s_val[node_idx]); - return strategy_group->strategies[stra_idx]; + return strategy_group->GetStrategies()[stra_idx]; } // Get the final sharding strategy according to the ILP solution. @@ -147,13 +147,13 @@ inline const ShardingStrategy& GetShardingStrategyForTuple( const StrategyGroup* strategy_group = strategy_map.at(inst).get(); CHECK(strategy_group->is_tuple); for (auto index_element : index) { - CHECK_LT(index_element, strategy_group->childs.size()); - const auto& strategies = strategy_group->childs[index_element]; + CHECK_LT(index_element, strategy_group->GetChildren().size()); + const auto& strategies = strategy_group->GetChildren()[index_element]; strategy_group = strategies.get(); } NodeIdx node_idx = strategy_group->node_idx; NodeStrategyIdx stra_idx = cost_graph.RemapIndex(node_idx, s_val[node_idx]); - return strategy_group->strategies[stra_idx]; + return strategy_group->GetStrategies()[stra_idx]; } } // namespace spmd diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc b/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc index 9224da821db47..15d0a03b0592f 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc +++ b/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc @@ -34,7 +34,6 @@ limitations under the License. #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" -#include "xla/array.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_device_mesh.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_option.h" @@ -51,6 +50,7 @@ limitations under the License. #include "xla/service/dot_as_convolution_util.h" #include "xla/service/hlo_cost_analysis.h" #include "xla/service/sharding_propagation.h" +#include "xla/shape.h" #include "tsl/platform/errors.h" namespace xla { @@ -331,15 +331,15 @@ void HandlerBase::AppendNewStrategy(const std::string& name, for (int i = 0; i < ins_->operand_count(); ++i) { const HloInstruction* operand = ins_->operand(i); + const Shape& operand_shape = operand->shape(); + const StrategyGroup& operand_strategy_group = *strategy_map_.at(operand); communication_resharding_costs.push_back(CommunicationReshardingCostVector( - strategy_map_.at(operand).get(), operand->shape(), input_specs[i], - cluster_env_)); + operand_strategy_group, operand_shape, input_specs[i], cluster_env_)); memory_resharding_costs.push_back(MemoryReshardingCostVector( - strategy_map_.at(operand).get(), operand->shape(), input_specs[i], - cluster_env_)); + operand_strategy_group, operand_shape, input_specs[i], cluster_env_)); } - strategy_group_->strategies.push_back(ShardingStrategy({ + strategy_group_->AddStrategy(ShardingStrategy({ name, output_spec, compute_cost, @@ -462,15 +462,19 @@ std::optional HandlerBase::GetShardingFromUser( } void HandlerBase::SortStrategies() { + auto strategies = strategy_group_->GetStrategies(); absl::c_stable_sort( - strategy_group_->strategies, - [](const ShardingStrategy& s1, const ShardingStrategy& s2) { + strategies, [](const ShardingStrategy& s1, const ShardingStrategy& s2) { if (s1.memory_cost == s2.memory_cost) { return s1.name < s2.name; } else { return s1.memory_cost < s2.memory_cost; } }); + strategy_group_->ClearStrategies(); + for (const ShardingStrategy& strategy : strategies) { + strategy_group_->AddStrategy(strategy); + } } /************** DotHandler function definitions **************/ @@ -962,8 +966,8 @@ absl::Status ConvHandler::RegisterStrategies() { // and only keep the data parallel strategies. if (option_.force_batch_dim_to_mesh_dim >= 0 && batch_map_.contains(GetBatchDimMapKey(ins_))) { - TF_RETURN_IF_ERROR(FilterStrategy(ins_, ins_->shape(), strategy_group_, - cluster_env_, batch_map_, option_)); + TF_RETURN_IF_ERROR(FilterStrategy(ins_, ins_->shape(), cluster_env_, + batch_map_, option_, *strategy_group_)); } SortStrategies(); diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc b/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc index 0cc623562cc96..3f1a9727cf210 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc +++ b/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc @@ -279,7 +279,6 @@ BuildStrategyAndCost( VLOG(5) << "Following while input " << while_input_tuple->name(); strategy_group = CreateTupleStrategyGroup(instruction_id); - strategy_group->childs.reserve(ins->shape().tuple_shapes_size()); // We use this following relationship to ensure that the input tuple // of the while loop, and the parameter of the body of that while // loop. Therefore, this followinf relationship is necessary for @@ -288,11 +287,11 @@ BuildStrategyAndCost( for (size_t i = 0; i < ins->shape().tuple_shapes_size(); ++i) { std::unique_ptr child_strategies = MaybeFollowInsStrategyGroup( - while_input_tuple_strategy_group->childs[i].get(), + *while_input_tuple_strategy_group->GetChildren()[i], ins->shape().tuple_shapes().at(i), instruction_id, strategy_groups, cluster_env, pretrimmed_strategy_map); child_strategies->tuple_element_idx = i; - strategy_group->childs.push_back(std::move(child_strategies)); + strategy_group->AddChild(std::move(child_strategies)); } } else { strategy_group = @@ -321,8 +320,8 @@ BuildStrategyAndCost( case HloOpcode::kConstant: { strategy_group = CreateLeafStrategyGroupWithoutInNodes(instruction_id, strategy_groups); - AddReplicatedStrategy(ins, ins->shape(), cluster_env, strategy_map, - strategy_group, 0); + AddReplicatedStrategy(ins, ins->shape(), cluster_env, strategy_map, 0, + {}, *strategy_group); break; } case HloOpcode::kScatter: { @@ -344,7 +343,7 @@ BuildStrategyAndCost( ins, scatter_sharding, strategy_map, cluster_env, call_graph, input_shardings_optional); - strategy_group->strategies.push_back(ShardingStrategy( + strategy_group->AddStrategy(ShardingStrategy( {name, scatter_sharding, compute_cost, communication_cost, memory_cost, std::move(resharding_costs.first), std::move(resharding_costs.second), input_shardings_optional})); @@ -356,9 +355,9 @@ BuildStrategyAndCost( const HloInstruction* scatter_update = scatter->scatter_updates()[0]; ForEachInCartesianProduct( - {strategy_map.at(scatter_data)->strategies, - strategy_map.at(scatter_indices)->strategies, - strategy_map.at(scatter_update)->strategies}, + {strategy_map.at(scatter_data)->GetStrategies(), + strategy_map.at(scatter_indices)->GetStrategies(), + strategy_map.at(scatter_update)->GetStrategies()}, [&](const std::vector& operand_shardings) { GenerateScatterShardingFromOperands( scatter, operand_shardings[0].output_sharding, @@ -396,7 +395,7 @@ BuildStrategyAndCost( ins, output_sharding, strategy_map, cluster_env, call_graph, input_shardings_optional); - strategy_group->strategies.push_back(ShardingStrategy( + strategy_group->AddStrategy(ShardingStrategy( {std::string(output_sharding.ToString()), output_sharding, compute_cost, communication_cost, memory_cost, std::move(resharding_costs.first), @@ -404,7 +403,7 @@ BuildStrategyAndCost( }; for (const ShardingStrategy& indices_strategy : - indices_strategy_group->strategies) { + indices_strategy_group->GetStrategies()) { const HloSharding& indices_spec = indices_strategy.output_sharding; const HloSharding& indices_to_combine_spec = hlo_sharding_util:: GatherOutputShardingFromIndexIndexPassthroughDimensions( @@ -420,7 +419,7 @@ BuildStrategyAndCost( } for (const ShardingStrategy& data_strategy : - data_strategy_group->strategies) { + data_strategy_group->GetStrategies()) { const HloSharding& data_spec = data_strategy.output_sharding; auto gather_parallel_dims = hlo_sharding_util::GetGatherParallelBatchDims(*ins, call_graph); @@ -499,9 +498,9 @@ BuildStrategyAndCost( } } } - AddReplicatedStrategy( - ins, ins->shape(), cluster_env, strategy_map, strategy_group, 0, - /* operands_to_consider_all_strategies_for */ {0}); + AddReplicatedStrategy(ins, ins->shape(), cluster_env, strategy_map, 0, + /* operands_to_consider_all_strategies_for */ {0}, + *strategy_group); break; } case HloOpcode::kBroadcast: { @@ -530,15 +529,13 @@ BuildStrategyAndCost( const HloInstruction* operand = ins->operand(0); // Create follow strategies - const StrategyGroup* src_strategy_group = - strategy_map.at(operand).get(); - CHECK(!src_strategy_group->is_tuple); - strategy_group->following = src_strategy_group; + const StrategyGroup& src_strategy_group = *strategy_map.at(operand); + CHECK(!src_strategy_group.is_tuple); + strategy_group->following = &src_strategy_group; - for (int64_t sid = 0; sid < src_strategy_group->strategies.size(); - ++sid) { + for (const auto& strategy : src_strategy_group.GetStrategies()) { HloSharding output_spec = Undefined(); - auto input_spec = src_strategy_group->strategies[sid].output_sharding; + const HloSharding& input_spec = strategy.output_sharding; if (opcode == HloOpcode::kTranspose) { output_spec = hlo_sharding_util::TransposeSharding( input_spec, ins->dimensions()); @@ -558,7 +555,7 @@ BuildStrategyAndCost( std::vector memory_resharding_costs = MemoryReshardingCostVector(src_strategy_group, operand->shape(), input_spec, cluster_env); - strategy_group->strategies.push_back( + strategy_group->AddStrategy( ShardingStrategy({name, output_spec, compute_cost, @@ -609,11 +606,9 @@ BuildStrategyAndCost( CHECK(!src_strategy_group->is_tuple); strategy_group->following = src_strategy_group; - for (int64_t sid = 0; sid < src_strategy_group->strategies.size(); - ++sid) { + for (const auto& strategy : src_strategy_group->GetStrategies()) { std::optional output_spec; - HloSharding input_spec = - src_strategy_group->strategies[sid].output_sharding; + const HloSharding& input_spec = strategy.output_sharding; double compute_cost = 0, communication_cost = 0; // Find output shardings. @@ -698,7 +693,7 @@ BuildStrategyAndCost( ins, *output_spec, strategy_map, cluster_env, call_graph, input_shardings); - strategy_group->strategies.push_back( + strategy_group->AddStrategy( ShardingStrategy({name, *output_spec, compute_cost, @@ -709,18 +704,18 @@ BuildStrategyAndCost( {input_spec}})); } - if (strategy_group->strategies.empty()) { + if (strategy_group->GetStrategies().empty()) { strategy_group->following = nullptr; - AddReplicatedStrategy(ins, ins->shape(), cluster_env, strategy_map, - strategy_group, 0); + AddReplicatedStrategy(ins, ins->shape(), cluster_env, strategy_map, 0, + {}, *strategy_group); } break; } case HloOpcode::kOptimizationBarrier: { - auto operand_strategies = strategy_map.at(ins->operand(0)).get(); + const auto& operand_strategy_group = *strategy_map.at(ins->operand(0)); strategy_group = MaybeFollowInsStrategyGroup( - operand_strategies, ins->shape(), instruction_id, strategy_groups, - cluster_env, pretrimmed_strategy_map); + operand_strategy_group, ins->shape(), instruction_id, + strategy_groups, cluster_env, pretrimmed_strategy_map); break; } case HloOpcode::kBitcast: { @@ -814,9 +809,10 @@ BuildStrategyAndCost( if (option.allow_recompute_heavy_op) { AddReplicatedStrategy( - ins, ins->shape(), cluster_env, strategy_map, strategy_group, + ins, ins->shape(), cluster_env, strategy_map, GetDotConvReplicationPenalty(ins, instruction_id, /* window */ 10, - sequence, hlo_cost_analysis)); + sequence, hlo_cost_analysis), + {}, *strategy_group); } break; } @@ -827,17 +823,18 @@ BuildStrategyAndCost( batch_dim_map, option, call_graph)); if (option.allow_recompute_heavy_op) { AddReplicatedStrategy( - ins, ins->shape(), cluster_env, strategy_map, strategy_group, + ins, ins->shape(), cluster_env, strategy_map, GetDotConvReplicationPenalty(ins, instruction_id, /* window */ 10, - sequence, hlo_cost_analysis)); + sequence, hlo_cost_analysis), + {}, *strategy_group); } break; } case HloOpcode::kRngGetAndUpdateState: { strategy_group = CreateLeafStrategyGroupWithoutInNodes(instruction_id, strategy_groups); - AddReplicatedStrategy(ins, ins->shape(), cluster_env, strategy_map, - strategy_group, 0); + AddReplicatedStrategy(ins, ins->shape(), cluster_env, strategy_map, 0, + {}, *strategy_group); break; } case HloOpcode::kIota: { @@ -853,16 +850,14 @@ BuildStrategyAndCost( } case HloOpcode::kTuple: { strategy_group = CreateTupleStrategyGroup(instruction_id); - strategy_group->childs.reserve(ins->operand_count()); for (size_t i = 0; i < ins->operand_count(); ++i) { const HloInstruction* operand = ins->operand(i); - const StrategyGroup* src_strategy_group = - strategy_map.at(operand).get(); + const StrategyGroup& src_strategy_group = *strategy_map.at(operand); auto child_strategies = MaybeFollowInsStrategyGroup( src_strategy_group, operand->shape(), instruction_id, strategy_groups, cluster_env, pretrimmed_strategy_map); child_strategies->tuple_element_idx = i; - strategy_group->childs.push_back(std::move(child_strategies)); + strategy_group->AddChild(std::move(child_strategies)); } if (ins->users().size() == 1 && @@ -878,13 +873,12 @@ BuildStrategyAndCost( } case HloOpcode::kGetTupleElement: { const HloInstruction* operand = ins->operand(0); - const StrategyGroup* src_strategy_group = - strategy_map.at(operand).get(); - CHECK(src_strategy_group->is_tuple); + const StrategyGroup& src_strategy_group = *strategy_map.at(operand); + CHECK(src_strategy_group.is_tuple); + const auto& src_children = src_strategy_group.GetChildren(); strategy_group = MaybeFollowInsStrategyGroup( - src_strategy_group->childs[ins->tuple_index()].get(), ins->shape(), - instruction_id, strategy_groups, cluster_env, - pretrimmed_strategy_map); + *src_children[ins->tuple_index()], ins->shape(), instruction_id, + strategy_groups, cluster_env, pretrimmed_strategy_map); break; } case HloOpcode::kCustomCall: { @@ -895,8 +889,6 @@ BuildStrategyAndCost( if (only_replicated) { if (ins->shape().IsTuple()) { strategy_group = CreateTupleStrategyGroup(instruction_id); - strategy_group->childs.reserve( - ins->shape().tuple_shapes_size()); for (size_t i = 0; i < ins->shape().tuple_shapes_size(); ++i) { std::unique_ptr child_strategies = @@ -904,16 +896,16 @@ BuildStrategyAndCost( strategy_map, strategy_groups); AddReplicatedStrategy(ins, ins->shape().tuple_shapes(i), cluster_env, strategy_map, - child_strategies, replicated_penalty); - strategy_group->childs.push_back( - std::move(child_strategies)); + replicated_penalty, {}, + *child_strategies); + strategy_group->AddChild(std::move(child_strategies)); } } else { strategy_group = CreateLeafStrategyGroup( instruction_id, ins, strategy_map, strategy_groups); AddReplicatedStrategy(ins, ins->shape(), cluster_env, - strategy_map, strategy_group, - replicated_penalty); + strategy_map, replicated_penalty, {}, + *strategy_group); } return; } @@ -951,8 +943,7 @@ BuildStrategyAndCost( // Follows operand 0's strategies if this custom-call op is // shardable and has the same input and output sizes. const HloInstruction* operand = ins->operand(0); - const StrategyGroup* src_strategy_group = - strategy_map.at(operand).get(); + const StrategyGroup& src_strategy_group = *strategy_map.at(operand); strategy_group = MaybeFollowInsStrategyGroup( src_strategy_group, ins->shape(), instruction_id, strategy_groups, cluster_env, pretrimmed_strategy_map); @@ -967,16 +958,15 @@ BuildStrategyAndCost( } case HloOpcode::kWhile: { strategy_group = CreateTupleStrategyGroup(instruction_id); - strategy_group->childs.reserve(ins->shape().tuple_shapes_size()); - const StrategyGroup* src_strategy_group = - strategy_map.at(ins->operand(0)).get(); + const auto& src_strategy_group = *strategy_map.at(ins->operand(0)); + const auto& src_children = src_strategy_group.GetChildren(); for (size_t i = 0; i < ins->shape().tuple_shapes_size(); ++i) { auto child_strategies = MaybeFollowInsStrategyGroup( - src_strategy_group->childs[i].get(), - ins->shape().tuple_shapes().at(i), instruction_id, - strategy_groups, cluster_env, pretrimmed_strategy_map); + *src_children[i], ins->shape().tuple_shapes().at(i), + instruction_id, strategy_groups, cluster_env, + pretrimmed_strategy_map); child_strategies->tuple_element_idx = i; - strategy_group->childs.push_back(std::move(child_strategies)); + strategy_group->AddChild(std::move(child_strategies)); } break; @@ -998,54 +988,55 @@ BuildStrategyAndCost( strategy_group = CreateLeafStrategyGroup(instruction_id, ins, strategy_map, strategy_groups); GenerateOutfeedStrategy(ins, ins->shape(), cluster_env, strategy_map, - strategy_group, replicated_penalty); + replicated_penalty, *strategy_group); break; } case HloOpcode::kRecv: case HloOpcode::kRecvDone: case HloOpcode::kSend: { strategy_group = CreateTupleStrategyGroup(instruction_id); - strategy_group->childs.reserve(ins->shape().tuple_shapes_size()); for (size_t i = 0; i < ins->shape().tuple_shapes_size(); ++i) { std::unique_ptr child_strategies = CreateLeafStrategyGroup(instruction_id, ins, strategy_map, strategy_groups); AddReplicatedStrategy(ins, ins->shape().tuple_shapes(i), cluster_env, - strategy_map, child_strategies, 0); + strategy_map, 0, {}, *child_strategies); child_strategies->tuple_element_idx = i; - strategy_group->childs.push_back(std::move(child_strategies)); + strategy_group->AddChild(std::move(child_strategies)); } break; } case HloOpcode::kSendDone: { strategy_group = CreateLeafStrategyGroup(instruction_id, ins, strategy_map, strategy_groups); - AddReplicatedStrategy(ins, ins->shape(), cluster_env, strategy_map, - strategy_group, 0); + AddReplicatedStrategy(ins, ins->shape(), cluster_env, strategy_map, 0, + {}, *strategy_group); break; } case HloOpcode::kAfterAll: { strategy_group = CreateLeafStrategyGroup(instruction_id, ins, strategy_map, strategy_groups); AddReplicatedStrategy(ins, ins->shape(), cluster_env, strategy_map, - strategy_group, replicated_penalty); + replicated_penalty, {}, *strategy_group); break; } default: LOG(FATAL) << "Unhandled instruction: " + ins->ToString(); } - RemoveDuplicatedStrategy(strategy_group); + CHECK(strategy_group != nullptr); + RemoveDuplicatedStrategy(*strategy_group); if (ins->has_sharding() && ins->opcode() != HloOpcode::kOutfeed) { // Finds the sharding strategy that aligns with the given sharding spec // Do not merge nodes if this one instruction has annotations. TrimOrGenerateStrategiesBasedOnExistingSharding( - ins->shape(), strategy_group.get(), strategy_map, instructions, - ins->sharding(), cluster_env, pretrimmed_strategy_map, call_graph, - option.nd_sharding_iteratively_strict_search_space); + ins->shape(), strategy_map, instructions, ins->sharding(), + cluster_env, pretrimmed_strategy_map, call_graph, + option.nd_sharding_iteratively_strict_search_space, *strategy_group); } if (!strategy_group->is_tuple && strategy_group->following) { - if (!LeafVectorsAreConsistent(strategy_group->strategies, - strategy_group->following->strategies)) { + if (!LeafVectorsAreConsistent( + strategy_group->GetStrategies(), + strategy_group->following->GetStrategies())) { // It confuses the solver if two instructions have different number of // sharding strategies but share the same ILP variable. The solver would // run much longer and/or return infeasible solutions. So if two @@ -1056,27 +1047,27 @@ BuildStrategyAndCost( strategy_group->following = nullptr; } } else if (strategy_group->is_tuple) { - for (size_t i = 0; i < strategy_group->childs.size(); i++) { - if (strategy_group->childs.at(i)->following && - !LeafVectorsAreConsistent( - strategy_group->childs.at(i)->strategies, - strategy_group->childs.at(i)->following->strategies)) { + for (size_t i = 0; i < strategy_group->GetChildren().size(); i++) { + auto& child = strategy_group->GetChildren().at(i); + if (child->following && + !LeafVectorsAreConsistent(child->GetStrategies(), + child->following->GetStrategies())) { CHECK(!is_follow_necessary_for_correctness) << "Reverting a following decision that is necessary for " "correctness. Please report this as a bug."; - strategy_group->childs.at(i)->following = nullptr; + child->following = nullptr; } } } if (!option.allow_shardings_small_dims_across_many_devices) { RemoveShardingsWhereSmallDimsShardedAcrossManyDevices( - ins->shape(), strategy_group.get(), - /* instruction_has_user_sharding */ ins->has_sharding()); + ins->shape(), /* instruction_has_user_sharding */ ins->has_sharding(), + *strategy_group); } if (instruction_execution_counts.contains(ins)) { - ScaleCostsWithExecutionCounts(strategy_group.get(), - instruction_execution_counts.at(ins)); + ScaleCostsWithExecutionCounts(instruction_execution_counts.at(ins), + *strategy_group); } else { VLOG(5) << "No execution count available for " << ins->name(); } @@ -1093,12 +1084,15 @@ BuildStrategyAndCost( CHECK(!strategy_group->is_tuple); std::vector new_strategies; int64_t idx = it - inst_indices.begin(); - for (const auto& stra : strategy_group->strategies) { - if (stra.name == stra_names[idx]) { - new_strategies.push_back(stra); + for (const auto& strategy : strategy_group->GetStrategies()) { + if (strategy.name == stra_names[idx]) { + new_strategies.push_back(strategy); } } - strategy_group->strategies = std::move(new_strategies); + strategy_group->ClearStrategies(); + for (const ShardingStrategy& strategy : new_strategies) { + strategy_group->AddStrategy(strategy); + } } } @@ -1109,10 +1103,11 @@ BuildStrategyAndCost( // the mesh shape we're trying does not match with the mesh shape used in // user specified shardings. So we disable the check in that situation. if (!trying_multiple_mesh_shapes) { - CHECK(strategy_group->is_tuple || !strategy_group->strategies.empty()) + CHECK(strategy_group->is_tuple || + !strategy_group->GetStrategies().empty()) << ins->ToString() << " does not have any valid strategies."; } else if (!(strategy_group->is_tuple || - !strategy_group->strategies.empty())) { + !strategy_group->GetStrategies().empty())) { return absl::Status( absl::StatusCode::kFailedPrecondition, "Could not generate any shardings for an instruction due " @@ -1121,7 +1116,7 @@ BuildStrategyAndCost( // Checks the shape of resharding_costs is valid. It will check fail if the // shape is not as expected. // CheckReshardingCostsShape(strategies.get()); - CheckMemoryCosts(strategy_group.get(), ins->shape()); + CheckMemoryCosts(*strategy_group, ins->shape()); strategy_map[ins] = std::move(strategy_group); } // end of for loop diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h b/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h index 9777b109e22b9..3a17e6f3d46b9 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h +++ b/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h @@ -183,18 +183,33 @@ struct StrategyGroup { std::vector in_nodes; // The followed strategy. Used for merging nodes. const StrategyGroup* following = nullptr; - // Used when is_tuple == False. Leaf strategy vector. - // A vector of strategy choices for the non-tuple output. - std::vector strategies; - // Used when is_tuple == True. A vector of pointers, each pointer is one - // StrategyGroup for one value in the output Tuple - std::vector> childs; // The index of this instruction in the HLO operand (or tuple shape) list. std::optional tuple_element_idx; - std::string ToString(size_t indention = 0) const { + StrategyGroup() = default; + + StrategyGroup(bool is_tuple, NodeIdx node_idx, size_t instruction_id) + : is_tuple(is_tuple), + node_idx(node_idx), + instruction_id(instruction_id) {} + + StrategyGroup(bool is_tuple, NodeIdx node_idx, size_t instruction_id, + const std::vector& in_nodes, + const StrategyGroup* following, + const std::vector& strategies) + : is_tuple(is_tuple), + node_idx(node_idx), + instruction_id(instruction_id), + in_nodes(in_nodes), + following(following) { + for (const ShardingStrategy& strategy : strategies) { + AddStrategy(strategy); + } + } + + std::string ToString(size_t indentation = 0) const { std::string str; - const std::string indent(indention, ' '); + const std::string indent(indentation, ' '); absl::StrAppend(&str, indent, "node_idx: ", node_idx, "\n"); absl::StrAppend(&str, indent, "instruction id: ", instruction_id, "\n"); absl::StrAppend(&str, indent, "is_tuple: ", is_tuple, "\n"); @@ -214,9 +229,9 @@ struct StrategyGroup { " instruction_id=", i->instruction_id, "\n"); } if (is_tuple) { - for (size_t i = 0; i < childs.size(); ++i) { + for (size_t i = 0; i < children.size(); ++i) { absl::StrAppend(&str, indent, "Tuple element #", i, ":\n"); - absl::StrAppend(&str, childs[i]->ToString(indention + 2)); + absl::StrAppend(&str, children[i]->ToString(indentation + 2)); } } else { for (const auto& strategy : strategies) { @@ -229,8 +244,8 @@ struct StrategyGroup { const StrategyGroup* GetSubStrategyGroup(const ShapeIndex& index) const { const StrategyGroup* result = this; for (auto index_element : index) { - CHECK_LE(index_element, result->childs.size()); - result = result->childs.at(index_element).get(); + CHECK_LE(index_element, result->children.size()); + result = result->children.at(index_element).get(); } return result; } @@ -238,7 +253,7 @@ struct StrategyGroup { void ForEachLeafStrategyGroup( absl::FunctionRef fn) const { if (is_tuple) { - for (const std::unique_ptr& child : childs) { + for (const std::unique_ptr& child : children) { fn(*child); } } else { @@ -248,13 +263,51 @@ struct StrategyGroup { void ForEachLeafStrategyGroup(absl::FunctionRef fn) { if (is_tuple) { - for (std::unique_ptr& child : childs) { + for (std::unique_ptr& child : children) { fn(*child); } } else { fn(*this); } } + + //////// Accessor methods for strategies //////// + + void AddStrategy(const ShardingStrategy& strategy) { + strategies.push_back(strategy); + } + + void ClearStrategies() { strategies.clear(); } + + ShardingStrategy& GetStrategy(size_t strategy_idx) { + return strategies[strategy_idx]; + } + + const std::vector& GetStrategies() const { + return strategies; + } + + //////// Accessor methods for children //////// + + void AddChild(std::unique_ptr child) { + children.push_back(std::move(child)); + } + + void ClearChildren() { children.clear(); } + + StrategyGroup& GetChild(size_t child_idx) { return *children[child_idx]; } + + const std::vector>& GetChildren() const { + return children; + } + + private: + // Used when is_tuple == False. Leaf strategy vector. + // A vector of strategy choices for the non-tuple output. + std::vector strategies; + // Used when is_tuple == True. A vector of pointers, each pointer is one + // StrategyGroup for one value in the output Tuple + std::vector> children; }; // Type aliases. diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc b/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc index 7dc7a456957bc..d12de663a115d 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc +++ b/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc @@ -826,30 +826,28 @@ bool AllInfinityCosts( // that were not intended to be replicated when being generating, but ending up // being replicated, which could happen when, for example, generating 2D // sharding for a 1D mesh shape. -void RemoveDuplicatedStrategy(std::unique_ptr& strategy_group) { - if (strategy_group->is_tuple) { - for (auto& child : strategy_group->childs) { - RemoveDuplicatedStrategy(child); +void RemoveDuplicatedStrategy(StrategyGroup& strategy_group) { + if (strategy_group.is_tuple) { + for (auto& child : strategy_group.GetChildren()) { + RemoveDuplicatedStrategy(*child); } return; } - if (strategy_group->following || strategy_group->strategies.empty()) { + if (strategy_group.following || strategy_group.GetStrategies().empty()) { return; } std::vector new_vector; std::vector deduped_replicated_strategies; absl::flat_hash_set added; size_t num_skipped_due_to_infinity_costs = 0; - for (size_t i = 0; i < strategy_group->strategies.size(); ++i) { - if (AllInfinityCosts( - strategy_group->strategies[i].communication_resharding_costs)) { + for (const ShardingStrategy& strategy : strategy_group.GetStrategies()) { + if (AllInfinityCosts(strategy.communication_resharding_costs)) { num_skipped_due_to_infinity_costs++; continue; } - std::string key = strategy_group->strategies[i].output_sharding.ToString(); - if (!strategy_group->strategies[i].input_shardings.empty()) { - for (const auto& sharding : - strategy_group->strategies[i].input_shardings) { + std::string key = strategy.output_sharding.ToString(); + if (!strategy.input_shardings.empty()) { + for (const auto& sharding : strategy.input_shardings) { key += "/" + (sharding.has_value() ? sharding->ToString() : "none"); } } @@ -857,14 +855,14 @@ void RemoveDuplicatedStrategy(std::unique_ptr& strategy_group) { continue; } added.insert(key); - if (!strategy_group->strategies[i].output_sharding.IsReplicated()) { - new_vector.push_back(std::move(strategy_group->strategies[i])); + if (!strategy.output_sharding.IsReplicated()) { + new_vector.push_back(strategy); } else { - deduped_replicated_strategies.push_back( - std::move(strategy_group->strategies[i])); + deduped_replicated_strategies.push_back(strategy); } } - CHECK_LT(num_skipped_due_to_infinity_costs, strategy_group->strategies.size()) + CHECK_LT(num_skipped_due_to_infinity_costs, + strategy_group.GetStrategies().size()) << "All strategies removed due to infinite resharding costs"; // Keeps replicated strategies as the last ones. if (!deduped_replicated_strategies.empty()) { @@ -872,7 +870,10 @@ void RemoveDuplicatedStrategy(std::unique_ptr& strategy_group) { new_vector.push_back(std::move(deduped_replicated_strategies[i])); } } - strategy_group->strategies = std::move(new_vector); + strategy_group.ClearStrategies(); + for (const ShardingStrategy& strategy : new_vector) { + strategy_group.AddStrategy(strategy); + } } bool IsDivisible(const HloInstruction* ins, const DeviceMesh& device_mesh, @@ -1770,13 +1771,13 @@ AliasSet BuildAliasSet(const HloModule* module, traverse_tuple_alias; traverse_tuple_alias = [&](const StrategyGroup* src_strategy_group, const StrategyGroup* dst_strategy_group) { + const auto& src_children = src_strategy_group->GetChildren(); + const auto& dst_children = dst_strategy_group->GetChildren(); if (src_strategy_group->is_tuple) { CHECK(dst_strategy_group->is_tuple); - CHECK_EQ(src_strategy_group->childs.size(), - dst_strategy_group->childs.size()); - for (size_t i = 0; i < src_strategy_group->childs.size(); ++i) { - traverse_tuple_alias(src_strategy_group->childs[i].get(), - dst_strategy_group->childs[i].get()); + CHECK_EQ(src_children.size(), dst_children.size()); + for (size_t i = 0; i < src_children.size(); ++i) { + traverse_tuple_alias(src_children[i].get(), dst_children[i].get()); } } else { alias_set.insert( @@ -1794,17 +1795,19 @@ AliasSet BuildAliasSet(const HloModule* module, HloInstruction* param_ins = parameter_instructions[alias.parameter_number]; if (alias.parameter_index.empty()) { - traverse_tuple_alias( - strategy_map.at(param_ins).get(), - strategy_map.at(output_tuple)->childs[output_index.front()].get()); + traverse_tuple_alias(strategy_map.at(param_ins).get(), + strategy_map.at(output_tuple) + ->GetChildren()[output_index.front()] + .get()); } else { // parameter_instructions[alias.parameter_number] is a tuple. // alias.parameter_index.size() == 1 per the CHECK_LT statement. - traverse_tuple_alias( - strategy_map.at(param_ins) - ->childs[alias.parameter_index.front()] - .get(), - strategy_map.at(output_tuple)->childs[output_index.front()].get()); + traverse_tuple_alias(strategy_map.at(param_ins) + ->GetChildren()[alias.parameter_index.front()] + .get(), + strategy_map.at(output_tuple) + ->GetChildren()[output_index.front()] + .get()); } }); @@ -1851,13 +1854,15 @@ absl::Status CheckAliasSetCompatibility(const AliasSet& alias_set, size_t compatible_cnt = 0; bool replicated = false; - for (size_t i = 0; i < src_strategy_group->strategies.size(); ++i) { - for (size_t j = 0; j < dst_strategy_group->strategies.size(); ++j) { - if (src_strategy_group->strategies[i].output_sharding == - dst_strategy_group->strategies[j].output_sharding) { + for (size_t i = 0; i < src_strategy_group->GetStrategies().size(); ++i) { + const HloSharding& src_sharding = + src_strategy_group->GetStrategies()[i].output_sharding; + for (size_t j = 0; j < dst_strategy_group->GetStrategies().size(); ++j) { + const HloSharding& dst_sharding = + dst_strategy_group->GetStrategies()[j].output_sharding; + if (src_sharding == dst_sharding) { compatible_cnt += 1; - if (src_strategy_group->strategies[i] - .output_sharding.IsReplicated()) { + if (src_sharding.IsReplicated()) { replicated = true; } } @@ -1865,8 +1870,8 @@ absl::Status CheckAliasSetCompatibility(const AliasSet& alias_set, } if (compatible_cnt == 1 && - (replicated && (src_strategy_group->strategies.size() > 1 || - dst_strategy_group->strategies.size() > 1))) { + (replicated && (src_strategy_group->GetStrategies().size() > 1 || + dst_strategy_group->GetStrategies().size() > 1))) { LOG(WARNING) << "Alias pair has only replicated strategy in common. This " "will result in choosing replicated strategy for these " @@ -1909,13 +1914,16 @@ absl::StatusOr ComputeAliasCompatibility( const StrategyGroup* dst_strategy_group, const std::vector& instructions) { AliasCompatibility alias_compatibility; - for (size_t i = 0; i < src_strategy_group->strategies.size(); ++i) { - for (size_t j = 0; j < dst_strategy_group->strategies.size(); ++j) { - if (src_strategy_group->strategies[i].output_sharding == - dst_strategy_group->strategies[j].output_sharding) { + for (size_t i = 0; i < src_strategy_group->GetStrategies().size(); ++i) { + const HloSharding& src_sharding = + src_strategy_group->GetStrategies()[i].output_sharding; + for (size_t j = 0; j < dst_strategy_group->GetStrategies().size(); ++j) { + const HloSharding& dst_sharding = + dst_strategy_group->GetStrategies()[j].output_sharding; + if (src_sharding == dst_sharding) { alias_compatibility.src_compatible.push_back(i); alias_compatibility.dst_compatible.push_back(j); - if (src_strategy_group->strategies[i].output_sharding.IsReplicated()) { + if (src_sharding.IsReplicated()) { alias_compatibility.replicated = true; } } @@ -1923,9 +1931,10 @@ absl::StatusOr ComputeAliasCompatibility( } int compatible_cnt = alias_compatibility.src_compatible.size(); - if (compatible_cnt == 1 && (alias_compatibility.replicated && - (src_strategy_group->strategies.size() > 1 || - dst_strategy_group->strategies.size() > 1))) { + if (compatible_cnt == 1 && + (alias_compatibility.replicated && + (src_strategy_group->GetStrategies().size() > 1 || + dst_strategy_group->GetStrategies().size() > 1))) { LOG(WARNING) << "Alias pair has only replicated strategy in common. This " "will result in choosing replicated strategy for these " "tensors and may result in large memory consumption: " @@ -2024,7 +2033,7 @@ absl::Status RemoveFollowersIfMismatchedStrategies( if (auto it = followee_root_valid_strategies.find(idx); it != followee_root_valid_strategies.end()) { for (auto strategy_idx : it->second) { - if (strategy_group.strategies[strategy_idx].compute_cost != + if (strategy_group.GetStrategies()[strategy_idx].compute_cost != kInfinityCost) { return; }