Skip to content

Commit 50bf1a9

Browse files
committed
syevBatched! interface accepts 3D CuArray
Signed-off-by: Steven Hahn <[email protected]>
1 parent 64df292 commit 50bf1a9

File tree

2 files changed

+68
-0
lines changed

2 files changed

+68
-0
lines changed

lib/cusolver/dense_generic.jl

Lines changed: 37 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -505,6 +505,43 @@ function Xgeev!(jobvl::Char, jobvr::Char, A::StridedCuMatrix{T}) where {T <: Bla
505505
end
506506

507507
# XsyevBatched
508+
"""
    XsyevBatched!(jobz::Char, uplo::Char, A::StridedCuArray{T,3}) where {T <: BlasFloat}

Call `cusolverDnXsyevBatched` on the `size(A, 3)` square slices `A[:, :, i]` of `A`.

# Arguments
- `jobz`: forwarded to cuSOLVER; `'N'` and `'V'` also select the return value (see below).
- `uplo`: validated with `chkuplo`, then forwarded to cuSOLVER.
- `A`: batch of `n × n` matrices, passed to cuSOLVER as its matrix argument. For more
  than one non-empty slice, the slices must be packed `stride(A, 2) * n` elements apart.

# Returns
- `W` if `jobz == 'N'`, where `W` is the `n × size(A, 3)` `CuMatrix{real(T)}` handed to
  cuSOLVER as its `W` argument.
- `(W, A)` if `jobz == 'V'`.

# Throws
- `ErrorException` if `CUSOLVER.version() < v"11.7.1"`.
- `ArgumentError` if the slices of `A` are not packed as described above.
- Whatever `chkuplo`, `checksquare` and `chkargsok` raise on their respective checks.
"""
function XsyevBatched!(jobz::Char, uplo::Char, A::StridedCuArray{T,3}) where {T <: BlasFloat}
    CUSOLVER.version() < v"11.7.1" && throw(ErrorException("This operation is not supported by the current CUDA version."))
    chkuplo(uplo)
    n = checksquare(A)
    batch_size = size(A,3)
    R = real(T)
    lda = max(1, stride(A,2))

    # The cuSOLVER call below receives only `lda` and `batch_size`, never a batch
    # stride, so the slices have to sit `lda * n` elements apart. A view such as
    # `view(B, 1:2, 1:2, :)` violates that and would otherwise be processed at the
    # wrong offsets without any error being reported.
    if batch_size > 1 && n > 0 && stride(A,3) != lda * n
        throw(ArgumentError("the matrices in A must be stored contiguously along the third dimension (expected batch stride $(lda * n), got $(stride(A,3)))"))
    end

    W = CuMatrix{R}(undef, n, batch_size)
    params = CuSolverParameters()
    dh = dense_handle()
    # One status entry per matrix in the batch.
    resize!(dh.info, batch_size)

    # Query the device and host workspace sizes (in that order) for this problem.
    function bufferSize()
        out_cpu = Ref{Csize_t}(0)
        out_gpu = Ref{Csize_t}(0)
        cusolverDnXsyevBatched_bufferSize(dh, params, jobz, uplo, n,
                                          T, A, lda, R, W, T, out_gpu, out_cpu, batch_size)
        return out_gpu[], out_cpu[]
    end
    with_workspaces(dh.workspace_gpu, dh.workspace_cpu, bufferSize()...) do buffer_gpu, buffer_cpu
        cusolverDnXsyevBatched(dh, params, jobz, uplo, n, T, A,
                               lda, R, W, T, buffer_gpu, sizeof(buffer_gpu),
                               buffer_cpu, sizeof(buffer_cpu), dh.info, batch_size)
    end

    # Copy the per-matrix status codes to the host and validate each of them.
    info = @allowscalar collect(dh.info)
    for i in 1:batch_size
        chkargsok(info[i] |> BlasInt)
    end

    if jobz == 'N'
        return W
    elseif jobz == 'V'
        return W, A
    end
    # Any other `jobz` that reaches this point yields `nothing`.
    return nothing
end
544+
508545
function XsyevBatched!(jobz::Char, uplo::Char, A::StridedCuMatrix{T}) where {T <: BlasFloat}
509546
CUSOLVER.version() < v"11.7.1" && throw(ErrorException("This operation is not supported by the current CUDA version."))
510547
chkuplo(uplo)

test/libraries/cusolver/dense_generic.jl

Lines changed: 31 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -33,6 +33,36 @@ p = 5
3333
end
3434

3535
# Exercises the 3-D array method of `CUSOLVER.XsyevBatched!` for both triangles.
@testset "syevBatched!" begin
    batch_size = 5
    for uplo in ('L', 'U')
        # This combination is skipped on CUSOLVER versions older than 11.7.2.
        (CUSOLVER.version() < v"11.7.2") && (uplo == 'L') && (elty == ComplexF32) && continue

        # B keeps each full matrix `S * S' + I`; A keeps only the `uplo` triangle
        # of it, which is what gets handed to the solver.
        A = rand(elty, n, n, batch_size)
        B = rand(elty, n, n, batch_size)
        for i = 1:batch_size
            S = rand(elty,n,n)
            S = S * S' + I
            B[:,:,i] .= S
            S = uplo == 'L' ? tril(S) : triu(S)
            A[:,:,i] .= S
        end

        d_A = CuArray(A)
        d_W, d_V = CUSOLVER.XsyevBatched!('V', uplo, d_A)
        W = collect(d_W)
        V = collect(d_V)
        for i = 1:batch_size
            Bᵢ = B[:,:,i]
            Wᵢ = Diagonal(W[:,i])
            Vᵢ = V[:,:,i]
            # Eigen-decomposition identity B * V ≈ V * Λ for every slice.
            @test Bᵢ * Vᵢ ≈ Vᵢ * Wᵢ
        end

        # Values-only path: one column of `n` values per matrix in the batch.
        d_A = CuArray(A)
        d_W = CUSOLVER.XsyevBatched!('N', uplo, d_A)
        @test size(d_W) == (n, batch_size)
    end
end
64+
65+
@testset "syevBatched! updated" begin
3666
batch_size = 5
3767
for uplo in ('L', 'U')
3868
(CUSOLVER.version() < v"11.7.2") && (uplo == 'L') && (elty == ComplexF32) && continue
@@ -61,6 +91,7 @@ p = 5
6191
d_W = CUSOLVER.XsyevBatched!('N', uplo, d_A)
6292
end
6393
end
94+
6495
end
6596

6697
if CUSOLVER.version() >= v"11.6.0"

0 commit comments

Comments
 (0)