Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add initializeprob_updatep! to ODEFunction, DAEFunction #698

Closed
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ version = "2.39.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
Expand Down Expand Up @@ -51,6 +52,7 @@ SciMLBaseZygoteExt = "Zygote"

[compat]
ADTypes = "0.2.5,1.0.0"
Accessors = "0.1"
ArrayInterface = "7.6"
ChainRules = "1.58.0"
ChainRulesCore = "1.18"
Expand Down
1 change: 1 addition & 0 deletions src/SciMLBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import FunctionWrappersWrappers
import RuntimeGeneratedFunctions
import EnumX
import ADTypes: AbstractADType
import Accessors: @reset

using Reexport
using SciMLOperators
Expand Down
14 changes: 12 additions & 2 deletions src/remake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,6 @@
p = missing, kwargs = missing, _kwargs...)

Remake the given `ODEProblem`.
If `u0` or `p` are given as symbolic maps `ModelingToolkit.jl` has to be loaded.
"""
function remake(prob::ODEProblem; f = missing,
u0 = missing,
Expand Down Expand Up @@ -128,7 +127,18 @@
else
_f = ODEFunction{isinplace(prob), specialization(prob.f)}(f)
end

if has_initializeprob(prob.f) && (typeof(u0) != typeof(prob.u0) || typeof(p) != typeof(prob.p) || typeof(tspan) != typeof(prob.tspan))
temp_state = ProblemState(; u = u0, p = p, t = tspan[1])
initu0 = [sym => getu(prob, sym)(temp_state) for sym in variable_symbols(prob.f.initializeprob)]
initp = [sym =>getu(prob, sym)(temp_state) for sym in parameter_symbols(prob.f.initializeprob)]
if initu0 == []
initu0 = nothing

Check warning on line 135 in src/remake.jl

View check run for this annotation

Codecov / codecov/patch

src/remake.jl#L131-L135

Added lines #L131 - L135 were not covered by tests
end
if initp == []
initp = nothing

Check warning on line 138 in src/remake.jl

View check run for this annotation

Codecov / codecov/patch

src/remake.jl#L137-L138

Added lines #L137 - L138 were not covered by tests
end
@reset _f.initializeprob = remake(prob.f.initializeprob, u0 = initu0, p = initp)

Check warning on line 140 in src/remake.jl

View check run for this annotation

Codecov / codecov/patch

src/remake.jl#L140

Added line #L140 was not covered by tests
end
if kwargs === missing
ODEProblem{isinplace(prob)}(_f, u0, tspan, p, prob.problem_type; prob.kwargs...,
_kwargs...)
Expand Down
87 changes: 64 additions & 23 deletions src/scimlfunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@
"""
struct ODEFunction{iip, specialize, F, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, WP, TPJ,
O, TCV,
SYS, IProb, IProbMap} <: AbstractODEFunction{iip}
SYS, IProb, IProbMap, IProbInit, IProbUp} <: AbstractODEFunction{iip}
f::F
mass_matrix::TMM
analytic::Ta
Expand All @@ -419,8 +419,17 @@
observed::O
colorvec::TCV
sys::SYS
# The initialization problem.
initializeprob::IProb
# Legacy: Function which takes (initializesol) and returns the state vector of the
# integrator
initializeprobmap::IProbMap
# Function which takes (initializeprob, integrator) and updates the problem with
# unknown and parameter values from the integrator.
initializeprob_init!::IProbInit
# Function which takes (integrator, initializesol) and updates the integrator with
# unknown and parameter values from initializesol (solution of initializeprob).
initializeprob_update!::IProbUp
end

@doc doc"""
Expand Down Expand Up @@ -1504,7 +1513,7 @@
numerically-defined functions.
"""
struct DAEFunction{iip, specialize, F, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, TPJ, O, TCV,
SYS, IProb, IProbMap} <:
SYS, IProb, IProbMap, IProbInit, IProbUp} <:
AbstractDAEFunction{iip}
f::F
analytic::Ta
Expand All @@ -1520,8 +1529,17 @@
observed::O
colorvec::TCV
sys::SYS
# The initialization problem.
initializeprob::IProb
# Legacy: Function which takes (initializesol) and returns the state vector of the
# integrator
initializeprobmap::IProbMap
# Function which takes (initializeprob, integrator) and updates the problem with
# unknown and parameter values from the integrator.
initializeprob_init!::IProbInit
# Function which takes (integrator, initializesol) and updates the integrator with
# unknown and parameter values from initializesol (solution of initializeprob).
initializeprob_update!::IProbUp
end

