Skip to content

Commit

Permalink
Decouples strategies from their associated input shardings.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 675096103
  • Loading branch information
Google-ML-Automation committed Sep 17, 2024
1 parent 8e1e8a9 commit ba9a668
Show file tree
Hide file tree
Showing 7 changed files with 238 additions and 208 deletions.
247 changes: 123 additions & 124 deletions xla/hlo/experimental/auto_sharding/auto_sharding.cc

Large diffs are not rendered by default.

3 changes: 1 addition & 2 deletions xla/hlo/experimental/auto_sharding/auto_sharding.h
Original file line number Diff line number Diff line change
Expand Up @@ -341,8 +341,7 @@ std::pair<ReshardingCosts, ReshardingCosts>
GenerateReshardingCostsAndMissingShardingsForAllOperands(
const HloInstruction* ins, const HloSharding& output_sharding,
const StrategyMap& strategy_map, const ClusterEnvironment& cluster_env,
const CallGraph& call_graph,
std::vector<std::optional<HloSharding>>& input_shardings);
const CallGraph& call_graph, InputShardings& input_shardings);

std::unique_ptr<StrategyGroup> MaybeFollowInsStrategyGroup(
const StrategyGroup& src_strategy_group, const Shape& shape,
Expand Down
30 changes: 29 additions & 1 deletion xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,20 @@ inline const ShardingStrategy& GetShardingStrategy(
return strategy_group->GetStrategies()[stra_idx];
}

// Retrieves the operand shardings that the ILP solution selected for the
// (non-tuple) instruction `inst`.
inline const InputShardings& GetInputShardings(
    const HloInstruction* inst, const StrategyMap& strategy_map,
    const CostGraph& cost_graph, absl::Span<const NodeStrategyIdx> s_val) {
  const StrategyGroup* group = strategy_map.at(inst).get();
  CHECK(!group->is_tuple);
  // Map the solver's node index through the cost-graph remapping to obtain
  // the strategy actually chosen for this node.
  const NodeIdx node = group->node_idx;
  return group->GetInputShardings(cost_graph.RemapIndex(node, s_val[node]));
}

// Get the final sharding strategy according to the ILP solution.
inline const ShardingStrategy& GetShardingStrategyForTuple(
const HloInstruction* inst, ShapeIndex index,
const HloInstruction* inst, const ShapeIndex& index,
const StrategyMap& strategy_map, const CostGraph& cost_graph,
absl::Span<const NodeStrategyIdx> s_val) {
const StrategyGroup* strategy_group = strategy_map.at(inst).get();
Expand All @@ -156,6 +167,23 @@ inline const ShardingStrategy& GetShardingStrategyForTuple(
return strategy_group->GetStrategies()[stra_idx];
}

// Retrieves the operand shardings that the ILP solution selected for the
// tuple element of `inst` identified by `index`.
inline const InputShardings& GetInputShardingsForTuple(
    const HloInstruction* inst, const ShapeIndex& index,
    const StrategyMap& strategy_map, const CostGraph& cost_graph,
    absl::Span<const NodeStrategyIdx> s_val) {
  const StrategyGroup* group = strategy_map.at(inst).get();
  CHECK(group->is_tuple);
  // Descend through the nested tuple structure to the leaf strategy group
  // addressed by `index`.
  for (const auto element : index) {
    CHECK_LT(element, group->GetChildren().size());
    group = group->GetChildren()[element].get();
  }
  const NodeIdx node = group->node_idx;
  return group->GetInputShardings(cost_graph.RemapIndex(node, s_val[node]));
}

} // namespace spmd
} // namespace xla

Expand Down
40 changes: 22 additions & 18 deletions xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ limitations under the License.
#include <optional>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
Expand Down Expand Up @@ -339,17 +340,13 @@ void HandlerBase::AppendNewStrategy(const std::string& name,
operand_strategy_group, operand_shape, input_specs[i], cluster_env_));
}

strategy_group_->AddStrategy(ShardingStrategy({
name,
output_spec,
compute_cost,
communication_cost,
static_cast<double>(
ByteSizeOfShapeWithSharding(ins_->shape(), output_spec)),
communication_resharding_costs,
memory_resharding_costs,
{input_specs.begin(), input_specs.end()},
}));
strategy_group_->AddStrategy(
ShardingStrategy({name, output_spec, compute_cost, communication_cost,
static_cast<double>(ByteSizeOfShapeWithSharding(
ins_->shape(), output_spec)),
communication_resharding_costs,
memory_resharding_costs}),
{input_specs.begin(), input_specs.end()});
}

