Skip to content

Commit

Permalink
Make unroll_config optional in WhileLoopUnroller::GetUnrollableLoops.
Browse files Browse the repository at this point in the history
This change allows more loops to be considered when using GetUnrollableLoops, e.g., in while_initializer_removal pass.

PiperOrigin-RevId: 667782425
  • Loading branch information
fhoushmand authored and copybara-github committed Aug 27, 2024
1 parent 6ea9438 commit 0e90a49
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 13 deletions.
2 changes: 1 addition & 1 deletion xla/service/hlo_unstacker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ struct UnstackerMetadata {
VLOG(3) << "Prepared module: " << module->name() << " for unstacking.";
}
std::vector<std::pair<HloInstruction*, WhileLoopConfig>> loops =
WhileLoopUnroller::GetUnrollableLoops(module, {});
WhileLoopUnroller::GetUnrollableLoops(module, {}, std::nullopt);
for (const auto& [instr, while_loop_config] : loops) {
metadata.unrollable_loop_bodies[instr->while_body()] = while_loop_config;
metadata.bodies[instr->while_body()] = instr;
Expand Down
3 changes: 2 additions & 1 deletion xla/service/scan_loop_accumulator_input_unification.cc
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,8 @@ absl::StatusOr<bool> ScanLoopAccumulatorInputUnification::Run(
// accumulators and inputs that are by definition updated and read fully via
// dynamic-update-slice and dynamic-sliced within a loop.
std::vector<std::pair<HloInstruction*, WhileLoopConfig>> unrollable_loops =
WhileLoopUnroller::GetUnrollableLoops(module, execution_threads);
WhileLoopUnroller::GetUnrollableLoops(module, execution_threads,
std::nullopt);

// TODO(b/337883537): We might want to simplify compare instructions before
// this. It helps us identify more inputs and accumulators.
Expand Down
17 changes: 10 additions & 7 deletions xla/service/while_loop_unroller.cc
Original file line number Diff line number Diff line change
Expand Up @@ -665,7 +665,7 @@ std::optional<int64_t> MatchShapeCoveringDynamicIndexInstruction(
WhileLoopUnroller::GetUnrollableLoops(
HloModule* module,
const absl::flat_hash_set<absl::string_view>& execution_threads,
const UnrollConfig& unroll_config) {
std::optional<UnrollConfig> unroll_config) {
// Processing the while loops in the reverse topological order. If the body
// of while loop A calls while loop B, B comes before A.
std::vector<HloInstruction*> all_while_ops;
Expand All @@ -676,13 +676,16 @@ WhileLoopUnroller::GetUnrollableLoops(
std::vector<std::pair<HloInstruction*, WhileLoopConfig>> while_loop_configs;
for (HloInstruction* instr : all_while_ops) {
std::optional<WhileLoopConfig> config = IsLoopUnrollable(instr);
if (config.has_value()) {
if (!InitialFeasibilityCheck(instr, config.value(), unroll_config)) {
VLOG(3) << "Initial feasibility check failed for " << instr->name();
continue;
}
while_loop_configs.emplace_back(instr, config.value());
if (!config.has_value()) {
continue;
}
if (unroll_config.has_value() &&
!InitialFeasibilityCheck(instr, config.value(),
unroll_config.value())) {
VLOG(3) << "Initial feasibility check failed for " << instr->name();
continue;
}
while_loop_configs.emplace_back(instr, config.value());
}
return while_loop_configs;
}
Expand Down
8 changes: 5 additions & 3 deletions xla/service/while_loop_unroller.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ struct WhileLoopConfig {
int64_t induction_var_idx;
};

// Config for unrollable while loops.
// Result for unrolled while loops.
struct UnrollResult {
// Whether it's unrolled.
bool unrolled = false;
Expand Down Expand Up @@ -120,12 +120,14 @@ class WhileLoopUnroller : public HloModulePass {
static std::optional<WhileLoopConfig> IsLoopUnrollable(
HloInstruction* while_op);

// Returns the list of unrollable loops in the given module
// Returns the list of unrollable loops in the given module. If
// `unroll_config` is provided, it will be used to check feasibility according
// to InitialFeasibilityCheck method
static std::vector<std::pair<HloInstruction*, WhileLoopConfig>>
GetUnrollableLoops(
HloModule* module,
const absl::flat_hash_set<absl::string_view>& execution_threads,
const UnrollConfig& unroll_config = UnrollConfig());
std::optional<UnrollConfig> unroll_config);

// Unrolls the given while loop with the default behaviour set to full unroll.
// If wrap_in_trivial_loop is set, the unrolled body of the loop will be
Expand Down
2 changes: 1 addition & 1 deletion xla/service/while_loop_unroller_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -496,7 +496,7 @@ TEST_F(WhileLoopUnrollerTest, GetUnrollableLoops) {
ParseAndReturnVerifiedModule(hlo_string));

auto unrollable_loops =
WhileLoopUnroller::GetUnrollableLoops(module.get(), {});
WhileLoopUnroller::GetUnrollableLoops(module.get(), {}, std::nullopt);
// Only while1 and while2 are unrollable
EXPECT_EQ(unrollable_loops.size(), 2);
}
Expand Down

0 comments on commit 0e90a49

Please sign in to comment.