diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding.cc b/xla/hlo/experimental/auto_sharding/auto_sharding.cc
index 29ad335d888e2..49958936358d3 100644
--- a/xla/hlo/experimental/auto_sharding/auto_sharding.cc
+++ b/xla/hlo/experimental/auto_sharding/auto_sharding.cc
@@ -3504,14 +3504,14 @@ std::pair<int64_t, int64_t> ReduceMemoryTerms(
   return num_terms;
 }
 
-absl::StatusOr<AutoShardingResult> AutoShardingImplementation::RunAutoSharding(
+absl::StatusOr<bool> AutoShardingImplementation::RunAutoSharding(
     HloModule* module,
     const absl::flat_hash_set<std::string>& replicated_small_tensors,
     const absl::flat_hash_set<absl::string_view>& execution_threads,
     const absl::flat_hash_map<std::string, const HloInstruction*>&
         sharding_propagation_solution) {
   if (!option_.enable) {
-    return AutoShardingResult::kModuleUnchanged;
+    return false;
   }
 
   bool module_is_changed = false;
@@ -3790,18 +3790,14 @@ absl::StatusOr<AutoShardingResult> AutoShardingImplementation::RunAutoSharding(
 
     // ----- Call the ILP Solver -----
     std::string request_name = absl::StrCat("mesh_idx_", mesh_idx);
-    auto solver_result =
+    spmd::AutoShardingSolverResult solver_result =
         Solve(*module, *hlo_live_range, strategy_map, strategy_groups,
               cost_graph, alias_set, reduced_node_intervals,
               reduced_edge_intervals, reduced_node_groups, reduced_edge_groups,
               option_, request_name, sharding_propagation_solution);
-    if (solver_result.skip_auto_sharding) {
-      return AutoShardingResult::kModuleUnchangedNoShardingPerformed;
-    } else if (!solver_result.status.ok()) {
-      return AutoShardingResult::kModuleUnchanged;
-    }
     TF_ASSIGN_OR_RETURN(spmd::AutoShardingSolverOutput output,
                         solver_result.status);
+
     if (mesh_idx == partial_mesh_shapes.size() - 1) {
       this->solver_optimal_objective_value_ = output.cost;
     }
@@ -3823,21 +3819,14 @@ absl::StatusOr<AutoShardingResult> AutoShardingImplementation::RunAutoSharding(
         output.s_val, (mesh_idx == partial_mesh_shapes.size() - 1));
 
     if (mesh_idx == partial_mesh_shapes.size() - 1) {
-      if (!spmd::SetHloShardingPostProcessing(sequence, instructions_to_shard,
-                                              preserve_shardings)
-               .ok()) {
-        return AutoShardingResult::kModuleUnchanged;
-      }
-
-      if (!InsertReshardReshapes(
-              sequence, instructions_to_shard, strategy_map, cost_graph,
-              output.s_val, cluster_env,
-              /* crash_at_error */ !option_.try_multiple_mesh_shapes,
-              option_.insert_resharding_reshapes_for_non_dot_ops,
-              preserve_shardings)
-               .ok()) {
-        return AutoShardingResult::kModuleUnchanged;
-      }
+      TF_RETURN_IF_ERROR(spmd::SetHloShardingPostProcessing(
+          sequence, instructions_to_shard, preserve_shardings));
+      TF_RETURN_IF_ERROR(InsertReshardReshapes(
+          sequence, instructions_to_shard, strategy_map, cost_graph,
+          output.s_val, cluster_env,
+          /* crash_at_error */ !option_.try_multiple_mesh_shapes,
+          option_.insert_resharding_reshapes_for_non_dot_ops,
+          preserve_shardings));
     } else {
       spmd::RecoverShardingsFromPartialMesh(sequence, preserve_shardings);
     }
@@ -3878,8 +3867,7 @@ absl::StatusOr<AutoShardingResult> AutoShardingImplementation::RunAutoSharding(
     }
   }
 
-  return module_is_changed ? AutoShardingResult::kModuleChangedShardingPerformed
-                           : AutoShardingResult::kModuleUnchanged;
+  return module_is_changed;
 }
 
 bool ModuleIsManuallyPartitioned(const HloModule* module) {
@@ -4109,15 +4097,12 @@ absl::StatusOr<bool> AutoSharding::Run(
     }
   }
 
-  absl::StatusOr<AutoShardingResult> min_mesh_pass_result =
-      AutoShardingResult::kModuleUnchanged;
-
+  bool module_is_changed = false;
   VLOG(1) << "Original mesh shape "
           << spmd::ToString(option_.device_mesh_shape);
   double min_objective_value = std::numeric_limits<double>::max();
   int min_mesh_shape_index = -1;
   std::unique_ptr<HloModule> min_mesh_shape_module;
-  bool skip_auto_sharding = true;
   for (size_t i = 0; i < mesh_shapes.size(); ++i) {
     VLOG(1) << "Trying mesh shape " << spmd::ToString(mesh_shapes[i]);
     AutoShardingOption this_option = option_;
@@ -4130,7 +4115,7 @@ absl::StatusOr<bool> AutoSharding::Run(
     }
     auto pass = std::make_unique<AutoShardingImplementation>(this_option);
     std::unique_ptr<HloModule> module_clone = CloneModule(module);
-    absl::StatusOr<AutoShardingResult> pass_result =
+    absl::StatusOr<bool> pass_result =
         pass->RunAutoSharding(module_clone.get(), replicated_small_tensors,
                               execution_threads, sharding_propagation_solution);
     if (!pass_result.ok()) {
@@ -4148,19 +4133,11 @@ absl::StatusOr<bool> AutoSharding::Run(
       min_mesh_shape_index = i;
      min_mesh_shape_module = std::move(module_clone);
       min_objective_value = this_mesh_objective_value;
-      min_mesh_pass_result = pass_result;
-    }
-    if (*pass_result !=
-        AutoShardingResult::kModuleUnchangedNoShardingPerformed) {
-      skip_auto_sharding = false;
+      CHECK_OK(pass_result);
+      module_is_changed = *pass_result;
     }
   }
-  if (skip_auto_sharding) {
-    RecordPassEndAndDumpModule(start_time, module);
-    LOG(FATAL) << "The auto-sharding solver has timed out without a solution.";
-  }
-
   std::string trying_to_find =
       option_.try_multiple_mesh_shapes
           ? "a device mesh (and the corresponding shardings)"
           : "shardings"
@@ -4173,28 +4150,18 @@ absl::StatusOr<bool> AutoSharding::Run(
       "higher budget). If you think you have set a reasonably large memory "
       "budget, please report this as a bug.";
 
-  if (!min_mesh_pass_result.ok()) {
-    RecordPassEndAndDumpModule(start_time, module);
-    return min_mesh_pass_result.status();
-  }
-
-  absl::StatusOr<bool> module_is_changed;
   solver_optimal_objective_value_ = min_objective_value;
-  if (*min_mesh_pass_result !=
-      AutoShardingResult::kModuleChangedShardingPerformed) {
-    RecordPassEndAndDumpModule(start_time, module);
-    return false;
+  if (module_is_changed) {
+    VLOG(1) << "Choosing mesh shape "
+            << spmd::ToString(mesh_shapes[min_mesh_shape_index])
+            << " which had the minimal solver objective value of "
+            << min_objective_value;
+    chosen_mesh_shape_ = mesh_shapes[min_mesh_shape_index];
+    TF_RETURN_IF_ERROR(MoveComputationsFromModuleToModule(
+        min_mesh_shape_module.get(), module));
   }
-
-  VLOG(1) << "Choosing mesh shape "
-          << spmd::ToString(mesh_shapes[min_mesh_shape_index])
-          << " which had the minimal solver objective value of "
-          << min_objective_value;
-  chosen_mesh_shape_ = mesh_shapes[min_mesh_shape_index];
-  TF_RETURN_IF_ERROR(
-      MoveComputationsFromModuleToModule(min_mesh_shape_module.get(), module));
   RecordPassEndAndDumpModule(start_time, module);
-  return true;
+  return module_is_changed;
 }
 
 }  // namespace xla
diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding.h b/xla/hlo/experimental/auto_sharding/auto_sharding.h
index e749fb8682532..e37aecaa46898 100644
--- a/xla/hlo/experimental/auto_sharding/auto_sharding.h
+++ b/xla/hlo/experimental/auto_sharding/auto_sharding.h
@@ -50,18 +50,12 @@ limitations under the License.
 
 namespace xla {
 
-enum class AutoShardingResult {
-  kModuleUnchanged,
-  kModuleChangedShardingPerformed,
-  kModuleUnchangedNoShardingPerformed
-};
-
 class AutoShardingImplementation {
  public:
   explicit AutoShardingImplementation(const AutoShardingOption& option);
   ~AutoShardingImplementation() = default;
 
-  absl::StatusOr<AutoShardingResult> RunAutoSharding(
+  absl::StatusOr<bool> RunAutoSharding(
       HloModule* module,
       const absl::flat_hash_set<std::string>& replicated_small_tensors,
       const absl::flat_hash_set<absl::string_view>& execution_threads,
diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc b/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc
index cd12aeaf3d846..aa3e45ea54b0b 100644
--- a/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc
+++ b/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc
@@ -2839,7 +2839,8 @@ ENTRY matmul {
   // TODO(b/369616683) Fix the error message output in this case.
   EXPECT_DEATH(
       absl::StatusOr<bool> status = AutoSharding(option).Run(module.get()),
-      "The auto-sharding solver has timed out without a solution.");
+      "The auto-sharding pass could not find shardings that works for this "
+      "input.");
 }
 
 TEST_F(AutoShardingTest, IgnoreShardAsShardLike) {
diff --git a/xla/tsl/BUILD b/xla/tsl/BUILD
index 18f54ec37eb9f..401cf088936d3 100644
--- a/xla/tsl/BUILD
+++ b/xla/tsl/BUILD
@@ -3,7 +3,7 @@ load("@bazel_skylib//lib:selects.bzl", "selects")
 load("@bazel_skylib//rules:common_settings.bzl", "bool_flag", "bool_setting")
 load("@tsl//tsl/platform:rules_cc.bzl", "cc_library")
 load("//xla/tsl:package_groups.bzl", "tsl_package_groups")
-load("//xla/tsl:tsl.bzl", "if_google", "if_oss")
+load("//xla/tsl:tsl.bzl", "if_google", "if_oss", "internal_visibility")
 load(
     "//xla/tsl:tsl.default.bzl",
     "tsl_extra_config_settings",
@@ -500,11 +500,20 @@ config_setting(
 )
 
 config_setting(
-    name = "no_nccl_support",
+    name = "using_no_nccl_support_define",
     define_values = dict(
-        if_google({"GOOGLE_CUDA_COMPILER": "clang"}),
         no_nccl_support = "true",
     ),
+    visibility = internal_visibility(["//visibility:private"]),
+)
+
+selects.config_setting_group(
+    name = "no_nccl_support",
+    match_all = [
+        ":using_no_nccl_support_define",
+    ] + if_google([
+        "@local_config_cuda//cuda:using_config_cuda",
+    ]),
     visibility = ["//visibility:public"],
 )