Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove AutoShardingResult in favor of a boolean now that the value kModuleUnchangedNoShardingPerformed of the enum is unused, effectively making it a boolean. Also simplified away some dead code. #17874

Merged
merged 1 commit into from
Oct 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 16 additions & 28 deletions xla/hlo/experimental/auto_sharding/auto_sharding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3504,14 +3504,14 @@ std::pair<int64_t, int64_t> ReduceMemoryTerms(
return num_terms;
}

absl::StatusOr<AutoShardingResult> AutoShardingImplementation::RunAutoSharding(
absl::StatusOr<bool> AutoShardingImplementation::RunAutoSharding(
HloModule* module,
const absl::flat_hash_set<std::string>& replicated_small_tensors,
const absl::flat_hash_set<absl::string_view>& execution_threads,
const absl::flat_hash_map<std::string, HloSharding>&
sharding_propagation_solution) {
if (!option_.enable) {
return AutoShardingResult::kModuleUnchanged;
return false;
}
bool module_is_changed = false;

Expand Down Expand Up @@ -3867,8 +3867,7 @@ absl::StatusOr<AutoShardingResult> AutoShardingImplementation::RunAutoSharding(
}
}

return module_is_changed ? AutoShardingResult::kModuleChangedShardingPerformed
: AutoShardingResult::kModuleUnchanged;
return module_is_changed;
}

bool ModuleIsManuallyPartitioned(const HloModule* module) {
Expand Down Expand Up @@ -4098,9 +4097,7 @@ absl::StatusOr<bool> AutoSharding::Run(
}
}

absl::StatusOr<AutoShardingResult> min_mesh_pass_result =
AutoShardingResult::kModuleUnchanged;

bool module_is_changed = false;
VLOG(1) << "Original mesh shape "
<< spmd::ToString(option_.device_mesh_shape);
double min_objective_value = std::numeric_limits<double>::max();
Expand All @@ -4118,7 +4115,7 @@ absl::StatusOr<bool> AutoSharding::Run(
}
auto pass = std::make_unique<AutoShardingImplementation>(this_option);
std::unique_ptr<HloModule> module_clone = CloneModule(module);
absl::StatusOr<AutoShardingResult> pass_result =
absl::StatusOr<bool> pass_result =
pass->RunAutoSharding(module_clone.get(), replicated_small_tensors,
execution_threads, sharding_propagation_solution);
if (!pass_result.ok()) {
Expand All @@ -4136,7 +4133,8 @@ absl::StatusOr<bool> AutoSharding::Run(
min_mesh_shape_index = i;
min_mesh_shape_module = std::move(module_clone);
min_objective_value = this_mesh_objective_value;
min_mesh_pass_result = pass_result;
CHECK_OK(pass_result);
module_is_changed = *pass_result;
}
}

Expand All @@ -4152,28 +4150,18 @@ absl::StatusOr<bool> AutoSharding::Run(
"higher budget). If you think you have set a reasonably large memory "
"budget, please report this as a bug.";

if (!min_mesh_pass_result.ok()) {
RecordPassEndAndDumpModule(start_time, module);
return min_mesh_pass_result.status();
}

absl::StatusOr<bool> module_is_changed;
solver_optimal_objective_value_ = min_objective_value;
if (*min_mesh_pass_result !=
AutoShardingResult::kModuleChangedShardingPerformed) {
RecordPassEndAndDumpModule(start_time, module);
return false;
if (module_is_changed) {
VLOG(1) << "Choosing mesh shape "
<< spmd::ToString(mesh_shapes[min_mesh_shape_index])
<< " which had the minimal solver objective value of "
<< min_objective_value;
chosen_mesh_shape_ = mesh_shapes[min_mesh_shape_index];
TF_RETURN_IF_ERROR(MoveComputationsFromModuleToModule(
min_mesh_shape_module.get(), module));
}

VLOG(1) << "Choosing mesh shape "
<< spmd::ToString(mesh_shapes[min_mesh_shape_index])
<< " which had the minimal solver objective value of "
<< min_objective_value;
chosen_mesh_shape_ = mesh_shapes[min_mesh_shape_index];
TF_RETURN_IF_ERROR(
MoveComputationsFromModuleToModule(min_mesh_shape_module.get(), module));
RecordPassEndAndDumpModule(start_time, module);
return true;
return module_is_changed;
}

} // namespace xla
8 changes: 1 addition & 7 deletions xla/hlo/experimental/auto_sharding/auto_sharding.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,18 +50,12 @@ limitations under the License.

namespace xla {

enum class AutoShardingResult {
kModuleUnchanged,
kModuleChangedShardingPerformed,
kModuleUnchangedNoShardingPerformed
};

class AutoShardingImplementation {
public:
explicit AutoShardingImplementation(const AutoShardingOption& option);
~AutoShardingImplementation() = default;

absl::StatusOr<AutoShardingResult> RunAutoSharding(
absl::StatusOr<bool> RunAutoSharding(
HloModule* module,
const absl::flat_hash_set<std::string>& replicated_small_tensors,
const absl::flat_hash_set<absl::string_view>& execution_threads,
Expand Down
Loading