diff --git a/Project.toml b/Project.toml index 01d1c0e4..e6ad83d8 100644 --- a/Project.toml +++ b/Project.toml @@ -52,7 +52,7 @@ Logging = "1.10" NonlinearSolve = "3.8.1" ODEInterface = "0.5" OrdinaryDiffEq = "6.63" -PreallocationTools = "0.4" +PreallocationTools = "0.4.24" PrecompileTools = "1.2" Preferences = "1.4" Random = "1.10" diff --git a/src/BoundaryValueDiffEq.jl b/src/BoundaryValueDiffEq.jl index 1c750d6c..1b7a3d26 100644 --- a/src/BoundaryValueDiffEq.jl +++ b/src/BoundaryValueDiffEq.jl @@ -2,7 +2,7 @@ module BoundaryValueDiffEq import PrecompileTools: @compile_workload, @setup_workload -using ADTypes, Adapt, DiffEqBase, ForwardDiff, LinearAlgebra, NonlinearSolve, +using ADTypes, Adapt, ArrayInterface, DiffEqBase, ForwardDiff, LinearAlgebra, NonlinearSolve, OrdinaryDiffEq, Preferences, RecursiveArrayTools, Reexport, SciMLBase, Setfield, SparseDiffTools diff --git a/src/collocation.jl b/src/collocation.jl index 2ac85538..9c14fc3f 100644 --- a/src/collocation.jl +++ b/src/collocation.jl @@ -37,7 +37,7 @@ end @views function Φ( fᵢ_cache, k_discrete, f, TU::MIRKTableau, y, u, p, mesh, mesh_dt, stage::Int) (; c, v, x, b) = TU - residuals = [similar(yᵢ) for yᵢ in y[1:(end - 1)]] + residuals = [__similar(yᵢ) for yᵢ in y[1:(end - 1)]] tmp = get_tmp(fᵢ_cache, u) T = eltype(u) for i in eachindex(k_discrete) diff --git a/src/solve/mirk.jl b/src/solve/mirk.jl index a82dc4f9..d9dcb2c0 100644 --- a/src/solve/mirk.jl +++ b/src/solve/mirk.jl @@ -42,10 +42,10 @@ function SciMLBase.__init(prob::BVProblem, alg::AbstractMIRK; dt = 0.0, mesh_dt = diff(mesh) chunksize = pickchunksize(N * (Nig - 1)) - __alloc = @closure x -> __maybe_allocate_diffcache(vec(x), chunksize, alg.jac_alg) + __alloc = @closure x -> __maybe_allocate_diffcache(vec(zero(x)), chunksize, alg.jac_alg) - fᵢ_cache = __alloc(similar(X)) - fᵢ₂_cache = vec(similar(X)) + fᵢ_cache = __alloc(zero(X)) + fᵢ₂_cache = vec(zero(X)) # Don't flatten this here, since we need to expand it later if needed y₀ = __initial_guess_on_mesh(prob.u0, mesh, prob.p) @@ -54,7 +54,7 @@ function SciMLBase.__init(prob::BVProblem, alg::AbstractMIRK; dt = 0.0, TU, ITU = constructMIRK(alg, T) stage = alg_stage(alg) - k_discrete = [__maybe_allocate_diffcache(similar(X, N, stage), chunksize, alg.jac_alg) + k_discrete = [__maybe_allocate_diffcache(__similar(X, N, stage), chunksize, alg.jac_alg) for _ in 1:Nig] k_interp = VectorOfArray([similar(X, N, ITU.s_star - stage) for _ in 1:Nig]) @@ -308,7 +308,7 @@ function __construct_nlproblem(cache::MIRKCache{iip}, y, loss_bc::BC, loss_collo resid_bc = cache.bcresid_prototype L = length(resid_bc) - resid_collocation = similar(y, cache.M * (N - 1)) + resid_collocation = __similar(y, cache.M * (N - 1)) loss_bcₚ = (iip ? __Fix3 : Base.Fix2)(loss_bc, cache.p) loss_collocationₚ = (iip ? __Fix3 : Base.Fix2)(loss_collocation, cache.p) @@ -341,8 +341,8 @@ function __construct_nlproblem(cache::MIRKCache{iip}, y, loss_bc::BC, loss_collo Val(iip), jac_alg.nonbc_diffmode, sd_collocation, loss_collocationₚ, resid_collocation, y) - J_bc = init_jacobian(cache_bc) - J_c = init_jacobian(cache_collocation) + J_bc = zero(init_jacobian(cache_bc)) + J_c = zero(init_jacobian(cache_collocation)) if J_full_band === nothing jac_prototype = vcat(J_bc, J_c) else @@ -414,7 +414,7 @@ function __construct_nlproblem(cache::MIRKCache{iip}, y, loss_bc::BC, loss_collo lossₚ = iip ? ((du, u) -> loss(du, u, cache.p)) : (u -> loss(u, cache.p)) resid = vcat(@view(cache.bcresid_prototype[1:prod(cache.resid_size[1])]), - similar(y, cache.M * (N - 1)), + __similar(y, cache.M * (N - 1)), @view(cache.bcresid_prototype[(prod(cache.resid_size[1]) + 1):end])) L = length(cache.bcresid_prototype) @@ -428,7 +428,7 @@ function __construct_nlproblem(cache::MIRKCache{iip}, y, loss_bc::BC, loss_collo NoSparsityDetection() end diffcache = __sparse_jacobian_cache(Val(iip), jac_alg.diffmode, sd, lossₚ, resid, y) - jac_prototype = init_jacobian(diffcache) + jac_prototype = zero(init_jacobian(diffcache)) jac = if iip @closure (J, u, p) -> __mirk_2point_jacobian!( diff --git a/src/types.jl b/src/types.jl index 9ba70c3f..448f1367 100644 --- a/src/types.jl +++ b/src/types.jl @@ -155,8 +155,8 @@ end function __maybe_allocate_diffcache(x, chunksize, jac_alg) return __needs_diffcache(jac_alg) ? DiffCache(x, chunksize) : FakeDiffCache(x) end -__maybe_allocate_diffcache(x::DiffCache, chunksize) = DiffCache(similar(x.du), chunksize) -__maybe_allocate_diffcache(x::FakeDiffCache, _) = FakeDiffCache(similar(x.du)) +__maybe_allocate_diffcache(x::DiffCache, chunksize) = DiffCache(zero(x.du), chunksize) +__maybe_allocate_diffcache(x::FakeDiffCache, _) = FakeDiffCache(zero(x.du)) const MaybeDiffCache = Union{DiffCache, FakeDiffCache} diff --git a/src/utils.jl b/src/utils.jl index e8f7221f..62eba2c9 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -2,7 +2,7 @@ recursive_length(x::Vector{<:AbstractArray}) = sum(length, x) recursive_length(x::Vector{<:MaybeDiffCache}) = sum(xᵢ -> length(xᵢ.u), x) function recursive_flatten(x::Vector{<:AbstractArray}) - y = similar(first(x), recursive_length(x)) + y = zero(first(x), recursive_length(x)) recursive_flatten!(y, x) return y end @@ -112,7 +112,7 @@ function __append_similar!(x::AbstractVector{<:AbstractArray}, n, _) N = n - length(x) N == 0 && return x N < 0 && throw(ArgumentError("Cannot append a negative number of elements")) - append!(x, [similar(last(x)) for _ in 1:N]) + append!(x, [zero(last(x)) for _ in 1:N]) return x end @@ -191,16 +191,21 @@ function __get_bcresid_prototype(::TwoPointBVProblem, prob::BVProblem, u) end function __get_bcresid_prototype(::StandardBVProblem, prob::BVProblem, u) prototype = prob.f.bcresid_prototype !== nothing ? prob.f.bcresid_prototype : - __zeros_like(u) + zero(u) return prototype, size(prototype) end -@inline function __fill_like(v, x, args...) +@inline function __similar(x, args...) y = similar(x, args...) + return zero(y) +end + +@inline function __fill_like(v, x) + y = __similar(x) fill!(y, v) return y end -@inline __zeros_like(args...) = __fill_like(0, args...) + @inline __ones_like(args...) = __fill_like(1, args...) @inline __safe_vec(x) = vec(x) diff --git a/test/misc/bigfloat_test.jl b/test/misc/bigfloat_test.jl new file mode 100644 index 00000000..53211b94 --- /dev/null +++ b/test/misc/bigfloat_test.jl @@ -0,0 +1,18 @@ +@testitem "BigFloat compatibility" begin + using BoundaryValueDiffEq + tspan = (0.0, pi / 2) + function simplependulum!(du, u, p, t) + θ = u[1] + dθ = u[2] + du[1] = dθ + du[2] = -9.81 * sin(θ) + end + function bc!(residual, u, p, t) + residual[1] = u[end ÷ 2][1] + big(pi / 2) + residual[2] = u[end][1] - big(pi / 2) + end + u0 = BigFloat.([pi / 2, pi / 2]) + prob = BVProblem(simplependulum!, bc!, u0, tspan) + sol = solve(prob, MIRK4(), dt = 0.05) + @test SciMLBase.successful_retcode(sol.retcode) +end