// Given lhs and rhs dim maps, infers a sharding for the output by relying
Expand Down Expand Up @@ -462,18 +459,25 @@ std::optional<HloSharding> HandlerBase::GetShardingFromUser(
}

void HandlerBase::SortStrategies() {
auto strategies = strategy_group_->GetStrategies();
std::vector<std::pair<ShardingStrategy, InputShardings>> strategy_shardings;
for (size_t sid = 0; sid < strategy_group_->GetStrategies().size(); ++sid) {
const ShardingStrategy& strategy = strategy_group_->GetStrategy(sid);
const auto& input_shardings = strategy_group_->GetInputShardings(sid);
strategy_shardings.push_back({strategy, input_shardings});
}
absl::c_stable_sort(
strategies, [](const ShardingStrategy& s1, const ShardingStrategy& s2) {
if (s1.memory_cost == s2.memory_cost) {
return s1.name < s2.name;
strategy_shardings,
[](const std::pair<ShardingStrategy, InputShardings>& s1,
const std::pair<ShardingStrategy, InputShardings>& s2) {
if (s1.first.memory_cost == s2.first.memory_cost) {
return s1.first.name < s2.first.name;
} else {
return s1.memory_cost < s2.memory_cost;
return s1.first.memory_cost < s2.first.memory_cost;
}
});
strategy_group_->ClearStrategies();
for (const ShardingStrategy& strategy : strategies) {
strategy_group_->AddStrategy(strategy);
for (const auto& [strategy, input_shardings] : strategy_shardings) {
strategy_group_->AddStrategy(strategy, input_shardings);
}
}

Expand Down
56 changes: 30 additions & 26 deletions xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -336,17 +336,19 @@ BuildStrategyAndCost(
double memory_cost =
ByteSizeOfShapeWithSharding(ins->shape(), scatter_sharding);

std::vector<std::optional<HloSharding>> input_shardings_optional(
InputShardings input_shardings_optional(
{data_sharding, indices_sharding, update_sharding});
std::pair<ReshardingCosts, ReshardingCosts> resharding_costs =
GenerateReshardingCostsAndMissingShardingsForAllOperands(
ins, scatter_sharding, strategy_map, cluster_env, call_graph,
input_shardings_optional);

strategy_group->AddStrategy(ShardingStrategy(
{name, scatter_sharding, compute_cost, communication_cost,
memory_cost, std::move(resharding_costs.first),
std::move(resharding_costs.second), input_shardings_optional}));
strategy_group->AddStrategy(
ShardingStrategy({name, scatter_sharding, compute_cost,
communication_cost, memory_cost,
std::move(resharding_costs.first),
std::move(resharding_costs.second)}),
input_shardings_optional);
};

const HloScatterInstruction* scatter = Cast<HloScatterInstruction>(ins);
Expand Down Expand Up @@ -388,18 +390,20 @@ BuildStrategyAndCost(
double compute_cost = 0, communication_cost = 0;
double memory_cost =
ByteSizeOfShapeWithSharding(gather_shape, output_sharding);
std::vector<std::optional<HloSharding>> input_shardings_optional(
InputShardings input_shardings_optional(
{data_sharding, indices_sharding});
std::pair<ReshardingCosts, ReshardingCosts> resharding_costs =
GenerateReshardingCostsAndMissingShardingsForAllOperands(
ins, output_sharding, strategy_map, cluster_env, call_graph,
input_shardings_optional);

strategy_group->AddStrategy(ShardingStrategy(
{std::string(output_sharding.ToString()), output_sharding,
compute_cost, communication_cost, memory_cost,
std::move(resharding_costs.first),
std::move(resharding_costs.second), input_shardings_optional}));
strategy_group->AddStrategy(
ShardingStrategy({std::string(output_sharding.ToString()),
output_sharding, compute_cost,
communication_cost, memory_cost,
std::move(resharding_costs.first),
std::move(resharding_costs.second)}),
input_shardings_optional);
};

for (const ShardingStrategy& indices_strategy :
Expand Down Expand Up @@ -562,8 +566,8 @@ BuildStrategyAndCost(
communication_cost,
memory_cost,
{communication_resharding_costs},
{memory_resharding_costs},
{input_spec}}));
{memory_resharding_costs}}),
{input_spec});
}
break;
}
Expand Down Expand Up @@ -671,7 +675,7 @@ BuildStrategyAndCost(
}

