Skip to content

Commit

Permalink
Add CUDAEnzyme extension
Browse files Browse the repository at this point in the history
  • Loading branch information
michel2323 committed Jan 25, 2024
1 parent e0b64a5 commit 8221520
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 61 deletions.
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ version = "0.9.15"
[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Atomix = "a9b6321e-bd34-4604-b9c9-b65b8de01458"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
Expand All @@ -20,9 +19,11 @@ UnsafeAtomicsLLVM = "d80eeb9a-aca5-4d75-85e5-170c8b632249"

[weakdeps]
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"

[extensions]
EnzymeExt = "EnzymeCore"
CUDAEnzymeExt = ["CUDA", "EnzymeCore"]

[compat]
Adapt = "0.4, 1.0, 2.0, 3.0, 4"
Expand Down
70 changes: 70 additions & 0 deletions ext/CUDAEnzymeExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
module CUDAEnzymeExt
if isdefined(Base, :get_extension)
using EnzymeCore
using EnzymeCore.EnzymeRules
else
using ..EnzymeCore
using ..EnzymeCore.EnzymeRules
end
import KernelAbstractions: Kernel, StaticSize, launch_config, __groupsize, __groupindex, blocks, mkcontext, CompilerMetadata, CPU, GPU
using CUDA

function EnzymeRules.augmented_primal(config::Config, func::Const{<:Kernel{CUDABackend}}, type::Type{Const{Nothing}}, args::Vararg{Any, N}; ndrange=nothing, workgroupsize=nothing) where N
println("Custom rule GPU")
kernel = func.val
f = kernel.f
mi = CUDA.methodinstance(typeof(()->return), Tuple{})
job = CUDA.CompilerJob(mi, CUDA.compiler_config(device()))

ndrange, workgroupsize, iterspace, dynamic = launch_config(kernel, ndrange, workgroupsize)
block = first(blocks(iterspace))
ctx = mkcontext(kernel, ndrange, iterspace)
ctxTy = Core.Typeof(ctx) # CompilerMetadata{ndrange(kernel), Core.Typeof(dynamic)}

# TODO autodiff_deferred on the func.val
ModifiedBetween = Val((overwritten(config)[1], false, overwritten(config)[2:end]...))

tup = Val(ntuple(Val(N)) do i
Base.@_inline_meta
args[i] isa Active
end)
f = make_active_byref(f, tup)
FT = Const{Core.Typeof(f)}

arg_refs = ntuple(Val(N)) do i
Base.@_inline_meta
if args[i] isa Active
Ref(EnzymeCore.make_zero(args[i].val))
else
nothing
end
end
args2 = ntuple(Val(N)) do i
Base.@_inline_meta
if args[i] isa Active
Duplicated(Ref(args[i].val), arg_refs[i])
else
args[i]
end
end

# TODO in KA backends like CUDAKernels, etc have a version with a parent job type
TapeType = EnzymeCore.tape_type(job, ReverseSplitModified(ReverseSplitWithPrimal, ModifiedBetween), FT, Const, Const{ctxTy}, map(Core.Typeof, args2)...)
@show TapeType


subtape = Array{TapeType}(undef, size(blocks(iterspace)))

aug_kernel = similar(kernel, aug_fwd)

aug_kernel(f, ModifiedBetween, subtape, args2...; ndrange, workgroupsize)

# TODO the fact that ctxTy is type unstable means this is all type unstable.
# Since custom rules require a fixed return type, explicitly cast to Any, rather
# than returning a AugmentedReturn{Nothing, Nothing, T} where T.

res = AugmentedReturn{Nothing, Nothing, Tuple{Array, typeof(arg_refs)}}(nothing, nothing, (subtape, arg_refs))
return res
end

end
59 changes: 0 additions & 59 deletions ext/EnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ module EnzymeExt
using ..EnzymeCore.EnzymeRules
end
import KernelAbstractions: Kernel, StaticSize, launch_config, __groupsize, __groupindex, blocks, mkcontext, CompilerMetadata, CPU, GPU
using CUDA

EnzymeRules.inactive(::Type{StaticSize}, x...) = nothing

Expand Down Expand Up @@ -112,64 +111,6 @@ module EnzymeExt
return res
end

function EnzymeRules.augmented_primal(config::Config, func::Const{<:Kernel{CUDABackend}}, type::Type{Const{Nothing}}, args::Vararg{Any, N}; ndrange=nothing, workgroupsize=nothing) where N
println("Custom rule GPU")
kernel = func.val
f = kernel.f
mi = CUDA.methodinstance(typeof(()->return), Tuple{})
job = CUDA.CompilerJob(mi, CUDA.compiler_config(device()))

ndrange, workgroupsize, iterspace, dynamic = launch_config(kernel, ndrange, workgroupsize)
block = first(blocks(iterspace))
ctx = mkcontext(kernel, ndrange, iterspace)
ctxTy = Core.Typeof(ctx) # CompilerMetadata{ndrange(kernel), Core.Typeof(dynamic)}

# TODO autodiff_deferred on the func.val
ModifiedBetween = Val((overwritten(config)[1], false, overwritten(config)[2:end]...))

tup = Val(ntuple(Val(N)) do i
Base.@_inline_meta
args[i] isa Active
end)
f = make_active_byref(f, tup)
FT = Const{Core.Typeof(f)}

arg_refs = ntuple(Val(N)) do i
Base.@_inline_meta
if args[i] isa Active
Ref(EnzymeCore.make_zero(args[i].val))
else
nothing
end
end
args2 = ntuple(Val(N)) do i
Base.@_inline_meta
if args[i] isa Active
Duplicated(Ref(args[i].val), arg_refs[i])
else
args[i]
end
end

# TODO in KA backends like CUDAKernels, etc have a version with a parent job type
TapeType = EnzymeCore.tape_type(job, ReverseSplitModified(ReverseSplitWithPrimal, ModifiedBetween), FT, Const, Const{ctxTy}, map(Core.Typeof, args2)...)
@show TapeType


subtape = Array{TapeType}(undef, size(blocks(iterspace)))

aug_kernel = similar(kernel, aug_fwd)

aug_kernel(f, ModifiedBetween, subtape, args2...; ndrange, workgroupsize)

# TODO the fact that ctxTy is type unstable means this is all type unstable.
# Since custom rules require a fixed return type, explicitly cast to Any, rather
# than returning a AugmentedReturn{Nothing, Nothing, T} where T.

res = AugmentedReturn{Nothing, Nothing, Tuple{Array, typeof(arg_refs)}}(nothing, nothing, (subtape, arg_refs))
return res
end

function EnzymeRules.reverse(config::Config, func::Const{<:Kernel}, ::Type{<:EnzymeCore.Annotation}, tape, args::Vararg{Any, N}; ndrange=nothing, workgroupsize=nothing) where N
subtape, arg_refs = tape

Expand Down
1 change: 0 additions & 1 deletion test/reverse_gpu.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
using Test
using Enzyme_jll
using Enzyme
using KernelAbstractions
using CUDA
Expand Down

0 comments on commit 8221520

Please sign in to comment.