Skip to content

Commit 9025e3e

Browse files
committed
refactor: implement internal caches function
1 parent 804c662 commit 9025e3e

File tree

20 files changed

+183
-318
lines changed

20 files changed

+183
-318
lines changed

lib/NonlinearSolveBase/Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
3131
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
3232
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
3333
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
34+
LineSearch = "87fe0de2-c867-4266-b59a-2f0a94fc965b"
3435
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
3536
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
3637
SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35"
@@ -39,6 +40,7 @@ SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35"
3940
NonlinearSolveBaseBandedMatricesExt = "BandedMatrices"
4041
NonlinearSolveBaseDiffEqBaseExt = "DiffEqBase"
4142
NonlinearSolveBaseForwardDiffExt = "ForwardDiff"
43+
NonlinearSolveBaseLineSearchExt = "LineSearch"
4244
NonlinearSolveBaseLinearSolveExt = "LinearSolve"
4345
NonlinearSolveBaseSparseArraysExt = "SparseArrays"
4446
NonlinearSolveBaseSparseMatrixColoringsExt = "SparseMatrixColorings"
@@ -60,6 +62,7 @@ FastClosures = "0.3"
6062
ForwardDiff = "0.10.36"
6163
FunctionProperties = "0.1.2"
6264
InteractiveUtils = "<0.0.1, 1"
65+
LineSearch = "0.1.4"
6366
LinearAlgebra = "1.10"
6467
LinearSolve = "2.36.1"
6568
Markdown = "1.10"
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
module NonlinearSolveBaseLineSearchExt
2+
3+
using LineSearch: LineSearch, AbstractLineSearchCache
4+
using NonlinearSolveBase: NonlinearSolveBase, InternalAPI
5+
using SciMLBase: SciMLBase
6+
7+
function NonlinearSolveBase.callback_into_cache!(
8+
topcache, cache::AbstractLineSearchCache, args...
9+
)
10+
return LineSearch.callback_into_cache!(cache, NonlinearSolveBase.get_fu(topcache))
11+
end
12+
13+
function InternalAPI.reinit!(cache::AbstractLineSearchCache; kwargs...)
14+
return SciMLBase.reinit!(cache; kwargs...)
15+
end
16+
17+
end

lib/NonlinearSolveBase/src/abstract_types.jl

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,31 @@
11
module InternalAPI
22

3+
using SciMLBase: NLStats
4+
35
function init end
46
function solve! end
5-
function reinit! end
67
function step! end
78

9+
function reinit! end
10+
function reinit_self! end
11+
12+
function reinit!(x::Any; kwargs...)
13+
@debug "`InternalAPI.reinit!` is not implemented for $(typeof(x))."
14+
return
15+
end
16+
function reinit_self!(x::Any; kwargs...)
17+
@debug "`InternalAPI.reinit_self!` is not implemented for $(typeof(x))."
18+
return
19+
end
20+
21+
function reinit_self!(stats::NLStats)
22+
stats.nf = 0
23+
stats.nsteps = 0
24+
stats.nfactors = 0
25+
stats.njacs = 0
26+
stats.nsolve = 0
27+
end
28+
829
end
930

1031
abstract type AbstractNonlinearSolveBaseAPI end # Mostly used for pretty-printing
@@ -512,3 +533,53 @@ accepted then these values should be copied into the toplevel cache.
512533
abstract type AbstractTrustRegionMethodCache <: AbstractNonlinearSolveBaseAPI end
513534

