diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding.cc b/xla/hlo/experimental/auto_sharding/auto_sharding.cc index ea48e81f1a4601..bf1a454198a130 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding.cc +++ b/xla/hlo/experimental/auto_sharding/auto_sharding.cc @@ -1468,9 +1468,12 @@ void TrimOrGenerateStrategiesBasedOnExistingSharding( // Sharding provided by XLA users, we need to keep them. strategy_group.following = nullptr; std::vector> new_strategies; - for (size_t sid = 0; sid < strategy_group.GetStrategies().size(); ++sid) { - const ShardingStrategy& strategy = strategy_group.GetStrategy(sid); - const auto& input_shardings = strategy_group.GetInputShardings(sid); + const auto& strategy_input_shardings = + strategy_group.GetStrategyInputShardings(); + for (size_t iid = 0; iid < strategy_input_shardings.size(); ++iid) { + const InputShardings& input_shardings = strategy_input_shardings[iid]; + const ShardingStrategy& strategy = + strategy_group.GetStrategyForInputShardings(iid); if (strategy.output_sharding == existing_sharding) { VLOG(1) << "Keeping strategy: " << strategy.ToString(); new_strategies.push_back({strategy, input_shardings}); @@ -1566,9 +1569,12 @@ void TrimOrGenerateStrategiesBasedOnExistingSharding( // It is IMPORTANT that we do this only for instructions that do no follow // others, to keep the number of ILP variable small. std::vector> new_vector; - for (size_t sid = 0; sid < strategy_group.GetStrategies().size(); ++sid) { - const ShardingStrategy& strategy = strategy_group.GetStrategy(sid); - const auto& input_shardings = strategy_group.GetInputShardings(sid); + const auto& strategy_input_shardings = + strategy_group.GetStrategyInputShardings(); + for (size_t iid = 0; iid < strategy_input_shardings.size(); ++iid) { + const InputShardings& input_shardings = strategy_input_shardings[iid]; + const ShardingStrategy& strategy = + strategy_group.GetStrategyForInputShardings(iid); if (strategy.output_sharding.IsReplicated() || ShardingIsConsistent(existing_sharding, strategy.output_sharding, strict) || @@ -3355,9 +3361,12 @@ absl::Status FilterStrategy(const HloInstruction* ins, const Shape& shape, } std::vector> new_strategies; - for (size_t sid = 0; sid < strategy_group.GetStrategies().size(); ++sid) { - const ShardingStrategy& strategy = strategy_group.GetStrategy(sid); - const auto& input_shardings = strategy_group.GetInputShardings(sid); + const auto& strategy_input_shardings = + strategy_group.GetStrategyInputShardings(); + for (size_t iid = 0; iid < strategy_input_shardings.size(); ++iid) { + const InputShardings& input_shardings = strategy_input_shardings[iid]; + const ShardingStrategy& strategy = + strategy_group.GetStrategyForInputShardings(iid); const HloSharding& output_sharding = strategy.output_sharding; const std::vector tensor_dim_to_mesh_dim = cluster_env.GetTensorDimToMeshDimWrapper(shape, output_sharding); diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h b/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h index d63e5bee65007f..4190993d33c15c 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h +++ b/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h @@ -147,7 +147,7 @@ inline const InputShardings& GetInputShardings( CHECK(!strategy_group->is_tuple); NodeIdx node_idx = strategy_group->node_idx; NodeStrategyIdx stra_idx = cost_graph.RemapIndex(node_idx, s_val[node_idx]); - return strategy_group->GetInputShardings(stra_idx); + return strategy_group->GetInputShardingsForStrategy(stra_idx); } // Get the final sharding strategy according to the ILP solution. @@ -181,7 +181,7 @@ inline const InputShardings& GetInputShardingsForTuple( } NodeIdx node_idx = strategy_group->node_idx; NodeStrategyIdx stra_idx = cost_graph.RemapIndex(node_idx, s_val[node_idx]); - return strategy_group->GetInputShardings(stra_idx); + return strategy_group->GetInputShardingsForStrategy(stra_idx); } } // namespace spmd diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc b/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc index 85898c8b4e6dfd..0c2b7925d2025e 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc +++ b/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc @@ -460,9 +460,12 @@ std::optional HandlerBase::GetShardingFromUser( void HandlerBase::SortStrategies() { std::vector> strategy_shardings; - for (size_t sid = 0; sid < strategy_group_->GetStrategies().size(); ++sid) { - const ShardingStrategy& strategy = strategy_group_->GetStrategy(sid); - const auto& input_shardings = strategy_group_->GetInputShardings(sid); + const auto strategy_input_shardings = + strategy_group_->GetStrategyInputShardings(); + for (size_t iid = 0; iid < strategy_input_shardings.size(); ++iid) { + const InputShardings& input_shardings = strategy_input_shardings[iid]; + const ShardingStrategy& strategy = + strategy_group_->GetStrategyForInputShardings(iid); strategy_shardings.push_back({strategy, input_shardings}); } absl::c_stable_sort( diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc b/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc index 121917b3cc0b41..b40c1211df8c2b 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc +++ b/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc @@ -1085,10 +1085,12 @@ BuildStrategyAndCost( CHECK(!strategy_group->is_tuple); std::vector> new_strategies; int64_t idx = it - inst_indices.begin(); - const auto& strategies = strategy_group->GetStrategies(); - for (size_t sid = 0; sid < strategies.size(); ++sid) { - const ShardingStrategy& strategy = strategy_group->GetStrategy(sid); - const auto& input_shardings = strategy_group->GetInputShardings(sid); + const auto& strategy_input_shardings = + strategy_group->GetStrategyInputShardings(); + for (size_t iid = 0; iid < strategy_input_shardings.size(); ++iid) { + const InputShardings& input_shardings = strategy_input_shardings[iid]; + const ShardingStrategy& strategy = + strategy_group->GetStrategyForInputShardings(iid); if (strategy.name == stra_names[idx]) { new_strategies.push_back({strategy, input_shardings}); } diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h b/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h index 8327f6b3e58ef4..c78e5dc0a68a69 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h +++ b/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h @@ -16,8 +16,10 @@ limitations under the License. #ifndef XLA_HLO_EXPERIMENTAL_AUTO_SHARDING_AUTO_SHARDING_STRATEGY_H_ #define XLA_HLO_EXPERIMENTAL_AUTO_SHARDING_AUTO_SHARDING_STRATEGY_H_ +#include #include #include +#include #include #include #include @@ -253,21 +255,49 @@ struct StrategyGroup { void AddStrategy(const ShardingStrategy& strategy, const InputShardings& input_shardings = {}) { - strategies.push_back(strategy); + // Create a new strategy if needed, otherwise reuse an existing one. + size_t strategy_idx = strategies.size(); + const size_t input_sharding_idx = strategy_input_shardings.size(); + const auto it = std::find(strategies.begin(), strategies.end(), strategy); + if (it == strategies.end()) { + strategies.push_back(strategy); + strategy_idx_to_input_sharding_idx.push_back(input_sharding_idx); + } else { + strategy_idx = std::distance(strategies.begin(), it); + } + input_sharding_idx_to_strategy_idx.push_back(strategy_idx); strategy_input_shardings.push_back(input_shardings); } void ClearStrategies() { strategies.clear(); strategy_input_shardings.clear(); + input_sharding_idx_to_strategy_idx.clear(); + strategy_idx_to_input_sharding_idx.clear(); } ShardingStrategy& GetStrategy(size_t strategy_idx) { return strategies[strategy_idx]; } - const InputShardings& GetInputShardings(size_t strategy_idx) const { - return strategy_input_shardings[strategy_idx]; + const ShardingStrategy& GetStrategyForInputShardings( + size_t input_sharding_idx) const { + const size_t strategy_idx = + input_sharding_idx_to_strategy_idx[input_sharding_idx]; + CHECK_LT(strategy_idx, strategies.size()); + return strategies[strategy_idx]; + } + + const InputShardings& GetInputShardings(size_t input_sharding_idx) const { + return strategy_input_shardings[input_sharding_idx]; + } + + const InputShardings& GetInputShardingsForStrategy( + size_t strategy_idx) const { + const size_t input_sharding_idx = + strategy_idx_to_input_sharding_idx[strategy_idx]; + CHECK_LT(input_sharding_idx, strategy_input_shardings.size()); + return strategy_input_shardings[input_sharding_idx]; } const std::vector& GetStrategies() const { @@ -297,6 +327,8 @@ struct StrategyGroup { // A vector of strategy choices for the non-tuple output. std::vector strategies; std::vector strategy_input_shardings; + std::vector input_sharding_idx_to_strategy_idx; + std::vector strategy_idx_to_input_sharding_idx; // Used when is_tuple == True. A vector of pointers, each pointer is one // StrategyGroup for one value in the output Tuple diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc b/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc index adcc185f969ddd..b1176a85e732db 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc +++ b/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc @@ -841,14 +841,17 @@ void RemoveDuplicatedStrategy(StrategyGroup& strategy_group) { deduped_replicated_strategies; absl::flat_hash_set added; size_t num_skipped_due_to_infinity_costs = 0; - for (size_t sid = 0; sid < strategy_group.GetStrategies().size(); ++sid) { - const ShardingStrategy& strategy = strategy_group.GetStrategy(sid); + const auto& strategy_input_shardings = + strategy_group.GetStrategyInputShardings(); + for (size_t iid = 0; iid < strategy_input_shardings.size(); ++iid) { + const InputShardings& input_shardings = strategy_input_shardings[iid]; + const ShardingStrategy& strategy = + strategy_group.GetStrategyForInputShardings(iid); if (AllInfinityCosts(strategy.communication_resharding_costs)) { num_skipped_due_to_infinity_costs++; continue; } std::string key = strategy.output_sharding.ToString(); - const auto& input_shardings = strategy_group.GetInputShardings(sid); if (!input_shardings.empty()) { for (const auto& sharding : input_shardings) { key += "/" + (sharding.has_value() ? sharding->ToString() : "none"); @@ -864,8 +867,7 @@ void RemoveDuplicatedStrategy(StrategyGroup& strategy_group) { deduped_replicated_strategies.push_back({strategy, input_shardings}); } } - CHECK_LT(num_skipped_due_to_infinity_costs, - strategy_group.GetStrategies().size()) + CHECK_LT(num_skipped_due_to_infinity_costs, strategy_input_shardings.size()) << "All strategies removed due to infinite resharding costs"; // Keeps replicated strategies as the last ones. if (!deduped_replicated_strategies.empty()) {