Skip to content

Commit

Permalink
[CUSOLVER] Don't reuse sparse handles (#2173)
Browse files Browse the repository at this point in the history
  • Loading branch information
amontoison authored Nov 21, 2023
1 parent 98586fe commit 66559cc
Showing 1 changed file with 22 additions and 27 deletions.
49 changes: 22 additions & 27 deletions lib/cusolver/sparse_factorizations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ mutable struct SparseQR{T <: BlasFloat} <: Factorization{T}
m::Cint
nnzA::Cint
mu::T
handle::cusolverSpHandle_t
descA::CuMatrixDescriptor
info::SparseQRInfo
buffer::Union{CuPtr{Cvoid},CuVector{UInt8}}
Expand All @@ -27,12 +26,10 @@ function SparseQR(A::CuSparseMatrixCSR{T,Cint}, index::Char='O') where T <: Blas
m,n = size(A)
nnzA = nnz(A)
mu = zero(T)
handle = sparse_handle()
descA = CuMatrixDescriptor('G', 'L', 'N', index)
handle = sparse_handle()
info = SparseQRInfo()
buffer = CU_NULL
F = SparseQR{T}(n, m, nnzA, mu, handle, descA, info, buffer)
F = SparseQR{T}(n, m, nnzA, mu, descA, info, buffer)
spqr_analyse(F, A)
spqr_buffer(F, A)
return F
Expand All @@ -50,7 +47,7 @@ end
# const int * csrColIndA,
# csrqrInfo_t info);
function spqr_analyse(F::SparseQR{T}, A::CuSparseMatrixCSR{T,Cint}) where T <: BlasFloat
cusolverSpXcsrqrAnalysis(F.handle, F.m, F.n, F.nnzA, F.descA, A.rowPtr, A.colVal, F.info)
cusolverSpXcsrqrAnalysis(sparse_handle(), F.m, F.n, F.nnzA, F.descA, A.rowPtr, A.colVal, F.info)
return F
end

Expand All @@ -77,7 +74,7 @@ for (bname, iname, fname, sname, pname, elty, relty) in
function spqr_buffer(F::SparseQR{$elty}, A::CuSparseMatrixCSR{$elty,Cint})
internalDataInBytes = Ref{Csize_t}(0)
workspaceInBytes = Ref{Csize_t}(0)
$bname(F.handle, F.m, F.n, F.nnzA, F.descA, A.nzVal, A.rowPtr, A.colVal, F.info, internalDataInBytes, workspaceInBytes)
$bname(sparse_handle(), F.m, F.n, F.nnzA, F.descA, A.nzVal, A.rowPtr, A.colVal, F.info, internalDataInBytes, workspaceInBytes)
F.buffer = CuVector{UInt8}(undef, workspaceInBytes[])
return F
end
Expand Down Expand Up @@ -116,19 +113,19 @@ for (bname, iname, fname, sname, pname, elty, relty) in
# double tol,
# int * position);
function spqr_factorise(F::SparseQR{$elty}, A::CuSparseMatrixCSR{$elty,Cint}, tol::$relty)
$iname(F.handle, F.m, F.n, F.nnzA, F.descA, A.nzVal, A.rowPtr, A.colVal, F.mu, F.info)
$fname(F.handle, F.m, F.n, F.nnzA, CU_NULL, CU_NULL, F.info, F.buffer)
$iname(sparse_handle(), F.m, F.n, F.nnzA, F.descA, A.nzVal, A.rowPtr, A.colVal, F.mu, F.info)
$fname(sparse_handle(), F.m, F.n, F.nnzA, CU_NULL, CU_NULL, F.info, F.buffer)
singularity = Ref{Cint}(0)
$pname(F.handle, F.info, tol, singularity)
$pname(sparse_handle(), F.info, tol, singularity)
(singularity[] 0) && throw(SingularException(singularity[]))
return F
end

function spqr_factorise_solve(F::SparseQR{$elty}, A::CuSparseMatrixCSR{$elty,Cint}, b::CuVector{$elty}, x::CuVector{$elty}, tol::$relty)
$iname(F.handle, F.m, F.n, F.nnzA, F.descA, A.nzVal, A.rowPtr, A.colVal, F.mu, F.info)
$fname(F.handle, F.m, F.n, F.nnzA, b, x, F.info, F.buffer)
$iname(sparse_handle(), F.m, F.n, F.nnzA, F.descA, A.nzVal, A.rowPtr, A.colVal, F.mu, F.info)
$fname(sparse_handle(), F.m, F.n, F.nnzA, b, x, F.info, F.buffer)
singularity = Ref{Cint}(0)
$pname(F.handle, F.info, tol, singularity)
$pname(sparse_handle(), F.info, tol, singularity)
(singularity[] 0) && throw(SingularException(singularity[]))
return F
end
Expand All @@ -144,14 +141,14 @@ for (bname, iname, fname, sname, pname, elty, relty) in
# csrqrInfo_t info,
# void * pBuffer);
function spqr_solve(F::SparseQR{$elty}, b::CuVector{$elty}, x::CuVector{$elty})
$sname(F.handle, F.m, F.n, b, x, F.info, F.buffer)
$sname(sparse_handle(), F.m, F.n, b, x, F.info, F.buffer)
return x
end

