Skip to content

Commit

Permalink
Adding extension fix and allocate
Browse files Browse the repository at this point in the history
  • Loading branch information
michel2323 committed Jan 29, 2024
1 parent 9dd121d commit b82dfb4
Show file tree
Hide file tree
Showing 7 changed files with 55 additions and 45 deletions.
8 changes: 5 additions & 3 deletions ext/CUDAEnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ module CUDAEnzymeExt
import KernelAbstractions: Kernel, StaticSize, launch_config, __groupsize, __groupindex, blocks, mkcontext, CompilerMetadata, CPU, GPU
using CUDA

include("enzyme_utils.jl")

function EnzymeRules.augmented_primal(config::Config, func::Const{<:Kernel{CUDABackend}}, type::Type{Const{Nothing}}, args::Vararg{Any, N}; ndrange=nothing, workgroupsize=nothing) where N
println("Custom rule GPU")
kernel = func.val
Expand Down Expand Up @@ -52,8 +54,8 @@ module CUDAEnzymeExt
TapeType = EnzymeCore.tape_type(job, ReverseSplitModified(ReverseSplitWithPrimal, ModifiedBetween), FT, Const, Const{ctxTy}, map(Core.Typeof, args2)...)
@show TapeType


subtape = Array{TapeType}(undef, size(blocks(iterspace)))
subtape = allocate(CUDABackend(), TapeType, size(blocks(iterspace)))
# subtape = Array{TapeType}(undef, size(blocks(iterspace)))

aug_kernel = similar(kernel, aug_fwd)

Expand All @@ -67,4 +69,4 @@ module CUDAEnzymeExt
return res
end

end
end
38 changes: 2 additions & 36 deletions ext/EnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,25 +8,9 @@ module EnzymeExt
end
import KernelAbstractions: Kernel, StaticSize, launch_config, __groupsize, __groupindex, blocks, mkcontext, CompilerMetadata, CPU, GPU

EnzymeRules.inactive(::Type{StaticSize}, x...) = nothing

"""
    fwd(ctx, f, args...)

Invoke `EnzymeCore.autodiff_deferred` in `Forward` mode on `f`, passing `f` and `ctx`
wrapped in `Const` and forwarding `args...` untouched. The bare `Const` given as the
third argument is presumably the return activity (confirm against EnzymeCore).

Always returns `nothing`; whatever `autodiff_deferred` returns is discarded.
"""
function fwd(ctx, f, args...)
    const_f = Const(f)
    const_ctx = Const(ctx)
    EnzymeCore.autodiff_deferred(Forward, const_f, Const, const_ctx, args...)
    return nothing
end
include("enzyme_utils.jl")

"""
    aug_fwd(ctx, f, ::Val{ModifiedBetween}, subtape, args...)

Request the split-mode thunk pair from `EnzymeCore.autodiff_deferred_thunk` for `f`,
`ctx` and the types of `args...`, run the first thunk of the pair, and store element
`[1]` of its result (presumably the tape) in `subtape` at index `__groupindex(ctx)`.

Always returns `nothing`.
"""
function aug_fwd(ctx, f::FT, ::Val{ModifiedBetween}, subtape, args...) where {ModifiedBetween, FT}
    mode = ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween))
    argtypes = map(Core.Typeof, args)
    # Only the first thunk of the pair is needed here; the second one is ignored.
    augmented, _ = EnzymeCore.autodiff_deferred_thunk(mode, Const{Core.Typeof(f)}, Const, Const{Core.Typeof(ctx)}, argtypes...)
    result = augmented(Const(f), Const(ctx), args...)
    subtape[__groupindex(ctx)] = result[1]
    return nothing
end

"""
    rev(ctx, f, ::Val{ModifiedBetween}, subtape, args...)

Request the split-mode thunk pair from `EnzymeCore.autodiff_deferred_thunk` for `f`,
`ctx` and the types of `args...`, read the entry of `subtape` at `__groupindex(ctx)`,
and call the second thunk of the pair with that entry appended after `args...`.

Always returns `nothing`.
"""
function rev(ctx, f::FT, ::Val{ModifiedBetween}, subtape, args...) where {ModifiedBetween, FT}
    mode = ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween))
    argtypes = map(Core.Typeof, args)
    # Only the second thunk of the pair is needed here; the first one is ignored.
    _, pullback = EnzymeCore.autodiff_deferred_thunk(mode, Const{Core.Typeof(f)}, Const, Const{Core.Typeof(ctx)}, argtypes...)
    tape = subtape[__groupindex(ctx)]
    pullback(Const(f), Const(ctx), args..., tape)
    return nothing
end
EnzymeRules.inactive(::Type{StaticSize}, x...) = nothing

function EnzymeRules.forward(func::Const{<:Kernel}, ::Type{Const{Nothing}}, args...; ndrange=nothing, workgroupsize=nothing)
kernel = func.val
Expand All @@ -36,24 +20,6 @@ module EnzymeExt
fwd_kernel(f, args...; ndrange, workgroupsize)
end

