diff --git a/llvm/lib/Transforms/Utils/InlineFunction.cpp b/llvm/lib/Transforms/Utils/InlineFunction.cpp index 4ad426285ce2f0..a27cb4dd219c30 100644 --- a/llvm/lib/Transforms/Utils/InlineFunction.cpp +++ b/llvm/lib/Transforms/Utils/InlineFunction.cpp @@ -181,9 +181,21 @@ namespace { } } }; - } // end anonymous namespace +static IntrinsicInst *getConvergenceEntry(BasicBlock &BB) { + auto *I = BB.getFirstNonPHI(); + while (I) { + if (auto *IntrinsicCall = dyn_cast(I)) { + if (IntrinsicCall->isEntry()) { + return IntrinsicCall; + } + } + I = I->getNextNode(); + } + return nullptr; +} + /// Get or create a target for the branch from ResumeInsts. BasicBlock *LandingPadInliningInfo::getInnerResumeDest() { if (InnerResumeDest) return InnerResumeDest; @@ -2496,15 +2508,10 @@ llvm::InlineResult llvm::InlineFunction(CallBase &CB, InlineFunctionInfo &IFI, // fully implements convergence control tokens, there is no mixing of // controlled and uncontrolled convergent operations in the whole program. if (CB.isConvergent()) { - auto *I = CalledFunc->getEntryBlock().getFirstNonPHI(); - if (auto *IntrinsicCall = dyn_cast(I)) { - if (IntrinsicCall->getIntrinsicID() == - Intrinsic::experimental_convergence_entry) { - if (!ConvergenceControlToken) { - return InlineResult::failure( - "convergent call needs convergencectrl operand"); - } - } + if (!ConvergenceControlToken && + getConvergenceEntry(CalledFunc->getEntryBlock())) { + return InlineResult::failure( + "convergent call needs convergencectrl operand"); } } @@ -2795,13 +2802,10 @@ llvm::InlineResult llvm::InlineFunction(CallBase &CB, InlineFunctionInfo &IFI, } if (ConvergenceControlToken) { - auto *I = FirstNewBlock->getFirstNonPHI(); - if (auto *IntrinsicCall = dyn_cast(I)) { - if (IntrinsicCall->getIntrinsicID() == - Intrinsic::experimental_convergence_entry) { - IntrinsicCall->replaceAllUsesWith(ConvergenceControlToken); - IntrinsicCall->eraseFromParent(); - } + IntrinsicInst *IntrinsicCall = getConvergenceEntry(*FirstNewBlock); + if (IntrinsicCall) { + IntrinsicCall->replaceAllUsesWith(ConvergenceControlToken); + IntrinsicCall->eraseFromParent(); } } diff --git a/llvm/test/Transforms/Inline/convergence-inline.ll b/llvm/test/Transforms/Inline/convergence-inline.ll index 8c67e6a59b7db1..4996a2376be638 100644 --- a/llvm/test/Transforms/Inline/convergence-inline.ll +++ b/llvm/test/Transforms/Inline/convergence-inline.ll @@ -185,6 +185,30 @@ define void @test_two_calls() convergent { ret void } +define i32 @token_not_first(i32 %x) convergent alwaysinline { +; CHECK-LABEL: @token_not_first( +; CHECK-NEXT: {{%.*}} = alloca ptr, align 8 +; CHECK-NEXT: [[TOKEN:%.*]] = call token @llvm.experimental.convergence.entry() +; CHECK-NEXT: [[Y:%.*]] = call i32 @g(i32 [[X:%.*]]) [ "convergencectrl"(token [[TOKEN]]) ] +; CHECK-NEXT: ret i32 [[Y]] +; + %p = alloca ptr, align 8 + %token = call token @llvm.experimental.convergence.entry() + %y = call i32 @g(i32 %x) [ "convergencectrl"(token %token) ] + ret i32 %y +} + +define void @test_token_not_first() convergent { +; CHECK-LABEL: @test_token_not_first( +; CHECK-NEXT: [[TOKEN:%.*]] = call token @llvm.experimental.convergence.entry() +; CHECK-NEXT: {{%.*}} = call i32 @g(i32 23) [ "convergencectrl"(token [[TOKEN]]) ] +; CHECK-NEXT: ret void +; + %token = call token @llvm.experimental.convergence.entry() + %x = call i32 @token_not_first(i32 23) [ "convergencectrl"(token %token) ] + ret void +} + declare void @f(i32) convergent declare i32 @g(i32) convergent