514535
last_step_accepted(cache::AbstractTrustRegionMethodCache) = cache.last_step_accepted
536+
537+
# Additional Interface
538+
"""
539+
callback_into_cache!(cache, internalcache, args...)
540+
541+
Define custom operations on `internalcache` tightly coupled with the calling `cache`.
542+
`args...` contain the sequence of caches calling into `internalcache`.
543+
544+
This unfortunately makes code very tightly coupled and not modular. It is recommended to not
545+
use this functionality unless it can't be avoided (like in [`LevenbergMarquardt`](@ref)).
546+
"""
547+
callback_into_cache!(cache, internalcache, args...) = nothing # By default do nothing
548+
549+
# Helper functions to generate cache callbacks and resetting functions
550+
macro internal_caches(cType, internal_cache_names...)
551+
callback_caches = map(internal_cache_names) do name
552+
return quote
553+
$(callback_into_cache!)(
554+
cache, getproperty(internalcache, $(name)), internalcache, args...
555+
)
556+
end
557+
end
558+
callbacks_self = map(internal_cache_names) do name
559+
return quote
560+
$(callback_into_cache!)(cache, getproperty(cache, $(name)))
561+
end
562+
end
563+
reinit_caches = map(internal_cache_names) do name
564+
return quote
565+
$(InternalAPI.reinit!)(getproperty(cache, $(name)), args...; kwargs...)
566+
end
567+
end
568+
return esc(quote
569+
function NonlinearSolveBase.callback_into_cache!(
570+
cache, internalcache::$(cType), args...
571+
)
572+
$(callback_caches...)
573+
return
574+
end
575+
function NonlinearSolveBase.callback_into_cache!(cache::$(cType))
576+
$(callbacks_self...)
577+
return
578+
end
579+
function NonlinearSolveBase.InternalAPI.reinit!(cache::$(cType), args...; kwargs...)
580+
$(reinit_caches...)
581+
$(InternalAPI.reinit_self!)(cache, args...; kwargs...)
582+
return
583+
end
584+
end)
585+
end

lib/NonlinearSolveBase/src/descent/damped_newton.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,7 @@ supports_trust_region(::DampedNewtonDescent) = true
4242
mode <: Union{Val{:normal_form}, Val{:least_squares}, Val{:simple}}
4343
end
4444

45-
# XXX: Implement
46-
# @internal_caches DampedNewtonDescentCache :lincache :damping_fn_cache
45+
NonlinearSolveBase.@internal_caches DampedNewtonDescentCache :lincache :damping_fn_cache
4746

