Skip to content

Commit

Permalink
Add Enzyme GPU support
Browse files Browse the repository at this point in the history
Still missing synchronize
  • Loading branch information
michel2323 committed Mar 28, 2024
1 parent 6aee730 commit 8d4e7f8
Show file tree
Hide file tree
Showing 8 changed files with 210 additions and 77 deletions.
1 change: 1 addition & 0 deletions .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ steps:
version:
- "1.8"
- "1.9"
- "1.10"
plugins:
- JuliaCI/julia#v1:
version: "{{matrix.version}}"
Expand Down
257 changes: 186 additions & 71 deletions ext/EnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@ module EnzymeExt
using ..EnzymeCore
using ..EnzymeCore.EnzymeRules
end
import KernelAbstractions: Kernel, StaticSize, launch_config, __groupsize, __groupindex, blocks, mkcontext, CompilerMetadata, CPU
using KernelAbstractions
import KernelAbstractions: Kernel, StaticSize, launch_config, allocate,
blocks, mkcontext, CompilerMetadata, CPU, GPU, argconvert,
supports_enzyme, __fake_compiler_job, backend,
__index_Group_Cartesian, __index_Global_Linear

# StaticSize is a compile-time size descriptor; mark its constructor inactive so
# Enzyme never tries to differentiate through it.
EnzymeRules.inactive(::Type{StaticSize}, x...) = nothing

Expand All @@ -15,55 +19,188 @@ module EnzymeExt
return nothing
end

function aug_fwd(ctx, f::FT, ::Val{ModifiedBetween}, subtape, args...) where {ModifiedBetween, FT}
forward, reverse = EnzymeCore.autodiff_deferred_thunk(ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)), Const{Core.Typeof(f)}, Const, Const{Core.Typeof(ctx)}, map(Core.Typeof, args)...)
subtape[__groupindex(ctx)] = forward(Const(f), Const(ctx), args...)[1]
# Forward-mode custom rule for a KernelAbstractions kernel launch.
#
# Differentiation is performed by launching a sibling kernel built from `fwd`
# (defined elsewhere in this file) via `similar`, forwarding the original
# launch configuration (`ndrange`, `workgroupsize`) unchanged.
function EnzymeRules.forward(func::Const{<:Kernel}, ::Type{Const{Nothing}}, args...; ndrange=nothing, workgroupsize=nothing)
    kernel = func.val
    f = kernel.f
    # Same backend and launch setup as `kernel`, but with `fwd` as the body.
    fwd_kernel = similar(kernel, fwd)

    fwd_kernel(f, args...; ndrange, workgroupsize)
end

# Build a kernel context for the CPU backend. The CPU `mkcontext` needs a
# concrete block; the first block of the iteration space is used here
# (presumably only the context's *type* matters to the caller — TODO confirm).
function _enzyme_mkcontext(kernel::Kernel{CPU}, ndrange, iterspace, dynamic)
    block = first(blocks(iterspace))
    return mkcontext(kernel, block, ndrange, iterspace, dynamic)
end

# GPU variant: the GPU `mkcontext` takes neither a host block index nor the
# `dynamic` flag, so those arguments are accepted (for a uniform call site)
# and ignored.
_enzyme_mkcontext(kernel::Kernel{<:GPU}, ndrange, iterspace, dynamic) =
    mkcontext(kernel, ndrange, iterspace)

# CPU augmented return: the tape tuple's component types are known here, so
# they are encoded concretely in the AugmentedReturn tape type parameter.
function _augmented_return(::Kernel{CPU}, subtape, arg_refs, tape_type)
    return AugmentedReturn{Nothing, Nothing, Tuple{Array, typeof(arg_refs), typeof(tape_type)}}(
        nothing, nothing, (subtape, arg_refs, tape_type)
    )
end

# GPU augmented return: the backend-specific subtape array type is not known
# here, so the tape slot is widened to `Any` to keep the rule's return type
# fixed across backends.
function _augmented_return(::Kernel{<:GPU}, subtape, arg_refs, tape_type)
    return AugmentedReturn{Nothing, Nothing, Any}(
        nothing, nothing, (subtape, arg_refs, tape_type)
    )
