1
1
const RelNormModes = Union{
2
- RelNormTerminationMode, RelNormSafeTerminationMode, RelNormSafeBestTerminationMode}
2
+ RelNormTerminationMode, RelNormSafeTerminationMode, RelNormSafeBestTerminationMode
3
+ }
3
4
const AbsNormModes = Union{
4
- AbsNormTerminationMode, AbsNormSafeTerminationMode, AbsNormSafeBestTerminationMode}
5
+ AbsNormTerminationMode, AbsNormSafeTerminationMode, AbsNormSafeBestTerminationMode
6
+ }
5
7
6
8
# Core Implementation
7
9
@concrete mutable struct NonlinearTerminationModeCache{uType, T}
32
34
33
35
function CommonSolve. init (
34
36
:: AbstractNonlinearProblem , mode:: AbstractNonlinearTerminationMode , du, u,
35
- saved_value_prototype... ; abstol = nothing , reltol = nothing , kwargs... )
37
+ saved_value_prototype... ; abstol = nothing , reltol = nothing , kwargs...
38
+ )
36
39
T = promote_type (eltype (du), eltype (u))
37
40
abstol = get_tolerance (u, abstol, T)
38
41
reltol = get_tolerance (u, reltol, T)
@@ -77,12 +80,14 @@ function CommonSolve.init(
77
80
return NonlinearTerminationModeCache (
78
81
u_unaliased, ReturnCode. Default, abstol, reltol, best_value, mode,
79
82
initial_objective, objectives_trace, 0 , saved_value_prototype,
80
- u0_norm, step_norm_trace, max_stalled_steps, u_diff_cache)
83
+ u0_norm, step_norm_trace, max_stalled_steps, u_diff_cache
84
+ )
81
85
end
82
86
83
87
function SciMLBase. reinit! (
84
88
cache:: NonlinearTerminationModeCache , du, u, saved_value_prototype... ;
85
- abstol = cache. abstol, reltol = cache. reltol, kwargs... )
89
+ abstol = cache. abstol, reltol = cache. reltol, kwargs...
90
+ )
86
91
T = eltype (cache. abstol)
87
92
length (saved_value_prototype) != 0 && (cache. saved_values = saved_value_prototype)
88
93
113
118
114
119
# # This dispatch is needed based on how Terminating Callback works!
115
120
function (cache:: NonlinearTerminationModeCache )(
116
- integrator:: AbstractODEIntegrator , abstol:: Number , reltol:: Number , min_t)
121
+ integrator:: AbstractODEIntegrator , abstol:: Number , reltol:: Number , min_t
122
+ )
117
123
if min_t === nothing || integrator. t ≥ min_t
118
124
return cache (cache. mode, SciMLBase. get_du (integrator),
119
125
integrator. u, integrator. uprev, abstol, reltol)
@@ -125,7 +131,8 @@ function (cache::NonlinearTerminationModeCache)(du, u, uprev, args...)
125
131
end
126
132
127
133
function (cache:: NonlinearTerminationModeCache )(
128
- mode:: AbstractNonlinearTerminationMode , du, u, uprev, abstol, reltol, args... )
134
+ mode:: AbstractNonlinearTerminationMode , du, u, uprev, abstol, reltol, args...
135
+ )
129
136
if check_convergence (mode, du, u, uprev, abstol, reltol)
130
137
cache. retcode = ReturnCode. Success
131
138
return true
@@ -134,7 +141,8 @@ function (cache::NonlinearTerminationModeCache)(
134
141
end
135
142
136
143
function (cache:: NonlinearTerminationModeCache )(
137
- mode:: AbstractSafeNonlinearTerminationMode , du, u, uprev, abstol, reltol, args... )
144
+ mode:: AbstractSafeNonlinearTerminationMode , du, u, uprev, abstol, reltol, args...
145
+ )
138
146
if mode isa AbsNormSafeTerminationMode || mode isa AbsNormSafeBestTerminationMode
139
147
objective = Utils. apply_norm (mode. internalnorm, du)
140
148
criteria = abstol
@@ -251,15 +259,17 @@ end
251
259
# High-Level API with defaults.
252
260
# # This is mostly for internal usage in NonlinearSolve and SimpleNonlinearSolve
253
261
function default_termination_mode (
254
- :: Union{ImmutableNonlinearProblem, NonlinearProblem} , :: Val{:simple} )
262
+ :: Union{ImmutableNonlinearProblem, NonlinearProblem} , :: Val{:simple}
263
+ )
255
264
return AbsNormTerminationMode (Base. Fix1 (maximum, abs))
256
265
end
257
266
function default_termination_mode (:: NonlinearLeastSquaresProblem , :: Val{:simple} )
258
267
return AbsNormTerminationMode (Base. Fix2 (norm, 2 ))
259
268
end
260
269
261
270
function default_termination_mode (
262
- :: Union{ImmutableNonlinearProblem, NonlinearProblem} , :: Val{:regular} )
271
+ :: Union{ImmutableNonlinearProblem, NonlinearProblem} , :: Val{:regular}
272
+ )
263
273
return AbsNormSafeBestTerminationMode (Base. Fix1 (maximum, abs); max_stalled_steps = 32 )
264
274
end
265
275
@@ -268,16 +278,53 @@ function default_termination_mode(::NonlinearLeastSquaresProblem, ::Val{:regular
268
278
end
269
279
270
280
function init_termination_cache (
271
- prob:: AbstractNonlinearProblem , abstol, reltol, du, u, :: Nothing , callee:: Val )
281
+ prob:: AbstractNonlinearProblem , abstol, reltol, du, u, :: Nothing , callee:: Val
282
+ )
272
283
return init_termination_cache (
273
284
prob, abstol, reltol, du, u, default_termination_mode (prob, callee), callee)
274
285
end
275
286
276
287
function init_termination_cache (prob:: AbstractNonlinearProblem , abstol, reltol, du,
277
- u, tc:: AbstractNonlinearTerminationMode , :: Val )
288
+ u, tc:: AbstractNonlinearTerminationMode , :: Val
289
+ )
278
290
T = promote_type (eltype (du), eltype (u))
279
291
abstol = get_tolerance (u, abstol, T)
280
292
reltol = get_tolerance (u, reltol, T)
281
293
cache = init (prob, tc, du, u; abstol, reltol)
282
294
return abstol, reltol, cache
283
295
end
296
+
297
+ function check_and_update! (cache, fu, u, uprev)
298
+ return check_and_update! (
299
+ cache. termination_cache, cache, fu, u, uprev, cache. termination_cache. mode
300
+ )
301
+ end
302
+
303
+ function check_and_update! (tc_cache, cache, fu, u, uprev, mode)
304
+ if tc_cache (fu, u, uprev)
305
+ cache. retcode = tc_cache. retcode
306
+ update_from_termination_cache! (tc_cache, cache, mode, u)
307
+ cache. force_stop = true
308
+ end
309
+ end
310
+
311
+ function update_from_termination_cache! (tc_cache, cache, u = get_u (cache))
312
+ return update_from_termination_cache! (tc_cache, cache, tc_cache. mode, u)
313
+ end
314
+
315
+ function update_from_termination_cache! (
316
+ tc_cache, cache, :: AbstractNonlinearTerminationMode , u = get_u (cache)
317
+ )
318
+ Utils. evaluate_f! (cache, u, cache. p)
319
+ end
320
+
321
+ function update_from_termination_cache! (
322
+ tc_cache, cache, :: AbstractSafeBestNonlinearTerminationMode , u = get_u (cache)
323
+ )
324
+ if SciMLBase. isinplace (cache)
325
+ copyto! (get_u (cache), tc_cache. u)
326
+ else
327
+ SciMLBase. set_u! (cache, tc_cache. u)
328
+ end
329
+ Utils. evaluate_f! (cache, get_u (cache), cache. p)
330
+ end
0 commit comments