Commit 7c61f6d

Remove AutoShardingSolverResult in favor of StatusOr<AutoShardingSolverOutput>, as AutoShardingSolverResult::skip_auto_sharding is now dead after some recent changes.

PiperOrigin-RevId: 678928364
Google-ML-Automation committed Oct 2, 2024
1 parent db9f12b commit 7c61f6d
Showing 17 changed files with 150 additions and 1,123 deletions.
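
At its core, the change collapses a two-channel result into one. The sketch below is a simplified reconstruction, not the verbatim XLA headers: the old struct's fields are inferred from the deleted operator== and call sites further down this diff, and AutoShardingSolverOutput is reduced to the two fields the diff actually touches.

#include <cstdint>
#include <vector>

#include "absl/status/statusor.h"

namespace xla::spmd {

// Reduced stand-in for the real solver output (cf. the designated
// initializer near the end of auto_sharding_solver.cc below).
struct AutoShardingSolverOutput {
  std::vector<int64_t> s_val;  // chosen strategy index per node
  double cost = 0.0;           // solver objective value
};

// Old shape, deleted by this commit: a StatusOr plus a side flag that
// recent changes left permanently false, i.e. dead.
struct AutoShardingSolverResult {
  absl::StatusOr<AutoShardingSolverOutput> status;
  bool skip_auto_sharding = false;
};

// New shape: success, failure, and the payload all travel in one StatusOr.
absl::StatusOr<AutoShardingSolverOutput> Solve(/* ...args unchanged... */);

}  // namespace xla::spmd

With the flag gone, callers no longer consult two fields to learn one outcome, which is what lets the wrapper struct, its operator==, and the tri-state AutoShardingResult enum below all disappear.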
901 changes: 0 additions & 901 deletions third_party/stablehlo/temporary.patch

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions third_party/stablehlo/workspace.bzl
@@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")

def repo():
# LINT.IfChange
STABLEHLO_COMMIT = "9d9290dc2308c1850cea69ea05f8c94017e484ee"
STABLEHLO_SHA256 = "29803fc8a3a96f9e5469c7ab51f2ff4292dc2419c17bd0466f5d15a448cf6815"
STABLEHLO_COMMIT = "f7f8e4e35296deeff2e12e39421ac8d9599ba340"
STABLEHLO_SHA256 = "c92b55d5512e58d6fefba62c58e60d7762adb184dc3ad489521de562f6ca7aeb"
# LINT.ThenChange(Google-internal path)

tf_http_archive(
3 changes: 2 additions & 1 deletion xla/hlo/experimental/auto_sharding/BUILD
@@ -217,6 +217,7 @@ cc_library(
"//xla/service:hlo_cost_analysis",
"@com_google_absl//absl/container:btree",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:string_view",
],
)
@@ -227,7 +228,6 @@ cc_library(
compatible_with = get_compatible_with_libtpu_portable(),
deps = [
":auto_sharding_cost_graph",
":auto_sharding_device_mesh",
":auto_sharding_option",
":auto_sharding_strategy",
":auto_sharding_wrapper",
@@ -236,6 +236,7 @@
"//xla/service:hlo_cost_analysis",
"@com_google_absl//absl/container:btree",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:string_view",
],
)
87 changes: 27 additions & 60 deletions xla/hlo/experimental/auto_sharding/auto_sharding.cc
@@ -1747,7 +1748,8 @@ std::unique_ptr<StrategyGroup> CreateReshapeStrategies(
return strategy_group;
}

AutoShardingSolverResult CreateAutoShardingSolverRequestAndCallSolver(
absl::StatusOr<AutoShardingSolverOutput>
CreateAutoShardingSolverRequestAndCallSolver(
const HloModule& hlo_module, const HloLiveRange& hlo_live_range,
const StrategyMap& strategy_map, const StrategyGroups& strategy_groups,
const CostGraph& cost_graph, const AliasSet& alias_set,
@@ -3504,14 +3505,14 @@ std::pair<int64_t, int64_t> ReduceMemoryTerms(
return num_terms;
}

absl::StatusOr<AutoShardingResult> AutoShardingImplementation::RunAutoSharding(
absl::StatusOr<bool> AutoShardingImplementation::RunAutoSharding(
HloModule* module,
const absl::flat_hash_set<std::string>& replicated_small_tensors,
const absl::flat_hash_set<absl::string_view>& execution_threads,
const absl::flat_hash_map<std::string, HloSharding>&
sharding_propagation_solution) {
if (!option_.enable) {
return AutoShardingResult::kModuleUnchanged;
return false;
}
bool module_is_changed = false;

@@ -3790,16 +3791,11 @@ absl::StatusOr<AutoShardingResult> AutoShardingImplementation::RunAutoSharding(

// ----- Call the ILP Solver -----
std::string request_name = absl::StrCat("mesh_idx_", mesh_idx);
auto solver_result =
spmd::AutoShardingSolverResult solver_result =
Solve(*module, *hlo_live_range, strategy_map, strategy_groups,
cost_graph, alias_set, reduced_node_intervals,
reduced_edge_intervals, reduced_node_groups, reduced_edge_groups,
option_, request_name, sharding_propagation_solution);
if (solver_result.skip_auto_sharding) {
return AutoShardingResult::kModuleUnchangedNoShardingPerformed;
} else if (!solver_result.status.ok()) {
return AutoShardingResult::kModuleUnchanged;
}
TF_ASSIGN_OR_RETURN(spmd::AutoShardingSolverOutput output,
solver_result.status);
if (mesh_idx == partial_mesh_shapes.size() - 1) {
@@ -3823,21 +3819,14 @@ absl::StatusOr<AutoShardingResult> AutoShardingImplementation::RunAutoSharding(
output.s_val, (mesh_idx == partial_mesh_shapes.size() - 1));

if (mesh_idx == partial_mesh_shapes.size() - 1) {
if (!spmd::SetHloShardingPostProcessing(sequence, instructions_to_shard,
preserve_shardings)
.ok()) {
return AutoShardingResult::kModuleUnchanged;
}

if (!InsertReshardReshapes(
sequence, instructions_to_shard, strategy_map, cost_graph,
output.s_val, cluster_env,
/* crash_at_error */ !option_.try_multiple_mesh_shapes,
option_.insert_resharding_reshapes_for_non_dot_ops,
preserve_shardings)
.ok()) {
return AutoShardingResult::kModuleUnchanged;
}
TF_RETURN_IF_ERROR(spmd::SetHloShardingPostProcessing(
sequence, instructions_to_shard, preserve_shardings));
TF_RETURN_IF_ERROR(InsertReshardReshapes(
sequence, instructions_to_shard, strategy_map, cost_graph,
output.s_val, cluster_env,
/* crash_at_error */ !option_.try_multiple_mesh_shapes,
option_.insert_resharding_reshapes_for_non_dot_ops,
preserve_shardings));
} else {
spmd::RecoverShardingsFromPartialMesh(sequence, preserve_shardings);
}
@@ -3878,8 +3867,7 @@ absl::StatusOr<AutoShardingResult> AutoShardingImplementation::RunAutoSharding(
}
}

return module_is_changed ? AutoShardingResult::kModuleChangedShardingPerformed
: AutoShardingResult::kModuleUnchanged;
return module_is_changed;
}

bool ModuleIsManuallyPartitioned(const HloModule* module) {
@@ -4109,15 +4097,12 @@ absl::StatusOr<bool> AutoSharding::Run(
}
}

absl::StatusOr<AutoShardingResult> min_mesh_pass_result =
AutoShardingResult::kModuleUnchanged;

bool module_is_changed = false;
VLOG(1) << "Original mesh shape "
<< spmd::ToString(option_.device_mesh_shape);
double min_objective_value = std::numeric_limits<double>::max();
int min_mesh_shape_index = -1;
std::unique_ptr<HloModule> min_mesh_shape_module;
bool skip_auto_sharding = true;
for (size_t i = 0; i < mesh_shapes.size(); ++i) {
VLOG(1) << "Trying mesh shape " << spmd::ToString(mesh_shapes[i]);
AutoShardingOption this_option = option_;
@@ -4130,7 +4115,7 @@
}
auto pass = std::make_unique<AutoShardingImplementation>(this_option);
std::unique_ptr<HloModule> module_clone = CloneModule(module);
absl::StatusOr<AutoShardingResult> pass_result =
absl::StatusOr<bool> pass_result =
pass->RunAutoSharding(module_clone.get(), replicated_small_tensors,
execution_threads, sharding_propagation_solution);
if (!pass_result.ok()) {
@@ -4148,19 +4133,11 @@
min_mesh_shape_index = i;
min_mesh_shape_module = std::move(module_clone);
min_objective_value = this_mesh_objective_value;
min_mesh_pass_result = pass_result;
}
if (*pass_result !=
AutoShardingResult::kModuleUnchangedNoShardingPerformed) {
skip_auto_sharding = false;
CHECK_OK(pass_result);
module_is_changed = *pass_result;
}
}

if (skip_auto_sharding) {
RecordPassEndAndDumpModule(start_time, module);
LOG(FATAL) << "The auto-sharding solver has timed out without a solution.";
}

std::string trying_to_find =
option_.try_multiple_mesh_shapes
? "a device mesh (and the corresponding shardings)"
@@ -4173,28 +4150,18 @@
"higher budget). If you think you have set a reasonably large memory "
"budget, please report this as a bug.";

if (!min_mesh_pass_result.ok()) {
RecordPassEndAndDumpModule(start_time, module);
return min_mesh_pass_result.status();
}

absl::StatusOr<bool> module_is_changed;
solver_optimal_objective_value_ = min_objective_value;
if (*min_mesh_pass_result !=
AutoShardingResult::kModuleChangedShardingPerformed) {
RecordPassEndAndDumpModule(start_time, module);
return false;
if (module_is_changed) {
VLOG(1) << "Choosing mesh shape "
<< spmd::ToString(mesh_shapes[min_mesh_shape_index])
<< " which had the minimal solver objective value of "
<< min_objective_value;
chosen_mesh_shape_ = mesh_shapes[min_mesh_shape_index];
TF_RETURN_IF_ERROR(MoveComputationsFromModuleToModule(
min_mesh_shape_module.get(), module));
}

VLOG(1) << "Choosing mesh shape "
<< spmd::ToString(mesh_shapes[min_mesh_shape_index])
<< " which had the minimal solver objective value of "
<< min_objective_value;
chosen_mesh_shape_ = mesh_shapes[min_mesh_shape_index];
TF_RETURN_IF_ERROR(
MoveComputationsFromModuleToModule(min_mesh_shape_module.get(), module));
RecordPassEndAndDumpModule(start_time, module);
return true;
return module_is_changed;
}

} // namespace xla
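
The caller-side pattern in this file changes accordingly: failures that were previously flattened into AutoShardingResult::kModuleUnchanged now propagate as real statuses through TF_ASSIGN_OR_RETURN and TF_RETURN_IF_ERROR. A toy sketch of that control-flow change, assuming the usual TSL macro header (ToySolve and ToyRun are hypothetical stand-ins, not XLA functions):

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "tsl/platform/statusor.h"  // TF_ASSIGN_OR_RETURN

// Hypothetical stand-in for Solve(): fails the way the solver can.
absl::StatusOr<int> ToySolve(bool feasible) {
  if (!feasible) return absl::InternalError("Solver timed out.");
  return 42;
}

absl::StatusOr<bool> ToyRun(bool feasible) {
  // Before: check result.skip_auto_sharding and result.status.ok()
  // separately, then map both onto an enum. After: unwrap the value or
  // propagate the underlying error in a single step.
  TF_ASSIGN_OR_RETURN(const int cost, ToySolve(feasible));
  return cost >= 0;  // "module changed"
}

The knock-on effect is visible in AutoSharding::Run above: with the enum gone, a solver timeout that used to surface as kModuleUnchangedNoShardingPerformed is now reported through the returned status, so the skip_auto_sharding bookkeeping and its LOG(FATAL) fallback could be deleted.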
21 changes: 1 addition & 20 deletions xla/hlo/experimental/auto_sharding/auto_sharding.h
@@ -50,18 +50,12 @@ limitations under the License.

namespace xla {

enum class AutoShardingResult {
kModuleUnchanged,
kModuleChangedShardingPerformed,
kModuleUnchangedNoShardingPerformed
};

class AutoShardingImplementation {
public:
explicit AutoShardingImplementation(const AutoShardingOption& option);
~AutoShardingImplementation() = default;

absl::StatusOr<AutoShardingResult> RunAutoSharding(
absl::StatusOr<bool> RunAutoSharding(
HloModule* module,
const absl::flat_hash_set<std::string>& replicated_small_tensors,
const absl::flat_hash_set<absl::string_view>& execution_threads,
@@ -216,19 +210,6 @@ HloSharding GetReduceScatterOutput(const HloInstruction* ins,
const ShardingStrategy& strategy,
const ClusterEnvironment& cluster_env);

// The high-level "recipe" for solving an Auto Sharding problem.
AutoShardingSolverResult Solve(
const HloModule& hlo_module, const HloLiveRange& hlo_live_range,
const StrategyMap& strategy_map, const StrategyGroups& strategy_groups,
const CostGraph& cost_graph, const AliasSet& alias_set,
const std::vector<std::pair<LivenessIdx, LivenessIdx>>& node_intervals,
const std::vector<std::pair<LivenessIdx, LivenessIdx>>& edge_intervals,
const std::vector<absl::btree_set<int64_t>>& node_groups,
const std::vector<absl::btree_set<int64_t>>& edge_groups,
const AutoShardingOption& option, absl::string_view request_prefix,
const absl::flat_hash_map<std::string, HloSharding>&
sharding_propagation_solution = {});

// Populates temporal distance values.
void PopulateTemporalValues(const CostGraph& cost_graph,
AutoShardingSolverRequest& request);
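
The Solve entry point itself survives with the same parameters; only the return type moves to absl::StatusOr. Its new definition appears in auto_sharding_impl.cc below, so the replacement declaration presumably reads as follows (a reconstruction for orientation, not a copy of the new header):

// Reconstructed: return type from the new definition in
// auto_sharding_impl.cc; parameter list unchanged from the deleted
// declaration above.
absl::StatusOr<AutoShardingSolverOutput> Solve(
    const HloModule& hlo_module, const HloLiveRange& hlo_live_range,
    const StrategyMap& strategy_map, const StrategyGroups& strategy_groups,
    const CostGraph& cost_graph, const AliasSet& alias_set,
    const std::vector<std::pair<LivenessIdx, LivenessIdx>>& node_intervals,
    const std::vector<std::pair<LivenessIdx, LivenessIdx>>& edge_intervals,
    const std::vector<absl::btree_set<int64_t>>& node_groups,
    const std::vector<absl::btree_set<int64_t>>& edge_groups,
    const AutoShardingOption& option, absl::string_view request_prefix,
    const absl::flat_hash_map<std::string, HloSharding>&
        sharding_propagation_solution = {});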
3 changes: 2 additions & 1 deletion xla/hlo/experimental/auto_sharding/auto_sharding_impl.cc
@@ -22,6 +22,7 @@ limitations under the License.

#include "absl/container/btree_set.h"
#include "absl/container/flat_hash_map.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h"
#include "xla/hlo/experimental/auto_sharding/auto_sharding_option.h"
@@ -37,7 +38,7 @@ limitations under the License.
namespace xla {
namespace spmd {

AutoShardingSolverResult Solve(
absl::StatusOr<AutoShardingSolverOutput> Solve(
const HloModule& hlo_module, const HloLiveRange& hlo_live_range,
const StrategyMap& strategy_map, const StrategyGroups& strategy_groups,
const CostGraph& cost_graph, const AliasSet& alias_set,
44 changes: 16 additions & 28 deletions xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc
@@ -81,12 +81,6 @@ bool AutoShardingSolverOutput::operator==(
peak_times == other.peak_times;
}

bool AutoShardingSolverResult::operator==(
const AutoShardingSolverResult& other) const {
return status == other.status &&
skip_auto_sharding == other.skip_auto_sharding;
}

void PrintLargestInstructions(
const std::vector<NodeStrategyIdx>& chosen_strategy,
const AutoShardingSolverRequest& request) {
@@ -143,7 +137,7 @@ void PrintLargestInstructions(
}
}

AutoShardingSolverResult SolveAndExtractSolution(
absl::StatusOr<AutoShardingSolverOutput> SolveAndExtractSolution(
const AutoShardingSolverRequest& request,
const std::vector<std::vector<MPVariable*>>& s,
const std::vector<std::vector<MPVariable*>>& e,
@@ -399,7 +393,7 @@ void AddMemoryTerms(
// can be a few (usually < 10) edges in the problem with negative costs. This
// is guaranteed to never produce a negative overall cost for the graph,
// however.
AutoShardingSolverResult FormulateAndSolveMIPFromSolverRequest(
absl::StatusOr<AutoShardingSolverOutput> FormulateAndSolveMIPFromSolverRequest(
const AutoShardingSolverRequest& unscaled_request) {
const absl::Time start_time = absl::Now();
const AutoShardingSolverRequest& request = ScaleRequest(unscaled_request);
@@ -568,8 +562,7 @@ AutoShardingSolverResult FormulateAndSolveMIPFromSolverRequest(
LOG(FATAL) << err_msg;
} else {
LOG(WARNING) << err_msg;
return AutoShardingSolverResult(absl::InternalError(err_msg),
/*skip_auto_sharding=*/false);
return absl::InternalError(err_msg);
}
}
}
@@ -783,9 +776,9 @@ AutoShardingSolverResult FormulateAndSolveMIPFromSolverRequest(
}
auto result = SolveAndExtractSolution(request, s, e, overbudget_var,
makespan_var, *solver);
if (result.status.ok()) {
if (result.ok()) {
const AutoShardingEvaluation evaluation =
Evaluate(unscaled_request, result);
Evaluate(unscaled_request, *result);
LOG(INFO) << "*** Total costs for the (unscaled) solver request ***";
LOG(INFO) << "Total Communication Cost: "
<< evaluation.total.communication_cost
@@ -831,7 +824,7 @@ std::vector<NodeStrategyIdx> GetChosenNodeStrategy(
return chosen_node_strategy;
}

AutoShardingSolverResult SolveAndExtractSolution(
absl::StatusOr<AutoShardingSolverOutput> SolveAndExtractSolution(
const AutoShardingSolverRequest& request,
const std::vector<std::vector<MPVariable*>>& s,
const std::vector<std::vector<MPVariable*>>& e,
@@ -869,22 +862,18 @@ AutoShardingSolverResult SolveAndExtractSolution(
}
}
#endif
return AutoShardingSolverResult(
absl::InternalError("MPSolver could not find any feasible solution."),
/*skip_auto_sharding=*/false);
return absl::InternalError(
"MPSolver could not find any feasible solution.");
} else if (status == operations_research::MPSolver::MODEL_INVALID) {
LOG(FATAL) << "Solver says that the input MIP is invalid. This is most "
"likely a bug and should be reported.";
return AutoShardingSolverResult(absl::InternalError("Invalid MIP."),
/*skip_auto_sharding=*/false);
LOG(FATAL) << "The MIP fed to the solver is invalid. This is most likely a "
"bug and should be reported.";
return absl::InternalError("Invalid MIP.");
} else if (status == operations_research::MPSolver::NOT_SOLVED) {
LOG(WARNING) << "Solver timeout; no solution was produced";
return AutoShardingSolverResult(absl::InternalError("Solver timed out."),
/*skip_auto_sharding=*/true);
return absl::InternalError("Solver timed out.");
} else if (status != operations_research::MPSolver::OPTIMAL) {
LOG(WARNING) << "Solver timeout; moving forward with a suboptimal solution";
}

// Fingerprint the model & solution (useful when checking for determinism).
// We use TensorFlow's fingerprint library here, which differs from CP-SAT's.
operations_research::MPModelProto model_proto;
@@ -951,9 +940,8 @@ AutoShardingSolverResult SolveAndExtractSolution(
<< request.memory_budget() / (1024 * 1024 * 1024) << " GB";
}
PrintLargestInstructions(chosen_node_strategy, request);
const AutoShardingSolverOutput output = {std::move(chosen_node_strategy),
solver.Objective().Value()};
return AutoShardingSolverResult(output, /*skip_auto_sharding=*/false);
return AutoShardingSolverOutput{.s_val = std::move(chosen_node_strategy),
.cost = solver.Objective().Value()};
}

bool CostComponents::operator==(const CostComponents& other) const {
@@ -977,13 +965,13 @@ bool AutoShardingEvaluation::operator==(
}

AutoShardingEvaluation Evaluate(const AutoShardingSolverRequest& request,
const AutoShardingSolverResult& result) {
const AutoShardingSolverOutput& result) {
const auto& c = request.computation_costs();
const auto& d = request.communication_costs();
const auto& r = request.resharding_costs();
const auto& v = request.value_costs();
const auto& p = request.departure_costs();
const std::vector<NodeStrategyIdx>& s_val = result.status->s_val;
const std::vector<NodeStrategyIdx>& s_val = result.s_val;
const auto e_val = [&](EdgeIdx edge_idx) {
const auto& edge = request.edges(edge_idx);
return s_val[edge.first()] * request.s_len(edge.second()) +
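Taken together, the solver-side changes make every exit path of SolveAndExtractSolution a single return. A condensed sketch of the new tail, paraphrased from the hunks above (the helper name ExtractTail is invented for illustration; the statuses, messages, and field order come from the diff, and NodeStrategyIdx / AutoShardingSolverOutput are the types used throughout this file):

#include <utility>
#include <vector>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "ortools/linear_solver/linear_solver.h"

absl::StatusOr<AutoShardingSolverOutput> ExtractTail(
    const operations_research::MPSolver& solver,
    operations_research::MPSolver::ResultStatus status,
    std::vector<NodeStrategyIdx> chosen_node_strategy) {
  using MPSolver = operations_research::MPSolver;
  if (status == MPSolver::INFEASIBLE) {
    return absl::InternalError(
        "MPSolver could not find any feasible solution.");
  }
  if (status == MPSolver::NOT_SOLVED) {
    return absl::InternalError("Solver timed out.");
  }
  // Success: aggregate-initialize the output in place; there is no longer
  // a wrapper struct or skip_auto_sharding flag to thread through.
  return AutoShardingSolverOutput{.s_val = std::move(chosen_node_strategy),
                                  .cost = solver.Objective().Value()};
}

Evaluate follows suit: it now consumes the AutoShardingSolverOutput directly, so reads like result.status->s_val become plain result.s_val.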
(The diffs for the remaining changed files are not rendered.)