diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding.cc b/xla/hlo/experimental/auto_sharding/auto_sharding.cc index 350f181890e47..ec9dc5f0916c5 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding.cc +++ b/xla/hlo/experimental/auto_sharding/auto_sharding.cc @@ -1476,9 +1476,12 @@ void TrimOrGenerateStrategiesBasedOnExistingSharding( // Sharding provided by XLA users, we need to keep them. strategy_group.following = nullptr; std::vector<std::pair<ShardingStrategy, InputShardings>> new_strategies; - for (size_t sid = 0; sid < strategy_group.GetStrategies().size(); ++sid) { - const ShardingStrategy& strategy = strategy_group.GetStrategy(sid); - const auto& input_shardings = strategy_group.GetInputShardings(sid); + const auto& strategy_input_shardings = + strategy_group.GetStrategyInputShardings(); + for (size_t iid = 0; iid < strategy_input_shardings.size(); ++iid) { + const InputShardings& input_shardings = strategy_input_shardings[iid]; + const ShardingStrategy& strategy = + strategy_group.GetStrategyForInputShardings(iid); if (strategy.output_sharding == existing_sharding) { VLOG(1) << "Keeping strategy: " << strategy.ToString(); new_strategies.push_back({strategy, input_shardings}); @@ -1574,9 +1577,12 @@ void TrimOrGenerateStrategiesBasedOnExistingSharding( // It is IMPORTANT that we do this only for instructions that do no follow // others, to keep the number of ILP variable small. 
std::vector<std::pair<ShardingStrategy, InputShardings>> new_vector; - for (size_t sid = 0; sid < strategy_group.GetStrategies().size(); ++sid) { - const ShardingStrategy& strategy = strategy_group.GetStrategy(sid); - const auto& input_shardings = strategy_group.GetInputShardings(sid); + const auto& strategy_input_shardings = + strategy_group.GetStrategyInputShardings(); + for (size_t iid = 0; iid < strategy_input_shardings.size(); ++iid) { + const InputShardings& input_shardings = strategy_input_shardings[iid]; + const ShardingStrategy& strategy = + strategy_group.GetStrategyForInputShardings(iid); if (strategy.output_sharding.IsReplicated() || ShardingIsConsistent(existing_sharding, strategy.output_sharding, strict) || @@ -3247,9 +3253,12 @@ absl::Status FilterStrategy(const HloInstruction* ins, const Shape& shape, } std::vector<std::pair<ShardingStrategy, InputShardings>> new_strategies; - for (size_t sid = 0; sid < strategy_group.GetStrategies().size(); ++sid) { - const ShardingStrategy& strategy = strategy_group.GetStrategy(sid); - const auto& input_shardings = strategy_group.GetInputShardings(sid); + const auto& strategy_input_shardings = + strategy_group.GetStrategyInputShardings(); + for (size_t iid = 0; iid < strategy_input_shardings.size(); ++iid) { + const InputShardings& input_shardings = strategy_input_shardings[iid]; + const ShardingStrategy& strategy = + strategy_group.GetStrategyForInputShardings(iid); const HloSharding& output_sharding = strategy.output_sharding; const std::vector<int64_t> tensor_dim_to_mesh_dim = cluster_env.GetTensorDimToMeshDimWrapper(shape, output_sharding); diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h b/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h index d63e5bee65007..4190993d33c15 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h +++ b/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h @@ -147,7 +147,7 @@ inline const InputShardings& GetInputShardings( CHECK(!strategy_group->is_tuple); NodeIdx node_idx = strategy_group->node_idx; 
NodeStrategyIdx stra_idx = cost_graph.RemapIndex(node_idx, s_val[node_idx]); - return strategy_group->GetInputShardings(stra_idx); + return strategy_group->GetInputShardingsForStrategy(stra_idx); } // Get the final sharding strategy according to the ILP solution. @@ -181,7 +181,7 @@ inline const InputShardings& GetInputShardingsForTuple( } NodeIdx node_idx = strategy_group->node_idx; NodeStrategyIdx stra_idx = cost_graph.RemapIndex(node_idx, s_val[node_idx]); - return strategy_group->GetInputShardings(stra_idx); + return strategy_group->GetInputShardingsForStrategy(stra_idx); } } // namespace spmd diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc b/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc index 237ffa4e3d4f4..f80958b099ff4 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc +++ b/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc @@ -460,9 +460,12 @@ std::optional<HloSharding> HandlerBase::GetShardingFromUser( void HandlerBase::SortStrategies() { std::vector<std::pair<ShardingStrategy, InputShardings>> strategy_shardings; - for (size_t sid = 0; sid < strategy_group_->GetStrategies().size(); ++sid) { - const ShardingStrategy& strategy = strategy_group_->GetStrategy(sid); - const auto& input_shardings = strategy_group_->GetInputShardings(sid); + const auto strategy_input_shardings = + strategy_group_->GetStrategyInputShardings(); + for (size_t iid = 0; iid < strategy_input_shardings.size(); ++iid) { + const InputShardings& input_shardings = strategy_input_shardings[iid]; + const ShardingStrategy& strategy = + strategy_group_->GetStrategyForInputShardings(iid); strategy_shardings.push_back({strategy, input_shardings}); } absl::c_stable_sort( diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc b/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc index f84cd704cda23..91d40860c7325 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc +++ 
b/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc @@ -66,7 +66,13 @@ namespace spmd { bool LeafVectorsAreConsistent(const std::vector<ShardingStrategy>& one, const std::vector<ShardingStrategy>& two) { - return one.size() == two.size(); + if (one.size() != two.size()) return false; + for (size_t sid = 0; sid < one.size(); ++sid) { + const bool invalid_strategy_one = (one[sid].compute_cost >= kInfinityCost); + const bool invalid_strategy_two = (two[sid].compute_cost >= kInfinityCost); + if (invalid_strategy_one != invalid_strategy_two) return false; + } + return true; } std::optional<HloSharding> ConstructImprovedSharding( @@ -1026,6 +1032,11 @@ BuildStrategyAndCost( } CHECK(strategy_group != nullptr); RemoveDuplicatedStrategy(*strategy_group); + if (!option.allow_shardings_small_dims_across_many_devices) { + RemoveShardingsWhereSmallDimsShardedAcrossManyDevices( + ins->shape(), /* instruction_has_user_sharding */ ins->has_sharding(), + *strategy_group); + } if (ins->has_sharding() && ins->opcode() != HloOpcode::kOutfeed) { // Finds the sharding strategy that aligns with the given sharding spec // Do not merge nodes if this one instruction has annotations. 
@@ -1060,11 +1071,6 @@ BuildStrategyAndCost( } } } - if (!option.allow_shardings_small_dims_across_many_devices) { - RemoveShardingsWhereSmallDimsShardedAcrossManyDevices( - ins->shape(), /* instruction_has_user_sharding */ ins->has_sharding(), - *strategy_group); - } if (instruction_execution_counts.contains(ins)) { ScaleCostsWithExecutionCounts(instruction_execution_counts.at(ins), @@ -1085,10 +1091,12 @@ BuildStrategyAndCost( CHECK(!strategy_group->is_tuple); std::vector<std::pair<ShardingStrategy, InputShardings>> new_strategies; int64_t idx = it - inst_indices.begin(); - const auto& strategies = strategy_group->GetStrategies(); - for (size_t sid = 0; sid < strategies.size(); ++sid) { - const ShardingStrategy& strategy = strategy_group->GetStrategy(sid); - const auto& input_shardings = strategy_group->GetInputShardings(sid); + const auto& strategy_input_shardings = + strategy_group->GetStrategyInputShardings(); + for (size_t iid = 0; iid < strategy_input_shardings.size(); ++iid) { + const InputShardings& input_shardings = strategy_input_shardings[iid]; + const ShardingStrategy& strategy = + strategy_group->GetStrategyForInputShardings(iid); if (strategy.name == stra_names[idx]) { new_strategies.push_back({strategy, input_shardings}); } diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h b/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h index 8327f6b3e58ef..04c15b20e9aa1 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h +++ b/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h @@ -16,8 +16,10 @@ limitations under the License. 
#ifndef XLA_HLO_EXPERIMENTAL_AUTO_SHARDING_AUTO_SHARDING_STRATEGY_H_ #define XLA_HLO_EXPERIMENTAL_AUTO_SHARDING_AUTO_SHARDING_STRATEGY_H_ +#include <algorithm> #include <cstddef> #include <cstdint> +#include <iterator> #include <memory> #include <optional> #include <string> @@ -213,7 +215,32 @@ struct StrategyGroup { } } else { for (const auto& strategy : strategies) { - absl::StrAppend(&str, indent, "Strategy ", strategy.ToStringLong()); + absl::StrAppend(&str, indent, "Strategy ", strategy.ToStringLong(), + "\n"); + } + } + if (!is_tuple) { + for (const auto& input_shardings : strategy_input_shardings) { + std::string input_sharding_str = "{"; + for (const auto& s : input_shardings) { + if (!s.has_value()) { + input_sharding_str += "[*],"; + } else if (s->IsReplicated()) { + input_sharding_str += "[R],"; + } else { + if (s->ReplicateOnLastTileDim()) { + input_sharding_str += + "[" + absl::StrJoin(s->tile_assignment().dimensions(), ", ") + + "]last_tile_dim_replicate,"; + } else { + input_sharding_str += + "[" + absl::StrJoin(s->tile_assignment().dimensions(), ", ") + + "],"; + } + } + } + input_sharding_str += "}\n"; + absl::StrAppend(&str, indent, "Input Sharding ", input_sharding_str); } } return str; @@ -253,21 +280,49 @@ struct StrategyGroup { void AddStrategy(const ShardingStrategy& strategy, const InputShardings& input_shardings = {}) { - strategies.push_back(strategy); + // Create a new strategy if needed (otherwise, reuse an existing one). 
+ size_t strategy_idx = strategies.size(); + const size_t input_sharding_idx = strategy_input_shardings.size(); + const auto it = std::find(strategies.begin(), strategies.end(), strategy); + if (it == strategies.end()) { + strategies.push_back(strategy); + strategy_idx_to_input_sharding_idx.push_back(input_sharding_idx); + } else { + strategy_idx = std::distance(strategies.begin(), it); + } + input_sharding_idx_to_strategy_idx.push_back(strategy_idx); strategy_input_shardings.push_back(input_shardings); } void ClearStrategies() { strategies.clear(); strategy_input_shardings.clear(); + input_sharding_idx_to_strategy_idx.clear(); + strategy_idx_to_input_sharding_idx.clear(); } ShardingStrategy& GetStrategy(size_t strategy_idx) { return strategies[strategy_idx]; } - const InputShardings& GetInputShardings(size_t strategy_idx) const { - return strategy_input_shardings[strategy_idx]; + const ShardingStrategy& GetStrategyForInputShardings( + size_t input_sharding_idx) const { + const size_t strategy_idx = + input_sharding_idx_to_strategy_idx[input_sharding_idx]; + CHECK_LT(strategy_idx, strategies.size()); + return strategies[strategy_idx]; + } + + const InputShardings& GetInputShardings(size_t input_sharding_idx) const { + return strategy_input_shardings[input_sharding_idx]; + } + + const InputShardings& GetInputShardingsForStrategy( + size_t strategy_idx) const { + const size_t input_sharding_idx = + strategy_idx_to_input_sharding_idx[strategy_idx]; + CHECK_LT(input_sharding_idx, strategy_input_shardings.size()); + return strategy_input_shardings[input_sharding_idx]; } const std::vector<ShardingStrategy>& GetStrategies() const { @@ -297,6 +352,8 @@ struct StrategyGroup { // A vector of strategy choices for the non-tuple output. std::vector<ShardingStrategy> strategies; std::vector<InputShardings> strategy_input_shardings; + std::vector<size_t> input_sharding_idx_to_strategy_idx; + std::vector<size_t> strategy_idx_to_input_sharding_idx; // Used when is_tuple == True. 
A vector of pointers, each pointer is one // StrategyGroup for one value in the output Tuple diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc b/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc index 9de0cfd3fedf3..2b0c2aec59e6f 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc +++ b/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc @@ -1145,7 +1145,8 @@ ENTRY %entry { ASSERT_NE(dot, nullptr); EXPECT_THAT(param0, op::Sharding("{devices=[4,1]0,1,2,3}")); EXPECT_THAT(param1, op::Sharding("{replicated}")); - EXPECT_THAT(dot, op::Sharding("{devices=[4,1]0,1,2,3}")); + EXPECT_THAT(dot, AnyOf(op::Sharding("{devices=[4,1]0,1,2,3}"), + op::Sharding("{devices=[2,2]<=[4]}"))); } TEST_F(AutoShardingTest, DotInsertReshardingReshapes) { diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc b/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc index c5647d7828b98..5ee4d464ff116 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc +++ b/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc @@ -840,14 +840,17 @@ void RemoveDuplicatedStrategy(StrategyGroup& strategy_group) { std::vector<std::pair<ShardingStrategy, InputShardings>> deduped_replicated_strategies; absl::flat_hash_set<std::string> added; size_t num_skipped_due_to_infinity_costs = 0; - for (size_t sid = 0; sid < strategy_group.GetStrategies().size(); ++sid) { - const ShardingStrategy& strategy = strategy_group.GetStrategy(sid); + const auto& strategy_input_shardings = + strategy_group.GetStrategyInputShardings(); + for (size_t iid = 0; iid < strategy_input_shardings.size(); ++iid) { + const InputShardings& input_shardings = strategy_input_shardings[iid]; + const ShardingStrategy& strategy = + strategy_group.GetStrategyForInputShardings(iid); if (AllInfinityCosts(strategy.communication_resharding_costs)) { num_skipped_due_to_infinity_costs++; continue; } std::string key = strategy.output_sharding.ToString(); - const auto& input_shardings = strategy_group.GetInputShardings(sid); if 
(!input_shardings.empty()) { for (const auto& sharding : input_shardings) { key += "/" + (sharding.has_value() ? sharding->ToString() : "none"); @@ -863,8 +866,7 @@ void RemoveDuplicatedStrategy(StrategyGroup& strategy_group) { deduped_replicated_strategies.push_back({strategy, input_shardings}); } } - CHECK_LT(num_skipped_due_to_infinity_costs, - strategy_group.GetStrategies().size()) + CHECK_LT(num_skipped_due_to_infinity_costs, strategy_input_shardings.size()) << "All strategies removed due to infinite resharding costs"; // Keeps replicated strategies as the last ones. if (!deduped_replicated_strategies.empty()) {