diff --git a/ext/ArrayInterfaceReverseDiffExt.jl b/ext/ArrayInterfaceReverseDiffExt.jl index ed544e48..37176824 100644 --- a/ext/ArrayInterfaceReverseDiffExt.jl +++ b/ext/ArrayInterfaceReverseDiffExt.jl @@ -15,4 +15,8 @@ function ArrayInterface.aos_to_soa(x::AbstractArray{<:ReverseDiff.TrackedReal, N end end +function ArrayInterface.restructure(x::Array, y::ReverseDiff.TrackedArray) + reshape(y, Base.size(x)...) +end + end # module diff --git a/ext/ArrayInterfaceTrackerExt.jl b/ext/ArrayInterfaceTrackerExt.jl index d2d4e2ce..5723d9f1 100644 --- a/ext/ArrayInterfaceTrackerExt.jl +++ b/ext/ArrayInterfaceTrackerExt.jl @@ -9,4 +9,11 @@ ArrayInterface.can_setindex(::Type{<:Tracker.TrackedArray}) = false ArrayInterface.fast_scalar_indexing(::Type{<:Tracker.TrackedArray}) = false ArrayInterface.aos_to_soa(x::AbstractArray{<:Tracker.TrackedReal,N}) where {N} = Tracker.collect(x) +function ArrayInterface.restructure(x::Array, y::Tracker.TrackedArray) + reshape(y, Base.size(x)...) +end +function ArrayInterface.restructure(x::Array, y::Array{<:Tracker.TrackedReal}) + reshape(y, Base.size(x)...) +end + end # module diff --git a/test/ad.jl b/test/ad.jl index 7c29c8dd..bc3c3dd8 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -18,3 +18,21 @@ x = Tracker.TrackedArray([4.0,4.0]) x = reduce(vcat, Tracker.TrackedArray([4.0,4.0])) x = [x[1],x[2]] @test ArrayInterface.aos_to_soa(x) isa Tracker.TrackedArray + +x = rand(4) +y = Tracker.TrackedReal.(rand(2,2)) +@test ArrayInterface.restructure(x, y) isa Array +@test eltype(ArrayInterface.restructure(x, y)) <: Tracker.TrackedReal +@test size(ArrayInterface.restructure(x, y)) == (4,) +y = Tracker.TrackedArray(rand(2,2)) +@test ArrayInterface.restructure(x, y) isa Tracker.TrackedArray +@test size(ArrayInterface.restructure(x, y)) == (4,) + +x = rand(4) +y = ReverseDiff.track(rand(2,2)) +@test ArrayInterface.restructure(x, y) isa ReverseDiff.TrackedArray +@test size(ArrayInterface.restructure(x, y)) == (4,) +y = ReverseDiff.track.(rand(2,2)) +@test ArrayInterface.restructure(x, y) isa Array +@test eltype(ArrayInterface.restructure(x, y)) <: ReverseDiff.TrackedReal +@test size(ArrayInterface.restructure(x, y)) == (4,)