From 1136ec539424e4106d3ad579045184a658ac41af Mon Sep 17 00:00:00 2001 From: Farzin Houshmand Date: Wed, 18 Sep 2024 15:51:30 -0700 Subject: [PATCH] Make unroll_config optional in WhileLoopUnroller::GetUnrollableLoops. This change allows more loops to be considered when using GetUnrollableLoops, e.g., in while_initializer_removal pass. Also, make sure the bfloat16 propagation pass correctly handles AllocateBuffer custom-calls. PiperOrigin-RevId: 676163078 --- xla/service/bfloat16_propagation.cc | 6 ++++-- xla/service/hlo_unstacker.cc | 3 ++- ...scan_loop_accumulator_input_unification.cc | 3 ++- xla/service/while_loop_unroller.cc | 21 +++++++++++-------- xla/service/while_loop_unroller.h | 8 ++++--- xla/service/while_loop_unroller_test.cc | 4 ++-- 6 files changed, 27 insertions(+), 18 deletions(-) diff --git a/xla/service/bfloat16_propagation.cc b/xla/service/bfloat16_propagation.cc index 3bf490a4f5e88..bf3dfedf4a0ca 100644 --- a/xla/service/bfloat16_propagation.cc +++ b/xla/service/bfloat16_propagation.cc @@ -352,8 +352,10 @@ bool BFloat16Propagation::ShouldKeepPrecisionUnchanged( } // Do not change precision for side-effecting instructions, control flow, and // bitcast-convert, because this pass might break the interfaces or - // assumptions for them. - return inst->opcode() == HloOpcode::kCustomCall || + // assumptions for them. It is safe to change precision for AllocateBuffer + // since it is merely a buffer allocation and does not have any side effects. + return (inst->opcode() == HloOpcode::kCustomCall && + !inst->IsCustomCall("AllocateBuffer")) || inst->opcode() == HloOpcode::kCall || inst->opcode() == HloOpcode::kBitcastConvert || inst->HasSideEffectNoRecurse(); diff --git a/xla/service/hlo_unstacker.cc b/xla/service/hlo_unstacker.cc index a3b7eec937dc8..03c6f5c2eb635 100644 --- a/xla/service/hlo_unstacker.cc +++ b/xla/service/hlo_unstacker.cc @@ -110,7 +110,8 @@ struct UnstackerMetadata { VLOG(3) << "Prepared module: " << module->name() << " for unstacking."; } std::vector> loops = - WhileLoopUnroller::GetUnrollableLoops(module, {}); + WhileLoopUnroller::GetUnrollableLoops(module, {}, + /*unroll_config=*/std::nullopt); for (const auto& [instr, while_loop_config] : loops) { metadata.unrollable_loop_bodies[instr->while_body()] = while_loop_config; metadata.bodies[instr->while_body()] = instr; diff --git a/xla/service/scan_loop_accumulator_input_unification.cc b/xla/service/scan_loop_accumulator_input_unification.cc index b123bd83e5fd9..4470ae615eb7e 100644 --- a/xla/service/scan_loop_accumulator_input_unification.cc +++ b/xla/service/scan_loop_accumulator_input_unification.cc @@ -270,7 +270,8 @@ absl::StatusOr ScanLoopAccumulatorInputUnification::Run( // accumulators and inputs that are by definition updated and read fully via // dynamic-update-slice and dynamic-sliced within a loop. std::vector> unrollable_loops = - WhileLoopUnroller::GetUnrollableLoops(module, execution_threads); + WhileLoopUnroller::GetUnrollableLoops(module, execution_threads, + /*unroll_config=*/std::nullopt); // TODO(b/337883537): We might want to simplify compare instructions before // this. It helps us identify more inputs and accumulators. diff --git a/xla/service/while_loop_unroller.cc b/xla/service/while_loop_unroller.cc index 3bb2f8d6c8300..053c20aa8aa23 100644 --- a/xla/service/while_loop_unroller.cc +++ b/xla/service/while_loop_unroller.cc @@ -665,7 +665,7 @@ std::optional MatchShapeCoveringDynamicIndexInstruction( WhileLoopUnroller::GetUnrollableLoops( HloModule* module, const absl::flat_hash_set& execution_threads, - const UnrollConfig& unroll_config) { + std::optional unroll_config) { // Processing the while loops in the reverse topological order. If the body // of while loop A calls while loop B, B comes before A. std::vector all_while_ops; @@ -676,13 +676,16 @@ WhileLoopUnroller::GetUnrollableLoops( std::vector> while_loop_configs; for (HloInstruction* instr : all_while_ops) { std::optional config = IsLoopUnrollable(instr); - if (config.has_value()) { - if (!InitialFeasibilityCheck(instr, config.value(), unroll_config)) { - VLOG(3) << "Initial feasibility check failed for " << instr->name(); - continue; - } - while_loop_configs.emplace_back(instr, config.value()); + if (!config.has_value()) { + continue; + } + if (unroll_config.has_value() && + !InitialFeasibilityCheck(instr, config.value(), + unroll_config.value())) { + VLOG(3) << "Initial feasibility check failed for " << instr->name(); + continue; } + while_loop_configs.emplace_back(instr, config.value()); } return while_loop_configs; } @@ -764,8 +767,8 @@ absl::StatusOr WhileLoopUnroller::Run( // unroll. We do this ahead of time so we don't have to worry about mutating // the lists of computations or instructions while we iterate. std::vector> - unrollable_while_ops = - GetUnrollableLoops(module, execution_threads, unroll_config_); + unrollable_while_ops = GetUnrollableLoops( + module, execution_threads, /*unroll_config=*/unroll_config_); VLOG(3) << "Number of while instructions in the module to unroll: " << unrollable_while_ops.size(); diff --git a/xla/service/while_loop_unroller.h b/xla/service/while_loop_unroller.h index def7d9d9a33e1..face63336372a 100644 --- a/xla/service/while_loop_unroller.h +++ b/xla/service/while_loop_unroller.h @@ -49,7 +49,7 @@ struct WhileLoopConfig { int64_t induction_var_idx; }; -// Config for unrollable while loops. +// Result for unrolled while loops. struct UnrollResult { // Whether it's unrolled. bool unrolled = false; @@ -120,12 +120,14 @@ class WhileLoopUnroller : public HloModulePass { static std::optional IsLoopUnrollable( HloInstruction* while_op); - // Returns the list of unrollable loops in the given module + // Returns the list of unrollable loops in the given module. If + // `unroll_config` is provided, it will be used to check feasibility according + // to InitialFeasibilityCheck method static std::vector> GetUnrollableLoops( HloModule* module, const absl::flat_hash_set& execution_threads, - const UnrollConfig& unroll_config = UnrollConfig()); + std::optional unroll_config); // Unrolls the given while loop with the default behaviour set to full unroll. // If wrap_in_trivial_loop is set, the unrolled body of the loop will be diff --git a/xla/service/while_loop_unroller_test.cc b/xla/service/while_loop_unroller_test.cc index fa8fb1dbff72f..fa44905344f66 100644 --- a/xla/service/while_loop_unroller_test.cc +++ b/xla/service/while_loop_unroller_test.cc @@ -495,8 +495,8 @@ TEST_F(WhileLoopUnrollerTest, GetUnrollableLoops) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_string)); - auto unrollable_loops = - WhileLoopUnroller::GetUnrollableLoops(module.get(), {}); + auto unrollable_loops = WhileLoopUnroller::GetUnrollableLoops( + module.get(), {}, /*unroll_config=*/std::nullopt); // Only while1 and while2 are unrollable EXPECT_EQ(unrollable_loops.size(), 2); }