Skip to content

feat: GridapPETSc wrapper #541

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 5 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ SimpleNonlinearSolve = {path = "lib/SimpleNonlinearSolve"}
[weakdeps]
FastLevenbergMarquardt = "7a0df574-e128-4d35-8cbd-3d84502bf7ce"
FixedPointAcceleration = "817d07cb-a79a-5c30-9a31-890123675176"
Gridap = "56d4f2e9-7ea1-5844-9cf6-b9c51ca7ce8e"
GridapPETSc = "bcdc36c2-0c3e-11ea-095a-c9dadae499f1"
LeastSquaresOptim = "0fc2ff8b-aaa3-5acd-a817-1944a5e08891"
LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255"
MINPACK = "4854310b-de5a-5eb6-a2a5-c1dee2bd17f9"
Expand All @@ -53,9 +55,16 @@ SIAMFANLEquations = "084e46ad-d928-497d-ad5e-07fa361a48c4"
SpeedMapping = "f1835b91-879b-4a3f-a438-e4baacf14412"
Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"

[sources]
NonlinearSolveBase = {path = "lib/NonlinearSolveBase"}
NonlinearSolveFirstOrder = {path = "lib/NonlinearSolveFirstOrder"}
NonlinearSolveQuasiNewton = {path = "lib/NonlinearSolveQuasiNewton"}
NonlinearSolveSpectralMethods = {path = "lib/NonlinearSolveSpectralMethods"}

[extensions]
NonlinearSolveFastLevenbergMarquardtExt = "FastLevenbergMarquardt"
NonlinearSolveFixedPointAccelerationExt = "FixedPointAcceleration"
NonlinearSolveGridapPETScExt = ["Gridap", "GridapPETSc"]
NonlinearSolveLeastSquaresOptimExt = "LeastSquaresOptim"
NonlinearSolveMINPACKExt = "MINPACK"
NonlinearSolveNLSolversExt = "NLSolvers"
Expand Down
123 changes: 123 additions & 0 deletions ext/NonlinearSolveGridapPETScExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
module NonlinearSolveGridapPETScExt

using Gridap: Gridap, Algebra
using GridapPETSc: GridapPETSc

using NonlinearSolveBase: NonlinearSolveBase
using NonlinearSolve: NonlinearSolve, GridapPETScSNES
using SciMLBase: SciMLBase, NonlinearProblem, ReturnCode

using ConcreteStructs: @concrete
using FastClosures: @closure

@concrete struct NonlinearSolveOperator <: Algebra.NonlinearOperator
f!
jac!
initial_guess_cache
resid_prototype
jacobian_prototype
end

function Algebra.residual!(b::AbstractVector, op::NonlinearSolveOperator, x::AbstractVector)
op.f!(b, x)
end

function Algebra.jacobian!(
A::AbstractMatrix, op::NonlinearSolveOperator, x::AbstractVector
)
op.jac!(A, x)
end

function Algebra.zero_initial_guess(op::NonlinearSolveOperator)
fill!(op.initial_guess_cache, 0)
return op.initial_guess_cache
end

function Algebra.allocate_residual(op::NonlinearSolveOperator, ::AbstractVector)
fill!(op.resid_prototype, 0)
return op.resid_prototype
end

function Algebra.allocate_jacobian(op::NonlinearSolveOperator, ::AbstractVector)
fill!(op.jacobian_prototype, 0)
return op.jacobian_prototype
end

# TODO: Later we should just wrap `Gridap` generally and pass in `PETSc` as the solver
function SciMLBase.__solve(
prob::NonlinearProblem, alg::GridapPETScSNES, args...;
abstol = nothing, reltol = nothing,
maxiters = 1000, alias_u0::Bool = false, termination_condition = nothing,
show_trace::Val = Val(false), kwargs...
)
# XXX: https://petsc.org/release/manualpages/SNES/SNESSetConvergenceTest/
NonlinearSolveBase.assert_extension_supported_termination_condition(
termination_condition, alg; abs_norm_supported = false
)

f_wrapped!, u0, resid = NonlinearSolveBase.construct_extension_function_wrapper(
prob; alias_u0
)
T = eltype(u0)

abstol = NonlinearSolveBase.get_tolerance(abstol, T)
reltol = NonlinearSolveBase.get_tolerance(reltol, T)

