Skip to content

Commit

Permalink
fix: fix ODESolution-related adjoints
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Mar 11, 2024
1 parent cd635f8 commit 7f57ef0
Showing 1 changed file with 11 additions and 11 deletions.
22 changes: 11 additions & 11 deletions ext/SciMLBaseZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@ using RecursiveArrayTools
T = eltype(eltype(VA.u))
N = length(VA.prob.p)
Δ′ = ODESolution{T, N, typeof(du), Nothing, Nothing, typeof(VA.t),
typeof(VA.k), typeof(dprob), typeof(VA.alg), typeof(VA.interp),
typeof(VA.stats), typeof(VA.alg_choice)}(du, nothing, nothing,
VA.t, VA.k, dprob, VA.alg, VA.interp, VA.dense, 0, VA.stats,
VA.alg_choice, VA.retcode)
typeof(VA.k), typeof(VA.discretes), typeof(dprob), typeof(VA.alg),
typeof(VA.interp), typeof(VA.stats), typeof(VA.alg_choice)}(du, nothing,
nothing, VA.t, VA.k, VA.discretes, dprob, VA.alg, VA.interp, VA.dense, 0,
VA.stats, VA.alg_choice, VA.retcode)
(Δ′, nothing, nothing)
end
VA[i, j], ODESolution_getindex_pullback
Expand All @@ -52,10 +52,10 @@ end
T = eltype(eltype(VA.u))
N = length(VA.prob.p)
Δ′ = ODESolution{T, N, typeof(du), Nothing, Nothing, typeof(VA.t),
typeof(VA.k), typeof(dprob), typeof(VA.alg), typeof(VA.interp),
typeof(VA.stats), typeof(VA.alg_choice)}(du, nothing, nothing,
VA.t, VA.k, dprob, VA.alg, VA.interp, VA.dense, 0, VA.stats,
VA.alg_choice, VA.retcode)
typeof(VA.k), typeof(VA.discretes), typeof(dprob), typeof(VA.alg),
typeof(VA.interp), typeof(VA.stats), typeof(VA.alg_choice)}(du, nothing,
nothing, VA.t, VA.k, VA.discretes, dprob, VA.alg, VA.interp, VA.dense, 0,
VA.stats, VA.alg_choice, VA.retcode)
(Δ′, nothing, nothing)
end
VA[sym, j], ODESolution_getindex_pullback
Expand Down Expand Up @@ -108,15 +108,15 @@ end
VA[sym], ODESolution_getindex_pullback
end

@adjoint function ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12
@adjoint function ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13

Check warning on line 111 in ext/SciMLBaseZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SciMLBaseZygoteExt.jl#L111

Added line #L111 was not covered by tests
}(u,
args...) where {T1, T2, T3, T4, T5, T6, T7, T8,
T9, T10, T11, T12}
T9, T10, T11, T12, T13}
function ODESolutionAdjoint(ȳ)
(ȳ, ntuple(_ -> nothing, length(args))...)
end

ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12}(u, args...),
ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13}(u, args...),

Check warning on line 119 in ext/SciMLBaseZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SciMLBaseZygoteExt.jl#L119

Added line #L119 was not covered by tests
ODESolutionAdjoint
end

Expand Down

0 comments on commit 7f57ef0

Please sign in to comment.