Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Associates names with individual input sharding combinations (rather than strategies). #17864

Merged
merged 1 commit into from
Oct 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
210 changes: 108 additions & 102 deletions xla/hlo/experimental/auto_sharding/auto_sharding.cc

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions xla/hlo/experimental/auto_sharding/auto_sharding.h
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ bool HasReduceScatterOpportunity(const HloInstruction* inst,
const ConstInstructionSet& modified);

HloSharding GetReduceScatterOutput(const HloInstruction* ins,
const InputShardings& input_shardings,
const ShardingStrategy& strategy,
const ClusterEnvironment& cluster_env);

Expand Down
18 changes: 14 additions & 4 deletions xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -123,16 +123,26 @@ CostGraph::CostGraph(const StrategyGroups& strategy_groups,
node_lens_[dst_idx]);
absl::flat_hash_map<std::string, NodeStrategyIdx>
src_strategy_name_to_idx_map;
for (NodeStrategyIdx i = 0; i < node_lens_[src_idx]; ++i) {
const auto& src_strategy_input_shardings =
src_strategy_group.GetStrategyInputShardings();
for (size_t iid = 0; iid < src_strategy_input_shardings.size(); ++iid) {
const InputShardings& input_shardings = src_strategy_input_shardings[iid];
NodeStrategyIdx i =
src_strategy_group.GetStrategyIdxForInputShardings(iid);
const ShardingStrategy& strategy = src_strategy_group.GetStrategy(i);
if (strategy.communication_cost > 0) {
src_strategy_name_to_idx_map[strategy.name] = i;
src_strategy_name_to_idx_map[input_shardings.name] = i;
}
}
for (NodeStrategyIdx i = 0; i < node_lens_[dst_idx]; ++i) {
const auto& dst_strategy_input_shardings =
dst_strategy_group.GetStrategyInputShardings();
for (size_t iid = 0; iid < dst_strategy_input_shardings.size(); ++iid) {
const InputShardings& input_shardings = dst_strategy_input_shardings[iid];
NodeStrategyIdx i =
dst_strategy_group.GetStrategyIdxForInputShardings(iid);
const ShardingStrategy& dst_strategy = dst_strategy_group.GetStrategy(i);
if (dst_strategy.communication_cost > 0) {
auto it = src_strategy_name_to_idx_map.find(dst_strategy.name);
auto it = src_strategy_name_to_idx_map.find(input_shardings.name);
if (it != src_strategy_name_to_idx_map.end()) {
const auto& src_strategy = src_strategy_group.GetStrategy(it->second);
CHECK_LE(std::abs(src_strategy.communication_cost -
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -335,12 +335,12 @@ void HandlerBase::AppendNewStrategy(const std::string& name,
}

strategy_group_->AddStrategy(
ShardingStrategy({name, output_spec, compute_cost, communication_cost,
ShardingStrategy({output_spec, compute_cost, communication_cost,
static_cast<double>(ByteSizeOfShapeWithSharding(
ins_->shape(), output_spec)),
communication_resharding_costs,
memory_resharding_costs}),
{input_specs.begin(), input_specs.end()});
{name, {input_specs.begin(), input_specs.end()}});
}

// Given lhs and rhs dim maps, infers a sharding for the output by relying
Expand Down Expand Up @@ -467,7 +467,7 @@ void HandlerBase::SortStrategies() {
[](const std::pair<ShardingStrategy, InputShardings>& s1,
const std::pair<ShardingStrategy, InputShardings>& s2) {
if (s1.first.memory_cost == s2.first.memory_cost) {
return s1.first.name < s2.first.name;
return s1.second.name < s2.second.name;
} else {
return s1.first.memory_cost < s2.first.memory_cost;
}
Expand Down
27 changes: 12 additions & 15 deletions xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -346,14 +346,14 @@ BuildStrategyAndCost(
ByteSizeOfShapeWithSharding(ins->shape(), scatter_sharding);

InputShardings input_shardings_optional(
{data_sharding, indices_sharding, update_sharding});
{name, {data_sharding, indices_sharding, update_sharding}});
std::pair<ReshardingCosts, ReshardingCosts> resharding_costs =
GenerateReshardingCostsAndMissingShardingsForAllOperands(
ins, scatter_sharding, strategy_map, cluster_env, call_graph,
input_shardings_optional);

strategy_group->AddStrategy(
ShardingStrategy({name, scatter_sharding, compute_cost,
ShardingStrategy({scatter_sharding, compute_cost,
communication_cost, memory_cost,
std::move(resharding_costs.first),
std::move(resharding_costs.second)}),
Expand Down Expand Up @@ -397,15 +397,14 @@ BuildStrategyAndCost(
double memory_cost =
ByteSizeOfShapeWithSharding(gather_shape, output_sharding);
InputShardings input_shardings_optional(
{data_sharding, indices_sharding});
{output_sharding.ToString(), {data_sharding, indices_sharding}});
std::pair<ReshardingCosts, ReshardingCosts> resharding_costs =
GenerateReshardingCostsAndMissingShardingsForAllOperands(
ins, output_sharding, strategy_map, cluster_env, call_graph,
input_shardings_optional);

strategy_group->AddStrategy(
ShardingStrategy({std::string(output_sharding.ToString()),
output_sharding, compute_cost,
ShardingStrategy({output_sharding, compute_cost,
communication_cost, memory_cost,
std::move(resharding_costs.first),
std::move(resharding_costs.second)}),
Expand Down Expand Up @@ -565,14 +564,13 @@ BuildStrategyAndCost(
MemoryReshardingCostVector(src_strategy_group, operand->shape(),
input_spec, cluster_env);
strategy_group->AddStrategy(
ShardingStrategy({name,
output_spec,
ShardingStrategy({output_spec,
compute_cost,
communication_cost,
memory_cost,
{communication_resharding_costs},
{memory_resharding_costs}}),
{input_spec});
{name, {input_spec}});
}
break;
}
Expand Down Expand Up @@ -685,9 +683,9 @@ BuildStrategyAndCost(
if (k == follow_idx ||
ToString(ins->operand(k)->shape().dimensions()) ==
ToString(operand->shape().dimensions())) {
input_shardings.push_back(input_spec);
input_shardings.shardings.push_back(input_spec);
} else {
input_shardings.push_back(std::nullopt);
input_shardings.shardings.push_back(std::nullopt);
}
}
if (!output_spec.has_value()) {
Expand All @@ -703,11 +701,10 @@ BuildStrategyAndCost(
input_shardings);

strategy_group->AddStrategy(
ShardingStrategy({name, *output_spec, compute_cost,
communication_cost, memory_cost,
std::move(resharding_costs.first),
ShardingStrategy({*output_spec, compute_cost, communication_cost,
memory_cost, std::move(resharding_costs.first),
std::move(resharding_costs.second)}),
{input_spec});
{name, {input_spec}});
}

if (strategy_group->GetStrategies().empty()) {
Expand Down Expand Up @@ -1094,7 +1091,7 @@ BuildStrategyAndCost(
const InputShardings& input_shardings = strategy_input_shardings[iid];
const ShardingStrategy& strategy =
strategy_group->GetStrategyForInputShardings(iid);
if (strategy.name == stra_names[idx]) {
if (input_shardings.name == stra_names[idx]) {
new_strategies.push_back({strategy, input_shardings});
}
}
Expand Down
64 changes: 37 additions & 27 deletions xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,37 @@ using ReshardingCache =
ConstInstructionMap<std::vector<std::pair<HloSharding, HloInstruction*>>>;
// Resharding costs for each operand
using ReshardingCosts = std::vector<std::vector<double>>;
// Optional shardings for each operand
using InputShardings = std::vector<std::optional<HloSharding>>;

// A named collection of optional shardings, one entry per operand.
struct InputShardings {
  std::string name;
  std::vector<std::optional<HloSharding>> shardings;

  // Renders the name followed by a compact, comma-terminated bracket form of
  // each operand sharding: "[*]," when absent, "[R]," when fully replicated,
  // otherwise the tile-assignment dimensions (suffixed with
  // "last_tile_dim_replicate" when the last tile dimension is replicated).
  std::string ToString() const {
    std::string result = absl::StrCat(name, " ");
    for (const auto& sharding : shardings) {
      if (!sharding.has_value()) {
        absl::StrAppend(&result, "[*],");
        continue;
      }
      if (sharding->IsReplicated()) {
        absl::StrAppend(&result, "[R],");
        continue;
      }
      const std::string dims =
          absl::StrJoin(sharding->tile_assignment().dimensions(), ", ");
      if (sharding->ReplicateOnLastTileDim()) {
        absl::StrAppend(&result, "[", dims, "]last_tile_dim_replicate,");
      } else {
        absl::StrAppend(&result, "[", dims, "],");
      }
    }
    return result;
  }
};

// One sharding strategy
struct ShardingStrategy {
std::string name;
HloSharding output_sharding;
double compute_cost;
double communication_cost;
Expand All @@ -94,9 +119,7 @@ struct ShardingStrategy {
ReshardingCosts communication_resharding_costs;
ReshardingCosts memory_resharding_costs;

std::string ToString() const {
return absl::StrCat(name, ", ", output_sharding.ToString());
}
std::string ToString() const { return output_sharding.ToString(); }

std::string ToStringLong() const {
std::vector<std::string> communication_resharding_vector_strings;
Expand All @@ -119,15 +142,15 @@ struct ShardingStrategy {
"{", absl::StrJoin(memory_resharding_vector_strings, ", "), "}");

return absl::StrCat(
name, ", ", output_sharding.ToString(), ", compute_cost=", compute_cost,
output_sharding.ToString(), ", compute_cost=", compute_cost,
", communication_cost=", communication_cost,
", memory_cost=", memory_cost,
", communication_resharding_costs=", communication_resharding_cost_str,
", memory_resharding_costs=", memory_resharding_cost_str);
}

bool operator==(const ShardingStrategy& other) const {
return name == other.name && output_sharding == other.output_sharding &&
return output_sharding == other.output_sharding &&
compute_cost == other.compute_cost &&
communication_cost == other.communication_cost &&
memory_cost == other.memory_cost &&
Expand Down Expand Up @@ -221,25 +244,8 @@ struct StrategyGroup {
}
if (!is_tuple) {
for (const auto& input_shardings : strategy_input_shardings) {
std::string input_sharding_str = "{";
for (const auto& s : input_shardings) {
if (!s.has_value()) {
input_sharding_str += "[*],";
} else if (s->IsReplicated()) {
input_sharding_str += "[R],";
} else {
if (s->ReplicateOnLastTileDim()) {
input_sharding_str +=
"[" + absl::StrJoin(s->tile_assignment().dimensions(), ", ") +
"]last_tile_dim_replicate,";
} else {
input_sharding_str +=
"[" + absl::StrJoin(s->tile_assignment().dimensions(), ", ") +
"],";
}
}
}
input_sharding_str += "}\n";
const std::string input_sharding_str =
absl::StrCat("{", input_shardings.ToString(), "}\n");
absl::StrAppend(&str, indent, "Input Sharding ", input_sharding_str);
}
}
Expand Down Expand Up @@ -313,6 +319,10 @@ struct StrategyGroup {
return strategies[strategy_idx];
}

  // Maps an index into the input-sharding combinations (as returned by
  // GetStrategyInputShardings) to the index of its associated strategy.
  // Performs no bounds checking; presumably multiple input-sharding
  // combinations may map to the same strategy index — TODO confirm.
  size_t GetStrategyIdxForInputShardings(size_t input_sharding_idx) const {
    return input_sharding_idx_to_strategy_idx[input_sharding_idx];
  }

  // Returns the input-sharding combination stored at input_sharding_idx.
  // Performs no bounds checking on the index.
  const InputShardings& GetInputShardings(size_t input_sharding_idx) const {
    return strategy_input_shardings[input_sharding_idx];
  }
Expand Down
4 changes: 2 additions & 2 deletions xla/hlo/experimental/auto_sharding/auto_sharding_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -851,8 +851,8 @@ void RemoveDuplicatedStrategy(StrategyGroup& strategy_group) {
continue;
}
std::string key = strategy.output_sharding.ToString();
if (!input_shardings.empty()) {
for (const auto& sharding : input_shardings) {
if (!input_shardings.shardings.empty()) {
for (const auto& sharding : input_shardings.shardings) {
key += "/" + (sharding.has_value() ? sharding->ToString() : "none");
}
}
Expand Down
Loading