Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: allow parameters in ODESystem to be unknowns in initialization system #2747

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 16 additions & 5 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -495,10 +495,18 @@
end

function SymbolicIndexingInterface.observed(sys::AbstractSystem, sym)
return let _fn = build_explicit_observed_function(sys, sym)
fn(u, p, t) = _fn(u, p, t)
fn(u, p::MTKParameters, t) = _fn(u, p..., t)
fn
if is_time_dependent(sys)
return let _fn = build_explicit_observed_function(sys, sym)
fn(u, p, t) = _fn(u, p, t)
fn(u, p::MTKParameters, t) = _fn(u, p..., t)
fn

Check warning on line 502 in src/systems/abstractsystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/abstractsystem.jl#L498-L502

Added lines #L498 - L502 were not covered by tests
end
else
return let _fn = build_explicit_observed_function(sys, sym)
fn2(u, p) = _fn(u, p)
fn2(u, p::MTKParameters) = _fn(u, p...)
fn2

Check warning on line 508 in src/systems/abstractsystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/abstractsystem.jl#L505-L508

Added lines #L505 - L508 were not covered by tests
end
end
end

Expand Down Expand Up @@ -1849,14 +1857,17 @@
end
initfn = NonlinearFunction(initsys)
initprobmap = getu(initsys, unknowns(sys))
initprob_init! = generate_initializeprob_init(sys, initsys)
initprob_update! = generate_initializeprob_update(sys, initsys)

Check warning on line 1861 in src/systems/abstractsystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/abstractsystem.jl#L1860-L1861

Added lines #L1860 - L1861 were not covered by tests
ps = full_parameters(sys)
lin_fun = let diff_idxs = diff_idxs,
alge_idxs = alge_idxs,
input_idxs = input_idxs,
sts = unknowns(sys),
get_initprob_u_p = get_initprob_u_p,
fun = ODEFunction{true, SciMLBase.FullSpecialize}(
sys, unknowns(sys), ps; initializeprobmap = initprobmap),
sys, unknowns(sys), ps; initializeprob_init! = initprob_init!,
initializeprob_update! = initprob_update!),
initfn = initfn,
h = build_explicit_observed_function(sys, outputs),
chunk = ForwardDiff.Chunk(input_idxs),
Expand Down
82 changes: 62 additions & 20 deletions src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,25 @@
all(iszero, tgrad)
end

struct GetAndSetFunctor{G, S}
getter::G
setter::S
end

function (gs::GetAndSetFunctor)(dest, source)
gs.setter(dest, gs.getter(source))

Check warning on line 289 in src/systems/diffeqs/abstractodesystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/diffeqs/abstractodesystem.jl#L288-L289

Added lines #L288 - L289 were not covered by tests
end

function generate_initializeprob_init(sys::AbstractSystem, initsys::AbstractSystem)
syms = vcat(variable_symbols(initsys), parameter_symbols(initsys))
return GetAndSetFunctor(getu(sys, syms), setu(initsys, syms))

Check warning on line 294 in src/systems/diffeqs/abstractodesystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/diffeqs/abstractodesystem.jl#L292-L294

Added lines #L292 - L294 were not covered by tests
end

function generate_initializeprob_update(sys::AbstractSystem, initsys::AbstractSystem)
syms = vcat(variable_symbols(sys), parameter_symbols(sys))
return GetAndSetFunctor(getu(initsys, syms), setu(sys, syms))

Check warning on line 299 in src/systems/diffeqs/abstractodesystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/diffeqs/abstractodesystem.jl#L297-L299

Added lines #L297 - L299 were not covered by tests
end

