Skip to content

Commit

Permalink
Use a descent result
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Feb 13, 2024
1 parent 9e46f00 commit 517f695
Show file tree
Hide file tree
Showing 13 changed files with 163 additions and 116 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "NonlinearSolve"
uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
authors = ["SciML"]
version = "3.5.4"
version = "3.6.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
1 change: 1 addition & 0 deletions docs/src/devdocs/internal_interfaces.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ NonlinearSolve.AbstractNonlinearSolveCache
```@docs
NonlinearSolve.AbstractDescentAlgorithm
NonlinearSolve.AbstractDescentCache
NonlinearSolve.DescentResult
```

## Approximate Jacobian
Expand Down
2 changes: 2 additions & 0 deletions src/NonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,13 @@ include("adtypes.jl")
include("timer_outputs.jl")
include("internal/helpers.jl")

include("descent/common.jl")
include("descent/newton.jl")
include("descent/steepest.jl")
include("descent/dogleg.jl")
include("descent/damped_newton.jl")
include("descent/geodesic_acceleration.jl")
include("descent/multistep.jl")

include("internal/operators.jl")
include("internal/jacobian.jl")
Expand Down
13 changes: 3 additions & 10 deletions src/abstract_types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ Abstract Type for all Descent Caches.
### `__internal_solve!` specification
```julia
δu, success, intermediates = __internal_solve!(cache::AbstractDescentCache, J, fu, u,
idx::Val; skip_solve::Bool = false, kwargs...)
descent_result = __internal_solve!(cache::AbstractDescentCache, J, fu, u, idx::Val;
skip_solve::Bool = false, kwargs...)
```
- `J`: Jacobian or Inverse Jacobian (if `pre_inverted = Val(true)`).
Expand All @@ -78,14 +78,7 @@ Abstract Type for all Descent Caches.
direction was rejected and we want to try with a modified trust region.
- `kwargs`: keyword arguments to pass to the linear solver if there is one.
#### Returned values
- `δu`: the descent direction.
- `success`: Certain Descent Algorithms can reject a descent direction for example
`GeodesicAcceleration`.
- `intermediates`: A named tuple containing intermediates computed during the solve.
For example, `GeodesicAcceleration` returns `NamedTuple{(:v, :a)}` containing the
"velocity" and "acceleration" terms.
Returns a result of type [`DescentResult`](@ref).
### Interface Functions
Expand Down
94 changes: 52 additions & 42 deletions src/core/approximate_jacobian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,7 @@ function SciMLBase.__init(prob::AbstractNonlinearProblem{uType, iip},

linsolve = get_linear_solver(alg.descent)
initialization_cache = __internal_init(prob, alg.initialization, alg, f, fu, u, p;
linsolve,
maxiters, internalnorm)
linsolve, maxiters, internalnorm)

