diff --git a/ext/SciMLBaseZygoteExt.jl b/ext/SciMLBaseZygoteExt.jl index 0d322a056..6a440c2de 100644 --- a/ext/SciMLBaseZygoteExt.jl +++ b/ext/SciMLBaseZygoteExt.jl @@ -23,10 +23,10 @@ using RecursiveArrayTools T = eltype(eltype(VA.u)) N = length(VA.prob.p) Δ′ = ODESolution{T, N, typeof(du), Nothing, Nothing, typeof(VA.t), - typeof(VA.k), typeof(dprob), typeof(VA.alg), typeof(VA.interp), - typeof(VA.stats), typeof(VA.alg_choice)}(du, nothing, nothing, - VA.t, VA.k, dprob, VA.alg, VA.interp, VA.dense, 0, VA.stats, - VA.alg_choice, VA.retcode) + typeof(VA.k), typeof(VA.discretes), typeof(dprob), typeof(VA.alg), + typeof(VA.interp), typeof(VA.stats), typeof(VA.alg_choice)}(du, nothing, + nothing, VA.t, VA.k, VA.discretes, dprob, VA.alg, VA.interp, VA.dense, 0, + VA.stats, VA.alg_choice, VA.retcode) (Δ′, nothing, nothing) end VA[i, j], ODESolution_getindex_pullback @@ -52,10 +52,10 @@ end T = eltype(eltype(VA.u)) N = length(VA.prob.p) Δ′ = ODESolution{T, N, typeof(du), Nothing, Nothing, typeof(VA.t), - typeof(VA.k), typeof(dprob), typeof(VA.alg), typeof(VA.interp), - typeof(VA.stats), typeof(VA.alg_choice)}(du, nothing, nothing, - VA.t, VA.k, dprob, VA.alg, VA.interp, VA.dense, 0, VA.stats, - VA.alg_choice, VA.retcode) + typeof(VA.k), typeof(VA.discretes), typeof(dprob), typeof(VA.alg), + typeof(VA.interp), typeof(VA.stats), typeof(VA.alg_choice)}(du, nothing, + nothing, VA.t, VA.k, VA.discretes, dprob, VA.alg, VA.interp, VA.dense, 0, + VA.stats, VA.alg_choice, VA.retcode) (Δ′, nothing, nothing) end VA[sym, j], ODESolution_getindex_pullback @@ -108,15 +108,15 @@ end VA[sym], ODESolution_getindex_pullback end -@adjoint function ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12 +@adjoint function ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13 }(u, args...) where {T1, T2, T3, T4, T5, T6, T7, T8, - T9, T10, T11, T12} + T9, T10, T11, T12, T13} function ODESolutionAdjoint(ȳ) (ȳ, ntuple(_ -> nothing, length(args))...) end - ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12}(u, args...), + ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13}(u, args...), ODESolutionAdjoint end