Skip to content

Commit e8397a9

Browse files
committed
standardize the cache
1 parent 96d1ac8 commit e8397a9

File tree

2 files changed

+33
-31
lines changed

2 files changed

+33
-31
lines changed

src/jacobian.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,11 +139,14 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u::Number,
139139
kwargs...) where {needsJᵀJ, F}
140140
# NOTE: Scalar `u` assumes scalar output from `f`
141141
uf = SciMLBase.JacobianWrapper{false}(f, p)
142-
needsJᵀJ && return uf, nothing, u, nothing, nothing, u, u, u
143-
return uf, FakeLinearSolveJLCache(u, u), u, nothing, nothing, u
142+
return uf, FakeLinearSolveJLCache(u, u), u, nothing, nothing, u, u, u
144143
end
145144

146145
# Linear Solve Cache
146+
function linsolve_caches(A::Number, b, u, p, alg; linsolve_kwargs = (;))
147+
return FakeLinearSolveJLCache(A, b)
148+
end
149+
147150
function linsolve_caches(A, b, u, p, alg; linsolve_kwargs = (;))
148151
if alg.linsolve === nothing && A isa SMatrix && linsolve_kwargs === (;)
149152
# Default handling for SArrays in LinearSolve is not great. Some parts are patched

src/trustRegion.jl

Lines changed: 28 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -186,15 +186,21 @@ end
186186
AbstractNonlinearSolveCache{iip}
187187
f
188188
alg
189-
u_prev
190189
u
191-
fu_prev
190+
u_cache
191+
u_cache_2
192192
fu
193-
fu2
193+
fu_cache
194+
fu_cache_2
195+
du
196+
u_gauss_newton
197+
u_cauchy
198+
J
199+
JᵀJ
200+
Jᵀf
194201
p
195202
uf
196203
linsolve
197-
J
198204
jac_cache
199205
force_stop::Bool
200206
maxiters::Int
@@ -213,14 +219,7 @@ end
213219
expand_factor::trustType
214220
loss::floatType
215221
loss_new::floatType
216-
H
217-
g
218222
shrink_counter::Int
219-
du
220-
u_tmp
221-
u_gauss_newton
222-
u_cauchy
223-
fu_new
224223
make_new_J::Bool
225224
r::floatType
226225
p1::floatType
@@ -240,24 +239,22 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::TrustRegion,
240239
linsolve_kwargs = (;), kwargs...) where {uType, iip}
241240
alg = get_concrete_algorithm(alg_, prob)
242241
@unpack f, u0, p = prob
243-
u = alias_u0 ? u0 : deepcopy(u0)
244-
u_prev = zero(u)
245-
fu1 = evaluate_f(prob, u)
246-
fu_prev = zero(fu1)
242+
u = __mabe_unaliased(u0, alias_u0)
243+
@bb u_cache = similar(u)
244+
@bb u_cache_2 = similar(u)
245+
fu = evaluate_f(prob, u)
246+
@bb fu_cache_2 = similar(fu)
247247

248-
loss = __get_trust_region_loss(fu1)
249-
uf, _, J, fu2, jac_cache, du, H, g = jacobian_caches(alg, f, u, p, Val(iip);
250-
linsolve_kwargs, linsolve_with_JᵀJ = Val(true), lininit = Val(false))
251-
g = _restructure(fu1, g)
252-
linsolve = u isa Number ? nothing : linsolve_caches(J, fu2, du, p, alg)
248+
loss = __get_trust_region_loss(fu)
249+
uf, _, J, fu_cache, jac_cache, du, JᵀJ, Jᵀf = jacobian_caches(alg, f, u, p, Val(iip);
250+
linsolve_with_JᵀJ = Val(true), lininit = Val(false))
251+
linsolve = linsolve_caches(J, fu, du, p, alg)
253252

254-
u_tmp = zero(u)
255-
u_cauchy = zero(u)
256-
u_gauss_newton = _mutable_zero(u)
253+
@bb u_cauchy = similar(u)
254+
@bb u_gauss_newton = similar(u)
257255

258256
loss_new = loss
259257
shrink_counter = 0
260-
fu_new = zero(fu1)
261258
make_new_J = true
262259
r = loss
263260

@@ -342,12 +339,13 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::TrustRegion,
342339
termination_condition)
343340
trace = init_nonlinearsolve_trace(alg, u, fu1, ApplyArray(__zero, J), du; kwargs...)
344341

345-
return TrustRegionCache{iip}(f, alg, u_prev, u, fu_prev, fu1, fu2, p, uf, linsolve, J,
346-
jac_cache, false, maxiters, internalnorm, ReturnCode.Default, abstol, reltol, prob,
342+
return TrustRegionCache{iip}(f, alg, u, u_cache, u_cache_2, fu, fu_cache, fu_cache_2,
343+
du, u_gauss_newton, u_cauchy, J, JᵀJ, Jᵀf, p, uf, linsolve, jac_cache, false,
344+
maxiters, internalnorm, ReturnCode.Default, abstol, reltol, prob,
347345
radius_update_scheme, initial_trust_radius, max_trust_radius, step_threshold,
348346
shrink_threshold, expand_threshold, shrink_factor, expand_factor, loss, loss_new,
349-
H, g, shrink_counter, du, u_tmp, u_gauss_newton, u_cauchy, fu_new, make_new_J, r,
350-
p1, p2, p3, p4, ϵ, NLStats(1, 0, 0, 0, 0), tc_cache, trace)
347+
shrink_counter, make_new_J, r, p1, p2, p3, p4, ϵ, NLStats(1, 0, 0, 0, 0), tc_cache,
348+
trace)
351349
end
352350

353351
function perform_step!(cache::TrustRegionCache{iip}) where {iip}
@@ -458,7 +456,8 @@ function trust_region_step!(cache::TrustRegionCache)
458456
cache.shrink_counter += 1
459457
else
460458
cache.shrink_counter = 0
461-
if r cache.expand_threshold && 2 * cache.internalnorm(cache.du) > cache.trust_r
459+
if r cache.expand_threshold &&
460+
2 * cache.internalnorm(cache.du) > cache.trust_r
462461
cache.p1 = cache.p3 * cache.p1
463462
end
464463
end

0 commit comments

Comments
 (0)