Skip to content

Commit

Permalink
Merge pull request #202 from ErikQQY/qqy/bigfloat
Browse files Browse the repository at this point in the history
Fix support for BigFloat u0
  • Loading branch information
ChrisRackauckas authored Aug 17, 2024
2 parents 5171740 + 40656db commit 53f8566
Show file tree
Hide file tree
Showing 7 changed files with 42 additions and 19 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ Logging = "1.10"
NonlinearSolve = "3.8.1"
ODEInterface = "0.5"
OrdinaryDiffEq = "6.63"
PreallocationTools = "0.4"
PreallocationTools = "0.4.24"
PrecompileTools = "1.2"
Preferences = "1.4"
Random = "1.10"
Expand Down
2 changes: 1 addition & 1 deletion src/BoundaryValueDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module BoundaryValueDiffEq

import PrecompileTools: @compile_workload, @setup_workload

using ADTypes, Adapt, DiffEqBase, ForwardDiff, LinearAlgebra, NonlinearSolve,
using ADTypes, Adapt, ArrayInterface, DiffEqBase, ForwardDiff, LinearAlgebra, NonlinearSolve,
OrdinaryDiffEq, Preferences, RecursiveArrayTools, Reexport, SciMLBase, Setfield,
SparseDiffTools

Expand Down
2 changes: 1 addition & 1 deletion src/collocation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ end
@views function Φ(
fᵢ_cache, k_discrete, f, TU::MIRKTableau, y, u, p, mesh, mesh_dt, stage::Int)
(; c, v, x, b) = TU
residuals = [similar(yᵢ) for yᵢ in y[1:(end - 1)]]
residuals = [__similar(yᵢ) for yᵢ in y[1:(end - 1)]]
tmp = get_tmp(fᵢ_cache, u)
T = eltype(u)
for i in eachindex(k_discrete)
Expand Down
18 changes: 9 additions & 9 deletions src/solve/mirk.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,10 @@ function SciMLBase.__init(prob::BVProblem, alg::AbstractMIRK; dt = 0.0,
mesh_dt = diff(mesh)

chunksize = pickchunksize(N * (Nig - 1))
__alloc = @closure x -> __maybe_allocate_diffcache(vec(x), chunksize, alg.jac_alg)
__alloc = @closure x -> __maybe_allocate_diffcache(vec(zero(x)), chunksize, alg.jac_alg)

fᵢ_cache = __alloc(similar(X))
fᵢ₂_cache = vec(similar(X))
fᵢ_cache = __alloc(zero(X))
fᵢ₂_cache = vec(zero(X))

# Don't flatten this here, since we need to expand it later if needed
y₀ = __initial_guess_on_mesh(prob.u0, mesh, prob.p)
Expand All @@ -54,7 +54,7 @@ function SciMLBase.__init(prob::BVProblem, alg::AbstractMIRK; dt = 0.0,
TU, ITU = constructMIRK(alg, T)
stage = alg_stage(alg)

k_discrete = [__maybe_allocate_diffcache(similar(X, N, stage), chunksize, alg.jac_alg)
k_discrete = [__maybe_allocate_diffcache(__similar(X, N, stage), chunksize, alg.jac_alg)
for _ in 1:Nig]
k_interp = VectorOfArray([similar(X, N, ITU.s_star - stage) for _ in 1:Nig])

Expand Down Expand Up @@ -308,7 +308,7 @@ function __construct_nlproblem(cache::MIRKCache{iip}, y, loss_bc::BC, loss_collo

resid_bc = cache.bcresid_prototype
L = length(resid_bc)
resid_collocation = similar(y, cache.M * (N - 1))
resid_collocation = __similar(y, cache.M * (N - 1))

loss_bcₚ = (iip ? __Fix3 : Base.Fix2)(loss_bc, cache.p)
loss_collocationₚ = (iip ? __Fix3 : Base.Fix2)(loss_collocation, cache.p)
Expand Down Expand Up @@ -341,8 +341,8 @@ function __construct_nlproblem(cache::MIRKCache{iip}, y, loss_bc::BC, loss_collo
Val(iip), jac_alg.nonbc_diffmode, sd_collocation,
loss_collocationₚ, resid_collocation, y)

J_bc = init_jacobian(cache_bc)
J_c = init_jacobian(cache_collocation)
J_bc = zero(init_jacobian(cache_bc))
J_c = zero(init_jacobian(cache_collocation))
if J_full_band === nothing
jac_prototype = vcat(J_bc, J_c)
else
Expand Down Expand Up @@ -414,7 +414,7 @@ function __construct_nlproblem(cache::MIRKCache{iip}, y, loss_bc::BC, loss_collo
lossₚ = iip ? ((du, u) -> loss(du, u, cache.p)) : (u -> loss(u, cache.p))

resid = vcat(@view(cache.bcresid_prototype[1:prod(cache.resid_size[1])]),
similar(y, cache.M * (N - 1)),
__similar(y, cache.M * (N - 1)),
@view(cache.bcresid_prototype[(prod(cache.resid_size[1]) + 1):end]))
L = length(cache.bcresid_prototype)

Expand All @@ -428,7 +428,7 @@ function __construct_nlproblem(cache::MIRKCache{iip}, y, loss_bc::BC, loss_collo
NoSparsityDetection()
end
diffcache = __sparse_jacobian_cache(Val(iip), jac_alg.diffmode, sd, lossₚ, resid, y)
jac_prototype = init_jacobian(diffcache)
jac_prototype = zero(init_jacobian(diffcache))

jac = if iip
@closure (J, u, p) -> __mirk_2point_jacobian!(
Expand Down
4 changes: 2 additions & 2 deletions src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,8 @@ end
function __maybe_allocate_diffcache(x, chunksize, jac_alg)
return __needs_diffcache(jac_alg) ? DiffCache(x, chunksize) : FakeDiffCache(x)
end
__maybe_allocate_diffcache(x::DiffCache, chunksize) = DiffCache(similar(x.du), chunksize)
__maybe_allocate_diffcache(x::FakeDiffCache, _) = FakeDiffCache(similar(x.du))
__maybe_allocate_diffcache(x::DiffCache, chunksize) = DiffCache(zero(x.du), chunksize)
__maybe_allocate_diffcache(x::FakeDiffCache, _) = FakeDiffCache(zero(x.du))

const MaybeDiffCache = Union{DiffCache, FakeDiffCache}

Expand Down
15 changes: 10 additions & 5 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ recursive_length(x::Vector{<:AbstractArray}) = sum(length, x)
recursive_length(x::Vector{<:MaybeDiffCache}) = sum(xᵢ -> length(xᵢ.u), x)

function recursive_flatten(x::Vector{<:AbstractArray})
y = similar(first(x), recursive_length(x))
y = zero(first(x), recursive_length(x))
recursive_flatten!(y, x)
return y
end
Expand Down Expand Up @@ -112,7 +112,7 @@ function __append_similar!(x::AbstractVector{<:AbstractArray}, n, _)
N = n - length(x)
N == 0 && return x
N < 0 && throw(ArgumentError("Cannot append a negative number of elements"))
append!(x, [similar(last(x)) for _ in 1:N])
append!(x, [zero(last(x)) for _ in 1:N])
return x
end

Expand Down Expand Up @@ -191,16 +191,21 @@ function __get_bcresid_prototype(::TwoPointBVProblem, prob::BVProblem, u)
end
function __get_bcresid_prototype(::StandardBVProblem, prob::BVProblem, u)
prototype = prob.f.bcresid_prototype !== nothing ? prob.f.bcresid_prototype :
__zeros_like(u)
zero(u)
return prototype, size(prototype)
end

@inline function __fill_like(v, x, args...)
@inline function __similar(x, args...)
y = similar(x, args...)
return zero(y)
end

@inline function __fill_like(v, x)
y = __similar(x)
fill!(y, v)
return y
end
@inline __zeros_like(args...) = __fill_like(0, args...)

@inline __ones_like(args...) = __fill_like(1, args...)

@inline __safe_vec(x) = vec(x)
Expand Down
18 changes: 18 additions & 0 deletions test/misc/bigfloat_test.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
@testitem "BigFloat compatibility" begin
using BoundaryValueDiffEq
tspan = (0.0, pi / 2)
function simplependulum!(du, u, p, t)
θ = u[1]
= u[2]
du[1] =
du[2] = -9.81 * sin(θ)
end
function bc!(residual, u, p, t)
residual[1] = u[end ÷ 2][1] + big(pi / 2)
residual[2] = u[end][1] - big(pi / 2)
end
u0 = BigFloat.([pi / 2, pi / 2])
prob = BVProblem(simplependulum!, bc!, u0, tspan)
sol = solve(prob, MIRK4(), dt = 0.05)
@test SciMLBase.successful_retcode(sol.retcode)
end

0 comments on commit 53f8566

Please sign in to comment.