Skip to content

Commit

Permalink
Adds accessor methods to StrategyGroup (so that clients can't directly manipulate the vectors containing sharding strategies and child groups).
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 675271442
  • Loading branch information
Google-ML-Automation committed Sep 16, 2024
1 parent 6683d9a commit f4a8c36
Show file tree
Hide file tree
Showing 8 changed files with 658 additions and 659 deletions.
805 changes: 375 additions & 430 deletions xla/hlo/experimental/auto_sharding/auto_sharding.cc

Large diffs are not rendered by default.

75 changes: 34 additions & 41 deletions xla/hlo/experimental/auto_sharding/auto_sharding.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ limitations under the License.
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h"
Expand Down Expand Up @@ -131,12 +132,12 @@ HloSharding Tile(const Shape& shape, absl::Span<const int64_t> tensor_dims,
const DeviceMesh& device_mesh);

std::vector<double> CommunicationReshardingCostVector(
const StrategyGroup* strategy_group, const Shape& shape,
const StrategyGroup& strategy_group, const Shape& shape,
const HloSharding& required_sharding,
const ClusterEnvironment& cluster_env);

std::vector<double> MemoryReshardingCostVector(
const StrategyGroup* strategy_group, const Shape& operand_shape,
const StrategyGroup& strategy_group, const Shape& operand_shape,
const HloSharding& required_sharding,
const ClusterEnvironment& cluster_env);

Expand All @@ -146,17 +147,13 @@ std::unique_ptr<StrategyGroup> CreateLeafStrategyGroup(
size_t instruction_id, const HloInstruction* ins,
const StrategyMap& strategy_map, StrategyGroups& strategy_groups);

void SetInNodesWithInstruction(std::unique_ptr<StrategyGroup>& strategy_group,
const HloInstruction* ins,
const StrategyMap& strategy_map);

void RemoveDuplicatedStrategy(std::unique_ptr<StrategyGroup>& strategy_group);
void RemoveDuplicatedStrategy(StrategyGroup& strategy_group);

absl::Status FilterStrategy(const HloInstruction* ins, const Shape& shape,
std::unique_ptr<StrategyGroup>& strategy_group,
const ClusterEnvironment& cluster_env,
const InstructionBatchDimMap& batch_map,
const AutoShardingOption& option);
const AutoShardingOption& option,
StrategyGroup& strategy_group);

absl::Status HandleDot(std::unique_ptr<StrategyGroup>& strategy_group,
StrategyGroups& strategy_groups,
Expand Down Expand Up @@ -242,10 +239,11 @@ void PopulateTemporalValues(const CostGraph& cost_graph,
void AddReplicatedStrategy(
const HloInstruction* ins, const Shape& shape,
const ClusterEnvironment& cluster_env, const StrategyMap& strategy_map,
std::unique_ptr<StrategyGroup>& strategy_group, double replicated_penalty,
absl::flat_hash_set<int64_t> operands_to_consider_all_strategies_for = {});
double replicated_penalty,
absl::flat_hash_set<int64_t> operands_to_consider_all_strategies_for,
StrategyGroup& strategy_group);

void CheckMemoryCosts(StrategyGroup* strategy_group, const Shape& shape);
void CheckMemoryCosts(const StrategyGroup& strategy_group, const Shape& shape);

// Choose an operand to follow. We choose to follow the operand with the highest
// priority.
Expand All @@ -254,13 +252,12 @@ std::pair<int64_t, bool> ChooseOperandToFollow(
const AliasMap& alias_map, int64_t max_depth, const HloInstruction* ins);

void FillAllStrategiesForArray(
std::unique_ptr<StrategyGroup>& strategy_group, const HloInstruction* ins,
const Shape& shape, const ClusterEnvironment& cluster_env,
const StrategyMap& strategy_map, const AutoShardingOption& option,
double replicated_penalty, const InstructionBatchDimMap& batch_dim_map,
const CallGraph& call_graph, bool only_allow_divisible,
bool create_replicated_strategies,
bool create_partially_replicated_strategies);
const HloInstruction* ins, const Shape& shape,
const ClusterEnvironment& cluster_env, const StrategyMap& strategy_map,
const AutoShardingOption& option, double replicated_penalty,
const InstructionBatchDimMap& batch_dim_map, const CallGraph& call_graph,
bool only_allow_divisible, bool create_replicated_strategies,
bool create_partially_replicated_strategies, StrategyGroup& strategy_group);

