diff --git a/lib/EnzymeTestUtils/Project.toml b/lib/EnzymeTestUtils/Project.toml index 9ede7c5d7d..f16e3a2383 100644 --- a/lib/EnzymeTestUtils/Project.toml +++ b/lib/EnzymeTestUtils/Project.toml @@ -1,7 +1,7 @@ name = "EnzymeTestUtils" uuid = "12d8515a-0907-448a-8884-5fe00fdf1c5a" authors = ["Seth Axen ", "William Moses ", "Valentin Churavy "] -version = "0.1.0" +version = "0.1.1" [deps] ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" diff --git a/lib/EnzymeTestUtils/src/finite_difference_calls.jl b/lib/EnzymeTestUtils/src/finite_difference_calls.jl index 67e86901c3..f4a00d79b5 100644 --- a/lib/EnzymeTestUtils/src/finite_difference_calls.jl +++ b/lib/EnzymeTestUtils/src/finite_difference_calls.jl @@ -156,7 +156,20 @@ function _wrap_reverse_function(f, xs, ignores) @assert j == length(sigargs) + 1 @assert length(callargs) == length(xs) @assert length(retargs) == count(!, ignores) - return (deepcopy(f)(callargs...), retargs...) + + # if an arg and a return alias, do not consider the contribution from the arg as returned here, + # it will already be taken into account. This is implemented using the deepcopy_internal, which + # will add all objects inside the return into the dict `zeros`. + zeros = IdDict() + origRet = Base.deepcopy_internal(deepcopy(f)(callargs...), zeros) + + # we will now explicitly zero all objects returned, and replace any of the args with this + # zero, if the input and output alias. + for k in keys(zeros) + zeros[k] = zero_tangent(k) + end + + return (origRet, Base.deepcopy_internal(retargs, zeros)...) end return fnew end diff --git a/lib/EnzymeTestUtils/src/test_reverse.jl b/lib/EnzymeTestUtils/src/test_reverse.jl index d0fde8db1a..845a7a90b0 100644 --- a/lib/EnzymeTestUtils/src/test_reverse.jl +++ b/lib/EnzymeTestUtils/src/test_reverse.jl @@ -94,7 +94,13 @@ function test_reverse( else # if there's a shadow result, then we need to set it to our random adjoint if !(shadow_result === nothing) - map_fields_recursive(copyto!, shadow_result, ȳ) + if !_any_batch_duplicated(map(typeof, activities)...) + map_fields_recursive(copyto!, shadow_result, ȳ) + else + for (sr, dy) in zip(shadow_result, ȳ) + map_fields_recursive(copyto!, sr, dy) + end + end end dx_ad = only(reverse(c_act, activities..., tape)) end