From 323cd24f28b8f936da92b4759814de8cd46c337b Mon Sep 17 00:00:00 2001 From: xla authors Date: Wed, 2 Oct 2024 10:02:13 -0700 Subject: [PATCH] Simplify error handling in auto-sharding. PiperOrigin-RevId: 681493654 --- .../auto_sharding/auto_sharding.cc | 41 +++++-------------- 1 file changed, 10 insertions(+), 31 deletions(-) diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding.cc b/xla/hlo/experimental/auto_sharding/auto_sharding.cc index 29ad335d888e2..82e2d64a7190a 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding.cc +++ b/xla/hlo/experimental/auto_sharding/auto_sharding.cc @@ -3790,18 +3790,14 @@ absl::StatusOr AutoShardingImplementation::RunAutoSharding( // ----- Call the ILP Solver ----- std::string request_name = absl::StrCat("mesh_idx_", mesh_idx); - auto solver_result = + spmd::AutoShardingSolverResult solver_result = Solve(*module, *hlo_live_range, strategy_map, strategy_groups, cost_graph, alias_set, reduced_node_intervals, reduced_edge_intervals, reduced_node_groups, reduced_edge_groups, option_, request_name, sharding_propagation_solution); - if (solver_result.skip_auto_sharding) { - return AutoShardingResult::kModuleUnchangedNoShardingPerformed; - } else if (!solver_result.status.ok()) { - return AutoShardingResult::kModuleUnchanged; - } TF_ASSIGN_OR_RETURN(spmd::AutoShardingSolverOutput output, solver_result.status); + if (mesh_idx == partial_mesh_shapes.size() - 1) { this->solver_optimal_objective_value_ = output.cost; } @@ -3823,21 +3819,14 @@ absl::StatusOr AutoShardingImplementation::RunAutoSharding( output.s_val, (mesh_idx == partial_mesh_shapes.size() - 1)); if (mesh_idx == partial_mesh_shapes.size() - 1) { - if (!spmd::SetHloShardingPostProcessing(sequence, instructions_to_shard, - preserve_shardings) - .ok()) { - return AutoShardingResult::kModuleUnchanged; - } - - if (!InsertReshardReshapes( - sequence, instructions_to_shard, strategy_map, cost_graph, - output.s_val, cluster_env, - /* crash_at_error */ !option_.try_multiple_mesh_shapes, - option_.insert_resharding_reshapes_for_non_dot_ops, - preserve_shardings) - .ok()) { - return AutoShardingResult::kModuleUnchanged; - } + TF_RETURN_IF_ERROR(spmd::SetHloShardingPostProcessing( + sequence, instructions_to_shard, preserve_shardings)); + TF_RETURN_IF_ERROR(InsertReshardReshapes( + sequence, instructions_to_shard, strategy_map, cost_graph, + output.s_val, cluster_env, + /* crash_at_error */ !option_.try_multiple_mesh_shapes, + option_.insert_resharding_reshapes_for_non_dot_ops, + preserve_shardings)); } else { spmd::RecoverShardingsFromPartialMesh(sequence, preserve_shardings); } @@ -4117,7 +4106,6 @@ absl::StatusOr AutoSharding::Run( double min_objective_value = std::numeric_limits::max(); int min_mesh_shape_index = -1; std::unique_ptr min_mesh_shape_module; - bool skip_auto_sharding = true; for (size_t i = 0; i < mesh_shapes.size(); ++i) { VLOG(1) << "Trying mesh shape " << spmd::ToString(mesh_shapes[i]); AutoShardingOption this_option = option_; @@ -4150,15 +4138,6 @@ absl::StatusOr AutoSharding::Run( min_objective_value = this_mesh_objective_value; min_mesh_pass_result = pass_result; } - if (*pass_result != - AutoShardingResult::kModuleUnchangedNoShardingPerformed) { - skip_auto_sharding = false; - } - } - - if (skip_auto_sharding) { - RecordPassEndAndDumpModule(start_time, module); - LOG(FATAL) << "The auto-sharding solver has timed out without a solution."; } std::string trying_to_find =