Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add JET test for MIRKN methods #271

Merged
merged 5 commits into from
Jan 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ ODEInterface = "54ca160b-1b9f-5127-a996-1867f4bc2a2c"
BoundaryValueDiffEqODEInterfaceExt = "ODEInterface"

[compat]
ADTypes = "1"
ADTypes = "1.11"
Aqua = "0.8.9"
ArrayInterface = "7.18"
BoundaryValueDiffEqAscher = "1"
Expand Down
2 changes: 1 addition & 1 deletion lib/BoundaryValueDiffEqAscher/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"

[compat]
ADTypes = "1.9"
ADTypes = "1.11"
Adapt = "4.1.1"
AlmostBlockDiagonals = "0.1.10"
ArrayInterface = "7.18"
Expand Down
4 changes: 2 additions & 2 deletions lib/BoundaryValueDiffEqCore/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "BoundaryValueDiffEqCore"
uuid = "56b672f2-a5fe-4263-ab2d-da677488eb3a"
authors = ["Qingyu Qu <[email protected]>"]
version = "1.5.0"
version = "1.6.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand All @@ -24,7 +24,7 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"

[compat]
ADTypes = "1.9"
ADTypes = "1.11"
Adapt = "4.1.1"
ArrayInterface = "7.18"
Aqua = "0.8.9"
Expand Down
2 changes: 1 addition & 1 deletion lib/BoundaryValueDiffEqCore/src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,7 @@ end
end

# Construct BVP Solution
function __build_solution(prob::BVProblem, odesol, nlsol)
function __build_solution(prob::AbstractBVProblem, odesol, nlsol)
retcode = ifelse(SciMLBase.successful_retcode(nlsol), odesol.retcode, nlsol.retcode)
return SciMLBase.solution_new_original_retcode(odesol, nlsol, retcode, nlsol.resid)
end
Expand Down
2 changes: 1 addition & 1 deletion lib/BoundaryValueDiffEqFIRK/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"

[compat]
ADTypes = "1.9"
ADTypes = "1.11"
Adapt = "4.1.1"
Aqua = "0.8.7"
ArrayInterface = "7.18"
Expand Down
2 changes: 1 addition & 1 deletion lib/BoundaryValueDiffEqMIRK/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"

[compat]
ADTypes = "1.9"
ADTypes = "1.11"
Adapt = "4.1.1"
Aqua = "0.8.7"
ArrayInterface = "7.18"
Expand Down
4 changes: 2 additions & 2 deletions lib/BoundaryValueDiffEqMIRKN/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "BoundaryValueDiffEqMIRKN"
uuid = "9255f1d6-53bf-473e-b6bd-23f1ff009da4"
authors = ["Qingyu Qu <[email protected]>"]
version = "1.2.0"
version = "1.3.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down Expand Up @@ -29,7 +29,7 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"

[compat]
ADTypes = "1.9"
ADTypes = "1.11"
Adapt = "4.1.1"
Aqua = "0.8.7"
ArrayInterface = "7.18"
Expand Down
3 changes: 2 additions & 1 deletion lib/BoundaryValueDiffEqMIRKN/src/BoundaryValueDiffEqMIRKN.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ using SciMLBase: SciMLBase, AbstractDiffEqInterpolation, AbstractBVProblem,
using Setfield: @set!, @set
using SparseArrays: sparse
using SparseDiffTools: init_jacobian, sparse_jacobian, sparse_jacobian_cache,
sparse_jacobian!, matrix_colors, SymbolicsSparsityDetection
sparse_jacobian!, matrix_colors, SymbolicsSparsityDetection,
NoSparsityDetection

@reexport using ADTypes, BoundaryValueDiffEqCore, SciMLBase

