From e488f51590a06fb2e35babad36a3a51246acaac0 Mon Sep 17 00:00:00 2001 From: Alexis Montoison Date: Wed, 8 Oct 2025 01:03:58 -0500 Subject: [PATCH] [CUSPARSE] Interface generic mv! for SparseMatrixBSR --- lib/cusparse/generic.jl | 40 ++++++------- lib/cusparse/level2.jl | 16 ++--- test/libraries/cusparse.jl | 16 +++-- test/libraries/cusparse/generic.jl | 95 +++++++++++++++--------------- 4 files changed, 83 insertions(+), 84 deletions(-) diff --git a/lib/cusparse/generic.jl b/lib/cusparse/generic.jl index 265487068c..eb7abad69e 100644 --- a/lib/cusparse/generic.jl +++ b/lib/cusparse/generic.jl @@ -152,22 +152,16 @@ function vv!(transx::SparseChar, X::CuSparseVector{T}, Y::DenseCuVector{T}, inde return result[] end -function mv!(transa::SparseChar, alpha::Number, A::Union{CuSparseMatrixCSC{TA},CuSparseMatrixCSR{TA},CuSparseMatrixCOO{TA}}, X::DenseCuVector{T}, +function mv!(transa::SparseChar, alpha::Number, A::CuSparseMatrix{TA}, X::DenseCuVector{T}, beta::Number, Y::DenseCuVector{T}, index::SparseChar, algo::cusparseSpMVAlg_t=CUSPARSE_SPMV_ALG_DEFAULT) where {TA, T} + (A isa CuSparseMatrixBSR) && (CUSPARSE.version() < v"12.6.3") && throw(ErrorException("This operation is not supported by the current CUDA version.")) + # Support transa = 'C' for real matrices transa = T <: Real && transa == 'C' ? 'T' : transa - if isa(A, CuSparseMatrixCSC) - # cusparseSpMV completely supports CSC matrices with CUSPARSE.version() ≥ v"12.0". - # We use Aᵀ to model them as CSR matrices for older versions of CUSPARSE. - descA = CuSparseMatrixDescriptor(A, index, transposed=true) - n,m = size(A) - transa = transa == 'N' ? 'T' : 'N' - else - descA = CuSparseMatrixDescriptor(A, index) - m,n = size(A) - end + descA = CuSparseMatrixDescriptor(A, index) + m,n = size(A) if transa == 'N' chkmvdims(X,n,Y,m) @@ -318,12 +312,12 @@ function bmm!(transa::SparseChar, transb::SparseChar, alpha::Number, A::CuSparse return out[] end with_workspace(bufferSize) do buffer - # We should find a way to reuse the buffer (issue #1362) - if !(A isa CuSparseMatrixCOO) - cusparseSpMM_preprocess( - handle(), transa, transb, Ref{T}(alpha), descA, descB, Ref{T}(beta), - descC, T, algo, buffer) - end + # Uncomment if we find a way to reuse the buffer (issue #1362) + # if !(A isa CuSparseMatrixCOO) + # cusparseSpMM_preprocess( + # handle(), transa, transb, Ref{T}(alpha), descA, descB, Ref{T}(beta), + # descC, T, algo, buffer) + # end cusparseSpMM( handle(), transa, transb, Ref{T}(alpha), descA, descB, Ref{T}(beta), descC, T, algo, buffer) @@ -372,12 +366,12 @@ function mm!(transa::SparseChar, transb::SparseChar, alpha::Number, A::DenseCuMa return out[] end with_workspace(bufferSize) do buffer - # We should find a way to reuse the buffer (issue #1362) - if !(B isa CuSparseMatrixCOO) - cusparseSpMM_preprocess( - handle(), transb, transa, Ref{T}(alpha), descB, descA, Ref{T}(beta), - descC, T, algo, buffer) - end + # Uncomment if we find a way to reuse the buffer (issue #1362) + # if !(B isa CuSparseMatrixCOO) + # cusparseSpMM_preprocess( + # handle(), transb, transa, Ref{T}(alpha), descB, descA, Ref{T}(beta), + # descC, T, algo, buffer) + # end cusparseSpMM( handle(), transb, transa, Ref{T}(alpha), descB, descA, Ref{T}(beta), descC, T, algo, buffer) diff --git a/lib/cusparse/level2.jl b/lib/cusparse/level2.jl index f18257e847..37a9bd686b 100644 --- a/lib/cusparse/level2.jl +++ b/lib/cusparse/level2.jl @@ -1,20 +1,20 @@ # sparse linear algebra functions that perform operations between sparse matrices and dense # vectors -export sv2!, sv2, gemvi! +export sv2!, sv2, mv2!, gemvi! for (fname,elty) in ((:cusparseSbsrmv, :Float32), (:cusparseDbsrmv, :Float64), (:cusparseCbsrmv, :ComplexF32), (:cusparseZbsrmv, :ComplexF64)) @eval begin - function mv!(transa::SparseChar, - alpha::Number, - A::CuSparseMatrixBSR{$elty}, - X::CuVector{$elty}, - beta::Number, - Y::CuVector{$elty}, - index::SparseChar) + function mv2!(transa::SparseChar, + alpha::Number, + A::CuSparseMatrixBSR{$elty}, + X::CuVector{$elty}, + beta::Number, + Y::CuVector{$elty}, + index::SparseChar) # Support transa = 'C' for real matrices transa = $elty <: Real && transa == 'C' ? 'T' : transa diff --git a/test/libraries/cusparse.jl b/test/libraries/cusparse.jl index 8b6d064fbc..8610013dc6 100644 --- a/test/libraries/cusparse.jl +++ b/test/libraries/cusparse.jl @@ -756,8 +756,7 @@ end alpha = rand(elty) beta = rand(elty) @testset "$(typeof(d_A))" for d_A in [CuSparseMatrixCSR(A), - CuSparseMatrixCSC(A), - CuSparseMatrixBSR(A, blockdim)] + CuSparseMatrixCSC(A)] d_x = CuArray(x) d_y = CuArray(y) @test_throws DimensionMismatch CUSPARSE.mv!('T',alpha,d_A,d_x,beta,d_y,'O') @@ -766,9 +765,16 @@ end h_z = collect(d_y) z = alpha * A * x + beta * y @test z ≈ h_z - #if d_A isa CuSparseMatrixCSR - # @test d_y' * (d_A * d_x) ≈ (d_y' * d_A) * d_x - #end + end + @testset "$(typeof(d_A))" for d_A in [CuSparseMatrixBSR(A, blockdim)] + d_x = CuArray(x) + d_y = CuArray(y) + @test_throws DimensionMismatch CUSPARSE.mv2!('T',alpha,d_A,d_x,beta,d_y,'O') + @test_throws DimensionMismatch CUSPARSE.mv2!('N',alpha,d_A,d_y,beta,d_x,'O') + CUSPARSE.mv2!('N',alpha,d_A,d_x,beta,d_y,'O') + h_z = collect(d_y) + z = alpha * A * x + beta * y + @test z ≈ h_z end end diff --git a/test/libraries/cusparse/generic.jl b/test/libraries/cusparse/generic.jl index 5bbe0c09da..7843fd40b9 100644 --- a/test/libraries/cusparse/generic.jl +++ b/test/libraries/cusparse/generic.jl @@ -2,7 +2,7 @@ using CUDA.CUSPARSE using SparseArrays using LinearAlgebra -@testset "generic mv!" for T in [Float32, Float64] +@testset "generic mv! -- $T" for T in [Float32, Float64] m = 10 A = sprand(T, m, m, 0.1) x = rand(Complex{T}, m) @@ -17,7 +17,7 @@ using LinearAlgebra dA = CuSparseMatrixCSR(dA) mv!('N', one(T), dA, dx, zero(T), dy, 'O') @test Array(dy) ≈ A * x - + A_bad = sprand(T, m+1, m, 0.1) dA_bad = adapt(CuArray, A_bad) @test_throws DimensionMismatch("Y must have length $(m+1), but has length $m") mv!('N', one(T), dA_bad, dx, zero(T), dy, 'O') @@ -32,9 +32,7 @@ SPMV_ALGOS = Dict(CuSparseMatrixCSC => [CUSPARSE.CUSPARSE_SPMV_ALG_DEFAULT], CUSPARSE.CUSPARSE_SPMV_CSR_ALG1, CUSPARSE.CUSPARSE_SPMV_CSR_ALG2], CuSparseMatrixCOO => [CUSPARSE.CUSPARSE_SPMV_ALG_DEFAULT, - CUSPARSE.CUSPARSE_SPMV_COO_ALG1, - ], - ) + CUSPARSE.CUSPARSE_SPMV_COO_ALG1]) SPMM_ALGOS = Dict(CuSparseMatrixCSC => [CUSPARSE.CUSPARSE_SPMM_ALG_DEFAULT], CuSparseMatrixCSR => [CUSPARSE.CUSPARSE_SPMM_ALG_DEFAULT, @@ -43,10 +41,9 @@ SPMM_ALGOS = Dict(CuSparseMatrixCSC => [CUSPARSE.CUSPARSE_SPMM_ALG_DEFAULT], CUSPARSE.CUSPARSE_SPMM_CSR_ALG3], CuSparseMatrixCOO => [CUSPARSE.CUSPARSE_SPMM_ALG_DEFAULT, CUSPARSE.CUSPARSE_SPMM_COO_ALG1, + CUSPARSE.CUSPARSE_SPMM_COO_ALG2, CUSPARSE.CUSPARSE_SPMM_COO_ALG3, - CUSPARSE.CUSPARSE_SPMM_COO_ALG4] - ) - + CUSPARSE.CUSPARSE_SPMM_COO_ALG4]) if CUSPARSE.version() >= v"12.1.3" push!(SPMV_ALGOS[CuSparseMatrixCOO], CUSPARSE.CUSPARSE_SPMV_COO_ALG2) @@ -57,15 +54,20 @@ if CUSPARSE.version() >= v"12.5.1" CUSPARSE.CUSPARSE_SPMM_BSR_ALG1] end +if CUSPARSE.version() >= v"12.6.3" + SPMV_ALGOS[CuSparseMatrixBSR] = [CUSPARSE.CUSPARSE_SPMV_ALG_DEFAULT, + CUSPARSE.CUSPARSE_SPMV_BSR_ALG1] +end + for SparseMatrixType in keys(SPMV_ALGOS) @testset "$SparseMatrixType -- mv! algo=$algo" for algo in SPMV_ALGOS[SparseMatrixType] @testset "mv! $T" for T in [Float32, Float64, ComplexF32, ComplexF64] @testset "transa = $transa" for (transa, opa) in [('N', identity), ('T', transpose), ('C', adjoint)] - SparseMatrixType == CuSparseMatrixCSC && T <: Complex && transa == 'C' && continue + (SparseMatrixType == CuSparseMatrixBSR) && (transa != 'N') && continue A = sprand(T, 20, 10, 0.1) B = transa == 'N' ? rand(T, 10) : rand(T, 20) C = transa == 'N' ? rand(T, 20) : rand(T, 10) - dA = SparseMatrixType(A) + dA = SparseMatrixType == CuSparseMatrixBSR ? SparseMatrixType(A,1) : SparseMatrixType(A) dB = CuArray(B) dC = CuArray(C) @@ -83,7 +85,6 @@ for SparseMatrixType in keys(SPMM_ALGOS) @testset "mm! $T" for T in [Float32, Float64, ComplexF32, ComplexF64] @testset "transa = $transa" for (transa, opa) in [('N', identity), ('T', transpose), ('C', adjoint)] @testset "transb = $transb" for (transb, opb) in [('N', identity), ('T', transpose), ('C', adjoint)] - CUSPARSE.version() < v"12.0" && SparseMatrixType == CuSparseMatrixCSC && T <: Complex && transa == 'C' && continue algo == CUSPARSE.CUSPARSE_SPMM_CSR_ALG3 && (transa != 'N' || transb != 'N') && continue (SparseMatrixType == CuSparseMatrixBSR) && (transa != 'N') && continue A = sprand(T, 10, 10, 0.1) @@ -122,7 +123,6 @@ for SparseMatrixType in keys(SPMM_ALGOS) @testset "$T" for T in [Float32, Float64, ComplexF32, ComplexF64] @testset "transa = $transa" for (transa, opa) in [('N', identity), ('T', transpose), ('C', adjoint)] @testset "transb = $transb" for (transb, opb) in [('N', identity), ('T', transpose), ('C', adjoint)] - CUSPARSE.version() < v"12.0" && SparseMatrixType == CuSparseMatrixCSR && T <: Complex && transb == 'C' && continue algo == CUSPARSE.CUSPARSE_SPMM_CSR_ALG3 && (transa != 'N' || transb != 'N') && continue A = rand(T, 10, 10) B = transb == 'N' ? sprand(T, 10, 5, 0.5) : sprand(T, 5, 10, 0.5) @@ -313,16 +313,14 @@ end @test Z ≈ collect(dY) end -SPGEMM_ALGOS = Dict(CuSparseMatrixCSR => [CUSPARSE.CUSPARSE_SPGEMM_DEFAULT], - CuSparseMatrixCSC => [CUSPARSE.CUSPARSE_SPGEMM_DEFAULT]) -if CUSPARSE.version() >= v"12.0" - append!(SPGEMM_ALGOS[CuSparseMatrixCSR], (CUSPARSE.CUSPARSE_SPGEMM_ALG1, - CUSPARSE.CUSPARSE_SPGEMM_ALG2, - CUSPARSE.CUSPARSE_SPGEMM_ALG3)) - append!(SPGEMM_ALGOS[CuSparseMatrixCSC], (CUSPARSE.CUSPARSE_SPGEMM_ALG1, - CUSPARSE.CUSPARSE_SPGEMM_ALG2, - CUSPARSE.CUSPARSE_SPGEMM_ALG3)) -end +SPGEMM_ALGOS = Dict(CuSparseMatrixCSR => [CUSPARSE.CUSPARSE_SPGEMM_DEFAULT, + CUSPARSE.CUSPARSE_SPGEMM_ALG1, + CUSPARSE.CUSPARSE_SPGEMM_ALG2, + CUSPARSE.CUSPARSE_SPGEMM_ALG3], + CuSparseMatrixCSC => [CUSPARSE.CUSPARSE_SPGEMM_DEFAULT, + CUSPARSE.CUSPARSE_SPGEMM_ALG1, + CUSPARSE.CUSPARSE_SPGEMM_ALG2, + CUSPARSE.CUSPARSE_SPGEMM_ALG3]) # Algorithms CUSPARSE.CUSPARSE_SPGEMM_CSR_ALG_DETERMINITIC and # CUSPARSE.CUSPARSE_SPGEMM_CSR_ALG_NONDETERMINITIC are dedicated to the cusparseSpGEMMreuse routine. @@ -391,39 +389,40 @@ for SparseMatrixType in keys(SPGEMM_ALGOS) end end -if CUSPARSE.version() >= v"11.4.1" +SDDMM_ALGOS = Dict(CuSparseMatrixCSR => [CUSPARSE.CUSPARSE_SDDMM_ALG_DEFAULT]) - SDDMM_ALGOS = Dict(CuSparseMatrixCSR => [CUSPARSE.CUSPARSE_SDDMM_ALG_DEFAULT]) +# if CUSPARSE.version() >= v"12.1.0" +# SDDMM_ALGOS[CuSparseMatrixBSR] = [CUSPARSE_SDDMM_ALG_DEFAULT] +# end - for SparseMatrixType in keys(SDDMM_ALGOS) - @testset "$SparseMatrixType -- sddmm! algo=$algo" for algo in SDDMM_ALGOS[SparseMatrixType] - @testset "sddmm! $T" for T in [Float32, Float64, ComplexF32, ComplexF64] - @testset "transa = $transa" for (transa, opa) in [('N', identity), ('T', transpose), ('C', adjoint)] - @testset "transb = $transb" for (transb, opb) in [('N', identity), ('T', transpose), ('C', adjoint)] - T <: Complex && (transa == 'C' || transb == 'C') && continue - mA = transa == 'N' ? 25 : 10 - nA = transa == 'N' ? 10 : 25 - mB = transb == 'N' ? 10 : 35 - nB = transb == 'N' ? 35 : 10 +for SparseMatrixType in keys(SDDMM_ALGOS) + @testset "$SparseMatrixType -- sddmm! algo=$algo" for algo in SDDMM_ALGOS[SparseMatrixType] + @testset "sddmm! $T" for T in [Float32, Float64, ComplexF32, ComplexF64] + @testset "transa = $transa" for (transa, opa) in [('N', identity), ('T', transpose), ('C', adjoint)] + @testset "transb = $transb" for (transb, opb) in [('N', identity), ('T', transpose), ('C', adjoint)] + T <: Complex && (transa == 'C' || transb == 'C') && continue + mA = transa == 'N' ? 25 : 10 + nA = transa == 'N' ? 10 : 25 + mB = transb == 'N' ? 10 : 35 + nB = transb == 'N' ? 35 : 10 - A = rand(T,mA,nA) - B = rand(T,mB,nB) - C = sprand(T,25,35,0.3) + A = rand(T,mA,nA) + B = rand(T,mB,nB) + C = sprand(T,25,35,0.3) - spyC = copy(C) - spyC.nzval .= one(T) + spyC = copy(C) + spyC.nzval .= one(T) - dA = CuArray(A) - dB = CuArray(B) - dC = SparseMatrixType(C) + dA = CuArray(A) + dB = CuArray(B) + dC = SparseMatrixType(C) - alpha = rand(T) - beta = rand(T) + alpha = rand(T) + beta = rand(T) - D = alpha * (opa(A) * opb(B)) .* spyC + beta * C - sddmm!(transa, transb, alpha, dA, dB, beta, dC, 'O', algo) - @test collect(dC) ≈ D - end + D = alpha * (opa(A) * opb(B)) .* spyC + beta * C + sddmm!(transa, transb, alpha, dA, dB, beta, dC, 'O', algo) + @test collect(dC) ≈ D end end end