From 9bc8f5bacb706ee0c3ef9382e2270b1ec5a791db Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 29 Nov 2023 21:45:54 -0500 Subject: [PATCH] Reuse more code in Broyden --- src/NonlinearSolve.jl | 2 +- src/broyden.jl | 95 ++++++++++++++----------------------------- src/raphson.jl | 2 +- src/utils.jl | 38 ++++++++++++----- 4 files changed, 60 insertions(+), 77 deletions(-) diff --git a/src/NonlinearSolve.jl b/src/NonlinearSolve.jl index f050bf007..f1782b8c1 100644 --- a/src/NonlinearSolve.jl +++ b/src/NonlinearSolve.jl @@ -25,7 +25,7 @@ import PrecompileTools: @recompile_invalidations, @compile_workload, @setup_work AbstractVectorOfArray, recursivecopy!, recursivefill! import SciMLBase: AbstractNonlinearAlgorithm, NLStats, _unwrap_val, has_jac, isinplace import SciMLOperators: FunctionOperator - import StaticArraysCore: StaticArray, SVector, SArray, MArray, Size, SMatrix + import StaticArraysCore: StaticArray, SVector, SArray, MArray, Size, SMatrix, MMatrix import UnPack: @unpack using ADTypes, LineSearches, SciMLBase, SimpleNonlinearSolve diff --git a/src/broyden.jl b/src/broyden.jl index 008ff589d..e0b69f19c 100644 --- a/src/broyden.jl +++ b/src/broyden.jl @@ -65,81 +65,46 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::GeneralBroyde termination_condition = nothing, internalnorm::F = DEFAULT_NORM, kwargs...) where {uType, iip, F} @unpack f, u0, p = prob - u = alias_u0 ? u0 : deepcopy(u0) + u = __maybe_unaliased(u0, alias_u0) fu = evaluate_f(prob, u) - du = _mutable_zero(u) + @bb du = copy(u) J⁻¹ = __init_identity_jacobian(u, fu) reset_tolerance = alg.reset_tolerance === nothing ? sqrt(eps(real(eltype(u)))) : alg.reset_tolerance reset_check = x -> abs(x) ≤ reset_tolerance + @bb u_prev = copy(u) + @bb fu2 = copy(fu) + @bb dfu = similar(fu) + @bb J⁻¹₂ = similar(u) + @bb J⁻¹df = similar(u) + abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, fu, u, termination_condition) trace = init_nonlinearsolve_trace(alg, u, fu, J⁻¹, du; uses_jac_inverse = Val(true), kwargs...) - return GeneralBroydenCache{iip}(f, alg, u, zero(u), du, fu, zero(fu), - zero(fu), p, J⁻¹, zero(_reshape(fu, 1, :)), _mutable_zero(u), false, 0, - alg.max_resets, maxiters, internalnorm, ReturnCode.Default, abstol, reltol, - reset_tolerance, reset_check, prob, NLStats(1, 0, 0, 0, 0), + return GeneralBroydenCache{iip}(f, alg, u, u_prev, du, fu, fu2, dfu, p, J⁻¹, + J⁻¹₂, J⁻¹df, false, 0, alg.max_resets, maxiters, internalnorm, ReturnCode.Default, + abstol, reltol, reset_tolerance, reset_check, prob, NLStats(1, 0, 0, 0, 0), init_linesearch_cache(alg.linesearch, f, u, p, fu, Val(iip)), tc_cache, trace) end -function perform_step!(cache::GeneralBroydenCache{true}) - @unpack f, p, du, fu, fu2, dfu, u, u_prev, J⁻¹, J⁻¹df, J⁻¹₂ = cache - T = eltype(u) - - mul!(_vec(du), J⁻¹, _vec(fu)) - α = perform_linesearch!(cache.ls_cache, u, du) - _axpy!(-α, du, u) - f(fu2, u, p) - - update_trace_with_invJ!(cache.trace, cache.stats.nsteps + 1, get_u(cache), - get_fu(cache), J⁻¹, du, α) - - check_and_update!(cache, fu2, u, u_prev) - cache.stats.nf += 1 - - cache.force_stop && return nothing +function perform_step!(cache::GeneralBroydenCache{iip}) where {iip} + T = eltype(cache.u) - # Update the inverse jacobian - dfu .= fu2 .- fu + @bb cache.du = cache.J⁻¹ × vec(cache.fu) + α = perform_linesearch!(cache.ls_cache, cache.u, cache.du) + @bb axpy!(-α, cache.du, cache.u) - if all(cache.reset_check, du) || all(cache.reset_check, dfu) - if cache.resets ≥ cache.max_resets - cache.retcode = ReturnCode.ConvergenceFailure - cache.force_stop = true - return nothing - end - fill!(J⁻¹, 0) - J⁻¹[diagind(J⁻¹)] .= T(1) - cache.resets += 1 + if iip + cache.f(cache.fu2, cache.u, cache.p) else - du .*= -1 - mul!(_vec(J⁻¹df), J⁻¹, _vec(dfu)) - mul!(J⁻¹₂, _vec(du)', J⁻¹) - denom = dot(du, J⁻¹df) - du .= (du .- J⁻¹df) ./ ifelse(iszero(denom), T(1e-5), denom) - mul!(J⁻¹, _vec(du), J⁻¹₂, 1, 1) + cache.fu2 = cache.f(cache.u, cache.p) end - fu .= fu2 - @. u_prev = u - - return nothing -end - -function perform_step!(cache::GeneralBroydenCache{false}) - @unpack f, p = cache - - T = eltype(cache.u) - - cache.du = _restructure(cache.du, cache.J⁻¹ * _vec(cache.fu)) - α = perform_linesearch!(cache.ls_cache, cache.u, cache.du) - cache.u = cache.u .- α * cache.du - cache.fu2 = f(cache.u, p) update_trace_with_invJ!(cache.trace, cache.stats.nsteps + 1, get_u(cache), - get_fu(cache), cache.J⁻¹, cache.du, α) + cache.fu2, cache.J⁻¹, cache.du, α) check_and_update!(cache, cache.fu2, cache.u, cache.u_prev) cache.stats.nf += 1 @@ -147,25 +112,27 @@ function perform_step!(cache::GeneralBroydenCache{false}) cache.force_stop && return nothing # Update the inverse jacobian - cache.dfu = cache.fu2 .- cache.fu + @bb @. cache.dfu = cache.fu2 - cache.fu + if all(cache.reset_check, cache.du) || all(cache.reset_check, cache.dfu) if cache.resets ≥ cache.max_resets cache.retcode = ReturnCode.ConvergenceFailure cache.force_stop = true return nothing end - cache.J⁻¹ = __init_identity_jacobian(cache.u, cache.fu) + cache.J⁻¹ = __reinit_identity_jacobian!!(cache.J⁻¹) cache.resets += 1 else - cache.du = -cache.du - cache.J⁻¹df = _restructure(cache.J⁻¹df, cache.J⁻¹ * _vec(cache.dfu)) - cache.J⁻¹₂ = _vec(cache.du)' * cache.J⁻¹ + @bb cache.du .*= -1 + @bb cache.J⁻¹df = cache.J⁻¹ × vec(cache.dfu) + @bb cache.J⁻¹₂ = cache.J⁻¹ × vec(cache.du) denom = dot(cache.du, cache.J⁻¹df) - cache.du = (cache.du .- cache.J⁻¹df) ./ ifelse(iszero(denom), T(1e-5), denom) - cache.J⁻¹ = cache.J⁻¹ .+ _vec(cache.du) * cache.J⁻¹₂ + @bb @. cache.du = (cache.du - cache.J⁻¹df) / ifelse(iszero(denom), T(1e-5), denom) + @bb cache.J⁻¹ += vec(cache.du) × transpose(cache.J⁻¹₂) end - cache.fu = cache.fu2 - cache.u_prev = @. cache.u + + @bb copyto!(cache.fu, cache.fu2) + @bb copyto!(cache.u_prev, cache.u) return nothing end diff --git a/src/raphson.jl b/src/raphson.jl index 4c4125579..52e47ac01 100644 --- a/src/raphson.jl +++ b/src/raphson.jl @@ -114,7 +114,7 @@ function perform_step!(cache::NewtonRaphsonCache{iip}) where {iip} α = perform_linesearch!(cache.ls_cache, cache.u, cache.du) @bb axpy!(-α, cache.du, cache.u) - evaluate_f(cache, cache.u) + evaluate_f(cache, cache.u, cache.p) update_trace!(cache.trace, cache.stats.nsteps + 1, get_u(cache), get_fu(cache), cache.J, cache.du, α) diff --git a/src/utils.jl b/src/utils.jl index d3017d42f..ab2db093f 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -188,12 +188,11 @@ function evaluate_f(prob::Union{NonlinearProblem{uType, iip}, return fu end -function evaluate_f(cache, u) - @unpack f, p = cache.prob +function evaluate_f(cache, u, p) if isinplace(cache) - f(get_fu(cache), u, p) + cache.prob.f(get_fu(cache), u, p) else - set_fu!(cache, f(u, p)) + set_fu!(cache, cache.prob.f(u, p)) end return nothing end @@ -301,14 +300,31 @@ function check_and_update!(tc_cache, cache, fu, u, uprev, end end -__init_identity_jacobian(u::Number, _) = u -function __init_identity_jacobian(u, fu) - return convert(parameterless_type(_mutable(u)), - Matrix{eltype(u)}(I, length(fu), length(u))) +@inline __init_identity_jacobian(u::Number, _) = one(u) +@inline function __init_identity_jacobian(u, fu) + J = similar(fu, promote_type(eltype(fu), eltype(u)), length(fu), length(u)) + fill!(J, zero(eltype(J))) + J[diagind(J)] .= one(eltype(J)) + return J end -function __init_identity_jacobian(u::StaticArray, fu) - return convert(MArray{Tuple{length(fu), length(u)}}, - Matrix{eltype(u)}(I, length(fu), length(u))) +@inline function __init_identity_jacobian(u::StaticArray, fu::StaticArray) + T = promote_type(eltype(fu), eltype(u)) + return MArray{Tuple{prod(Size(fu)), prod(Size(u))}, T}(I) +end +@inline function __init_identity_jacobian(u::SArray, fu::SArray) + T = promote_type(eltype(fu), eltype(u)) + return SArray{Tuple{prod(Size(fu)), prod(Size(u))}, T}(I) +end + +@inline __reinit_identity_jacobian!!(J::Number) = one(J) +@inline function __reinit_identity_jacobian!!(J::AbstractMatrix) + fill!(J, zero(eltype(J))) + J[diagind(J)] .= one(eltype(J)) + return J +end +@inline function __reinit_identity_jacobian!!(J::SMatrix) + S = Size(J) + return SArray{Tuple{S[1], S[2]}, eltype(J)}(I) end function __init_low_rank_jacobian(u::StaticArray, fu, threshold::Int)