Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Augments StrategyGroup to model a one-to-many relationship between strategies and input shardings. #17289

Merged
merged 1 commit into from
Sep 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 18 additions & 9 deletions xla/hlo/experimental/auto_sharding/auto_sharding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1476,9 +1476,12 @@ void TrimOrGenerateStrategiesBasedOnExistingSharding(
// Shardings provided by XLA users; we need to keep them.
strategy_group.following = nullptr;
std::vector<std::pair<ShardingStrategy, InputShardings>> new_strategies;
for (size_t sid = 0; sid < strategy_group.GetStrategies().size(); ++sid) {
const ShardingStrategy& strategy = strategy_group.GetStrategy(sid);
const auto& input_shardings = strategy_group.GetInputShardings(sid);
const auto& strategy_input_shardings =
strategy_group.GetStrategyInputShardings();
for (size_t iid = 0; iid < strategy_input_shardings.size(); ++iid) {
const InputShardings& input_shardings = strategy_input_shardings[iid];
const ShardingStrategy& strategy =
strategy_group.GetStrategyForInputShardings(iid);
if (strategy.output_sharding == existing_sharding) {
VLOG(1) << "Keeping strategy: " << strategy.ToString();
new_strategies.push_back({strategy, input_shardings});
Expand Down Expand Up @@ -1574,9 +1577,12 @@ void TrimOrGenerateStrategiesBasedOnExistingSharding(
// It is IMPORTANT that we do this only for instructions that do not follow
// others, to keep the number of ILP variables small.
std::vector<std::pair<ShardingStrategy, InputShardings>> new_vector;
for (size_t sid = 0; sid < strategy_group.GetStrategies().size(); ++sid) {
const ShardingStrategy& strategy = strategy_group.GetStrategy(sid);
const auto& input_shardings = strategy_group.GetInputShardings(sid);
const auto& strategy_input_shardings =
strategy_group.GetStrategyInputShardings();
for (size_t iid = 0; iid < strategy_input_shardings.size(); ++iid) {
const InputShardings& input_shardings = strategy_input_shardings[iid];
const ShardingStrategy& strategy =
strategy_group.GetStrategyForInputShardings(iid);
if (strategy.output_sharding.IsReplicated() ||
ShardingIsConsistent(existing_sharding, strategy.output_sharding,
strict) ||
Expand Down Expand Up @@ -3247,9 +3253,12 @@ absl::Status FilterStrategy(const HloInstruction* ins, const Shape& shape,
}

std::vector<std::pair<ShardingStrategy, InputShardings>> new_strategies;
for (size_t sid = 0; sid < strategy_group.GetStrategies().size(); ++sid) {
const ShardingStrategy& strategy = strategy_group.GetStrategy(sid);
const auto& input_shardings = strategy_group.GetInputShardings(sid);
const auto& strategy_input_shardings =
strategy_group.GetStrategyInputShardings();
for (size_t iid = 0; iid < strategy_input_shardings.size(); ++iid) {
const InputShardings& input_shardings = strategy_input_shardings[iid];
const ShardingStrategy& strategy =
strategy_group.GetStrategyForInputShardings(iid);
const HloSharding& output_sharding = strategy.output_sharding;
const std::vector<int64_t> tensor_dim_to_mesh_dim =
cluster_env.GetTensorDimToMeshDimWrapper(shape, output_sharding);
Expand Down
4 changes: 2 additions & 2 deletions xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ inline const InputShardings& GetInputShardings(
CHECK(!strategy_group->is_tuple);
NodeIdx node_idx = strategy_group->node_idx;
NodeStrategyIdx stra_idx = cost_graph.RemapIndex(node_idx, s_val[node_idx]);
return strategy_group->GetInputShardings(stra_idx);
return strategy_group->GetInputShardingsForStrategy(stra_idx);
}

// Get the final sharding strategy according to the ILP solution.
Expand Down Expand Up @@ -181,7 +181,7 @@ inline const InputShardings& GetInputShardingsForTuple(
}
NodeIdx node_idx = strategy_group->node_idx;
NodeStrategyIdx stra_idx = cost_graph.RemapIndex(node_idx, s_val[node_idx]);
return strategy_group->GetInputShardings(stra_idx);
return strategy_group->GetInputShardingsForStrategy(stra_idx);
}

} // namespace spmd
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -460,9 +460,12 @@ std::optional<HloSharding> HandlerBase::GetShardingFromUser(

void HandlerBase::SortStrategies() {
std::vector<std::pair<ShardingStrategy, InputShardings>> strategy_shardings;
for (size_t sid = 0; sid < strategy_group_->GetStrategies().size(); ++sid) {
const ShardingStrategy& strategy = strategy_group_->GetStrategy(sid);
const auto& input_shardings = strategy_group_->GetInputShardings(sid);
const auto strategy_input_shardings =
strategy_group_->GetStrategyInputShardings();
for (size_t iid = 0; iid < strategy_input_shardings.size(); ++iid) {
const InputShardings& input_shardings = strategy_input_shardings[iid];
const ShardingStrategy& strategy =
strategy_group_->GetStrategyForInputShardings(iid);
strategy_shardings.push_back({strategy, input_shardings});
}
absl::c_stable_sort(
Expand Down
28 changes: 18 additions & 10 deletions xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,13 @@ namespace spmd {

// Returns true if two leaf strategy vectors are consistent: they have the
// same length, and each position is valid (finite compute cost) in one
// vector iff it is valid in the other.
bool LeafVectorsAreConsistent(const std::vector<ShardingStrategy>& one,
                              const std::vector<ShardingStrategy>& two) {
  if (one.size() != two.size()) return false;
  for (size_t sid = 0; sid < one.size(); ++sid) {
    // A strategy is treated as invalid when its compute cost is infinite.
    const bool invalid_strategy_one = (one[sid].compute_cost >= kInfinityCost);
    const bool invalid_strategy_two = (two[sid].compute_cost >= kInfinityCost);
    if (invalid_strategy_one != invalid_strategy_two) return false;
  }
  return true;
}

std::optional<HloSharding> ConstructImprovedSharding(
Expand Down Expand Up @@ -1026,6 +1032,11 @@ BuildStrategyAndCost(
}
CHECK(strategy_group != nullptr);
RemoveDuplicatedStrategy(*strategy_group);
if (!option.allow_shardings_small_dims_across_many_devices) {
RemoveShardingsWhereSmallDimsShardedAcrossManyDevices(
ins->shape(), /* instruction_has_user_sharding */ ins->has_sharding(),
*strategy_group);
}
if (ins->has_sharding() && ins->opcode() != HloOpcode::kOutfeed) {
// Finds the sharding strategy that aligns with the given sharding spec
// Do not merge nodes if this one instruction has annotations.
Expand Down Expand Up @@ -1060,11 +1071,6 @@ BuildStrategyAndCost(
}
}
}
if (!option.allow_shardings_small_dims_across_many_devices) {
RemoveShardingsWhereSmallDimsShardedAcrossManyDevices(
ins->shape(), /* instruction_has_user_sharding */ ins->has_sharding(),
*strategy_group);
}

if (instruction_execution_counts.contains(ins)) {
ScaleCostsWithExecutionCounts(instruction_execution_counts.at(ins),
Expand All @@ -1085,10 +1091,12 @@ BuildStrategyAndCost(
CHECK(!strategy_group->is_tuple);
std::vector<std::pair<ShardingStrategy, InputShardings>> new_strategies;
int64_t idx = it - inst_indices.begin();
const auto& strategies = strategy_group->GetStrategies();
for (size_t sid = 0; sid < strategies.size(); ++sid) {
const ShardingStrategy& strategy = strategy_group->GetStrategy(sid);
const auto& input_shardings = strategy_group->GetInputShardings(sid);
const auto& strategy_input_shardings =
strategy_group->GetStrategyInputShardings();
for (size_t iid = 0; iid < strategy_input_shardings.size(); ++iid) {
const InputShardings& input_shardings = strategy_input_shardings[iid];
const ShardingStrategy& strategy =
strategy_group->GetStrategyForInputShardings(iid);
if (strategy.name == stra_names[idx]) {
new_strategies.push_back({strategy, input_shardings});
}
Expand Down
65 changes: 61 additions & 4 deletions xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@ limitations under the License.
#ifndef XLA_HLO_EXPERIMENTAL_AUTO_SHARDING_AUTO_SHARDING_STRATEGY_H_
#define XLA_HLO_EXPERIMENTAL_AUTO_SHARDING_AUTO_SHARDING_STRATEGY_H_

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <memory>
#include <optional>
#include <string>
Expand Down Expand Up @@ -213,7 +215,32 @@ struct StrategyGroup {
}
} else {
for (const auto& strategy : strategies) {
absl::StrAppend(&str, indent, "Strategy ", strategy.ToStringLong());
absl::StrAppend(&str, indent, "Strategy ", strategy.ToStringLong(),
"\n");
}
}
if (!is_tuple) {
for (const auto& input_shardings : strategy_input_shardings) {
std::string input_sharding_str = "{";
for (const auto& s : input_shardings) {
if (!s.has_value()) {
input_sharding_str += "[*],";
} else if (s->IsReplicated()) {
input_sharding_str += "[R],";
} else {
if (s->ReplicateOnLastTileDim()) {
input_sharding_str +=
"[" + absl::StrJoin(s->tile_assignment().dimensions(), ", ") +
"]last_tile_dim_replicate,";
} else {
input_sharding_str +=
"[" + absl::StrJoin(s->tile_assignment().dimensions(), ", ") +
"],";
}
}
}
input_sharding_str += "}\n";
absl::StrAppend(&str, indent, "Input Sharding ", input_sharding_str);
}
}
return str;
Expand Down Expand Up @@ -253,21 +280,49 @@ struct StrategyGroup {

// Registers `strategy` together with one combination of input shardings.
// Several input-sharding combinations may map to the same strategy
// (one-to-many): the strategy is stored at most once, and the two index
// vectors keep both directions of the mapping in sync.
// Fix: removed the stale unconditional `strategies.push_back(strategy);`
// left over from the pre-change version, which double-inserted the strategy
// and desynchronized the index maps.
void AddStrategy(const ShardingStrategy& strategy,
                 const InputShardings& input_shardings = {}) {
  // Create a new strategy if needed (otherwise, reuse an existing one).
  size_t strategy_idx = strategies.size();
  const size_t input_sharding_idx = strategy_input_shardings.size();
  const auto it = std::find(strategies.begin(), strategies.end(), strategy);
  if (it == strategies.end()) {
    strategies.push_back(strategy);
    // Record which input-sharding combination was first seen for this
    // strategy so GetInputShardingsForStrategy can look it up.
    strategy_idx_to_input_sharding_idx.push_back(input_sharding_idx);
  } else {
    strategy_idx = std::distance(strategies.begin(), it);
  }
  input_sharding_idx_to_strategy_idx.push_back(strategy_idx);
  strategy_input_shardings.push_back(input_shardings);
}

// Empties the group: removes all strategies, all recorded input-sharding
// combinations, and both index maps relating them.
void ClearStrategies() {
  strategy_idx_to_input_sharding_idx.clear();
  input_sharding_idx_to_strategy_idx.clear();
  strategy_input_shardings.clear();
  strategies.clear();
}

// Mutable access to the strategy stored at `strategy_idx`.
ShardingStrategy& GetStrategy(size_t strategy_idx) {
  ShardingStrategy& strategy = strategies[strategy_idx];
  return strategy;
}

const InputShardings& GetInputShardings(size_t strategy_idx) const {
return strategy_input_shardings[strategy_idx];
const ShardingStrategy& GetStrategyForInputShardings(
size_t input_sharding_idx) const {
const size_t strategy_idx =
input_sharding_idx_to_strategy_idx[input_sharding_idx];
CHECK_LT(strategy_idx, strategies.size());
return strategies[strategy_idx];
}

// Read-only access to the input-sharding combination at `input_sharding_idx`.
const InputShardings& GetInputShardings(size_t input_sharding_idx) const {
  const InputShardings& shardings =
      strategy_input_shardings[input_sharding_idx];
  return shardings;
}

// Returns the input-sharding combination recorded for the strategy at
// `strategy_idx` (the combination present when the strategy was added).
const InputShardings& GetInputShardingsForStrategy(
    size_t strategy_idx) const {
  const size_t shardings_idx =
      strategy_idx_to_input_sharding_idx[strategy_idx];
  CHECK_LT(shardings_idx, strategy_input_shardings.size());
  return strategy_input_shardings[shardings_idx];
}

const std::vector<ShardingStrategy>& GetStrategies() const {
Expand Down Expand Up @@ -297,6 +352,8 @@ struct StrategyGroup {
// A vector of strategy choices for the non-tuple output.
std::vector<ShardingStrategy> strategies;
std::vector<InputShardings> strategy_input_shardings;
std::vector<size_t> input_sharding_idx_to_strategy_idx;
std::vector<size_t> strategy_idx_to_input_sharding_idx;

// Used when is_tuple == True. A vector of pointers, each pointer is one
// StrategyGroup for one value in the output Tuple
Expand Down
3 changes: 2 additions & 1 deletion xla/hlo/experimental/auto_sharding/auto_sharding_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1145,7 +1145,8 @@ ENTRY %entry {
ASSERT_NE(dot, nullptr);
EXPECT_THAT(param0, op::Sharding("{devices=[4,1]0,1,2,3}"));
EXPECT_THAT(param1, op::Sharding("{replicated}"));
EXPECT_THAT(dot, op::Sharding("{devices=[4,1]0,1,2,3}"));
EXPECT_THAT(dot, AnyOf(op::Sharding("{devices=[4,1]0,1,2,3}"),
op::Sharding("{devices=[2,2]<=[4]}")));
}

TEST_F(AutoShardingTest, DotInsertReshardingReshapes) {
Expand Down
12 changes: 7 additions & 5 deletions xla/hlo/experimental/auto_sharding/auto_sharding_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -840,14 +840,17 @@ void RemoveDuplicatedStrategy(StrategyGroup& strategy_group) {
deduped_replicated_strategies;
absl::flat_hash_set<std::string> added;
size_t num_skipped_due_to_infinity_costs = 0;
for (size_t sid = 0; sid < strategy_group.GetStrategies().size(); ++sid) {
const ShardingStrategy& strategy = strategy_group.GetStrategy(sid);
const auto& strategy_input_shardings =
strategy_group.GetStrategyInputShardings();
for (size_t iid = 0; iid < strategy_input_shardings.size(); ++iid) {
const InputShardings& input_shardings = strategy_input_shardings[iid];
const ShardingStrategy& strategy =
strategy_group.GetStrategyForInputShardings(iid);
if (AllInfinityCosts(strategy.communication_resharding_costs)) {
num_skipped_due_to_infinity_costs++;
continue;
}
std::string key = strategy.output_sharding.ToString();
const auto& input_shardings = strategy_group.GetInputShardings(sid);
if (!input_shardings.empty()) {
for (const auto& sharding : input_shardings) {
key += "/" + (sharding.has_value() ? sharding->ToString() : "none");
Expand All @@ -863,8 +866,7 @@ void RemoveDuplicatedStrategy(StrategyGroup& strategy_group) {
deduped_replicated_strategies.push_back({strategy, input_shardings});
}
}
CHECK_LT(num_skipped_due_to_infinity_costs,
strategy_group.GetStrategies().size())
CHECK_LT(num_skipped_due_to_infinity_costs, strategy_input_shardings.size())
<< "All strategies removed due to infinite resharding costs";
// Keeps replicated strategies as the last ones.
if (!deduped_replicated_strategies.empty()) {
Expand Down
Loading