Skip to content

Commit

Permalink
Add the structures ILU0Info() and IC0Info() for the preconditioners (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
amontoison authored Dec 11, 2023
1 parent c92bb31 commit 4f29369
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 54 deletions.
56 changes: 56 additions & 0 deletions lib/cusparse/helpers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -262,3 +262,59 @@ mutable struct CuSparseSpSMDescriptor
end

Base.unsafe_convert(::Type{cusparseSpSMDescr_t}, desc::CuSparseSpSMDescriptor) = desc.handle

mutable struct IC0Info
info::csric02Info_t

function IC0Info()
info_ref = Ref{csric02Info_t}()
cusparseCreateCsric02Info(info_ref)
obj = new(info_ref[])
finalizer(cusparseDestroyCsric02Info, obj)
obj
end
end

Base.unsafe_convert(::Type{csric02Info_t}, info::IC0Info) = info.info

mutable struct IC0InfoBSR
info::bsric02Info_t

function IC0InfoBSR()
info_ref = Ref{bsric02Info_t}()
cusparseCreateBsric02Info(info_ref)
obj = new(info_ref[])
finalizer(cusparseDestroyBsric02Info, obj)
obj
end
end

Base.unsafe_convert(::Type{bsric02Info_t}, info::IC0InfoBSR) = info.info

mutable struct ILU0Info
info::csrilu02Info_t

function ILU0Info()
info_ref = Ref{csrilu02Info_t}()
cusparseCreateCsrilu02Info(info_ref)
obj = new(info_ref[])
finalizer(cusparseDestroyCsrilu02Info, obj)
obj
end
end

Base.unsafe_convert(::Type{csrilu02Info_t}, info::ILU0Info) = info.info

mutable struct ILU0InfoBSR
info::bsrilu02Info_t

function ILU0InfoBSR()
info_ref = Ref{bsrilu02Info_t}()
cusparseCreateBsrilu02Info(info_ref)
obj = new(info_ref[])
finalizer(cusparseDestroyBsrilu02Info, obj)
obj
end
end

Base.unsafe_convert(::Type{bsrilu02Info_t}, info::ILU0InfoBSR) = info.info
94 changes: 40 additions & 54 deletions lib/cusparse/preconditioners.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,29 +40,27 @@ for (bname,aname,sname,elty) in ((:cusparseScsric02_bufferSize, :cusparseScsric0
if m != n
throw(DimensionMismatch("A must be square, but has dimensions ($m,$n)!"))
end
info = csric02Info_t[0]
cusparseCreateCsric02Info(info)
info = IC0Info()

function bufferSize()
out = Ref{Cint}(1)
$bname(handle(), m, nnz(A), desc, nonzeros(A), A.rowPtr, A.colVal, info[1],
$bname(handle(), m, nnz(A), desc, nonzeros(A), A.rowPtr, A.colVal, info,
out)
return out[]
end
with_workspace(bufferSize) do buffer
$aname(handle(), m, nnz(A), desc,
nonzeros(A), A.rowPtr, A.colVal, info[1],
nonzeros(A), A.rowPtr, A.colVal, info,
CUSPARSE_SOLVE_POLICY_USE_LEVEL, buffer)
posit = Ref{Cint}(1)
cusparseXcsric02_zeroPivot(handle(), info[1], posit)
cusparseXcsric02_zeroPivot(handle(), info, posit)
if posit[] >= 0
error("Structural/numerical zero in A at ($(posit[]),$(posit[])))")
end
$sname(handle(), m, nnz(A),
desc, nonzeros(A), A.rowPtr, A.colVal, info[1],
desc, nonzeros(A), A.rowPtr, A.colVal, info,
CUSPARSE_SOLVE_POLICY_USE_LEVEL, buffer)
end
cusparseDestroyCsric02Info(info[1])
A
end
end
Expand All @@ -80,29 +78,27 @@ for (bname,aname,sname,elty) in ((:cusparseScsric02_bufferSize, :cusparseScsric0
if m != n
throw(DimensionMismatch("A must be square, but has dimensions ($m,$n)!"))
end
info = csric02Info_t[0]
cusparseCreateCsric02Info(info)
info = IC0Info()

function bufferSize()
out = Ref{Cint}(1)
$bname(handle(), m, nnz(A), desc, nonzeros(A), A.colPtr, rowvals(A),
info[1], out)
info, out)
return out[]
end
with_workspace(bufferSize) do buffer
$aname(handle(), m, nnz(A), desc,
nonzeros(A), A.colPtr, rowvals(A), info[1],
nonzeros(A), A.colPtr, rowvals(A), info,
CUSPARSE_SOLVE_POLICY_USE_LEVEL, buffer)
posit = Ref{Cint}(1)
cusparseXcsric02_zeroPivot(handle(), info[1], posit)
cusparseXcsric02_zeroPivot(handle(), info, posit)
if posit[] >= 0
error("Structural/numerical zero in A at ($(posit[]),$(posit[])))")
end
$sname(handle(), m, nnz(A),
desc, nonzeros(A), A.colPtr, rowvals(A), info[1],
desc, nonzeros(A), A.colPtr, rowvals(A), info,
CUSPARSE_SOLVE_POLICY_USE_LEVEL, buffer)
end
cusparseDestroyCsric02Info(info[1])
A
end
end
Expand All @@ -120,30 +116,28 @@ for (bname,aname,sname,elty) in ((:cusparseScsrilu02_bufferSize, :cusparseScsril
if m != n
throw(DimensionMismatch("A must be square, but has dimensions ($m,$n)!"))
end
info = csrilu02Info_t[0]
cusparseCreateCsrilu02Info(info)
info = ILU0Info()

function bufferSize()
out = Ref{Cint}(1)
$bname(handle(), m, nnz(A), desc,
nonzeros(A), A.rowPtr, A.colVal, info[1],
nonzeros(A), A.rowPtr, A.colVal, info,
out)
return out[]
end
with_workspace(bufferSize) do buffer
$aname(handle(), m, nnz(A), desc,
nonzeros(A), A.rowPtr, A.colVal, info[1],
nonzeros(A), A.rowPtr, A.colVal, info,
CUSPARSE_SOLVE_POLICY_USE_LEVEL, buffer)
posit = Ref{Cint}(1)
cusparseXcsrilu02_zeroPivot(handle(), info[1], posit)
cusparseXcsrilu02_zeroPivot(handle(), info, posit)
if posit[] >= 0
error("Structural zero in A at ($(posit[]),$(posit[])))")
end
$sname(handle(), m, nnz(A),
desc, nonzeros(A), A.rowPtr, A.colVal, info[1],
desc, nonzeros(A), A.rowPtr, A.colVal, info,
CUSPARSE_SOLVE_POLICY_USE_LEVEL, buffer)
end
cusparseDestroyCsrilu02Info(info[1])
A
end
end
Expand All @@ -161,30 +155,28 @@ for (bname,aname,sname,elty) in ((:cusparseScsrilu02_bufferSize, :cusparseScsril
if m != n
throw(DimensionMismatch("A must be square, but has dimensions ($m,$n)!"))
end
info = csrilu02Info_t[0]
cusparseCreateCsrilu02Info(info)
info = ILU0Info()

function bufferSize()
out = Ref{Cint}(1)
$bname(handle(), m, nnz(A), desc,
nonzeros(A), A.colPtr, rowvals(A), info[1],
nonzeros(A), A.colPtr, rowvals(A), info,
out)
return out[]
end
with_workspace(bufferSize) do buffer
$aname(handle(), m, nnz(A), desc,
nonzeros(A), A.colPtr, rowvals(A), info[1],
nonzeros(A), A.colPtr, rowvals(A), info,
CUSPARSE_SOLVE_POLICY_USE_LEVEL, buffer)
posit = Ref{Cint}(1)
cusparseXcsrilu02_zeroPivot(handle(), info[1], posit)
cusparseXcsrilu02_zeroPivot(handle(), info, posit)
if posit[] >= 0
error("Structural zero in A at ($(posit[]),$(posit[])))")
end
$sname(handle(), m, nnz(A),
desc, nonzeros(A), A.colPtr, rowvals(A), info[1],
desc, nonzeros(A), A.colPtr, rowvals(A), info,
CUSPARSE_SOLVE_POLICY_USE_LEVEL, buffer)
end
cusparseDestroyCsrilu02Info(info[1])
A
end
end
Expand All @@ -203,30 +195,27 @@ for (bname,aname,sname,elty) in ((:cusparseSbsric02_bufferSize, :cusparseSbsric0
throw(DimensionMismatch("A must be square, but has dimensions ($m,$n)!"))
end
mb = div(m,A.blockDim)
info = bsric02Info_t[0]
cusparseCreateBsric02Info(info)
info = IC0InfoBSR()

function bufferSize()
out = Ref{Cint}(1)
$bname(handle(), A.dir, mb, nnz(A), desc,
nonzeros(A), A.rowPtr, A.colVal, A.blockDim, info[1],
out)
$bname(handle(), A.dir, mb, nnz(A), desc, nonzeros(A),
A.rowPtr, A.colVal, A.blockDim, info, out)
return out[]
end
with_workspace(bufferSize) do buffer
$aname(handle(), A.dir, mb, nnz(A), desc,
nonzeros(A), A.rowPtr, A.colVal, A.blockDim, info[1],
CUSPARSE_SOLVE_POLICY_USE_LEVEL, buffer)
posit = Ref{Cint}(1)
cusparseXbsric02_zeroPivot(handle(), info[1], posit)
if posit[] >= 0
error("Structural/numerical zero in A at ($(posit[]),$(posit[])))")
end
$sname(handle(), A.dir, mb, nnz(A), desc,
nonzeros(A), A.rowPtr, A.colVal, A.blockDim, info[1],
CUSPARSE_SOLVE_POLICY_USE_LEVEL, buffer)
$aname(handle(), A.dir, mb, nnz(A), desc,
nonzeros(A), A.rowPtr, A.colVal, A.blockDim, info,
CUSPARSE_SOLVE_POLICY_USE_LEVEL, buffer)
posit = Ref{Cint}(1)
cusparseXbsric02_zeroPivot(handle(), info, posit)
if posit[] >= 0
error("Structural/numerical zero in A at ($(posit[]),$(posit[])))")
end
cusparseDestroyBsric02Info(info[1])
$sname(handle(), A.dir, mb, nnz(A), desc,
nonzeros(A), A.rowPtr, A.colVal, A.blockDim, info,
CUSPARSE_SOLVE_POLICY_USE_LEVEL, buffer)
end
A
end
end
Expand All @@ -245,30 +234,27 @@ for (bname,aname,sname,elty) in ((:cusparseSbsrilu02_bufferSize, :cusparseSbsril
throw(DimensionMismatch("A must be square, but has dimensions ($m,$n)!"))
end
mb = div(m,A.blockDim)
info = bsrilu02Info_t[0]
cusparseCreateBsrilu02Info(info)
info = ILU0InfoBSR()

function bufferSize()
out = Ref{Cint}(1)
$bname(handle(), A.dir, mb, nnz(A), desc,
nonzeros(A), A.rowPtr, A.colVal, A.blockDim, info[1],
out)
$bname(handle(), A.dir, mb, nnz(A), desc, nonzeros(A),
A.rowPtr, A.colVal, A.blockDim, info, out)
return out[]
end
with_workspace(bufferSize) do buffer
$aname(handle(), A.dir, mb, nnz(A), desc,
nonzeros(A), A.rowPtr, A.colVal, A.blockDim, info[1],
nonzeros(A), A.rowPtr, A.colVal, A.blockDim, info,
CUSPARSE_SOLVE_POLICY_USE_LEVEL, buffer)
posit = Ref{Cint}(1)
cusparseXbsrilu02_zeroPivot(handle(), info[1], posit)
cusparseXbsrilu02_zeroPivot(handle(), info, posit)
if posit[] >= 0
error("Structural/numerical zero in A at ($(posit[]),$(posit[])))")
end
$sname(handle(), A.dir, mb, nnz(A), desc,
nonzeros(A), A.rowPtr, A.colVal, A.blockDim, info[1],
nonzeros(A), A.rowPtr, A.colVal, A.blockDim, info,
CUSPARSE_SOLVE_POLICY_USE_LEVEL, buffer)
end
cusparseDestroyBsrilu02Info(info[1])
A
end
end
Expand Down

0 comments on commit 4f29369

Please sign in to comment.