@@ -530,112 +530,159 @@ Instruction *InstCombinerImpl::commonShiftTransforms(BinaryOperator &I) {
   return nullptr;
 }
 
-/// Return true if we can simplify two logical (either left or right) shifts
-/// that have constant shift amounts: OuterShift (InnerShift X, C1), C2.
-static bool canEvaluateShiftedShift(unsigned OuterShAmt, bool IsOuterShl,
-                                    Instruction *InnerShift,
-                                    InstCombinerImpl &IC, Instruction *CxtI) {
+/// Return a bitmask of all constant outer shift amounts that can be simplified
+/// by foldShiftedShift().
+static APInt getEvaluableShiftedShiftMask(bool IsOuterShl,
+                                          Instruction *InnerShift,
+                                          InstCombinerImpl &IC,
+                                          Instruction *CxtI) {
   assert(InnerShift->isLogicalShift() && "Unexpected instruction type");
 
+  const unsigned TypeWidth = InnerShift->getType()->getScalarSizeInBits();
+
   // We need constant scalar or constant splat shifts.
   const APInt *InnerShiftConst;
   if (!match(InnerShift->getOperand(1), m_APInt(InnerShiftConst)))
-    return false;
+    return APInt::getZero(TypeWidth);
 
-  // Two logical shifts in the same direction:
+  if (InnerShiftConst->uge(TypeWidth))
+    return APInt::getZero(TypeWidth);
+
+  const unsigned InnerShAmt = InnerShiftConst->getZExtValue();
+
+  // Two logical shifts in the same direction can always be simplified, so long
+  // as the total shift amount is legal.
   // shl (shl X, C1), C2 --> shl X, C1 + C2
   // lshr (lshr X, C1), C2 --> lshr X, C1 + C2
   bool IsInnerShl = InnerShift->getOpcode() == Instruction::Shl;
   if (IsInnerShl == IsOuterShl)
-    return true;
+    return APInt::getLowBitsSet(TypeWidth, TypeWidth - InnerShAmt);
 
+  APInt ShMask = APInt::getZero(TypeWidth);
   // Equal shift amounts in opposite directions become bitwise 'and':
   // lshr (shl X, C), C --> and X, C'
   // shl (lshr X, C), C --> and X, C'
-  if (*InnerShiftConst == OuterShAmt)
-    return true;
+  ShMask.setBit(InnerShAmt);
 
-  // If the 2nd shift is bigger than the 1st, we can fold:
+  // If the inner shift is bigger than the outer, we can fold:
   // lshr (shl X, C1), C2 --> and (shl X, C1 - C2), C3
   // shl (lshr X, C1), C2 --> and (lshr X, C1 - C2), C3
-  // but it isn't profitable unless we know the and'd out bits are already zero.
-  // Also, check that the inner shift is valid (less than the type width) or
-  // we'll crash trying to produce the bit mask for the 'and'.
-  unsigned TypeWidth = InnerShift->getType()->getScalarSizeInBits();
-  if (InnerShiftConst->ugt(OuterShAmt) && InnerShiftConst->ult(TypeWidth)) {
-    unsigned InnerShAmt = InnerShiftConst->getZExtValue();
-    unsigned MaskShift =
-        IsInnerShl ? TypeWidth - InnerShAmt : InnerShAmt - OuterShAmt;
-    APInt Mask = APInt::getLowBitsSet(TypeWidth, OuterShAmt) << MaskShift;
-    if (IC.MaskedValueIsZero(InnerShift->getOperand(0), Mask, CxtI))
-      return true;
-  }
-
-  return false;
+  // but it isn't profitable unless we know the masked out bits are already
+  // zero.
+  KnownBits Known = IC.computeKnownBits(InnerShift->getOperand(0), CxtI);
+  // Isolate the bits that are annihilated by the inner shift.
+  APInt InnerShMask = IsInnerShl ? Known.Zero.lshr(TypeWidth - InnerShAmt)
+                                 : Known.Zero.trunc(InnerShAmt);
+  // Isolate the upper (resp. lower) InnerShAmt bits of the base operand of the
+  // inner shl (resp. lshr).
+  // Then:
+  // - lshr (shl X, C1), C2 == (shl X, C1 - C2) if the bottom C2 of the isolated
+  //   bits are zero
+  // - shl (lshr X, C1), C2 == (lshr X, C1 - C2) if the top C2 of the isolated
+  //   bits are zero
+  const unsigned MaxOuterShAmt =
+      IsInnerShl ? Known.Zero.lshr(TypeWidth - InnerShAmt).countr_one()
+                 : Known.Zero.trunc(InnerShAmt).countl_one();
+  ShMask.setLowBits(MaxOuterShAmt);
+  return ShMask;
 }
 
-/// See if we can compute the specified value, but shifted logically to the left
-/// or right by some number of bits. This should return true if the expression
-/// can be computed for the same cost as the current expression tree. This is
-/// used to eliminate extraneous shifting from things like:
-///      %C = shl i128 %A, 64
-///      %D = shl i128 %B, 96
-///      %E = or i128 %C, %D
-///      %F = lshr i128 %E, 64
-/// where the client will ask if E can be computed shifted right by 64-bits. If
-/// this succeeds, getShiftedValue() will be called to produce the value.
-static bool canEvaluateShifted(Value *V, unsigned NumBits, bool IsLeftShift,
-                               InstCombinerImpl &IC, Instruction *CxtI) {
+/// Given a bitmask \p ShiftMask of desired shift amounts, determine the submask
+/// of bits corresponding to shift amounts X for which the given expression \p V
+/// can be computed for at worst the same cost as the current expression tree
+/// when shifted by X. For each set bit in the \p ShiftMask afterward,
+/// getShiftedValue() can produce the corresponding value.
+///
+/// \returns true if and only if at least one bit of the \p ShiftMask is set
+/// after refinement.
+static bool refineEvaluableShiftMask(Value *V, APInt &ShiftMask,
+                                     bool IsLeftShift, InstCombinerImpl &IC,
+                                     Instruction *CxtI) {
   // We can always evaluate immediate constants.
   if (match(V, m_ImmConstant()))
     return true;
 
   Instruction *I = dyn_cast<Instruction>(V);
-  if (!I) return false;
+  if (!I) {
+    ShiftMask.clearAllBits();
+    return false;
+  }
 
   // We can't mutate something that has multiple uses: doing so would
   // require duplicating the instruction in general, which isn't profitable.
-  if (!I->hasOneUse()) return false;
+  if (!I->hasOneUse()) {
+    ShiftMask.clearAllBits();
+    return false;
+  }
 
   switch (I->getOpcode()) {
-  default: return false;
+  default: {
+    ShiftMask.clearAllBits();
+    return false;
+  }
   case Instruction::And:
   case Instruction::Or:
   case Instruction::Xor:
-    // Bitwise operators can all arbitrarily be arbitrarily evaluated shifted.
-    return canEvaluateShifted(I->getOperand(0), NumBits, IsLeftShift, IC, I) &&
-           canEvaluateShifted(I->getOperand(1), NumBits, IsLeftShift, IC, I);
+    return refineEvaluableShiftMask(I->getOperand(0), ShiftMask, IsLeftShift,
+                                    IC, I) &&
+           refineEvaluableShiftMask(I->getOperand(1), ShiftMask, IsLeftShift,
+                                    IC, I);
 
   case Instruction::Shl:
-  case Instruction::LShr:
-    return canEvaluateShiftedShift(NumBits, IsLeftShift, I, IC, CxtI);
+  case Instruction::LShr: {
+    ShiftMask &= getEvaluableShiftedShiftMask(IsLeftShift, I, IC, CxtI);
+    return !ShiftMask.isZero();
+  }
 
   case Instruction::Select: {
     SelectInst *SI = cast<SelectInst>(I);
     Value *TrueVal = SI->getTrueValue();
     Value *FalseVal = SI->getFalseValue();
-    return canEvaluateShifted(TrueVal, NumBits, IsLeftShift, IC, SI) &&
-           canEvaluateShifted(FalseVal, NumBits, IsLeftShift, IC, SI);
+    return refineEvaluableShiftMask(TrueVal, ShiftMask, IsLeftShift, IC, SI) &&
+           refineEvaluableShiftMask(FalseVal, ShiftMask, IsLeftShift, IC, SI);
   }
   case Instruction::PHI: {
     // We can change a phi if we can change all operands. Note that we never
     // get into trouble with cyclic PHIs here because we only consider
     // instructions with a single use.
     PHINode *PN = cast<PHINode>(I);
     for (Value *IncValue : PN->incoming_values())
-      if (!canEvaluateShifted(IncValue, NumBits, IsLeftShift, IC, PN))
+      if (!refineEvaluableShiftMask(IncValue, ShiftMask, IsLeftShift, IC, PN))
         return false;
     return true;
   }
   case Instruction::Mul: {
     const APInt *MulConst;
     // We can fold (shr (mul X, -(1 << C)), C) -> (and (neg X), C`)
-    return !IsLeftShift && match(I->getOperand(1), m_APInt(MulConst)) &&
-           MulConst->isNegatedPowerOf2() && MulConst->countr_zero() == NumBits;
+    if (IsLeftShift || !match(I->getOperand(1), m_APInt(MulConst)) ||
+        !MulConst->isNegatedPowerOf2()) {
+      ShiftMask.clearAllBits();
+      return false;
+    }
+    ShiftMask &=
+        APInt::getOneBitSet(ShiftMask.getBitWidth(), MulConst->countr_zero());
+    return !ShiftMask.isZero();
   }
   }
 }
 
+/// See if we can compute the specified value, but shifted logically to the left
+/// or right by some number of bits. This should return true if the expression
+/// can be computed for the same cost as the current expression tree. This is
+/// used to eliminate extraneous shifting from things like:
+///      %C = shl i128 %A, 64
+///      %D = shl i128 %B, 96
+///      %E = or i128 %C, %D
+///      %F = lshr i128 %E, 64
+/// where the client will ask if E can be computed shifted right by 64-bits. If
+/// this succeeds, getShiftedValue() will be called to produce the value.
+static bool canEvaluateShifted(Value *V, unsigned ShAmt, bool IsLeftShift,
+                               InstCombinerImpl &IC, Instruction *CxtI) {
+  APInt ShiftMask =
+      APInt::getOneBitSet(V->getType()->getScalarSizeInBits(), ShAmt);
+  return refineEvaluableShiftMask(V, ShiftMask, IsLeftShift, IC, CxtI);
+}
+
 /// Fold OuterShift (InnerShift X, C1), C2.
 /// See canEvaluateShiftedShift() for the constraints on these instructions.
 static Value *foldShiftedShift(BinaryOperator *InnerShift, unsigned OuterShAmt,
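A minimal standalone sketch (not part of the patch) of the opposite-direction identity that getEvaluableShiftedShiftMask() relies on: lshr (shl X, C1), C2 equals shl X, C1 - C2 whenever the C2 bits just below the top C1 bits of X are zero. The constants (32-bit value, C1 = 8, C2 = 3) are hypothetical.

// Sketch only: verifies lshr(shl(x, C1), C2) == shl(x, C1 - C2) when the
// C2 bits just below the top C1 bits of x (bits [W-C1, W-C1+C2)) are zero.
#include <cassert>
#include <cstdint>

int main() {
  const unsigned C1 = 8, C2 = 3;
  // Bits 24..26 of x are zero, so the precondition holds for C1 = 8, C2 = 3.
  const uint32_t x = 0x38FFFFFFu;
  const uint32_t viaTwoShifts = (x << C1) >> C2;
  const uint32_t viaOneShift = x << (C1 - C2);
  assert(viaTwoShifts == viaOneShift);
  return 0;
}

The returned mask simply records, for each candidate outer amount C2, whether this precondition is known to hold from the operand's KnownBits.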
@@ -985,37 +1032,32 @@ static Instruction *foldShrThroughZExtedShl(BinaryOperator &I, Value *Op,
                                             InstCombinerImpl &IC,
                                             const DataLayout &DL) {
   Type *DestTy = I.getType();
+  const unsigned InnerBitWidth = Op->getType()->getScalarSizeInBits();
 
-  auto *Inner = dyn_cast<Instruction>(Op);
-  if (!Inner)
+  // Determine if the operand is effectively right-shifted by counting the
+  // known leading zero bits.
+  KnownBits Known = IC.computeKnownBits(Op, nullptr);
+  const unsigned MaxInnerShrAmt = Known.countMinLeadingZeros();
+  if (MaxInnerShrAmt == 0)
     return nullptr;
+  APInt ShrMask =
+      APInt::getLowBitsSet(InnerBitWidth, std::min(MaxInnerShrAmt, ShlAmt) + 1);
 
-  // Dig through operations until the first shift.
-  while (!Inner->isShift())
-    if (!match(Inner, m_BinOp(m_OneUse(m_Instruction(Inner)), m_Constant())))
-      return nullptr;
-
-  // Fold only if the inner shift is a logical right-shift.
-  const APInt *InnerShrConst;
-  if (!match(Inner, m_LShr(m_Value(), m_APInt(InnerShrConst))))
+  // Undo the maximal inner right shift amount that simplifies the overall
+  // computation.
+  if (!refineEvaluableShiftMask(Op, ShrMask, /*IsLeftShift=*/true, IC, nullptr))
     return nullptr;
 
-  const uint64_t InnerShrAmt = InnerShrConst->getZExtValue();
-  if (InnerShrAmt >= ShlAmt) {
-    const uint64_t ReducedShrAmt = InnerShrAmt - ShlAmt;
-    if (!canEvaluateShifted(Op, ReducedShrAmt, /*IsLeftShift=*/false, IC,
-                            nullptr))
-      return nullptr;
-    Value *NewOp =
-        getShiftedValue(Op, ReducedShrAmt, /*isLeftShift=*/false, IC, DL);
-    return new ZExtInst(NewOp, DestTy);
-  }
-
-  if (!canEvaluateShifted(Op, InnerShrAmt, /*IsLeftShift=*/true, IC, nullptr))
+  const unsigned InnerShrAmt = ShrMask.getActiveBits() - 1;
+  if (InnerShrAmt == 0)
     return nullptr;
+  assert(InnerShrAmt <= ShlAmt);
 
   const uint64_t ReducedShlAmt = ShlAmt - InnerShrAmt;
   Value *NewOp = getShiftedValue(Op, InnerShrAmt, /*isLeftShift=*/true, IC, DL);
+  if (ReducedShlAmt == 0)
+    return new ZExtInst(NewOp, DestTy);
+
   Value *NewZExt = IC.Builder.CreateZExt(NewOp, DestTy);
   NewZExt->takeName(I.getOperand(0));
   auto *NewShl = BinaryOperator::CreateShl(
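A minimal standalone sketch (not part of the patch) of the identity this second hunk relies on: if the narrow operand is known to have k leading zero bits and k <= ShlAmt, the left shift can be performed before the zext and the outer shift reduced by k. The constants below (16-bit operand, k = 4, ShlAmt = 6) are hypothetical.

// Sketch only: if a 16-bit operand has k known leading zero bits and
// k <= ShlAmt, then shl(zext(op), ShlAmt) == shl(zext(op << k), ShlAmt - k).
#include <cassert>
#include <cstdint>

int main() {
  const unsigned ShlAmt = 6, K = 4;
  const uint16_t Op = 0x0ABC; // top 4 bits are zero
  const uint32_t Original = static_cast<uint32_t>(Op) << ShlAmt;
  const uint32_t Rewritten =
      static_cast<uint32_t>(static_cast<uint16_t>(Op << K)) << (ShlAmt - K);
  assert(Original == Rewritten);
  return 0;
}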