Skip to content

Commit

Permalink
refactor: move Descent Directions to NonlinearSolveBase
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 26, 2024
1 parent 6d5f743 commit 33f1ec9
Show file tree
Hide file tree
Showing 21 changed files with 900 additions and 803 deletions.
6 changes: 0 additions & 6 deletions docs/src/devdocs/operators.md
Original file line number Diff line number Diff line change
@@ -1,11 +1,5 @@
# Custom SciML Operators

## Abstract Operators

```@docs
NonlinearSolve.AbstractNonlinearSolveOperator
```

## Low-Rank Jacobian Operators

```@docs
Expand Down
4 changes: 4 additions & 0 deletions lib/NonlinearSolveBase/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@ FunctionProperties = "f62d2435-5019-4c03-9749-2d4c77af0cbc"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
MaybeInplace = "bb5d69b7-63fc-4a16-80bd-7e42200c7bdb"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SciMLJacobianOperators = "19f34311-ddf3-4b8b-af20-060888a46c0e"
SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"

[weakdeps]
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
Expand Down Expand Up @@ -57,6 +59,7 @@ LinearAlgebra = "1.10"
LinearSolve = "2.36.1"
Markdown = "1.10"
MaybeInplace = "0.1.4"
Preferences = "1.4"
RecursiveArrayTools = "3"
SciMLBase = "2.50"
SciMLJacobianOperators = "0.1.1"
Expand All @@ -65,6 +68,7 @@ SparseArrays = "1.10"
SparseMatrixColorings = "0.4.8"
StaticArraysCore = "1.4"
Test = "1.10"
TimerOutputs = "0.5.23"
julia = "1.10"

[extras]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module NonlinearSolveBaseSparseArraysExt

using NonlinearSolveBase: NonlinearSolveBase
using NonlinearSolveBase: NonlinearSolveBase, Utils
using SparseArrays: AbstractSparseMatrix, AbstractSparseMatrixCSC, nonzeros

function NonlinearSolveBase.NAN_CHECK(x::AbstractSparseMatrixCSC)
Expand All @@ -9,4 +9,6 @@ end

NonlinearSolveBase.sparse_or_structured_prototype(::AbstractSparseMatrix) = true

Utils.maybe_symmetric(x::AbstractSparseMatrix) = x

end
14 changes: 13 additions & 1 deletion lib/NonlinearSolveBase/src/NonlinearSolveBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@ using DifferentiationInterface: DifferentiationInterface, Constant
using EnzymeCore: EnzymeCore
using FastClosures: @closure
using FunctionProperties: hasbranching
using LinearAlgebra: Diagonal, norm, ldiv!
using LinearAlgebra: LinearAlgebra, Diagonal, norm, ldiv!, diagind
using Markdown: @doc_str
using MaybeInplace: @bb
using Preferences: @load_preference
using RecursiveArrayTools: AbstractVectorOfArray, ArrayPartition
using SciMLBase: SciMLBase, ReturnCode, AbstractODEIntegrator, AbstractNonlinearProblem,
AbstractNonlinearAlgorithm, AbstractNonlinearFunction,
Expand All @@ -37,6 +38,14 @@ include("termination_conditions.jl")
include("autodiff.jl")
include("jacobian.jl")
include("linear_solve.jl")
include("timer_outputs.jl")

include("descent/common.jl")
include("descent/newton.jl")
include("descent/steepest.jl")
include("descent/damped_newton.jl")
include("descent/dogleg.jl")
include("descent/geodesic_acceleration.jl")

# Unexported Public API
@compat(public, (L2_NORM, Linf_NORM, NAN_CHECK, UNITLESS_ABS2, get_tolerance))
Expand All @@ -55,4 +64,7 @@ export RelTerminationMode, AbsTerminationMode,
RelNormSafeTerminationMode, AbsNormSafeTerminationMode,
RelNormSafeBestTerminationMode, AbsNormSafeBestTerminationMode

export DescentResult, SteepestDescent, NewtonDescent, DampedNewtonDescent, Dogleg,
GeodesicAcceleration

end
184 changes: 177 additions & 7 deletions lib/NonlinearSolveBase/src/abstract_types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,65 @@ function reinit! end

end

abstract type AbstractDescentDirection end
abstract type AbstractNonlinearSolveBaseAPI end # Mostly used for pretty-printing

function Base.show(io::IO, ::MIME"text/plain", alg::AbstractNonlinearSolveBaseAPI)
main_name = nameof(typeof(alg))
modifiers = String[]
for field in fieldnames(typeof(alg))
val = getfield(alg, field)
Utils.is_default_value(val, field, getfield(alg, field)) && continue
push!(modifiers, "$(field) = $(val)")
end
print(io, "$(main_name)($(join(modifiers, ", ")))")
return
end

"""
AbstractDescentDirection
Abstract Type for all Descent Directions used in NonlinearSolveBase. Given the Jacobian
`J` and the residual `fu`, these algorithms compute the descent direction `δu`.
For non-square Jacobian problems, if we need to solve a linear solve problem, we use a
least squares solver by default, unless the provided `linsolve` can't handle non-square
matrices, in which case we use the normal form equations ``JᵀJ δu = Jᵀ fu``. Note that
this factorization is often the faster choice, but it is not as numerically stable as
the least squares solver.
### `InternalAPI.init` specification
```julia
InternalAPI.init(
prob::AbstractNonlinearProblem, alg::AbstractDescentDirection, J, fu, u;
pre_inverted::Val = Val(false), linsolve_kwargs = (;),
abstol = nothing, reltol = nothing, alias_J::Bool = true,
shared::Val = Val(1), kwargs...
)::AbstractDescentCache
```
- `pre_inverted`: whether or not the Jacobian has been pre_inverted.
- `linsolve_kwargs`: keyword arguments to pass to the linear solver.
- `abstol`: absolute tolerance for the linear solver.
- `reltol`: relative tolerance for the linear solver.
- `alias_J`: whether or not to alias the Jacobian.
- `shared`: Store multiple descent directions in the cache. Allows efficient and
correct reuse of factorizations if needed.
Some of the algorithms also allow additional keyword arguments. See the documentation for
the specific algorithm for more information.
### Interface Functions
- `supports_trust_region(alg)`: whether or not the algorithm supports trust region
methods. Defaults to `false`.
- `supports_line_search(alg)`: whether or not the algorithm supports line search
methods. Defaults to `false`.
See also [`NewtonDescent`](@ref), [`Dogleg`](@ref), [`SteepestDescent`](@ref),
[`DampedNewtonDescent`](@ref).
"""
abstract type AbstractDescentDirection <: AbstractNonlinearSolveBaseAPI end

supports_line_search(::AbstractDescentDirection) = false
supports_trust_region(::AbstractDescentDirection) = false
Expand All @@ -15,7 +73,46 @@ function get_linear_solver(alg::AbstractDescentDirection)
return Utils.safe_getproperty(alg, Val(:linsolve))
end

abstract type AbstractDescentCache end
"""
AbstractDescentCache
Abstract Type for all Descent Caches.
### `InternalAPI.solve!` specification
```julia
InternalAPI.solve!(
cache::AbstractDescentCache, J, fu, u, idx::Val;
skip_solve::Bool = false, new_jacobian::Bool = true, kwargs...
)::DescentResult
```
- `J`: Jacobian or Inverse Jacobian (if `pre_inverted = Val(true)`).
- `fu`: residual.
- `u`: current state.
- `idx`: index of the descent problem to solve and return. Defaults to `Val(1)`.
- `skip_solve`: Skip the direction computation and return the previous direction.
Defaults to `false`. This is useful for Trust Region Methods where the previous
direction was rejected and we want to try with a modified trust region.
- `new_jacobian`: Whether the Jacobian has been updated. Defaults to `true`.
- `kwargs`: keyword arguments to pass to the linear solver if there is one.
#### Returned values
- `descent_result`: Result in a [`DescentResult`](@ref).
### Interface Functions
- `get_du(cache)`: get the descent direction.
- `get_du(cache, ::Val{N})`: get the `N`th descent direction.
- `set_du!(cache, δu)`: set the descent direction.
- `set_du!(cache, δu, ::Val{N})`: set the `N`th descent direction.
- `last_step_accepted(cache)`: whether or not the last step was accepted. Checks if the
cache has a `last_step_accepted` field and returns it if it does, else returns `true`.
- `preinverted_jacobian(cache)`: whether or not the Jacobian has been preinverted.
- `normal_form(cache)`: whether or not the linear solver uses normal form.
"""
abstract type AbstractDescentCache <: AbstractNonlinearSolveBaseAPI end

SciMLBase.get_du(cache::AbstractDescentCache) = cache.δu
SciMLBase.get_du(cache::AbstractDescentCache, ::Val{1}) = SciMLBase.get_du(cache)
Expand All @@ -29,6 +126,79 @@ function last_step_accepted(cache::AbstractDescentCache)
return true
end

for fname in (:preinverted_jacobian, :normal_form)
@eval function $(fname)(alg::AbstractDescentCache)
res = Utils.unwrap_val(Utils.safe_getproperty(alg, Val($(QuoteNode(fname)))))
res === missing && return false
return res
end
end

"""
AbstractDampingFunction
Abstract Type for Damping Functions in DampedNewton.
### `InternalAPI.init` specification
```julia
InternalAPI.init(
prob::AbstractNonlinearProblem, f::AbstractDampingFunction, initial_damping,
J, fu, u, args...;
internalnorm::F = L2_NORM, kwargs...
)::AbstractDampingFunctionCache
```
Returns a [`NonlinearSolveBase.AbstractDampingFunctionCache`](@ref).
"""
abstract type AbstractDampingFunction <: AbstractNonlinearAlgorithm end

"""
AbstractDampingFunctionCache
Abstract Type for the Caches created by AbstractDampingFunctions
### Interface Functions
- `requires_normal_form_jacobian(alg)`: whether or not the Jacobian is needed in normal
form. No default.
- `requires_normal_form_rhs(alg)`: whether or not the residual is needed in normal form.
No default.
- `returns_norm_form_damping(alg)`: whether or not the damping function returns the
damping factor in normal form. Defaults to
`requires_normal_form_jacobian(alg) || requires_normal_form_rhs(alg)`.
- `(cache::AbstractDampingFunctionCache)(::Nothing)`: returns the damping factor. The type
of the damping factor returned from `solve!` is guaranteed to be the same as this.
### `InternalAPI.solve!` specification
```julia
InternalAPI.solve!(
cache::AbstractDampingFunctionCache, J, fu, u, δu, descent_stats
)
```
Returns the damping factor.
"""
abstract type AbstractDampingFunctionCache <: AbstractNonlinearAlgorithm end

function requires_normal_form_jacobian end
function requires_normal_form_rhs end
function returns_norm_form_damping(f::F) where {F}
return requires_normal_form_jacobian(f) || requires_normal_form_rhs(f)
end

"""
AbstractNonlinearSolveAlgorithm <: AbstractNonlinearAlgorithm
Abstract Type for all NonlinearSolveBase Algorithms.
### Interface Functions
- `concrete_jac(alg)`: whether or not the algorithm uses a concrete Jacobian. Defaults
to `nothing`.
- `get_name(alg)`: get the name of the algorithm.
"""
abstract type AbstractNonlinearSolveAlgorithm <: AbstractNonlinearAlgorithm end

get_name(alg::AbstractNonlinearSolveAlgorithm) = Utils.safe_getproperty(alg, Val(:name))
Expand All @@ -47,20 +217,20 @@ concrete_jac(v::Bool) = v
concrete_jac(::Val{false}) = false
concrete_jac(::Val{true}) = true

abstract type AbstractNonlinearSolveCache end
abstract type AbstractNonlinearSolveCache <: AbstractNonlinearSolveBaseAPI end

"""
AbstractLinearSolverCache
Abstract Type for all Linear Solvers used in NonlinearSolve. Subtypes of these are
Abstract Type for all Linear Solvers used in NonlinearSolveBase. Subtypes of these are
meant to be constructured via [`construct_linear_solver`](@ref).
"""
abstract type AbstractLinearSolverCache end
abstract type AbstractLinearSolverCache <: AbstractNonlinearSolveBaseAPI end

"""
AbstractJacobianCache
Abstract Type for all Jacobian Caches used in NonlinearSolve. Subtypes of these are
Abstract Type for all Jacobian Caches used in NonlinearSolveBase. Subtypes of these are
meant to be constructured via [`construct_jacobian_cache`](@ref).
"""
abstract type AbstractJacobianCache end
abstract type AbstractJacobianCache <: AbstractNonlinearSolveBaseAPI end
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""
DescentResult(; δu = missing, u = missing, success::Bool = true,
linsolve_success::Bool = true, extras = (;))
DescentResult(;
δu = missing, u = missing, success::Bool = true, linsolve_success::Bool = true,
extras = (;)
)
Construct a `DescentResult` object.
Expand All @@ -23,8 +25,10 @@ Construct a `DescentResult` object.
extras
end

function DescentResult(; δu = missing, u = missing, success::Bool = true,
linsolve_success::Bool = true, extras = (;))
function DescentResult(;
δu = missing, u = missing, success::Bool = true, linsolve_success::Bool = true,
extras = (;)
)
@assert δu !== missing || u !== missing
return DescentResult(δu, u, success, linsolve_success, extras)
end
Loading

0 comments on commit 33f1ec9

Please sign in to comment.