-
-
Notifications
You must be signed in to change notification settings - Fork 42
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
152 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
function MultiStepNonlinearSolver(; concrete_jac = nothing, linsolve = nothing, | ||
scheme = MSS.PotraPtak3, precs = DEFAULT_PRECS, autodiff = nothing) | ||
descent = GenericMultiStepDescent(; scheme, linsolve, precs) | ||
# TODO: Use the scheme as the name | ||
return GeneralizedFirstOrderAlgorithm(; concrete_jac, name = :MultiStepNonlinearSolver, | ||
descent, jacobian_ad = autodiff) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,134 @@ | ||
""" | ||
MultiStepSchemes | ||
This module defines the multistep schemes used in the multistep descent algorithms. The | ||
naming convention follows <name of method><order of convergence>. The name of method is | ||
typically the last names of the authors of the paper that introduced the method. | ||
""" | ||
module MultiStepSchemes | ||
|
||
abstract type AbstractMultiStepScheme end | ||
|
||
function Base.show(io::IO, mss::AbstractMultiStepScheme) | ||
print(io, "MultiStepSchemes.$(string(nameof(typeof(mss)))[3:end])") | ||
end | ||
|
||
struct __PotraPtak3 <: AbstractMultiStepScheme end | ||
const PotraPtak3 = __PotraPtak3() | ||
|
||
alg_steps(::__PotraPtak3) = 1 | ||
|
||
struct __SinghSharma4 <: AbstractMultiStepScheme end | ||
const SinghSharma4 = __SinghSharma4() | ||
|
||
alg_steps(::__SinghSharma4) = 3 | ||
|
||
struct __SinghSharma5 <: AbstractMultiStepScheme end | ||
const SinghSharma5 = __SinghSharma5() | ||
|
||
alg_steps(::__SinghSharma5) = 3 | ||
|
||
struct __SinghSharma7 <: AbstractMultiStepScheme end | ||
const SinghSharma7 = __SinghSharma7() | ||
|
||
alg_steps(::__SinghSharma7) = 4 | ||
|
||
end | ||
|
||
const MSS = MultiStepSchemes | ||
|
||
@kwdef @concrete struct GenericMultiStepDescent <: AbstractDescentAlgorithm | ||
scheme | ||
linsolve = nothing | ||
precs = DEFAULT_PRECS | ||
end | ||
|
||
supports_line_search(::GenericMultiStepDescent) = false | ||
supports_trust_region(::GenericMultiStepDescent) = false | ||
|
||
@concrete mutable struct GenericMultiStepDescentCache{S, INV} <: AbstractDescentCache | ||
f | ||
p | ||
δu | ||
δus | ||
scheme::S | ||
lincache | ||
timer | ||
nf::Int | ||
end | ||
|
||
@internal_caches GenericMultiStepDescentCache :lincache | ||
|
||
function __reinit_internal!(cache::GenericMultiStepDescentCache, args...; p = cache.p, | ||
kwargs...) | ||
cache.nf = 0 | ||
cache.p = p | ||
end | ||
|
||
function __δu_caches(scheme::MSS.__PotraPtak3, fu, u, ::Val{N}) where {N} | ||
caches = ntuple(N) do i | ||
@bb δu = similar(u) | ||
@bb y = similar(u) | ||
@bb fy = similar(fu) | ||
@bb δy = similar(u) | ||
@bb u_new = similar(u) | ||
(δu, δy, fy, y, u_new) | ||
end | ||
return first(caches), (N ≤ 1 ? nothing : caches[2:end]) | ||
end | ||
|
||
function __internal_init(prob::NonlinearProblem, alg::GenericMultiStepDescent, J, fu, u; | ||
shared::Val{N} = Val(1), pre_inverted::Val{INV} = False, linsolve_kwargs = (;), | ||
abstol = nothing, reltol = nothing, timer = get_timer_output(), | ||
kwargs...) where {INV, N} | ||
δu, δus = __δu_caches(alg.scheme, fu, u, shared) | ||
INV && return GenericMultiStepDescentCache{true}(prob.f, prob.p, δu, δus, | ||
alg.scheme, nothing, timer, 0) | ||
lincache = LinearSolverCache(alg, alg.linsolve, J, _vec(fu), _vec(u); abstol, reltol, | ||
linsolve_kwargs...) | ||
return GenericMultiStepDescentCache{false}(prob.f, prob.p, δu, δus, alg.scheme, | ||
lincache, timer, 0) | ||
end | ||
|
||
function __internal_init(prob::NonlinearLeastSquaresProblem, alg::GenericMultiStepDescent, | ||
J, fu, u; kwargs...) | ||
error("Multi-Step Descent Algorithms for NLLS are not implemented yet.") | ||
end | ||
|
||
function __internal_solve!(cache::GenericMultiStepDescentCache{MSS.__PotraPtak3, INV}, J, | ||
fu, u, idx::Val = Val(1); skip_solve::Bool = false, new_jacobian::Bool = true, | ||
kwargs...) where {INV} | ||
(u_new, δy, fy, y, δu) = get_du(cache, idx) | ||
skip_solve && return DescentResult(; u = u_new) | ||
|
||
@static_timeit cache.timer "linear solve" begin | ||
@static_timeit cache.timer "solve and step 1" begin | ||
if INV | ||
J !== nothing && @bb(δu=J × _vec(fu)) | ||
else | ||
δu = cache.lincache(; A = J, b = _vec(fu), kwargs..., linu = _vec(δu), | ||
du = _vec(δu), | ||
reuse_A_if_factorization = !new_jacobian || (idx !== Val(1))) | ||
δu = _restructure(u, δu) | ||
end | ||
@bb @. y = u - δu | ||
end | ||
|
||
fy = evaluate_f!!(cache.f, fy, y, cache.p) | ||
cache.nf += 1 | ||
|
||
@static_timeit cache.timer "solve and step 2" begin | ||
if INV | ||
J !== nothing && @bb(δy=J × _vec(fy)) | ||
else | ||
δy = cache.lincache(; A = J, b = _vec(fy), kwargs..., linu = _vec(δy), | ||
du = _vec(δy), reuse_A_if_factorization = true) | ||
δy = _restructure(u, δy) | ||
end | ||
@bb @. u_new = y - δy | ||
end | ||
end | ||
|
||
set_du!(cache, (u_new, δy, fy, y, δu), idx) | ||
return DescentResult(; u = u_new) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters