Skip to content

Commit

Permalink
Trust Region mostly works
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Dec 5, 2023
1 parent 13e590e commit 445e97b
Show file tree
Hide file tree
Showing 5 changed files with 245 additions and 344 deletions.
98 changes: 49 additions & 49 deletions src/NonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ include("trace.jl")
include("extension_algs.jl")
include("linesearch.jl")
include("raphson.jl")
# include("trustRegion.jl")
include("trustRegion.jl")
include("levenberg.jl")
include("gaussnewton.jl")
include("dfsane.jl")
Expand All @@ -179,54 +179,54 @@ include("klement.jl")
include("lbroyden.jl")
include("jacobian.jl")
include("ad.jl")
# include("default.jl")

# @setup_workload begin
# nlfuncs = ((NonlinearFunction{false}((u, p) -> u .* u .- p), 0.1),
# (NonlinearFunction{false}((u, p) -> u .* u .- p), [0.1]),
# (NonlinearFunction{true}((du, u, p) -> du .= u .* u .- p), [0.1]))
# probs_nls = NonlinearProblem[]
# for T in (Float32, Float64), (fn, u0) in nlfuncs
# push!(probs_nls, NonlinearProblem(fn, T.(u0), T(2)))
# end

# nls_algs = (NewtonRaphson(), TrustRegion(), LevenbergMarquardt(), PseudoTransient(),
# GeneralBroyden(), GeneralKlement(), DFSane(), nothing)

# probs_nlls = NonlinearLeastSquaresProblem[]
# nlfuncs = ((NonlinearFunction{false}((u, p) -> (u .^ 2 .- p)[1:1]), [0.1, 0.0]),
# (NonlinearFunction{false}((u, p) -> vcat(u .* u .- p, u .* u .- p)), [0.1, 0.1]),
# (NonlinearFunction{true}((du, u, p) -> du[1] = u[1] * u[1] - p,
# resid_prototype = zeros(1)), [0.1, 0.0]),
# (NonlinearFunction{true}((du, u, p) -> du .= vcat(u .* u .- p, u .* u .- p),
# resid_prototype = zeros(4)), [0.1, 0.1]))
# for (fn, u0) in nlfuncs
# push!(probs_nlls, NonlinearLeastSquaresProblem(fn, u0, 2.0))
# end
# nlfuncs = ((NonlinearFunction{false}((u, p) -> (u .^ 2 .- p)[1:1]), Float32[0.1, 0.0]),
# (NonlinearFunction{false}((u, p) -> vcat(u .* u .- p, u .* u .- p)),
# Float32[0.1, 0.1]),
# (NonlinearFunction{true}((du, u, p) -> du[1] = u[1] * u[1] - p,
# resid_prototype = zeros(Float32, 1)), Float32[0.1, 0.0]),
# (NonlinearFunction{true}((du, u, p) -> du .= vcat(u .* u .- p, u .* u .- p),
# resid_prototype = zeros(Float32, 4)), Float32[0.1, 0.1]))
# for (fn, u0) in nlfuncs
# push!(probs_nlls, NonlinearLeastSquaresProblem(fn, u0, 2.0f0))
# end

# nlls_algs = (LevenbergMarquardt(), GaussNewton(),
# LevenbergMarquardt(; linsolve = LUFactorization()),
# GaussNewton(; linsolve = LUFactorization()))

# @compile_workload begin
# for prob in probs_nls, alg in nls_algs
# solve(prob, alg, abstol = 1e-2)
# end
# for prob in probs_nlls, alg in nlls_algs
# solve(prob, alg, abstol = 1e-2)
# end
# end
# end
include("default.jl")

# Precompilation workload: build small sample problems and algorithm lists, then solve
# every problem/algorithm pair inside `@compile_workload`.
# NOTE(review): `@setup_workload` / `@compile_workload` look like the PrecompileTools.jl
# macros -- confirm against this module's imports.
@setup_workload begin
# Root-finding samples for `u .* u .- p`: out-of-place with a scalar u0, out-of-place
# with a vector u0, and in-place (`du .= ...`) with a vector u0.
nlfuncs = ((NonlinearFunction{false}((u, p) -> u .* u .- p), 0.1),
(NonlinearFunction{false}((u, p) -> u .* u .- p), [0.1]),
(NonlinearFunction{true}((du, u, p) -> du .= u .* u .- p), [0.1]))
# One NonlinearProblem per (element type, sample) pair; u0 is converted with `T.(u0)`
# and the parameter is `T(2)`.
probs_nls = NonlinearProblem[]
for T in (Float32, Float64), (fn, u0) in nlfuncs
push!(probs_nls, NonlinearProblem(fn, T.(u0), T(2)))
end

