Skip to content

Commit

Permalink
Add NLsolve
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Dec 7, 2023
1 parent 6b58f42 commit 2f7a772
Show file tree
Hide file tree
Showing 6 changed files with 101 additions and 25 deletions.
2 changes: 1 addition & 1 deletion docs/src/api/nlsolve.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,5 @@ using NLSolve, NonlinearSolve
## Solver API

```@docs
NLSolveJL
NLsolveJL
```
2 changes: 1 addition & 1 deletion docs/src/solvers/NonlinearSystemSolvers.md
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ computationally expensive than direct methods.

This is a wrapper package for importing solvers from NLsolve.jl into the SciML interface.

- `NLSolveJL()`: A wrapper for [NLsolve.jl](https://github.com/JuliaNLSolvers/NLsolve.jl)
- `NLsolveJL()`: A wrapper for [NLsolve.jl](https://github.com/JuliaNLSolvers/NLsolve.jl)

Submethod choices for this algorithm include:

Expand Down
9 changes: 4 additions & 5 deletions ext/NonlinearSolveMINPACKExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ function SciMLBase.__solve(prob::Union{NonlinearProblem{uType, iip},
# unwrapping alg params
show_trace = alg.show_trace
tracing = alg.tracing
io = alg.io

if !iip && prob.u0 isa Number
f! = (du, u) -> (du .= prob.f(first(u), p); Cint(0))

Check warning on line 24 in ext/NonlinearSolveMINPACKExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveMINPACKExt.jl#L24

Added line #L24 was not covered by tests
Expand All @@ -36,9 +35,10 @@ function SciMLBase.__solve(prob::Union{NonlinearProblem{uType, iip},
u = zero(u0)
resid = NonlinearSolve.evaluate_f(prob, u)
m = length(resid)
size_jac = (length(resid), length(u))

method = ifelse(alg.method === :auto,
ifelse(prob isa NonlinearLeastSquaresProblem, :lm, :hydr), alg.method)
ifelse(prob isa NonlinearLeastSquaresProblem, :lm, :hybr), alg.method)

if SciMLBase.has_jac(prob.f)
if !iip && prob.u0 isa Number
Expand All @@ -51,9 +51,8 @@ function SciMLBase.__solve(prob::Union{NonlinearProblem{uType, iip},
g! = (du, u) -> prob.f.jac(du, u, p)
else # Then it's an in-place function on an abstract array
g! = function (du, u)
prob.f.jac(reshape(du, sizeu), reshape(u, sizeu), p)
du = vec(du)
return CInt(0)
prob.f.jac(reshape(du, size_jac), reshape(u, sizeu), p)
return Cint(0)

Check warning on line 55 in ext/NonlinearSolveMINPACKExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveMINPACKExt.jl#L53-L55

Added lines #L53 - L55 were not covered by tests
end
end
original = MINPACK.fsolve(f!, g!, u0, m; tol = abstol, show_trace, tracing, method,
Expand Down
77 changes: 77 additions & 0 deletions ext/NonlinearSolveNLsolveExt.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,80 @@
module NonlinearSolveNLsolveExt

using NonlinearSolve, NLsolve, DiffEqBase, SciMLBase
import UnPack: @unpack

function SciMLBase.__solve(prob::NonlinearProblem, alg::NLsolveJL, args...; abstol = 1e-6,
maxiters = 1000, alias_u0::Bool = false, kwargs...)
if typeof(prob.u0) <: Number
u0 = [prob.u0]
else
u0 = NonlinearSolve.__maybe_unaliased(prob.u0, alias_u0)
end

iip = isinplace(prob)

sizeu = size(prob.u0)
p = prob.p

# unwrapping alg params
@unpack method, autodiff, store_trace, extended_trace, linesearch, linsolve = alg
@unpack factor, autoscale, m, beta, show_trace = alg

if !iip && prob.u0 isa Number
f! = (du, u) -> (du .= prob.f(first(u), p); Cint(0))
elseif !iip && prob.u0 isa Vector{Float64}
f! = (du, u) -> (du .= prob.f(u, p); Cint(0))
elseif !iip && prob.u0 isa AbstractArray
f! = (du, u) -> (du .= vec(prob.f(reshape(u, sizeu), p)); Cint(0))

Check warning on line 28 in ext/NonlinearSolveNLsolveExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveNLsolveExt.jl#L28

Added line #L28 was not covered by tests
elseif prob.u0 isa Vector{Float64}
f! = (du, u) -> prob.f(du, u, p)
else # Then it's an in-place function on an abstract array
f! = (du, u) -> (prob.f(reshape(du, sizeu), reshape(u, sizeu), p); du = vec(du); 0)
end

if prob.u0 isa Number
resid = [NonlinearSolve.evaluate_f(prob, first(u0))]
else
resid = NonlinearSolve.evaluate_f(prob, u0)
end

size_jac = (length(resid), length(u0))

if SciMLBase.has_jac(prob.f)
if !iip && prob.u0 isa Number
g! = (du, u) -> (du .= prob.f.jac(first(u), p); Cint(0))

Check warning on line 45 in ext/NonlinearSolveNLsolveExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveNLsolveExt.jl#L45

Added line #L45 was not covered by tests
elseif !iip && prob.u0 isa Vector{Float64}
g! = (du, u) -> (du .= prob.f.jac(u, p); Cint(0))

Check warning on line 47 in ext/NonlinearSolveNLsolveExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveNLsolveExt.jl#L47

Added line #L47 was not covered by tests
elseif !iip && prob.u0 isa AbstractArray
g! = (du, u) -> (du .= vec(prob.f.jac(reshape(u, sizeu), p)); Cint(0))

Check warning on line 49 in ext/NonlinearSolveNLsolveExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveNLsolveExt.jl#L49

Added line #L49 was not covered by tests
elseif prob.u0 isa Vector{Float64}
g! = (du, u) -> prob.f.jac(du, u, p)
else # Then it's an in-place function on an abstract array
g! = function (du, u)
prob.f.jac(reshape(du, size_jac), reshape(u, sizeu), p)
return Cint(0)
end
end
if prob.f.jac_prototype !== nothing
J = zero(prob.f.jac_prototype)
df = OnceDifferentiable(f!, g!, u0, resid, J)

Check warning on line 60 in ext/NonlinearSolveNLsolveExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveNLsolveExt.jl#L59-L60

Added lines #L59 - L60 were not covered by tests
else
df = OnceDifferentiable(f!, g!, u0, resid)
end
else
df = OnceDifferentiable(f!, u0, resid; autodiff)
end

original = nlsolve(df, u0; ftol = abstol, iterations = maxiters, method, store_trace,
extended_trace, linesearch, linsolve, factor, autoscale, m, beta, show_trace)

u = reshape(original.zero, size(u0))
f!(resid, u)
retcode = original.x_converged || original.f_converged ? ReturnCode.Success :
ReturnCode.Failure
stats = SciMLBase.NLStats(original.f_calls, original.g_calls, original.g_calls,
original.g_calls, original.iterations)
return SciMLBase.build_solution(prob, alg, u, resid; retcode, original, stats)
end

end
18 changes: 9 additions & 9 deletions src/extension_algs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -143,10 +143,10 @@ function CMINPACK(; show_trace::Bool = false, tracing::Bool = false, method::Sym
end

"""
NLSolveJL(; method=:trust_region, autodiff=:central, store_trace=false,
extended_trace=false, linesearch=LineSearches.Static(),
linsolve=(x, A, b) -> copyto!(x, A\\b), factor = one(Float64), autoscale=true,
m=10, beta=one(Float64), show_trace=false)
NLsolveJL(; method=:trust_region, autodiff=:central, store_trace=false,
extended_trace=false, linesearch=LineSearches.Static(),
linsolve=(x, A, b) -> copyto!(x, A\\b), factor = one(Float64), autoscale=true,
m=10, beta=one(Float64), show_trace=false)
### Keyword Arguments
Expand All @@ -171,7 +171,7 @@ end
### Submethod Choice
Choices for methods in `NLSolveJL`:
Choices for methods in `NLsolveJL`:
- `:anderson`: Anderson-accelerated fixed-point iteration
- `:broyden`: Broyden's quasi-Newton method
Expand All @@ -180,7 +180,7 @@ Choices for methods in `NLSolveJL`:
these arguments, consult the
[NLsolve.jl documentation](https://github.com/JuliaNLSolvers/NLsolve.jl).
"""
@concrete struct NLSolveJL <: AbstractNonlinearAlgorithm
@concrete struct NLsolveJL <: AbstractNonlinearAlgorithm
method::Symbol
autodiff::Symbol
store_trace::Bool
Expand All @@ -194,14 +194,14 @@ Choices for methods in `NLSolveJL`:
show_trace::Bool
end

function NLSolveJL(; method = :trust_region, autodiff = :central, store_trace = false,
function NLsolveJL(; method = :trust_region, autodiff = :central, store_trace = false,
extended_trace = false, linesearch = LineSearches.Static(),
linsolve = (x, A, b) -> copyto!(x, A \ b), factor = 1.0, autoscale = true, m = 10,

Check warning on line 199 in src/extension_algs.jl

View check run for this annotation

Codecov / codecov/patch

src/extension_algs.jl#L199

Added line #L199 was not covered by tests
beta = one(Float64), show_trace = false)
if Base.get_extension(@__MODULE__, :NonlinearSolveNLsolveExt) === nothing
error("NLSolveJL requires NLsolve.jl to be loaded")
error("NLsolveJL requires NLsolve.jl to be loaded")

Check warning on line 202 in src/extension_algs.jl

View check run for this annotation

Codecov / codecov/patch

src/extension_algs.jl#L202

Added line #L202 was not covered by tests
end

return NLSolveJL(method, autodiff, store_trace, extended_trace, linesearch, linsolve,
return NLsolveJL(method, autodiff, store_trace, extended_trace, linesearch, linsolve,
factor, autoscale, m, beta, show_trace)
end
18 changes: 9 additions & 9 deletions test/nlsolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ u0 = zeros(2)
prob_iip = SteadyStateProblem(f_iip, u0)
abstol = 1e-8

for alg in [NLSolveJL()]
for alg in [NLsolveJL()]
sol = solve(prob_iip, alg)
@test sol.retcode == ReturnCode.Success
p = nothing
Expand All @@ -24,7 +24,7 @@ f_oop(u, p, t) = [2 - 2u[1], u[1] - 4u[2]]
u0 = zeros(2)
prob_oop = SteadyStateProblem(f_oop, u0)

for alg in [NLSolveJL()]
for alg in [NLsolveJL()]
sol = solve(prob_oop, alg)
@test sol.retcode == ReturnCode.Success
# test the solver is doing reasonable things for linear solve
Expand All @@ -45,7 +45,7 @@ end
u0 = zeros(2)
prob_iip = NonlinearProblem{true}(f_iip, u0)
abstol = 1e-8
for alg in [NLSolveJL()]
for alg in [NLsolveJL()]
local sol
sol = solve(prob_iip, alg)
@test sol.retcode == ReturnCode.Success
Expand All @@ -60,7 +60,7 @@ end
f_oop(u, p) = [2 - 2u[1], u[1] - 4u[2]]
u0 = zeros(2)
prob_oop = NonlinearProblem{false}(f_oop, u0)
for alg in [NLSolveJL()]
for alg in [NLsolveJL()]
local sol
sol = solve(prob_oop, alg)
@test sol.retcode == ReturnCode.Success
Expand All @@ -74,7 +74,7 @@ end
f_tol(u, p) = u^2 - 2
prob_tol = NonlinearProblem(f_tol, 1.0)
for tol in [1e-1, 1e-3, 1e-6, 1e-10, 1e-15]
sol = solve(prob_tol, NLSolveJL(), abstol = tol)
sol = solve(prob_tol, NLsolveJL(), abstol = tol)
@test abs(sol.u[1] - sqrt(2)) < tol
end

Expand All @@ -85,7 +85,7 @@ function f!(fvec, x, p)
end

prob = NonlinearProblem{true}(f!, [0.1; 1.2])
sol = solve(prob, NLSolveJL(autodiff = :central))
sol = solve(prob, NLsolveJL(autodiff = :central))

du = zeros(2)
f!(du, sol.u, nothing)
Expand All @@ -98,7 +98,7 @@ function f!(fvec, x, p)
end

prob = NonlinearProblem{true}(f!, [0.1; 1.2])
sol = solve(prob, NLSolveJL(autodiff = :forward))
sol = solve(prob, NLsolveJL(autodiff = :forward))

du = zeros(2)
f!(du, sol.u, nothing)
Expand Down Expand Up @@ -131,8 +131,8 @@ f = NonlinearFunction(f!, jac = j!)
p = A

ProbN = NonlinearProblem(f, init, p)
sol = solve(ProbN, NLSolveJL(), reltol = 1e-8, abstol = 1e-8)
sol = solve(ProbN, NLsolveJL(), reltol = 1e-8, abstol = 1e-8)

init = ones(Complex{Float64}, 152);
ProbN = NonlinearProblem(f, init, p)
sol = solve(ProbN, NLSolveJL(), reltol = 1e-8, abstol = 1e-8)
sol = solve(ProbN, NLsolveJL(), reltol = 1e-8, abstol = 1e-8)

0 comments on commit 2f7a772

Please sign in to comment.