"""
```julia
DiffEqBase.ODEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys),
Expand Down Expand Up @@ -323,7 +342,8 @@
analytic = nothing,
split_idxs = nothing,
initializeprob = nothing,
initializeprobmap = nothing,
initializeprob_init! = nothing,
initializeprob_update! = nothing,
kwargs...) where {iip, specialize}
if !iscomplete(sys)
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `ODEFunction`")
Expand Down Expand Up @@ -506,7 +526,8 @@
sparsity = sparsity ? jacobian_sparsity(sys) : nothing,
analytic = analytic,
initializeprob = initializeprob,
initializeprobmap = initializeprobmap)
initializeprob_init! = initializeprob_init!,
initializeprob_update! = initializeprob_update!)
end

"""
Expand Down Expand Up @@ -537,7 +558,8 @@
eval_module = @__MODULE__,
checkbounds = false,
initializeprob = nothing,
initializeprobmap = nothing,
initializeprob_init! = nothing,
initializeprob_update! = nothing,
kwargs...) where {iip}
if !iscomplete(sys)
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating a `DAEFunction`")
Expand Down Expand Up @@ -611,7 +633,8 @@
jac_prototype = jac_prototype,
observed = observedfun,
initializeprob = initializeprob,
initializeprobmap = initializeprobmap)
initializeprob_init! = initializeprob_init!,
initializeprob_update! = initializeprob_update!)
end

function DiffEqBase.DDEFunction(sys::AbstractODESystem, args...; kwargs...)
Expand Down Expand Up @@ -862,7 +885,6 @@
varmap = canonicalize_varmap(varmap)
varlist = collect(map(unwrap, dvs))
missingvars = setdiff(varlist, collect(keys(varmap)))

# Append zeros to the variables which are determined by the initialization system
# This essentially bypasses the check for if initial conditions are defined for DAEs
# since they will be checked in the initialization problem's construction
Expand All @@ -873,11 +895,14 @@
parammap = Dict(unwrap(k) => v for (k, v) in todict(parammap))
elseif parammap isa AbstractArray
if isempty(parammap)
parammap = SciMLBase.NullParameters()
parammap = Dict()

Check warning on line 898 in src/systems/diffeqs/abstractodesystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/diffeqs/abstractodesystem.jl#L898

Added line #L898 was not covered by tests
else
parammap = Dict(unwrap.(parameters(sys)) .=> parammap)
end
elseif parammap === nothing || parammap isa SciMLBase.NullParameters
parammap = Dict()

Check warning on line 903 in src/systems/diffeqs/abstractodesystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/diffeqs/abstractodesystem.jl#L902-L903

Added lines #L902 - L903 were not covered by tests
end
missingpars = setdiff(parameters(sys), keys(parammap))

Check warning on line 905 in src/systems/diffeqs/abstractodesystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/diffeqs/abstractodesystem.jl#L905

Added line #L905 was not covered by tests

if has_discrete_subsystems(sys) && get_discrete_subsystems(sys) !== nothing
clockedparammap = Dict()
Expand All @@ -886,7 +911,7 @@
v = unwrap(v)
is_discrete_domain(v) || continue
op = operation(v)
if !isa(op, Symbolics.Operator) && parammap != SciMLBase.NullParameters() &&
if !isa(op, Symbolics.Operator) && !isempty(parammap) &&

Check warning on line 914 in src/systems/diffeqs/abstractodesystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/diffeqs/abstractodesystem.jl#L914

Added line #L914 was not covered by tests
haskey(parammap, v)
error("Initial conditions for discrete variables must be for the past state of the unknown. Instead of providing the condition for $v, provide the condition for $(Shift(iv, -1)(v)).")
end
Expand All @@ -909,7 +934,7 @@
# TODO: make it work with clocks
# ModelingToolkit.get_tearing_state(sys) !== nothing => Requires structural_simplify first
if sys isa ODESystem && build_initializeprob &&
(implicit_dae || !isempty(missingvars)) &&
(implicit_dae || !isempty(missingvars) || !isempty(missingpars)) &&
all(isequal(Continuous()), ci.var_domain) &&
ModelingToolkit.get_tearing_state(sys) !== nothing &&
t !== nothing
Expand All @@ -921,15 +946,28 @@
end
initializeprob = ModelingToolkit.InitializationProblem(
sys, t, u0map, parammap; guesses, warn_initialize_determined)
initializeprobmap = getu(initializeprob, unknowns(sys))

