Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix support for BigFloat u0 #202

Merged
merged 8 commits into from
Aug 17, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/BoundaryValueDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module BoundaryValueDiffEq

import PrecompileTools: @compile_workload, @setup_workload

using ADTypes, Adapt, DiffEqBase, ForwardDiff, LinearAlgebra, NonlinearSolve,
using ADTypes, Adapt, ArrayInterface, DiffEqBase, ForwardDiff, LinearAlgebra, NonlinearSolve,
OrdinaryDiffEq, Preferences, RecursiveArrayTools, Reexport, SciMLBase, Setfield,
SparseDiffTools

Expand Down
2 changes: 1 addition & 1 deletion src/collocation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ end
@views function Φ(
fᵢ_cache, k_discrete, f, TU::MIRKTableau, y, u, p, mesh, mesh_dt, stage::Int)
(; c, v, x, b) = TU
residuals = [similar(yᵢ) for yᵢ in y[1:(end - 1)]]
residuals = [__similar(yᵢ) for yᵢ in y[1:(end - 1)]]
tmp = get_tmp(fᵢ_cache, u)
T = eltype(u)
for i in eachindex(k_discrete)
Expand Down
25 changes: 13 additions & 12 deletions src/solve/mirk.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,10 @@ function SciMLBase.__init(prob::BVProblem, alg::AbstractMIRK; dt = 0.0,
mesh_dt = diff(mesh)

chunksize = pickchunksize(N * (Nig - 1))
__alloc = @closure x -> __maybe_allocate_diffcache(vec(x), chunksize, alg.jac_alg)
__alloc = @closure x -> __maybe_allocate_diffcache(vec(__similar(x)), chunksize, alg.jac_alg)

fᵢ_cache = __alloc(similar(X))
fᵢ₂_cache = vec(similar(X))
fᵢ_cache = __alloc(__similar(X))
fᵢ₂_cache = vec(__similar(X))
ErikQQY marked this conversation as resolved.
Show resolved Hide resolved

# Don't flatten this here, since we need to expand it later if needed
y₀ = __initial_guess_on_mesh(prob.u0, mesh, prob.p, false)
Expand All @@ -54,9 +54,10 @@ function SciMLBase.__init(prob::BVProblem, alg::AbstractMIRK; dt = 0.0,
TU, ITU = constructMIRK(alg, T)
stage = alg_stage(alg)

k_discrete = [__maybe_allocate_diffcache(similar(X, N, stage), chunksize, alg.jac_alg)
k_discrete = [__maybe_allocate_diffcache(__similar(X, N, stage), chunksize, alg.jac_alg)
for _ in 1:Nig]
k_interp = [similar(X, N, ITU.s_star - stage) for _ in 1:Nig]
k_interp = [__similar(X, ifelse(adaptive, N, 0), ifelse(adaptive, ITU.s_star - stage, 0))
for _ in 1:Nig]

bcresid_prototype, resid₁_size = __get_bcresid_prototype(prob.problem_type, prob, X)

Expand All @@ -70,8 +71,8 @@ function SciMLBase.__init(prob::BVProblem, alg::AbstractMIRK; dt = 0.0,
nothing
end

defect = [similar(X, ifelse(adaptive, N, 0)) for _ in 1:Nig]
new_stages = [similar(X, N) for _ in 1:Nig]
defect = [__similar(X, ifelse(adaptive, N, 0)) for _ in 1:Nig]
new_stages = [__similar(X, ifelse(adaptive, N, 0)) for _ in 1:Nig]

# Transform the functions to handle non-vector inputs
bcresid_prototype = __vec(bcresid_prototype)
Expand Down Expand Up @@ -302,7 +303,7 @@ function __construct_nlproblem(cache::MIRKCache{iip}, y, loss_bc::BC, loss_collo

resid_bc = cache.bcresid_prototype
L = length(resid_bc)
resid_collocation = similar(y, cache.M * (N - 1))
resid_collocation = __similar(y, cache.M * (N - 1))

loss_bcₚ = (iip ? __Fix3 : Base.Fix2)(loss_bc, cache.p)
loss_collocationₚ = (iip ? __Fix3 : Base.Fix2)(loss_collocation, cache.p)
Expand Down Expand Up @@ -335,8 +336,8 @@ function __construct_nlproblem(cache::MIRKCache{iip}, y, loss_bc::BC, loss_collo
Val(iip), jac_alg.nonbc_diffmode, sd_collocation,
loss_collocationₚ, resid_collocation, y)

J_bc = init_jacobian(cache_bc)
J_c = init_jacobian(cache_collocation)
J_bc = __init_bigfloat_array!!(init_jacobian(cache_bc))
J_c = __init_bigfloat_array!!(init_jacobian(cache_collocation))
if J_full_band === nothing
jac_prototype = vcat(J_bc, J_c)
else
Expand Down Expand Up @@ -408,7 +409,7 @@ function __construct_nlproblem(cache::MIRKCache{iip}, y, loss_bc::BC, loss_collo
lossₚ = iip ? ((du, u) -> loss(du, u, cache.p)) : (u -> loss(u, cache.p))

resid = vcat(@view(cache.bcresid_prototype[1:prod(cache.resid_size[1])]),
similar(y, cache.M * (N - 1)),
__similar(y, cache.M * (N - 1)),
@view(cache.bcresid_prototype[(prod(cache.resid_size[1]) + 1):end]))
L = length(cache.bcresid_prototype)

Expand All @@ -422,7 +423,7 @@ function __construct_nlproblem(cache::MIRKCache{iip}, y, loss_bc::BC, loss_collo
NoSparsityDetection()
end
diffcache = __sparse_jacobian_cache(Val(iip), jac_alg.diffmode, sd, lossₚ, resid, y)
jac_prototype = init_jacobian(diffcache)
jac_prototype = __init_bigfloat_array!!(init_jacobian(diffcache))

jac = if iip
@closure (J, u, p) -> __mirk_2point_jacobian!(
Expand Down
4 changes: 2 additions & 2 deletions src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,8 @@ end
function __maybe_allocate_diffcache(x, chunksize, jac_alg)
return __needs_diffcache(jac_alg) ? DiffCache(x, chunksize) : FakeDiffCache(x)
end
__maybe_allocate_diffcache(x::DiffCache, chunksize) = DiffCache(similar(x.du), chunksize)
__maybe_allocate_diffcache(x::FakeDiffCache, _) = FakeDiffCache(similar(x.du))
__maybe_allocate_diffcache(x::DiffCache, chunksize) = DiffCache(__similar(x.du), chunksize)
__maybe_allocate_diffcache(x::FakeDiffCache, _) = FakeDiffCache(__similar(x.du))
ErikQQY marked this conversation as resolved.
Show resolved Hide resolved

const MaybeDiffCache = Union{DiffCache, FakeDiffCache}

Expand Down
27 changes: 23 additions & 4 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ recursive_length(x::Vector{<:AbstractArray}) = sum(length, x)
recursive_length(x::Vector{<:MaybeDiffCache}) = sum(xᵢ -> length(xᵢ.u), x)

function recursive_flatten(x::Vector{<:AbstractArray})
y = similar(first(x), recursive_length(x))
y = __similar(first(x), recursive_length(x))
recursive_flatten!(y, x)
return y
end
Expand Down Expand Up @@ -110,7 +110,7 @@ function __append_similar!(x::AbstractVector{<:AbstractArray}, n, _)
N = n - length(x)
N == 0 && return x
N < 0 && throw(ArgumentError("Cannot append a negative number of elements"))
append!(x, [similar(last(x)) for _ in 1:N])
append!(x, [__similar(last(x)) for _ in 1:N])
ErikQQY marked this conversation as resolved.
Show resolved Hide resolved
return x
end

Expand Down Expand Up @@ -185,12 +185,31 @@ function __get_bcresid_prototype(::StandardBVProblem, prob::BVProblem, u)
return prototype, size(prototype)
end

@inline function __fill_like(v, x, args...)
@inline function __similar(x, args...)
y = similar(x, args...)
return __init_bigfloat_array!!(y)
end

@inline function __init_bigfloat_array!!(x)
if ArrayInterface.can_setindex(x)
eltype(x) <: BigFloat && fill!(x, BigFloat(0))
return x
end
return x
end

@inline function __fill_like(v, x)
y = __similar(x)
fill!(y, v)
return y
end
@inline __zeros_like(args...) = __fill_like(0, args...)
@inline function __zeros_like(u)
if ArrayInterface.can_setindex(u)
eltype(u) <: BigFloat && __fill_like(BigFloat(0), u)
return u
avik-pal marked this conversation as resolved.
Show resolved Hide resolved
end
return __fill_like(0, u)
end
@inline __ones_like(args...) = __fill_like(1, args...)

@inline __safe_vec(x) = vec(x)
Expand Down
17 changes: 17 additions & 0 deletions test/misc/bigfloat_test.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
@testitem "BigFloat compatibility" begin
using BoundaryValueDiffEq
tspan = (0.0, pi / 2)
function simplependulum!(du, u, p, t)
θ = u[1]
dθ = u[2]
du[1] = dθ
du[2] = -9.81 * sin(θ)
end
function bc!(residual, u, p, t)
residual[1] = u[end ÷ 2][1] + pi / 2
residual[2] = u[end][1] - pi / 2
end
u0 = BigFloat.([pi / 2, pi / 2])
prob = BVProblem(simplependulum!, bc!, u0, tspan)
sol = solve(prob, MIRK4(), dt = 0.05)
end
Loading