From 98cfe4f7679f0dc2488bee75204e842b51688c72 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 28 May 2024 15:49:03 +0530 Subject: [PATCH 1/6] feat: add initializeprob_updatep! to ODEFunction, DAEFunction --- src/scimlfunctions.jl | 56 ++++++++++++++++++++++++++++--------------- 1 file changed, 37 insertions(+), 19 deletions(-) diff --git a/src/scimlfunctions.jl b/src/scimlfunctions.jl index becff1bf9..2ae083242 100644 --- a/src/scimlfunctions.jl +++ b/src/scimlfunctions.jl @@ -402,7 +402,7 @@ numerically-defined functions. """ struct ODEFunction{iip, specialize, F, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, WP, TPJ, O, TCV, - SYS, IProb, IProbMap} <: AbstractODEFunction{iip} + SYS, IProb, IProbMap, IProbP} <: AbstractODEFunction{iip} f::F mass_matrix::TMM analytic::Ta @@ -421,6 +421,7 @@ struct ODEFunction{iip, specialize, F, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TW sys::SYS initializeprob::IProb initializeprobmap::IProbMap + initializeprob_updatep!::IProbP end @doc doc""" @@ -1504,7 +1505,7 @@ automatically symbolically generating the Jacobian and more from the numerically-defined functions. """ struct DAEFunction{iip, specialize, F, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, TPJ, O, TCV, - SYS, IProb, IProbMap} <: + SYS, IProb, IProbMap, IProbP} <: AbstractDAEFunction{iip} f::F analytic::Ta @@ -1522,6 +1523,7 @@ struct DAEFunction{iip, specialize, F, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, TP sys::SYS initializeprob::IProb initializeprobmap::IProbMap + initializeprob_updatep!::IProbP end """ @@ -2376,7 +2378,8 @@ function ODEFunction{iip, specialize}(f; colorvec = __has_colorvec(f) ? f.colorvec : nothing, sys = __has_sys(f) ? f.sys : nothing, initializeprob = __has_initializeprob(f) ? f.initializeprob : nothing, - initializeprobmap = __has_initializeprobmap(f) ? f.initializeprobmap : nothing + initializeprobmap = __has_initializeprobmap(f) ? f.initializeprobmap : nothing, + initializeprob_updatep! = __has_initializeprob_updatep(f) ? f.initializeprob_updatep! : nothing ) where {iip, specialize } @@ -2434,10 +2437,11 @@ function ODEFunction{iip, specialize}(f; typeof(sparsity), Any, Any, typeof(W_prototype), Any, Any, typeof(_colorvec), - typeof(sys), Any, Any}(_f, mass_matrix, analytic, tgrad, jac, + typeof(sys), Any, Any, Any}(_f, mass_matrix, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, Wfact_t, W_prototype, paramjac, - observed, _colorvec, sys, initializeprob, initializeprobmap) + observed, _colorvec, sys, initializeprob, initializeprobmap, + initializeprob_updatep!) elseif specialize === false ODEFunction{iip, FunctionWrapperSpecialize, typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad), @@ -2447,10 +2451,12 @@ function ODEFunction{iip, specialize}(f; typeof(observed), typeof(_colorvec), typeof(sys), typeof(initializeprob), - typeof(initializeprobmap)}(_f, mass_matrix, analytic, tgrad, jac, + typeof(initializeprobmap), + typeof(initializeprob_updatep!)}(_f, mass_matrix, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, Wfact_t, W_prototype, paramjac, - observed, _colorvec, sys, initializeprob, initializeprobmap) + observed, _colorvec, sys, initializeprob, initializeprobmap, + initializeprob_updatep!) else ODEFunction{iip, specialize, typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad), @@ -2460,10 +2466,12 @@ function ODEFunction{iip, specialize}(f; typeof(observed), typeof(_colorvec), typeof(sys), typeof(initializeprob), - typeof(initializeprobmap)}(_f, mass_matrix, analytic, tgrad, jac, + typeof(initializeprobmap), + typeof(initializeprob_updatep!)}(_f, mass_matrix, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, Wfact_t, W_prototype, paramjac, - observed, _colorvec, sys, initializeprob, initializeprobmap) + observed, _colorvec, sys, initializeprob, initializeprobmap, + initializeprob_updatep!) end end @@ -2480,10 +2488,11 @@ function unwrapped_f(f::ODEFunction, newf = unwrapped_f(f.f)) Any, Any, Any, Any, typeof(f.jac_prototype), typeof(f.sparsity), Any, Any, Any, Any, typeof(f.colorvec), - typeof(f.sys), Any, Any}(newf, f.mass_matrix, f.analytic, f.tgrad, f.jac, + typeof(f.sys), Any, Any, Any}(newf, f.mass_matrix, f.analytic, f.tgrad, f.jac, f.jvp, f.vjp, f.jac_prototype, f.sparsity, f.Wfact, f.Wfact_t, f.W_prototype, f.paramjac, - f.observed, f.colorvec, f.sys, f.initializeprob, f.initializeprobmap) + f.observed, f.colorvec, f.sys, f.initializeprob, f.initializeprobmap, + f.initializeprob_updatep!) else ODEFunction{isinplace(f), specialization(f), typeof(newf), typeof(f.mass_matrix), typeof(f.analytic), typeof(f.tgrad), @@ -2492,11 +2501,12 @@ function unwrapped_f(f::ODEFunction, newf = unwrapped_f(f.f)) typeof(f.paramjac), typeof(f.observed), typeof(f.colorvec), typeof(f.sys), typeof(f.initializeprob), - typeof(f.initializeprobmap)}(newf, f.mass_matrix, f.analytic, f.tgrad, f.jac, - f.jvp, f.vjp, f.jac_prototype, f.sparsity, f.Wfact, + typeof(f.initializeprobmap), + typeof(f.initializeprob_updatep!)}(newf, f.mass_matrix, f.analytic, f.tgrad, + f.jac, f.jvp, f.vjp, f.jac_prototype, f.sparsity, f.Wfact, f.Wfact_t, f.W_prototype, f.paramjac, f.observed, f.colorvec, f.sys, f.initializeprob, - f.initializeprobmap) + f.initializeprobmap, f.initializeprob_updatep!) end end @@ -3288,7 +3298,9 @@ function DAEFunction{iip, specialize}(f; colorvec = __has_colorvec(f) ? f.colorvec : nothing, sys = __has_sys(f) ? f.sys : nothing, initializeprob = __has_initializeprob(f) ? f.initializeprob : nothing, - initializeprobmap = __has_initializeprobmap(f) ? f.initializeprobmap : nothing) where { + initializeprobmap = __has_initializeprobmap(f) ? f.initializeprobmap : nothing, + initializeprob_updatep! = __has_initializeprob_updatep(f) ? f.initializeprob_updatep! : nothing + ) where { iip, specialize } @@ -3328,21 +3340,22 @@ function DAEFunction{iip, specialize}(f; DAEFunction{iip, specialize, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, - Any, typeof(_colorvec), Any, Any, Any}(_f, analytic, tgrad, jac, jvp, + Any, typeof(_colorvec), Any, Any, Any, Any}(_f, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, Wfact_t, paramjac, observed, - _colorvec, sys, initializeprob, initializeprobmap) + _colorvec, sys, initializeprob, initializeprobmap, initializeprob_updatep!) else DAEFunction{iip, specialize, typeof(_f), typeof(analytic), typeof(tgrad), typeof(jac), typeof(jvp), typeof(vjp), typeof(jac_prototype), typeof(sparsity), typeof(Wfact), typeof(Wfact_t), typeof(paramjac), typeof(observed), typeof(_colorvec), - typeof(sys), typeof(initializeprob), typeof(initializeprobmap)}( + typeof(sys), typeof(initializeprob), typeof(initializeprobmap), + typeof(initializeprob_updatep!)}( _f, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, Wfact_t, paramjac, observed, - _colorvec, sys, initializeprob, initializeprobmap) + _colorvec, sys, initializeprob, initializeprobmap, initializeprob_updatep!) end end @@ -4232,6 +4245,7 @@ __has_analytic_full(f) = isdefined(f, :analytic_full) __has_resid_prototype(f) = isdefined(f, :resid_prototype) __has_initializeprob(f) = isdefined(f, :initializeprob) __has_initializeprobmap(f) = isdefined(f, :initializeprobmap) +__has_initializeprob_updatep(f) = isdefined(f, :initializeprob_updatep!) # compatibility has_invW(f::AbstractSciMLFunction) = false @@ -4250,6 +4264,10 @@ end function has_initializeprobmap(f::AbstractSciMLFunction) __has_initializeprobmap(f) && f.initializeprobmap !== nothing end +function has_initializeprob_updatep(f::AbstractSciMLFunction) + __has_initializeprob_updatep(f) && f.initializeprob_updatep! !== nothing +end + function has_syms(f::AbstractSciMLFunction) if __has_syms(f) From c7227ef46177d20a7d916fee65611c00c6903617 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 29 May 2024 15:45:27 +0530 Subject: [PATCH 2/6] fixup! feat: add initializeprob_updatep! to ODEFunction, DAEFunction --- src/scimlfunctions.jl | 89 ++++++++++++++++++++++++------------------- 1 file changed, 49 insertions(+), 40 deletions(-) diff --git a/src/scimlfunctions.jl b/src/scimlfunctions.jl index 2ae083242..414577426 100644 --- a/src/scimlfunctions.jl +++ b/src/scimlfunctions.jl @@ -402,7 +402,7 @@ numerically-defined functions. """ struct ODEFunction{iip, specialize, F, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, WP, TPJ, O, TCV, - SYS, IProb, IProbMap, IProbP} <: AbstractODEFunction{iip} + SYS, IProb, IProbInit, IProbUp} <: AbstractODEFunction{iip} f::F mass_matrix::TMM analytic::Ta @@ -419,9 +419,14 @@ struct ODEFunction{iip, specialize, F, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TW observed::O colorvec::TCV sys::SYS + # The initialization problem. initializeprob::IProb - initializeprobmap::IProbMap - initializeprob_updatep!::IProbP + # Function which takes (initializeprob, integrator) and updates the problem with + # unknown and parameter values from the integrator. + initializeprob_init!::IProbInit + # Function which takes (integrator, initializesol) and updates the integrator with + # unknown and parameter values from initializesol (solution of initializeprob). + initializeprob_update!::IProbUp end @doc doc""" @@ -1505,7 +1510,7 @@ automatically symbolically generating the Jacobian and more from the numerically-defined functions. """ struct DAEFunction{iip, specialize, F, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, TPJ, O, TCV, - SYS, IProb, IProbMap, IProbP} <: + SYS, IProb, IProbInit, IProbUp} <: AbstractDAEFunction{iip} f::F analytic::Ta @@ -1521,9 +1526,14 @@ struct DAEFunction{iip, specialize, F, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, TP observed::O colorvec::TCV sys::SYS + # The initialization problem. initializeprob::IProb - initializeprobmap::IProbMap - initializeprob_updatep!::IProbP + # Function which takes (initializeprob, integrator) and updates the problem with + # unknown and parameter values from the integrator. + initializeprob_init!::IProbInit + # Function which takes (integrator, initializesol) and updates the integrator with + # unknown and parameter values from initializesol (solution of initializeprob). + initializeprob_update!::IProbUp end """ @@ -2378,8 +2388,8 @@ function ODEFunction{iip, specialize}(f; colorvec = __has_colorvec(f) ? f.colorvec : nothing, sys = __has_sys(f) ? f.sys : nothing, initializeprob = __has_initializeprob(f) ? f.initializeprob : nothing, - initializeprobmap = __has_initializeprobmap(f) ? f.initializeprobmap : nothing, - initializeprob_updatep! = __has_initializeprob_updatep(f) ? f.initializeprob_updatep! : nothing + initializeprob_init! = __has_initializeprob_init(f) ? f.initializeprob_init! : nothing, + initializeprob_update! = __has_initializeprob_update(f) ? f.initializeprob_update! : nothing ) where {iip, specialize } @@ -2440,8 +2450,8 @@ function ODEFunction{iip, specialize}(f; typeof(sys), Any, Any, Any}(_f, mass_matrix, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, Wfact_t, W_prototype, paramjac, - observed, _colorvec, sys, initializeprob, initializeprobmap, - initializeprob_updatep!) + observed, _colorvec, sys, initializeprob, initializeprob_init!, + initializeprob_update!) elseif specialize === false ODEFunction{iip, FunctionWrapperSpecialize, typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad), @@ -2451,12 +2461,12 @@ function ODEFunction{iip, specialize}(f; typeof(observed), typeof(_colorvec), typeof(sys), typeof(initializeprob), - typeof(initializeprobmap), - typeof(initializeprob_updatep!)}(_f, mass_matrix, analytic, tgrad, jac, + typeof(initializeprob_init!), + typeof(initializeprob_update!)}(_f, mass_matrix, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, Wfact_t, W_prototype, paramjac, - observed, _colorvec, sys, initializeprob, initializeprobmap, - initializeprob_updatep!) + observed, _colorvec, sys, initializeprob, initializeprob_init!, + initializeprob_update!) else ODEFunction{iip, specialize, typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad), @@ -2465,13 +2475,12 @@ function ODEFunction{iip, specialize}(f; typeof(paramjac), typeof(observed), typeof(_colorvec), - typeof(sys), typeof(initializeprob), - typeof(initializeprobmap), - typeof(initializeprob_updatep!)}(_f, mass_matrix, analytic, tgrad, jac, + typeof(sys), typeof(initializeprob), typeof(initializeprob_init!), + typeof(initializeprob_update!)}(_f, mass_matrix, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, Wfact_t, W_prototype, paramjac, - observed, _colorvec, sys, initializeprob, initializeprobmap, - initializeprob_updatep!) + observed, _colorvec, sys, initializeprob, initializeprob_init!, + initializeprob_update!) end end @@ -2488,11 +2497,11 @@ function unwrapped_f(f::ODEFunction, newf = unwrapped_f(f.f)) Any, Any, Any, Any, typeof(f.jac_prototype), typeof(f.sparsity), Any, Any, Any, Any, typeof(f.colorvec), - typeof(f.sys), Any, Any, Any}(newf, f.mass_matrix, f.analytic, f.tgrad, f.jac, + typeof(f.sys), Any, Any, Any, Any}(newf, f.mass_matrix, f.analytic, f.tgrad, f.jac, f.jvp, f.vjp, f.jac_prototype, f.sparsity, f.Wfact, f.Wfact_t, f.W_prototype, f.paramjac, - f.observed, f.colorvec, f.sys, f.initializeprob, f.initializeprobmap, - f.initializeprob_updatep!) + f.observed, f.colorvec, f.sys, f.initializeprob, f.initializeprob_init!, + f.initializeprob_update!) else ODEFunction{isinplace(f), specialization(f), typeof(newf), typeof(f.mass_matrix), typeof(f.analytic), typeof(f.tgrad), @@ -2500,13 +2509,12 @@ function unwrapped_f(f::ODEFunction, newf = unwrapped_f(f.f)) typeof(f.sparsity), typeof(f.Wfact), typeof(f.Wfact_t), typeof(f.W_prototype), typeof(f.paramjac), typeof(f.observed), typeof(f.colorvec), - typeof(f.sys), typeof(f.initializeprob), - typeof(f.initializeprobmap), - typeof(f.initializeprob_updatep!)}(newf, f.mass_matrix, f.analytic, f.tgrad, + typeof(f.sys), typeof(f.initializeprob), typeof(f.initializeprob_init!), + typeof(f.initializeprob_update!)}(newf, f.mass_matrix, f.analytic, f.tgrad, f.jac, f.jvp, f.vjp, f.jac_prototype, f.sparsity, f.Wfact, f.Wfact_t, f.W_prototype, f.paramjac, - f.observed, f.colorvec, f.sys, f.initializeprob, - f.initializeprobmap, f.initializeprob_updatep!) + f.observed, f.colorvec, f.sys, f.initializeprob, f.initializeprob_init!, + f.initializeprob_update!) end end @@ -3298,8 +3306,8 @@ function DAEFunction{iip, specialize}(f; colorvec = __has_colorvec(f) ? f.colorvec : nothing, sys = __has_sys(f) ? f.sys : nothing, initializeprob = __has_initializeprob(f) ? f.initializeprob : nothing, - initializeprobmap = __has_initializeprobmap(f) ? f.initializeprobmap : nothing, - initializeprob_updatep! = __has_initializeprob_updatep(f) ? f.initializeprob_updatep! : nothing + initializeprob_init! = __has_initializeprob_init(f) ? f.initializeprob_init! : nothing, + initializeprob_update! = __has_initializeprob_update(f) ? f.initializeprob_update! : nothing ) where { iip, specialize @@ -3340,22 +3348,23 @@ function DAEFunction{iip, specialize}(f; DAEFunction{iip, specialize, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, - Any, typeof(_colorvec), Any, Any, Any, Any}(_f, analytic, tgrad, jac, jvp, + Any, typeof(_colorvec), Any, Any, Any}(_f, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, Wfact_t, paramjac, observed, - _colorvec, sys, initializeprob, initializeprobmap, initializeprob_updatep!) + _colorvec, sys, initializeprob, initializeprob_update!) else DAEFunction{iip, specialize, typeof(_f), typeof(analytic), typeof(tgrad), typeof(jac), typeof(jvp), typeof(vjp), typeof(jac_prototype), typeof(sparsity), typeof(Wfact), typeof(Wfact_t), typeof(paramjac), typeof(observed), typeof(_colorvec), - typeof(sys), typeof(initializeprob), typeof(initializeprobmap), - typeof(initializeprob_updatep!)}( + typeof(sys), typeof(initializeprob), typeof(initializeprob_init!), + typeof(initializeprob_update!)}( _f, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, Wfact_t, paramjac, observed, - _colorvec, sys, initializeprob, initializeprobmap, initializeprob_updatep!) + _colorvec, sys, initializeprob, initializeprob_init!, + initializeprob_update!) end end @@ -4244,8 +4253,8 @@ __has_sys(f) = isdefined(f, :sys) __has_analytic_full(f) = isdefined(f, :analytic_full) __has_resid_prototype(f) = isdefined(f, :resid_prototype) __has_initializeprob(f) = isdefined(f, :initializeprob) -__has_initializeprobmap(f) = isdefined(f, :initializeprobmap) -__has_initializeprob_updatep(f) = isdefined(f, :initializeprob_updatep!) +__has_initializeprob_init(f) = isdefined(f, :initializeprob_init!) +__has_initializeprob_update(f) = isdefined(f, :initializeprob_update!) # compatibility has_invW(f::AbstractSciMLFunction) = false @@ -4261,11 +4270,11 @@ has_sys(f::AbstractSciMLFunction) = __has_sys(f) && f.sys !== nothing function has_initializeprob(f::AbstractSciMLFunction) __has_initializeprob(f) && f.initializeprob !== nothing end -function has_initializeprobmap(f::AbstractSciMLFunction) - __has_initializeprobmap(f) && f.initializeprobmap !== nothing +function has_initializeprob_init(f::AbstractSciMLFunction) + __has_initializeprob_init(f) && f.initializeprob_init! !== nothing end -function has_initializeprob_updatep(f::AbstractSciMLFunction) - __has_initializeprob_updatep(f) && f.initializeprob_updatep! !== nothing +function has_initializeprob_update(f::AbstractSciMLFunction) + __has_initializeprob_update(f) && f.initializeprob_update! !== nothing end From ac59617453705006155b813627e9491b716332e4 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 29 May 2024 16:34:07 +0530 Subject: [PATCH 3/6] feat: remake initialization problem when remaking ODEProblem --- Project.toml | 2 ++ src/SciMLBase.jl | 1 + src/remake.jl | 14 ++++++++++++-- 3 files changed, 15 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 98d008501..778d4c1b1 100644 --- a/Project.toml +++ b/Project.toml @@ -5,6 +5,7 @@ version = "2.39.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" +Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2" ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" @@ -51,6 +52,7 @@ SciMLBaseZygoteExt = "Zygote" [compat] ADTypes = "0.2.5,1.0.0" +Accessors = "0.1" ArrayInterface = "7.6" ChainRules = "1.58.0" ChainRulesCore = "1.18" diff --git a/src/SciMLBase.jl b/src/SciMLBase.jl index dad4ff058..8e9e8187a 100644 --- a/src/SciMLBase.jl +++ b/src/SciMLBase.jl @@ -22,6 +22,7 @@ import FunctionWrappersWrappers import RuntimeGeneratedFunctions import EnumX import ADTypes: AbstractADType +import Accessors: @reset using Reexport using SciMLOperators diff --git a/src/remake.jl b/src/remake.jl index 017632a00..d066a7459 100644 --- a/src/remake.jl +++ b/src/remake.jl @@ -78,7 +78,6 @@ end p = missing, kwargs = missing, _kwargs...) Remake the given `ODEProblem`. -If `u0` or `p` are given as symbolic maps `ModelingToolkit.jl` has to be loaded. """ function remake(prob::ODEProblem; f = missing, u0 = missing, @@ -128,7 +127,18 @@ function remake(prob::ODEProblem; f = missing, else _f = ODEFunction{isinplace(prob), specialization(prob.f)}(f) end - + if has_initializeprob(prob.f) && (typeof(u0) != typeof(prob.u0) || typeof(p) != typeof(prob.p) || typeof(tspan) != typeof(prob.tspan)) + temp_state = ProblemState(; u = u0, p = p, t = tspan[1]) + initu0 = [sym => getu(prob, sym)(temp_state) for sym in variable_symbols(prob.f.initializeprob)] + initp = [sym =>getu(prob, sym)(temp_state) for sym in parameter_symbols(prob.f.initializeprob)] + if initu0 == [] + initu0 = nothing + end + if initp == [] + initp = nothing + end + @reset _f.initializeprob = remake(prob.f.initializeprob, u0 = initu0, p = initp) + end if kwargs === missing ODEProblem{isinplace(prob)}(_f, u0, tspan, p, prob.problem_type; prob.kwargs..., _kwargs...) From 08abfe1c34cc470b5ffe2992f36196739324ceea Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 30 May 2024 11:55:59 +0530 Subject: [PATCH 4/6] fixup! feat: add initializeprob_updatep! to ODEFunction, DAEFunction --- src/scimlfunctions.jl | 50 +++++++++++++++++++++++++++---------------- 1 file changed, 32 insertions(+), 18 deletions(-) diff --git a/src/scimlfunctions.jl b/src/scimlfunctions.jl index 414577426..e30e35c97 100644 --- a/src/scimlfunctions.jl +++ b/src/scimlfunctions.jl @@ -402,7 +402,7 @@ numerically-defined functions. """ struct ODEFunction{iip, specialize, F, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, WP, TPJ, O, TCV, - SYS, IProb, IProbInit, IProbUp} <: AbstractODEFunction{iip} + SYS, IProb, IProbMap, IProbInit, IProbUp} <: AbstractODEFunction{iip} f::F mass_matrix::TMM analytic::Ta @@ -421,6 +421,9 @@ struct ODEFunction{iip, specialize, F, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TW sys::SYS # The initialization problem. initializeprob::IProb + # Legacy: Function which takes (initializesol) and returns the state vector of the + # integrator + initializeprobmap::IProbMap # Function which takes (initializeprob, integrator) and updates the problem with # unknown and parameter values from the integrator. initializeprob_init!::IProbInit @@ -1510,7 +1513,7 @@ automatically symbolically generating the Jacobian and more from the numerically-defined functions. """ struct DAEFunction{iip, specialize, F, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, TPJ, O, TCV, - SYS, IProb, IProbInit, IProbUp} <: + SYS, IProb, IProbMap, IProbInit, IProbUp} <: AbstractDAEFunction{iip} f::F analytic::Ta @@ -1528,6 +1531,9 @@ struct DAEFunction{iip, specialize, F, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, TP sys::SYS # The initialization problem. initializeprob::IProb + # Legacy: Function which takes (initializesol) and returns the state vector of the + # integrator + initializeprobmap::IProbMap # Function which takes (initializeprob, integrator) and updates the problem with # unknown and parameter values from the integrator. initializeprob_init!::IProbInit @@ -2388,6 +2394,7 @@ function ODEFunction{iip, specialize}(f; colorvec = __has_colorvec(f) ? f.colorvec : nothing, sys = __has_sys(f) ? f.sys : nothing, initializeprob = __has_initializeprob(f) ? f.initializeprob : nothing, + initializeprobmap = __has_initializeprobmap(f) ? f.initializeprobmap : nothing, initializeprob_init! = __has_initializeprob_init(f) ? f.initializeprob_init! : nothing, initializeprob_update! = __has_initializeprob_update(f) ? f.initializeprob_update! : nothing ) where {iip, @@ -2447,10 +2454,10 @@ function ODEFunction{iip, specialize}(f; typeof(sparsity), Any, Any, typeof(W_prototype), Any, Any, typeof(_colorvec), - typeof(sys), Any, Any, Any}(_f, mass_matrix, analytic, tgrad, jac, + typeof(sys), Any, Any, Any, Any}(_f, mass_matrix, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, Wfact_t, W_prototype, paramjac, - observed, _colorvec, sys, initializeprob, initializeprob_init!, + observed, _colorvec, sys, initializeprob, initializeprobmap, initializeprob_init!, initializeprob_update!) elseif specialize === false ODEFunction{iip, FunctionWrapperSpecialize, @@ -2460,12 +2467,12 @@ function ODEFunction{iip, specialize}(f; typeof(paramjac), typeof(observed), typeof(_colorvec), - typeof(sys), typeof(initializeprob), + typeof(sys), typeof(initializeprob), typeof(initializeprobmap) typeof(initializeprob_init!), typeof(initializeprob_update!)}(_f, mass_matrix, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, Wfact_t, W_prototype, paramjac, - observed, _colorvec, sys, initializeprob, initializeprob_init!, + observed, _colorvec, sys, initializeprob, initializeprobmap, initializeprob_init!, initializeprob_update!) else ODEFunction{iip, specialize, @@ -2475,11 +2482,12 @@ function ODEFunction{iip, specialize}(f; typeof(paramjac), typeof(observed), typeof(_colorvec), - typeof(sys), typeof(initializeprob), typeof(initializeprob_init!), + typeof(sys), typeof(initializeprob), typeof(initializeprobmap), + typeof(initializeprob_init!), typeof(initializeprob_update!)}(_f, mass_matrix, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, Wfact_t, W_prototype, paramjac, - observed, _colorvec, sys, initializeprob, initializeprob_init!, + observed, _colorvec, sys, initializeprob, initializeprobmap, initializeprob_init!, initializeprob_update!) end end @@ -2497,11 +2505,11 @@ function unwrapped_f(f::ODEFunction, newf = unwrapped_f(f.f)) Any, Any, Any, Any, typeof(f.jac_prototype), typeof(f.sparsity), Any, Any, Any, Any, typeof(f.colorvec), - typeof(f.sys), Any, Any, Any, Any}(newf, f.mass_matrix, f.analytic, f.tgrad, f.jac, + typeof(f.sys), Any, Any, Any, Any, Any}(newf, f.mass_matrix, f.analytic, f.tgrad, f.jac, f.jvp, f.vjp, f.jac_prototype, f.sparsity, f.Wfact, f.Wfact_t, f.W_prototype, f.paramjac, - f.observed, f.colorvec, f.sys, f.initializeprob, f.initializeprob_init!, - f.initializeprob_update!) + f.observed, f.colorvec, f.sys, f.initializeprob, f.initializeprobmap, + f.initializeprob_init!, f.initializeprob_update!) else ODEFunction{isinplace(f), specialization(f), typeof(newf), typeof(f.mass_matrix), typeof(f.analytic), typeof(f.tgrad), @@ -2509,12 +2517,13 @@ function unwrapped_f(f::ODEFunction, newf = unwrapped_f(f.f)) typeof(f.sparsity), typeof(f.Wfact), typeof(f.Wfact_t), typeof(f.W_prototype), typeof(f.paramjac), typeof(f.observed), typeof(f.colorvec), - typeof(f.sys), typeof(f.initializeprob), typeof(f.initializeprob_init!), + typeof(f.sys), typeof(f.initializeprob), typeof(f.initializeprobmap), + typeof(f.initializeprob_init!), typeof(f.initializeprob_update!)}(newf, f.mass_matrix, f.analytic, f.tgrad, f.jac, f.jvp, f.vjp, f.jac_prototype, f.sparsity, f.Wfact, f.Wfact_t, f.W_prototype, f.paramjac, - f.observed, f.colorvec, f.sys, f.initializeprob, f.initializeprob_init!, - f.initializeprob_update!) + f.observed, f.colorvec, f.sys, f.initializeprob, f.initializeprobmap, + f.initializeprob_init!, f.initializeprob_update!) end end @@ -3306,6 +3315,7 @@ function DAEFunction{iip, specialize}(f; colorvec = __has_colorvec(f) ? f.colorvec : nothing, sys = __has_sys(f) ? f.sys : nothing, initializeprob = __has_initializeprob(f) ? f.initializeprob : nothing, + initializeprobmap = __has_initializeprobmap(f) ? f.initializeprobmap : nothing, initializeprob_init! = __has_initializeprob_init(f) ? f.initializeprob_init! : nothing, initializeprob_update! = __has_initializeprob_update(f) ? f.initializeprob_update! : nothing ) where { @@ -3348,22 +3358,22 @@ function DAEFunction{iip, specialize}(f; DAEFunction{iip, specialize, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, - Any, typeof(_colorvec), Any, Any, Any}(_f, analytic, tgrad, jac, jvp, + Any, typeof(_colorvec), Any, Any, Any, Any}(_f, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, Wfact_t, paramjac, observed, - _colorvec, sys, initializeprob, initializeprob_update!) + _colorvec, sys, initializeprob, initializeprobmap, initializeprob_init!, initializeprob_update!) else DAEFunction{iip, specialize, typeof(_f), typeof(analytic), typeof(tgrad), typeof(jac), typeof(jvp), typeof(vjp), typeof(jac_prototype), typeof(sparsity), typeof(Wfact), typeof(Wfact_t), typeof(paramjac), typeof(observed), typeof(_colorvec), - typeof(sys), typeof(initializeprob), typeof(initializeprob_init!), + typeof(sys), typeof(initializeprob), typeof(initializeprobmap), typeof(initializeprob_init!), typeof(initializeprob_update!)}( _f, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, Wfact_t, paramjac, observed, - _colorvec, sys, initializeprob, initializeprob_init!, + _colorvec, sys, initializeprob, initializeprobmap, initializeprob_init!, initializeprob_update!) end end @@ -4253,6 +4263,7 @@ __has_sys(f) = isdefined(f, :sys) __has_analytic_full(f) = isdefined(f, :analytic_full) __has_resid_prototype(f) = isdefined(f, :resid_prototype) __has_initializeprob(f) = isdefined(f, :initializeprob) +__has_initializeprobmap(f) = isdefined(f, :initializeprobmap) __has_initializeprob_init(f) = isdefined(f, :initializeprob_init!) __has_initializeprob_update(f) = isdefined(f, :initializeprob_update!) @@ -4270,6 +4281,9 @@ has_sys(f::AbstractSciMLFunction) = __has_sys(f) && f.sys !== nothing function has_initializeprob(f::AbstractSciMLFunction) __has_initializeprob(f) && f.initializeprob !== nothing end +function has_initializeprobmap(f::AbstractSciMLFunction) + __has_initializeprobmap(f) && f.initializeprobmap !== nothing +end function has_initializeprob_init(f::AbstractSciMLFunction) __has_initializeprob_init(f) && f.initializeprob_init! !== nothing end From 18620a9b84f0b5b4f1677f19d3fd6941804c5513 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 30 May 2024 12:01:05 +0530 Subject: [PATCH 5/6] fixup! fixup! feat: add initializeprob_updatep! to ODEFunction, DAEFunction --- src/scimlfunctions.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/scimlfunctions.jl b/src/scimlfunctions.jl index e30e35c97..562130d23 100644 --- a/src/scimlfunctions.jl +++ b/src/scimlfunctions.jl @@ -2467,7 +2467,7 @@ function ODEFunction{iip, specialize}(f; typeof(paramjac), typeof(observed), typeof(_colorvec), - typeof(sys), typeof(initializeprob), typeof(initializeprobmap) + typeof(sys), typeof(initializeprob), typeof(initializeprobmap), typeof(initializeprob_init!), typeof(initializeprob_update!)}(_f, mass_matrix, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, From c4030d7bed7fc7d31ff73c79e830025603424e24 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 30 May 2024 12:14:29 +0530 Subject: [PATCH 6/6] fixup! fixup! feat: add initializeprob_updatep! to ODEFunction, DAEFunction --- src/scimlfunctions.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/scimlfunctions.jl b/src/scimlfunctions.jl index 562130d23..8eee72eb8 100644 --- a/src/scimlfunctions.jl +++ b/src/scimlfunctions.jl @@ -3358,7 +3358,7 @@ function DAEFunction{iip, specialize}(f; DAEFunction{iip, specialize, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, - Any, typeof(_colorvec), Any, Any, Any, Any}(_f, analytic, tgrad, jac, jvp, + Any, typeof(_colorvec), Any, Any, Any, Any, Any}(_f, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, Wfact_t, paramjac, observed, _colorvec, sys, initializeprob, initializeprobmap, initializeprob_init!, initializeprob_update!)