4847
function InternalAPI.init(
4948
prob::AbstractNonlinearProblem, alg::DampedNewtonDescent, J, fu, u; stats,

lib/NonlinearSolveBase/src/descent/dogleg.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,7 @@ end
4444
normal_form <: Union{Val{false}, Val{true}}
4545
end
4646

47-
# XXX: Implement
48-
# @internal_caches DoglegCache :newton_cache :cauchy_cache
47+
NonlinearSolveBase.@internal_caches DoglegCache :newton_cache :cauchy_cache
4948

5049
function InternalAPI.init(
5150
prob::AbstractNonlinearProblem, alg::Dogleg, J, fu, u;

lib/NonlinearSolveBase/src/descent/geodesic_acceleration.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,13 +49,12 @@ get_linear_solver(alg::GeodesicAcceleration) = get_linear_solver(alg.descent)
4949
last_step_accepted::Bool
5050
end
5151

52-
function InternalAPI.reinit!(cache::GeodesicAccelerationCache; p = cache.p, kwargs...)
52+
function InternalAPI.reinit_self!(cache::GeodesicAccelerationCache; p = cache.p, kwargs...)
5353
cache.p = p
5454
cache.last_step_accepted = false
5555
end
5656

57-
# XXX: Implement
58-
# @internal_caches GeodesicAccelerationCache :descent_cache
57+
NonlinearSolveBase.@internal_caches GeodesicAccelerationCache :descent_cache
5958

6059
function get_velocity(cache::GeodesicAccelerationCache)
6160
return SciMLBase.get_du(cache.descent_cache, Val(1))

lib/NonlinearSolveBase/src/descent/newton.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,7 @@ supports_line_search(::NewtonDescent) = true
2424
normal_form <: Union{Val{false}, Val{true}}
2525
end
2626

27-
# XXX: Implement
28-
# @internal_caches NewtonDescentCache :lincache
27+
NonlinearSolveBase.@internal_caches NewtonDescentCache :lincache
2928

3029
function InternalAPI.init(
3130
prob::AbstractNonlinearProblem, alg::NewtonDescent, J, fu, u; stats,

lib/NonlinearSolveBase/src/descent/steepest.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,7 @@ supports_line_search(::SteepestDescent) = true
2121
preinverted_jacobian <: Union{Val{false}, Val{true}}
2222
end
2323

24-
# XXX: Implement
25-
# @internal_caches SteepestDescentCache :lincache
24+
NonlinearSolveBase.@internal_caches SteepestDescentCache :lincache
2625

2726
function InternalAPI.init(
2827
prob::AbstractNonlinearProblem, alg::SteepestDescent, J, fu, u;

lib/NonlinearSolveBase/src/solve.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,7 @@ function CommonSolve.solve!(cache::AbstractNonlinearSolveCache)
1818
)
1919
end
2020

21-
# XXX: Implement this
22-
# update_from_termination_cache!(cache.termination_cache, cache)
21+
update_from_termination_cache!(cache.termination_cache, cache)
2322

2423
update_trace!(
2524
cache.trace, cache.nsteps, get_u(cache), get_fu(cache), nothing, nothing, nothing;

lib/NonlinearSolveBase/src/termination_conditions.jl

Lines changed: 59 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
const RelNormModes = Union{
2-
RelNormTerminationMode, RelNormSafeTerminationMode, RelNormSafeBestTerminationMode}
2+
RelNormTerminationMode, RelNormSafeTerminationMode, RelNormSafeBestTerminationMode
3+
}
34
const AbsNormModes = Union{
4-
AbsNormTerminationMode, AbsNormSafeTerminationMode, AbsNormSafeBestTerminationMode}
5+
AbsNormTerminationMode, AbsNormSafeTerminationMode, AbsNormSafeBestTerminationMode
6+
}
57

68
# Core Implementation
79
@concrete mutable struct NonlinearTerminationModeCache{uType, T}
@@ -32,7 +34,8 @@ end
3234

3335
function CommonSolve.init(
3436
::AbstractNonlinearProblem, mode::AbstractNonlinearTerminationMode, du, u,
35-
saved_value_prototype...; abstol = nothing, reltol = nothing, kwargs...)
37+
saved_value_prototype...; abstol = nothing, reltol = nothing, kwargs...
38+
)
3639
T = promote_type(eltype(du), eltype(u))
3740
abstol = get_tolerance(u, abstol, T)
3841
reltol = get_tolerance(u, reltol, T)
@@ -77,12 +80,14 @@ function CommonSolve.init(
7780
return NonlinearTerminationModeCache(
7881
u_unaliased, ReturnCode.Default, abstol, reltol, best_value, mode,
7982
initial_objective, objectives_trace, 0, saved_value_prototype,
80-
u0_norm, step_norm_trace, max_stalled_steps, u_diff_cache)
83+
u0_norm, step_norm_trace, max_stalled_steps, u_diff_cache
84+
)
8185
end
8286

8387
function SciMLBase.reinit!(
8488
cache::NonlinearTerminationModeCache, du, u, saved_value_prototype...;
85-
abstol = cache.abstol, reltol = cache.reltol, kwargs...)
89+
abstol = cache.abstol, reltol = cache.reltol, kwargs...
90+
)
8691
T = eltype(cache.abstol)
8792
length(saved_value_prototype) != 0 && (cache.saved_values = saved_value_prototype)
8893

@@ -113,7 +118,8 @@ end
113118

