diff --git a/llvm/include/llvm/CodeGen/SDPatternMatch.h b/llvm/include/llvm/CodeGen/SDPatternMatch.h index d21cc962da46cb..fc8ef717c74f6a 100644 --- a/llvm/include/llvm/CodeGen/SDPatternMatch.h +++ b/llvm/include/llvm/CodeGen/SDPatternMatch.h @@ -547,6 +547,39 @@ struct BinaryOpc_match { } }; +/// Matching while capturing mask +template struct SDShuffle_match { + T0 Op1; + T1 Op2; + T2 Mask; + + SDShuffle_match(const T0 &Op1, const T1 &Op2, const T2 &Mask) + : Op1(Op1), Op2(Op2), Mask(Mask) {} + + template + bool match(const MatchContext &Ctx, SDValue N) { + if (auto *I = dyn_cast(N)) { + return Op1.match(Ctx, I->getOperand(0)) && + Op2.match(Ctx, I->getOperand(1)) && Mask.match(I->getMask()); + } + return false; + } +}; +struct m_Mask { + ArrayRef &MaskRef; + m_Mask(ArrayRef &MaskRef) : MaskRef(MaskRef) {} + bool match(ArrayRef Mask) { + MaskRef = Mask; + return true; + } +}; + +struct m_SpecificMask { + ArrayRef MaskRef; + m_SpecificMask(ArrayRef MaskRef) : MaskRef(MaskRef) {} + bool match(ArrayRef Mask) { return MaskRef == Mask; } +}; + template struct MaxMin_match { @@ -797,6 +830,17 @@ inline BinaryOpc_match m_FRem(const LHS &L, const RHS &R) { return BinaryOpc_match(ISD::FREM, L, R); } +template +inline BinaryOpc_match m_Shuffle(const V1_t &v1, const V2_t &v2) { + return BinaryOpc_match(ISD::VECTOR_SHUFFLE, v1, v2); +} + +template +inline SDShuffle_match +m_Shuffle(const V1_t &v1, const V2_t &v2, const Mask_t &mask) { + return SDShuffle_match(v1, v2, mask); +} + template inline BinaryOpc_match m_ExtractElt(const LHS &Vec, const RHS &Idx) { return BinaryOpc_match(ISD::EXTRACT_VECTOR_ELT, Vec, Idx); diff --git a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp index 259bdad0ab2723..a2e1e588d03dea 100644 --- a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp +++ b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp @@ -119,6 +119,33 @@ TEST_F(SelectionDAGPatternMatchTest, matchValueType) { EXPECT_FALSE(sd_match(Op2, m_ScalableVectorVT())); } +TEST_F(SelectionDAGPatternMatchTest, matchVecShuffle) { + SDLoc DL; + auto Int32VT = EVT::getIntegerVT(Context, 32); + auto VInt32VT = EVT::getVectorVT(Context, Int32VT, 4); + const std::array MaskData = {2, 0, 3, 1}; + const std::array OtherMaskData = {1, 2, 3, 4}; + ArrayRef Mask; + + SDValue V0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, VInt32VT); + SDValue V1 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 2, VInt32VT); + SDValue VecShuffleWithMask = + DAG->getVectorShuffle(VInt32VT, DL, V0, V1, MaskData); + + using namespace SDPatternMatch; + EXPECT_TRUE(sd_match(VecShuffleWithMask, m_Shuffle(m_Value(), m_Value()))); + EXPECT_TRUE(sd_match(VecShuffleWithMask, + m_Shuffle(m_Value(), m_Value(), m_Mask(Mask)))); + EXPECT_TRUE( + sd_match(VecShuffleWithMask, + m_Shuffle(m_Value(), m_Value(), m_SpecificMask(MaskData)))); + EXPECT_FALSE( + sd_match(VecShuffleWithMask, + m_Shuffle(m_Value(), m_Value(), m_SpecificMask(OtherMaskData)))); + EXPECT_TRUE( + std::equal(MaskData.begin(), MaskData.end(), Mask.begin(), Mask.end())); +} + TEST_F(SelectionDAGPatternMatchTest, matchTernaryOp) { SDLoc DL; auto Int32VT = EVT::getIntegerVT(Context, 32);