Skip to content

Commit

Permalink
Add ReverseDiff
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas committed Sep 1, 2024
1 parent c3acb74 commit 00cc8c6
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 1 deletion.
4 changes: 4 additions & 0 deletions ext/ArrayInterfaceReverseDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,8 @@ function ArrayInterface.aos_to_soa(x::AbstractArray{<:ReverseDiff.TrackedReal, N
end
end

function ArrayInterface.restructure(x::Array, y::ReverseDiff.TrackedArray)
reshape(y, Base.size(x)...)
end

end # module
11 changes: 10 additions & 1 deletion test/ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,13 @@ y = Tracker.TrackedReal.(rand(2,2))
@test size(ArrayInterface.restructure(x, y)) == (4,)
y = Tracker.TrackedArray(rand(2,2))
@test ArrayInterface.restructure(x, y) isa Tracker.TrackedArray
@test size(ArrayInterface.restructure(x, y)) == (4,)
@test size(ArrayInterface.restructure(x, y)) == (4,)

x = rand(4)
y = ReverseDiff.track(rand(2,2))
@test ArrayInterface.restructure(x, y) isa ReverseDiff.TrackedArray
@test size(ArrayInterface.restructure(x, y)) == (4,)
y = ReverseDiff.track.(rand(2,2))
@test ArrayInterface.restructure(x, y) isa Array
@test eltype(ArrayInterface.restructure(x, y)) <: ReverseDiff.TrackedReal
@test size(ArrayInterface.restructure(x, y)) == (4,)

0 comments on commit 00cc8c6

Please sign in to comment.