diff --git a/src/PreallocationTools.jl b/src/PreallocationTools.jl index ced6354..9834c81 100644 --- a/src/PreallocationTools.jl +++ b/src/PreallocationTools.jl @@ -129,27 +129,39 @@ const dualcache = DiffCache Returns the `Dual` or normal cache array stored in `dc` based on the type of `u`. """ function get_tmp(dc::DiffCache, u::T) where {T <: ForwardDiff.Dual} - nelem = div(sizeof(T), sizeof(eltype(dc.dual_du))) * length(dc.du) - if nelem > length(dc.dual_du) - enlargediffcache!(dc, nelem) + if isbitstype(T) + nelem = div(sizeof(T), sizeof(eltype(dc.dual_du))) * length(dc.du) + if nelem > length(dc.dual_du) + enlargediffcache!(dc, nelem) + end + _restructure(dc.du, reinterpret(T, view(dc.dual_du, 1:nelem))) + else + _restructure(dc.du, zeros(T, size(dc.du))) end - _restructure(dc.du, reinterpret(T, view(dc.dual_du, 1:nelem))) end function get_tmp(dc::DiffCache, ::Type{T}) where {T <: ForwardDiff.Dual} - nelem = div(sizeof(T), sizeof(eltype(dc.dual_du))) * length(dc.du) - if nelem > length(dc.dual_du) - enlargediffcache!(dc, nelem) + if isbitstype(T) + nelem = div(sizeof(T), sizeof(eltype(dc.dual_du))) * length(dc.du) + if nelem > length(dc.dual_du) + enlargediffcache!(dc, nelem) + end + _restructure(dc.du, reinterpret(T, view(dc.dual_du, 1:nelem))) + else + _restructure(dc.du, zeros(T, size(dc.du))) end - _restructure(dc.du, reinterpret(T, view(dc.dual_du, 1:nelem))) end function get_tmp(dc::DiffCache, u::AbstractArray{T}) where {T <: ForwardDiff.Dual} - nelem = div(sizeof(T), sizeof(eltype(dc.dual_du))) * length(dc.du) - if nelem > length(dc.dual_du) - enlargediffcache!(dc, nelem) + if isbitstype(T) + nelem = div(sizeof(T), sizeof(eltype(dc.dual_du))) * length(dc.du) + if nelem > length(dc.dual_du) + enlargediffcache!(dc, nelem) + end + _restructure(dc.du, reinterpret(T, view(dc.dual_du, 1:nelem))) + else + _restructure(dc.du, zeros(T, size(dc.du))) end - _restructure(dc.du, reinterpret(T, view(dc.dual_du, 1:nelem))) end function get_tmp(dc::DiffCache, u::Union{Number, AbstractArray})