"""
    make_active_byref(f, ::Val{ActiveTys})

Return `f` itself when no entry of `ActiveTys` is `true`. Otherwise return a closure
`inact(ctx, args...)` that dereferences (with `[]`) every argument whose position is
flagged in `ActiveTys`, leaves the other arguments as they are, and then calls
`f(ctx, ...)` with the resulting tuple.
"""
@inline function make_active_byref(f::F, ::Val{ActiveTys}) where {F, ActiveTys}
    # Nothing is flagged: no wrapper is needed.
    any(ActiveTys) || return f
    function inact(ctx, wrapped::Vararg{Any, N}) where N
        unwrapped = ntuple(Val(N)) do i
            Base.@_inline_meta
            ActiveTys[i] ? wrapped[i][] : wrapped[i]
        end
        f(ctx, unwrapped...)
    end
    return inact
end

function EnzymeRules.augmented_primal(config::Config, func::Const{<:Kernel{CPU}}, ::Type{Const{Nothing}}, args::Vararg{Any, N}; ndrange=nothing, workgroupsize=nothing) where N
println("Custom rule CPU")
kernel = func.val
Expand Down
35 changes: 35 additions & 0 deletions ext/enzyme_utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
"""
    fwd(ctx, f, args...)

Invoke `EnzymeCore.autodiff_deferred` in `Forward` mode on `f`, passing `f` and `ctx`
wrapped in `Const` and forwarding `args...` untouched. The bare `Const` given as the
third argument is presumably the return activity (confirm against EnzymeCore).

Always returns `nothing`; whatever `autodiff_deferred` returns is discarded.
"""
function fwd(ctx, f, args...)
    const_f = Const(f)
    const_ctx = Const(ctx)
    EnzymeCore.autodiff_deferred(Forward, const_f, Const, const_ctx, args...)
    return nothing
end

"""
    aug_fwd(ctx, f, ::Val{ModifiedBetween}, subtape, args...)

Request the split-mode thunk pair from `EnzymeCore.autodiff_deferred_thunk` for `f`,
`ctx` and the types of `args...`, run the first thunk of the pair, and store element
`[1]` of its result (presumably the tape) in `subtape` at index `__groupindex(ctx)`.

Always returns `nothing`.
"""
function aug_fwd(ctx, f::FT, ::Val{ModifiedBetween}, subtape, args...) where {ModifiedBetween, FT}
    mode = ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween))
    argtypes = map(Core.Typeof, args)
    # Only the first thunk of the pair is needed here; the second one is ignored.
    augmented, _ = EnzymeCore.autodiff_deferred_thunk(mode, Const{Core.Typeof(f)}, Const, Const{Core.Typeof(ctx)}, argtypes...)
    result = augmented(Const(f), Const(ctx), args...)
    subtape[__groupindex(ctx)] = result[1]
    return nothing
end

"""
    rev(ctx, f, ::Val{ModifiedBetween}, subtape, args...)

Request the split-mode thunk pair from `EnzymeCore.autodiff_deferred_thunk` for `f`,
`ctx` and the types of `args...`, read the entry of `subtape` at `__groupindex(ctx)`,
and call the second thunk of the pair with that entry appended after `args...`.

Always returns `nothing`.
"""
function rev(ctx, f::FT, ::Val{ModifiedBetween}, subtape, args...) where {ModifiedBetween, FT}
    mode = ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween))
    argtypes = map(Core.Typeof, args)
    # Only the second thunk of the pair is needed here; the first one is ignored.
    _, pullback = EnzymeCore.autodiff_deferred_thunk(mode, Const{Core.Typeof(f)}, Const, Const{Core.Typeof(ctx)}, argtypes...)
    tape = subtape[__groupindex(ctx)]
    pullback(Const(f), Const(ctx), args..., tape)
    return nothing
end

"""
    make_active_byref(f, ::Val{ActiveTys})

Return `f` itself when no entry of `ActiveTys` is `true`. Otherwise return a closure
`inact(ctx, args...)` that dereferences (with `[]`) every argument whose position is
flagged in `ActiveTys`, leaves the other arguments as they are, and then calls
`f(ctx, ...)` with the resulting tuple.
"""
@inline function make_active_byref(f::F, ::Val{ActiveTys}) where {F, ActiveTys}
    # Nothing is flagged: no wrapper is needed.
    any(ActiveTys) || return f
    function inact(ctx, wrapped::Vararg{Any, N}) where N
        unwrapped = ntuple(Val(N)) do i
            Base.@_inline_meta
            ActiveTys[i] ? wrapped[i][] : wrapped[i]
        end
        f(ctx, unwrapped...)
    end
    return inact
end
2 changes: 2 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
Enzyme_jll = "7cc45869-7501-5eee-bdea-0790c847d4ef"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
# KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Expand Down
5 changes: 3 additions & 2 deletions test/extensions/enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ end
# Test helper: build the `square!` kernel (defined outside this hunk) for `backend`,
# launch it on `A` with `ndrange=size(A)`, then synchronize `backend`.
# NOTE(review): in this flattened diff view the `synchronize` call below looks like the
# line the commit removes (the test suite synchronizes after `Enzyme.autodiff` instead);
# confirm against the actual file.
function square_caller(A, backend)
kernel = square!(backend)
kernel(A, ndrange=size(A))
KernelAbstractions.synchronize(backend)
end


