Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix support for BigFloat u0 #202

Merged
merged 8 commits into from
Aug 17, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ Logging = "1.10"
NonlinearSolve = "3.8.1"
ODEInterface = "0.5"
OrdinaryDiffEq = "6.63"
PreallocationTools = "0.4"
PreallocationTools = "0.4.24"
PrecompileTools = "1.2"
Preferences = "1.4"
Random = "1.10"
Expand Down
2 changes: 1 addition & 1 deletion src/BoundaryValueDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module BoundaryValueDiffEq

import PrecompileTools: @compile_workload, @setup_workload

using ADTypes, Adapt, DiffEqBase, ForwardDiff, LinearAlgebra, NonlinearSolve,
using ADTypes, Adapt, ArrayInterface, DiffEqBase, ForwardDiff, LinearAlgebra, NonlinearSolve,
OrdinaryDiffEq, Preferences, RecursiveArrayTools, Reexport, SciMLBase, Setfield,
SparseDiffTools

Expand Down
2 changes: 1 addition & 1 deletion src/collocation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ end
@views function Φ(
fᵢ_cache, k_discrete, f, TU::MIRKTableau, y, u, p, mesh, mesh_dt, stage::Int)
(; c, v, x, b) = TU
residuals = [similar(yᵢ) for yᵢ in y[1:(end - 1)]]
residuals = [__similar(yᵢ) for yᵢ in y[1:(end - 1)]]
tmp = get_tmp(fᵢ_cache, u)
T = eltype(u)
for i in eachindex(k_discrete)
Expand Down
18 changes: 9 additions & 9 deletions src/solve/mirk.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,10 @@ function SciMLBase.__init(prob::BVProblem, alg::AbstractMIRK; dt = 0.0,
mesh_dt = diff(mesh)

chunksize = pickchunksize(N * (Nig - 1))
__alloc = @closure x -> __maybe_allocate_diffcache(vec(x), chunksize, alg.jac_alg)
__alloc = @closure x -> __maybe_allocate_diffcache(vec(__similar(x)), chunksize, alg.jac_alg)

fᵢ_cache = __alloc(similar(X))
fᵢ₂_cache = vec(similar(X))
fᵢ_cache = __alloc(__similar(X))
fᵢ₂_cache = vec(__similar(X))
ErikQQY marked this conversation as resolved.
Show resolved Hide resolved

# Don't flatten this here, since we need to expand it later if needed
y₀ = __initial_guess_on_mesh(prob.u0, mesh, prob.p)
Expand All @@ -54,7 +54,7 @@ function SciMLBase.__init(prob::BVProblem, alg::AbstractMIRK; dt = 0.0,
TU, ITU = constructMIRK(alg, T)
stage = alg_stage(alg)

k_discrete = [__maybe_allocate_diffcache(similar(X, N, stage), chunksize, alg.jac_alg)
k_discrete = [__maybe_allocate_diffcache(__similar(X, N, stage), chunksize, alg.jac_alg)
for _ in 1:Nig]
k_interp = VectorOfArray([similar(X, N, ITU.s_star - stage) for _ in 1:Nig])

Expand Down Expand Up @@ -308,7 +308,7 @@ function __construct_nlproblem(cache::MIRKCache{iip}, y, loss_bc::BC, loss_collo

resid_bc = cache.bcresid_prototype
L = length(resid_bc)
resid_collocation = similar(y, cache.M * (N - 1))
resid_collocation = __similar(y, cache.M * (N - 1))

loss_bcₚ = (iip ? __Fix3 : Base.Fix2)(loss_bc, cache.p)
loss_collocationₚ = (iip ? __Fix3 : Base.Fix2)(loss_collocation, cache.p)
Expand Down Expand Up @@ -341,8 +341,8 @@ function __construct_nlproblem(cache::MIRKCache{iip}, y, loss_bc::BC, loss_collo
Val(iip), jac_alg.nonbc_diffmode, sd_collocation,
loss_collocationₚ, resid_collocation, y)

J_bc = init_jacobian(cache_bc)
J_c = init_jacobian(cache_collocation)
J_bc = __similar(init_jacobian(cache_bc))
J_c = __similar(init_jacobian(cache_collocation))
if J_full_band === nothing
jac_prototype = vcat(J_bc, J_c)
else
Expand Down Expand Up @@ -414,7 +414,7 @@ function __construct_nlproblem(cache::MIRKCache{iip}, y, loss_bc::BC, loss_collo
lossₚ = iip ? ((du, u) -> loss(du, u, cache.p)) : (u -> loss(u, cache.p))

resid = vcat(@view(cache.bcresid_prototype[1:prod(cache.resid_size[1])]),
similar(y, cache.M * (N - 1)),
__similar(y, cache.M * (N - 1)),
@view(cache.bcresid_prototype[(prod(cache.resid_size[1]) + 1):end]))
L = length(cache.bcresid_prototype)

Expand All @@ -428,7 +428,7 @@ function __construct_nlproblem(cache::MIRKCache{iip}, y, loss_bc::BC, loss_collo
NoSparsityDetection()
end
diffcache = __sparse_jacobian_cache(Val(iip), jac_alg.diffmode, sd, lossₚ, resid, y)
jac_prototype = init_jacobian(diffcache)
jac_prototype = __similar(init_jacobian(diffcache))

jac = if iip
@closure (J, u, p) -> __mirk_2point_jacobian!(
Expand Down
4 changes: 2 additions & 2 deletions src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,8 @@ end
function __maybe_allocate_diffcache(x, chunksize, jac_alg)
return __needs_diffcache(jac_alg) ? DiffCache(x, chunksize) : FakeDiffCache(x)
end
__maybe_allocate_diffcache(x::DiffCache, chunksize) = DiffCache(similar(x.du), chunksize)
__maybe_allocate_diffcache(x::FakeDiffCache, _) = FakeDiffCache(similar(x.du))
__maybe_allocate_diffcache(x::DiffCache, chunksize) = DiffCache(__similar(x.du), chunksize)
__maybe_allocate_diffcache(x::FakeDiffCache, _) = FakeDiffCache(__similar(x.du))
ErikQQY marked this conversation as resolved.
Show resolved Hide resolved

const MaybeDiffCache = Union{DiffCache, FakeDiffCache}

Expand Down
20 changes: 16 additions & 4 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ recursive_length(x::Vector{<:AbstractArray}) = sum(length, x)
recursive_length(x::Vector{<:MaybeDiffCache}) = sum(xᵢ -> length(xᵢ.u), x)

function recursive_flatten(x::Vector{<:AbstractArray})
y = similar(first(x), recursive_length(x))
y = __similar(first(x), recursive_length(x))
recursive_flatten!(y, x)
return y
end
Expand Down Expand Up @@ -112,7 +112,7 @@ function __append_similar!(x::AbstractVector{<:AbstractArray}, n, _)
N = n - length(x)
N == 0 && return x
N < 0 && throw(ArgumentError("Cannot append a negative number of elements"))
append!(x, [similar(last(x)) for _ in 1:N])
append!(x, [__similar(last(x)) for _ in 1:N])
ErikQQY marked this conversation as resolved.
Show resolved Hide resolved
return x
end

Expand Down Expand Up @@ -195,12 +195,24 @@ function __get_bcresid_prototype(::StandardBVProblem, prob::BVProblem, u)
return prototype, size(prototype)
end

@inline function __fill_like(v, x, args...)
@inline function __similar(x, args...)
y = similar(x, args...)
eltype(y) <: BigFloat && fill!(y, BigFloat(0))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is there a reason to not always fill with zeros?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

zeros is better, just change them all. While NonlinearSolve is doing the similar thing like specialize on bigfloat and do __similar things in SciML/NonlinearSolve.jl#438, maybe we can change to zeros everywhere in NonlinearSolve.jl as well?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note this will also fix issues with non-fully defined functions. For example, if a user doesn't write to du[i], we should still converge now, while before you'd get something random. So a few bugs should be fixed, some bugs avoided, etc. Avoiding uninitialized memory is just a good idea all around and NonlinearSolve needs to do this change too.

return y
end

@inline function __fill_like(v, x)
y = __similar(x)
fill!(y, v)
return y
end
@inline __zeros_like(args...) = __fill_like(0, args...)
@inline function __zeros_like(u)
if ArrayInterface.can_setindex(u)
eltype(u) <: BigFloat && __fill_like(BigFloat(0), u)
return u
avik-pal marked this conversation as resolved.
Show resolved Hide resolved
end
return __fill_like(0, u)
end
@inline __ones_like(args...) = __fill_like(1, args...)

@inline __safe_vec(x) = vec(x)
Expand Down
18 changes: 18 additions & 0 deletions test/misc/bigfloat_test.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
@testitem "BigFloat compatibility" begin
using BoundaryValueDiffEq
tspan = (0.0, pi / 2)
function simplependulum!(du, u, p, t)
θ = u[1]
dθ = u[2]
du[1] = dθ
du[2] = -9.81 * sin(θ)
end
function bc!(residual, u, p, t)
residual[1] = u[end ÷ 2][1] + big(pi / 2)
residual[2] = u[end][1] - big(pi / 2)
end
u0 = BigFloat.([pi / 2, pi / 2])
prob = BVProblem(simplependulum!, bc!, u0, tspan)
sol = solve(prob, MIRK4(), dt = 0.05)
@test SciMLBase.successful_retcode(sol.retcode)
end
Loading