Skip to content

Commit

Permalink
Share reinit code
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Nov 30, 2023
1 parent 9bc8f5b commit a5c6195
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 55 deletions.
36 changes: 35 additions & 1 deletion src/NonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ import Reexport: @reexport
import PrecompileTools: @recompile_invalidations, @compile_workload, @setup_workload

@recompile_invalidations begin
using DiffEqBase, LazyArrays, LinearAlgebra, LinearSolve, Printf, SparseArrays,
using DiffEqBase,
LazyArrays, LinearAlgebra, LinearSolve, Printf, SparseArrays,
SparseDiffTools

import ADTypes: AbstractFiniteDifferencesMode
Expand Down Expand Up @@ -51,6 +52,39 @@ abstract type AbstractNonlinearSolveCache{iip} end

isinplace(::AbstractNonlinearSolveCache{iip}) where {iip} = iip

function SciMLBase.reinit!(cache::AbstractNonlinearSolveCache{iip}, u0 = get_u(cache);
p = cache.p, abstol = cache.abstol, reltol = cache.reltol,
maxiters = cache.maxiters, alias_u0 = false,
termination_condition = get_termination_mode(cache.tc_cache)) where {iip}
cache.p = p
if iip
recursivecopy!(get_u(cache), u0)
cache.f(cache.fu1, get_u(cache), p)
else
cache.u = __maybe_unaliased(u0, alias_u0)
set_fu!(cache, cache.f(cache.u, p))
end

reset!(cache.trace)
abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, get_fu(cache),
get_u(cache), termination_condition)

cache.abstol = abstol
cache.reltol = reltol
cache.tc_cache = tc_cache
cache.maxiters = maxiters
cache.stats.nf = 1
cache.stats.nsteps = 1
cache.force_stop = false
cache.retcode = ReturnCode.Default

__reinit_internal!(cache)

return cache
end

__reinit_internal!(cache::AbstractNonlinearSolveCache) = nothing

function Base.show(io::IO, alg::AbstractNonlinearSolveAlgorithm)
str = "$(nameof(typeof(alg)))("
modifiers = String[]
Expand Down
29 changes: 3 additions & 26 deletions src/broyden.jl
Original file line number Diff line number Diff line change
Expand Up @@ -137,31 +137,8 @@ function perform_step!(cache::GeneralBroydenCache{iip}) where {iip}
return nothing
end

function SciMLBase.reinit!(cache::GeneralBroydenCache{iip}, u0 = cache.u; p = cache.p,
abstol = cache.abstol, reltol = cache.reltol, maxiters = cache.maxiters,
termination_condition = get_termination_mode(cache.tc_cache)) where {iip}
cache.p = p
if iip
recursivecopy!(cache.u, u0)
cache.f(cache.fu, cache.u, p)
else
# don't have alias_u0 but cache.u is never mutated for OOP problems so it doesn't matter
cache.u = u0
cache.fu = cache.f(cache.u, p)
end

reset!(cache.trace)
abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, cache.fu, cache.u,
termination_condition)

cache.abstol = abstol
cache.reltol = reltol
cache.tc_cache = tc_cache
cache.maxiters = maxiters
cache.stats.nf = 1
cache.stats.nsteps = 1
function __reinit_internal!(cache::GeneralBroydenCache)
cache.J⁻¹ = __reinit_identity_jacobian!!(cache.J⁻¹)
cache.resets = 0
cache.force_stop = false
cache.retcode = ReturnCode.Default
return cache
return nothing
end
28 changes: 0 additions & 28 deletions src/raphson.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,31 +128,3 @@ function perform_step!(cache::NewtonRaphsonCache{iip}) where {iip}
cache.stats.nfactors += 1
return nothing
end

function SciMLBase.reinit!(cache::NewtonRaphsonCache{iip}, u0 = cache.u; p = cache.p,
abstol = cache.abstol, reltol = cache.reltol, maxiters = cache.maxiters,
termination_condition = get_termination_mode(cache.tc_cache)) where {iip}
cache.p = p
if iip
recursivecopy!(cache.u, u0)
cache.f(cache.fu1, cache.u, p)
else
# don't have alias_u0 but cache.u is never mutated for OOP problems so it doesn't matter
cache.u = u0
cache.fu1 = cache.f(cache.u, p)
end

reset!(cache.trace)
abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, cache.fu1, cache.u,
termination_condition)

cache.abstol = abstol
cache.reltol = reltol
cache.tc_cache = tc_cache
cache.maxiters = maxiters
cache.stats.nf = 1
cache.stats.nsteps = 1
cache.force_stop = false
cache.retcode = ReturnCode.Default
return cache
end

0 comments on commit a5c6195

Please sign in to comment.