From 702b2c09b7952b0df46a975cc88aa19138d0e861 Mon Sep 17 00:00:00 2001 From: Yingwei Zheng Date: Sat, 4 Jan 2025 22:55:51 +0800 Subject: [PATCH 1/7] [IRBuilder] Refactor FMF interface --- llvm/include/llvm/IR/IRBuilder.h | 237 +++++++++--------- llvm/lib/IR/IRBuilder.cpp | 59 ++--- .../AArch64/AArch64TargetTransformInfo.cpp | 6 +- .../AggressiveInstCombine.cpp | 8 +- .../InstCombine/InstCombineAddSub.cpp | 19 +- .../InstCombine/InstCombineAndOrXor.cpp | 37 ++- .../InstCombine/InstCombineCasts.cpp | 16 +- .../InstCombine/InstCombineMulDivRem.cpp | 34 +-- .../InstCombine/InstCombineSelect.cpp | 21 +- llvm/lib/Transforms/Utils/SimplifyCFG.cpp | 15 +- .../lib/Transforms/Utils/SimplifyLibCalls.cpp | 52 ++-- .../lib/Transforms/Vectorize/VPlanRecipes.cpp | 7 +- 12 files changed, 229 insertions(+), 282 deletions(-) diff --git a/llvm/include/llvm/IR/IRBuilder.h b/llvm/include/llvm/IR/IRBuilder.h index 8cdfa27ece9378..c1f15783fb7592 100644 --- a/llvm/include/llvm/IR/IRBuilder.h +++ b/llvm/include/llvm/IR/IRBuilder.h @@ -87,6 +87,20 @@ class IRBuilderCallbackInserter : public IRBuilderDefaultInserter { } }; +/// This provides a helper for copying FMF from an instruction or setting +/// specified flags. +struct FMFSource final { + Instruction *Source; + std::optional FMF; + + FMFSource() : Source(nullptr) {} + FMFSource(Instruction *Source) : Source(Source) { + if (Source) + FMF = Source->getFastMathFlags(); + } + FMFSource(FastMathFlags FMF) : Source(nullptr), FMF(FMF) {} +}; + /// Common base class shared among various IRBuilders. class IRBuilderBase { /// Pairs of (metadata kind, MDNode *) that should be added to all newly @@ -958,29 +972,27 @@ class IRBuilderBase { /// Create a call to intrinsic \p ID with 1 operand which is mangled on its /// type. CallInst *CreateUnaryIntrinsic(Intrinsic::ID ID, Value *V, - Instruction *FMFSource = nullptr, + FMFSource FMFSource = {}, const Twine &Name = ""); /// Create a call to intrinsic \p ID with 2 operands which is mangled on the /// first type. Value *CreateBinaryIntrinsic(Intrinsic::ID ID, Value *LHS, Value *RHS, - Instruction *FMFSource = nullptr, + FMFSource FMFSource = {}, const Twine &Name = ""); /// Create a call to intrinsic \p ID with \p Args, mangled using \p Types. If /// \p FMFSource is provided, copy fast-math-flags from that instruction to /// the intrinsic. CallInst *CreateIntrinsic(Intrinsic::ID ID, ArrayRef Types, - ArrayRef Args, - Instruction *FMFSource = nullptr, + ArrayRef Args, FMFSource FMFSource = {}, const Twine &Name = ""); /// Create a call to intrinsic \p ID with \p RetTy and \p Args. If /// \p FMFSource is provided, copy fast-math-flags from that instruction to /// the intrinsic. CallInst *CreateIntrinsic(Type *RetTy, Intrinsic::ID ID, - ArrayRef Args, - Instruction *FMFSource = nullptr, + ArrayRef Args, FMFSource FMFSource = {}, const Twine &Name = ""); /// Create call to the minnum intrinsic. @@ -1026,15 +1038,14 @@ class IRBuilderBase { } /// Create call to the copysign intrinsic. - Value *CreateCopySign(Value *LHS, Value *RHS, - Instruction *FMFSource = nullptr, + Value *CreateCopySign(Value *LHS, Value *RHS, FMFSource FMFSource = {}, const Twine &Name = "") { return CreateBinaryIntrinsic(Intrinsic::copysign, LHS, RHS, FMFSource, Name); } /// Create call to the ldexp intrinsic. - Value *CreateLdexp(Value *Src, Value *Exp, Instruction *FMFSource = nullptr, + Value *CreateLdexp(Value *Src, Value *Exp, FMFSource FMFSource = {}, const Twine &Name = "") { assert(!IsFPConstrained && "TODO: Support strictfp"); return CreateIntrinsic(Intrinsic::ldexp, {Src->getType(), Exp->getType()}, @@ -1555,144 +1566,113 @@ class IRBuilderBase { Value *CreateFAdd(Value *L, Value *R, const Twine &Name = "", MDNode *FPMD = nullptr) { - if (IsFPConstrained) - return CreateConstrainedFPBinOp(Intrinsic::experimental_constrained_fadd, - L, R, nullptr, Name, FPMD); - - if (Value *V = Folder.FoldBinOpFMF(Instruction::FAdd, L, R, FMF)) - return V; - Instruction *I = setFPAttrs(BinaryOperator::CreateFAdd(L, R), FPMD, FMF); - return Insert(I, Name); + return CreateFAddFMF(L, R, {}, Name, FPMD); } - /// Copy fast-math-flags from an instruction rather than using the builder's - /// default FMF. - Value *CreateFAddFMF(Value *L, Value *R, Instruction *FMFSource, - const Twine &Name = "") { + Value *CreateFAddFMF(Value *L, Value *R, FMFSource FMFSource, + const Twine &Name = "", MDNode *FPMD = nullptr) { if (IsFPConstrained) return CreateConstrainedFPBinOp(Intrinsic::experimental_constrained_fadd, - L, R, FMFSource, Name); + L, R, FMFSource, Name, FPMD); - FastMathFlags FMF = FMFSource->getFastMathFlags(); - if (Value *V = Folder.FoldBinOpFMF(Instruction::FAdd, L, R, FMF)) + if (Value *V = Folder.FoldBinOpFMF(Instruction::FAdd, L, R, + FMFSource.FMF.value_or(FMF))) return V; - Instruction *I = setFPAttrs(BinaryOperator::CreateFAdd(L, R), nullptr, FMF); + Instruction *I = setFPAttrs(BinaryOperator::CreateFAdd(L, R), FPMD, + FMFSource.FMF.value_or(FMF)); return Insert(I, Name); } Value *CreateFSub(Value *L, Value *R, const Twine &Name = "", MDNode *FPMD = nullptr) { - if (IsFPConstrained) - return CreateConstrainedFPBinOp(Intrinsic::experimental_constrained_fsub, - L, R, nullptr, Name, FPMD); - - if (Value *V = Folder.FoldBinOpFMF(Instruction::FSub, L, R, FMF)) - return V; - Instruction *I = setFPAttrs(BinaryOperator::CreateFSub(L, R), FPMD, FMF); - return Insert(I, Name); + return CreateFSubFMF(L, R, {}, Name, FPMD); } - /// Copy fast-math-flags from an instruction rather than using the builder's - /// default FMF. - Value *CreateFSubFMF(Value *L, Value *R, Instruction *FMFSource, - const Twine &Name = "") { + Value *CreateFSubFMF(Value *L, Value *R, FMFSource FMFSource, + const Twine &Name = "", MDNode *FPMD = nullptr) { if (IsFPConstrained) return CreateConstrainedFPBinOp(Intrinsic::experimental_constrained_fsub, - L, R, FMFSource, Name); + L, R, FMFSource, Name, FPMD); - FastMathFlags FMF = FMFSource->getFastMathFlags(); - if (Value *V = Folder.FoldBinOpFMF(Instruction::FSub, L, R, FMF)) + if (Value *V = Folder.FoldBinOpFMF(Instruction::FSub, L, R, + FMFSource.FMF.value_or(FMF))) return V; - Instruction *I = setFPAttrs(BinaryOperator::CreateFSub(L, R), nullptr, FMF); + Instruction *I = setFPAttrs(BinaryOperator::CreateFSub(L, R), FPMD, + FMFSource.FMF.value_or(FMF)); return Insert(I, Name); } Value *CreateFMul(Value *L, Value *R, const Twine &Name = "", MDNode *FPMD = nullptr) { - if (IsFPConstrained) - return CreateConstrainedFPBinOp(Intrinsic::experimental_constrained_fmul, - L, R, nullptr, Name, FPMD); - - if (Value *V = Folder.FoldBinOpFMF(Instruction::FMul, L, R, FMF)) - return V; - Instruction *I = setFPAttrs(BinaryOperator::CreateFMul(L, R), FPMD, FMF); - return Insert(I, Name); + return CreateFMulFMF(L, R, {}, Name, FPMD); } - /// Copy fast-math-flags from an instruction rather than using the builder's - /// default FMF. - Value *CreateFMulFMF(Value *L, Value *R, Instruction *FMFSource, - const Twine &Name = "") { + Value *CreateFMulFMF(Value *L, Value *R, FMFSource FMFSource, + const Twine &Name = "", MDNode *FPMD = nullptr) { if (IsFPConstrained) return CreateConstrainedFPBinOp(Intrinsic::experimental_constrained_fmul, - L, R, FMFSource, Name); + L, R, FMFSource, Name, FPMD); - FastMathFlags FMF = FMFSource->getFastMathFlags(); - if (Value *V = Folder.FoldBinOpFMF(Instruction::FMul, L, R, FMF)) + if (Value *V = Folder.FoldBinOpFMF(Instruction::FMul, L, R, + FMFSource.FMF.value_or(FMF))) return V; - Instruction *I = setFPAttrs(BinaryOperator::CreateFMul(L, R), nullptr, FMF); + Instruction *I = setFPAttrs(BinaryOperator::CreateFMul(L, R), FPMD, + FMFSource.FMF.value_or(FMF)); return Insert(I, Name); } Value *CreateFDiv(Value *L, Value *R, const Twine &Name = "", MDNode *FPMD = nullptr) { - if (IsFPConstrained) - return CreateConstrainedFPBinOp(Intrinsic::experimental_constrained_fdiv, - L, R, nullptr, Name, FPMD); - - if (Value *V = Folder.FoldBinOpFMF(Instruction::FDiv, L, R, FMF)) - return V; - Instruction *I = setFPAttrs(BinaryOperator::CreateFDiv(L, R), FPMD, FMF); - return Insert(I, Name); + return CreateFDivFMF(L, R, {}, Name, FPMD); } - /// Copy fast-math-flags from an instruction rather than using the builder's - /// default FMF. - Value *CreateFDivFMF(Value *L, Value *R, Instruction *FMFSource, - const Twine &Name = "") { + Value *CreateFDivFMF(Value *L, Value *R, FMFSource FMFSource, + const Twine &Name = "", MDNode *FPMD = nullptr) { if (IsFPConstrained) return CreateConstrainedFPBinOp(Intrinsic::experimental_constrained_fdiv, - L, R, FMFSource, Name); + L, R, FMFSource, Name, FPMD); - FastMathFlags FMF = FMFSource->getFastMathFlags(); - if (Value *V = Folder.FoldBinOpFMF(Instruction::FDiv, L, R, FMF)) + if (Value *V = Folder.FoldBinOpFMF(Instruction::FDiv, L, R, + FMFSource.FMF.value_or(FMF))) return V; - Instruction *I = setFPAttrs(BinaryOperator::CreateFDiv(L, R), nullptr, FMF); + Instruction *I = setFPAttrs(BinaryOperator::CreateFDiv(L, R), FPMD, + FMFSource.FMF.value_or(FMF)); return Insert(I, Name); } Value *CreateFRem(Value *L, Value *R, const Twine &Name = "", MDNode *FPMD = nullptr) { - if (IsFPConstrained) - return CreateConstrainedFPBinOp(Intrinsic::experimental_constrained_frem, - L, R, nullptr, Name, FPMD); - - if (Value *V = Folder.FoldBinOpFMF(Instruction::FRem, L, R, FMF)) return V; - Instruction *I = setFPAttrs(BinaryOperator::CreateFRem(L, R), FPMD, FMF); - return Insert(I, Name); + return CreateFRemFMF(L, R, {}, Name, FPMD); } - /// Copy fast-math-flags from an instruction rather than using the builder's - /// default FMF. - Value *CreateFRemFMF(Value *L, Value *R, Instruction *FMFSource, - const Twine &Name = "") { + Value *CreateFRemFMF(Value *L, Value *R, FMFSource FMFSource, + const Twine &Name = "", MDNode *FPMD = nullptr) { if (IsFPConstrained) return CreateConstrainedFPBinOp(Intrinsic::experimental_constrained_frem, - L, R, FMFSource, Name); + L, R, FMFSource, Name, FPMD); - FastMathFlags FMF = FMFSource->getFastMathFlags(); - if (Value *V = Folder.FoldBinOpFMF(Instruction::FRem, L, R, FMF)) return V; - Instruction *I = setFPAttrs(BinaryOperator::CreateFRem(L, R), nullptr, FMF); + if (Value *V = Folder.FoldBinOpFMF(Instruction::FRem, L, R, + FMFSource.FMF.value_or(FMF))) + return V; + Instruction *I = setFPAttrs(BinaryOperator::CreateFRem(L, R), FPMD, + FMFSource.FMF.value_or(FMF)); return Insert(I, Name); } Value *CreateBinOp(Instruction::BinaryOps Opc, Value *LHS, Value *RHS, const Twine &Name = "", MDNode *FPMathTag = nullptr) { - if (Value *V = Folder.FoldBinOp(Opc, LHS, RHS)) return V; + return CreateBinOpFMF(Opc, LHS, RHS, {}, Name, FPMathTag); + } + + Value *CreateBinOpFMF(Instruction::BinaryOps Opc, Value *LHS, Value *RHS, + FMFSource FMFSource, const Twine &Name = "", + MDNode *FPMathTag = nullptr) { + if (Value *V = Folder.FoldBinOp(Opc, LHS, RHS)) + return V; Instruction *BinOp = BinaryOperator::Create(Opc, LHS, RHS); if (isa(BinOp)) - setFPAttrs(BinOp, FPMathTag, FMF); + setFPAttrs(BinOp, FPMathTag, FMFSource.FMF.value_or(FMF)); return Insert(BinOp, Name); } @@ -1731,13 +1711,13 @@ class IRBuilderBase { } CallInst *CreateConstrainedFPBinOp( - Intrinsic::ID ID, Value *L, Value *R, Instruction *FMFSource = nullptr, + Intrinsic::ID ID, Value *L, Value *R, FMFSource FMFSource = {}, const Twine &Name = "", MDNode *FPMathTag = nullptr, std::optional Rounding = std::nullopt, std::optional Except = std::nullopt); CallInst *CreateConstrainedFPUnroundedBinOp( - Intrinsic::ID ID, Value *L, Value *R, Instruction *FMFSource = nullptr, + Intrinsic::ID ID, Value *L, Value *R, FMFSource FMFSource = {}, const Twine &Name = "", MDNode *FPMathTag = nullptr, std::optional Except = std::nullopt); @@ -1752,21 +1732,17 @@ class IRBuilderBase { Value *CreateFNeg(Value *V, const Twine &Name = "", MDNode *FPMathTag = nullptr) { - if (Value *Res = Folder.FoldUnOpFMF(Instruction::FNeg, V, FMF)) - return Res; - return Insert(setFPAttrs(UnaryOperator::CreateFNeg(V), FPMathTag, FMF), - Name); + return CreateFNegFMF(V, {}, Name, FPMathTag); } - /// Copy fast-math-flags from an instruction rather than using the builder's - /// default FMF. - Value *CreateFNegFMF(Value *V, Instruction *FMFSource, - const Twine &Name = "") { - FastMathFlags FMF = FMFSource->getFastMathFlags(); - if (Value *Res = Folder.FoldUnOpFMF(Instruction::FNeg, V, FMF)) + Value *CreateFNegFMF(Value *V, FMFSource FMFSource, const Twine &Name = "", + MDNode *FPMathTag = nullptr) { + if (Value *Res = Folder.FoldUnOpFMF(Instruction::FNeg, V, + FMFSource.FMF.value_or(FMF))) return Res; - return Insert(setFPAttrs(UnaryOperator::CreateFNeg(V), nullptr, FMF), - Name); + return Insert(setFPAttrs(UnaryOperator::CreateFNeg(V), FPMathTag, + FMFSource.FMF.value_or(FMF)), + Name); } Value *CreateNot(Value *V, const Twine &Name = "") { @@ -2127,19 +2103,31 @@ class IRBuilderBase { Value *CreateFPTrunc(Value *V, Type *DestTy, const Twine &Name = "", MDNode *FPMathTag = nullptr) { + return CreateFPTruncFMF(V, DestTy, {}, Name, FPMathTag); + } + + Value *CreateFPTruncFMF(Value *V, Type *DestTy, FMFSource FMFSource, + const Twine &Name = "", MDNode *FPMathTag = nullptr) { if (IsFPConstrained) return CreateConstrainedFPCast( - Intrinsic::experimental_constrained_fptrunc, V, DestTy, nullptr, Name, - FPMathTag); - return CreateCast(Instruction::FPTrunc, V, DestTy, Name, FPMathTag); + Intrinsic::experimental_constrained_fptrunc, V, DestTy, FMFSource, + Name, FPMathTag); + return CreateCast(Instruction::FPTrunc, V, DestTy, Name, FPMathTag, + FMFSource); } Value *CreateFPExt(Value *V, Type *DestTy, const Twine &Name = "", MDNode *FPMathTag = nullptr) { + return CreateFPExtFMF(V, DestTy, {}, Name, FPMathTag); + } + + Value *CreateFPExtFMF(Value *V, Type *DestTy, FMFSource FMFSource, + const Twine &Name = "", MDNode *FPMathTag = nullptr) { if (IsFPConstrained) return CreateConstrainedFPCast(Intrinsic::experimental_constrained_fpext, - V, DestTy, nullptr, Name, FPMathTag); - return CreateCast(Instruction::FPExt, V, DestTy, Name, FPMathTag); + V, DestTy, FMFSource, Name, FPMathTag); + return CreateCast(Instruction::FPExt, V, DestTy, Name, FPMathTag, + FMFSource); } Value *CreatePtrToInt(Value *V, Type *DestTy, @@ -2187,14 +2175,15 @@ class IRBuilderBase { } Value *CreateCast(Instruction::CastOps Op, Value *V, Type *DestTy, - const Twine &Name = "", MDNode *FPMathTag = nullptr) { + const Twine &Name = "", MDNode *FPMathTag = nullptr, + FMFSource FMFSource = {}) { if (V->getType() == DestTy) return V; if (Value *Folded = Folder.FoldCast(Op, V, DestTy)) return Folded; Instruction *Cast = CastInst::Create(Op, V, DestTy); if (isa(Cast)) - setFPAttrs(Cast, FPMathTag, FMF); + setFPAttrs(Cast, FPMathTag, FMFSource.FMF.value_or(FMF)); return Insert(Cast, Name); } @@ -2255,9 +2244,8 @@ class IRBuilderBase { } CallInst *CreateConstrainedFPCast( - Intrinsic::ID ID, Value *V, Type *DestTy, - Instruction *FMFSource = nullptr, const Twine &Name = "", - MDNode *FPMathTag = nullptr, + Intrinsic::ID ID, Value *V, Type *DestTy, FMFSource FMFSource = {}, + const Twine &Name = "", MDNode *FPMathTag = nullptr, std::optional Rounding = std::nullopt, std::optional Except = std::nullopt); @@ -2392,7 +2380,16 @@ class IRBuilderBase { // Note that this differs from CreateFCmpS only if IsFPConstrained is true. Value *CreateFCmp(CmpInst::Predicate P, Value *LHS, Value *RHS, const Twine &Name = "", MDNode *FPMathTag = nullptr) { - return CreateFCmpHelper(P, LHS, RHS, Name, FPMathTag, false); + return CreateFCmpHelper(P, LHS, RHS, Name, FPMathTag, {}, false); + } + + // Create a quiet floating-point comparison (i.e. one that raises an FP + // exception only in the case where an input is a signaling NaN). + // Note that this differs from CreateFCmpS only if IsFPConstrained is true. + Value *CreateFCmpFMF(CmpInst::Predicate P, Value *LHS, Value *RHS, + FMFSource FMFSource, const Twine &Name = "", + MDNode *FPMathTag = nullptr) { + return CreateFCmpHelper(P, LHS, RHS, Name, FPMathTag, FMFSource, false); } Value *CreateCmp(CmpInst::Predicate Pred, Value *LHS, Value *RHS, @@ -2407,14 +2404,14 @@ class IRBuilderBase { // Note that this differs from CreateFCmp only if IsFPConstrained is true. Value *CreateFCmpS(CmpInst::Predicate P, Value *LHS, Value *RHS, const Twine &Name = "", MDNode *FPMathTag = nullptr) { - return CreateFCmpHelper(P, LHS, RHS, Name, FPMathTag, true); + return CreateFCmpHelper(P, LHS, RHS, Name, FPMathTag, {}, true); } private: // Helper routine to create either a signaling or a quiet FP comparison. Value *CreateFCmpHelper(CmpInst::Predicate P, Value *LHS, Value *RHS, const Twine &Name, MDNode *FPMathTag, - bool IsSignaling); + FMFSource FMFSource, bool IsSignaling); public: CallInst *CreateConstrainedFPCmp( @@ -2436,8 +2433,7 @@ class IRBuilderBase { private: CallInst *createCallHelper(Function *Callee, ArrayRef Ops, - const Twine &Name = "", - Instruction *FMFSource = nullptr, + const Twine &Name = "", FMFSource FMFSource = {}, ArrayRef OpBundles = {}); public: @@ -2483,6 +2479,9 @@ class IRBuilderBase { Value *CreateSelect(Value *C, Value *True, Value *False, const Twine &Name = "", Instruction *MDFrom = nullptr); + Value *CreateSelectFMF(Value *C, Value *True, Value *False, + FMFSource FMFSource, const Twine &Name = "", + Instruction *MDFrom = nullptr); VAArgInst *CreateVAArg(Value *List, Type *Ty, const Twine &Name = "") { return Insert(new VAArgInst(List, Ty), Name); diff --git a/llvm/lib/IR/IRBuilder.cpp b/llvm/lib/IR/IRBuilder.cpp index f340f7aafdc76f..9bb4acfb96c750 100644 --- a/llvm/lib/IR/IRBuilder.cpp +++ b/llvm/lib/IR/IRBuilder.cpp @@ -78,11 +78,11 @@ void IRBuilderBase::SetInstDebugLocation(Instruction *I) const { CallInst * IRBuilderBase::createCallHelper(Function *Callee, ArrayRef Ops, - const Twine &Name, Instruction *FMFSource, + const Twine &Name, FMFSource FMFSource, ArrayRef OpBundles) { CallInst *CI = CreateCall(Callee, Ops, OpBundles, Name); - if (FMFSource) - CI->copyFastMathFlags(FMFSource); + if (FMFSource.FMF.has_value()) + CI->setFastMathFlags(*FMFSource.FMF); return CI; } @@ -869,7 +869,7 @@ CallInst *IRBuilderBase::CreateGCGetPointerOffset(Value *DerivedPtr, } CallInst *IRBuilderBase::CreateUnaryIntrinsic(Intrinsic::ID ID, Value *V, - Instruction *FMFSource, + FMFSource FMFSource, const Twine &Name) { Module *M = BB->getModule(); Function *Fn = Intrinsic::getOrInsertDeclaration(M, ID, {V->getType()}); @@ -877,12 +877,12 @@ CallInst *IRBuilderBase::CreateUnaryIntrinsic(Intrinsic::ID ID, Value *V, } Value *IRBuilderBase::CreateBinaryIntrinsic(Intrinsic::ID ID, Value *LHS, - Value *RHS, Instruction *FMFSource, + Value *RHS, FMFSource FMFSource, const Twine &Name) { Module *M = BB->getModule(); Function *Fn = Intrinsic::getOrInsertDeclaration(M, ID, {LHS->getType()}); if (Value *V = Folder.FoldBinaryIntrinsic(ID, LHS, RHS, Fn->getReturnType(), - FMFSource)) + FMFSource.Source)) return V; return createCallHelper(Fn, {LHS, RHS}, Name, FMFSource); } @@ -890,7 +890,7 @@ Value *IRBuilderBase::CreateBinaryIntrinsic(Intrinsic::ID ID, Value *LHS, CallInst *IRBuilderBase::CreateIntrinsic(Intrinsic::ID ID, ArrayRef Types, ArrayRef Args, - Instruction *FMFSource, + FMFSource FMFSource, const Twine &Name) { Module *M = BB->getModule(); Function *Fn = Intrinsic::getOrInsertDeclaration(M, ID, Types); @@ -899,7 +899,7 @@ CallInst *IRBuilderBase::CreateIntrinsic(Intrinsic::ID ID, CallInst *IRBuilderBase::CreateIntrinsic(Type *RetTy, Intrinsic::ID ID, ArrayRef Args, - Instruction *FMFSource, + FMFSource FMFSource, const Twine &Name) { Module *M = BB->getModule(); @@ -925,16 +925,13 @@ CallInst *IRBuilderBase::CreateIntrinsic(Type *RetTy, Intrinsic::ID ID, } CallInst *IRBuilderBase::CreateConstrainedFPBinOp( - Intrinsic::ID ID, Value *L, Value *R, Instruction *FMFSource, - const Twine &Name, MDNode *FPMathTag, - std::optional Rounding, + Intrinsic::ID ID, Value *L, Value *R, FMFSource FMFSource, + const Twine &Name, MDNode *FPMathTag, std::optional Rounding, std::optional Except) { Value *RoundingV = getConstrainedFPRounding(Rounding); Value *ExceptV = getConstrainedFPExcept(Except); - FastMathFlags UseFMF = FMF; - if (FMFSource) - UseFMF = FMFSource->getFastMathFlags(); + FastMathFlags UseFMF = FMFSource.FMF.value_or(FMF); CallInst *C = CreateIntrinsic(ID, {L->getType()}, {L, R, RoundingV, ExceptV}, nullptr, Name); @@ -944,14 +941,12 @@ CallInst *IRBuilderBase::CreateConstrainedFPBinOp( } CallInst *IRBuilderBase::CreateConstrainedFPUnroundedBinOp( - Intrinsic::ID ID, Value *L, Value *R, Instruction *FMFSource, + Intrinsic::ID ID, Value *L, Value *R, FMFSource FMFSource, const Twine &Name, MDNode *FPMathTag, std::optional Except) { Value *ExceptV = getConstrainedFPExcept(Except); - FastMathFlags UseFMF = FMF; - if (FMFSource) - UseFMF = FMFSource->getFastMathFlags(); + FastMathFlags UseFMF = FMFSource.FMF.value_or(FMF); CallInst *C = CreateIntrinsic(ID, {L->getType()}, {L, R, ExceptV}, nullptr, Name); @@ -976,15 +971,12 @@ Value *IRBuilderBase::CreateNAryOp(unsigned Opc, ArrayRef Ops, } CallInst *IRBuilderBase::CreateConstrainedFPCast( - Intrinsic::ID ID, Value *V, Type *DestTy, - Instruction *FMFSource, const Twine &Name, MDNode *FPMathTag, - std::optional Rounding, + Intrinsic::ID ID, Value *V, Type *DestTy, FMFSource FMFSource, + const Twine &Name, MDNode *FPMathTag, std::optional Rounding, std::optional Except) { Value *ExceptV = getConstrainedFPExcept(Except); - FastMathFlags UseFMF = FMF; - if (FMFSource) - UseFMF = FMFSource->getFastMathFlags(); + FastMathFlags UseFMF = FMFSource.FMF.value_or(FMF); CallInst *C; if (Intrinsic::hasConstrainedFPRoundingModeOperand(ID)) { @@ -1002,9 +994,10 @@ CallInst *IRBuilderBase::CreateConstrainedFPCast( return C; } -Value *IRBuilderBase::CreateFCmpHelper( - CmpInst::Predicate P, Value *LHS, Value *RHS, const Twine &Name, - MDNode *FPMathTag, bool IsSignaling) { +Value *IRBuilderBase::CreateFCmpHelper(CmpInst::Predicate P, Value *LHS, + Value *RHS, const Twine &Name, + MDNode *FPMathTag, FMFSource FMFSource, + bool IsSignaling) { if (IsFPConstrained) { auto ID = IsSignaling ? Intrinsic::experimental_constrained_fcmps : Intrinsic::experimental_constrained_fcmp; @@ -1013,7 +1006,9 @@ Value *IRBuilderBase::CreateFCmpHelper( if (auto *V = Folder.FoldCmp(P, LHS, RHS)) return V; - return Insert(setFPAttrs(new FCmpInst(P, LHS, RHS), FPMathTag, FMF), Name); + return Insert(setFPAttrs(new FCmpInst(P, LHS, RHS), FPMathTag, + FMFSource.FMF.value_or(FMF)), + Name); } CallInst *IRBuilderBase::CreateConstrainedFPCmp( @@ -1047,6 +1042,12 @@ CallInst *IRBuilderBase::CreateConstrainedFPCall( Value *IRBuilderBase::CreateSelect(Value *C, Value *True, Value *False, const Twine &Name, Instruction *MDFrom) { + return CreateSelectFMF(C, True, False, {}, Name, MDFrom); +} + +Value *IRBuilderBase::CreateSelectFMF(Value *C, Value *True, Value *False, + FMFSource FMFSource, const Twine &Name, + Instruction *MDFrom) { if (auto *V = Folder.FoldSelect(C, True, False)) return V; @@ -1057,7 +1058,7 @@ Value *IRBuilderBase::CreateSelect(Value *C, Value *True, Value *False, Sel = addBranchMetadata(Sel, Prof, Unpred); } if (isa(Sel)) - setFPAttrs(Sel, nullptr /* MDNode* */, FMF); + setFPAttrs(Sel, /*MDNode=*/nullptr, FMFSource.FMF.value_or(FMF)); return Insert(Sel, Name); } diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp index 515764c915bf4a..7c5e5336b65313 100644 --- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp +++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp @@ -1638,10 +1638,8 @@ instCombineSVEVectorBinOp(InstCombiner &IC, IntrinsicInst &II) { !match(OpPredicate, m_Intrinsic( m_ConstantInt()))) return std::nullopt; - IRBuilderBase::FastMathFlagGuard FMFGuard(IC.Builder); - IC.Builder.setFastMathFlags(II.getFastMathFlags()); - auto BinOp = - IC.Builder.CreateBinOp(BinOpCode, II.getOperand(1), II.getOperand(2)); + auto BinOp = IC.Builder.CreateBinOpFMF( + BinOpCode, II.getOperand(1), II.getOperand(2), II.getFastMathFlags()); return IC.replaceInstUsesWith(II, BinOp); } diff --git a/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp b/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp index 12ae6740e055ef..d45aee37801736 100644 --- a/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp +++ b/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp @@ -425,11 +425,9 @@ static bool foldSqrt(CallInst *Call, LibFunc Func, TargetTransformInfo &TTI, Arg, 0, SimplifyQuery(Call->getDataLayout(), &TLI, &DT, &AC, Call)))) { IRBuilder<> Builder(Call); - IRBuilderBase::FastMathFlagGuard Guard(Builder); - Builder.setFastMathFlags(Call->getFastMathFlags()); - - Value *NewSqrt = Builder.CreateIntrinsic(Intrinsic::sqrt, Ty, Arg, - /*FMFSource=*/nullptr, "sqrt"); + Value *NewSqrt = + Builder.CreateIntrinsic(Intrinsic::sqrt, Ty, Arg, + /*FMFSource=*/Call->getFastMathFlags(), "sqrt"); Call->replaceAllUsesWith(NewSqrt); // Explicitly erase the old call because a call with side effects is not diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp index 7a184a19d7c54a..9dc593bdf3058f 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp @@ -2845,12 +2845,11 @@ Instruction *InstCombinerImpl::hoistFNegAboveFMulFDiv(Value *FNegOp, // Make sure to preserve flags and metadata on the call. if (II->getIntrinsicID() == Intrinsic::ldexp) { FastMathFlags FMF = FMFSource.getFastMathFlags() | II->getFastMathFlags(); - IRBuilder<>::FastMathFlagGuard FMFGuard(Builder); - Builder.setFastMathFlags(FMF); - - CallInst *New = Builder.CreateCall( - II->getCalledFunction(), - {Builder.CreateFNeg(II->getArgOperand(0)), II->getArgOperand(1)}); + CallInst *New = + Builder.CreateCall(II->getCalledFunction(), + {Builder.CreateFNegFMF(II->getArgOperand(0), FMF), + II->getArgOperand(1)}); + New->setFastMathFlags(FMF); New->copyMetadata(*II); return New; } @@ -2932,12 +2931,8 @@ Instruction *InstCombinerImpl::visitFNeg(UnaryOperator &I) { // flags the copysign doesn't also have. FastMathFlags FMF = I.getFastMathFlags(); FMF &= cast(OneUse)->getFastMathFlags(); - - IRBuilder<>::FastMathFlagGuard FMFGuard(Builder); - Builder.setFastMathFlags(FMF); - - Value *NegY = Builder.CreateFNeg(Y); - Value *NewCopySign = Builder.CreateCopySign(X, NegY); + Value *NegY = Builder.CreateFNegFMF(Y, FMF); + Value *NewCopySign = Builder.CreateCopySign(X, NegY, FMF); return replaceInstUsesWith(I, NewCopySign); } diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp index e576eea4ca36a1..37a7c4d88b234d 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp @@ -39,11 +39,12 @@ static Value *getNewICmpValue(unsigned Code, bool Sign, Value *LHS, Value *RHS, /// This is the complement of getFCmpCode, which turns an opcode and two /// operands into either a FCmp instruction, or a true/false constant. static Value *getFCmpValue(unsigned Code, Value *LHS, Value *RHS, - InstCombiner::BuilderTy &Builder) { + InstCombiner::BuilderTy &Builder, + FastMathFlags FMF) { FCmpInst::Predicate NewPred; if (Constant *TorF = getPredForFCmpCode(Code, LHS->getType(), NewPred)) return TorF; - return Builder.CreateFCmp(NewPred, LHS, RHS); + return Builder.CreateFCmpFMF(NewPred, LHS, RHS, FMF); } /// Emit a computation of: (V >= Lo && V < Hi) if Inside is true, otherwise @@ -1429,12 +1430,9 @@ static Value *matchIsFiniteTest(InstCombiner::BuilderTy &Builder, FCmpInst *LHS, !matchUnorderedInfCompare(PredR, RHS0, RHS1)) return nullptr; - IRBuilder<>::FastMathFlagGuard FMFG(Builder); - FastMathFlags FMF = LHS->getFastMathFlags(); - FMF &= RHS->getFastMathFlags(); - Builder.setFastMathFlags(FMF); - - return Builder.CreateFCmp(FCmpInst::getOrderedPredicate(PredR), RHS0, RHS1); + return Builder.CreateFCmpFMF(FCmpInst::getOrderedPredicate(PredR), RHS0, RHS1, + LHS->getFastMathFlags() & + RHS->getFastMathFlags()); } Value *InstCombinerImpl::foldLogicOfFCmps(FCmpInst *LHS, FCmpInst *RHS, @@ -1470,12 +1468,8 @@ Value *InstCombinerImpl::foldLogicOfFCmps(FCmpInst *LHS, FCmpInst *RHS, // Intersect the fast math flags. // TODO: We can union the fast math flags unless this is a logical select. - IRBuilder<>::FastMathFlagGuard FMFG(Builder); - FastMathFlags FMF = LHS->getFastMathFlags(); - FMF &= RHS->getFastMathFlags(); - Builder.setFastMathFlags(FMF); - - return getFCmpValue(NewPred, LHS0, LHS1, Builder); + return getFCmpValue(NewPred, LHS0, LHS1, Builder, + LHS->getFastMathFlags() & RHS->getFastMathFlags()); } // This transform is not valid for a logical select. @@ -1492,10 +1486,8 @@ Value *InstCombinerImpl::foldLogicOfFCmps(FCmpInst *LHS, FCmpInst *RHS, // Ignore the constants because they are obviously not NANs: // (fcmp ord x, 0.0) & (fcmp ord y, 0.0) -> (fcmp ord x, y) // (fcmp uno x, 0.0) | (fcmp uno y, 0.0) -> (fcmp uno x, y) - IRBuilder<>::FastMathFlagGuard FMFG(Builder); - Builder.setFastMathFlags(LHS->getFastMathFlags() & - RHS->getFastMathFlags()); - return Builder.CreateFCmp(PredL, LHS0, RHS0); + return Builder.CreateFCmpFMF( + PredL, LHS0, RHS0, LHS->getFastMathFlags() & RHS->getFastMathFlags()); } } @@ -1557,15 +1549,14 @@ Value *InstCombinerImpl::foldLogicOfFCmps(FCmpInst *LHS, FCmpInst *RHS, std::swap(PredL, PredR); } if (IsLessThanOrLessEqual(IsAnd ? PredL : PredR)) { - BuilderTy::FastMathFlagGuard Guard(Builder); FastMathFlags NewFlag = LHS->getFastMathFlags(); if (!IsLogicalSelect) NewFlag |= RHS->getFastMathFlags(); - Builder.setFastMathFlags(NewFlag); - Value *FAbs = Builder.CreateUnaryIntrinsic(Intrinsic::fabs, LHS0); - return Builder.CreateFCmp(PredL, FAbs, - ConstantFP::get(LHS0->getType(), *LHSC)); + Value *FAbs = + Builder.CreateUnaryIntrinsic(Intrinsic::fabs, LHS0, NewFlag); + return Builder.CreateFCmpFMF( + PredL, FAbs, ConstantFP::get(LHS0->getType(), *LHSC), NewFlag); } } diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp index 0b9379965f4249..4ec1af394464bb 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp @@ -1852,15 +1852,13 @@ Instruction *InstCombinerImpl::visitFPTrunc(FPTruncInst &FPT) { Value *X; Instruction *Op = dyn_cast(FPT.getOperand(0)); if (Op && Op->hasOneUse()) { - IRBuilder<>::FastMathFlagGuard FMFG(Builder); FastMathFlags FMF = FPT.getFastMathFlags(); if (auto *FPMO = dyn_cast(Op)) FMF &= FPMO->getFastMathFlags(); - Builder.setFastMathFlags(FMF); if (match(Op, m_FNeg(m_Value(X)))) { - Value *InnerTrunc = Builder.CreateFPTrunc(X, Ty); - Value *Neg = Builder.CreateFNeg(InnerTrunc); + Value *InnerTrunc = Builder.CreateFPTruncFMF(X, Ty, FMF); + Value *Neg = Builder.CreateFNegFMF(InnerTrunc, FMF); return replaceInstUsesWith(FPT, Neg); } @@ -1870,15 +1868,17 @@ Instruction *InstCombinerImpl::visitFPTrunc(FPTruncInst &FPT) { if (match(Op, m_Select(m_Value(Cond), m_FPExt(m_Value(X)), m_Value(Y))) && X->getType() == Ty) { // fptrunc (select Cond, (fpext X), Y --> select Cond, X, (fptrunc Y) - Value *NarrowY = Builder.CreateFPTrunc(Y, Ty); - Value *Sel = Builder.CreateSelect(Cond, X, NarrowY, "narrow.sel", Op); + Value *NarrowY = Builder.CreateFPTruncFMF(Y, Ty, FMF); + Value *Sel = + Builder.CreateSelectFMF(Cond, X, NarrowY, FMF, "narrow.sel", Op); return replaceInstUsesWith(FPT, Sel); } if (match(Op, m_Select(m_Value(Cond), m_Value(Y), m_FPExt(m_Value(X)))) && X->getType() == Ty) { // fptrunc (select Cond, Y, (fpext X) --> select Cond, (fptrunc Y), X - Value *NarrowY = Builder.CreateFPTrunc(Y, Ty); - Value *Sel = Builder.CreateSelect(Cond, NarrowY, X, "narrow.sel", Op); + Value *NarrowY = Builder.CreateFPTruncFMF(Y, Ty, FMF); + Value *Sel = + Builder.CreateSelectFMF(Cond, NarrowY, X, FMF, "narrow.sel", Op); return replaceInstUsesWith(FPT, Sel); } } diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp index f85a3c93651353..e376376c3ce282 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp @@ -121,21 +121,17 @@ static Value *foldMulSelectToNegate(BinaryOperator &I, // fmul OtherOp, (select Cond, 1.0, -1.0) --> select Cond, OtherOp, -OtherOp if (match(&I, m_c_FMul(m_OneUse(m_Select(m_Value(Cond), m_SpecificFP(1.0), m_SpecificFP(-1.0))), - m_Value(OtherOp)))) { - IRBuilder<>::FastMathFlagGuard FMFGuard(Builder); - Builder.setFastMathFlags(I.getFastMathFlags()); - return Builder.CreateSelect(Cond, OtherOp, Builder.CreateFNeg(OtherOp)); - } + m_Value(OtherOp)))) + return Builder.CreateSelectFMF(Cond, OtherOp, + Builder.CreateFNegFMF(OtherOp, &I), &I); // fmul (select Cond, -1.0, 1.0), OtherOp --> select Cond, -OtherOp, OtherOp // fmul OtherOp, (select Cond, -1.0, 1.0) --> select Cond, -OtherOp, OtherOp if (match(&I, m_c_FMul(m_OneUse(m_Select(m_Value(Cond), m_SpecificFP(-1.0), m_SpecificFP(1.0))), - m_Value(OtherOp)))) { - IRBuilder<>::FastMathFlagGuard FMFGuard(Builder); - Builder.setFastMathFlags(I.getFastMathFlags()); - return Builder.CreateSelect(Cond, Builder.CreateFNeg(OtherOp), OtherOp); - } + m_Value(OtherOp)))) + return Builder.CreateSelectFMF(Cond, Builder.CreateFNegFMF(OtherOp, &I), + OtherOp, &I); return nullptr; } @@ -590,11 +586,9 @@ Instruction *InstCombinerImpl::foldFPSignBitOps(BinaryOperator &I) { // fabs(X) / fabs(Y) --> fabs(X / Y) if (match(Op0, m_FAbs(m_Value(X))) && match(Op1, m_FAbs(m_Value(Y))) && (Op0->hasOneUse() || Op1->hasOneUse())) { - IRBuilder<>::FastMathFlagGuard FMFGuard(Builder); - Builder.setFastMathFlags(I.getFastMathFlags()); - Value *XY = Builder.CreateBinOp(Opcode, X, Y); - Value *Fabs = Builder.CreateUnaryIntrinsic(Intrinsic::fabs, XY); - Fabs->takeName(&I); + Value *XY = Builder.CreateBinOpFMF(Opcode, X, Y, &I); + Value *Fabs = + Builder.CreateUnaryIntrinsic(Intrinsic::fabs, XY, &I, I.getName()); return replaceInstUsesWith(I, Fabs); } @@ -685,8 +679,6 @@ Instruction *InstCombinerImpl::foldFMulReassoc(BinaryOperator &I) { match(Op0, m_AllowReassoc(m_BinOp(Op0BinOp)))) { // Everything in this scope folds I with Op0, intersecting their FMF. FastMathFlags FMF = I.getFastMathFlags() & Op0BinOp->getFastMathFlags(); - IRBuilder<>::FastMathFlagGuard FMFGuard(Builder); - Builder.setFastMathFlags(FMF); Constant *C1; if (match(Op0, m_OneUse(m_FDiv(m_Constant(C1), m_Value(X))))) { // (C1 / X) * C --> (C * C1) / X @@ -718,7 +710,7 @@ Instruction *InstCombinerImpl::foldFMulReassoc(BinaryOperator &I) { // (X + C1) * C --> (X * C) + (C * C1) if (Constant *CC1 = ConstantFoldBinaryOpOperands(Instruction::FMul, C, C1, DL)) { - Value *XC = Builder.CreateFMul(X, C); + Value *XC = Builder.CreateFMulFMF(X, C, FMF); return BinaryOperator::CreateFAddFMF(XC, CC1, FMF); } } @@ -726,7 +718,7 @@ Instruction *InstCombinerImpl::foldFMulReassoc(BinaryOperator &I) { // (C1 - X) * C --> (C * C1) - (X * C) if (Constant *CC1 = ConstantFoldBinaryOpOperands(Instruction::FMul, C, C1, DL)) { - Value *XC = Builder.CreateFMul(X, C); + Value *XC = Builder.CreateFMulFMF(X, C, FMF); return BinaryOperator::CreateFSubFMF(CC1, XC, FMF); } } @@ -740,9 +732,7 @@ Instruction *InstCombinerImpl::foldFMulReassoc(BinaryOperator &I) { FastMathFlags FMF = I.getFastMathFlags() & DivOp->getFastMathFlags(); if (FMF.allowReassoc()) { // Sink division: (X / Y) * Z --> (X * Z) / Y - IRBuilder<>::FastMathFlagGuard FMFGuard(Builder); - Builder.setFastMathFlags(FMF); - auto *NewFMul = Builder.CreateFMul(X, Z); + auto *NewFMul = Builder.CreateFMulFMF(X, Z, FMF); return BinaryOperator::CreateFDivFMF(NewFMul, Y, FMF); } } diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp index e7a8e947705f8d..042542ea3c0f11 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -3910,12 +3910,11 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { // (X ugt Y) ? X : Y -> (X ole Y) ? Y : X if (FCmp->hasOneUse() && FCmpInst::isUnordered(Pred)) { FCmpInst::Predicate InvPred = FCmp->getInversePredicate(); - IRBuilder<>::FastMathFlagGuard FMFG(Builder); // FIXME: The FMF should propagate from the select, not the fcmp. - Builder.setFastMathFlags(FCmp->getFastMathFlags()); - Value *NewCond = Builder.CreateFCmp(InvPred, Cmp0, Cmp1, - FCmp->getName() + ".inv"); - Value *NewSel = Builder.CreateSelect(NewCond, FalseVal, TrueVal); + Value *NewCond = Builder.CreateFCmpFMF(InvPred, Cmp0, Cmp1, FCmp, + FCmp->getName() + ".inv"); + Value *NewSel = + Builder.CreateSelectFMF(NewCond, FalseVal, TrueVal, FCmp); return replaceInstUsesWith(SI, NewSel); } } @@ -4080,15 +4079,11 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { CmpInst::Predicate MinMaxPred = getMinMaxPred(SPF, SPR.Ordered); Value *Cmp; - if (CmpInst::isIntPredicate(MinMaxPred)) { + if (CmpInst::isIntPredicate(MinMaxPred)) Cmp = Builder.CreateICmp(MinMaxPred, LHS, RHS); - } else { - IRBuilder<>::FastMathFlagGuard FMFG(Builder); - auto FMF = - cast(SI.getCondition())->getFastMathFlags(); - Builder.setFastMathFlags(FMF); - Cmp = Builder.CreateFCmp(MinMaxPred, LHS, RHS); - } + else + Cmp = Builder.CreateFCmpFMF(MinMaxPred, LHS, RHS, + cast(SI.getCondition())); Value *NewSI = Builder.CreateSelect(Cmp, LHS, RHS, SI.getName(), &SI); if (!IsCastNeeded) diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp index febc5682c21295..03dc6c1d17446d 100644 --- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp +++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp @@ -2153,12 +2153,9 @@ bool SimplifyCFGOpt::hoistSuccIdenticalTerminatorToSwitchOrIf( SelectInst *&SI = InsertedSelects[std::make_pair(BB1V, BB2V)]; if (!SI) { // Propagate fast-math-flags from phi node to its replacement select. - IRBuilder<>::FastMathFlagGuard FMFGuard(Builder); - if (isa(PN)) - Builder.setFastMathFlags(PN.getFastMathFlags()); - - SI = cast(Builder.CreateSelect( + SI = cast(Builder.CreateSelectFMF( BI->getCondition(), BB1V, BB2V, + isa(PN) ? &PN : nullptr, BB1V->getName() + "." + BB2V->getName(), BI)); } @@ -3898,16 +3895,14 @@ static bool foldTwoEntryPHINode(PHINode *PN, const TargetTransformInfo &TTI, IRBuilder Builder(DomBI); // Propagate fast-math-flags from phi nodes to replacement selects. - IRBuilder<>::FastMathFlagGuard FMFGuard(Builder); while (PHINode *PN = dyn_cast(BB->begin())) { - if (isa(PN)) - Builder.setFastMathFlags(PN->getFastMathFlags()); - // Change the PHI node into a select instruction. Value *TrueVal = PN->getIncomingValueForBlock(IfTrue); Value *FalseVal = PN->getIncomingValueForBlock(IfFalse); - Value *Sel = Builder.CreateSelect(IfCond, TrueVal, FalseVal, "", DomBI); + Value *Sel = Builder.CreateSelectFMF(IfCond, TrueVal, FalseVal, + isa(PN) ? PN : nullptr, + "", DomBI); PN->replaceAllUsesWith(Sel); Sel->takeName(PN); PN->eraseFromParent(); diff --git a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp index 737818b7825cf4..2b2b4670714b68 100644 --- a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp +++ b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp @@ -2005,28 +2005,21 @@ Value *LibCallSimplifier::optimizeCAbs(CallInst *CI, IRBuilderBase &B) { AbsOp = Real; } - if (AbsOp) { - IRBuilderBase::FastMathFlagGuard Guard(B); - B.setFastMathFlags(CI->getFastMathFlags()); - + if (AbsOp) return copyFlags( - *CI, B.CreateUnaryIntrinsic(Intrinsic::fabs, AbsOp, nullptr, "cabs")); - } + *CI, B.CreateUnaryIntrinsic(Intrinsic::fabs, AbsOp, CI, "cabs")); if (!CI->isFast()) return nullptr; } // Propagate fast-math flags from the existing call to new instructions. - IRBuilderBase::FastMathFlagGuard Guard(B); - B.setFastMathFlags(CI->getFastMathFlags()); - - Value *RealReal = B.CreateFMul(Real, Real); - Value *ImagImag = B.CreateFMul(Imag, Imag); - - return copyFlags(*CI, B.CreateUnaryIntrinsic(Intrinsic::sqrt, - B.CreateFAdd(RealReal, ImagImag), - nullptr, "cabs")); + Value *RealReal = B.CreateFMulFMF(Real, Real, CI); + Value *ImagImag = B.CreateFMulFMF(Imag, Imag, CI); + return copyFlags( + *CI, B.CreateUnaryIntrinsic(Intrinsic::sqrt, + B.CreateFAddFMF(RealReal, ImagImag, CI), CI, + "cabs")); } // Return a properly extended integer (DstWidth bits wide) if the operation is @@ -2480,15 +2473,13 @@ Value *LibCallSimplifier::optimizeFMinFMax(CallInst *CI, IRBuilderBase &B) { // "Ideally, fmax would be sensitive to the sign of zero, for example // fmax(-0.0, +0.0) would return +0; however, implementation in software // might be impractical." - IRBuilderBase::FastMathFlagGuard Guard(B); FastMathFlags FMF = CI->getFastMathFlags(); FMF.setNoSignedZeros(); - B.setFastMathFlags(FMF); Intrinsic::ID IID = Callee->getName().starts_with("fmin") ? Intrinsic::minnum : Intrinsic::maxnum; return copyFlags(*CI, B.CreateBinaryIntrinsic(IID, CI->getArgOperand(0), - CI->getArgOperand(1))); + CI->getArgOperand(1), FMF)); } Value *LibCallSimplifier::optimizeLog(CallInst *Log, IRBuilderBase &B) { @@ -2783,20 +2774,18 @@ Value *LibCallSimplifier::optimizeSqrt(CallInst *CI, IRBuilderBase &B) { // Fast math flags for any created instructions should match the sqrt // and multiply. - IRBuilderBase::FastMathFlagGuard Guard(B); - B.setFastMathFlags(I->getFastMathFlags()); // If we found a repeated factor, hoist it out of the square root and // replace it with the fabs of that factor. Value *FabsCall = - B.CreateUnaryIntrinsic(Intrinsic::fabs, RepeatOp, nullptr, "fabs"); + B.CreateUnaryIntrinsic(Intrinsic::fabs, RepeatOp, I, "fabs"); if (OtherOp) { // If we found a non-repeated factor, we still need to get its square // root. We then multiply that by the value that was simplified out // of the square root calculation. Value *SqrtCall = - B.CreateUnaryIntrinsic(Intrinsic::sqrt, OtherOp, nullptr, "sqrt"); - return copyFlags(*CI, B.CreateFMul(FabsCall, SqrtCall)); + B.CreateUnaryIntrinsic(Intrinsic::sqrt, OtherOp, I, "sqrt"); + return copyFlags(*CI, B.CreateFMulFMF(FabsCall, SqrtCall, I)); } return copyFlags(*CI, FabsCall); } @@ -2951,26 +2940,23 @@ static Value *optimizeSymmetricCall(CallInst *CI, bool IsEven, Value *Src = CI->getArgOperand(0); if (match(Src, m_OneUse(m_FNeg(m_Value(X))))) { - IRBuilderBase::FastMathFlagGuard Guard(B); - B.setFastMathFlags(CI->getFastMathFlags()); - - auto *CallInst = copyFlags(*CI, B.CreateCall(CI->getCalledFunction(), {X})); + auto *Call = B.CreateCall(CI->getCalledFunction(), {X}); + Call->copyFastMathFlags(CI); + auto *CallInst = copyFlags(*CI, Call); if (IsEven) { // Even function: f(-x) = f(x) return CallInst; } // Odd function: f(-x) = -f(x) - return B.CreateFNeg(CallInst); + return B.CreateFNegFMF(CallInst, CI); } // Even function: f(abs(x)) = f(x), f(copysign(x, y)) = f(x) if (IsEven && (match(Src, m_FAbs(m_Value(X))) || match(Src, m_CopySign(m_Value(X), m_Value())))) { - IRBuilderBase::FastMathFlagGuard Guard(B); - B.setFastMathFlags(CI->getFastMathFlags()); - - auto *CallInst = copyFlags(*CI, B.CreateCall(CI->getCalledFunction(), {X})); - return CallInst; + auto *Call = B.CreateCall(CI->getCalledFunction(), {X}); + Call->copyFastMathFlags(CI); + return copyFlags(*CI, Call); } return nullptr; diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp index 77c08839dbfa95..ec586fa47fe1d2 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp @@ -1352,10 +1352,9 @@ void VPWidenRecipe::execute(VPTransformState &State) { Value *C = nullptr; if (FCmp) { // Propagate fast math flags. - IRBuilder<>::FastMathFlagGuard FMFG(Builder); - if (auto *I = dyn_cast_or_null(getUnderlyingValue())) - Builder.setFastMathFlags(I->getFastMathFlags()); - C = Builder.CreateFCmp(getPredicate(), A, B); + C = Builder.CreateFCmpFMF( + getPredicate(), A, B, + dyn_cast_or_null(getUnderlyingValue())); } else { C = Builder.CreateICmp(getPredicate(), A, B); } From 9397e712f6010be15ccf62f12740e9b4a67de2f4 Mon Sep 17 00:00:00 2001 From: Yingwei Zheng Date: Sun, 5 Jan 2025 16:52:24 +0800 Subject: [PATCH 2/7] [IRBuilder] Remove source inst --- llvm/include/llvm/IR/IRBuilder.h | 67 +++++++++++++++++--------------- llvm/lib/IR/IRBuilder.cpp | 20 +++++----- 2 files changed, 45 insertions(+), 42 deletions(-) diff --git a/llvm/include/llvm/IR/IRBuilder.h b/llvm/include/llvm/IR/IRBuilder.h index c1f15783fb7592..b73309175f20d1 100644 --- a/llvm/include/llvm/IR/IRBuilder.h +++ b/llvm/include/llvm/IR/IRBuilder.h @@ -89,16 +89,19 @@ class IRBuilderCallbackInserter : public IRBuilderDefaultInserter { /// This provides a helper for copying FMF from an instruction or setting /// specified flags. -struct FMFSource final { - Instruction *Source; +class FMFSource { std::optional FMF; - FMFSource() : Source(nullptr) {} - FMFSource(Instruction *Source) : Source(Source) { +public: + FMFSource() = default; + FMFSource(Instruction *Source) { if (Source) FMF = Source->getFastMathFlags(); } - FMFSource(FastMathFlags FMF) : Source(nullptr), FMF(FMF) {} + FMFSource(FastMathFlags FMF) : FMF(FMF) {} + FastMathFlags get(FastMathFlags Default) const { + return FMF.value_or(Default); + } }; /// Common base class shared among various IRBuilders. @@ -1575,11 +1578,11 @@ class IRBuilderBase { return CreateConstrainedFPBinOp(Intrinsic::experimental_constrained_fadd, L, R, FMFSource, Name, FPMD); - if (Value *V = Folder.FoldBinOpFMF(Instruction::FAdd, L, R, - FMFSource.FMF.value_or(FMF))) + if (Value *V = + Folder.FoldBinOpFMF(Instruction::FAdd, L, R, FMFSource.get(FMF))) return V; - Instruction *I = setFPAttrs(BinaryOperator::CreateFAdd(L, R), FPMD, - FMFSource.FMF.value_or(FMF)); + Instruction *I = + setFPAttrs(BinaryOperator::CreateFAdd(L, R), FPMD, FMFSource.get(FMF)); return Insert(I, Name); } @@ -1594,11 +1597,11 @@ class IRBuilderBase { return CreateConstrainedFPBinOp(Intrinsic::experimental_constrained_fsub, L, R, FMFSource, Name, FPMD); - if (Value *V = Folder.FoldBinOpFMF(Instruction::FSub, L, R, - FMFSource.FMF.value_or(FMF))) + if (Value *V = + Folder.FoldBinOpFMF(Instruction::FSub, L, R, FMFSource.get(FMF))) return V; - Instruction *I = setFPAttrs(BinaryOperator::CreateFSub(L, R), FPMD, - FMFSource.FMF.value_or(FMF)); + Instruction *I = + setFPAttrs(BinaryOperator::CreateFSub(L, R), FPMD, FMFSource.get(FMF)); return Insert(I, Name); } @@ -1613,11 +1616,11 @@ class IRBuilderBase { return CreateConstrainedFPBinOp(Intrinsic::experimental_constrained_fmul, L, R, FMFSource, Name, FPMD); - if (Value *V = Folder.FoldBinOpFMF(Instruction::FMul, L, R, - FMFSource.FMF.value_or(FMF))) + if (Value *V = + Folder.FoldBinOpFMF(Instruction::FMul, L, R, FMFSource.get(FMF))) return V; - Instruction *I = setFPAttrs(BinaryOperator::CreateFMul(L, R), FPMD, - FMFSource.FMF.value_or(FMF)); + Instruction *I = + setFPAttrs(BinaryOperator::CreateFMul(L, R), FPMD, FMFSource.get(FMF)); return Insert(I, Name); } @@ -1632,11 +1635,11 @@ class IRBuilderBase { return CreateConstrainedFPBinOp(Intrinsic::experimental_constrained_fdiv, L, R, FMFSource, Name, FPMD); - if (Value *V = Folder.FoldBinOpFMF(Instruction::FDiv, L, R, - FMFSource.FMF.value_or(FMF))) + if (Value *V = + Folder.FoldBinOpFMF(Instruction::FDiv, L, R, FMFSource.get(FMF))) return V; - Instruction *I = setFPAttrs(BinaryOperator::CreateFDiv(L, R), FPMD, - FMFSource.FMF.value_or(FMF)); + Instruction *I = + setFPAttrs(BinaryOperator::CreateFDiv(L, R), FPMD, FMFSource.get(FMF)); return Insert(I, Name); } @@ -1651,11 +1654,11 @@ class IRBuilderBase { return CreateConstrainedFPBinOp(Intrinsic::experimental_constrained_frem, L, R, FMFSource, Name, FPMD); - if (Value *V = Folder.FoldBinOpFMF(Instruction::FRem, L, R, - FMFSource.FMF.value_or(FMF))) + if (Value *V = + Folder.FoldBinOpFMF(Instruction::FRem, L, R, FMFSource.get(FMF))) return V; - Instruction *I = setFPAttrs(BinaryOperator::CreateFRem(L, R), FPMD, - FMFSource.FMF.value_or(FMF)); + Instruction *I = + setFPAttrs(BinaryOperator::CreateFRem(L, R), FPMD, FMFSource.get(FMF)); return Insert(I, Name); } @@ -1672,7 +1675,7 @@ class IRBuilderBase { return V; Instruction *BinOp = BinaryOperator::Create(Opc, LHS, RHS); if (isa(BinOp)) - setFPAttrs(BinOp, FPMathTag, FMFSource.FMF.value_or(FMF)); + setFPAttrs(BinOp, FPMathTag, FMFSource.get(FMF)); return Insert(BinOp, Name); } @@ -1737,12 +1740,12 @@ class IRBuilderBase { Value *CreateFNegFMF(Value *V, FMFSource FMFSource, const Twine &Name = "", MDNode *FPMathTag = nullptr) { - if (Value *Res = Folder.FoldUnOpFMF(Instruction::FNeg, V, - FMFSource.FMF.value_or(FMF))) + if (Value *Res = + Folder.FoldUnOpFMF(Instruction::FNeg, V, FMFSource.get(FMF))) return Res; - return Insert(setFPAttrs(UnaryOperator::CreateFNeg(V), FPMathTag, - FMFSource.FMF.value_or(FMF)), - Name); + return Insert( + setFPAttrs(UnaryOperator::CreateFNeg(V), FPMathTag, FMFSource.get(FMF)), + Name); } Value *CreateNot(Value *V, const Twine &Name = "") { @@ -2183,7 +2186,7 @@ class IRBuilderBase { return Folded; Instruction *Cast = CastInst::Create(Op, V, DestTy); if (isa(Cast)) - setFPAttrs(Cast, FPMathTag, FMFSource.FMF.value_or(FMF)); + setFPAttrs(Cast, FPMathTag, FMFSource.get(FMF)); return Insert(Cast, Name); } diff --git a/llvm/lib/IR/IRBuilder.cpp b/llvm/lib/IR/IRBuilder.cpp index 9bb4acfb96c750..27b499e42a4e4c 100644 --- a/llvm/lib/IR/IRBuilder.cpp +++ b/llvm/lib/IR/IRBuilder.cpp @@ -81,8 +81,8 @@ IRBuilderBase::createCallHelper(Function *Callee, ArrayRef Ops, const Twine &Name, FMFSource FMFSource, ArrayRef OpBundles) { CallInst *CI = CreateCall(Callee, Ops, OpBundles, Name); - if (FMFSource.FMF.has_value()) - CI->setFastMathFlags(*FMFSource.FMF); + if (isa(CI)) + CI->setFastMathFlags(FMFSource.get(FMF)); return CI; } @@ -882,7 +882,7 @@ Value *IRBuilderBase::CreateBinaryIntrinsic(Intrinsic::ID ID, Value *LHS, Module *M = BB->getModule(); Function *Fn = Intrinsic::getOrInsertDeclaration(M, ID, {LHS->getType()}); if (Value *V = Folder.FoldBinaryIntrinsic(ID, LHS, RHS, Fn->getReturnType(), - FMFSource.Source)) + /*FMFSource=*/nullptr)) return V; return createCallHelper(Fn, {LHS, RHS}, Name, FMFSource); } @@ -931,7 +931,7 @@ CallInst *IRBuilderBase::CreateConstrainedFPBinOp( Value *RoundingV = getConstrainedFPRounding(Rounding); Value *ExceptV = getConstrainedFPExcept(Except); - FastMathFlags UseFMF = FMFSource.FMF.value_or(FMF); + FastMathFlags UseFMF = FMFSource.get(FMF); CallInst *C = CreateIntrinsic(ID, {L->getType()}, {L, R, RoundingV, ExceptV}, nullptr, Name); @@ -946,7 +946,7 @@ CallInst *IRBuilderBase::CreateConstrainedFPUnroundedBinOp( std::optional Except) { Value *ExceptV = getConstrainedFPExcept(Except); - FastMathFlags UseFMF = FMFSource.FMF.value_or(FMF); + FastMathFlags UseFMF = FMFSource.get(FMF); CallInst *C = CreateIntrinsic(ID, {L->getType()}, {L, R, ExceptV}, nullptr, Name); @@ -976,7 +976,7 @@ CallInst *IRBuilderBase::CreateConstrainedFPCast( std::optional Except) { Value *ExceptV = getConstrainedFPExcept(Except); - FastMathFlags UseFMF = FMFSource.FMF.value_or(FMF); + FastMathFlags UseFMF = FMFSource.get(FMF); CallInst *C; if (Intrinsic::hasConstrainedFPRoundingModeOperand(ID)) { @@ -1006,9 +1006,9 @@ Value *IRBuilderBase::CreateFCmpHelper(CmpInst::Predicate P, Value *LHS, if (auto *V = Folder.FoldCmp(P, LHS, RHS)) return V; - return Insert(setFPAttrs(new FCmpInst(P, LHS, RHS), FPMathTag, - FMFSource.FMF.value_or(FMF)), - Name); + return Insert( + setFPAttrs(new FCmpInst(P, LHS, RHS), FPMathTag, FMFSource.get(FMF)), + Name); } CallInst *IRBuilderBase::CreateConstrainedFPCmp( @@ -1058,7 +1058,7 @@ Value *IRBuilderBase::CreateSelectFMF(Value *C, Value *True, Value *False, Sel = addBranchMetadata(Sel, Prof, Unpred); } if (isa(Sel)) - setFPAttrs(Sel, /*MDNode=*/nullptr, FMFSource.FMF.value_or(FMF)); + setFPAttrs(Sel, /*MDNode=*/nullptr, FMFSource.get(FMF)); return Insert(Sel, Name); } From e781b909c95ff8633ab978a96054a93efd9389a8 Mon Sep 17 00:00:00 2001 From: Yingwei Zheng Date: Sun, 5 Jan 2025 20:58:40 +0800 Subject: [PATCH 3/7] [InstCombine] Address review comments. NFC. --- .../Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp b/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp index d45aee37801736..fe7b3b1676e084 100644 --- a/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp +++ b/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp @@ -426,8 +426,7 @@ static bool foldSqrt(CallInst *Call, LibFunc Func, TargetTransformInfo &TTI, SimplifyQuery(Call->getDataLayout(), &TLI, &DT, &AC, Call)))) { IRBuilder<> Builder(Call); Value *NewSqrt = - Builder.CreateIntrinsic(Intrinsic::sqrt, Ty, Arg, - /*FMFSource=*/Call->getFastMathFlags(), "sqrt"); + Builder.CreateIntrinsic(Intrinsic::sqrt, Ty, Arg, Call, "sqrt"); Call->replaceAllUsesWith(NewSqrt); // Explicitly erase the old call because a call with side effects is not From a56b52ac3821b434bee676e2d06cc97bac9cb79b Mon Sep 17 00:00:00 2001 From: Narayan Sreekumar Date: Fri, 3 Jan 2025 01:33:41 +0530 Subject: [PATCH 4/7] [InstCombine] Pre-Commit Tests --- .../Transforms/InstCombine/select_frexp.ll | 129 ++++++++++++++++++ 1 file changed, 129 insertions(+) create mode 100644 llvm/test/Transforms/InstCombine/select_frexp.ll diff --git a/llvm/test/Transforms/InstCombine/select_frexp.ll b/llvm/test/Transforms/InstCombine/select_frexp.ll new file mode 100644 index 00000000000000..b3f05f4db42dd1 --- /dev/null +++ b/llvm/test/Transforms/InstCombine/select_frexp.ll @@ -0,0 +1,129 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5 +; RUN: opt -passes=instcombine -S < %s | FileCheck %s + +declare { float, i32 } @llvm.frexp.f32.i32(float) +declare void @use(float) + +; Basic test case - constant in true position +define float @test_select_frexp_basic(float %x, i1 %cond) { +; CHECK-LABEL: define float @test_select_frexp_basic( +; CHECK-SAME: float [[X:%.*]], i1 [[COND:%.*]]) { +; CHECK-NEXT: [[SEL:%.*]] = select i1 [[COND]], float 1.000000e+00, float [[X]] +; CHECK-NEXT: [[FREXP:%.*]] = call { float, i32 } @llvm.frexp.f32.i32(float [[SEL]]) +; CHECK-NEXT: [[FREXP_0:%.*]] = extractvalue { float, i32 } [[FREXP]], 0 +; CHECK-NEXT: ret float [[FREXP_0]] +; + %sel = select i1 %cond, float 1.000000e+00, float %x + %frexp = call { float, i32 } @llvm.frexp.f32.i32(float %sel) + %frexp.0 = extractvalue { float, i32 } %frexp, 0 + ret float %frexp.0 +} + +; Test with constant in false position +define float @test_select_frexp_const_false(float %x, i1 %cond) { +; CHECK-LABEL: define float @test_select_frexp_const_false( +; CHECK-SAME: float [[X:%.*]], i1 [[COND:%.*]]) { +; CHECK-NEXT: [[SEL:%.*]] = select i1 [[COND]], float [[X]], float 1.000000e+00 +; CHECK-NEXT: [[FREXP:%.*]] = call { float, i32 } @llvm.frexp.f32.i32(float [[SEL]]) +; CHECK-NEXT: [[FREXP_0:%.*]] = extractvalue { float, i32 } [[FREXP]], 0 +; CHECK-NEXT: ret float [[FREXP_0]] +; + %sel = select i1 %cond, float %x, float 1.000000e+00 + %frexp = call { float, i32 } @llvm.frexp.f32.i32(float %sel) + %frexp.0 = extractvalue { float, i32 } %frexp, 0 + ret float %frexp.0 +} + +; Multi-use test +define float @test_select_frexp_multi_use(float %x, i1 %cond) { +; CHECK-LABEL: define float @test_select_frexp_multi_use( +; CHECK-SAME: float [[X:%.*]], i1 [[COND:%.*]]) { +; CHECK-NEXT: [[SEL:%.*]] = select i1 [[COND]], float 1.000000e+00, float [[X]] +; CHECK-NEXT: call void @use(float [[SEL]]) +; CHECK-NEXT: [[FREXP:%.*]] = call { float, i32 } @llvm.frexp.f32.i32(float [[SEL]]) +; CHECK-NEXT: [[FREXP_0:%.*]] = extractvalue { float, i32 } [[FREXP]], 0 +; CHECK-NEXT: ret float [[FREXP_0]] +; + %sel = select i1 %cond, float 1.000000e+00, float %x + call void @use(float %sel) + %frexp = call { float, i32 } @llvm.frexp.f32.i32(float %sel) + %frexp.0 = extractvalue { float, i32 } %frexp, 0 + ret float %frexp.0 +} + +; Vector test - splat constant +define <2 x float> @test_select_frexp_vec_splat(<2 x float> %x, <2 x i1> %cond) { +; CHECK-LABEL: define <2 x float> @test_select_frexp_vec_splat( +; CHECK-SAME: <2 x float> [[X:%.*]], <2 x i1> [[COND:%.*]]) { +; CHECK-NEXT: [[SEL:%.*]] = select <2 x i1> [[COND]], <2 x float> splat (float 1.000000e+00), <2 x float> [[X]] +; CHECK-NEXT: [[FREXP:%.*]] = call { <2 x float>, <2 x i32> } @llvm.frexp.v2f32.v2i32(<2 x float> [[SEL]]) +; CHECK-NEXT: [[FREXP_0:%.*]] = extractvalue { <2 x float>, <2 x i32> } [[FREXP]], 0 +; CHECK-NEXT: ret <2 x float> [[FREXP_0]] +; + %sel = select <2 x i1> %cond, <2 x float> , <2 x float> %x + %frexp = call { <2 x float>, <2 x i32> } @llvm.frexp.v2f32.v2i32(<2 x float> %sel) + %frexp.0 = extractvalue { <2 x float>, <2 x i32> } %frexp, 0 + ret <2 x float> %frexp.0 +} + +; Vector test with poison +define <2 x float> @test_select_frexp_vec_poison(<2 x float> %x, <2 x i1> %cond) { +; CHECK-LABEL: define <2 x float> @test_select_frexp_vec_poison( +; CHECK-SAME: <2 x float> [[X:%.*]], <2 x i1> [[COND:%.*]]) { +; CHECK-NEXT: [[SEL:%.*]] = select <2 x i1> [[COND]], <2 x float> , <2 x float> [[X]] +; CHECK-NEXT: [[FREXP:%.*]] = call { <2 x float>, <2 x i32> } @llvm.frexp.v2f32.v2i32(<2 x float> [[SEL]]) +; CHECK-NEXT: [[FREXP_0:%.*]] = extractvalue { <2 x float>, <2 x i32> } [[FREXP]], 0 +; CHECK-NEXT: ret <2 x float> [[FREXP_0]] +; + %sel = select <2 x i1> %cond, <2 x float> , <2 x float> %x + %frexp = call { <2 x float>, <2 x i32> } @llvm.frexp.v2f32.v2i32(<2 x float> %sel) + %frexp.0 = extractvalue { <2 x float>, <2 x i32> } %frexp, 0 + ret <2 x float> %frexp.0 +} + +; Vector test - non-splat (should not fold) +define <2 x float> @test_select_frexp_vec_nonsplat(<2 x float> %x, <2 x i1> %cond) { +; CHECK-LABEL: define <2 x float> @test_select_frexp_vec_nonsplat( +; CHECK-SAME: <2 x float> [[X:%.*]], <2 x i1> [[COND:%.*]]) { +; CHECK-NEXT: [[SEL:%.*]] = select <2 x i1> [[COND]], <2 x float> , <2 x float> [[X]] +; CHECK-NEXT: [[FREXP:%.*]] = call { <2 x float>, <2 x i32> } @llvm.frexp.v2f32.v2i32(<2 x float> [[SEL]]) +; CHECK-NEXT: [[FREXP_0:%.*]] = extractvalue { <2 x float>, <2 x i32> } [[FREXP]], 0 +; CHECK-NEXT: ret <2 x float> [[FREXP_0]] +; + %sel = select <2 x i1> %cond, <2 x float> , <2 x float> %x + %frexp = call { <2 x float>, <2 x i32> } @llvm.frexp.v2f32.v2i32(<2 x float> %sel) + %frexp.0 = extractvalue { <2 x float>, <2 x i32> } %frexp, 0 + ret <2 x float> %frexp.0 +} + +; Negative test - both operands non-constant +define float @test_select_frexp_no_const(float %x, float %y, i1 %cond) { +; CHECK-LABEL: define float @test_select_frexp_no_const( +; CHECK-SAME: float [[X:%.*]], float [[Y:%.*]], i1 [[COND:%.*]]) { +; CHECK-NEXT: [[SEL:%.*]] = select i1 [[COND]], float [[X]], float [[Y]] +; CHECK-NEXT: [[FREXP:%.*]] = call { float, i32 } @llvm.frexp.f32.i32(float [[SEL]]) +; CHECK-NEXT: [[FREXP_0:%.*]] = extractvalue { float, i32 } [[FREXP]], 0 +; CHECK-NEXT: ret float [[FREXP_0]] +; + %sel = select i1 %cond, float %x, float %y + %frexp = call { float, i32 } @llvm.frexp.f32.i32(float %sel) + %frexp.0 = extractvalue { float, i32 } %frexp, 0 + ret float %frexp.0 +} + +; Negative test - extracting exp instead of mantissa +define i32 @test_select_frexp_extract_exp(float %x, i1 %cond) { +; CHECK-LABEL: define i32 @test_select_frexp_extract_exp( +; CHECK-SAME: float [[X:%.*]], i1 [[COND:%.*]]) { +; CHECK-NEXT: [[SEL:%.*]] = select i1 [[COND]], float 1.000000e+00, float [[X]] +; CHECK-NEXT: [[FREXP:%.*]] = call { float, i32 } @llvm.frexp.f32.i32(float [[SEL]]) +; CHECK-NEXT: [[FREXP_1:%.*]] = extractvalue { float, i32 } [[FREXP]], 1 +; CHECK-NEXT: ret i32 [[FREXP_1]] +; + %sel = select i1 %cond, float 1.000000e+00, float %x + %frexp = call { float, i32 } @llvm.frexp.f32.i32(float %sel) + %frexp.1 = extractvalue { float, i32 } %frexp, 1 + ret i32 %frexp.1 +} + +declare { <2 x float>, <2 x i32> } @llvm.frexp.v2f32.v2i32(<2 x float>) From 010ef8208ebcf158b8de23c7dbff5b7d4e54f1f5 Mon Sep 17 00:00:00 2001 From: Narayan Sreekumar Date: Fri, 3 Jan 2025 18:03:45 +0530 Subject: [PATCH 5/7] [InstCombine] InstCombine should fold frexp of select to select of frexp --- .../InstCombine/InstructionCombining.cpp | 67 ++++++++++++++++++- .../Transforms/InstCombine/select_frexp.ll | 17 ++--- 2 files changed, 75 insertions(+), 9 deletions(-) diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp index 934156f04f7fdd..b4b31c25ef080d 100644 --- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp +++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp @@ -4043,6 +4043,52 @@ InstCombinerImpl::foldExtractOfOverflowIntrinsic(ExtractValueInst &EV) { return nullptr; } +static Value *foldFrexpOfSelect(ExtractValueInst &EV, CallInst *FrexpCall, + SelectInst *SelectInst, + InstCombiner::BuilderTy &Builder) { + // Helper to fold frexp of select to select of frexp. + Value *Cond = SelectInst->getCondition(); + Value *TrueVal = SelectInst->getTrueValue(); + Value *FalseVal = SelectInst->getFalseValue(); + ConstantFP *ConstOp = nullptr; + Value *VarOp = nullptr; + bool ConstIsTrue = false; + + if (auto *TrueConst = dyn_cast(TrueVal)) { + ConstOp = TrueConst; + VarOp = FalseVal; + ConstIsTrue = true; + } else if (auto *FalseConst = dyn_cast(FalseVal)) { + ConstOp = FalseConst; + VarOp = TrueVal; + ConstIsTrue = false; + } + + if (!ConstOp || !VarOp) + return nullptr; + + CallInst *NewFrexp = + Builder.CreateCall(FrexpCall->getCalledFunction(), {VarOp}, "frexp"); + + Value *NewEV = Builder.CreateExtractValue(NewFrexp, 0, "mantissa"); + + APFloat ConstVal = ConstOp->getValueAPF(); + int Exp = 0; + APFloat Mantissa = ConstVal; + + if (ConstVal.isFiniteNonZero()) { + Mantissa = frexp(ConstVal, Exp, APFloat::rmNearestTiesToEven); + } + + Constant *ConstantMantissa = ConstantFP::get(ConstOp->getType(), Mantissa); + + Value *NewSel = Builder.CreateSelect( + Cond, ConstIsTrue ? ConstantMantissa : NewEV, + ConstIsTrue ? NewEV : ConstantMantissa, "select.frexp"); + + return NewSel; +} + Instruction *InstCombinerImpl::visitExtractValueInst(ExtractValueInst &EV) { Value *Agg = EV.getAggregateOperand(); @@ -4052,7 +4098,26 @@ Instruction *InstCombinerImpl::visitExtractValueInst(ExtractValueInst &EV) { if (Value *V = simplifyExtractValueInst(Agg, EV.getIndices(), SQ.getWithInstruction(&EV))) return replaceInstUsesWith(EV, V); - + if (EV.getNumIndices() == 1 && EV.getIndices()[0] == 0) { + if (auto *FrexpCall = dyn_cast(Agg)) { + if (Function *F = FrexpCall->getCalledFunction()) { + if (F->getIntrinsicID() == Intrinsic::frexp) { + if (auto *SelInst = + dyn_cast(FrexpCall->getArgOperand(0))) { + if (isa(SelInst->getTrueValue()) || + isa(SelInst->getFalseValue())) { + Builder.SetInsertPoint(&EV); + + if (Value *Result = + foldFrexpOfSelect(EV, FrexpCall, SelInst, Builder)) { + return replaceInstUsesWith(EV, Result); + } + } + } + } + } + } + } if (InsertValueInst *IV = dyn_cast(Agg)) { // We're extracting from an insertvalue instruction, compare the indices const unsigned *exti, *exte, *insi, *inse; diff --git a/llvm/test/Transforms/InstCombine/select_frexp.ll b/llvm/test/Transforms/InstCombine/select_frexp.ll index b3f05f4db42dd1..652d4de27b7591 100644 --- a/llvm/test/Transforms/InstCombine/select_frexp.ll +++ b/llvm/test/Transforms/InstCombine/select_frexp.ll @@ -8,10 +8,10 @@ declare void @use(float) define float @test_select_frexp_basic(float %x, i1 %cond) { ; CHECK-LABEL: define float @test_select_frexp_basic( ; CHECK-SAME: float [[X:%.*]], i1 [[COND:%.*]]) { -; CHECK-NEXT: [[SEL:%.*]] = select i1 [[COND]], float 1.000000e+00, float [[X]] -; CHECK-NEXT: [[FREXP:%.*]] = call { float, i32 } @llvm.frexp.f32.i32(float [[SEL]]) +; CHECK-NEXT: [[FREXP:%.*]] = call { float, i32 } @llvm.frexp.f32.i32(float [[X]]) ; CHECK-NEXT: [[FREXP_0:%.*]] = extractvalue { float, i32 } [[FREXP]], 0 -; CHECK-NEXT: ret float [[FREXP_0]] +; CHECK-NEXT: [[SELECT_FREXP:%.*]] = select i1 [[COND]], float 5.000000e-01, float [[FREXP_0]] +; CHECK-NEXT: ret float [[SELECT_FREXP]] ; %sel = select i1 %cond, float 1.000000e+00, float %x %frexp = call { float, i32 } @llvm.frexp.f32.i32(float %sel) @@ -23,10 +23,10 @@ define float @test_select_frexp_basic(float %x, i1 %cond) { define float @test_select_frexp_const_false(float %x, i1 %cond) { ; CHECK-LABEL: define float @test_select_frexp_const_false( ; CHECK-SAME: float [[X:%.*]], i1 [[COND:%.*]]) { -; CHECK-NEXT: [[SEL:%.*]] = select i1 [[COND]], float [[X]], float 1.000000e+00 -; CHECK-NEXT: [[FREXP:%.*]] = call { float, i32 } @llvm.frexp.f32.i32(float [[SEL]]) +; CHECK-NEXT: [[FREXP:%.*]] = call { float, i32 } @llvm.frexp.f32.i32(float [[X]]) ; CHECK-NEXT: [[FREXP_0:%.*]] = extractvalue { float, i32 } [[FREXP]], 0 -; CHECK-NEXT: ret float [[FREXP_0]] +; CHECK-NEXT: [[SELECT_FREXP:%.*]] = select i1 [[COND]], float [[FREXP_0]], float 5.000000e-01 +; CHECK-NEXT: ret float [[SELECT_FREXP]] ; %sel = select i1 %cond, float %x, float 1.000000e+00 %frexp = call { float, i32 } @llvm.frexp.f32.i32(float %sel) @@ -40,9 +40,10 @@ define float @test_select_frexp_multi_use(float %x, i1 %cond) { ; CHECK-SAME: float [[X:%.*]], i1 [[COND:%.*]]) { ; CHECK-NEXT: [[SEL:%.*]] = select i1 [[COND]], float 1.000000e+00, float [[X]] ; CHECK-NEXT: call void @use(float [[SEL]]) -; CHECK-NEXT: [[FREXP:%.*]] = call { float, i32 } @llvm.frexp.f32.i32(float [[SEL]]) +; CHECK-NEXT: [[FREXP:%.*]] = call { float, i32 } @llvm.frexp.f32.i32(float [[X]]) ; CHECK-NEXT: [[FREXP_0:%.*]] = extractvalue { float, i32 } [[FREXP]], 0 -; CHECK-NEXT: ret float [[FREXP_0]] +; CHECK-NEXT: [[SELECT_FREXP:%.*]] = select i1 [[COND]], float 5.000000e-01, float [[FREXP_0]] +; CHECK-NEXT: ret float [[SELECT_FREXP]] ; %sel = select i1 %cond, float 1.000000e+00, float %x call void @use(float %sel) From e646bbdc2225fc690632ae78d88022f04a8d7968 Mon Sep 17 00:00:00 2001 From: Narayan Sreekumar Date: Mon, 6 Jan 2025 23:08:37 +0530 Subject: [PATCH 6/7] [InstCombine] Refactor and Preserve fast math flags --- .../InstCombine/InstructionCombining.cpp | 58 +++++++++---------- .../Transforms/InstCombine/select_frexp.ll | 37 +++++++++++- 2 files changed, 62 insertions(+), 33 deletions(-) diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp index b4b31c25ef080d..1c6d24cc541ef1 100644 --- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp +++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp @@ -33,6 +33,7 @@ //===----------------------------------------------------------------------===// #include "InstCombineInternal.h" +#include "llvm/ADT/APFloat.h" #include "llvm/ADT/APInt.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" @@ -4043,52 +4044,57 @@ InstCombinerImpl::foldExtractOfOverflowIntrinsic(ExtractValueInst &EV) { return nullptr; } -static Value *foldFrexpOfSelect(ExtractValueInst &EV, CallInst *FrexpCall, +static Value *foldFrexpOfSelect(ExtractValueInst &EV, IntrinsicInst *FrexpCall, SelectInst *SelectInst, InstCombiner::BuilderTy &Builder) { // Helper to fold frexp of select to select of frexp. Value *Cond = SelectInst->getCondition(); Value *TrueVal = SelectInst->getTrueValue(); Value *FalseVal = SelectInst->getFalseValue(); - ConstantFP *ConstOp = nullptr; + + const APFloat *ConstVal = nullptr; Value *VarOp = nullptr; bool ConstIsTrue = false; - if (auto *TrueConst = dyn_cast(TrueVal)) { - ConstOp = TrueConst; + if (match(TrueVal, m_APFloat(ConstVal))) { VarOp = FalseVal; ConstIsTrue = true; - } else if (auto *FalseConst = dyn_cast(FalseVal)) { - ConstOp = FalseConst; + } else if (match(FalseVal, m_APFloat(ConstVal))) { VarOp = TrueVal; ConstIsTrue = false; + } else { + return nullptr; } - if (!ConstOp || !VarOp) - return nullptr; + Builder.SetInsertPoint(&EV); CallInst *NewFrexp = Builder.CreateCall(FrexpCall->getCalledFunction(), {VarOp}, "frexp"); + NewFrexp->copyIRFlags(FrexpCall); Value *NewEV = Builder.CreateExtractValue(NewFrexp, 0, "mantissa"); - APFloat ConstVal = ConstOp->getValueAPF(); - int Exp = 0; - APFloat Mantissa = ConstVal; + int Exp; + APFloat Mantissa = frexp(*ConstVal, Exp, APFloat::rmNearestTiesToEven); - if (ConstVal.isFiniteNonZero()) { - Mantissa = frexp(ConstVal, Exp, APFloat::rmNearestTiesToEven); + Constant *ConstantMantissa; + if (auto *VecTy = dyn_cast(TrueVal->getType())) { + SmallVector Elems( + VecTy->getElementCount().getFixedValue(), + ConstantFP::get(VecTy->getElementType(), Mantissa)); + ConstantMantissa = ConstantVector::get(Elems); + } else { + ConstantMantissa = ConstantFP::get(TrueVal->getType(), Mantissa); } - Constant *ConstantMantissa = ConstantFP::get(ConstOp->getType(), Mantissa); - Value *NewSel = Builder.CreateSelect( Cond, ConstIsTrue ? ConstantMantissa : NewEV, ConstIsTrue ? NewEV : ConstantMantissa, "select.frexp"); + if (auto *NewSelInst = dyn_cast(NewSel)) + NewSelInst->copyFastMathFlags(SelectInst); return NewSel; } - Instruction *InstCombinerImpl::visitExtractValueInst(ExtractValueInst &EV) { Value *Agg = EV.getAggregateOperand(); @@ -4099,20 +4105,12 @@ Instruction *InstCombinerImpl::visitExtractValueInst(ExtractValueInst &EV) { SQ.getWithInstruction(&EV))) return replaceInstUsesWith(EV, V); if (EV.getNumIndices() == 1 && EV.getIndices()[0] == 0) { - if (auto *FrexpCall = dyn_cast(Agg)) { - if (Function *F = FrexpCall->getCalledFunction()) { - if (F->getIntrinsicID() == Intrinsic::frexp) { - if (auto *SelInst = - dyn_cast(FrexpCall->getArgOperand(0))) { - if (isa(SelInst->getTrueValue()) || - isa(SelInst->getFalseValue())) { - Builder.SetInsertPoint(&EV); - - if (Value *Result = - foldFrexpOfSelect(EV, FrexpCall, SelInst, Builder)) { - return replaceInstUsesWith(EV, Result); - } - } + if (auto *FrexpCall = dyn_cast(Agg)) { + if (FrexpCall->getIntrinsicID() == Intrinsic::frexp) { + if (auto *SelInst = dyn_cast(FrexpCall->getArgOperand(0))) { + if (Value *Result = + foldFrexpOfSelect(EV, FrexpCall, SelInst, Builder)) { + return replaceInstUsesWith(EV, Result); } } } diff --git a/llvm/test/Transforms/InstCombine/select_frexp.ll b/llvm/test/Transforms/InstCombine/select_frexp.ll index 652d4de27b7591..d729e7c7005142 100644 --- a/llvm/test/Transforms/InstCombine/select_frexp.ll +++ b/llvm/test/Transforms/InstCombine/select_frexp.ll @@ -56,10 +56,10 @@ define float @test_select_frexp_multi_use(float %x, i1 %cond) { define <2 x float> @test_select_frexp_vec_splat(<2 x float> %x, <2 x i1> %cond) { ; CHECK-LABEL: define <2 x float> @test_select_frexp_vec_splat( ; CHECK-SAME: <2 x float> [[X:%.*]], <2 x i1> [[COND:%.*]]) { -; CHECK-NEXT: [[SEL:%.*]] = select <2 x i1> [[COND]], <2 x float> splat (float 1.000000e+00), <2 x float> [[X]] -; CHECK-NEXT: [[FREXP:%.*]] = call { <2 x float>, <2 x i32> } @llvm.frexp.v2f32.v2i32(<2 x float> [[SEL]]) +; CHECK-NEXT: [[FREXP:%.*]] = call { <2 x float>, <2 x i32> } @llvm.frexp.v2f32.v2i32(<2 x float> [[X]]) ; CHECK-NEXT: [[FREXP_0:%.*]] = extractvalue { <2 x float>, <2 x i32> } [[FREXP]], 0 -; CHECK-NEXT: ret <2 x float> [[FREXP_0]] +; CHECK-NEXT: [[SELECT_FREXP:%.*]] = select <2 x i1> [[COND]], <2 x float> splat (float 5.000000e-01), <2 x float> [[FREXP_0]] +; CHECK-NEXT: ret <2 x float> [[SELECT_FREXP]] ; %sel = select <2 x i1> %cond, <2 x float> , <2 x float> %x %frexp = call { <2 x float>, <2 x i32> } @llvm.frexp.v2f32.v2i32(<2 x float> %sel) @@ -127,4 +127,35 @@ define i32 @test_select_frexp_extract_exp(float %x, i1 %cond) { ret i32 %frexp.1 } +; Test with fast math flags +define float @test_select_frexp_fast_math_select(float %x, i1 %cond) { +; CHECK-LABEL: define float @test_select_frexp_fast_math_select( +; CHECK-SAME: float [[X:%.*]], i1 [[COND:%.*]]) { +; CHECK-NEXT: [[FREXP1:%.*]] = call { float, i32 } @llvm.frexp.f32.i32(float [[X]]) +; CHECK-NEXT: [[MANTISSA:%.*]] = extractvalue { float, i32 } [[FREXP1]], 0 +; CHECK-NEXT: [[SELECT_FREXP:%.*]] = select nnan ninf nsz i1 [[COND]], float 5.000000e-01, float [[MANTISSA]] +; CHECK-NEXT: ret float [[SELECT_FREXP]] +; + %sel = select nnan ninf nsz i1 %cond, float 1.000000e+00, float %x + %frexp = call { float, i32 } @llvm.frexp.f32.i32(float %sel) + %frexp.0 = extractvalue { float, i32 } %frexp, 0 + ret float %frexp.0 +} + + +; Test vector case with fast math flags +define <2 x float> @test_select_frexp_vec_fast_math(<2 x float> %x, <2 x i1> %cond) { +; CHECK-LABEL: define <2 x float> @test_select_frexp_vec_fast_math( +; CHECK-SAME: <2 x float> [[X:%.*]], <2 x i1> [[COND:%.*]]) { +; CHECK-NEXT: [[FREXP1:%.*]] = call { <2 x float>, <2 x i32> } @llvm.frexp.v2f32.v2i32(<2 x float> [[X]]) +; CHECK-NEXT: [[MANTISSA:%.*]] = extractvalue { <2 x float>, <2 x i32> } [[FREXP1]], 0 +; CHECK-NEXT: [[SELECT_FREXP:%.*]] = select nnan ninf nsz <2 x i1> [[COND]], <2 x float> splat (float 5.000000e-01), <2 x float> [[MANTISSA]] +; CHECK-NEXT: ret <2 x float> [[SELECT_FREXP]] +; + %sel = select nnan ninf nsz <2 x i1> %cond, <2 x float> , <2 x float> %x + %frexp = call { <2 x float>, <2 x i32> } @llvm.frexp.v2f32.v2i32(<2 x float> %sel) + %frexp.0 = extractvalue { <2 x float>, <2 x i32> } %frexp, 0 + ret <2 x float> %frexp.0 +} + declare { <2 x float>, <2 x i32> } @llvm.frexp.v2f32.v2i32(<2 x float>) From d8f576a91a10d5ccf3b859e7e84e08e176b85f02 Mon Sep 17 00:00:00 2001 From: Narayan Sreekumar Date: Sat, 11 Jan 2025 00:42:41 +0530 Subject: [PATCH 7/7] [InstCombine] Refactor PatternMatch and add scalable Vector tests --- .../InstCombine/InstructionCombining.cpp | 40 +++++++------------ .../Transforms/InstCombine/select_frexp.ll | 36 +++++++++++++++-- 2 files changed, 48 insertions(+), 28 deletions(-) diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp index 1c6d24cc541ef1..6586535ab996a3 100644 --- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp +++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp @@ -4048,6 +4048,9 @@ static Value *foldFrexpOfSelect(ExtractValueInst &EV, IntrinsicInst *FrexpCall, SelectInst *SelectInst, InstCombiner::BuilderTy &Builder) { // Helper to fold frexp of select to select of frexp. + + if (!SelectInst->hasOneUse() || !FrexpCall->hasOneUse()) + return nullptr; Value *Cond = SelectInst->getCondition(); Value *TrueVal = SelectInst->getTrueValue(); Value *FalseVal = SelectInst->getFalseValue(); @@ -4077,22 +4080,11 @@ static Value *foldFrexpOfSelect(ExtractValueInst &EV, IntrinsicInst *FrexpCall, int Exp; APFloat Mantissa = frexp(*ConstVal, Exp, APFloat::rmNearestTiesToEven); - Constant *ConstantMantissa; - if (auto *VecTy = dyn_cast(TrueVal->getType())) { - SmallVector Elems( - VecTy->getElementCount().getFixedValue(), - ConstantFP::get(VecTy->getElementType(), Mantissa)); - ConstantMantissa = ConstantVector::get(Elems); - } else { - ConstantMantissa = ConstantFP::get(TrueVal->getType(), Mantissa); - } + Constant *ConstantMantissa = ConstantFP::get(TrueVal->getType(), Mantissa); - Value *NewSel = Builder.CreateSelect( + Value *NewSel = Builder.CreateSelectFMF( Cond, ConstIsTrue ? ConstantMantissa : NewEV, - ConstIsTrue ? NewEV : ConstantMantissa, "select.frexp"); - if (auto *NewSelInst = dyn_cast(NewSel)) - NewSelInst->copyFastMathFlags(SelectInst); - + ConstIsTrue ? NewEV : ConstantMantissa, SelectInst, "select.frexp"); return NewSel; } Instruction *InstCombinerImpl::visitExtractValueInst(ExtractValueInst &EV) { @@ -4104,17 +4096,15 @@ Instruction *InstCombinerImpl::visitExtractValueInst(ExtractValueInst &EV) { if (Value *V = simplifyExtractValueInst(Agg, EV.getIndices(), SQ.getWithInstruction(&EV))) return replaceInstUsesWith(EV, V); - if (EV.getNumIndices() == 1 && EV.getIndices()[0] == 0) { - if (auto *FrexpCall = dyn_cast(Agg)) { - if (FrexpCall->getIntrinsicID() == Intrinsic::frexp) { - if (auto *SelInst = dyn_cast(FrexpCall->getArgOperand(0))) { - if (Value *Result = - foldFrexpOfSelect(EV, FrexpCall, SelInst, Builder)) { - return replaceInstUsesWith(EV, Result); - } - } - } - } + + Value *Cond, *TrueVal, *FalseVal; + if (match(&EV, m_ExtractValue<0>(m_Intrinsic(m_Select( + m_Value(Cond), m_Value(TrueVal), m_Value(FalseVal)))))) { + auto *SelInst = + cast(cast(Agg)->getArgOperand(0)); + if (Value *Result = + foldFrexpOfSelect(EV, cast(Agg), SelInst, Builder)) + return replaceInstUsesWith(EV, Result); } if (InsertValueInst *IV = dyn_cast(Agg)) { // We're extracting from an insertvalue instruction, compare the indices diff --git a/llvm/test/Transforms/InstCombine/select_frexp.ll b/llvm/test/Transforms/InstCombine/select_frexp.ll index d729e7c7005142..d025aedda7170d 100644 --- a/llvm/test/Transforms/InstCombine/select_frexp.ll +++ b/llvm/test/Transforms/InstCombine/select_frexp.ll @@ -40,10 +40,9 @@ define float @test_select_frexp_multi_use(float %x, i1 %cond) { ; CHECK-SAME: float [[X:%.*]], i1 [[COND:%.*]]) { ; CHECK-NEXT: [[SEL:%.*]] = select i1 [[COND]], float 1.000000e+00, float [[X]] ; CHECK-NEXT: call void @use(float [[SEL]]) -; CHECK-NEXT: [[FREXP:%.*]] = call { float, i32 } @llvm.frexp.f32.i32(float [[X]]) +; CHECK-NEXT: [[FREXP:%.*]] = call { float, i32 } @llvm.frexp.f32.i32(float [[SEL]]) ; CHECK-NEXT: [[FREXP_0:%.*]] = extractvalue { float, i32 } [[FREXP]], 0 -; CHECK-NEXT: [[SELECT_FREXP:%.*]] = select i1 [[COND]], float 5.000000e-01, float [[FREXP_0]] -; CHECK-NEXT: ret float [[SELECT_FREXP]] +; CHECK-NEXT: ret float [[FREXP_0]] ; %sel = select i1 %cond, float 1.000000e+00, float %x call void @use(float %sel) @@ -158,4 +157,35 @@ define <2 x float> @test_select_frexp_vec_fast_math(<2 x float> %x, <2 x i1> %co ret <2 x float> %frexp.0 } +; Test with scalable vectors with constant at True Position +define @test_select_frexp_scalable_vec0( %x, %cond) { +; CHECK-LABEL: define @test_select_frexp_scalable_vec0( +; CHECK-SAME: [[X:%.*]], [[COND:%.*]]) { +; CHECK-NEXT: [[FREXP1:%.*]] = call { , } @llvm.frexp.nxv2f32.nxv2i32( [[X]]) +; CHECK-NEXT: [[MANTISSA:%.*]] = extractvalue { , } [[FREXP1]], 0 +; CHECK-NEXT: [[SELECT_FREXP:%.*]] = select [[COND]], splat (float 5.000000e-01), [[MANTISSA]] +; CHECK-NEXT: ret [[SELECT_FREXP]] +; + %sel = select %cond, splat (float 1.000000e+00), %x + %frexp = call { , } @llvm.frexp.nxv2f32.nxv2i32( %sel) + %frexp.0 = extractvalue { , } %frexp, 0 + ret %frexp.0 +} + +; Test with scalable vectors with constant at False Position +define @test_select_frexp_scalable_vec1( %x, %cond) { +; CHECK-LABEL: define @test_select_frexp_scalable_vec1( +; CHECK-SAME: [[X:%.*]], [[COND:%.*]]) { +; CHECK-NEXT: [[FREXP1:%.*]] = call { , } @llvm.frexp.nxv2f32.nxv2i32( [[X]]) +; CHECK-NEXT: [[MANTISSA:%.*]] = extractvalue { , } [[FREXP1]], 0 +; CHECK-NEXT: [[SELECT_FREXP:%.*]] = select [[COND]], [[MANTISSA]], splat (float 5.000000e-01) +; CHECK-NEXT: ret [[SELECT_FREXP]] +; + %sel = select %cond, %x, splat (float 1.000000e+00) + %frexp = call { , } @llvm.frexp.nxv2f32.nxv2i32( %sel) + %frexp.0 = extractvalue { , } %frexp, 0 + ret %frexp.0 +} + declare { <2 x float>, <2 x i32> } @llvm.frexp.v2f32.v2i32(<2 x float>) +declare { , } @llvm.frexp.nxv2f32.nxv2i32()