Skip to content

Commit

Permalink
Reuse more code in Broyden
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Nov 30, 2023
1 parent 74c2ad7 commit 9bc8f5b
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 77 deletions.
2 changes: 1 addition & 1 deletion src/NonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import PrecompileTools: @recompile_invalidations, @compile_workload, @setup_work
AbstractVectorOfArray, recursivecopy!, recursivefill!
import SciMLBase: AbstractNonlinearAlgorithm, NLStats, _unwrap_val, has_jac, isinplace
import SciMLOperators: FunctionOperator
import StaticArraysCore: StaticArray, SVector, SArray, MArray, Size, SMatrix
import StaticArraysCore: StaticArray, SVector, SArray, MArray, Size, SMatrix, MMatrix
import UnPack: @unpack

using ADTypes, LineSearches, SciMLBase, SimpleNonlinearSolve
Expand Down
95 changes: 31 additions & 64 deletions src/broyden.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,107 +65,74 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::GeneralBroyde
termination_condition = nothing, internalnorm::F = DEFAULT_NORM,
kwargs...) where {uType, iip, F}
@unpack f, u0, p = prob
u = alias_u0 ? u0 : deepcopy(u0)
u = __maybe_unaliased(u0, alias_u0)
fu = evaluate_f(prob, u)
du = _mutable_zero(u)
@bb du = copy(u)
J⁻¹ = __init_identity_jacobian(u, fu)
reset_tolerance = alg.reset_tolerance === nothing ? sqrt(eps(real(eltype(u)))) :
alg.reset_tolerance
reset_check = x -> abs(x) ≤ reset_tolerance

@bb u_prev = copy(u)
@bb fu2 = copy(fu)
@bb dfu = similar(fu)
@bb J⁻¹₂ = similar(u)
@bb J⁻¹df = similar(u)

abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, fu, u,
termination_condition)
trace = init_nonlinearsolve_trace(alg, u, fu, J⁻¹, du; uses_jac_inverse = Val(true),
kwargs...)

