From 1e62b87e87a640a1607c12d70fa19b201513c1c4 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 11 Mar 2024 15:42:45 +0530 Subject: [PATCH] fix: fix ODESolution-related adjoints --- ext/SciMLBaseZygoteExt.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/ext/SciMLBaseZygoteExt.jl b/ext/SciMLBaseZygoteExt.jl index aa6abbdf2..5264bff91 100644 --- a/ext/SciMLBaseZygoteExt.jl +++ b/ext/SciMLBaseZygoteExt.jl @@ -31,7 +31,7 @@ using RecursiveArrayTools N = length((size(dprob.u0)..., length(du))) end Δ′ = ODESolution{T, N}(du, nothing, nothing, - VA.t, VA.k, dprob, VA.alg, VA.interp, VA.dense, 0, VA.stats, + VA.t, VA.k, VA.discretes, dprob, VA.alg, VA.interp, VA.dense, 0, VA.stats, VA.alg_choice, VA.retcode) (Δ′, nothing, nothing) end @@ -66,7 +66,7 @@ end N = length((size(dprob.u0)..., length(du))) end Δ′ = ODESolution{T, N}(du, nothing, nothing, - VA.t, VA.k, dprob, VA.alg, VA.interp, VA.dense, 0, VA.stats, + VA.t, VA.k, VA.discretes, dprob, VA.alg, VA.interp, VA.dense, 0, VA.stats, VA.alg_choice, VA.retcode) (Δ′, nothing, nothing) end @@ -144,15 +144,15 @@ end VA[sym], ODESolution_getindex_pullback end -@adjoint function ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12 +@adjoint function ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13 }(u, args...) where {T1, T2, T3, T4, T5, T6, T7, T8, - T9, T10, T11, T12} + T9, T10, T11, T12, T13} function ODESolutionAdjoint(ȳ) (ȳ, ntuple(_ -> nothing, length(args))...) end - ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12}(u, args...), + ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13}(u, args...), ODESolutionAdjoint end