diff --git a/third_party/stablehlo/workspace.bzl b/third_party/stablehlo/workspace.bzl
index ca2f3a937f73ad..e6985fc2f07af8 100644
--- a/third_party/stablehlo/workspace.bzl
+++ b/third_party/stablehlo/workspace.bzl
@@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
 
 def repo():
     # LINT.IfChange
-    STABLEHLO_COMMIT = "271e8634de184fbfafd677d3876170feb6d08c97"
-    STABLEHLO_SHA256 = "06db84c751bd4a980dc76249e02f10e119175fceba3eebed008da122cb480bab"
+    STABLEHLO_COMMIT = "1bdf7c2603b7e68d97c1b9be92a51826e06cb6ee"
+    STABLEHLO_SHA256 = "24b594aa66a5d780d30a98e50d24be6d52dd46643a875abc1004288144c6cbc2"
     # LINT.ThenChange(Google-internal path)
 
     tf_http_archive(
diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding.cc b/xla/hlo/experimental/auto_sharding/auto_sharding.cc
index 70166ae23a3fcb..d6457eaf5f62d8 100644
--- a/xla/hlo/experimental/auto_sharding/auto_sharding.cc
+++ b/xla/hlo/experimental/auto_sharding/auto_sharding.cc
@@ -2309,10 +2309,10 @@ Status SetHloShardingPostProcessing(
       continue;
     } else {
       if (inst->shape().IsTuple()) {
-        // While we do not support nested tuples fully, this is a hack to get
-        // things to work in some cases (specifically observed for the llama and
-        // gemma models) where nested tuples as used as inputs/outputs of the
-        // kOptimizationBarrier instruction.
+        // While we do not support nested tuples fully (b/332951306), this is a
+        // hack to get things to work in some cases (specifically observed for
+        // the llama and gemma models) where nested tuples are used as
+        // inputs/outputs of the kOptimizationBarrier instruction.
         if (absl::c_any_of(
                 inst->shape().tuple_shapes(),
                 [](const Shape& shape) { return shape.IsTuple(); })) {
@@ -2355,7 +2355,7 @@ Status SetHloShardingPostProcessing(
         for (size_t i = 0; i < inst->shape().tuple_shapes_size(); ++i) {
           CHECK(!inst->shape().tuple_shapes(i).IsTuple())
               << "We currently do not support ops with nested tuples as "
-                 "output.";
+                 "output. See b/332951306.";
           const ShardingStrategy& stra =
               GetShardingStrategyForTuple(inst, {static_cast<int64_t>(i)},
                                           strategy_map, cost_graph, s_val);
@@ -2842,7 +2842,7 @@ void FindReplicateSet(
 }
 
 // Substitute all-reduce strategies with their reduce-scatter variants.
-void GenerateReduceScatter(
+absl::Status GenerateReduceScatter(
     const HloInstructionSequence& sequence, const AliasMap& alias_map,
     const InstructionDepthMap& depth_map, const StrategyMap& strategy_map,
     const CostGraph& cost_graph, absl::Span<const NodeStrategyIdx> s_val,
@@ -3107,8 +3107,9 @@ void GenerateReduceScatter(
     replace_with->set_sharding(
         GetShardingStrategy(inst, strategy_map, cost_graph, s_val)
             .output_sharding);
-    TF_CHECK_OK(inst->ReplaceAllUsesWith(replace_with));
+    TF_RETURN_IF_ERROR(inst->ReplaceAllUsesWith(replace_with));
   }
+  return OkStatus();
 }
 
 void AnnotateShardingWithSimpleHeuristic(
@@ -3837,8 +3838,9 @@ absl::StatusOr<AutoShardingResult> AutoShardingImplementation::RunAutoSharding(
 
   // ----- Substitute all-reduce with reduce-scatter -----
   if (option_.prefer_reduce_scatter) {
-    GenerateReduceScatter(sequence, alias_map, ins_depth_map, strategy_map,
-                          cost_graph, s_val, cluster_env, option_);
+    TF_RETURN_IF_ERROR(GenerateReduceScatter(
+        sequence, alias_map, ins_depth_map, strategy_map, cost_graph, s_val,
+        cluster_env, option_));
   }
   // ----- Set Sharding -----
   SetHloSharding(sequence, strategy_map, cost_graph, s_val,
@@ -3918,6 +3920,21 @@ bool ShardedOnTooManyMeshAxes(const HloModule& module) {
   return false;
 }
 
+bool HasUnsupportedNestedTuples(const HloModule& module) {
+  for (const auto* computation : module.computations()) {
+    for (const auto* instruction : computation->instructions()) {
+      if (instruction->opcode() == HloOpcode::kConditional) {
+        for (const HloInstruction* operand : instruction->operands()) {
+          if (ShapeUtil::IsNestedTuple(operand->shape())) {
+            return true;
+          }
+        }
+      }
+    }
+  }
+  return false;
+}
+
 std::unique_ptr<HloModule> CloneModule(const HloModule* module) {
   auto module_clone = module->Clone("");
   module_clone->set_layout_canonicalization_callback(
@@ -3938,15 +3955,25 @@ absl::StatusOr<bool> AutoSharding::Run(
 
   if (IsModuleManuallySharded(module)) {
     LOG(FATAL)
-        << "Auto-sharding on partially manually sharded modules is not yet "
-           "supported. Please fall back on the sharding propagation pass.";
+        << "Auto-sharding on partially manually sharded modules "  // Crash OK
+           "is not yet supported. Please fall back on the sharding "
+           "propagation pass.";
     return false;
   }
 
   if (ShardedOnTooManyMeshAxes(*module)) {
-    LOG(FATAL) << "The input module contains sharding annotations over a mesh "
-                  "with too many axes (>2). This case is currently not well "
-                  "supported.";
+    LOG(FATAL) << "The input module contains sharding annotations "  // Crash OK
+                  "over a mesh with too many axes (>2). This case is currently "
+                  "not well supported.";
+    return false;
+  }
+
+  // TODO(b/332951306): Remove this check once nested tuples are supported
+  // everywhere
+  if (HasUnsupportedNestedTuples(*module)) {
+    LOG(FATAL) << "The input module contains nested tuples "  // Crash OK
+                  "which we do not currently support well. See b/332951306 to "
+                  "track progress on this.";
     return false;
   }
 
@@ -3960,6 +3987,8 @@ absl::StatusOr<bool> AutoSharding::Run(
   metrics::RecordAutoShardingInvocations();
 #endif
 
+  TF_RETURN_IF_ERROR(module->RemoveUnusedComputations());
+
   TF_RETURN_IF_ERROR(option_.CheckAndSetup());
   LOG(INFO) << "AutoShardingOptions:\n" << option_.ToString();
@@ -4133,12 +4162,25 @@
               << " which had the minimal solver objective value of "
               << min_objective_value;
     chosen_mesh_shape_ = mesh_shapes[min_mesh_shape_index];
+    TF_RETURN_IF_ERROR(
+        modules[min_mesh_shape_index]->RemoveUnusedComputations());
+    const std::vector<HloComputation*>& original_module_computations =
+        module->MakeComputationSorted();
+    const std::vector<HloComputation*>& clone_module_computations =
+        modules[min_mesh_shape_index]->MakeComputationSorted();
+    if (original_module_computations.size() !=
+        clone_module_computations.size()) {
+      return absl::InternalError(
+          "The cloned and the original modules do not have the same number "
+          "of computations. This is a bug and should be reported.");
+    }
+
     absl::flat_hash_map<HloComputation*, HloComputation*>
         computation_replacements;
-    for (size_t i = 0; i < module->computation_count(); ++i) {
-      auto original_computation = module->mutable_computation(i);
-      auto new_computation =
-          modules[min_mesh_shape_index]->mutable_computation(i);
+    for (size_t i = 0; i < original_module_computations.size(); ++i) {
+      HloComputation* original_computation =
+          original_module_computations[i];
+      HloComputation* new_computation = clone_module_computations[i];
       computation_replacements[original_computation] = new_computation;
     }
 
diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding.h b/xla/hlo/experimental/auto_sharding/auto_sharding.h
index 4a0dbe0902fe9c..9a994bfce9a7ed 100644
--- a/xla/hlo/experimental/auto_sharding/auto_sharding.h
+++ b/xla/hlo/experimental/auto_sharding/auto_sharding.h
@@ -210,7 +210,7 @@ Status CheckAliasSetCompatibility(const AliasSet& alias_set,
                                   const HloInstructionSequence& sequence,
                                   bool crash_on_error);
 
-void GenerateReduceScatter(
+absl::Status GenerateReduceScatter(
     const HloInstructionSequence& sequence, const AliasMap& alias_map,
     const InstructionDepthMap& depth_map, const StrategyMap& strategy_map,
     const CostGraph& cost_graph, absl::Span<const NodeStrategyIdx> s_val,
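
Note for reviewers (not part of the patch): the new HasUnsupportedNestedTuples() check above rejects a module as soon as any kConditional operand has a nested tuple shape. The sketch below shows what ShapeUtil::IsNestedTuple() distinguishes; the wrapper function and the concrete shapes are made up for illustration, and only the ShapeUtil calls are existing XLA API.

    #include "xla/shape.h"
    #include "xla/shape_util.h"
    #include "tsl/platform/logging.h"

    // Hypothetical example; not part of the patch.
    void NestedTupleExample() {
      // A flat tuple (f32[8], s32[4]): every element is an array shape, so
      // ShapeUtil::IsNestedTuple() returns false and auto-sharding proceeds.
      const xla::Shape flat = xla::ShapeUtil::MakeTupleShape(
          {xla::ShapeUtil::MakeShape(xla::F32, {8}),
           xla::ShapeUtil::MakeShape(xla::S32, {4})});
      // A nested tuple ((f32[8], s32[4]), s32[4]): the first element is itself
      // a tuple, so IsNestedTuple() returns true. If such a shape feeds a
      // kConditional, AutoSharding::Run() now crashes early (b/332951306)
      // instead of tripping the nested-tuple CHECKs later in the pass.
      const xla::Shape nested = xla::ShapeUtil::MakeTupleShape(
          {flat, xla::ShapeUtil::MakeShape(xla::S32, {4})});
      CHECK(!xla::ShapeUtil::IsNestedTuple(flat));
      CHECK(xla::ShapeUtil::IsNestedTuple(nested));
    }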