"""
Expand Down Expand Up @@ -2376,7 +2394,9 @@
colorvec = __has_colorvec(f) ? f.colorvec : nothing,
sys = __has_sys(f) ? f.sys : nothing,
initializeprob = __has_initializeprob(f) ? f.initializeprob : nothing,
initializeprobmap = __has_initializeprobmap(f) ? f.initializeprobmap : nothing
initializeprobmap = __has_initializeprobmap(f) ? f.initializeprobmap : nothing,
initializeprob_init! = __has_initializeprob_init(f) ? f.initializeprob_init! : nothing,
initializeprob_update! = __has_initializeprob_update(f) ? f.initializeprob_update! : nothing
) where {iip,
specialize
}
Expand Down Expand Up @@ -2434,10 +2454,11 @@
typeof(sparsity), Any, Any, typeof(W_prototype), Any,
Any,
typeof(_colorvec),
typeof(sys), Any, Any}(_f, mass_matrix, analytic, tgrad, jac,
typeof(sys), Any, Any, Any, Any}(_f, mass_matrix, analytic, tgrad, jac,
jvp, vjp, jac_prototype, sparsity, Wfact,
Wfact_t, W_prototype, paramjac,
observed, _colorvec, sys, initializeprob, initializeprobmap)
observed, _colorvec, sys, initializeprob, initializeprobmap, initializeprob_init!,
initializeprob_update!)
elseif specialize === false
ODEFunction{iip, FunctionWrapperSpecialize,
typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad),
Expand All @@ -2446,11 +2467,13 @@
typeof(paramjac),
typeof(observed),
typeof(_colorvec),
typeof(sys), typeof(initializeprob),
typeof(initializeprobmap)}(_f, mass_matrix, analytic, tgrad, jac,
typeof(sys), typeof(initializeprob), typeof(initializeprobmap),
typeof(initializeprob_init!),
typeof(initializeprob_update!)}(_f, mass_matrix, analytic, tgrad, jac,
jvp, vjp, jac_prototype, sparsity, Wfact,
Wfact_t, W_prototype, paramjac,
observed, _colorvec, sys, initializeprob, initializeprobmap)
observed, _colorvec, sys, initializeprob, initializeprobmap, initializeprob_init!,
initializeprob_update!)
else
ODEFunction{iip, specialize,
typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad),
Expand All @@ -2459,11 +2482,13 @@
typeof(paramjac),
typeof(observed),
typeof(_colorvec),
typeof(sys), typeof(initializeprob),
typeof(initializeprobmap)}(_f, mass_matrix, analytic, tgrad, jac,
typeof(sys), typeof(initializeprob), typeof(initializeprobmap),
typeof(initializeprob_init!),
typeof(initializeprob_update!)}(_f, mass_matrix, analytic, tgrad, jac,
jvp, vjp, jac_prototype, sparsity, Wfact,
Wfact_t, W_prototype, paramjac,
observed, _colorvec, sys, initializeprob, initializeprobmap)
observed, _colorvec, sys, initializeprob, initializeprobmap, initializeprob_init!,
initializeprob_update!)
end
end

Expand All @@ -2480,23 +2505,25 @@
Any, Any, Any, Any, typeof(f.jac_prototype),
typeof(f.sparsity), Any, Any, Any,
Any, typeof(f.colorvec),
typeof(f.sys), Any, Any}(newf, f.mass_matrix, f.analytic, f.tgrad, f.jac,
typeof(f.sys), Any, Any, Any, Any, Any}(newf, f.mass_matrix, f.analytic, f.tgrad, f.jac,
f.jvp, f.vjp, f.jac_prototype, f.sparsity, f.Wfact,
f.Wfact_t, f.W_prototype, f.paramjac,
f.observed, f.colorvec, f.sys, f.initializeprob, f.initializeprobmap)
f.observed, f.colorvec, f.sys, f.initializeprob, f.initializeprobmap,
f.initializeprob_init!, f.initializeprob_update!)
else
ODEFunction{isinplace(f), specialization(f), typeof(newf), typeof(f.mass_matrix),
typeof(f.analytic), typeof(f.tgrad),
typeof(f.jac), typeof(f.jvp), typeof(f.vjp), typeof(f.jac_prototype),
typeof(f.sparsity), typeof(f.Wfact), typeof(f.Wfact_t), typeof(f.W_prototype),
typeof(f.paramjac),
typeof(f.observed), typeof(f.colorvec),
typeof(f.sys), typeof(f.initializeprob),
typeof(f.initializeprobmap)}(newf, f.mass_matrix, f.analytic, f.tgrad, f.jac,
f.jvp, f.vjp, f.jac_prototype, f.sparsity, f.Wfact,
typeof(f.sys), typeof(f.initializeprob), typeof(f.initializeprobmap),
typeof(f.initializeprob_init!),
typeof(f.initializeprob_update!)}(newf, f.mass_matrix, f.analytic, f.tgrad,
f.jac, f.jvp, f.vjp, f.jac_prototype, f.sparsity, f.Wfact,
f.Wfact_t, f.W_prototype, f.paramjac,
f.observed, f.colorvec, f.sys, f.initializeprob,
f.initializeprobmap)
f.observed, f.colorvec, f.sys, f.initializeprob, f.initializeprobmap,
f.initializeprob_init!, f.initializeprob_update!)
end
end

Expand Down Expand Up @@ -3288,7 +3315,10 @@
colorvec = __has_colorvec(f) ? f.colorvec : nothing,
sys = __has_sys(f) ? f.sys : nothing,
initializeprob = __has_initializeprob(f) ? f.initializeprob : nothing,
initializeprobmap = __has_initializeprobmap(f) ? f.initializeprobmap : nothing) where {
initializeprobmap = __has_initializeprobmap(f) ? f.initializeprobmap : nothing,
initializeprob_init! = __has_initializeprob_init(f) ? f.initializeprob_init! : nothing,
initializeprob_update! = __has_initializeprob_update(f) ? f.initializeprob_update! : nothing
) where {
iip,
specialize
}
Expand Down Expand Up @@ -3328,21 +3358,23 @@
DAEFunction{iip, specialize, Any, Any, Any,
Any, Any, Any, Any, Any,
Any, Any, Any,
Any, typeof(_colorvec), Any, Any, Any}(_f, analytic, tgrad, jac, jvp,
Any, typeof(_colorvec), Any, Any, Any, Any, Any}(_f, analytic, tgrad, jac, jvp,
vjp, jac_prototype, sparsity,
Wfact, Wfact_t, paramjac, observed,
_colorvec, sys, initializeprob, initializeprobmap)
_colorvec, sys, initializeprob, initializeprobmap, initializeprob_init!, initializeprob_update!)
else
DAEFunction{iip, specialize, typeof(_f), typeof(analytic), typeof(tgrad),
typeof(jac), typeof(jvp), typeof(vjp), typeof(jac_prototype),
typeof(sparsity), typeof(Wfact), typeof(Wfact_t),
typeof(paramjac),
typeof(observed), typeof(_colorvec),
typeof(sys), typeof(initializeprob), typeof(initializeprobmap)}(
typeof(sys), typeof(initializeprob), typeof(initializeprobmap), typeof(initializeprob_init!),
typeof(initializeprob_update!)}(
_f, analytic, tgrad, jac, jvp, vjp,
jac_prototype, sparsity, Wfact, Wfact_t,
paramjac, observed,
_colorvec, sys, initializeprob, initializeprobmap)
_colorvec, sys, initializeprob, initializeprobmap, initializeprob_init!,
initializeprob_update!)
end
end

Expand Down Expand Up @@ -4232,6 +4264,8 @@
__has_resid_prototype(f) = isdefined(f, :resid_prototype)
__has_initializeprob(f) = isdefined(f, :initializeprob)
__has_initializeprobmap(f) = isdefined(f, :initializeprobmap)
__has_initializeprob_init(f) = isdefined(f, :initializeprob_init!)
__has_initializeprob_update(f) = isdefined(f, :initializeprob_update!)

# compatibility
has_invW(f::AbstractSciMLFunction) = false
Expand All @@ -4250,6 +4284,13 @@
function has_initializeprobmap(f::AbstractSciMLFunction)
__has_initializeprobmap(f) && f.initializeprobmap !== nothing
end
function has_initializeprob_init(f::AbstractSciMLFunction)
__has_initializeprob_init(f) && f.initializeprob_init! !== nothing

Check warning on line 4288 in src/scimlfunctions.jl

View check run for this annotation

Codecov / codecov/patch

src/scimlfunctions.jl#L4287-L4288

Added lines #L4287 - L4288 were not covered by tests
end
function has_initializeprob_update(f::AbstractSciMLFunction)
__has_initializeprob_update(f) && f.initializeprob_update! !== nothing

Check warning on line 4291 in src/scimlfunctions.jl

View check run for this annotation

Codecov / codecov/patch

src/scimlfunctions.jl#L4290-L4291

Added lines #L4290 - L4291 were not covered by tests
end


function has_syms(f::AbstractSciMLFunction)
if __has_syms(f)
Expand Down
Loading