@@ -505,6 +505,43 @@ function Xgeev!(jobvl::Char, jobvr::Char, A::StridedCuMatrix{T}) where {T <: Bla
505505end
506506
507507# XsyevBatched
508+ function XsyevBatched! (jobz:: Char , uplo:: Char , A:: StridedCuArray{T,3} ) where {T <: BlasFloat }
509+ CUSOLVER. version () < v " 11.7.1" && throw (ErrorException (" This operation is not supported by the current CUDA version." ))
510+ chkuplo (uplo)
511+ n = checksquare (A)
512+ batch_size = size (A,3 )
513+ R = real (T)
514+ lda = max (1 , stride (A,2 ))
515+ W = CuMatrix {R} (undef, n, batch_size)
516+ params = CuSolverParameters ()
517+ dh = dense_handle ()
518+ resize! (dh. info, batch_size)
519+
520+ function bufferSize ()
521+ out_cpu = Ref {Csize_t} (0 )
522+ out_gpu = Ref {Csize_t} (0 )
523+ cusolverDnXsyevBatched_bufferSize (dh, params, jobz, uplo, n,
524+ T, A, lda, R, W, T, out_gpu, out_cpu, batch_size)
525+ out_gpu[], out_cpu[]
526+ end
527+ with_workspaces (dh. workspace_gpu, dh. workspace_cpu, bufferSize ()... ) do buffer_gpu, buffer_cpu
528+ cusolverDnXsyevBatched (dh, params, jobz, uplo, n, T, A,
529+ lda, R, W, T, buffer_gpu, sizeof (buffer_gpu),
530+ buffer_cpu, sizeof (buffer_cpu), dh. info, batch_size)
531+ end
532+
533+ info = @allowscalar collect (dh. info)
534+ for i = 1 : batch_size
535+ chkargsok (info[i] |> BlasInt)
536+ end
537+
538+ if jobz == ' N'
539+ return W
540+ elseif jobz == ' V'
541+ return W, A
542+ end
543+ end
544+
508545function XsyevBatched! (jobz:: Char , uplo:: Char , A:: StridedCuMatrix{T} ) where {T <: BlasFloat }
509546 CUSOLVER. version () < v " 11.7.1" && throw (ErrorException (" This operation is not supported by the current CUDA version." ))
510547 chkuplo (uplo)
0 commit comments