Skip to content

Commit

Permalink
Fix non-vector input tests error
Browse files Browse the repository at this point in the history
Signed-off-by: ErikQQY <[email protected]>
  • Loading branch information
ErikQQY committed Aug 15, 2024
1 parent 0508d64 commit 5597128
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions src/solve/mirk.jl
Original file line number Diff line number Diff line change
Expand Up @@ -276,14 +276,14 @@ end
@views function __mirk_loss_bc!(
resid, u, p, pt, bc!::BC, y, mesh, cache::MIRKCache) where {BC}
y_ = recursive_unflatten!(y, u)
soly_ = VectorOfArray([[r[i] for i in 1:cache.M] for r in y_])
soly_ = VectorOfArray(y_)
eval_bc_residual!(resid, pt, bc!, soly_, p, mesh)
return nothing
end

@views function __mirk_loss_bc(u, p, pt, bc!::BC, y, mesh, cache::MIRKCache) where {BC}
y_ = recursive_unflatten!(y, u)
soly_ = VectorOfArray([[r[i] for i in 1:cache.M] for r in y_])
soly_ = VectorOfArray(y_)
return eval_bc_residual(pt, bc!, soly_, p, mesh)
end

Expand Down
4 changes: 2 additions & 2 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -234,8 +234,8 @@ __vec_bc(sol, p, bc, u_size) = vec(bc(reshape(sol, u_size), p))
@inline __get_non_sparse_ad(ad::AutoSparse) = ADTypes.dense_ad(ad)

# Restructure Solution
function __restructure_sol(sol::Vector{<:AbstractArray}, u_size)
return map(Base.Fix2(reshape, u_size), sol)
function __restructure_sol(sol::AbstractVectorOfArray, u_size)
return VectorOfArray(map(Base.Fix2(reshape, u_size), sol))
end

# Override the checks for NonlinearFunction
Expand Down

0 comments on commit 5597128

Please sign in to comment.