Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[InstCombine] InstCombine should fold frexp of select to select of frexp #121227

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
240 changes: 121 additions & 119 deletions llvm/include/llvm/IR/IRBuilder.h

Large diffs are not rendered by default.

59 changes: 30 additions & 29 deletions llvm/lib/IR/IRBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,11 @@ void IRBuilderBase::SetInstDebugLocation(Instruction *I) const {

CallInst *
IRBuilderBase::createCallHelper(Function *Callee, ArrayRef<Value *> Ops,
const Twine &Name, Instruction *FMFSource,
const Twine &Name, FMFSource FMFSource,
ArrayRef<OperandBundleDef> OpBundles) {
CallInst *CI = CreateCall(Callee, Ops, OpBundles, Name);
if (FMFSource)
CI->copyFastMathFlags(FMFSource);
if (isa<FPMathOperator>(CI))
CI->setFastMathFlags(FMFSource.get(FMF));
return CI;
}

Expand Down Expand Up @@ -869,28 +869,28 @@ CallInst *IRBuilderBase::CreateGCGetPointerOffset(Value *DerivedPtr,
}

CallInst *IRBuilderBase::CreateUnaryIntrinsic(Intrinsic::ID ID, Value *V,
Instruction *FMFSource,
FMFSource FMFSource,
const Twine &Name) {
Module *M = BB->getModule();
Function *Fn = Intrinsic::getOrInsertDeclaration(M, ID, {V->getType()});
return createCallHelper(Fn, {V}, Name, FMFSource);
}

Value *IRBuilderBase::CreateBinaryIntrinsic(Intrinsic::ID ID, Value *LHS,
Value *RHS, Instruction *FMFSource,
Value *RHS, FMFSource FMFSource,
const Twine &Name) {
Module *M = BB->getModule();
Function *Fn = Intrinsic::getOrInsertDeclaration(M, ID, {LHS->getType()});
if (Value *V = Folder.FoldBinaryIntrinsic(ID, LHS, RHS, Fn->getReturnType(),
FMFSource))
/*FMFSource=*/nullptr))
return V;
return createCallHelper(Fn, {LHS, RHS}, Name, FMFSource);
}

