Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix support for BigFloat u0 #202

Merged
merged 8 commits into from
Aug 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ Logging = "1.10"
NonlinearSolve = "3.8.1"
ODEInterface = "0.5"
OrdinaryDiffEq = "6.63"
PreallocationTools = "0.4"
PreallocationTools = "0.4.24"
PrecompileTools = "1.2"
Preferences = "1.4"
Random = "1.10"
Expand Down
2 changes: 1 addition & 1 deletion src/BoundaryValueDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module BoundaryValueDiffEq

import PrecompileTools: @compile_workload, @setup_workload

using ADTypes, Adapt, DiffEqBase, ForwardDiff, LinearAlgebra, NonlinearSolve,
using ADTypes, Adapt, ArrayInterface, DiffEqBase, ForwardDiff, LinearAlgebra, NonlinearSolve,
OrdinaryDiffEq, Preferences, RecursiveArrayTools, Reexport, SciMLBase, Setfield,
SparseDiffTools

Expand Down
2 changes: 1 addition & 1 deletion src/collocation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ end
@views function Φ(
fᵢ_cache, k_discrete, f, TU::MIRKTableau, y, u, p, mesh, mesh_dt, stage::Int)
(; c, v, x, b) = TU
residuals = [similar(yᵢ) for yᵢ in y[1:(end - 1)]]
residuals = [__similar(yᵢ) for yᵢ in y[1:(end - 1)]]
tmp = get_tmp(fᵢ_cache, u)
T = eltype(u)
for i in eachindex(k_discrete)
Expand Down
18 changes: 9 additions & 9 deletions src/solve/mirk.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,10 @@ function SciMLBase.__init(prob::BVProblem, alg::AbstractMIRK; dt = 0.0,
mesh_dt = diff(mesh)

chunksize = pickchunksize(N * (Nig - 1))
__alloc = @closure x -> __maybe_allocate_diffcache(vec(x), chunksize, alg.jac_alg)
__alloc = @closure x -> __maybe_allocate_diffcache(vec(zero(x)), chunksize, alg.jac_alg)

fᵢ_cache = __alloc(similar(X))
fᵢ₂_cache = vec(similar(X))
fᵢ_cache = __alloc(zero(X))
fᵢ₂_cache = vec(zero(X))

# Don't flatten this here, since we need to expand it later if needed
y₀ = __initial_guess_on_mesh(prob.u0, mesh, prob.p)
Expand All @@ -54,7 +54,7 @@ function SciMLBase.__init(prob::BVProblem, alg::AbstractMIRK; dt = 0.0,
TU, ITU = constructMIRK(alg, T)
stage = alg_stage(alg)

k_discrete = [__maybe_allocate_diffcache(similar(X, N, stage), chunksize, alg.jac_alg)
k_discrete = [__maybe_allocate_diffcache(__similar(X, N, stage), chunksize, alg.jac_alg)
for _ in 1:Nig]
k_interp = VectorOfArray([similar(X, N, ITU.s_star - stage) for _ in 1:Nig])

Expand Down Expand Up @@ -308,7 +308,7 @@ function __construct_nlproblem(cache::MIRKCache{iip}, y, loss_bc::BC, loss_collo

resid_bc = cache.bcresid_prototype
L = length(resid_bc)
resid_collocation = similar(y, cache.M * (N - 1))
resid_collocation = __similar(y, cache.M * (N - 1))

loss_bcₚ = (iip ? __Fix3 : Base.Fix2)(loss_bc, cache.p)
loss_collocationₚ = (iip ? __Fix3 : Base.Fix2)(loss_collocation, cache.p)
Expand Down Expand Up @@ -341,8 +341,8 @@ function __construct_nlproblem(cache::MIRKCache{iip}, y, loss_bc::BC, loss_collo
Val(iip), jac_alg.nonbc_diffmode, sd_collocation,
loss_collocationₚ, resid_collocation, y)

J_bc = init_jacobian(cache_bc)
J_c = init_jacobian(cache_collocation)
J_bc = zero(init_jacobian(cache_bc))
J_c = zero(init_jacobian(cache_collocation))
if J_full_band === nothing
jac_prototype = vcat(J_bc, J_c)
else
Expand Down Expand Up @@ -414,7 +414,7 @@ function __construct_nlproblem(cache::MIRKCache{iip}, y, loss_bc::BC, loss_collo
lossₚ = iip ? ((du, u) -> loss(du, u, cache.p)) : (u -> loss(u, cache.p))

resid = vcat(@view(cache.bcresid_prototype[1:prod(cache.resid_size[1])]),
similar(y, cache.M * (N - 1)),
__similar(y, cache.M * (N - 1)),
@view(cache.bcresid_prototype[(prod(cache.resid_size[1]) + 1):end]))
L = length(cache.bcresid_prototype)

Expand All @@ -428,7 +428,7 @@ function __construct_nlproblem(cache::MIRKCache{iip}, y, loss_bc::BC, loss_collo
NoSparsityDetection()
end
diffcache = __sparse_jacobian_cache(Val(iip), jac_alg.diffmode, sd, lossₚ, resid, y)
jac_prototype = init_jacobian(diffcache)
jac_prototype = zero(init_jacobian(diffcache))

jac = if iip
@closure (J, u, p) -> __mirk_2point_jacobian!(
Expand Down
4 changes: 2 additions & 2 deletions src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,8 @@ end
function __maybe_allocate_diffcache(x, chunksize, jac_alg)
return __needs_diffcache(jac_alg) ? DiffCache(x, chunksize) : FakeDiffCache(x)
end
__maybe_allocate_diffcache(x::DiffCache, chunksize) = DiffCache(similar(x.du), chunksize)
__maybe_allocate_diffcache(x::FakeDiffCache, _) = FakeDiffCache(similar(x.du))
__maybe_allocate_diffcache(x::DiffCache, chunksize) = DiffCache(zero(x.du), chunksize)
__maybe_allocate_diffcache(x::FakeDiffCache, _) = FakeDiffCache(zero(x.du))

const MaybeDiffCache = Union{DiffCache, FakeDiffCache}

Expand Down
15 changes: 10 additions & 5 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ recursive_length(x::Vector{<:AbstractArray}) = sum(length, x)
recursive_length(x::Vector{<:MaybeDiffCache}) = sum(xᵢ -> length(xᵢ.u), x)

function recursive_flatten(x::Vector{<:AbstractArray})
y = similar(first(x), recursive_length(x))
y = zero(first(x), recursive_length(x))
recursive_flatten!(y, x)
return y
end
Expand Down Expand Up @@ -112,7 +112,7 @@ function __append_similar!(x::AbstractVector{<:AbstractArray}, n, _)
N = n - length(x)
N == 0 && return x
N < 0 && throw(ArgumentError("Cannot append a negative number of elements"))
append!(x, [similar(last(x)) for _ in 1:N])
append!(x, [zero(last(x)) for _ in 1:N])
return x
end

Expand Down Expand Up @@ -191,16 +191,21 @@ function __get_bcresid_prototype(::TwoPointBVProblem, prob::BVProblem, u)
end
function __get_bcresid_prototype(::StandardBVProblem, prob::BVProblem, u)
prototype = prob.f.bcresid_prototype !== nothing ? prob.f.bcresid_prototype :
__zeros_like(u)
zero(u)
return prototype, size(prototype)
end

@inline function __fill_like(v, x, args...)
@inline function __similar(x, args...)
y = similar(x, args...)
return zero(y)
end

@inline function __fill_like(v, x)
y = __similar(x)
fill!(y, v)
return y
end
@inline __zeros_like(args...) = __fill_like(0, args...)

@inline __ones_like(args...) = __fill_like(1, args...)

@inline __safe_vec(x) = vec(x)
Expand Down
18 changes: 18 additions & 0 deletions test/misc/bigfloat_test.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
@testitem "BigFloat compatibility" begin
using BoundaryValueDiffEq
tspan = (0.0, pi / 2)
function simplependulum!(du, u, p, t)
θ = u[1]
dθ = u[2]
du[1] = dθ
du[2] = -9.81 * sin(θ)
end
function bc!(residual, u, p, t)
residual[1] = u[end ÷ 2][1] + big(pi / 2)
residual[2] = u[end][1] - big(pi / 2)
end
u0 = BigFloat.([pi / 2, pi / 2])
prob = BVProblem(simplependulum!, bc!, u0, tspan)
sol = solve(prob, MIRK4(), dt = 0.05)
@test SciMLBase.successful_retcode(sol.retcode)
end
Loading