Skip to content

Commit

Permalink
fix: remake initialization problem during DAE initialization
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed May 22, 2024
1 parent 46e3153 commit 4d8276a
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 1 deletion.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
StaticArrayInterface = "0d7ed370-da01-4f52-bd93-41d350b8b718"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77"

[compat]
Expand Down Expand Up @@ -84,6 +85,7 @@ SparseArrays = "1.9"
SparseDiffTools = "2.3"
StaticArrayInterface = "1.2"
StaticArrays = "1.0"
SymbolicIndexingInterface = "0.3.16"
TruncatedStacktraces = "1.2"
julia = "1.10"

Expand Down
2 changes: 2 additions & 0 deletions src/OrdinaryDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ using ExponentialUtilities

using NonlinearSolve

using SymbolicIndexingInterface

# Required by temporary fix in not in-place methods with 12+ broadcasts
# `MVector` is used by Nordsieck forms
import StaticArrays: SArray, MVector, SVector, @SVector, StaticArray, MMatrix, SA
Expand Down
9 changes: 8 additions & 1 deletion src/initialize_dae.jl
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,14 @@ end
function _initialize_dae!(integrator, prob::Union{ODEProblem, DAEProblem},
alg::OverrideInit, isinplace::Union{Val{true}, Val{false}})
initializeprob = prob.f.initializeprob

initu0vars = variable_symbols(initializeprob)
initu0order = variable_index.((initializeprob,), initu0vars)
# Variable symbols are not guaranteed to be in order
invpermute!(initu0vars, initu0order)
initu0 = getu(prob.f.initializeprob, initu0vars)(prob)
initp = remake_buffer(initializeprob, parameter_values(initializeprob),
Dict(sym => getu(prob, sym)(prob) for sym in parameter_symbols(initializeprob)))
initializeprob = remake(initializeprob; u0 = initu0, p = initp)
# If it doesn't have autodiff, assume it comes from symbolic system like ModelingToolkit
# Since then it's the case of not a DAE but has initializeprob
# In which case, it should be differentiable
Expand Down

0 comments on commit 4d8276a

Please sign in to comment.