Skip to content

Commit

Permalink
Remove AutoShardingSolverResult in favor of StatusOr<AutoShardingSolv…
Browse files Browse the repository at this point in the history
…erOutput> as the AutoShardingSolverResult::skip_auto_sharding is now dead after some recent changes.

PiperOrigin-RevId: 681682683
  • Loading branch information
Google-ML-Automation committed Oct 3, 2024
1 parent 150b7ba commit 92e3c7a
Show file tree
Hide file tree
Showing 9 changed files with 107 additions and 150 deletions.
4 changes: 3 additions & 1 deletion xla/hlo/experimental/auto_sharding/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ cc_library(
"//xla/service:hlo_cost_analysis",
"@com_google_absl//absl/container:btree",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:string_view",
],
)
Expand All @@ -227,7 +228,6 @@ cc_library(
compatible_with = get_compatible_with_libtpu_portable(),
deps = [
":auto_sharding_cost_graph",
":auto_sharding_device_mesh",
":auto_sharding_option",
":auto_sharding_strategy",
":auto_sharding_wrapper",
Expand All @@ -236,6 +236,7 @@ cc_library(
"//xla/service:hlo_cost_analysis",
"@com_google_absl//absl/container:btree",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:string_view",
],
)
Expand Down Expand Up @@ -426,5 +427,6 @@ xla_cc_test(
"@com_google_absl//absl/status",
"@com_google_googletest//:gtest",
"@tsl//tsl/platform",
"@tsl//tsl/platform:statusor",
] + if_google(["@com_google_ortools//ortools/linear_solver:linear_solver_scip"]),
)
11 changes: 5 additions & 6 deletions xla/hlo/experimental/auto_sharding/auto_sharding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1745,7 +1745,8 @@ std::unique_ptr<StrategyGroup> CreateReshapeStrategies(
return strategy_group;
}

