Skip to content

Commit

Permalink
fix: fix ODESolution-related adjoints
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed May 3, 2024
1 parent 8f07d49 commit 1e62b87
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions ext/SciMLBaseZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ using RecursiveArrayTools
N = length((size(dprob.u0)..., length(du)))
end
Δ′ = ODESolution{T, N}(du, nothing, nothing,
VA.t, VA.k, dprob, VA.alg, VA.interp, VA.dense, 0, VA.stats,
VA.t, VA.k, VA.discretes, dprob, VA.alg, VA.interp, VA.dense, 0, VA.stats,
VA.alg_choice, VA.retcode)
(Δ′, nothing, nothing)
end
Expand Down Expand Up @@ -66,7 +66,7 @@ end
N = length((size(dprob.u0)..., length(du)))
end
Δ′ = ODESolution{T, N}(du, nothing, nothing,
VA.t, VA.k, dprob, VA.alg, VA.interp, VA.dense, 0, VA.stats,
VA.t, VA.k, VA.discretes, dprob, VA.alg, VA.interp, VA.dense, 0, VA.stats,
VA.alg_choice, VA.retcode)
(Δ′, nothing, nothing)
end
Expand Down Expand Up @@ -144,15 +144,15 @@ end
VA[sym], ODESolution_getindex_pullback
end

@adjoint function ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12
@adjoint function ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13

Check warning on line 147 in ext/SciMLBaseZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SciMLBaseZygoteExt.jl#L147

Added line #L147 was not covered by tests
}(u,
args...) where {T1, T2, T3, T4, T5, T6, T7, T8,
T9, T10, T11, T12}
T9, T10, T11, T12, T13}
function ODESolutionAdjoint(ȳ)
(ȳ, ntuple(_ -> nothing, length(args))...)
end

ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12}(u, args...),
ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13}(u, args...),

Check warning on line 155 in ext/SciMLBaseZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SciMLBaseZygoteExt.jl#L155

Added line #L155 was not covered by tests
ODESolutionAdjoint
end

Expand Down

0 comments on commit 1e62b87

Please sign in to comment.