CallInst *IRBuilderBase::CreateIntrinsic(Intrinsic::ID ID,
ArrayRef<Type *> Types,
ArrayRef<Value *> Args,
Instruction *FMFSource,
FMFSource FMFSource,
const Twine &Name) {
Module *M = BB->getModule();
Function *Fn = Intrinsic::getOrInsertDeclaration(M, ID, Types);
Expand All @@ -899,7 +899,7 @@ CallInst *IRBuilderBase::CreateIntrinsic(Intrinsic::ID ID,

CallInst *IRBuilderBase::CreateIntrinsic(Type *RetTy, Intrinsic::ID ID,
ArrayRef<Value *> Args,
Instruction *FMFSource,
FMFSource FMFSource,
const Twine &Name) {
Module *M = BB->getModule();

Expand All @@ -925,16 +925,13 @@ CallInst *IRBuilderBase::CreateIntrinsic(Type *RetTy, Intrinsic::ID ID,
}

CallInst *IRBuilderBase::CreateConstrainedFPBinOp(
Intrinsic::ID ID, Value *L, Value *R, Instruction *FMFSource,
const Twine &Name, MDNode *FPMathTag,
std::optional<RoundingMode> Rounding,
Intrinsic::ID ID, Value *L, Value *R, FMFSource FMFSource,
const Twine &Name, MDNode *FPMathTag, std::optional<RoundingMode> Rounding,
std::optional<fp::ExceptionBehavior> Except) {
Value *RoundingV = getConstrainedFPRounding(Rounding);
Value *ExceptV = getConstrainedFPExcept(Except);

FastMathFlags UseFMF = FMF;
if (FMFSource)
UseFMF = FMFSource->getFastMathFlags();
FastMathFlags UseFMF = FMFSource.get(FMF);

CallInst *C = CreateIntrinsic(ID, {L->getType()},
{L, R, RoundingV, ExceptV}, nullptr, Name);
Expand All @@ -944,14 +941,12 @@ CallInst *IRBuilderBase::CreateConstrainedFPBinOp(
}

CallInst *IRBuilderBase::CreateConstrainedFPUnroundedBinOp(
Intrinsic::ID ID, Value *L, Value *R, Instruction *FMFSource,
Intrinsic::ID ID, Value *L, Value *R, FMFSource FMFSource,
const Twine &Name, MDNode *FPMathTag,
std::optional<fp::ExceptionBehavior> Except) {
Value *ExceptV = getConstrainedFPExcept(Except);

FastMathFlags UseFMF = FMF;
if (FMFSource)
UseFMF = FMFSource->getFastMathFlags();
FastMathFlags UseFMF = FMFSource.get(FMF);

CallInst *C =
CreateIntrinsic(ID, {L->getType()}, {L, R, ExceptV}, nullptr, Name);
Expand All @@ -976,15 +971,12 @@ Value *IRBuilderBase::CreateNAryOp(unsigned Opc, ArrayRef<Value *> Ops,
}

CallInst *IRBuilderBase::CreateConstrainedFPCast(
Intrinsic::ID ID, Value *V, Type *DestTy,
Instruction *FMFSource, const Twine &Name, MDNode *FPMathTag,
std::optional<RoundingMode> Rounding,
Intrinsic::ID ID, Value *V, Type *DestTy, FMFSource FMFSource,
const Twine &Name, MDNode *FPMathTag, std::optional<RoundingMode> Rounding,
std::optional<fp::ExceptionBehavior> Except) {
Value *ExceptV = getConstrainedFPExcept(Except);

FastMathFlags UseFMF = FMF;
if (FMFSource)
UseFMF = FMFSource->getFastMathFlags();
FastMathFlags UseFMF = FMFSource.get(FMF);

CallInst *C;
if (Intrinsic::hasConstrainedFPRoundingModeOperand(ID)) {
Expand All @@ -1002,9 +994,10 @@ CallInst *IRBuilderBase::CreateConstrainedFPCast(
return C;
}

Value *IRBuilderBase::CreateFCmpHelper(
CmpInst::Predicate P, Value *LHS, Value *RHS, const Twine &Name,
MDNode *FPMathTag, bool IsSignaling) {
Value *IRBuilderBase::CreateFCmpHelper(CmpInst::Predicate P, Value *LHS,
Value *RHS, const Twine &Name,
MDNode *FPMathTag, FMFSource FMFSource,
bool IsSignaling) {
if (IsFPConstrained) {
auto ID = IsSignaling ? Intrinsic::experimental_constrained_fcmps
: Intrinsic::experimental_constrained_fcmp;
Expand All @@ -1013,7 +1006,9 @@ Value *IRBuilderBase::CreateFCmpHelper(

if (auto *V = Folder.FoldCmp(P, LHS, RHS))
return V;
return Insert(setFPAttrs(new FCmpInst(P, LHS, RHS), FPMathTag, FMF), Name);
return Insert(
setFPAttrs(new FCmpInst(P, LHS, RHS), FPMathTag, FMFSource.get(FMF)),
Name);
}

CallInst *IRBuilderBase::CreateConstrainedFPCmp(
Expand Down Expand Up @@ -1047,6 +1042,12 @@ CallInst *IRBuilderBase::CreateConstrainedFPCall(

Value *IRBuilderBase::CreateSelect(Value *C, Value *True, Value *False,
const Twine &Name, Instruction *MDFrom) {
return CreateSelectFMF(C, True, False, {}, Name, MDFrom);
}

Value *IRBuilderBase::CreateSelectFMF(Value *C, Value *True, Value *False,
FMFSource FMFSource, const Twine &Name,
Instruction *MDFrom) {
if (auto *V = Folder.FoldSelect(C, True, False))
return V;

Expand All @@ -1057,7 +1058,7 @@ Value *IRBuilderBase::CreateSelect(Value *C, Value *True, Value *False,
Sel = addBranchMetadata(Sel, Prof, Unpred);
}
if (isa<FPMathOperator>(Sel))
setFPAttrs(Sel, nullptr /* MDNode* */, FMF);
setFPAttrs(Sel, /*MDNode=*/nullptr, FMFSource.get(FMF));
return Insert(Sel, Name);
}

Expand Down
6 changes: 2 additions & 4 deletions llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1638,10 +1638,8 @@ instCombineSVEVectorBinOp(InstCombiner &IC, IntrinsicInst &II) {
!match(OpPredicate, m_Intrinsic<Intrinsic::aarch64_sve_ptrue>(
m_ConstantInt<AArch64SVEPredPattern::all>())))
return std::nullopt;
IRBuilderBase::FastMathFlagGuard FMFGuard(IC.Builder);
IC.Builder.setFastMathFlags(II.getFastMathFlags());
auto BinOp =
IC.Builder.CreateBinOp(BinOpCode, II.getOperand(1), II.getOperand(2));
auto BinOp = IC.Builder.CreateBinOpFMF(
BinOpCode, II.getOperand(1), II.getOperand(2), II.getFastMathFlags());
return IC.replaceInstUsesWith(II, BinOp);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -425,11 +425,8 @@ static bool foldSqrt(CallInst *Call, LibFunc Func, TargetTransformInfo &TTI,
Arg, 0,
SimplifyQuery(Call->getDataLayout(), &TLI, &DT, &AC, Call)))) {
IRBuilder<> Builder(Call);
IRBuilderBase::FastMathFlagGuard Guard(Builder);
Builder.setFastMathFlags(Call->getFastMathFlags());

Value *NewSqrt = Builder.CreateIntrinsic(Intrinsic::sqrt, Ty, Arg,
/*FMFSource=*/nullptr, "sqrt");
Value *NewSqrt =
Builder.CreateIntrinsic(Intrinsic::sqrt, Ty, Arg, Call, "sqrt");
Call->replaceAllUsesWith(NewSqrt);

// Explicitly erase the old call because a call with side effects is not
Expand Down
19 changes: 7 additions & 12 deletions llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2845,12 +2845,11 @@ Instruction *InstCombinerImpl::hoistFNegAboveFMulFDiv(Value *FNegOp,
// Make sure to preserve flags and metadata on the call.
if (II->getIntrinsicID() == Intrinsic::ldexp) {
FastMathFlags FMF = FMFSource.getFastMathFlags() | II->getFastMathFlags();
IRBuilder<>::FastMathFlagGuard FMFGuard(Builder);
Builder.setFastMathFlags(FMF);

CallInst *New = Builder.CreateCall(
II->getCalledFunction(),
{Builder.CreateFNeg(II->getArgOperand(0)), II->getArgOperand(1)});
CallInst *New =
Builder.CreateCall(II->getCalledFunction(),
{Builder.CreateFNegFMF(II->getArgOperand(0), FMF),
II->getArgOperand(1)});
New->setFastMathFlags(FMF);
New->copyMetadata(*II);
return New;
}
Expand Down Expand Up @@ -2932,12 +2931,8 @@ Instruction *InstCombinerImpl::visitFNeg(UnaryOperator &I) {
// flags the copysign doesn't also have.
FastMathFlags FMF = I.getFastMathFlags();
FMF &= cast<FPMathOperator>(OneUse)->getFastMathFlags();

IRBuilder<>::FastMathFlagGuard FMFGuard(Builder);
Builder.setFastMathFlags(FMF);

Value *NegY = Builder.CreateFNeg(Y);
Value *NewCopySign = Builder.CreateCopySign(X, NegY);
Value *NegY = Builder.CreateFNegFMF(Y, FMF);
Value *NewCopySign = Builder.CreateCopySign(X, NegY, FMF);
return replaceInstUsesWith(I, NewCopySign);
}

Expand Down
37 changes: 14 additions & 23 deletions llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,12 @@ static Value *getNewICmpValue(unsigned Code, bool Sign, Value *LHS, Value *RHS,
/// This is the complement of getFCmpCode, which turns an opcode and two
/// operands into either a FCmp instruction, or a true/false constant.
static Value *getFCmpValue(unsigned Code, Value *LHS, Value *RHS,
InstCombiner::BuilderTy &Builder) {
InstCombiner::BuilderTy &Builder,
FastMathFlags FMF) {
FCmpInst::Predicate NewPred;
if (Constant *TorF = getPredForFCmpCode(Code, LHS->getType(), NewPred))
return TorF;
return Builder.CreateFCmp(NewPred, LHS, RHS);
return Builder.CreateFCmpFMF(NewPred, LHS, RHS, FMF);
}

/// Emit a computation of: (V >= Lo && V < Hi) if Inside is true, otherwise
Expand Down Expand Up @@ -1429,12 +1430,9 @@ static Value *matchIsFiniteTest(InstCombiner::BuilderTy &Builder, FCmpInst *LHS,
!matchUnorderedInfCompare(PredR, RHS0, RHS1))
return nullptr;

IRBuilder<>::FastMathFlagGuard FMFG(Builder);
FastMathFlags FMF = LHS->getFastMathFlags();
FMF &= RHS->getFastMathFlags();
Builder.setFastMathFlags(FMF);

return Builder.CreateFCmp(FCmpInst::getOrderedPredicate(PredR), RHS0, RHS1);
return Builder.CreateFCmpFMF(FCmpInst::getOrderedPredicate(PredR), RHS0, RHS1,
LHS->getFastMathFlags() &
RHS->getFastMathFlags());
}

Value *InstCombinerImpl::foldLogicOfFCmps(FCmpInst *LHS, FCmpInst *RHS,
Expand Down Expand Up @@ -1470,12 +1468,8 @@ Value *InstCombinerImpl::foldLogicOfFCmps(FCmpInst *LHS, FCmpInst *RHS,

// Intersect the fast math flags.
// TODO: We can union the fast math flags unless this is a logical select.
IRBuilder<>::FastMathFlagGuard FMFG(Builder);
FastMathFlags FMF = LHS->getFastMathFlags();
FMF &= RHS->getFastMathFlags();
Builder.setFastMathFlags(FMF);

return getFCmpValue(NewPred, LHS0, LHS1, Builder);
return getFCmpValue(NewPred, LHS0, LHS1, Builder,
LHS->getFastMathFlags() & RHS->getFastMathFlags());
}

// This transform is not valid for a logical select.
Expand All @@ -1492,10 +1486,8 @@ Value *InstCombinerImpl::foldLogicOfFCmps(FCmpInst *LHS, FCmpInst *RHS,
// Ignore the constants because they are obviously not NANs:
// (fcmp ord x, 0.0) & (fcmp ord y, 0.0) -> (fcmp ord x, y)
// (fcmp uno x, 0.0) | (fcmp uno y, 0.0) -> (fcmp uno x, y)
IRBuilder<>::FastMathFlagGuard FMFG(Builder);
Builder.setFastMathFlags(LHS->getFastMathFlags() &
RHS->getFastMathFlags());
return Builder.CreateFCmp(PredL, LHS0, RHS0);
return Builder.CreateFCmpFMF(
PredL, LHS0, RHS0, LHS->getFastMathFlags() & RHS->getFastMathFlags());
}
}

Expand Down Expand Up @@ -1557,15 +1549,14 @@ Value *InstCombinerImpl::foldLogicOfFCmps(FCmpInst *LHS, FCmpInst *RHS,
std::swap(PredL, PredR);
}
if (IsLessThanOrLessEqual(IsAnd ? PredL : PredR)) {
BuilderTy::FastMathFlagGuard Guard(Builder);
FastMathFlags NewFlag = LHS->getFastMathFlags();
if (!IsLogicalSelect)
NewFlag |= RHS->getFastMathFlags();
Builder.setFastMathFlags(NewFlag);

Value *FAbs = Builder.CreateUnaryIntrinsic(Intrinsic::fabs, LHS0);
return Builder.CreateFCmp(PredL, FAbs,
ConstantFP::get(LHS0->getType(), *LHSC));
Value *FAbs =
Builder.CreateUnaryIntrinsic(Intrinsic::fabs, LHS0, NewFlag);
return Builder.CreateFCmpFMF(
PredL, FAbs, ConstantFP::get(LHS0->getType(), *LHSC), NewFlag);
}
}

Expand Down
16 changes: 8 additions & 8 deletions llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1852,15 +1852,13 @@ Instruction *InstCombinerImpl::visitFPTrunc(FPTruncInst &FPT) {
Value *X;
Instruction *Op = dyn_cast<Instruction>(FPT.getOperand(0));
if (Op && Op->hasOneUse()) {
IRBuilder<>::FastMathFlagGuard FMFG(Builder);
FastMathFlags FMF = FPT.getFastMathFlags();
if (auto *FPMO = dyn_cast<FPMathOperator>(Op))
FMF &= FPMO->getFastMathFlags();
Builder.setFastMathFlags(FMF);

if (match(Op, m_FNeg(m_Value(X)))) {
Value *InnerTrunc = Builder.CreateFPTrunc(X, Ty);
Value *Neg = Builder.CreateFNeg(InnerTrunc);
Value *InnerTrunc = Builder.CreateFPTruncFMF(X, Ty, FMF);
Value *Neg = Builder.CreateFNegFMF(InnerTrunc, FMF);
return replaceInstUsesWith(FPT, Neg);
}

Expand All @@ -1870,15 +1868,17 @@ Instruction *InstCombinerImpl::visitFPTrunc(FPTruncInst &FPT) {
if (match(Op, m_Select(m_Value(Cond), m_FPExt(m_Value(X)), m_Value(Y))) &&
X->getType() == Ty) {
// fptrunc (select Cond, (fpext X), Y --> select Cond, X, (fptrunc Y)
Value *NarrowY = Builder.CreateFPTrunc(Y, Ty);
Value *Sel = Builder.CreateSelect(Cond, X, NarrowY, "narrow.sel", Op);
Value *NarrowY = Builder.CreateFPTruncFMF(Y, Ty, FMF);
Value *Sel =
Builder.CreateSelectFMF(Cond, X, NarrowY, FMF, "narrow.sel", Op);
return replaceInstUsesWith(FPT, Sel);
}
if (match(Op, m_Select(m_Value(Cond), m_Value(Y), m_FPExt(m_Value(X)))) &&
X->getType() == Ty) {
// fptrunc (select Cond, Y, (fpext X) --> select Cond, (fptrunc Y), X
Value *NarrowY = Builder.CreateFPTrunc(Y, Ty);
Value *Sel = Builder.CreateSelect(Cond, NarrowY, X, "narrow.sel", Op);
Value *NarrowY = Builder.CreateFPTruncFMF(Y, Ty, FMF);
Value *Sel =
Builder.CreateSelectFMF(Cond, NarrowY, X, FMF, "narrow.sel", Op);
return replaceInstUsesWith(FPT, Sel);
}
}
Expand Down
Loading
Loading