Commit 8b7f1e0
Associates names with individual input sharding combinations (rather than strategies).

PiperOrigin-RevId: 681254097
Google-ML-Automation committed Oct 2, 2024
1 parent 93be085 commit 8b7f1e0
Showing 7 changed files with 177 additions and 153 deletions.
210 changes: 108 additions & 102 deletions xla/hlo/experimental/auto_sharding/auto_sharding.cc

Large diffs are not rendered by default.
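
The substance of the refactor is easiest to see as a data-structure change: the human-readable strategy name moves off ShardingStrategy and onto the per-operand input-sharding combination. A minimal standalone sketch of the before/after shape, with std::string standing in for xla::HloSharding so it compiles on its own (field lists abridged):

#include <optional>
#include <string>
#include <vector>

using HloShardingLike = std::string;  // stand-in for xla::HloSharding

// Before this commit (abridged): the name lived on the strategy itself.
struct ShardingStrategyBefore {
  std::string name;
  HloShardingLike output_sharding;
  double compute_cost = 0.0;
};

// After this commit (abridged): the strategy no longer carries a name ...
struct ShardingStrategyAfter {
  HloShardingLike output_sharding;
  double compute_cost = 0.0;
};

// ... because the name now identifies an input-sharding combination, of
// which a single strategy may have several.
struct InputShardings {
  std::string name;
  std::vector<std::optional<HloShardingLike>> shardings;
};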

1 change: 1 addition & 0 deletions xla/hlo/experimental/auto_sharding/auto_sharding.h
@@ -213,6 +213,7 @@ bool HasReduceScatterOpportunity(const HloInstruction* inst,
                                  const ConstInstructionSet& modified);
 
 HloSharding GetReduceScatterOutput(const HloInstruction* ins,
+                                   const InputShardings& input_shardings,
                                    const ShardingStrategy& strategy,
                                    const ClusterEnvironment& cluster_env);

18 changes: 14 additions & 4 deletions xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.cc
@@ -123,16 +123,26 @@ CostGraph::CostGraph(const StrategyGroups& strategy_groups,
                  node_lens_[dst_idx]);
     absl::flat_hash_map<std::string, NodeStrategyIdx>
         src_strategy_name_to_idx_map;
-    for (NodeStrategyIdx i = 0; i < node_lens_[src_idx]; ++i) {
+    const auto& src_strategy_input_shardings =
+        src_strategy_group.GetStrategyInputShardings();
+    for (size_t iid = 0; iid < src_strategy_input_shardings.size(); ++iid) {
+      const InputShardings& input_shardings = src_strategy_input_shardings[iid];
+      NodeStrategyIdx i =
+          src_strategy_group.GetStrategyIdxForInputShardings(iid);
       const ShardingStrategy& strategy = src_strategy_group.GetStrategy(i);
       if (strategy.communication_cost > 0) {
-        src_strategy_name_to_idx_map[strategy.name] = i;
+        src_strategy_name_to_idx_map[input_shardings.name] = i;
       }
     }
-    for (NodeStrategyIdx i = 0; i < node_lens_[dst_idx]; ++i) {
+    const auto& dst_strategy_input_shardings =
+        dst_strategy_group.GetStrategyInputShardings();
+    for (size_t iid = 0; iid < dst_strategy_input_shardings.size(); ++iid) {
+      const InputShardings& input_shardings = dst_strategy_input_shardings[iid];
+      NodeStrategyIdx i =
+          dst_strategy_group.GetStrategyIdxForInputShardings(iid);
       const ShardingStrategy& dst_strategy = dst_strategy_group.GetStrategy(i);
       if (dst_strategy.communication_cost > 0) {
-        auto it = src_strategy_name_to_idx_map.find(dst_strategy.name);
+        auto it = src_strategy_name_to_idx_map.find(input_shardings.name);
         if (it != src_strategy_name_to_idx_map.end()) {
           const auto& src_strategy = src_strategy_group.GetStrategy(it->second);
           CHECK_LE(std::abs(src_strategy.communication_cost -
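
With names attached to combinations, the cost-graph construction iterates over input-sharding combinations and recovers each one's strategy index through the new GetStrategyIdxForInputShardings accessor, so the name-to-index map is keyed by combination names. A simplified standalone sketch of that pattern (stand-in types; std::unordered_map in place of absl::flat_hash_map):

#include <cstddef>
#include <string>
#include <unordered_map>
#include <vector>

struct InputShardingsLike { std::string name; };
struct StrategyLike { double communication_cost; };

struct StrategyGroupLike {
  std::vector<InputShardingsLike> strategy_input_shardings;
  // Many-to-one: several input-sharding combinations may share a strategy.
  std::vector<std::size_t> input_sharding_idx_to_strategy_idx;
  std::vector<StrategyLike> strategies;
};

std::unordered_map<std::string, std::size_t> NameToStrategyIdx(
    const StrategyGroupLike& group) {
  std::unordered_map<std::string, std::size_t> result;
  for (std::size_t iid = 0; iid < group.strategy_input_shardings.size();
       ++iid) {
    // Recover the strategy index for this combination, as in the diff above.
    const std::size_t i = group.input_sharding_idx_to_strategy_idx[iid];
    if (group.strategies[i].communication_cost > 0) {
      result[group.strategy_input_shardings[iid].name] = i;
    }
  }
  return result;
}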
@@ -335,12 +335,12 @@ void HandlerBase::AppendNewStrategy(const std::string& name,
   }
 
   strategy_group_->AddStrategy(
-      ShardingStrategy({name, output_spec, compute_cost, communication_cost,
+      ShardingStrategy({output_spec, compute_cost, communication_cost,
                         static_cast<double>(ByteSizeOfShapeWithSharding(
                             ins_->shape(), output_spec)),
                         communication_resharding_costs,
                         memory_resharding_costs}),
-      {input_specs.begin(), input_specs.end()});
+      {name, {input_specs.begin(), input_specs.end()}});
 }
 
 // Given lhs and rhs dim maps, infers a sharding for the output by relying
@@ -467,7 +467,7 @@ void HandlerBase::SortStrategies() {
       [](const std::pair<ShardingStrategy, InputShardings>& s1,
          const std::pair<ShardingStrategy, InputShardings>& s2) {
         if (s1.first.memory_cost == s2.first.memory_cost) {
-          return s1.first.name < s2.first.name;
+          return s1.second.name < s2.second.name;
         } else {
           return s1.first.memory_cost < s2.first.memory_cost;
         }
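
Strategies and their input shardings travel together as std::pair<ShardingStrategy, InputShardings>, so the tie-break on equal memory cost now compares the name stored in the pair's second element. A compilable sketch of the comparator with stand-in types:

#include <algorithm>
#include <string>
#include <utility>
#include <vector>

struct StrategyLike { double memory_cost; };
struct InputShardingsLike { std::string name; };
using Entry = std::pair<StrategyLike, InputShardingsLike>;

// Sort by memory cost; on ties, fall back to the combination's name,
// which is where the name now lives.
void SortByMemoryCostThenName(std::vector<Entry>& entries) {
  std::sort(entries.begin(), entries.end(),
            [](const Entry& s1, const Entry& s2) {
              if (s1.first.memory_cost == s2.first.memory_cost) {
                return s1.second.name < s2.second.name;
              }
              return s1.first.memory_cost < s2.first.memory_cost;
            });
}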
27 changes: 12 additions & 15 deletions xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc
@@ -350,14 +350,14 @@ BuildStrategyAndCost(
               ByteSizeOfShapeWithSharding(ins->shape(), scatter_sharding);
 
           InputShardings input_shardings_optional(
-              {data_sharding, indices_sharding, update_sharding});
+              {name, {data_sharding, indices_sharding, update_sharding}});
           std::pair<ReshardingCosts, ReshardingCosts> resharding_costs =
               GenerateReshardingCostsAndMissingShardingsForAllOperands(
                   ins, scatter_sharding, strategy_map, cluster_env, call_graph,
                   input_shardings_optional);
 
           strategy_group->AddStrategy(
-              ShardingStrategy({name, scatter_sharding, compute_cost,
+              ShardingStrategy({scatter_sharding, compute_cost,
                                 communication_cost, memory_cost,
                                 std::move(resharding_costs.first),
                                 std::move(resharding_costs.second)}),
@@ -401,15 +401,14 @@ BuildStrategyAndCost(
           double memory_cost =
               ByteSizeOfShapeWithSharding(gather_shape, output_sharding);
           InputShardings input_shardings_optional(
-              {data_sharding, indices_sharding});
+              {output_sharding.ToString(), {data_sharding, indices_sharding}});
           std::pair<ReshardingCosts, ReshardingCosts> resharding_costs =
               GenerateReshardingCostsAndMissingShardingsForAllOperands(
                   ins, output_sharding, strategy_map, cluster_env, call_graph,
                   input_shardings_optional);
 
           strategy_group->AddStrategy(
-              ShardingStrategy({std::string(output_sharding.ToString()),
-                                output_sharding, compute_cost,
+              ShardingStrategy({output_sharding, compute_cost,
                                 communication_cost, memory_cost,
                                 std::move(resharding_costs.first),
                                 std::move(resharding_costs.second)}),
@@ -569,14 +568,13 @@ BuildStrategyAndCost(
               MemoryReshardingCostVector(src_strategy_group, operand->shape(),
                                          input_spec, cluster_env);
           strategy_group->AddStrategy(
-              ShardingStrategy({name,
-                                output_spec,
+              ShardingStrategy({output_spec,
                                 compute_cost,
                                 communication_cost,
                                 memory_cost,
                                 {communication_resharding_costs},
                                 {memory_resharding_costs}}),
-              {input_spec});
+              {name, {input_spec}});
         }
         break;
       }
@@ -689,9 +687,9 @@ BuildStrategyAndCost(
           if (k == follow_idx ||
               ToString(ins->operand(k)->shape().dimensions()) ==
                   ToString(operand->shape().dimensions())) {
-            input_shardings.push_back(input_spec);
+            input_shardings.shardings.push_back(input_spec);
           } else {
-            input_shardings.push_back(std::nullopt);
+            input_shardings.shardings.push_back(std::nullopt);
           }
         }
         if (!output_spec.has_value()) {
@@ -707,11 +705,10 @@
                   input_shardings);
 
           strategy_group->AddStrategy(
-              ShardingStrategy({name, *output_spec, compute_cost,
-                                communication_cost, memory_cost,
-                                std::move(resharding_costs.first),
+              ShardingStrategy({*output_spec, compute_cost, communication_cost,
+                                memory_cost, std::move(resharding_costs.first),
                                 std::move(resharding_costs.second)}),
-              {input_spec});
+              {name, {input_spec}});
         }
 
         if (strategy_group->GetStrategies().empty()) {
@@ -1098,7 +1095,7 @@ BuildStrategyAndCost(
         const InputShardings& input_shardings = strategy_input_shardings[iid];
         const ShardingStrategy& strategy =
             strategy_group->GetStrategyForInputShardings(iid);
-        if (strategy.name == stra_names[idx]) {
+        if (input_shardings.name == stra_names[idx]) {
           new_strategies.push_back({strategy, input_shardings});
         }
       }
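
Call sites in BuildStrategyAndCost now brace-initialize InputShardings with the name first and one optional sharding per operand second. A minimal sketch of the construction idiom (std::string stands in for xla::HloSharding; the name and values are invented for illustration):

#include <optional>
#include <string>
#include <vector>

using HloShardingLike = std::string;  // stand-in for xla::HloSharding

struct InputShardings {
  std::string name;
  std::vector<std::optional<HloShardingLike>> shardings;
};

int main() {
  const HloShardingLike data_sharding = "{devices=[2,1]0,1}";
  const HloShardingLike indices_sharding = "{replicated}";
  // Name first, then per-operand entries; std::nullopt marks an operand
  // for which no input sharding was chosen.
  const InputShardings input_shardings{
      "example_gather_strategy",
      {data_sharding, indices_sharding, std::nullopt}};
  return input_shardings.shardings.size() == 3 ? 0 : 1;
}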
64 changes: 37 additions & 27 deletions xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h
@@ -77,12 +77,37 @@ using ReshardingCache =
     ConstInstructionMap<std::vector<std::pair<HloSharding, HloInstruction*>>>;
 // Resharding costs for each operand
 using ReshardingCosts = std::vector<std::vector<double>>;
-// Optional shardings for each operand
-using InputShardings = std::vector<std::optional<HloSharding>>;
 
+// A named vector of optional shardings for each operand.
+struct InputShardings {
+  std::string name;
+  std::vector<std::optional<HloSharding>> shardings;
+
+  std::string ToString() const {
+    std::string str = absl::StrCat(name, " ");
+    for (const auto& s : shardings) {
+      if (!s.has_value()) {
+        absl::StrAppend(&str, "[*],");
+      } else if (s->IsReplicated()) {
+        absl::StrAppend(&str, "[R],");
+      } else {
+        if (s->ReplicateOnLastTileDim()) {
+          absl::StrAppend(
+              &str, "[", absl::StrJoin(s->tile_assignment().dimensions(), ", "),
+              "]last_tile_dim_replicate,");
+        } else {
+          absl::StrAppend(
+              &str, "[", absl::StrJoin(s->tile_assignment().dimensions(), ", "),
+              "],");
+        }
+      }
+    }
+    return str;
+  }
+};
+
 // One sharding strategy
 struct ShardingStrategy {
-  std::string name;
   HloSharding output_sharding;
   double compute_cost;
   double communication_cost;
@@ -94,9 +119,7 @@ struct ShardingStrategy {
   ReshardingCosts communication_resharding_costs;
   ReshardingCosts memory_resharding_costs;
 
-  std::string ToString() const {
-    return absl::StrCat(name, ", ", output_sharding.ToString());
-  }
+  std::string ToString() const { return output_sharding.ToString(); }
 
   std::string ToStringLong() const {
     std::vector<std::string> communication_resharding_vector_strings;
@@ -119,15 +142,15 @@
         "{", absl::StrJoin(memory_resharding_vector_strings, ", "), "}");
 
     return absl::StrCat(
-        name, ", ", output_sharding.ToString(), ", compute_cost=", compute_cost,
+        output_sharding.ToString(), ", compute_cost=", compute_cost,
         ", communication_cost=", communication_cost,
        ", memory_cost=", memory_cost,
         ", communication_resharding_costs=", communication_resharding_cost_str,
         ", memory_resharding_costs=", memory_resharding_cost_str);
   }
 
   bool operator==(const ShardingStrategy& other) const {
-    return name == other.name && output_sharding == other.output_sharding &&
+    return output_sharding == other.output_sharding &&
            compute_cost == other.compute_cost &&
            communication_cost == other.communication_cost &&
            memory_cost == other.memory_cost &&
@@ -221,25 +244,8 @@ struct StrategyGroup {
     }
     if (!is_tuple) {
       for (const auto& input_shardings : strategy_input_shardings) {
-        std::string input_sharding_str = "{";
-        for (const auto& s : input_shardings) {
-          if (!s.has_value()) {
-            input_sharding_str += "[*],";
-          } else if (s->IsReplicated()) {
-            input_sharding_str += "[R],";
-          } else {
-            if (s->ReplicateOnLastTileDim()) {
-              input_sharding_str +=
-                  "[" + absl::StrJoin(s->tile_assignment().dimensions(), ", ") +
-                  "]last_tile_dim_replicate,";
-            } else {
-              input_sharding_str +=
-                  "[" + absl::StrJoin(s->tile_assignment().dimensions(), ", ") +
-                  "],";
-            }
-          }
-        }
-        input_sharding_str += "}\n";
+        const std::string input_sharding_str =
+            absl::StrCat("{", input_shardings.ToString(), "}\n");
         absl::StrAppend(&str, indent, "Input Sharding ", input_sharding_str);
       }
     }
@@ -313,6 +319,10 @@ struct StrategyGroup {
     return strategies[strategy_idx];
   }
 
+  size_t GetStrategyIdxForInputShardings(size_t input_sharding_idx) const {
+    return input_sharding_idx_to_strategy_idx[input_sharding_idx];
+  }
+
   const InputShardings& GetInputShardings(size_t input_sharding_idx) const {
     return strategy_input_shardings[input_sharding_idx];
   }
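
InputShardings::ToString centralizes formatting that StrategyGroup::ToString previously inlined. A standalone illustration of the output format, using a hypothetical MockSharding that carries only the fields the formatter reads (the strategy name in main is an invented example):

#include <cstddef>
#include <iostream>
#include <optional>
#include <string>
#include <vector>

struct MockSharding {
  bool replicated = false;
  bool replicate_on_last_tile_dim = false;
  std::vector<int> tile_dims;
};

std::string Render(const std::string& name,
                   const std::vector<std::optional<MockSharding>>& shardings) {
  std::string str = name + " ";
  for (const auto& s : shardings) {
    if (!s.has_value()) {
      str += "[*],";  // operand with no sharding recorded
    } else if (s->replicated) {
      str += "[R],";  // fully replicated operand
    } else {
      str += "[";
      for (std::size_t i = 0; i < s->tile_dims.size(); ++i) {
        if (i > 0) str += ", ";
        str += std::to_string(s->tile_dims[i]);
      }
      str += s->replicate_on_last_tile_dim ? "]last_tile_dim_replicate," : "],";
    }
  }
  return str;
}

int main() {
  std::cout << Render("SS = SR x RS",
                      {MockSharding{false, false, {2, 4}},
                       MockSharding{true},
                       std::nullopt})
            << "\n";  // prints: SS = SR x RS [2, 4],[R],[*],
}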
4 changes: 2 additions & 2 deletions xla/hlo/experimental/auto_sharding/auto_sharding_util.cc
@@ -851,8 +851,8 @@ void RemoveDuplicatedStrategy(StrategyGroup& strategy_group) {
       continue;
     }
     std::string key = strategy.output_sharding.ToString();
-    if (!input_shardings.empty()) {
-      for (const auto& sharding : input_shardings) {
+    if (!input_shardings.shardings.empty()) {
+      for (const auto& sharding : input_shardings.shardings) {
         key += "/" + (sharding.has_value() ? sharding->ToString() : "none");
       }
     }
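
The deduplication key in RemoveDuplicatedStrategy is unchanged in substance; it simply reaches through the new shardings member. A sketch of the key construction, with std::string standing in for the result of HloSharding::ToString():

#include <optional>
#include <string>
#include <vector>

// Build the key used to spot duplicate strategies: the output sharding plus
// each (possibly absent) input sharding, with "none" marking absent entries.
std::string DedupKey(
    const std::string& output_sharding,
    const std::vector<std::optional<std::string>>& shardings) {
  std::string key = output_sharding;
  for (const auto& sharding : shardings) {
    key += "/" + (sharding.has_value() ? *sharding : "none");
  }
  return key;
}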
