Skip to content

[AArch64] Allow commuting cmn #150514

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 23 additions & 24 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3455,7 +3455,8 @@ static SDValue emitStrictFPComparison(SDValue LHS, SDValue RHS, const SDLoc &DL,
}

static SDValue emitComparison(SDValue LHS, SDValue RHS, ISD::CondCode CC,
const SDLoc &DL, SelectionDAG &DAG) {
AArch64CC::CondCode &OutCC, const SDLoc &DL,
SelectionDAG &DAG) {
EVT VT = LHS.getValueType();
const bool FullFP16 = DAG.getSubtarget<AArch64Subtarget>().hasFullFP16();

Expand All @@ -3478,12 +3479,12 @@ static SDValue emitComparison(SDValue LHS, SDValue RHS, ISD::CondCode CC,
// Can we combine a (CMP op1, (sub 0, op2) into a CMN instruction ?
Opcode = AArch64ISD::ADDS;
RHS = RHS.getOperand(1);
} else if (LHS.getOpcode() == ISD::SUB && isNullConstant(LHS.getOperand(0)) &&
isIntEqualitySetCC(CC)) {
} else if (isCMN(LHS, CC, DAG)) {
// As we are looking for EQ/NE compares, the operands can be commuted ; can
// we combine a (CMP (sub 0, op1), op2) into a CMN instruction ?
Opcode = AArch64ISD::ADDS;
LHS = LHS.getOperand(1);
OutCC = getSwappedCondition(OutCC);
} else if (isNullConstant(RHS) && !isUnsignedIntSetCC(CC)) {
if (LHS.getOpcode() == ISD::AND) {
// Similarly, (CMP (and X, Y), 0) can be implemented with a TST
Expand Down Expand Up @@ -3561,7 +3562,7 @@ static SDValue emitComparison(SDValue LHS, SDValue RHS, ISD::CondCode CC,
static SDValue emitConditionalComparison(SDValue LHS, SDValue RHS,
ISD::CondCode CC, SDValue CCOp,
AArch64CC::CondCode Predicate,
AArch64CC::CondCode OutCC,
AArch64CC::CondCode &OutCC,
const SDLoc &DL, SelectionDAG &DAG) {
unsigned Opcode = 0;
const bool FullFP16 = DAG.getSubtarget<AArch64Subtarget>().hasFullFP16();
Expand All @@ -3583,12 +3584,12 @@ static SDValue emitConditionalComparison(SDValue LHS, SDValue RHS,
} else if (isCMN(RHS, CC, DAG)) {
Opcode = AArch64ISD::CCMN;
RHS = RHS.getOperand(1);
} else if (LHS.getOpcode() == ISD::SUB && isNullConstant(LHS.getOperand(0)) &&
isIntEqualitySetCC(CC)) {
} else if (isCMN(LHS, CC, DAG)) {
// As we are looking for EQ/NE compares, the operands can be commuted ; can
// we combine a (CCMP (sub 0, op1), op2) into a CCMN instruction ?
Opcode = AArch64ISD::CCMN;
LHS = LHS.getOperand(1);
OutCC = getSwappedCondition(OutCC);
}
if (Opcode == 0)
Opcode = AArch64ISD::CCMP;
Expand Down Expand Up @@ -3701,7 +3702,7 @@ static SDValue emitConjunctionRec(SelectionDAG &DAG, SDValue Val,
if (ExtraCC != AArch64CC::AL) {
SDValue ExtraCmp;
if (!CCOp.getNode())
ExtraCmp = emitComparison(LHS, RHS, CC, DL, DAG);
ExtraCmp = emitComparison(LHS, RHS, CC, ExtraCC, DL, DAG);
else
ExtraCmp = emitConditionalComparison(LHS, RHS, CC, CCOp, Predicate,
ExtraCC, DL, DAG);
Expand All @@ -3712,7 +3713,7 @@ static SDValue emitConjunctionRec(SelectionDAG &DAG, SDValue Val,

// Produce a normal comparison if we are first in the chain
if (!CCOp)
return emitComparison(LHS, RHS, CC, DL, DAG);
return emitComparison(LHS, RHS, CC, OutCC, DL, DAG);
// Otherwise produce a ccmp.
return emitConditionalComparison(LHS, RHS, CC, CCOp, Predicate, OutCC, DL,
DAG);
Expand Down Expand Up @@ -3929,13 +3930,11 @@ static SDValue getAArch64Cmp(SDValue LHS, SDValue RHS, ISD::CondCode CC,
// can be turned into:
// cmp w12, w11, lsl #1
if (!isa<ConstantSDNode>(RHS) || !isLegalCmpImmed(RHS->getAsAPIntVal())) {
bool LHSIsCMN = isCMN(LHS, CC, DAG);
bool RHSIsCMN = isCMN(RHS, CC, DAG);
SDValue TheLHS = LHSIsCMN ? LHS.getOperand(1) : LHS;
SDValue TheRHS = RHSIsCMN ? RHS.getOperand(1) : RHS;
SDValue TheLHS = isCMN(LHS, CC, DAG) ? LHS.getOperand(1) : LHS;
SDValue TheRHS = isCMN(RHS, CC, DAG) ? RHS.getOperand(1) : RHS;

if (getCmpOperandFoldingProfit(TheLHS) + (LHSIsCMN ? 1 : 0) >
getCmpOperandFoldingProfit(TheRHS) + (RHSIsCMN ? 1 : 0)) {
if (getCmpOperandFoldingProfit(TheLHS) >
getCmpOperandFoldingProfit(TheRHS)) {
std::swap(LHS, RHS);
CC = ISD::getSetCCSwappedOperands(CC);
}
Expand Down Expand Up @@ -3971,10 +3970,11 @@ static SDValue getAArch64Cmp(SDValue LHS, SDValue RHS, ISD::CondCode CC,
SDValue SExt =
DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, LHS.getValueType(), LHS,
DAG.getValueType(MVT::i16));

AArch64CC = changeIntCCToAArch64CC(CC);
Cmp = emitComparison(
SExt, DAG.getSignedConstant(ValueofRHS, DL, RHS.getValueType()), CC,
DL, DAG);
AArch64CC = changeIntCCToAArch64CC(CC);
AArch64CC, DL, DAG);
}
}

Expand All @@ -3987,8 +3987,8 @@ static SDValue getAArch64Cmp(SDValue LHS, SDValue RHS, ISD::CondCode CC,
}

if (!Cmp) {
Cmp = emitComparison(LHS, RHS, CC, DL, DAG);
AArch64CC = changeIntCCToAArch64CC(CC);
Cmp = emitComparison(LHS, RHS, CC, AArch64CC, DL, DAG);
}
AArch64cc = DAG.getConstant(AArch64CC, DL, MVT_CC);
return Cmp;
Expand Down Expand Up @@ -10574,8 +10574,8 @@ SDValue AArch64TargetLowering::LowerBR_CC(SDValue Op, SelectionDAG &DAG) const {

// Unfortunately, the mapping of LLVM FP CC's onto AArch64 CC's isn't totally
// clean. Some of them require two branches to implement.
SDValue Cmp = emitComparison(LHS, RHS, CC, DL, DAG);
AArch64CC::CondCode CC1, CC2;
AArch64CC::CondCode CC1 = AArch64CC::AL, CC2;
SDValue Cmp = emitComparison(LHS, RHS, CC, CC1, DL, DAG);
changeFPCCToAArch64CC(CC, CC1, CC2);
SDValue CC1Val = DAG.getConstant(CC1, DL, MVT::i32);
SDValue BR1 =
Expand Down Expand Up @@ -11059,12 +11059,12 @@ SDValue AArch64TargetLowering::LowerSETCC(SDValue Op, SelectionDAG &DAG) const {
// If that fails, we'll need to perform an FCMP + CSEL sequence. Go ahead
// and do the comparison.
SDValue Cmp;
AArch64CC::CondCode CC1 = AArch64CC::AL, CC2;
if (IsStrict)
Cmp = emitStrictFPComparison(LHS, RHS, DL, DAG, Chain, IsSignaling);
else
Cmp = emitComparison(LHS, RHS, CC, DL, DAG);
Cmp = emitComparison(LHS, RHS, CC, CC1, DL, DAG);

AArch64CC::CondCode CC1, CC2;
changeFPCCToAArch64CC(CC, CC1, CC2);
SDValue Res;
if (CC2 == AArch64CC::AL) {
Expand Down Expand Up @@ -11444,12 +11444,11 @@ SDValue AArch64TargetLowering::LowerSELECT_CC(
if (VectorCmp)
return VectorCmp;
}

SDValue Cmp = emitComparison(LHS, RHS, CC, DL, DAG);
AArch64CC::CondCode CC1 = AArch64CC::AL, CC2;
SDValue Cmp = emitComparison(LHS, RHS, CC, CC1, DL, DAG);

// Unfortunately, the mapping of LLVM FP CC's onto AArch64 CC's isn't totally
// clean. Some of them require two CSELs to implement.
AArch64CC::CondCode CC1, CC2;
changeFPCCToAArch64CC(CC, CC1, CC2);

if (DAG.getTarget().Options.UnsafeFPMath) {
Expand Down
4 changes: 3 additions & 1 deletion llvm/lib/Target/AArch64/GISel/AArch64GlobalISelUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,10 @@ bool AArch64GISelUtils::isCMN(const MachineInstr *MaybeSub,
//
// %sub = G_SUB 0, %y
// %cmp = G_ICMP eq/ne, %z, %sub
// or with signed comparisons with the no-signed-wrap flag set
if (!MaybeSub || MaybeSub->getOpcode() != TargetOpcode::G_SUB ||
!CmpInst::isEquality(Pred))
(!CmpInst::isEquality(Pred) &&
!(CmpInst::isSigned(Pred) && MaybeSub->getFlag(MachineInstr::NoSWrap))))
return false;
auto MaybeZero =
getIConstantVRegValWithLookThrough(MaybeSub->getOperand(1).getReg(), MRI);
Expand Down
74 changes: 50 additions & 24 deletions llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ class AArch64InstructionSelector : public InstructionSelector {
MachineInstr *emitConditionalComparison(Register LHS, Register RHS,
CmpInst::Predicate CC,
AArch64CC::CondCode Predicate,
AArch64CC::CondCode OutCC,
AArch64CC::CondCode &OutCC,
MachineIRBuilder &MIB) const;
MachineInstr *emitConjunctionRec(Register Val, AArch64CC::CondCode &OutCC,
bool Negate, Register CCOp,
Expand Down Expand Up @@ -1810,7 +1810,7 @@ bool AArch64InstructionSelector::selectCompareBranchFedByICmp(

// Couldn't optimize. Emit a compare + a Bcc.
MachineBasicBlock *DestMBB = I.getOperand(1).getMBB();
auto PredOp = ICmp.getOperand(1);
auto &PredOp = ICmp.getOperand(1);
emitIntegerCompare(ICmp.getOperand(2), ICmp.getOperand(3), PredOp, MIB);
const AArch64CC::CondCode CC = changeICMPPredToAArch64CC(
static_cast<CmpInst::Predicate>(PredOp.getPredicate()));
Expand Down Expand Up @@ -2506,12 +2506,12 @@ bool AArch64InstructionSelector::earlySelect(MachineInstr &I) {
return false;
}
auto &PredOp = Cmp->getOperand(1);
auto Pred = static_cast<CmpInst::Predicate>(PredOp.getPredicate());
const AArch64CC::CondCode InvCC =
changeICMPPredToAArch64CC(CmpInst::getInversePredicate(Pred));
MIB.setInstrAndDebugLoc(I);
emitIntegerCompare(/*LHS=*/Cmp->getOperand(2),
/*RHS=*/Cmp->getOperand(3), PredOp, MIB);
auto Pred = static_cast<CmpInst::Predicate>(PredOp.getPredicate());
const AArch64CC::CondCode InvCC =
changeICMPPredToAArch64CC(CmpInst::getInversePredicate(Pred));
emitCSINC(/*Dst=*/AddDst, /*Src =*/AddLHS, /*Src2=*/AddLHS, InvCC, MIB);
I.eraseFromParent();
return true;
Expand Down Expand Up @@ -3574,10 +3574,11 @@ bool AArch64InstructionSelector::select(MachineInstr &I) {
return false;
}

auto Pred = static_cast<CmpInst::Predicate>(I.getOperand(1).getPredicate());
auto &PredOp = I.getOperand(1);
emitIntegerCompare(I.getOperand(2), I.getOperand(3), PredOp, MIB);
auto Pred = static_cast<CmpInst::Predicate>(PredOp.getPredicate());
const AArch64CC::CondCode InvCC =
changeICMPPredToAArch64CC(CmpInst::getInversePredicate(Pred));
emitIntegerCompare(I.getOperand(2), I.getOperand(3), I.getOperand(1), MIB);
emitCSINC(/*Dst=*/I.getOperand(0).getReg(), /*Src1=*/AArch64::WZR,
/*Src2=*/AArch64::WZR, InvCC, MIB);
I.eraseFromParent();
Expand Down Expand Up @@ -4868,7 +4869,7 @@ static bool canEmitConjunction(Register Val, bool &CanNegate, bool &MustBeFirst,

MachineInstr *AArch64InstructionSelector::emitConditionalComparison(
Register LHS, Register RHS, CmpInst::Predicate CC,
AArch64CC::CondCode Predicate, AArch64CC::CondCode OutCC,
AArch64CC::CondCode Predicate, AArch64CC::CondCode &OutCC,
MachineIRBuilder &MIB) const {
auto &MRI = *MIB.getMRI();
LLT OpTy = MRI.getType(LHS);
Expand All @@ -4877,7 +4878,25 @@ MachineInstr *AArch64InstructionSelector::emitConditionalComparison(
if (CmpInst::isIntPredicate(CC)) {
assert(OpTy.getSizeInBits() == 32 || OpTy.getSizeInBits() == 64);
C = getIConstantVRegValWithLookThrough(RHS, MRI);
if (!C || C->Value.sgt(31) || C->Value.slt(-31))
if (!C) {
MachineInstr *Def = getDefIgnoringCopies(RHS, MRI);
if (isCMN(Def, CC, MRI)) {
RHS = Def->getOperand(2).getReg();
CCmpOpc =
OpTy.getSizeInBits() == 32 ? AArch64::CCMNWr : AArch64::CCMNXr;
} else {
Def = getDefIgnoringCopies(LHS, MRI);
if (isCMN(Def, CC, MRI)) {
LHS = Def->getOperand(2).getReg();
OutCC = getSwappedCondition(OutCC);
CCmpOpc =
OpTy.getSizeInBits() == 32 ? AArch64::CCMNWr : AArch64::CCMNXr;
} else {
CCmpOpc =
OpTy.getSizeInBits() == 32 ? AArch64::CCMPWr : AArch64::CCMPXr;
}
}
} else if (C->Value.sgt(31) || C->Value.slt(-31))
CCmpOpc = OpTy.getSizeInBits() == 32 ? AArch64::CCMPWr : AArch64::CCMPXr;
else if (C->Value.ule(31))
CCmpOpc = OpTy.getSizeInBits() == 32 ? AArch64::CCMPWi : AArch64::CCMPXi;
Expand All @@ -4903,8 +4922,7 @@ MachineInstr *AArch64InstructionSelector::emitConditionalComparison(
}
AArch64CC::CondCode InvOutCC = AArch64CC::getInvertedCondCode(OutCC);
unsigned NZCV = AArch64CC::getNZCVToSatisfyCondCode(InvOutCC);
auto CCmp =
MIB.buildInstr(CCmpOpc, {}, {LHS});
auto CCmp = MIB.buildInstr(CCmpOpc, {}, {LHS});
if (CCmpOpc == AArch64::CCMPWi || CCmpOpc == AArch64::CCMPXi)
CCmp.addImm(C->Value.getZExtValue());
else if (CCmpOpc == AArch64::CCMNWi || CCmpOpc == AArch64::CCMNXi)
Expand Down Expand Up @@ -5096,11 +5114,11 @@ bool AArch64InstructionSelector::tryOptSelect(GSelect &I) {

AArch64CC::CondCode CondCode;
if (CondOpc == TargetOpcode::G_ICMP) {
auto Pred =
static_cast<CmpInst::Predicate>(CondDef->getOperand(1).getPredicate());
auto &PredOp = CondDef->getOperand(1);
emitIntegerCompare(CondDef->getOperand(2), CondDef->getOperand(3), PredOp,
MIB);
auto Pred = static_cast<CmpInst::Predicate>(PredOp.getPredicate());
CondCode = changeICMPPredToAArch64CC(Pred);
emitIntegerCompare(CondDef->getOperand(2), CondDef->getOperand(3),
CondDef->getOperand(1), MIB);
} else {
// Get the condition code for the select.
auto Pred =
Expand Down Expand Up @@ -5148,29 +5166,37 @@ MachineInstr *AArch64InstructionSelector::tryFoldIntegerCompare(
MachineInstr *LHSDef = getDefIgnoringCopies(LHS.getReg(), MRI);
MachineInstr *RHSDef = getDefIgnoringCopies(RHS.getReg(), MRI);
auto P = static_cast<CmpInst::Predicate>(Predicate.getPredicate());

// Given this:
//
// x = G_SUB 0, y
// G_ICMP x, z
// G_ICMP z, x
//
// Produce this:
//
// cmn y, z
if (isCMN(LHSDef, P, MRI))
return emitCMN(LHSDef->getOperand(2), RHS, MIRBuilder);
// cmn z, y
if (isCMN(RHSDef, P, MRI))
return emitCMN(LHS, RHSDef->getOperand(2), MIRBuilder);

// Same idea here, but with the RHS of the compare instead:
// Same idea here, but with the LHS of the compare instead:
//
// Given this:
//
// x = G_SUB 0, y
// G_ICMP z, x
// G_ICMP x, z
//
// Produce this:
//
// cmn z, y
if (isCMN(RHSDef, P, MRI))
return emitCMN(LHS, RHSDef->getOperand(2), MIRBuilder);
// cmn y, z
//
// But be careful! We need to swap the predicate!
if (isCMN(LHSDef, P, MRI)) {
if (!CmpInst::isEquality(P)) {
P = CmpInst::getSwappedPredicate(P);
Predicate = MachineOperand::CreatePredicate(P);
}
return emitCMN(LHSDef->getOperand(2), RHS, MIRBuilder);
}

// Given this:
//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -667,11 +667,10 @@ body: |
; SELECT-NEXT: {{ $}}
; SELECT-NEXT: %zero:gpr64 = COPY $xzr
; SELECT-NEXT: %reg0:gpr64 = COPY $x0
; SELECT-NEXT: %shl:gpr64 = UBFMXri %reg0, 1, 0
; SELECT-NEXT: %cmp_lhs:gpr64 = SUBSXrs %zero, %reg0, 63, implicit-def dead $nzcv
; SELECT-NEXT: %reg1:gpr64 = COPY $x1
; SELECT-NEXT: %sext_in_reg:gpr64 = SBFMXri %reg1, 0, 0
; SELECT-NEXT: %cmp_rhs:gpr64 = SUBSXrs %zero, %sext_in_reg, 131, implicit-def dead $nzcv
; SELECT-NEXT: [[ADDSXrr:%[0-9]+]]:gpr64 = ADDSXrr %shl, %cmp_rhs, implicit-def $nzcv
; SELECT-NEXT: [[ADDSXrs:%[0-9]+]]:gpr64 = ADDSXrs %cmp_lhs, %sext_in_reg, 131, implicit-def $nzcv
; SELECT-NEXT: %cmp:gpr32 = CSINCWr $wzr, $wzr, 1, implicit $nzcv
; SELECT-NEXT: $w0 = COPY %cmp
; SELECT-NEXT: RET_ReallyLR implicit $w0
Expand Down
16 changes: 7 additions & 9 deletions llvm/test/CodeGen/AArch64/cmp-chains.ll
Original file line number Diff line number Diff line change
Expand Up @@ -270,14 +270,13 @@ define i32 @neg_range_int_comp(i32 %a, i32 %b, i32 %c, i32 %d) {
;
; GISEL-LABEL: neg_range_int_comp:
; GISEL: // %bb.0:
; GISEL-NEXT: orr w8, w3, #0x1
; GISEL-NEXT: cmp w0, w2
; GISEL-NEXT: neg w8, w8
; GISEL-NEXT: ccmp w1, w8, #4, lt
; GISEL-NEXT: orr w8, w3, #0x1
; GISEL-NEXT: ccmn w1, w8, #4, lt
; GISEL-NEXT: csel w0, w1, w0, gt
; GISEL-NEXT: ret
%dor = or i32 %d, 1
%negd = sub i32 0, %dor
%negd = sub nsw i32 0, %dor
%cmp = icmp sgt i32 %b, %negd
%cmp1 = icmp slt i32 %a, %c
%or.cond = and i1 %cmp, %cmp1
Expand Down Expand Up @@ -373,14 +372,13 @@ define i32 @neg_range_int_comp2(i32 %a, i32 %b, i32 %c, i32 %d) {
;
; GISEL-LABEL: neg_range_int_comp2:
; GISEL: // %bb.0:
; GISEL-NEXT: orr w8, w3, #0x1
; GISEL-NEXT: cmp w0, w2
; GISEL-NEXT: neg w8, w8
; GISEL-NEXT: ccmp w1, w8, #0, ge
; GISEL-NEXT: orr w8, w3, #0x1
; GISEL-NEXT: ccmn w1, w8, #0, ge
; GISEL-NEXT: csel w0, w1, w0, lt
; GISEL-NEXT: ret
%dor = or i32 %d, 1
%negd = sub i32 0, %dor
%negd = sub nsw i32 0, %dor
%cmp = icmp slt i32 %b, %negd
%cmp1 = icmp sge i32 %a, %c
%or.cond = and i1 %cmp, %cmp1
Expand All @@ -407,7 +405,7 @@ define i32 @neg_range_int_comp_u2(i32 %a, i32 %b, i32 %c, i32 %d) {
; GISEL-NEXT: csel w0, w1, w0, lo
; GISEL-NEXT: ret
%dor = or i32 %d, 1
%negd = sub i32 0, %dor
%negd = sub nsw i32 0, %dor
%cmp = icmp ult i32 %b, %negd
%cmp1 = icmp sgt i32 %a, %c
%or.cond = and i1 %cmp, %cmp1
Expand Down
Loading