diff --git a/ext/EnzymeExt.jl b/ext/EnzymeExt.jl index df9fd72f..cf3ac66b 100644 --- a/ext/EnzymeExt.jl +++ b/ext/EnzymeExt.jl @@ -6,7 +6,8 @@ module EnzymeExt using ..EnzymeCore using ..EnzymeCore.EnzymeRules end - import KernelAbstractions: Kernel, StaticSize, launch_config, __groupsize, __groupindex, blocks, mkcontext, CompilerMetadata, CPU + import KernelAbstractions: Kernel, StaticSize, launch_config, __groupsize, __groupindex, blocks, mkcontext, CompilerMetadata, CPU, GPU + using CUDA EnzymeRules.inactive(::Type{StaticSize}, x...) = nothing @@ -55,6 +56,7 @@ module EnzymeExt end function EnzymeRules.augmented_primal(config::Config, func::Const{<:Kernel{CPU}}, ::Type{Const{Nothing}}, args::Vararg{Any, N}; ndrange=nothing, workgroupsize=nothing) where N + println("Custom rule CPU") kernel = func.val f = kernel.f @@ -93,6 +95,65 @@ module EnzymeExt # TODO in KA backends like CUDAKernels, etc have a version with a parent job type TapeType = EnzymeCore.tape_type(ReverseSplitModified(ReverseSplitWithPrimal, ModifiedBetween), FT, Const, Const{ctxTy}, map(Core.Typeof, args2)...) + @show TapeType + + + subtape = Array{TapeType}(undef, size(blocks(iterspace))) + + aug_kernel = similar(kernel, aug_fwd) + + aug_kernel(f, ModifiedBetween, subtape, args2...; ndrange, workgroupsize) + + # TODO the fact that ctxTy is type unstable means this is all type unstable. + # Since custom rules require a fixed return type, explicitly cast to Any, rather + # than returning a AugmentedReturn{Nothing, Nothing, T} where T. + + res = AugmentedReturn{Nothing, Nothing, Tuple{Array, typeof(arg_refs)}}(nothing, nothing, (subtape, arg_refs)) + return res + end + + function EnzymeRules.augmented_primal(config::Config, func::Const{<:Kernel{CUDABackend}}, type::Type{Const{Nothing}}, args::Vararg{Any, N}; ndrange=nothing, workgroupsize=nothing) where N + println("Custom rule GPU") + kernel = func.val + f = kernel.f + mi = CUDA.methodinstance(typeof(()->return), Tuple{}) + job = CUDA.CompilerJob(mi, CUDA.compiler_config(device())) + + ndrange, workgroupsize, iterspace, dynamic = launch_config(kernel, ndrange, workgroupsize) + block = first(blocks(iterspace)) + ctx = mkcontext(kernel, ndrange, iterspace) + ctxTy = Core.Typeof(ctx) # CompilerMetadata{ndrange(kernel), Core.Typeof(dynamic)} + + # TODO autodiff_deferred on the func.val + ModifiedBetween = Val((overwritten(config)[1], false, overwritten(config)[2:end]...)) + + tup = Val(ntuple(Val(N)) do i + Base.@_inline_meta + args[i] isa Active + end) + f = make_active_byref(f, tup) + FT = Const{Core.Typeof(f)} + + arg_refs = ntuple(Val(N)) do i + Base.@_inline_meta + if args[i] isa Active + Ref(EnzymeCore.make_zero(args[i].val)) + else + nothing + end + end + args2 = ntuple(Val(N)) do i + Base.@_inline_meta + if args[i] isa Active + Duplicated(Ref(args[i].val), arg_refs[i]) + else + args[i] + end + end + + # TODO in KA backends like CUDAKernels, etc have a version with a parent job type + TapeType = EnzymeCore.tape_type(job, ReverseSplitModified(ReverseSplitWithPrimal, ModifiedBetween), FT, Const, Const{ctxTy}, map(Core.Typeof, args2)...) + @show TapeType subtape = Array{TapeType}(undef, size(blocks(iterspace))) diff --git a/reverse_gpu.jl b/reverse_gpu.jl index 8766700b..f9682f25 100644 --- a/reverse_gpu.jl +++ b/reverse_gpu.jl @@ -59,6 +59,6 @@ function enzyme_testsuite(backend, ArrayT, supports_reverse=true) end end -enzyme_testsuite(CPU, Array, true) +# enzyme_testsuite(CPU, Array, true) # enzyme_testsuite(CUDABackend, CuArray, false) -enzyme_testsuite(CUDABackend, CuArray, true) \ No newline at end of file +enzyme_testsuite(CUDABackend, CuArray, true)