Skip to content

Commit

Permalink
Compiler crashes
Browse files Browse the repository at this point in the history
  • Loading branch information
michel2323 committed Dec 15, 2023
1 parent df499c2 commit d6c5a72
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 3 deletions.
63 changes: 62 additions & 1 deletion ext/EnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ module EnzymeExt
using ..EnzymeCore
using ..EnzymeCore.EnzymeRules
end
import KernelAbstractions: Kernel, StaticSize, launch_config, __groupsize, __groupindex, blocks, mkcontext, CompilerMetadata, CPU
import KernelAbstractions: Kernel, StaticSize, launch_config, __groupsize, __groupindex, blocks, mkcontext, CompilerMetadata, CPU, GPU
using CUDA

EnzymeRules.inactive(::Type{StaticSize}, x...) = nothing

Expand Down Expand Up @@ -55,6 +56,7 @@ module EnzymeExt
end

function EnzymeRules.augmented_primal(config::Config, func::Const{<:Kernel{CPU}}, ::Type{Const{Nothing}}, args::Vararg{Any, N}; ndrange=nothing, workgroupsize=nothing) where N
println("Custom rule CPU")
kernel = func.val
f = kernel.f

Expand Down Expand Up @@ -93,6 +95,65 @@ module EnzymeExt

# TODO in KA backends like CUDAKernels, etc have a version with a parent job type
TapeType = EnzymeCore.tape_type(ReverseSplitModified(ReverseSplitWithPrimal, ModifiedBetween), FT, Const, Const{ctxTy}, map(Core.Typeof, args2)...)
@show TapeType


subtape = Array{TapeType}(undef, size(blocks(iterspace)))

aug_kernel = similar(kernel, aug_fwd)

aug_kernel(f, ModifiedBetween, subtape, args2...; ndrange, workgroupsize)

# TODO the fact that ctxTy is type unstable means this is all type unstable.
# Since custom rules require a fixed return type, explicitly cast to Any, rather
# than returning a AugmentedReturn{Nothing, Nothing, T} where T.

res = AugmentedReturn{Nothing, Nothing, Tuple{Array, typeof(arg_refs)}}(nothing, nothing, (subtape, arg_refs))
return res
end

function EnzymeRules.augmented_primal(config::Config, func::Const{<:Kernel{CUDABackend}}, type::Type{Const{Nothing}}, args::Vararg{Any, N}; ndrange=nothing, workgroupsize=nothing) where N
println("Custom rule GPU")
kernel = func.val
f = kernel.f
mi = CUDA.methodinstance(typeof(()->return), Tuple{})
job = CUDA.CompilerJob(mi, CUDA.compiler_config(device()))

ndrange, workgroupsize, iterspace, dynamic = launch_config(kernel, ndrange, workgroupsize)
block = first(blocks(iterspace))
ctx = mkcontext(kernel, ndrange, iterspace)
ctxTy = Core.Typeof(ctx) # CompilerMetadata{ndrange(kernel), Core.Typeof(dynamic)}

# TODO autodiff_deferred on the func.val
ModifiedBetween = Val((overwritten(config)[1], false, overwritten(config)[2:end]...))

tup = Val(ntuple(Val(N)) do i
Base.@_inline_meta
args[i] isa Active
end)
f = make_active_byref(f, tup)
FT = Const{Core.Typeof(f)}

arg_refs = ntuple(Val(N)) do i
Base.@_inline_meta
if args[i] isa Active
Ref(EnzymeCore.make_zero(args[i].val))
else
nothing
end
end
args2 = ntuple(Val(N)) do i
Base.@_inline_meta
if args[i] isa Active
Duplicated(Ref(args[i].val), arg_refs[i])
else
args[i]
end
end

# TODO in KA backends like CUDAKernels, etc have a version with a parent job type
TapeType = EnzymeCore.tape_type(job, ReverseSplitModified(ReverseSplitWithPrimal, ModifiedBetween), FT, Const, Const{ctxTy}, map(Core.Typeof, args2)...)
@show TapeType


subtape = Array{TapeType}(undef, size(blocks(iterspace)))
Expand Down
4 changes: 2 additions & 2 deletions reverse_gpu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,6 @@ function enzyme_testsuite(backend, ArrayT, supports_reverse=true)
end
end

enzyme_testsuite(CPU, Array, true)
# enzyme_testsuite(CPU, Array, true)
# enzyme_testsuite(CUDABackend, CuArray, false)
enzyme_testsuite(CUDABackend, CuArray, true)
enzyme_testsuite(CUDABackend, CuArray, true)

0 comments on commit d6c5a72

Please sign in to comment.