diff --git a/xla/service/bfloat16_propagation.cc b/xla/service/bfloat16_propagation.cc index 3bf490a4f5e887..4ea18f044cb55b 100644 --- a/xla/service/bfloat16_propagation.cc +++ b/xla/service/bfloat16_propagation.cc @@ -352,8 +352,10 @@ bool BFloat16Propagation::ShouldKeepPrecisionUnchanged( } // Do not change precision for side-effecting instructions, control flow, and // bitcast-convert, because this pass might break the interfaces or - // assumptions for them. - return inst->opcode() == HloOpcode::kCustomCall || + // assumptions for them. It is safe to change precision for AllocateBuffer + // since it is merely a buffer allocation and does not have any side effects. + return (inst->opcode() == HloOpcode::kCustomCall && + !inst->IsCustomCall("AllocateBuffer")) || inst->opcode() == HloOpcode::kCall || inst->opcode() == HloOpcode::kBitcastConvert || inst->HasSideEffectNoRecurse(); diff --git a/xla/service/hlo_unstacker.cc b/xla/service/hlo_unstacker.cc index a3b7eec937dc86..03c6f5c2eb6356 100644 --- a/xla/service/hlo_unstacker.cc +++ b/xla/service/hlo_unstacker.cc @@ -110,7 +110,8 @@ struct UnstackerMetadata { VLOG(3) << "Prepared module: " << module->name() <<
" for unstacking."; } std::vector> loops = - WhileLoopUnroller::GetUnrollableLoops(module, {}); + WhileLoopUnroller::GetUnrollableLoops(module, {}, + /*unroll_config=*/std::nullopt); for (const auto& [instr, while_loop_config] : loops) { metadata.unrollable_loop_bodies[instr->while_body()] = while_loop_config; metadata.bodies[instr->while_body()] = instr; diff --git a/xla/service/scan_loop_accumulator_input_unification.cc b/xla/service/scan_loop_accumulator_input_unification.cc index b123bd83e5fd99..4470ae615eb7e5 100644 --- a/xla/service/scan_loop_accumulator_input_unification.cc +++ b/xla/service/scan_loop_accumulator_input_unification.cc @@ -270,7 +270,8 @@ absl::StatusOr ScanLoopAccumulatorInputUnification::Run( // accumulators and inputs that are by definition updated and read fully via // dynamic-update-slice and dynamic-sliced within a loop. std::vector> unrollable_loops = - WhileLoopUnroller::GetUnrollableLoops(module, execution_threads); + WhileLoopUnroller::GetUnrollableLoops(module, execution_threads, + /*unroll_config=*/std::nullopt); // TODO(b/337883537): We might want to simplify compare instructions before // this. It helps us identify more inputs and accumulators. diff --git a/xla/service/while_loop_unroller.cc b/xla/service/while_loop_unroller.cc index 3bb2f8d6c8300b..053c20aa8aa231 100644 --- a/xla/service/while_loop_unroller.cc +++ b/xla/service/while_loop_unroller.cc @@ -665,7 +665,7 @@ std::optional MatchShapeCoveringDynamicIndexInstruction( WhileLoopUnroller::GetUnrollableLoops( HloModule* module, const absl::flat_hash_set& execution_threads, - const UnrollConfig& unroll_config) { + std::optional unroll_config) { // Processing the while loops in the reverse topological order. If the body // of while loop A calls while loop B, B comes before A. 
std::vector all_while_ops; @@ -676,13 +676,16 @@ WhileLoopUnroller::GetUnrollableLoops( std::vector> while_loop_configs; for (HloInstruction* instr : all_while_ops) { std::optional config = IsLoopUnrollable(instr); - if (config.has_value()) { - if (!InitialFeasibilityCheck(instr, config.value(), unroll_config)) { - VLOG(3) << "Initial feasibility check failed for " << instr->name(); - continue; - } - while_loop_configs.emplace_back(instr, config.value()); + if (!config.has_value()) { + continue; + } + if (unroll_config.has_value() && + !InitialFeasibilityCheck(instr, config.value(), + unroll_config.value())) { + VLOG(3) << "Initial feasibility check failed for " << instr->name(); + continue; } + while_loop_configs.emplace_back(instr, config.value()); } return while_loop_configs; } @@ -764,8 +767,8 @@ absl::StatusOr WhileLoopUnroller::Run( // unroll. We do this ahead of time so we don't have to worry about mutating // the lists of computations or instructions while we iterate. std::vector> - unrollable_while_ops = - GetUnrollableLoops(module, execution_threads, unroll_config_); + unrollable_while_ops = GetUnrollableLoops( + module, execution_threads, /*unroll_config=*/unroll_config_); VLOG(3) << "Number of while instructions in the module to unroll: " << unrollable_while_ops.size(); diff --git a/xla/service/while_loop_unroller.h b/xla/service/while_loop_unroller.h index def7d9d9a33e1c..face63336372a1 100644 --- a/xla/service/while_loop_unroller.h +++ b/xla/service/while_loop_unroller.h @@ -49,7 +49,7 @@ struct WhileLoopConfig { int64_t induction_var_idx; }; -// Config for unrollable while loops. +// Result for unrolled while loops. struct UnrollResult { // Whether it's unrolled. bool unrolled = false; @@ -120,12 +120,14 @@ class WhileLoopUnroller : public HloModulePass { static std::optional IsLoopUnrollable( HloInstruction* while_op); - // Returns the list of unrollable loops in the given module + // Returns the list of unrollable loops in the given module. 
If + // `unroll_config` is provided, it will be used to check feasibility according + // to InitialFeasibilityCheck method static std::vector> GetUnrollableLoops( HloModule* module, const absl::flat_hash_set& execution_threads, - const UnrollConfig& unroll_config = UnrollConfig()); + std::optional unroll_config); // Unrolls the given while loop with the default behaviour set to full unroll. // If wrap_in_trivial_loop is set, the unrolled body of the loop will be diff --git a/xla/service/while_loop_unroller_test.cc b/xla/service/while_loop_unroller_test.cc index fa8fb1dbff72ff..fa44905344f66d 100644 --- a/xla/service/while_loop_unroller_test.cc +++ b/xla/service/while_loop_unroller_test.cc @@ -495,8 +495,8 @@ TEST_F(WhileLoopUnrollerTest, GetUnrollableLoops) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_string)); - auto unrollable_loops = - WhileLoopUnroller::GetUnrollableLoops(module.get(), {}); + auto unrollable_loops = WhileLoopUnroller::GetUnrollableLoops( + module.get(), {}, /*unroll_config=*/std::nullopt); // Only while1 and while2 are unrollable EXPECT_EQ(unrollable_loops.size(), 2); }