114119
## This dispatch is needed based on how Terminating Callback works!
115120
function (cache::NonlinearTerminationModeCache)(
116-
integrator::AbstractODEIntegrator, abstol::Number, reltol::Number, min_t)
121+
integrator::AbstractODEIntegrator, abstol::Number, reltol::Number, min_t
122+
)
117123
if min_t === nothing || integrator.t min_t
118124
return cache(cache.mode, SciMLBase.get_du(integrator),
119125
integrator.u, integrator.uprev, abstol, reltol)
@@ -125,7 +131,8 @@ function (cache::NonlinearTerminationModeCache)(du, u, uprev, args...)
125131
end
126132

127133
function (cache::NonlinearTerminationModeCache)(
128-
mode::AbstractNonlinearTerminationMode, du, u, uprev, abstol, reltol, args...)
134+
mode::AbstractNonlinearTerminationMode, du, u, uprev, abstol, reltol, args...
135+
)
129136
if check_convergence(mode, du, u, uprev, abstol, reltol)
130137
cache.retcode = ReturnCode.Success
131138
return true
@@ -134,7 +141,8 @@ function (cache::NonlinearTerminationModeCache)(
134141
end
135142

136143
function (cache::NonlinearTerminationModeCache)(
137-
mode::AbstractSafeNonlinearTerminationMode, du, u, uprev, abstol, reltol, args...)
144+
mode::AbstractSafeNonlinearTerminationMode, du, u, uprev, abstol, reltol, args...
145+
)
138146
if mode isa AbsNormSafeTerminationMode || mode isa AbsNormSafeBestTerminationMode
139147
objective = Utils.apply_norm(mode.internalnorm, du)
140148
criteria = abstol
@@ -251,15 +259,17 @@ end
251259
# High-Level API with defaults.
252260
## This is mostly for internal usage in NonlinearSolve and SimpleNonlinearSolve
253261
function default_termination_mode(
254-
::Union{ImmutableNonlinearProblem, NonlinearProblem}, ::Val{:simple})
262+
::Union{ImmutableNonlinearProblem, NonlinearProblem}, ::Val{:simple}
263+
)
255264
return AbsNormTerminationMode(Base.Fix1(maximum, abs))
256265
end
257266
function default_termination_mode(::NonlinearLeastSquaresProblem, ::Val{:simple})
258267
return AbsNormTerminationMode(Base.Fix2(norm, 2))
259268
end
260269

261270
function default_termination_mode(
262-
::Union{ImmutableNonlinearProblem, NonlinearProblem}, ::Val{:regular})
271+
::Union{ImmutableNonlinearProblem, NonlinearProblem}, ::Val{:regular}
272+
)
263273
return AbsNormSafeBestTerminationMode(Base.Fix1(maximum, abs); max_stalled_steps = 32)
264274
end
265275

@@ -268,16 +278,53 @@ function default_termination_mode(::NonlinearLeastSquaresProblem, ::Val{:regular
268278
end
269279

270280
function init_termination_cache(
271-
prob::AbstractNonlinearProblem, abstol, reltol, du, u, ::Nothing, callee::Val)
281+
prob::AbstractNonlinearProblem, abstol, reltol, du, u, ::Nothing, callee::Val
282+
)
272283
return init_termination_cache(
273284
prob, abstol, reltol, du, u, default_termination_mode(prob, callee), callee)
274285
end
275286

276287
function init_termination_cache(prob::AbstractNonlinearProblem, abstol, reltol, du,
277-
u, tc::AbstractNonlinearTerminationMode, ::Val)
288+
u, tc::AbstractNonlinearTerminationMode, ::Val
289+
)
278290
T = promote_type(eltype(du), eltype(u))
279291
abstol = get_tolerance(u, abstol, T)
280292
reltol = get_tolerance(u, reltol, T)
281293
cache = init(prob, tc, du, u; abstol, reltol)
282294
return abstol, reltol, cache
283295
end
296+
297+
function check_and_update!(cache, fu, u, uprev)
298+
return check_and_update!(
299+
cache.termination_cache, cache, fu, u, uprev, cache.termination_cache.mode
300+
)
301+
end
302+
303+
function check_and_update!(tc_cache, cache, fu, u, uprev, mode)
304+
if tc_cache(fu, u, uprev)
305+
cache.retcode = tc_cache.retcode
306+
update_from_termination_cache!(tc_cache, cache, mode, u)
307+
cache.force_stop = true
308+
end
309+
end
310+
311+
function update_from_termination_cache!(tc_cache, cache, u = get_u(cache))
312+
return update_from_termination_cache!(tc_cache, cache, tc_cache.mode, u)
313+
end
314+
315+
function update_from_termination_cache!(
316+
tc_cache, cache, ::AbstractNonlinearTerminationMode, u = get_u(cache)
317+
)
318+
Utils.evaluate_f!(cache, u, cache.p)
319+
end
320+
321+
function update_from_termination_cache!(
322+
tc_cache, cache, ::AbstractSafeBestNonlinearTerminationMode, u = get_u(cache)
323+
)
324+
if SciMLBase.isinplace(cache)
325+
copyto!(get_u(cache), tc_cache.u)
326+
else
327+
SciMLBase.set_u!(cache, tc_cache.u)
328+
end
329+
Utils.evaluate_f!(cache, get_u(cache), cache.p)
330+
end

