Skip to content

Commit

Permalink
Merge pull request #2094 from SciML/interpolation_output_types
Browse files Browse the repository at this point in the history
Fix interpolation output types for dynamical ODEs
  • Loading branch information
ChrisRackauckas authored Dec 27, 2023
2 parents 4406430 + dbb3e02 commit f1b8d90
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 9 deletions.
39 changes: 30 additions & 9 deletions src/dense/generic_dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -687,8 +687,13 @@ Herimte Interpolation, chosen if no other dispatch for ode_interpolant
@muladd function hermite_interpolant(Θ, dt, y₀, y₁, k, ::Type{Val{false}}, idxs::Nothing,
T::Type{Val{0}}, differential_vars) # Default interpolant is Hermite
#@.. broadcast=false (1-Θ)*y₀+Θ*y₁+Θ*(Θ-1)*((1-2Θ)*(y₁-y₀)+(Θ-1)*dt*k[1] + Θ*dt*k[2])
@inbounds (1 - Θ) * y₀ + Θ * y₁ +
differential_vars .**- 1) * ((1 - 2Θ) * (y₁ - y₀) +- 1) * dt * k[1] + Θ * dt * k[2]))
if all(differential_vars)
@inbounds (1 - Θ) * y₀ + Θ * y₁ +
*- 1) * ((1 - 2Θ) * (y₁ - y₀) +- 1) * dt * k[1] + Θ * dt * k[2]))
else
@inbounds (1 - Θ) * y₀ + Θ * y₁ +
differential_vars .**- 1) * ((1 - 2Θ) * (y₁ - y₀) +- 1) * dt * k[1] + Θ * dt * k[2]))
end
end

