diff --git a/llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp b/llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp index 02ec1d5c259cd6..9f24181d5d1f6d 100644 --- a/llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp +++ b/llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp @@ -324,6 +324,11 @@ class Vectorizer { Instruction *ChainElem, Instruction *ChainBegin, const DenseMap &ChainOffsets); + /// Merges the equivalence classes if they have underlying objects that differ + /// by one level of indirection (i.e., one is a getelementptr and the other is + /// the base pointer in that getelementptr). + void mergeEquivalenceClasses(EquivalenceClassMap &EQClasses) const; + /// Collects loads and stores grouped by "equivalence class", where: /// - all elements in an eq class are a load or all are a store, /// - they all load/store the same element size (it's OK to have e.g. i8 and @@ -1305,6 +1310,123 @@ std::optional Vectorizer::getConstantOffsetSelects( return std::nullopt; } +void Vectorizer::mergeEquivalenceClasses(EquivalenceClassMap &EQClasses) const { + if (EQClasses.size() < 2) // There is nothing to merge. + return; + + // The reduced key has all elements of the ECClassKey except the underlying + // object. Check that EqClassKey has 4 elements and define the reduced key. + static_assert(std::tuple_size_v == 4, + "EqClassKey has changed - EqClassReducedKey needs changes too"); + using EqClassReducedKey = + std::tuple /* AddrSpace */, + std::tuple_element_t<2, EqClassKey> /* Element size */, + std::tuple_element_t<3, EqClassKey> /* IsLoad; */>; + using ECReducedKeyToUnderlyingObjectMap = + MapVector, 4>>; + + // Form a map from the reduced key (without the underlying object) to the + // underlying objects: 1 reduced key to many underlying objects, to form + // groups of potentially merge-able equivalence classes. + ECReducedKeyToUnderlyingObjectMap RedKeyToUOMap; + bool FoundPotentiallyOptimizableEC = false; + for (const auto &EC : EQClasses) { + const auto &Key = EC.first; + EqClassReducedKey RedKey{std::get<1>(Key), std::get<2>(Key), + std::get<3>(Key)}; + RedKeyToUOMap[RedKey].insert(std::get<0>(Key)); + if (RedKeyToUOMap[RedKey].size() > 1) + FoundPotentiallyOptimizableEC = true; + } + if (!FoundPotentiallyOptimizableEC) + return; + + LLVM_DEBUG({ + dbgs() << "LSV: mergeEquivalenceClasses: before merging:\n"; + for (const auto &EC : EQClasses) { + dbgs() << " Key: ([" << std::get<0>(EC.first) + << "]: " << *std::get<0>(EC.first) << ", " << std::get<1>(EC.first) + << ", " << std::get<2>(EC.first) << ", " + << static_cast(std::get<3>(EC.first)) << ")\n"; + for (const auto &Inst : EC.second) + dbgs() << "\tInst: " << *Inst << '\n'; + } + }); + LLVM_DEBUG({ + dbgs() << "LSV: mergeEquivalenceClasses: RedKeyToUOMap:\n"; + for (const auto &RedKeyToUO : RedKeyToUOMap) { + dbgs() << " Reduced key: (" << std::get<0>(RedKeyToUO.first) << ", " + << std::get<1>(RedKeyToUO.first) << ", " + << static_cast(std::get<2>(RedKeyToUO.first)) << ") --> " + << RedKeyToUO.second.size() << " underlying objects:\n"; + for (auto UObject : RedKeyToUO.second) + dbgs() << " [" << UObject << "]: " << *UObject << '\n'; + } + }); + + using UObjectToUObjectMap = DenseMap; + + // Compute the ultimate targets for a set of underlying objects. + auto GetUltimateTargets = + [](SmallPtrSetImpl &UObjects) -> UObjectToUObjectMap { + UObjectToUObjectMap IndirectionMap; + for (const auto *UObject : UObjects) { + const unsigned MaxLookupDepth = 1; // look for 1-level indirections only + const auto *UltimateTarget = getUnderlyingObject(UObject, MaxLookupDepth); + if (UltimateTarget != UObject) + IndirectionMap[UObject] = UltimateTarget; + } + UObjectToUObjectMap UltimateTargetsMap; + for (const auto *UObject : UObjects) { + auto Target = UObject; + auto It = IndirectionMap.find(Target); + for (; It != IndirectionMap.end(); It = IndirectionMap.find(Target)) + Target = It->second; + UltimateTargetsMap[UObject] = Target; + } + return UltimateTargetsMap; + }; + + // For each item in RedKeyToUOMap, if it has more than one underlying object, + // try to merge the equivalence classes. + for (auto &[RedKey, UObjects] : RedKeyToUOMap) { + if (UObjects.size() < 2) + continue; + auto UTMap = GetUltimateTargets(UObjects); + for (const auto &[UObject, UltimateTarget] : UTMap) { + if (UObject == UltimateTarget) + continue; + + EqClassKey KeyFrom{UObject, std::get<0>(RedKey), std::get<1>(RedKey), + std::get<2>(RedKey)}; + EqClassKey KeyTo{UltimateTarget, std::get<0>(RedKey), std::get<1>(RedKey), + std::get<2>(RedKey)}; + const auto &VecFrom = EQClasses[KeyFrom]; + const auto &VecTo = EQClasses[KeyTo]; + SmallVector MergedVec; + std::merge(VecFrom.begin(), VecFrom.end(), VecTo.begin(), VecTo.end(), + std::back_inserter(MergedVec), + [](Instruction *A, Instruction *B) { + return A && B && A->comesBefore(B); + }); + EQClasses[KeyTo] = std::move(MergedVec); + EQClasses.erase(KeyFrom); + } + } + LLVM_DEBUG({ + dbgs() << "LSV: mergeEquivalenceClasses: after merging:\n"; + for (const auto &EC : EQClasses) { + dbgs() << " Key: ([" << std::get<0>(EC.first) + << "]: " << *std::get<0>(EC.first) << ", " << std::get<1>(EC.first) + << ", " << std::get<2>(EC.first) << ", " + << static_cast(std::get<3>(EC.first)) << ")\n"; + for (const auto &Inst : EC.second) + dbgs() << "\tInst: " << *Inst << '\n'; + } + }); +} + EquivalenceClassMap Vectorizer::collectEquivalenceClasses(BasicBlock::iterator Begin, BasicBlock::iterator End) { @@ -1377,6 +1499,7 @@ Vectorizer::collectEquivalenceClasses(BasicBlock::iterator Begin, .emplace_back(&I); } + mergeEquivalenceClasses(Ret); return Ret; } diff --git a/llvm/test/Transforms/LoadStoreVectorizer/X86/massive_indirection.ll b/llvm/test/Transforms/LoadStoreVectorizer/X86/massive_indirection.ll new file mode 100644 index 00000000000000..c4b0d2e311d9d7 --- /dev/null +++ b/llvm/test/Transforms/LoadStoreVectorizer/X86/massive_indirection.ll @@ -0,0 +1,142 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5 +; RUN: opt %s -mtriple=x86_64-unknown-linux-gnu -passes=load-store-vectorizer -mcpu=skx -S -o - | FileCheck %s + +; This test verifies that the vectorizer can handle an extended sequence of +; getelementptr instructions and generate longer vectors. With special handling, +; some elements can still be vectorized even if they require looking up the +; common underlying object deeper than 6 levels from the original pointer. + +; The test below is the simplified version of actual performance oriented +; workload; the offsets in getelementptr instructions are similar or same for +; the test simplicity. + +define void @v1_v2_v4_v1_to_v8_levels_6_7_8_8(i32 %arg0, ptr align 16 %arg1) { +; CHECK-LABEL: define void @v1_v2_v4_v1_to_v8_levels_6_7_8_8( +; CHECK-SAME: i32 [[ARG0:%.*]], ptr align 16 [[ARG1:%.*]]) #[[ATTR0:[0-9]+]] { +; CHECK-NEXT: [[LEVEL1:%.*]] = getelementptr i8, ptr [[ARG1]], i32 917504 +; CHECK-NEXT: [[LEVEL2:%.*]] = getelementptr i8, ptr [[LEVEL1]], i32 [[ARG0]] +; CHECK-NEXT: [[LEVEL3:%.*]] = getelementptr i8, ptr [[LEVEL2]], i32 32768 +; CHECK-NEXT: [[LEVEL4:%.*]] = getelementptr i8, ptr [[LEVEL3]], i32 [[ARG0]] +; CHECK-NEXT: [[LEVEL5:%.*]] = getelementptr i8, ptr [[LEVEL4]], i32 [[ARG0]] +; CHECK-NEXT: [[A6:%.*]] = getelementptr i8, ptr [[LEVEL5]], i32 [[ARG0]] +; CHECK-NEXT: store <8 x half> zeroinitializer, ptr [[A6]], align 16 +; CHECK-NEXT: ret void +; + + %level1 = getelementptr i8, ptr %arg1, i32 917504 + %level2 = getelementptr i8, ptr %level1, i32 %arg0 + %level3 = getelementptr i8, ptr %level2, i32 32768 + %level4 = getelementptr i8, ptr %level3, i32 %arg0 + %level5 = getelementptr i8, ptr %level4, i32 %arg0 + + %a6 = getelementptr i8, ptr %level5, i32 %arg0 + %b7 = getelementptr i8, ptr %a6, i32 2 + %c8 = getelementptr i8, ptr %b7, i32 8 + %d8 = getelementptr i8, ptr %b7, i32 12 + + store half 0xH0000, ptr %a6, align 16 + store <4 x half> zeroinitializer, ptr %b7, align 2 + store <2 x half> zeroinitializer, ptr %c8, align 2 + store half 0xH0000, ptr %d8, align 2 + ret void +} + +define void @v1x8_levels_6_7_8_9_10_11_12_13(i32 %arg0, ptr align 16 %arg1) { +; CHECK-LABEL: define void @v1x8_levels_6_7_8_9_10_11_12_13( +; CHECK-SAME: i32 [[ARG0:%.*]], ptr align 16 [[ARG1:%.*]]) #[[ATTR0]] { +; CHECK-NEXT: [[LEVEL1:%.*]] = getelementptr i8, ptr [[ARG1]], i32 917504 +; CHECK-NEXT: [[LEVEL2:%.*]] = getelementptr i8, ptr [[LEVEL1]], i32 [[ARG0]] +; CHECK-NEXT: [[LEVEL3:%.*]] = getelementptr i8, ptr [[LEVEL2]], i32 32768 +; CHECK-NEXT: [[LEVEL4:%.*]] = getelementptr i8, ptr [[LEVEL3]], i32 [[ARG0]] +; CHECK-NEXT: [[LEVEL5:%.*]] = getelementptr i8, ptr [[LEVEL4]], i32 [[ARG0]] +; CHECK-NEXT: [[A6:%.*]] = getelementptr i8, ptr [[LEVEL5]], i32 [[ARG0]] +; CHECK-NEXT: store <8 x half> zeroinitializer, ptr [[A6]], align 16 +; CHECK-NEXT: ret void +; + + %level1 = getelementptr i8, ptr %arg1, i32 917504 + %level2 = getelementptr i8, ptr %level1, i32 %arg0 + %level3 = getelementptr i8, ptr %level2, i32 32768 + %level4 = getelementptr i8, ptr %level3, i32 %arg0 + %level5 = getelementptr i8, ptr %level4, i32 %arg0 + + %a6 = getelementptr i8, ptr %level5, i32 %arg0 + %b7 = getelementptr i8, ptr %a6, i32 2 + %c8 = getelementptr i8, ptr %b7, i32 2 + %d9 = getelementptr i8, ptr %c8, i32 2 + %e10 = getelementptr i8, ptr %d9, i32 2 + %f11 = getelementptr i8, ptr %e10, i32 2 + %g12 = getelementptr i8, ptr %f11, i32 2 + %h13 = getelementptr i8, ptr %g12, i32 2 + + store half 0xH0000, ptr %a6, align 16 + store half 0xH0000, ptr %b7, align 2 + store half 0xH0000, ptr %c8, align 2 + store half 0xH0000, ptr %d9, align 2 + store half 0xH0000, ptr %e10, align 8 + store half 0xH0000, ptr %f11, align 2 + store half 0xH0000, ptr %g12, align 2 + store half 0xH0000, ptr %h13, align 2 + ret void +} + +define void @v1_4_4_4_2_1_to_v8_8_levels_6_7(i32 %arg0, ptr addrspace(3) align 16 %arg1_ptr, i32 %arg2, i32 %arg3, i32 %arg4, i32 %arg5, half %arg6_half, half %arg7_half, <2 x half> %arg8_2xhalf) { +; CHECK-LABEL: define void @v1_4_4_4_2_1_to_v8_8_levels_6_7( +; CHECK-SAME: i32 [[ARG0:%.*]], ptr addrspace(3) align 16 [[ARG1_PTR:%.*]], i32 [[ARG2:%.*]], i32 [[ARG3:%.*]], i32 [[ARG4:%.*]], i32 [[ARG5:%.*]], half [[ARG6_HALF:%.*]], half [[ARG7_HALF:%.*]], <2 x half> [[ARG8_2XHALF:%.*]]) #[[ATTR0]] { +; CHECK-NEXT: [[TMP1:%.*]] = getelementptr inbounds i8, ptr addrspace(3) [[ARG1_PTR]], i32 458752 +; CHECK-NEXT: br [[DOTPREHEADER11_PREHEADER:label %.*]] +; CHECK: [[_PREHEADER11_PREHEADER:.*:]] +; CHECK-NEXT: [[TMP2:%.*]] = shl nuw nsw i32 [[ARG0]], 6 +; CHECK-NEXT: [[TMP3:%.*]] = getelementptr inbounds i8, ptr addrspace(3) [[TMP1]], i32 [[TMP2]] +; CHECK-NEXT: [[TMP4:%.*]] = getelementptr inbounds i8, ptr addrspace(3) [[TMP3]], i32 [[ARG2]] +; CHECK-NEXT: [[TMP5:%.*]] = getelementptr inbounds i8, ptr addrspace(3) [[TMP4]], i32 [[ARG3]] +; CHECK-NEXT: [[CMP:%.*]] = icmp sgt i32 [[ARG0]], 2 +; CHECK-NEXT: br i1 [[CMP]], [[DOTLR_PH:label %.*]], [[DOTEXIT_POINT:label %.*]] +; CHECK: [[_LR_PH:.*:]] +; CHECK-NEXT: [[GEP:%.*]] = getelementptr inbounds i8, ptr addrspace(3) [[TMP5]], i32 [[ARG4]] +; CHECK-NEXT: [[TMP6:%.*]] = getelementptr inbounds i8, ptr addrspace(3) [[GEP]], i32 [[ARG5]] +; CHECK-NEXT: [[TMP7:%.*]] = insertelement <8 x half> poison, half [[ARG6_HALF]], i32 0 +; CHECK-NEXT: [[TMP8:%.*]] = insertelement <8 x half> [[TMP7]], half 0xH0000, i32 1 +; CHECK-NEXT: [[TMP9:%.*]] = insertelement <8 x half> [[TMP8]], half 0xH0000, i32 2 +; CHECK-NEXT: [[TMP10:%.*]] = insertelement <8 x half> [[TMP9]], half 0xH0000, i32 3 +; CHECK-NEXT: [[TMP11:%.*]] = insertelement <8 x half> [[TMP10]], half 0xH0000, i32 4 +; CHECK-NEXT: [[TMP12:%.*]] = extractelement <2 x half> [[ARG8_2XHALF]], i32 0 +; CHECK-NEXT: [[TMP13:%.*]] = insertelement <8 x half> [[TMP11]], half [[TMP12]], i32 5 +; CHECK-NEXT: [[TMP14:%.*]] = extractelement <2 x half> [[ARG8_2XHALF]], i32 1 +; CHECK-NEXT: [[TMP15:%.*]] = insertelement <8 x half> [[TMP13]], half [[TMP14]], i32 6 +; CHECK-NEXT: [[TMP16:%.*]] = insertelement <8 x half> [[TMP15]], half [[ARG7_HALF]], i32 7 +; CHECK-NEXT: store <8 x half> [[TMP16]], ptr addrspace(3) [[TMP6]], align 2 +; CHECK-NEXT: br [[DOTEXIT_POINT]] +; CHECK: [[_EXIT_POINT:.*:]] +; CHECK-NEXT: ret void +; + %base1 = getelementptr inbounds i8, ptr addrspace(3) %arg1_ptr, i32 458752 + br label %.preheader11.preheader + +.preheader11.preheader: + %base2 = shl nuw nsw i32 %arg0, 6 + %base3 = getelementptr inbounds i8, ptr addrspace(3) %base1, i32 %base2 + + %base4 = getelementptr inbounds i8, ptr addrspace(3) %base3, i32 %arg2 + %base5 = getelementptr inbounds i8, ptr addrspace(3) %base4, i32 %arg3 + + %cmp = icmp sgt i32 %arg0, 2 + br i1 %cmp, label %.lr.ph, label %.exit_point + +.lr.ph: + %gep = getelementptr inbounds i8, ptr addrspace(3) %base5, i32 %arg4 + + %dst = getelementptr inbounds i8, ptr addrspace(3) %gep, i32 %arg5 + %dst_off2 = getelementptr inbounds i8, ptr addrspace(3) %dst, i32 2 + %dst_off10 = getelementptr inbounds i8, ptr addrspace(3) %dst, i32 10 + %dst_off14 = getelementptr inbounds i8, ptr addrspace(3) %dst, i32 14 + + store half %arg6_half, ptr addrspace(3) %dst, align 2 + store <4 x half> zeroinitializer, ptr addrspace(3) %dst_off2, align 2 + store <2 x half> %arg8_2xhalf, ptr addrspace(3) %dst_off10, align 2 + store half %arg7_half, ptr addrspace(3) %dst_off14, align 2 + br label %.exit_point + +.exit_point: + ret void +}