diff --git a/ext/EnzymeExt.jl b/ext/EnzymeExt.jl index d293e54b..7cf7ef73 100644 --- a/ext/EnzymeExt.jl +++ b/ext/EnzymeExt.jl @@ -44,12 +44,12 @@ EnzymeRules.inactive(::Type{StaticSize}, x...) = nothing # https://github.com/EnzymeAD/Enzyme.jl/issues/1516 # On the CPU `autodiff_deferred` can deadlock. # Hence a specialized CPU version -function cpu_fwd(config, ctx, f, args...) +function cpu_fwd(ctx, config, f, args...) EnzymeCore.autodiff(EnzymeCore.set_runtime_activity(Forward, config), Const(f), Const{Nothing}, Const(ctx), args...) return nothing end -function gpu_fwd(ctx, f, args...) +function gpu_fwd(ctx, config, f, args...) EnzymeCore.autodiff_deferred(EnzymeCore.set_runtime_activity(Forward, config), Const(f), Const{Nothing}, Const(ctx), args...) return nothing end @@ -66,7 +66,7 @@ function EnzymeRules.forward( f = kernel.f fwd_kernel = similar(config, kernel, cpu_fwd) - fwd_kernel(f, args...; ndrange, workgroupsize) + fwd_kernel(config, f, args...; ndrange, workgroupsize) end function EnzymeRules.forward( @@ -81,7 +81,7 @@ function EnzymeRules.forward( f = kernel.f fwd_kernel = similar(config, kernel, gpu_fwd) - fwd_kernel(f, args...; ndrange, workgroupsize) + fwd_kernel(config, f, args...; ndrange, workgroupsize) end _enzyme_mkcontext(kernel::Kernel{CPU}, ndrange, iterspace, dynamic) =