# Algorithms run against the root-finding problems.
# NOTE(review): the trailing `nothing` presumably selects the default algorithm
# (see default.jl) -- confirm.
nls_algs = (NewtonRaphson(), TrustRegion(), LevenbergMarquardt(), PseudoTransient(),
GeneralBroyden(), GeneralKlement(), DFSane(), nothing)

# Least-squares samples with Float64 data and p = 2.0. Each u0 has two entries; the
# residuals keep either only the first entry (`[1:1]` / `du[1]`) or the `vcat` of two
# copies (four entries), matching the `resid_prototype` lengths given for the in-place
# functions.
probs_nlls = NonlinearLeastSquaresProblem[]
nlfuncs = ((NonlinearFunction{false}((u, p) -> (u .^ 2 .- p)[1:1]), [0.1, 0.0]),
(NonlinearFunction{false}((u, p) -> vcat(u .* u .- p, u .* u .- p)), [0.1, 0.1]),
(NonlinearFunction{true}((du, u, p) -> du[1] = u[1] * u[1] - p,
resid_prototype = zeros(1)), [0.1, 0.0]),
(NonlinearFunction{true}((du, u, p) -> du .= vcat(u .* u .- p, u .* u .- p),
resid_prototype = zeros(4)), [0.1, 0.1]))
for (fn, u0) in nlfuncs
push!(probs_nlls, NonlinearLeastSquaresProblem(fn, u0, 2.0))
end
# The same four least-squares samples again with Float32 data and p = 2.0f0; they are
# appended to the same `probs_nlls` vector.
nlfuncs = ((NonlinearFunction{false}((u, p) -> (u .^ 2 .- p)[1:1]), Float32[0.1, 0.0]),
(NonlinearFunction{false}((u, p) -> vcat(u .* u .- p, u .* u .- p)),
Float32[0.1, 0.1]),
(NonlinearFunction{true}((du, u, p) -> du[1] = u[1] * u[1] - p,
resid_prototype = zeros(Float32, 1)), Float32[0.1, 0.0]),
(NonlinearFunction{true}((du, u, p) -> du .= vcat(u .* u .- p, u .* u .- p),
resid_prototype = zeros(Float32, 4)), Float32[0.1, 0.1]))
for (fn, u0) in nlfuncs
push!(probs_nlls, NonlinearLeastSquaresProblem(fn, u0, 2.0f0))
end

# Least-squares algorithms: each one constructed without a `linsolve` keyword and with
# `linsolve = LUFactorization()`.
nlls_algs = (LevenbergMarquardt(), GaussNewton(),
LevenbergMarquardt(; linsolve = LUFactorization()),
GaussNewton(; linsolve = LUFactorization()))

# Only the `solve` calls sit inside `@compile_workload`; every call uses abstol = 1e-2.
# NOTE(review): the loose tolerance is presumably there to keep the workload cheap --
# confirm.
@compile_workload begin
for prob in probs_nls, alg in nls_algs
solve(prob, alg, abstol = 1e-2)
end
for prob in probs_nlls, alg in nlls_algs
solve(prob, alg, abstol = 1e-2)
end
end
end

export RadiusUpdateSchemes

Expand Down
4 changes: 2 additions & 2 deletions src/gaussnewton.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,8 @@ function perform_step!(cache::GaussNewtonCache{iip}) where {iip}

# Use normal form to solve the Linear Problem
if cache.JᵀJ !== nothing
__update_JᵀJ!(cache, Val(:JᵀJ))
__update_Jᵀf!(cache, Val(:JᵀJ))
__update_JᵀJ!(cache)
__update_Jᵀf!(cache)
A, b = __maybe_symmetric(cache.JᵀJ), _vec(cache.Jᵀf)
else
A, b = cache.J, _vec(cache.fu)
Expand Down
46 changes: 34 additions & 12 deletions src/jacobian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u::Number,
kwargs...) where {needsJᵀJ, F}
# NOTE: Scalar `u` assumes scalar output from `f`
uf = SciMLBase.JacobianWrapper{false}(f, p)
return uf, FakeLinearSolveJLCache(u, u), u, nothing, nothing, u, u, u
return uf, FakeLinearSolveJLCache(u, u), u, zero(u), nothing, u, u, u
end

# Linear Solve Cache
Expand Down Expand Up @@ -208,27 +208,49 @@ function __concrete_vjp_autodiff(vjp_autodiff, jvp_autodiff, uf)
end
end

# Build the Jacobian-vector-product operator.
# Generic inputs: all positional and keyword arguments are forwarded unchanged to `JacVec`.
__jacvec(args...; kwargs...) = JacVec(args...; kwargs...)

