Skip to content

Commit

Permalink
Enzyme: reverse mode kernels (#2422)
Browse files Browse the repository at this point in the history
[skip julia]
[skip cuda]
[skip subpackages]
[skip special]
  • Loading branch information
wsmoses authored Jul 31, 2024
1 parent beccab1 commit d7077da
Show file tree
Hide file tree
Showing 2 changed files with 169 additions and 0 deletions.
150 changes: 150 additions & 0 deletions ext/EnzymeCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,156 @@ function EnzymeCore.EnzymeRules.forward(ofn::EnzymeCore.Annotation{CUDA.HostKern
return nothing
end

function EnzymeCore.EnzymeRules.augmented_primal(config, ofn::Const{typeof(cufunction)},
::Type{RT}, f::Const{F},
tt::Const{TT}; kwargs...) where {F,CT, RT<:EnzymeCore.Annotation{CT}, TT}
res = ofn.val(f.val, tt.val; kwargs...)

primal = if EnzymeRules.needs_primal(config)
res
else
nothing
end

shadow = if EnzymeRules.needs_shadow(config)
if EnzymeRules.width(config) == 1
res
else
ntuple(Val(EnzymeRules.width(config))) do i
Base.@_inline_meta
res
end
end
else
nothing
end
return EnzymeRules.AugmentedReturn{(EnzymeRules.needs_primal(config) ? CT : Nothing), (EnzymeRules.needs_shadow(config) ? (EnzymeRules.width(config) == 1 ? CT : NTuple{EnzymeRules.width(config), CT}) : Nothing), Nothing}(primal, shadow, nothing)
end

function EnzymeCore.EnzymeRules.reverse(config, ofn::EnzymeCore.Const{typeof(cufunction)},::Type{RT}, subtape, f, tt; kwargs...) where RT
return (nothing, nothing)
end

function meta_augf(f, tape::CuDeviceArray{TapeType}, ::Val{ModifiedBetween}, args::Vararg{Any, N}) where {N, ModifiedBetween, TapeType}
forward, _ = EnzymeCore.autodiff_deferred_thunk(
ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)),
TapeType,
Const{Core.Typeof(f)},
Const{Nothing},
map(typeof, args)...,
)

idx = 0
# idx *= gridDim().x
idx += blockIdx().x-1

idx *= gridDim().y
idx += blockIdx().y-1

idx *= gridDim().z
idx += blockIdx().z-1

idx *= blockDim().x
idx += threadIdx().x-1

idx *= blockDim().y
idx += threadIdx().y-1

idx *= blockDim().z
idx += threadIdx().z-1
idx += 1

@inbounds tape[idx] = forward(Const(f), args...)[1]
nothing
end

function EnzymeCore.EnzymeRules.augmented_primal(config, ofn::EnzymeCore.Annotation{CUDA.HostKernel{F,TT}},
::Type{Const{Nothing}}, args0...;
threads::CuDim=1, blocks::CuDim=1, kwargs...) where {F,TT}
args = ((cudaconvert(arg) for arg in args0)...,)
ModifiedBetween = overwritten(config)
TapeType = EnzymeCore.tape_type(
EnzymeCore.compiler_job_from_backend(CUDABackend(), typeof(Base.identity), Tuple{Float64}),
ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)),
Const{F},
Const{Nothing},
map(typeof, args)...,
)
threads = CuDim3(threads)
blocks = CuDim3(blocks)
subtape = CuArray{TapeType}(undef, blocks.x*blocks.y*blocks.z*threads.x*threads.y*threads.z)

GC.@preserve args subtape, begin
subtape2 = cudaconvert(subtape)
T2 = (F, typeof(subtape2), Val{ModifiedBetween}, (typeof(a) for a in args)...)
TT2 = Tuple{T2...}
cuf = cufunction(meta_augf, TT2)
res = cuf(ofn.val.f, subtape2, Val(ModifiedBetween), args...; threads=(threads.x, threads.y, threads.z), blocks=(blocks.x, blocks.y, blocks.z), kwargs...)
end

return AugmentedReturn{Nothing,Nothing,CuArray}(nothing, nothing, subtape)
end

function meta_revf(f, tape::CuDeviceArray{TapeType}, ::Val{ModifiedBetween}, args::Vararg{Any, N}) where {N, ModifiedBetween, TapeType}
_, reverse = EnzymeCore.autodiff_deferred_thunk(
ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)),
TapeType,
Const{Core.Typeof(f)},
Const{Nothing},
map(typeof, args)...,
)

idx = 0
# idx *= gridDim().x
idx += blockIdx().x-1

idx *= gridDim().y
idx += blockIdx().y-1

idx *= gridDim().z
idx += blockIdx().z-1

idx *= blockDim().x
idx += threadIdx().x-1

idx *= blockDim().y
idx += threadIdx().y-1

idx *= blockDim().z
idx += threadIdx().z-1
idx += 1
reverse(Const(f), args..., @inbounds tape[idx])
nothing
end

function EnzymeCore.EnzymeRules.reverse(config, ofn::EnzymeCore.Annotation{CUDA.HostKernel{F,TT}},
::Type{Const{Nothing}}, subtape, args0...;
threads::CuDim=1, blocks::CuDim=1, kwargs...) where {F,TT}
args = ((cudaconvert(arg) for arg in args0)...,)
ModifiedBetween = overwritten(config)
TapeType = EnzymeCore.tape_type(
ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)),
Const{F},
Const{Nothing},
map(typeof, args)...,
)
threads = CuDim3(threads)
blocks = CuDim3(blocks)

GC.@preserve args0 subtape, begin
subtape2 = cudaconvert(subtape)
T2 = (F, typeof(subtape2), Val{ModifiedBetween}, (typeof(a) for a in args)...)
TT2 = Tuple{T2...}
cuf = cufunction(meta_revf, TT2)
res = cuf(ofn.val.f, subtape2, Val(ModifiedBetween), args...; threads=(threads.x, threads.y, threads.z), blocks=(blocks.x, blocks.y, blocks.z), kwargs...)
end

return ntuple(Val(length(args0))) do i
Base.@_inline_meta
nothing
end
end

function EnzymeCore.EnzymeRules.forward(ofn::Const{typeof(Base.fill!)}, ::Type{RT}, A::EnzymeCore.Annotation{<:DenseCuArray{T}}, x) where {RT, T <: CUDA.MemsetCompatTypes}
if A isa Const || A isa Duplicated || A isa BatchDuplicated
ofn.val(A.val, x.val)
Expand Down
19 changes: 19 additions & 0 deletions test/extensions/enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,25 @@ end
@test all(dA2 .≈ 3*(2:2:64))
end

@testset "Reverse Kernel" begin
A = CUDA.rand(64)
dA = CUDA.ones(64)
A .= (1:1:64)
dA .= 1
Enzyme.autodiff(Reverse, square!, Duplicated(A, dA))
@test all(dA .≈ (2:2:128))

A = CUDA.rand(32)
dA = CUDA.ones(32)
dA2 = CUDA.ones(32)
A .= (1:1:32)
dA .= 1
dA2 .= 3
Enzyme.autodiff(Reverse, square!, BatchDuplicated(A, (dA, dA2)))
@test all(dA .≈ (2:2:64))
@test all(dA2 .≈ 3*(2:2:64))
end

@testset "Forward Fill!" begin
A = CUDA.ones(64)
dA = CUDA.ones(64)
Expand Down

0 comments on commit d7077da

Please sign in to comment.