From 8108e873b29077d729569fae037726cb27169596 Mon Sep 17 00:00:00 2001
From: xla authors
Date: Wed, 2 Oct 2024 10:35:09 -0700
Subject: [PATCH] Remove AutoShardingResult in favor of a boolean now that the
 value kModuleUnchangedNoShardingPerformed of the enum is unused, effectively
 making it a boolean.

Also simplified away some dead code.

PiperOrigin-RevId: 681506949
---
 .../auto_sharding/auto_sharding.cc      | 85 ++++++-------------
 .../auto_sharding/auto_sharding.h       |  8 +-
 .../auto_sharding/auto_sharding_test.cc |  3 +-
 3 files changed, 29 insertions(+), 67 deletions(-)

diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding.cc b/xla/hlo/experimental/auto_sharding/auto_sharding.cc
index 29ad335d888e23..49958936358d33 100644
--- a/xla/hlo/experimental/auto_sharding/auto_sharding.cc
+++ b/xla/hlo/experimental/auto_sharding/auto_sharding.cc
@@ -3504,14 +3504,14 @@ std::pair<int64_t, int64_t> ReduceMemoryTerms(
   return num_terms;
 }
 
-absl::StatusOr<AutoShardingResult> AutoShardingImplementation::RunAutoSharding(
+absl::StatusOr<bool> AutoShardingImplementation::RunAutoSharding(
     HloModule* module,
     const absl::flat_hash_set<std::string>& replicated_small_tensors,
     const absl::flat_hash_set<absl::string_view>& execution_threads,
     const absl::flat_hash_map<std::string, const HloInstruction*>&
         sharding_propagation_solution) {
   if (!option_.enable) {
-    return AutoShardingResult::kModuleUnchanged;
+    return false;
   }
 
   bool module_is_changed = false;
@@ -3790,18 +3790,14 @@ absl::StatusOr<AutoShardingResult> AutoShardingImplementation::RunAutoSharding(
 
     // ----- Call the ILP Solver -----
     std::string request_name = absl::StrCat("mesh_idx_", mesh_idx);
-    auto solver_result =
+    spmd::AutoShardingSolverResult solver_result =
         Solve(*module, *hlo_live_range, strategy_map, strategy_groups,
               cost_graph, alias_set, reduced_node_intervals,
               reduced_edge_intervals, reduced_node_groups, reduced_edge_groups,
               option_, request_name, sharding_propagation_solution);
-    if (solver_result.skip_auto_sharding) {
-      return AutoShardingResult::kModuleUnchangedNoShardingPerformed;
-    } else if (!solver_result.status.ok()) {
-      return AutoShardingResult::kModuleUnchanged;
-    }
     TF_ASSIGN_OR_RETURN(spmd::AutoShardingSolverOutput output,
                         solver_result.status);
+
     if (mesh_idx == partial_mesh_shapes.size() - 1) {
       this->solver_optimal_objective_value_ = output.cost;
     }
@@ -3823,21 +3819,14 @@ absl::StatusOr<AutoShardingResult> AutoShardingImplementation::RunAutoSharding(
         output.s_val, (mesh_idx == partial_mesh_shapes.size() - 1));
 
     if (mesh_idx == partial_mesh_shapes.size() - 1) {
-      if (!spmd::SetHloShardingPostProcessing(sequence, instructions_to_shard,
-                                              preserve_shardings)
-               .ok()) {
-        return AutoShardingResult::kModuleUnchanged;
-      }
-
-      if (!InsertReshardReshapes(
-              sequence, instructions_to_shard, strategy_map, cost_graph,
-              output.s_val, cluster_env,
-              /* crash_at_error */ !option_.try_multiple_mesh_shapes,
-              option_.insert_resharding_reshapes_for_non_dot_ops,
-              preserve_shardings)
-               .ok()) {
-        return AutoShardingResult::kModuleUnchanged;
-      }
+      TF_RETURN_IF_ERROR(spmd::SetHloShardingPostProcessing(
+          sequence, instructions_to_shard, preserve_shardings));
+      TF_RETURN_IF_ERROR(InsertReshardReshapes(
+          sequence, instructions_to_shard, strategy_map, cost_graph,
+          output.s_val, cluster_env,
+          /* crash_at_error */ !option_.try_multiple_mesh_shapes,
+          option_.insert_resharding_reshapes_for_non_dot_ops,
+          preserve_shardings));
     } else {
       spmd::RecoverShardingsFromPartialMesh(sequence, preserve_shardings);
     }
@@ -3878,8 +3867,7 @@ absl::StatusOr<AutoShardingResult> AutoShardingImplementation::RunAutoSharding(
     }
   }
 
-  return module_is_changed ? AutoShardingResult::kModuleChangedShardingPerformed
-                           : AutoShardingResult::kModuleUnchanged;
+  return module_is_changed;
 }
 
 bool ModuleIsManuallyPartitioned(const HloModule* module) {
@@ -4109,15 +4097,12 @@ absl::StatusOr<bool> AutoSharding::Run(
     }
   }
 
-  absl::StatusOr<AutoShardingResult> min_mesh_pass_result =
-      AutoShardingResult::kModuleUnchanged;
-
+  bool module_is_changed = false;
   VLOG(1) << "Original mesh shape "
           << spmd::ToString(option_.device_mesh_shape);
   double min_objective_value = std::numeric_limits<double>::max();
   int min_mesh_shape_index = -1;
   std::unique_ptr<HloModule> min_mesh_shape_module;
-  bool skip_auto_sharding = true;
   for (size_t i = 0; i < mesh_shapes.size(); ++i) {
     VLOG(1) << "Trying mesh shape " << spmd::ToString(mesh_shapes[i]);
     AutoShardingOption this_option = option_;
@@ -4130,7 +4115,7 @@ absl::StatusOr<bool> AutoSharding::Run(
     }
     auto pass = std::make_unique<AutoShardingImplementation>(this_option);
     std::unique_ptr<HloModule> module_clone = CloneModule(module);
-    absl::StatusOr<AutoShardingResult> pass_result =
+    absl::StatusOr<bool> pass_result =
         pass->RunAutoSharding(module_clone.get(), replicated_small_tensors,
                               execution_threads, sharding_propagation_solution);
     if (!pass_result.ok()) {
@@ -4148,19 +4133,11 @@ absl::StatusOr<bool> AutoSharding::Run(
       min_mesh_shape_index = i;
       min_mesh_shape_module = std::move(module_clone);
       min_objective_value = this_mesh_objective_value;
-      min_mesh_pass_result = pass_result;
-    }
-    if (*pass_result !=
-        AutoShardingResult::kModuleUnchangedNoShardingPerformed) {
-      skip_auto_sharding = false;
+      CHECK_OK(pass_result);
+      module_is_changed = *pass_result;
     }
   }
-  if (skip_auto_sharding) {
-    RecordPassEndAndDumpModule(start_time, module);
-    LOG(FATAL) << "The auto-sharding solver has timed out without a solution.";
-  }
-
   std::string trying_to_find =
       option_.try_multiple_mesh_shapes
           ? "a device mesh (and the corresponding shardings)"
          : "shardings";
@@ -4173,28 +4150,18 @@ absl::StatusOr<bool> AutoSharding::Run(
       "higher budget). If you think you have set a reasonably large memory "
       "budget, please report this as a bug.";
 
-  if (!min_mesh_pass_result.ok()) {
-    RecordPassEndAndDumpModule(start_time, module);
-    return min_mesh_pass_result.status();
-  }
-
-  absl::StatusOr<bool> module_is_changed;
   solver_optimal_objective_value_ = min_objective_value;
-  if (*min_mesh_pass_result !=
-      AutoShardingResult::kModuleChangedShardingPerformed) {
-    RecordPassEndAndDumpModule(start_time, module);
-    return false;
+  if (module_is_changed) {
+    VLOG(1) << "Choosing mesh shape "
+            << spmd::ToString(mesh_shapes[min_mesh_shape_index])
+            << " which had the minimal solver objective value of "
+            << min_objective_value;
+    chosen_mesh_shape_ = mesh_shapes[min_mesh_shape_index];
+    TF_RETURN_IF_ERROR(MoveComputationsFromModuleToModule(
+        min_mesh_shape_module.get(), module));
   }
-
-  VLOG(1) << "Choosing mesh shape "
-          << spmd::ToString(mesh_shapes[min_mesh_shape_index])
-          << " which had the minimal solver objective value of "
-          << min_objective_value;
-  chosen_mesh_shape_ = mesh_shapes[min_mesh_shape_index];
-  TF_RETURN_IF_ERROR(
-      MoveComputationsFromModuleToModule(min_mesh_shape_module.get(), module));
   RecordPassEndAndDumpModule(start_time, module);
-  return true;
+  return module_is_changed;
 }
 
 }  // namespace xla
diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding.h b/xla/hlo/experimental/auto_sharding/auto_sharding.h
index e749fb8682532d..e37aecaa46898e 100644
--- a/xla/hlo/experimental/auto_sharding/auto_sharding.h
+++ b/xla/hlo/experimental/auto_sharding/auto_sharding.h
@@ -50,18 +50,12 @@ limitations under the License.
 
 namespace xla {
 
-enum class AutoShardingResult {
-  kModuleUnchanged,
-  kModuleChangedShardingPerformed,
-  kModuleUnchangedNoShardingPerformed
-};
-
 class AutoShardingImplementation {
  public:
   explicit AutoShardingImplementation(const AutoShardingOption& option);
   ~AutoShardingImplementation() = default;
 
-  absl::StatusOr<AutoShardingResult> RunAutoSharding(
+  absl::StatusOr<bool> RunAutoSharding(
       HloModule* module,
       const absl::flat_hash_set<std::string>& replicated_small_tensors,
       const absl::flat_hash_set<absl::string_view>& execution_threads,
diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc b/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc
index cd12aeaf3d846c..aa3e45ea54b0b3 100644
--- a/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc
+++ b/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc
@@ -2839,7 +2839,8 @@ ENTRY matmul {
   // TODO(b/369616683) Fix the error message output in this case.
   EXPECT_DEATH(
       absl::StatusOr<bool> status = AutoSharding(option).Run(module.get()),
-      "The auto-sharding solver has timed out without a solution.");
+      "The auto-sharding pass could not find shardings that works for this "
+      "input.");
 }
 
 TEST_F(AutoShardingTest, IgnoreShardAsShardLike) {
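Note for readers of the diff (not part of the patch itself): the old enum carried only one bit of information once kModuleUnchangedNoShardingPerformed stopped being produced, so the pass can report "did this change the module?" as a plain bool while errors keep flowing through the absl::Status side of absl::StatusOr. A minimal, self-contained sketch of the caller pattern before and after the change; RunPassOld and RunPassNew are hypothetical stand-ins, not the real AutoShardingImplementation API:

// Sketch only: stand-ins for the pass interface before and after this patch.
#include <iostream>

#include "absl/status/statusor.h"

// Before: a three-valued result, one value of which was never produced.
enum class AutoShardingResult {
  kModuleUnchanged,
  kModuleChangedShardingPerformed,
  kModuleUnchangedNoShardingPerformed  // unused -> the enum is really a bool
};

// Hypothetical stand-in for the old RunAutoSharding signature.
absl::StatusOr<AutoShardingResult> RunPassOld(bool will_shard) {
  return will_shard ? AutoShardingResult::kModuleChangedShardingPerformed
                    : AutoShardingResult::kModuleUnchanged;
}

// Hypothetical stand-in for the new signature: the bool answers
// "was the module changed?", and errors still travel via absl::Status.
absl::StatusOr<bool> RunPassNew(bool will_shard) { return will_shard; }

int main() {
  absl::StatusOr<AutoShardingResult> old_result = RunPassOld(true);
  const bool old_changed =
      old_result.ok() &&
      *old_result == AutoShardingResult::kModuleChangedShardingPerformed;

  absl::StatusOr<bool> new_result = RunPassNew(true);
  const bool new_changed = new_result.ok() && *new_result;

  std::cout << old_changed << " " << new_changed << "\n";  // prints "1 1"
}

Either way the caller ends up branching on a single boolean plus the error status, which is exactly what the diff switches RunAutoSharding and AutoSharding::Run to.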