From 2f7a772b6733ee3b3b35d11a6fa3b0fd8d293f9f Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 7 Dec 2023 18:32:48 -0500 Subject: [PATCH] Add NLsolve --- docs/src/api/nlsolve.md | 2 +- docs/src/solvers/NonlinearSystemSolvers.md | 2 +- ext/NonlinearSolveMINPACKExt.jl | 9 ++- ext/NonlinearSolveNLsolveExt.jl | 77 ++++++++++++++++++++++ src/extension_algs.jl | 18 ++--- test/nlsolve.jl | 18 ++--- 6 files changed, 101 insertions(+), 25 deletions(-) diff --git a/docs/src/api/nlsolve.md b/docs/src/api/nlsolve.md index 6de117872..05db0937f 100644 --- a/docs/src/api/nlsolve.md +++ b/docs/src/api/nlsolve.md @@ -13,5 +13,5 @@ using NLSolve, NonlinearSolve ## Solver API ```@docs -NLSolveJL +NLsolveJL ``` diff --git a/docs/src/solvers/NonlinearSystemSolvers.md b/docs/src/solvers/NonlinearSystemSolvers.md index 55e799efb..776e58b9c 100644 --- a/docs/src/solvers/NonlinearSystemSolvers.md +++ b/docs/src/solvers/NonlinearSystemSolvers.md @@ -114,7 +114,7 @@ computationally expensive than direct methods. This is a wrapper package for importing solvers from NLsolve.jl into the SciML interface. - - `NLSolveJL()`: A wrapper for [NLsolve.jl](https://github.com/JuliaNLSolvers/NLsolve.jl) + - `NLsolveJL()`: A wrapper for [NLsolve.jl](https://github.com/JuliaNLSolvers/NLsolve.jl) Submethod choices for this algorithm include: diff --git a/ext/NonlinearSolveMINPACKExt.jl b/ext/NonlinearSolveMINPACKExt.jl index 588e93fa4..5f15d85b3 100644 --- a/ext/NonlinearSolveMINPACKExt.jl +++ b/ext/NonlinearSolveMINPACKExt.jl @@ -19,7 +19,6 @@ function SciMLBase.__solve(prob::Union{NonlinearProblem{uType, iip}, # unwrapping alg params show_trace = alg.show_trace tracing = alg.tracing - io = alg.io if !iip && prob.u0 isa Number f! = (du, u) -> (du .= prob.f(first(u), p); Cint(0)) @@ -36,9 +35,10 @@ function SciMLBase.__solve(prob::Union{NonlinearProblem{uType, iip}, u = zero(u0) resid = NonlinearSolve.evaluate_f(prob, u) m = length(resid) + size_jac = (length(resid), length(u)) method = ifelse(alg.method === :auto, - ifelse(prob isa NonlinearLeastSquaresProblem, :lm, :hydr), alg.method) + ifelse(prob isa NonlinearLeastSquaresProblem, :lm, :hybr), alg.method) if SciMLBase.has_jac(prob.f) if !iip && prob.u0 isa Number @@ -51,9 +51,8 @@ function SciMLBase.__solve(prob::Union{NonlinearProblem{uType, iip}, g! = (du, u) -> prob.f.jac(du, u, p) else # Then it's an in-place function on an abstract array g! = function (du, u) - prob.f.jac(reshape(du, sizeu), reshape(u, sizeu), p) - du = vec(du) - return CInt(0) + prob.f.jac(reshape(du, size_jac), reshape(u, sizeu), p) + return Cint(0) end end original = MINPACK.fsolve(f!, g!, u0, m; tol = abstol, show_trace, tracing, method, diff --git a/ext/NonlinearSolveNLsolveExt.jl b/ext/NonlinearSolveNLsolveExt.jl index a9424fdf7..5350e9be2 100644 --- a/ext/NonlinearSolveNLsolveExt.jl +++ b/ext/NonlinearSolveNLsolveExt.jl @@ -1,3 +1,80 @@ module NonlinearSolveNLsolveExt +using NonlinearSolve, NLsolve, DiffEqBase, SciMLBase +import UnPack: @unpack + +function SciMLBase.__solve(prob::NonlinearProblem, alg::NLsolveJL, args...; abstol = 1e-6, + maxiters = 1000, alias_u0::Bool = false, kwargs...) + if typeof(prob.u0) <: Number + u0 = [prob.u0] + else + u0 = NonlinearSolve.__maybe_unaliased(prob.u0, alias_u0) + end + + iip = isinplace(prob) + + sizeu = size(prob.u0) + p = prob.p + + # unwrapping alg params + @unpack method, autodiff, store_trace, extended_trace, linesearch, linsolve = alg + @unpack factor, autoscale, m, beta, show_trace = alg + + if !iip && prob.u0 isa Number + f! = (du, u) -> (du .= prob.f(first(u), p); Cint(0)) + elseif !iip && prob.u0 isa Vector{Float64} + f! = (du, u) -> (du .= prob.f(u, p); Cint(0)) + elseif !iip && prob.u0 isa AbstractArray + f! = (du, u) -> (du .= vec(prob.f(reshape(u, sizeu), p)); Cint(0)) + elseif prob.u0 isa Vector{Float64} + f! = (du, u) -> prob.f(du, u, p) + else # Then it's an in-place function on an abstract array + f! = (du, u) -> (prob.f(reshape(du, sizeu), reshape(u, sizeu), p); du = vec(du); 0) + end + + if prob.u0 isa Number + resid = [NonlinearSolve.evaluate_f(prob, first(u0))] + else + resid = NonlinearSolve.evaluate_f(prob, u0) + end + + size_jac = (length(resid), length(u0)) + + if SciMLBase.has_jac(prob.f) + if !iip && prob.u0 isa Number + g! = (du, u) -> (du .= prob.f.jac(first(u), p); Cint(0)) + elseif !iip && prob.u0 isa Vector{Float64} + g! = (du, u) -> (du .= prob.f.jac(u, p); Cint(0)) + elseif !iip && prob.u0 isa AbstractArray + g! = (du, u) -> (du .= vec(prob.f.jac(reshape(u, sizeu), p)); Cint(0)) + elseif prob.u0 isa Vector{Float64} + g! = (du, u) -> prob.f.jac(du, u, p) + else # Then it's an in-place function on an abstract array + g! = function (du, u) + prob.f.jac(reshape(du, size_jac), reshape(u, sizeu), p) + return Cint(0) + end + end + if prob.f.jac_prototype !== nothing + J = zero(prob.f.jac_prototype) + df = OnceDifferentiable(f!, g!, u0, resid, J) + else + df = OnceDifferentiable(f!, g!, u0, resid) + end + else + df = OnceDifferentiable(f!, u0, resid; autodiff) + end + + original = nlsolve(df, u0; ftol = abstol, iterations = maxiters, method, store_trace, + extended_trace, linesearch, linsolve, factor, autoscale, m, beta, show_trace) + + u = reshape(original.zero, size(u0)) + f!(resid, u) + retcode = original.x_converged || original.f_converged ? ReturnCode.Success : + ReturnCode.Failure + stats = SciMLBase.NLStats(original.f_calls, original.g_calls, original.g_calls, + original.g_calls, original.iterations) + return SciMLBase.build_solution(prob, alg, u, resid; retcode, original, stats) +end + end diff --git a/src/extension_algs.jl b/src/extension_algs.jl index d59627a67..c825f98a5 100644 --- a/src/extension_algs.jl +++ b/src/extension_algs.jl @@ -143,10 +143,10 @@ function CMINPACK(; show_trace::Bool = false, tracing::Bool = false, method::Sym end """ - NLSolveJL(; method=:trust_region, autodiff=:central, store_trace=false, - extended_trace=false, linesearch=LineSearches.Static(), - linsolve=(x, A, b) -> copyto!(x, A\\b), factor = one(Float64), autoscale=true, - m=10, beta=one(Float64), show_trace=false) + NLsolveJL(; method=:trust_region, autodiff=:central, store_trace=false, + extended_trace=false, linesearch=LineSearches.Static(), + linsolve=(x, A, b) -> copyto!(x, A\\b), factor = one(Float64), autoscale=true, + m=10, beta=one(Float64), show_trace=false) ### Keyword Arguments @@ -171,7 +171,7 @@ end ### Submethod Choice -Choices for methods in `NLSolveJL`: +Choices for methods in `NLsolveJL`: - `:anderson`: Anderson-accelerated fixed-point iteration - `:broyden`: Broyden's quasi-Newton method @@ -180,7 +180,7 @@ Choices for methods in `NLSolveJL`: these arguments, consult the [NLsolve.jl documentation](https://github.com/JuliaNLSolvers/NLsolve.jl). """ -@concrete struct NLSolveJL <: AbstractNonlinearAlgorithm +@concrete struct NLsolveJL <: AbstractNonlinearAlgorithm method::Symbol autodiff::Symbol store_trace::Bool @@ -194,14 +194,14 @@ Choices for methods in `NLSolveJL`: show_trace::Bool end -function NLSolveJL(; method = :trust_region, autodiff = :central, store_trace = false, +function NLsolveJL(; method = :trust_region, autodiff = :central, store_trace = false, extended_trace = false, linesearch = LineSearches.Static(), linsolve = (x, A, b) -> copyto!(x, A \ b), factor = 1.0, autoscale = true, m = 10, beta = one(Float64), show_trace = false) if Base.get_extension(@__MODULE__, :NonlinearSolveNLsolveExt) === nothing - error("NLSolveJL requires NLsolve.jl to be loaded") + error("NLsolveJL requires NLsolve.jl to be loaded") end - return NLSolveJL(method, autodiff, store_trace, extended_trace, linesearch, linsolve, + return NLsolveJL(method, autodiff, store_trace, extended_trace, linesearch, linsolve, factor, autoscale, m, beta, show_trace) end diff --git a/test/nlsolve.jl b/test/nlsolve.jl index e1d7714d0..ebda5d272 100644 --- a/test/nlsolve.jl +++ b/test/nlsolve.jl @@ -9,7 +9,7 @@ u0 = zeros(2) prob_iip = SteadyStateProblem(f_iip, u0) abstol = 1e-8 -for alg in [NLSolveJL()] +for alg in [NLsolveJL()] sol = solve(prob_iip, alg) @test sol.retcode == ReturnCode.Success p = nothing @@ -24,7 +24,7 @@ f_oop(u, p, t) = [2 - 2u[1], u[1] - 4u[2]] u0 = zeros(2) prob_oop = SteadyStateProblem(f_oop, u0) -for alg in [NLSolveJL()] +for alg in [NLsolveJL()] sol = solve(prob_oop, alg) @test sol.retcode == ReturnCode.Success # test the solver is doing reasonable things for linear solve @@ -45,7 +45,7 @@ end u0 = zeros(2) prob_iip = NonlinearProblem{true}(f_iip, u0) abstol = 1e-8 -for alg in [NLSolveJL()] +for alg in [NLsolveJL()] local sol sol = solve(prob_iip, alg) @test sol.retcode == ReturnCode.Success @@ -60,7 +60,7 @@ end f_oop(u, p) = [2 - 2u[1], u[1] - 4u[2]] u0 = zeros(2) prob_oop = NonlinearProblem{false}(f_oop, u0) -for alg in [NLSolveJL()] +for alg in [NLsolveJL()] local sol sol = solve(prob_oop, alg) @test sol.retcode == ReturnCode.Success @@ -74,7 +74,7 @@ end f_tol(u, p) = u^2 - 2 prob_tol = NonlinearProblem(f_tol, 1.0) for tol in [1e-1, 1e-3, 1e-6, 1e-10, 1e-15] - sol = solve(prob_tol, NLSolveJL(), abstol = tol) + sol = solve(prob_tol, NLsolveJL(), abstol = tol) @test abs(sol.u[1] - sqrt(2)) < tol end @@ -85,7 +85,7 @@ function f!(fvec, x, p) end prob = NonlinearProblem{true}(f!, [0.1; 1.2]) -sol = solve(prob, NLSolveJL(autodiff = :central)) +sol = solve(prob, NLsolveJL(autodiff = :central)) du = zeros(2) f!(du, sol.u, nothing) @@ -98,7 +98,7 @@ function f!(fvec, x, p) end prob = NonlinearProblem{true}(f!, [0.1; 1.2]) -sol = solve(prob, NLSolveJL(autodiff = :forward)) +sol = solve(prob, NLsolveJL(autodiff = :forward)) du = zeros(2) f!(du, sol.u, nothing) @@ -131,8 +131,8 @@ f = NonlinearFunction(f!, jac = j!) p = A ProbN = NonlinearProblem(f, init, p) -sol = solve(ProbN, NLSolveJL(), reltol = 1e-8, abstol = 1e-8) +sol = solve(ProbN, NLsolveJL(), reltol = 1e-8, abstol = 1e-8) init = ones(Complex{Float64}, 152); ProbN = NonlinearProblem(f, init, p) -sol = solve(ProbN, NLSolveJL(), reltol = 1e-8, abstol = 1e-8) +sol = solve(ProbN, NLsolveJL(), reltol = 1e-8, abstol = 1e-8)