Skip to content

Commit

Permalink
Simplify error handling in auto-sharding.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 681493654
  • Loading branch information
Google-ML-Automation committed Oct 2, 2024
1 parent 93be085 commit 323cd24
Showing 1 changed file with 10 additions and 31 deletions.
41 changes: 10 additions & 31 deletions xla/hlo/experimental/auto_sharding/auto_sharding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3790,18 +3790,14 @@ absl::StatusOr<AutoShardingResult> AutoShardingImplementation::RunAutoSharding(

// ----- Call the ILP Solver -----
std::string request_name = absl::StrCat("mesh_idx_", mesh_idx);
auto solver_result =
spmd::AutoShardingSolverResult solver_result =
Solve(*module, *hlo_live_range, strategy_map, strategy_groups,
cost_graph, alias_set, reduced_node_intervals,
reduced_edge_intervals, reduced_node_groups, reduced_edge_groups,
option_, request_name, sharding_propagation_solution);
if (solver_result.skip_auto_sharding) {
return AutoShardingResult::kModuleUnchangedNoShardingPerformed;
} else if (!solver_result.status.ok()) {
return AutoShardingResult::kModuleUnchanged;
}
TF_ASSIGN_OR_RETURN(spmd::AutoShardingSolverOutput output,
solver_result.status);

if (mesh_idx == partial_mesh_shapes.size() - 1) {
this->solver_optimal_objective_value_ = output.cost;
}
Expand All @@ -3823,21 +3819,14 @@ absl::StatusOr<AutoShardingResult> AutoShardingImplementation::RunAutoSharding(
output.s_val, (mesh_idx == partial_mesh_shapes.size() - 1));

if (mesh_idx == partial_mesh_shapes.size() - 1) {
if (!spmd::SetHloShardingPostProcessing(sequence, instructions_to_shard,
preserve_shardings)
.ok()) {
return AutoShardingResult::kModuleUnchanged;
}

if (!InsertReshardReshapes(
sequence, instructions_to_shard, strategy_map, cost_graph,
output.s_val, cluster_env,
/* crash_at_error */ !option_.try_multiple_mesh_shapes,
option_.insert_resharding_reshapes_for_non_dot_ops,
preserve_shardings)
.ok()) {
return AutoShardingResult::kModuleUnchanged;
}
TF_RETURN_IF_ERROR(spmd::SetHloShardingPostProcessing(
sequence, instructions_to_shard, preserve_shardings));
TF_RETURN_IF_ERROR(InsertReshardReshapes(
sequence, instructions_to_shard, strategy_map, cost_graph,
output.s_val, cluster_env,
/* crash_at_error */ !option_.try_multiple_mesh_shapes,
option_.insert_resharding_reshapes_for_non_dot_ops,
preserve_shardings));
} else {
spmd::RecoverShardingsFromPartialMesh(sequence, preserve_shardings);
}
Expand Down Expand Up @@ -4117,7 +4106,6 @@ absl::StatusOr<bool> AutoSharding::Run(
double min_objective_value = std::numeric_limits<double>::max();
int min_mesh_shape_index = -1;
std::unique_ptr<HloModule> min_mesh_shape_module;
bool skip_auto_sharding = true;
for (size_t i = 0; i < mesh_shapes.size(); ++i) {
VLOG(1) << "Trying mesh shape " << spmd::ToString(mesh_shapes[i]);
AutoShardingOption this_option = option_;
Expand Down Expand Up @@ -4150,15 +4138,6 @@ absl::StatusOr<bool> AutoSharding::Run(
min_objective_value = this_mesh_objective_value;
min_mesh_pass_result = pass_result;
}
if (*pass_result !=
AutoShardingResult::kModuleUnchangedNoShardingPerformed) {
skip_auto_sharding = false;
}
}

if (skip_auto_sharding) {
RecordPassEndAndDumpModule(start_time, module);
LOG(FATAL) << "The auto-sharding solver has timed out without a solution.";
}

std::string trying_to_find =
Expand Down

0 comments on commit 323cd24

Please sign in to comment.