Skip to content

Commit

Permalink
Start cleaning up TrustRegion
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Dec 1, 2023
1 parent 031639f commit f18fe15
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 125 deletions.
1 change: 1 addition & 0 deletions src/jacobian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@ function __concrete_vjp_autodiff(vjp_autodiff, jvp_autodiff, uf)
end

# Generic Handling of Krylov Methods for Normal Form Linear Solves
# FIXME: Use MaybeInplace here for efficient matmuls
# Symbol-based entry point: look up the cache field named by `sym` and forward it,
# together with `J`, to the five-argument `__update_JᵀJ!` method.
function __update_JᵀJ!(iip::Val, cache, sym::Symbol, J)
    target = getproperty(cache, sym)
    return __update_JᵀJ!(iip, cache, sym, target, J)
end
Expand Down
171 changes: 53 additions & 118 deletions src/trustRegion.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
`RadiusUpdateSchemes`
RadiusUpdateSchemes
`RadiusUpdateSchemes` is the standard enum interface for different types of radius update schemes
implemented in the Trust Region method. These schemes specify how the radius of the so-called trust region
Expand All @@ -16,7 +16,7 @@ states as `RadiusUpdateSchemes.T`. Simply put the desired scheme as follows:
"""
@enumx RadiusUpdateSchemes begin
"""
`RadiusUpdateSchemes.Simple`
RadiusUpdateSchemes.Simple
The simple or conventional radius update scheme. This scheme is chosen by default
and follows the conventional approach to update the trust region radius, i.e. if the
Expand All @@ -26,21 +26,21 @@ states as `RadiusUpdateSchemes.T`. Simply put the desired scheme as follows:
Simple

"""
`RadiusUpdateSchemes.NLsolve`
RadiusUpdateSchemes.NLsolve
The same updating scheme as in NLsolve's (https://github.com/JuliaNLSolvers/NLsolve.jl) trust region dogleg implementation.
"""
NLsolve

"""
`RadiusUpdateSchemes.NocedalWright`
RadiusUpdateSchemes.NocedalWright
Trust region updating scheme as in Nocedal and Wright [see Alg 11.5, page 291].
"""
NocedalWright