lib/NonlinearSolveBase/src/tracing.jl

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -103,8 +103,16 @@ function NonlinearSolveTraceEntry(prob::AbstractNonlinearProblem, iteration, fu,
103103
norm_type = ifelse(prob isa NonlinearLeastSquaresProblem, :L2, :Inf)
104104
fnorm = prob isa NonlinearLeastSquaresProblem ? L2_NORM(fu) : Linf_NORM(fu)
105105
condJ = J !== missing ? Utils.condition_number(J) : nothing
106-
storage = u === missing ? nothing :
107-
(; u = copy(u), fu = copy(fu), δu = copy(δu), J = copy(J))
106+
storage = if u === missing
107+
nothing
108+
else
109+
(;
110+
u = ArrayInterface.ismutable(u) ? copy(u) : u,
111+
fu = ArrayInterface.ismutable(fu) ? copy(fu) : fu,
112+
δu = ArrayInterface.ismutable(δu) ? copy(δu) : δu,
113+
J = ArrayInterface.ismutable(J) ? copy(J) : J
114+
)
115+
end
108116
return NonlinearSolveTraceEntry(
109117
iteration, fnorm, L2_NORM(δu), condJ, storage, norm_type
110118
)
@@ -149,7 +157,8 @@ function init_nonlinearsolve_trace(
149157
)
150158
if show_trace isa Val{true}
151159
print("\nAlgorithm: ")
152-
Base.printstyled(alg, "\n\n"; color = :green, bold = true)
160+
str = Utils.clean_sprint_struct(alg, 0)
161+
Base.printstyled(str, "\n\n"; color = :green, bold = true)
153162
end
154163
J = uses_jac_inverse isa Val{true} ?
155164
(trace_level.trace_mode isa Val{:minimal} ? J : LinearAlgebra.pinv(J)) : J

lib/NonlinearSolveQuasiNewton/src/initialization.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,11 @@ is reinitialized.
3737
internalnorm
3838
end
3939

40-
function InternalAPI.reinit!(cache::InitializedApproximateJacobianCache; kwargs...)
40+
function InternalAPI.reinit_self!(cache::InitializedApproximateJacobianCache; kwargs...)
4141
cache.initialized = false
4242
end
4343

44-
# XXX: Implement
45-
# @internal_caches InitializedApproximateJacobianCache :cache
44+
NonlinearSolveBase.@internal_caches InitializedApproximateJacobianCache :cache
4645

4746
function (cache::InitializedApproximateJacobianCache)(::Nothing)
4847
return NonlinearSolveBase.get_full_jacobian(cache, cache.structure, cache.J)

0 commit comments

Comments
 (0)