absl::StatusOr<std::unique_ptr<StrategyGroup>> CreateAllStrategiesGroup(
const HloInstruction* ins, const Shape& shape, size_t instruction_id,
Expand Down Expand Up @@ -313,22 +310,19 @@ void EnumerateAll1DPartition(const HloInstruction* ins, const Shape& shape,
const DeviceMesh& device_mesh,
const ClusterEnvironment& cluster_env,
const StrategyMap& strategy_map,
std::unique_ptr<StrategyGroup>& strategy_group,
bool only_allow_divisible,
const std::string& suffix,
const CallGraph& call_graph);
const CallGraph& call_graph,
StrategyGroup& strategy_group);

// Enumerate all partitions recursively.
void EnumerateAllPartition(const HloInstruction* ins, const Shape& shape,
const DeviceMesh& device_mesh,
const ClusterEnvironment& cluster_env,
const StrategyMap& strategy_map,
std::unique_ptr<StrategyGroup>& strategy_group,
const InstructionBatchDimMap& batch_dim_map,
bool only_allow_divisible,
const CallGraph& call_graph,
int64_t partition_dimensions,
const std::vector<int64_t>& tensor_dims = {});
void EnumerateAllPartition(
const HloInstruction* ins, const Shape& shape,
const DeviceMesh& device_mesh, const ClusterEnvironment& cluster_env,
const StrategyMap& strategy_map,
const InstructionBatchDimMap& batch_dim_map, bool only_allow_divisible,
const CallGraph& call_graph, int64_t partition_dimensions,
const std::vector<int64_t>& tensor_dims, StrategyGroup& strategy_group);

absl::StatusOr<std::unique_ptr<StrategyGroup>> FollowReduceStrategy(
const HloInstruction* ins, const Shape& output_shape,
Expand All @@ -340,8 +334,8 @@ absl::StatusOr<std::unique_ptr<StrategyGroup>> FollowReduceStrategy(
void GenerateOutfeedStrategy(const HloInstruction* ins, const Shape& shape,
const ClusterEnvironment& cluster_env,
const StrategyMap& strategy_map,
std::unique_ptr<StrategyGroup>& strategy_group,
double replicated_penalty);
double replicated_penalty,
StrategyGroup& strategy_group);

std::pair<ReshardingCosts, ReshardingCosts>
GenerateReshardingCostsAndMissingShardingsForAllOperands(
Expand All @@ -351,28 +345,27 @@ GenerateReshardingCostsAndMissingShardingsForAllOperands(
std::vector<std::optional<HloSharding>>& input_shardings);

std::unique_ptr<StrategyGroup> MaybeFollowInsStrategyGroup(
const StrategyGroup* src_strategy_group, const Shape& shape,
const StrategyGroup& src_strategy_group, const Shape& shape,
size_t instruction_id, StrategyGroups& strategy_groups,
const ClusterEnvironment& cluster_env,
const StableMap<NodeIdx, std::vector<ShardingStrategy>>&
pretrimmed_strategy_map);

void RemoveShardingsWhereSmallDimsShardedAcrossManyDevices(
const Shape& shape, StrategyGroup* strategy_group,
bool instruction_has_user_sharding);
const Shape& shape, bool instruction_has_user_sharding,
StrategyGroup& strategy_group);

void ScaleCostsWithExecutionCounts(StrategyGroup* strategy_group,
int64_t execution_count);
void ScaleCostsWithExecutionCounts(int64_t execution_count,
StrategyGroup& strategy_group);

// Existing shardings refer to the HloSharding field in the given
// HloInstruction.
void TrimOrGenerateStrategiesBasedOnExistingSharding(
const Shape& output_shape, StrategyGroup* strategy_group,
const StrategyMap& strategy_map,
const Shape& output_shape, const StrategyMap& strategy_map,
const std::vector<HloInstruction*>& instructions,
const HloSharding& existing_sharding, const ClusterEnvironment& cluster_env,
StableMap<int64_t, std::vector<ShardingStrategy>>& pretrimmed_strategy_map,
const CallGraph& call_graph, bool strict);
const CallGraph& call_graph, bool strict, StrategyGroup& strategy_group);

