Skip to content

Commit

Permalink
feat: set integrator.du in OverrideInit for DAEProblems
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Nov 28, 2024
1 parent 14e6cb6 commit 4074647
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 2 deletions.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ SparseDiffTools = "2"
Static = "0.8, 1"
StaticArrayInterface = "1.2"
StaticArrays = "1.0"
SymbolicIndexingInterface = "0.3.31"
TruncatedStacktraces = "1.2"
julia = "1.10"

Expand All @@ -166,9 +167,10 @@ ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"

[targets]
test = ["Calculus", "ComponentArrays", "Symbolics", "AlgebraicMultigrid", "IncompleteLU", "DiffEqCallbacks", "DiffEqDevTools", "ODEProblemLibrary", "ElasticArrays", "InteractiveUtils", "ParameterizedFunctions", "PoissonRandom", "Printf", "Random", "ReverseDiff", "SafeTestsets", "SparseArrays", "Statistics", "Test", "Unitful", "ModelingToolkit", "Pkg", "NLsolve"]
test = ["Calculus", "ComponentArrays", "Symbolics", "AlgebraicMultigrid", "IncompleteLU", "DiffEqCallbacks", "DiffEqDevTools", "ODEProblemLibrary", "ElasticArrays", "InteractiveUtils", "ParameterizedFunctions", "PoissonRandom", "Printf", "Random", "ReverseDiff", "SafeTestsets", "SparseArrays", "Statistics", "Test", "Unitful", "ModelingToolkit", "Pkg", "NLsolve", "SymbolicIndexingInterface"]
13 changes: 12 additions & 1 deletion lib/OrdinaryDiffEqCore/src/initialize_dae.jl
Original file line number Diff line number Diff line change
Expand Up @@ -156,12 +156,23 @@ function _initialize_dae!(integrator, prob::AbstractDEProblem,

nlsolve_alg = default_nlsolve(alg.nlsolve, isinplace, initializeprob.u0, initializeprob, isAD)

u0, p, success = SciMLBase.get_initial_values(prob, integrator, prob.f, alg, isinplace; nlsolve_alg, abstol = integrator.opts.abstol, reltol = integrator.opts.reltol)
if prob isa DAEProblem
du0, u0, p, success = SciMLBase.get_initial_values(prob, integrator, prob.f, alg, isinplace; nlsolve_alg, abstol = integrator.opts.abstol, reltol = integrator.opts.reltol, return_du0 = true)
else
u0, p, success = SciMLBase.get_initial_values(prob, integrator, prob.f, alg, isinplace; nlsolve_alg, abstol = integrator.opts.abstol, reltol = integrator.opts.reltol)
du0 = nothing
end

if isinplace === Val{true}()
integrator.u .= u0
if du0 !== nothing
integrator.du .= du0
end
elseif isinplace === Val{false}()
integrator.u = u0
if du0 !== nothing
integrator.du = du0
end
else
error("Unreachable reached. Report this error.")
end
Expand Down
35 changes: 35 additions & 0 deletions test/interface/dae_initialization_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -114,3 +114,38 @@ prob = ODEProblem(f, ones(3), (0.0, 1.0))
integrator = init(prob, Rodas5P(),
initializealg = ShampineCollocationInit(1.0, BrokenNLSolve()))
@test all(isequal(reinterpret(Float64, 0xDEADBEEFDEADBEEF)), integrator.u)

@testset "OverrideInit for DAEProblem" begin
function daerhs(du, u, p, t)
return [u[1] * t + p, u[1]^2 - u[2]^2]
end
# unknowns are u[2], p, D(u[1]), D(u[2]). Parameters are u[1], t
initprob = NonlinearProblem([1.0, 1.0, 1.0, 1.0], [1.0, 0.0]) do x, _p
u2, p, du1, du2 = x
u1, t = _p
return [u1^3 - u2^3, p^2 - 2p + 1, du1 - u1 * t - p, 2u1 * du1 - 2u2 * du2]
end

update_initializeprob! = function (iprob, integ)
iprob.p[1] = integ.u[1]
iprob.p[2] = integ.t
end
initprobmap = function (nlsol)
return [parameter_values(nlsol)[1], nlsol.u[1]]
end
initprobpmap = function (_, nlsol)
return nlsol.u[2]
end
initprob_du0map = function (nlsol)
return nlsol.u[3:4]
end
initialization_data = SciMLBase.OverrideInitData(
initprob, update_initializeprob!, initprobmap, initprobpmap, initprob_du0map)
fn = DAEFunction(daerhs; initialization_data)
prob = DAEProblem(fn, [0.0, 0.0], [2.0, 0.0], (0.0, 1.0), 0.0)
integ = init(prob, DImplicitEuler())
@test integ.du [1.0, 1.0]
@test integ.u [2.0, 2.0]
@test integ.p 1.0
@test integ.sol.retcode != SciMLBase.ReturnCode.InitialFailure
end

0 comments on commit 4074647

Please sign in to comment.