Skip to content

Commit

Permalink
fixup! feat: allow initialization of null integrators
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed May 29, 2024
1 parent f0720b7 commit 1651ee6
Showing 1 changed file with 3 additions and 26 deletions.
29 changes: 3 additions & 26 deletions src/initialize_dae.jl
Original file line number Diff line number Diff line change
Expand Up @@ -140,20 +140,7 @@ end
function _initialize_dae!(integrator, prob::Union{ODEProblem, DAEProblem},
alg::OverrideInit, isinplace::Union{Val{true}, Val{false}})
initializeprob = prob.f.initializeprob
if initializeprob.f.sys !== nothing && prob.f.sys !== nothing
if initializeprob.u0 === nothing || isempty(initializeprob.u0)
initu0 = Float64[]
else
initu0vars = variable_symbols(initializeprob)
initu0order = variable_index.((initializeprob,), initu0vars)
# Variable symbols are not guaranteed to be in order
invpermute!(initu0vars, initu0order)
initu0 = getu(prob.f.initializeprob, initu0vars)(prob)
end
initp = remake_buffer(initializeprob, parameter_values(initializeprob),
Dict(sym => getu(prob, sym)(prob) for sym in parameter_symbols(initializeprob)))
initializeprob = remake(initializeprob; u0 = initu0, p = initp)
end
prob.f.initializeprob_init!(initializeprob, integrator)
# If it doesn't have autodiff, assume it comes from symbolic system like ModelingToolkit
# Since then it's the case of not a DAE but has initializeprob
# In which case, it should be differentiable
Expand All @@ -168,19 +155,9 @@ function _initialize_dae!(integrator, prob::Union{ODEProblem, DAEProblem},
alg = default_nlsolve(alg.nlsolve, isinplace, initializeprob.u0, initializeprob, isAD)
nlsol = solve(initializeprob, alg)
if isinplace === Val{true}()
if prob.u0 !== nothing && !isempty(prob.u0)
integrator.u .= prob.f.initializeprobmap(nlsol)
end
if SciMLBase.has_initializeprob_updatep(prob.f)
prob.f.initializeprob_updatep!(integrator.p, nlsol)
end
prob.f.initializeprob_update!(integrator, nlsol)
elseif isinplace === Val{false}()
if prob.u0 !== nothing && !isempty(prob.u0)
integrator.u .= prob.f.initializeprobmap(nlsol)
end
if SciMLBase.has_initializeprob_updatep(prob.f)
prob.f.initializeprob_updatep!(integrator.p, nlsol)
end
prob.f.initializeprob_update!(integrator, nlsol)
else
error("Unreachable reached. Report this error.")
end
Expand Down

0 comments on commit 1651ee6

Please sign in to comment.