punknowns = [p

Check warning on line 949 in src/systems/diffeqs/abstractodesystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/diffeqs/abstractodesystem.jl#L949

Added line #L949 was not covered by tests
for p in parameters(sys)
if is_variable(initializeprob, p) || is_observed(initializeprob, p)]
initializeprob_init! = generate_initializeprob_init(sys, initializeprob.f.sys)
initializeprob_update! = generate_initializeprob_update(sys, initializeprob.f.sys)

Check warning on line 953 in src/systems/diffeqs/abstractodesystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/diffeqs/abstractodesystem.jl#L952-L953

Added lines #L952 - L953 were not covered by tests
zerovars = Dict(setdiff(unknowns(sys), keys(defaults(sys))) .=> 0.0)
zeropars = Dict()
for p in punknowns
zeropars[p] = if Symbolics.isarraysymbolic(p)
collect(unwrap.(zero(p)))

Check warning on line 958 in src/systems/diffeqs/abstractodesystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/diffeqs/abstractodesystem.jl#L955-L958

Added lines #L955 - L958 were not covered by tests
else
unwrap(zero(p))

Check warning on line 960 in src/systems/diffeqs/abstractodesystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/diffeqs/abstractodesystem.jl#L960

Added line #L960 was not covered by tests
end
end

Check warning on line 962 in src/systems/diffeqs/abstractodesystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/diffeqs/abstractodesystem.jl#L962

Added line #L962 was not covered by tests
trueinit = collect(merge(zerovars, eltype(u0map) <: Pair ? todict(u0map) : u0map))
u0map isa StaticArraysCore.StaticArray &&
(trueinit = SVector{length(trueinit)}(trueinit))
else
initializeprob = nothing
initializeprobmap = nothing
zeropars = Dict()
initializeprob_init! = nothing
initializeprob_update! = nothing

Check warning on line 970 in src/systems/diffeqs/abstractodesystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/diffeqs/abstractodesystem.jl#L968-L970

Added lines #L968 - L970 were not covered by tests
trueinit = u0map
end

Expand All @@ -940,7 +978,12 @@
parammap == SciMLBase.NullParameters() && isempty(defs)
nothing
else
MTKParameters(sys, parammap, trueinit)
if parammap === nothing || parammap == SciMLBase.NullParameters()
parammap = Dict()

Check warning on line 982 in src/systems/diffeqs/abstractodesystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/diffeqs/abstractodesystem.jl#L981-L982

Added lines #L981 - L982 were not covered by tests
else
parammap = todict(parammap)

Check warning on line 984 in src/systems/diffeqs/abstractodesystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/diffeqs/abstractodesystem.jl#L984

Added line #L984 was not covered by tests
end
MTKParameters(sys, merge(parammap, zeropars), trueinit)

Check warning on line 986 in src/systems/diffeqs/abstractodesystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/diffeqs/abstractodesystem.jl#L986

Added line #L986 was not covered by tests
end
else
u0, p, defs = get_u0_p(sys,
Expand Down Expand Up @@ -973,8 +1016,8 @@
checkbounds = checkbounds, p = p,
linenumbers = linenumbers, parallel = parallel, simplify = simplify,
sparse = sparse, eval_expression = eval_expression,
initializeprob = initializeprob,
initializeprobmap = initializeprobmap,
initializeprob = initializeprob, initializeprob_init! = initializeprob_init!,
initializeprob_update! = initializeprob_update!,
kwargs...)
implicit_dae ? (f, du0, u0, p) : (f, u0, p)
end
Expand Down Expand Up @@ -1602,13 +1645,15 @@
if !iscomplete(sys)
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `ODEProblem`")
end
parammap = parammap isa SciMLBase.NullParameters ? Dict() : todict(parammap)

Check warning on line 1648 in src/systems/diffeqs/abstractodesystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/diffeqs/abstractodesystem.jl#L1648