// Get a list of input shardings, each corresponds to an operand.
std::vector<std::optional<HloSharding>> input_shardings;
InputShardings input_shardings;
for (int64_t k = 0; k < ins->operand_count(); ++k) {
if (k == follow_idx ||
ToString(ins->operand(k)->shape().dimensions()) ==
Expand All @@ -694,14 +698,11 @@ BuildStrategyAndCost(
input_shardings);

strategy_group->AddStrategy(
ShardingStrategy({name,
*output_spec,
compute_cost,
communication_cost,
memory_cost,
ShardingStrategy({name, *output_spec, compute_cost,
communication_cost, memory_cost,
std::move(resharding_costs.first),
std::move(resharding_costs.second),
{input_spec}}));
std::move(resharding_costs.second)}),
{input_spec});
}

if (strategy_group->GetStrategies().empty()) {
Expand Down Expand Up @@ -1082,16 +1083,19 @@ BuildStrategyAndCost(
auto it = absl::c_find(inst_indices, strategy_group->node_idx);
if (it != inst_indices.end()) {
CHECK(!strategy_group->is_tuple);
std::vector<ShardingStrategy> new_strategies;
std::vector<std::pair<ShardingStrategy, InputShardings>> new_strategies;
int64_t idx = it - inst_indices.begin();
for (const auto& strategy : strategy_group->GetStrategies()) {
const auto& strategies = strategy_group->GetStrategies();
for (size_t sid = 0; sid < strategies.size(); ++sid) {
const ShardingStrategy& strategy = strategy_group->GetStrategy(sid);
const auto& input_shardings = strategy_group->GetInputShardings(sid);
if (strategy.name == stra_names[idx]) {
new_strategies.push_back(strategy);
new_strategies.push_back({strategy, input_shardings});
}
}
strategy_group->ClearStrategies();
for (const ShardingStrategy& strategy : new_strategies) {
strategy_group->AddStrategy(strategy);
for (const auto& [strategy, input_shardings] : new_strategies) {
strategy_group->AddStrategy(strategy, input_shardings);
}
}
}
Expand Down
49 changes: 21 additions & 28 deletions xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ using ReshardingCache =
ConstInstructionMap<std::vector<std::pair<HloSharding, HloInstruction*>>>;
// Resharding costs for each operand
using ReshardingCosts = std::vector<std::vector<double>>;
// Optional shardings for each operand
using InputShardings = std::vector<std::optional<HloSharding>>;

// One sharding strategy
struct ShardingStrategy {
Expand All @@ -89,9 +91,6 @@ struct ShardingStrategy {
// cost from i-th tuple element's j-th strategy.
ReshardingCosts communication_resharding_costs;
ReshardingCosts memory_resharding_costs;
// Optional: the required shardings of operands.
// This is used to guide the SPMD partitioner.
std::vector<std::optional<HloSharding>> input_shardings;

std::string ToString() const {
return absl::StrCat(name, ", ", output_sharding.ToString());
Expand All @@ -117,32 +116,12 @@ struct ShardingStrategy {
std::string memory_resharding_cost_str = absl::StrCat(
"{", absl::StrJoin(memory_resharding_vector_strings, ", "), "}");

std::string input_sharding_str = "{";
for (const auto& s : input_shardings) {
if (!s.has_value()) {
input_sharding_str += "[*],";
} else if (s->IsReplicated()) {
input_sharding_str += "[R],";
} else {
if (s->ReplicateOnLastTileDim()) {
input_sharding_str +=
"[" + absl::StrJoin(s->tile_assignment().dimensions(), ", ") +
"]last_tile_dim_replicate,";
} else {
input_sharding_str +=
"[" + absl::StrJoin(s->tile_assignment().dimensions(), ", ") +
"],";
}
}
}
input_sharding_str += "}\n";
return absl::StrCat(
name, ", ", output_sharding.ToString(), ", compute_cost=", compute_cost,
", communication_cost=", communication_cost,
", memory_cost=", memory_cost,
", communication_resharding_costs=", communication_resharding_cost_str,
", memory_resharding_costs=", memory_resharding_cost_str,
", input_shardings=", input_sharding_str);
", memory_resharding_costs=", memory_resharding_cost_str);
}

bool operator==(const ShardingStrategy& other) const {
Expand All @@ -152,8 +131,7 @@ struct ShardingStrategy {
memory_cost == other.memory_cost &&
communication_resharding_costs ==
other.communication_resharding_costs &&
memory_resharding_costs == other.memory_resharding_costs &&
input_shardings == other.input_shardings;
memory_resharding_costs == other.memory_resharding_costs;
}
};

Expand Down Expand Up @@ -273,20 +251,33 @@ struct StrategyGroup {

//////// Accessor methods for strategies ////////

void AddStrategy(const ShardingStrategy& strategy) {
// Appends `strategy` to this group together with the (possibly empty) list of
// operand shardings it assumes, keeping `strategies` and
// `strategy_input_shardings` as parallel vectors indexed by strategy id.
void AddStrategy(const ShardingStrategy& strategy,
const InputShardings& input_shardings = {}) {
strategies.push_back(strategy);
strategy_input_shardings.push_back(input_shardings);
}

void ClearStrategies() { strategies.clear(); }
// Removes all strategies and their associated input shardings; both parallel
// vectors are cleared together so they stay in sync.
void ClearStrategies() {
strategies.clear();
strategy_input_shardings.clear();
}

// Returns a mutable reference to the strategy at `strategy_idx`. No bounds
// checking beyond vector operator[]; callers must pass a valid index.
ShardingStrategy& GetStrategy(size_t strategy_idx) {
return strategies[strategy_idx];
}

// Returns the operand shardings recorded for the strategy at `strategy_idx`.
// No bounds checking beyond vector operator[].
const InputShardings& GetInputShardings(size_t strategy_idx) const {
return strategy_input_shardings[strategy_idx];
}

// Returns a read-only view of all strategies in this group.
const std::vector<ShardingStrategy>& GetStrategies() const {
return strategies;
}

// Returns a read-only view of the per-strategy input shardings; entry i
// corresponds to strategies[i].
const std::vector<InputShardings>& GetStrategyInputShardings() const {
return strategy_input_shardings;
}

//////// Accessor methods for children ////////

void AddChild(std::unique_ptr<StrategyGroup> child) {
Expand All @@ -305,6 +296,8 @@ struct StrategyGroup {
// Used when is_tuple == False. Leaf strategy vector.
// A vector of strategy choices for the non-tuple output.
std::vector<ShardingStrategy> strategies;
std::vector<InputShardings> strategy_input_shardings;

// Used when is_tuple == True. A vector of pointers, each pointer is one
// StrategyGroup for one value in the output Tuple
std::vector<std::unique_ptr<StrategyGroup>> children;
Expand Down
21 changes: 12 additions & 9 deletions xla/hlo/experimental/auto_sharding/auto_sharding_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -836,18 +836,21 @@ void RemoveDuplicatedStrategy(StrategyGroup& strategy_group) {
if (strategy_group.following || strategy_group.GetStrategies().empty()) {
return;
}
std::vector<ShardingStrategy> new_vector;
std::vector<ShardingStrategy> deduped_replicated_strategies;
std::vector<std::pair<ShardingStrategy, InputShardings>> new_vector;
std::vector<std::pair<ShardingStrategy, InputShardings>>
deduped_replicated_strategies;
absl::flat_hash_set<std::string> added;
size_t num_skipped_due_to_infinity_costs = 0;
for (const ShardingStrategy& strategy : strategy_group.GetStrategies()) {
for (size_t sid = 0; sid < strategy_group.GetStrategies().size(); ++sid) {
const ShardingStrategy& strategy = strategy_group.GetStrategy(sid);
if (AllInfinityCosts(strategy.communication_resharding_costs)) {
num_skipped_due_to_infinity_costs++;
continue;
}
std::string key = strategy.output_sharding.ToString();
if (!strategy.input_shardings.empty()) {
for (const auto& sharding : strategy.input_shardings) {
const auto& input_shardings = strategy_group.GetInputShardings(sid);
if (!input_shardings.empty()) {
for (const auto& sharding : input_shardings) {
key += "/" + (sharding.has_value() ? sharding->ToString() : "none");
}
}
Expand All @@ -856,9 +859,9 @@ void RemoveDuplicatedStrategy(StrategyGroup& strategy_group) {
}
added.insert(key);
if (!strategy.output_sharding.IsReplicated()) {
new_vector.push_back(strategy);
new_vector.push_back({strategy, input_shardings});
} else {
deduped_replicated_strategies.push_back(strategy);
deduped_replicated_strategies.push_back({strategy, input_shardings});
}
}
CHECK_LT(num_skipped_due_to_infinity_costs,
Expand All @@ -871,8 +874,8 @@ void RemoveDuplicatedStrategy(StrategyGroup& strategy_group) {
}
}
strategy_group.ClearStrategies();
for (const ShardingStrategy& strategy : new_vector) {
strategy_group.AddStrategy(strategy);
for (const auto& [strategy, input_shardings] : new_vector) {
strategy_group.AddStrategy(strategy, input_shardings);
}
}

Expand Down

0 comments on commit ba9a668

Please sign in to comment.