
Commit 730920c

Refactored canEvaluateShifted to identify candidates for
simplification.
1 parent 8ea6668 commit 730920c

2 files changed: +207 -86 lines changed


llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp

Lines changed: 114 additions & 72 deletions
@@ -530,112 +530,159 @@ Instruction *InstCombinerImpl::commonShiftTransforms(BinaryOperator &I) {
   return nullptr;
 }

-/// Return true if we can simplify two logical (either left or right) shifts
-/// that have constant shift amounts: OuterShift (InnerShift X, C1), C2.
-static bool canEvaluateShiftedShift(unsigned OuterShAmt, bool IsOuterShl,
-                                    Instruction *InnerShift,
-                                    InstCombinerImpl &IC, Instruction *CxtI) {
+/// Return a bitmask of all constant outer shift amounts that can be simplified
+/// by foldShiftedShift().
+static APInt getEvaluableShiftedShiftMask(bool IsOuterShl,
+                                          Instruction *InnerShift,
+                                          InstCombinerImpl &IC,
+                                          Instruction *CxtI) {
   assert(InnerShift->isLogicalShift() && "Unexpected instruction type");

+  const unsigned TypeWidth = InnerShift->getType()->getScalarSizeInBits();
+
   // We need constant scalar or constant splat shifts.
   const APInt *InnerShiftConst;
   if (!match(InnerShift->getOperand(1), m_APInt(InnerShiftConst)))
-    return false;
+    return APInt::getZero(TypeWidth);

-  // Two logical shifts in the same direction:
+  if (InnerShiftConst->uge(TypeWidth))
+    return APInt::getZero(TypeWidth);
+
+  const unsigned InnerShAmt = InnerShiftConst->getZExtValue();
+
+  // Two logical shifts in the same direction can always be simplified, so long
+  // as the total shift amount is legal.
   // shl (shl X, C1), C2 --> shl X, C1 + C2
   // lshr (lshr X, C1), C2 --> lshr X, C1 + C2
   bool IsInnerShl = InnerShift->getOpcode() == Instruction::Shl;
   if (IsInnerShl == IsOuterShl)
-    return true;
+    return APInt::getLowBitsSet(TypeWidth, TypeWidth - InnerShAmt);

+  APInt ShMask = APInt::getZero(TypeWidth);
   // Equal shift amounts in opposite directions become bitwise 'and':
   // lshr (shl X, C), C --> and X, C'
   // shl (lshr X, C), C --> and X, C'
-  if (*InnerShiftConst == OuterShAmt)
-    return true;
+  ShMask.setBit(InnerShAmt);

-  // If the 2nd shift is bigger than the 1st, we can fold:
+  // If the inner shift is bigger than the outer, we can fold:
   // lshr (shl X, C1), C2 --> and (shl X, C1 - C2), C3
   // shl (lshr X, C1), C2 --> and (lshr X, C1 - C2), C3
-  // but it isn't profitable unless we know the and'd out bits are already zero.
-  // Also, check that the inner shift is valid (less than the type width) or
-  // we'll crash trying to produce the bit mask for the 'and'.
-  unsigned TypeWidth = InnerShift->getType()->getScalarSizeInBits();
-  if (InnerShiftConst->ugt(OuterShAmt) && InnerShiftConst->ult(TypeWidth)) {
-    unsigned InnerShAmt = InnerShiftConst->getZExtValue();
-    unsigned MaskShift =
-        IsInnerShl ? TypeWidth - InnerShAmt : InnerShAmt - OuterShAmt;
-    APInt Mask = APInt::getLowBitsSet(TypeWidth, OuterShAmt) << MaskShift;
-    if (IC.MaskedValueIsZero(InnerShift->getOperand(0), Mask, CxtI))
-      return true;
-  }
-
-  return false;
+  // but it isn't profitable unless we know the masked out bits are already
+  // zero.
+  KnownBits Known = IC.computeKnownBits(InnerShift->getOperand(0), CxtI);
+  // Isolate the bits that are annihilated by the inner shift.
+  APInt InnerShMask = IsInnerShl ? Known.Zero.lshr(TypeWidth - InnerShAmt)
+                                 : Known.Zero.trunc(InnerShAmt);
+  // Isolate the upper (resp. lower) InnerShAmt bits of the base operand of the
+  // inner shl (resp. lshr).
+  // Then:
+  // - lshr (shl X, C1), C2 == (shl X, C1 - C2) if the bottom C2 of the isolated
+  //   bits are zero
+  // - shl (lshr X, C1), C2 == (lshr X, C1 - C2) if the top C2 of the isolated
+  //   bits are zero
+  const unsigned MaxOuterShAmt =
+      IsInnerShl ? Known.Zero.lshr(TypeWidth - InnerShAmt).countr_one()
+                 : Known.Zero.trunc(InnerShAmt).countl_one();
+  ShMask.setLowBits(MaxOuterShAmt);
+  return ShMask;
 }

-/// See if we can compute the specified value, but shifted logically to the left
-/// or right by some number of bits. This should return true if the expression
-/// can be computed for the same cost as the current expression tree. This is
-/// used to eliminate extraneous shifting from things like:
-///   %C = shl i128 %A, 64
-///   %D = shl i128 %B, 96
-///   %E = or i128 %C, %D
-///   %F = lshr i128 %E, 64
-/// where the client will ask if E can be computed shifted right by 64-bits. If
-/// this succeeds, getShiftedValue() will be called to produce the value.
-static bool canEvaluateShifted(Value *V, unsigned NumBits, bool IsLeftShift,
-                               InstCombinerImpl &IC, Instruction *CxtI) {
+/// Given a bitmask \p ShiftMask of desired shift amounts, determine the submask
+/// of bits corresponding to shift amounts X for which the given expression \p V
+/// can be computed for at worst the same cost as the current expression tree
+/// when shifted by X. For each set bit in the \p ShiftMask afterward,
+/// getShiftedValue() can produce the corresponding value.
+///
+/// \returns true if and only if at least one bit of the \p ShiftMask is set
+/// after refinement.
+static bool refineEvaluableShiftMask(Value *V, APInt &ShiftMask,
+                                     bool IsLeftShift, InstCombinerImpl &IC,
+                                     Instruction *CxtI) {
   // We can always evaluate immediate constants.
   if (match(V, m_ImmConstant()))
     return true;

   Instruction *I = dyn_cast<Instruction>(V);
-  if (!I) return false;
+  if (!I) {
+    ShiftMask.clearAllBits();
+    return false;
+  }

   // We can't mutate something that has multiple uses: doing so would
   // require duplicating the instruction in general, which isn't profitable.
-  if (!I->hasOneUse()) return false;
+  if (!I->hasOneUse()) {
+    ShiftMask.clearAllBits();
+    return false;
+  }

   switch (I->getOpcode()) {
-  default: return false;
+  default: {
+    ShiftMask.clearAllBits();
+    return false;
+  }
   case Instruction::And:
   case Instruction::Or:
   case Instruction::Xor:
-    // Bitwise operators can all arbitrarily be arbitrarily evaluated shifted.
-    return canEvaluateShifted(I->getOperand(0), NumBits, IsLeftShift, IC, I) &&
-           canEvaluateShifted(I->getOperand(1), NumBits, IsLeftShift, IC, I);
+    return refineEvaluableShiftMask(I->getOperand(0), ShiftMask, IsLeftShift,
+                                    IC, I) &&
+           refineEvaluableShiftMask(I->getOperand(1), ShiftMask, IsLeftShift,
+                                    IC, I);

   case Instruction::Shl:
-  case Instruction::LShr:
-    return canEvaluateShiftedShift(NumBits, IsLeftShift, I, IC, CxtI);
+  case Instruction::LShr: {
+    ShiftMask &= getEvaluableShiftedShiftMask(IsLeftShift, I, IC, CxtI);
+    return !ShiftMask.isZero();
+  }

   case Instruction::Select: {
     SelectInst *SI = cast<SelectInst>(I);
     Value *TrueVal = SI->getTrueValue();
     Value *FalseVal = SI->getFalseValue();
-    return canEvaluateShifted(TrueVal, NumBits, IsLeftShift, IC, SI) &&
-           canEvaluateShifted(FalseVal, NumBits, IsLeftShift, IC, SI);
+    return refineEvaluableShiftMask(TrueVal, ShiftMask, IsLeftShift, IC, SI) &&
+           refineEvaluableShiftMask(FalseVal, ShiftMask, IsLeftShift, IC, SI);
   }
   case Instruction::PHI: {
     // We can change a phi if we can change all operands. Note that we never
     // get into trouble with cyclic PHIs here because we only consider
     // instructions with a single use.
     PHINode *PN = cast<PHINode>(I);
     for (Value *IncValue : PN->incoming_values())
-      if (!canEvaluateShifted(IncValue, NumBits, IsLeftShift, IC, PN))
+      if (!refineEvaluableShiftMask(IncValue, ShiftMask, IsLeftShift, IC, PN))
         return false;
     return true;
   }
   case Instruction::Mul: {
     const APInt *MulConst;
     // We can fold (shr (mul X, -(1 << C)), C) -> (and (neg X), C`)
-    return !IsLeftShift && match(I->getOperand(1), m_APInt(MulConst)) &&
-           MulConst->isNegatedPowerOf2() && MulConst->countr_zero() == NumBits;
+    if (IsLeftShift || !match(I->getOperand(1), m_APInt(MulConst)) ||
+        !MulConst->isNegatedPowerOf2()) {
+      ShiftMask.clearAllBits();
+      return false;
+    }
+    ShiftMask &=
+        APInt::getOneBitSet(ShiftMask.getBitWidth(), MulConst->countr_zero());
+    return !ShiftMask.isZero();
   }
   }
 }

+/// See if we can compute the specified value, but shifted logically to the left
+/// or right by some number of bits. This should return true if the expression
+/// can be computed for the same cost as the current expression tree. This is
+/// used to eliminate extraneous shifting from things like:
+///   %C = shl i128 %A, 64
+///   %D = shl i128 %B, 96
+///   %E = or i128 %C, %D
+///   %F = lshr i128 %E, 64
+/// where the client will ask if E can be computed shifted right by 64-bits. If
+/// this succeeds, getShiftedValue() will be called to produce the value.
+static bool canEvaluateShifted(Value *V, unsigned ShAmt, bool IsLeftShift,
+                               InstCombinerImpl &IC, Instruction *CxtI) {
+  APInt ShiftMask =
+      APInt::getOneBitSet(V->getType()->getScalarSizeInBits(), ShAmt);
+  return refineEvaluableShiftMask(V, ShiftMask, IsLeftShift, IC, CxtI);
+}
+
 /// Fold OuterShift (InnerShift X, C1), C2.
 /// See canEvaluateShiftedShift() for the constraints on these instructions.
 static Value *foldShiftedShift(BinaryOperator *InnerShift, unsigned OuterShAmt,
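The hunk above turns the old yes/no query ("can this value be recomputed shifted by exactly N bits?") into a bitmask of candidate shift amounts that is narrowed while recursing through the operand tree, so a caller can pick any surviving amount. Below is a standalone sketch of that idea, with a toy Expr type and uint64_t standing in for APInt; all names are invented for illustration and this is not the LLVM implementation.

// Minimal model: bit i of the mask set  <=>  the expression can be recomputed
// shifted left by i at no extra cost, cf. refineEvaluableShiftMask().
#include <cstdint>
#include <vector>

struct Expr {
  enum Kind { Const, BitwiseOr, ShlOverLShr } kind;
  unsigned lshrAmt = 0;        // inner lshr amount, used by ShlOverLShr
  std::vector<Expr> ops;       // operands, used by BitwiseOr
};

uint64_t refineMask(const Expr &e, uint64_t mask) {
  switch (e.kind) {
  case Expr::Const:
    return mask;               // constants can be shifted for free
  case Expr::BitwiseOr:        // bitwise ops: every operand must agree,
    for (const Expr &op : e.ops)
      mask = refineMask(op, mask); // so the masks are intersected
    return mask;
  case Expr::ShlOverLShr:
    // shl (lshr X, C), C --> and X, C': only the amount C survives here
    // (the real getEvaluableShiftedShiftMask widens this using known bits).
    return mask & (uint64_t{1} << e.lshrAmt);
  }
  return 0;
}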
@@ -985,37 +1032,32 @@ static Instruction *foldShrThroughZExtedShl(BinaryOperator &I, Value *Op,
                                             InstCombinerImpl &IC,
                                             const DataLayout &DL) {
   Type *DestTy = I.getType();
+  const unsigned InnerBitWidth = Op->getType()->getScalarSizeInBits();

-  auto *Inner = dyn_cast<Instruction>(Op);
-  if (!Inner)
+  // Determine if the operand is effectively right-shifted by counting the
+  // known leading zero bits.
+  KnownBits Known = IC.computeKnownBits(Op, nullptr);
+  const unsigned MaxInnerShrAmt = Known.countMinLeadingZeros();
+  if (MaxInnerShrAmt == 0)
     return nullptr;
+  APInt ShrMask =
+      APInt::getLowBitsSet(InnerBitWidth, std::min(MaxInnerShrAmt, ShlAmt) + 1);

-  // Dig through operations until the first shift.
-  while (!Inner->isShift())
-    if (!match(Inner, m_BinOp(m_OneUse(m_Instruction(Inner)), m_Constant())))
-      return nullptr;
-
-  // Fold only if the inner shift is a logical right-shift.
-  const APInt *InnerShrConst;
-  if (!match(Inner, m_LShr(m_Value(), m_APInt(InnerShrConst))))
+  // Undo the maximal inner right shift amount that simplifies the overall
+  // computation.
+  if (!refineEvaluableShiftMask(Op, ShrMask, /*IsLeftShift=*/true, IC, nullptr))
     return nullptr;

-  const uint64_t InnerShrAmt = InnerShrConst->getZExtValue();
-  if (InnerShrAmt >= ShlAmt) {
-    const uint64_t ReducedShrAmt = InnerShrAmt - ShlAmt;
-    if (!canEvaluateShifted(Op, ReducedShrAmt, /*IsLeftShift=*/false, IC,
-                            nullptr))
-      return nullptr;
-    Value *NewOp =
-        getShiftedValue(Op, ReducedShrAmt, /*isLeftShift=*/false, IC, DL);
-    return new ZExtInst(NewOp, DestTy);
-  }
-
-  if (!canEvaluateShifted(Op, InnerShrAmt, /*IsLeftShift=*/true, IC, nullptr))
+  const unsigned InnerShrAmt = ShrMask.getActiveBits() - 1;
+  if (InnerShrAmt == 0)
     return nullptr;
+  assert(InnerShrAmt <= ShlAmt);

   const uint64_t ReducedShlAmt = ShlAmt - InnerShrAmt;
   Value *NewOp = getShiftedValue(Op, InnerShrAmt, /*isLeftShift=*/true, IC, DL);
+  if (ReducedShlAmt == 0)
+    return new ZExtInst(NewOp, DestTy);
+
   Value *NewZExt = IC.Builder.CreateZExt(NewOp, DestTy);
   NewZExt->takeName(I.getOperand(0));
   auto *NewShl = BinaryOperator::CreateShl(
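The rewritten foldShrThroughZExtedShl above drives the same machinery from known bits: the operand's known leading zeros bound the inner lshr amount that can be undone, a low-bits mask of candidate amounts is refined by refineEvaluableShiftMask, and the highest surviving bit decides how much of the outer shl remains. The following plain-integer sketch walks that arithmetic with the numbers from the masked tests in the updated test file (lshr by 4, and with 0xff, zext, shl by 48); it is a simplified model, not the LLVM code.

#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
  // Operand: and (lshr i32 %x, 4), 0xff  --  at least 24 known leading zeros.
  const unsigned ShlAmt = 48;          // outer shl applied after the zext
  const unsigned MaxInnerShrAmt = 24;  // stand-in for Known.countMinLeadingZeros()

  // Candidate "undo" amounts: the low min(MaxInnerShrAmt, ShlAmt) + 1 bits.
  uint32_t ShrMask = (1u << (std::min(MaxInnerShrAmt, ShlAmt) + 1)) - 1;

  // Refinement keeps only amounts every operand can absorb; for the inner
  // lshr by 4 (with unknown low bits of %x) that is just the bit at index 4.
  ShrMask &= 1u << 4;

  // Index of the highest surviving bit, cf. ShrMask.getActiveBits() - 1.
  unsigned InnerShrAmt = 0;
  for (unsigned i = 0; i < 32; ++i)
    if (ShrMask & (1u << i))
      InnerShrAmt = i;

  const unsigned ReducedShlAmt = ShlAmt - InnerShrAmt;
  // Prints "undo lshr by 4, remaining shl by 44", matching the
  // "and ..., 4080" and "shl ..., 44" CHECK lines in the tests below.
  std::printf("undo lshr by %u, remaining shl by %u\n", InnerShrAmt, ReducedShlAmt);
}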

llvm/test/Transforms/InstCombine/shifts-around-zext.ll

Lines changed: 93 additions & 14 deletions
@@ -1,6 +1,8 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
 ; RUN: opt -S -passes=instcombine %s | FileCheck %s

+declare void @clobber.i32(i32)
+
 define i64 @simple(i32 %x) {
 ; CHECK-LABEL: define i64 @simple(
 ; CHECK-SAME: i32 [[X:%.*]]) {
@@ -15,6 +17,20 @@ define i64 @simple(i32 %x) {
   ret i64 %shl
 }

+define <2 x i64> @simple.vec(<2 x i32> %v) {
+; CHECK-LABEL: define <2 x i64> @simple.vec(
+; CHECK-SAME: <2 x i32> [[V:%.*]]) {
+; CHECK-NEXT:    [[LSHR:%.*]] = and <2 x i32> [[V]], splat (i32 -256)
+; CHECK-NEXT:    [[ZEXT:%.*]] = zext <2 x i32> [[LSHR]] to <2 x i64>
+; CHECK-NEXT:    [[SHL:%.*]] = shl nuw nsw <2 x i64> [[ZEXT]], splat (i64 24)
+; CHECK-NEXT:    ret <2 x i64> [[SHL]]
+;
+  %lshr = lshr <2 x i32> %v, splat(i32 8)
+  %zext = zext <2 x i32> %lshr to <2 x i64>
+  %shl = shl <2 x i64> %zext, splat(i64 32)
+  ret <2 x i64> %shl
+}
+
 ;; u0xff0 = 4080
 define i64 @masked(i32 %x) {
 ; CHECK-LABEL: define i64 @masked(
@@ -31,6 +47,83 @@ define i64 @masked(i32 %x) {
   ret i64 %shl
 }

+define i64 @masked.multi_use.0(i32 %x) {
+; CHECK-LABEL: define i64 @masked.multi_use.0(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[LSHR:%.*]] = lshr i32 [[X]], 4
+; CHECK-NEXT:    call void @clobber.i32(i32 [[LSHR]])
+; CHECK-NEXT:    [[MASK:%.*]] = and i32 [[LSHR]], 255
+; CHECK-NEXT:    [[ZEXT:%.*]] = zext nneg i32 [[MASK]] to i64
+; CHECK-NEXT:    [[SHL:%.*]] = shl nuw nsw i64 [[ZEXT]], 48
+; CHECK-NEXT:    ret i64 [[SHL]]
+;
+  %lshr = lshr i32 %x, 4
+  call void @clobber.i32(i32 %lshr)
+  %mask = and i32 %lshr, u0xff
+  %zext = zext i32 %mask to i64
+  %shl = shl i64 %zext, 48
+  ret i64 %shl
+}
+
+define i64 @masked.multi_use.1(i32 %x) {
+; CHECK-LABEL: define i64 @masked.multi_use.1(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[LSHR:%.*]] = lshr i32 [[X]], 4
+; CHECK-NEXT:    [[MASK:%.*]] = and i32 [[LSHR]], 255
+; CHECK-NEXT:    call void @clobber.i32(i32 [[MASK]])
+; CHECK-NEXT:    [[ZEXT:%.*]] = zext nneg i32 [[MASK]] to i64
+; CHECK-NEXT:    [[SHL:%.*]] = shl nuw nsw i64 [[ZEXT]], 48
+; CHECK-NEXT:    ret i64 [[SHL]]
+;
+  %lshr = lshr i32 %x, 4
+  %mask = and i32 %lshr, u0xff
+  call void @clobber.i32(i32 %mask)
+  %zext = zext i32 %mask to i64
+  %shl = shl i64 %zext, 48
+  ret i64 %shl
+}
+
+define <2 x i64> @masked.multi_use.2(i32 %x) {
+; CHECK-LABEL: define <2 x i64> @masked.multi_use.2(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[LSHR:%.*]] = lshr i32 [[X]], 4
+; CHECK-NEXT:    [[MASK:%.*]] = and i32 [[LSHR]], 255
+; CHECK-NEXT:    [[ZEXT:%.*]] = zext nneg i32 [[MASK]] to i64
+; CHECK-NEXT:    [[SHL:%.*]] = shl nuw nsw i64 [[ZEXT]], 48
+; CHECK-NEXT:    [[CLOBBER:%.*]] = xor i32 [[MASK]], 255
+; CHECK-NEXT:    [[CLOBBER_Z:%.*]] = zext nneg i32 [[CLOBBER]] to i64
+; CHECK-NEXT:    [[V_0:%.*]] = insertelement <2 x i64> poison, i64 [[SHL]], i64 0
+; CHECK-NEXT:    [[V_1:%.*]] = insertelement <2 x i64> [[V_0]], i64 [[CLOBBER_Z]], i64 1
+; CHECK-NEXT:    ret <2 x i64> [[V_1]]
+;
+  %lshr = lshr i32 %x, 4
+  %mask = and i32 %lshr, u0xff
+  %zext = zext i32 %mask to i64
+  %shl = shl i64 %zext, 48
+
+  %clobber = xor i32 %mask, u0xff
+  %clobber.z = zext i32 %clobber to i64
+  %v.0 = insertelement <2 x i64> poison, i64 %shl, i32 0
+  %v.1 = insertelement <2 x i64> %v.0, i64 %clobber.z, i32 1
+  ret <2 x i64> %v.1
+}
+
+;; u0xff0 = 4080
+define <2 x i64> @masked.vec(<2 x i32> %v) {
+; CHECK-LABEL: define <2 x i64> @masked.vec(
+; CHECK-SAME: <2 x i32> [[V:%.*]]) {
+; CHECK-NEXT:    [[MASK:%.*]] = and <2 x i32> [[V]], splat (i32 4080)
+; CHECK-NEXT:    [[ZEXT:%.*]] = zext nneg <2 x i32> [[MASK]] to <2 x i64>
+; CHECK-NEXT:    [[SHL:%.*]] = shl nuw nsw <2 x i64> [[ZEXT]], splat (i64 44)
+; CHECK-NEXT:    ret <2 x i64> [[SHL]]
+;
+  %lshr = lshr <2 x i32> %v, splat(i32 4)
+  %mask = and <2 x i32> %lshr, splat(i32 u0xff)
+  %zext = zext <2 x i32> %mask to <2 x i64>
+  %shl = shl <2 x i64> %zext, splat(i64 48)
+  ret <2 x i64> %shl
+}
+
 define i64 @combine(i32 %lower, i32 %upper) {
 ; CHECK-LABEL: define i64 @combine(
 ; CHECK-SAME: i32 [[LOWER:%.*]], i32 [[UPPER:%.*]]) {
@@ -67,17 +160,3 @@ define i64 @combine(i32 %lower, i32 %upper) {

   ret i64 %o.3
 }
-
-define <2 x i64> @simple.vec(<2 x i32> %v) {
-; CHECK-LABEL: define <2 x i64> @simple.vec(
-; CHECK-SAME: <2 x i32> [[V:%.*]]) {
-; CHECK-NEXT:    [[LSHR:%.*]] = and <2 x i32> [[V]], splat (i32 -256)
-; CHECK-NEXT:    [[ZEXT:%.*]] = zext <2 x i32> [[LSHR]] to <2 x i64>
-; CHECK-NEXT:    [[SHL:%.*]] = shl nuw nsw <2 x i64> [[ZEXT]], splat (i64 24)
-; CHECK-NEXT:    ret <2 x i64> [[SHL]]
-;
-  %lshr = lshr <2 x i32> %v, splat(i32 8)
-  %zext = zext <2 x i32> %lshr to <2 x i64>
-  %shl = shl <2 x i64> %zext, splat(i64 32)
-  ret <2 x i64> %shl
-}
