From df2d940e9009682fc459aa8d91428973e5c81311 Mon Sep 17 00:00:00 2001 From: xla authors Date: Wed, 25 Sep 2024 18:39:54 -0700 Subject: [PATCH] Remove AutoShardingSolverResult in favor of StatusOr as the AutoShardingSolverResult::skip_auto_sharding is now dead after some recent changes. PiperOrigin-RevId: 678928364 --- xla/hlo/experimental/auto_sharding/BUILD | 4 +- .../auto_sharding/auto_sharding.cc | 11 +- .../auto_sharding/auto_sharding.h | 13 -- .../auto_sharding/auto_sharding_impl.cc | 3 +- .../auto_sharding/auto_sharding_solver.cc | 44 ++---- .../auto_sharding/auto_sharding_solver.h | 16 +- .../auto_sharding_solver_impl.cc | 2 +- .../auto_sharding_solver_test.cc | 147 ++++++++---------- .../auto_sharding/auto_sharding_wrapper.h | 17 +- 9 files changed, 107 insertions(+), 150 deletions(-) diff --git a/xla/hlo/experimental/auto_sharding/BUILD b/xla/hlo/experimental/auto_sharding/BUILD index ee63f4e96cc4f..35cd7a5f71d60 100644 --- a/xla/hlo/experimental/auto_sharding/BUILD +++ b/xla/hlo/experimental/auto_sharding/BUILD @@ -217,6 +217,7 @@ cc_library( "//xla/service:hlo_cost_analysis", "@com_google_absl//absl/container:btree", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", ], ) @@ -227,7 +228,6 @@ cc_library( compatible_with = get_compatible_with_libtpu_portable(), deps = [ ":auto_sharding_cost_graph", - ":auto_sharding_device_mesh", ":auto_sharding_option", ":auto_sharding_strategy", ":auto_sharding_wrapper", @@ -236,6 +236,7 @@ cc_library( "//xla/service:hlo_cost_analysis", "@com_google_absl//absl/container:btree", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", ], ) @@ -426,5 +427,6 @@ xla_cc_test( "@com_google_absl//absl/status", "@com_google_googletest//:gtest", "@tsl//tsl/platform", + "@tsl//tsl/platform:statusor", ] + if_google(["@com_google_ortools//ortools/linear_solver:linear_solver_scip"]), ) diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding.cc b/xla/hlo/experimental/auto_sharding/auto_sharding.cc index 49958936358d3..42a7bfa1e6141 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding.cc +++ b/xla/hlo/experimental/auto_sharding/auto_sharding.cc @@ -1747,7 +1747,8 @@ std::unique_ptr CreateReshapeStrategies( return strategy_group; } -AutoShardingSolverResult CreateAutoShardingSolverRequestAndCallSolver( +absl::StatusOr +CreateAutoShardingSolverRequestAndCallSolver( const HloModule& hlo_module, const HloLiveRange& hlo_live_range, const StrategyMap& strategy_map, const StrategyGroups& strategy_groups, const CostGraph& cost_graph, const AliasSet& alias_set, @@ -3790,14 +3791,12 @@ absl::StatusOr AutoShardingImplementation::RunAutoSharding( // ----- Call the ILP Solver ----- std::string request_name = absl::StrCat("mesh_idx_", mesh_idx); - spmd::AutoShardingSolverResult solver_result = + TF_ASSIGN_OR_RETURN( + spmd::AutoShardingSolverOutput output, Solve(*module, *hlo_live_range, strategy_map, strategy_groups, cost_graph, alias_set, reduced_node_intervals, reduced_edge_intervals, reduced_node_groups, reduced_edge_groups, - option_, request_name, sharding_propagation_solution); - TF_ASSIGN_OR_RETURN(spmd::AutoShardingSolverOutput output, - solver_result.status); - + option_, request_name, sharding_propagation_solution)); if (mesh_idx == partial_mesh_shapes.size() - 1) { this->solver_optimal_objective_value_ = output.cost; } diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding.h b/xla/hlo/experimental/auto_sharding/auto_sharding.h index e37aecaa46898..7153bf860515c 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding.h +++ b/xla/hlo/experimental/auto_sharding/auto_sharding.h @@ -210,19 +210,6 @@ HloSharding GetReduceScatterOutput(const HloInstruction* ins, const ShardingStrategy& strategy, const ClusterEnvironment& cluster_env); -// The high-level "recipe" for solving an Auto Sharding problem. -AutoShardingSolverResult Solve( - const HloModule& hlo_module, const HloLiveRange& hlo_live_range, - const StrategyMap& strategy_map, const StrategyGroups& strategy_groups, - const CostGraph& cost_graph, const AliasSet& alias_set, - const std::vector>& node_intervals, - const std::vector>& edge_intervals, - const std::vector>& node_groups, - const std::vector>& edge_groups, - const AutoShardingOption& option, absl::string_view request_prefix, - const absl::flat_hash_map& - sharding_propagation_solution = {}); - // Populates temporal distance values. void PopulateTemporalValues(const CostGraph& cost_graph, AutoShardingSolverRequest& request); diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_impl.cc b/xla/hlo/experimental/auto_sharding/auto_sharding_impl.cc index 7a92ac5715039..b9226f561244e 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding_impl.cc +++ b/xla/hlo/experimental/auto_sharding/auto_sharding_impl.cc @@ -22,6 +22,7 @@ limitations under the License. #include "absl/container/btree_set.h" #include "absl/container/flat_hash_map.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_option.h" @@ -37,7 +38,7 @@ limitations under the License. namespace xla { namespace spmd { -AutoShardingSolverResult Solve( +absl::StatusOr Solve( const HloModule& hlo_module, const HloLiveRange& hlo_live_range, const StrategyMap& strategy_map, const StrategyGroups& strategy_groups, const CostGraph& cost_graph, const AliasSet& alias_set, diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc b/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc index 114cca321a050..dec88705f4086 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc +++ b/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc @@ -81,12 +81,6 @@ bool AutoShardingSolverOutput::operator==( peak_times == other.peak_times; } -bool AutoShardingSolverResult::operator==( - const AutoShardingSolverResult& other) const { - return status == other.status && - skip_auto_sharding == other.skip_auto_sharding; -} - void PrintLargestInstructions( const std::vector& chosen_strategy, const AutoShardingSolverRequest& request) { @@ -143,7 +137,7 @@ void PrintLargestInstructions( } } -AutoShardingSolverResult SolveAndExtractSolution( +absl::StatusOr SolveAndExtractSolution( const AutoShardingSolverRequest& request, const std::vector>& s, const std::vector>& e, @@ -399,7 +393,7 @@ void AddMemoryTerms( // can be a few (usually < 10) edges in the problem with negative costs. This // is guaranteed to never produce a negative overall cost for the graph, // however. -AutoShardingSolverResult FormulateAndSolveMIPFromSolverRequest( +absl::StatusOr FormulateAndSolveMIPFromSolverRequest( const AutoShardingSolverRequest& unscaled_request) { const absl::Time start_time = absl::Now(); const AutoShardingSolverRequest& request = ScaleRequest(unscaled_request); @@ -568,8 +562,7 @@ AutoShardingSolverResult FormulateAndSolveMIPFromSolverRequest( LOG(FATAL) << err_msg; } else { LOG(WARNING) << err_msg; - return AutoShardingSolverResult(absl::InternalError(err_msg), - /*skip_auto_sharding=*/false); + return absl::InternalError(err_msg); } } } @@ -783,9 +776,9 @@ AutoShardingSolverResult FormulateAndSolveMIPFromSolverRequest( } auto result = SolveAndExtractSolution(request, s, e, overbudget_var, makespan_var, *solver); - if (result.status.ok()) { + if (result.ok()) { const AutoShardingEvaluation evaluation = - Evaluate(unscaled_request, result); + Evaluate(unscaled_request, *result); LOG(INFO) << "*** Total costs for the (unscaled) solver request ***"; LOG(INFO) << "Total Communication Cost: " << evaluation.total.communication_cost @@ -831,7 +824,7 @@ std::vector GetChosenNodeStrategy( return chosen_node_strategy; } -AutoShardingSolverResult SolveAndExtractSolution( +absl::StatusOr SolveAndExtractSolution( const AutoShardingSolverRequest& request, const std::vector>& s, const std::vector>& e, @@ -869,22 +862,18 @@ AutoShardingSolverResult SolveAndExtractSolution( } } #endif - return AutoShardingSolverResult( - absl::InternalError("MPSolver could not find any feasible solution."), - /*skip_auto_sharding=*/false); + return absl::InternalError( + "MPSolver could not find any feasible solution."); } else if (status == operations_research::MPSolver::MODEL_INVALID) { - LOG(FATAL) << "Solver says that the input MIP is invalid. This is most " - "likely a bug and should be reported."; - return AutoShardingSolverResult(absl::InternalError("Invalid MIP."), - /*skip_auto_sharding=*/false); + LOG(FATAL) << "The MIP fed to the solver is invalid. This is most likely a " + "bug and should be reported."; + return absl::InternalError("Invalid MIP."); } else if (status == operations_research::MPSolver::NOT_SOLVED) { LOG(WARNING) << "Solver timeout; no solution was produced"; - return AutoShardingSolverResult(absl::InternalError("Solver timed out."), - /*skip_auto_sharding=*/true); + return absl::InternalError("Solver timed out."); } else if (status != operations_research::MPSolver::OPTIMAL) { LOG(WARNING) << "Solver timeout; moving forward with a suboptimal solution"; } - // Fingerprint the model & solution (useful when checking for determinism). // We use TensorFlow's fingerprint library here, which differs from CP-SAT's. operations_research::MPModelProto model_proto; @@ -951,9 +940,8 @@ AutoShardingSolverResult SolveAndExtractSolution( << request.memory_budget() / (1024 * 1024 * 1024) << " GB"; } PrintLargestInstructions(chosen_node_strategy, request); - const AutoShardingSolverOutput output = {std::move(chosen_node_strategy), - solver.Objective().Value()}; - return AutoShardingSolverResult(output, /*skip_auto_sharding=*/false); + return AutoShardingSolverOutput{.s_val = std::move(chosen_node_strategy), + .cost = solver.Objective().Value()}; } bool CostComponents::operator==(const CostComponents& other) const { @@ -977,13 +965,13 @@ bool AutoShardingEvaluation::operator==( } AutoShardingEvaluation Evaluate(const AutoShardingSolverRequest& request, - const AutoShardingSolverResult& result) { + const AutoShardingSolverOutput& result) { const auto& c = request.computation_costs(); const auto& d = request.communication_costs(); const auto& r = request.resharding_costs(); const auto& v = request.value_costs(); const auto& p = request.departure_costs(); - const std::vector& s_val = result.status->s_val; + const std::vector& s_val = result.s_val; const auto e_val = [&](EdgeIdx edge_idx) { const auto& edge = request.edges(edge_idx); return s_val[edge.first()] * request.s_len(edge.second()) + diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_solver.h b/xla/hlo/experimental/auto_sharding/auto_sharding_solver.h index 88884f7286d0b..e6dd82717b6e8 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding_solver.h +++ b/xla/hlo/experimental/auto_sharding/auto_sharding_solver.h @@ -37,17 +37,7 @@ struct AutoShardingSolverOutput { bool operator==(const AutoShardingSolverOutput& other) const; }; -struct AutoShardingSolverResult { - public: - AutoShardingSolverResult(absl::StatusOr status, - bool skip_auto_sharding) - : status(status), skip_auto_sharding(skip_auto_sharding) {} - bool operator==(const AutoShardingSolverResult& other) const; - absl::StatusOr status; - bool skip_auto_sharding; -}; - -AutoShardingSolverResult FormulateAndSolveMIPFromSolverRequest( +absl::StatusOr FormulateAndSolveMIPFromSolverRequest( const AutoShardingSolverRequest& request); enum AutoShardingViolationCode { @@ -92,7 +82,7 @@ struct AutoShardingEvaluation { // Evaluates the given solver result w.r.t. the input request, computing various // solution quality metrics and validating the consistency of hard constraints. AutoShardingEvaluation Evaluate(const AutoShardingSolverRequest& request, - const AutoShardingSolverResult& result); + const AutoShardingSolverOutput& result); // Creates and returns a variable for makespan. operations_research::MPVariable* CreateMakespanVar( @@ -101,7 +91,7 @@ operations_research::MPVariable* CreateMakespanVar( operations_research::MPSolver& solver); double EvaluateMakespan(const AutoShardingSolverRequest& request, - const AutoShardingSolverResult& result, + const AutoShardingSolverOutput& result, AutoShardingEvaluation& evaluation); // Scale down values to reduce the range of costs & coefficients in the solver. diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_solver_impl.cc b/xla/hlo/experimental/auto_sharding/auto_sharding_solver_impl.cc index 4be54f98a0a49..176f1426a9866 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding_solver_impl.cc +++ b/xla/hlo/experimental/auto_sharding/auto_sharding_solver_impl.cc @@ -33,7 +33,7 @@ MPVariable* CreateMakespanVar(const AutoShardingSolverRequest& request, } double EvaluateMakespan(const AutoShardingSolverRequest& request, - const AutoShardingSolverResult& result, + const AutoShardingSolverOutput& result, AutoShardingEvaluation& evaluation) { return 0.0; // TODO(moffitt): Implement this. } diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_solver_test.cc b/xla/hlo/experimental/auto_sharding/auto_sharding_solver_test.cc index 3e0c82d3b7551..4ddafbee670ca 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding_solver_test.cc +++ b/xla/hlo/experimental/auto_sharding/auto_sharding_solver_test.cc @@ -25,6 +25,7 @@ limitations under the License. #include "xla/hlo/experimental/auto_sharding/auto_sharding.pb.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h" #include "tsl/platform/platform.h" +#include "tsl/platform/statusor.h" namespace xla { namespace spmd { @@ -253,14 +254,13 @@ AutoShardingSolverRequest AutoShardingSolverRequestWithEquivalences() { TEST(FormulateAndSolveMIPFromSolverRequestTest, SolvesOptimally) { const AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest(); - const AutoShardingSolverResult result = - FormulateAndSolveMIPFromSolverRequest(request); + TF_ASSERT_OK_AND_ASSIGN(const AutoShardingSolverOutput result, + FormulateAndSolveMIPFromSolverRequest(request)); const std::vector s_val = {0, 0, 0, 0, 0}; const double objective_value = 7650.0; const AutoShardingSolverOutput expected_output = {s_val, objective_value}; - const AutoShardingSolverResult expected_result = {expected_output, false}; - EXPECT_EQ(result, expected_result); + EXPECT_EQ(result, expected_output); } TEST(FormulateAndSolveMIPFromSolverRequestTest, SolvesOverbudget) { @@ -268,42 +268,39 @@ TEST(FormulateAndSolveMIPFromSolverRequestTest, SolvesOverbudget) { request.set_memory_budget(100000); request.mutable_overbudget_coeff()->set_coeff(10.0); - const AutoShardingSolverResult result = - FormulateAndSolveMIPFromSolverRequest(request); + TF_ASSERT_OK_AND_ASSIGN(const AutoShardingSolverOutput result, + FormulateAndSolveMIPFromSolverRequest(request)); const std::vector s_val = {0, 0, 0, 0, 0}; const double objective_value = 9007650.0; const AutoShardingSolverOutput expected_output = {s_val, objective_value}; - const AutoShardingSolverResult expected_result = {expected_output, false}; - EXPECT_EQ(result, expected_result); + EXPECT_EQ(result, expected_output); } TEST(FormulateAndSolveMIPFromSolverRequestTest, SolvesMaxDepartures) { AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest(); request.mutable_max_departures()->set_coeff(3.0); - const AutoShardingSolverResult result = - FormulateAndSolveMIPFromSolverRequest(request); + TF_ASSERT_OK_AND_ASSIGN(const AutoShardingSolverOutput result, + FormulateAndSolveMIPFromSolverRequest(request)); const std::vector s_val = {0, 0, 1, 1, 0}; const double objective_value = 7872.0; const AutoShardingSolverOutput expected_output = {s_val, objective_value}; - const AutoShardingSolverResult expected_result = {expected_output, false}; - EXPECT_EQ(result, expected_result); + EXPECT_EQ(result, expected_output); } TEST(FormulateAndSolveMIPFromSolverRequestTest, MinimizesDepartures) { AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest(); request.set_minimize_departures(true); - const AutoShardingSolverResult result = - FormulateAndSolveMIPFromSolverRequest(request); + TF_ASSERT_OK_AND_ASSIGN(const AutoShardingSolverOutput result, + FormulateAndSolveMIPFromSolverRequest(request)); const std::vector s_val = {0, 1, 0, 0, 1}; const double objective_value = 3.0; const AutoShardingSolverOutput expected_output = {s_val, objective_value}; - const AutoShardingSolverResult expected_result = {expected_output, false}; - EXPECT_EQ(result, expected_result); + EXPECT_EQ(result, expected_output); } TEST(FormulateAndSolveMIPFromSolverRequestTest, AvoidsInfiniteNodeCosts) { @@ -312,28 +309,26 @@ TEST(FormulateAndSolveMIPFromSolverRequestTest, AvoidsInfiniteNodeCosts) { request.mutable_computation_costs(0)->set_costs(1, kInfinityCost); request.mutable_computation_costs(0)->set_costs(2, kInfinityCost); - const AutoShardingSolverResult result = - FormulateAndSolveMIPFromSolverRequest(request); + TF_ASSERT_OK_AND_ASSIGN(const AutoShardingSolverOutput result, + FormulateAndSolveMIPFromSolverRequest(request)); const std::vector s_val = {3, 0, 0, 0, 0}; const double objective_value = 10683.0; const AutoShardingSolverOutput expected_output = {s_val, objective_value}; - const AutoShardingSolverResult expected_result = {expected_output, false}; - EXPECT_EQ(result, expected_result); + EXPECT_EQ(result, expected_output); } TEST(FormulateAndSolveMIPFromSolverRequestTest, AvoidsInfiniteEdgeCosts) { AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest(); request.mutable_resharding_costs(0)->set_costs(0, kInfinityCost); - const AutoShardingSolverResult result = - FormulateAndSolveMIPFromSolverRequest(request); + TF_ASSERT_OK_AND_ASSIGN(const AutoShardingSolverOutput result, + FormulateAndSolveMIPFromSolverRequest(request)); const std::vector s_val = {0, 0, 1, 1, 0}; const double objective_value = 7872.0; const AutoShardingSolverOutput expected_output = {s_val, objective_value}; - const AutoShardingSolverResult expected_result = {expected_output, false}; - EXPECT_EQ(result, expected_result); + EXPECT_EQ(result, expected_output); } TEST(FormulateAndSolveMIPFromSolverRequestTest, HandlesFollowedEdges) { @@ -352,14 +347,13 @@ TEST(FormulateAndSolveMIPFromSolverRequestTest, HandlesFollowedEdges) { 70000, 71000, 72000, 73000}}; AddCosts(request.mutable_duration_costs(), t); - const AutoShardingSolverResult result = - FormulateAndSolveMIPFromSolverRequest(request); + TF_ASSERT_OK_AND_ASSIGN(const AutoShardingSolverOutput result, + FormulateAndSolveMIPFromSolverRequest(request)); const std::vector s_val = {0, 0, 0, 0, 0}; const double objective_value = 12650.0; const AutoShardingSolverOutput expected_output = {s_val, objective_value}; - const AutoShardingSolverResult expected_result = {expected_output, false}; - EXPECT_EQ(result, expected_result); + EXPECT_EQ(result, expected_output); } TEST(FormulateAndSolveMIPFromSolverRequestTest, HandlesCollapsedEdge) { @@ -380,14 +374,13 @@ TEST(FormulateAndSolveMIPFromSolverRequestTest, HandlesCollapsedEdge) { 80000, 81000, 82000, 83000}}; AddCosts(request.mutable_duration_costs(), t); - const AutoShardingSolverResult result = - FormulateAndSolveMIPFromSolverRequest(request); + TF_ASSERT_OK_AND_ASSIGN(const AutoShardingSolverOutput result, + FormulateAndSolveMIPFromSolverRequest(request)); const std::vector s_val = {0, 0, 1, 1, 0}; const double objective_value = 13972.0; const AutoShardingSolverOutput expected_output = {s_val, objective_value}; - const AutoShardingSolverResult expected_result = {expected_output, false}; - EXPECT_EQ(result, expected_result); + EXPECT_EQ(result, expected_output); } TEST(FormulateAndSolveMIPFromSolverRequestTest, UsesHint) { @@ -395,38 +388,36 @@ TEST(FormulateAndSolveMIPFromSolverRequestTest, UsesHint) { const auto s_hint = {1, 0, 0, 0, 0}; // Not optimal, but close. request.mutable_s_hint()->Add(s_hint.begin(), s_hint.end()); - const AutoShardingSolverResult result = - FormulateAndSolveMIPFromSolverRequest(request); + TF_ASSERT_OK_AND_ASSIGN(const AutoShardingSolverOutput result, + FormulateAndSolveMIPFromSolverRequest(request)); const std::vector s_val = {0, 0, 0, 0, 0}; const double objective_value = 7650.0; const AutoShardingSolverOutput expected_output = {s_val, objective_value}; - const AutoShardingSolverResult expected_result = {expected_output, false}; - EXPECT_EQ(result, expected_result); + EXPECT_EQ(result, expected_output); } TEST(FormulateAndSolveMIPFromSolverRequestTest, HonorsMaxCost) { AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest(); request.mutable_max_cost()->set_coeff(7600.0); // Best possible is 7650.0 - const AutoShardingSolverResult result = + const absl::StatusOr result = FormulateAndSolveMIPFromSolverRequest(request); - EXPECT_TRUE(absl::IsInternal(result.status.status())); + EXPECT_TRUE(absl::IsInternal(result.status())); } TEST(FormulateAndSolveMIPFromSolverRequestTest, HandlesExtremelyHighMaxCost) { AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest(); request.mutable_max_cost()->set_coeff(1e19); - const AutoShardingSolverResult result = - FormulateAndSolveMIPFromSolverRequest(request); + TF_ASSERT_OK_AND_ASSIGN(const AutoShardingSolverOutput result, + FormulateAndSolveMIPFromSolverRequest(request)); const std::vector s_val = {0, 0, 0, 0, 0}; const double objective_value = 7650.0; const AutoShardingSolverOutput expected_output = {s_val, objective_value}; - const AutoShardingSolverResult expected_result = {expected_output, false}; - EXPECT_EQ(result, expected_result); + EXPECT_EQ(result, expected_output); } TEST(FormulateAndSolveMIPFromSolverRequestTest, HandlesMemoryEdgeCosts) { @@ -443,14 +434,13 @@ TEST(FormulateAndSolveMIPFromSolverRequestTest, HandlesMemoryEdgeCosts) { AddCosts(request.mutable_memory_edge_costs(), memory_edge_costs); request.set_enable_memory_edge_costs(true); - const AutoShardingSolverResult result = - FormulateAndSolveMIPFromSolverRequest(request); + TF_ASSERT_OK_AND_ASSIGN(const AutoShardingSolverOutput result, + FormulateAndSolveMIPFromSolverRequest(request)); const std::vector s_val = {0, 0, 1, 1, 0}; const double objective_value = 7872.0; const AutoShardingSolverOutput expected_output = {s_val, objective_value}; - const AutoShardingSolverResult expected_result = {expected_output, false}; - EXPECT_EQ(result, expected_result); + EXPECT_EQ(result, expected_output); } TEST(FormulateAndSolveMIPFromSolverRequestTest, HandlesIntervals) { @@ -472,14 +462,13 @@ TEST(FormulateAndSolveMIPFromSolverRequestTest, HandlesIntervals) { AddCosts(request.mutable_memory_edge_costs(), memory_edge_costs); request.set_enable_memory_edge_costs(true); - const AutoShardingSolverResult result = - FormulateAndSolveMIPFromSolverRequest(request); + TF_ASSERT_OK_AND_ASSIGN(const AutoShardingSolverOutput result, + FormulateAndSolveMIPFromSolverRequest(request)); const std::vector s_val = {0, 0, 1, 1, 0}; const double objective_value = 7872.0; const AutoShardingSolverOutput expected_output = {s_val, objective_value}; - const AutoShardingSolverResult expected_result = {expected_output, false}; - EXPECT_EQ(result, expected_result); + EXPECT_EQ(result, expected_output); } TEST(FormulateAndSolveMIPFromSolverRequestTest, @@ -506,14 +495,13 @@ TEST(FormulateAndSolveMIPFromSolverRequestTest, AddCosts(request.mutable_memory_edge_costs(), memory_edge_costs); request.set_enable_memory_edge_costs(true); - const AutoShardingSolverResult result = - FormulateAndSolveMIPFromSolverRequest(request); + TF_ASSERT_OK_AND_ASSIGN(const AutoShardingSolverOutput result, + FormulateAndSolveMIPFromSolverRequest(request)); const std::vector s_val = {0, 0, 1, 1, 0}; const double objective_value = 7872.0; const AutoShardingSolverOutput expected_output = {s_val, objective_value}; - const AutoShardingSolverResult expected_result = {expected_output, false}; - EXPECT_EQ(result, expected_result); + EXPECT_EQ(result, expected_output); } TEST(FormulateAndSolveMIPFromSolverRequestTest, @@ -527,14 +515,13 @@ TEST(FormulateAndSolveMIPFromSolverRequestTest, AddGroups(request.mutable_node_groups(), node_groups); request.set_enable_memory_edge_costs(false); - const AutoShardingSolverResult result = - FormulateAndSolveMIPFromSolverRequest(request); + TF_ASSERT_OK_AND_ASSIGN(const AutoShardingSolverOutput result, + FormulateAndSolveMIPFromSolverRequest(request)); const std::vector s_val = {0, 0, 0, 0, 0}; const double objective_value = 7650.0; const AutoShardingSolverOutput expected_output = {s_val, objective_value}; - const AutoShardingSolverResult expected_result = {expected_output, false}; - EXPECT_EQ(result, expected_result); + EXPECT_EQ(result, expected_output); } TEST(FormulateAndSolveMIPFromSolverRequestTest, @@ -569,28 +556,26 @@ TEST(FormulateAndSolveMIPFromSolverRequestTest, request.set_enable_memory_edge_costs(true); request.set_memory_budget(4321); - const AutoShardingSolverResult result = - FormulateAndSolveMIPFromSolverRequest(request); + TF_ASSERT_OK_AND_ASSIGN(const AutoShardingSolverOutput result, + FormulateAndSolveMIPFromSolverRequest(request)); const std::vector s_val = {0, 0, 0, 0, 0}; const double objective_value = 7650.0; const AutoShardingSolverOutput expected_output = {s_val, objective_value}; - const AutoShardingSolverResult expected_result = {expected_output, false}; - EXPECT_EQ(result, expected_result); + EXPECT_EQ(result, expected_output); } TEST(FormulateAndSolveMIPFromSolverRequestTest, SolvesWithEquivalences) { const AutoShardingSolverRequest request = AutoShardingSolverRequestWithEquivalences(); - const AutoShardingSolverResult result = - FormulateAndSolveMIPFromSolverRequest(request); + TF_ASSERT_OK_AND_ASSIGN(const AutoShardingSolverOutput result, + FormulateAndSolveMIPFromSolverRequest(request)); const std::vector s_val = {0, 0, 5, 5, 1}; const double objective_value = 7650.0; const AutoShardingSolverOutput expected_output = {s_val, objective_value}; - const AutoShardingSolverResult expected_result = {expected_output, false}; - EXPECT_EQ(result, expected_result); + EXPECT_EQ(result, expected_output); } TEST(AutoShardingEvaluatorTest, NoViolations) { @@ -598,9 +583,8 @@ TEST(AutoShardingEvaluatorTest, NoViolations) { const std::vector s_val = {3, 1, 2, 2, 1}; const double objective_value = 12149.0; const AutoShardingSolverOutput output = {s_val, objective_value}; - const AutoShardingSolverResult result = {output, false}; - const AutoShardingEvaluation evaluation = Evaluate(request, result); + const AutoShardingEvaluation evaluation = Evaluate(request, output); AutoShardingEvaluation expected_evaluation; expected_evaluation.total.computation_cost = 159.0; // 13+21+32+42+51 @@ -620,9 +604,8 @@ TEST(AutoShardingEvaluatorTest, EvaluatesOverbudget) { const std::vector s_val = {2 /* violates */, 1, 2, 2, 1}; const double objective_value = 11138.0; const AutoShardingSolverOutput output = {s_val, objective_value}; - const AutoShardingSolverResult result = {output, false}; - const AutoShardingEvaluation evaluation = Evaluate(request, result); + const AutoShardingEvaluation evaluation = Evaluate(request, output); AutoShardingEvaluation expected_evaluation; expected_evaluation.total.computation_cost = 158.0; // 12+21+32+42+51 @@ -648,9 +631,8 @@ TEST(AutoShardingEvaluatorTest, EvaluatesOverbudgetWithIntervals) { const std::vector s_val = {2 /* violates */, 1, 2, 2, 1}; const double objective_value = 11138.0; const AutoShardingSolverOutput output = {s_val, objective_value}; - const AutoShardingSolverResult result = {output, false}; - const AutoShardingEvaluation evaluation = Evaluate(request, result); + const AutoShardingEvaluation evaluation = Evaluate(request, output); AutoShardingEvaluation expected_evaluation; expected_evaluation.total.computation_cost = 158.0; // 12+21+32+42+51 @@ -679,9 +661,8 @@ TEST(AutoShardingEvaluatorTest, const std::vector s_val = {2 /* violates */, 1, 2, 2, 1}; const double objective_value = 11138.0; const AutoShardingSolverOutput output = {s_val, objective_value}; - const AutoShardingSolverResult result = {output, false}; - const AutoShardingEvaluation evaluation = Evaluate(request, result); + const AutoShardingEvaluation evaluation = Evaluate(request, output); AutoShardingEvaluation expected_evaluation; expected_evaluation.total.computation_cost = 158.0; // 12+21+32+42+51 @@ -701,9 +682,8 @@ TEST(AutoShardingEvaluatorTest, ViolatesFollower) { const std::vector s_val = {3, 1, 2, 1 /* violates */, 1}; const double objective_value = 12138.0; const AutoShardingSolverOutput output = {s_val, objective_value}; - const AutoShardingSolverResult result = {output, false}; - const AutoShardingEvaluation evaluation = Evaluate(request, result); + const AutoShardingEvaluation evaluation = Evaluate(request, output); AutoShardingEvaluation expected_evaluation; expected_evaluation.violation_codes = {kFollowerViolationCode}; @@ -722,9 +702,8 @@ TEST(AutoShardingEvaluatorTest, ViolatesAlias) { const std::vector s_val = {3, 1, 2, 2, 0 /* violates */}; const double objective_value = 12138.0; const AutoShardingSolverOutput output = {s_val, objective_value}; - const AutoShardingSolverResult result = {output, false}; - const AutoShardingEvaluation evaluation = Evaluate(request, result); + const AutoShardingEvaluation evaluation = Evaluate(request, output); AutoShardingEvaluation expected_evaluation; expected_evaluation.violation_codes = {kAliasViolationCode}; @@ -743,9 +722,8 @@ TEST(AutoShardingEvaluatorTest, ViolatesMemory) { const std::vector s_val = {2 /* violates */, 1, 2, 2, 1}; const double objective_value = 11138.0; const AutoShardingSolverOutput output = {s_val, objective_value}; - const AutoShardingSolverResult result = {output, false}; - const AutoShardingEvaluation evaluation = Evaluate(request, result); + const AutoShardingEvaluation evaluation = Evaluate(request, output); AutoShardingEvaluation expected_evaluation; expected_evaluation.violation_codes = {kMemoryViolationCode}; @@ -767,9 +745,8 @@ TEST(AutoShardingEvaluatorTest, ViolatesInfiniteCostForNode) { const std::vector s_val = {0 /* violates */, 1, 2, 2, 1}; const double objective_value = 1e+20; const AutoShardingSolverOutput output = {s_val, objective_value}; - const AutoShardingSolverResult result = {output, false}; - const AutoShardingEvaluation evaluation = Evaluate(request, result); + const AutoShardingEvaluation evaluation = Evaluate(request, output); AutoShardingEvaluation expected_evaluation; expected_evaluation.violation_codes = {kInfiniteCostViolationCode}; @@ -789,9 +766,8 @@ TEST(AutoShardingEvaluatorTest, ViolatesInfiniteCostForEdge) { const std::vector s_val = {0, 1, 2, 2, 1}; const double objective_value = 1e+20; const AutoShardingSolverOutput output = {s_val, objective_value}; - const AutoShardingSolverResult result = {output, false}; - const AutoShardingEvaluation evaluation = Evaluate(request, result); + const AutoShardingEvaluation evaluation = Evaluate(request, output); AutoShardingEvaluation expected_evaluation; expected_evaluation.violation_codes = {kInfiniteCostViolationCode}; @@ -811,9 +787,8 @@ TEST(AutoShardingEvaluatorTest, ViolatesMaxDepartures) { const std::vector s_val = {3, 1, 2, 2, 1}; const double objective_value = 12149.0; const AutoShardingSolverOutput output = {s_val, objective_value}; - const AutoShardingSolverResult result = {output, false}; - const AutoShardingEvaluation evaluation = Evaluate(request, result); + const AutoShardingEvaluation evaluation = Evaluate(request, output); AutoShardingEvaluation expected_evaluation; expected_evaluation.violation_codes = {kMaxDeparturesViolationCode}; diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_wrapper.h b/xla/hlo/experimental/auto_sharding/auto_sharding_wrapper.h index f9058802eea52..333df715447f0 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding_wrapper.h +++ b/xla/hlo/experimental/auto_sharding/auto_sharding_wrapper.h @@ -25,6 +25,7 @@ limitations under the License. #include "absl/container/btree_set.h" #include "absl/container/flat_hash_map.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_option.h" @@ -39,9 +40,23 @@ limitations under the License. namespace xla { namespace spmd { +// The high-level "recipe" for solving an Auto Sharding problem. +absl::StatusOr Solve( + const HloModule& hlo_module, const HloLiveRange& hlo_live_range, + const StrategyMap& strategy_map, const StrategyGroups& strategy_groups, + const CostGraph& cost_graph, const AliasSet& alias_set, + const std::vector>& node_intervals, + const std::vector>& edge_intervals, + const std::vector>& node_groups, + const std::vector>& edge_groups, + const AutoShardingOption& option, absl::string_view request_prefix, + const absl::flat_hash_map& + sharding_propagation_solution = {}); + // A wrapper around the solver that converts the given objects into a // combinatorial optimization problem & solves it. -AutoShardingSolverResult CreateAutoShardingSolverRequestAndCallSolver( +absl::StatusOr +CreateAutoShardingSolverRequestAndCallSolver( const HloModule& hlo_module, const HloLiveRange& hlo_live_range, const StrategyMap& strategy_map, const StrategyGroups& strategy_groups, const CostGraph& cost_graph, const AliasSet& alias_set,