diff --git a/src/solve/mirk.jl b/src/solve/mirk.jl index 94a5ee2c..a82dc4f9 100644 --- a/src/solve/mirk.jl +++ b/src/solve/mirk.jl @@ -276,14 +276,14 @@ end @views function __mirk_loss_bc!( resid, u, p, pt, bc!::BC, y, mesh, cache::MIRKCache) where {BC} y_ = recursive_unflatten!(y, u) - soly_ = VectorOfArray([[r[i] for i in 1:cache.M] for r in y_]) + soly_ = VectorOfArray(y_) eval_bc_residual!(resid, pt, bc!, soly_, p, mesh) return nothing end @views function __mirk_loss_bc(u, p, pt, bc!::BC, y, mesh, cache::MIRKCache) where {BC} y_ = recursive_unflatten!(y, u) - soly_ = VectorOfArray([[r[i] for i in 1:cache.M] for r in y_]) + soly_ = VectorOfArray(y_) return eval_bc_residual(pt, bc!, soly_, p, mesh) end diff --git a/src/utils.jl b/src/utils.jl index 1b569648..ce50f771 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -234,8 +234,8 @@ __vec_bc(sol, p, bc, u_size) = vec(bc(reshape(sol, u_size), p)) @inline __get_non_sparse_ad(ad::AutoSparse) = ADTypes.dense_ad(ad) # Restructure Solution -function __restructure_sol(sol::Vector{<:AbstractArray}, u_size) - return map(Base.Fix2(reshape, u_size), sol) +function __restructure_sol(sol::AbstractVectorOfArray, u_size) + return VectorOfArray(map(Base.Fix2(reshape, u_size), sol)) end # Override the checks for NonlinearFunction