Skip to content

Commit

Permalink
Remove AutoShardingResult in favor of a boolean now that the value kM…
Browse files Browse the repository at this point in the history
…oduleUnchangedNoShardingPerformed of the enum is unused, effectively making it a boolean. Also simplified away some dead code.

PiperOrigin-RevId: 681506949
  • Loading branch information
Google-ML-Automation committed Oct 2, 2024
1 parent 29cad9d commit 8108e87
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 67 deletions.
85 changes: 26 additions & 59 deletions xla/hlo/experimental/auto_sharding/auto_sharding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3504,14 +3504,14 @@ std::pair<int64_t, int64_t> ReduceMemoryTerms(
return num_terms;
}

absl::StatusOr<AutoShardingResult> AutoShardingImplementation::RunAutoSharding(
absl::StatusOr<bool> AutoShardingImplementation::RunAutoSharding(
HloModule* module,
const absl::flat_hash_set<std::string>& replicated_small_tensors,
const absl::flat_hash_set<absl::string_view>& execution_threads,
const absl::flat_hash_map<std::string, HloSharding>&
sharding_propagation_solution) {
if (!option_.enable) {
return AutoShardingResult::kModuleUnchanged;
return false;
}
bool module_is_changed = false;

Expand Down Expand Up @@ -3790,18 +3790,14 @@ absl::StatusOr<AutoShardingResult> AutoShardingImplementation::RunAutoSharding(

// ----- Call the ILP Solver -----
std::string request_name = absl::StrCat("mesh_idx_", mesh_idx);
auto solver_result =
spmd::AutoShardingSolverResult solver_result =
Solve(*module, *hlo_live_range, strategy_map, strategy_groups,
cost_graph, alias_set, reduced_node_intervals,
reduced_edge_intervals, reduced_node_groups, reduced_edge_groups,
option_, request_name, sharding_propagation_solution);
if (solver_result.skip_auto_sharding) {
return AutoShardingResult::kModuleUnchangedNoShardingPerformed;
} else if (!solver_result.status.ok()) {
return AutoShardingResult::kModuleUnchanged;
}
TF_ASSIGN_OR_RETURN(spmd::AutoShardingSolverOutput output,
solver_result.status);

if (mesh_idx == partial_mesh_shapes.size() - 1) {
this->solver_optimal_objective_value_ = output.cost;
}
Expand All @@ -3823,21 +3819,14 @@ absl::StatusOr<AutoShardingResult> AutoShardingImplementation::RunAutoSharding(
output.s_val, (mesh_idx == partial_mesh_shapes.size() - 1));

if (mesh_idx == partial_mesh_shapes.size() - 1) {
if (!spmd::SetHloShardingPostProcessing(sequence, instructions_to_shard,
preserve_shardings)
.ok()) {
return AutoShardingResult::kModuleUnchanged;
}

if (!InsertReshardReshapes(
sequence, instructions_to_shard, strategy_map, cost_graph,
output.s_val, cluster_env,
/* crash_at_error */ !option_.try_multiple_mesh_shapes,
option_.insert_resharding_reshapes_for_non_dot_ops,
preserve_shardings)
.ok()) {
return AutoShardingResult::kModuleUnchanged;
}
TF_RETURN_IF_ERROR(spmd::SetHloShardingPostProcessing(
sequence, instructions_to_shard, preserve_shardings));
TF_RETURN_IF_ERROR(InsertReshardReshapes(
sequence, instructions_to_shard, strategy_map, cost_graph,
output.s_val, cluster_env,
/* crash_at_error */ !option_.try_multiple_mesh_shapes,
option_.insert_resharding_reshapes_for_non_dot_ops,
preserve_shardings));
} else {
spmd::RecoverShardingsFromPartialMesh(sequence, preserve_shardings);
}
Expand Down Expand Up @@ -3878,8 +3867,7 @@ absl::StatusOr<AutoShardingResult> AutoShardingImplementation::RunAutoSharding(
}
}

return module_is_changed ? AutoShardingResult::kModuleChangedShardingPerformed
: AutoShardingResult::kModuleUnchanged;
return module_is_changed;
}

bool ModuleIsManuallyPartitioned(const HloModule* module) {
Expand Down Expand Up @@ -4109,15 +4097,12 @@ absl::StatusOr<bool> AutoSharding::Run(
}
}

absl::StatusOr<AutoShardingResult> min_mesh_pass_result =
AutoShardingResult::kModuleUnchanged;

bool module_is_changed = false;
VLOG(1) << "Original mesh shape "
<< spmd::ToString(option_.device_mesh_shape);
double min_objective_value = std::numeric_limits<double>::max();
int min_mesh_shape_index = -1;
std::unique_ptr<HloModule> min_mesh_shape_module;
bool skip_auto_sharding = true;
for (size_t i = 0; i < mesh_shapes.size(); ++i) {
VLOG(1) << "Trying mesh shape " << spmd::ToString(mesh_shapes[i]);
AutoShardingOption this_option = option_;
Expand All @@ -4130,7 +4115,7 @@ absl::StatusOr<bool> AutoSharding::Run(
}
auto pass = std::make_unique<AutoShardingImplementation>(this_option);
std::unique_ptr<HloModule> module_clone = CloneModule(module);
absl::StatusOr<AutoShardingResult> pass_result =
absl::StatusOr<bool> pass_result =
pass->RunAutoSharding(module_clone.get(), replicated_small_tensors,
execution_threads, sharding_propagation_solution);
if (!pass_result.ok()) {
Expand All @@ -4148,19 +4133,11 @@ absl::StatusOr<bool> AutoSharding::Run(
min_mesh_shape_index = i;
min_mesh_shape_module = std::move(module_clone);
min_objective_value = this_mesh_objective_value;
min_mesh_pass_result = pass_result;
}
if (*pass_result !=
AutoShardingResult::kModuleUnchangedNoShardingPerformed) {
skip_auto_sharding = false;
CHECK_OK(pass_result);
module_is_changed = *pass_result;
}
}

if (skip_auto_sharding) {
RecordPassEndAndDumpModule(start_time, module);
LOG(FATAL) << "The auto-sharding solver has timed out without a solution.";
}

std::string trying_to_find =
option_.try_multiple_mesh_shapes
? "a device mesh (and the corresponding shardings)"
Expand All @@ -4173,28 +4150,18 @@ absl::StatusOr<bool> AutoSharding::Run(
"higher budget). If you think you have set a reasonably large memory "
"budget, please report this as a bug.";

if (!min_mesh_pass_result.ok()) {
RecordPassEndAndDumpModule(start_time, module);
return min_mesh_pass_result.status();
}

absl::StatusOr<bool> module_is_changed;
solver_optimal_objective_value_ = min_objective_value;
if (*min_mesh_pass_result !=
AutoShardingResult::kModuleChangedShardingPerformed) {
RecordPassEndAndDumpModule(start_time, module);
return false;
if (module_is_changed) {
VLOG(1) << "Choosing mesh shape "
<< spmd::ToString(mesh_shapes[min_mesh_shape_index])
<< " which had the minimal solver objective value of "
<< min_objective_value;
chosen_mesh_shape_ = mesh_shapes[min_mesh_shape_index];
TF_RETURN_IF_ERROR(MoveComputationsFromModuleToModule(
min_mesh_shape_module.get(), module));
}

VLOG(1) << "Choosing mesh shape "
<< spmd::ToString(mesh_shapes[min_mesh_shape_index])
<< " which had the minimal solver objective value of "
<< min_objective_value;
chosen_mesh_shape_ = mesh_shapes[min_mesh_shape_index];
TF_RETURN_IF_ERROR(
MoveComputationsFromModuleToModule(min_mesh_shape_module.get(), module));
RecordPassEndAndDumpModule(start_time, module);
return true;
return module_is_changed;
}

} // namespace xla
8 changes: 1 addition & 7 deletions xla/hlo/experimental/auto_sharding/auto_sharding.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,18 +50,12 @@ limitations under the License.

namespace xla {

enum class AutoShardingResult {
kModuleUnchanged,
kModuleChangedShardingPerformed,
kModuleUnchangedNoShardingPerformed
};

class AutoShardingImplementation {
public:
explicit AutoShardingImplementation(const AutoShardingOption& option);
~AutoShardingImplementation() = default;

absl::StatusOr<AutoShardingResult> RunAutoSharding(
absl::StatusOr<bool> RunAutoSharding(
HloModule* module,
const absl::flat_hash_set<std::string>& replicated_small_tensors,
const absl::flat_hash_set<absl::string_view>& execution_threads,
Expand Down
3 changes: 2 additions & 1 deletion xla/hlo/experimental/auto_sharding/auto_sharding_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2839,7 +2839,8 @@ ENTRY matmul {
// TODO(b/369616683) Fix the error message output in this case.
EXPECT_DEATH(
absl::StatusOr<bool> status = AutoSharding(option).Run(module.get()),
"The auto-sharding solver has timed out without a solution.");
"The auto-sharding pass could not find shardings that works for this "
"input.");
}

TEST_F(AutoShardingTest, IgnoreShardAsShardLike) {
Expand Down

0 comments on commit 8108e87

Please sign in to comment.