Expand Down
40 changes: 27 additions & 13 deletions lib/BoundaryValueDiffEqMIRKN/src/mirkn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,19 +86,34 @@ function SciMLBase.__init(prob::SecondOrderBVProblem, alg::AbstractMIRKN;

return MIRKNCache{iip, T}(
alg_order(alg), stage, M, size(X), f, bc, prob_, prob.problem_type,
prob.p, alg, TU, bcresid_prototype, mesh, mesh_dt, k_discrete,
y, y₀, residual, fᵢ_cache, fᵢ₂_cache, resid_size, kwargs)
prob.p, alg, TU, bcresid_prototype, mesh, mesh_dt, k_discrete, y,
y₀, residual, fᵢ_cache, fᵢ₂_cache, resid_size, (; dt, kwargs...))
end

function __split_mirkn_kwargs(; dt, kwargs...)
return ((dt), (; kwargs...))
end

function SciMLBase.solve!(cache::MIRKNCache{iip, T}) where {iip, T}
(; mesh, M, p, prob, kwargs) = cache
nlprob = __construct_nlproblem(cache, vec(cache.y₀))
(_), kwargs = __split_mirkn_kwargs(; cache.kwargs...)
info::ReturnCode.T = ReturnCode.Success

sol_nlprob, info = __perform_mirkn_iteration(cache; kwargs...)

solu = ArrayPartition.(
cache.y₀.u[1:length(cache.mesh)], cache.y₀.u[(length(cache.mesh) + 1):end])
odesol = SciMLBase.build_solution(
cache.prob, cache.alg, cache.mesh, solu; retcode = info)
return __build_solution(cache.prob, odesol, sol_nlprob)
end

function __perform_mirkn_iteration(cache::MIRKNCache; nlsolve_kwargs = (;), kwargs...)
nlprob::NonlinearProblem = __construct_nlproblem(cache, vec(cache.y₀))
nlsolve_alg = __concrete_nonlinearsolve_algorithm(nlprob, cache.alg.nlsolve)
sol_nlprob = __solve(nlprob, nlsolve_alg; kwargs..., alias_u0 = true)
sol_nlprob = __solve(nlprob, nlsolve_alg; kwargs..., nlsolve_kwargs..., alias_u0 = true)
recursive_unflatten!(cache.y₀, sol_nlprob.u)
solu = ArrayPartition.(cache.y₀.u[1:length(mesh)], cache.y₀.u[(length(mesh) + 1):end])
return SciMLBase.build_solution(
prob, cache.alg, mesh, solu; retcode = sol_nlprob.retcode)

return sol_nlprob, sol_nlprob.retcode
end

