From 1be50b8107adc69a22330d4b7eb1a2d4a62f4383 Mon Sep 17 00:00:00 2001 From: Farzin Houshmand Date: Mon, 26 Aug 2024 18:07:22 -0700 Subject: [PATCH] Make unroll_config optional in WhileLoopUnroller::GetUnrollableLoops. This change allows more loops to be considered when using GetUnrollableLoops, e.g., in while_initializer_removal pass. PiperOrigin-RevId: 667782425 --- xla/service/hlo_unstacker.cc | 3 ++- ...scan_loop_accumulator_input_unification.cc | 3 ++- xla/service/while_loop_unroller.cc | 21 +++++++++++-------- xla/service/while_loop_unroller.h | 8 ++++--- xla/service/while_loop_unroller_test.cc | 4 ++-- 5 files changed, 23 insertions(+), 16 deletions(-) diff --git a/xla/service/hlo_unstacker.cc b/xla/service/hlo_unstacker.cc index a3b7eec937dc86..03c6f5c2eb6356 100644 --- a/xla/service/hlo_unstacker.cc +++ b/xla/service/hlo_unstacker.cc @@ -110,7 +110,8 @@ struct UnstackerMetadata { VLOG(3) << "Prepared module: " << module->name() << " for unstacking."; } std::vector> loops = - WhileLoopUnroller::GetUnrollableLoops(module, {}); + WhileLoopUnroller::GetUnrollableLoops(module, {}, + /*unroll_config=*/std::nullopt); for (const auto& [instr, while_loop_config] : loops) { metadata.unrollable_loop_bodies[instr->while_body()] = while_loop_config; metadata.bodies[instr->while_body()] = instr; diff --git a/xla/service/scan_loop_accumulator_input_unification.cc b/xla/service/scan_loop_accumulator_input_unification.cc index b123bd83e5fd99..4470ae615eb7e5 100644 --- a/xla/service/scan_loop_accumulator_input_unification.cc +++ b/xla/service/scan_loop_accumulator_input_unification.cc @@ -270,7 +270,8 @@ absl::StatusOr ScanLoopAccumulatorInputUnification::Run( // accumulators and inputs that are by definition updated and read fully via // dynamic-update-slice and dynamic-sliced within a loop. std::vector> unrollable_loops = - WhileLoopUnroller::GetUnrollableLoops(module, execution_threads); + WhileLoopUnroller::GetUnrollableLoops(module, execution_threads, + /*unroll_config=*/std::nullopt); // TODO(b/337883537): We might want to simplify compare instructions before // this. It helps us identify more inputs and accumulators. diff --git a/xla/service/while_loop_unroller.cc b/xla/service/while_loop_unroller.cc index b8f2bcae972abf..51574a8a9fd5ef 100644 --- a/xla/service/while_loop_unroller.cc +++ b/xla/service/while_loop_unroller.cc @@ -665,7 +665,7 @@ std::optional MatchShapeCoveringDynamicIndexInstruction( WhileLoopUnroller::GetUnrollableLoops( HloModule* module, const absl::flat_hash_set& execution_threads, - const UnrollConfig& unroll_config) { + std::optional unroll_config) { // Processing the while loops in the reverse topological order. If the body // of while loop A calls while loop B, B comes before A. std::vector all_while_ops; @@ -676,13 +676,16 @@ WhileLoopUnroller::GetUnrollableLoops( std::vector> while_loop_configs; for (HloInstruction* instr : all_while_ops) { std::optional config = IsLoopUnrollable(instr); - if (config.has_value()) { - if (!InitialFeasibilityCheck(instr, config.value(), unroll_config)) { - VLOG(3) << "Initial feasibility check failed for " << instr->name(); - continue; - } - while_loop_configs.emplace_back(instr, config.value()); + if (!config.has_value()) { + continue; + } + if (unroll_config.has_value() && + !InitialFeasibilityCheck(instr, config.value(), + unroll_config.value())) { + VLOG(3) << "Initial feasibility check failed for " << instr->name(); + continue; } + while_loop_configs.emplace_back(instr, config.value()); } return while_loop_configs; } @@ -764,8 +767,8 @@ absl::StatusOr WhileLoopUnroller::Run( // unroll. We do this ahead of time so we don't have to worry about mutating // the lists of computations or instructions while we iterate. std::vector> - unrollable_while_ops = - GetUnrollableLoops(module, execution_threads, unroll_config_); + unrollable_while_ops = GetUnrollableLoops( + module, execution_threads, /*unroll_config=*/unroll_config_); VLOG(3) << "Number of while instructions in the module to unroll: " << unrollable_while_ops.size(); diff --git a/xla/service/while_loop_unroller.h b/xla/service/while_loop_unroller.h index cb513320a742b7..b6c3e1f3e1f10e 100644 --- a/xla/service/while_loop_unroller.h +++ b/xla/service/while_loop_unroller.h @@ -49,7 +49,7 @@ struct WhileLoopConfig { int64_t induction_var_idx; }; -// Config for unrollable while loops. +// Result for unrolled while loops. struct UnrollResult { // Whether it's unrolled. bool unrolled = false; @@ -120,12 +120,14 @@ class WhileLoopUnroller : public HloModulePass { static std::optional IsLoopUnrollable( HloInstruction* while_op); - // Returns the list of unrollable loops in the given module + // Returns the list of unrollable loops in the given module. If + // `unroll_config` is provided, it will be used to check feasibility according + // to InitialFeasibilityCheck method static std::vector> GetUnrollableLoops( HloModule* module, const absl::flat_hash_set& execution_threads, - const UnrollConfig& unroll_config = UnrollConfig()); + std::optional unroll_config); // Unrolls the given while loop with the default behaviour set to full unroll. // If wrap_in_trivial_loop is set, the unrolled body of the loop will be diff --git a/xla/service/while_loop_unroller_test.cc b/xla/service/while_loop_unroller_test.cc index fa8fb1dbff72ff..fa44905344f66d 100644 --- a/xla/service/while_loop_unroller_test.cc +++ b/xla/service/while_loop_unroller_test.cc @@ -495,8 +495,8 @@ TEST_F(WhileLoopUnrollerTest, GetUnrollableLoops) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_string)); - auto unrollable_loops = - WhileLoopUnroller::GetUnrollableLoops(module.get(), {}); + auto unrollable_loops = WhileLoopUnroller::GetUnrollableLoops( + module.get(), {}, /*unroll_config=*/std::nullopt); // Only while1 and while2 are unrollable EXPECT_EQ(unrollable_loops.size(), 2); }