@muladd function hermite_interpolant(Θ, dt, y₀, y₁, k, ::Type{Val{true}}, idxs::Nothing,
Expand Down Expand Up @@ -755,10 +760,17 @@ Herimte Interpolation, chosen if no other dispatch for ode_interpolant
@muladd function hermite_interpolant(Θ, dt, y₀, y₁, k, ::Type{Val{false}}, idxs::Nothing,
T::Type{Val{1}}, differential_vars) # Default interpolant is Hermite
#@.. broadcast=false k[1] + Θ*(-4*dt*k[1] - 2*dt*k[2] - 6*y₀ + Θ*(3*dt*k[1] + 3*dt*k[2] + 6*y₀ - 6*y₁) + 6*y₁)/dt
@inbounds (.!differential_vars).*(y₁ - y₀)/dt + differential_vars .*(
k[1] +
Θ * (-4 * dt * k[1] - 2 * dt * k[2] - 6 * y₀ +
Θ * (3 * dt * k[1] + 3 * dt * k[2] + 6 * y₀ - 6 * y₁) + 6 * y₁) / dt)
if all(differential_vars)
@inbounds (
k[1] +
Θ * (-4 * dt * k[1] - 2 * dt * k[2] - 6 * y₀ +
Θ * (3 * dt * k[1] + 3 * dt * k[2] + 6 * y₀ - 6 * y₁) + 6 * y₁) / dt)
else
@inbounds (.!differential_vars).*(y₁ - y₀)/dt + differential_vars .*(
k[1] +
Θ * (-4 * dt * k[1] - 2 * dt * k[2] - 6 * y₀ +
Θ * (3 * dt * k[1] + 3 * dt * k[2] + 6 * y₀ - 6 * y₁) + 6 * y₁) / dt)
end
end

@muladd function hermite_interpolant(Θ, dt, y₀, y₁, k, ::Type{Val{true}}, idxs::Nothing,
Expand Down Expand Up @@ -826,8 +838,13 @@ Herimte Interpolation, chosen if no other dispatch for ode_interpolant
@muladd function hermite_interpolant(Θ, dt, y₀, y₁, k, ::Type{Val{false}}, idxs::Nothing,
T::Type{Val{2}}, differential_vars) # Default interpolant is Hermite
#@.. broadcast=false (-4*dt*k[1] - 2*dt*k[2] - 6*y₀ + Θ*(6*dt*k[1] + 6*dt*k[2] + 12*y₀ - 12*y₁) + 6*y₁)/(dt*dt)
@inbounds differential_vars .* (-4 * dt * k[1] - 2 * dt * k[2] - 6 * y₀ +
Θ * (6 * dt * k[1] + 6 * dt * k[2] + 12 * y₀ - 12 * y₁) + 6 * y₁) / (dt * dt)
if all(differential_vars)
@inbounds (-4 * dt * k[1] - 2 * dt * k[2] - 6 * y₀ +
Θ * (6 * dt * k[1] + 6 * dt * k[2] + 12 * y₀ - 12 * y₁) + 6 * y₁) / (dt * dt)
else
@inbounds differential_vars .* (-4 * dt * k[1] - 2 * dt * k[2] - 6 * y₀ +
Θ * (6 * dt * k[1] + 6 * dt * k[2] + 12 * y₀ - 12 * y₁) + 6 * y₁) / (dt * dt)
end
end

@muladd function hermite_interpolant(Θ, dt, y₀, y₁, k, ::Type{Val{true}}, idxs::Nothing,
Expand Down Expand Up @@ -887,7 +904,11 @@ Herimte Interpolation, chosen if no other dispatch for ode_interpolant
@muladd function hermite_interpolant(Θ, dt, y₀, y₁, k, ::Type{Val{false}}, idxs::Nothing,
T::Type{Val{3}}, differential_vars) # Default interpolant is Hermite
#@.. broadcast=false (6*dt*k[1] + 6*dt*k[2] + 12*y₀ - 12*y₁)/(dt*dt*dt)
@inbounds differential_vars .* (6 * dt * k[1] + 6 * dt * k[2] + 12 * y₀ - 12 * y₁) / (dt * dt * dt)
if all(differential_vars)
@inbounds (6 * dt * k[1] + 6 * dt * k[2] + 12 * y₀ - 12 * y₁) / (dt * dt * dt)
else
@inbounds differential_vars .* (6 * dt * k[1] + 6 * dt * k[2] + 12 * y₀ - 12 * y₁) / (dt * dt * dt)
end
end

@muladd function hermite_interpolant(Θ, dt, y₀, y₁, k, ::Type{Val{true}}, idxs::Nothing,
Expand Down
22 changes: 22 additions & 0 deletions test/interface/interpolation_output_types.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
using OrdinaryDiffEq, Test

# in terms of the voltage across all three elements
rlc1!(v′,v,(R,L,C),t) = -(v′/R + v/L)/C
identity_f(v,u,p,t) = v # needed to form second order dynamical ODE

setup_rlc(R,L,C;v_init=0.0,v′_init=0.0,tspan=(0.0,50.0)) =
DynamicalODEProblem{false}(rlc1!,identity_f,v′_init,v_init,tspan,(R,L,C))

# simulate voltage impulse
R,L,C = 10, 0.3, 2

prob = setup_rlc(R,L,C,v_init=2.0)

res1 = solve(prob,Vern8(),dt=1/10,saveat=1/10)
res3 = solve(prob,CalvoSanz4(),dt=1/10,saveat=1/10)

sol = solve(prob,CalvoSanz4(),dt=1/10)
@test sol(0.32) isa OrdinaryDiffEq.ArrayPartition
@test sol(0.32, Val{1}) isa OrdinaryDiffEq.ArrayPartition
@test sol(0.32, Val{2}) isa OrdinaryDiffEq.ArrayPartition
@test sol(0.32, Val{3}) isa OrdinaryDiffEq.ArrayPartition
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ end
@time @safetestset "Complex Tests" include("interface/complex_tests.jl")
@time @safetestset "Ndim Complex Tests" include("interface/ode_ndim_complex_tests.jl")
@time @safetestset "Number Type Tests" include("interface/ode_numbertype_tests.jl")
@time @safetestset "Interpolation Output Type Tests" include("interface/interpolation_output_types.jl")
@time @safetestset "Stiffness Detection Tests" include("interface/stiffness_detection_test.jl")
@time @safetestset "Composite Interpolation Tests" include("interface/composite_interpolation.jl")
@time @safetestset "Export tests" include("interface/export_tests.jl")
Expand Down

0 comments on commit f1b8d90

Please sign in to comment.