diff --git a/src/NonlinearSolve.jl b/src/NonlinearSolve.jl index 578343345..278667790 100644 --- a/src/NonlinearSolve.jl +++ b/src/NonlinearSolve.jl @@ -169,7 +169,7 @@ include("trace.jl") include("extension_algs.jl") include("linesearch.jl") include("raphson.jl") -# include("trustRegion.jl") +include("trustRegion.jl") include("levenberg.jl") include("gaussnewton.jl") include("dfsane.jl") @@ -179,54 +179,54 @@ include("klement.jl") include("lbroyden.jl") include("jacobian.jl") include("ad.jl") -# include("default.jl") - -# @setup_workload begin -# nlfuncs = ((NonlinearFunction{false}((u, p) -> u .* u .- p), 0.1), -# (NonlinearFunction{false}((u, p) -> u .* u .- p), [0.1]), -# (NonlinearFunction{true}((du, u, p) -> du .= u .* u .- p), [0.1])) -# probs_nls = NonlinearProblem[] -# for T in (Float32, Float64), (fn, u0) in nlfuncs -# push!(probs_nls, NonlinearProblem(fn, T.(u0), T(2))) -# end - -# nls_algs = (NewtonRaphson(), TrustRegion(), LevenbergMarquardt(), PseudoTransient(), -# GeneralBroyden(), GeneralKlement(), DFSane(), nothing) - -# probs_nlls = NonlinearLeastSquaresProblem[] -# nlfuncs = ((NonlinearFunction{false}((u, p) -> (u .^ 2 .- p)[1:1]), [0.1, 0.0]), -# (NonlinearFunction{false}((u, p) -> vcat(u .* u .- p, u .* u .- p)), [0.1, 0.1]), -# (NonlinearFunction{true}((du, u, p) -> du[1] = u[1] * u[1] - p, -# resid_prototype = zeros(1)), [0.1, 0.0]), -# (NonlinearFunction{true}((du, u, p) -> du .= vcat(u .* u .- p, u .* u .- p), -# resid_prototype = zeros(4)), [0.1, 0.1])) -# for (fn, u0) in nlfuncs -# push!(probs_nlls, NonlinearLeastSquaresProblem(fn, u0, 2.0)) -# end -# nlfuncs = ((NonlinearFunction{false}((u, p) -> (u .^ 2 .- p)[1:1]), Float32[0.1, 0.0]), -# (NonlinearFunction{false}((u, p) -> vcat(u .* u .- p, u .* u .- p)), -# Float32[0.1, 0.1]), -# (NonlinearFunction{true}((du, u, p) -> du[1] = u[1] * u[1] - p, -# resid_prototype = zeros(Float32, 1)), Float32[0.1, 0.0]), -# (NonlinearFunction{true}((du, u, p) -> du .= vcat(u .* u .- p, u .* u .- p), -# resid_prototype = zeros(Float32, 4)), Float32[0.1, 0.1])) -# for (fn, u0) in nlfuncs -# push!(probs_nlls, NonlinearLeastSquaresProblem(fn, u0, 2.0f0)) -# end - -# nlls_algs = (LevenbergMarquardt(), GaussNewton(), -# LevenbergMarquardt(; linsolve = LUFactorization()), -# GaussNewton(; linsolve = LUFactorization())) - -# @compile_workload begin -# for prob in probs_nls, alg in nls_algs -# solve(prob, alg, abstol = 1e-2) -# end -# for prob in probs_nlls, alg in nlls_algs -# solve(prob, alg, abstol = 1e-2) -# end -# end -# end +include("default.jl") + +@setup_workload begin + nlfuncs = ((NonlinearFunction{false}((u, p) -> u .* u .- p), 0.1), + (NonlinearFunction{false}((u, p) -> u .* u .- p), [0.1]), + (NonlinearFunction{true}((du, u, p) -> du .= u .* u .- p), [0.1])) + probs_nls = NonlinearProblem[] + for T in (Float32, Float64), (fn, u0) in nlfuncs + push!(probs_nls, NonlinearProblem(fn, T.(u0), T(2))) + end + + nls_algs = (NewtonRaphson(), TrustRegion(), LevenbergMarquardt(), PseudoTransient(), + GeneralBroyden(), GeneralKlement(), DFSane(), nothing) + + probs_nlls = NonlinearLeastSquaresProblem[] + nlfuncs = ((NonlinearFunction{false}((u, p) -> (u .^ 2 .- p)[1:1]), [0.1, 0.0]), + (NonlinearFunction{false}((u, p) -> vcat(u .* u .- p, u .* u .- p)), [0.1, 0.1]), + (NonlinearFunction{true}((du, u, p) -> du[1] = u[1] * u[1] - p, + resid_prototype = zeros(1)), [0.1, 0.0]), + (NonlinearFunction{true}((du, u, p) -> du .= vcat(u .* u .- p, u .* u .- p), + resid_prototype = zeros(4)), [0.1, 0.1])) + for (fn, u0) in nlfuncs + push!(probs_nlls, NonlinearLeastSquaresProblem(fn, u0, 2.0)) + end + nlfuncs = ((NonlinearFunction{false}((u, p) -> (u .^ 2 .- p)[1:1]), Float32[0.1, 0.0]), + (NonlinearFunction{false}((u, p) -> vcat(u .* u .- p, u .* u .- p)), + Float32[0.1, 0.1]), + (NonlinearFunction{true}((du, u, p) -> du[1] = u[1] * u[1] - p, + resid_prototype = zeros(Float32, 1)), Float32[0.1, 0.0]), + (NonlinearFunction{true}((du, u, p) -> du .= vcat(u .* u .- p, u .* u .- p), + resid_prototype = zeros(Float32, 4)), Float32[0.1, 0.1])) + for (fn, u0) in nlfuncs + push!(probs_nlls, NonlinearLeastSquaresProblem(fn, u0, 2.0f0)) + end + + nlls_algs = (LevenbergMarquardt(), GaussNewton(), + LevenbergMarquardt(; linsolve = LUFactorization()), + GaussNewton(; linsolve = LUFactorization())) + + @compile_workload begin + for prob in probs_nls, alg in nls_algs + solve(prob, alg, abstol = 1e-2) + end + for prob in probs_nlls, alg in nlls_algs + solve(prob, alg, abstol = 1e-2) + end + end +end export RadiusUpdateSchemes diff --git a/src/gaussnewton.jl b/src/gaussnewton.jl index 94f2e975a..9a227a7fa 100644 --- a/src/gaussnewton.jl +++ b/src/gaussnewton.jl @@ -116,8 +116,8 @@ function perform_step!(cache::GaussNewtonCache{iip}) where {iip} # Use normal form to solve the Linear Problem if cache.JᵀJ !== nothing - __update_JᵀJ!(cache, Val(:JᵀJ)) - __update_Jᵀf!(cache, Val(:JᵀJ)) + __update_JᵀJ!(cache) + __update_Jᵀf!(cache) A, b = __maybe_symmetric(cache.JᵀJ), _vec(cache.Jᵀf) else A, b = cache.J, _vec(cache.fu) diff --git a/src/jacobian.jl b/src/jacobian.jl index 2e539fcd8..60c42672c 100644 --- a/src/jacobian.jl +++ b/src/jacobian.jl @@ -138,7 +138,7 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u::Number, kwargs...) where {needsJᵀJ, F} # NOTE: Scalar `u` assumes scalar output from `f` uf = SciMLBase.JacobianWrapper{false}(f, p) - return uf, FakeLinearSolveJLCache(u, u), u, nothing, nothing, u, u, u + return uf, FakeLinearSolveJLCache(u, u), u, zero(u), nothing, u, u, u end # Linear Solve Cache @@ -208,27 +208,48 @@ function __concrete_vjp_autodiff(vjp_autodiff, jvp_autodiff, uf) end end +# jvp fallback scalar +__jacvec(args...; kwargs...) = JacVec(args...; kwargs...) +__jacvec(uf, u::Number; autodiff, kwargs...) = JVPScalar(uf, u, autodiff) + +@concrete mutable struct JVPScalar + uf + u + autodiff +end + +function Base.:*(jvp::JVPScalar, v) + if jvp.autodiff isa AutoForwardDiff + elseif jvp.autodiff isa AutoFiniteDiff + else + error("JVPScalar only supports AutoForwardDiff and AutoFiniteDiff") + end +end + # Generic Handling of Krylov Methods for Normal Form Linear Solves -function __update_JᵀJ!(cache::AbstractNonlinearSolveCache) +function __update_JᵀJ!(cache::AbstractNonlinearSolveCache, J = nothing) if !(cache.JᵀJ isa KrylovJᵀJ) - @bb cache.JᵀJ = transpose(cache.J) × cache.J + J_ = ifelse(J === nothing, cache.J, J) + @bb cache.JᵀJ = transpose(J_) × J_ end end -function __update_Jᵀf!(cache::AbstractNonlinearSolveCache) +function __update_Jᵀf!(cache::AbstractNonlinearSolveCache, J = nothing) if cache.JᵀJ isa KrylovJᵀJ @bb cache.Jᵀf = cache.JᵀJ.Jᵀ × cache.fu else - @bb cache.Jᵀf = transpose(cache.J) × vec(cache.fu) + J_ = ifelse(J === nothing, cache.J, J) + @bb cache.Jᵀf = transpose(J_) × vec(cache.fu) end end # Left-Right Multiplication -__lr_mul(::Val, H, g) = dot(g, H, g) -## TODO: Use a cache here to avoid allocations -__lr_mul(::Val{false}, H::KrylovJᵀJ, g) = dot(g, H.JᵀJ, g) -function __lr_mul(::Val{true}, H::KrylovJᵀJ, g) - c = similar(g) - mul!(c, H.JᵀJ, g) - return dot(g, c) +__lr_mul(cache::AbstractNonlinearSolveCache) = __lr_mul(cache, cache.JᵀJ, cache.Jᵀf) +function __lr_mul(cache::AbstractNonlinearSolveCache, JᵀJ::KrylovJᵀJ, Jᵀf) + @bb cache.lr_mul_cache = JᵀJ.JᵀJ × vec(Jᵀf) + return dot(_vec(Jᵀf), _vec(cache.lr_mul_cache)) +end +function __lr_mul(cache::AbstractNonlinearSolveCache, JᵀJ, Jᵀf) + @bb cache.lr_mul_cache = JᵀJ × Jᵀf + return dot(_vec(Jᵀf), _vec(cache.lr_mul_cache)) end diff --git a/src/trustRegion.jl b/src/trustRegion.jl index 7e5497ffd..f27259d3f 100644 --- a/src/trustRegion.jl +++ b/src/trustRegion.jl @@ -182,19 +182,26 @@ function TrustRegion(; concrete_jac = nothing, linsolve = nothing, precs = DEFAU expand_threshold, shrink_factor, expand_factor, max_shrink_times, vjp_autodiff) end -@concrete mutable struct TrustRegionCache{iip, trustType, floatType} <: - AbstractNonlinearSolveCache{iip} +@concrete mutable struct TrustRegionCache{iip} <: AbstractNonlinearSolveCache{iip} f alg - u_prev u - fu_prev + u_cache + u_cache_2 + u_gauss_newton + u_cauchy fu - fu2 + fu_cache + fu_cache_2 + J + J_cache + JᵀJ + Jᵀf p uf + du + lr_mul_cache linsolve - J jac_cache force_stop::Bool maxiters::Int @@ -204,60 +211,55 @@ end reltol prob radius_update_scheme::RadiusUpdateSchemes.T - trust_r::trustType - max_trust_r::trustType + trust_r + max_trust_r step_threshold - shrink_threshold::trustType - expand_threshold::trustType - shrink_factor::trustType - expand_factor::trustType - loss::floatType - loss_new::floatType - H - g + shrink_threshold + expand_threshold + shrink_factor + expand_factor + loss + loss_new shrink_counter::Int - du - u_tmp - u_gauss_newton - u_cauchy - fu_new make_new_J::Bool - r::floatType - p1::floatType - p2::floatType - p3::floatType - p4::floatType - ϵ::floatType + r + p1 + p2 + p3 + p4 + ϵ + jvp_operator # For Yuan stats::NLStats tc_cache trace end -# TODO: add J_cache function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::TrustRegion, args...; alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing, - termination_condition = nothing, internalnorm = DEFAULT_NORM, linsolve_kwargs = (;), - kwargs...) where {uType, iip} + termination_condition = nothing, internalnorm = Base.Fix2(norm, 2), + linsolve_kwargs = (;), kwargs...) where {uType, iip} alg = get_concrete_algorithm(alg_, prob) @unpack f, u0, p = prob - u = alias_u0 ? u0 : deepcopy(u0) - u_prev = zero(u) - fu1 = evaluate_f(prob, u) - fu_prev = zero(fu1) + u = __maybe_unaliased(u0, alias_u0) + @bb u_cache = copy(u) + @bb u_cache_2 = similar(u) + fu = evaluate_f(prob, u) + @bb fu_cache_2 = zero(fu) - loss = __get_trust_region_loss(fu1) - uf, _, J, fu2, jac_cache, du, H, g = jacobian_caches(alg, f, u, p, Val(iip); + loss = __trust_region_loss(internalnorm, fu) + + uf, _, J, fu_cache, jac_cache, du, JᵀJ, Jᵀf = jacobian_caches(alg, f, u, p, Val(iip); linsolve_kwargs, linsolve_with_JᵀJ = Val(true), lininit = Val(false)) - g = _restructure(fu1, g) - linsolve = u isa Number ? nothing : linsolve_caches(J, fu2, du, p, alg) + linsolve = linsolve_caches(J, fu_cache, du, p, alg) - u_tmp = zero(u) - u_cauchy = zero(u) - u_gauss_newton = _mutable_zero(u) + @bb u_cache_2 = similar(u) + @bb u_cauchy = similar(u) + @bb u_gauss_newton = similar(u) + @bb J_cache = similar(J) + @bb lr_mul_cache = similar(du) loss_new = loss shrink_counter = 0 - fu_new = zero(fu1) make_new_J = true r = loss @@ -270,11 +272,13 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::TrustRegion, trustType = floatType if radius_update_scheme == RadiusUpdateSchemes.NLsolve max_trust_radius = convert(trustType, Inf) - initial_trust_radius = norm(u0) > 0 ? convert(trustType, norm(u0)) : one(trustType) + initial_trust_radius = internalnorm(u0) > 0 ? convert(trustType, internalnorm(u0)) : + one(trustType) else max_trust_radius = convert(trustType, alg.max_trust_radius) if iszero(max_trust_radius) - max_trust_radius = convert(trustType, max(norm(fu1), maximum(u) - minimum(u))) + max_trust_radius = convert(trustType, + max(internalnorm(fu), maximum(u) - minimum(u))) end initial_trust_radius = convert(trustType, alg.initial_trust_radius) if iszero(initial_trust_radius) @@ -293,6 +297,7 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::TrustRegion, p3 = convert(floatType, 0.0) p4 = convert(floatType, 0.0) ϵ = convert(floatType, 1.0e-8) + jvp_operator = nothing if radius_update_scheme === RadiusUpdateSchemes.NLsolve p1 = convert(floatType, 0.5) elseif radius_update_scheme === RadiusUpdateSchemes.Hei @@ -311,16 +316,9 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::TrustRegion, p1 = convert(floatType, 2.0) # μ p2 = convert(floatType, 1 / 6) # c5 p3 = convert(floatType, 6.0) # c6 - if iip - auto_jacvec!(g, (fu, x) -> f(fu, x, p), u, fu1) - else - if isa(u, Number) - g = ForwardDiff.derivative(x -> f(x, p), u) - else - g = auto_jacvec(x -> f(x, p), u, fu1) - end - end - initial_trust_radius = convert(trustType, p1 * norm(g)) + jvp_operator = __jacvec(uf, u; fu, autodiff = __get_nonsparse_ad(alg.ad)) + @bb Jᵀf = jvp_operator × fu + initial_trust_radius = convert(trustType, p1 * internalnorm(Jᵀf)) elseif radius_update_scheme === RadiusUpdateSchemes.Fan step_threshold = convert(trustType, 0.0001) shrink_threshold = convert(trustType, 0.25) @@ -329,7 +327,7 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::TrustRegion, p2 = convert(floatType, 0.25) # c5 p3 = convert(floatType, 12.0) # c6 p4 = convert(floatType, 1.0e18) # M - initial_trust_radius = convert(trustType, p1 * (norm(fu1)^0.99)) + initial_trust_radius = convert(trustType, p1 * (internalnorm(fu)^0.99)) elseif radius_update_scheme === RadiusUpdateSchemes.Bastin step_threshold = convert(trustType, 0.05) shrink_threshold = convert(trustType, 0.05) @@ -339,25 +337,25 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::TrustRegion, initial_trust_radius = convert(trustType, 1.0) end - abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, fu1, u, + abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, fu, u, termination_condition) - trace = init_nonlinearsolve_trace(alg, u, fu1, ApplyArray(__zero, J), du; kwargs...) + trace = init_nonlinearsolve_trace(alg, u, fu, ApplyArray(__zero, J), du; kwargs...) - return TrustRegionCache{iip}(f, alg, u_prev, u, fu_prev, fu1, fu2, p, uf, linsolve, J, + return TrustRegionCache{iip}(f, alg, u, u_cache, u_cache_2, u_gauss_newton, u_cauchy, + fu, fu_cache, fu_cache_2, J, J_cache, JᵀJ, Jᵀf, p, uf, du, lr_mul_cache, linsolve, jac_cache, false, maxiters, internalnorm, ReturnCode.Default, abstol, reltol, prob, radius_update_scheme, initial_trust_radius, max_trust_radius, step_threshold, shrink_threshold, expand_threshold, shrink_factor, expand_factor, loss, loss_new, - H, g, shrink_counter, du, u_tmp, u_gauss_newton, u_cauchy, fu_new, make_new_J, r, - p1, p2, p3, p4, ϵ, NLStats(1, 0, 0, 0, 0), tc_cache, trace) + shrink_counter, make_new_J, r, p1, p2, p3, p4, ϵ, jvp_operator, + NLStats(1, 0, 0, 0, 0), tc_cache, trace) end function perform_step!(cache::TrustRegionCache{iip}) where {iip} if cache.make_new_J cache.J = jacobian!!(cache.J, cache) - __update_JᵀJ!(Val{iip}(), cache, :H, cache.J) - __update_Jᵀf!(Val{iip}(), cache, :g, :H, cache.J, _vec(cache.fu)) - cache.stats.njacs += 1 + __update_JᵀJ!(cache) + __update_Jᵀf!(cache) # do not use A = cache.H, b = _vec(cache.g) since it is equivalent # to A = cache.J, b = _vec(fu) as long as the Jacobian is non-singular @@ -374,7 +372,7 @@ function perform_step!(cache::TrustRegionCache{iip}) where {iip} # compute the potentially new u @bb @. cache.u_cache_2 = cache.u + cache.du - evaluate_f(cache, cache.u_tmp, cache.p, Val{:fu_cache_2}()) + evaluate_f(cache, cache.u_cache_2, cache.p, Val{:fu_cache_2}()) trust_region_step!(cache) cache.stats.nsolve += 1 cache.stats.nfactors += 1 @@ -383,278 +381,157 @@ end function retrospective_step!(cache::TrustRegionCache{iip}) where {iip} J = jacobian!!(cache.J_cache, cache) - __update_JᵀJ!(Val{iip}(), cache, :H, J) - __update_Jᵀf!(Val{iip}(), cache, :g, :H, J, cache.fu) - cache.stats.njacs += 1 + __update_JᵀJ!(cache, J) + __update_Jᵀf!(cache, J) - # FIXME: Caching in __lr_mul - num = __get_trust_region_loss(cache.fu) - __get_trust_region_loss(cache.fu_cache) - denom = dot(_vec(du), _vec(g)) + __lr_mul(Val{iip}(), H, _vec(du)) / 2 + num = __trust_region_loss(cache, cache.fu) - + __get_trust_region_loss(cache, cache.fu_cache) + denom = dot(_vec(cache.du), _vec(cache.Jᵀf)) + __lr_mul(cache, cache.JᵀJ, cache.du) / 2 return num / denom end -# TODO function trust_region_step!(cache::TrustRegionCache) - @unpack fu_new, du, g, H, loss, max_trust_r, radius_update_scheme = cache - - cache.loss_new = __get_trust_region_loss(fu_new) + cache.loss_new = __trust_region_loss(cache, cache.fu_cache_2) # Compute the ratio of the actual reduction to the predicted reduction. - cache.r = -(loss - cache.loss_new) / - (dot(_vec(du), _vec(g)) + __lr_mul(Val(isinplace(cache)), H, _vec(du)) / 2) - @unpack r = cache + cache.r = -(cache.loss - cache.loss_new) / + (dot(_vec(cache.du), _vec(cache.Jᵀf)) + + __lr_mul(cache, cache.JᵀJ, _vec(cache.du)) / 2) + + @unpack r, radius_update_scheme = cache + make_new_J = false + if r > cache.step_threshold + take_step!(cache) + cache.loss = cache.loss_new + make_new_J = true + end if radius_update_scheme === RadiusUpdateSchemes.Simple - # Update the trust region radius. if r < cache.shrink_threshold cache.trust_r *= cache.shrink_factor cache.shrink_counter += 1 else cache.shrink_counter = 0 - end - if r > cache.step_threshold - take_step!(cache) - cache.loss = cache.loss_new - - # Update the trust region radius. - if r > cache.expand_threshold - cache.trust_r = min(cache.expand_factor * cache.trust_r, max_trust_r) + if r > cache.step_threshold && r > cache.expand_threshold + cache.trust_r = min(cache.expand_factor * cache.trust_r, cache.max_trust_r) end - - cache.make_new_J = true - else - # No need to make a new J, no step was taken, so we try again with a smaller trust_r - cache.make_new_J = false end - update_trace!(cache.trace, cache.stats.nsteps + 1, cache.u, cache.fu, cache.J, - @~(cache.u.-cache.u_prev)) - check_and_update!(cache, cache.fu, cache.u, cache.u_prev) - elseif radius_update_scheme === RadiusUpdateSchemes.NLsolve - # accept/reject decision - if r > cache.step_threshold # accept - take_step!(cache) - cache.loss = cache.loss_new - cache.make_new_J = true - else # reject - cache.make_new_J = false - end - - # trust region update - if r < 1 // 10 # cache.shrink_threshold - cache.trust_r *= 1 // 2 # cache.shrink_factor - elseif r >= 9 // 10 # cache.expand_threshold - cache.trust_r = 2 * norm(cache.du) # cache.expand_factor * norm(cache.du) - elseif r >= 1 // 2 # cache.p1 - cache.trust_r = max(cache.trust_r, 2 * norm(cache.du)) # cache.expand_factor * norm(cache.du)) + if r < 1 // 10 + cache.shrink_counter += 1 + cache.trust_r *= 1 // 2 + else + cache.shrink_counter = 0 + if r ≥ 9 // 10 + cache.trust_r = 2 * cache.internalnorm(cache.du) + elseif r ≥ 1 // 2 + cache.trust_r = max(cache.trust_r, 2 * cache.internalnorm(cache.du)) + end end - - update_trace!(cache.trace, cache.stats.nsteps + 1, cache.u, cache.fu, cache.J, - @~(cache.u.-cache.u_prev)) - # convergence test - check_and_update!(cache, cache.fu, cache.u, cache.u_prev) - elseif radius_update_scheme === RadiusUpdateSchemes.NocedalWright - # accept/reject decision - if r > cache.step_threshold # accept - take_step!(cache) - cache.loss = cache.loss_new - cache.make_new_J = true - else # reject - cache.make_new_J = false - end - if r < 1 // 4 - cache.trust_r = (1 // 4) * norm(cache.du) - elseif (r > (3 // 4)) && abs(norm(cache.du) - cache.trust_r) / cache.trust_r < 1e-6 - cache.trust_r = min(2 * cache.trust_r, cache.max_trust_r) - end - - update_trace!(cache.trace, cache.stats.nsteps + 1, cache.u, cache.fu, cache.J, - @~(cache.u.-cache.u_prev)) - # convergence test - check_and_update!(cache, cache.fu, cache.u, cache.u_prev) - - elseif radius_update_scheme === RadiusUpdateSchemes.Hei - if r > cache.step_threshold - take_step!(cache) - cache.loss = cache.loss_new - cache.make_new_J = true + cache.shrink_counter += 1 + cache.trust_r = (1 // 4) * cache.internalnorm(cache.du) else - cache.make_new_J = false + cache.shrink_counter = 0 + if r > 3 // 4 && + abs(cache.internalnorm(cache.du) - cache.trust_r) < 1e-6 * cache.trust_r + cache.trust_r = min(2 * cache.trust_r, cache.max_trust_r) + end end - # Hei's radius update scheme + elseif radius_update_scheme === RadiusUpdateSchemes.Hei @unpack shrink_threshold, p1, p2, p3, p4 = cache - if rfunc(r, shrink_threshold, p1, p3, p4, p2) * cache.internalnorm(du) < - cache.trust_r + tr_new = __rfunc(r, shrink_threshold, p1, p3, p4, p2) * cache.internalnorm(du) + if tr_new < cache.trust_r cache.shrink_counter += 1 else cache.shrink_counter = 0 end - cache.trust_r = rfunc(r, shrink_threshold, p1, p3, p4, p2) * - cache.internalnorm(du) - - update_trace!(cache.trace, cache.stats.nsteps + 1, cache.u, cache.fu, cache.J, - @~(cache.u.-cache.u_prev)) - check_and_update!(cache, cache.fu, cache.u, cache.u_prev) - cache.internalnorm(g) < cache.ϵ && (cache.force_stop = true) + cache.trust_r = tr_new + cache.internalnorm(cache.Jᵀf) < cache.ϵ && (cache.force_stop = true) elseif radius_update_scheme === RadiusUpdateSchemes.Yuan if r < cache.shrink_threshold cache.p1 = cache.p2 * cache.p1 cache.shrink_counter += 1 - elseif r >= cache.expand_threshold && - cache.internalnorm(du) > cache.trust_r / 2 - cache.p1 = cache.p3 * cache.p1 - cache.shrink_counter = 0 - end - - if r > cache.step_threshold - take_step!(cache) - cache.loss = cache.loss_new - cache.make_new_J = true else - cache.make_new_J = false + if r ≥ cache.expand_threshold && + cache.internalnorm(cache.du) > cache.trust_r / 2 + cache.p1 = cache.p3 * cache.p1 + end + cache.shrink_counter = 0 end - @unpack p1 = cache - # TODO: Use the `vjp_autodiff` to for the jvp - cache.trust_r = p1 * cache.internalnorm(jvp!(cache)) + @bb cache.Jᵀf = cache.jvp_operator × vec(cache.fu) + cache.trust_r = cache.p1 * cache.internalnorm(cache.Jᵀf) - update_trace!(cache.trace, cache.stats.nsteps + 1, cache.u, cache.fu, cache.J, - @~(cache.u.-cache.u_prev)) - check_and_update!(cache, cache.fu, cache.u, cache.u_prev) - cache.internalnorm(g) < cache.ϵ && (cache.force_stop = true) - #Fan's update scheme + cache.internalnorm(cache.Jᵀf) < cache.ϵ && (cache.force_stop = true) elseif radius_update_scheme === RadiusUpdateSchemes.Fan if r < cache.shrink_threshold cache.p1 *= cache.p2 cache.shrink_counter += 1 - elseif r > cache.expand_threshold - cache.p1 = min(cache.p1 * cache.p3, cache.p4) - cache.shrink_counter = 0 - end - - if r > cache.step_threshold - take_step!(cache) - cache.loss = cache.loss_new - cache.make_new_J = true else - cache.make_new_J = false + cache.shrink_counter = 0 + r > cache.expand_threshold && (cache.p1 = min(cache.p1 * cache.p3, cache.p4)) end - - @unpack p1 = cache - cache.trust_r = p1 * (cache.internalnorm(cache.fu)^0.99) - - update_trace!(cache.trace, cache.stats.nsteps + 1, cache.u, cache.fu, cache.J, - @~(cache.u.-cache.u_prev)) - check_and_update!(cache, cache.fu, cache.u, cache.u_prev) - cache.internalnorm(g) < cache.ϵ && (cache.force_stop = true) + cache.trust_r = cache.p1 * (cache.internalnorm(cache.fu)^0.99) + cache.internalnorm(cache.Jᵀf) < cache.ϵ && (cache.force_stop = true) elseif radius_update_scheme === RadiusUpdateSchemes.Bastin if r > cache.step_threshold - take_step!(cache) - cache.loss = cache.loss_new - cache.make_new_J = true - if retrospective_step!(cache) >= cache.expand_threshold + if retrospective_step!(cache) ≥ cache.expand_threshold cache.trust_r = max(cache.p1 * cache.internalnorm(du), cache.trust_r) end - + cache.shrink_counter = 0 else - cache.make_new_J = false cache.trust_r *= cache.p2 cache.shrink_counter += 1 end - - update_trace!(cache.trace, cache.stats.nsteps + 1, cache.u, cache.fu, cache.J, - @~(cache.u.-cache.u_prev)) - check_and_update!(cache, cache.fu, cache.u, cache.u_prev) - end -end - -# TODO -function dogleg!(cache::TrustRegionCache{true}) - @unpack u_tmp, u_gauss_newton, u_cauchy, trust_r = cache - - # Take the full Gauss-Newton step if lies within the trust region. - if norm(u_gauss_newton) ≤ trust_r - cache.du .= u_gauss_newton - return end - # Take intersection of steepest descent direction and trust region if Cauchy point lies outside of trust region - l_grad = norm(cache.g) # length of the gradient - d_cauchy = l_grad^3 / __lr_mul(Val{true}(), cache.H, _vec(cache.g)) # distance of the cauchy point from the current iterate - if d_cauchy >= trust_r - @. cache.du = -(trust_r / l_grad) * cache.g # step to the end of the trust region - return - end - - # Take the intersection of dogleg with trust region if Cauchy point lies inside the trust region - @. u_cauchy = -(d_cauchy / l_grad) * cache.g # compute Cauchy point - @. u_tmp = u_gauss_newton - u_cauchy # calf of the dogleg -- use u_tmp to avoid allocation - - a = dot(u_tmp, u_tmp) - b = 2 * dot(u_cauchy, u_tmp) - c = d_cauchy^2 - trust_r^2 - aux = max(b^2 - 4 * a * c, 0.0) # technically guaranteed to be non-negative but hedging against floating point issues - τ = (-b + sqrt(aux)) / (2 * a) # stepsize along dogleg to trust region boundary - - @. cache.du = u_cauchy + τ * u_tmp + update_trace!(cache.trace, cache.stats.nsteps + 1, cache.u, cache.fu, cache.J, + @~(cache.u.-cache.u_cache)) + check_and_update!(cache, cache.fu, cache.u, cache.u_cache) end -# TODO -function dogleg!(cache::TrustRegionCache{false}) - @unpack u_tmp, u_gauss_newton, u_cauchy, trust_r = cache - +function dogleg!(cache::TrustRegionCache{iip}) where {iip} # Take the full Gauss-Newton step if lies within the trust region. - if norm(u_gauss_newton) ≤ trust_r - cache.du = deepcopy(u_gauss_newton) + if cache.internalnorm(cache.u_gauss_newton) ≤ cache.trust_r + @bb copyto!(cache.du, cache.u_gauss_newton) return end - ## Take intersection of steepest descent direction and trust region if Cauchy point lies outside of trust region - l_grad = norm(cache.g) - d_cauchy = l_grad^3 / __lr_mul(Val{false}(), cache.H, _vec(cache.g)) # distance of the cauchy point from the current iterate - if d_cauchy > trust_r # cauchy point lies outside of trust region - cache.du = -(trust_r / l_grad) * cache.g # step to the end of the trust region + # Take intersection of steepest descent direction and trust region if Cauchy point lies + # outside of trust region + l_grad = cache.internalnorm(cache.Jᵀf) # length of the gradient + d_cauchy = l_grad^3 / __lr_mul(cache) + if d_cauchy ≥ cache.trust_r + # step to the end of the trust region + @bb @. cache.du = -(cache.trust_r / l_grad) * cache.Jᵀf return end - # Take the intersection of dogleg with trust region if Cauchy point lies inside the trust region - u_cauchy = -(d_cauchy / l_grad) * cache.g # compute Cauchy point - u_tmp = u_gauss_newton - u_cauchy # calf of the dogleg - a = dot(u_tmp, u_tmp) - b = 2 * dot(u_cauchy, u_tmp) - c = d_cauchy^2 - trust_r^2 - aux = max(b^2 - 4 * a * c, 0.0) # technically guaranteed to be non-negative but hedging against floating point issues - τ = (-b + sqrt(aux)) / (2 * a) # stepsize along dogleg to trust region boundary - - cache.du = u_cauchy + τ * u_tmp + # Take the intersection of dogleg with trust region if Cauchy point lies inside the + # trust region + @bb @. cache.u_cauchy = -(d_cauchy / l_grad) * cache.Jᵀf # compute Cauchy point + @bb @. cache.u_cache_2 = cache.u_gauss_newton - cache.u_cauchy # calf of the dogleg + + a = dot(cache.u_cache_2, cache.u_cache_2) + b = 2 * dot(cache.u_cauchy, cache.u_cache_2) + c = d_cauchy^2 - cache.trust_r^2 + # technically guaranteed to be non-negative but hedging against floating point issues + aux = max(b^2 - 4 * a * c, 0) + # stepsize along dogleg to trust region boundary + τ = (-b + sqrt(aux)) / (2 * a) + + @bb @. cache.du = cache.u_cauchy + τ * cache.u_cache_2 + return end -function __take_step!(cache::TrustRegionCache) +function take_step!(cache::TrustRegionCache) @bb copyto!(cache.u_cache, cache.u) - @bb copyto!(cache.u, cache.u_cache_2) # u_tmp --> u_cache_2 + @bb copyto!(cache.u, cache.u_cache_2) @bb copyto!(cache.fu_cache, cache.fu) - @bb copyto!(cache.fu, cache.fu_cache_2) # fu_new --> fu_cache_2 -end - -# TODO -function jvp!(cache::TrustRegionCache{false}) - @unpack f, u, fu, uf = cache - if isa(u, Number) - return value_derivative(uf, u) - end - return auto_jacvec(uf, u, fu) -end - -function jvp!(cache::TrustRegionCache{true}) - @unpack g, f, u, fu, uf = cache - if isa(u, Number) - return value_derivative(uf, u) - end - auto_jacvec!(g, uf, u, fu) - return g + @bb copyto!(cache.fu, cache.fu_cache_2) end function not_terminated(cache::TrustRegionCache) @@ -670,8 +547,9 @@ function not_terminated(cache::TrustRegionCache) return true end +# FIXME: Update the JacVec Operator for Yuan function __reinit_internal!(cache::TrustRegionCache; kwargs...) - cache.loss = __get_trust_region_loss(cache.fu) + cache.loss = __trust_region_loss(cache, cache.fu) cache.shrink_counter = 0 cache.trust_r = convert(eltype(cache.u), ifelse(cache.alg.initial_trust_radius == 0, cache.alg.initial_trust_radius, @@ -680,4 +558,13 @@ function __reinit_internal!(cache::TrustRegionCache; kwargs...) return nothing end -__get_trust_region_loss(fu) = norm(fu)^2 / 2 +# This only holds for 2-norm? +__trust_region_loss(cache::TrustRegionCache, x) = __trust_region_loss(cache.internalnorm, x) +__trust_region_loss(nf::F, x) where {F} = nf(x)^2 / 2 + +# R-function for adaptive trust region method +function __rfunc(r::R, c2::R, M::R, γ1::R, γ2::R, β::R) where {R <: Real} + return ifelse(r ≥ c2, + (2 * (M - 1 - γ2) * atan(r - c2) + (1 + γ2)) / R(π), + (1 - γ1 - β) * (exp(r - c2) + β / (1 - γ1 - β))) +end \ No newline at end of file diff --git a/src/utils.jl b/src/utils.jl index e19771ef7..4d8496015 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -151,14 +151,6 @@ function wrapprecs(_Pl, _Pr, weight) return Pl, Pr end -function rfunc(r::R, c2::R, M::R, γ1::R, γ2::R, β::R) where {R <: Real} # R-function for adaptive trust region method - if (r ≥ c2) - return (2 * (M - 1 - γ2) * atan(r - c2) + (1 + γ2)) / π - else - return (1 - γ1 - β) * (exp(r - c2) + β / (1 - γ1 - β)) - end -end - concrete_jac(_) = nothing concrete_jac(::AbstractNewtonAlgorithm{CJ}) where {CJ} = CJ