From 491b4ca965f11c1e10b4bc71554a571f9d4c567a Mon Sep 17 00:00:00 2001 From: ErikQQY <2283984853@qq.com> Date: Sat, 27 Jul 2024 13:44:17 +0800 Subject: [PATCH 1/6] Fix support for BigFloat u0 Signed-off-by: ErikQQY <2283984853@qq.com> --- src/BoundaryValueDiffEq.jl | 2 +- src/collocation.jl | 2 +- src/solve/mirk.jl | 25 +++++++++++++------------ src/types.jl | 4 ++-- src/utils.jl | 27 +++++++++++++++++++++++---- test/misc/bigfloat_test.jl | 17 +++++++++++++++++ 6 files changed, 57 insertions(+), 20 deletions(-) create mode 100644 test/misc/bigfloat_test.jl diff --git a/src/BoundaryValueDiffEq.jl b/src/BoundaryValueDiffEq.jl index fba2b302..13a174f3 100644 --- a/src/BoundaryValueDiffEq.jl +++ b/src/BoundaryValueDiffEq.jl @@ -2,7 +2,7 @@ module BoundaryValueDiffEq import PrecompileTools: @compile_workload, @setup_workload -using ADTypes, Adapt, DiffEqBase, ForwardDiff, LinearAlgebra, NonlinearSolve, +using ADTypes, Adapt, ArrayInterface, DiffEqBase, ForwardDiff, LinearAlgebra, NonlinearSolve, OrdinaryDiffEq, Preferences, RecursiveArrayTools, Reexport, SciMLBase, Setfield, SparseDiffTools diff --git a/src/collocation.jl b/src/collocation.jl index 2ac85538..9c14fc3f 100644 --- a/src/collocation.jl +++ b/src/collocation.jl @@ -37,7 +37,7 @@ end @views function Φ( fᵢ_cache, k_discrete, f, TU::MIRKTableau, y, u, p, mesh, mesh_dt, stage::Int) (; c, v, x, b) = TU - residuals = [similar(yᵢ) for yᵢ in y[1:(end - 1)]] + residuals = [__similar(yᵢ) for yᵢ in y[1:(end - 1)]] tmp = get_tmp(fᵢ_cache, u) T = eltype(u) for i in eachindex(k_discrete) diff --git a/src/solve/mirk.jl b/src/solve/mirk.jl index 725e4074..8506ad15 100644 --- a/src/solve/mirk.jl +++ b/src/solve/mirk.jl @@ -42,10 +42,10 @@ function SciMLBase.__init(prob::BVProblem, alg::AbstractMIRK; dt = 0.0, mesh_dt = diff(mesh) chunksize = pickchunksize(N * (Nig - 1)) - __alloc = @closure x -> __maybe_allocate_diffcache(vec(x), chunksize, alg.jac_alg) + __alloc = @closure x -> __maybe_allocate_diffcache(vec(__similar(x)), chunksize, alg.jac_alg) - fᵢ_cache = __alloc(similar(X)) - fᵢ₂_cache = vec(similar(X)) + fᵢ_cache = __alloc(__similar(X)) + fᵢ₂_cache = vec(__similar(X)) # Don't flatten this here, since we need to expand it later if needed y₀ = __initial_guess_on_mesh(prob.u0, mesh, prob.p, false) @@ -54,9 +54,10 @@ function SciMLBase.__init(prob::BVProblem, alg::AbstractMIRK; dt = 0.0, TU, ITU = constructMIRK(alg, T) stage = alg_stage(alg) - k_discrete = [__maybe_allocate_diffcache(similar(X, N, stage), chunksize, alg.jac_alg) + k_discrete = [__maybe_allocate_diffcache(__similar(X, N, stage), chunksize, alg.jac_alg) for _ in 1:Nig] - k_interp = [similar(X, N, ITU.s_star - stage) for _ in 1:Nig] + k_interp = [__similar(X, ifelse(adaptive, N, 0), ifelse(adaptive, ITU.s_star - stage, 0)) + for _ in 1:Nig] bcresid_prototype, resid₁_size = __get_bcresid_prototype(prob.problem_type, prob, X) @@ -70,8 +71,8 @@ function SciMLBase.__init(prob::BVProblem, alg::AbstractMIRK; dt = 0.0, nothing end - defect = [similar(X, ifelse(adaptive, N, 0)) for _ in 1:Nig] - new_stages = [similar(X, N) for _ in 1:Nig] + defect = [__similar(X, ifelse(adaptive, N, 0)) for _ in 1:Nig] + new_stages = [__similar(X, ifelse(adaptive, N, 0)) for _ in 1:Nig] # Transform the functions to handle non-vector inputs bcresid_prototype = __vec(bcresid_prototype) @@ -302,7 +303,7 @@ function __construct_nlproblem(cache::MIRKCache{iip}, y, loss_bc::BC, loss_collo resid_bc = cache.bcresid_prototype L = length(resid_bc) - resid_collocation = similar(y, cache.M * (N - 1)) + resid_collocation = __similar(y, cache.M * (N - 1)) loss_bcₚ = (iip ? __Fix3 : Base.Fix2)(loss_bc, cache.p) loss_collocationₚ = (iip ? __Fix3 : Base.Fix2)(loss_collocation, cache.p) @@ -335,8 +336,8 @@ function __construct_nlproblem(cache::MIRKCache{iip}, y, loss_bc::BC, loss_collo Val(iip), jac_alg.nonbc_diffmode, sd_collocation, loss_collocationₚ, resid_collocation, y) - J_bc = init_jacobian(cache_bc) - J_c = init_jacobian(cache_collocation) + J_bc = __init_bigfloat_array!!(init_jacobian(cache_bc)) + J_c = __init_bigfloat_array!!(init_jacobian(cache_collocation)) if J_full_band === nothing jac_prototype = vcat(J_bc, J_c) else @@ -408,7 +409,7 @@ function __construct_nlproblem(cache::MIRKCache{iip}, y, loss_bc::BC, loss_collo lossₚ = iip ? ((du, u) -> loss(du, u, cache.p)) : (u -> loss(u, cache.p)) resid = vcat(@view(cache.bcresid_prototype[1:prod(cache.resid_size[1])]), - similar(y, cache.M * (N - 1)), + __similar(y, cache.M * (N - 1)), @view(cache.bcresid_prototype[(prod(cache.resid_size[1]) + 1):end])) L = length(cache.bcresid_prototype) @@ -422,7 +423,7 @@ function __construct_nlproblem(cache::MIRKCache{iip}, y, loss_bc::BC, loss_collo NoSparsityDetection() end diffcache = __sparse_jacobian_cache(Val(iip), jac_alg.diffmode, sd, lossₚ, resid, y) - jac_prototype = init_jacobian(diffcache) + jac_prototype = __init_bigfloat_array!!(init_jacobian(diffcache)) jac = if iip @closure (J, u, p) -> __mirk_2point_jacobian!( diff --git a/src/types.jl b/src/types.jl index 9ba70c3f..1f758742 100644 --- a/src/types.jl +++ b/src/types.jl @@ -155,8 +155,8 @@ end function __maybe_allocate_diffcache(x, chunksize, jac_alg) return __needs_diffcache(jac_alg) ? DiffCache(x, chunksize) : FakeDiffCache(x) end -__maybe_allocate_diffcache(x::DiffCache, chunksize) = DiffCache(similar(x.du), chunksize) -__maybe_allocate_diffcache(x::FakeDiffCache, _) = FakeDiffCache(similar(x.du)) +__maybe_allocate_diffcache(x::DiffCache, chunksize) = DiffCache(__similar(x.du), chunksize) +__maybe_allocate_diffcache(x::FakeDiffCache, _) = FakeDiffCache(__similar(x.du)) const MaybeDiffCache = Union{DiffCache, FakeDiffCache} diff --git a/src/utils.jl b/src/utils.jl index f9d10d2d..2b2e0cd2 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -2,7 +2,7 @@ recursive_length(x::Vector{<:AbstractArray}) = sum(length, x) recursive_length(x::Vector{<:MaybeDiffCache}) = sum(xᵢ -> length(xᵢ.u), x) function recursive_flatten(x::Vector{<:AbstractArray}) - y = similar(first(x), recursive_length(x)) + y = __similar(first(x), recursive_length(x)) recursive_flatten!(y, x) return y end @@ -110,7 +110,7 @@ function __append_similar!(x::AbstractVector{<:AbstractArray}, n, _) N = n - length(x) N == 0 && return x N < 0 && throw(ArgumentError("Cannot append a negative number of elements")) - append!(x, [similar(last(x)) for _ in 1:N]) + append!(x, [__similar(last(x)) for _ in 1:N]) return x end @@ -185,12 +185,31 @@ function __get_bcresid_prototype(::StandardBVProblem, prob::BVProblem, u) return prototype, size(prototype) end -@inline function __fill_like(v, x, args...) +@inline function __similar(x, args...) y = similar(x, args...) + return __init_bigfloat_array!!(y) +end + +@inline function __init_bigfloat_array!!(x) + if ArrayInterface.can_setindex(x) + eltype(x) <: BigFloat && fill!(x, BigFloat(0)) + return x + end + return x +end + +@inline function __fill_like(v, x) + y = __similar(x) fill!(y, v) return y end -@inline __zeros_like(args...) = __fill_like(0, args...) +@inline function __zeros_like(u) + if ArrayInterface.can_setindex(u) + eltype(u) <: BigFloat && __fill_like(BigFloat(0), u) + return u + end + return __fill_like(0, u) +end @inline __ones_like(args...) = __fill_like(1, args...) @inline __safe_vec(x) = vec(x) diff --git a/test/misc/bigfloat_test.jl b/test/misc/bigfloat_test.jl new file mode 100644 index 00000000..0b7e9137 --- /dev/null +++ b/test/misc/bigfloat_test.jl @@ -0,0 +1,17 @@ +@testitem "BigFloat compatibility" begin + using BoundaryValueDiffEq + tspan = (0.0, pi / 2) + function simplependulum!(du, u, p, t) + θ = u[1] + dθ = u[2] + du[1] = dθ + du[2] = -9.81 * sin(θ) + end + function bc!(residual, u, p, t) + residual[1] = u[end ÷ 2][1] + pi / 2 + residual[2] = u[end][1] - pi / 2 + end + u0 = BigFloat.([pi / 2, pi / 2]) + prob = BVProblem(simplependulum!, bc!, u0, tspan) + sol = solve(prob, MIRK4(), dt = 0.05) +end From 85bfb6b5b395d93ba5e68619a280949de2226a5c Mon Sep 17 00:00:00 2001 From: ErikQQY <2283984853@qq.com> Date: Tue, 13 Aug 2024 21:05:19 +0800 Subject: [PATCH 2/6] finally working Signed-off-by: ErikQQY <2283984853@qq.com> --- src/solve/mirk.jl | 1 - src/types.jl | 43 +++++++++++++++++++++++++++++++++++++- test/misc/bigfloat_test.jl | 5 +++-- 3 files changed, 45 insertions(+), 4 deletions(-) diff --git a/src/solve/mirk.jl b/src/solve/mirk.jl index 8506ad15..6c3e1c9f 100644 --- a/src/solve/mirk.jl +++ b/src/solve/mirk.jl @@ -70,7 +70,6 @@ function SciMLBase.__init(prob::BVProblem, alg::AbstractMIRK; dt = 0.0, else nothing end - defect = [__similar(X, ifelse(adaptive, N, 0)) for _ in 1:Nig] new_stages = [__similar(X, ifelse(adaptive, N, 0)) for _ in 1:Nig] diff --git a/src/types.jl b/src/types.jl index 1f758742..ea2b5dff 100644 --- a/src/types.jl +++ b/src/types.jl @@ -152,13 +152,54 @@ end du end +# hacking DiffCache to handling with BigFloat case +@concrete struct BigFloatDiffCache{T <: AbstractArray, S <: AbstractArray} + du::T + dual_du::S + any_du::Vector{Any} +end + +function BigFloatDiffCache(u::AbstractArray{T}, siz, chunk_sizes) where {T} + x = adapt(ArrayInterface.parameterless_type(u), + zeros(T, prod(chunk_sizes .+ 1) * prod(siz))) + xany = Any[] + BigFloatDiffCache(u, x, xany) +end +function BigFloatDiffCache(u::AbstractArray, N::Int = ForwardDiff.pickchunksize(length(u)); + levels::Int = 1) + BigFloatDiffCache(u, size(u), N * ones(Int, levels)) +end +BigFloatDiffCache(u::AbstractArray, N::AbstractArray{<:Int}) = BigFloatDiffCache(u, size(u), N) +function BigFloatDiffCache(u::AbstractArray, ::Type{Val{N}}; levels::Int = 1) where {N} + BigFloatDiffCache(u, N; levels) +end +BigFloatDiffCache(u::AbstractArray, ::Val{N}; levels::Int = 1) where {N} = BigFloatDiffCache(u, N; levels) + +function get_tmp(dc::BigFloatDiffCache, u::T) where {T <: ForwardDiff.Dual} + nelem = length(dc.du) + PreallocationTools._restructure(dc.du, view(T.(dc.dual_du), 1:nelem)) +end +function get_tmp(dc::BigFloatDiffCache, u::AbstractArray{T}) where {T <: ForwardDiff.Dual} + nelem = length(dc.du) + PreallocationTools._restructure(dc.du, view(T.(dc.dual_du), 1:nelem)) +end +function get_tmp(dc::BigFloatDiffCache, u::Union{Number, AbstractArray}) + return dc.du +end +function get_tmp(dc::BigFloatDiffCache, ::Type{T}) where {T <: Number} + return dc.du +end +get_tmp(dc::Vector{BigFloat}, u::Vector{BigFloat}) = dc + function __maybe_allocate_diffcache(x, chunksize, jac_alg) + eltype(x) <: BigFloat && return (__needs_diffcache(jac_alg) ? BigFloatDiffCache(x, chunksize) : FakeDiffCache(x)) return __needs_diffcache(jac_alg) ? DiffCache(x, chunksize) : FakeDiffCache(x) end __maybe_allocate_diffcache(x::DiffCache, chunksize) = DiffCache(__similar(x.du), chunksize) __maybe_allocate_diffcache(x::FakeDiffCache, _) = FakeDiffCache(__similar(x.du)) +__maybe_allocate_diffcache(x::BigFloatDiffCache, chunksize) = BigFloatDiffCache(x.du, chunksize) -const MaybeDiffCache = Union{DiffCache, FakeDiffCache} +const MaybeDiffCache = Union{DiffCache, FakeDiffCache, BigFloatDiffCache} ## get_tmp shows a warning as it should on cache exapansion, this behavior however is ## expected for adaptive BVP solvers so we write our own `get_tmp` and drop the warning logs diff --git a/test/misc/bigfloat_test.jl b/test/misc/bigfloat_test.jl index 0b7e9137..53211b94 100644 --- a/test/misc/bigfloat_test.jl +++ b/test/misc/bigfloat_test.jl @@ -8,10 +8,11 @@ du[2] = -9.81 * sin(θ) end function bc!(residual, u, p, t) - residual[1] = u[end ÷ 2][1] + pi / 2 - residual[2] = u[end][1] - pi / 2 + residual[1] = u[end ÷ 2][1] + big(pi / 2) + residual[2] = u[end][1] - big(pi / 2) end u0 = BigFloat.([pi / 2, pi / 2]) prob = BVProblem(simplependulum!, bc!, u0, tspan) sol = solve(prob, MIRK4(), dt = 0.05) + @test SciMLBase.successful_retcode(sol.retcode) end From b1904890df437962a347e239f7089a3f4179a790 Mon Sep 17 00:00:00 2001 From: ErikQQY <2283984853@qq.com> Date: Fri, 16 Aug 2024 23:51:59 +0800 Subject: [PATCH 3/6] Use the latest PreallocationTools.jl Signed-off-by: ErikQQY <2283984853@qq.com> --- Project.toml | 2 +- src/types.jl | 43 +------------------------------------------ 2 files changed, 2 insertions(+), 43 deletions(-) diff --git a/Project.toml b/Project.toml index cc8fe14d..22923708 100644 --- a/Project.toml +++ b/Project.toml @@ -52,7 +52,7 @@ Logging = "1.10" NonlinearSolve = "3.8.1" ODEInterface = "0.5" OrdinaryDiffEq = "6.63" -PreallocationTools = "0.4" +PreallocationTools = "0.4.24" PrecompileTools = "1.2" Preferences = "1.4" Random = "1.10" diff --git a/src/types.jl b/src/types.jl index ea2b5dff..1f758742 100644 --- a/src/types.jl +++ b/src/types.jl @@ -152,54 +152,13 @@ end du end -# hacking DiffCache to handling with BigFloat case -@concrete struct BigFloatDiffCache{T <: AbstractArray, S <: AbstractArray} - du::T - dual_du::S - any_du::Vector{Any} -end - -function BigFloatDiffCache(u::AbstractArray{T}, siz, chunk_sizes) where {T} - x = adapt(ArrayInterface.parameterless_type(u), - zeros(T, prod(chunk_sizes .+ 1) * prod(siz))) - xany = Any[] - BigFloatDiffCache(u, x, xany) -end -function BigFloatDiffCache(u::AbstractArray, N::Int = ForwardDiff.pickchunksize(length(u)); - levels::Int = 1) - BigFloatDiffCache(u, size(u), N * ones(Int, levels)) -end -BigFloatDiffCache(u::AbstractArray, N::AbstractArray{<:Int}) = BigFloatDiffCache(u, size(u), N) -function BigFloatDiffCache(u::AbstractArray, ::Type{Val{N}}; levels::Int = 1) where {N} - BigFloatDiffCache(u, N; levels) -end -BigFloatDiffCache(u::AbstractArray, ::Val{N}; levels::Int = 1) where {N} = BigFloatDiffCache(u, N; levels) - -function get_tmp(dc::BigFloatDiffCache, u::T) where {T <: ForwardDiff.Dual} - nelem = length(dc.du) - PreallocationTools._restructure(dc.du, view(T.(dc.dual_du), 1:nelem)) -end -function get_tmp(dc::BigFloatDiffCache, u::AbstractArray{T}) where {T <: ForwardDiff.Dual} - nelem = length(dc.du) - PreallocationTools._restructure(dc.du, view(T.(dc.dual_du), 1:nelem)) -end -function get_tmp(dc::BigFloatDiffCache, u::Union{Number, AbstractArray}) - return dc.du -end -function get_tmp(dc::BigFloatDiffCache, ::Type{T}) where {T <: Number} - return dc.du -end -get_tmp(dc::Vector{BigFloat}, u::Vector{BigFloat}) = dc - function __maybe_allocate_diffcache(x, chunksize, jac_alg) - eltype(x) <: BigFloat && return (__needs_diffcache(jac_alg) ? BigFloatDiffCache(x, chunksize) : FakeDiffCache(x)) return __needs_diffcache(jac_alg) ? DiffCache(x, chunksize) : FakeDiffCache(x) end __maybe_allocate_diffcache(x::DiffCache, chunksize) = DiffCache(__similar(x.du), chunksize) __maybe_allocate_diffcache(x::FakeDiffCache, _) = FakeDiffCache(__similar(x.du)) -__maybe_allocate_diffcache(x::BigFloatDiffCache, chunksize) = BigFloatDiffCache(x.du, chunksize) -const MaybeDiffCache = Union{DiffCache, FakeDiffCache, BigFloatDiffCache} +const MaybeDiffCache = Union{DiffCache, FakeDiffCache} ## get_tmp shows a warning as it should on cache exapansion, this behavior however is ## expected for adaptive BVP solvers so we write our own `get_tmp` and drop the warning logs From 9db2d10b8a7c7151752acebfe5a124c7a2c452c9 Mon Sep 17 00:00:00 2001 From: ErikQQY <2283984853@qq.com> Date: Sat, 17 Aug 2024 13:35:03 +0800 Subject: [PATCH 4/6] remove __init_bigfloat_array Signed-off-by: ErikQQY <2283984853@qq.com> --- src/solve/mirk.jl | 6 +++--- src/utils.jl | 11 ++--------- 2 files changed, 5 insertions(+), 12 deletions(-) diff --git a/src/solve/mirk.jl b/src/solve/mirk.jl index c6033d1a..08fc45c3 100644 --- a/src/solve/mirk.jl +++ b/src/solve/mirk.jl @@ -335,8 +335,8 @@ function __construct_nlproblem(cache::MIRKCache{iip}, y, loss_bc::BC, loss_collo Val(iip), jac_alg.nonbc_diffmode, sd_collocation, loss_collocationₚ, resid_collocation, y) - J_bc = __init_bigfloat_array!!(init_jacobian(cache_bc)) - J_c = __init_bigfloat_array!!(init_jacobian(cache_collocation)) + J_bc = __similar(init_jacobian(cache_bc)) + J_c = __similar(init_jacobian(cache_collocation)) if J_full_band === nothing jac_prototype = vcat(J_bc, J_c) else @@ -422,7 +422,7 @@ function __construct_nlproblem(cache::MIRKCache{iip}, y, loss_bc::BC, loss_collo NoSparsityDetection() end diffcache = __sparse_jacobian_cache(Val(iip), jac_alg.diffmode, sd, lossₚ, resid, y) - jac_prototype = __init_bigfloat_array!!(init_jacobian(diffcache)) + jac_prototype = __similar(init_jacobian(diffcache)) jac = if iip @closure (J, u, p) -> __mirk_2point_jacobian!( diff --git a/src/utils.jl b/src/utils.jl index 496d84f4..38c7841c 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -197,15 +197,8 @@ end @inline function __similar(x, args...) y = similar(x, args...) - return __init_bigfloat_array!!(y) -end - -@inline function __init_bigfloat_array!!(x) - if ArrayInterface.can_setindex(x) - eltype(x) <: BigFloat && fill!(x, BigFloat(0)) - return x - end - return x + eltype(y) <: BigFloat && fill!(y, BigFloat(0)) + return y end @inline function __fill_like(v, x) From 9a5f442a2df3095fff321971e165d56aa6143971 Mon Sep 17 00:00:00 2001 From: ErikQQY <2283984853@qq.com> Date: Sat, 17 Aug 2024 17:55:30 +0800 Subject: [PATCH 5/6] zero everywhere except length is changed Signed-off-by: ErikQQY <2283984853@qq.com> --- src/solve/mirk.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/solve/mirk.jl b/src/solve/mirk.jl index e01db41e..d9dcb2c0 100644 --- a/src/solve/mirk.jl +++ b/src/solve/mirk.jl @@ -42,10 +42,10 @@ function SciMLBase.__init(prob::BVProblem, alg::AbstractMIRK; dt = 0.0, mesh_dt = diff(mesh) chunksize = pickchunksize(N * (Nig - 1)) - __alloc = @closure x -> __maybe_allocate_diffcache(vec(__similar(x)), chunksize, alg.jac_alg) + __alloc = @closure x -> __maybe_allocate_diffcache(vec(zero(x)), chunksize, alg.jac_alg) - fᵢ_cache = __alloc(__similar(X)) - fᵢ₂_cache = vec(__similar(X)) + fᵢ_cache = __alloc(zero(X)) + fᵢ₂_cache = vec(zero(X)) # Don't flatten this here, since we need to expand it later if needed y₀ = __initial_guess_on_mesh(prob.u0, mesh, prob.p) @@ -341,8 +341,8 @@ function __construct_nlproblem(cache::MIRKCache{iip}, y, loss_bc::BC, loss_collo Val(iip), jac_alg.nonbc_diffmode, sd_collocation, loss_collocationₚ, resid_collocation, y) - J_bc = __similar(init_jacobian(cache_bc)) - J_c = __similar(init_jacobian(cache_collocation)) + J_bc = zero(init_jacobian(cache_bc)) + J_c = zero(init_jacobian(cache_collocation)) if J_full_band === nothing jac_prototype = vcat(J_bc, J_c) else @@ -428,7 +428,7 @@ function __construct_nlproblem(cache::MIRKCache{iip}, y, loss_bc::BC, loss_collo NoSparsityDetection() end diffcache = __sparse_jacobian_cache(Val(iip), jac_alg.diffmode, sd, lossₚ, resid, y) - jac_prototype = __similar(init_jacobian(diffcache)) + jac_prototype = zero(init_jacobian(diffcache)) jac = if iip @closure (J, u, p) -> __mirk_2point_jacobian!( From 40656db0437483c9ee1d9a32d761a38c7c7a4691 Mon Sep 17 00:00:00 2001 From: ErikQQY <2283984853@qq.com> Date: Sat, 17 Aug 2024 18:18:06 +0800 Subject: [PATCH 6/6] change to zeros everywhere Signed-off-by: ErikQQY <2283984853@qq.com> --- src/types.jl | 4 ++-- src/utils.jl | 17 +++++------------ 2 files changed, 7 insertions(+), 14 deletions(-) diff --git a/src/types.jl b/src/types.jl index 1f758742..448f1367 100644 --- a/src/types.jl +++ b/src/types.jl @@ -155,8 +155,8 @@ end function __maybe_allocate_diffcache(x, chunksize, jac_alg) return __needs_diffcache(jac_alg) ? DiffCache(x, chunksize) : FakeDiffCache(x) end -__maybe_allocate_diffcache(x::DiffCache, chunksize) = DiffCache(__similar(x.du), chunksize) -__maybe_allocate_diffcache(x::FakeDiffCache, _) = FakeDiffCache(__similar(x.du)) +__maybe_allocate_diffcache(x::DiffCache, chunksize) = DiffCache(zero(x.du), chunksize) +__maybe_allocate_diffcache(x::FakeDiffCache, _) = FakeDiffCache(zero(x.du)) const MaybeDiffCache = Union{DiffCache, FakeDiffCache} diff --git a/src/utils.jl b/src/utils.jl index a5c771ed..62eba2c9 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -2,7 +2,7 @@ recursive_length(x::Vector{<:AbstractArray}) = sum(length, x) recursive_length(x::Vector{<:MaybeDiffCache}) = sum(xᵢ -> length(xᵢ.u), x) function recursive_flatten(x::Vector{<:AbstractArray}) - y = __similar(first(x), recursive_length(x)) + y = zero(first(x), recursive_length(x)) recursive_flatten!(y, x) return y end @@ -112,7 +112,7 @@ function __append_similar!(x::AbstractVector{<:AbstractArray}, n, _) N = n - length(x) N == 0 && return x N < 0 && throw(ArgumentError("Cannot append a negative number of elements")) - append!(x, [__similar(last(x)) for _ in 1:N]) + append!(x, [zero(last(x)) for _ in 1:N]) return x end @@ -191,14 +191,13 @@ function __get_bcresid_prototype(::TwoPointBVProblem, prob::BVProblem, u) end function __get_bcresid_prototype(::StandardBVProblem, prob::BVProblem, u) prototype = prob.f.bcresid_prototype !== nothing ? prob.f.bcresid_prototype : - __zeros_like(u) + zero(u) return prototype, size(prototype) end @inline function __similar(x, args...) y = similar(x, args...) - eltype(y) <: BigFloat && fill!(y, BigFloat(0)) - return y + return zero(y) end @inline function __fill_like(v, x) @@ -206,13 +205,7 @@ end fill!(y, v) return y end -@inline function __zeros_like(u) - if ArrayInterface.can_setindex(u) - eltype(u) <: BigFloat && __fill_like(BigFloat(0), u) - return u - end - return __fill_like(0, u) -end + @inline __ones_like(args...) = __fill_like(1, args...) @inline __safe_vec(x) = vec(x)