abstol, reltol, termination_cache = init_termination_cache(abstol, reltol, fu, u,
termination_condition)
Expand Down Expand Up @@ -222,9 +221,7 @@ function __step!(cache::ApproximateJacobianSolveCache{INV, GB, iip};
new_jacobian = true
@static_timeit cache.timer "jacobian init/reinit" begin
if get_nsteps(cache) == 0 # First Step is special ignore kwargs
J_init = __internal_solve!(cache.initialization_cache,
cache.fu,
cache.u,
J_init = __internal_solve!(cache.initialization_cache, cache.fu, cache.u,
Val(false))
if INV
if jacobian_initialized_preinverted(cache.initialization_cache.alg)
Expand Down Expand Up @@ -283,52 +280,65 @@ function __step!(cache::ApproximateJacobianSolveCache{INV, GB, iip};
@static_timeit cache.timer "descent" begin
if cache.trustregion_cache !== nothing &&
hasfield(typeof(cache.trustregion_cache), :trust_region)
δu, descent_success, descent_intermediates = __internal_solve!(cache.descent_cache,
J, cache.fu, cache.u; new_jacobian,
trust_region = cache.trustregion_cache.trust_region)
descent_result = __internal_solve!(cache.descent_cache, J, cache.fu, cache.u;
new_jacobian, trust_region = cache.trustregion_cache.trust_region)
else
δu, descent_success, descent_intermediates = __internal_solve!(cache.descent_cache,
J, cache.fu, cache.u; new_jacobian)
descent_result = __internal_solve!(cache.descent_cache, J, cache.fu, cache.u;
new_jacobian)
end
end

if descent_success
if GB === :LineSearch
@static_timeit cache.timer "linesearch" begin
needs_reset, α = __internal_solve!(cache.linesearch_cache, cache.u, δu)
end
if needs_reset && cache.steps_since_last_reset > 5 # Reset after a burn-in period
cache.force_reinit = true
else
@static_timeit cache.timer "step" begin
@bb axpy!(α, δu, cache.u)
evaluate_f!(cache, cache.u, cache.p)
end
end
elseif GB === :TrustRegion
@static_timeit cache.timer "trustregion" begin
tr_accepted, u_new, fu_new = __internal_solve!(cache.trustregion_cache, J,
cache.fu, cache.u, δu, descent_intermediates)
if tr_accepted
@bb copyto!(cache.u, u_new)
@bb copyto!(cache.fu, fu_new)
end
if hasfield(typeof(cache.trustregion_cache), :shrink_counter) &&
cache.trustregion_cache.shrink_counter > cache.max_shrink_times
cache.retcode = ReturnCode.ShrinkThresholdExceeded
cache.force_stop = true
end
end
α = true
elseif GB === :None
if descent_result.success
if GB === :None
@static_timeit cache.timer "step" begin
@bb axpy!(1, δu, cache.u)
if descent_result.u !== missing
@bb copyto!(cache.u, descent_result.u)
elseif descent_result.δu !== missing
@bb axpy!(1, descent_result.δu, cache.u)
else
error("This shouldn't occur. `$(cache.alg.descent)` is incorrectly \
specified.")
end
evaluate_f!(cache, cache.u, cache.p)
end
α = true
else
error("Unknown Globalization Strategy: $(GB). Allowed values are (:LineSearch, \
:TrustRegion, :None)")
δu = descent_result.δu
@assert δu!==missing "Descent Supporting LineSearch or TrustRegion must return a `δu`."

if GB === :LineSearch
@static_timeit cache.timer "linesearch" begin
needs_reset, α = __internal_solve!(cache.linesearch_cache, cache.u, δu)
end
if needs_reset && cache.steps_since_last_reset > 5 # Reset after a burn-in period
cache.force_reinit = true
else
@static_timeit cache.timer "step" begin
@bb axpy!(α, δu, cache.u)
evaluate_f!(cache, cache.u, cache.p)
end
end
elseif GB === :TrustRegion
@static_timeit cache.timer "trustregion" begin
tr_accepted, u_new, fu_new = __internal_solve!(cache.trustregion_cache,
J, cache.fu, cache.u, δu, descent_result.extras)
if tr_accepted
@bb copyto!(cache.u, u_new)
@bb copyto!(cache.fu, fu_new)
α = true
else
α = false
end
if hasfield(typeof(cache.trustregion_cache), :shrink_counter) &&
cache.trustregion_cache.shrink_counter > cache.max_shrink_times
cache.retcode = ReturnCode.ShrinkThresholdExceeded
cache.force_stop = true
end
end
else
error("Unknown Globalization Strategy: $(GB). Allowed values are \
(:LineSearch, :TrustRegion, :None)")
end
end
check_and_update!(cache, cache.fu, cache.u, cache.u_cache)
else
Expand Down
88 changes: 49 additions & 39 deletions src/core/generalized_first_order.jl
Original file line number Diff line number Diff line change
Expand Up @@ -215,57 +215,67 @@ function __step!(cache::GeneralizedFirstOrderAlgorithmCache{iip, GB};
@static_timeit cache.timer "descent" begin
if cache.trustregion_cache !== nothing &&
hasfield(typeof(cache.trustregion_cache), :trust_region)
δu, descent_success, descent_intermediates = __internal_solve!(cache.descent_cache,
J, cache.fu, cache.u; new_jacobian,
trust_region = cache.trustregion_cache.trust_region)
descent_result = __internal_solve!(cache.descent_cache, J, cache.fu, cache.u;
new_jacobian, trust_region = cache.trustregion_cache.trust_region)
else
δu, descent_success, descent_intermediates = __internal_solve!(cache.descent_cache,
J, cache.fu, cache.u; new_jacobian)
descent_result = __internal_solve!(cache.descent_cache, J, cache.fu, cache.u;
new_jacobian)
end
end

if descent_success
if descent_result.success
cache.make_new_jacobian = true
if GB === :LineSearch
@static_timeit cache.timer "linesearch" begin
linesearch_failed, α = __internal_solve!(cache.linesearch_cache,
cache.u, δu)
end
if linesearch_failed
cache.retcode = ReturnCode.InternalLineSearchFailed
cache.force_stop = true
end
if GB === :None
@static_timeit cache.timer "step" begin
@bb axpy!(α, δu, cache.u)
evaluate_f!(cache, cache.u, cache.p)
end
elseif GB === :TrustRegion
@static_timeit cache.timer "trustregion" begin
tr_accepted, u_new, fu_new = __internal_solve!(cache.trustregion_cache, J,
cache.fu, cache.u, δu, descent_intermediates)
if tr_accepted
@bb copyto!(cache.u, u_new)
@bb copyto!(cache.fu, fu_new)
α = true
if descent_result.u !== missing
@bb copyto!(cache.u, descent_result.u)
elseif descent_result.δu !== missing
@bb axpy!(1, descent_result.δu, cache.u)
else
α = false
cache.make_new_jacobian = false
error("This shouldn't occur. `$(cache.alg.descent)` is incorrectly \
specified.")
end
if hasfield(typeof(cache.trustregion_cache), :shrink_counter) &&
cache.trustregion_cache.shrink_counter > cache.max_shrink_times
cache.retcode = ReturnCode.ShrinkThresholdExceeded
cache.force_stop = true
end
end
elseif GB === :None
@static_timeit cache.timer "step" begin
@bb axpy!(1, δu, cache.u)
evaluate_f!(cache, cache.u, cache.p)
end
α = true
else
error("Unknown Globalization Strategy: $(GB). Allowed values are (:LineSearch, \
:TrustRegion, :None)")
δu = descent_result.δu
@assert δu!==missing "Descent Supporting LineSearch or TrustRegion must return a `δu`."

if GB === :LineSearch
@static_timeit cache.timer "linesearch" begin
failed, α = __internal_solve!(cache.linesearch_cache, cache.u, δu)
end
if failed
cache.retcode = ReturnCode.InternalLineSearchFailed
cache.force_stop = true
else
@static_timeit cache.timer "step" begin
@bb axpy!(α, δu, cache.u)
evaluate_f!(cache, cache.u, cache.p)
end
end
elseif GB === :TrustRegion
@static_timeit cache.timer "trustregion" begin
tr_accepted, u_new, fu_new = __internal_solve!(cache.trustregion_cache,
J, cache.fu, cache.u, δu, descent_result.extras)
if tr_accepted
@bb copyto!(cache.u, u_new)
@bb copyto!(cache.fu, fu_new)
α = true
else
α = false
end
if hasfield(typeof(cache.trustregion_cache), :shrink_counter) &&
cache.trustregion_cache.shrink_counter > cache.max_shrink_times
cache.retcode = ReturnCode.ShrinkThresholdExceeded
cache.force_stop = true
end
end
else
error("Unknown Globalization Strategy: $(GB). Allowed values are \
(:LineSearch, :TrustRegion, :None)")
end
end
check_and_update!(cache, cache.fu, cache.u, cache.u_cache)
else
Expand Down
26 changes: 26 additions & 0 deletions src/descent/common.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
"""
DescentResult(; δu = missing, u = missing, success::Bool = true, extras = (;))
Construct a `DescentResult` object.
### Keyword Arguments
* `δu`: The descent direction.
* `u`: The new iterate. This is provided only for multi-step methods currently.
* `success`: Certain Descent Algorithms can reject a descent direction for example
[`GeodesicAcceleration`](@ref).
* `extras`: A named tuple containing intermediates computed during the solve.
For example, [`GeodesicAcceleration`](@ref) returns `NamedTuple{(:v, :a)}` containing
the "velocity" and "acceleration" terms.
"""
@concrete struct DescentResult
δu
u
success::Bool
extras
end

function DescentResult(; δu = missing, u = missing, success::Bool = true, extras = (;))
@assert δu !== missing || u !== missing
return DescentResult(δu, u, success, extras)
end
9 changes: 4 additions & 5 deletions src/descent/damped_newton.jl
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ function __internal_solve!(cache::DampedNewtonDescentCache{INV, mode}, J, fu, u,
idx::Val{N} = Val(1); skip_solve::Bool = false, new_jacobian::Bool = true,
kwargs...) where {INV, N, mode}
δu = get_du(cache, idx)
skip_solve && return δu, true, (;)
skip_solve && return DescentResult(; δu)

recompute_A = idx === Val(1)

Expand Down Expand Up @@ -201,15 +201,14 @@ function __internal_solve!(cache::DampedNewtonDescentCache{INV, mode}, J, fu, u,
end

@static_timeit cache.timer "linear solve" begin
δu = cache.lincache(; A, b,
reuse_A_if_factorization = !new_jacobian && !recompute_A,
kwargs..., linu = _vec(δu))
δu = cache.lincache(; A, b, linu = _vec(δu),
reuse_A_if_factorization = !new_jacobian && !recompute_A, kwargs...)
δu = _restructure(get_du(cache, idx), δu)
end

@bb @. δu *= -1
set_du!(cache, δu, idx)
return δu, true, (;)
return DescentResult(; δu)
end

# Define special concatenation for certain Array combinations
Expand Down
Loading

0 comments on commit 517f695

Please sign in to comment.