# Scalar fallback: for a `Number` input, wrap `uf`, `u` and `autodiff` in a `JVPScalar`.
# Any extra keyword arguments are accepted and ignored.
#
# Throws an `ArgumentError` unless `autodiff isa AutoForwardDiff`. The check is an explicit
# `throw` rather than `@assert` because assertions may be disabled at some optimization
# levels, so they must not be relied on to validate a caller-supplied option.
function __jacvec(uf, u::Number; autodiff, kwargs...)
    autodiff isa AutoForwardDiff ||
        throw(ArgumentError("Only ForwardDiff is currently supported."))
    return JVPScalar(uf, u, autodiff)
end

# Operator object for Jacobian-vector products at a scalar point; it is what the scalar
# method of `__jacvec` returns.
# NOTE(review): `@concrete` presumably comes from ConcreteStructs.jl and turns the untyped
# fields into type parameters -- confirm against the module's imports.
@concrete mutable struct JVPScalar
# Callable that `Base.:*` evaluates on a ForwardDiff dual number.
uf
# Scalar point used as the value part of that dual number.
u
# Autodiff backend object passed to `__jacvec`; stored here, not read by `Base.:*`.
autodiff
end

# Multiply the scalar JVP operator by `v`: evaluate `jvp.uf` on a ForwardDiff dual number
# whose value is `jvp.u` and whose partial is `v`, then return the derivative part that
# `ForwardDiff.extract_derivative` pulls out of the result.
function Base.:*(jvp::JVPScalar, v)
    # Tag type built from the types of the wrapped function and of the scalar point.
    tag = typeof(ForwardDiff.Tag(typeof(jvp.uf), typeof(jvp.u)))
    seeded = ForwardDiff.Dual{tag}(jvp.u, v)
    return ForwardDiff.extract_derivative(tag, jvp.uf(seeded))
end

# Generic Handling of Krylov Methods for Normal Form Linear Solves
function __update_JᵀJ!(cache::AbstractNonlinearSolveCache)
function __update_JᵀJ!(cache::AbstractNonlinearSolveCache, J = nothing)
if !(cache.JᵀJ isa KrylovJᵀJ)
@bb cache.JᵀJ = transpose(cache.J) × cache.J
J_ = ifelse(J === nothing, cache.J, J)
@bb cache.JᵀJ = transpose(J_) × J_
end
end

function __update_Jᵀf!(cache::AbstractNonlinearSolveCache)
function __update_Jᵀf!(cache::AbstractNonlinearSolveCache, J = nothing)
if cache.JᵀJ isa KrylovJᵀJ
@bb cache.Jᵀf = cache.JᵀJ.Jᵀ × cache.fu
else
@bb cache.Jᵀf = transpose(cache.J) × vec(cache.fu)
J_ = ifelse(J === nothing, cache.J, J)
@bb cache.Jᵀf = transpose(J_) × vec(cache.fu)
end
end

# Left-Right Multiplication
__lr_mul(::Val, H, g) = dot(g, H, g)
## TODO: Use a cache here to avoid allocations
__lr_mul(::Val{false}, H::KrylovJᵀJ, g) = dot(g, H.JᵀJ, g)
function __lr_mul(::Val{true}, H::KrylovJᵀJ, g)
c = similar(g)
mul!(c, H.JᵀJ, g)
return dot(g, c)
# Left-right product of `Jᵀf` with `JᵀJ`: the result is `dot(Jᵀf, JᵀJ × Jᵀf)`, with the
# intermediate product stored in `cache.lr_mul_cache`.
# Convenience method: takes both operands from the cache fields `JᵀJ` and `Jᵀf`.
__lr_mul(cache::AbstractNonlinearSolveCache) = __lr_mul(cache, cache.JᵀJ, cache.Jᵀf)
# Krylov variant: the operator applied is the wrapped `JᵀJ.JᵀJ` field, and it is applied
# to `vec(Jᵀf)`.
# NOTE(review): `@bb` and `×` are project-level helpers not shown here; presumably they
# write the product into `cache.lr_mul_cache` in place when possible -- confirm.
function __lr_mul(cache::AbstractNonlinearSolveCache, JᵀJ::KrylovJᵀJ, Jᵀf)
@bb cache.lr_mul_cache = JᵀJ.JᵀJ × vec(Jᵀf)
return dot(_vec(Jᵀf), _vec(cache.lr_mul_cache))
end
# Generic variant: `JᵀJ` is used directly as the left operand and `Jᵀf` is passed to `×`
# without `vec`.
function __lr_mul(cache::AbstractNonlinearSolveCache, JᵀJ, Jᵀf)
@bb cache.lr_mul_cache = JᵀJ × Jᵀf
return dot(_vec(Jᵀf), _vec(cache.lr_mul_cache))
end
Loading

0 comments on commit 445e97b

Please sign in to comment.