diff --git a/Project.toml b/Project.toml index 30e9c014..281e5854 100644 --- a/Project.toml +++ b/Project.toml @@ -6,7 +6,6 @@ version = "0.9.15" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" Atomix = "a9b6321e-bd34-4604-b9c9-b65b8de01458" -CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" @@ -20,9 +19,11 @@ UnsafeAtomicsLLVM = "d80eeb9a-aca5-4d75-85e5-170c8b632249" [weakdeps] EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" +CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" [extensions] EnzymeExt = "EnzymeCore" +CUDAEnzymeExt = ["CUDA", "EnzymeCore"] [compat] Adapt = "0.4, 1.0, 2.0, 3.0, 4" diff --git a/ext/CUDAEnzymeExt.jl b/ext/CUDAEnzymeExt.jl new file mode 100644 index 00000000..ee82e1c3 --- /dev/null +++ b/ext/CUDAEnzymeExt.jl @@ -0,0 +1,70 @@ +module CUDAEnzymeExt + if isdefined(Base, :get_extension) + using EnzymeCore + using EnzymeCore.EnzymeRules + else + using ..EnzymeCore + using ..EnzymeCore.EnzymeRules + end + import KernelAbstractions: Kernel, StaticSize, launch_config, __groupsize, __groupindex, blocks, mkcontext, CompilerMetadata, CPU, GPU + using CUDA + + function EnzymeRules.augmented_primal(config::Config, func::Const{<:Kernel{CUDABackend}}, type::Type{Const{Nothing}}, args::Vararg{Any, N}; ndrange=nothing, workgroupsize=nothing) where N + println("Custom rule GPU") + kernel = func.val + f = kernel.f + mi = CUDA.methodinstance(typeof(()->return), Tuple{}) + job = CUDA.CompilerJob(mi, CUDA.compiler_config(device())) + + ndrange, workgroupsize, iterspace, dynamic = launch_config(kernel, ndrange, workgroupsize) + block = first(blocks(iterspace)) + ctx = mkcontext(kernel, ndrange, iterspace) + ctxTy = Core.Typeof(ctx) # CompilerMetadata{ndrange(kernel), Core.Typeof(dynamic)} + + # TODO autodiff_deferred on the func.val + ModifiedBetween = Val((overwritten(config)[1], false, overwritten(config)[2:end]...)) + + tup = Val(ntuple(Val(N)) do i + Base.@_inline_meta + args[i] isa Active + end) + f = make_active_byref(f, tup) + FT = Const{Core.Typeof(f)} + + arg_refs = ntuple(Val(N)) do i + Base.@_inline_meta + if args[i] isa Active + Ref(EnzymeCore.make_zero(args[i].val)) + else + nothing + end + end + args2 = ntuple(Val(N)) do i + Base.@_inline_meta + if args[i] isa Active + Duplicated(Ref(args[i].val), arg_refs[i]) + else + args[i] + end + end + + # TODO in KA backends like CUDAKernels, etc have a version with a parent job type + TapeType = EnzymeCore.tape_type(job, ReverseSplitModified(ReverseSplitWithPrimal, ModifiedBetween), FT, Const, Const{ctxTy}, map(Core.Typeof, args2)...) + @show TapeType + + + subtape = Array{TapeType}(undef, size(blocks(iterspace))) + + aug_kernel = similar(kernel, aug_fwd) + + aug_kernel(f, ModifiedBetween, subtape, args2...; ndrange, workgroupsize) + + # TODO the fact that ctxTy is type unstable means this is all type unstable. + # Since custom rules require a fixed return type, explicitly cast to Any, rather + # than returning a AugmentedReturn{Nothing, Nothing, T} where T. + + res = AugmentedReturn{Nothing, Nothing, Tuple{Array, typeof(arg_refs)}}(nothing, nothing, (subtape, arg_refs)) + return res + end + +end \ No newline at end of file diff --git a/ext/EnzymeExt.jl b/ext/EnzymeExt.jl index cf3ac66b..3eb52d6b 100644 --- a/ext/EnzymeExt.jl +++ b/ext/EnzymeExt.jl @@ -7,7 +7,6 @@ module EnzymeExt using ..EnzymeCore.EnzymeRules end import KernelAbstractions: Kernel, StaticSize, launch_config, __groupsize, __groupindex, blocks, mkcontext, CompilerMetadata, CPU, GPU - using CUDA EnzymeRules.inactive(::Type{StaticSize}, x...) = nothing @@ -112,64 +111,6 @@ module EnzymeExt return res end - function EnzymeRules.augmented_primal(config::Config, func::Const{<:Kernel{CUDABackend}}, type::Type{Const{Nothing}}, args::Vararg{Any, N}; ndrange=nothing, workgroupsize=nothing) where N - println("Custom rule GPU") - kernel = func.val - f = kernel.f - mi = CUDA.methodinstance(typeof(()->return), Tuple{}) - job = CUDA.CompilerJob(mi, CUDA.compiler_config(device())) - - ndrange, workgroupsize, iterspace, dynamic = launch_config(kernel, ndrange, workgroupsize) - block = first(blocks(iterspace)) - ctx = mkcontext(kernel, ndrange, iterspace) - ctxTy = Core.Typeof(ctx) # CompilerMetadata{ndrange(kernel), Core.Typeof(dynamic)} - - # TODO autodiff_deferred on the func.val - ModifiedBetween = Val((overwritten(config)[1], false, overwritten(config)[2:end]...)) - - tup = Val(ntuple(Val(N)) do i - Base.@_inline_meta - args[i] isa Active - end) - f = make_active_byref(f, tup) - FT = Const{Core.Typeof(f)} - - arg_refs = ntuple(Val(N)) do i - Base.@_inline_meta - if args[i] isa Active - Ref(EnzymeCore.make_zero(args[i].val)) - else - nothing - end - end - args2 = ntuple(Val(N)) do i - Base.@_inline_meta - if args[i] isa Active - Duplicated(Ref(args[i].val), arg_refs[i]) - else - args[i] - end - end - - # TODO in KA backends like CUDAKernels, etc have a version with a parent job type - TapeType = EnzymeCore.tape_type(job, ReverseSplitModified(ReverseSplitWithPrimal, ModifiedBetween), FT, Const, Const{ctxTy}, map(Core.Typeof, args2)...) - @show TapeType - - - subtape = Array{TapeType}(undef, size(blocks(iterspace))) - - aug_kernel = similar(kernel, aug_fwd) - - aug_kernel(f, ModifiedBetween, subtape, args2...; ndrange, workgroupsize) - - # TODO the fact that ctxTy is type unstable means this is all type unstable. - # Since custom rules require a fixed return type, explicitly cast to Any, rather - # than returning a AugmentedReturn{Nothing, Nothing, T} where T. - - res = AugmentedReturn{Nothing, Nothing, Tuple{Array, typeof(arg_refs)}}(nothing, nothing, (subtape, arg_refs)) - return res - end - function EnzymeRules.reverse(config::Config, func::Const{<:Kernel}, ::Type{<:EnzymeCore.Annotation}, tape, args::Vararg{Any, N}; ndrange=nothing, workgroupsize=nothing) where N subtape, arg_refs = tape diff --git a/test/reverse_gpu.jl b/test/reverse_gpu.jl index d5e51ea6..2758e64c 100644 --- a/test/reverse_gpu.jl +++ b/test/reverse_gpu.jl @@ -1,5 +1,4 @@ using Test -using Enzyme_jll using Enzyme using KernelAbstractions using CUDA