Skip to content

Commit

Permalink
Remove AutoShardingSolverResult in favor of StatusOr<AutoShardingSolverOutput>
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 678928364
  • Loading branch information
Google-ML-Automation committed Sep 26, 2024
1 parent 0032f7c commit 8ce1ab7
Show file tree
Hide file tree
Showing 8 changed files with 89 additions and 115 deletions.
9 changes: 3 additions & 6 deletions xla/hlo/experimental/auto_sharding/auto_sharding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1757,7 +1757,7 @@ std::unique_ptr<StrategyGroup> CreateReshapeStrategies(
return strategy_group;
}

AutoShardingSolverResult CallSolver(
absl::StatusOr<AutoShardingSolverOutput> CallSolver(
const HloModule& hlo_module, const HloLiveRange& hlo_live_range,
const StrategyMap& strategy_map, const StrategyGroups& strategy_groups,
const CostGraph& cost_graph, const AliasSet& alias_set,
Expand Down Expand Up @@ -3804,13 +3804,10 @@ absl::StatusOr<AutoShardingResult> AutoShardingImplementation::RunAutoSharding(
cost_graph, alias_set, reduced_node_intervals,
reduced_edge_intervals, reduced_node_groups, reduced_edge_groups,
option_, request_name, sharding_propagation_solution);
if (solver_result.skip_auto_sharding) {
return AutoShardingResult::kModuleUnchangedNoShardingPerformed;
} else if (!solver_result.status.ok()) {
if (!solver_result.ok()) {
return AutoShardingResult::kModuleUnchanged;
}
TF_ASSIGN_OR_RETURN(spmd::AutoShardingSolverOutput output,
solver_result.status);
TF_ASSIGN_OR_RETURN(spmd::AutoShardingSolverOutput output, solver_result);
if (mesh_idx == partial_mesh_shapes.size() - 1) {
this->solver_optimal_objective_value_ = output.cost;
}
Expand Down
2 changes: 1 addition & 1 deletion xla/hlo/experimental/auto_sharding/auto_sharding.h
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ HloSharding GetReduceScatterOutput(const HloInstruction* ins,
const ClusterEnvironment& cluster_env);

// The high-level "recipe" for solving an Auto Sharding problem.
AutoShardingSolverResult Solve(
absl::StatusOr<AutoShardingSolverOutput> Solve(
const HloModule& hlo_module, const HloLiveRange& hlo_live_range,
const StrategyMap& strategy_map, const StrategyGroups& strategy_groups,
const CostGraph& cost_graph, const AliasSet& alias_set,
Expand Down
2 changes: 1 addition & 1 deletion xla/hlo/experimental/auto_sharding/auto_sharding_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ limitations under the License.
namespace xla {
namespace spmd {

AutoShardingSolverResult Solve(
absl::StatusOr<AutoShardingSolverOutput> Solve(
const HloModule& hlo_module, const HloLiveRange& hlo_live_range,
const StrategyMap& strategy_map, const StrategyGroups& strategy_groups,
const CostGraph& cost_graph, const AliasSet& alias_set,
Expand Down
40 changes: 18 additions & 22 deletions xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,11 @@ bool AutoShardingSolverOutput::operator==(
peak_times == other.peak_times;
}

bool AutoShardingSolverResult::operator==(
const AutoShardingSolverResult& other) const {
return status == other.status &&
skip_auto_sharding == other.skip_auto_sharding;
}
// bool AutoShardingSolverResult::operator==(
// const AutoShardingSolverResult& other) const {
// return status == other.status &&
// skip_auto_sharding == other.skip_auto_sharding;
// }

void PrintLargestInstructions(
const std::vector<NodeStrategyIdx>& chosen_strategy,
Expand Down Expand Up @@ -143,7 +143,7 @@ void PrintLargestInstructions(
}
}

AutoShardingSolverResult SolveAndExtractSolution(
absl::StatusOr<AutoShardingSolverOutput> SolveAndExtractSolution(
const AutoShardingSolverRequest& request,
const std::vector<std::vector<MPVariable*>>& s,
const std::vector<std::vector<MPVariable*>>& e,
Expand Down Expand Up @@ -399,7 +399,7 @@ void AddMemoryTerms(
// can be a few (usually < 10) edges in the problem with negative costs. This
// is guaranteed to never produce a negative overall cost for the graph,
// however.
AutoShardingSolverResult CallORToolsSolver(
absl::StatusOr<AutoShardingSolverOutput> CallORToolsSolver(
const AutoShardingSolverRequest& unscaled_request) {
const absl::Time start_time = absl::Now();
const AutoShardingSolverRequest& request = ScaleRequest(unscaled_request);
Expand Down Expand Up @@ -565,8 +565,7 @@ AutoShardingSolverResult CallORToolsSolver(
LOG(FATAL) << err_msg;
} else {
LOG(WARNING) << err_msg;
return AutoShardingSolverResult(absl::InternalError(err_msg),
/*skip_auto_sharding=*/false);
return absl::InternalError(err_msg);
}
}
}
Expand Down Expand Up @@ -784,9 +783,9 @@ AutoShardingSolverResult CallORToolsSolver(
}
auto result = SolveAndExtractSolution(request, s, e, overbudget_var,
makespan_var, *solver);
if (result.status.ok()) {
if (result.ok()) {
const AutoShardingEvaluation evaluation =
Evaluate(unscaled_request, result);
Evaluate(unscaled_request, *result);
LOG(INFO) << "*** Total costs for the (unscaled) solver request ***";
LOG(INFO) << "Total Communication Cost: "
<< evaluation.total.communication_cost
Expand Down Expand Up @@ -832,7 +831,7 @@ std::vector<NodeStrategyIdx> GetChosenNodeStrategy(
return chosen_node_strategy;
}

AutoShardingSolverResult SolveAndExtractSolution(
absl::StatusOr<AutoShardingSolverOutput> SolveAndExtractSolution(
const AutoShardingSolverRequest& request,
const std::vector<std::vector<MPVariable*>>& s,
const std::vector<std::vector<MPVariable*>>& e,
Expand Down Expand Up @@ -870,17 +869,14 @@ AutoShardingSolverResult SolveAndExtractSolution(
}
}
#endif
return AutoShardingSolverResult(
absl::InternalError("MPSolver could not find any feasible solution."),
/*skip_auto_sharding=*/false);
return absl::InternalError(
"MPSolver could not find any feasible solution.");
} else if (status == operations_research::MPSolver::MODEL_INVALID) {
LOG(FATAL) << "Solver says that the input MIP is invalid. This is most "
"likely a bug and should be reported.";
return AutoShardingSolverResult(absl::InternalError("Solver timed out."),
/*skip_auto_sharding=*/false);
return absl::InternalError("Solver timed out.");
} else if (status != operations_research::MPSolver::OPTIMAL) {
return AutoShardingSolverResult(absl::InternalError("Solver timed out."),
/*skip_auto_sharding=*/true);
return absl::InternalError("Solver timed out.");
}

// Fingerprint the model & solution (useful when checking for determinism).
Expand Down Expand Up @@ -951,7 +947,7 @@ AutoShardingSolverResult SolveAndExtractSolution(
PrintLargestInstructions(chosen_node_strategy, request);
const AutoShardingSolverOutput output = {std::move(chosen_node_strategy),
solver.Objective().Value()};
return AutoShardingSolverResult(output, /*skip_auto_sharding=*/false);
return output;
}

bool CostComponents::operator==(const CostComponents& other) const {
Expand All @@ -975,13 +971,13 @@ bool AutoShardingEvaluation::operator==(
}

AutoShardingEvaluation Evaluate(const AutoShardingSolverRequest& request,
const AutoShardingSolverResult& result) {
const AutoShardingSolverOutput& result) {
const auto& c = request.computation_costs();
const auto& d = request.communication_costs();
const auto& r = request.resharding_costs();
const auto& v = request.value_costs();
const auto& p = request.departure_costs();
const std::vector<NodeStrategyIdx>& s_val = result.status->s_val;
const std::vector<NodeStrategyIdx>& s_val = result.s_val;
const auto e_val = [&](EdgeIdx edge_idx) {
const auto& edge = request.edges(edge_idx);
return s_val[edge.first()] * request.s_len(edge.second()) +
Expand Down
16 changes: 3 additions & 13 deletions xla/hlo/experimental/auto_sharding/auto_sharding_solver.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,17 +37,7 @@ struct AutoShardingSolverOutput {
bool operator==(const AutoShardingSolverOutput& other) const;
};

struct AutoShardingSolverResult {
public:
AutoShardingSolverResult(absl::StatusOr<AutoShardingSolverOutput> status,
bool skip_auto_sharding)
: status(status), skip_auto_sharding(skip_auto_sharding) {}
bool operator==(const AutoShardingSolverResult& other) const;
absl::StatusOr<AutoShardingSolverOutput> status;
bool skip_auto_sharding;
};

AutoShardingSolverResult CallORToolsSolver(
absl::StatusOr<AutoShardingSolverOutput> CallORToolsSolver(
const AutoShardingSolverRequest& request);

enum AutoShardingViolationCode {
Expand Down Expand Up @@ -92,7 +82,7 @@ struct AutoShardingEvaluation {
// Evaluates the given solver result w.r.t. the input request, computing various
// solution quality metrics and validating the consistency of hard constraints.
AutoShardingEvaluation Evaluate(const AutoShardingSolverRequest& request,
const AutoShardingSolverResult& result);
const AutoShardingSolverOutput& result);

// Creates and returns a variable for makespan.
operations_research::MPVariable* CreateMakespanVar(
Expand All @@ -101,7 +91,7 @@ operations_research::MPVariable* CreateMakespanVar(
operations_research::MPSolver& solver);

double EvaluateMakespan(const AutoShardingSolverRequest& request,
const AutoShardingSolverResult& result,
const AutoShardingSolverOutput& result,
AutoShardingEvaluation& evaluation);

// Scale down values to reduce the range of costs & coefficients in the solver.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ MPVariable* CreateMakespanVar(const AutoShardingSolverRequest& request,
}

double EvaluateMakespan(const AutoShardingSolverRequest& request,
const AutoShardingSolverResult& result,
const AutoShardingSolverOutput& result,
AutoShardingEvaluation& evaluation) {
return 0.0; // TODO(moffitt): Implement this.
}
Expand Down
Loading

0 comments on commit 8ce1ab7

Please sign in to comment.