Skip to content

Commit

Permalink
Augments StrategyGroup to model a one-to-many relationship between strategies and input shardings.
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 676370963
  • Loading branch information
Google-ML-Automation committed Sep 19, 2024
1 parent 45dea71 commit b9fcb24
Show file tree
Hide file tree
Showing 7 changed files with 114 additions and 34 deletions.
27 changes: 18 additions & 9 deletions xla/hlo/experimental/auto_sharding/auto_sharding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1476,9 +1476,12 @@ void TrimOrGenerateStrategiesBasedOnExistingSharding(
// Sharding provided by XLA users, we need to keep them.
strategy_group.following = nullptr;
std::vector<std::pair<ShardingStrategy, InputShardings>> new_strategies;
for (size_t sid = 0; sid < strategy_group.GetStrategies().size(); ++sid) {
const ShardingStrategy& strategy = strategy_group.GetStrategy(sid);
const auto& input_shardings = strategy_group.GetInputShardings(sid);
const auto& strategy_input_shardings =
strategy_group.GetStrategyInputShardings();
for (size_t iid = 0; iid < strategy_input_shardings.size(); ++iid) {
const InputShardings& input_shardings = strategy_input_shardings[iid];
const ShardingStrategy& strategy =
strategy_group.GetStrategyForInputShardings(iid);
if (strategy.output_sharding == existing_sharding) {
VLOG(1) << "Keeping strategy: " << strategy.ToString();
new_strategies.push_back({strategy, input_shardings});
Expand Down Expand Up @@ -1574,9 +1577,12 @@ void TrimOrGenerateStrategiesBasedOnExistingSharding(
// It is IMPORTANT that we do this only for instructions that do not follow
// others, to keep the number of ILP variables small.
std::vector<std::pair<ShardingStrategy, InputShardings>> new_vector;
for (size_t sid = 0; sid < strategy_group.GetStrategies().size(); ++sid) {
const ShardingStrategy& strategy = strategy_group.GetStrategy(sid);
const auto& input_shardings = strategy_group.GetInputShardings(sid);
const auto& strategy_input_shardings =
strategy_group.GetStrategyInputShardings();
for (size_t iid = 0; iid < strategy_input_shardings.size(); ++iid) {
const InputShardings& input_shardings = strategy_input_shardings[iid];
const ShardingStrategy& strategy =
strategy_group.GetStrategyForInputShardings(iid);
if (strategy.output_sharding.IsReplicated() ||
ShardingIsConsistent(existing_sharding, strategy.output_sharding,
strict) ||
Expand Down Expand Up @@ -3247,9 +3253,12 @@ absl::Status FilterStrategy(const HloInstruction* ins, const Shape& shape,
}

std::vector<std::pair<ShardingStrategy, InputShardings>> new_strategies;
for (size_t sid = 0; sid < strategy_group.GetStrategies().size(); ++sid) {
const ShardingStrategy& strategy = strategy_group.GetStrategy(sid);
const auto& input_shardings = strategy_group.GetInputShardings(sid);
const auto& strategy_input_shardings =
strategy_group.GetStrategyInputShardings();
for (size_t iid = 0; iid < strategy_input_shardings.size(); ++iid) {
const InputShardings& input_shardings = strategy_input_shardings[iid];
const ShardingStrategy& strategy =
strategy_group.GetStrategyForInputShardings(iid);
const HloSharding& output_sharding = strategy.output_sharding;
const std::vector<int64_t> tensor_dim_to_mesh_dim =
cluster_env.GetTensorDimToMeshDimWrapper(shape, output_sharding);
Expand Down
4 changes: 2 additions & 2 deletions xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ inline const InputShardings& GetInputShardings(
CHECK(!strategy_group->is_tuple);
NodeIdx node_idx = strategy_group->node_idx;
NodeStrategyIdx stra_idx = cost_graph.RemapIndex(node_idx, s_val[node_idx]);
return strategy_group->GetInputShardings(stra_idx);
return strategy_group->GetInputShardingsForStrategy(stra_idx);
}

// Get the final sharding strategy according to the ILP solution.
Expand Down Expand Up @@ -181,7 +181,7 @@ inline const InputShardings& GetInputShardingsForTuple(
}
NodeIdx node_idx = strategy_group->node_idx;
NodeStrategyIdx stra_idx = cost_graph.RemapIndex(node_idx, s_val[node_idx]);
return strategy_group->GetInputShardings(stra_idx);
return strategy_group->GetInputShardingsForStrategy(stra_idx);
}

} // namespace spmd
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -460,9 +460,12 @@ std::optional<HloSharding> HandlerBase::GetShardingFromUser(

void HandlerBase::SortStrategies() {
std::vector<std::pair<ShardingStrategy, InputShardings>> strategy_shardings;
for (size_t sid = 0; sid < strategy_group_->GetStrategies().size(); ++sid) {
const ShardingStrategy& strategy = strategy_group_->GetStrategy(sid);
const auto& input_shardings = strategy_group_->GetInputShardings(sid);
const auto strategy_input_shardings =
strategy_group_->GetStrategyInputShardings();
for (size_t iid = 0; iid < strategy_input_shardings.size(); ++iid) {
const InputShardings& input_shardings = strategy_input_shardings[iid];
const ShardingStrategy& strategy =
strategy_group_->GetStrategyForInputShardings(iid);
strategy_shardings.push_back({strategy, input_shardings});
}
absl::c_stable_sort(
Expand Down
28 changes: 18 additions & 10 deletions xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,13 @@ namespace spmd {

bool LeafVectorsAreConsistent(const std::vector<ShardingStrategy>& one,
const std::vector<ShardingStrategy>& two) {
return one.size() == two.size();
if (one.size() != two.size()) return false;
for (size_t sid = 0; sid < one.size(); ++sid) {
const bool invalid_strategy_one = (one[sid].compute_cost >= kInfinityCost);
const bool invalid_strategy_two = (two[sid].compute_cost >= kInfinityCost);
if (invalid_strategy_one != invalid_strategy_two) return false;
}
return true;
}

std::optional<HloSharding> ConstructImprovedSharding(
Expand Down Expand Up @@ -1026,6 +1032,11 @@ BuildStrategyAndCost(
}
CHECK(strategy_group != nullptr);
RemoveDuplicatedStrategy(*strategy_group);
if (!option.allow_shardings_small_dims_across_many_devices) {
RemoveShardingsWhereSmallDimsShardedAcrossManyDevices(
ins->shape(), /* instruction_has_user_sharding */ ins->has_sharding(),
*strategy_group);
}
if (ins->has_sharding() && ins->opcode() != HloOpcode::kOutfeed) {
// Finds the sharding strategy that aligns with the given sharding spec
// Do not merge nodes if this one instruction has annotations.
Expand Down Expand Up @@ -1060,11 +1071,6 @@ BuildStrategyAndCost(
}
}
}
if (!option.allow_shardings_small_dims_across_many_devices) {
RemoveShardingsWhereSmallDimsShardedAcrossManyDevices(
ins->shape(), /* instruction_has_user_sharding */ ins->has_sharding(),
*strategy_group);
}

if (instruction_execution_counts.contains(ins)) {
ScaleCostsWithExecutionCounts(instruction_execution_counts.at(ins),
Expand All @@ -1085,10 +1091,12 @@ BuildStrategyAndCost(
CHECK(!strategy_group->is_tuple);
std::vector<std::pair<ShardingStrategy, InputShardings>> new_strategies;
int64_t idx = it - inst_indices.begin();
const auto& strategies = strategy_group->GetStrategies();
for (size_t sid = 0; sid < strategies.size(); ++sid) {
const ShardingStrategy& strategy = strategy_group->GetStrategy(sid);
const auto& input_shardings = strategy_group->GetInputShardings(sid);
const auto& strategy_input_shardings =
strategy_group->GetStrategyInputShardings();
for (size_t iid = 0; iid < strategy_input_shardings.size(); ++iid) {
const InputShardings& input_shardings = strategy_input_shardings[iid];
const ShardingStrategy& strategy =
strategy_group->GetStrategyForInputShardings(iid);
if (strategy.name == stra_names[idx]) {
new_strategies.push_back({strategy, input_shardings});
}
Expand Down
65 changes: 61 additions & 4 deletions xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@ limitations under the License.
#ifndef XLA_HLO_EXPERIMENTAL_AUTO_SHARDING_AUTO_SHARDING_STRATEGY_H_
#define XLA_HLO_EXPERIMENTAL_AUTO_SHARDING_AUTO_SHARDING_STRATEGY_H_

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <memory>
#include <optional>
#include <string>
Expand Down Expand Up @@ -213,7 +215,32 @@ struct StrategyGroup {
}
} else {
for (const auto& strategy : strategies) {
absl::StrAppend(&str, indent, "Strategy ", strategy.ToStringLong());
absl::StrAppend(&str, indent, "Strategy ", strategy.ToStringLong(),
"\n");
}
}
if (!is_tuple) {
for (const auto& input_shardings : strategy_input_shardings) {
std::string input_sharding_str = "{";
for (const auto& s : input_shardings) {
if (!s.has_value()) {
input_sharding_str += "[*],";
} else if (s->IsReplicated()) {
input_sharding_str += "[R],";
} else {
if (s->ReplicateOnLastTileDim()) {
input_sharding_str +=
"[" + absl::StrJoin(s->tile_assignment().dimensions(), ", ") +
"]last_tile_dim_replicate,";
} else {
input_sharding_str +=
"[" + absl::StrJoin(s->tile_assignment().dimensions(), ", ") +
"],";
}
}
}
input_sharding_str += "}\n";
absl::StrAppend(&str, indent, "Input Sharding ", input_sharding_str);
}
}
return str;
Expand Down Expand Up @@ -253,21 +280,49 @@ struct StrategyGroup {

void AddStrategy(const ShardingStrategy& strategy,
const InputShardings& input_shardings = {}) {
strategies.push_back(strategy);
// Create a new strategy if needed (otherwise, reuse an existing one).
size_t strategy_idx = strategies.size();
const size_t input_sharding_idx = strategy_input_shardings.size();
const auto it = std::find(strategies.begin(), strategies.end(), strategy);
if (it == strategies.end()) {
strategies.push_back(strategy);
strategy_idx_to_input_sharding_idx.push_back(input_sharding_idx);
} else {
strategy_idx = std::distance(strategies.begin(), it);
}
input_sharding_idx_to_strategy_idx.push_back(strategy_idx);
strategy_input_shardings.push_back(input_shardings);
}

void ClearStrategies() {
strategies.clear();
strategy_input_shardings.clear();
input_sharding_idx_to_strategy_idx.clear();
strategy_idx_to_input_sharding_idx.clear();
}

ShardingStrategy& GetStrategy(size_t strategy_idx) {
return strategies[strategy_idx];
}

const InputShardings& GetInputShardings(size_t strategy_idx) const {
return strategy_input_shardings[strategy_idx];
const ShardingStrategy& GetStrategyForInputShardings(
size_t input_sharding_idx) const {
const size_t strategy_idx =
input_sharding_idx_to_strategy_idx[input_sharding_idx];
CHECK_LT(strategy_idx, strategies.size());
return strategies[strategy_idx];
}

const InputShardings& GetInputShardings(size_t input_sharding_idx) const {
return strategy_input_shardings[input_sharding_idx];
}

const InputShardings& GetInputShardingsForStrategy(
size_t strategy_idx) const {
const size_t input_sharding_idx =
strategy_idx_to_input_sharding_idx[strategy_idx];
CHECK_LT(input_sharding_idx, strategy_input_shardings.size());
return strategy_input_shardings[input_sharding_idx];
}

const std::vector<ShardingStrategy>& GetStrategies() const {
Expand Down Expand Up @@ -297,6 +352,8 @@ struct StrategyGroup {
// A vector of strategy choices for the non-tuple output.
std::vector<ShardingStrategy> strategies;
std::vector<InputShardings> strategy_input_shardings;
std::vector<size_t> input_sharding_idx_to_strategy_idx;
std::vector<size_t> strategy_idx_to_input_sharding_idx;

// Used when is_tuple == True. A vector of pointers, each pointer is one
// StrategyGroup for one value in the output Tuple
Expand Down
3 changes: 2 additions & 1 deletion xla/hlo/experimental/auto_sharding/auto_sharding_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1145,7 +1145,8 @@ ENTRY %entry {
ASSERT_NE(dot, nullptr);
EXPECT_THAT(param0, op::Sharding("{devices=[4,1]0,1,2,3}"));
EXPECT_THAT(param1, op::Sharding("{replicated}"));
EXPECT_THAT(dot, op::Sharding("{devices=[4,1]0,1,2,3}"));
EXPECT_THAT(dot, AnyOf(op::Sharding("{devices=[4,1]0,1,2,3}"),
op::Sharding("{devices=[2,2]<=[4]}")));
}

TEST_F(AutoShardingTest, DotInsertReshardingReshapes) {
Expand Down
12 changes: 7 additions & 5 deletions xla/hlo/experimental/auto_sharding/auto_sharding_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -840,14 +840,17 @@ void RemoveDuplicatedStrategy(StrategyGroup& strategy_group) {
deduped_replicated_strategies;
absl::flat_hash_set<std::string> added;
size_t num_skipped_due_to_infinity_costs = 0;
for (size_t sid = 0; sid < strategy_group.GetStrategies().size(); ++sid) {
const ShardingStrategy& strategy = strategy_group.GetStrategy(sid);
const auto& strategy_input_shardings =
strategy_group.GetStrategyInputShardings();
for (size_t iid = 0; iid < strategy_input_shardings.size(); ++iid) {
const InputShardings& input_shardings = strategy_input_shardings[iid];
const ShardingStrategy& strategy =
strategy_group.GetStrategyForInputShardings(iid);
if (AllInfinityCosts(strategy.communication_resharding_costs)) {
num_skipped_due_to_infinity_costs++;
continue;
}
std::string key = strategy.output_sharding.ToString();
const auto& input_shardings = strategy_group.GetInputShardings(sid);
if (!input_shardings.empty()) {
for (const auto& sharding : input_shardings) {
key += "/" + (sharding.has_value() ? sharding->ToString() : "none");
Expand All @@ -863,8 +866,7 @@ void RemoveDuplicatedStrategy(StrategyGroup& strategy_group) {
deduped_replicated_strategies.push_back({strategy, input_shardings});
}
}
CHECK_LT(num_skipped_due_to_infinity_costs,
strategy_group.GetStrategies().size())
CHECK_LT(num_skipped_due_to_infinity_costs, strategy_input_shardings.size())
<< "All strategies removed due to infinite resharding costs";
// Keeps replicated strategies as the last ones.
if (!deduped_replicated_strategies.empty()) {
Expand Down

0 comments on commit b9fcb24

Please sign in to comment.