Skip to content

Commit

Permalink
standardize the cache
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Dec 1, 2023
1 parent 96d1ac8 commit e8397a9
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 31 deletions.
7 changes: 5 additions & 2 deletions src/jacobian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -139,11 +139,14 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u::Number,
kwargs...) where {needsJᵀJ, F}
# NOTE: Scalar `u` assumes scalar output from `f`
uf = SciMLBase.JacobianWrapper{false}(f, p)
needsJᵀJ && return uf, nothing, u, nothing, nothing, u, u, u
return uf, FakeLinearSolveJLCache(u, u), u, nothing, nothing, u
return uf, FakeLinearSolveJLCache(u, u), u, nothing, nothing, u, u, u
end

# Linear Solve Cache
function linsolve_caches(A::Number, b, u, p, alg; linsolve_kwargs = (;))
return FakeLinearSolveJLCache(A, b)
end

function linsolve_caches(A, b, u, p, alg; linsolve_kwargs = (;))
if alg.linsolve === nothing && A isa SMatrix && linsolve_kwargs === (;)
# Default handling for SArrays in LinearSolve is not great. Some parts are patched
Expand Down
57 changes: 28 additions & 29 deletions src/trustRegion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -186,15 +186,21 @@ end
AbstractNonlinearSolveCache{iip}
f
alg
u_prev
u
fu_prev
u_cache
u_cache_2
fu
fu2
fu_cache
fu_cache_2
du
u_gauss_newton
u_cauchy
J
JᵀJ
Jᵀf
p
uf
linsolve
J
jac_cache
force_stop::Bool
maxiters::Int
Expand All @@ -213,14 +219,7 @@ end
expand_factor::trustType
loss::floatType
loss_new::floatType
H
g
shrink_counter::Int
du
u_tmp
u_gauss_newton
u_cauchy
fu_new
make_new_J::Bool
r::floatType
p1::floatType
Expand All @@ -240,24 +239,22 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::TrustRegion,
linsolve_kwargs = (;), kwargs...) where {uType, iip}
alg = get_concrete_algorithm(alg_, prob)
@unpack f, u0, p = prob
u = alias_u0 ? u0 : deepcopy(u0)
u_prev = zero(u)
fu1 = evaluate_f(prob, u)
fu_prev = zero(fu1)
u = __mabe_unaliased(u0, alias_u0)

Check warning on line 242 in src/trustRegion.jl

View workflow job for this annotation

GitHub Actions / Spell Check with Typos

"mabe" should be "maybe".
@bb u_cache = similar(u)
@bb u_cache_2 = similar(u)
fu = evaluate_f(prob, u)
@bb fu_cache_2 = similar(fu)

loss = __get_trust_region_loss(fu1)
uf, _, J, fu2, jac_cache, du, H, g = jacobian_caches(alg, f, u, p, Val(iip);
linsolve_kwargs, linsolve_with_JᵀJ = Val(true), lininit = Val(false))
g = _restructure(fu1, g)
linsolve = u isa Number ? nothing : linsolve_caches(J, fu2, du, p, alg)
loss = __get_trust_region_loss(fu)
uf, _, J, fu_cache, jac_cache, du, JᵀJ, Jᵀf = jacobian_caches(alg, f, u, p, Val(iip);
linsolve_with_JᵀJ = Val(true), lininit = Val(false))
linsolve = linsolve_caches(J, fu, du, p, alg)

u_tmp = zero(u)
u_cauchy = zero(u)
u_gauss_newton = _mutable_zero(u)
@bb u_cauchy = similar(u)
@bb u_gauss_newton = similar(u)

loss_new = loss
shrink_counter = 0
fu_new = zero(fu1)
make_new_J = true
r = loss

Expand Down Expand Up @@ -342,12 +339,13 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::TrustRegion,
termination_condition)
trace = init_nonlinearsolve_trace(alg, u, fu1, ApplyArray(__zero, J), du; kwargs...)

return TrustRegionCache{iip}(f, alg, u_prev, u, fu_prev, fu1, fu2, p, uf, linsolve, J,
jac_cache, false, maxiters, internalnorm, ReturnCode.Default, abstol, reltol, prob,
return TrustRegionCache{iip}(f, alg, u, u_cache, u_cache_2, fu, fu_cache, fu_cache_2,
du, u_gauss_newton, u_cauchy, J, JᵀJ, Jᵀf, p, uf, linsolve, jac_cache, false,
maxiters, internalnorm, ReturnCode.Default, abstol, reltol, prob,
radius_update_scheme, initial_trust_radius, max_trust_radius, step_threshold,
shrink_threshold, expand_threshold, shrink_factor, expand_factor, loss, loss_new,
H, g, shrink_counter, du, u_tmp, u_gauss_newton, u_cauchy, fu_new, make_new_J, r,
p1, p2, p3, p4, ϵ, NLStats(1, 0, 0, 0, 0), tc_cache, trace)
shrink_counter, make_new_J, r, p1, p2, p3, p4, ϵ, NLStats(1, 0, 0, 0, 0), tc_cache,
trace)
end

function perform_step!(cache::TrustRegionCache{iip}) where {iip}
Expand Down Expand Up @@ -458,7 +456,8 @@ function trust_region_step!(cache::TrustRegionCache)
cache.shrink_counter += 1
else
cache.shrink_counter = 0
if r cache.expand_threshold && 2 * cache.internalnorm(cache.du) > cache.trust_r
if r cache.expand_threshold &&
2 * cache.internalnorm(cache.du) > cache.trust_r
cache.p1 = cache.p3 * cache.p1
end
end
Expand Down

0 comments on commit e8397a9

Please sign in to comment.