"""
`RadiusUpdateSchemes.Hei`
RadiusUpdateSchemes.Hei
This scheme is proposed by [Hei, L.](https://www.jstor.org/stable/43693061). The trust region radius
depends on the size (norm) of the current step size. The hypothesis is to let the radius converge to zero
Expand All @@ -50,7 +50,7 @@ states as `RadiusUpdateSchemes.T`. Simply put the desired scheme as follows:
Hei

"""
`RadiusUpdateSchemes.Yuan`
RadiusUpdateSchemes.Yuan
This scheme is proposed by [Yuan, Y.](https://www.researchgate.net/publication/249011466_A_new_trust_region_algorithm_with_trust_region_radius_converging_to_zero).
Similar to Hei's scheme, the trust region is updated in a way so that it converges to zero, however here,
Expand All @@ -60,7 +60,7 @@ states as `RadiusUpdateSchemes.T`. Simply put the desired scheme as follows:
Yuan

"""
`RadiusUpdateSchemes.Bastin`
RadiusUpdateSchemes.Bastin
This scheme is proposed by [Bastin, et al.](https://www.researchgate.net/publication/225100660_A_retrospective_trust-region_method_for_unconstrained_optimization).
The scheme is called a retrospective update scheme as it uses the model function at the current
Expand All @@ -71,7 +71,7 @@ states as `RadiusUpdateSchemes.T`. Simply put the desired scheme as follows:
Bastin

"""
`RadiusUpdateSchemes.Fan`
RadiusUpdateSchemes.Fan
This scheme is proposed by [Fan, J.](https://link.springer.com/article/10.1007/s10589-005-3078-8). It is very similar to
Hei's and Yuan's schemes as it lets the trust region radius depend on the current size (norm) of the objective (merit)
Expand Down Expand Up @@ -170,7 +170,7 @@ function set_ad(alg::TrustRegion{CJ}, ad) where {CJ}
end

function TrustRegion(; concrete_jac = nothing, linsolve = nothing, precs = DEFAULT_PRECS,
radius_update_scheme::RadiusUpdateSchemes.T = RadiusUpdateSchemes.Simple, #defaults to conventional radius update
radius_update_scheme::RadiusUpdateSchemes.T = RadiusUpdateSchemes.Simple,
max_trust_radius::Real = 0 // 1, initial_trust_radius::Real = 0 // 1,
step_threshold::Real = 1 // 10000, shrink_threshold::Real = 1 // 4,
expand_threshold::Real = 3 // 4, shrink_factor::Real = 1 // 4,
Expand Down Expand Up @@ -233,6 +233,7 @@ end
trace
end

# TODO: add J_cache
function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::TrustRegion, args...;
alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing,
termination_condition = nothing, internalnorm = DEFAULT_NORM, linsolve_kwargs = (;),
Expand All @@ -244,7 +245,7 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::TrustRegion,
fu1 = evaluate_f(prob, u)
fu_prev = zero(fu1)

loss = get_loss(fu1)
loss = __get_trust_region_loss(fu1)
uf, _, J, fu2, jac_cache, du, H, g = jacobian_caches(alg, f, u, p, Val(iip);
linsolve_kwargs, linsolve_with_JᵀJ = Val(true), lininit = Val(false))
g = _restructure(fu1, g)
Expand Down Expand Up @@ -350,92 +351,54 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::TrustRegion,
p1, p2, p3, p4, ϵ, NLStats(1, 0, 0, 0, 0), tc_cache, trace)
end

function perform_step!(cache::TrustRegionCache{true})
@unpack make_new_J, J, fu, f, u, p, u_gauss_newton, alg, linsolve = cache
function perform_step!(cache::TrustRegionCache{iip}) where {iip}
if cache.make_new_J
jacobian!!(J, cache)
__update_JᵀJ!(Val{true}(), cache, :H, J)
__update_Jᵀf!(Val{true}(), cache, :g, :H, J, _vec(fu))
cache.J = jacobian!!(cache.J, cache)

__update_JᵀJ!(Val{iip}(), cache, :H, cache.J)
__update_Jᵀf!(Val{iip}(), cache, :g, :H, cache.J, _vec(cache.fu))
cache.stats.njacs += 1

# do not use A = cache.H, b = _vec(cache.g) since it is equivalent
# to A = cache.J, b = _vec(fu) as long as the Jacobian is non-singular
linres = dolinsolve(alg.precs, linsolve, A = J, b = _vec(fu),
linu = _vec(u_gauss_newton), p = p, reltol = cache.abstol)
linres = dolinsolve(cache.alg.precs, cache.linsolve, A = cache.J,
b = _vec(cache.fu), linu = _vec(cache.u_gauss_newton), p = cache.p,
reltol = cache.abstol)
cache.linsolve = linres.cache
@. cache.u_gauss_newton = -1 * u_gauss_newton
end

# Compute dogleg step
dogleg!(cache)

# Compute the potentially new u
@. cache.u_tmp = u + cache.du
f(cache.fu_new, cache.u_tmp, p)
trust_region_step!(cache)
cache.stats.nf += 1
cache.stats.nsolve += 1
cache.stats.nfactors += 1
return nothing
end

function perform_step!(cache::TrustRegionCache{false})
@unpack make_new_J, fu, f, u, p = cache

if make_new_J
J = jacobian!!(cache.J, cache)
__update_JᵀJ!(Val{false}(), cache, :H, J)
__update_Jᵀf!(Val{false}(), cache, :g, :H, J, _vec(fu))
cache.stats.njacs += 1

if cache.linsolve === nothing
# Scalar
cache.u_gauss_newton = -cache.H \ cache.g
else
# do not use A = cache.H, b = _vec(cache.g) since it is equivalent
# to A = cache.J, b = _vec(fu) as long as the Jacobian is non-singular
linres = dolinsolve(cache.alg.precs, cache.linsolve, A = cache.J, b = _vec(fu),
linu = _vec(cache.u_gauss_newton), p = p, reltol = cache.abstol)
cache.linsolve = linres.cache
@. cache.u_gauss_newton *= -1
end
cache.u_gauss_newton = _restructure(cache.u_gauss_newton, linres.u)
@bb @. cache.u_gauss_newton *= -1
end

# Compute the Newton step.
# compute dogleg step
dogleg!(cache)

# Compute the potentially new u
cache.u_tmp = u + cache.du

cache.fu_new = f(cache.u_tmp, p)
# compute the potentially new u
@bb @. cache.u_cache_2 = cache.u + cache.du
evaluate_f(cache, cache.u_tmp, cache.p, Val{:fu_cache_2}())
trust_region_step!(cache)
cache.stats.nf += 1
cache.stats.nsolve += 1
cache.stats.nfactors += 1
return nothing
end

function retrospective_step!(cache::TrustRegionCache)
@unpack J, fu_prev, fu, u_prev, u = cache
J = jacobian!!(deepcopy(J), cache)
if J isa Number
cache.H = J' * J
cache.g = J' * fu
else
__update_JᵀJ!(Val{isinplace(cache)}(), cache, :H, J)
__update_Jᵀf!(Val{isinplace(cache)}(), cache, :g, :H, J, fu)
end
function retrospective_step!(cache::TrustRegionCache{iip}) where {iip}
J = jacobian!!(cache.J_cache, cache)
__update_JᵀJ!(Val{iip}(), cache, :H, J)
__update_Jᵀf!(Val{iip}(), cache, :g, :H, J, cache.fu)
cache.stats.njacs += 1
@unpack H, g, du = cache

return -(get_loss(fu_prev) - get_loss(fu)) /
(dot(_vec(du), _vec(g)) + __lr_mul(Val(isinplace(cache)), H, _vec(du)) / 2)
# FIXME: Caching in __lr_mul
num = __get_trust_region_loss(cache.fu) - __get_trust_region_loss(cache.fu_cache)
denom = dot(_vec(du), _vec(g)) + __lr_mul(Val{iip}(), H, _vec(du)) / 2
return num / denom
end

# TODO
function trust_region_step!(cache::TrustRegionCache)
@unpack fu_new, du, g, H, loss, max_trust_r, radius_update_scheme = cache

cache.loss_new = get_loss(fu_new)
cache.loss_new = __get_trust_region_loss(fu_new)

# Compute the ratio of the actual reduction to the predicted reduction.
cache.r = -(loss - cache.loss_new) /
Expand Down Expand Up @@ -556,6 +519,7 @@ function trust_region_step!(cache::TrustRegionCache)
end

@unpack p1 = cache
# TODO: Use the `vjp_autodiff` for the jvp
cache.trust_r = p1 * cache.internalnorm(jvp!(cache))

update_trace!(cache.trace, cache.stats.nsteps + 1, cache.u, cache.fu, cache.J,
Expand Down Expand Up @@ -608,6 +572,7 @@ function trust_region_step!(cache::TrustRegionCache)
end
end

# TODO
function dogleg!(cache::TrustRegionCache{true})
@unpack u_tmp, u_gauss_newton, u_cauchy, trust_r = cache

Expand Down Expand Up @@ -638,6 +603,7 @@ function dogleg!(cache::TrustRegionCache{true})
@. cache.du = u_cauchy + τ * u_tmp
end

# TODO
function dogleg!(cache::TrustRegionCache{false})
@unpack u_tmp, u_gauss_newton, u_cauchy, trust_r = cache

Expand Down Expand Up @@ -667,20 +633,14 @@ function dogleg!(cache::TrustRegionCache{false})
cache.du = u_cauchy + τ * u_tmp
end

# In-place acceptance of the trial step: save the current iterate and residual into
# the `*_prev` buffers, then overwrite them with the trial values, all by broadcast
# assignment into the existing arrays.
function take_step!(cache::TrustRegionCache{true})
    @. cache.u_prev = cache.u
    @. cache.u = cache.u_tmp
    @. cache.fu_prev = cache.fu
    @. cache.fu = cache.fu_new
end

function take_step!(cache::TrustRegionCache{false})
cache.u_prev = cache.u
cache.u = cache.u_tmp
cache.fu_prev = cache.fu
cache.fu = cache.fu_new
# Accept the trial step stored on `cache`: the current iterate and residual are saved
# into `u_cache` / `fu_cache`, then replaced by the trial values held in
# `u_cache_2` / `fu_cache_2`. Each `copyto!` copies its second argument into its first.
# NOTE(review): `@bb` is presumably the MaybeInplace.jl macro that selects an in-place
# or out-of-place copy depending on the array type — confirm against the package imports.
function __take_step!(cache::TrustRegionCache)
# keep the current iterate as the "previous" one (role of the former `u_prev`)
@bb copyto!(cache.u_cache, cache.u)
@bb copyto!(cache.u, cache.u_cache_2) # u_tmp --> u_cache_2
# keep the current residual as the "previous" one (role of the former `fu_prev`)
@bb copyto!(cache.fu_cache, cache.fu)
@bb copyto!(cache.fu, cache.fu_cache_2) # fu_new --> fu_cache_2
end

# TODO
function jvp!(cache::TrustRegionCache{false})
@unpack f, u, fu, uf = cache
if isa(u, Number)
Expand Down Expand Up @@ -710,40 +670,15 @@ function not_terminated(cache::TrustRegionCache)
end
return true
end
# Accessor for the residual currently stored on the trust-region cache.
function get_fu(cache::TrustRegionCache)
    return cache.fu
end

# Store `fu` as the cache's residual; the call evaluates to the assigned `fu`.
function set_fu!(cache::TrustRegionCache, fu)
    return cache.fu = fu
end

function SciMLBase.reinit!(cache::TrustRegionCache{iip}, u0 = cache.u; p = cache.p,
abstol = cache.abstol, reltol = cache.reltol, maxiters = cache.maxiters,
termination_condition = get_termination_mode(cache.tc_cache)) where {iip}
cache.p = p
if iip
recursivecopy!(cache.u, u0)
cache.f(cache.fu, cache.u, p)
else
# don't have alias_u0 but cache.u is never mutated for OOP problems so it doesn't matter
cache.u = u0
cache.fu = cache.f(cache.u, p)
end

reset!(cache.trace)
abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, cache.fu, cache.u,
termination_condition)

cache.abstol = abstol
cache.reltol = reltol
cache.tc_cache = tc_cache
cache.maxiters = maxiters
cache.stats.nf = 1
cache.stats.nsteps = 1
cache.force_stop = false
cache.retcode = ReturnCode.Default
cache.make_new_J = true
cache.loss = get_loss(cache.fu)
# Reset the trust-region-specific state of `cache` when a solve is re-initialised.
#
# - `cache.loss` is recomputed from the residual currently stored in `cache.fu`.
# - `cache.shrink_counter` is cleared.
# - `cache.trust_r` is reset to the algorithm's `initial_trust_radius`, converted to
#   the element type of `cache.u`. A zero radius (the constructor default) is replaced
#   by `cache.max_trust_r / 11`.
# - `cache.make_new_J` is set so the next step recomputes the Jacobian.
#
# Extra keyword arguments are accepted and ignored. Returns `cache`.
function __reinit_internal!(cache::TrustRegionCache; kwargs...)
    cache.loss = __get_trust_region_loss(cache.fu)
    cache.shrink_counter = 0

    T = eltype(cache.u)
    trust_r = convert(T, cache.alg.initial_trust_radius)
    # Fall back to `max_trust_r / 11` only when no (i.e. a zero) initial radius was
    # given. The previous `ifelse` had its branches swapped, and it sat after an early
    # `return`, so neither it nor the `make_new_J` reset below was ever executed.
    if iszero(trust_r)
        trust_r = convert(T, cache.max_trust_r / 11)
    end
    cache.trust_r = trust_r

    # The stored Jacobian belongs to the previous problem state; force a rebuild.
    cache.make_new_J = true
    return cache
end

"""
    __get_trust_region_loss(fu)

Return half of the squared norm of the residual `fu`, i.e. `norm(fu)^2 / 2`.
"""
function __get_trust_region_loss(fu)
    residual_norm = norm(fu)
    return residual_norm^2 / 2
end
21 changes: 14 additions & 7 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,6 @@ function wrapprecs(_Pl, _Pr, weight)
return Pl, Pr
end

"""
    get_loss(fu)

Return the least-squares merit value of the residual `fu`: half its squared norm.
"""
function get_loss(fu)
    half_squared_norm = abs2(norm(fu)) / 2
    return half_squared_norm
end

function rfunc(r::R, c2::R, M::R, γ1::R, γ2::R, β::R) where {R <: Real} # R-function for adaptive trust region method
if (r c2)
return (2 * (M - 1 - γ2) * atan(r - c2) + (1 + γ2)) / π
Expand Down Expand Up @@ -188,7 +186,7 @@ function evaluate_f(prob::Union{NonlinearProblem{uType, iip},
return fu
end

function evaluate_f(f, u, p, ::Val{iip}; fu = nothing) where {iip}
function evaluate_f(f::F, u, p, ::Val{iip}; fu = nothing) where {F, iip <: Bool}
if iip
f(fu, u, p)
return fu
Expand All @@ -197,11 +195,20 @@ function evaluate_f(f, u, p, ::Val{iip}; fu = nothing) where {iip}
end
end

function evaluate_f(cache, u, p)
if isinplace(cache)
cache.prob.f(get_fu(cache), u, p)
function evaluate_f(cache::AbstractNonlinearSolveCache, u, p,
fu_sym::Val{FUSYM} = Val(nothing)) where {FUSYM}
if FUSYM === nothing
if isinplace(cache)
cache.prob.f(get_fu(cache), u, p)
else
set_fu!(cache, cache.prob.f(u, p))
end
else
set_fu!(cache, cache.prob.f(u, p))
if isinplace(cache)
cache.prob.f(__getproperty(cache, fu_sym), u, p)
else
setproperty!(cache, FUSYM, cache.prob.f(u, p))
end
end
return nothing
end
Expand Down

0 comments on commit f18fe15

Please sign in to comment.