Skip to content

Commit

Permalink
[EnzymeTestUtils] Fix batch duplicated return and aliasing return/arg…
Browse files Browse the repository at this point in the history
…ument (#1079)

* Fix batch duplicated return

* Handle aliasing input and outputs
  • Loading branch information
wsmoses authored and michel2323 committed Nov 7, 2023
1 parent 8fa5ab8 commit e0e4393
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 3 deletions.
2 changes: 1 addition & 1 deletion lib/EnzymeTestUtils/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "EnzymeTestUtils"
uuid = "12d8515a-0907-448a-8884-5fe00fdf1c5a"
authors = ["Seth Axen <[email protected]>", "William Moses <[email protected]>", "Valentin Churavy <[email protected]>"]
version = "0.1.0"
version = "0.1.1"

[deps]
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
Expand Down
15 changes: 14 additions & 1 deletion lib/EnzymeTestUtils/src/finite_difference_calls.jl
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,20 @@ function _wrap_reverse_function(f, xs, ignores)
@assert j == length(sigargs) + 1
@assert length(callargs) == length(xs)
@assert length(retargs) == count(!, ignores)
return (deepcopy(f)(callargs...), retargs...)

# if an arg and a return alias, do not consider the contribution from the arg as returned here,
# it will already be taken into account. This is implemented using the deepcopy_internal, which
# will add all objects inside the return into the dict `zeros`.
zeros = IdDict()
origRet = Base.deepcopy_internal(deepcopy(f)(callargs...), zeros)

# we will now explicitly zero all objects returned, and replace any of the args with this
# zero, if the input and output alias.
for k in keys(zeros)
zeros[k] = zero_tangent(k)
end

return (origRet, Base.deepcopy_internal(retargs, zeros)...)
end
return fnew
end
8 changes: 7 additions & 1 deletion lib/EnzymeTestUtils/src/test_reverse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,13 @@ function test_reverse(
else
# if there's a shadow result, then we need to set it to our random adjoint
if !(shadow_result === nothing)
map_fields_recursive(copyto!, shadow_result, ȳ)
if !_any_batch_duplicated(map(typeof, activities)...)
map_fields_recursive(copyto!, shadow_result, ȳ)
else
for (sr, dy) in zip(shadow_result, ȳ)
map_fields_recursive(copyto!, sr, dy)
end
end
end
dx_ad = only(reverse(c_act, activities..., tape))
end
Expand Down

0 comments on commit e0e4393

Please sign in to comment.