From f3bc8c34c98a4b1a5361c3148eaeebd51151513f Mon Sep 17 00:00:00 2001 From: Aidan Goldfarb <47676355+AidanGoldfarb@users.noreply.github.com> Date: Mon, 6 Jan 2025 13:57:19 -0500 Subject: [PATCH] Add SD matchers and unit test coverage for ISD::VECTOR_SHUFFLE (#119592) This PR resolves #118845. I aimed to mirror the implementation `m_Shuffle()` in [PatternMatch.h](https://github.com/llvm/llvm-project/blob/main/llvm/include/llvm/IR/PatternMatch.h). Updated [SDPatternMatch.h](https://github.com/llvm/llvm-project/blob/main/llvm/include/llvm/CodeGen/SDPatternMatch.h) - Added `struct m_Mask` to match masks (`ArrayRef`) - Added two `m_Shuffle` functions. One to match independently of mask, and one to match considering mask. - Added `struct SDShuffle_match` to match `ISD::VECTOR_SHUFFLE` considering mask Updated [SDPatternMatchTest.cpp](https://github.com/llvm/llvm-project/blob/main/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp) - Added `matchVecShuffle` test, which tests the behavior of both `m_Shuffle()` functions - - - I am not sure if my test coverage is complete. I am not sure how to test a `false` match, simply test against a different instruction? [Other tests ](https://github.com/llvm/llvm-project/blob/main/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp#L175), such as for `VSelect`, test against `Select`. I am not sure if there is an analogous instruction to compare against for `VECTOR_SHUFFLE`. I would appreciate some pointers in this area. In general, please liberally critique this PR! --------- Co-authored-by: Aidan --- llvm/include/llvm/CodeGen/SDPatternMatch.h | 44 +++++++++++++++++++ .../CodeGen/SelectionDAGPatternMatchTest.cpp | 27 ++++++++++++ 2 files changed, 71 insertions(+) diff --git a/llvm/include/llvm/CodeGen/SDPatternMatch.h b/llvm/include/llvm/CodeGen/SDPatternMatch.h index d21cc962da46cb..fc8ef717c74f6a 100644 --- a/llvm/include/llvm/CodeGen/SDPatternMatch.h +++ b/llvm/include/llvm/CodeGen/SDPatternMatch.h @@ -547,6 +547,39 @@ struct BinaryOpc_match { } }; +/// Matching while capturing mask +template struct SDShuffle_match { + T0 Op1; + T1 Op2; + T2 Mask; + + SDShuffle_match(const T0 &Op1, const T1 &Op2, const T2 &Mask) + : Op1(Op1), Op2(Op2), Mask(Mask) {} + + template + bool match(const MatchContext &Ctx, SDValue N) { + if (auto *I = dyn_cast(N)) { + return Op1.match(Ctx, I->getOperand(0)) && + Op2.match(Ctx, I->getOperand(1)) && Mask.match(I->getMask()); + } + return false; + } +}; +struct m_Mask { + ArrayRef &MaskRef; + m_Mask(ArrayRef &MaskRef) : MaskRef(MaskRef) {} + bool match(ArrayRef Mask) { + MaskRef = Mask; + return true; + } +}; + +struct m_SpecificMask { + ArrayRef MaskRef; + m_SpecificMask(ArrayRef MaskRef) : MaskRef(MaskRef) {} + bool match(ArrayRef Mask) { return MaskRef == Mask; } +}; + template struct MaxMin_match { @@ -797,6 +830,17 @@ inline BinaryOpc_match m_FRem(const LHS &L, const RHS &R) { return BinaryOpc_match(ISD::FREM, L, R); } +template +inline BinaryOpc_match m_Shuffle(const V1_t &v1, const V2_t &v2) { + return BinaryOpc_match(ISD::VECTOR_SHUFFLE, v1, v2); +} + +template +inline SDShuffle_match +m_Shuffle(const V1_t &v1, const V2_t &v2, const Mask_t &mask) { + return SDShuffle_match(v1, v2, mask); +} + template inline BinaryOpc_match m_ExtractElt(const LHS &Vec, const RHS &Idx) { return BinaryOpc_match(ISD::EXTRACT_VECTOR_ELT, Vec, Idx); diff --git a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp index 259bdad0ab2723..a2e1e588d03dea 100644 --- a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp +++ b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp @@ -119,6 +119,33 @@ TEST_F(SelectionDAGPatternMatchTest, matchValueType) { EXPECT_FALSE(sd_match(Op2, m_ScalableVectorVT())); } +TEST_F(SelectionDAGPatternMatchTest, matchVecShuffle) { + SDLoc DL; + auto Int32VT = EVT::getIntegerVT(Context, 32); + auto VInt32VT = EVT::getVectorVT(Context, Int32VT, 4); + const std::array MaskData = {2, 0, 3, 1}; + const std::array OtherMaskData = {1, 2, 3, 4}; + ArrayRef Mask; + + SDValue V0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, VInt32VT); + SDValue V1 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 2, VInt32VT); + SDValue VecShuffleWithMask = + DAG->getVectorShuffle(VInt32VT, DL, V0, V1, MaskData); + + using namespace SDPatternMatch; + EXPECT_TRUE(sd_match(VecShuffleWithMask, m_Shuffle(m_Value(), m_Value()))); + EXPECT_TRUE(sd_match(VecShuffleWithMask, + m_Shuffle(m_Value(), m_Value(), m_Mask(Mask)))); + EXPECT_TRUE( + sd_match(VecShuffleWithMask, + m_Shuffle(m_Value(), m_Value(), m_SpecificMask(MaskData)))); + EXPECT_FALSE( + sd_match(VecShuffleWithMask, + m_Shuffle(m_Value(), m_Value(), m_SpecificMask(OtherMaskData)))); + EXPECT_TRUE( + std::equal(MaskData.begin(), MaskData.end(), Mask.begin(), Mask.end())); +} + TEST_F(SelectionDAGPatternMatchTest, matchTernaryOp) { SDLoc DL; auto Int32VT = EVT::getIntegerVT(Context, 32);