@@ -186,15 +186,21 @@ end
186
186
AbstractNonlinearSolveCache{iip}
187
187
f
188
188
alg
189
- u_prev
190
189
u
191
- fu_prev
190
+ u_cache
191
+ u_cache_2
192
192
fu
193
- fu2
193
+ fu_cache
194
+ fu_cache_2
195
+ du
196
+ u_gauss_newton
197
+ u_cauchy
198
+ J
199
+ JᵀJ
200
+ Jᵀf
194
201
p
195
202
uf
196
203
linsolve
197
- J
198
204
jac_cache
199
205
force_stop:: Bool
200
206
maxiters:: Int
213
219
expand_factor:: trustType
214
220
loss:: floatType
215
221
loss_new:: floatType
216
- H
217
- g
218
222
shrink_counter:: Int
219
- du
220
- u_tmp
221
- u_gauss_newton
222
- u_cauchy
223
- fu_new
224
223
make_new_J:: Bool
225
224
r:: floatType
226
225
p1:: floatType
@@ -240,24 +239,22 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::TrustRegion,
240
239
linsolve_kwargs = (;), kwargs... ) where {uType, iip}
241
240
alg = get_concrete_algorithm (alg_, prob)
242
241
@unpack f, u0, p = prob
243
- u = alias_u0 ? u0 : deepcopy (u0)
244
- u_prev = zero (u)
245
- fu1 = evaluate_f (prob, u)
246
- fu_prev = zero (fu1)
242
+ u = __mabe_unaliased (u0, alias_u0)
243
+ @bb u_cache = similar (u)
244
+ @bb u_cache_2 = similar (u)
245
+ fu = evaluate_f (prob, u)
246
+ @bb fu_cache_2 = similar (fu)
247
247
248
- loss = __get_trust_region_loss (fu1)
249
- uf, _, J, fu2, jac_cache, du, H, g = jacobian_caches (alg, f, u, p, Val (iip);
250
- linsolve_kwargs, linsolve_with_JᵀJ = Val (true ), lininit = Val (false ))
251
- g = _restructure (fu1, g)
252
- linsolve = u isa Number ? nothing : linsolve_caches (J, fu2, du, p, alg)
248
+ loss = __get_trust_region_loss (fu)
249
+ uf, _, J, fu_cache, jac_cache, du, JᵀJ, Jᵀf = jacobian_caches (alg, f, u, p, Val (iip);
250
+ linsolve_with_JᵀJ = Val (true ), lininit = Val (false ))
251
+ linsolve = linsolve_caches (J, fu, du, p, alg)
253
252
254
- u_tmp = zero (u)
255
- u_cauchy = zero (u)
256
- u_gauss_newton = _mutable_zero (u)
253
+ @bb u_cauchy = similar (u)
254
+ @bb u_gauss_newton = similar (u)
257
255
258
256
loss_new = loss
259
257
shrink_counter = 0
260
- fu_new = zero (fu1)
261
258
make_new_J = true
262
259
r = loss
263
260
@@ -342,12 +339,13 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::TrustRegion,
342
339
termination_condition)
343
340
trace = init_nonlinearsolve_trace (alg, u, fu1, ApplyArray (__zero, J), du; kwargs... )
344
341
345
- return TrustRegionCache {iip} (f, alg, u_prev, u, fu_prev, fu1, fu2, p, uf, linsolve, J,
346
- jac_cache, false , maxiters, internalnorm, ReturnCode. Default, abstol, reltol, prob,
342
+ return TrustRegionCache {iip} (f, alg, u, u_cache, u_cache_2, fu, fu_cache, fu_cache_2,
343
+ du, u_gauss_newton, u_cauchy, J, JᵀJ, Jᵀf, p, uf, linsolve, jac_cache, false ,
344
+ maxiters, internalnorm, ReturnCode. Default, abstol, reltol, prob,
347
345
radius_update_scheme, initial_trust_radius, max_trust_radius, step_threshold,
348
346
shrink_threshold, expand_threshold, shrink_factor, expand_factor, loss, loss_new,
349
- H, g, shrink_counter, du, u_tmp, u_gauss_newton, u_cauchy, fu_new, make_new_J, r ,
350
- p1, p2, p3, p4, ϵ, NLStats ( 1 , 0 , 0 , 0 , 0 ), tc_cache, trace)
347
+ shrink_counter, make_new_J, r, p1, p2, p3, p4, ϵ, NLStats ( 1 , 0 , 0 , 0 , 0 ), tc_cache ,
348
+ trace)
351
349
end
352
350
353
351
function perform_step! (cache:: TrustRegionCache{iip} ) where {iip}
@@ -458,7 +456,8 @@ function trust_region_step!(cache::TrustRegionCache)
458
456
cache. shrink_counter += 1
459
457
else
460
458
cache. shrink_counter = 0
461
- if r ≥ cache. expand_threshold && 2 * cache. internalnorm (cache. du) > cache. trust_r
459
+ if r ≥ cache. expand_threshold &&
460
+ 2 * cache. internalnorm (cache. du) > cache. trust_r
462
461
cache. p1 = cache. p3 * cache. p1
463
462
end
464
463
end
0 commit comments