diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index c45e29d4..de7e3a1e 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -34,6 +34,7 @@ steps: version: - "1.8" - "1.9" + - "1.10" plugins: - JuliaCI/julia#v1: version: "{{matrix.version}}" diff --git a/ext/EnzymeExt.jl b/ext/EnzymeExt.jl index df9fd72f..c2561a77 100644 --- a/ext/EnzymeExt.jl +++ b/ext/EnzymeExt.jl @@ -6,7 +6,11 @@ module EnzymeExt using ..EnzymeCore using ..EnzymeCore.EnzymeRules end - import KernelAbstractions: Kernel, StaticSize, launch_config, __groupsize, __groupindex, blocks, mkcontext, CompilerMetadata, CPU + using KernelAbstractions + import KernelAbstractions: Kernel, StaticSize, launch_config, allocate, + blocks, mkcontext, CompilerMetadata, CPU, GPU, argconvert, + supports_enzyme, __fake_compiler_job, backend, + __index_Group_Cartesian, __index_Global_Linear EnzymeRules.inactive(::Type{StaticSize}, x...) = nothing @@ -15,55 +19,188 @@ module EnzymeExt return nothing end - function aug_fwd(ctx, f::FT, ::Val{ModifiedBetween}, subtape, args...) where {ModifiedBetween, FT} - forward, reverse = EnzymeCore.autodiff_deferred_thunk(ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)), Const{Core.Typeof(f)}, Const, Const{Core.Typeof(ctx)}, map(Core.Typeof, args)...) - subtape[__groupindex(ctx)] = forward(Const(f), Const(ctx), args...)[1] + function EnzymeRules.forward(func::Const{<:Kernel}, ::Type{Const{Nothing}}, args...; ndrange=nothing, workgroupsize=nothing) + kernel = func.val + f = kernel.f + fwd_kernel = similar(kernel, fwd) + + fwd_kernel(f, args...; ndrange, workgroupsize) + end + + function _enzyme_mkcontext(kernel::Kernel{CPU}, ndrange, iterspace, dynamic) + block = first(blocks(iterspace)) + return mkcontext(kernel, block, ndrange, iterspace, dynamic) + end + + function _enzyme_mkcontext(kernel::Kernel{<:GPU}, ndrange, iterspace, dynamic) + return mkcontext(kernel, ndrange, iterspace) + end + + function _augmented_return(::Kernel{CPU}, subtape, arg_refs, tape_type) + return AugmentedReturn{Nothing, Nothing, Tuple{Array, typeof(arg_refs), typeof(tape_type)}}( + nothing, nothing, (subtape, arg_refs, tape_type) + ) + end + + function _augmented_return(::Kernel{<:GPU}, subtape, arg_refs, tape_type) + return AugmentedReturn{Nothing, Nothing, Any}( + nothing, nothing, (subtape, arg_refs, tape_type) + ) + end + + function _create_tape_kernel( + kernel::Kernel{CPU}, ModifiedBetween, + FT, ctxTy, ndrange, iterspace, args2... + ) + TapeType = EnzymeCore.tape_type( + ReverseSplitModified(ReverseSplitWithPrimal, ModifiedBetween), + FT, Const, Const{ctxTy}, map(Core.Typeof, args2)... + ) + subtape = Array{TapeType}(undef, size(blocks(iterspace))) + aug_kernel = similar(kernel, cpu_aug_fwd) + return TapeType, subtape, aug_kernel + end + + function _create_tape_kernel( + kernel::Kernel{<:GPU}, ModifiedBetween, + FT, ctxTy, ndrange, iterspace, args2... + ) + # For peeking at the TapeType we need to first construct a correct compilation job + # this requires the use of the device side representation of arguments. + # So we convert the arguments here, this is a bit wasteful since the `aug_kernel` call + # will later do the same. + dev_args2 = ((argconvert(kernel, a) for a in args2)...,) + dev_TT = map(Core.Typeof, dev_args2) + + job = __fake_compiler_job(backend(kernel)) + TapeType = EnzymeCore.tape_type( + job, ReverseSplitModified(ReverseSplitWithPrimal, ModifiedBetween), + FT, Const, Const{ctxTy}, dev_TT... + ) + + # Allocate per thread + subtape = allocate(backend(kernel), TapeType, prod(ndrange)) + + aug_kernel = similar(kernel, gpu_aug_fwd) + return TapeType, subtape, aug_kernel + end + + _create_rev_kernel(kernel::Kernel{CPU}) = similar(kernel, cpu_rev) + _create_rev_kernel(kernel::Kernel{<:GPU}) = similar(kernel, gpu_rev) + + function cpu_aug_fwd( + ctx, f::FT, ::Val{ModifiedBetween}, subtape, ::Val{TapeType}, args... + ) where {ModifiedBetween, FT, TapeType} + # A2 = Const{Nothing} -- since f->Nothing + forward, _ = EnzymeCore.autodiff_deferred_thunk( + ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)), TapeType, + Const{Core.Typeof(f)}, Const, Const{Nothing}, + Const{Core.Typeof(ctx)}, map(Core.Typeof, args)... + ) + + # On the CPU: F is a per block function + # On the CPU: subtape::Vector{Vector} + I = __index_Group_Cartesian(ctx, #=fake=#CartesianIndex(1,1)) + subtape[I] = forward(Const(f), Const(ctx), args...)[1] return nothing end - function rev(ctx, f::FT, ::Val{ModifiedBetween}, subtape, args...) where {ModifiedBetween, FT} - forward, reverse = EnzymeCore.autodiff_deferred_thunk(ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)), Const{Core.Typeof(f)}, Const, Const{Core.Typeof(ctx)}, map(Core.Typeof, args)...) - tp = subtape[__groupindex(ctx)] + function cpu_rev( + ctx, f::FT, ::Val{ModifiedBetween}, subtape, ::Val{TapeType}, args... + ) where {ModifiedBetween, FT, TapeType} + _, reverse = EnzymeCore.autodiff_deferred_thunk( + ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)), TapeType, + Const{Core.Typeof(f)}, Const, Const{Nothing}, + Const{Core.Typeof(ctx)}, map(Core.Typeof, args)... + ) + I = __index_Group_Cartesian(ctx, #=fake=#CartesianIndex(1,1)) + tp = subtape[I] reverse(Const(f), Const(ctx), args..., tp) return nothing end - function EnzymeRules.forward(func::Const{<:Kernel}, ::Type{Const{Nothing}}, args...; ndrange=nothing, workgroupsize=nothing) + function EnzymeRules.reverse(config::Config, func::Const{<:Kernel}, ::Type{<:EnzymeCore.Annotation}, tape, args::Vararg{Any, N}; ndrange=nothing, workgroupsize=nothing) where N + subtape, arg_refs, tape_type = tape + + args2 = ntuple(Val(N)) do i + Base.@_inline_meta + if args[i] isa Active + Duplicated(Ref(args[i].val), arg_refs[i]) + else + args[i] + end + end + kernel = func.val f = kernel.f - fwd_kernel = similar(kernel, fwd) - fwd_kernel(f, args...; ndrange, workgroupsize) - end + tup = Val(ntuple(Val(N)) do i + Base.@_inline_meta + args[i] isa Active + end) + f = make_active_byref(f, tup) - @inline function make_active_byref(f::F, ::Val{ActiveTys}) where {F, ActiveTys} - if !any(ActiveTys) - return f - end - function inact(ctx, args2::Vararg{Any, N}) where N - args3 = ntuple(Val(N)) do i - Base.@_inline_meta - if ActiveTys[i] - args2[i][] - else - args2[i] - end + ModifiedBetween = Val((overwritten(config)[1], false, overwritten(config)[2:end]...)) + + rev_kernel = _create_rev_kernel(kernel) + rev_kernel(f, ModifiedBetween, subtape, Val(tape_type), args2...; ndrange, workgroupsize) + res = ntuple(Val(N)) do i + Base.@_inline_meta + if args[i] isa Active + arg_refs[i][] + else + nothing end - f(ctx, args3...) end - return inact + return res end - function EnzymeRules.augmented_primal(config::Config, func::Const{<:Kernel{CPU}}, ::Type{Const{Nothing}}, args::Vararg{Any, N}; ndrange=nothing, workgroupsize=nothing) where N + # GPU support + function gpu_aug_fwd( + ctx, f::FT, ::Val{ModifiedBetween}, subtape, ::Val{TapeType}, args... + ) where {ModifiedBetween, FT, TapeType} + # A2 = Const{Nothing} -- since f->Nothing + forward, _ = EnzymeCore.autodiff_deferred_thunk( + ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)), TapeType, + Const{Core.Typeof(f)}, Const, Const{Nothing}, + Const{Core.Typeof(ctx)}, map(Core.Typeof, args)... + ) + + # On the GPU: F is a per thread function + # On the GPU: subtape::Vector + I = __index_Global_Linear(ctx) + subtape[I] = forward(Const(f), Const(ctx), args...)[1] + return nothing + end + + function gpu_rev( + ctx, f::FT, ::Val{ModifiedBetween}, subtape, ::Val{TapeType}, args... + ) where {ModifiedBetween, FT, TapeType} + # XXX: TapeType and A2 as args to autodiff_deferred_thunk + _, reverse = EnzymeCore.autodiff_deferred_thunk( + ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)), TapeType, + Const{Core.Typeof(f)}, Const, Const{Nothing}, + Const{Core.Typeof(ctx)}, map(Core.Typeof, args)... + ) + I = __index_Global_Linear(ctx) + tp = subtape[I] + reverse(Const(f), Const(ctx), args..., tp) + return nothing + end + + function EnzymeRules.augmented_primal( + config::Config, func::Const{<:Kernel}, + ::Type{Const{Nothing}}, args::Vararg{Any, N}; ndrange=nothing, workgroupsize=nothing + ) where N kernel = func.val + if !supports_enzyme(backend(kernel)) + error("KernelAbstractions backend does not support Enzyme") + end f = kernel.f ndrange, workgroupsize, iterspace, dynamic = launch_config(kernel, ndrange, workgroupsize) - block = first(blocks(iterspace)) - - ctx = mkcontext(kernel, block, ndrange, iterspace, dynamic) + ctx = _enzyme_mkcontext(kernel, ndrange, iterspace, dynamic) ctxTy = Core.Typeof(ctx) # CompilerMetadata{ndrange(kernel), Core.Typeof(dynamic)} - # TODO autodiff_deferred on the func.val ModifiedBetween = Val((overwritten(config)[1], false, overwritten(config)[2:end]...)) @@ -91,56 +228,34 @@ module EnzymeExt end end - # TODO in KA backends like CUDAKernels, etc have a version with a parent job type - TapeType = EnzymeCore.tape_type(ReverseSplitModified(ReverseSplitWithPrimal, ModifiedBetween), FT, Const, Const{ctxTy}, map(Core.Typeof, args2)...) - - - subtape = Array{TapeType}(undef, size(blocks(iterspace))) - - aug_kernel = similar(kernel, aug_fwd) - - aug_kernel(f, ModifiedBetween, subtape, args2...; ndrange, workgroupsize) + TapeType, subtape, aug_kernel = _create_tape_kernel( + kernel, ModifiedBetween, FT, ctxTy, ndrange, iterspace, args2... + ) + aug_kernel(f, ModifiedBetween, subtape, Val(TapeType), args2...; ndrange, workgroupsize) + KernelAbstractions.synchronize(backend(kernel)) # TODO the fact that ctxTy is type unstable means this is all type unstable. # Since custom rules require a fixed return type, explicitly cast to Any, rather # than returning a AugmentedReturn{Nothing, Nothing, T} where T. - - res = AugmentedReturn{Nothing, Nothing, Tuple{Array, typeof(arg_refs)}}(nothing, nothing, (subtape, arg_refs)) - return res + return _augmented_return(kernel, subtape, arg_refs, TapeType) end - function EnzymeRules.reverse(config::Config, func::Const{<:Kernel}, ::Type{<:EnzymeCore.Annotation}, tape, args::Vararg{Any, N}; ndrange=nothing, workgroupsize=nothing) where N - subtape, arg_refs = tape - - args2 = ntuple(Val(N)) do i - Base.@_inline_meta - if args[i] isa Active - Duplicated(Ref(args[i].val), arg_refs[i]) - else - args[i] - end - end - - kernel = func.val - f = kernel.f - - tup = Val(ntuple(Val(N)) do i - Base.@_inline_meta - args[i] isa Active - end) - f = make_active_byref(f, tup) - - ModifiedBetween = Val((overwritten(config)[1], false, overwritten(config)[2:end]...)) - - rev_kernel = similar(func.val, rev) - rev_kernel(f, ModifiedBetween, subtape, args2...; ndrange, workgroupsize) - return ntuple(Val(N)) do i + @inline function make_active_byref(f::F, ::Val{ActiveTys}) where {F, ActiveTys} + if !any(ActiveTys) + return f + end + function inact(ctx, args2::Vararg{Any, N}) where N + args3 = ntuple(Val(N)) do i Base.@_inline_meta - if args[i] isa Active - arg_refs[i][] + if ActiveTys[i] + args2[i][] else - nothing + args2[i] end end + f(ctx, args3...) end + return inact +end + end diff --git a/src/KernelAbstractions.jl b/src/KernelAbstractions.jl index efec0c4d..3cc263c0 100644 --- a/src/KernelAbstractions.jl +++ b/src/KernelAbstractions.jl @@ -698,6 +698,18 @@ end __size(args::Tuple) = Tuple{args...} __size(i::Int) = Tuple{i} +""" + argconvert(::Kernel, arg) + +Convert arguments to the device side representation. +""" +argconvert(k::Kernel{T}, arg) where T = + error("Don't know how to convert arguments for Kernel{$T}") + +# Enzyme support +supports_enzyme(::Backend) = false +function __fake_compiler_job end + ### # Extras # - LoopInfo diff --git a/src/cpu.jl b/src/cpu.jl index 8c3e8afd..9779c79f 100644 --- a/src/cpu.jl +++ b/src/cpu.jl @@ -191,4 +191,6 @@ end end # Argument conversion -KernelAbstractions.argconvert(k::Kernel{CPU}, arg) = arg +argconvert(k::Kernel{CPU}, arg) = arg + +supports_enzyme(::CPU) = true diff --git a/src/reflection.jl b/src/reflection.jl index 3ab8080c..e0be71c6 100644 --- a/src/reflection.jl +++ b/src/reflection.jl @@ -1,9 +1,6 @@ import InteractiveUtils export @ka_code_typed, @ka_code_llvm -argconvert(k::Kernel{T}, arg) where T = - error("Don't know how to convert arguments for Kernel{$T}") - using UUIDs const Cthulhu = Base.PkgId(UUID("f68482b8-f384-11e8-15f7-abe071a5a75f"), "Cthulhu") diff --git a/test/Project.toml b/test/Project.toml index 231ca958..afd1389e 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,5 +1,6 @@ [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/test/extensions/enzyme.jl b/test/extensions/enzyme.jl index 39e6916a..48804ec2 100644 --- a/test/extensions/enzyme.jl +++ b/test/extensions/enzyme.jl @@ -10,7 +10,6 @@ end function square_caller(A, backend) kernel = square!(backend) kernel(A, ndrange=size(A)) - KernelAbstractions.synchronize(backend) end @@ -22,7 +21,6 @@ end function mul_caller(A, B, backend) kernel = mul!(backend) kernel(A, B, ndrange=size(A)) - KernelAbstractions.synchronize(backend) end function enzyme_testsuite(backend, ArrayT, supports_reverse=true) @@ -36,6 +34,7 @@ function enzyme_testsuite(backend, ArrayT, supports_reverse=true) dA .= 1 Enzyme.autodiff(Reverse, square_caller, Duplicated(A, dA), Const(backend())) + KernelAbstractions.synchronize(backend()) @test all(dA .≈ (2:2:128)) @@ -43,6 +42,7 @@ function enzyme_testsuite(backend, ArrayT, supports_reverse=true) dA .= 1 _, dB, _ = Enzyme.autodiff(Reverse, mul_caller, Duplicated(A, dA), Active(1.2), Const(backend()))[1] + KernelAbstractions.synchronize(backend()) @test all(dA .≈ 1.2) @test dB ≈ sum(1:1:64) @@ -52,6 +52,7 @@ function enzyme_testsuite(backend, ArrayT, supports_reverse=true) dA .= 1 Enzyme.autodiff(Forward, square_caller, Duplicated(A, dA), Const(backend())) + KernelAbstractions.synchronize(backend()) @test all(dA .≈ 2:2:128) end diff --git a/test/runtests.jl b/test/runtests.jl index 15c33e27..79b2ddf4 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,3 +1,4 @@ +using CUDA using KernelAbstractions using Test @@ -74,5 +75,8 @@ include("extensions/enzyme.jl") @static if VERSION >= v"1.7.0" @testset "Enzyme" begin enzyme_testsuite(CPU, Array) + if CUDA.functional() && CUDA.has_cuda_gpu() + enzyme_testsuite(CUDABackend, CuArray) + end end end