Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add EnzymeRules for batchnorm #537

Open
wants to merge 15 commits into
base: master
Choose a base branch
from
Prev Previous commit
rebase
wsmoses committed Oct 8, 2023

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
commit 1933a9afe32ff020ba550f34ee882c1a557f5f33
8 changes: 2 additions & 6 deletions ext/NNlibCUDACUDNNExt/batchnorm.jl
Original file line number Diff line number Diff line change
@@ -186,9 +186,7 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{
cache_running_var = nothing

if !(typeof(y) <: EnzymeCore.Const)
if !(typeof(x) <: EnzymeCore.Const)
|| !(typeof(g) <: EnzymeCore.Const)
|| !(typeof(b) <: EnzymeCore.Const)
if !(typeof(x) <: EnzymeCore.Const) || !(typeof(g) <: EnzymeCore.Const) || !(typeof(b) <: EnzymeCore.Const)

if EnzymeCore.EnzymeRules.overwritten(config)[3]
cache_g = copy(g.val)
@@ -218,9 +216,7 @@ function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(cu
cache_g, cache_x, cache_running_mean, cache_running_var = cache

if !(typeof(y) <: EnzymeCore.Const)
if !(typeof(x) <: EnzymeCore.Const)
|| !(typeof(g) <: EnzymeCore.Const)
|| !(typeof(b) <: EnzymeCore.Const)
if !(typeof(x) <: EnzymeCore.Const) || !(typeof(g) <: EnzymeCore.Const) || !(typeof(b) <: EnzymeCore.Const)

if EnzymeCore.EnzymeRules.overwritten(config)[3]
cache_g = g.val
2 changes: 0 additions & 2 deletions src/NNlib.jl
Original file line number Diff line number Diff line change
@@ -123,6 +123,4 @@ include("impl/depthwiseconv_im2col.jl")
include("impl/pooling_direct.jl")
include("deprecations.jl")

include("enzyme.jl")

end # module NNlib
358 changes: 0 additions & 358 deletions src/enzyme.jl

This file was deleted.

18 changes: 0 additions & 18 deletions test/gather.jl
Original file line number Diff line number Diff line change
@@ -154,24 +154,6 @@ function gather_testsuite(Backend)
gradtest_fn((s, i) -> gather(s, i), src, idx)
end


@testset "EnzymeRules: gather! gradient for scalar index" begin
src = device(Float64[3, 4, 5, 6, 7])
idx = device([
1 2 3 4;
4 2 1 3;
3 5 5 3])
dst = gather(src, idx)
for Tret in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated),
Tdst in (EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated),
Tsrc in (EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated)

EnzymeTestUtils.are_activities_compatible(Tret, Tdst, Tsrc) || continue

EnzymeTestUtils.test_reverse(gather!, Tret, (dst, Tdst), (src, Tsrc), (idx, EnzymeCore.Const))
end
end

@static if Test_Enzyme

@testset "EnzymeRules: gather! gradient for scalar index" begin