Skip to content

Commit

Permalink
Augments StrategyGroup to model a one-to-many relationship between st…
Browse files Browse the repository at this point in the history
…rategies and input shardings.

PiperOrigin-RevId: 675276024
  • Loading branch information
Google-ML-Automation committed Sep 17, 2024
1 parent 3406c60 commit fcf9caf
Show file tree
Hide file tree
Showing 6 changed files with 74 additions and 26 deletions.
27 changes: 18 additions & 9 deletions xla/hlo/experimental/auto_sharding/auto_sharding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1468,9 +1468,12 @@ void TrimOrGenerateStrategiesBasedOnExistingSharding(
// Sharding provided by XLA users, we need to keep them.
strategy_group.following = nullptr;
std::vector<std::pair<ShardingStrategy, InputShardings>> new_strategies;
for (size_t sid = 0; sid < strategy_group.GetStrategies().size(); ++sid) {
const ShardingStrategy& strategy = strategy_group.GetStrategy(sid);
const auto& input_shardings = strategy_group.GetInputShardings(sid);
const auto& strategy_input_shardings =
strategy_group.GetStrategyInputShardings();
for (size_t iid = 0; iid < strategy_input_shardings.size(); ++iid) {
const InputShardings& input_shardings = strategy_input_shardings[iid];
const ShardingStrategy& strategy =
strategy_group.GetStrategyForInputShardings(iid);
if (strategy.output_sharding == existing_sharding) {
VLOG(1) << "Keeping strategy: " << strategy.ToString();
new_strategies.push_back({strategy, input_shardings});
Expand Down Expand Up @@ -1566,9 +1569,12 @@ void TrimOrGenerateStrategiesBasedOnExistingSharding(
// It is IMPORTANT that we do this only for instructions that do no follow
// others, to keep the number of ILP variable small.
std::vector<std::pair<ShardingStrategy, InputShardings>> new_vector;
for (size_t sid = 0; sid < strategy_group.GetStrategies().size(); ++sid) {
const ShardingStrategy& strategy = strategy_group.GetStrategy(sid);
const auto& input_shardings = strategy_group.GetInputShardings(sid);
const auto& strategy_input_shardings =
strategy_group.GetStrategyInputShardings();
for (size_t iid = 0; iid < strategy_input_shardings.size(); ++iid) {
const InputShardings& input_shardings = strategy_input_shardings[iid];
const ShardingStrategy& strategy =
strategy_group.GetStrategyForInputShardings(iid);
if (strategy.output_sharding.IsReplicated() ||
ShardingIsConsistent(existing_sharding, strategy.output_sharding,
strict) ||
Expand Down Expand Up @@ -3355,9 +3361,12 @@ absl::Status FilterStrategy(const HloInstruction* ins, const Shape& shape,
}

std::vector<std::pair<ShardingStrategy, InputShardings>> new_strategies;
for (size_t sid = 0; sid < strategy_group.GetStrategies().size(); ++sid) {
const ShardingStrategy& strategy = strategy_group.GetStrategy(sid);
const auto& input_shardings = strategy_group.GetInputShardings(sid);
const auto& strategy_input_shardings =
strategy_group.GetStrategyInputShardings();
for (size_t iid = 0; iid < strategy_input_shardings.size(); ++iid) {
const InputShardings& input_shardings = strategy_input_shardings[iid];
const ShardingStrategy& strategy =
strategy_group.GetStrategyForInputShardings(iid);
const HloSharding& output_sharding = strategy.output_sharding;
const std::vector<int64_t> tensor_dim_to_mesh_dim =
cluster_env.GetTensorDimToMeshDimWrapper(shape, output_sharding);
Expand Down
4 changes: 2 additions & 2 deletions xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ inline const InputShardings& GetInputShardings(
CHECK(!strategy_group->is_tuple);
NodeIdx node_idx = strategy_group->node_idx;
NodeStrategyIdx stra_idx = cost_graph.RemapIndex(node_idx, s_val[node_idx]);
return strategy_group->GetInputShardings(stra_idx);
return strategy_group->GetInputShardingsForStrategy(stra_idx);
}

// Get the final sharding strategy according to the ILP solution.
Expand Down Expand Up @@ -181,7 +181,7 @@ inline const InputShardings& GetInputShardingsForTuple(
}
NodeIdx node_idx = strategy_group->node_idx;
NodeStrategyIdx stra_idx = cost_graph.RemapIndex(node_idx, s_val[node_idx]);
return strategy_group->GetInputShardings(stra_idx);
return strategy_group->GetInputShardingsForStrategy(stra_idx);
}

} // namespace spmd
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -460,9 +460,12 @@ std::optional<HloSharding> HandlerBase::GetShardingFromUser(

void HandlerBase::SortStrategies() {
std::vector<std::pair<ShardingStrategy, InputShardings>> strategy_shardings;
for (size_t sid = 0; sid < strategy_group_->GetStrategies().size(); ++sid) {
const ShardingStrategy& strategy = strategy_group_->GetStrategy(sid);
const auto& input_shardings = strategy_group_->GetInputShardings(sid);
const auto strategy_input_shardings =
strategy_group_->GetStrategyInputShardings();
for (size_t iid = 0; iid < strategy_input_shardings.size(); ++iid) {
const InputShardings& input_shardings = strategy_input_shardings[iid];
const ShardingStrategy& strategy =
strategy_group_->GetStrategyForInputShardings(iid);
strategy_shardings.push_back({strategy, input_shardings});
}
absl::c_stable_sort(
Expand Down
10 changes: 6 additions & 4 deletions xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1085,10 +1085,12 @@ BuildStrategyAndCost(
CHECK(!strategy_group->is_tuple);
std::vector<std::pair<ShardingStrategy, InputShardings>> new_strategies;
int64_t idx = it - inst_indices.begin();
const auto& strategies = strategy_group->GetStrategies();
for (size_t sid = 0; sid < strategies.size(); ++sid) {
const ShardingStrategy& strategy = strategy_group->GetStrategy(sid);
const auto& input_shardings = strategy_group->GetInputShardings(sid);
const auto& strategy_input_shardings =
strategy_group->GetStrategyInputShardings();
for (size_t iid = 0; iid < strategy_input_shardings.size(); ++iid) {
const InputShardings& input_shardings = strategy_input_shardings[iid];
const ShardingStrategy& strategy =
strategy_group->GetStrategyForInputShardings(iid);
if (strategy.name == stra_names[idx]) {
new_strategies.push_back({strategy, input_shardings});
}
Expand Down
38 changes: 35 additions & 3 deletions xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@ limitations under the License.
#ifndef XLA_HLO_EXPERIMENTAL_AUTO_SHARDING_AUTO_SHARDING_STRATEGY_H_
#define XLA_HLO_EXPERIMENTAL_AUTO_SHARDING_AUTO_SHARDING_STRATEGY_H_

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <memory>
#include <optional>
#include <string>
Expand Down Expand Up @@ -253,21 +255,49 @@ struct StrategyGroup {

void AddStrategy(const ShardingStrategy& strategy,
const InputShardings& input_shardings = {}) {
strategies.push_back(strategy);
// Create a new strategy if needed, otherwise reuse an existing one.
size_t strategy_idx = strategies.size();
const size_t input_sharding_idx = strategy_input_shardings.size();
const auto it = std::find(strategies.begin(), strategies.end(), strategy);
if (it == strategies.end()) {
strategies.push_back(strategy);
strategy_idx_to_input_sharding_idx.push_back(input_sharding_idx);
} else {
strategy_idx = std::distance(strategies.begin(), it);
}
input_sharding_idx_to_strategy_idx.push_back(strategy_idx);
strategy_input_shardings.push_back(input_shardings);
}

void ClearStrategies() {
strategies.clear();
strategy_input_shardings.clear();
input_sharding_idx_to_strategy_idx.clear();
strategy_idx_to_input_sharding_idx.clear();
}

ShardingStrategy& GetStrategy(size_t strategy_idx) {
return strategies[strategy_idx];
}

const InputShardings& GetInputShardings(size_t strategy_idx) const {
return strategy_input_shardings[strategy_idx];
const ShardingStrategy& GetStrategyForInputShardings(
size_t input_sharding_idx) const {
const size_t strategy_idx =
input_sharding_idx_to_strategy_idx[input_sharding_idx];
CHECK_LT(strategy_idx, strategies.size());
return strategies[strategy_idx];
}

const InputShardings& GetInputShardings(size_t input_sharding_idx) const {
return strategy_input_shardings[input_sharding_idx];
}

const InputShardings& GetInputShardingsForStrategy(
size_t strategy_idx) const {
const size_t input_sharding_idx =
strategy_idx_to_input_sharding_idx[strategy_idx];
CHECK_LT(input_sharding_idx, strategy_input_shardings.size());
return strategy_input_shardings[input_sharding_idx];
}

const std::vector<ShardingStrategy>& GetStrategies() const {
Expand Down Expand Up @@ -297,6 +327,8 @@ struct StrategyGroup {
// A vector of strategy choices for the non-tuple output.
std::vector<ShardingStrategy> strategies;
std::vector<InputShardings> strategy_input_shardings;
std::vector<size_t> input_sharding_idx_to_strategy_idx;
std::vector<size_t> strategy_idx_to_input_sharding_idx;

// Used when is_tuple == True. A vector of pointers, each pointer is one
// StrategyGroup for one value in the output Tuple
Expand Down
12 changes: 7 additions & 5 deletions xla/hlo/experimental/auto_sharding/auto_sharding_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -841,14 +841,17 @@ void RemoveDuplicatedStrategy(StrategyGroup& strategy_group) {
deduped_replicated_strategies;
absl::flat_hash_set<std::string> added;
size_t num_skipped_due_to_infinity_costs = 0;
for (size_t sid = 0; sid < strategy_group.GetStrategies().size(); ++sid) {
const ShardingStrategy& strategy = strategy_group.GetStrategy(sid);
const auto& strategy_input_shardings =
strategy_group.GetStrategyInputShardings();
for (size_t iid = 0; iid < strategy_input_shardings.size(); ++iid) {
const InputShardings& input_shardings = strategy_input_shardings[iid];
const ShardingStrategy& strategy =
strategy_group.GetStrategyForInputShardings(iid);
if (AllInfinityCosts(strategy.communication_resharding_costs)) {
num_skipped_due_to_infinity_costs++;
continue;
}
std::string key = strategy.output_sharding.ToString();
const auto& input_shardings = strategy_group.GetInputShardings(sid);
if (!input_shardings.empty()) {
for (const auto& sharding : input_shardings) {
key += "/" + (sharding.has_value() ? sharding->ToString() : "none");
Expand All @@ -864,8 +867,7 @@ void RemoveDuplicatedStrategy(StrategyGroup& strategy_group) {
deduped_replicated_strategies.push_back({strategy, input_shardings});
}
}
CHECK_LT(num_skipped_due_to_infinity_costs,
strategy_group.GetStrategies().size())
CHECK_LT(num_skipped_due_to_infinity_costs, strategy_input_shardings.size())
<< "All strategies removed due to infinite resharding costs";
// Keeps replicated strategies as the last ones.
if (!deduped_replicated_strategies.empty()) {
Expand Down

0 comments on commit fcf9caf

Please sign in to comment.