From c179173b36e083522b09d12a7095158584d084f8 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Fri, 1 Mar 2024 02:28:58 -0600 Subject: [PATCH] Add reinitialization map For mapping the initializeprob to a new time --- src/scimlfunctions.jl | 39 +++++++++++++++++++++++++-------------- 1 file changed, 25 insertions(+), 14 deletions(-) diff --git a/src/scimlfunctions.jl b/src/scimlfunctions.jl index f387893b3..01b3edd73 100644 --- a/src/scimlfunctions.jl +++ b/src/scimlfunctions.jl @@ -402,7 +402,7 @@ numerically-defined functions. """ struct ODEFunction{iip, specialize, F, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, WP, TPJ, O, TCV, - SYS, IProb, IProbMap} <: AbstractODEFunction{iip} + SYS, IProb, IProbMap, RProbMap} <: AbstractODEFunction{iip} f::F mass_matrix::TMM analytic::Ta @@ -421,6 +421,7 @@ struct ODEFunction{iip, specialize, F, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TW sys::SYS initializeprob::IProb initializeprobmap::IProbMap + reinitializemap::RProbMap end @doc doc""" @@ -1504,7 +1505,7 @@ automatically symbolically generating the Jacobian and more from the numerically-defined functions. """ struct DAEFunction{iip, specialize, F, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, TPJ, O, TCV, - SYS, IProb, IProbMap} <: + SYS, IProb, IProbMap, RProbMap} <: AbstractDAEFunction{iip} f::F analytic::Ta @@ -1522,6 +1523,7 @@ struct DAEFunction{iip, specialize, F, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, TP sys::SYS initializeprob::IProb initializeprobmap::IProbMap + reinitializemap::RProbMap end """ @@ -2248,6 +2250,7 @@ function ODEFunction{iip, specialize}(f; sys = __has_sys(f) ? f.sys : nothing, initializeprob = __has_initializeprob(f) ? f.initializeprob : nothing, initializeprobmap = __has_initializeprobmap(f) ? f.initializeprobmap : nothing + reinitializemap = __has_reinitializemap(f) ? f.reinitializemap : nothing ) where {iip, specialize } @@ -2308,7 +2311,7 @@ function ODEFunction{iip, specialize}(f; typeof(sys), Any, Any}(_f, mass_matrix, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, Wfact_t, W_prototype, paramjac, - observed, _colorvec, sys, initializeprob, initializeprobmap) + observed, _colorvec, sys, initializeprob, initializeprobmap, reinitializemap) elseif specialize === false ODEFunction{iip, FunctionWrapperSpecialize, typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad), @@ -2318,10 +2321,11 @@ function ODEFunction{iip, specialize}(f; typeof(observed), typeof(_colorvec), typeof(sys), typeof(initializeprob), - typeof(initializeprobmap)}(_f, mass_matrix, analytic, tgrad, jac, + typeof(initializeprobmap), typeof(reinitializemap)}( + _f, mass_matrix, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, Wfact_t, W_prototype, paramjac, - observed, _colorvec, sys, initializeprob, initializeprobmap) + observed, _colorvec, sys, initializeprob, initializeprobmap, reinitializemap) else ODEFunction{iip, specialize, typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad), @@ -2331,10 +2335,10 @@ function ODEFunction{iip, specialize}(f; typeof(observed), typeof(_colorvec), typeof(sys), typeof(initializeprob), - typeof(initializeprobmap)}(_f, mass_matrix, analytic, tgrad, jac, + typeof(initializeprobmap), typeof(reinitializemap)}(_f, mass_matrix, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, Wfact_t, W_prototype, paramjac, - observed, _colorvec, sys, initializeprob, initializeprobmap) + observed, _colorvec, sys, initializeprob, initializeprobmap, reinitializemap) end end @@ -2354,7 +2358,7 @@ function unwrapped_f(f::ODEFunction, newf = unwrapped_f(f.f)) typeof(f.sys), Any, Any}(newf, f.mass_matrix, f.analytic, f.tgrad, f.jac, f.jvp, f.vjp, f.jac_prototype, f.sparsity, f.Wfact, f.Wfact_t, f.W_prototype, f.paramjac, - f.observed, f.colorvec, f.sys, f.initializeprob, f.initializeprobmap) + f.observed, f.colorvec, f.sys, f.initializeprob, f.initializeprobmap, f.reinitializemap) else ODEFunction{isinplace(f), specialization(f), typeof(newf), typeof(f.mass_matrix), typeof(f.analytic), typeof(f.tgrad), @@ -2363,11 +2367,11 @@ function unwrapped_f(f::ODEFunction, newf = unwrapped_f(f.f)) typeof(f.paramjac), typeof(f.observed), typeof(f.colorvec), typeof(f.sys), typeof(f.initializeprob), - typeof(f.initializeprobmap)}(newf, f.mass_matrix, f.analytic, f.tgrad, f.jac, + typeof(f.initializeprobmap), typeof(f.reinitializemap)}(newf, f.mass_matrix, f.analytic, f.tgrad, f.jac, f.jvp, f.vjp, f.jac_prototype, f.sparsity, f.Wfact, f.Wfact_t, f.W_prototype, f.paramjac, f.observed, f.colorvec, f.sys, f.initializeprob, - f.initializeprobmap) + f.initializeprobmap, f.reinitializemap) end end @@ -3159,7 +3163,9 @@ function DAEFunction{iip, specialize}(f; colorvec = __has_colorvec(f) ? f.colorvec : nothing, sys = __has_sys(f) ? f.sys : nothing, initializeprob = __has_initializeprob(f) ? f.initializeprob : nothing, - initializeprobmap = __has_initializeprobmap(f) ? f.initializeprobmap : nothing) where { + initializeprobmap = __has_initializeprobmap(f) ? f.initializeprobmap : nothing + reinitializemap = __has_reinitializemap(f) ? f.reinitializemap : nothing + ) where { iip, specialize } @@ -3202,18 +3208,19 @@ function DAEFunction{iip, specialize}(f; Any, typeof(_colorvec), Any, Any, Any}(_f, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, Wfact_t, paramjac, observed, - _colorvec, sys, initializeprob, initializeprobmap) + _colorvec, sys, initializeprob, initializeprobmap, reinitializemap) else DAEFunction{iip, specialize, typeof(_f), typeof(analytic), typeof(tgrad), typeof(jac), typeof(jvp), typeof(vjp), typeof(jac_prototype), typeof(sparsity), typeof(Wfact), typeof(Wfact_t), typeof(paramjac), typeof(observed), typeof(_colorvec), - typeof(sys), typeof(initializeprob), typeof(initializeprobmap)}( + typeof(sys), typeof(initializeprob), typeof(initializeprobmap), + typeof(reinitializemap)}( _f, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, Wfact_t, paramjac, observed, - _colorvec, sys, initializeprob, initializeprobmap) + _colorvec, sys, initializeprob, initializeprobmap, reinitializemap) end end @@ -3928,6 +3935,7 @@ __has_analytic_full(f) = isdefined(f, :analytic_full) __has_resid_prototype(f) = isdefined(f, :resid_prototype) __has_initializeprob(f) = isdefined(f, :initializeprob) __has_initializeprobmap(f) = isdefined(f, :initializeprobmap) +__has_reinitializemap(f) = isdefined(f, :reinitializemap) # compatibility has_invW(f::AbstractSciMLFunction) = false @@ -3946,6 +3954,9 @@ end function has_initializeprobmap(f::AbstractSciMLFunction) __has_initializeprobmap(f) && f.initializeprobmap !== nothing end +function has_reinitializemap(f::AbstractSciMLFunction) + __has_reinitializemap(f) && f.reinitializemap !== nothing +end function has_syms(f::AbstractSciMLFunction) if __has_syms(f)