I want to design an rrule (from ChainRules) for my kernel. Below is a simple reproducible example.
System info:
Julia 1.10
CUDA v5.4.0
ChainRulesCore v1.23.0
ChainRulesTestUtils v1.13.0
Enzyme v0.12.9 https://github.com/EnzymeAD/Enzyme.jl.git#main
EnzymeTestUtils v0.1.7
KernelAbstractions v0.9.20 https://github.com/JuliaGPU/KernelAbstractions.jl#main
GPU: Nvidia RTX 3090
code
using KernelAbstractions
using ChainRulesCore, Zygote, CUDA, Enzyme, Test
@kernel function example_kenr(@Const(A), A_out)
    index = @index(Global)
    shared_arr = @localmem Float32 (@groupsize()[1], 1)
    shared_arr[@index(Local, Linear)] = A[index]
    A_out[index] = shared_arr[@index(Local, Linear), 1]
    index = @index(Global)
end
function call_example(A, A_out)
    dev = get_backend(A)
    example_kenr(dev, 256)(A, A_out, ndrange=(size(A)[1]))
    KernelAbstractions.synchronize(dev)
    return nothing
end
A=CUDA.ones(10).*2
A_out=CUDA.ones(10)
call_example(A,A_out)
@test A_out == CUDA.ones(10).*2
function ChainRulesCore.rrule(::typeof(call_example), A, A_out)
    # modify A_out by mutation
    call_example(A, A_out)
    function call_test_kernel1_pullback(d_A_out)
        d_A_out = CuArray(collect(d_A_out))
        d_A = CUDA.zeros(size(A)...)
        Enzyme.autodiff_deferred(Enzyme.Reverse, call_example, Const, Duplicated(A, d_A), Duplicated(A_out, d_A_out))
        # NoTangent for the function itself
        return NoTangent(), d_A, d_A_out
    end
    return A_out, call_test_kernel1_pullback
end
out,pull_back=rrule(call_example,A,A_out)
pull_back(CUDA.ones(10))
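For reference, the kernel above only copies A into A_out, so the gradient I expect from the pullback is just the incoming cotangent passed through to A. Below is a minimal CPU-only sketch of that expected behaviour with a hand-written pullback and no Enzyme or GPU involved. The name `copy_example!` is a hypothetical stand-in for `call_example`, not part of the reproducer, and unlike `call_example` it returns `A_out` so that the primal value returned by the rrule matches what the function itself returns.

```julia
using ChainRulesCore

# Hypothetical CPU stand-in for call_example: copies A into A_out in place.
# Assumption: returns A_out (the original returns nothing).
function copy_example!(A, A_out)
    A_out .= A
    return A_out
end

function ChainRulesCore.rrule(::typeof(copy_example!), A, A_out)
    # forward pass, mutating A_out
    out = copy_example!(A, A_out)
    function copy_example_pullback(d_A_out)
        # A_out[i] = A[i], so the cotangent flows unchanged to A
        d_A = copy(unthunk(d_A_out))
        # the incoming contents of A_out are overwritten, so they get no gradient
        return NoTangent(), d_A, ZeroTangent()
    end
    return out, copy_example_pullback
end

A = fill(2f0, 10)
A_out = ones(Float32, 10)
out, pull_back = rrule(copy_example!, A, A_out)
_, d_A, _ = pull_back(ones(Float32, 10))
```

With these inputs `out` should equal `A` and `d_A` should be all ones, which is the result the Enzyme-based pullback would need to reproduce on the GPU.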
error