end

# CPU path: compute the Enzyme tape type for the split reverse pass and build
# the augmented-forward kernel. Returns (TapeType, subtape, aug_kernel).
function _create_tape_kernel(
    kernel::Kernel{CPU}, ModifiedBetween,
    FT, ctxTy, ndrange, iterspace, args2...
)
    TapeType = EnzymeCore.tape_type(
        ReverseSplitModified(ReverseSplitWithPrimal, ModifiedBetween),
        FT, Const, Const{ctxTy}, map(Core.Typeof, args2)...
    )
    # One tape entry per block: on the CPU the differentiated body runs once
    # per workgroup (see cpu_aug_fwd).
    subtape = Array{TapeType}(undef, size(blocks(iterspace)))
    aug_kernel = similar(kernel, cpu_aug_fwd)
    return TapeType, subtape, aug_kernel
end

# GPU path: compute the Enzyme tape type via a backend-provided fake compiler
# job, allocate one tape slot per thread, and build the augmented-forward
# kernel. Returns (TapeType, subtape, aug_kernel).
function _create_tape_kernel(
    kernel::Kernel{<:GPU}, ModifiedBetween,
    FT, ctxTy, ndrange, iterspace, args2...
)
    # For peeking at the TapeType we need to first construct a correct compilation job
    # this requires the use of the device side representation of arguments.
    # So we convert the arguments here, this is a bit wasteful since the `aug_kernel` call
    # will later do the same.
    # `map` over the args tuple keeps the result a tuple in a type-stable way,
    # unlike splatting a generator into a tuple literal.
    dev_args2 = map(a -> argconvert(kernel, a), args2)
    dev_TT = map(Core.Typeof, dev_args2)

    job = __fake_compiler_job(backend(kernel))
    TapeType = EnzymeCore.tape_type(
        job, ReverseSplitModified(ReverseSplitWithPrimal, ModifiedBetween),
        FT, Const, Const{ctxTy}, dev_TT...
    )

    # Allocate per thread
    subtape = allocate(backend(kernel), TapeType, prod(ndrange))

    aug_kernel = similar(kernel, gpu_aug_fwd)
    return TapeType, subtape, aug_kernel
end

# Select the reverse-pass kernel body matching the backend.
_create_rev_kernel(kernel::Kernel{CPU}) = similar(kernel, cpu_rev)
_create_rev_kernel(kernel::Kernel{<:GPU}) = similar(kernel, gpu_rev)

