Reduce unnecessary allocations and reuse code
avik-pal committed Nov 29, 2023
Commit 74c2ad7 (parent: 657c549)
Showing 9 changed files with 204 additions and 173 deletions.
7 changes: 4 additions & 3 deletions Project.toml
@@ -16,6 +16,7 @@ LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
MaybeInplace = "bb5d69b7-63fc-4a16-80bd-7e42200c7bdb"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
@@ -42,8 +43,8 @@ NonlinearSolveZygoteExt = "Zygote"

[compat]
ADTypes = "0.2"
ArrayInterface = "6.0.24, 7"
Aqua = "0.8"
ArrayInterface = "6.0.24, 7"
BandedMatrices = "1"
BenchmarkTools = "1"
ConcreteStructs = "0.2"
@@ -70,9 +71,9 @@ Reexport = "0.2, 1"
SafeTestsets = "0.1"
SciMLBase = "2.9"
SciMLOperators = "0.3"
SimpleNonlinearSolve = "0.1.23"
SimpleNonlinearSolve = "1" # FIXME: Don't update the version in this PR. Using it to test
SparseArrays = "<0.0.1, 1"
SparseDiffTools = "2.12"
SparseDiffTools = "2.14"
StaticArrays = "1"
StaticArraysCore = "1.4"
Symbolics = "5"
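
The new MaybeInplace dependency underpins most of the allocation reductions in this commit: its @bb macro lets a single broadcast statement run in place for mutable arrays and out of place for immutable ones. The sketch below hand-rolls that idea with ArrayInterface.can_setindex (which src/NonlinearSolve.jl below also starts importing); it only illustrates the pattern, it is not MaybeInplace's implementation, and it assumes ArrayInterface's StaticArraysCore extension reports SVector as non-writable.

using ArrayInterface, StaticArrays

# Hand-rolled sketch of the "maybe in-place" idiom that MaybeInplace's @bb
# macro automates: mutate when the container supports setindex!, otherwise
# rebuild and rebind. (Illustration only.)
function axpy_maybe_inplace!(y, a, x)
    if ArrayInterface.can_setindex(typeof(y))
        @. y = y + a * x      # in-place: no extra allocation for a Vector
    else
        y = @. y + a * x      # out-of-place: an SVector stays immutable
    end
    return y
end

axpy_maybe_inplace!([1.0, 2.0], 2.0, [3.0, 4.0])                 # mutates and returns the Vector
axpy_maybe_inplace!(SVector(1.0, 2.0), 2.0, SVector(3.0, 4.0))   # returns a new SVector
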
113 changes: 56 additions & 57 deletions src/NonlinearSolve.jl
@@ -8,25 +8,24 @@ import Reexport: @reexport
import PrecompileTools: @recompile_invalidations, @compile_workload, @setup_workload

@recompile_invalidations begin
using DiffEqBase,
LazyArrays, LinearAlgebra, LinearSolve, Printf, SparseArrays,
using DiffEqBase, LazyArrays, LinearAlgebra, LinearSolve, Printf, SparseArrays,
SparseDiffTools
using FastBroadcast: @..
import ArrayInterface: restructure

import ADTypes: AbstractFiniteDifferencesMode
import ArrayInterface: undefmatrix,
import ArrayInterface: undefmatrix, restructure, can_setindex,
matrix_colors, parameterless_type, ismutable, issingular, fast_scalar_indexing
import ConcreteStructs: @concrete
import EnumX: @enumx
import FastBroadcast: @..
import ForwardDiff
import ForwardDiff: Dual
import LinearSolve: ComposePreconditioner, InvPreconditioner, needs_concrete_A
import MaybeInplace: @bb
import RecursiveArrayTools: ArrayPartition,
AbstractVectorOfArray, recursivecopy!, recursivefill!
import SciMLBase: AbstractNonlinearAlgorithm, NLStats, _unwrap_val, has_jac, isinplace
import SciMLOperators: FunctionOperator
import StaticArraysCore: StaticArray, SVector, SArray, MArray
import StaticArraysCore: StaticArray, SVector, SArray, MArray, Size, SMatrix
import UnPack: @unpack

using ADTypes, LineSearches, SciMLBase, SimpleNonlinearSolve
@@ -55,13 +54,13 @@ isinplace(::AbstractNonlinearSolveCache{iip}) where {iip} = iip
function Base.show(io::IO, alg::AbstractNonlinearSolveAlgorithm)
str = "$(nameof(typeof(alg)))("
modifiers = String[]
if _getproperty(alg, Val(:ad)) !== nothing
if __getproperty(alg, Val(:ad)) !== nothing
push!(modifiers, "ad = $(nameof(typeof(alg.ad)))()")
end
if _getproperty(alg, Val(:linsolve)) !== nothing
if __getproperty(alg, Val(:linsolve)) !== nothing
push!(modifiers, "linsolve = $(nameof(typeof(alg.linsolve)))()")
end
if _getproperty(alg, Val(:linesearch)) !== nothing
if __getproperty(alg, Val(:linesearch)) !== nothing
ls = alg.linesearch
if ls isa LineSearch
ls.method !== nothing &&
@@ -70,7 +69,7 @@ function Base.show(io::IO, alg::AbstractNonlinearSolveAlgorithm)
push!(modifiers, "linesearch = $(nameof(typeof(alg.linesearch)))()")
end
end
if _getproperty(alg, Val(:radius_update_scheme)) !== nothing
if __getproperty(alg, Val(:radius_update_scheme)) !== nothing
push!(modifiers, "radius_update_scheme = $(alg.radius_update_scheme)")
end
str = str * join(modifiers, ", ")
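
The _getproperty helper used in this show method is being renamed to __getproperty throughout the package. Its definition is not part of this diff; a minimal hypothetical sketch of what such a helper does (return a field if the object has it, otherwise nothing) might look like:

# Hypothetical sketch only; the real __getproperty lives outside this diff
# and may differ. Returns the requested field when present, else nothing.
function __getproperty(obj, ::Val{field}) where {field}
    hasfield(typeof(obj), field) && return getproperty(obj, field)
    return nothing
end

__getproperty((; ad = 1), Val(:ad))        # 1
__getproperty((; ad = 1), Val(:linsolve))  # nothing
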
@@ -107,7 +106,7 @@ function SciMLBase.solve!(cache::AbstractNonlinearSolveCache)
end
end

trace = _getproperty(cache, Val{:trace}())
trace = __getproperty(cache, Val{:trace}())
if trace !== nothing
update_trace!(trace, cache.stats.nsteps, get_u(cache), get_fu(cache), nothing,
nothing, nothing; last = Val(true))
@@ -134,52 +133,52 @@ include("jacobian.jl")
include("ad.jl")
include("default.jl")

@setup_workload begin
nlfuncs = ((NonlinearFunction{false}((u, p) -> u .* u .- p), 0.1),
(NonlinearFunction{false}((u, p) -> u .* u .- p), [0.1]),
(NonlinearFunction{true}((du, u, p) -> du .= u .* u .- p), [0.1]))
probs_nls = NonlinearProblem[]
for T in (Float32, Float64), (fn, u0) in nlfuncs
push!(probs_nls, NonlinearProblem(fn, T.(u0), T(2)))
end

nls_algs = (NewtonRaphson(), TrustRegion(), LevenbergMarquardt(), PseudoTransient(),
GeneralBroyden(), GeneralKlement(), DFSane(), nothing)

probs_nlls = NonlinearLeastSquaresProblem[]
nlfuncs = ((NonlinearFunction{false}((u, p) -> (u .^ 2 .- p)[1:1]), [0.1, 0.0]),
(NonlinearFunction{false}((u, p) -> vcat(u .* u .- p, u .* u .- p)), [0.1, 0.1]),
(NonlinearFunction{true}((du, u, p) -> du[1] = u[1] * u[1] - p,
resid_prototype = zeros(1)), [0.1, 0.0]),
(NonlinearFunction{true}((du, u, p) -> du .= vcat(u .* u .- p, u .* u .- p),
resid_prototype = zeros(4)), [0.1, 0.1]))
for (fn, u0) in nlfuncs
push!(probs_nlls, NonlinearLeastSquaresProblem(fn, u0, 2.0))
end
nlfuncs = ((NonlinearFunction{false}((u, p) -> (u .^ 2 .- p)[1:1]), Float32[0.1, 0.0]),
(NonlinearFunction{false}((u, p) -> vcat(u .* u .- p, u .* u .- p)),
Float32[0.1, 0.1]),
(NonlinearFunction{true}((du, u, p) -> du[1] = u[1] * u[1] - p,
resid_prototype = zeros(Float32, 1)), Float32[0.1, 0.0]),
(NonlinearFunction{true}((du, u, p) -> du .= vcat(u .* u .- p, u .* u .- p),
resid_prototype = zeros(Float32, 4)), Float32[0.1, 0.1]))
for (fn, u0) in nlfuncs
push!(probs_nlls, NonlinearLeastSquaresProblem(fn, u0, 2.0f0))
end

nlls_algs = (LevenbergMarquardt(), GaussNewton(),
LevenbergMarquardt(; linsolve = LUFactorization()),
GaussNewton(; linsolve = LUFactorization()))

@compile_workload begin
for prob in probs_nls, alg in nls_algs
solve(prob, alg, abstol = 1e-2)
end
for prob in probs_nlls, alg in nlls_algs
solve(prob, alg, abstol = 1e-2)
end
end
end
# @setup_workload begin
# nlfuncs = ((NonlinearFunction{false}((u, p) -> u .* u .- p), 0.1),
# (NonlinearFunction{false}((u, p) -> u .* u .- p), [0.1]),
# (NonlinearFunction{true}((du, u, p) -> du .= u .* u .- p), [0.1]))
# probs_nls = NonlinearProblem[]
# for T in (Float32, Float64), (fn, u0) in nlfuncs
# push!(probs_nls, NonlinearProblem(fn, T.(u0), T(2)))
# end

# nls_algs = (NewtonRaphson(), TrustRegion(), LevenbergMarquardt(), PseudoTransient(),
# GeneralBroyden(), GeneralKlement(), DFSane(), nothing)

# probs_nlls = NonlinearLeastSquaresProblem[]
# nlfuncs = ((NonlinearFunction{false}((u, p) -> (u .^ 2 .- p)[1:1]), [0.1, 0.0]),
# (NonlinearFunction{false}((u, p) -> vcat(u .* u .- p, u .* u .- p)), [0.1, 0.1]),
# (NonlinearFunction{true}((du, u, p) -> du[1] = u[1] * u[1] - p,
# resid_prototype = zeros(1)), [0.1, 0.0]),
# (NonlinearFunction{true}((du, u, p) -> du .= vcat(u .* u .- p, u .* u .- p),
# resid_prototype = zeros(4)), [0.1, 0.1]))
# for (fn, u0) in nlfuncs
# push!(probs_nlls, NonlinearLeastSquaresProblem(fn, u0, 2.0))
# end
# nlfuncs = ((NonlinearFunction{false}((u, p) -> (u .^ 2 .- p)[1:1]), Float32[0.1, 0.0]),
# (NonlinearFunction{false}((u, p) -> vcat(u .* u .- p, u .* u .- p)),
# Float32[0.1, 0.1]),
# (NonlinearFunction{true}((du, u, p) -> du[1] = u[1] * u[1] - p,
# resid_prototype = zeros(Float32, 1)), Float32[0.1, 0.0]),
# (NonlinearFunction{true}((du, u, p) -> du .= vcat(u .* u .- p, u .* u .- p),
# resid_prototype = zeros(Float32, 4)), Float32[0.1, 0.1]))
# for (fn, u0) in nlfuncs
# push!(probs_nlls, NonlinearLeastSquaresProblem(fn, u0, 2.0f0))
# end

# nlls_algs = (LevenbergMarquardt(), GaussNewton(),
# LevenbergMarquardt(; linsolve = LUFactorization()),
# GaussNewton(; linsolve = LUFactorization()))

# @compile_workload begin
# for prob in probs_nls, alg in nls_algs
# solve(prob, alg, abstol = 1e-2)
# end
# for prob in probs_nlls, alg in nlls_algs
# solve(prob, alg, abstol = 1e-2)
# end
# end
# end

export RadiusUpdateSchemes

96 changes: 52 additions & 44 deletions src/jacobian.jl
@@ -3,8 +3,11 @@
Jᵀ
end

SciMLBase.isinplace(JᵀJ::KrylovJᵀJ) = isinplace(JᵀJ.Jᵀ)
__maybe_symmetric(x::KrylovJᵀJ) = x.JᵀJ

isinplace(JᵀJ::KrylovJᵀJ) = isinplace(JᵀJ.Jᵀ)
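
For orientation: KrylovJᵀJ lazily represents the normal-form system used by the least-squares paths, where a Gauss-Newton or Levenberg-Marquardt step solves (JᵀJ) δ = Jᵀ f rather than J δ = f, and __maybe_symmetric tells the linear solver when that operator may be treated as symmetric. A small dense illustration of the same step:

using LinearAlgebra

# Dense version of the normal-equation solve that KrylovJᵀJ represents lazily
# for Krylov-based linear solvers.
J = [1.0 2.0; 3.0 4.0; 5.0 6.0]   # overdetermined Jacobian
f = [1.0, 2.0, 3.0]               # residual
δ = Symmetric(J' * J) \ (J' * f)
@assert J' * (J * δ) ≈ J' * f
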

# Select if we are going to use sparse differentiation or not
sparsity_detection_alg(_, _) = NoSparsityDetection()
function sparsity_detection_alg(f, ad::AbstractSparseADType)
if f.sparsity === nothing
@@ -33,13 +36,21 @@ function jacobian!!(J::Union{AbstractMatrix{<:Number}, Nothing}, cache)
@unpack f, uf, u, p, jac_cache, alg, fu2 = cache
iip = isinplace(cache)
if iip
has_jac(f) ? f.jac(J, u, p) :
sparse_jacobian!(J, alg.ad, jac_cache, uf, fu2, _maybe_mutable(u, alg.ad))
if has_jac(f)
f.jac(J, u, p)
else
sparse_jacobian!(J, alg.ad, jac_cache, uf, fu2, u)
end
return J
else
return has_jac(f) ? f.jac(u, p) :
sparse_jacobian!(J, alg.ad, jac_cache, uf, _maybe_mutable(u, alg.ad))
if has_jac(f)
return f.jac(u, p)
elseif can_setindex(typeof(J))
return sparse_jacobian!(J, alg.ad, jac_cache, uf, u)
else
return sparse_jacobian(alg.ad, jac_cache, uf, u)
end
end
return J
end
# Scalar case
jacobian!!(::Number, cache) = last(value_derivative(cache.uf, cache.u))
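
The reworked out-of-place branch above chooses between SparseDiffTools' mutating and non-mutating entry points based on can_setindex. A self-contained sketch of those two calls, assuming the sparse_jacobian_cache / sparse_jacobian / sparse_jacobian! signatures exactly as they are used in this file:

using ADTypes, SparseDiffTools

# Out-of-place vs. in-place Jacobian evaluation, mirroring the branch above.
g(u) = u .^ 2 .- 2.0
u0 = [1.0, 2.0, 3.0]
ad = AutoForwardDiff()
cache = sparse_jacobian_cache(ad, NoSparsityDetection(), g, u0)
J = sparse_jacobian(ad, cache, g, u0)      # allocates and returns the Jacobian
J2 = similar(J)
sparse_jacobian!(J2, ad, cache, g, u0)     # fills a preallocated matrix instead
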
@@ -59,13 +70,13 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u, p, ::Val
alg_wants_jac = (concrete_jac(alg) !== nothing && concrete_jac(alg))

# NOTE: The deepcopy is needed here since we are using the resid_prototype elsewhere
fu = f.resid_prototype === nothing ? (iip ? _mutable_zero(u) : _mutable(f(u, p))) :
fu = f.resid_prototype === nothing ? (iip ? zero(u) : f(u, p)) :
(iip ? deepcopy(f.resid_prototype) : f.resid_prototype)
if !has_analytic_jac && (linsolve_needs_jac || alg_wants_jac)
sd = sparsity_detection_alg(f, alg.ad)
ad = alg.ad
jac_cache = iip ? sparse_jacobian_cache(ad, sd, uf, fu, _maybe_mutable(u, ad)) :
sparse_jacobian_cache(ad, sd, uf, _maybe_mutable(u, ad); fx = fu)
jac_cache = iip ? sparse_jacobian_cache(ad, sd, uf, fu, u) :
sparse_jacobian_cache(ad, sd, uf, __maybe_mutable(u, ad); fx = fu)
else
jac_cache = nothing
end
@@ -76,11 +87,11 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u, p, ::Val
JacVec(uf, u; fu, autodiff = __get_nonsparse_ad(alg.ad))
else
if iip
jvp = (_, u, v) -> (du = similar(fu); f.jvp(du, v, u, p); du)
jvp! = (du, _, u, v) -> f.jvp(du, v, u, p)
jvp = (_, u, v) -> (du_ = similar(fu); f.jvp(du_, v, u, p); du_)
jvp! = (du_, _, u, v) -> f.jvp(du_, v, u, p)
else
jvp = (_, u, v) -> f.jvp(v, u, p)
jvp! = (du, _, u, v) -> (du .= f.jvp(v, u, p))
jvp! = (du_, _, u, v) -> (du_ .= f.jvp(v, u, p))
end
op = SparseDiffTools.FwdModeAutoDiffVecProd(f, u, (), jvp, jvp!)
FunctionOperator(op, u, fu; isinplace = Val(true), outofplace = Val(false),
@@ -89,24 +100,27 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u, p, ::Val
else
if has_analytic_jac
f.jac_prototype === nothing ? undefmatrix(u) : f.jac_prototype
elseif f.jac_prototype === nothing
init_jacobian(jac_cache; preserve_immutable = Val(true))
else
f.jac_prototype === nothing ? init_jacobian(jac_cache) : f.jac_prototype
f.jac_prototype
end
end

du = _mutable_zero(u)
du = copy(u)

if needsJᵀJ
JᵀJ, Jᵀfu = __init_JᵀJ(J, _vec(fu), uf, u; f,
vjp_autodiff = __get_nonsparse_ad(_getproperty(alg, Val(:vjp_autodiff))),
vjp_autodiff = __get_nonsparse_ad(__getproperty(alg, Val(:vjp_autodiff))),
jvp_autodiff = __get_nonsparse_ad(alg.ad))
end

if linsolve_init
linprob_A = alg isa PseudoTransient ?
(J - (1 / (convert(eltype(u), alg.alpha_initial))) * I) :
(needsJᵀJ ? __maybe_symmetric(JᵀJ) : J)
linsolve = __setup_linsolve(linprob_A, needsJᵀJ ? Jᵀfu : fu, du, p, alg)
linsolve = linsolve_caches(linprob_A, needsJᵀJ ? Jᵀfu : fu, du, p, alg;
linsolve_kwargs)
else
linsolve = nothing
end
@@ -115,22 +129,33 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u, p, ::Val
return uf, linsolve, J, fu, jac_cache, du
end

function __setup_linsolve(A, b, u, p, alg)
linprob = LinearProblem(A, _vec(b); u0 = _vec(u))
## Special Handling for Scalars
function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u::Number, p,
::Val{false}; linsolve_with_JᵀJ::Val{needsJᵀJ} = Val(false),
kwargs...) where {needsJᵀJ, F}
# NOTE: Scalar `u` assumes scalar output from `f`
uf = SciMLBase.JacobianWrapper{false}(f, p)
needsJᵀJ && return uf, nothing, u, nothing, nothing, u, u, u
return uf, FakeLinearSolveJLCache(u, u), u, nothing, nothing, u
end

weight = similar(u)
recursivefill!(weight, true)
# Linear Solve Cache
function linsolve_caches(A, b, u, p, alg; linsolve_kwargs = (;))
if alg.linsolve === nothing && A isa SMatrix && linsolve_kwargs === (;)
# Default handling for SArrays in LinearSolve is not great. Some parts are patched
# but there are quite a few unnecessary allocations
return FakeLinearSolveJLCache(A, b)
end

linprob = LinearProblem(A, _vec(b); u0 = _vec(u), linsolve_kwargs...)

weight = __init_ones(u)

Pl, Pr = wrapprecs(alg.precs(A, nothing, u, p, nothing, nothing, nothing, nothing,
nothing)..., weight)
return init(linprob, alg.linsolve; alias_A = true, alias_b = true, Pl, Pr)
end
__setup_linsolve(A::KrylovJᵀJ, b, u, p, alg) = __setup_linsolve(A.JᵀJ, b, u, p, alg)

__get_nonsparse_ad(::AutoSparseForwardDiff) = AutoForwardDiff()
__get_nonsparse_ad(::AutoSparseFiniteDiff) = AutoFiniteDiff()
__get_nonsparse_ad(::AutoSparseZygote) = AutoZygote()
__get_nonsparse_ad(ad) = ad
linsolve_caches(A::KrylovJᵀJ, b, u, p, alg) = linsolve_caches(A.JᵀJ, b, u, p, alg)
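
The SMatrix branch in linsolve_caches above bypasses LinearSolve because small static systems are cheapest to solve directly; the behavior that path ultimately needs is little more than a backslash solve with no persistent factorization. A sketch of the rationale (not of FakeLinearSolveJLCache itself):

using StaticArrays

# A direct backslash solve on an SMatrix is specialized by StaticArrays and
# needs no factorization object to reuse, which is why a full LinearSolve
# cache buys little here.
A = @SMatrix [4.0 1.0; 1.0 3.0]
b = @SVector [1.0, 2.0]
x = A \ b
@assert A * x ≈ b
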

__init_JᵀJ(J::Number, args...; kwargs...) = zero(J), zero(J)
function __init_JᵀJ(J::AbstractArray, fu, args...; kwargs...)
@@ -180,24 +205,7 @@ function __concrete_vjp_autodiff(vjp_autodiff, jvp_autodiff, uf)
end
end

__maybe_symmetric(x) = Symmetric(x)
__maybe_symmetric(x::Number) = x
# LinearSolve with `nothing` doesn't dispatch correctly here
__maybe_symmetric(x::StaticArray) = x
__maybe_symmetric(x::SparseArrays.AbstractSparseMatrix) = x
__maybe_symmetric(x::SciMLOperators.AbstractSciMLOperator) = x
__maybe_symmetric(x::KrylovJᵀJ) = x.JᵀJ

## Special Handling for Scalars
function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u::Number, p,
::Val{false}; linsolve_with_JᵀJ::Val{needsJᵀJ} = Val(false),
kwargs...) where {needsJᵀJ, F}
# NOTE: Scalar `u` assumes scalar output from `f`
uf = SciMLBase.JacobianWrapper{false}(f, p)
needsJᵀJ && return uf, nothing, u, nothing, nothing, u, u, u
return uf, nothing, u, nothing, nothing, u
end

# Generic Handling of Krylov Methods for Normal Form Linear Solves
function __update_JᵀJ!(iip::Val, cache, sym::Symbol, J)
return __update_JᵀJ!(iip, cache, sym, getproperty(cache, sym), J)
end
2 changes: 1 addition & 1 deletion src/klement.jl
@@ -87,7 +87,7 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::GeneralKlement
linsolve_alg = alg_.linsolve === nothing && u isa Array ? LUFactorization() :
nothing
alg = set_linsolve(alg_, linsolve_alg)
linsolve = __setup_linsolve(J, _vec(fu), _vec(du), p, alg)
linsolve = linsolve_caches(J, _vec(fu), _vec(du), p, alg)
end

abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, fu, u,
2 changes: 1 addition & 1 deletion src/levenberg.jl
@@ -232,7 +232,7 @@ function SciMLBase.__init(prob::Union{NonlinearProblem{uType, iip},
fill!(mat_tmp, zero(eltype(u)))
rhs_tmp = vcat(_vec(fu1), _vec(u))
fill!(rhs_tmp, zero(eltype(u)))
linsolve = __setup_linsolve(mat_tmp, rhs_tmp, u, p, alg)
linsolve = linsolve_caches(mat_tmp, rhs_tmp, u, p, alg)
end

return LevenbergMarquardtCache{iip, !_unwrap_val(linsolve_with_JᵀJ)}(f, alg, u, copy(u),
(The remaining 4 changed files in this commit are not shown above.)