Skip to content

Commit

Permalink
Make unroll_config optional in WhileLoopUnroller::GetUnrollableLoops.
Browse files Browse the repository at this point in the history
This change allows more loops to be considered when using GetUnrollableLoops, e.g., in while_initializer_removal pass.

Also, make sure the bfloat16 propagation pass correctly handles AllocateBuffer custom-calls.

PiperOrigin-RevId: 667782425
  • Loading branch information
fhoushmand authored and Google-ML-Automation committed Sep 18, 2024
1 parent a9e9b31 commit 8c897a0
Show file tree
Hide file tree
Showing 6 changed files with 27 additions and 18 deletions.
6 changes: 4 additions & 2 deletions xla/service/bfloat16_propagation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -352,8 +352,10 @@ bool BFloat16Propagation::ShouldKeepPrecisionUnchanged(
}
// Do not change precision for side-effecting instructions, control flow, and
// bitcast-convert, because this pass might break the interfaces or
// assumptions for them.
return inst->opcode() == HloOpcode::kCustomCall ||
// assumptions for them. It is safe to change precision for AllocateBuffer
// since it is merely a buffer allocation and does not have any side effects.
return (inst->opcode() == HloOpcode::kCustomCall &&
!inst->IsCustomCall("AllocateBuffer")) ||
inst->opcode() == HloOpcode::kCall ||
inst->opcode() == HloOpcode::kBitcastConvert ||
inst->HasSideEffectNoRecurse();
Expand Down
3 changes: 2 additions & 1 deletion xla/service/hlo_unstacker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,8 @@ struct UnstackerMetadata {
VLOG(3) << "Prepared module: " << module->name() << " for unstacking.";
}
std::vector<std::pair<HloInstruction*, WhileLoopConfig>> loops =
WhileLoopUnroller::GetUnrollableLoops(module, {});
WhileLoopUnroller::GetUnrollableLoops(module, {},
/*unroll_config=*/std::nullopt);
for (const auto& [instr, while_loop_config] : loops) {
metadata.unrollable_loop_bodies[instr->while_body()] = while_loop_config;
metadata.bodies[instr->while_body()] = instr;
Expand Down
3 changes: 2 additions & 1 deletion xla/service/scan_loop_accumulator_input_unification.cc
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,8 @@ absl::StatusOr<bool> ScanLoopAccumulatorInputUnification::Run(
// accumulators and inputs that are by definition updated and read fully via
// dynamic-update-slice and dynamic-sliced within a loop.
std::vector<std::pair<HloInstruction*, WhileLoopConfig>> unrollable_loops =
WhileLoopUnroller::GetUnrollableLoops(module, execution_threads);
WhileLoopUnroller::GetUnrollableLoops(module, execution_threads,
/*unroll_config=*/std::nullopt);

// TODO(b/337883537): We might want to simplify compare instructions before
// this. It helps us identify more inputs and accumulators.
Expand Down
21 changes: 12 additions & 9 deletions xla/service/while_loop_unroller.cc
Original file line number Diff line number Diff line change
Expand Up @@ -665,7 +665,7 @@ std::optional<int64_t> MatchShapeCoveringDynamicIndexInstruction(
WhileLoopUnroller::GetUnrollableLoops(
HloModule* module,
const absl::flat_hash_set<absl::string_view>& execution_threads,
const UnrollConfig& unroll_config) {
std::optional<UnrollConfig> unroll_config) {
// Processing the while loops in the reverse topological order. If the body
// of while loop A calls while loop B, B comes before A.
std::vector<HloInstruction*> all_while_ops;
Expand All @@ -676,13 +676,16 @@ WhileLoopUnroller::GetUnrollableLoops(
std::vector<std::pair<HloInstruction*, WhileLoopConfig>> while_loop_configs;
for (HloInstruction* instr : all_while_ops) {
std::optional<WhileLoopConfig> config = IsLoopUnrollable(instr);
if (config.has_value()) {
if (!InitialFeasibilityCheck(instr, config.value(), unroll_config)) {
VLOG(3) << "Initial feasibility check failed for " << instr->name();
continue;
}
while_loop_configs.emplace_back(instr, config.value());
if (!config.has_value()) {
continue;
}
if (unroll_config.has_value() &&
!InitialFeasibilityCheck(instr, config.value(),
unroll_config.value())) {
VLOG(3) << "Initial feasibility check failed for " << instr->name();
continue;
}
while_loop_configs.emplace_back(instr, config.value());
}
return while_loop_configs;
}
Expand Down Expand Up @@ -764,8 +767,8 @@ absl::StatusOr<bool> WhileLoopUnroller::Run(
// unroll. We do this ahead of time so we don't have to worry about mutating
// the lists of computations or instructions while we iterate.
std::vector<std::pair<HloInstruction*, WhileLoopConfig>>
unrollable_while_ops =
GetUnrollableLoops(module, execution_threads, unroll_config_);
unrollable_while_ops = GetUnrollableLoops(
module, execution_threads, /*unroll_config=*/unroll_config_);
VLOG(3) << "Number of while instructions in the module to unroll: "
<< unrollable_while_ops.size();

Expand Down
8 changes: 5 additions & 3 deletions xla/service/while_loop_unroller.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ struct WhileLoopConfig {
int64_t induction_var_idx;
};

// Config for unrollable while loops.
// Result for unrolled while loops.
struct UnrollResult {
// Whether it's unrolled.
bool unrolled = false;
Expand Down Expand Up @@ -120,12 +120,14 @@ class WhileLoopUnroller : public HloModulePass {
static std::optional<WhileLoopConfig> IsLoopUnrollable(
HloInstruction* while_op);

// Returns the list of unrollable loops in the given module
// Returns the list of unrollable loops in the given module. If
// `unroll_config` is provided, it will be used to check feasibility according
// to InitialFeasibilityCheck method
static std::vector<std::pair<HloInstruction*, WhileLoopConfig>>
GetUnrollableLoops(
HloModule* module,
const absl::flat_hash_set<absl::string_view>& execution_threads,
const UnrollConfig& unroll_config = UnrollConfig());
std::optional<UnrollConfig> unroll_config);

// Unrolls the given while loop with the default behaviour set to full unroll.
// If wrap_in_trivial_loop is set, the unrolled body of the loop will be
Expand Down
4 changes: 2 additions & 2 deletions xla/service/while_loop_unroller_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -495,8 +495,8 @@ TEST_F(WhileLoopUnrollerTest, GetUnrollableLoops) {
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(hlo_string));

auto unrollable_loops =
WhileLoopUnroller::GetUnrollableLoops(module.get(), {});
auto unrollable_loops = WhileLoopUnroller::GetUnrollableLoops(
module.get(), {}, /*unroll_config=*/std::nullopt);
// Only while1 and while2 are unrollable
EXPECT_EQ(unrollable_loops.size(), 2);
}
Expand Down

0 comments on commit 8c897a0

Please sign in to comment.