# Augmented forward pass, CPU backend.
#
# Builds the split-mode forward thunk for the kernel body `f` and runs it,
# storing the returned tape in `subtape` at this workgroup's index so the
# reverse pass (`cpu_rev`) can replay it with the same thunk configuration.
function cpu_aug_fwd(
    ctx, f::FT, ::Val{ModifiedBetween}, subtape, ::Val{TapeType}, args...
) where {ModifiedBetween, FT, TapeType}
    # A2 = Const{Nothing} -- since f->Nothing
    forward, _ = EnzymeCore.autodiff_deferred_thunk(
        ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)), TapeType,
        Const{Core.Typeof(f)}, Const, Const{Nothing},
        Const{Core.Typeof(ctx)}, map(Core.Typeof, args)...
    )

    # On the CPU: F is a per block function
    # On the CPU: subtape::Vector{Vector}
    I = __index_Group_Cartesian(ctx, #=fake=#CartesianIndex(1,1))
    # forward(...) returns (tape, primal, shadow); keep only the tape.
    subtape[I] = forward(Const(f), Const(ctx), args...)[1]
    return nothing
end

function rev(ctx, f::FT, ::Val{ModifiedBetween}, subtape, args...) where {ModifiedBetween, FT}
forward, reverse = EnzymeCore.autodiff_deferred_thunk(ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)), Const{Core.Typeof(f)}, Const, Const{Core.Typeof(ctx)}, map(Core.Typeof, args)...)
tp = subtape[__groupindex(ctx)]
# Reverse pass, CPU backend.
#
# Rebuilds the split-mode reverse thunk with the same configuration as
# `cpu_aug_fwd`, loads this workgroup's tape from `subtape`, and replays it.
function cpu_rev(
    ctx, f::FT, ::Val{ModifiedBetween}, subtape, ::Val{TapeType}, args...
) where {ModifiedBetween, FT, TapeType}
    _, reverse = EnzymeCore.autodiff_deferred_thunk(
        ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)), TapeType,
        Const{Core.Typeof(f)}, Const, Const{Nothing},
        Const{Core.Typeof(ctx)}, map(Core.Typeof, args)...
    )
    # Same per-block index the forward pass used when storing the tape.
    I = __index_Group_Cartesian(ctx, #=fake=#CartesianIndex(1,1))
    tp = subtape[I]
    reverse(Const(f), Const(ctx), args..., tp)
    return nothing
end

function EnzymeRules.forward(func::Const{<:Kernel}, ::Type{Const{Nothing}}, args...; ndrange=nothing, workgroupsize=nothing)
function EnzymeRules.reverse(config::Config, func::Const{<:Kernel}, ::Type{<:EnzymeCore.Annotation}, tape, args::Vararg{Any, N}; ndrange=nothing, workgroupsize=nothing) where N
subtape, arg_refs, tape_type = tape

args2 = ntuple(Val(N)) do i
Base.@_inline_meta
if args[i] isa Active
Duplicated(Ref(args[i].val), arg_refs[i])
else
args[i]
end
end

kernel = func.val
f = kernel.f
fwd_kernel = similar(kernel, fwd)

fwd_kernel(f, args...; ndrange, workgroupsize)
end
tup = Val(ntuple(Val(N)) do i
Base.@_inline_meta
args[i] isa Active
end)
f = make_active_byref(f, tup)

@inline function make_active_byref(f::F, ::Val{ActiveTys}) where {F, ActiveTys}
if !any(ActiveTys)
return f
end
function inact(ctx, args2::Vararg{Any, N}) where N
args3 = ntuple(Val(N)) do i
Base.@_inline_meta
if ActiveTys[i]
args2[i][]
else
args2[i]
end
ModifiedBetween = Val((overwritten(config)[1], false, overwritten(config)[2:end]...))

rev_kernel = _create_rev_kernel(kernel)
rev_kernel(f, ModifiedBetween, subtape, Val(tape_type), args2...; ndrange, workgroupsize)
res = ntuple(Val(N)) do i
Base.@_inline_meta
if args[i] isa Active
arg_refs[i][]
else
nothing
end
f(ctx, args3...)
end
return inact
return res
end

function EnzymeRules.augmented_primal(config::Config, func::Const{<:Kernel{CPU}}, ::Type{Const{Nothing}}, args::Vararg{Any, N}; ndrange=nothing, workgroupsize=nothing) where N
# GPU support

# Augmented forward pass, GPU backends.
#
# Same structure as `cpu_aug_fwd`, but the differentiated body runs once per
# thread, so the tape is stored at the global linear index.
function gpu_aug_fwd(
    ctx, f::FT, ::Val{ModifiedBetween}, subtape, ::Val{TapeType}, args...
) where {ModifiedBetween, FT, TapeType}
    # A2 = Const{Nothing} -- since f->Nothing
    forward, _ = EnzymeCore.autodiff_deferred_thunk(
        ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)), TapeType,
        Const{Core.Typeof(f)}, Const, Const{Nothing},
        Const{Core.Typeof(ctx)}, map(Core.Typeof, args)...
    )

    # On the GPU: F is a per thread function
    # On the GPU: subtape::Vector
    I = __index_Global_Linear(ctx)
    # forward(...) returns (tape, primal, shadow); keep only the tape.
    subtape[I] = forward(Const(f), Const(ctx), args...)[1]
    return nothing