Added line #L1648 was not covered by tests
if isempty(u0map) && get_initializesystem(sys) !== nothing
isys = get_initializesystem(sys)
elseif isempty(u0map) && get_initializesystem(sys) === nothing
isys = structural_simplify(generate_initializesystem(sys); fully_determined = false)
isys = structural_simplify(

Check warning on line 1652 in src/systems/diffeqs/abstractodesystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/diffeqs/abstractodesystem.jl#L1652

Added line #L1652 was not covered by tests
generate_initializesystem(sys; pmap = parammap); fully_determined = false)
else
isys = structural_simplify(
generate_initializesystem(sys; u0map); fully_determined = false)
generate_initializesystem(sys; u0map, pmap = parammap); fully_determined = false)
end

uninit = setdiff(unknowns(sys), [unknowns(isys); getfield.(observed(isys), :lhs)])
Expand All @@ -1628,10 +1673,7 @@
if warn_initialize_determined && neqs < nunknown
@warn "Initialization system is underdetermined. $neqs equations for $nunknown unknowns. Initialization will default to using least squares. To suppress this warning pass warn_initialize_determined = false."
end

parammap = parammap isa DiffEqBase.NullParameters || isempty(parammap) ?
[get_iv(sys) => t] :
merge(todict(parammap), Dict(get_iv(sys) => t))
parammap[get_iv(sys)] = t

Check warning on line 1676 in src/systems/diffeqs/abstractodesystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/diffeqs/abstractodesystem.jl#L1676

Added line #L1676 was not covered by tests
if isempty(u0map)
u0map = Dict()
end
Expand Down
29 changes: 28 additions & 1 deletion src/systems/nonlinear/initializesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"""
function generate_initializesystem(sys::ODESystem;
u0map = Dict(),
pmap = Dict(),
name = nameof(sys),
guesses = Dict(), check_defguess = false,
default_dd_value = 0.0,
Expand Down Expand Up @@ -69,6 +70,32 @@
defs = merge(defaults(sys), filtered_u0)
guesses = merge(get_guesses(sys), todict(guesses), dd_guess)

all_params = parameters(sys)
pars = [parameters(sys); get_iv(sys)]
paramsubs = Dict()
for p in all_params
haskey(pmap, p) && continue
paramsubs[p] = tovar(p)
push!(full_states, tovar(p))
deleteat!(pars, findfirst(isequal(p), pars))
if haskey(defs, p)
def = defs[p]
if def isa Equation
p ∉ keys(guesses) && check_defguess &&

Check warning on line 84 in src/systems/nonlinear/initializesystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/nonlinear/initializesystem.jl#L73-L84

Added lines #L73 - L84 were not covered by tests
error("Invalid setup: parameter $(p) has an initial condition equation with no guess.")
push!(eqs_ics, def)
push!(u0, p => guesses[p])

Check warning on line 87 in src/systems/nonlinear/initializesystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/nonlinear/initializesystem.jl#L86-L87

Added lines #L86 - L87 were not covered by tests
else
push!(eqs_ics, p ~ def)
push!(u0, p => def)

Check warning on line 90 in src/systems/nonlinear/initializesystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/nonlinear/initializesystem.jl#L89-L90

Added lines #L89 - L90 were not covered by tests
end
elseif haskey(guesses, p)
push!(u0, p => guesses[p])
elseif check_defguess
error("Invalid setup: parameter $(p) has no default value or initial guess")

Check warning on line 95 in src/systems/nonlinear/initializesystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/nonlinear/initializesystem.jl#L92-L95

Added lines #L92 - L95 were not covered by tests
end
end

Check warning on line 97 in src/systems/nonlinear/initializesystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/nonlinear/initializesystem.jl#L97

Added line #L97 was not covered by tests

if !algebraic_only
for st in full_states
if st ∈ keys(defs)
Expand All @@ -91,12 +118,12 @@
end
end

pars = [parameters(sys); get_iv(sys)]
nleqs = if algebraic_only
[eqs_ics; observed(sys)]
else
[eqs_ics; get_initialization_eqs(sys); observed(sys)]
end
nleqs = fast_substitute(nleqs, paramsubs)

Check warning on line 126 in src/systems/nonlinear/initializesystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/nonlinear/initializesystem.jl#L126

Added line #L126 was not covered by tests

sys_nl = NonlinearSystem(nleqs,
full_states,
Expand Down
Loading