Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix support for BigFloat u0 #202

Merged
merged 8 commits into from
Aug 17, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/BoundaryValueDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module BoundaryValueDiffEq

import PrecompileTools: @compile_workload, @setup_workload

using ADTypes, Adapt, DiffEqBase, ForwardDiff, LinearAlgebra, NonlinearSolve,
using ADTypes, Adapt, ArrayInterface, DiffEqBase, ForwardDiff, LinearAlgebra, NonlinearSolve,
OrdinaryDiffEq, Preferences, RecursiveArrayTools, Reexport, SciMLBase, Setfield,
SparseDiffTools

Expand Down
2 changes: 1 addition & 1 deletion src/collocation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ end
@views function Φ(
fᵢ_cache, k_discrete, f, TU::MIRKTableau, y, u, p, mesh, mesh_dt, stage::Int)
(; c, v, x, b) = TU
residuals = [similar(yᵢ) for yᵢ in y[1:(end - 1)]]
residuals = [__similar(yᵢ) for yᵢ in y[1:(end - 1)]]
tmp = get_tmp(fᵢ_cache, u)
T = eltype(u)
for i in eachindex(k_discrete)
Expand Down
18 changes: 9 additions & 9 deletions src/solve/mirk.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,10 @@ function SciMLBase.__init(prob::BVProblem, alg::AbstractMIRK; dt = 0.0,
mesh_dt = diff(mesh)

chunksize = pickchunksize(N * (Nig - 1))
__alloc = @closure x -> __maybe_allocate_diffcache(vec(x), chunksize, alg.jac_alg)
__alloc = @closure x -> __maybe_allocate_diffcache(vec(__similar(x)), chunksize, alg.jac_alg)

fᵢ_cache = __alloc(similar(X))
fᵢ₂_cache = vec(similar(X))
fᵢ_cache = __alloc(__similar(X))
fᵢ₂_cache = vec(__similar(X))
ErikQQY marked this conversation as resolved.
Show resolved Hide resolved

# Don't flatten this here, since we need to expand it later if needed
y₀ = __initial_guess_on_mesh(prob.u0, mesh, prob.p)
Expand All @@ -54,7 +54,7 @@ function SciMLBase.__init(prob::BVProblem, alg::AbstractMIRK; dt = 0.0,
TU, ITU = constructMIRK(alg, T)
stage = alg_stage(alg)

k_discrete = [__maybe_allocate_diffcache(similar(X, N, stage), chunksize, alg.jac_alg)
k_discrete = [__maybe_allocate_diffcache(__similar(X, N, stage), chunksize, alg.jac_alg)
for _ in 1:Nig]
k_interp = VectorOfArray([similar(X, N, ITU.s_star - stage) for _ in 1:Nig])

Expand Down Expand Up @@ -302,7 +302,7 @@ function __construct_nlproblem(cache::MIRKCache{iip}, y, loss_bc::BC, loss_collo

resid_bc = cache.bcresid_prototype
L = length(resid_bc)
resid_collocation = similar(y, cache.M * (N - 1))
resid_collocation = __similar(y, cache.M * (N - 1))

loss_bcₚ = (iip ? __Fix3 : Base.Fix2)(loss_bc, cache.p)
loss_collocationₚ = (iip ? __Fix3 : Base.Fix2)(loss_collocation, cache.p)
Expand Down Expand Up @@ -335,8 +335,8 @@ function __construct_nlproblem(cache::MIRKCache{iip}, y, loss_bc::BC, loss_collo
Val(iip), jac_alg.nonbc_diffmode, sd_collocation,
loss_collocationₚ, resid_collocation, y)

J_bc = init_jacobian(cache_bc)
J_c = init_jacobian(cache_collocation)
J_bc = __init_bigfloat_array!!(init_jacobian(cache_bc))
J_c = __init_bigfloat_array!!(init_jacobian(cache_collocation))
if J_full_band === nothing
jac_prototype = vcat(J_bc, J_c)
else
Expand Down Expand Up @@ -408,7 +408,7 @@ function __construct_nlproblem(cache::MIRKCache{iip}, y, loss_bc::BC, loss_collo
lossₚ = iip ? ((du, u) -> loss(du, u, cache.p)) : (u -> loss(u, cache.p))

resid = vcat(@view(cache.bcresid_prototype[1:prod(cache.resid_size[1])]),
similar(y, cache.M * (N - 1)),
__similar(y, cache.M * (N - 1)),
@view(cache.bcresid_prototype[(prod(cache.resid_size[1]) + 1):end]))
L = length(cache.bcresid_prototype)

Expand All @@ -422,7 +422,7 @@ function __construct_nlproblem(cache::MIRKCache{iip}, y, loss_bc::BC, loss_collo
NoSparsityDetection()
end
diffcache = __sparse_jacobian_cache(Val(iip), jac_alg.diffmode, sd, lossₚ, resid, y)
jac_prototype = init_jacobian(diffcache)
jac_prototype = __init_bigfloat_array!!(init_jacobian(diffcache))

jac = if iip
@closure (J, u, p) -> __mirk_2point_jacobian!(
Expand Down
47 changes: 44 additions & 3 deletions src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -152,13 +152,54 @@ end
du
end

# hacking DiffCache to handling with BigFloat case
@concrete struct BigFloatDiffCache{T <: AbstractArray, S <: AbstractArray}
du::T
dual_du::S
any_du::Vector{Any}
end

function BigFloatDiffCache(u::AbstractArray{T}, siz, chunk_sizes) where {T}
x = adapt(ArrayInterface.parameterless_type(u),
zeros(T, prod(chunk_sizes .+ 1) * prod(siz)))
xany = Any[]
BigFloatDiffCache(u, x, xany)
end
function BigFloatDiffCache(u::AbstractArray, N::Int = ForwardDiff.pickchunksize(length(u));
levels::Int = 1)
BigFloatDiffCache(u, size(u), N * ones(Int, levels))
end
BigFloatDiffCache(u::AbstractArray, N::AbstractArray{<:Int}) = BigFloatDiffCache(u, size(u), N)
function BigFloatDiffCache(u::AbstractArray, ::Type{Val{N}}; levels::Int = 1) where {N}
BigFloatDiffCache(u, N; levels)
end
BigFloatDiffCache(u::AbstractArray, ::Val{N}; levels::Int = 1) where {N} = BigFloatDiffCache(u, N; levels)

function get_tmp(dc::BigFloatDiffCache, u::T) where {T <: ForwardDiff.Dual}
nelem = length(dc.du)
PreallocationTools._restructure(dc.du, view(T.(dc.dual_du), 1:nelem))
end
function get_tmp(dc::BigFloatDiffCache, u::AbstractArray{T}) where {T <: ForwardDiff.Dual}
nelem = length(dc.du)
PreallocationTools._restructure(dc.du, view(T.(dc.dual_du), 1:nelem))
end
ErikQQY marked this conversation as resolved.
Show resolved Hide resolved
function get_tmp(dc::BigFloatDiffCache, u::Union{Number, AbstractArray})
return dc.du
end
function get_tmp(dc::BigFloatDiffCache, ::Type{T}) where {T <: Number}
return dc.du
end
get_tmp(dc::Vector{BigFloat}, u::Vector{BigFloat}) = dc

function __maybe_allocate_diffcache(x, chunksize, jac_alg)
eltype(x) <: BigFloat && return (__needs_diffcache(jac_alg) ? BigFloatDiffCache(x, chunksize) : FakeDiffCache(x))
return __needs_diffcache(jac_alg) ? DiffCache(x, chunksize) : FakeDiffCache(x)
end
__maybe_allocate_diffcache(x::DiffCache, chunksize) = DiffCache(similar(x.du), chunksize)
__maybe_allocate_diffcache(x::FakeDiffCache, _) = FakeDiffCache(similar(x.du))
__maybe_allocate_diffcache(x::DiffCache, chunksize) = DiffCache(__similar(x.du), chunksize)
__maybe_allocate_diffcache(x::FakeDiffCache, _) = FakeDiffCache(__similar(x.du))
ErikQQY marked this conversation as resolved.
Show resolved Hide resolved
__maybe_allocate_diffcache(x::BigFloatDiffCache, chunksize) = BigFloatDiffCache(x.du, chunksize)

const MaybeDiffCache = Union{DiffCache, FakeDiffCache}
const MaybeDiffCache = Union{DiffCache, FakeDiffCache, BigFloatDiffCache}

## get_tmp shows a warning as it should on cache exapansion, this behavior however is
## expected for adaptive BVP solvers so we write our own `get_tmp` and drop the warning logs
Expand Down
27 changes: 23 additions & 4 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ recursive_length(x::Vector{<:AbstractArray}) = sum(length, x)
recursive_length(x::Vector{<:MaybeDiffCache}) = sum(xᵢ -> length(xᵢ.u), x)

function recursive_flatten(x::Vector{<:AbstractArray})
y = similar(first(x), recursive_length(x))
y = __similar(first(x), recursive_length(x))
recursive_flatten!(y, x)
return y
end
Expand Down Expand Up @@ -112,7 +112,7 @@ function __append_similar!(x::AbstractVector{<:AbstractArray}, n, _)
N = n - length(x)
N == 0 && return x
N < 0 && throw(ArgumentError("Cannot append a negative number of elements"))
append!(x, [similar(last(x)) for _ in 1:N])
append!(x, [__similar(last(x)) for _ in 1:N])
ErikQQY marked this conversation as resolved.
Show resolved Hide resolved
return x
end

Expand Down Expand Up @@ -195,12 +195,31 @@ function __get_bcresid_prototype(::StandardBVProblem, prob::BVProblem, u)
return prototype, size(prototype)
end

@inline function __fill_like(v, x, args...)
@inline function __similar(x, args...)
y = similar(x, args...)
return __init_bigfloat_array!!(y)
end

@inline function __init_bigfloat_array!!(x)
if ArrayInterface.can_setindex(x)
eltype(x) <: BigFloat && fill!(x, BigFloat(0))
return x
end
return x
end

@inline function __fill_like(v, x)
y = __similar(x)
fill!(y, v)
return y
end
@inline __zeros_like(args...) = __fill_like(0, args...)
@inline function __zeros_like(u)
if ArrayInterface.can_setindex(u)
eltype(u) <: BigFloat && __fill_like(BigFloat(0), u)
return u
avik-pal marked this conversation as resolved.
Show resolved Hide resolved
end
return __fill_like(0, u)
end
@inline __ones_like(args...) = __fill_like(1, args...)

@inline __safe_vec(x) = vec(x)
Expand Down
18 changes: 18 additions & 0 deletions test/misc/bigfloat_test.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
@testitem "BigFloat compatibility" begin
using BoundaryValueDiffEq
tspan = (0.0, pi / 2)
function simplependulum!(du, u, p, t)
θ = u[1]
dθ = u[2]
du[1] = dθ
du[2] = -9.81 * sin(θ)
end
function bc!(residual, u, p, t)
residual[1] = u[end ÷ 2][1] + big(pi / 2)
residual[2] = u[end][1] - big(pi / 2)
end
u0 = BigFloat.([pi / 2, pi / 2])
prob = BVProblem(simplependulum!, bc!, u0, tspan)
sol = solve(prob, MIRK4(), dt = 0.05)
@test SciMLBase.successful_retcode(sol.retcode)
end
Loading