function spqr_solve(F::SparseQR{$elty}, B::CuMatrix{$elty}, X::CuMatrix{$elty})
m, p = size(B)
for j=1:p
$sname(F.handle, F.m, F.n, view(B,:,j), view(X,:,j), F.info, F.buffer)
$sname(sparse_handle(), F.m, F.n, view(B,:,j), view(X,:,j), F.info, F.buffer)
end
return X
end
Expand All @@ -175,7 +172,6 @@ Base.unsafe_convert(::Type{csrcholInfo_t}, info::SparseCholeskyInfo) = info.info
mutable struct SparseCholesky{T <: BlasFloat} <: Factorization{T}
n::Cint
nnzA::Cint
handle::cusolverSpHandle_t
descA::CuMatrixDescriptor
info::SparseCholeskyInfo
buffer::Union{CuPtr{Cvoid},CuVector{UInt8}}
Expand All @@ -184,11 +180,10 @@ end
function SparseCholesky(A::Union{CuSparseMatrixCSC{T,Cint},CuSparseMatrixCSR{T,Cint}}, index::Char='O') where T <: BlasFloat
n = checksquare(A)
nnzA = nnz(A)
handle = sparse_handle()
descA = CuMatrixDescriptor('G', 'L', 'N', index)
info = SparseCholeskyInfo()
buffer = CU_NULL
F = SparseCholesky{T}(n, nnzA, handle, descA, info, buffer)
F = SparseCholesky{T}(n, nnzA, descA, info, buffer)
spcholesky_analyse(F, A)
spcholesky_buffer(F, A)
return F
Expand All @@ -206,9 +201,9 @@ end
# csrcholInfo_t info);
function spcholesky_analyse(F::SparseCholesky{T}, A::Union{CuSparseMatrixCSC{T,Cint},CuSparseMatrixCSR{T,Cint}}) where T <: BlasFloat
if A isa CuSparseMatrixCSC
cusolverSpXcsrcholAnalysis(F.handle, F.n, F.nnzA, F.descA, A.colPtr, A.rowVal, F.info)
cusolverSpXcsrcholAnalysis(sparse_handle(), F.n, F.nnzA, F.descA, A.colPtr, A.rowVal, F.info)
else
cusolverSpXcsrcholAnalysis(F.handle, F.n, F.nnzA, F.descA, A.rowPtr, A.colVal, F.info)
cusolverSpXcsrcholAnalysis(sparse_handle(), F.n, F.nnzA, F.descA, A.rowPtr, A.colVal, F.info)
end
return F
end
Expand Down Expand Up @@ -236,9 +231,9 @@ for (bname, fname, pname, elty, relty) in
internalDataInBytes = Ref{Csize_t}(0)
workspaceInBytes = Ref{Csize_t}(0)
if A isa CuSparseMatrixCSC
$bname(F.handle, F.n, F.nnzA, F.descA, A.nzVal, A.colPtr, A.rowVal, F.info, internalDataInBytes, workspaceInBytes)
$bname(sparse_handle(), F.n, F.nnzA, F.descA, A.nzVal, A.colPtr, A.rowVal, F.info, internalDataInBytes, workspaceInBytes)
else
$bname(F.handle, F.n, F.nnzA, F.descA, A.nzVal, A.rowPtr, A.colVal, F.info, internalDataInBytes, workspaceInBytes)
$bname(sparse_handle(), F.n, F.nnzA, F.descA, A.nzVal, A.rowPtr, A.colVal, F.info, internalDataInBytes, workspaceInBytes)
end
F.buffer = CuVector{UInt8}(undef, workspaceInBytes[])
return F
Expand Down Expand Up @@ -267,12 +262,12 @@ for (bname, fname, pname, elty, relty) in
function spcholesky_factorise(F::SparseCholesky{$elty}, A::Union{CuSparseMatrixCSC{$elty,Cint},CuSparseMatrixCSR{$elty,Cint}}, tol::$relty)
if A isa CuSparseMatrixCSC
nzval = $elty <: Complex ? conj(A.nzVal) : A.nzVal
$fname(F.handle, F.n, F.nnzA, F.descA, nzval, A.colPtr, A.rowVal, F.info, F.buffer)
$fname(sparse_handle(), F.n, F.nnzA, F.descA, nzval, A.colPtr, A.rowVal, F.info, F.buffer)
else
$fname(F.handle, F.n, F.nnzA, F.descA, A.nzVal, A.rowPtr, A.colVal, F.info, F.buffer)
$fname(sparse_handle(), F.n, F.nnzA, F.descA, A.nzVal, A.rowPtr, A.colVal, F.info, F.buffer)
end
singularity = Ref{Cint}(0)
$pname(F.handle, F.info, tol, singularity)
$pname(sparse_handle(), F.info, tol, singularity)
(singularity[] 0) && throw(SingularException(singularity[]))
return F
end
Expand All @@ -294,14 +289,14 @@ for (sname, dname, elty, relty) in ((:cusolverSpScsrcholSolve, :cusolverSpScsrch
# csrcholInfo_t info,
# void * pBuffer);
function spcholesky_solve(F::SparseCholesky{$elty}, b::CuVector{$elty}, x::CuVector{$elty})
$sname(F.handle, F.n, b, x, F.info, F.buffer)
$sname(sparse_handle(), F.n, b, x, F.info, F.buffer)
return x
end

function spcholesky_solve(F::SparseCholesky{$elty}, B::CuMatrix{$elty}, X::CuMatrix{$elty})
n, p = size(B)
for j=1:p
$sname(F.handle, F.n, view(B,:,j), view(X,:,j), F.info, F.buffer)
$sname(sparse_handle(), F.n, view(B,:,j), view(X,:,j), F.info, F.buffer)
end
return X
end
Expand All @@ -313,7 +308,7 @@ for (sname, dname, elty, relty) in ((:cusolverSpScsrcholSolve, :cusolverSpScsrch
# csrcholInfo_t info,
# float * diag);
function spcholesky_diag(F::SparseCholesky{$elty}, diag::CuVector{$relty})
$dname(F.handle, F.info, diag)
$dname(sparse_handle(), F.info, diag)
return diag
end
end
Expand Down

0 comments on commit 66559cc

Please sign in to comment.