AutoShardingSolverResult CreateAutoShardingSolverRequestAndCallSolver(
absl::StatusOr<AutoShardingSolverOutput>
CreateAutoShardingSolverRequestAndCallSolver(
const HloModule& hlo_module, const HloLiveRange& hlo_live_range,
const StrategyMap& strategy_map, const StrategyGroups& strategy_groups,
const CostGraph& cost_graph, const AliasSet& alias_set,
Expand Down Expand Up @@ -3796,14 +3797,12 @@ absl::StatusOr<bool> AutoShardingImplementation::RunAutoSharding(

// ----- Call the ILP Solver -----
std::string request_name = absl::StrCat("mesh_idx_", mesh_idx);
spmd::AutoShardingSolverResult solver_result =
TF_ASSIGN_OR_RETURN(
spmd::AutoShardingSolverOutput output,
Solve(*module, *hlo_live_range, strategy_map, strategy_groups,
cost_graph, alias_set, reduced_node_intervals,
reduced_edge_intervals, reduced_node_groups, reduced_edge_groups,
option_, request_name, sharding_propagation_solution);
TF_ASSIGN_OR_RETURN(spmd::AutoShardingSolverOutput output,
solver_result.status);

option_, request_name, sharding_propagation_solution));
if (mesh_idx == partial_mesh_shapes.size() - 1) {
this->solver_optimal_objective_value_ = output.cost;
}
Expand Down
13 changes: 0 additions & 13 deletions xla/hlo/experimental/auto_sharding/auto_sharding.h
Original file line number Diff line number Diff line change
Expand Up @@ -211,19 +211,6 @@ HloSharding GetReduceScatterOutput(const HloInstruction* ins,
const ShardingStrategy& strategy,
const ClusterEnvironment& cluster_env);

// The high-level "recipe" for solving an Auto Sharding problem.
AutoShardingSolverResult Solve(
const HloModule& hlo_module, const HloLiveRange& hlo_live_range,
const StrategyMap& strategy_map, const StrategyGroups& strategy_groups,
const CostGraph& cost_graph, const AliasSet& alias_set,
const std::vector<std::pair<LivenessIdx, LivenessIdx>>& node_intervals,
const std::vector<std::pair<LivenessIdx, LivenessIdx>>& edge_intervals,
const std::vector<absl::btree_set<int64_t>>& node_groups,
const std::vector<absl::btree_set<int64_t>>& edge_groups,
const AutoShardingOption& option, absl::string_view request_prefix,
const absl::flat_hash_map<std::string, HloSharding>&
sharding_propagation_solution = {});

// Populates temporal distance values.
void PopulateTemporalValues(const CostGraph& cost_graph,
AutoShardingSolverRequest& request);
Expand Down
3 changes: 2 additions & 1 deletion xla/hlo/experimental/auto_sharding/auto_sharding_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ limitations under the License.

#include "absl/container/btree_set.h"
#include "absl/container/flat_hash_map.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h"
#include "xla/hlo/experimental/auto_sharding/auto_sharding_option.h"
Expand All @@ -37,7 +38,7 @@ limitations under the License.
namespace xla {
namespace spmd {

AutoShardingSolverResult Solve(
absl::StatusOr<AutoShardingSolverOutput> Solve(
const HloModule& hlo_module, const HloLiveRange& hlo_live_range,
const StrategyMap& strategy_map, const StrategyGroups& strategy_groups,
const CostGraph& cost_graph, const AliasSet& alias_set,
Expand Down
44 changes: 16 additions & 28 deletions xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,6 @@ bool AutoShardingSolverOutput::operator==(
peak_times == other.peak_times;
}

bool AutoShardingSolverResult::operator==(
const AutoShardingSolverResult& other) const {
return status == other.status &&
skip_auto_sharding == other.skip_auto_sharding;
}

void PrintLargestInstructions(
const std::vector<NodeStrategyIdx>& chosen_strategy,
const AutoShardingSolverRequest& request) {
Expand Down Expand Up @@ -143,7 +137,7 @@ void PrintLargestInstructions(
}
}

AutoShardingSolverResult SolveAndExtractSolution(
absl::StatusOr<AutoShardingSolverOutput> SolveAndExtractSolution(
const AutoShardingSolverRequest& request,
const std::vector<std::vector<MPVariable*>>& s,
const std::vector<std::vector<MPVariable*>>& e,
Expand Down Expand Up @@ -399,7 +393,7 @@ void AddMemoryTerms(
// can be a few (usually < 10) edges in the problem with negative costs. This
// is guaranteed to never produce a negative overall cost for the graph,
// however.
AutoShardingSolverResult FormulateAndSolveMIPFromSolverRequest(
absl::StatusOr<AutoShardingSolverOutput> FormulateAndSolveMIPFromSolverRequest(
const AutoShardingSolverRequest& unscaled_request) {
const absl::Time start_time = absl::Now();
const AutoShardingSolverRequest& request = ScaleRequest(unscaled_request);
Expand Down Expand Up @@ -568,8 +562,7 @@ AutoShardingSolverResult FormulateAndSolveMIPFromSolverRequest(
LOG(FATAL) << err_msg;
} else {
LOG(WARNING) << err_msg;
return AutoShardingSolverResult(absl::InternalError(err_msg),
/*skip_auto_sharding=*/false);
return absl::InternalError(err_msg);
}
}
}
Expand Down Expand Up @@ -783,9 +776,9 @@ AutoShardingSolverResult FormulateAndSolveMIPFromSolverRequest(
}
auto result = SolveAndExtractSolution(request, s, e, overbudget_var,
makespan_var, *solver);
if (result.status.ok()) {
if (result.ok()) {
const AutoShardingEvaluation evaluation =
Evaluate(unscaled_request, result);
Evaluate(unscaled_request, *result);
LOG(INFO) << "*** Total costs for the (unscaled) solver request ***";
LOG(INFO) << "Total Communication Cost: "
<< evaluation.total.communication_cost
Expand Down Expand Up @@ -831,7 +824,7 @@ std::vector<NodeStrategyIdx> GetChosenNodeStrategy(
return chosen_node_strategy;
}

AutoShardingSolverResult SolveAndExtractSolution(
absl::StatusOr<AutoShardingSolverOutput> SolveAndExtractSolution(
const AutoShardingSolverRequest& request,
const std::vector<std::vector<MPVariable*>>& s,
const std::vector<std::vector<MPVariable*>>& e,
Expand Down Expand Up @@ -869,22 +862,18 @@ AutoShardingSolverResult SolveAndExtractSolution(
}
}
#endif
return AutoShardingSolverResult(
absl::InternalError("MPSolver could not find any feasible solution."),
/*skip_auto_sharding=*/false);
return absl::InternalError(
"MPSolver could not find any feasible solution.");
} else if (status == operations_research::MPSolver::MODEL_INVALID) {
LOG(FATAL) << "Solver says that the input MIP is invalid. This is most "
"likely a bug and should be reported.";
return AutoShardingSolverResult(absl::InternalError("Invalid MIP."),
/*skip_auto_sharding=*/false);
LOG(FATAL) << "The MIP fed to the solver is invalid. This is most likely a "
"bug and should be reported.";
return absl::InternalError("Invalid MIP.");
} else if (status == operations_research::MPSolver::NOT_SOLVED) {
LOG(WARNING) << "Solver timeout; no solution was produced";
return AutoShardingSolverResult(absl::InternalError("Solver timed out."),
/*skip_auto_sharding=*/true);
return absl::InternalError("Solver timed out.");
} else if (status != operations_research::MPSolver::OPTIMAL) {
LOG(WARNING) << "Solver timeout; moving forward with a suboptimal solution";
}

// Fingerprint the model & solution (useful when checking for determinism).
// We use TensorFlow's fingerprint library here, which differs from CP-SAT's.
operations_research::MPModelProto model_proto;
Expand Down Expand Up @@ -951,9 +940,8 @@ AutoShardingSolverResult SolveAndExtractSolution(
<< request.memory_budget() / (1024 * 1024 * 1024) << " GB";
}
PrintLargestInstructions(chosen_node_strategy, request);
const AutoShardingSolverOutput output = {std::move(chosen_node_strategy),
solver.Objective().Value()};
return AutoShardingSolverResult(output, /*skip_auto_sharding=*/false);
return AutoShardingSolverOutput{.s_val = std::move(chosen_node_strategy),
.cost = solver.Objective().Value()};
}

bool CostComponents::operator==(const CostComponents& other) const {
Expand All @@ -977,13 +965,13 @@ bool AutoShardingEvaluation::operator==(
}

AutoShardingEvaluation Evaluate(const AutoShardingSolverRequest& request,
const AutoShardingSolverResult& result) {
const AutoShardingSolverOutput& result) {
const auto& c = request.computation_costs();
const auto& d = request.communication_costs();
const auto& r = request.resharding_costs();
const auto& v = request.value_costs();
const auto& p = request.departure_costs();
const std::vector<NodeStrategyIdx>& s_val = result.status->s_val;
const std::vector<NodeStrategyIdx>& s_val = result.s_val;
const auto e_val = [&](EdgeIdx edge_idx) {
const auto& edge = request.edges(edge_idx);
return s_val[edge.first()] * request.s_len(edge.second()) +
Expand Down
16 changes: 3 additions & 13 deletions xla/hlo/experimental/auto_sharding/auto_sharding_solver.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,17 +37,7 @@ struct AutoShardingSolverOutput {
bool operator==(const AutoShardingSolverOutput& other) const;
};

struct AutoShardingSolverResult {
public:
AutoShardingSolverResult(absl::StatusOr<AutoShardingSolverOutput> status,
bool skip_auto_sharding)
: status(status), skip_auto_sharding(skip_auto_sharding) {}
bool operator==(const AutoShardingSolverResult& other) const;
absl::StatusOr<AutoShardingSolverOutput> status;
bool skip_auto_sharding;
};

AutoShardingSolverResult FormulateAndSolveMIPFromSolverRequest(
absl::StatusOr<AutoShardingSolverOutput> FormulateAndSolveMIPFromSolverRequest(
const AutoShardingSolverRequest& request);

enum AutoShardingViolationCode {
Expand Down Expand Up @@ -92,7 +82,7 @@ struct AutoShardingEvaluation {
// Evaluates the given solver result w.r.t. the input request, computing various
// solution quality metrics and validating the consistency of hard constraints.
AutoShardingEvaluation Evaluate(const AutoShardingSolverRequest& request,
const AutoShardingSolverResult& result);
const AutoShardingSolverOutput& result);

// Creates and returns a variable for makespan.
operations_research::MPVariable* CreateMakespanVar(
Expand All @@ -101,7 +91,7 @@ operations_research::MPVariable* CreateMakespanVar(
operations_research::MPSolver& solver);

double EvaluateMakespan(const AutoShardingSolverRequest& request,
const AutoShardingSolverResult& result,
const AutoShardingSolverOutput& result,
AutoShardingEvaluation& evaluation);

// Scale down values to reduce the range of costs & coefficients in the solver.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ MPVariable* CreateMakespanVar(const AutoShardingSolverRequest& request,
}

double EvaluateMakespan(const AutoShardingSolverRequest& request,
const AutoShardingSolverResult& result,
const AutoShardingSolverOutput& result,
AutoShardingEvaluation& evaluation) {
return 0.0; // TODO(moffitt): Implement this.
}
Expand Down
Loading

0 comments on commit 92e3c7a

Please sign in to comment.