Skip to content

Commit

Permalink
Synchronize rule
Browse files Browse the repository at this point in the history
  • Loading branch information
michel2323 committed May 31, 2024
1 parent 5bb7651 commit ede2f76
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 8 deletions.
34 changes: 28 additions & 6 deletions ext/EnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,16 @@ module EnzymeExt
using ..EnzymeCore.EnzymeRules
end

function EnzymeCore.compiler_job_from_backend(b::Backend, @nospecialize(F::Type), @nospecialize(TT::Type))
error("EnzymeCore.compiler_job_from_backend is not yet implemented for $(typeof(b)), please file an issue.")
end
# TODO: Remove
using KernelAbstractions
import KernelAbstractions: Kernel, StaticSize, launch_config, allocate,
blocks, mkcontext, CompilerMetadata, CPU, GPU, argconvert,
supports_enzyme, __fake_compiler_job, backend,
__index_Group_Cartesian, __index_Global_Linear,
__groupsize, __groupindex, Backend
__groupsize, __groupindex, Backend, synchronize

function EnzymeCore.compiler_job_from_backend(b::Backend, @nospecialize(F::Type), @nospecialize(TT::Type))
error("EnzymeCore.compiler_job_from_backend is not yet implemented for $(typeof(b)), please file an issue.")
end

EnzymeRules.inactive(::Type{StaticSize}, x...) = nothing

Expand Down Expand Up @@ -156,6 +156,7 @@ module EnzymeExt
nothing
end
end
synchronize(backend(kernel))
return res
end

Expand Down Expand Up @@ -231,7 +232,6 @@ module EnzymeExt
kernel, ModifiedBetween, FT, ctxTy, ndrange, iterspace, args2...
)
aug_kernel(f, ModifiedBetween, subtape, Val(TapeType), args2...; ndrange, workgroupsize)
KernelAbstractions.synchronize(backend(kernel))

# TODO the fact that ctxTy is type unstable means this is all type unstable.
# Since custom rules require a fixed return type, explicitly cast to Any, rather
Expand All @@ -257,4 +257,26 @@ module EnzymeExt
return inact
end

# Synchronize rules
# TODO: Right now we do the synchronization as part of the kernel launch in the augmented primal
# and reverse rules. This is not ideal, as we would want to launch the kernel in the reverse
# synchronize rule and then synchronize where the launch was. However, with the current
# kernel semantics this ensures correctness for now.
function EnzymeRules.augmented_primal(
config::Config,
func::Const{typeof(synchronize)},
::Type{Const{Nothing}},
backend::T
) where T <: EnzymeCore.Annotation
synchronize(backend.val)
return AugmentedReturn(
nothing, nothing, nothing
)
end

function EnzymeRules.reverse(config::Config, func::Const{typeof(synchronize)}, ::Type{Const{Nothing}}, tape, backend)
# noop for now
return (nothing,)
end

end
3 changes: 1 addition & 2 deletions test/extensions/enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ end
function mul_caller(A, B, backend)
kernel = mul!(backend)
kernel(A, B, ndrange=size(A))
KernelAbstractions.synchronize(backend)
end

function enzyme_testsuite(backend, ArrayT, supports_reverse=true)
Expand All @@ -34,15 +35,13 @@ function enzyme_testsuite(backend, ArrayT, supports_reverse=true)
dA .= 1

Enzyme.autodiff(Reverse, square_caller, Duplicated(A, dA), Const(backend()))
KernelAbstractions.synchronize(backend())
@test all(dA .≈ (2:2:128))


A .= (1:1:64)
dA .= 1

_, dB, _ = Enzyme.autodiff(Reverse, mul_caller, Duplicated(A, dA), Active(1.2), Const(backend()))[1]
KernelAbstractions.synchronize(backend())

@test all(dA .≈ 1.2)
@test dB sum(1:1:64)
Expand Down

0 comments on commit ede2f76

Please sign in to comment.