diff --git a/Project.toml b/Project.toml index 60764651b..8a42f9d21 100644 --- a/Project.toml +++ b/Project.toml @@ -16,6 +16,7 @@ LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" +MaybeInplace = "bb5d69b7-63fc-4a16-80bd-7e42200c7bdb" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" @@ -42,8 +43,8 @@ NonlinearSolveZygoteExt = "Zygote" [compat] ADTypes = "0.2" -ArrayInterface = "6.0.24, 7" Aqua = "0.8" +ArrayInterface = "6.0.24, 7" BandedMatrices = "1" BenchmarkTools = "1" ConcreteStructs = "0.2" @@ -70,9 +71,9 @@ Reexport = "0.2, 1" SafeTestsets = "0.1" SciMLBase = "2.9" SciMLOperators = "0.3" -SimpleNonlinearSolve = "0.1.23" +SimpleNonlinearSolve = "1" # FIXME: Don't update the version in this PR. Using it to test SparseArrays = "<0.0.1, 1" -SparseDiffTools = "2.12" +SparseDiffTools = "2.14" StaticArrays = "1" StaticArraysCore = "1.4" Symbolics = "5" diff --git a/src/NonlinearSolve.jl b/src/NonlinearSolve.jl index c591eb4ee..f050bf007 100644 --- a/src/NonlinearSolve.jl +++ b/src/NonlinearSolve.jl @@ -8,25 +8,24 @@ import Reexport: @reexport import PrecompileTools: @recompile_invalidations, @compile_workload, @setup_workload @recompile_invalidations begin - using DiffEqBase, - LazyArrays, LinearAlgebra, LinearSolve, Printf, SparseArrays, + using DiffEqBase, LazyArrays, LinearAlgebra, LinearSolve, Printf, SparseArrays, SparseDiffTools - using FastBroadcast: @.. - import ArrayInterface: restructure import ADTypes: AbstractFiniteDifferencesMode - import ArrayInterface: undefmatrix, + import ArrayInterface: undefmatrix, restructure, can_setindex, matrix_colors, parameterless_type, ismutable, issingular, fast_scalar_indexing import ConcreteStructs: @concrete import EnumX: @enumx + import FastBroadcast: @.. import ForwardDiff import ForwardDiff: Dual import LinearSolve: ComposePreconditioner, InvPreconditioner, needs_concrete_A + import MaybeInplace: @bb import RecursiveArrayTools: ArrayPartition, AbstractVectorOfArray, recursivecopy!, recursivefill! import SciMLBase: AbstractNonlinearAlgorithm, NLStats, _unwrap_val, has_jac, isinplace import SciMLOperators: FunctionOperator - import StaticArraysCore: StaticArray, SVector, SArray, MArray + import StaticArraysCore: StaticArray, SVector, SArray, MArray, Size, SMatrix import UnPack: @unpack using ADTypes, LineSearches, SciMLBase, SimpleNonlinearSolve @@ -55,13 +54,13 @@ isinplace(::AbstractNonlinearSolveCache{iip}) where {iip} = iip function Base.show(io::IO, alg::AbstractNonlinearSolveAlgorithm) str = "$(nameof(typeof(alg)))(" modifiers = String[] - if _getproperty(alg, Val(:ad)) !== nothing + if __getproperty(alg, Val(:ad)) !== nothing push!(modifiers, "ad = $(nameof(typeof(alg.ad)))()") end - if _getproperty(alg, Val(:linsolve)) !== nothing + if __getproperty(alg, Val(:linsolve)) !== nothing push!(modifiers, "linsolve = $(nameof(typeof(alg.linsolve)))()") end - if _getproperty(alg, Val(:linesearch)) !== nothing + if __getproperty(alg, Val(:linesearch)) !== nothing ls = alg.linesearch if ls isa LineSearch ls.method !== nothing && @@ -70,7 +69,7 @@ function Base.show(io::IO, alg::AbstractNonlinearSolveAlgorithm) push!(modifiers, "linesearch = $(nameof(typeof(alg.linesearch)))()") end end - if _getproperty(alg, Val(:radius_update_scheme)) !== nothing + if __getproperty(alg, Val(:radius_update_scheme)) !== nothing push!(modifiers, "radius_update_scheme = $(alg.radius_update_scheme)") end str = str * join(modifiers, ", ") @@ -107,7 +106,7 @@ function SciMLBase.solve!(cache::AbstractNonlinearSolveCache) end end - trace = _getproperty(cache, Val{:trace}()) + trace = __getproperty(cache, Val{:trace}()) if trace !== nothing update_trace!(trace, cache.stats.nsteps, get_u(cache), get_fu(cache), nothing, nothing, nothing; last = Val(true)) @@ -134,52 +133,52 @@ include("jacobian.jl") include("ad.jl") include("default.jl") -@setup_workload begin - nlfuncs = ((NonlinearFunction{false}((u, p) -> u .* u .- p), 0.1), - (NonlinearFunction{false}((u, p) -> u .* u .- p), [0.1]), - (NonlinearFunction{true}((du, u, p) -> du .= u .* u .- p), [0.1])) - probs_nls = NonlinearProblem[] - for T in (Float32, Float64), (fn, u0) in nlfuncs - push!(probs_nls, NonlinearProblem(fn, T.(u0), T(2))) - end - - nls_algs = (NewtonRaphson(), TrustRegion(), LevenbergMarquardt(), PseudoTransient(), - GeneralBroyden(), GeneralKlement(), DFSane(), nothing) - - probs_nlls = NonlinearLeastSquaresProblem[] - nlfuncs = ((NonlinearFunction{false}((u, p) -> (u .^ 2 .- p)[1:1]), [0.1, 0.0]), - (NonlinearFunction{false}((u, p) -> vcat(u .* u .- p, u .* u .- p)), [0.1, 0.1]), - (NonlinearFunction{true}((du, u, p) -> du[1] = u[1] * u[1] - p, - resid_prototype = zeros(1)), [0.1, 0.0]), - (NonlinearFunction{true}((du, u, p) -> du .= vcat(u .* u .- p, u .* u .- p), - resid_prototype = zeros(4)), [0.1, 0.1])) - for (fn, u0) in nlfuncs - push!(probs_nlls, NonlinearLeastSquaresProblem(fn, u0, 2.0)) - end - nlfuncs = ((NonlinearFunction{false}((u, p) -> (u .^ 2 .- p)[1:1]), Float32[0.1, 0.0]), - (NonlinearFunction{false}((u, p) -> vcat(u .* u .- p, u .* u .- p)), - Float32[0.1, 0.1]), - (NonlinearFunction{true}((du, u, p) -> du[1] = u[1] * u[1] - p, - resid_prototype = zeros(Float32, 1)), Float32[0.1, 0.0]), - (NonlinearFunction{true}((du, u, p) -> du .= vcat(u .* u .- p, u .* u .- p), - resid_prototype = zeros(Float32, 4)), Float32[0.1, 0.1])) - for (fn, u0) in nlfuncs - push!(probs_nlls, NonlinearLeastSquaresProblem(fn, u0, 2.0f0)) - end - - nlls_algs = (LevenbergMarquardt(), GaussNewton(), - LevenbergMarquardt(; linsolve = LUFactorization()), - GaussNewton(; linsolve = LUFactorization())) - - @compile_workload begin - for prob in probs_nls, alg in nls_algs - solve(prob, alg, abstol = 1e-2) - end - for prob in probs_nlls, alg in nlls_algs - solve(prob, alg, abstol = 1e-2) - end - end -end +# @setup_workload begin +# nlfuncs = ((NonlinearFunction{false}((u, p) -> u .* u .- p), 0.1), +# (NonlinearFunction{false}((u, p) -> u .* u .- p), [0.1]), +# (NonlinearFunction{true}((du, u, p) -> du .= u .* u .- p), [0.1])) +# probs_nls = NonlinearProblem[] +# for T in (Float32, Float64), (fn, u0) in nlfuncs +# push!(probs_nls, NonlinearProblem(fn, T.(u0), T(2))) +# end + +# nls_algs = (NewtonRaphson(), TrustRegion(), LevenbergMarquardt(), PseudoTransient(), +# GeneralBroyden(), GeneralKlement(), DFSane(), nothing) + +# probs_nlls = NonlinearLeastSquaresProblem[] +# nlfuncs = ((NonlinearFunction{false}((u, p) -> (u .^ 2 .- p)[1:1]), [0.1, 0.0]), +# (NonlinearFunction{false}((u, p) -> vcat(u .* u .- p, u .* u .- p)), [0.1, 0.1]), +# (NonlinearFunction{true}((du, u, p) -> du[1] = u[1] * u[1] - p, +# resid_prototype = zeros(1)), [0.1, 0.0]), +# (NonlinearFunction{true}((du, u, p) -> du .= vcat(u .* u .- p, u .* u .- p), +# resid_prototype = zeros(4)), [0.1, 0.1])) +# for (fn, u0) in nlfuncs +# push!(probs_nlls, NonlinearLeastSquaresProblem(fn, u0, 2.0)) +# end +# nlfuncs = ((NonlinearFunction{false}((u, p) -> (u .^ 2 .- p)[1:1]), Float32[0.1, 0.0]), +# (NonlinearFunction{false}((u, p) -> vcat(u .* u .- p, u .* u .- p)), +# Float32[0.1, 0.1]), +# (NonlinearFunction{true}((du, u, p) -> du[1] = u[1] * u[1] - p, +# resid_prototype = zeros(Float32, 1)), Float32[0.1, 0.0]), +# (NonlinearFunction{true}((du, u, p) -> du .= vcat(u .* u .- p, u .* u .- p), +# resid_prototype = zeros(Float32, 4)), Float32[0.1, 0.1])) +# for (fn, u0) in nlfuncs +# push!(probs_nlls, NonlinearLeastSquaresProblem(fn, u0, 2.0f0)) +# end + +# nlls_algs = (LevenbergMarquardt(), GaussNewton(), +# LevenbergMarquardt(; linsolve = LUFactorization()), +# GaussNewton(; linsolve = LUFactorization())) + +# @compile_workload begin +# for prob in probs_nls, alg in nls_algs +# solve(prob, alg, abstol = 1e-2) +# end +# for prob in probs_nlls, alg in nlls_algs +# solve(prob, alg, abstol = 1e-2) +# end +# end +# end export RadiusUpdateSchemes diff --git a/src/jacobian.jl b/src/jacobian.jl index 41c7319a1..54f1c0f0e 100644 --- a/src/jacobian.jl +++ b/src/jacobian.jl @@ -3,8 +3,11 @@ Jᵀ end -SciMLBase.isinplace(JᵀJ::KrylovJᵀJ) = isinplace(JᵀJ.Jᵀ) +__maybe_symmetric(x::KrylovJᵀJ) = x.JᵀJ + +isinplace(JᵀJ::KrylovJᵀJ) = isinplace(JᵀJ.Jᵀ) +# Select if we are going to use sparse differentiation or not sparsity_detection_alg(_, _) = NoSparsityDetection() function sparsity_detection_alg(f, ad::AbstractSparseADType) if f.sparsity === nothing @@ -33,13 +36,21 @@ function jacobian!!(J::Union{AbstractMatrix{<:Number}, Nothing}, cache) @unpack f, uf, u, p, jac_cache, alg, fu2 = cache iip = isinplace(cache) if iip - has_jac(f) ? f.jac(J, u, p) : - sparse_jacobian!(J, alg.ad, jac_cache, uf, fu2, _maybe_mutable(u, alg.ad)) + if has_jac(f) + f.jac(J, u, p) + else + sparse_jacobian!(J, alg.ad, jac_cache, uf, fu2, u) + end + return J else - return has_jac(f) ? f.jac(u, p) : - sparse_jacobian!(J, alg.ad, jac_cache, uf, _maybe_mutable(u, alg.ad)) + if has_jac(f) + return f.jac(u, p) + elseif can_setindex(typeof(J)) + return sparse_jacobian!(J, alg.ad, jac_cache, uf, u) + else + return sparse_jacobian(alg.ad, jac_cache, uf, u) + end end - return J end # Scalar case jacobian!!(::Number, cache) = last(value_derivative(cache.uf, cache.u)) @@ -59,13 +70,13 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u, p, ::Val alg_wants_jac = (concrete_jac(alg) !== nothing && concrete_jac(alg)) # NOTE: The deepcopy is needed here since we are using the resid_prototype elsewhere - fu = f.resid_prototype === nothing ? (iip ? _mutable_zero(u) : _mutable(f(u, p))) : + fu = f.resid_prototype === nothing ? (iip ? zero(u) : f(u, p)) : (iip ? deepcopy(f.resid_prototype) : f.resid_prototype) if !has_analytic_jac && (linsolve_needs_jac || alg_wants_jac) sd = sparsity_detection_alg(f, alg.ad) ad = alg.ad - jac_cache = iip ? sparse_jacobian_cache(ad, sd, uf, fu, _maybe_mutable(u, ad)) : - sparse_jacobian_cache(ad, sd, uf, _maybe_mutable(u, ad); fx = fu) + jac_cache = iip ? sparse_jacobian_cache(ad, sd, uf, fu, u) : + sparse_jacobian_cache(ad, sd, uf, __maybe_mutable(u, ad); fx = fu) else jac_cache = nothing end @@ -76,11 +87,11 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u, p, ::Val JacVec(uf, u; fu, autodiff = __get_nonsparse_ad(alg.ad)) else if iip - jvp = (_, u, v) -> (du = similar(fu); f.jvp(du, v, u, p); du) - jvp! = (du, _, u, v) -> f.jvp(du, v, u, p) + jvp = (_, u, v) -> (du_ = similar(fu); f.jvp(du_, v, u, p); du_) + jvp! = (du_, _, u, v) -> f.jvp(du_, v, u, p) else jvp = (_, u, v) -> f.jvp(v, u, p) - jvp! = (du, _, u, v) -> (du .= f.jvp(v, u, p)) + jvp! = (du_, _, u, v) -> (du_ .= f.jvp(v, u, p)) end op = SparseDiffTools.FwdModeAutoDiffVecProd(f, u, (), jvp, jvp!) FunctionOperator(op, u, fu; isinplace = Val(true), outofplace = Val(false), @@ -89,16 +100,18 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u, p, ::Val else if has_analytic_jac f.jac_prototype === nothing ? undefmatrix(u) : f.jac_prototype + elseif f.jac_prototype === nothing + init_jacobian(jac_cache; preserve_immutable = Val(true)) else - f.jac_prototype === nothing ? init_jacobian(jac_cache) : f.jac_prototype + f.jac_prototype end end - du = _mutable_zero(u) + du = copy(u) if needsJᵀJ JᵀJ, Jᵀfu = __init_JᵀJ(J, _vec(fu), uf, u; f, - vjp_autodiff = __get_nonsparse_ad(_getproperty(alg, Val(:vjp_autodiff))), + vjp_autodiff = __get_nonsparse_ad(__getproperty(alg, Val(:vjp_autodiff))), jvp_autodiff = __get_nonsparse_ad(alg.ad)) end @@ -106,7 +119,8 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u, p, ::Val linprob_A = alg isa PseudoTransient ? (J - (1 / (convert(eltype(u), alg.alpha_initial))) * I) : (needsJᵀJ ? __maybe_symmetric(JᵀJ) : J) - linsolve = __setup_linsolve(linprob_A, needsJᵀJ ? Jᵀfu : fu, du, p, alg) + linsolve = linsolve_caches(linprob_A, needsJᵀJ ? Jᵀfu : fu, du, p, alg; + linsolve_kwargs) else linsolve = nothing end @@ -115,22 +129,33 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u, p, ::Val return uf, linsolve, J, fu, jac_cache, du end -function __setup_linsolve(A, b, u, p, alg) - linprob = LinearProblem(A, _vec(b); u0 = _vec(u)) +## Special Handling for Scalars +function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u::Number, p, + ::Val{false}; linsolve_with_JᵀJ::Val{needsJᵀJ} = Val(false), + kwargs...) where {needsJᵀJ, F} + # NOTE: Scalar `u` assumes scalar output from `f` + uf = SciMLBase.JacobianWrapper{false}(f, p) + needsJᵀJ && return uf, nothing, u, nothing, nothing, u, u, u + return uf, FakeLinearSolveJLCache(u, u), u, nothing, nothing, u +end - weight = similar(u) - recursivefill!(weight, true) +# Linear Solve Cache +function linsolve_caches(A, b, u, p, alg; linsolve_kwargs = (;)) + if alg.linsolve === nothing && A isa SMatrix && linsolve_kwargs === (;) + # Default handling for SArrays in LinearSolve is not great. Some parts are patched + # but there are quite a few unnecessary allocations + return FakeLinearSolveJLCache(A, b) + end + + linprob = LinearProblem(A, _vec(b); u0 = _vec(u), linsolve_kwargs...) + + weight = __init_ones(u) Pl, Pr = wrapprecs(alg.precs(A, nothing, u, p, nothing, nothing, nothing, nothing, nothing)..., weight) return init(linprob, alg.linsolve; alias_A = true, alias_b = true, Pl, Pr) end -__setup_linsolve(A::KrylovJᵀJ, b, u, p, alg) = __setup_linsolve(A.JᵀJ, b, u, p, alg) - -__get_nonsparse_ad(::AutoSparseForwardDiff) = AutoForwardDiff() -__get_nonsparse_ad(::AutoSparseFiniteDiff) = AutoFiniteDiff() -__get_nonsparse_ad(::AutoSparseZygote) = AutoZygote() -__get_nonsparse_ad(ad) = ad +linsolve_caches(A::KrylovJᵀJ, b, u, p, alg) = linsolve_caches(A.JᵀJ, b, u, p, alg) __init_JᵀJ(J::Number, args...; kwargs...) = zero(J), zero(J) function __init_JᵀJ(J::AbstractArray, fu, args...; kwargs...) @@ -180,24 +205,7 @@ function __concrete_vjp_autodiff(vjp_autodiff, jvp_autodiff, uf) end end -__maybe_symmetric(x) = Symmetric(x) -__maybe_symmetric(x::Number) = x -# LinearSolve with `nothing` doesn't dispatch correctly here -__maybe_symmetric(x::StaticArray) = x -__maybe_symmetric(x::SparseArrays.AbstractSparseMatrix) = x -__maybe_symmetric(x::SciMLOperators.AbstractSciMLOperator) = x -__maybe_symmetric(x::KrylovJᵀJ) = x.JᵀJ - -## Special Handling for Scalars -function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u::Number, p, - ::Val{false}; linsolve_with_JᵀJ::Val{needsJᵀJ} = Val(false), - kwargs...) where {needsJᵀJ, F} - # NOTE: Scalar `u` assumes scalar output from `f` - uf = SciMLBase.JacobianWrapper{false}(f, p) - needsJᵀJ && return uf, nothing, u, nothing, nothing, u, u, u - return uf, nothing, u, nothing, nothing, u -end - +# Generic Handling of Krylov Methods for Normal Form Linear Solves function __update_JᵀJ!(iip::Val, cache, sym::Symbol, J) return __update_JᵀJ!(iip, cache, sym, getproperty(cache, sym), J) end diff --git a/src/klement.jl b/src/klement.jl index ec32dc6b8..8a9640fd4 100644 --- a/src/klement.jl +++ b/src/klement.jl @@ -87,7 +87,7 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::GeneralKleme linsolve_alg = alg_.linsolve === nothing && u isa Array ? LUFactorization() : nothing alg = set_linsolve(alg_, linsolve_alg) - linsolve = __setup_linsolve(J, _vec(fu), _vec(du), p, alg) + linsolve = linsolve_caches(J, _vec(fu), _vec(du), p, alg) end abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, fu, u, diff --git a/src/levenberg.jl b/src/levenberg.jl index dcc07d85e..94e882223 100644 --- a/src/levenberg.jl +++ b/src/levenberg.jl @@ -232,7 +232,7 @@ function SciMLBase.__init(prob::Union{NonlinearProblem{uType, iip}, fill!(mat_tmp, zero(eltype(u))) rhs_tmp = vcat(_vec(fu1), _vec(u)) fill!(rhs_tmp, zero(eltype(u))) - linsolve = __setup_linsolve(mat_tmp, rhs_tmp, u, p, alg) + linsolve = linsolve_caches(mat_tmp, rhs_tmp, u, p, alg) end return LevenbergMarquardtCache{iip, !_unwrap_val(linsolve_with_JᵀJ)}(f, alg, u, copy(u), diff --git a/src/raphson.jl b/src/raphson.jl index 594b893e5..4c4125579 100644 --- a/src/raphson.jl +++ b/src/raphson.jl @@ -80,7 +80,7 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::NewtonRaphso kwargs...) where {uType, iip} alg = get_concrete_algorithm(alg_, prob) @unpack f, u0, p = prob - u = alias_u0 ? u0 : deepcopy(u0) + u = __maybe_unaliased(u0, alias_u0) fu1 = evaluate_f(prob, u) uf, linsolve, J, fu2, jac_cache, du = jacobian_caches(alg, f, u, p, Val(iip); linsolve_kwargs) @@ -91,62 +91,37 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::NewtonRaphso ls_cache = init_linesearch_cache(alg.linesearch, f, u, p, fu1, Val(iip)) trace = init_nonlinearsolve_trace(alg, u, fu1, ApplyArray(__zero, J), du; kwargs...) - return NewtonRaphsonCache{iip}(f, alg, u, copy(u), fu1, fu2, du, p, uf, linsolve, J, + @bb u_prev = copy(u) + + return NewtonRaphsonCache{iip}(f, alg, u, u_prev, fu1, fu2, du, p, uf, linsolve, J, jac_cache, false, maxiters, internalnorm, ReturnCode.Default, abstol, reltol, prob, NLStats(1, 0, 0, 0, 0), ls_cache, tc_cache, trace) end -function perform_step!(cache::NewtonRaphsonCache{true}) - @unpack u, u_prev, fu1, f, p, alg, J, linsolve, du = cache - jacobian!!(J, cache) +function perform_step!(cache::NewtonRaphsonCache{iip}) where {iip} + @unpack alg = cache + + cache.J = jacobian!!(cache.J, cache) # u = u - J \ fu - linres = dolinsolve(alg.precs, linsolve; A = J, b = _vec(fu1), linu = _vec(du), - p, reltol = cache.abstol) + linres = dolinsolve(alg.precs, cache.linsolve; A = cache.J, b = _vec(cache.fu1), + linu = _vec(cache.du), cache.p, reltol = cache.abstol) cache.linsolve = linres.cache - # Line Search - α = perform_linesearch!(cache.ls_cache, u, du) - _axpy!(-α, du, u) - f(cache.fu1, u, p) - - update_trace!(cache.trace, cache.stats.nsteps + 1, get_u(cache), get_fu(cache), J, - cache.du, α) - - check_and_update!(cache, cache.fu1, cache.u, cache.u_prev) - - @. u_prev = u - cache.stats.nf += 1 - cache.stats.njacs += 1 - cache.stats.nsolve += 1 - cache.stats.nfactors += 1 - return nothing -end - -function perform_step!(cache::NewtonRaphsonCache{false}) - @unpack u, u_prev, fu1, f, p, alg, linsolve = cache - - cache.J = jacobian!!(cache.J, cache) - # u = u - J \ fu - if linsolve === nothing - cache.du = fu1 / cache.J - else - linres = dolinsolve(alg.precs, linsolve; A = cache.J, b = _vec(fu1), - linu = _vec(cache.du), p, reltol = cache.abstol) - cache.linsolve = linres.cache - end + !iip && (cache.du = linres.u) # Line Search - α = perform_linesearch!(cache.ls_cache, u, cache.du) - cache.u = @. u - α * cache.du # `u` might not support mutation - cache.fu1 = f(cache.u, p) + α = perform_linesearch!(cache.ls_cache, cache.u, cache.du) + @bb axpy!(-α, cache.du, cache.u) + + evaluate_f(cache, cache.u) update_trace!(cache.trace, cache.stats.nsteps + 1, get_u(cache), get_fu(cache), cache.J, cache.du, α) check_and_update!(cache, cache.fu1, cache.u, cache.u_prev) - cache.u_prev = cache.u + @bb copyto!(cache.u_prev, cache.u) cache.stats.nf += 1 cache.stats.njacs += 1 cache.stats.nsolve += 1 diff --git a/src/trace.jl b/src/trace.jl index c458c7d07..e89efe956 100644 --- a/src/trace.jl +++ b/src/trace.jl @@ -151,8 +151,10 @@ function reset!(trace::NonlinearSolveTrace) end function Base.show(io::IO, trace::NonlinearSolveTrace) - for entry in trace.history - show(io, entry) + if trace.history !== nothing + foreach(entry -> show(io, entry), trace.history) + else + print(io, "Tracing Disabled") end return nothing end diff --git a/src/trustRegion.jl b/src/trustRegion.jl index 8b4041b75..5493aa4d7 100644 --- a/src/trustRegion.jl +++ b/src/trustRegion.jl @@ -248,7 +248,7 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::TrustRegion, uf, _, J, fu2, jac_cache, du, H, g = jacobian_caches(alg, f, u, p, Val(iip); linsolve_kwargs, linsolve_with_JᵀJ = Val(true), lininit = Val(false)) g = _restructure(fu1, g) - linsolve = u isa Number ? nothing : __setup_linsolve(J, fu2, du, p, alg) + linsolve = u isa Number ? nothing : linsolve_caches(J, fu2, du, p, alg) u_tmp = zero(u) u_cauchy = zero(u) diff --git a/src/utils.jl b/src/utils.jl index bf6d1152f..d3017d42f 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,5 +1,15 @@ const DEFAULT_NORM = DiffEqBase.NONLINEARSOLVE_DEFAULT_NORM +@concrete mutable struct FakeLinearSolveJLCache + A + b +end + +@concrete struct FakeLinearSolveJLResult + cache + u +end + # Ignores NaN function __findmin(f, x) return findmin(x) do xᵢ @@ -55,7 +65,7 @@ function default_adargs_to_adtype(; chunk_size = missing, autodiff = nothing, end """ -value_derivative(f, x) + value_derivative(f, x) Compute `f(x), d/dx f(x)` in the most efficient way. """ @@ -65,10 +75,6 @@ function value_derivative(f::F, x::R) where {F, R} ForwardDiff.value(out), ForwardDiff.extract_derivative(T, out) end -function value_derivative(f::F, x::SVector) where {F} - f(x), ForwardDiff.jacobian(f, x) -end - @inline value(x) = x @inline value(x::Dual) = ForwardDiff.value(x) @inline value(x::AbstractArray{<:Dual}) = map(ForwardDiff.value, x) @@ -82,6 +88,15 @@ end DEFAULT_PRECS(W, du, u, p, t, newW, Plprev, Prprev, cachedata) = nothing, nothing +function dolinsolve(precs::P, linsolve::FakeLinearSolveJLCache; A = nothing, + linu = nothing, b = nothing, du = nothing, p = nothing, weight = nothing, + cachedata = nothing, reltol = nothing, reuse_A_if_factorization = false) where {P} + A !== nothing && (linsolve.A = A) + b !== nothing && (linsolve.b = b) + linres = linsolve.A \ linsolve.b + return FakeLinearSolveJLResult(linsolve, linres) +end + function dolinsolve(precs::P, linsolve; A = nothing, linu = nothing, b = nothing, du = nothing, p = nothing, weight = nothing, cachedata = nothing, reltol = nothing, reuse_A_if_factorization = false) where {P} @@ -155,33 +170,32 @@ _mutable_zero(x::SArray) = MArray(x) _mutable(x) = x _mutable(x::SArray) = MArray(x) -_maybe_mutable(x, ::AbstractFiniteDifferencesMode) = _mutable(x) +# __maybe_mutable(x, ::AbstractFiniteDifferencesMode) = _mutable(x) # The shadow allocated for Enzyme needs to be mutable -_maybe_mutable(x, ::AutoSparseEnzyme) = _mutable(x) -_maybe_mutable(x, _) = x +__maybe_mutable(x, ::AutoSparseEnzyme) = _mutable(x) +__maybe_mutable(x, _) = x # Helper function to get value of `f(u, p)` function evaluate_f(prob::Union{NonlinearProblem{uType, iip}, NonlinearLeastSquaresProblem{uType, iip}}, u) where {uType, iip} @unpack f, u0, p = prob if iip - fu = f.resid_prototype === nothing ? zero(u) : f.resid_prototype + fu = f.resid_prototype === nothing ? similar(u) : f.resid_prototype f(fu, u, p) else - fu = _mutable(f(u, p)) + fu = f(u, p) end return fu end -evaluate_f(cache, u; fu = nothing) = evaluate_f(cache.f, u, cache.p, Val(cache.iip); fu) - -function evaluate_f(f, u, p, ::Val{iip}; fu = nothing) where {iip} - if iip - f(fu, u, p) - return fu +function evaluate_f(cache, u) + @unpack f, p = cache.prob + if isinplace(cache) + f(get_fu(cache), u, p) else - return f(u, p) + set_fu!(cache, f(u, p)) end + return nothing end """ @@ -206,7 +220,7 @@ end function __get_concrete_algorithm(alg, prob) @unpack sparsity, jac_prototype = prob.f use_sparse_ad = sparsity !== nothing || jac_prototype !== nothing - ad = if eltype(prob.u0) <: Complex + ad = if !ForwardDiff.can_dual(eltype(prob.u0)) # Use Finite Differencing use_sparse_ad ? AutoSparseFiniteDiff() : AutoFiniteDiff() else @@ -310,16 +324,16 @@ function __init_low_rank_jacobian(u, fu, threshold::Int) end # Check Singular Matrix -_issingular(x::Number) = iszero(x) -@generated function _issingular(x::T) where {T} +@inline _issingular(x::Number) = iszero(x) +@inline @generated function _issingular(x::T) where {T} hasmethod(issingular, Tuple{T}) && return :(issingular(x)) return :(__issingular(x)) end -__issingular(x::AbstractMatrix{T}) where {T} = cond(x) > inv(sqrt(eps(real(T)))) -__issingular(x) = false ## If SciMLOperator and such +@inline __issingular(x::AbstractMatrix{T}) where {T} = cond(x) > inv(sqrt(eps(real(T)))) +@inline __issingular(x) = false ## If SciMLOperator and such # Safe getproperty -@generated function _getproperty(s::S, ::Val{X}) where {S, X} +@generated function __getproperty(s::S, ::Val{X}) where {S, X} hasfield(S, X) && return :(s.$X) return :(nothing) end @@ -348,6 +362,7 @@ _try_factorize_and_check_singular!(::Nothing, x) = _issingular(x), false return :(@. y += α * x) end +# Non-square matrix @inline _needs_square_A(_, ::Number) = true @inline _needs_square_A(_, ::StaticArray) = true @inline _needs_square_A(alg, _) = LinearSolve.needs_square_A(alg.linsolve) @@ -355,9 +370,40 @@ end # Define special concatenation for certain Array combinations @inline _vcat(x, y) = vcat(x, y) +# LazyArrays for tracing __zero(x::AbstractArray) = zero(x) __zero(x) = x LazyArrays.applied_eltype(::typeof(__zero), x) = eltype(x) LazyArrays.applied_ndims(::typeof(__zero), x) = ndims(x) LazyArrays.applied_size(::typeof(__zero), x) = size(x) LazyArrays.applied_axes(::typeof(__zero), x) = axes(x) + +# SparseAD --> NonSparseAD +@inline __get_nonsparse_ad(::AutoSparseForwardDiff) = AutoForwardDiff() +@inline __get_nonsparse_ad(::AutoSparseFiniteDiff) = AutoFiniteDiff() +@inline __get_nonsparse_ad(::AutoSparseZygote) = AutoZygote() +@inline __get_nonsparse_ad(ad) = ad + +# Use Symmetric Matrices if known to be efficient +@inline __maybe_symmetric(x) = Symmetric(x) +@inline __maybe_symmetric(x::Number) = x +## LinearSolve with `nothing` doesn't dispatch correctly here +@inline __maybe_symmetric(x::StaticArray) = x +@inline __maybe_symmetric(x::SparseArrays.AbstractSparseMatrix) = x +@inline __maybe_symmetric(x::SciMLOperators.AbstractSciMLOperator) = x + +# Unalias +@inline __maybe_unaliased(x::Union{Number, SArray}, ::Bool) = x +@inline function __maybe_unaliased(x::AbstractArray, alias::Bool) + # Spend time coping iff we will mutate the array + (alias || !can_setindex(typeof(x))) && return x + return deepcopy(x) +end + +# Init ones +@inline function __init_ones(x) + w = similar(x) + recursivefill!(w, true) + return w +end +@inline __init_ones(x::StaticArray) = ones(typeof(x))