nf = Ref{Int}(0)

f! = @closure (fx, x) -> begin
nf[] += 1
f_wrapped!(fx, x)
return fx
end

if prob.u0 isa Number
jac! = NonlinearSolveBase.construct_extension_jac(
prob, alg, prob.u0, prob.u0; alg.autodiff
)
J_init = zeros(T, 1, 1)
else
jac!, J_init = NonlinearSolveBase.construct_extension_jac(
prob, alg, u0, resid; alg.autodiff, initial_jacobian = Val(true)
)
end

njac = Ref{Int}(-1)
jac_fn! = @closure (J, x) -> begin
njac[] += 1
jac!(J, x)
return J
end

nop = NonlinearSolveOperator(f!, jac_fn!, u0, resid, J_init)

petsc_args = [
"-snes_rtol", string(reltol), "-snes_atol", string(abstol),
"-snes_max_it", string(maxiters)
]
for (k, v) in pairs(alg.snes_options)
push!(petsc_args, "-$(k)")
push!(petsc_args, string(v))
end
show_trace isa Val{true} && push!(petsc_args, "-snes_monitor")

# TODO: We can reuse the cache returned from this function
sol_u = GridapPETSc.with(args = petsc_args) do
sol_u = copy(u0)
Algebra.solve!(sol_u, GridapPETSc.PETScNonlinearSolver(), nop)
return sol_u
end

f_wrapped!(resid, sol_u)
u_res = prob.u0 isa Number ? sol_u[1] : sol_u
resid_res = prob.u0 isa Number ? resid[1] : resid

objective = maximum(abs, resid)
retcode = ifelse(objective ≤ abstol, ReturnCode.Success, ReturnCode.Failure)
return SciMLBase.build_solution(
prob, alg, u_res, resid_res;
retcode, stats = SciMLBase.NLStats(nf[], njac[], -1, -1, -1)
)
end

end
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
end
end

12 changes: 9 additions & 3 deletions ext/NonlinearSolvePETScExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@ function SciMLBase.__solve(
maxiters = 1000, alias_u0::Bool = false, termination_condition = nothing,
show_trace::Val = Val(false), kwargs...
)
if !MPI.Initialized()
@warn "MPI not initialized. Initializing MPI with MPI.Init()." maxlog=1
MPI.Init()
end

