diff --git a/ext/SciMLBaseZygoteExt.jl b/ext/SciMLBaseZygoteExt.jl index aa6abbdf2..5264bff91 100644 --- a/ext/SciMLBaseZygoteExt.jl +++ b/ext/SciMLBaseZygoteExt.jl @@ -31,7 +31,7 @@ using RecursiveArrayTools N = length((size(dprob.u0)..., length(du))) end Δ′ = ODESolution{T, N}(du, nothing, nothing, - VA.t, VA.k, dprob, VA.alg, VA.interp, VA.dense, 0, VA.stats, + VA.t, VA.k, VA.discretes, dprob, VA.alg, VA.interp, VA.dense, 0, VA.stats, VA.alg_choice, VA.retcode) (Δ′, nothing, nothing) end @@ -66,7 +66,7 @@ end N = length((size(dprob.u0)..., length(du))) end Δ′ = ODESolution{T, N}(du, nothing, nothing, - VA.t, VA.k, dprob, VA.alg, VA.interp, VA.dense, 0, VA.stats, + VA.t, VA.k, VA.discretes, dprob, VA.alg, VA.interp, VA.dense, 0, VA.stats, VA.alg_choice, VA.retcode) (Δ′, nothing, nothing) end @@ -144,15 +144,15 @@ end VA[sym], ODESolution_getindex_pullback end -@adjoint function ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12 +@adjoint function ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13 }(u, args...) where {T1, T2, T3, T4, T5, T6, T7, T8, - T9, T10, T11, T12} + T9, T10, T11, T12, T13} function ODESolutionAdjoint(ȳ) (ȳ, ntuple(_ -> nothing, length(args))...) end - ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12}(u, args...), + ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13}(u, args...), ODESolutionAdjoint end