diff --git a/src/ParallelKernel/EnzymeExt/autodiff_gpu.jl b/src/ParallelKernel/EnzymeExt/autodiff_gpu.jl
index e4ae9894..319c9ede 100644
--- a/src/ParallelKernel/EnzymeExt/autodiff_gpu.jl
+++ b/src/ParallelKernel/EnzymeExt/autodiff_gpu.jl
@@ -8,12 +8,29 @@ import Enzyme
 #     end
 # end
 
+"""
+    promote_to_const(args...)
+
+Return a tuple of `args` in which every element that is not already an
+`Enzyme.Annotation` is wrapped in `Enzyme.Const`; annotated elements are passed
+through unchanged.
+"""
+function promote_to_const(args::Vararg{Any,N}) where {N}
+    # `Val(N)` passes the argument count to `ntuple` as a type parameter.
+    return ntuple(Val(N)) do i
+        @inbounds arg = args[i] # `i` is in 1:N by construction of `ntuple`
+        arg isa Enzyme.Annotation ? arg : Enzyme.Const(arg)
+    end
+end
+
 function ParallelStencil.ParallelKernel.AD.autodiff_deferred!(arg, args...) # NOTE: minimal specialization is used to avoid overwriting the default method
+    args = promote_to_const(args...)
     Enzyme.autodiff_deferred(arg, args...)
     return
 end
 
 function ParallelStencil.ParallelKernel.AD.autodiff_deferred_thunk!(arg, args...) # NOTE: minimal specialization is used to avoid overwriting the default method
+    args = promote_to_const(args...)
     Enzyme.autodiff_deferred_thunk(arg, args...)
     return
 end
diff --git a/test/ParallelKernel/test_parallel.jl b/test/ParallelKernel/test_parallel.jl
index cb07fff6..a97076cf 100644
--- a/test/ParallelKernel/test_parallel.jl
+++ b/test/ParallelKernel/test_parallel.jl
@@ -114,10 +114,10 @@ import Enzyme
                     return
                 end
                 # Enzyme requires explicit argument annotation, PS injects arguments without annotation
-                # @parallel configcall=f!(A, B, a) AD.autodiff_deferred!(Enzyme.Reverse, Const(f!), DuplicatedNoNeed(A, Ā), DuplicatedNoNeed(B, B̄), Const(a))
+                @parallel configcall=f!(A, B, a) AD.autodiff_deferred!(Enzyme.Reverse, Const(f!), DuplicatedNoNeed(A, Ā), DuplicatedNoNeed(B, B̄), Const(a))
                 Enzyme.autodiff_deferred(Enzyme.Reverse, Const(g!), DuplicatedNoNeed(A_ref, Ā_ref), DuplicatedNoNeed(B_ref, B̄_ref), Const(a))
-                # @test Array(Ā) ≈ Ā_ref
-                # @test Array(B̄) ≈ B̄_ref
+                @test Array(Ā) ≈ Ā_ref
+                @test Array(B̄) ≈ B̄_ref
             end
         end
         @testset "@parallel_indices" begin