end

# Reverse pass, GPU backends.
#
# Rebuilds the reverse thunk with the same configuration as `gpu_aug_fwd`,
# loads this thread's tape, and replays it.
function gpu_rev(
    ctx, f::FT, ::Val{ModifiedBetween}, subtape, ::Val{TapeType}, args...
) where {ModifiedBetween, FT, TapeType}
    # XXX: TapeType and A2 as args to autodiff_deferred_thunk
    _, reverse = EnzymeCore.autodiff_deferred_thunk(
        ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)), TapeType,
        Const{Core.Typeof(f)}, Const, Const{Nothing},
        Const{Core.Typeof(ctx)}, map(Core.Typeof, args)...
    )
    # Same per-thread index the forward pass used when storing the tape.
    I = __index_Global_Linear(ctx)
    tp = subtape[I]
    reverse(Const(f), Const(ctx), args..., tp)
    return nothing
end

function EnzymeRules.augmented_primal(
config::Config, func::Const{<:Kernel},
::Type{Const{Nothing}}, args::Vararg{Any, N}; ndrange=nothing, workgroupsize=nothing
) where N
kernel = func.val
if !supports_enzyme(backend(kernel))
error("KernelAbstractions backend does not support Enzyme")
end
f = kernel.f

ndrange, workgroupsize, iterspace, dynamic = launch_config(kernel, ndrange, workgroupsize)
block = first(blocks(iterspace))

ctx = mkcontext(kernel, block, ndrange, iterspace, dynamic)
ctx = _enzyme_mkcontext(kernel, ndrange, iterspace, dynamic)
ctxTy = Core.Typeof(ctx) # CompilerMetadata{ndrange(kernel), Core.Typeof(dynamic)}

# TODO autodiff_deferred on the func.val
ModifiedBetween = Val((overwritten(config)[1], false, overwritten(config)[2:end]...))

Expand Down Expand Up @@ -91,56 +228,34 @@ module EnzymeExt
end
end

# TODO in KA backends like CUDAKernels, etc have a version with a parent job type
TapeType = EnzymeCore.tape_type(ReverseSplitModified(ReverseSplitWithPrimal, ModifiedBetween), FT, Const, Const{ctxTy}, map(Core.Typeof, args2)...)


subtape = Array{TapeType}(undef, size(blocks(iterspace)))

aug_kernel = similar(kernel, aug_fwd)

aug_kernel(f, ModifiedBetween, subtape, args2...; ndrange, workgroupsize)
TapeType, subtape, aug_kernel = _create_tape_kernel(
kernel, ModifiedBetween, FT, ctxTy, ndrange, iterspace, args2...
)
aug_kernel(f, ModifiedBetween, subtape, Val(TapeType), args2...; ndrange, workgroupsize)
KernelAbstractions.synchronize(backend(kernel))

# TODO the fact that ctxTy is type unstable means this is all type unstable.
# Since custom rules require a fixed return type, explicitly cast to Any, rather
# than returning a AugmentedReturn{Nothing, Nothing, T} where T.

res = AugmentedReturn{Nothing, Nothing, Tuple{Array, typeof(arg_refs)}}(nothing, nothing, (subtape, arg_refs))
return res
return _augmented_return(kernel, subtape, arg_refs, TapeType)
end

function EnzymeRules.reverse(config::Config, func::Const{<:Kernel}, ::Type{<:EnzymeCore.Annotation}, tape, args::Vararg{Any, N}; ndrange=nothing, workgroupsize=nothing) where N
subtape, arg_refs = tape

args2 = ntuple(Val(N)) do i
Base.@_inline_meta
if args[i] isa Active
Duplicated(Ref(args[i].val), arg_refs[i])
else
args[i]
end
end

kernel = func.val
f = kernel.f

tup = Val(ntuple(Val(N)) do i
Base.@_inline_meta
args[i] isa Active
end)
f = make_active_byref(f, tup)

ModifiedBetween = Val((overwritten(config)[1], false, overwritten(config)[2:end]...))

