"""
    XsyevBatched!(jobz::Char, uplo::Char, A::StridedCuArray{T, 3}) where {T <: BlasFloat}

Batched eigen-decomposition through `cusolverDnXsyevBatched`, where `A[:, :, k]`
is the `k`-th `n × n` matrix of the batch and `size(A, 3)` is the batch size.

# Arguments
- `jobz`: `'N'` to return eigenvalues only, `'V'` to also return `A`, which the
  library overwrites in place (callers in the test-suite use it as the
  eigenvector batch).
- `uplo`: which triangle of each matrix is referenced; validated by `chkuplo`.
- `A`: batch of square matrices. It is handed to cuSOLVER as an in/out buffer, so
  its previous contents should be treated as clobbered for either `jobz`.
  The library is only told `lda` and the batch size, so consecutive matrices
  must sit exactly `lda * n` elements apart.

# Returns
- `W` when `jobz == 'N'`, an `n × batch_size` `CuMatrix{real(T)}` whose `k`-th
  column belongs to `A[:, :, k]`.
- `(W, A)` when `jobz == 'V'`.

# Throws
- `ErrorException` if the loaded cuSOLVER is older than 11.7.1.
- `ArgumentError` if `jobz` is not `'N'`/`'V'`, or if the batch stride of `A`
  is incompatible with the packed layout described above.
- `DimensionMismatch` (from `checksquare`) if `size(A, 1) != size(A, 2)`.
- Whatever `chkuplo` / `chkargsok` raise for a bad `uplo` or a reported illegal
  argument in any batch entry.
"""
function XsyevBatched!(jobz::Char, uplo::Char, A::StridedCuArray{T, 3}) where {T <: BlasFloat}
    CUSOLVER.version() < v"11.7.1" && throw(ErrorException("This operation is not supported by the current CUDA version."))
    # Reject bad modes before touching the shared handle state below; this also
    # guarantees the final `return` covers every reachable case.
    jobz in ('N', 'V') || throw(ArgumentError("jobz must be 'N' or 'V', got '$jobz'."))
    chkuplo(uplo)
    n = checksquare(A)
    batch_size = size(A, 3)
    R = real(T)
    lda = max(1, stride(A, 2))
    # The call below passes only `lda` and `batch_size`, so the k-th matrix is
    # located at offset (k - 1) * lda * n. A strided view with any other batch
    # stride would be read and written at the wrong addresses.
    if n > 0 && batch_size > 1 && stride(A, 3) != lda * n
        throw(ArgumentError("the batch stride of A ($(stride(A, 3))) must equal lda * n ($(lda * n)); pass a contiguous batch."))
    end
    W = CuMatrix{R}(undef, n, batch_size)
    params = CuSolverParameters()
    dh = dense_handle()
    # One status entry per matrix of the batch.
    resize!(dh.info, batch_size)

    # Query the device/host workspace sizes (in that order) for this problem.
    function bufferSize()
        out_cpu = Ref{Csize_t}(0)
        out_gpu = Ref{Csize_t}(0)
        cusolverDnXsyevBatched_bufferSize(
            dh, params, jobz, uplo, n,
            T, A, lda, R, W, T, out_gpu, out_cpu, batch_size
        )
        return out_gpu[], out_cpu[]
    end
    with_workspaces(dh.workspace_gpu, dh.workspace_cpu, bufferSize()...) do buffer_gpu, buffer_cpu
        cusolverDnXsyevBatched(
            dh, params, jobz, uplo, n, T, A,
            lda, R, W, T, buffer_gpu, sizeof(buffer_gpu),
            buffer_cpu, sizeof(buffer_cpu), dh.info, batch_size
        )
    end

    # Bring all per-matrix status codes to the host in one transfer and check each.
    info = @allowscalar collect(dh.info)
    for i in 1:batch_size
        chkargsok(info[i] |> BlasInt)
    end

    return jobz == 'N' ? W : (W, A)
end
# Exercises the 3-D (n × n × batch) method of `CUSOLVER.XsyevBatched!`.
# `elty` and `n` are taken from the enclosing test scope (not visible here).
@testset "syevBatched!" begin
    batch_size = 5
    for uplo in ('L', 'U')
        (CUSOLVER.version() < v"11.7.2") && (uplo == 'L') && (elty == ComplexF32) && continue

        # B keeps the full matrices for the reference check; A only carries the
        # `uplo` triangle, which is what gets uploaded to the device.
        A = rand(elty, n, n, batch_size)
        B = rand(elty, n, n, batch_size)
        for i in 1:batch_size
            S = rand(elty, n, n)
            S = S * S' + I
            B[:, :, i] .= S
            S = uplo == 'L' ? tril(S) : triu(S)
            A[:, :, i] .= S
        end

        # Eigenvalues + eigenvectors: each pair must satisfy B * V = V * Λ.
        d_A = CuArray(A)
        d_W, d_V = CUSOLVER.XsyevBatched!('V', uplo, d_A)
        W = collect(d_W)
        V = collect(d_V)
        for i in 1:batch_size
            Bᵢ = B[:, :, i]
            Wᵢ = Diagonal(W[:, i])
            Vᵢ = V[:, :, i]
            @test Bᵢ * Vᵢ ≈ Vᵢ * Wᵢ
        end

        # Eigenvalues only: must agree with the eigenvalues from the 'V' run.
        d_A = CuArray(A)
        d_W = CUSOLVER.XsyevBatched!('N', uplo, d_A)
        @test collect(d_W) ≈ W
    end
end