function __construct_nlproblem(cache::MIRKNCache{iip}, y::AbstractVector) where {iip}
Expand All @@ -115,19 +130,18 @@ function __construct_nlproblem(cache::MIRKNCache{iip}, y::AbstractVector) where
sd = alg.jac_alg.diffmode isa AutoSparse ? SymbolicsSparsityDetection() :
NoSparsityDetection()
ad = alg.jac_alg.diffmode
lz = reduce(vcat, cache.y₀)
jac_cache = __sparse_jacobian_cache(Val(iip), ad, sd, lossₚ, lz, lz)
lz = __similar(y)
jac_cache = __sparse_jacobian_cache(Val(iip), ad, sd, lossₚ, lz, y)
jac_prototype = init_jacobian(jac_cache)
jac = if iip
@closure (J, u, p) -> __mirkn_mpoint_jacobian!(J, u, ad, jac_cache, lossₚ, lz)
else
@closure (u, p) -> __mirkn_mpoint_jacobian(jac_prototype, u, ad, jac_cache, lossₚ)
end
resid_prototype = zero(lz)
_nlf = NonlinearFunction{iip}(
nlf = NonlinearFunction{iip}(
loss; jac = jac, resid_prototype = resid_prototype, jac_prototype = jac_prototype)
nlprob::NonlinearProblem = NonlinearProblem(_nlf, lz, cache.p)
return nlprob
return __internal_nlsolve_problem(cache.prob, resid_prototype, lz, nlf, lz, cache.p)
end

function __mirkn_2point_jacobian!(J, x, diffmode, diffcache, loss_fn::L, resid) where {L}
Expand Down
160 changes: 91 additions & 69 deletions lib/BoundaryValueDiffEqMIRKN/test/mirkn_basic_tests.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,94 @@
@testsetup module MIRKNConvergenceTests

using BoundaryValueDiffEqMIRKN

for order in (4, 6)
s = Symbol("MIRKN$(order)")
@eval mirkn_solver(::Val{$order}, args...; kwargs...) = $(s)(args...; kwargs...)
end

function f!(ddu, du, u, p, t)
ddu[1] = u[1]
end
function f(du, u, p, t)
return u[1]
end
function bc!(res, du, u, p, t)
res[1] = u(0.0)[1] - 1
res[2] = u(1.0)[1]
end
function bc(du, u, p, t)
return [u(0.0)[1] - 1, u(1.0)[1]]
end
function bc_indexing!(res, du, u, p, t)
res[1] = u[:, 1][1] - 1
res[2] = u[:, end][1]
end
function bc_indexing(du, u, p, t)
return [u[:, 1][1] - 1, u[:, end][1]]
end
function bc_a!(res, du, u, p)
res[1] = u[1] - 1
end
function bc_b!(res, du, u, p)
res[1] = u[1]
end
function bc_a(du, u, p)
return [u[1] - 1]
end
function bc_b(du, u, p)
return [u[1]]
end
analytical_solution = (u0, p, t) -> [
(exp(-t) - exp(t - 2)) / (1 - exp(-2)), (-exp(-t) - exp(t - 2)) / (1 - exp(-2))]
u0 = [1.0]
tspan = (0.0, 1.0)
testTol = 0.2
bvpf1 = DynamicalBVPFunction(f!, bc!, analytic = analytical_solution)
bvpf2 = DynamicalBVPFunction(f, bc, analytic = analytical_solution)
bvpf3 = DynamicalBVPFunction(f!, bc_indexing!, analytic = analytical_solution)
bvpf4 = DynamicalBVPFunction(f, bc_indexing, analytic = analytical_solution)
bvpf5 = DynamicalBVPFunction(f!, (bc_a!, bc_b!), analytic = analytical_solution,
bcresid_prototype = (zeros(1), zeros(1)), twopoint = Val(true))
bvpf6 = DynamicalBVPFunction(f, (bc_a, bc_b), analytic = analytical_solution,
bcresid_prototype = (zeros(1), zeros(1)), twopoint = Val(true))
probArr = [SecondOrderBVProblem(bvpf1, u0, tspan), SecondOrderBVProblem(bvpf2, u0, tspan),
SecondOrderBVProblem(bvpf3, u0, tspan), SecondOrderBVProblem(bvpf4, u0, tspan),
TwoPointSecondOrderBVProblem(bvpf5, u0, tspan),
TwoPointSecondOrderBVProblem(bvpf6, u0, tspan)]
dts = 1 .// 2 .^ (3:-1:1)

export probArr, dts, testTol, mirkn_solver

end

@testitem "Convergence on Linear" setup=[MIRKNConvergenceTests] begin
using LinearAlgebra, DiffEqDevTools

@testset "Problem: $i" for i in (1, 2, 3, 4, 5, 6)
prob = probArr[i]
@testset "MIRKN$order" for order in (4, 6)
sim = test_convergence(
dts, prob, mirkn_solver(Val(order)); abstol = 1e-8, reltol = 1e-8)
@test sim.𝒪est[:final]≈order atol=testTol
end
end
end

@testitem "JET tests" setup=[MIRKNConvergenceTests] begin
using JET

@testset "Problem: $i" for i in 1:6
prob = probArr[i]
@testset "MIRKN$order" for order in (4, 6)
solver = mirkn_solver(Val(order); nlsolve = NewtonRaphson(),
jac_alg = BVPJacobianAlgorithm(AutoForwardDiff(; chunksize = 2)))
@test_call target_modules=(BoundaryValueDiffEqMIRKN,) solve(
prob, solver; dt = 0.2)
end
end
end

@testitem "Example problem from paper" begin
using BoundaryValueDiffEqMIRKN

Expand Down Expand Up @@ -81,72 +172,3 @@
end
end
end

@testitem "Convergence on Linear" begin
using LinearAlgebra, DiffEqDevTools

for order in (4, 6)
s = Symbol("MIRKN$(order)")
@eval mirkn_solver(::Val{$order}, args...; kwargs...) = $(s)(args...; kwargs...)
end

function f!(ddu, du, u, p, t)
ddu[1] = u[1]
end
function f(du, u, p, t)
return u[1]
end
function bc!(res, du, u, p, t)
res[1] = u(0.0)[1] - 1
res[2] = u(1.0)[1]
end
function bc(du, u, p, t)
return [u(0.0)[1] - 1, u(1.0)[1]]
end
function bc_indexing!(res, du, u, p, t)
res[1] = u[:, 1][1] - 1
res[2] = u[:, end][1]
end
function bc_indexing(du, u, p, t)
return [u[:, 1][1] - 1, u[:, end][1]]
end
function bc_a!(res, du, u, p)
res[1] = u[1] - 1
end
function bc_b!(res, du, u, p)
res[1] = u[1]
end
function bc_a(du, u, p)
return [u[1] - 1]
end
function bc_b(du, u, p)
return [u[1]]
end
analytical_solution = (u0, p, t) -> [
(exp(-t) - exp(t - 2)) / (1 - exp(-2)), (-exp(-t) - exp(t - 2)) / (1 - exp(-2))]
u0 = [1.0]
tspan = (0.0, 1.0)
testTol = 0.2
bvpf1 = DynamicalBVPFunction(f!, bc!, analytic = analytical_solution)
bvpf2 = DynamicalBVPFunction(f, bc, analytic = analytical_solution)
bvpf3 = DynamicalBVPFunction(f!, bc_indexing!, analytic = analytical_solution)
bvpf4 = DynamicalBVPFunction(f, bc_indexing, analytic = analytical_solution)
bvpf5 = DynamicalBVPFunction(f!, (bc_a!, bc_b!), analytic = analytical_solution,
bcresid_prototype = (zeros(1), zeros(1)), twopoint = Val(true))
bvpf6 = DynamicalBVPFunction(f, (bc_a, bc_b), analytic = analytical_solution,
bcresid_prototype = (zeros(1), zeros(1)), twopoint = Val(true))
probArr = [
SecondOrderBVProblem(bvpf1, u0, tspan), SecondOrderBVProblem(bvpf2, u0, tspan),
SecondOrderBVProblem(bvpf3, u0, tspan), SecondOrderBVProblem(bvpf4, u0, tspan),
TwoPointSecondOrderBVProblem(bvpf5, u0, tspan),
TwoPointSecondOrderBVProblem(bvpf6, u0, tspan)]
dts = 1 .// 2 .^ (3:-1:1)
@testset "Problem: $i" for i in (1, 2, 3, 4, 5, 6)
prob = probArr[i]
@testset "MIRKN$order" for order in (4, 6)
sim = test_convergence(
dts, prob, mirkn_solver(Val(order)); abstol = 1e-8, reltol = 1e-8)
@test sim.𝒪est[:final]≈order atol=testTol
end
end
end
2 changes: 1 addition & 1 deletion lib/BoundaryValueDiffEqShooting/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"

[compat]
ADTypes = "1.9"
ADTypes = "1.11"
Adapt = "4.1.1"
Aqua = "0.8.7"
ArrayInterface = "7.18"
Expand Down
Loading