Skip to content

Commit 7018113

Browse files
committed
Rebase and add update heuristic; the hyperparameter still needs to be fixed.
1 parent b07360f commit 7018113

File tree

2 files changed

+69
-15
lines changed

2 files changed

+69
-15
lines changed

src/raphson.jl

Lines changed: 64 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""
2-
NewtonRaphson(; concrete_jac = nothing, linsolve = nothing,
3-
precs = DEFAULT_PRECS, adkwargs...)
2+
NewtonRaphson(; concrete_jac = nothing, linsolve = nothing, linesearch = LineSearch(),
3+
precs = DEFAULT_PRECS, reuse = true, reusetol = 1e-6, adkwargs...)
44
55
An advanced NewtonRaphson implementation with support for efficient handling of sparse
66
matrices via colored automatic differentiation and preconditioned linear solvers. Designed
@@ -29,31 +29,49 @@ for large-scale and numerically-difficult nonlinear systems.
2929
- `linesearch`: the line search algorithm to use. Defaults to [`LineSearch()`](@ref),
3030
which means that no line search is performed. Algorithms from `LineSearches.jl` can be
3131
used here directly, and they will be converted to the correct `LineSearch`.
32+
- `reuse`: Determines if the Jacobian is reused between (quasi-)Newton steps. Defaults to
33+
`true`. If `true`, we track how far we have stepped with the same Jacobian and compute a
34+
new Jacobian once the accumulated step exceeds `reusetol`, or if convergence slows or
35+
starts to diverge. If `false`, the Jacobian is recomputed at every step.
3236
"""
3337
@concrete struct NewtonRaphson{CJ, AD} <:
3438
AbstractNewtonAlgorithm{CJ, AD}
3539
ad::AD
3640
linsolve
3741
precs
3842
linesearch
43+
reusetol
44+
reuse::Bool
3945
end
4046

4147
function set_ad(alg::NewtonRaphson{CJ}, ad) where {CJ}
42-
return NewtonRaphson{CJ}(ad, alg.linsolve, alg.precs, alg.linesearch)
48+
return NewtonRaphson{CJ}(ad,
49+
alg.linsolve,
50+
alg.precs,
51+
alg.linesearch,
52+
alg.reusetol,
53+
alg.reuse)
4354
end
4455

4556
function NewtonRaphson(; concrete_jac = nothing, linsolve = nothing,
46-
linesearch = LineSearch(), precs = DEFAULT_PRECS, adkwargs...)
57+
linesearch = LineSearch(), precs = DEFAULT_PRECS, reuse = true, reusetol = 1e-6,
58+
adkwargs...)
4759
ad = default_adargs_to_adtype(; adkwargs...)
4860
linesearch = linesearch isa LineSearch ? linesearch : LineSearch(; method = linesearch)
49-
return NewtonRaphson{_unwrap_val(concrete_jac)}(ad, linsolve, precs, linesearch)
61+
return NewtonRaphson{_unwrap_val(concrete_jac)}(ad,
62+
linsolve,
63+
precs,
64+
linesearch,
65+
reusetol,
66+
reuse)
5067
end
5168

5269
@concrete mutable struct NewtonRaphsonCache{iip} <: AbstractNonlinearSolveCache{iip}
5370
f
5471
alg
5572
u
56-
u_prev
73+
uprev
74+
Δu
5775
fu1
5876
fu2
5977
du
@@ -81,22 +99,40 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::NewtonRaphso
8199
alg = get_concrete_algorithm(alg_, prob)
82100
@unpack f, u0, p = prob
83101
u = alias_u0 ? u0 : deepcopy(u0)
102+
uprev = deepcopy(u0)
103+
Δu = zero(u0)
104+
84105
fu1 = evaluate_f(prob, u)
85106
uf, linsolve, J, fu2, jac_cache, du = jacobian_caches(alg, f, u, p, Val(iip);
86107
linsolve_kwargs)
87108

88109
abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, fu1, u,
89110
termination_condition)
90111

91-
return NewtonRaphsonCache{iip}(f, alg, u, copy(u), fu1, fu2, du, p, uf, linsolve, J,
112+
return NewtonRaphsonCache{iip}(f, alg, u, uprev, Δu, fu1, fu2, du, p, uf, linsolve, J,
92113
jac_cache, false, maxiters, internalnorm, ReturnCode.Default, abstol, reltol, prob,
93114
NLStats(1, 0, 0, 0, 0),
94115
init_linesearch_cache(alg.linesearch, f, u, p, fu1, Val(iip)), tc_cache)
95116
end
96117

97118
function perform_step!(cache::NewtonRaphsonCache{true})
98-
@unpack u, u_prev, fu1, f, p, alg, J, linsolve, du = cache
99-
jacobian!!(J, cache)
119+
@unpack u, uprev, Δu, fu1, f, p, alg, J, linsolve, du = cache
120+
@unpack reuse = alg
121+
122+
if reuse
123+
# check how far we stepped
124+
@. Δu += u - uprev
125+
update = cache.internalnorm(Δu) > alg.reusetol
126+
if update || cache.stats.njacs == 0
127+
jacobian!!(J, cache)
128+
cache.stats.njacs += 1
129+
Δu .*= false
130+
end
131+
else
132+
jacobian!!(J, cache)
133+
cache.stats.njacs += 1
134+
end
135+
cache.uprev .= u
100136

101137
# u = u - J \ fu
102138
linres = dolinsolve(alg.precs, linsolve; A = J, b = _vec(fu1), linu = _vec(du),
@@ -112,16 +148,32 @@ function perform_step!(cache::NewtonRaphsonCache{true})
112148

113149
@. u_prev = u
114150
cache.stats.nf += 1
115-
cache.stats.njacs += 1
116151
cache.stats.nsolve += 1
117152
cache.stats.nfactors += 1
118153
return nothing
119154
end
120155

121156
function perform_step!(cache::NewtonRaphsonCache{false})
122-
@unpack u, u_prev, fu1, f, p, alg, linsolve = cache
157+
@unpack u, uprev, Δu, fu1, f, p, alg, linsolve = cache
158+
@unpack reuse = alg
159+
160+
if reuse
161+
# check how far we stepped
162+
cache.Δu += u - uprev
163+
update = cache.internalnorm(Δu) > alg.reusetol
164+
if update || cache.stats.njacs == 0
165+
cache.J = jacobian!!(cache.J, cache)
166+
cache.stats.njacs += 1
167+
cache.Δu *= false
168+
end
169+
else
170+
cache.J = jacobian!!(cache.J, cache)
171+
# cache.Δu *= false
172+
cache.stats.njacs += 1
173+
end
174+
175+
cache.uprev = u
123176

124-
cache.J = jacobian!!(cache.J, cache)
125177
# u = u - J \ fu
126178
if linsolve === nothing
127179
cache.du = fu1 / cache.J
@@ -140,7 +192,6 @@ function perform_step!(cache::NewtonRaphsonCache{false})
140192

141193
cache.u_prev = cache.u
142194
cache.stats.nf += 1
143-
cache.stats.njacs += 1
144195
cache.stats.nsolve += 1
145196
cache.stats.nfactors += 1
146197
return nothing

test/23_test_problems.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,15 @@ function test_on_library(problems, dicts, alg_ops, broken_tests, ϵ = 1e-4)
3030
end
3131
end
3232

33-
@testset "NewtonRaphson 23 Test Problems" begin
34-
alg_ops = (NewtonRaphson(),)
33+
# NewtonRaphson
34+
@testset "NewtonRaphson test problem library" begin
35+
alg_ops = (NewtonRaphson(; reuse = false),
36+
NewtonRaphson(; reuse = true, reusetol = 1e-6))
3537

3638
# dictionary with indices of test problems where method does not converge to small residual
3739
broken_tests = Dict(alg => Int[] for alg in alg_ops)
3840
broken_tests[alg_ops[1]] = [1, 6]
41+
broken_tests[alg_ops[2]] = [1, 6]
3942

4043
test_on_library(problems, dicts, alg_ops, broken_tests)
4144
end

0 commit comments

Comments
 (0)