Commit 84f4292

Ensure that the module we consume has no unused computations.

This can cause issues because we clone modules to support
try_multiple_mesh_shapes, and cloning an HLO module removes dead
computations, leading to mismatches between the original module and its
clones.

PiperOrigin-RevId: 621361695
tensorflower-gardener authored and copybara-github committed Apr 5, 2024
1 parent 4d135db commit 84f4292
Showing 3 changed files with 63 additions and 21 deletions.
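To make the failure mode in the commit message concrete, here is a small self-contained C++ sketch. It is a toy model, not XLA code: the Computation and Module types and their Clone/RemoveUnusedComputations methods are invented stand-ins for the HloModule APIs touched in the diff below. It shows how a clone that silently drops dead computations breaks index-based pairing with the original, and why stripping unused computations from the consumed module first restores a one-to-one correspondence.

```cpp
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

// Toy stand-in for an HLO computation: just a name and a liveness flag.
struct Computation {
  std::string name;
  bool dead;  // True if unreachable from the entry computation.
};

// Toy stand-in for HloModule.
struct Module {
  std::vector<Computation> computations;

  // Mimics the behavior described in the commit message: cloning keeps only
  // live computations.
  Module Clone() const {
    Module clone;
    for (const Computation& c : computations) {
      if (!c.dead) clone.computations.push_back(c);
    }
    return clone;
  }

  // Mimics removing unused computations from the consumed module.
  void RemoveUnusedComputations() {
    computations.erase(
        std::remove_if(computations.begin(), computations.end(),
                       [](const Computation& c) { return c.dead; }),
        computations.end());
  }
};

int main() {
  Module original{
      {{"entry", false}, {"dead_helper", true}, {"while_body", false}}};
  Module clone = original.Clone();

  // Without the fix, index-based pairing misaligns: original[1] is
  // "dead_helper" while clone[1] is "while_body" (and the sizes differ).
  std::cout << "original: " << original.computations.size()
            << " computations, clone: " << clone.computations.size() << "\n";

  // The fix: strip unused computations from the consumed module too, so both
  // sides enumerate the same computations in the same order.
  original.RemoveUnusedComputations();
  for (std::size_t i = 0; i < original.computations.size(); ++i) {
    std::cout << original.computations[i].name << " <-> "
              << clone.computations[i].name << "\n";
  }
  return 0;
}
```

This is the invariant that the new size check in auto_sharding.cc (see the InternalError hunk below) guards against.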
4 changes: 2 additions & 2 deletions third_party/stablehlo/workspace.bzl
@@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")

 def repo():
     # LINT.IfChange
-    STABLEHLO_COMMIT = "271e8634de184fbfafd677d3876170feb6d08c97"
-    STABLEHLO_SHA256 = "06db84c751bd4a980dc76249e02f10e119175fceba3eebed008da122cb480bab"
+    STABLEHLO_COMMIT = "1bdf7c2603b7e68d97c1b9be92a51826e06cb6ee"
+    STABLEHLO_SHA256 = "24b594aa66a5d780d30a98e50d24be6d52dd46643a875abc1004288144c6cbc2"
     # LINT.ThenChange(Google-internal path)

     tf_http_archive(

78 changes: 60 additions & 18 deletions xla/hlo/experimental/auto_sharding/auto_sharding.cc
@@ -2309,10 +2309,10 @@ Status SetHloShardingPostProcessing(
       continue;
     } else {
       if (inst->shape().IsTuple()) {
-        // While we do not support nested tuples fully, this is a hack to get
-        // things to work in some cases (specifically observed for the llama and
-        // gemma models) where nested tuples as used as inputs/outputs of the
-        // kOptimizationBarrier instruction.
+        // While we do not support nested tuples fully (b/332951306), this is a
+        // hack to get things to work in some cases (specifically observed for
+        // the llama and gemma models) where nested tuples as used as
+        // inputs/outputs of the kOptimizationBarrier instruction.
         if (absl::c_any_of(
                 inst->shape().tuple_shapes(),
                 [](const Shape& shape) { return shape.IsTuple(); })) {
@@ -2355,7 +2355,7 @@ Status SetHloShardingPostProcessing(
         for (size_t i = 0; i < inst->shape().tuple_shapes_size(); ++i) {
           CHECK(!inst->shape().tuple_shapes(i).IsTuple())
               << "We currently do not support ops with nested tuples as "
-                 "output.";
+                 "output. See b/332951306.";
           const ShardingStrategy& stra =
               GetShardingStrategyForTuple(inst, {static_cast<int64_t>(i)},
                                           strategy_map, cost_graph, s_val);
@@ -2842,7 +2842,7 @@ void FindReplicateSet(
 }

 // Substitute all-reduce strategies with their reduce-scatter variants.
-void GenerateReduceScatter(
+absl::Status GenerateReduceScatter(
     const HloInstructionSequence& sequence, const AliasMap& alias_map,
     const InstructionDepthMap& depth_map, const StrategyMap& strategy_map,
     const CostGraph& cost_graph, absl::Span<const NodeStrategyIdx> s_val,
@@ -3107,8 +3107,9 @@ void GenerateReduceScatter(
     replace_with->set_sharding(
         GetShardingStrategy(inst, strategy_map, cost_graph, s_val)
             .output_sharding);
-    TF_CHECK_OK(inst->ReplaceAllUsesWith(replace_with));
+    TF_RETURN_IF_ERROR(inst->ReplaceAllUsesWith(replace_with));
   }
+  return OkStatus();
 }

 void AnnotateShardingWithSimpleHeuristic(
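The two hunks above change GenerateReduceScatter from a void function that crashes via TF_CHECK_OK into one that returns absl::Status and propagates failures with TF_RETURN_IF_ERROR. Below is a minimal sketch of that pattern, assuming only Abseil; RETURN_IF_ERROR is a simplified stand-in for XLA's TF_RETURN_IF_ERROR, and ReplaceAllUses/Generate are invented names used for illustration.

```cpp
#include <iostream>

#include "absl/status/status.h"

// Simplified stand-in for XLA's TF_RETURN_IF_ERROR: evaluate an expression
// returning absl::Status and early-return it from the caller if it is not OK.
#define RETURN_IF_ERROR(expr)            \
  do {                                   \
    const absl::Status _status = (expr); \
    if (!_status.ok()) return _status;   \
  } while (false)

// Stand-in for a fallible operation such as replacing an instruction's uses.
absl::Status ReplaceAllUses(bool fail) {
  if (fail) return absl::InternalError("replacement failed");
  return absl::OkStatus();
}

// Before the change, the equivalent function returned void and called
// TF_CHECK_OK(...), crashing the process on failure. Returning a Status lets
// the caller decide how to handle the error instead.
absl::Status Generate(bool fail) {
  RETURN_IF_ERROR(ReplaceAllUses(fail));
  return absl::OkStatus();
}

int main() {
  std::cout << Generate(/*fail=*/false) << "\n";  // Prints: OK
  std::cout << Generate(/*fail=*/true) << "\n";   // Prints: INTERNAL: replacement failed
  return 0;
}
```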
@@ -3837,8 +3838,9 @@ absl::StatusOr<AutoShardingResult> AutoShardingImplementation::RunAutoSharding(

   // ----- Substitute all-reduce with reduce-scatter -----
   if (option_.prefer_reduce_scatter) {
-    GenerateReduceScatter(sequence, alias_map, ins_depth_map, strategy_map,
-                          cost_graph, s_val, cluster_env, option_);
+    TF_RETURN_IF_ERROR(GenerateReduceScatter(
+        sequence, alias_map, ins_depth_map, strategy_map, cost_graph, s_val,
+        cluster_env, option_));
   }
   // ----- Set Sharding -----
   SetHloSharding(sequence, strategy_map, cost_graph, s_val,
@@ -3918,6 +3920,21 @@ bool ShardedOnTooManyMeshAxes(const HloModule& module) {
   return false;
 }

+bool HasUnsupportedNestedTuples(const HloModule& module) {
+  for (const auto* computation : module.computations()) {
+    for (const auto* instruction : computation->instructions()) {
+      if (instruction->opcode() == HloOpcode::kConditional) {
+        for (const HloInstruction* operand : instruction->operands()) {
+          if (ShapeUtil::IsNestedTuple(operand->shape())) {
+            return true;
+          }
+        }
+      }
+    }
+  }
+  return false;
+}
+
 std::unique_ptr<HloModule> CloneModule(const HloModule* module) {
   auto module_clone = module->Clone("");
   module_clone->set_layout_canonicalization_callback(
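The new HasUnsupportedNestedTuples helper rejects modules whose kConditional operands have tuple-of-tuples shapes. The sketch below illustrates what ShapeUtil::IsNestedTuple distinguishes; it assumes XLA's ShapeUtil helpers (MakeShape, MakeTupleShape, IsNestedTuple) and the include paths used elsewhere in this repository, and is not part of the commit.

```cpp
#include <iostream>

#include "xla/shape.h"
#include "xla/shape_util.h"

int main() {
  const xla::Shape scalar = xla::ShapeUtil::MakeShape(xla::F32, {});
  const xla::Shape flat_tuple =
      xla::ShapeUtil::MakeTupleShape({scalar, scalar});
  // A tuple whose element is itself a tuple: the shape the new check rejects
  // when it appears as a kConditional operand.
  const xla::Shape nested_tuple =
      xla::ShapeUtil::MakeTupleShape({scalar, flat_tuple});

  std::cout << xla::ShapeUtil::IsNestedTuple(flat_tuple) << "\n";    // 0
  std::cout << xla::ShapeUtil::IsNestedTuple(nested_tuple) << "\n";  // 1
  return 0;
}
```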
@@ -3938,15 +3955,25 @@ absl::StatusOr<bool> AutoSharding::Run(

   if (IsModuleManuallySharded(module)) {
     LOG(FATAL)
-        << "Auto-sharding on partially manually sharded modules is not yet "
-           "supported. Please fall back on the sharding propagation pass.";
+        << "Auto-sharding on partially manually sharded modules "  // Crash OK
+           "is not yet supported. Please fall back on the sharding "
+           "propagation pass.";
     return false;
   }

   if (ShardedOnTooManyMeshAxes(*module)) {
-    LOG(FATAL) << "The input module contains sharding annotations over a mesh "
-                  "with too many axes (>2). This case is currently not well "
-                  "supported.";
+    LOG(FATAL) << "The input module contains sharding annotations "  // Crash OK
+                  "over a mesh with too many axes (>2). This case is currently "
+                  "not well supported.";
     return false;
   }

+  // TODO(b/332951306): Remove this check once nested tuples are supported
+  // everywhere
+  if (HasUnsupportedNestedTuples(*module)) {
+    LOG(FATAL) << "The input module contains nested tuples "  // Crash OK
+                  "which we do not currently support well. See b/332951306 to "
+                  "track progress on this.";
+    return false;
+  }
+

@@ -3960,6 +3987,8 @@ absl::StatusOr<bool> AutoSharding::Run(
   metrics::RecordAutoShardingInvocations();
 #endif

+  TF_RETURN_IF_ERROR(module->RemoveUnusedComputations());
+
   TF_RETURN_IF_ERROR(option_.CheckAndSetup());
   LOG(INFO) << "AutoShardingOptions:\n" << option_.ToString();

@@ -4133,12 +4162,25 @@ absl::StatusOr<bool> AutoSharding::Run(
               << " which had the minimal solver objective value of "
               << min_objective_value;
     chosen_mesh_shape_ = mesh_shapes[min_mesh_shape_index];
+    TF_RETURN_IF_ERROR(
+        modules[min_mesh_shape_index]->RemoveUnusedComputations());
+    const std::vector<HloComputation*>& original_module_computations =
+        module->MakeComputationSorted();
+    const std::vector<HloComputation*>& clone_module_computations =
+        modules[min_mesh_shape_index]->MakeComputationSorted();
+    if (original_module_computations.size() !=
+        clone_module_computations.size()) {
+      return absl::InternalError(
+          "The cloned and the original modules do not have the same number "
+          "of computations. This is a bug and should be reported.");
+    }
+
     absl::flat_hash_map<HloComputation*, HloComputation*>
         computation_replacements;
-    for (size_t i = 0; i < module->computation_count(); ++i) {
-      auto original_computation = module->mutable_computation(i);
-      auto new_computation =
-          modules[min_mesh_shape_index]->mutable_computation(i);
+    for (size_t i = 0; i < original_module_computations.size(); ++i) {
+      HloComputation* original_computation =
+          original_module_computations[i];
+      HloComputation* new_computation = clone_module_computations[i];
       computation_replacements[original_computation] = new_computation;
     }

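The last hunk of auto_sharding.cc removes unused computations from the chosen clone, verifies that both modules list the same number of computations, and then pairs them positionally via MakeComputationSorted. The toy sketch below (invented Computation type, not XLA code) mirrors that pairing step: after the size guard, build a pointer-keyed old-to-new replacement map, one entry per computation, as computation_replacements does in the diff.

```cpp
#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

#include "absl/container/flat_hash_map.h"

// Invented stand-in for HloComputation; only its identity (pointer) matters.
struct Computation {
  std::string name;
};

int main() {
  // Stand-ins for MakeComputationSorted() on the original module and on the
  // chosen clone, after unused computations have been removed from both.
  std::vector<Computation> original = {{"entry"}, {"while_body"}};
  std::vector<Computation> clone = {{"entry"}, {"while_body"}};
  if (original.size() != clone.size()) {
    std::cerr << "computation count mismatch\n";
    return 1;
  }

  // Pair computations positionally and key the replacement map by pointer.
  absl::flat_hash_map<const Computation*, const Computation*> replacements;
  for (std::size_t i = 0; i < original.size(); ++i) {
    replacements[&original[i]] = &clone[i];
  }
  std::cout << "built " << replacements.size() << " replacements\n";
  return 0;
}
```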
2 changes: 1 addition & 1 deletion xla/hlo/experimental/auto_sharding/auto_sharding.h
@@ -210,7 +210,7 @@ Status CheckAliasSetCompatibility(const AliasSet& alias_set,
                                   const HloInstructionSequence& sequence,
                                   bool crash_on_error);

-void GenerateReduceScatter(
+absl::Status GenerateReduceScatter(
     const HloInstructionSequence& sequence, const AliasMap& alias_map,
     const InstructionDepthMap& depth_map, const StrategyMap& strategy_map,
     const CostGraph& cost_graph, absl::Span<const int64_t> s_val,
