diff --git a/Project.toml b/Project.toml index 1e47a2f8..d5274406 100644 --- a/Project.toml +++ b/Project.toml @@ -4,6 +4,7 @@ authors = ["Chris Rackauckas "] version = "6.71.1" [deps] +ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" @@ -20,6 +21,7 @@ Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" MuladdMacro = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221" NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56" OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" +OrdinaryDiffEqCore = "bbf590c4-e513-4bbe-9b18-05decba2e5d8" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" RandomNumbers = "e6cf234a-135c-5ec9-84dd-332b85af5143" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" @@ -32,6 +34,7 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed" [compat] +ADTypes = "1" Adapt = "3, 4" ArrayInterface = "6, 7" DataStructures = "0.18" @@ -48,11 +51,12 @@ Logging = "1.6" MuladdMacro = "0.2.1" NLsolve = "4" OrdinaryDiffEq = "6.87" +OrdinaryDiffEqCore = "1.12.1" Random = "1.6" RandomNumbers = "1.5.3" RecursiveArrayTools = "2, 3" Reexport = "0.2, 1.0" -SciMLBase = "2.59.2" +SciMLBase = "2.65" SciMLOperators = "0.2.9, 0.3" SparseArrays = "1.6" SparseDiffTools = "2" diff --git a/src/StochasticDiffEq.jl b/src/StochasticDiffEq.jl index caf8c11d..b1da7e57 100644 --- a/src/StochasticDiffEq.jl +++ b/src/StochasticDiffEq.jl @@ -9,6 +9,8 @@ using DocStringExtensions using Reexport @reexport using DiffEqBase + import ADTypes + import OrdinaryDiffEq import OrdinaryDiffEq: default_controller, isstandard, ispredictive, beta2_default, beta1_default, gamma_default, @@ -41,7 +43,7 @@ using DocStringExtensions import DiffEqBase: step!, initialize!, DEAlgorithm, AbstractSDEAlgorithm, AbstractRODEAlgorithm, DEIntegrator, AbstractDiffEqInterpolation, DECache, AbstractSDEIntegrator, AbstractRODEIntegrator, AbstractContinuousCallback, - Tableau + Tableau, AbstractSDDEIntegrator # Integrator Interface import DiffEqBase: resize!,deleteat!,addat!,full_cache,user_cache,u_cache,du_cache, @@ -58,6 +60,8 @@ using OrdinaryDiffEq: nlsolvefail, isnewton, set_new_W!, get_W, _vec, _reshape using OrdinaryDiffEq: NLSolver +import OrdinaryDiffEqCore + if isdefined(OrdinaryDiffEq,:FastConvergence) using OrdinaryDiffEq: FastConvergence, Convergence, SlowConvergence, VerySlowConvergence, Divergence @@ -119,6 +123,7 @@ end include("cache_utils.jl") include("integrators/integrator_interface.jl") include("iterator_interface.jl") + include("initialize_dae.jl") include("solve.jl") include("initdt.jl") include("perform_step/low_order.jl") diff --git a/src/alg_utils.jl b/src/alg_utils.jl index 6ddf4083..4d22b89e 100644 --- a/src/alg_utils.jl +++ b/src/alg_utils.jl @@ -9,6 +9,31 @@ SciMLBase.forwarddiffs_model(alg::Union{StochasticDiffEqNewtonAlgorithm, StochasticDiffEqNewtonAdaptiveAlgorithm,StochasticDiffEqJumpNewtonAdaptiveAlgorithm, StochasticDiffEqJumpNewtonDiffusionAdaptiveAlgorithm}) = OrdinaryDiffEq.alg_autodiff(alg) +# Required for initialization, because ODECore._initialize_dae! calls it during +# OverrideInit +OrdinaryDiffEqCore.has_autodiff(::Union{StochasticDiffEqAlgorithm,StochasticDiffEqRODEAlgorithm,StochasticDiffEqJumpAlgorithm}) = false +for T in [ + StochasticDiffEqNewtonAlgorithm, StochasticDiffEqNewtonAdaptiveAlgorithm, + StochasticDiffEqJumpNewtonAdaptiveAlgorithm, + StochasticDiffEqJumpNewtonDiffusionAdaptiveAlgorithm] + @eval OrdinaryDiffEqCore.has_autodiff(::$T) = true +end + +_alg_autodiff(::StochasticDiffEqNewtonAlgorithm{T, AD}) where {T, AD} = Val{AD}() +_alg_autodiff(::StochasticDiffEqNewtonAdaptiveAlgorithm{T, AD}) where {T, AD} = Val{AD}() +_alg_autodiff(::StochasticDiffEqJumpNewtonAdaptiveAlgorithm{T, AD}) where {T, AD} = Val{AD}() +_alg_autodiff(::StochasticDiffEqJumpNewtonDiffusionAdaptiveAlgorithm{T, AD}) where {T, AD} = Val{AD}() + +function OrdinaryDiffEqCore.alg_autodiff(alg) + ad = _alg_autodiff(alg) + if ad == Val(false) + return ADTypes.AutoFiniteDiff() + elseif ad == Val(true) + return ADTypes.AutoForwardDiff() + else + return SciMLBase._unwrap_val(ad) + end +end isadaptive(alg::Union{StochasticDiffEqAlgorithm,StochasticDiffEqRODEAlgorithm}) = false isadaptive(alg::Union{StochasticDiffEqAdaptiveAlgorithm,StochasticDiffEqRODEAdaptiveAlgorithm,StochasticDiffEqJumpAdaptiveAlgorithm,StochasticDiffEqJumpDiffusionAdaptiveAlgorithm}) = true diff --git a/src/initialize_dae.jl b/src/initialize_dae.jl new file mode 100644 index 00000000..f9bfbfc0 --- /dev/null +++ b/src/initialize_dae.jl @@ -0,0 +1,13 @@ +struct SDEDefaultInit <: DiffEqBase.DAEInitializationAlgorithm end + +function DiffEqBase.initialize_dae!(integrator::Union{AbstractSDEIntegrator, AbstractSDDEIntegrator}, initializealg = integrator.initializealg) + OrdinaryDiffEqCore._initialize_dae!(integrator, integrator.sol.prob, initializealg, Val(DiffEqBase.isinplace(integrator.sol.prob))) +end + +function OrdinaryDiffEqCore._initialize_dae!(integrator::Union{AbstractSDEIntegrator, AbstractSDDEIntegrator}, prob, ::SDEDefaultInit, isinplace) + if SciMLBase.has_initializeprob(prob.f) + OrdinaryDiffEqCore._initialize_dae!(integrator, prob, SciMLBase.OverrideInit(), isinplace) + elseif SciMLBase.__has_mass_matrix(prob.f) + OrdinaryDiffEqCore._initialize_dae!(integrator, prob, SciMLBase.CheckInit(), isinplace) + end +end diff --git a/src/integrators/type.jl b/src/integrators/type.jl index b7780597..979fdf60 100644 --- a/src/integrators/type.jl +++ b/src/integrators/type.jl @@ -1,4 +1,4 @@ -mutable struct SDEIntegrator{algType,IIP,uType,uEltype,tType,tdirType,P2,eigenType,tTypeNoUnits,uEltypeNoUnits,randType,randType2,rateType,solType,cacheType,F4,F5,F6,OType,noiseType,EventErrorType,CallbackCacheType,RCs} <: AbstractSDEIntegrator{algType,IIP,uType,tType} +mutable struct SDEIntegrator{algType,IIP,uType,uEltype,tType,tdirType,P2,eigenType,tTypeNoUnits,uEltypeNoUnits,randType,randType2,rateType,solType,cacheType,F4,F5,F6,OType,noiseType,EventErrorType,CallbackCacheType,RCs,IA} <: AbstractSDEIntegrator{algType,IIP,uType,tType} f::F4 g::F5 c::F6 @@ -43,4 +43,5 @@ mutable struct SDEIntegrator{algType,IIP,uType,uEltype,tType,tdirType,P2,eigenTy qold::tTypeNoUnits q11::tTypeNoUnits stats::DiffEqBase.Stats + initializealg::IA end diff --git a/src/solve.jl b/src/solve.jl index d7bc2f8c..8858694d 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -66,6 +66,7 @@ function DiffEqBase.__init( userdata=nothing, initialize_integrator=true, seed = UInt64(0), alias_u0=false, alias_jumps = Threads.threadid()==1, + initializealg = SDEDefaultInit(), kwargs...) where recompile_flag prob = concrete_prob(_prob) @@ -587,7 +588,8 @@ function DiffEqBase.__init( uBottomEltype,tType,typeof(tdir),typeof(p), typeof(eigen_est),QT, uEltypeNoUnits,typeof(W),typeof(P),rateType,typeof(sol),typeof(cache), - FType,GType,CType,typeof(opts),typeof(noise),typeof(last_event_error),typeof(callback_cache),typeof(rate_constants)}( + FType,GType,CType,typeof(opts),typeof(noise),typeof(last_event_error),typeof(callback_cache),typeof(rate_constants), + typeof(initializealg)}( f,g,c,noise,uprev,tprev,t,u,p,tType(dt),tType(dt),tType(dt),dtcache,tspan[2],tdir, just_hit_tstop,do_error_check,isout,event_last_time, vector_event_last_time,last_event_error,accept_step, @@ -597,9 +599,10 @@ function DiffEqBase.__init( alg,sol, cache,callback_cache,tType(dt),W,P,rate_constants, opts,iter,success_iter,eigen_est,EEst,q, - QT(qoldinit),q11,stats) + QT(qoldinit),q11,stats,initializealg) if initialize_integrator + DiffEqBase.initialize_dae!(integrator) initialize_callbacks!(integrator, initialize_save) initialize!(integrator,integrator.cache) save_start && alg isa Union{StochasticDiffEqCompositeAlgorithm,