diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 1b8a80d2b3c94..88da7015fc770 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -14568,8 +14568,6 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals,
   LLVM_DEBUG(dbgs() << "SLP: Calculating cost for tree of size "
                     << VectorizableTree.size() << ".\n");
 
-  unsigned BundleWidth = VectorizableTree[0]->Scalars.size();
-
   SmallPtrSet<Value *, 4> CheckedExtracts;
   for (unsigned I = 0, E = VectorizableTree.size(); I < E; ++I) {
     TreeEntry &TE = *VectorizableTree[I];
@@ -14632,6 +14630,11 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals,
   }
   SmallDenseSet<std::pair<Value *, User *>, 8> CheckedScalarUser;
   for (ExternalUser &EU : ExternalUses) {
+    LLVM_DEBUG(dbgs() << "SLP: Computing cost for external use of TreeEntry "
+                      << EU.E.Idx << " in lane " << EU.Lane << "\n");
+    LLVM_DEBUG(dbgs() << "  User:" << *EU.User << "\n");
+    LLVM_DEBUG(dbgs() << "  Use: " << EU.Scalar->getNameOrAsOperand() << "\n");
+
     // Uses by ephemeral values are free (because the ephemeral value will be
     // removed prior to code generation, and so the extraction will be
     // removed as well).
@@ -14739,6 +14742,8 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals,
     // for the extract and the added cost of the sign extend if needed.
     InstructionCost ExtraCost = TTI::TCC_Free;
     auto *ScalarTy = EU.Scalar->getType();
+    const unsigned BundleWidth = EU.E.getVectorFactor();
+    assert(EU.Lane < BundleWidth && "Extracted lane out of bounds.");
     auto *VecTy = getWidenedType(ScalarTy, BundleWidth);
     const TreeEntry *Entry = &EU.E;
     auto It = MinBWs.find(Entry);
@@ -14752,10 +14757,14 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals,
       VecTy = getWidenedType(MinTy, BundleWidth);
       ExtraCost =
           getExtractWithExtendCost(*TTI, Extend, ScalarTy, VecTy, EU.Lane);
+      LLVM_DEBUG(dbgs() << "  ExtractExtend or ExtractSubvec cost: "
+                        << ExtraCost << "\n");
     } else {
       ExtraCost =
           getVectorInstrCost(*TTI, ScalarTy, Instruction::ExtractElement, VecTy,
                              CostKind, EU.Lane, EU.Scalar, ScalarUserAndIdx);
+      LLVM_DEBUG(dbgs() << "  ExtractElement cost for " << *ScalarTy << " from "
+                        << *VecTy << ": " << ExtraCost << "\n");
     }
     // Leave the scalar instructions as is if they are cheaper than extracts.
     if (Entry->Idx != 0 || Entry->getOpcode() == Instruction::GetElementPtr ||
diff --git a/llvm/test/Transforms/SLPVectorizer/RISCV/segmented-loads.ll b/llvm/test/Transforms/SLPVectorizer/RISCV/segmented-loads.ll
index ce26bd3b89392..e800b5e016b74 100644
--- a/llvm/test/Transforms/SLPVectorizer/RISCV/segmented-loads.ll
+++ b/llvm/test/Transforms/SLPVectorizer/RISCV/segmented-loads.ll
@@ -31,3 +31,34 @@ define void @test() {
   store double %res4, ptr getelementptr inbounds ([8 x double], ptr @dst, i32 0, i64 3), align 8
   ret void
 }
+
+; Same as above, but %a7 is also used as a scalar and must be extracted from
+; the wide load. (Or in this case, kept as a scalar load.)
+define double @test_with_extract() {
+; CHECK-LABEL: @test_with_extract(
+; CHECK-NEXT:    [[TMP1:%.*]] = load <8 x double>, ptr @src, align 8
+; CHECK-NEXT:    [[A7:%.*]] = load double, ptr getelementptr inbounds ([8 x double], ptr @src, i32 0, i64 7), align 8
+; CHECK-NEXT:    [[TMP2:%.*]] = shufflevector <8 x double> [[TMP1]], <8 x double> poison, <4 x i32> <i32 0, i32 2, i32 4, i32 6>
+; CHECK-NEXT:    [[TMP3:%.*]] = shufflevector <8 x double> [[TMP1]], <8 x double> poison, <4 x i32> <i32 1, i32 3, i32 5, i32 7>
+; CHECK-NEXT:    [[TMP4:%.*]] = fsub fast <4 x double> [[TMP2]], [[TMP3]]
+; CHECK-NEXT:    store <4 x double> [[TMP4]], ptr @dst, align 8
+; CHECK-NEXT:    ret double [[A7]]
+;
+  %a0 = load double, ptr @src, align 8
+  %a1 = load double, ptr getelementptr inbounds ([8 x double], ptr @src, i32 0, i64 1), align 8
+  %a2 = load double, ptr getelementptr inbounds ([8 x double], ptr @src, i32 0, i64 2), align 8
+  %a3 = load double, ptr getelementptr inbounds ([8 x double], ptr @src, i32 0, i64 3), align 8
+  %a4 = load double, ptr getelementptr inbounds ([8 x double], ptr @src, i32 0, i64 4), align 8
+  %a5 = load double, ptr getelementptr inbounds ([8 x double], ptr @src, i32 0, i64 5), align 8
+  %a6 = load double, ptr getelementptr inbounds ([8 x double], ptr @src, i32 0, i64 6), align 8
+  %a7 = load double, ptr getelementptr inbounds ([8 x double], ptr @src, i32 0, i64 7), align 8
+  %res1 = fsub fast double %a0, %a1
+  %res2 = fsub fast double %a2, %a3
+  %res3 = fsub fast double %a4, %a5
+  %res4 = fsub fast double %a6, %a7
+  store double %res1, ptr @dst, align 8
+  store double %res2, ptr getelementptr inbounds ([8 x double], ptr @dst, i32 0, i64 1), align 8
+  store double %res3, ptr getelementptr inbounds ([8 x double], ptr @dst, i32 0, i64 2), align 8
+  store double %res4, ptr getelementptr inbounds ([8 x double], ptr @dst, i32 0, i64 3), align 8
+  ret double %a7
+}