Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[EnzymeTestUtils] Fix batch duplicated return and aliasing return/argument #1079

Merged
merged 2 commits into from
Sep 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lib/EnzymeTestUtils/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "EnzymeTestUtils"
uuid = "12d8515a-0907-448a-8884-5fe00fdf1c5a"
authors = ["Seth Axen <[email protected]>", "William Moses <[email protected]>", "Valentin Churavy <[email protected]>"]
version = "0.1.0"
version = "0.1.1"

[deps]
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
Expand Down
15 changes: 14 additions & 1 deletion lib/EnzymeTestUtils/src/finite_difference_calls.jl
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,20 @@ function _wrap_reverse_function(f, xs, ignores)
@assert j == length(sigargs) + 1
@assert length(callargs) == length(xs)
@assert length(retargs) == count(!, ignores)
return (deepcopy(f)(callargs...), retargs...)

# if an arg and a return alias, do not consider the contribution from the arg as returned here,
# it will already be taken into account. This is implemented using the deepcopy_internal, which
# will add all objects inside the return into the dict `zeros`.
zeros = IdDict()
origRet = Base.deepcopy_internal(deepcopy(f)(callargs...), zeros)

# we will now explicitly zero all objects returned, and replace any of the args with this
# zero, if the input and output alias.
for k in keys(zeros)
zeros[k] = zero_tangent(k)
end

return (origRet, Base.deepcopy_internal(retargs, zeros)...)
end
return fnew
end
8 changes: 7 additions & 1 deletion lib/EnzymeTestUtils/src/test_reverse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,13 @@ function test_reverse(
else
# if there's a shadow result, then we need to set it to our random adjoint
if !(shadow_result === nothing)
map_fields_recursive(copyto!, shadow_result, ȳ)
if !_any_batch_duplicated(map(typeof, activities)...)
map_fields_recursive(copyto!, shadow_result, ȳ)
else
for (sr, dy) in zip(shadow_result, ȳ)
map_fields_recursive(copyto!, sr, dy)
end
end
end
dx_ad = only(reverse(c_act, activities..., tape))
end
Expand Down
Loading