Skip to content

Commit

Permalink
feat: support automatic sparsity detection for PETSc
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 25, 2024
1 parent 8fe5ee6 commit 12baf76
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 15 deletions.
55 changes: 42 additions & 13 deletions ext/NonlinearSolvePETScExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,19 @@ using NonlinearSolveBase: NonlinearSolveBase, get_tolerance
using NonlinearSolve: NonlinearSolve, PETScSNES
using PETSc: PETSc
using SciMLBase: SciMLBase, NonlinearProblem, ReturnCode
using SparseArrays: AbstractSparseMatrix

function SciMLBase.__solve(
prob::NonlinearProblem, alg::PETScSNES, args...; abstol = nothing, reltol = nothing,
maxiters = 1000, alias_u0::Bool = false, termination_condition = nothing,
show_trace::Val{ShT} = Val(false), kwargs...) where {ShT}
# XXX: https://petsc.org/release/manualpages/SNES/SNESSetConvergenceTest/
termination_condition === nothing ||
error("`PETScSNES` does not support termination conditions!")

_f!, u0, resid = NonlinearSolve.__construct_extension_f(prob; alias_u0)
T = eltype(prob.u0)
@assert T PETSc.scalar_types

if alg.petsclib === missing
petsclibidx = findfirst(PETSc.petsclibs) do petsclib
Expand All @@ -35,7 +38,10 @@ function SciMLBase.__solve(
abstol = get_tolerance(abstol, T)
reltol = get_tolerance(reltol, T)

nf = Ref{Int}(0)

f! = @closure (cfx, cx, user_ctx) -> begin
nf[] += 1
fx = cfx isa Ptr{Nothing} ? PETSc.unsafe_localarray(T, cfx; read = false) : cfx
x = cx isa Ptr{Nothing} ? PETSc.unsafe_localarray(T, cx; write = false) : cx
_f!(fx, x)
Expand All @@ -49,25 +55,47 @@ function SciMLBase.__solve(
alg.snes_options..., snes_monitor = ShT, snes_rtol = reltol,
snes_atol = abstol, snes_max_it = maxiters)

PETSc.setfunction!(snes, f!, PETSc.VecSeq(zero(u0)))

if alg.autodiff === missing && prob.f.jac === nothing
_jac! = nothing
njac = Ref{Int}(-1)
else
autodiff = alg.autodiff === missing ? nothing : alg.autodiff
_jac! = NonlinearSolve.__construct_extension_jac(prob, alg, u0, resid; autodiff)
end
_jac!, J_init = NonlinearSolve.__construct_extension_jac(
prob, alg, u0, resid; autodiff, initial_jacobian = Val(true))

PETSc.setfunction!(snes, f!, PETSc.VecSeq(zero(u0)))
njac = Ref{Int}(0)

if _jac! !== nothing # XXX: Sparsity Handling???
PJ = PETSc.MatSeqDense(zeros(T, length(resid), length(u0)))
jac! = @closure (cx, J, _, user_ctx) -> begin
x = cx isa Ptr{Nothing} ? PETSc.unsafe_localarray(T, cx; write = false) : cx
_jac!(J, x)
Base.finalize(x)
PETSc.assemble(J)
return
if J_init isa AbstractSparseMatrix
PJ = PETSc.MatSeqAIJ(J_init)
jac! = @closure (cx, J, _, user_ctx) -> begin
njac[] += 1
x = cx isa Ptr{Nothing} ? PETSc.unsafe_localarray(T, cx; write = false) : cx
if J isa PETSc.AbstractMat
_jac!(user_ctx.jacobian, x)
copyto!(J, user_ctx.jacobian)
PETSc.assemble(J)
else
_jac!(J, x)
end
Base.finalize(x)
return
end
PETSc.setjacobian!(snes, jac!, PJ, PJ)
snes.user_ctx = (; jacobian = J_init)
else
PJ = PETSc.MatSeqDense(J_init)
jac! = @closure (cx, J, _, user_ctx) -> begin
njac[] += 1
x = cx isa Ptr{Nothing} ? PETSc.unsafe_localarray(T, cx; write = false) : cx
_jac!(J, x)
Base.finalize(x)
J isa PETSc.AbstractMat && PETSc.assemble(J)
return
end
PETSc.setjacobian!(snes, jac!, PJ, PJ)
end
PETSc.setjacobian!(snes, jac!, PJ, PJ)
end

res = PETSc.solve!(u0, snes)
Expand All @@ -79,7 +107,8 @@ function SciMLBase.__solve(
objective = maximum(abs, resid)
# XXX: Return Code from PETSc
retcode = ifelse(objective abstol, ReturnCode.Success, ReturnCode.Failure)
return SciMLBase.build_solution(prob, alg, u_, resid_; retcode, original = snes)
return SciMLBase.build_solution(prob, alg, u_, resid_; retcode, original = snes,
stats = SciMLBase.NLStats(nf[], njac[], -1, -1, -1))
end

end
7 changes: 5 additions & 2 deletions src/internal/helpers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,8 @@ function __construct_extension_f(prob::AbstractNonlinearProblem; alias_u0::Bool
end

function __construct_extension_jac(prob, alg, u0, fu; can_handle_oop::Val = False,
can_handle_scalar::Val = False, autodiff = nothing, kwargs...)
can_handle_scalar::Val = False, autodiff = nothing, initial_jacobian = False,
kwargs...)
autodiff = select_jacobian_autodiff(prob, autodiff)

Jₚ = JacobianCache(
Expand All @@ -120,7 +121,9 @@ function __construct_extension_jac(prob, alg, u0, fu; can_handle_oop::Val = Fals
𝐉 = (can_handle_oop === False && !isinplace(prob)) ?
@closure((J, u)->copyto!(J, 𝓙(u))) : 𝓙

return 𝐉
initial_jacobian === False && return 𝐉

return 𝐉, Jₚ(nothing)
end

function reinit_cache! end
Expand Down

0 comments on commit 12baf76

Please sign in to comment.