return GeneralBroydenCache{iip}(f, alg, u, zero(u), du, fu, zero(fu),
zero(fu), p, J⁻¹, zero(_reshape(fu, 1, :)), _mutable_zero(u), false, 0,
alg.max_resets, maxiters, internalnorm, ReturnCode.Default, abstol, reltol,
reset_tolerance, reset_check, prob, NLStats(1, 0, 0, 0, 0),
return GeneralBroydenCache{iip}(f, alg, u, u_prev, du, fu, fu2, dfu, p, J⁻¹,
J⁻¹₂, J⁻¹df, false, 0, alg.max_resets, maxiters, internalnorm, ReturnCode.Default,
abstol, reltol, reset_tolerance, reset_check, prob, NLStats(1, 0, 0, 0, 0),
init_linesearch_cache(alg.linesearch, f, u, p, fu, Val(iip)), tc_cache, trace)
end

function perform_step!(cache::GeneralBroydenCache{true})
@unpack f, p, du, fu, fu2, dfu, u, u_prev, J⁻¹, J⁻¹df, J⁻¹₂ = cache
T = eltype(u)

mul!(_vec(du), J⁻¹, _vec(fu))
α = perform_linesearch!(cache.ls_cache, u, du)
_axpy!(-α, du, u)
f(fu2, u, p)

update_trace_with_invJ!(cache.trace, cache.stats.nsteps + 1, get_u(cache),
get_fu(cache), J⁻¹, du, α)

check_and_update!(cache, fu2, u, u_prev)
cache.stats.nf += 1

cache.force_stop && return nothing
function perform_step!(cache::GeneralBroydenCache{iip}) where {iip}
T = eltype(cache.u)

# Update the inverse jacobian
dfu .= fu2 .- fu
@bb cache.du = cache.J⁻¹ × vec(cache.fu)
α = perform_linesearch!(cache.ls_cache, cache.u, cache.du)
@bb axpy!(-α, cache.du, cache.u)

if all(cache.reset_check, du) || all(cache.reset_check, dfu)
if cache.resets ≥ cache.max_resets
cache.retcode = ReturnCode.ConvergenceFailure
cache.force_stop = true
return nothing
end
fill!(J⁻¹, 0)
J⁻¹[diagind(J⁻¹)] .= T(1)
cache.resets += 1
if iip
cache.f(cache.fu2, cache.u, cache.p)
else
du .*= -1
mul!(_vec(J⁻¹df), J⁻¹, _vec(dfu))
mul!(J⁻¹₂, _vec(du)', J⁻¹)
denom = dot(du, J⁻¹df)
du .= (du .- J⁻¹df) ./ ifelse(iszero(denom), T(1e-5), denom)
mul!(J⁻¹, _vec(du), J⁻¹₂, 1, 1)
cache.fu2 = cache.f(cache.u, cache.p)
end
fu .= fu2
@. u_prev = u

return nothing
end

function perform_step!(cache::GeneralBroydenCache{false})
@unpack f, p = cache

T = eltype(cache.u)

cache.du = _restructure(cache.du, cache.J⁻¹ * _vec(cache.fu))
α = perform_linesearch!(cache.ls_cache, cache.u, cache.du)
cache.u = cache.u .- α * cache.du
cache.fu2 = f(cache.u, p)

update_trace_with_invJ!(cache.trace, cache.stats.nsteps + 1, get_u(cache),
get_fu(cache), cache.J⁻¹, cache.du, α)
cache.fu2, cache.J⁻¹, cache.du, α)

check_and_update!(cache, cache.fu2, cache.u, cache.u_prev)
cache.stats.nf += 1

cache.force_stop && return nothing

# Update the inverse jacobian
cache.dfu = cache.fu2 .- cache.fu
@bb @. cache.dfu = cache.fu2 - cache.fu

if all(cache.reset_check, cache.du) || all(cache.reset_check, cache.dfu)
if cache.resets ≥ cache.max_resets
cache.retcode = ReturnCode.ConvergenceFailure
cache.force_stop = true
return nothing
end
cache.J⁻¹ = __init_identity_jacobian(cache.u, cache.fu)
cache.J⁻¹ = __reinit_identity_jacobian!!(cache.J⁻¹)
cache.resets += 1
else
cache.du = -cache.du
cache.J⁻¹df = _restructure(cache.J⁻¹df, cache.J⁻¹ * _vec(cache.dfu))
cache.J⁻¹₂ = _vec(cache.du)' * cache.J⁻¹
@bb cache.du .*= -1
@bb cache.J⁻¹df = cache.J⁻¹ × vec(cache.dfu)
@bb cache.J⁻¹₂ = cache.J⁻¹ × vec(cache.du)
denom = dot(cache.du, cache.J⁻¹df)
cache.du = (cache.du .- cache.J⁻¹df) ./ ifelse(iszero(denom), T(1e-5), denom)
cache.J⁻¹ = cache.J⁻¹ .+ _vec(cache.du) * cache.J⁻¹₂
@bb @. cache.du = (cache.du - cache.J⁻¹df) / ifelse(iszero(denom), T(1e-5), denom)
@bb cache.J⁻¹ += vec(cache.du) × transpose(cache.J⁻¹₂)
end
cache.fu = cache.fu2
cache.u_prev = @. cache.u

@bb copyto!(cache.fu, cache.fu2)
@bb copyto!(cache.u_prev, cache.u)

return nothing
end
Expand Down
2 changes: 1 addition & 1 deletion src/raphson.jl
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ function perform_step!(cache::NewtonRaphsonCache{iip}) where {iip}
α = perform_linesearch!(cache.ls_cache, cache.u, cache.du)
@bb axpy!(-α, cache.du, cache.u)

evaluate_f(cache, cache.u)
evaluate_f(cache, cache.u, cache.p)

update_trace!(cache.trace, cache.stats.nsteps + 1, get_u(cache), get_fu(cache), cache.J,
cache.du, α)
Expand Down
38 changes: 27 additions & 11 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -188,12 +188,11 @@ function evaluate_f(prob::Union{NonlinearProblem{uType, iip},
return fu
end

function evaluate_f(cache, u)
@unpack f, p = cache.prob
function evaluate_f(cache, u, p)
if isinplace(cache)
f(get_fu(cache), u, p)
cache.prob.f(get_fu(cache), u, p)
else
set_fu!(cache, f(u, p))
set_fu!(cache, cache.prob.f(u, p))
end
return nothing
end
Expand Down Expand Up @@ -301,14 +300,31 @@ function check_and_update!(tc_cache, cache, fu, u, uprev,
end
end

__init_identity_jacobian(u::Number, _) = u
function __init_identity_jacobian(u, fu)
return convert(parameterless_type(_mutable(u)),
Matrix{eltype(u)}(I, length(fu), length(u)))
# Scalar case: the "identity" for a number `u` is the multiplicative identity of
# its own type, so the result is `one(u)`. The second argument is ignored.
@inline function __init_identity_jacobian(u::Number, _)
    return one(u)
end
# Generic array case: allocate a `length(fu) × length(u)` array via `similar(fu, ...)`
# whose element type is the promotion of both inputs' element types, zero it, and
# write ones along the main diagonal (the indices returned by `diagind`).
@inline function __init_identity_jacobian(u, fu)
    nrows, ncols = length(fu), length(u)
    J = similar(fu, promote_type(eltype(fu), eltype(u)), nrows, ncols)
    ElT = eltype(J)
    fill!(J, zero(ElT))
    # `view(...) .=` is what `J[diagind(J)] .= x` lowers to; it assigns in place.
    view(J, diagind(J)) .= one(ElT)
    return J
end
function __init_identity_jacobian(u::StaticArray, fu)
return convert(MArray{Tuple{length(fu), length(u)}},
Matrix{eltype(u)}(I, length(fu), length(u)))
# Static-array case: build an `MArray` from `I` with `prod(Size(fu))` rows and
# `prod(Size(u))` columns. The element type is the promotion of both inputs'
# element types.
@inline function __init_identity_jacobian(u::StaticArray, fu::StaticArray)
    ElT = promote_type(eltype(fu), eltype(u))
    Dims = Tuple{prod(Size(fu)), prod(Size(u))}
    return MArray{Dims, ElT}(I)
end
# `SArray` inputs: same layout as the general static-array method
# (`prod(Size(fu))` rows by `prod(Size(u))` columns, promoted element type),
# but the result is an `SArray` constructed from `I`.
@inline function __init_identity_jacobian(u::SArray, fu::SArray)
    ElT = promote_type(eltype(fu), eltype(u))
    Dims = Tuple{prod(Size(fu)), prod(Size(u))}
    return SArray{Dims, ElT}(I)
end

# Scalar case: nothing to mutate, so return the multiplicative identity of `J`'s type.
@inline function __reinit_identity_jacobian!!(J::Number)
    return one(J)
end
# Matrix case: overwrite `J` in place with zeros, then set the entries on its
# main diagonal (indices from `diagind`) to one. The same array object is returned.
@inline function __reinit_identity_jacobian!!(J::AbstractMatrix)
    ElT = eltype(J)
    fill!(J, zero(ElT))
    # `view(...) .=` is what `J[diagind(J)] .= x` lowers to; it assigns in place.
    view(J, diagind(J)) .= one(ElT)
    return J
end
# `SMatrix` case: rather than writing into `J`, construct and return a new
# `SArray` from `I` with the same two dimensions (read from `Size(J)`) and
# element type.
@inline function __reinit_identity_jacobian!!(J::SMatrix)
    dims = Size(J)
    nrows, ncols = dims[1], dims[2]
    return SArray{Tuple{nrows, ncols}, eltype(J)}(I)
end

function __init_low_rank_jacobian(u::StaticArray, fu, threshold::Int)
Expand Down

0 comments on commit 9bc8f5b

Please sign in to comment.