// Build possible sharding strategies and their costs for all instructions.
absl::StatusOr<std::tuple<StrategyMap, StrategyGroups, AssociativeDotPairs>>
Expand Down
32 changes: 16 additions & 16 deletions xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,9 @@ CostGraph::CostGraph(const StrategyGroups& strategy_groups,

// Build the cost graph.
for (StrategyGroup* strategy_group : strategy_groups) {
node_lens_.push_back(strategy_group->strategies.size());
node_lens_.push_back(strategy_group->GetStrategies().size());
extra_node_costs_.push_back(
std::vector<double>(strategy_group->strategies.size(), 0.0));
std::vector<double>(strategy_group->GetStrategies().size(), 0.0));

const auto& in_nodes = strategy_group->in_nodes;
for (size_t i = 0; i < in_nodes.size(); ++i) {
Expand All @@ -74,8 +74,8 @@ CostGraph::CostGraph(const StrategyGroups& strategy_groups,
CreateEdgeCost(src_idx, dst_idx, i, strategy_group);
AddEdgeCost(src_idx, dst_idx, edge_cost);
} else if (in_nodes[i]->is_tuple && in_nodes.size() > 1) {
for (size_t l = 0; l < in_nodes[i]->childs.size(); ++l) {
NodeIdx src_idx = in_nodes[i]->childs[l]->node_idx;
for (const auto& child : in_nodes[i]->GetChildren()) {
NodeIdx src_idx = child->node_idx;
NodeIdx dst_idx = strategy_group->node_idx;
EdgeReshardingCostMatrix edge_cost =
CreateEdgeCost(src_idx, dst_idx, i, strategy_group, true);
Expand All @@ -86,8 +86,8 @@ CostGraph::CostGraph(const StrategyGroups& strategy_groups,
<< "Do not support instructions with more than one tuple "
"operand. If this CHECK fails, we will need to fix "
"b/233412625.";
for (size_t l = 0; l < in_nodes[i]->childs.size(); ++l) {
NodeIdx src_idx = in_nodes[i]->childs[l]->node_idx;
for (size_t l = 0; l < in_nodes[i]->GetChildren().size(); ++l) {
NodeIdx src_idx = in_nodes[i]->GetChildren()[l]->node_idx;
NodeIdx dst_idx = strategy_group->node_idx;
// TODO(b/233412625) Support more general case, e.g., multiple tuple
// operands. If there is only one operand and it's a tuple, the
Expand All @@ -101,8 +101,8 @@ CostGraph::CostGraph(const StrategyGroups& strategy_groups,
}

if (strategy_group->following) {
CHECK_EQ(strategy_group->strategies.size(),
strategy_group->following->strategies.size())
CHECK_EQ(strategy_group->GetStrategies().size(),
strategy_group->following->GetStrategies().size())
<< "Different strategy counts for instruction ID "
<< strategy_group->instruction_id << " and following instruction ID "
<< strategy_group->following->instruction_id;
Expand All @@ -116,26 +116,25 @@ CostGraph::CostGraph(const StrategyGroups& strategy_groups,
for (const auto& pair : associative_dot_pairs) {
NodeIdx src_idx = pair.first->node_idx;
NodeIdx dst_idx = pair.second->node_idx;
StrategyGroup& src_strategy_group = *strategy_groups[src_idx];
StrategyGroup& dst_strategy_group = *strategy_groups[dst_idx];

EdgeReshardingCostMatrix edge_cost(node_lens_[src_idx],
node_lens_[dst_idx]);
absl::flat_hash_map<std::string, NodeStrategyIdx>
src_strategy_name_to_idx_map;
for (NodeStrategyIdx i = 0; i < node_lens_[src_idx]; ++i) {
const ShardingStrategy& strategy =
strategy_groups[src_idx]->strategies[i];
const ShardingStrategy& strategy = src_strategy_group.GetStrategy(i);
if (strategy.communication_cost > 0) {
src_strategy_name_to_idx_map[strategy.name] = i;
}
}
for (NodeStrategyIdx i = 0; i < node_lens_[dst_idx]; ++i) {
const ShardingStrategy& dst_strategy =
strategy_groups[dst_idx]->strategies[i];
const ShardingStrategy& dst_strategy = dst_strategy_group.GetStrategy(i);
if (dst_strategy.communication_cost > 0) {
auto it = src_strategy_name_to_idx_map.find(dst_strategy.name);
if (it != src_strategy_name_to_idx_map.end()) {
const ShardingStrategy& src_strategy =
strategy_groups[src_idx]->strategies[it->second];
const auto& src_strategy = src_strategy_group.GetStrategy(it->second);
CHECK_LE(std::abs(src_strategy.communication_cost -
dst_strategy.communication_cost),
1e-6);
Expand All @@ -154,8 +153,9 @@ EdgeReshardingCostMatrix CostGraph::CreateEdgeCost(
CHECK_LT(src_idx, node_lens_.size());
CHECK_LT(dst_idx, node_lens_.size());
EdgeReshardingCostMatrix edge_cost(node_lens_[src_idx], node_lens_[dst_idx]);
for (NodeStrategyIdx k = 0; k < strategy_group->strategies.size(); ++k) {
const ShardingStrategy& strategy = strategy_group->strategies[k];
const auto& strategies = strategy_group->GetStrategies();
for (NodeStrategyIdx k = 0; k < strategies.size(); ++k) {
const ShardingStrategy& strategy = strategies[k];
size_t start_idx = 0;
CHECK_LT(in_node_idx, strategy.memory_resharding_costs.size())
<< strategy_group->node_idx;
Expand Down
8 changes: 4 additions & 4 deletions xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ inline const ShardingStrategy& GetShardingStrategy(
CHECK(!strategy_group->is_tuple);
NodeIdx node_idx = strategy_group->node_idx;
NodeStrategyIdx stra_idx = cost_graph.RemapIndex(node_idx, s_val[node_idx]);
return strategy_group->strategies[stra_idx];
return strategy_group->GetStrategies()[stra_idx];
}

// Get the final sharding strategy according to the ILP solution.
Expand All @@ -147,13 +147,13 @@ inline const ShardingStrategy& GetShardingStrategyForTuple(
const StrategyGroup* strategy_group = strategy_map.at(inst).get();
CHECK(strategy_group->is_tuple);
for (auto index_element : index) {
CHECK_LT(index_element, strategy_group->childs.size());
const auto& strategies = strategy_group->childs[index_element];
CHECK_LT(index_element, strategy_group->GetChildren().size());
const auto& strategies = strategy_group->GetChildren()[index_element];
strategy_group = strategies.get();
}
NodeIdx node_idx = strategy_group->node_idx;
NodeStrategyIdx stra_idx = cost_graph.RemapIndex(node_idx, s_val[node_idx]);
return strategy_group->strategies[stra_idx];
return strategy_group->GetStrategies()[stra_idx];
}

} // namespace spmd
Expand Down
24 changes: 14 additions & 10 deletions xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ limitations under the License.
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "xla/array.h"
#include "xla/hlo/experimental/auto_sharding/auto_sharding.h"
#include "xla/hlo/experimental/auto_sharding/auto_sharding_device_mesh.h"
#include "xla/hlo/experimental/auto_sharding/auto_sharding_option.h"
Expand All @@ -51,6 +50,7 @@ limitations under the License.
#include "xla/service/dot_as_convolution_util.h"
#include "xla/service/hlo_cost_analysis.h"
#include "xla/service/sharding_propagation.h"
#include "xla/shape.h"
#include "tsl/platform/errors.h"

namespace xla {
Expand Down Expand Up @@ -331,15 +331,15 @@ void HandlerBase::AppendNewStrategy(const std::string& name,

for (int i = 0; i < ins_->operand_count(); ++i) {
const HloInstruction* operand = ins_->operand(i);
const Shape& operand_shape = operand->shape();
const StrategyGroup& operand_strategy_group = *strategy_map_.at(operand);
communication_resharding_costs.push_back(CommunicationReshardingCostVector(
strategy_map_.at(operand).get(), operand->shape(), input_specs[i],
cluster_env_));
operand_strategy_group, operand_shape, input_specs[i], cluster_env_));
memory_resharding_costs.push_back(MemoryReshardingCostVector(
strategy_map_.at(operand).get(), operand->shape(), input_specs[i],
cluster_env_));
operand_strategy_group, operand_shape, input_specs[i], cluster_env_));
}

strategy_group_->strategies.push_back(ShardingStrategy({
strategy_group_->AddStrategy(ShardingStrategy({
name,
output_spec,
compute_cost,
Expand Down Expand Up @@ -462,15 +462,19 @@ std::optional<HloSharding> HandlerBase::GetShardingFromUser(
}

void HandlerBase::SortStrategies() {
auto strategies = strategy_group_->GetStrategies();
absl::c_stable_sort(
strategy_group_->strategies,
[](const ShardingStrategy& s1, const ShardingStrategy& s2) {
strategies, [](const ShardingStrategy& s1, const ShardingStrategy& s2) {
if (s1.memory_cost == s2.memory_cost) {
return s1.name < s2.name;
} else {
return s1.memory_cost < s2.memory_cost;
}
});
strategy_group_->ClearStrategies();
for (const ShardingStrategy& strategy : strategies) {
strategy_group_->AddStrategy(strategy);
}
}

/************** DotHandler function definitions **************/
Expand Down Expand Up @@ -962,8 +966,8 @@ absl::Status ConvHandler::RegisterStrategies() {
// and only keep the data parallel strategies.
if (option_.force_batch_dim_to_mesh_dim >= 0 &&
batch_map_.contains(GetBatchDimMapKey(ins_))) {
TF_RETURN_IF_ERROR(FilterStrategy(ins_, ins_->shape(), strategy_group_,
cluster_env_, batch_map_, option_));
TF_RETURN_IF_ERROR(FilterStrategy(ins_, ins_->shape(), cluster_env_,
batch_map_, option_, *strategy_group_));
}

SortStrategies();
Expand Down
Loading

0 comments on commit f4a8c36

Please sign in to comment.