rev_kernel = similar(func.val, rev)
rev_kernel(f, ModifiedBetween, subtape, args2...; ndrange, workgroupsize)
return ntuple(Val(N)) do i
@inline function make_active_byref(f::F, ::Val{ActiveTys}) where {F, ActiveTys}
if !any(ActiveTys)
return f
end
function inact(ctx, args2::Vararg{Any, N}) where N
args3 = ntuple(Val(N)) do i
Base.@_inline_meta
if args[i] isa Active
arg_refs[i][]
if ActiveTys[i]
args2[i][]
else
nothing
args2[i]
end
end
f(ctx, args3...)
end
return inact
end

end
12 changes: 12 additions & 0 deletions src/KernelAbstractions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -698,6 +698,18 @@ end
__size(args::Tuple) = Tuple{args...}
__size(i::Int) = Tuple{i}

"""
    argconvert(::Kernel, arg)

Convert arguments to the device side representation.

Backends override this for their kernel type; the fallback errors.
"""
argconvert(k::Kernel{T}, arg) where T =
    error("Don't know how to convert arguments for Kernel{$T}")

# Enzyme support
# Backends opt in to Enzyme differentiation by overriding this to return `true`.
supports_enzyme(::Backend) = false
# Hook implemented by backends that support Enzyme; used by ext/EnzymeExt.jl to
# construct a compilation job for computing device-side tape types.
function __fake_compiler_job end

###
# Extras
# - LoopInfo
Expand Down
4 changes: 3 additions & 1 deletion src/cpu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -191,4 +191,6 @@ end
end

# Argument conversion
KernelAbstractions.argconvert(k::Kernel{CPU}, arg) = arg
# On the CPU the host representation already is the device representation.
argconvert(k::Kernel{CPU}, arg) = arg

# The CPU backend supports Enzyme differentiation (see ext/EnzymeExt.jl).
supports_enzyme(::CPU) = true
3 changes: 0 additions & 3 deletions src/reflection.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
import InteractiveUtils
export @ka_code_typed, @ka_code_llvm

argconvert(k::Kernel{T}, arg) where T =
error("Don't know how to convert arguments for Kernel{$T}")

using UUIDs
const Cthulhu = Base.PkgId(UUID("f68482b8-f384-11e8-15f7-abe071a5a75f"), "Cthulhu")

Expand Down
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand Down
5 changes: 3 additions & 2 deletions test/extensions/enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ end
function square_caller(A, backend)
kernel = square!(backend)
kernel(A, ndrange=size(A))
KernelAbstractions.synchronize(backend)
end


Expand All @@ -22,7 +21,6 @@ end
function mul_caller(A, B, backend)
kernel = mul!(backend)
kernel(A, B, ndrange=size(A))
KernelAbstractions.synchronize(backend)
end

function enzyme_testsuite(backend, ArrayT, supports_reverse=true)
Expand All @@ -36,13 +34,15 @@ function enzyme_testsuite(backend, ArrayT, supports_reverse=true)
dA .= 1

Enzyme.autodiff(Reverse, square_caller, Duplicated(A, dA), Const(backend()))
KernelAbstractions.synchronize(backend())
@test all(dA .≈ (2:2:128))


A .= (1:1:64)
dA .= 1

_, dB, _ = Enzyme.autodiff(Reverse, mul_caller, Duplicated(A, dA), Active(1.2), Const(backend()))[1]
KernelAbstractions.synchronize(backend())

@test all(dA .≈ 1.2)
@test dB ≈ sum(1:1:64)
Expand All @@ -52,6 +52,7 @@ function enzyme_testsuite(backend, ArrayT, supports_reverse=true)
dA .= 1

Enzyme.autodiff(Forward, square_caller, Duplicated(A, dA), Const(backend()))
KernelAbstractions.synchronize(backend())
@test all(dA .≈ 2:2:128)

end
Expand Down
Loading

0 comments on commit 8d4e7f8

Please sign in to comment.