From cd532ece939fb4301ebd593f38b79e37bfa932e1 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Mon, 23 Dec 2024 13:22:25 -0500 Subject: [PATCH] add support for DDEAliasSpecifier --- Project.toml | 2 +- src/DelayDiffEq.jl | 1 + src/solve.jl | 39 +++++++++++++++++++++++++++++++++++++- test/interface/aliasing.jl | 23 ++++++++++++++++++++++ 4 files changed, 63 insertions(+), 2 deletions(-) create mode 100644 test/interface/aliasing.jl diff --git a/Project.toml b/Project.toml index 196f2e8..eba16e6 100644 --- a/Project.toml +++ b/Project.toml @@ -35,7 +35,7 @@ OrdinaryDiffEqNonlinearSolve = "1.2.2" OrdinaryDiffEqRosenbrock = "1.2.0" RecursiveArrayTools = "2, 3" Reexport = "0.2, 1.0" -SciMLBase = "2.68" +SciMLBase = "2.69" SimpleNonlinearSolve = "0.1, 1, 2" SimpleUnPack = "1" SymbolicIndexingInterface = "0.3.36" diff --git a/src/DelayDiffEq.jl b/src/DelayDiffEq.jl index 9ecc323..48b877d 100644 --- a/src/DelayDiffEq.jl +++ b/src/DelayDiffEq.jl @@ -25,6 +25,7 @@ using OrdinaryDiffEqNonlinearSolve: NLNewton, NLAnderson, NLFunctional, Abstract using OrdinaryDiffEqRosenbrock: RosenbrockMutableCache import SciMLBase +using SciMLBase: DDEAliasSpecifier export Discontinuity, MethodOfSteps diff --git a/src/solve.jl b/src/solve.jl index ea4f5d0..ce8636e 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -62,6 +62,7 @@ function DiffEqBase.__init(prob::DiffEqBase.AbstractDDEProblem, allow_extrapolation = OrdinaryDiffEqCore.alg_extrapolates(alg), initialize_integrator = true, alias_u0 = false, + alias = DDEAliasSpecifier(), # keyword arguments for DDEs discontinuity_interp_points::Int = 10, discontinuity_abstol = eltype(prob.tspan)(1 // Int64(10)^12), @@ -107,6 +108,42 @@ function DiffEqBase.__init(prob::DiffEqBase.AbstractDDEProblem, # unpack problem @unpack f, u0, h, tspan, p, neutral, constant_lags, dependent_lags = prob + use_old_kwargs = haskey(kwargs, :alias_u0) + + if haskey(kwargs, :alias_u0) + aliases = DDEAliasSpecifier() + message = "`alias_u0` keyword argument is deprecated, to set `alias_u0`, + please use an ODEAliasSpecifier, e.g. `solve(prob, alias = ODEAliasSpecifier(alias_u0 = true))" + Base.depwarn(message, :init) + Base.depwarn(message, :solve) + aliases = DDEAliasSpecifier(alias_u0 = values(kwargs).alias_u0) + else + # If alias isa Bool, all fields of ODEAliases set to alias + if alias isa Bool + aliases = DDEAliasSpecifier(alias = alias) + elseif alias isa DDEAliasSpecifier + aliases = alias + end + end + + if isnothing(aliases.alias_f) || aliases.alias_f + f = f + else + f = deepcopy(f) + end + + if isnothing(aliases.alias_p) || aliases.alias_p + p = p + else + p = recursivecopy(p) + end + + if !isnothing(aliases.alias_u0) && aliases.alias_u0 + u = prob.u0 + else + u = recursivecopy(prob.u0) + end + # determine type and direction of time tType = eltype(tspan) t0 = first(tspan) @@ -131,7 +168,7 @@ function DiffEqBase.__init(prob::DiffEqBase.AbstractDDEProblem, # get states (possibly different from the ODE integrator!) u, uprev, uprev2 = u_uprev_uprev2(u0, alg; - alias_u0 = alias_u0, + alias_u0 = aliases.alias_u0, adaptive = adaptive, allow_extrapolation = allow_extrapolation, calck = calck) diff --git a/test/interface/aliasing.jl b/test/interface/aliasing.jl new file mode 100644 index 0000000..9a2a3fa --- /dev/null +++ b/test/interface/aliasing.jl @@ -0,0 +1,23 @@ +using DelayDiffEq, DDEProblemLibrary + +# For now, testing if the old keyword works the same as the new alias keyword +prob_ip = prob_dde_constant_1delay_ip +prob_scalar = prob_dde_constant_1delay_scalar +ts = 0:0.1:10 + +noreuse = NLNewton(fast_convergence_cutoff = 0) + +const working_algs = [ImplicitMidpoint(), SSPSDIRK2(), KenCarp5(nlsolve = noreuse), + ImplicitEuler(nlsolve = noreuse), Trapezoid(nlsolve = noreuse), + TRBDF2(nlsolve = noreuse)] + +@testset "Algorithm $(nameof(typeof(alg)))" for alg in working_algs + println(nameof(typeof(alg))) + + stepsalg = MethodOfSteps(alg) + sol_new_alias = solve(prob_ip, stepsalg; dt = 0.1, alias = DDEAliasSpecifier(alias_u0 = true)) + sol_old_alias = solve( + prob_ip, stepsalg; dt = 0.1, alias = alias_u0 = true) + + @test sol_new_alias == sol_old_alias +end \ No newline at end of file