Expand All @@ -22,7 +21,6 @@ end
# Test helper: build the `mul!` kernel (defined outside this hunk) for `backend`,
# launch it on `A` and `B` with `ndrange=size(A)`, then synchronize `backend`.
# NOTE(review): in this flattened diff view the `synchronize` call below looks like the
# line the commit removes (the test suite synchronizes after `Enzyme.autodiff` instead);
# confirm against the actual file.
function mul_caller(A, B, backend)
kernel = mul!(backend)
kernel(A, B, ndrange=size(A))
KernelAbstractions.synchronize(backend)
end

function enzyme_testsuite(backend, ArrayT, supports_reverse=true)
Expand All @@ -36,13 +34,15 @@ function enzyme_testsuite(backend, ArrayT, supports_reverse=true)
dA .= 1

Enzyme.autodiff(Reverse, square_caller, Duplicated(A, dA), Const(backend()))
KernelAbstractions.synchronize(backend())
@test all(dA .≈ (2:2:128))


A .= (1:1:64)
dA .= 1

_, dB, _ = Enzyme.autodiff(Reverse, mul_caller, Duplicated(A, dA), Active(1.2), Const(backend()))[1]
KernelAbstractions.synchronize(backend())

@test all(dA .≈ 1.2)
@test dB ≈ sum(1:1:64)
Expand All @@ -52,6 +52,7 @@ function enzyme_testsuite(backend, ArrayT, supports_reverse=true)
dA .= 1

Enzyme.autodiff(Forward, square_caller, Duplicated(A, dA), Const(backend()))
KernelAbstractions.synchronize(backend())
@test all(dA .≈ 2:2:128)

end
Expand Down
9 changes: 5 additions & 4 deletions test/reverse_gpu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ end
# Test helper: build the `square!` kernel (defined outside this hunk) for `backend`,
# launch it on `A` with `ndrange=size(A)`, then synchronize `backend`.
# NOTE(review): in this flattened diff view the `synchronize` call below looks like the
# line the commit removes (the test suite synchronizes after `Enzyme.autodiff` instead);
# confirm against the actual file.
function square_caller(A, backend)
kernel = square!(backend)
kernel(A, ndrange=size(A))
KernelAbstractions.synchronize(backend)
end


Expand All @@ -23,7 +22,6 @@ end
# Test helper: build the `mul!` kernel (defined outside this hunk) for `backend`,
# launch it on `A` and `B` with `ndrange=size(A)`, then synchronize `backend`.
# NOTE(review): in this flattened diff view the `synchronize` call below looks like the
# line the commit removes (the test suite synchronizes after `Enzyme.autodiff` instead);
# confirm against the actual file.
function mul_caller(A, B, backend)
kernel = mul!(backend)
kernel(A, B, ndrange=size(A))
KernelAbstractions.synchronize(backend)
end

function enzyme_testsuite(backend, ArrayT, supports_reverse=true)
Expand All @@ -37,13 +35,15 @@ function enzyme_testsuite(backend, ArrayT, supports_reverse=true)
dA .= 1

Enzyme.autodiff(Reverse, square_caller, Duplicated(A, dA), Const(backend()))
KernelAbstractions.synchronize(backend())
@test all(dA .≈ (2:2:128))


A .= (1:1:64)
dA .= 1

_, dB, _ = Enzyme.autodiff(Reverse, mul_caller, Duplicated(A, dA), Active(1.2), Const(backend()))[1]
KernelAbstractions.synchronize(backend())

@test all(dA .≈ 1.2)
@test dB ≈ sum(1:1:64)
Expand All @@ -53,11 +53,12 @@ function enzyme_testsuite(backend, ArrayT, supports_reverse=true)
dA .= 1

Enzyme.autodiff(Forward, square_caller, Duplicated(A, dA), Const(backend()))
KernelAbstractions.synchronize(backend())
@test all(dA .≈ 2:2:128)

end
end

# enzyme_testsuite(CPU, Array, true)
enzyme_testsuite(CUDABackend, CuArray, false)
# enzyme_testsuite(CUDABackend, CuArray, true)
# enzyme_testsuite(CUDABackend, CuArray, false)
enzyme_testsuite(CUDABackend, CuArray, true)
3 changes: 3 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using CUDA
using KernelAbstractions
using Test

Expand Down Expand Up @@ -74,5 +75,7 @@ include("extensions/enzyme.jl")
@static if VERSION >= v"1.7.0"
@testset "Enzyme" begin
enzyme_testsuite(CPU, Array)
enzyme_testsuite(CUDABackend, CuArray, false)
# enzyme_testsuite(CUDABackend, CuArray, true)
end
end

0 comments on commit b82dfb4

Please sign in to comment.