Automated Code Change
FUTURE_COPYBARA_INTEGRATE_REVIEW=#17430 from ROCm:ci_use_shared_ptr_20240920 80eb830
PiperOrigin-RevId: 680652154
IllogicalMoose authored and Google-ML-Automation committed Oct 2, 2024
1 parent c37c16e commit 875e54c
Showing 5 changed files with 30 additions and 41 deletions.
41 changes: 10 additions & 31 deletions xla/hlo/experimental/auto_sharding/auto_sharding.cc
@@ -3790,18 +3790,14 @@ absl::StatusOr<AutoShardingResult> AutoShardingImplementation::RunAutoSharding(
 
     // ----- Call the ILP Solver -----
     std::string request_name = absl::StrCat("mesh_idx_", mesh_idx);
-    auto solver_result =
+    spmd::AutoShardingSolverResult solver_result =
         Solve(*module, *hlo_live_range, strategy_map, strategy_groups,
               cost_graph, alias_set, reduced_node_intervals,
               reduced_edge_intervals, reduced_node_groups, reduced_edge_groups,
               option_, request_name, sharding_propagation_solution);
-    if (solver_result.skip_auto_sharding) {
-      return AutoShardingResult::kModuleUnchangedNoShardingPerformed;
-    } else if (!solver_result.status.ok()) {
-      return AutoShardingResult::kModuleUnchanged;
-    }
     TF_ASSIGN_OR_RETURN(spmd::AutoShardingSolverOutput output,
                         solver_result.status);
 
     if (mesh_idx == partial_mesh_shapes.size() - 1) {
       this->solver_optimal_objective_value_ = output.cost;
     }
@@ -3823,21 +3819,14 @@ absl::StatusOr<AutoShardingResult> AutoShardingImplementation::RunAutoSharding(
         output.s_val, (mesh_idx == partial_mesh_shapes.size() - 1));
 
     if (mesh_idx == partial_mesh_shapes.size() - 1) {
-      if (!spmd::SetHloShardingPostProcessing(sequence, instructions_to_shard,
-                                              preserve_shardings)
-               .ok()) {
-        return AutoShardingResult::kModuleUnchanged;
-      }
-
-      if (!InsertReshardReshapes(
-              sequence, instructions_to_shard, strategy_map, cost_graph,
-              output.s_val, cluster_env,
-              /* crash_at_error */ !option_.try_multiple_mesh_shapes,
-              option_.insert_resharding_reshapes_for_non_dot_ops,
-              preserve_shardings)
-               .ok()) {
-        return AutoShardingResult::kModuleUnchanged;
-      }
+      TF_RETURN_IF_ERROR(spmd::SetHloShardingPostProcessing(
+          sequence, instructions_to_shard, preserve_shardings));
+      TF_RETURN_IF_ERROR(InsertReshardReshapes(
+          sequence, instructions_to_shard, strategy_map, cost_graph,
+          output.s_val, cluster_env,
+          /* crash_at_error */ !option_.try_multiple_mesh_shapes,
+          option_.insert_resharding_reshapes_for_non_dot_ops,
+          preserve_shardings));
     } else {
       spmd::RecoverShardingsFromPartialMesh(sequence, preserve_shardings);
    }
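The rewritten hunk replaces hand-rolled "check .ok(), early-return kModuleUnchanged" blocks with XLA's status macros, which also means the caller now receives the underlying error status rather than a success-shaped "module unchanged" result. Below is a minimal sketch of the pattern with toy stand-ins for the macros (the real TF_ASSIGN_OR_RETURN / TF_RETURN_IF_ERROR live in TSL's platform headers and generate unique temporaries so they can appear many times in one scope); the function names are illustrative, not the auto-sharding APIs:

#include <iostream>
#include <utility>

#include "absl/status/status.h"
#include "absl/status/statusor.h"

// Toy stand-ins for TSL's TF_RETURN_IF_ERROR / TF_ASSIGN_OR_RETURN.
#define RETURN_IF_ERROR(expr)          \
  do {                                 \
    absl::Status _st = (expr);         \
    if (!_st.ok()) return _st;         \
  } while (false)

#define ASSIGN_OR_RETURN(lhs, expr)                \
  auto _st_or = (expr);                            \
  if (!_st_or.ok()) return _st_or.status();        \
  lhs = std::move(*_st_or)

absl::StatusOr<int> SolveCost() { return 42; }  // stands in for Solve()

absl::Status Postprocess(int cost) {  // stands in for the post-processing calls
  return cost >= 0 ? absl::OkStatus()
                   : absl::InvalidArgumentError("negative cost");
}

absl::Status RunPass() {
  // Each macro collapses a multi-line "check .ok(), early-return" block.
  ASSIGN_OR_RETURN(int cost, SolveCost());
  RETURN_IF_ERROR(Postprocess(cost));
  std::cout << "solver cost = " << cost << "\n";
  return absl::OkStatus();
}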
@@ -4117,7 +4106,6 @@ absl::StatusOr<bool> AutoSharding::Run(
   double min_objective_value = std::numeric_limits<double>::max();
   int min_mesh_shape_index = -1;
   std::unique_ptr<HloModule> min_mesh_shape_module;
-  bool skip_auto_sharding = true;
   for (size_t i = 0; i < mesh_shapes.size(); ++i) {
     VLOG(1) << "Trying mesh shape " << spmd::ToString(mesh_shapes[i]);
     AutoShardingOption this_option = option_;
@@ -4150,15 +4138,6 @@ absl::StatusOr<bool> AutoSharding::Run(
       min_objective_value = this_mesh_objective_value;
       min_mesh_pass_result = pass_result;
     }
-    if (*pass_result !=
-        AutoShardingResult::kModuleUnchangedNoShardingPerformed) {
-      skip_auto_sharding = false;
-    }
   }
-
-  if (skip_auto_sharding) {
-    RecordPassEndAndDumpModule(start_time, module);
-    LOG(FATAL) << "The auto-sharding solver has timed out without a solution.";
-  }
 
   std::string trying_to_find =
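With skip_auto_sharding gone, Run() no longer tallies per-mesh-shape outcomes or LOG(FATAL)s with the solver-timeout message; judging from the test update in the next file, an input for which no sharding can be found now fails through the pass's pre-existing "could not find shardings" path instead.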
3 changes: 2 additions & 1 deletion xla/hlo/experimental/auto_sharding/auto_sharding_test.cc
@@ -2839,7 +2839,8 @@ ENTRY matmul {
   // TODO(b/369616683) Fix the error message output in this case.
   EXPECT_DEATH(
       absl::StatusOr<bool> status = AutoSharding(option).Run(module.get()),
-      "The auto-sharding solver has timed out without a solution.");
+      "The auto-sharding pass could not find shardings that works for this "
+      "input.");
 }
 
 TEST_F(AutoShardingTest, IgnoreShardAsShardLike) {
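One detail worth noting about the updated expectation: EXPECT_DEATH's second argument is a regular expression matched against the output the statement emits while terminating the (forked) test process, so the literal in the test must track the production fatal message verbatim. A self-contained sketch of the mechanism, where the function and message are invented for illustration:

#include "absl/log/log.h"
#include "gtest/gtest.h"

namespace {

// Hypothetical stand-in for a pass that dies on unshardable input.
void RunSolver() { LOG(FATAL) << "could not find shardings for this input"; }

TEST(AutoShardingDeathTest, MessageMatchesFatalLog) {
  // The matcher is a regex applied to the dying process's output; if the
  // production message changes, this literal must change with it.
  EXPECT_DEATH(RunSolver(), "could not find shardings");
}

}  // namespace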
7 changes: 4 additions & 3 deletions xla/pjrt/pjrt_stream_executor_client.cc
@@ -2230,7 +2230,7 @@ absl::Status CheckCompatibleShapes(bool strict_shape_checking,
 }
 
 // Makes a tuple from the arguments to an execution.
-absl::StatusOr<TupleHandle> MakeTupleHelper(
+absl::StatusOr<std::unique_ptr<TupleHandle>> MakeTupleHelper(
     PjRtStreamExecutorClient* client, LocalDeviceState* local_device,
     bool strict_shape_checking, const Shape& tupled_parameter_shape,
     absl::Span<PjRtBuffer* const> py_buffers,
@@ -2296,7 +2296,8 @@ absl::StatusOr<TupleHandle> MakeTupleHelper(
   auto transfer_event =
       std::make_shared<BufferSequencingEvent>(client->thread_pool());
   transfer_event->SetSequencingEvent(std::move(event_or).value(), stream);
-  return TupleHandle({std::move(execution_input), std::move(transfer_event)});
+  return std::make_unique<TupleHandle>(
+      TupleHandle({std::move(execution_input), std::move(transfer_event)}));
 }
 
 // Converts a ScopedShapedBuffer returned from an execution into a
@@ -2465,7 +2466,7 @@ PjRtStreamExecutorLoadedExecutable::MakeExecutionInputsAndWaitForEvents(
       client_->client()->backend().transfer_manager();
   // Lift tuple_handle outside the conditional so that the event it returns is
   // not destroyed until after the loop below that waits on events.
-  std::optional<TupleHandle> tuple_handle;
+  std::unique_ptr<TupleHandle> tuple_handle;
   if (parameter_is_tupled_arguments_ && !options.arguments_are_tupled) {
     TF_ASSIGN_OR_RETURN(
         tuple_handle,
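These three hunks change MakeTupleHelper from returning TupleHandle by value (held in a std::optional at the call site) to returning it behind a std::unique_ptr. A sketch of the ownership pattern under illustrative names (Handle and MakeHandle are not the PjRt types): a factory returning absl::StatusOr<std::unique_ptr<T>> only ever moves a pointer, so T itself never has to be copied or move-assigned, and the null pointer doubles as the "absent" state that std::optional previously encoded.

#include <memory>
#include <utility>

#include "absl/status/status.h"
#include "absl/status/statusor.h"

// Illustrative stand-in for TupleHandle; imagine members that make the
// type awkward or expensive to copy and move-assign.
struct Handle {
  int payload;
};

absl::StatusOr<std::unique_ptr<Handle>> MakeHandle(int payload) {
  if (payload < 0) return absl::InvalidArgumentError("bad payload");
  return std::make_unique<Handle>(Handle{payload});
}

absl::Status Use(bool tupled) {
  // Lifted outside the branch, like tuple_handle in the diff, so whatever
  // the handle owns stays alive for the rest of the function.
  std::unique_ptr<Handle> handle;
  if (tupled) {
    absl::StatusOr<std::unique_ptr<Handle>> handle_or = MakeHandle(1);
    if (!handle_or.ok()) return handle_or.status();
    handle = std::move(*handle_or);
  }
  return absl::OkStatus();
}

In the diff itself, std::make_unique<TupleHandle>(TupleHandle({...})) constructs the handle and moves it into the single heap allocation in one expression.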
17 changes: 13 additions & 4 deletions xla/tsl/BUILD
@@ -3,7 +3,7 @@ load("@bazel_skylib//lib:selects.bzl", "selects")
 load("@bazel_skylib//rules:common_settings.bzl", "bool_flag", "bool_setting")
 load("@tsl//tsl/platform:rules_cc.bzl", "cc_library")
 load("//xla/tsl:package_groups.bzl", "tsl_package_groups")
-load("//xla/tsl:tsl.bzl", "if_google", "if_oss")
+load("//xla/tsl:tsl.bzl", "if_google", "if_oss", "internal_visibility")
 load(
     "//xla/tsl:tsl.default.bzl",
     "tsl_extra_config_settings",
@@ -39,7 +39,7 @@ alias(
     name = "is_cuda_enabled",
     actual = if_oss(
         "@local_config_cuda//:is_cuda_enabled",
-        "@local_config_cuda//cuda:using_clang",
+        "@local_config_cuda//cuda:using_config_cuda",
     ),
     visibility = ["//visibility:public"],
 )
@@ -500,11 +500,20 @@ config_setting(
 )
 
 config_setting(
-    name = "no_nccl_support",
+    name = "using_no_nccl_support_define",
     define_values = dict(
         if_google({"GOOGLE_CUDA_COMPILER": "clang"}),
         no_nccl_support = "true",
     ),
+    visibility = internal_visibility(["//visibility:private"]),
+)
+
+selects.config_setting_group(
+    name = "no_nccl_support",
+    match_all = [
+        ":using_no_nccl_support_define",
+    ] + if_google([
+        "@local_config_cuda//cuda:using_config_cuda",
+    ]),
     visibility = ["//visibility:public"],
 )
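For readers unfamiliar with Skylib: selects.config_setting_group with match_all is a logical AND, matching only when every listed condition holds. The public :no_nccl_support label therefore now requires the no_nccl_support define and, under if_google, the @local_config_cuda//cuda:using_config_cuda configuration on top of the define_values the private setting already checks.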
3 changes: 1 addition & 2 deletions xla/tsl/tsl.bzl
@@ -87,8 +87,7 @@ def if_cuda_or_rocm(if_true, if_false = []):
     """
     return select({
         "@local_config_cuda//cuda:using_nvcc": if_true,
-        "@local_config_cuda//cuda:using_clang": if_true,
+        clean_dep("//xla/tsl:is_cuda_enabled"): if_true,
         "@local_config_rocm//rocm:using_hipcc": if_true,
         "//conditions:default": if_false,
     })
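As usual with Bazel, select() resolves to the branch whose condition matches the current configuration, with //conditions:default as the fallback, so listing several conditions mapped to if_true makes if_cuda_or_rocm an "any of these" test. The change swaps the hard-coded @local_config_cuda//cuda:using_clang condition for the :is_cuda_enabled alias that xla/tsl/BUILD defines above.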