# XXX: https://petsc.org/release/manualpages/SNES/SNESSetConvergenceTest/
NonlinearSolveBase.assert_extension_supported_termination_condition(
termination_condition, alg; abs_norm_supported = false
Expand Down Expand Up @@ -68,8 +73,10 @@ function SciMLBase.__solve(
PETSc.setfunction!(snes, f!, PETSc.VecSeq(zero(u0)))

njac = Ref{Int}(-1)
if alg.autodiff !== missing || prob.f.jac !== nothing
# `missing` -> let PETSc compute the Jacobian using finite differences
if alg.autodiff !== missing
autodiff = alg.autodiff === missing ? nothing : alg.autodiff

if prob.u0 isa Number
jac! = NonlinearSolveBase.construct_extension_jac(
prob, alg, prob.u0, prob.u0; autodiff
Expand Down Expand Up @@ -125,8 +132,7 @@ function SciMLBase.__solve(
retcode = ifelse(objective ≤ abstol, ReturnCode.Success, ReturnCode.Failure)
return SciMLBase.build_solution(
prob, alg, u_res, resid_res;
retcode, original = snes,
stats = SciMLBase.NLStats(nf[], njac[], -1, -1, -1)
retcode, stats = SciMLBase.NLStats(nf[], njac[], -1, -1, -1)
)
end

Expand Down
3 changes: 3 additions & 0 deletions lib/BracketingNonlinearSolve/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,6 @@ TestItemRunner = "f8b46487-2199-4994-9208-9a1283c18c0a"

[targets]
test = ["Aqua", "ExplicitImports", "ForwardDiff", "InteractiveUtils", "Test", "TestItemRunner"]

[sources]
NonlinearSolveBase = {path = "../NonlinearSolveBase"}
3 changes: 3 additions & 0 deletions lib/NonlinearSolveBase/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -92,3 +92,6 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Aqua", "BandedMatrices", "DiffEqBase", "ExplicitImports", "ForwardDiff", "InteractiveUtils", "LinearAlgebra", "SparseArrays", "Test"]

[sources]
SciMLJacobianOperators = {path = "../SciMLJacobianOperators"}
4 changes: 4 additions & 0 deletions lib/NonlinearSolveFirstOrder/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "BandedMatrices", "BenchmarkTools", "ForwardDiff", "Enzyme", "ExplicitImports", "Hwloc", "InteractiveUtils", "LineSearch", "LineSearches", "NonlinearProblemLibrary", "Pkg", "Random", "ReTestItems", "SparseArrays", "SparseConnectivityTracer", "SparseMatrixColorings", "StableRNGs", "StaticArrays", "Test", "Zygote"]

[sources]
NonlinearSolveBase = {path = "../NonlinearSolveBase"}
SciMLJacobianOperators = {path = "../SciMLJacobianOperators"}
3 changes: 3 additions & 0 deletions lib/NonlinearSolveHomotopyContinuation/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,6 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Aqua", "Test", "NonlinearSolve", "Enzyme", "NaNMath"]

[sources]
NonlinearSolveBase = {path = "../NonlinearSolveBase"}
3 changes: 3 additions & 0 deletions lib/NonlinearSolveQuasiNewton/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -83,3 +83,6 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["ADTypes", "Aqua", "BenchmarkTools", "Enzyme", "ExplicitImports", "FiniteDiff", "ForwardDiff", "Hwloc", "InteractiveUtils", "LineSearch", "LineSearches", "NonlinearProblemLibrary", "Pkg", "ReTestItems", "StableRNGs", "StaticArrays", "Test", "Zygote"]

[sources]
NonlinearSolveBase = {path = "../NonlinearSolveBase"}
3 changes: 3 additions & 0 deletions lib/NonlinearSolveSpectralMethods/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,6 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Aqua", "BenchmarkTools", "ExplicitImports", "Hwloc", "InteractiveUtils", "NonlinearProblemLibrary", "Pkg", "ReTestItems", "StableRNGs", "StaticArrays", "Test"]

[sources]
NonlinearSolveBase = {path = "../NonlinearSolveBase"}
3 changes: 3 additions & 0 deletions lib/SCCNonlinearSolve/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,6 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Aqua", "BenchmarkTools", "ExplicitImports", "Hwloc", "InteractiveUtils", "NonlinearProblemLibrary", "NonlinearSolveBase", "NonlinearSolveFirstOrder", "Pkg", "ReTestItems", "StableRNGs", "StaticArrays", "Test"]

[sources]
NonlinearSolveFirstOrder = {path = "../NonlinearSolveFirstOrder"}
3 changes: 3 additions & 0 deletions lib/SimpleNonlinearSolve/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,6 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "DiffEqBase", "Enzyme", "ExplicitImports", "InteractiveUtils", "NonlinearProblemLibrary", "Pkg", "PolyesterForwardDiff", "Random", "ReverseDiff", "StaticArrays", "Test", "TestItemRunner", "Tracker", "Zygote"]

[sources]
NonlinearSolveBase = {path = "../NonlinearSolveBase"}
2 changes: 1 addition & 1 deletion src/NonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,6 @@ export NonlinearSolvePolyAlgorithm, FastShortcutNonlinearPolyalg, FastShortcutNL
# Extension Algorithms
export LeastSquaresOptimJL, FastLevenbergMarquardtJL, NLsolveJL, NLSolversJL,
FixedPointAccelerationJL, SpeedMappingJL, SIAMFANLEquationsJL
export PETScSNES, CMINPACK
export PETScSNES, GridapPETScSNES, CMINPACK

end
13 changes: 13 additions & 0 deletions src/extension_algs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -488,3 +488,16 @@ function PETScSNES(; petsclib = missing, autodiff = nothing, mpi_comm = missing,
end
return PETScSNES(petsclib, mpi_comm, autodiff, kwargs)
end

# TODO: Docs
@concrete struct GridapPETScSNES <: AbstractNonlinearSolveAlgorithm
autodiff
snes_options
end

function GridapPETScSNES(; autodiff = nothing, kwargs...)
if Base.get_extension(@__MODULE__, :NonlinearSolveGridapPETScExt) === nothing
error("`GridapPETScSNES` requires `GridapPETSc.jl` to be loaded")
end
return GridapPETScSNES(autodiff, kwargs)
end
Loading