From aea063cea1b6fab688ebc08c194faf7e56d884bc Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 28 Sep 2023 15:19:04 -0400 Subject: [PATCH] Add EnzymeRule for conv!/gather!/scatter!/dropout!/pool! (#536) * Add EnzymeRule for conv * Fix * Add missing file * Change to enzymecore ext * attempt fix * Finish enzymecore rewrite * Add missing file * Also add gather * Additional functions, tests, and fixes * minor fixup * Add pooling * Add dropout * Fix scatter bug * fix pool * More fixups * fix up pool * split conv/depth * Cleanup rule * Fix minor test bug * Fix depthwise conv * Fix typo * Bound tests * Failing to extension * Add file * Address review * Remove inlining --- Project.toml | 19 +- ext/NNlibEnzymeCoreExt/NNlibEnzymeCoreExt.jl | 382 +++++++++++++++++++ test/conv.jl | 46 +++ test/dropout.jl | 31 +- test/gather.jl | 23 ++ test/pooling.jl | 22 ++ test/runtests.jl | 4 + test/scatter.jl | 23 ++ 8 files changed, 542 insertions(+), 8 deletions(-) create mode 100644 ext/NNlibEnzymeCoreExt/NNlibEnzymeCoreExt.jl diff --git a/Project.toml b/Project.toml index b33b90564..7fad9dfed 100644 --- a/Project.toml +++ b/Project.toml @@ -16,24 +16,27 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [weakdeps] AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" -cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" +cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" [extensions] NNlibAMDGPUExt = "AMDGPU" -NNlibCUDAExt = "CUDA" NNlibCUDACUDNNExt = ["CUDA", "cuDNN"] +NNlibCUDAExt = "CUDA" +NNlibEnzymeCoreExt = "EnzymeCore" [compat] AMDGPU = "0.5, 0.6" Adapt = "3.2" Atomix = "0.1" -ChainRulesCore = "1.13" CUDA = "4, 5" -cuDNN = "1" +ChainRulesCore = "1.13" +EnzymeCore = "0.5, 0.6" GPUArraysCore = "0.1" KernelAbstractions = "0.9.2" Requires = "1.0" +cuDNN = "1" julia = "1.9" [extras] @@ -41,6 +44,9 @@ AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" +EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" @@ -52,6 +58,5 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" [targets] -test = ["AMDGPU", "CUDA", "ChainRulesTestUtils", "Documenter", - "FiniteDifferences", "ForwardDiff", "Logging", "ReverseDiff", - "StableRNGs", "Test", "UnicodePlots", "Zygote", "cuDNN"] +test = ["AMDGPU", "CUDA", "ChainRulesTestUtils", "Documenter", "FiniteDifferences", "ForwardDiff", "Logging", "ReverseDiff", "StableRNGs", "Test", "UnicodePlots", "Zygote", "cuDNN", + "Enzyme", "EnzymeCore", "EnzymeTestUtils"] diff --git a/ext/NNlibEnzymeCoreExt/NNlibEnzymeCoreExt.jl b/ext/NNlibEnzymeCoreExt/NNlibEnzymeCoreExt.jl new file mode 100644 index 000000000..7624463da --- /dev/null +++ b/ext/NNlibEnzymeCoreExt/NNlibEnzymeCoreExt.jl @@ -0,0 +1,382 @@ +module NNlibEnzymeCoreExt + +using NNlib +import EnzymeCore +using Random + +using EnzymeCore.EnzymeRules + +for (name, dataname, filtername) in ((typeof(NNlib.conv!), NNlib.∇conv_data!, NNlib.∇conv_filter!), + (typeof(NNlib.depthwiseconv!), NNlib.∇depthwiseconv_data!, NNlib.∇depthwiseconv_filter!) ) + @eval begin + + function EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{$name}, ::Type{RT}, + y::EnzymeCore.Annotation{<:AbstractArray{yT, N}}, + x::EnzymeCore.Annotation{<:AbstractArray{xT, N}}, + w::EnzymeCore.Annotation{<:AbstractArray{wT, N}}, + cdims; kwargs...) where {RT, yT, xT, wT, N} + + if typeof(y) <: EnzymeCore.Duplicated || typeof(y) <: EnzymeCore.BatchDuplicated + func.val(y.val, x.val, w.val, cdims.val; kwargs...) + end + + primal = if EnzymeRules.needs_primal(config) + y.val + else + nothing + end + shadow = if EnzymeRules.needs_shadow(config) + y.dval + else + nothing + end + + # Cache x if its overwritten and w is active (and thus required) + cache_x = ( EnzymeRules.overwritten(config)[3] + && !(typeof(w) <: EnzymeCore.Const) + && !(typeof(y) <: EnzymeCore.Const) + ) ? copy(x.val) : nothing + + # Cache w if its overwritten and x is active (and thus required) + cache_w = ( EnzymeRules.overwritten(config)[4] + && !(typeof(x) <: EnzymeCore.Const) + && !(typeof(y) <: EnzymeCore.Const) + ) ? copy(w.val) : nothing + + cache = (cache_x, cache_w) + + return EnzymeRules.AugmentedReturn(primal, shadow, cache) + end + + function EnzymeRules.reverse(config, func::EnzymeCore.Const{$name}, ::Type{RT}, cache, + y::EnzymeCore.Annotation{<:AbstractArray{yT, N}}, + x::EnzymeCore.Annotation{<:AbstractArray{xT, N}}, + w::EnzymeCore.Annotation{<:AbstractArray{wT, N}}, + cdims; kwargs...) where {RT, yT, xT, wT, N} + cache_x, cache_w = cache + + # Don't cache x if not overwritten and w is active (and thus required) + if !(typeof(w) <: EnzymeCore.Const) && !(typeof(y) <: EnzymeCore.Const) + if !EnzymeRules.overwritten(config)[3] + cache_x = x.val + end + end + + # Don't cache w if not overwritten and x is active (and thus required) + if !(typeof(x) <: EnzymeCore.Const) && !(typeof(y) <: EnzymeCore.Const) + if !EnzymeRules.overwritten(config)[4] + cache_w = w.val + end + end + + dys = y.dval + dxs = (typeof(x) <: EnzymeCore.Const) ? dys : x.dval + dws = (typeof(w) <: EnzymeCore.Const) ? dys : w.dval + + if EnzymeRules.width(config) == 1 + dys = (dys,) + dxs = (dxs,) + dws = (dws,) + end + + for (dy, dx, dw) in zip(dys, dxs, dws) + if !(typeof(y) <: EnzymeCore.Const) && dy !== y.val + + if !(typeof(x) <: EnzymeCore.Const) && dx !== x.val + # dx += grad wrt x.val + $dataname(dx, dy, cache_w, cdims.val; alpha=xT(1), beta=xT(1), kwargs...) + end + if !(typeof(w) <: EnzymeCore.Const) && dw !== w.val + # dw += grad wrt w.val + $filtername(dw, cache_x, dy, cdims.val; alpha=wT(1), beta=wT(1), kwargs...) + end + + dy .= 0 + end + end + + return (nothing, nothing, nothing, nothing) + end + +end +end + +function EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{typeof(NNlib.gather!)}, ::Type{RT}, dst::OutType, src, idx::EnzymeCore.Const) where {OutType, RT} + + if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.BatchDuplicated + func.val(dst.val, src.val, idx.val) + end + + primal = if EnzymeRules.needs_primal(config) + dst.val + else + nothing + end + shadow = if EnzymeRules.needs_shadow(config) + dst.dval + else + nothing + end + + # Cache idx if its overwritten + cache_idx = ( EnzymeRules.overwritten(config)[4] + && !(typeof(src) <: EnzymeCore.Const) + && !(typeof(dst) <: EnzymeCore.Const) + ) ? copy(idx.val) : nothing + + return EnzymeRules.AugmentedReturn(primal, shadow, cache_idx) +end + +function EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NNlib.gather!)}, ::Type{RT}, cache_idx, dst::OutType, src, idx::EnzymeCore.Const) where {OutType, RT} + + # Don't cache idx if not overwritten + if !(typeof(src) <: EnzymeCore.Const) && !(typeof(dst) <: EnzymeCore.Const) + if !EnzymeRules.overwritten(config)[4] + cache_idx = idx.val + end + end + + ddsts = dst.dval + dsrcs = (typeof(src) <: EnzymeCore.Const) ? ddsts : src.dval + + if EnzymeRules.width(config) == 1 + ddsts = (ddsts,) + dsrcs = (dsrcs,) + end + + for (ddst, dsrc) in zip(ddsts, dsrcs) + if !(typeof(dst) <: EnzymeCore.Const) && ddst !== dst.val + + if !(typeof(src) <: EnzymeCore.Const) && dsrc !== src.val + NNlib.scatter!(+, dsrc, ddst, cache_idx) + end + + ddst .= 0 + end + end + + return (nothing, nothing, nothing) +end + + + +function EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{typeof(NNlib.scatter!)}, ::Type{RT}, op::EnzymeCore.Const, dst::OutType, src, idx::EnzymeCore.Const) where {OutType, RT} + + @assert !(OutType <: EnzymeCore.Const) + if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.BatchDuplicated + func.val(op.val, dst.val, src.val, idx.val) + end + + primal = if EnzymeRules.needs_primal(config) + dst.val + else + nothing + end + shadow = if EnzymeRules.needs_shadow(config) + dst.dval + else + nothing + end + + # Cache idx if its overwritten + cache_idx = ( EnzymeRules.overwritten(config)[4] + && !(typeof(src) <: EnzymeCore.Const) + && !(typeof(dst) <: EnzymeCore.Const) + ) ? copy(idx.val) : nothing + + return EnzymeRules.AugmentedReturn(primal, shadow, cache_idx) +end + +function EnzymeRules.reverse(config, + func::EnzymeCore.Const{typeof(NNlib.scatter!)}, + ::Type{RT}, + cache_idx, + op::Union{EnzymeCore.Const{typeof(+)},EnzymeCore.Const{typeof(-)}}, dst::OutType, + src, + idx::EnzymeCore.Const) where {OutType, RT} + + # Don't cache idx if not overwritten + if !(typeof(src) <: EnzymeCore.Const) && !(typeof(dst) <: EnzymeCore.Const) + if !EnzymeRules.overwritten(config)[4] + cache_idx = idx.val + end + end + + ddsts = dst.dval + dsrcs = (typeof(src) <: EnzymeCore.Const) ? ddsts : src.dval + + if EnzymeRules.width(config) == 1 + ddsts = (ddsts,) + dsrcs = (dsrcs,) + end + + for (ddst, dsrc) in zip(ddsts, dsrcs) + if !(typeof(dst) <: EnzymeCore.Const) && ddst !== dst.val + + if !(typeof(src) <: EnzymeCore.Const) && dsrc !== src.val + + if eltype(typeof(op)) == typeof(+) + dsrc .+= NNlib.gather(ddst, cache_idx) + else + @assert eltype(typeof(op)) == typeof(-) + dsrc .-= NNlib.gather(ddst, cache_idx) + end + end + + end + end + + return (nothing, nothing, nothing, nothing) +end + + + +for pool in [:maxpool, :meanpool, :lpnormpool] + pool! = Symbol(pool, :!) + ∇pool = Symbol(:∇, pool, :!) + + @eval begin + +function EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{typeof($pool!)}, ::Type{RT}, y::OutType, x, dims; kwargs...) where {OutType, RT} + + if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.BatchDuplicated + func.val(y.val, x.val, dims.val; kwargs...) + end + + primal = if EnzymeRules.needs_primal(config) + y.val + else + nothing + end + shadow = if EnzymeRules.needs_shadow(config) + y.dval + else + nothing + end + + cache_y = ( EnzymeRules.overwritten(config)[2] + && !(typeof(x) <: EnzymeCore.Const) + && !(typeof(y) <: EnzymeCore.Const) + ) ? copy(y.val) : nothing + + cache_x = ( EnzymeRules.overwritten(config)[3] + && !(typeof(x) <: EnzymeCore.Const) + && !(typeof(y) <: EnzymeCore.Const) + ) ? copy(x.val) : nothing + + cache = (cache_y, cache_x) + + return EnzymeRules.AugmentedReturn(primal, shadow, cache) +end + +function EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof($pool!)}, ::Type{RT}, cache, y, x, dims; kwargs...) where {RT} + cache_y, cache_x = cache + + # Don't cache y if not overwritten + if !(typeof(x) <: EnzymeCore.Const) && !(typeof(y) <: EnzymeCore.Const) + if !EnzymeRules.overwritten(config)[2] + cache_y = y.val + end + end + + # Don't cache x if not overwritten + if !(typeof(x) <: EnzymeCore.Const) && !(typeof(y) <: EnzymeCore.Const) + if !EnzymeRules.overwritten(config)[3] + cache_x = x.val + end + end + + dys = y.dval + dxs = (typeof(x) <: EnzymeCore.Const) ? dys : x.dval + + if EnzymeRules.width(config) == 1 + dys = (dys,) + dxs = (dxs,) + end + + for (dy, dx) in zip(dys, dxs) + if !(typeof(y) <: EnzymeCore.Const) && dy !== y.val + + if !(typeof(x) <: EnzymeCore.Const) && dx !== x.val + NNlib.$(∇pool)(dx, dy, cache_y, cache_x, dims.val; alpha=eltype(dx)(1), beta=eltype(dx)(1), kwargs...) + end + + dy .= 0 + end + end + + return (nothing, nothing, nothing) +end + +end +end + +function EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{typeof(NNlib._dropout!)}, ::Type{RT}, rng, dst::OutType, src, p, dims) where {OutType, RT} + + T = float(real(eltype(dst.val))) + val = convert(T, 1/(1-p.val)) + keep = if dims.val isa Colon + similar(dst.val, T, size(dst.val)) + else + similar(dst.val, T, ntuple(d -> d in dims.val ? size(dst.val,d) : 1, ndims(dst.val))) + end + rand!(rng.val, keep) + + keep = keep .> p.val + + if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.BatchDuplicated + dst.val .= (keep .* val) .* src.val + end + + primal = if EnzymeRules.needs_primal(config) + dst.val + else + nothing + end + shadow = if EnzymeRules.needs_shadow(config) + dst.dval + else + nothing + end + + if typeof(dst) <: EnzymeCore.Const || typeof(src) <: EnzymeCore.Const + keep = nothing + end + + return EnzymeRules.AugmentedReturn(primal, shadow, keep) +end + +function EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NNlib._dropout!)}, ::Type{RT}, keep, rng, dst::OutType, src, p, dims) where {OutType, RT} + T = float(real(eltype(dst.val))) + val = convert(T, 1/(1-p.val)) + + ddsts = dst.dval + dsrcs = (typeof(src) <: EnzymeCore.Const) ? ddsts : src.dval + + if EnzymeRules.width(config) == 1 + ddsts = (ddsts,) + dsrcs = (dsrcs,) + end + + for (ddst, dsrc) in zip(ddsts, dsrcs) + if !(typeof(dst) <: EnzymeCore.Const) && ddst !== dst.val + + if !(typeof(src) <: EnzymeCore.Const) && dsrc !== src.val + dsrc .+= (keep .* val) .* ddst + end + + ddst .= 0 + end + end + + dp = if typeof(p) <: EnzymeCore.Active + typeof(p.val)(0) + else + nothing + end + + return (nothing, nothing, nothing, dp, nothing) +end + + +end \ No newline at end of file diff --git a/test/conv.jl b/test/conv.jl index 8107bd387..dc3fc57f5 100644 --- a/test/conv.jl +++ b/test/conv.jl @@ -886,3 +886,49 @@ end gradtest((y, w) -> ∇depthwiseconv_data(y, w, dcdims), y, w) gradtest((y, w) -> sum(∇depthwiseconv_data(y, w, dcdims)), y, w) end + +@static if Test_Enzyme + +@testset "EnzymeRules: conv! spatial_rank=$spatial_rank" for spatial_rank in (1, 2, 3) + x = rand(rng, repeat([5], spatial_rank)..., 3, 2) + w = rand(rng, repeat([3], spatial_rank)..., 3, 3) + + cdims = DenseConvDims(x, w) + + curconv = conv + curconv! = conv! + dst = curconv(x, w, cdims) + + for Tret in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated), + Tdst in (EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated), + Tx in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated), + Tw in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated) + + EnzymeTestUtils.are_activities_compatible(Tret, Tdst, Tx, Tw) || continue + + EnzymeTestUtils.test_reverse(curconv!, Tret, (dst, Tdst), (x, Tx), (w, Tw), (cdims, EnzymeCore.Const)) + end +end + +@testset "EnzymeRules: depthwiseconv! spatial_rank=$spatial_rank" for spatial_rank in (1, 2, 3) + x = rand(rng, repeat([5], spatial_rank)..., 3, 2) + w = rand(rng, repeat([3], spatial_rank)..., 3, 3) + + cdims = DepthwiseConvDims(x, w) + + curconv = depthwiseconv + curconv! = depthwiseconv! + dst = curconv(x, w, cdims) + + for Tret in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated), + Tdst in (EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated), + Tx in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated), + Tw in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated) + + EnzymeTestUtils.are_activities_compatible(Tret, Tdst, Tx, Tw) || continue + + EnzymeTestUtils.test_reverse(curconv!, Tret, (dst, Tdst), (x, Tx), (w, Tw), (cdims, EnzymeCore.Const)) + end +end + +end \ No newline at end of file diff --git a/test/dropout.jl b/test/dropout.jl index 07a48edc5..0da70111e 100644 --- a/test/dropout.jl +++ b/test/dropout.jl @@ -1,5 +1,5 @@ using NNlib, Test, Statistics, Random, LinearAlgebra -using Zygote, StableRNGs, ChainRulesCore +using Zygote, StableRNGs, ChainRulesCore, Enzyme @testset "dropout" begin # Basics @@ -75,3 +75,32 @@ using Zygote, StableRNGs, ChainRulesCore @test_throws ArgumentError dropout(x1, 2) @test_throws ArgumentError dropout!(y1, x1, 3) end + +@static if Test_Enzyme + +@testset "EnzymeRules: dropout " begin + rng = Random.default_rng() + + x1 = randn(Float32, 3000, 4000) + dx1 = zeros(Float32, 3000, 4000) + + dout = randn(Float32, 3000, 4000) + + p = 0.2f0 + + forward, reverse = Enzyme.autodiff_thunk(ReverseSplitWithPrimal, typeof(Const(dropout)), Duplicated, typeof(Const(rng)), typeof(Duplicated(x1, dx1)), typeof(Const(0.2f0))) + + tape, primal, shadow = forward(Const(dropout), Const(rng), Duplicated(x1, dx1), Const(p)) + + shadow .= dout + + reverse(Const(dropout), Const(rng), Duplicated(x1, dx1), Const(p), tape) + + @test dx1[.!tape[1]] ≈ zero(x1)[.!tape[1]] + + val = convert(Float32, 1/(1-p)) + + @test dx1[tape[1]] ≈ (val * dout)[tape[1]] +end + +end \ No newline at end of file diff --git a/test/gather.jl b/test/gather.jl index e3221145b..92e3bfb7d 100644 --- a/test/gather.jl +++ b/test/gather.jl @@ -1,4 +1,6 @@ using NNlib: gather, gather! +import EnzymeTestUtils +using EnzymeCore function gather_testsuite(Backend) device(x) = adapt(Backend(), x) @@ -152,6 +154,27 @@ function gather_testsuite(Backend) gradtest_fn((s, i) -> gather(s, i), src, idx) end + @static if Test_Enzyme + + @testset "EnzymeRules: gather! gradient for scalar index" begin + src = device(Float64[3, 4, 5, 6, 7]) + idx = device([ + 1 2 3 4; + 4 2 1 3; + 3 5 5 3]) + dst = gather(src, idx) + for Tret in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated), + Tdst in (EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated), + Tsrc in (EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated) + + EnzymeTestUtils.are_activities_compatible(Tret, Tdst, Tsrc) || continue + + EnzymeTestUtils.test_reverse(gather!, Tret, (dst, Tdst), (src, Tsrc), (idx, EnzymeCore.Const)) + end + end + + end + @testset "gather gradient for tuple index" begin src = device(Float64[ 3 5 7 diff --git a/test/pooling.jl b/test/pooling.jl index b4b4f40b7..97d014d52 100644 --- a/test/pooling.jl +++ b/test/pooling.jl @@ -967,3 +967,25 @@ end gradtest(x -> sum(maxpool(x, k)), x, skip = spatial_rank==2) gradtest(x -> sum(meanpool(x, k)), x) end + +@static if Test_Enzyme + +@testset "EnzymeRules: pooling! $pool spatial_rank=$spatial_rank " for spatial_rank in (1, 2), + (pool, pool!) in ((maxpool, maxpool!), (meanpool, meanpool!)) + + x = rand(rng, repeat([10], spatial_rank)..., 3, 2) + pdims = PoolDims(x, 2) + y = pool(x, pdims) + + for Tret in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated), + Tdst in (EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated), + Tsrc in (EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated) + + EnzymeTestUtils.are_activities_compatible(Tret, Tdst, Tsrc) || continue + + EnzymeTestUtils.test_reverse(pool!, Tret, (y, Tdst), (x, Tsrc), (pdims, EnzymeCore.Const)) + end + +end + +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 03602a40d..8b359ad87 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,6 +1,8 @@ using NNlib, Test, Statistics, Random using ChainRulesCore, ChainRulesTestUtils using Base.Broadcast: broadcasted +import EnzymeTestUtils +using EnzymeCore import FiniteDifferences import ForwardDiff import Zygote @@ -11,6 +13,8 @@ using Adapt using KernelAbstractions import ReverseDiff as RD # used in `pooling.jl` +const Test_Enzyme = VERSION <= v"1.10" && !Sys.iswindows() + DocMeta.setdocmeta!(NNlib, :DocTestSetup, :(using NNlib, UnicodePlots); recursive=true) # ENV["NNLIB_TEST_CUDA"] = "true" # uncomment to run CUDA tests diff --git a/test/scatter.jl b/test/scatter.jl index 26fc06cde..7e06e2650 100644 --- a/test/scatter.jl +++ b/test/scatter.jl @@ -207,5 +207,28 @@ function scatter_testsuite(Backend) gradtest_fn((xs, i) -> scatter(op, xs, i), src, idx) end end + + @static if Test_Enzyme + + @testset "EnzymeRules" begin + idx = device([2, 2, 3, 4, 4]) + src = device(ones(T, 3, 5)) + + for op in (+, -) + + dst = scatter(op, src, idx) + + for Tret in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated), + Tdst in (EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated), + Tsrc in (EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated) + + EnzymeTestUtils.are_activities_compatible(Tret, Tdst, Tsrc) || continue + + EnzymeTestUtils.test_reverse(scatter!, Tret, (op, EnzymeCore.Const), (dst, Tdst), (src, Tsrc), (idx, EnzymeCore.Const)) + end + end + end + + end end end