diff --git a/src/ParallelKernel/EnzymeExt/autodiff_gpu.jl b/src/ParallelKernel/EnzymeExt/autodiff_gpu.jl index 319c9ede..bde30b07 100644 --- a/src/ParallelKernel/EnzymeExt/autodiff_gpu.jl +++ b/src/ParallelKernel/EnzymeExt/autodiff_gpu.jl @@ -8,10 +8,15 @@ import Enzyme # end # end -function promoto_to_const(args...) +# ParallelStencil injects a configuration parameter at the end, for Enzyme we need to wrap that parameter as a Annotation +# for all purposes this ought to be Const. This is not ideal since we might accidentially wrap other parameters the user +# provided as well. This is needed to support @parallel autodiff_deferred(...) + function promote_to_const(args...) ntuple(length(args)) do i - @inbounds - if !(args[i] isa Enzyme.Annotation) + @inline + if !(args[i] isa Enzyme.Annotation || + (args[i] isa UnionAll && args[i] <: Enzyme.Annotation) || # Const + (args[i] isa DataType && args[i] <: Enzyme.Annotation)) # Const{Nothing} return Enzyme.Const(args[i]) else return args[i] diff --git a/test/ParallelKernel/test_parallel.jl b/test/ParallelKernel/test_parallel.jl index a97076cf..213faa21 100644 --- a/test/ParallelKernel/test_parallel.jl +++ b/test/ParallelKernel/test_parallel.jl @@ -113,9 +113,8 @@ import Enzyme end return end - # Enzyme requires explicit argument annotation, PS injects arguments without annotation - @parallel configcall=f!(A, B, a) AD.autodiff_deferred!(Enzyme.Reverse, Const(f!), DuplicatedNoNeed(A, Ā), DuplicatedNoNeed(B, B̄), Const(a)) - Enzyme.autodiff_deferred(Enzyme.Reverse, Const(g!), DuplicatedNoNeed(A_ref, Ā_ref), DuplicatedNoNeed(B_ref, B̄_ref), Const(a)) + @parallel configcall=f!(A, B, a) AD.autodiff_deferred!(Enzyme.Reverse, Const(f!), Const, DuplicatedNoNeed(A, Ā), DuplicatedNoNeed(B, B̄), Const(a)) + Enzyme.autodiff_deferred(Enzyme.Reverse, Const(g!),Const, DuplicatedNoNeed(A_ref, Ā_ref), DuplicatedNoNeed(B_ref, B̄_ref), Const(a)) @test Array(Ā) ≈ Ā_ref @test Array(B̄) ≈ B̄_ref end