Skip to content

Commit

Permalink
wrap args to Const
Browse files Browse the repository at this point in the history
Co-authored-by: Valentin Churavy <[email protected]>
  • Loading branch information
aelligp and vchuravy committed Oct 9, 2024
1 parent 408977d commit 01d4f7b
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 6 deletions.
11 changes: 8 additions & 3 deletions src/ParallelKernel/EnzymeExt/autodiff_gpu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,15 @@ import Enzyme
# end
# end

function promoto_to_const(args...)
# ParallelStencil injects a configuration parameter at the end, for Enzyme we need to wrap that parameter as a Annotation
# for all purposes this ought to be Const. This is not ideal since we might accidentially wrap other parameters the user
# provided as well. This is needed to support @parallel autodiff_deferred(...)
function promote_to_const(args...)
ntuple(length(args)) do i
@inbounds
if !(args[i] isa Enzyme.Annotation)
@inline
if !(args[i] isa Enzyme.Annotation ||
(args[i] isa UnionAll && args[i] <: Enzyme.Annotation) || # Const
(args[i] isa DataType && args[i] <: Enzyme.Annotation)) # Const{Nothing}
return Enzyme.Const(args[i])
else
return args[i]
Expand Down
5 changes: 2 additions & 3 deletions test/ParallelKernel/test_parallel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,8 @@ import Enzyme
end
return
end
# Enzyme requires explicit argument annotation, PS injects arguments without annotation
@parallel configcall=f!(A, B, a) AD.autodiff_deferred!(Enzyme.Reverse, Const(f!), DuplicatedNoNeed(A, Ā), DuplicatedNoNeed(B, B̄), Const(a))
Enzyme.autodiff_deferred(Enzyme.Reverse, Const(g!), DuplicatedNoNeed(A_ref, Ā_ref), DuplicatedNoNeed(B_ref, B̄_ref), Const(a))
@parallel configcall=f!(A, B, a) AD.autodiff_deferred!(Enzyme.Reverse, Const(f!), Const, DuplicatedNoNeed(A, Ā), DuplicatedNoNeed(B, B̄), Const(a))
Enzyme.autodiff_deferred(Enzyme.Reverse, Const(g!),Const, DuplicatedNoNeed(A_ref, Ā_ref), DuplicatedNoNeed(B_ref, B̄_ref), Const(a))
@test Array(Ā) Ā_ref
@test Array(B̄) B̄_ref
end
Expand Down

0 comments on commit 01d4f7b

Please sign in to comment.