From 5966cf953336e30aefa3ed4a8e7f2ae5b1fa5fdf Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Sun, 24 Sep 2023 02:03:26 -0500 Subject: [PATCH 01/26] Add EnzymeRule for conv --- Project.toml | 6 ++- ext/NNlibEnzymeExt.jl | 88 +++++++++++++++++++++++++++++++++++++++++++ src/NNlib.jl | 8 ++++ 3 files changed, 101 insertions(+), 1 deletion(-) create mode 100644 ext/NNlibEnzymeExt.jl diff --git a/Project.toml b/Project.toml index b33b90564..aecfcf057 100644 --- a/Project.toml +++ b/Project.toml @@ -6,6 +6,7 @@ version = "0.9.6" Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" Atomix = "a9b6321e-bd34-4604-b9c9-b65b8de01458" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" @@ -16,6 +17,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [weakdeps] AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" @@ -23,6 +25,7 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" NNlibAMDGPUExt = "AMDGPU" NNlibCUDAExt = "CUDA" NNlibCUDACUDNNExt = ["CUDA", "cuDNN"] +NNlibEnzymeExt = "Enzyme" [compat] AMDGPU = "0.5, 0.6" @@ -41,6 +44,7 @@ AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" @@ -54,4 +58,4 @@ cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" [targets] test = ["AMDGPU", "CUDA", "ChainRulesTestUtils", "Documenter", "FiniteDifferences", "ForwardDiff", "Logging", "ReverseDiff", - "StableRNGs", "Test", "UnicodePlots", "Zygote", "cuDNN"] + "StableRNGs", "Test", "UnicodePlots", "Zygote", "cuDNN", "Enzyme"] diff --git a/ext/NNlibEnzymeExt.jl b/ext/NNlibEnzymeExt.jl new file mode 100644 index 000000000..16d3dd809 --- /dev/null +++ b/ext/NNlibEnzymeExt.jl @@ -0,0 +1,88 @@ +module NNlibEnzymeExt + +using NNlib +isdefined(Base, :get_extension) ? (import Enzyme) : (import ..Enzyme) + +using Enzyme + +using EnzymeCore + +function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(NNlib.conv!)}, ::Type{RT}, y::OutType, x, w, cdims::Const; kwargs...) +) where {OutType, RT} + + @assert !(OutType <: Const) + if OutType <: Duplicated || OutType <: DuplicatedNoNeed + func.val(y.val, x.val, y.val, cdims.val; kwargs...) + end + + dres = if EnzymeRules.width(config) == 1 + func.val(prob.dval, alg.val; kwargs...) + else + ntuple(Val(EnzymeRules.width(config))) do i + Base.@_inline_meta + func.val(prob.dval[i], alg.val; kwargs...) + end + end + + primal = if EnzymeRules.needs_primal(config) + y.val + else + nothing + end + shadow = if EnzymeRules.needs_shadow(config) + y.dval + else + nothing + end + + # Cache x if its overwritten and w is active (and thus required) + cache_x = ( EnzymeRules.overwritten(config)[3] && !(typeof(w) <: Const) ) ? copy(x.val) : nothing + + # Cache w if its overwritten and x is active (and thus required) + cache_w = ( EnzymeRules.overwritten(config)[4] && !(typeof(x) <: Const) ) ? copy(w.val) : nothing + + cache = (cache_x, cache_w) + + return EnzymeCore.EnzymeRules.AugmentedReturn(y.val, y.dval, cache) +end + +function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(NNlib.conv!)}, ::Type{RT}, cache, y, x, w, cdims::Const; kwargs...) where {RT} + cache_x, cache_w = cache + + # Don't cache x if not overwritten and w is active (and thus required) + if !(typeof(w) <: Const) + if !EnzymeRules.overwritten(config)[3] + cache_x = x.val + end + end + + # Don't cache w if not overwritten and x is active (and thus required) + if !(typeof(x) <: Const) + if !EnzymeRules.overwritten(config)[4] + cache_w = w.val + end + end + + dys = y.dval + dxs = (typeof(x) <: Const) ? nothing : x.dval + dws = (typeof(w) <: Const) ? nothing : w.dval + + if EnzymeRules.width(config) == 1 + dys = (dys,) + dxs = (dxs,) + dws = (dws,) + end + + for (dy, dx, dw) in (dys, dxs, dws) + if !(typeof(x) <: Const) + # dx += grad wrt x + NNlib.∇conv_data!(dx, dy, cache_w, cdims; alpha=1, beta=1, kwargs...) + end + if !(typeof(y) <: Const) + # dw += grad wrt w + NNlib.∇conv_filter!(dw, cache_x, dy, cdims; alpha=1, beta=1, kwargs...) + end + end + + return (nothing, nothing, nothing, nothing) +end \ No newline at end of file diff --git a/src/NNlib.jl b/src/NNlib.jl index 8450a0261..2e39f9448 100644 --- a/src/NNlib.jl +++ b/src/NNlib.jl @@ -123,4 +123,12 @@ include("impl/depthwiseconv_im2col.jl") include("impl/pooling_direct.jl") include("deprecations.jl") +function __init__() + @static if !isdefined(Base, :get_extension) + @require Enzyme="7da242da-08ed-463a-9acd-ee780be4f1d9" begin + include("../ext/NNlibEnzymeExt.jl") + end + end +end + end # module NNlib From 9356ca3987f5567627458ac32231378aee76631c Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Sun, 24 Sep 2023 02:52:38 -0500 Subject: [PATCH 02/26] Fix --- Project.toml | 1 + ext/NNlibEnzymeExt.jl | 88 ------------------------------------------- src/NNlib.jl | 8 ++-- 3 files changed, 4 insertions(+), 93 deletions(-) delete mode 100644 ext/NNlibEnzymeExt.jl diff --git a/Project.toml b/Project.toml index aecfcf057..a9064ab57 100644 --- a/Project.toml +++ b/Project.toml @@ -34,6 +34,7 @@ Atomix = "0.1" ChainRulesCore = "1.13" CUDA = "4, 5" cuDNN = "1" +Enzyme = "0.11.8" GPUArraysCore = "0.1" KernelAbstractions = "0.9.2" Requires = "1.0" diff --git a/ext/NNlibEnzymeExt.jl b/ext/NNlibEnzymeExt.jl deleted file mode 100644 index 16d3dd809..000000000 --- a/ext/NNlibEnzymeExt.jl +++ /dev/null @@ -1,88 +0,0 @@ -module NNlibEnzymeExt - -using NNlib -isdefined(Base, :get_extension) ? (import Enzyme) : (import ..Enzyme) - -using Enzyme - -using EnzymeCore - -function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(NNlib.conv!)}, ::Type{RT}, y::OutType, x, w, cdims::Const; kwargs...) -) where {OutType, RT} - - @assert !(OutType <: Const) - if OutType <: Duplicated || OutType <: DuplicatedNoNeed - func.val(y.val, x.val, y.val, cdims.val; kwargs...) - end - - dres = if EnzymeRules.width(config) == 1 - func.val(prob.dval, alg.val; kwargs...) - else - ntuple(Val(EnzymeRules.width(config))) do i - Base.@_inline_meta - func.val(prob.dval[i], alg.val; kwargs...) - end - end - - primal = if EnzymeRules.needs_primal(config) - y.val - else - nothing - end - shadow = if EnzymeRules.needs_shadow(config) - y.dval - else - nothing - end - - # Cache x if its overwritten and w is active (and thus required) - cache_x = ( EnzymeRules.overwritten(config)[3] && !(typeof(w) <: Const) ) ? copy(x.val) : nothing - - # Cache w if its overwritten and x is active (and thus required) - cache_w = ( EnzymeRules.overwritten(config)[4] && !(typeof(x) <: Const) ) ? copy(w.val) : nothing - - cache = (cache_x, cache_w) - - return EnzymeCore.EnzymeRules.AugmentedReturn(y.val, y.dval, cache) -end - -function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(NNlib.conv!)}, ::Type{RT}, cache, y, x, w, cdims::Const; kwargs...) where {RT} - cache_x, cache_w = cache - - # Don't cache x if not overwritten and w is active (and thus required) - if !(typeof(w) <: Const) - if !EnzymeRules.overwritten(config)[3] - cache_x = x.val - end - end - - # Don't cache w if not overwritten and x is active (and thus required) - if !(typeof(x) <: Const) - if !EnzymeRules.overwritten(config)[4] - cache_w = w.val - end - end - - dys = y.dval - dxs = (typeof(x) <: Const) ? nothing : x.dval - dws = (typeof(w) <: Const) ? nothing : w.dval - - if EnzymeRules.width(config) == 1 - dys = (dys,) - dxs = (dxs,) - dws = (dws,) - end - - for (dy, dx, dw) in (dys, dxs, dws) - if !(typeof(x) <: Const) - # dx += grad wrt x - NNlib.∇conv_data!(dx, dy, cache_w, cdims; alpha=1, beta=1, kwargs...) - end - if !(typeof(y) <: Const) - # dw += grad wrt w - NNlib.∇conv_filter!(dw, cache_x, dy, cdims; alpha=1, beta=1, kwargs...) - end - end - - return (nothing, nothing, nothing, nothing) -end \ No newline at end of file diff --git a/src/NNlib.jl b/src/NNlib.jl index 2e39f9448..fd9249176 100644 --- a/src/NNlib.jl +++ b/src/NNlib.jl @@ -123,11 +123,9 @@ include("impl/depthwiseconv_im2col.jl") include("impl/pooling_direct.jl") include("deprecations.jl") -function __init__() - @static if !isdefined(Base, :get_extension) - @require Enzyme="7da242da-08ed-463a-9acd-ee780be4f1d9" begin - include("../ext/NNlibEnzymeExt.jl") - end +@init @static if !isdefined(Base, :get_extension) + @require Enzyme="7da242da-08ed-463a-9acd-ee780be4f1d9" begin + include("../ext/NNlibEnzymeExt/NNlibEnzymeExt.jl") end end From 280e00a66d7ffa9884d3ff199feb7250712415c7 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Sun, 24 Sep 2023 02:53:59 -0500 Subject: [PATCH 03/26] Add missing file --- ext/NNlibEnzymeExt/NNlibEnzymeExt.jl | 79 ++++++++++++++++++++++++++++ 1 file changed, 79 insertions(+) create mode 100644 ext/NNlibEnzymeExt/NNlibEnzymeExt.jl diff --git a/ext/NNlibEnzymeExt/NNlibEnzymeExt.jl b/ext/NNlibEnzymeExt/NNlibEnzymeExt.jl new file mode 100644 index 000000000..874764374 --- /dev/null +++ b/ext/NNlibEnzymeExt/NNlibEnzymeExt.jl @@ -0,0 +1,79 @@ +module NNlibEnzymeExt + +using NNlib +isdefined(Base, :get_extension) ? (import Enzyme) : (import ..Enzyme) + +using EnzymeCore + +function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(NNlib.conv!)}, ::Type{RT}, y::OutType, x, w, cdims; kwargs...) where {OutType, RT} + + @assert !(OutType <: Const) + if OutType <: Duplicated || OutType <: DuplicatedNoNeed + func.val(y.val, x.val, w.val, cdims.val; kwargs...) + end + + primal = if EnzymeCore.EnzymeRules.needs_primal(config) + y.val + else + nothing + end + shadow = if EnzymeCore.EnzymeRules.needs_shadow(config) + y.dval + else + nothing + end + + # Cache x if its overwritten and w is active (and thus required) + cache_x = ( EnzymeCore.EnzymeRules.overwritten(config)[3] && !(typeof(w) <: Const) ) ? copy(x.val) : nothing + + # Cache w if its overwritten and x is active (and thus required) + cache_w = ( EnzymeCore.EnzymeRules.overwritten(config)[4] && !(typeof(x) <: Const) ) ? copy(w.val) : nothing + + cache = (cache_x, cache_w) + + return EnzymeCore.EnzymeRules.AugmentedReturn(primal, shadow, cache) +end + +function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(NNlib.conv!)}, ::Type{RT}, cache, y, x, w, cdims; kwargs...) where {RT} + cache_x, cache_w = cache + + # Don't cache x if not overwritten and w is active (and thus required) + if !(typeof(w) <: Const) + if !EnzymeCore.EnzymeRules.overwritten(config)[3] + cache_x = x.val + end + end + + # Don't cache w if not overwritten and x is active (and thus required) + if !(typeof(x) <: Const) + if !EnzymeCore.EnzymeRules.overwritten(config)[4] + cache_w = w.val + end + end + + dys = y.dval + dxs = (typeof(x) <: Const) ? dys : x.dval + dws = (typeof(w) <: Const) ? dys : w.dval + + if EnzymeCore.EnzymeRules.width(config) == 1 + dys = (dys,) + dxs = (dxs,) + dws = (dws,) + end + + for (dy, dx, dw) in zip(dys, dxs, dws) + if !(typeof(x) <: Const) && dx !== x + # dx += grad wrt x + NNlib.∇conv_data!(dx, dy, cache_w, cdims.val; alpha=eltype(dw)(1), beta=eltype(dw)(1), kwargs...) + end + if !(typeof(w) <: Const) && dw !== w + # dw += grad wrt w + NNlib.∇conv_filter!(dw, cache_x, dy, cdims.val; alpha=eltype(dw)(1), beta=eltype(dw)(1), kwargs...) + end + dy .= 0 + end + + return (nothing, nothing, nothing, nothing) +end + +end \ No newline at end of file From ad4a39dd6abd76d6a9509095c1a162a5ff3ef178 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Sun, 24 Sep 2023 11:00:56 -0500 Subject: [PATCH 04/26] Change to enzymecore ext --- Project.toml | 6 ++---- .../NNlibEnzymeCoreExt.jl} | 0 src/NNlib.jl | 4 ++-- 3 files changed, 4 insertions(+), 6 deletions(-) rename ext/{NNlibEnzymeExt/NNlibEnzymeExt.jl => NNlibEnzymeCoreExt/NNlibEnzymeCoreExt.jl} (100%) diff --git a/Project.toml b/Project.toml index a9064ab57..cbd243e4a 100644 --- a/Project.toml +++ b/Project.toml @@ -6,7 +6,6 @@ version = "0.9.6" Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" Atomix = "a9b6321e-bd34-4604-b9c9-b65b8de01458" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" @@ -17,15 +16,15 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [weakdeps] AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" -Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" [extensions] NNlibAMDGPUExt = "AMDGPU" NNlibCUDAExt = "CUDA" NNlibCUDACUDNNExt = ["CUDA", "cuDNN"] -NNlibEnzymeExt = "Enzyme" +NNlibEnzymeCoreExt = "EnzymeCore" [compat] AMDGPU = "0.5, 0.6" @@ -34,7 +33,6 @@ Atomix = "0.1" ChainRulesCore = "1.13" CUDA = "4, 5" cuDNN = "1" -Enzyme = "0.11.8" GPUArraysCore = "0.1" KernelAbstractions = "0.9.2" Requires = "1.0" diff --git a/ext/NNlibEnzymeExt/NNlibEnzymeExt.jl b/ext/NNlibEnzymeCoreExt/NNlibEnzymeCoreExt.jl similarity index 100% rename from ext/NNlibEnzymeExt/NNlibEnzymeExt.jl rename to ext/NNlibEnzymeCoreExt/NNlibEnzymeCoreExt.jl diff --git a/src/NNlib.jl b/src/NNlib.jl index fd9249176..c4ad18750 100644 --- a/src/NNlib.jl +++ b/src/NNlib.jl @@ -124,8 +124,8 @@ include("impl/pooling_direct.jl") include("deprecations.jl") @init @static if !isdefined(Base, :get_extension) - @require Enzyme="7da242da-08ed-463a-9acd-ee780be4f1d9" begin - include("../ext/NNlibEnzymeExt/NNlibEnzymeExt.jl") + @require EnzymeCore="f151be2c-9106-41f4-ab19-57ee4f262869" begin + include("../ext/NNlibEnzymeCoreExt/NNlibEnzymeCoresExt.jl") end end From b01f2c2a8dacec1e605624e7f630484c87a498a2 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Sun, 24 Sep 2023 12:15:49 -0500 Subject: [PATCH 05/26] attempt fix --- Project.toml | 16 +++++------ ext/NNlibEnzymeCoreExt/NNlibEnzymeCoreExt.jl | 30 +++++++++----------- src/NNlib.jl | 6 +--- test/conv.jl | 2 +- test/runtests.jl | 1 + test/test_utils.jl | 18 +++++++++++- 6 files changed, 41 insertions(+), 32 deletions(-) diff --git a/Project.toml b/Project.toml index cbd243e4a..8ea3da8f0 100644 --- a/Project.toml +++ b/Project.toml @@ -6,6 +6,7 @@ version = "0.9.6" Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" Atomix = "a9b6321e-bd34-4604-b9c9-b65b8de01458" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" @@ -16,26 +17,24 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [weakdeps] AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" -cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" -EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" +cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" [extensions] NNlibAMDGPUExt = "AMDGPU" -NNlibCUDAExt = "CUDA" NNlibCUDACUDNNExt = ["CUDA", "cuDNN"] -NNlibEnzymeCoreExt = "EnzymeCore" +NNlibCUDAExt = "CUDA" [compat] AMDGPU = "0.5, 0.6" Adapt = "3.2" Atomix = "0.1" -ChainRulesCore = "1.13" CUDA = "4, 5" -cuDNN = "1" +ChainRulesCore = "1.13" GPUArraysCore = "0.1" KernelAbstractions = "0.9.2" Requires = "1.0" +cuDNN = "1" julia = "1.9" [extras] @@ -44,6 +43,7 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" @@ -55,6 +55,4 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" [targets] -test = ["AMDGPU", "CUDA", "ChainRulesTestUtils", "Documenter", - "FiniteDifferences", "ForwardDiff", "Logging", "ReverseDiff", - "StableRNGs", "Test", "UnicodePlots", "Zygote", "cuDNN", "Enzyme"] +test = ["AMDGPU", "CUDA", "ChainRulesTestUtils", "Documenter", "FiniteDifferences", "ForwardDiff", "Logging", "ReverseDiff", "StableRNGs", "Test", "UnicodePlots", "Zygote", "cuDNN", "Enzyme", "EnzymeTestUtils"] diff --git a/ext/NNlibEnzymeCoreExt/NNlibEnzymeCoreExt.jl b/ext/NNlibEnzymeCoreExt/NNlibEnzymeCoreExt.jl index 874764374..3d109fa06 100644 --- a/ext/NNlibEnzymeCoreExt/NNlibEnzymeCoreExt.jl +++ b/ext/NNlibEnzymeCoreExt/NNlibEnzymeCoreExt.jl @@ -1,14 +1,12 @@ -module NNlibEnzymeExt +module NNlibEnzymeCoreExt using NNlib -isdefined(Base, :get_extension) ? (import Enzyme) : (import ..Enzyme) +isdefined(Base, :get_extension) ? (import EnzymeCore) : (import ..EnzymeCore) -using EnzymeCore +function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{typeof(NNlib.conv!)}, ::Type{RT}, y::OutType, x, w, cdims; kwargs...) where {OutType, RT} -function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(NNlib.conv!)}, ::Type{RT}, y::OutType, x, w, cdims; kwargs...) where {OutType, RT} - - @assert !(OutType <: Const) - if OutType <: Duplicated || OutType <: DuplicatedNoNeed + @assert !(OutType <: EnzymeCore.Const) + if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.DuplicatedNoNeed func.val(y.val, x.val, w.val, cdims.val; kwargs...) end @@ -24,36 +22,36 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(NNli end # Cache x if its overwritten and w is active (and thus required) - cache_x = ( EnzymeCore.EnzymeRules.overwritten(config)[3] && !(typeof(w) <: Const) ) ? copy(x.val) : nothing + cache_x = ( EnzymeCore.EnzymeRules.overwritten(config)[3] && !(typeof(w) <: EnzymeCore.Const) ) ? copy(x.val) : nothing # Cache w if its overwritten and x is active (and thus required) - cache_w = ( EnzymeCore.EnzymeRules.overwritten(config)[4] && !(typeof(x) <: Const) ) ? copy(w.val) : nothing + cache_w = ( EnzymeCore.EnzymeRules.overwritten(config)[4] && !(typeof(x) <: EnzymeCore.Const) ) ? copy(w.val) : nothing cache = (cache_x, cache_w) return EnzymeCore.EnzymeRules.AugmentedReturn(primal, shadow, cache) end -function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(NNlib.conv!)}, ::Type{RT}, cache, y, x, w, cdims; kwargs...) where {RT} +function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NNlib.conv!)}, ::Type{RT}, cache, y, x, w, cdims; kwargs...) where {RT} cache_x, cache_w = cache # Don't cache x if not overwritten and w is active (and thus required) - if !(typeof(w) <: Const) + if !(typeof(w) <: EnzymeCore.Const) if !EnzymeCore.EnzymeRules.overwritten(config)[3] cache_x = x.val end end # Don't cache w if not overwritten and x is active (and thus required) - if !(typeof(x) <: Const) + if !(typeof(x) <: EnzymeCore.Const) if !EnzymeCore.EnzymeRules.overwritten(config)[4] cache_w = w.val end end dys = y.dval - dxs = (typeof(x) <: Const) ? dys : x.dval - dws = (typeof(w) <: Const) ? dys : w.dval + dxs = (typeof(x) <: EnzymeCore.Const) ? dys : x.dval + dws = (typeof(w) <: EnzymeCore.Const) ? dys : w.dval if EnzymeCore.EnzymeRules.width(config) == 1 dys = (dys,) @@ -62,11 +60,11 @@ function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(NNlib.conv!)} end for (dy, dx, dw) in zip(dys, dxs, dws) - if !(typeof(x) <: Const) && dx !== x + if !(typeof(x) <: EnzymeCore.Const) && dx !== x # dx += grad wrt x NNlib.∇conv_data!(dx, dy, cache_w, cdims.val; alpha=eltype(dw)(1), beta=eltype(dw)(1), kwargs...) end - if !(typeof(w) <: Const) && dw !== w + if !(typeof(w) <: EnzymeCore.Const) && dw !== w # dw += grad wrt w NNlib.∇conv_filter!(dw, cache_x, dy, cdims.val; alpha=eltype(dw)(1), beta=eltype(dw)(1), kwargs...) end diff --git a/src/NNlib.jl b/src/NNlib.jl index c4ad18750..14ba2c70f 100644 --- a/src/NNlib.jl +++ b/src/NNlib.jl @@ -123,10 +123,6 @@ include("impl/depthwiseconv_im2col.jl") include("impl/pooling_direct.jl") include("deprecations.jl") -@init @static if !isdefined(Base, :get_extension) - @require EnzymeCore="f151be2c-9106-41f4-ab19-57ee4f262869" begin - include("../ext/NNlibEnzymeCoreExt/NNlibEnzymeCoresExt.jl") - end -end +include("enzyme.jl") end # module NNlib diff --git a/test/conv.jl b/test/conv.jl index 8107bd387..9b9c046a3 100644 --- a/test/conv.jl +++ b/test/conv.jl @@ -870,7 +870,7 @@ end w = rand(rng, repeat([3], spatial_rank)..., 3, 3) cdims = DenseConvDims(x, w) gradtest((x, w) -> conv(x, w, cdims), x, w) - gradtest((x, w) -> sum(conv(x, w, cdims)), x, w) # https://github.com/FluxML/Flux.jl/issues/1055 + gradtest((x, w) -> sum(conv(x, w, cdims)), x, w; check_enzyme_rule=true) # https://github.com/FluxML/Flux.jl/issues/1055 y = conv(x, w, cdims) gradtest((y, w) -> ∇conv_data(y, w, cdims), y, w) diff --git a/test/runtests.jl b/test/runtests.jl index 03602a40d..660f9167f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,6 +1,7 @@ using NNlib, Test, Statistics, Random using ChainRulesCore, ChainRulesTestUtils using Base.Broadcast: broadcasted +import EnzymeTestUtils import FiniteDifferences import ForwardDiff import Zygote diff --git a/test/test_utils.jl b/test/test_utils.jl index 16b3998dc..598f0f66b 100644 --- a/test/test_utils.jl +++ b/test/test_utils.jl @@ -12,7 +12,7 @@ Applies also `ChainRulesTestUtils.test_rrule` if the rrule for `f` is explicitly """ function gradtest( f, xs...; atol = 1e-6, rtol = 1e-6, fkwargs = NamedTuple(), - check_rrule = false, fdm = :central, check_broadcast = false, + check_rrule = false, check_enzyme_rrule = false, fdm = :central, check_broadcast = false, skip = false, broken = false, ) # TODO: revamp when https://github.com/JuliaDiff/ChainRulesTestUtils.jl/pull/166 @@ -20,6 +20,22 @@ function gradtest( if check_rrule test_rrule(f, xs...; fkwargs = fkwargs) end + if check_enzyme_rrule + if len(xs) == 2 + for Tret in (Const, Active), + Tx in (Const, Duplicated, BatchDuplicated), + Ty in (Const, Duplicated, BatchDuplicated) + + are_activities_compatible(Tret, Tx, Ty) || continue + + test_reverse(fun, Tret, (xs[1], Tx), (ys[1], Ty); atol, rtol) + end + else + throw(AssertionError("Unsupported arg count for testing")) + end + + EnzymeTestUtils.test_rrule(f, xs...; fkwargs = fkwargs) + end if check_broadcast length(fkwargs) > 0 && @warn("CHECK_BROADCAST: dropping keywords args") From a08f6fe8ade9d8cbb1bca0cea0735fb278772c18 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Sun, 24 Sep 2023 12:19:24 -0500 Subject: [PATCH 06/26] Finish enzymecore rewrite --- ext/NNlibEnzymeCoreExt/NNlibEnzymeCoreExt.jl | 77 -------------------- 1 file changed, 77 deletions(-) delete mode 100644 ext/NNlibEnzymeCoreExt/NNlibEnzymeCoreExt.jl diff --git a/ext/NNlibEnzymeCoreExt/NNlibEnzymeCoreExt.jl b/ext/NNlibEnzymeCoreExt/NNlibEnzymeCoreExt.jl deleted file mode 100644 index 3d109fa06..000000000 --- a/ext/NNlibEnzymeCoreExt/NNlibEnzymeCoreExt.jl +++ /dev/null @@ -1,77 +0,0 @@ -module NNlibEnzymeCoreExt - -using NNlib -isdefined(Base, :get_extension) ? (import EnzymeCore) : (import ..EnzymeCore) - -function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{typeof(NNlib.conv!)}, ::Type{RT}, y::OutType, x, w, cdims; kwargs...) where {OutType, RT} - - @assert !(OutType <: EnzymeCore.Const) - if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.DuplicatedNoNeed - func.val(y.val, x.val, w.val, cdims.val; kwargs...) - end - - primal = if EnzymeCore.EnzymeRules.needs_primal(config) - y.val - else - nothing - end - shadow = if EnzymeCore.EnzymeRules.needs_shadow(config) - y.dval - else - nothing - end - - # Cache x if its overwritten and w is active (and thus required) - cache_x = ( EnzymeCore.EnzymeRules.overwritten(config)[3] && !(typeof(w) <: EnzymeCore.Const) ) ? copy(x.val) : nothing - - # Cache w if its overwritten and x is active (and thus required) - cache_w = ( EnzymeCore.EnzymeRules.overwritten(config)[4] && !(typeof(x) <: EnzymeCore.Const) ) ? copy(w.val) : nothing - - cache = (cache_x, cache_w) - - return EnzymeCore.EnzymeRules.AugmentedReturn(primal, shadow, cache) -end - -function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NNlib.conv!)}, ::Type{RT}, cache, y, x, w, cdims; kwargs...) where {RT} - cache_x, cache_w = cache - - # Don't cache x if not overwritten and w is active (and thus required) - if !(typeof(w) <: EnzymeCore.Const) - if !EnzymeCore.EnzymeRules.overwritten(config)[3] - cache_x = x.val - end - end - - # Don't cache w if not overwritten and x is active (and thus required) - if !(typeof(x) <: EnzymeCore.Const) - if !EnzymeCore.EnzymeRules.overwritten(config)[4] - cache_w = w.val - end - end - - dys = y.dval - dxs = (typeof(x) <: EnzymeCore.Const) ? dys : x.dval - dws = (typeof(w) <: EnzymeCore.Const) ? dys : w.dval - - if EnzymeCore.EnzymeRules.width(config) == 1 - dys = (dys,) - dxs = (dxs,) - dws = (dws,) - end - - for (dy, dx, dw) in zip(dys, dxs, dws) - if !(typeof(x) <: EnzymeCore.Const) && dx !== x - # dx += grad wrt x - NNlib.∇conv_data!(dx, dy, cache_w, cdims.val; alpha=eltype(dw)(1), beta=eltype(dw)(1), kwargs...) - end - if !(typeof(w) <: EnzymeCore.Const) && dw !== w - # dw += grad wrt w - NNlib.∇conv_filter!(dw, cache_x, dy, cdims.val; alpha=eltype(dw)(1), beta=eltype(dw)(1), kwargs...) - end - dy .= 0 - end - - return (nothing, nothing, nothing, nothing) -end - -end \ No newline at end of file From d6426ac81b03dcbd0fa29b7179ef559290eebe5f Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Sun, 24 Sep 2023 12:32:28 -0500 Subject: [PATCH 07/26] Add missing file --- src/enzyme.jl | 72 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 72 insertions(+) create mode 100644 src/enzyme.jl diff --git a/src/enzyme.jl b/src/enzyme.jl new file mode 100644 index 000000000..8aaf38704 --- /dev/null +++ b/src/enzyme.jl @@ -0,0 +1,72 @@ +import EnzymeCore + +function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{typeof(NNlib.conv!)}, ::Type{RT}, y::OutType, x, w, cdims; kwargs...) where {OutType, RT} + + @assert !(OutType <: EnzymeCore.Const) + if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.DuplicatedNoNeed + func.val(y.val, x.val, w.val, cdims.val; kwargs...) + end + + primal = if EnzymeCore.EnzymeRules.needs_primal(config) + y.val + else + nothing + end + shadow = if EnzymeCore.EnzymeRules.needs_shadow(config) + y.dval + else + nothing + end + + # Cache x if its overwritten and w is active (and thus required) + cache_x = ( EnzymeCore.EnzymeRules.overwritten(config)[3] && !(typeof(w) <: EnzymeCore.Const) ) ? copy(x.val) : nothing + + # Cache w if its overwritten and x is active (and thus required) + cache_w = ( EnzymeCore.EnzymeRules.overwritten(config)[4] && !(typeof(x) <: EnzymeCore.Const) ) ? copy(w.val) : nothing + + cache = (cache_x, cache_w) + + return EnzymeCore.EnzymeRules.AugmentedReturn(primal, shadow, cache) +end + +function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NNlib.conv!)}, ::Type{RT}, cache, y, x, w, cdims; kwargs...) where {RT} + cache_x, cache_w = cache + + # Don't cache x if not overwritten and w is active (and thus required) + if !(typeof(w) <: EnzymeCore.Const) + if !EnzymeCore.EnzymeRules.overwritten(config)[3] + cache_x = x.val + end + end + + # Don't cache w if not overwritten and x is active (and thus required) + if !(typeof(x) <: EnzymeCore.Const) + if !EnzymeCore.EnzymeRules.overwritten(config)[4] + cache_w = w.val + end + end + + dys = y.dval + dxs = (typeof(x) <: EnzymeCore.Const) ? dys : x.dval + dws = (typeof(w) <: EnzymeCore.Const) ? dys : w.dval + + if EnzymeCore.EnzymeRules.width(config) == 1 + dys = (dys,) + dxs = (dxs,) + dws = (dws,) + end + + for (dy, dx, dw) in zip(dys, dxs, dws) + if !(typeof(x) <: EnzymeCore.Const) && dx !== x + # dx += grad wrt x + NNlib.∇conv_data!(dx, dy, cache_w, cdims.val; alpha=eltype(dw)(1), beta=eltype(dw)(1), kwargs...) + end + if !(typeof(w) <: EnzymeCore.Const) && dw !== w + # dw += grad wrt w + NNlib.∇conv_filter!(dw, cache_x, dy, cdims.val; alpha=eltype(dw)(1), beta=eltype(dw)(1), kwargs...) + end + dy .= 0 + end + + return (nothing, nothing, nothing, nothing) +end \ No newline at end of file From af42451f69e277050a36e48682fee8b9b9945e1c Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Sun, 24 Sep 2023 12:55:23 -0500 Subject: [PATCH 08/26] Also add gather --- src/enzyme.jl | 63 +++++++++++++++++++++++++++++++++++++++++++--- test/conv.jl | 2 +- test/gather.jl | 13 ++++++++++ test/runtests.jl | 1 + test/test_utils.jl | 10 ++++---- 5 files changed, 79 insertions(+), 10 deletions(-) diff --git a/src/enzyme.jl b/src/enzyme.jl index 8aaf38704..7bf15f997 100644 --- a/src/enzyme.jl +++ b/src/enzyme.jl @@ -57,16 +57,71 @@ function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NN end for (dy, dx, dw) in zip(dys, dxs, dws) - if !(typeof(x) <: EnzymeCore.Const) && dx !== x - # dx += grad wrt x + if !(typeof(x) <: EnzymeCore.Const) && dx !== x.val + # dx += grad wrt x.val NNlib.∇conv_data!(dx, dy, cache_w, cdims.val; alpha=eltype(dw)(1), beta=eltype(dw)(1), kwargs...) end - if !(typeof(w) <: EnzymeCore.Const) && dw !== w - # dw += grad wrt w + if !(typeof(w) <: EnzymeCore.Const) && dw !== w.val + # dw += grad wrt w.val NNlib.∇conv_filter!(dw, cache_x, dy, cdims.val; alpha=eltype(dw)(1), beta=eltype(dw)(1), kwargs...) end dy .= 0 end + return (nothing, nothing, nothing, nothing) +end + + +function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{typeof(NNlib.gather!)}, ::Type{RT}, dst::OutType, src, idx::EnzymeCore.Const) where {OutType, RT} + + @assert !(OutType <: EnzymeCore.Const) + if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.DuplicatedNoNeed + func.val(dst.val, src.val, idx.val) + end + + primal = if EnzymeCore.EnzymeRules.needs_primal(config) + dst.val + else + nothing + end + shadow = if EnzymeCore.EnzymeRules.needs_shadow(config) + dst.dval + else + nothing + end + + # Cache idx if its overwritten + cache_idx = ( EnzymeCore.EnzymeRules.overwritten(config)[4] && !(typeof(src) <: EnzymeCore.Const) ) ? copy(idx.val) : nothing + + return EnzymeCore.EnzymeRules.AugmentedReturn(primal, shadow, cache_idx) +end + +function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NNlib.gather!)}, ::Type{RT}, cache_idx, dst::OutType, src, idx::EnzymeCore.Const) where {OutType, RT} + + # Don't cache idx if not overwritten + if !(typeof(src) <: EnzymeCore.Const) + if !EnzymeCore.EnzymeRules.overwritten(config)[4] + cache_idx = idx.val + end + end + + ddsts = dst.dval + dsrcs = src.dval + + if EnzymeCore.EnzymeRules.width(config) == 1 + ddsts = (ddsts,) + dsrcs = (dsrcs,) + end + + for (ddst, dsrc) in zip(ddsts, dsrcs) + if !(typeof(src) <: EnzymeCore.Const) && ddst !== dst.val + src_size = size(src.val) + NNlib.∇gather_src(ddst, src_size, cache_idx) + end + if !(typeof(w) <: EnzymeCore.Const) && dw !== w + ddst .= 0 + end + end + return (nothing, nothing, nothing, nothing) end \ No newline at end of file diff --git a/test/conv.jl b/test/conv.jl index 9b9c046a3..44fe06443 100644 --- a/test/conv.jl +++ b/test/conv.jl @@ -870,7 +870,7 @@ end w = rand(rng, repeat([3], spatial_rank)..., 3, 3) cdims = DenseConvDims(x, w) gradtest((x, w) -> conv(x, w, cdims), x, w) - gradtest((x, w) -> sum(conv(x, w, cdims)), x, w; check_enzyme_rule=true) # https://github.com/FluxML/Flux.jl/issues/1055 + gradtest((x, w) -> sum(conv(x, w, cdims)), x, w; check_enzyme_rrule=true) # https://github.com/FluxML/Flux.jl/issues/1055 y = conv(x, w, cdims) gradtest((y, w) -> ∇conv_data(y, w, cdims), y, w) diff --git a/test/gather.jl b/test/gather.jl index e3221145b..7d68eedb4 100644 --- a/test/gather.jl +++ b/test/gather.jl @@ -1,4 +1,6 @@ using NNlib: gather, gather! +import EnzymeTestUtils +using EnzymeCore function gather_testsuite(Backend) device(x) = adapt(Backend(), x) @@ -150,6 +152,17 @@ function gather_testsuite(Backend) Backend == CPU ? gradtest_fn(xs -> gather(xs, idx), src) : gradtest_fn((s, i) -> gather(s, i), src, idx) + + if Backend == CPU + for Tret in (EnzymeCore.Const, EnzymeCore.Duplicated), + Tdst in (EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated), + Tsrc in (EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated) + + EnzymeTestUtils.are_activities_compatible(Tret, Tdst, Tsrc) || continue + + EnzymeTestUtils.test_reverse(fun, Tret, (dst, Tdst), (src, Tsrc), (idx, EnzymeCore.Const)) + end + end end @testset "gather gradient for tuple index" begin diff --git a/test/runtests.jl b/test/runtests.jl index 660f9167f..9b43be8b8 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,6 +2,7 @@ using NNlib, Test, Statistics, Random using ChainRulesCore, ChainRulesTestUtils using Base.Broadcast: broadcasted import EnzymeTestUtils +using EnzymeCore import FiniteDifferences import ForwardDiff import Zygote diff --git a/test/test_utils.jl b/test/test_utils.jl index 598f0f66b..2e4f19d5f 100644 --- a/test/test_utils.jl +++ b/test/test_utils.jl @@ -22,13 +22,13 @@ function gradtest( end if check_enzyme_rrule if len(xs) == 2 - for Tret in (Const, Active), - Tx in (Const, Duplicated, BatchDuplicated), - Ty in (Const, Duplicated, BatchDuplicated) + for Tret in (EnzymeCore.Const, EnzymeCore.Active), + Tx in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated), + Ty in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated) - are_activities_compatible(Tret, Tx, Ty) || continue + EnzymeTestUtils.are_activities_compatible(Tret, Tx, Ty) || continue - test_reverse(fun, Tret, (xs[1], Tx), (ys[1], Ty); atol, rtol) + EnzymeTestUtils.test_reverse(fun, Tret, (xs[1], Tx), (ys[1], Ty); atol, rtol) end else throw(AssertionError("Unsupported arg count for testing")) From 8b92077f38a28d551dec1a05aa04dd2877ed51f1 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Sun, 24 Sep 2023 16:48:54 -0500 Subject: [PATCH 09/26] Additional functions, tests, and fixes --- src/enzyme.jl | 85 +++++++++++++++++++++++++++++++++++++++++----- test/conv.jl | 24 ++++++++++++- test/gather.jl | 21 ++++++++---- test/scatter.jl | 19 +++++++++++ test/test_utils.jl | 18 +--------- 5 files changed, 133 insertions(+), 34 deletions(-) diff --git a/src/enzyme.jl b/src/enzyme.jl index 7bf15f997..d466f05c0 100644 --- a/src/enzyme.jl +++ b/src/enzyme.jl @@ -1,9 +1,12 @@ import EnzymeCore -function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{typeof(NNlib.conv!)}, ::Type{RT}, y::OutType, x, w, cdims; kwargs...) where {OutType, RT} +for name in (typeof(NNlib.conv!), typeof(NNlib.depthwiseconv!)) + @eval begin + +function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{$name}, ::Type{RT}, y::OutType, x, w, cdims; kwargs...) where {OutType, RT} @assert !(OutType <: EnzymeCore.Const) - if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.DuplicatedNoNeed + if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.BatchDuplicated func.val(y.val, x.val, w.val, cdims.val; kwargs...) end @@ -29,7 +32,7 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{ return EnzymeCore.EnzymeRules.AugmentedReturn(primal, shadow, cache) end -function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NNlib.conv!)}, ::Type{RT}, cache, y, x, w, cdims; kwargs...) where {RT} +function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{$name}, ::Type{RT}, cache, y, x, w, cdims; kwargs...) where {RT} cache_x, cache_w = cache # Don't cache x if not overwritten and w is active (and thus required) @@ -71,11 +74,13 @@ function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NN return (nothing, nothing, nothing, nothing) end +end +end function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{typeof(NNlib.gather!)}, ::Type{RT}, dst::OutType, src, idx::EnzymeCore.Const) where {OutType, RT} @assert !(OutType <: EnzymeCore.Const) - if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.DuplicatedNoNeed + if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.BatchDuplicated func.val(dst.val, src.val, idx.val) end @@ -114,14 +119,76 @@ function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NN end for (ddst, dsrc) in zip(ddsts, dsrcs) - if !(typeof(src) <: EnzymeCore.Const) && ddst !== dst.val - src_size = size(src.val) - NNlib.∇gather_src(ddst, src_size, cache_idx) + if !(typeof(src) <: EnzymeCore.Const) && dsrc !== src.val && + !(typeof(dst) <: EnzymeCore.Const) && ddst !== dst.val + NNlib.scatter!(+, dsrc, ddst, cache_idx) end - if !(typeof(w) <: EnzymeCore.Const) && dw !== w + if !(typeof(dst) <: EnzymeCore.Const) && ddst !== dst.val ddst .= 0 end end + return (nothing, nothing, nothing) +end + + + +function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{typeof(NNlib.scatter!)}, ::Type{RT}, op::EnzymeCore.Const, dst::OutType, src, idx::EnzymeCore.Const) where {OutType, RT} + + @assert !(OutType <: EnzymeCore.Const) + if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.BatchDuplicated + func.val(op.val, dst.val, src.val, idx.val) + end + + primal = if EnzymeCore.EnzymeRules.needs_primal(config) + dst.val + else + nothing + end + shadow = if EnzymeCore.EnzymeRules.needs_shadow(config) + dst.dval + else + nothing + end + + # Cache idx if its overwritten + cache_idx = ( EnzymeCore.EnzymeRules.overwritten(config)[4] && !(typeof(src) <: EnzymeCore.Const) ) ? copy(idx.val) : nothing + + return EnzymeCore.EnzymeRules.AugmentedReturn(primal, shadow, cache_idx) +end + +function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NNlib.scatter!)}, ::Type{RT}, cache_idx, op::Union{EnzymeCore.Const{typeof(+)},EnzymeCore.Const{typeof(-)}}, dst::OutType, src, idx::EnzymeCore.Const) where {OutType, RT} + + # Don't cache idx if not overwritten + if !(typeof(src) <: EnzymeCore.Const) + if !EnzymeCore.EnzymeRules.overwritten(config)[4] + cache_idx = idx.val + end + end + + ddsts = dst.dval + dsrcs = src.dval + + if EnzymeCore.EnzymeRules.width(config) == 1 + ddsts = (ddsts,) + dsrcs = (dsrcs,) + end + + for (ddst, dsrc) in zip(ddsts, dsrcs) + if !(typeof(src) <: EnzymeCore.Const) && dsrc !== src.val && + !(typeof(dst) <: EnzymeCore.Const) && ddst !== dst.val + + if eltype(typeof(op)) == typeof(+) + dsrc .+= NNlib.gather(ddst, cache_idx) + else + @assert eltype(typeof(op)) == typeof(-) + dsrc .-= NNlib.gather(ddst, cache_idx) + end + end + end + return (nothing, nothing, nothing, nothing) -end \ No newline at end of file +end + + + diff --git a/test/conv.jl b/test/conv.jl index 44fe06443..a182fef99 100644 --- a/test/conv.jl +++ b/test/conv.jl @@ -870,7 +870,7 @@ end w = rand(rng, repeat([3], spatial_rank)..., 3, 3) cdims = DenseConvDims(x, w) gradtest((x, w) -> conv(x, w, cdims), x, w) - gradtest((x, w) -> sum(conv(x, w, cdims)), x, w; check_enzyme_rrule=true) # https://github.com/FluxML/Flux.jl/issues/1055 + gradtest((x, w) -> sum(conv(x, w, cdims)), x, w) # https://github.com/FluxML/Flux.jl/issues/1055 y = conv(x, w, cdims) gradtest((y, w) -> ∇conv_data(y, w, cdims), y, w) @@ -886,3 +886,25 @@ end gradtest((y, w) -> ∇depthwiseconv_data(y, w, dcdims), y, w) gradtest((y, w) -> sum(∇depthwiseconv_data(y, w, dcdims)), y, w) end + +@testset "EnzymeRules: conv! spatial_rank=$spatial_rank" for spatial_rank in (1, 2, 3) + x = rand(rng, repeat([5], spatial_rank)..., 3, 2) + w = rand(rng, repeat([3], spatial_rank)..., 3, 3) + cdims = DenseConvDims(x, w) + + for name in (:conv, :depthwiseconv) + curconv = @eval $(Symbol("$(name)")) + curconv! = @eval $(Symbol("$(name)!")) + dst = curconv(x, w, cdims) + + for Tret in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated), + Tdst in (EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated), + Tx in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated), + Tw in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated) + + EnzymeTestUtils.are_activities_compatible(Tret, Tdst, Tsrc) || continue + + EnzymeTestUtils.test_reverse(curconv!, Tret, (dst, Tdst), (x, Tx), (x, Tw), (idx, EnzymeCore.Const)) + end + end +end diff --git a/test/gather.jl b/test/gather.jl index 7d68eedb4..c5cd5ee57 100644 --- a/test/gather.jl +++ b/test/gather.jl @@ -152,16 +152,23 @@ function gather_testsuite(Backend) Backend == CPU ? gradtest_fn(xs -> gather(xs, idx), src) : gradtest_fn((s, i) -> gather(s, i), src, idx) + end - if Backend == CPU - for Tret in (EnzymeCore.Const, EnzymeCore.Duplicated), - Tdst in (EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated), - Tsrc in (EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated) - EnzymeTestUtils.are_activities_compatible(Tret, Tdst, Tsrc) || continue + @testset "EnzymeRules: gather! gradient for scalar index" begin + src = device(Float64[3, 4, 5, 6, 7]) + idx = device([ + 1 2 3 4; + 4 2 1 3; + 3 5 5 3]) + dst = gather(src, idx) + for Tret in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated), + Tdst in (EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated), + Tsrc in (EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated) + + EnzymeTestUtils.are_activities_compatible(Tret, Tdst, Tsrc) || continue - EnzymeTestUtils.test_reverse(fun, Tret, (dst, Tdst), (src, Tsrc), (idx, EnzymeCore.Const)) - end + EnzymeTestUtils.test_reverse(gather!, Tret, (dst, Tdst), (src, Tsrc), (idx, EnzymeCore.Const)) end end diff --git a/test/scatter.jl b/test/scatter.jl index 26fc06cde..289ed37b8 100644 --- a/test/scatter.jl +++ b/test/scatter.jl @@ -207,5 +207,24 @@ function scatter_testsuite(Backend) gradtest_fn((xs, i) -> scatter(op, xs, i), src, idx) end end + + @testset "EnzymeRules" begin + idx = device([2, 2, 3, 4, 4]) + src = device(ones(T, 3, 5)) + + for op in (+, -) + + dst = scatter(op, src, idx) + + for Tret in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated), + Tdst in (EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated), + Tsrc in (EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated) + + EnzymeTestUtils.are_activities_compatible(Tret, Tdst, Tsrc) || continue + + EnzymeTestUtils.test_reverse(scatter!, Tret, (op, EnzymeCore.Const), (dst, Tdst), (src, Tsrc), (idx, EnzymeCore.Const)) + end + end + end end end diff --git a/test/test_utils.jl b/test/test_utils.jl index 2e4f19d5f..16b3998dc 100644 --- a/test/test_utils.jl +++ b/test/test_utils.jl @@ -12,7 +12,7 @@ Applies also `ChainRulesTestUtils.test_rrule` if the rrule for `f` is explicitly """ function gradtest( f, xs...; atol = 1e-6, rtol = 1e-6, fkwargs = NamedTuple(), - check_rrule = false, check_enzyme_rrule = false, fdm = :central, check_broadcast = false, + check_rrule = false, fdm = :central, check_broadcast = false, skip = false, broken = false, ) # TODO: revamp when https://github.com/JuliaDiff/ChainRulesTestUtils.jl/pull/166 @@ -20,22 +20,6 @@ function gradtest( if check_rrule test_rrule(f, xs...; fkwargs = fkwargs) end - if check_enzyme_rrule - if len(xs) == 2 - for Tret in (EnzymeCore.Const, EnzymeCore.Active), - Tx in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated), - Ty in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated) - - EnzymeTestUtils.are_activities_compatible(Tret, Tx, Ty) || continue - - EnzymeTestUtils.test_reverse(fun, Tret, (xs[1], Tx), (ys[1], Ty); atol, rtol) - end - else - throw(AssertionError("Unsupported arg count for testing")) - end - - EnzymeTestUtils.test_rrule(f, xs...; fkwargs = fkwargs) - end if check_broadcast length(fkwargs) > 0 && @warn("CHECK_BROADCAST: dropping keywords args") From e913bcab830f0cd08b6ad98d10d98be97e198e06 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Sun, 24 Sep 2023 18:55:41 -0500 Subject: [PATCH 10/26] minor fixup --- test/conv.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/conv.jl b/test/conv.jl index a182fef99..54fd5b4bf 100644 --- a/test/conv.jl +++ b/test/conv.jl @@ -902,7 +902,7 @@ end Tx in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated), Tw in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated) - EnzymeTestUtils.are_activities_compatible(Tret, Tdst, Tsrc) || continue + EnzymeTestUtils.are_activities_compatible(Tret, Tdst, Tw, Tw) || continue EnzymeTestUtils.test_reverse(curconv!, Tret, (dst, Tdst), (x, Tx), (x, Tw), (idx, EnzymeCore.Const)) end From 11962277f05722e15ab4fa27491d77b4bf0e250d Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Sun, 24 Sep 2023 19:51:24 -0500 Subject: [PATCH 11/26] Add pooling --- src/enzyme.jl | 161 +++++++++++++++++++++++++++++++++++++++--------- test/pooling.jl | 19 ++++++ 2 files changed, 150 insertions(+), 30 deletions(-) diff --git a/src/enzyme.jl b/src/enzyme.jl index d466f05c0..a80f0f932 100644 --- a/src/enzyme.jl +++ b/src/enzyme.jl @@ -5,7 +5,6 @@ for name in (typeof(NNlib.conv!), typeof(NNlib.depthwiseconv!)) function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{$name}, ::Type{RT}, y::OutType, x, w, cdims; kwargs...) where {OutType, RT} - @assert !(OutType <: EnzymeCore.Const) if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.BatchDuplicated func.val(y.val, x.val, w.val, cdims.val; kwargs...) end @@ -22,10 +21,16 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{ end # Cache x if its overwritten and w is active (and thus required) - cache_x = ( EnzymeCore.EnzymeRules.overwritten(config)[3] && !(typeof(w) <: EnzymeCore.Const) ) ? copy(x.val) : nothing + cache_x = ( EnzymeCore.EnzymeRules.overwritten(config)[3] + && !(typeof(w) <: EnzymeCore.Const) + && !(typeof(y) <: EnzymeCore.Const) + ) ? copy(x.val) : nothing # Cache w if its overwritten and x is active (and thus required) - cache_w = ( EnzymeCore.EnzymeRules.overwritten(config)[4] && !(typeof(x) <: EnzymeCore.Const) ) ? copy(w.val) : nothing + cache_w = ( EnzymeCore.EnzymeRules.overwritten(config)[4] + && !(typeof(x) <: EnzymeCore.Const) + && !(typeof(y) <: EnzymeCore.Const) + ) ? copy(w.val) : nothing cache = (cache_x, cache_w) @@ -36,14 +41,14 @@ function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{$name}, : cache_x, cache_w = cache # Don't cache x if not overwritten and w is active (and thus required) - if !(typeof(w) <: EnzymeCore.Const) + if !(typeof(w) <: EnzymeCore.Const) && !(typeof(y) <: EnzymeCore.Const) if !EnzymeCore.EnzymeRules.overwritten(config)[3] cache_x = x.val end end # Don't cache w if not overwritten and x is active (and thus required) - if !(typeof(x) <: EnzymeCore.Const) + if !(typeof(x) <: EnzymeCore.Const) && !(typeof(y) <: EnzymeCore.Const) if !EnzymeCore.EnzymeRules.overwritten(config)[4] cache_w = w.val end @@ -60,15 +65,19 @@ function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{$name}, : end for (dy, dx, dw) in zip(dys, dxs, dws) - if !(typeof(x) <: EnzymeCore.Const) && dx !== x.val - # dx += grad wrt x.val - NNlib.∇conv_data!(dx, dy, cache_w, cdims.val; alpha=eltype(dw)(1), beta=eltype(dw)(1), kwargs...) - end - if !(typeof(w) <: EnzymeCore.Const) && dw !== w.val - # dw += grad wrt w.val - NNlib.∇conv_filter!(dw, cache_x, dy, cdims.val; alpha=eltype(dw)(1), beta=eltype(dw)(1), kwargs...) + if !(typeof(y) <: EnzymeCore.Const) && dy !== w.val + + if !(typeof(x) <: EnzymeCore.Const) && dx !== x.val + # dx += grad wrt x.val + NNlib.∇conv_data!(dx, dy, cache_w, cdims.val; alpha=eltype(dw)(1), beta=eltype(dw)(1), kwargs...) + end + if !(typeof(w) <: EnzymeCore.Const) && dw !== w.val + # dw += grad wrt w.val + NNlib.∇conv_filter!(dw, cache_x, dy, cdims.val; alpha=eltype(dw)(1), beta=eltype(dw)(1), kwargs...) + end + + dy .= 0 end - dy .= 0 end return (nothing, nothing, nothing, nothing) @@ -79,7 +88,6 @@ end function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{typeof(NNlib.gather!)}, ::Type{RT}, dst::OutType, src, idx::EnzymeCore.Const) where {OutType, RT} - @assert !(OutType <: EnzymeCore.Const) if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.BatchDuplicated func.val(dst.val, src.val, idx.val) end @@ -96,7 +104,10 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{ end # Cache idx if its overwritten - cache_idx = ( EnzymeCore.EnzymeRules.overwritten(config)[4] && !(typeof(src) <: EnzymeCore.Const) ) ? copy(idx.val) : nothing + cache_idx = ( EnzymeCore.EnzymeRules.overwritten(config)[4] + && !(typeof(src) <: EnzymeCore.Const) + && !(typeof(dst) <: EnzymeCore.Const) + ) ? copy(idx.val) : nothing return EnzymeCore.EnzymeRules.AugmentedReturn(primal, shadow, cache_idx) end @@ -104,7 +115,7 @@ end function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NNlib.gather!)}, ::Type{RT}, cache_idx, dst::OutType, src, idx::EnzymeCore.Const) where {OutType, RT} # Don't cache idx if not overwritten - if !(typeof(src) <: EnzymeCore.Const) + if !(typeof(src) <: EnzymeCore.Const) && !(typeof(dst) <: EnzymeCore.Const) if !EnzymeCore.EnzymeRules.overwritten(config)[4] cache_idx = idx.val end @@ -119,11 +130,12 @@ function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NN end for (ddst, dsrc) in zip(ddsts, dsrcs) - if !(typeof(src) <: EnzymeCore.Const) && dsrc !== src.val && - !(typeof(dst) <: EnzymeCore.Const) && ddst !== dst.val - NNlib.scatter!(+, dsrc, ddst, cache_idx) - end if !(typeof(dst) <: EnzymeCore.Const) && ddst !== dst.val + + if !(typeof(src) <: EnzymeCore.Const) && dsrc !== src.val + NNlib.scatter!(+, dsrc, ddst, cache_idx) + end + ddst .= 0 end end @@ -152,7 +164,10 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{ end # Cache idx if its overwritten - cache_idx = ( EnzymeCore.EnzymeRules.overwritten(config)[4] && !(typeof(src) <: EnzymeCore.Const) ) ? copy(idx.val) : nothing + cache_idx = ( EnzymeCore.EnzymeRules.overwritten(config)[4] + && !(typeof(src) <: EnzymeCore.Const) + && !(typeof(dst) <: EnzymeCore.Const) + ) ? copy(idx.val) : nothing return EnzymeCore.EnzymeRules.AugmentedReturn(primal, shadow, cache_idx) end @@ -160,7 +175,7 @@ end function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NNlib.scatter!)}, ::Type{RT}, cache_idx, op::Union{EnzymeCore.Const{typeof(+)},EnzymeCore.Const{typeof(-)}}, dst::OutType, src, idx::EnzymeCore.Const) where {OutType, RT} # Don't cache idx if not overwritten - if !(typeof(src) <: EnzymeCore.Const) + if !(typeof(src) <: EnzymeCore.Const) && !(typeof(dst) <: EnzymeCore.Const) if !EnzymeCore.EnzymeRules.overwritten(config)[4] cache_idx = idx.val end @@ -175,15 +190,20 @@ function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NN end for (ddst, dsrc) in zip(ddsts, dsrcs) - if !(typeof(src) <: EnzymeCore.Const) && dsrc !== src.val && - !(typeof(dst) <: EnzymeCore.Const) && ddst !== dst.val - - if eltype(typeof(op)) == typeof(+) - dsrc .+= NNlib.gather(ddst, cache_idx) - else - @assert eltype(typeof(op)) == typeof(-) - dsrc .-= NNlib.gather(ddst, cache_idx) + if !(typeof(dst) <: EnzymeCore.Const) && ddst !== dst.val + + if !(typeof(src) <: EnzymeCore.Const) && dsrc !== src.val + + if eltype(typeof(op)) == typeof(+) + dsrc .+= NNlib.gather(ddst, cache_idx) + else + @assert eltype(typeof(op)) == typeof(-) + dsrc .-= NNlib.gather(ddst, cache_idx) + end end + + ddst .= 0 + end end @@ -192,3 +212,84 @@ end +for pool in [:maxpool, :meanpool, :lpnormpool] + pool! = Symbol(pool, :!) + ∇pool = Symbol(:∇, pool) + + @eval begin + +function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{typeof($pool!)}, ::Type{RT}, y::OutType, x, dims; kwargs...) where {OutType, RT} + + if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.BatchDuplicated + func.val(y.val, x.val, dims.val; kwargs...) + end + + primal = if EnzymeCore.EnzymeRules.needs_primal(config) + y.val + else + nothing + end + shadow = if EnzymeCore.EnzymeRules.needs_shadow(config) + y.dval + else + nothing + end + + cache_y = ( EnzymeCore.EnzymeRules.overwritten(config)[2] + && !(typeof(x) <: EnzymeCore.Const) + && !(typeof(y) <: EnzymeCore.Const) + ) ? copy(y.val) : nothing + + cache_x = ( EnzymeCore.EnzymeRules.overwritten(config)[3] + && !(typeof(x) <: EnzymeCore.Const) + && !(typeof(y) <: EnzymeCore.Const) + ) ? copy(x.val) : nothing + + cache = (cache_y, cache_x) + + return EnzymeCore.EnzymeRules.AugmentedReturn(primal, shadow, cache) +end + +function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof($pool!)}, ::Type{RT}, cache, y, x, dims; kwargs...) where {RT} + cache_y, cache_x = cache + + # Don't cache y if not overwritten + if !(typeof(x) <: EnzymeCore.Const) && !(typeof(y) <: EnzymeCore.Const) + if !EnzymeCore.EnzymeRules.overwritten(config)[2] + cache_y = y.val + end + end + + # Don't cache x if not overwritten + if !(typeof(x) <: EnzymeCore.Const) && !(typeof(y) <: EnzymeCore.Const) + if !EnzymeCore.EnzymeRules.overwritten(config)[3] + cache_x = x.val + end + end + + dys = y.dval + dxs = (typeof(x) <: EnzymeCore.Const) ? dys : x.dval + + if EnzymeCore.EnzymeRules.width(config) == 1 + dys = (dys,) + dxs = (dxs,) + end + + for (dy, dx, dw) in zip(dys, dxs) + if !(typeof(y) <: EnzymeCore.Const) && dy !== y.val + + if !(typeof(x) <: EnzymeCore.Const) && dx !== x.val + NNlib.$(∇pool)(dx, dy, cache_y, cache_x, dims; alpha=eltype(dx)(1), beta=eltype(dx)(1), kwargs...) + end + + dy .= 0 + end + end + + return (nothing, nothing, nothing) +end + +end +end + + diff --git a/test/pooling.jl b/test/pooling.jl index b4b4f40b7..d695bd421 100644 --- a/test/pooling.jl +++ b/test/pooling.jl @@ -967,3 +967,22 @@ end gradtest(x -> sum(maxpool(x, k)), x, skip = spatial_rank==2) gradtest(x -> sum(meanpool(x, k)), x) end + + +@testset "EnzymeRules: pooling! $pool spatial_rank=$spatial_rank " for spatial_rank in (1, 2), + (pool, pool!) in ((maxpool, maxpool!), (meanpool, meanpool!)) + + x = rand(rng, repeat([10], spatial_rank)..., 3, 2) + pdims = PoolDims(x, 2) + y = pool(x, pdims) + + for Tret in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated), + Tdst in (EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated), + Tsrc in (EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated) + + EnzymeTestUtils.are_activities_compatible(Tret, Tdst, Tsrc) || continue + + EnzymeTestUtils.test_reverse(pool!, Tret, (y, Tdst), (x, Tsrc), (pdims, EnzymeCore.Const)) + end + +end \ No newline at end of file From 46dcac19c619e98a168e477555cbaef9bd580af6 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Sun, 24 Sep 2023 21:16:29 -0500 Subject: [PATCH 12/26] Add dropout --- src/enzyme.jl | 71 +++++++++++++++++++++++++++++++++++++++++++++++++ test/dropout.jl | 27 ++++++++++++++++++- 2 files changed, 97 insertions(+), 1 deletion(-) diff --git a/src/enzyme.jl b/src/enzyme.jl index a80f0f932..9eb391616 100644 --- a/src/enzyme.jl +++ b/src/enzyme.jl @@ -292,4 +292,75 @@ end end end +function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{typeof(NNlib._dropout!)}, ::Type{RT}, rng, dst::OutType, src, p, dims) where {OutType, RT} + T = float(real(eltype(dst.val))) + val = convert(T, 1/(1-p.val)) + keep = if dims.val isa Colon + similar(dst.val, T, size(dst.val)) + else + similar(dst.val, T, ntuple(d -> d in dims.val ? size(dst.val,d) : 1, ndims(dst.val))) + end + rand!(rng.val, keep) + + keep = keep .> p.val + + if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.BatchDuplicated + dst.val .= (keep .* val) .* src.val + end + + primal = if EnzymeCore.EnzymeRules.needs_primal(config) + dst.val + else + nothing + end + shadow = if EnzymeCore.EnzymeRules.needs_shadow(config) + dst.dval + else + nothing + end + + if typeof(dst) <: EnzymeCore.Const || typeof(src) <: EnzymeCore.Const + keep = nothing + end + + # Cache idx if its overwritten + cache_idx = ( EnzymeCore.EnzymeRules.overwritten(config)[4] + && !(typeof(src) <: EnzymeCore.Const) + && !(typeof(dst) <: EnzymeCore.Const) + ) ? copy(idx.val) : nothing + + return EnzymeCore.EnzymeRules.AugmentedReturn(primal, shadow, keep) +end + +function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NNlib._dropout!)}, ::Type{RT}, keep, rng, dst::OutType, src, p, dims) where {OutType, RT} + T = float(real(eltype(dst.val))) + val = convert(T, 1/(1-p.val)) + + ddsts = dst.dval + dsrcs = src.dval + + if EnzymeCore.EnzymeRules.width(config) == 1 + ddsts = (ddsts,) + dsrcs = (dsrcs,) + end + + for (ddst, dsrc) in zip(ddsts, dsrcs) + if !(typeof(dst) <: EnzymeCore.Const) && ddst !== dst.val + + if !(typeof(src) <: EnzymeCore.Const) && dsrc !== src.val + dsrc .+= (keep .* val) .* ddst + end + + ddst .= 0 + end + end + + dp = if typeof(p) <: EnzymeCore.Active + typeof(p.val)(0) + else + nothing + end + + return (nothing, nothing, nothing, dp, nothing) +end diff --git a/test/dropout.jl b/test/dropout.jl index 07a48edc5..315d5146e 100644 --- a/test/dropout.jl +++ b/test/dropout.jl @@ -1,5 +1,5 @@ using NNlib, Test, Statistics, Random, LinearAlgebra -using Zygote, StableRNGs, ChainRulesCore +using Zygote, StableRNGs, ChainRulesCore, Enzyme @testset "dropout" begin # Basics @@ -75,3 +75,28 @@ using Zygote, StableRNGs, ChainRulesCore @test_throws ArgumentError dropout(x1, 2) @test_throws ArgumentError dropout!(y1, x1, 3) end + +@testset "EnzymeRules: dropout " + rng = Random.default_rng() + + x1 = randn(Float32, 3000, 4000) + dx1 = zeros(Float32, 3000, 4000) + + dout = randn(Float32, 3000, 4000) + + p = 0.2f0 + + forward, reverse = Enzyme.autodiff_thunk(ReverseSplitWithPrimal, typeof(Const(dropout)), Duplicated, typeof(Const(rng)), typeof(Duplicated(x1, dx1)), typeof(Const(0.2f0))) + + tape, primal, shadow = forward(Const(dropout), Const(rng), Duplicated(x1, dx1), Const(p)) + + shadow .= dout + + reverse(Const(dropout), Const(rng), Duplicated(x1, dx1), Const(p), tape) + + @test dx1[.!tape[1]] ≈ zero(x1)[.!tape[1]] + + val = convert(Float32, 1/(1-p)) + + @test dx1[tape[1]] ≈ (val * dout)[tape[1]] +end \ No newline at end of file From d7d83893007dc1938b88ff5037654ac718c801a0 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Sun, 24 Sep 2023 21:45:03 -0500 Subject: [PATCH 13/26] Fix scatter bug --- src/enzyme.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/enzyme.jl b/src/enzyme.jl index 9eb391616..773066cff 100644 --- a/src/enzyme.jl +++ b/src/enzyme.jl @@ -202,8 +202,6 @@ function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NN end end - ddst .= 0 - end end From 2fd798034c641b01e78ceda9c2333959169ed002 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Sun, 24 Sep 2023 22:00:00 -0500 Subject: [PATCH 14/26] fix pool --- src/enzyme.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/enzyme.jl b/src/enzyme.jl index 773066cff..0e1b75d8a 100644 --- a/src/enzyme.jl +++ b/src/enzyme.jl @@ -273,7 +273,7 @@ function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof($p dxs = (dxs,) end - for (dy, dx, dw) in zip(dys, dxs) + for (dy, dx) in zip(dys, dxs) if !(typeof(y) <: EnzymeCore.Const) && dy !== y.val if !(typeof(x) <: EnzymeCore.Const) && dx !== x.val From 93916d5be04a7432588c87019922713a396d518d Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Sun, 24 Sep 2023 22:44:15 -0500 Subject: [PATCH 15/26] More fixups --- src/enzyme.jl | 12 +++--------- test/conv.jl | 2 +- 2 files changed, 4 insertions(+), 10 deletions(-) diff --git a/src/enzyme.jl b/src/enzyme.jl index 0e1b75d8a..1d4d0c20c 100644 --- a/src/enzyme.jl +++ b/src/enzyme.jl @@ -122,7 +122,7 @@ function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NN end ddsts = dst.dval - dsrcs = src.dval + dsrcs = (typeof(src) <: EnzymeCore.Const) ? ddsts : src.dval if EnzymeCore.EnzymeRules.width(config) == 1 ddsts = (ddsts,) @@ -182,7 +182,7 @@ function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NN end ddsts = dst.dval - dsrcs = src.dval + dsrcs = (typeof(src) <: EnzymeCore.Const) ? ddsts : src.dval if EnzymeCore.EnzymeRules.width(config) == 1 ddsts = (ddsts,) @@ -322,12 +322,6 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{ keep = nothing end - # Cache idx if its overwritten - cache_idx = ( EnzymeCore.EnzymeRules.overwritten(config)[4] - && !(typeof(src) <: EnzymeCore.Const) - && !(typeof(dst) <: EnzymeCore.Const) - ) ? copy(idx.val) : nothing - return EnzymeCore.EnzymeRules.AugmentedReturn(primal, shadow, keep) end @@ -336,7 +330,7 @@ function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NN val = convert(T, 1/(1-p.val)) ddsts = dst.dval - dsrcs = src.dval + dsrcs = (typeof(src) <: EnzymeCore.Const) ? ddsts : src.dval if EnzymeCore.EnzymeRules.width(config) == 1 ddsts = (ddsts,) diff --git a/test/conv.jl b/test/conv.jl index 54fd5b4bf..67ee0bdf3 100644 --- a/test/conv.jl +++ b/test/conv.jl @@ -904,7 +904,7 @@ end EnzymeTestUtils.are_activities_compatible(Tret, Tdst, Tw, Tw) || continue - EnzymeTestUtils.test_reverse(curconv!, Tret, (dst, Tdst), (x, Tx), (x, Tw), (idx, EnzymeCore.Const)) + EnzymeTestUtils.test_reverse(curconv!, Tret, (dst, Tdst), (x, Tx), (x, Tw), (cdims, EnzymeCore.Const)) end end end From 92c82216d25f0b24d7d17a531f77f6868b0acb9e Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Sun, 24 Sep 2023 23:46:31 -0500 Subject: [PATCH 16/26] fix up pool --- src/enzyme.jl | 4 ++-- test/conv.jl | 30 +++++++++++++++++------------- 2 files changed, 19 insertions(+), 15 deletions(-) diff --git a/src/enzyme.jl b/src/enzyme.jl index 1d4d0c20c..6a6f45165 100644 --- a/src/enzyme.jl +++ b/src/enzyme.jl @@ -212,7 +212,7 @@ end for pool in [:maxpool, :meanpool, :lpnormpool] pool! = Symbol(pool, :!) - ∇pool = Symbol(:∇, pool) + ∇pool = Symbol(:∇, pool, :!) @eval begin @@ -277,7 +277,7 @@ function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof($p if !(typeof(y) <: EnzymeCore.Const) && dy !== y.val if !(typeof(x) <: EnzymeCore.Const) && dx !== x.val - NNlib.$(∇pool)(dx, dy, cache_y, cache_x, dims; alpha=eltype(dx)(1), beta=eltype(dx)(1), kwargs...) + NNlib.$(∇pool)(dx, dy, cache_y, cache_x, dims.val; alpha=eltype(dx)(1), beta=eltype(dx)(1), kwargs...) end dy .= 0 diff --git a/test/conv.jl b/test/conv.jl index 67ee0bdf3..d8f7b69d4 100644 --- a/test/conv.jl +++ b/test/conv.jl @@ -887,24 +887,28 @@ end gradtest((y, w) -> sum(∇depthwiseconv_data(y, w, dcdims)), y, w) end -@testset "EnzymeRules: conv! spatial_rank=$spatial_rank" for spatial_rank in (1, 2, 3) +@testset "EnzymeRules: $conv ! spatial_rank=$spatial_rank" for spatial_rank in (1, 2, 3), + name in (:conv, :depthwiseconv) x = rand(rng, repeat([5], spatial_rank)..., 3, 2) w = rand(rng, repeat([3], spatial_rank)..., 3, 3) - cdims = DenseConvDims(x, w) + + cdims = if name == :conv + DenseConvDims(x, w) + else + DepthwiseConvDims(x, w) + end - for name in (:conv, :depthwiseconv) - curconv = @eval $(Symbol("$(name)")) - curconv! = @eval $(Symbol("$(name)!")) - dst = curconv(x, w, cdims) + curconv = @eval $(Symbol("$(name)")) + curconv! = @eval $(Symbol("$(name)!")) + dst = curconv(x, w, cdims) - for Tret in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated), - Tdst in (EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated), - Tx in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated), - Tw in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated) + for Tret in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated), + Tdst in (EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated), + Tx in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated), + Tw in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated) - EnzymeTestUtils.are_activities_compatible(Tret, Tdst, Tw, Tw) || continue + EnzymeTestUtils.are_activities_compatible(Tret, Tdst, Tw, Tw) || continue - EnzymeTestUtils.test_reverse(curconv!, Tret, (dst, Tdst), (x, Tx), (x, Tw), (cdims, EnzymeCore.Const)) - end + EnzymeTestUtils.test_reverse(curconv!, Tret, (dst, Tdst), (x, Tx), (w, Tw), (cdims, EnzymeCore.Const)) end end From 645b5919fcfa2855e6a2ef36b797359381fab993 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Mon, 25 Sep 2023 09:15:43 -0500 Subject: [PATCH 17/26] split conv/depth --- test/conv.jl | 35 ++++++++++++++++++++++++++--------- 1 file changed, 26 insertions(+), 9 deletions(-) diff --git a/test/conv.jl b/test/conv.jl index d8f7b69d4..e46dda3ab 100644 --- a/test/conv.jl +++ b/test/conv.jl @@ -887,19 +887,14 @@ end gradtest((y, w) -> sum(∇depthwiseconv_data(y, w, dcdims)), y, w) end -@testset "EnzymeRules: $conv ! spatial_rank=$spatial_rank" for spatial_rank in (1, 2, 3), - name in (:conv, :depthwiseconv) +@testset "EnzymeRules: conv! spatial_rank=$spatial_rank" for spatial_rank in (1, 2, 3) x = rand(rng, repeat([5], spatial_rank)..., 3, 2) w = rand(rng, repeat([3], spatial_rank)..., 3, 3) - cdims = if name == :conv - DenseConvDims(x, w) - else - DepthwiseConvDims(x, w) - end + cdims = DenseConvDims(x, w) - curconv = @eval $(Symbol("$(name)")) - curconv! = @eval $(Symbol("$(name)!")) + curconv = conv + curconv! = conv! dst = curconv(x, w, cdims) for Tret in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated), @@ -912,3 +907,25 @@ end EnzymeTestUtils.test_reverse(curconv!, Tret, (dst, Tdst), (x, Tx), (w, Tw), (cdims, EnzymeCore.Const)) end end + + +@testset "EnzymeRules: depthwiseconv! spatial_rank=$spatial_rank" for spatial_rank in (1, 2, 3) + x = rand(rng, repeat([5], spatial_rank)..., 3, 2) + w = rand(rng, repeat([3], spatial_rank)..., 3, 3) + + cdims = DepthwiseConvDims(x, w) + + curconv = depthwiseconv + curconv! = depthwiseconv! + dst = curconv(x, w, cdims) + + for Tret in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated), + Tdst in (EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated), + Tx in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated), + Tw in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated) + + EnzymeTestUtils.are_activities_compatible(Tret, Tdst, Tw, Tw) || continue + + EnzymeTestUtils.test_reverse(curconv!, Tret, (dst, Tdst), (x, Tx), (w, Tw), (cdims, EnzymeCore.Const)) + end +end \ No newline at end of file From 06dbbaad27401567a12d02bd7978d27e44a9b2d4 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Mon, 25 Sep 2023 23:12:48 -0500 Subject: [PATCH 18/26] Cleanup rule --- src/conv.jl | 14 +++++++------- src/enzyme.jl | 20 ++++++++++++++------ 2 files changed, 21 insertions(+), 13 deletions(-) diff --git a/src/conv.jl b/src/conv.jl index 3fecb9151..f15551e02 100644 --- a/src/conv.jl +++ b/src/conv.jl @@ -47,7 +47,7 @@ Apply convolution filter `w` to input `x`. `x` and `w` are 3d/4d/5d tensors in 1d/2d/3d convolutions respectively. `x` and `w` may have real or complex element types. """ -function conv(x, w::AbstractArray{T, N}; stride = 1, pad = 0, dilation = 1, flipped = false, groups = 1) where {T, N} +@inline function conv(x, w::AbstractArray{T, N}; stride = 1, pad = 0, dilation = 1, flipped = false, groups = 1) where {T, N} stride = expand(Val(N - 2), stride) padding = expand(Val(N - 2), pad) dilation = expand(Val(N - 2), dilation) @@ -62,7 +62,7 @@ end Depthwise convolution operation with filter `w` on input `x`. `x` and `w` are 3d/4d/5d tensors in 1d/2d/3d convolutions respectively. """ -function depthwiseconv(x, w::AbstractArray{T, N}; stride=1, pad=0, dilation=1, flipped=false) where {T, N} +@inline function depthwiseconv(x, w::AbstractArray{T, N}; stride=1, pad=0, dilation=1, flipped=false) where {T, N} stride = expand(Val(N-2), stride) pad = expand(Val(N-2), pad) dilation = expand(Val(N-2), dilation) @@ -80,7 +80,7 @@ for backend in (Symbol(), :_direct, :_im2col, :_nnpack) # First make auto-allocating versions of the conv()-like calls: for name in (:conv, :depthwiseconv) @eval begin - function $(Symbol("$(name)$(backend)"))( + @inline function $(Symbol("$(name)$(backend)"))( x::AbstractArray{xT,N}, w::AbstractArray{wT,N}, cdims::ConvDims; kwargs...) where {xT, wT, N} y = similar(x, promote_type(xT, wT), output_size(cdims)..., @@ -92,7 +92,7 @@ for backend in (Symbol(), :_direct, :_im2col, :_nnpack) for name in (:∇conv_data, :∇depthwiseconv_data) @eval begin - function $(Symbol("$(name)$(backend)"))( + @inline function $(Symbol("$(name)$(backend)"))( dy::AbstractArray{yT,N}, w::AbstractArray{wT,N}, cdims::C; kwargs...) where {yT, wT, N, C <: ConvDims} dx = similar(dy, input_size(cdims)..., channels_in(cdims), size(dy, N)) @@ -104,7 +104,7 @@ for backend in (Symbol(), :_direct, :_im2col, :_nnpack) # We do the conv/depthwiseconv filter backprops separately, as the shape calculation # for `w` is slightly different for depthwise than for normal dense convolution. @eval begin - function $(Symbol("∇conv_filter$(backend)"))( + @inline function $(Symbol("∇conv_filter$(backend)"))( x::AbstractArray{xT,N}, dy::AbstractArray{yT,N}, cdims::ConvDims; kwargs...) where {xT, yT, N} dw = similar(dy, kernel_size(cdims)..., channels_in(cdims) ÷ groupcount(cdims), @@ -114,7 +114,7 @@ for backend in (Symbol(), :_direct, :_im2col, :_nnpack) end @eval begin - function $(Symbol("∇depthwiseconv_filter$(backend)"))( + @inline function $(Symbol("∇depthwiseconv_filter$(backend)"))( x::AbstractArray{xT,N}, dy::AbstractArray{yT,N}, cdims::ConvDims; kwargs...) where {xT, yT, N} dw = similar(dy, kernel_size(cdims)..., channel_multiplier(cdims), @@ -137,7 +137,7 @@ for front_name in (:conv, :∇conv_data, :∇conv_filter, for backend in (Symbol(), :_direct, :_im2col) ## NNPACK is only for 2d conv for N in (3, 4) @eval begin - function $(Symbol("$(front_name)$(backend)!"))( + @inline function $(Symbol("$(front_name)$(backend)!"))( y::AbstractArray{yT,$N}, x::AbstractArray{xT,$N}, w::AbstractArray{wT,$N}, cdims::ConvDims; kwargs...) where {yT, xT, wT} diff --git a/src/enzyme.jl b/src/enzyme.jl index 6a6f45165..c44539108 100644 --- a/src/enzyme.jl +++ b/src/enzyme.jl @@ -3,9 +3,13 @@ import EnzymeCore for name in (typeof(NNlib.conv!), typeof(NNlib.depthwiseconv!)) @eval begin -function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{$name}, ::Type{RT}, y::OutType, x, w, cdims; kwargs...) where {OutType, RT} +function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{$name}, ::Type{RT}, + y::EnzymeCore.Annotation{<:AbstractArray{yT, N}}, + x::EnzymeCore.Annotation{<:AbstractArray{xT, N}}, + w::EnzymeCore.Annotation{<:AbstractArray{wT, N}}, + cdims; kwargs...) where {RT, yT, xT, wT, N} - if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.BatchDuplicated + if typeof(y) <: EnzymeCore.Duplicated || typeof(y) <: EnzymeCore.BatchDuplicated func.val(y.val, x.val, w.val, cdims.val; kwargs...) end @@ -37,7 +41,11 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{ return EnzymeCore.EnzymeRules.AugmentedReturn(primal, shadow, cache) end -function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{$name}, ::Type{RT}, cache, y, x, w, cdims; kwargs...) where {RT} +function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{$name}, ::Type{RT}, cache, + y::EnzymeCore.Annotation{<:AbstractArray{yT, N}}, + x::EnzymeCore.Annotation{<:AbstractArray{xT, N}}, + w::EnzymeCore.Annotation{<:AbstractArray{wT, N}}, + cdims; kwargs...) where {RT, yT, xT, wT, N} cache_x, cache_w = cache # Don't cache x if not overwritten and w is active (and thus required) @@ -65,15 +73,15 @@ function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{$name}, : end for (dy, dx, dw) in zip(dys, dxs, dws) - if !(typeof(y) <: EnzymeCore.Const) && dy !== w.val + if !(typeof(y) <: EnzymeCore.Const) && dy !== y.val if !(typeof(x) <: EnzymeCore.Const) && dx !== x.val # dx += grad wrt x.val - NNlib.∇conv_data!(dx, dy, cache_w, cdims.val; alpha=eltype(dw)(1), beta=eltype(dw)(1), kwargs...) + NNlib.∇conv_data!(dx, dy, cache_w, cdims.val; alpha=xT(1), beta=xT(1), kwargs...) end if !(typeof(w) <: EnzymeCore.Const) && dw !== w.val # dw += grad wrt w.val - NNlib.∇conv_filter!(dw, cache_x, dy, cdims.val; alpha=eltype(dw)(1), beta=eltype(dw)(1), kwargs...) + NNlib.∇conv_filter!(dw, cache_x, dy, cdims.val; alpha=wT(1), beta=wT(1), kwargs...) end dy .= 0 From cd6365ce21084e5a5272517906601c4dac9f4a4d Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Mon, 25 Sep 2023 23:15:55 -0500 Subject: [PATCH 19/26] Fix minor test bug --- test/conv.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/conv.jl b/test/conv.jl index e46dda3ab..02b853865 100644 --- a/test/conv.jl +++ b/test/conv.jl @@ -902,7 +902,7 @@ end Tx in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated), Tw in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated) - EnzymeTestUtils.are_activities_compatible(Tret, Tdst, Tw, Tw) || continue + EnzymeTestUtils.are_activities_compatible(Tret, Tdst, Tx, Tw) || continue EnzymeTestUtils.test_reverse(curconv!, Tret, (dst, Tdst), (x, Tx), (w, Tw), (cdims, EnzymeCore.Const)) end @@ -924,7 +924,7 @@ end Tx in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated), Tw in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated) - EnzymeTestUtils.are_activities_compatible(Tret, Tdst, Tw, Tw) || continue + EnzymeTestUtils.are_activities_compatible(Tret, Tdst, Tx, Tw) || continue EnzymeTestUtils.test_reverse(curconv!, Tret, (dst, Tdst), (x, Tx), (w, Tw), (cdims, EnzymeCore.Const)) end From 10cfdc6389983094750edd85c062e4d4ed74b53b Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Mon, 25 Sep 2023 23:23:19 -0500 Subject: [PATCH 20/26] Fix depthwise conv --- src/enzyme.jl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/enzyme.jl b/src/enzyme.jl index c44539108..aee96b641 100644 --- a/src/enzyme.jl +++ b/src/enzyme.jl @@ -1,6 +1,7 @@ import EnzymeCore -for name in (typeof(NNlib.conv!), typeof(NNlib.depthwiseconv!)) +for (name, dataname, filtername) in ((typeof(NNlib.conv!), NNlib.∇conv_data!, NNlib.∇conv_filter!), + (typeof(NNlib.depthwiseconv!), NNlib.∇depthwiseconv_data!, NNlib.∇depthwiseconv_filter!) ) @eval begin function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{$name}, ::Type{RT}, @@ -77,11 +78,11 @@ function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{$name}, : if !(typeof(x) <: EnzymeCore.Const) && dx !== x.val # dx += grad wrt x.val - NNlib.∇conv_data!(dx, dy, cache_w, cdims.val; alpha=xT(1), beta=xT(1), kwargs...) + $dataname(dx, dy, cache_w, cdims.val; alpha=xT(1), beta=xT(1), kwargs...) end if !(typeof(w) <: EnzymeCore.Const) && dw !== w.val # dw += grad wrt w.val - NNlib.∇conv_filter!(dw, cache_x, dy, cdims.val; alpha=wT(1), beta=wT(1), kwargs...) + $filtername(dw, cache_x, dy, cdims.val; alpha=wT(1), beta=wT(1), kwargs...) end dy .= 0 From f7f0186456e89b6837f8f55072337d2efba45f64 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Mon, 25 Sep 2023 23:30:13 -0500 Subject: [PATCH 21/26] Fix typo --- test/dropout.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/dropout.jl b/test/dropout.jl index 315d5146e..9f5649596 100644 --- a/test/dropout.jl +++ b/test/dropout.jl @@ -76,7 +76,7 @@ using Zygote, StableRNGs, ChainRulesCore, Enzyme @test_throws ArgumentError dropout!(y1, x1, 3) end -@testset "EnzymeRules: dropout " +@testset "EnzymeRules: dropout " begin rng = Random.default_rng() x1 = randn(Float32, 3000, 4000) From 472d33b7e8c808d42a1fca42dfc5408269c88a52 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Tue, 26 Sep 2023 10:53:52 -0500 Subject: [PATCH 22/26] Bound tests --- test/conv.jl | 5 ++++- test/dropout.jl | 4 ++++ test/gather.jl | 3 +++ test/pooling.jl | 3 +++ test/runtests.jl | 2 ++ test/scatter.jl | 4 ++++ 6 files changed, 20 insertions(+), 1 deletion(-) diff --git a/test/conv.jl b/test/conv.jl index 02b853865..dc3fc57f5 100644 --- a/test/conv.jl +++ b/test/conv.jl @@ -887,6 +887,8 @@ end gradtest((y, w) -> sum(∇depthwiseconv_data(y, w, dcdims)), y, w) end +@static if Test_Enzyme + @testset "EnzymeRules: conv! spatial_rank=$spatial_rank" for spatial_rank in (1, 2, 3) x = rand(rng, repeat([5], spatial_rank)..., 3, 2) w = rand(rng, repeat([3], spatial_rank)..., 3, 3) @@ -908,7 +910,6 @@ end end end - @testset "EnzymeRules: depthwiseconv! spatial_rank=$spatial_rank" for spatial_rank in (1, 2, 3) x = rand(rng, repeat([5], spatial_rank)..., 3, 2) w = rand(rng, repeat([3], spatial_rank)..., 3, 3) @@ -928,4 +929,6 @@ end EnzymeTestUtils.test_reverse(curconv!, Tret, (dst, Tdst), (x, Tx), (w, Tw), (cdims, EnzymeCore.Const)) end +end + end \ No newline at end of file diff --git a/test/dropout.jl b/test/dropout.jl index 9f5649596..0da70111e 100644 --- a/test/dropout.jl +++ b/test/dropout.jl @@ -76,6 +76,8 @@ using Zygote, StableRNGs, ChainRulesCore, Enzyme @test_throws ArgumentError dropout!(y1, x1, 3) end +@static if Test_Enzyme + @testset "EnzymeRules: dropout " begin rng = Random.default_rng() @@ -99,4 +101,6 @@ end val = convert(Float32, 1/(1-p)) @test dx1[tape[1]] ≈ (val * dout)[tape[1]] +end + end \ No newline at end of file diff --git a/test/gather.jl b/test/gather.jl index c5cd5ee57..92e3bfb7d 100644 --- a/test/gather.jl +++ b/test/gather.jl @@ -154,6 +154,7 @@ function gather_testsuite(Backend) gradtest_fn((s, i) -> gather(s, i), src, idx) end + @static if Test_Enzyme @testset "EnzymeRules: gather! gradient for scalar index" begin src = device(Float64[3, 4, 5, 6, 7]) @@ -172,6 +173,8 @@ function gather_testsuite(Backend) end end + end + @testset "gather gradient for tuple index" begin src = device(Float64[ 3 5 7 diff --git a/test/pooling.jl b/test/pooling.jl index d695bd421..97d014d52 100644 --- a/test/pooling.jl +++ b/test/pooling.jl @@ -968,6 +968,7 @@ end gradtest(x -> sum(meanpool(x, k)), x) end +@static if Test_Enzyme @testset "EnzymeRules: pooling! $pool spatial_rank=$spatial_rank " for spatial_rank in (1, 2), (pool, pool!) in ((maxpool, maxpool!), (meanpool, meanpool!)) @@ -985,4 +986,6 @@ end EnzymeTestUtils.test_reverse(pool!, Tret, (y, Tdst), (x, Tsrc), (pdims, EnzymeCore.Const)) end +end + end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 9b43be8b8..8b359ad87 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -13,6 +13,8 @@ using Adapt using KernelAbstractions import ReverseDiff as RD # used in `pooling.jl` +const Test_Enzyme = VERSION <= v"1.10" && !Sys.iswindows() + DocMeta.setdocmeta!(NNlib, :DocTestSetup, :(using NNlib, UnicodePlots); recursive=true) # ENV["NNLIB_TEST_CUDA"] = "true" # uncomment to run CUDA tests diff --git a/test/scatter.jl b/test/scatter.jl index 289ed37b8..7e06e2650 100644 --- a/test/scatter.jl +++ b/test/scatter.jl @@ -208,6 +208,8 @@ function scatter_testsuite(Backend) end end + @static if Test_Enzyme + @testset "EnzymeRules" begin idx = device([2, 2, 3, 4, 4]) src = device(ones(T, 3, 5)) @@ -226,5 +228,7 @@ function scatter_testsuite(Backend) end end end + + end end end From b90ef83b6994c4b09f044ece2ac14fb69ca1fc04 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Tue, 26 Sep 2023 12:24:40 -0500 Subject: [PATCH 23/26] Failing to extension --- Project.toml | 7 +++++-- src/NNlib.jl | 8 ++++++-- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/Project.toml b/Project.toml index 8ea3da8f0..818cd315c 100644 --- a/Project.toml +++ b/Project.toml @@ -6,7 +6,6 @@ version = "0.9.6" Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" Atomix = "a9b6321e-bd34-4604-b9c9-b65b8de01458" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" @@ -18,12 +17,14 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [weakdeps] AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" [extensions] NNlibAMDGPUExt = "AMDGPU" NNlibCUDACUDNNExt = ["CUDA", "cuDNN"] NNlibCUDAExt = "CUDA" +NNlibEnzymeCoreExt = "EnzymeCore" [compat] AMDGPU = "0.5, 0.6" @@ -31,6 +32,7 @@ Adapt = "3.2" Atomix = "0.1" CUDA = "4, 5" ChainRulesCore = "1.13" +EnzymeCore = "0.5, 0.6" GPUArraysCore = "0.1" KernelAbstractions = "0.9.2" Requires = "1.0" @@ -43,6 +45,7 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" @@ -55,4 +58,4 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" [targets] -test = ["AMDGPU", "CUDA", "ChainRulesTestUtils", "Documenter", "FiniteDifferences", "ForwardDiff", "Logging", "ReverseDiff", "StableRNGs", "Test", "UnicodePlots", "Zygote", "cuDNN", "Enzyme", "EnzymeTestUtils"] +test = ["AMDGPU", "CUDA", "ChainRulesTestUtils", "Documenter", "FiniteDifferences", "ForwardDiff", "Logging", "ReverseDiff", "StableRNGs", "Test", "UnicodePlots", "Zygote", "cuDNN", "Enzyme", "EnzymeCore", "EnzymeTestUtils"] diff --git a/src/NNlib.jl b/src/NNlib.jl index 14ba2c70f..4c6deeb41 100644 --- a/src/NNlib.jl +++ b/src/NNlib.jl @@ -39,6 +39,12 @@ is_nnpack_available() = false end end +@static if !isdefined(Base, :get_extension) +@init @require EnzymeCore="f151be2c-9106-41f4-ab19-57ee4f262869" begin + include("../ext/NNlibEnzymeCore/NNlibEnzymeCoreExt.jl") +end +end + include("activations.jl") for f in ACTIVATIONS @eval export $(f) @@ -123,6 +129,4 @@ include("impl/depthwiseconv_im2col.jl") include("impl/pooling_direct.jl") include("deprecations.jl") -include("enzyme.jl") - end # module NNlib From c7112748699d0dd0a6d92acbd8b80f11585eeccf Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Tue, 26 Sep 2023 15:34:35 -0500 Subject: [PATCH 24/26] Add file --- ext/NNlibEnzymeCoreExt/NNlibEnzymeCoreExt.jl | 374 +++++++++++++++++++ 1 file changed, 374 insertions(+) create mode 100644 ext/NNlibEnzymeCoreExt/NNlibEnzymeCoreExt.jl diff --git a/ext/NNlibEnzymeCoreExt/NNlibEnzymeCoreExt.jl b/ext/NNlibEnzymeCoreExt/NNlibEnzymeCoreExt.jl new file mode 100644 index 000000000..27cbcb982 --- /dev/null +++ b/ext/NNlibEnzymeCoreExt/NNlibEnzymeCoreExt.jl @@ -0,0 +1,374 @@ +module NNlibEnzymeCoreExt + +using NNlib +import EnzymeCore +using Random + +for (name, dataname, filtername) in ((typeof(NNlib.conv!), NNlib.∇conv_data!, NNlib.∇conv_filter!), + (typeof(NNlib.depthwiseconv!), NNlib.∇depthwiseconv_data!, NNlib.∇depthwiseconv_filter!) ) + @eval begin + +function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{$name}, ::Type{RT}, + y::EnzymeCore.Annotation{<:AbstractArray{yT, N}}, + x::EnzymeCore.Annotation{<:AbstractArray{xT, N}}, + w::EnzymeCore.Annotation{<:AbstractArray{wT, N}}, + cdims; kwargs...) where {RT, yT, xT, wT, N} + + if typeof(y) <: EnzymeCore.Duplicated || typeof(y) <: EnzymeCore.BatchDuplicated + func.val(y.val, x.val, w.val, cdims.val; kwargs...) + end + + primal = if EnzymeCore.EnzymeRules.needs_primal(config) + y.val + else + nothing + end + shadow = if EnzymeCore.EnzymeRules.needs_shadow(config) + y.dval + else + nothing + end + + # Cache x if its overwritten and w is active (and thus required) + cache_x = ( EnzymeCore.EnzymeRules.overwritten(config)[3] + && !(typeof(w) <: EnzymeCore.Const) + && !(typeof(y) <: EnzymeCore.Const) + ) ? copy(x.val) : nothing + + # Cache w if its overwritten and x is active (and thus required) + cache_w = ( EnzymeCore.EnzymeRules.overwritten(config)[4] + && !(typeof(x) <: EnzymeCore.Const) + && !(typeof(y) <: EnzymeCore.Const) + ) ? copy(w.val) : nothing + + cache = (cache_x, cache_w) + + return EnzymeCore.EnzymeRules.AugmentedReturn(primal, shadow, cache) +end + +function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{$name}, ::Type{RT}, cache, + y::EnzymeCore.Annotation{<:AbstractArray{yT, N}}, + x::EnzymeCore.Annotation{<:AbstractArray{xT, N}}, + w::EnzymeCore.Annotation{<:AbstractArray{wT, N}}, + cdims; kwargs...) where {RT, yT, xT, wT, N} + cache_x, cache_w = cache + + # Don't cache x if not overwritten and w is active (and thus required) + if !(typeof(w) <: EnzymeCore.Const) && !(typeof(y) <: EnzymeCore.Const) + if !EnzymeCore.EnzymeRules.overwritten(config)[3] + cache_x = x.val + end + end + + # Don't cache w if not overwritten and x is active (and thus required) + if !(typeof(x) <: EnzymeCore.Const) && !(typeof(y) <: EnzymeCore.Const) + if !EnzymeCore.EnzymeRules.overwritten(config)[4] + cache_w = w.val + end + end + + dys = y.dval + dxs = (typeof(x) <: EnzymeCore.Const) ? dys : x.dval + dws = (typeof(w) <: EnzymeCore.Const) ? dys : w.dval + + if EnzymeCore.EnzymeRules.width(config) == 1 + dys = (dys,) + dxs = (dxs,) + dws = (dws,) + end + + for (dy, dx, dw) in zip(dys, dxs, dws) + if !(typeof(y) <: EnzymeCore.Const) && dy !== y.val + + if !(typeof(x) <: EnzymeCore.Const) && dx !== x.val + # dx += grad wrt x.val + $dataname(dx, dy, cache_w, cdims.val; alpha=xT(1), beta=xT(1), kwargs...) + end + if !(typeof(w) <: EnzymeCore.Const) && dw !== w.val + # dw += grad wrt w.val + $filtername(dw, cache_x, dy, cdims.val; alpha=wT(1), beta=wT(1), kwargs...) + end + + dy .= 0 + end + end + + return (nothing, nothing, nothing, nothing) +end + +end +end + +function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{typeof(NNlib.gather!)}, ::Type{RT}, dst::OutType, src, idx::EnzymeCore.Const) where {OutType, RT} + + if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.BatchDuplicated + func.val(dst.val, src.val, idx.val) + end + + primal = if EnzymeCore.EnzymeRules.needs_primal(config) + dst.val + else + nothing + end + shadow = if EnzymeCore.EnzymeRules.needs_shadow(config) + dst.dval + else + nothing + end + + # Cache idx if its overwritten + cache_idx = ( EnzymeCore.EnzymeRules.overwritten(config)[4] + && !(typeof(src) <: EnzymeCore.Const) + && !(typeof(dst) <: EnzymeCore.Const) + ) ? copy(idx.val) : nothing + + return EnzymeCore.EnzymeRules.AugmentedReturn(primal, shadow, cache_idx) +end + +function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NNlib.gather!)}, ::Type{RT}, cache_idx, dst::OutType, src, idx::EnzymeCore.Const) where {OutType, RT} + + # Don't cache idx if not overwritten + if !(typeof(src) <: EnzymeCore.Const) && !(typeof(dst) <: EnzymeCore.Const) + if !EnzymeCore.EnzymeRules.overwritten(config)[4] + cache_idx = idx.val + end + end + + ddsts = dst.dval + dsrcs = (typeof(src) <: EnzymeCore.Const) ? ddsts : src.dval + + if EnzymeCore.EnzymeRules.width(config) == 1 + ddsts = (ddsts,) + dsrcs = (dsrcs,) + end + + for (ddst, dsrc) in zip(ddsts, dsrcs) + if !(typeof(dst) <: EnzymeCore.Const) && ddst !== dst.val + + if !(typeof(src) <: EnzymeCore.Const) && dsrc !== src.val + NNlib.scatter!(+, dsrc, ddst, cache_idx) + end + + ddst .= 0 + end + end + + return (nothing, nothing, nothing) +end + + + +function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{typeof(NNlib.scatter!)}, ::Type{RT}, op::EnzymeCore.Const, dst::OutType, src, idx::EnzymeCore.Const) where {OutType, RT} + + @assert !(OutType <: EnzymeCore.Const) + if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.BatchDuplicated + func.val(op.val, dst.val, src.val, idx.val) + end + + primal = if EnzymeCore.EnzymeRules.needs_primal(config) + dst.val + else + nothing + end + shadow = if EnzymeCore.EnzymeRules.needs_shadow(config) + dst.dval + else + nothing + end + + # Cache idx if its overwritten + cache_idx = ( EnzymeCore.EnzymeRules.overwritten(config)[4] + && !(typeof(src) <: EnzymeCore.Const) + && !(typeof(dst) <: EnzymeCore.Const) + ) ? copy(idx.val) : nothing + + return EnzymeCore.EnzymeRules.AugmentedReturn(primal, shadow, cache_idx) +end + +function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NNlib.scatter!)}, ::Type{RT}, cache_idx, op::Union{EnzymeCore.Const{typeof(+)},EnzymeCore.Const{typeof(-)}}, dst::OutType, src, idx::EnzymeCore.Const) where {OutType, RT} + + # Don't cache idx if not overwritten + if !(typeof(src) <: EnzymeCore.Const) && !(typeof(dst) <: EnzymeCore.Const) + if !EnzymeCore.EnzymeRules.overwritten(config)[4] + cache_idx = idx.val + end + end + + ddsts = dst.dval + dsrcs = (typeof(src) <: EnzymeCore.Const) ? ddsts : src.dval + + if EnzymeCore.EnzymeRules.width(config) == 1 + ddsts = (ddsts,) + dsrcs = (dsrcs,) + end + + for (ddst, dsrc) in zip(ddsts, dsrcs) + if !(typeof(dst) <: EnzymeCore.Const) && ddst !== dst.val + + if !(typeof(src) <: EnzymeCore.Const) && dsrc !== src.val + + if eltype(typeof(op)) == typeof(+) + dsrc .+= NNlib.gather(ddst, cache_idx) + else + @assert eltype(typeof(op)) == typeof(-) + dsrc .-= NNlib.gather(ddst, cache_idx) + end + end + + end + end + + return (nothing, nothing, nothing, nothing) +end + + + +for pool in [:maxpool, :meanpool, :lpnormpool] + pool! = Symbol(pool, :!) + ∇pool = Symbol(:∇, pool, :!) + + @eval begin + +function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{typeof($pool!)}, ::Type{RT}, y::OutType, x, dims; kwargs...) where {OutType, RT} + + if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.BatchDuplicated + func.val(y.val, x.val, dims.val; kwargs...) + end + + primal = if EnzymeCore.EnzymeRules.needs_primal(config) + y.val + else + nothing + end + shadow = if EnzymeCore.EnzymeRules.needs_shadow(config) + y.dval + else + nothing + end + + cache_y = ( EnzymeCore.EnzymeRules.overwritten(config)[2] + && !(typeof(x) <: EnzymeCore.Const) + && !(typeof(y) <: EnzymeCore.Const) + ) ? copy(y.val) : nothing + + cache_x = ( EnzymeCore.EnzymeRules.overwritten(config)[3] + && !(typeof(x) <: EnzymeCore.Const) + && !(typeof(y) <: EnzymeCore.Const) + ) ? copy(x.val) : nothing + + cache = (cache_y, cache_x) + + return EnzymeCore.EnzymeRules.AugmentedReturn(primal, shadow, cache) +end + +function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof($pool!)}, ::Type{RT}, cache, y, x, dims; kwargs...) where {RT} + cache_y, cache_x = cache + + # Don't cache y if not overwritten + if !(typeof(x) <: EnzymeCore.Const) && !(typeof(y) <: EnzymeCore.Const) + if !EnzymeCore.EnzymeRules.overwritten(config)[2] + cache_y = y.val + end + end + + # Don't cache x if not overwritten + if !(typeof(x) <: EnzymeCore.Const) && !(typeof(y) <: EnzymeCore.Const) + if !EnzymeCore.EnzymeRules.overwritten(config)[3] + cache_x = x.val + end + end + + dys = y.dval + dxs = (typeof(x) <: EnzymeCore.Const) ? dys : x.dval + + if EnzymeCore.EnzymeRules.width(config) == 1 + dys = (dys,) + dxs = (dxs,) + end + + for (dy, dx) in zip(dys, dxs) + if !(typeof(y) <: EnzymeCore.Const) && dy !== y.val + + if !(typeof(x) <: EnzymeCore.Const) && dx !== x.val + NNlib.$(∇pool)(dx, dy, cache_y, cache_x, dims.val; alpha=eltype(dx)(1), beta=eltype(dx)(1), kwargs...) + end + + dy .= 0 + end + end + + return (nothing, nothing, nothing) +end + +end +end + +function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{typeof(NNlib._dropout!)}, ::Type{RT}, rng, dst::OutType, src, p, dims) where {OutType, RT} + + T = float(real(eltype(dst.val))) + val = convert(T, 1/(1-p.val)) + keep = if dims.val isa Colon + similar(dst.val, T, size(dst.val)) + else + similar(dst.val, T, ntuple(d -> d in dims.val ? size(dst.val,d) : 1, ndims(dst.val))) + end + rand!(rng.val, keep) + + keep = keep .> p.val + + if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.BatchDuplicated + dst.val .= (keep .* val) .* src.val + end + + primal = if EnzymeCore.EnzymeRules.needs_primal(config) + dst.val + else + nothing + end + shadow = if EnzymeCore.EnzymeRules.needs_shadow(config) + dst.dval + else + nothing + end + + if typeof(dst) <: EnzymeCore.Const || typeof(src) <: EnzymeCore.Const + keep = nothing + end + + return EnzymeCore.EnzymeRules.AugmentedReturn(primal, shadow, keep) +end + +function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NNlib._dropout!)}, ::Type{RT}, keep, rng, dst::OutType, src, p, dims) where {OutType, RT} + T = float(real(eltype(dst.val))) + val = convert(T, 1/(1-p.val)) + + ddsts = dst.dval + dsrcs = (typeof(src) <: EnzymeCore.Const) ? ddsts : src.dval + + if EnzymeCore.EnzymeRules.width(config) == 1 + ddsts = (ddsts,) + dsrcs = (dsrcs,) + end + + for (ddst, dsrc) in zip(ddsts, dsrcs) + if !(typeof(dst) <: EnzymeCore.Const) && ddst !== dst.val + + if !(typeof(src) <: EnzymeCore.Const) && dsrc !== src.val + dsrc .+= (keep .* val) .* ddst + end + + ddst .= 0 + end + end + + dp = if typeof(p) <: EnzymeCore.Active + typeof(p.val)(0) + else + nothing + end + + return (nothing, nothing, nothing, dp, nothing) +end + + +end \ No newline at end of file From a8d35f08b2b14501e342bb02612598afc8925316 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Wed, 27 Sep 2023 10:09:30 -0500 Subject: [PATCH 25/26] Address review --- Project.toml | 3 +- ext/NNlibEnzymeCoreExt/NNlibEnzymeCoreExt.jl | 246 +++++++------ src/NNlib.jl | 6 - src/enzyme.jl | 367 ------------------- 4 files changed, 129 insertions(+), 493 deletions(-) delete mode 100644 src/enzyme.jl diff --git a/Project.toml b/Project.toml index 818cd315c..7fad9dfed 100644 --- a/Project.toml +++ b/Project.toml @@ -58,4 +58,5 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" [targets] -test = ["AMDGPU", "CUDA", "ChainRulesTestUtils", "Documenter", "FiniteDifferences", "ForwardDiff", "Logging", "ReverseDiff", "StableRNGs", "Test", "UnicodePlots", "Zygote", "cuDNN", "Enzyme", "EnzymeCore", "EnzymeTestUtils"] +test = ["AMDGPU", "CUDA", "ChainRulesTestUtils", "Documenter", "FiniteDifferences", "ForwardDiff", "Logging", "ReverseDiff", "StableRNGs", "Test", "UnicodePlots", "Zygote", "cuDNN", + "Enzyme", "EnzymeCore", "EnzymeTestUtils"] diff --git a/ext/NNlibEnzymeCoreExt/NNlibEnzymeCoreExt.jl b/ext/NNlibEnzymeCoreExt/NNlibEnzymeCoreExt.jl index 27cbcb982..7624463da 100644 --- a/ext/NNlibEnzymeCoreExt/NNlibEnzymeCoreExt.jl +++ b/ext/NNlibEnzymeCoreExt/NNlibEnzymeCoreExt.jl @@ -4,132 +4,134 @@ using NNlib import EnzymeCore using Random +using EnzymeCore.EnzymeRules + for (name, dataname, filtername) in ((typeof(NNlib.conv!), NNlib.∇conv_data!, NNlib.∇conv_filter!), (typeof(NNlib.depthwiseconv!), NNlib.∇depthwiseconv_data!, NNlib.∇depthwiseconv_filter!) ) @eval begin -function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{$name}, ::Type{RT}, - y::EnzymeCore.Annotation{<:AbstractArray{yT, N}}, - x::EnzymeCore.Annotation{<:AbstractArray{xT, N}}, - w::EnzymeCore.Annotation{<:AbstractArray{wT, N}}, - cdims; kwargs...) where {RT, yT, xT, wT, N} - - if typeof(y) <: EnzymeCore.Duplicated || typeof(y) <: EnzymeCore.BatchDuplicated - func.val(y.val, x.val, w.val, cdims.val; kwargs...) - end - - primal = if EnzymeCore.EnzymeRules.needs_primal(config) - y.val - else - nothing - end - shadow = if EnzymeCore.EnzymeRules.needs_shadow(config) - y.dval - else - nothing - end - - # Cache x if its overwritten and w is active (and thus required) - cache_x = ( EnzymeCore.EnzymeRules.overwritten(config)[3] - && !(typeof(w) <: EnzymeCore.Const) - && !(typeof(y) <: EnzymeCore.Const) - ) ? copy(x.val) : nothing - - # Cache w if its overwritten and x is active (and thus required) - cache_w = ( EnzymeCore.EnzymeRules.overwritten(config)[4] - && !(typeof(x) <: EnzymeCore.Const) - && !(typeof(y) <: EnzymeCore.Const) - ) ? copy(w.val) : nothing - - cache = (cache_x, cache_w) - - return EnzymeCore.EnzymeRules.AugmentedReturn(primal, shadow, cache) -end - -function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{$name}, ::Type{RT}, cache, - y::EnzymeCore.Annotation{<:AbstractArray{yT, N}}, - x::EnzymeCore.Annotation{<:AbstractArray{xT, N}}, - w::EnzymeCore.Annotation{<:AbstractArray{wT, N}}, - cdims; kwargs...) where {RT, yT, xT, wT, N} - cache_x, cache_w = cache - - # Don't cache x if not overwritten and w is active (and thus required) - if !(typeof(w) <: EnzymeCore.Const) && !(typeof(y) <: EnzymeCore.Const) - if !EnzymeCore.EnzymeRules.overwritten(config)[3] - cache_x = x.val - end - end - - # Don't cache w if not overwritten and x is active (and thus required) - if !(typeof(x) <: EnzymeCore.Const) && !(typeof(y) <: EnzymeCore.Const) - if !EnzymeCore.EnzymeRules.overwritten(config)[4] - cache_w = w.val - end - end - - dys = y.dval - dxs = (typeof(x) <: EnzymeCore.Const) ? dys : x.dval - dws = (typeof(w) <: EnzymeCore.Const) ? dys : w.dval - - if EnzymeCore.EnzymeRules.width(config) == 1 - dys = (dys,) - dxs = (dxs,) - dws = (dws,) - end - - for (dy, dx, dw) in zip(dys, dxs, dws) - if !(typeof(y) <: EnzymeCore.Const) && dy !== y.val - - if !(typeof(x) <: EnzymeCore.Const) && dx !== x.val - # dx += grad wrt x.val - $dataname(dx, dy, cache_w, cdims.val; alpha=xT(1), beta=xT(1), kwargs...) - end - if !(typeof(w) <: EnzymeCore.Const) && dw !== w.val - # dw += grad wrt w.val - $filtername(dw, cache_x, dy, cdims.val; alpha=wT(1), beta=wT(1), kwargs...) - end - - dy .= 0 - end - end - - return (nothing, nothing, nothing, nothing) -end + function EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{$name}, ::Type{RT}, + y::EnzymeCore.Annotation{<:AbstractArray{yT, N}}, + x::EnzymeCore.Annotation{<:AbstractArray{xT, N}}, + w::EnzymeCore.Annotation{<:AbstractArray{wT, N}}, + cdims; kwargs...) where {RT, yT, xT, wT, N} + + if typeof(y) <: EnzymeCore.Duplicated || typeof(y) <: EnzymeCore.BatchDuplicated + func.val(y.val, x.val, w.val, cdims.val; kwargs...) + end + + primal = if EnzymeRules.needs_primal(config) + y.val + else + nothing + end + shadow = if EnzymeRules.needs_shadow(config) + y.dval + else + nothing + end + + # Cache x if its overwritten and w is active (and thus required) + cache_x = ( EnzymeRules.overwritten(config)[3] + && !(typeof(w) <: EnzymeCore.Const) + && !(typeof(y) <: EnzymeCore.Const) + ) ? copy(x.val) : nothing + + # Cache w if its overwritten and x is active (and thus required) + cache_w = ( EnzymeRules.overwritten(config)[4] + && !(typeof(x) <: EnzymeCore.Const) + && !(typeof(y) <: EnzymeCore.Const) + ) ? copy(w.val) : nothing + + cache = (cache_x, cache_w) + + return EnzymeRules.AugmentedReturn(primal, shadow, cache) + end + + function EnzymeRules.reverse(config, func::EnzymeCore.Const{$name}, ::Type{RT}, cache, + y::EnzymeCore.Annotation{<:AbstractArray{yT, N}}, + x::EnzymeCore.Annotation{<:AbstractArray{xT, N}}, + w::EnzymeCore.Annotation{<:AbstractArray{wT, N}}, + cdims; kwargs...) where {RT, yT, xT, wT, N} + cache_x, cache_w = cache + + # Don't cache x if not overwritten and w is active (and thus required) + if !(typeof(w) <: EnzymeCore.Const) && !(typeof(y) <: EnzymeCore.Const) + if !EnzymeRules.overwritten(config)[3] + cache_x = x.val + end + end + + # Don't cache w if not overwritten and x is active (and thus required) + if !(typeof(x) <: EnzymeCore.Const) && !(typeof(y) <: EnzymeCore.Const) + if !EnzymeRules.overwritten(config)[4] + cache_w = w.val + end + end + + dys = y.dval + dxs = (typeof(x) <: EnzymeCore.Const) ? dys : x.dval + dws = (typeof(w) <: EnzymeCore.Const) ? dys : w.dval + + if EnzymeRules.width(config) == 1 + dys = (dys,) + dxs = (dxs,) + dws = (dws,) + end + + for (dy, dx, dw) in zip(dys, dxs, dws) + if !(typeof(y) <: EnzymeCore.Const) && dy !== y.val + + if !(typeof(x) <: EnzymeCore.Const) && dx !== x.val + # dx += grad wrt x.val + $dataname(dx, dy, cache_w, cdims.val; alpha=xT(1), beta=xT(1), kwargs...) + end + if !(typeof(w) <: EnzymeCore.Const) && dw !== w.val + # dw += grad wrt w.val + $filtername(dw, cache_x, dy, cdims.val; alpha=wT(1), beta=wT(1), kwargs...) + end + + dy .= 0 + end + end + + return (nothing, nothing, nothing, nothing) + end end end -function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{typeof(NNlib.gather!)}, ::Type{RT}, dst::OutType, src, idx::EnzymeCore.Const) where {OutType, RT} +function EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{typeof(NNlib.gather!)}, ::Type{RT}, dst::OutType, src, idx::EnzymeCore.Const) where {OutType, RT} if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.BatchDuplicated func.val(dst.val, src.val, idx.val) end - primal = if EnzymeCore.EnzymeRules.needs_primal(config) + primal = if EnzymeRules.needs_primal(config) dst.val else nothing end - shadow = if EnzymeCore.EnzymeRules.needs_shadow(config) + shadow = if EnzymeRules.needs_shadow(config) dst.dval else nothing end # Cache idx if its overwritten - cache_idx = ( EnzymeCore.EnzymeRules.overwritten(config)[4] + cache_idx = ( EnzymeRules.overwritten(config)[4] && !(typeof(src) <: EnzymeCore.Const) && !(typeof(dst) <: EnzymeCore.Const) ) ? copy(idx.val) : nothing - return EnzymeCore.EnzymeRules.AugmentedReturn(primal, shadow, cache_idx) + return EnzymeRules.AugmentedReturn(primal, shadow, cache_idx) end -function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NNlib.gather!)}, ::Type{RT}, cache_idx, dst::OutType, src, idx::EnzymeCore.Const) where {OutType, RT} +function EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NNlib.gather!)}, ::Type{RT}, cache_idx, dst::OutType, src, idx::EnzymeCore.Const) where {OutType, RT} # Don't cache idx if not overwritten if !(typeof(src) <: EnzymeCore.Const) && !(typeof(dst) <: EnzymeCore.Const) - if !EnzymeCore.EnzymeRules.overwritten(config)[4] + if !EnzymeRules.overwritten(config)[4] cache_idx = idx.val end end @@ -137,7 +139,7 @@ function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NN ddsts = dst.dval dsrcs = (typeof(src) <: EnzymeCore.Const) ? ddsts : src.dval - if EnzymeCore.EnzymeRules.width(config) == 1 + if EnzymeRules.width(config) == 1 ddsts = (ddsts,) dsrcs = (dsrcs,) end @@ -158,38 +160,44 @@ end -function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{typeof(NNlib.scatter!)}, ::Type{RT}, op::EnzymeCore.Const, dst::OutType, src, idx::EnzymeCore.Const) where {OutType, RT} +function EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{typeof(NNlib.scatter!)}, ::Type{RT}, op::EnzymeCore.Const, dst::OutType, src, idx::EnzymeCore.Const) where {OutType, RT} @assert !(OutType <: EnzymeCore.Const) if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.BatchDuplicated func.val(op.val, dst.val, src.val, idx.val) end - primal = if EnzymeCore.EnzymeRules.needs_primal(config) + primal = if EnzymeRules.needs_primal(config) dst.val else nothing end - shadow = if EnzymeCore.EnzymeRules.needs_shadow(config) + shadow = if EnzymeRules.needs_shadow(config) dst.dval else nothing end # Cache idx if its overwritten - cache_idx = ( EnzymeCore.EnzymeRules.overwritten(config)[4] + cache_idx = ( EnzymeRules.overwritten(config)[4] && !(typeof(src) <: EnzymeCore.Const) && !(typeof(dst) <: EnzymeCore.Const) ) ? copy(idx.val) : nothing - return EnzymeCore.EnzymeRules.AugmentedReturn(primal, shadow, cache_idx) + return EnzymeRules.AugmentedReturn(primal, shadow, cache_idx) end -function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NNlib.scatter!)}, ::Type{RT}, cache_idx, op::Union{EnzymeCore.Const{typeof(+)},EnzymeCore.Const{typeof(-)}}, dst::OutType, src, idx::EnzymeCore.Const) where {OutType, RT} +function EnzymeRules.reverse(config, + func::EnzymeCore.Const{typeof(NNlib.scatter!)}, + ::Type{RT}, + cache_idx, + op::Union{EnzymeCore.Const{typeof(+)},EnzymeCore.Const{typeof(-)}}, dst::OutType, + src, + idx::EnzymeCore.Const) where {OutType, RT} # Don't cache idx if not overwritten if !(typeof(src) <: EnzymeCore.Const) && !(typeof(dst) <: EnzymeCore.Const) - if !EnzymeCore.EnzymeRules.overwritten(config)[4] + if !EnzymeRules.overwritten(config)[4] cache_idx = idx.val end end @@ -197,7 +205,7 @@ function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NN ddsts = dst.dval dsrcs = (typeof(src) <: EnzymeCore.Const) ? ddsts : src.dval - if EnzymeCore.EnzymeRules.width(config) == 1 + if EnzymeRules.width(config) == 1 ddsts = (ddsts,) dsrcs = (dsrcs,) end @@ -229,51 +237,51 @@ for pool in [:maxpool, :meanpool, :lpnormpool] @eval begin -function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{typeof($pool!)}, ::Type{RT}, y::OutType, x, dims; kwargs...) where {OutType, RT} +function EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{typeof($pool!)}, ::Type{RT}, y::OutType, x, dims; kwargs...) where {OutType, RT} if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.BatchDuplicated func.val(y.val, x.val, dims.val; kwargs...) end - primal = if EnzymeCore.EnzymeRules.needs_primal(config) + primal = if EnzymeRules.needs_primal(config) y.val else nothing end - shadow = if EnzymeCore.EnzymeRules.needs_shadow(config) + shadow = if EnzymeRules.needs_shadow(config) y.dval else nothing end - cache_y = ( EnzymeCore.EnzymeRules.overwritten(config)[2] + cache_y = ( EnzymeRules.overwritten(config)[2] && !(typeof(x) <: EnzymeCore.Const) && !(typeof(y) <: EnzymeCore.Const) ) ? copy(y.val) : nothing - cache_x = ( EnzymeCore.EnzymeRules.overwritten(config)[3] + cache_x = ( EnzymeRules.overwritten(config)[3] && !(typeof(x) <: EnzymeCore.Const) && !(typeof(y) <: EnzymeCore.Const) ) ? copy(x.val) : nothing cache = (cache_y, cache_x) - return EnzymeCore.EnzymeRules.AugmentedReturn(primal, shadow, cache) + return EnzymeRules.AugmentedReturn(primal, shadow, cache) end -function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof($pool!)}, ::Type{RT}, cache, y, x, dims; kwargs...) where {RT} +function EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof($pool!)}, ::Type{RT}, cache, y, x, dims; kwargs...) where {RT} cache_y, cache_x = cache # Don't cache y if not overwritten if !(typeof(x) <: EnzymeCore.Const) && !(typeof(y) <: EnzymeCore.Const) - if !EnzymeCore.EnzymeRules.overwritten(config)[2] + if !EnzymeRules.overwritten(config)[2] cache_y = y.val end end # Don't cache x if not overwritten if !(typeof(x) <: EnzymeCore.Const) && !(typeof(y) <: EnzymeCore.Const) - if !EnzymeCore.EnzymeRules.overwritten(config)[3] + if !EnzymeRules.overwritten(config)[3] cache_x = x.val end end @@ -281,7 +289,7 @@ function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof($p dys = y.dval dxs = (typeof(x) <: EnzymeCore.Const) ? dys : x.dval - if EnzymeCore.EnzymeRules.width(config) == 1 + if EnzymeRules.width(config) == 1 dys = (dys,) dxs = (dxs,) end @@ -303,7 +311,7 @@ end end end -function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{typeof(NNlib._dropout!)}, ::Type{RT}, rng, dst::OutType, src, p, dims) where {OutType, RT} +function EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{typeof(NNlib._dropout!)}, ::Type{RT}, rng, dst::OutType, src, p, dims) where {OutType, RT} T = float(real(eltype(dst.val))) val = convert(T, 1/(1-p.val)) @@ -320,12 +328,12 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{ dst.val .= (keep .* val) .* src.val end - primal = if EnzymeCore.EnzymeRules.needs_primal(config) + primal = if EnzymeRules.needs_primal(config) dst.val else nothing end - shadow = if EnzymeCore.EnzymeRules.needs_shadow(config) + shadow = if EnzymeRules.needs_shadow(config) dst.dval else nothing @@ -335,17 +343,17 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{ keep = nothing end - return EnzymeCore.EnzymeRules.AugmentedReturn(primal, shadow, keep) + return EnzymeRules.AugmentedReturn(primal, shadow, keep) end -function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NNlib._dropout!)}, ::Type{RT}, keep, rng, dst::OutType, src, p, dims) where {OutType, RT} +function EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NNlib._dropout!)}, ::Type{RT}, keep, rng, dst::OutType, src, p, dims) where {OutType, RT} T = float(real(eltype(dst.val))) val = convert(T, 1/(1-p.val)) ddsts = dst.dval dsrcs = (typeof(src) <: EnzymeCore.Const) ? ddsts : src.dval - if EnzymeCore.EnzymeRules.width(config) == 1 + if EnzymeRules.width(config) == 1 ddsts = (ddsts,) dsrcs = (dsrcs,) end diff --git a/src/NNlib.jl b/src/NNlib.jl index 4c6deeb41..8450a0261 100644 --- a/src/NNlib.jl +++ b/src/NNlib.jl @@ -39,12 +39,6 @@ is_nnpack_available() = false end end -@static if !isdefined(Base, :get_extension) -@init @require EnzymeCore="f151be2c-9106-41f4-ab19-57ee4f262869" begin - include("../ext/NNlibEnzymeCore/NNlibEnzymeCoreExt.jl") -end -end - include("activations.jl") for f in ACTIVATIONS @eval export $(f) diff --git a/src/enzyme.jl b/src/enzyme.jl deleted file mode 100644 index aee96b641..000000000 --- a/src/enzyme.jl +++ /dev/null @@ -1,367 +0,0 @@ -import EnzymeCore - -for (name, dataname, filtername) in ((typeof(NNlib.conv!), NNlib.∇conv_data!, NNlib.∇conv_filter!), - (typeof(NNlib.depthwiseconv!), NNlib.∇depthwiseconv_data!, NNlib.∇depthwiseconv_filter!) ) - @eval begin - -function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{$name}, ::Type{RT}, - y::EnzymeCore.Annotation{<:AbstractArray{yT, N}}, - x::EnzymeCore.Annotation{<:AbstractArray{xT, N}}, - w::EnzymeCore.Annotation{<:AbstractArray{wT, N}}, - cdims; kwargs...) where {RT, yT, xT, wT, N} - - if typeof(y) <: EnzymeCore.Duplicated || typeof(y) <: EnzymeCore.BatchDuplicated - func.val(y.val, x.val, w.val, cdims.val; kwargs...) - end - - primal = if EnzymeCore.EnzymeRules.needs_primal(config) - y.val - else - nothing - end - shadow = if EnzymeCore.EnzymeRules.needs_shadow(config) - y.dval - else - nothing - end - - # Cache x if its overwritten and w is active (and thus required) - cache_x = ( EnzymeCore.EnzymeRules.overwritten(config)[3] - && !(typeof(w) <: EnzymeCore.Const) - && !(typeof(y) <: EnzymeCore.Const) - ) ? copy(x.val) : nothing - - # Cache w if its overwritten and x is active (and thus required) - cache_w = ( EnzymeCore.EnzymeRules.overwritten(config)[4] - && !(typeof(x) <: EnzymeCore.Const) - && !(typeof(y) <: EnzymeCore.Const) - ) ? copy(w.val) : nothing - - cache = (cache_x, cache_w) - - return EnzymeCore.EnzymeRules.AugmentedReturn(primal, shadow, cache) -end - -function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{$name}, ::Type{RT}, cache, - y::EnzymeCore.Annotation{<:AbstractArray{yT, N}}, - x::EnzymeCore.Annotation{<:AbstractArray{xT, N}}, - w::EnzymeCore.Annotation{<:AbstractArray{wT, N}}, - cdims; kwargs...) where {RT, yT, xT, wT, N} - cache_x, cache_w = cache - - # Don't cache x if not overwritten and w is active (and thus required) - if !(typeof(w) <: EnzymeCore.Const) && !(typeof(y) <: EnzymeCore.Const) - if !EnzymeCore.EnzymeRules.overwritten(config)[3] - cache_x = x.val - end - end - - # Don't cache w if not overwritten and x is active (and thus required) - if !(typeof(x) <: EnzymeCore.Const) && !(typeof(y) <: EnzymeCore.Const) - if !EnzymeCore.EnzymeRules.overwritten(config)[4] - cache_w = w.val - end - end - - dys = y.dval - dxs = (typeof(x) <: EnzymeCore.Const) ? dys : x.dval - dws = (typeof(w) <: EnzymeCore.Const) ? dys : w.dval - - if EnzymeCore.EnzymeRules.width(config) == 1 - dys = (dys,) - dxs = (dxs,) - dws = (dws,) - end - - for (dy, dx, dw) in zip(dys, dxs, dws) - if !(typeof(y) <: EnzymeCore.Const) && dy !== y.val - - if !(typeof(x) <: EnzymeCore.Const) && dx !== x.val - # dx += grad wrt x.val - $dataname(dx, dy, cache_w, cdims.val; alpha=xT(1), beta=xT(1), kwargs...) - end - if !(typeof(w) <: EnzymeCore.Const) && dw !== w.val - # dw += grad wrt w.val - $filtername(dw, cache_x, dy, cdims.val; alpha=wT(1), beta=wT(1), kwargs...) - end - - dy .= 0 - end - end - - return (nothing, nothing, nothing, nothing) -end - -end -end - -function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{typeof(NNlib.gather!)}, ::Type{RT}, dst::OutType, src, idx::EnzymeCore.Const) where {OutType, RT} - - if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.BatchDuplicated - func.val(dst.val, src.val, idx.val) - end - - primal = if EnzymeCore.EnzymeRules.needs_primal(config) - dst.val - else - nothing - end - shadow = if EnzymeCore.EnzymeRules.needs_shadow(config) - dst.dval - else - nothing - end - - # Cache idx if its overwritten - cache_idx = ( EnzymeCore.EnzymeRules.overwritten(config)[4] - && !(typeof(src) <: EnzymeCore.Const) - && !(typeof(dst) <: EnzymeCore.Const) - ) ? copy(idx.val) : nothing - - return EnzymeCore.EnzymeRules.AugmentedReturn(primal, shadow, cache_idx) -end - -function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NNlib.gather!)}, ::Type{RT}, cache_idx, dst::OutType, src, idx::EnzymeCore.Const) where {OutType, RT} - - # Don't cache idx if not overwritten - if !(typeof(src) <: EnzymeCore.Const) && !(typeof(dst) <: EnzymeCore.Const) - if !EnzymeCore.EnzymeRules.overwritten(config)[4] - cache_idx = idx.val - end - end - - ddsts = dst.dval - dsrcs = (typeof(src) <: EnzymeCore.Const) ? ddsts : src.dval - - if EnzymeCore.EnzymeRules.width(config) == 1 - ddsts = (ddsts,) - dsrcs = (dsrcs,) - end - - for (ddst, dsrc) in zip(ddsts, dsrcs) - if !(typeof(dst) <: EnzymeCore.Const) && ddst !== dst.val - - if !(typeof(src) <: EnzymeCore.Const) && dsrc !== src.val - NNlib.scatter!(+, dsrc, ddst, cache_idx) - end - - ddst .= 0 - end - end - - return (nothing, nothing, nothing) -end - - - -function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{typeof(NNlib.scatter!)}, ::Type{RT}, op::EnzymeCore.Const, dst::OutType, src, idx::EnzymeCore.Const) where {OutType, RT} - - @assert !(OutType <: EnzymeCore.Const) - if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.BatchDuplicated - func.val(op.val, dst.val, src.val, idx.val) - end - - primal = if EnzymeCore.EnzymeRules.needs_primal(config) - dst.val - else - nothing - end - shadow = if EnzymeCore.EnzymeRules.needs_shadow(config) - dst.dval - else - nothing - end - - # Cache idx if its overwritten - cache_idx = ( EnzymeCore.EnzymeRules.overwritten(config)[4] - && !(typeof(src) <: EnzymeCore.Const) - && !(typeof(dst) <: EnzymeCore.Const) - ) ? copy(idx.val) : nothing - - return EnzymeCore.EnzymeRules.AugmentedReturn(primal, shadow, cache_idx) -end - -function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NNlib.scatter!)}, ::Type{RT}, cache_idx, op::Union{EnzymeCore.Const{typeof(+)},EnzymeCore.Const{typeof(-)}}, dst::OutType, src, idx::EnzymeCore.Const) where {OutType, RT} - - # Don't cache idx if not overwritten - if !(typeof(src) <: EnzymeCore.Const) && !(typeof(dst) <: EnzymeCore.Const) - if !EnzymeCore.EnzymeRules.overwritten(config)[4] - cache_idx = idx.val - end - end - - ddsts = dst.dval - dsrcs = (typeof(src) <: EnzymeCore.Const) ? ddsts : src.dval - - if EnzymeCore.EnzymeRules.width(config) == 1 - ddsts = (ddsts,) - dsrcs = (dsrcs,) - end - - for (ddst, dsrc) in zip(ddsts, dsrcs) - if !(typeof(dst) <: EnzymeCore.Const) && ddst !== dst.val - - if !(typeof(src) <: EnzymeCore.Const) && dsrc !== src.val - - if eltype(typeof(op)) == typeof(+) - dsrc .+= NNlib.gather(ddst, cache_idx) - else - @assert eltype(typeof(op)) == typeof(-) - dsrc .-= NNlib.gather(ddst, cache_idx) - end - end - - end - end - - return (nothing, nothing, nothing, nothing) -end - - - -for pool in [:maxpool, :meanpool, :lpnormpool] - pool! = Symbol(pool, :!) - ∇pool = Symbol(:∇, pool, :!) - - @eval begin - -function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{typeof($pool!)}, ::Type{RT}, y::OutType, x, dims; kwargs...) where {OutType, RT} - - if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.BatchDuplicated - func.val(y.val, x.val, dims.val; kwargs...) - end - - primal = if EnzymeCore.EnzymeRules.needs_primal(config) - y.val - else - nothing - end - shadow = if EnzymeCore.EnzymeRules.needs_shadow(config) - y.dval - else - nothing - end - - cache_y = ( EnzymeCore.EnzymeRules.overwritten(config)[2] - && !(typeof(x) <: EnzymeCore.Const) - && !(typeof(y) <: EnzymeCore.Const) - ) ? copy(y.val) : nothing - - cache_x = ( EnzymeCore.EnzymeRules.overwritten(config)[3] - && !(typeof(x) <: EnzymeCore.Const) - && !(typeof(y) <: EnzymeCore.Const) - ) ? copy(x.val) : nothing - - cache = (cache_y, cache_x) - - return EnzymeCore.EnzymeRules.AugmentedReturn(primal, shadow, cache) -end - -function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof($pool!)}, ::Type{RT}, cache, y, x, dims; kwargs...) where {RT} - cache_y, cache_x = cache - - # Don't cache y if not overwritten - if !(typeof(x) <: EnzymeCore.Const) && !(typeof(y) <: EnzymeCore.Const) - if !EnzymeCore.EnzymeRules.overwritten(config)[2] - cache_y = y.val - end - end - - # Don't cache x if not overwritten - if !(typeof(x) <: EnzymeCore.Const) && !(typeof(y) <: EnzymeCore.Const) - if !EnzymeCore.EnzymeRules.overwritten(config)[3] - cache_x = x.val - end - end - - dys = y.dval - dxs = (typeof(x) <: EnzymeCore.Const) ? dys : x.dval - - if EnzymeCore.EnzymeRules.width(config) == 1 - dys = (dys,) - dxs = (dxs,) - end - - for (dy, dx) in zip(dys, dxs) - if !(typeof(y) <: EnzymeCore.Const) && dy !== y.val - - if !(typeof(x) <: EnzymeCore.Const) && dx !== x.val - NNlib.$(∇pool)(dx, dy, cache_y, cache_x, dims.val; alpha=eltype(dx)(1), beta=eltype(dx)(1), kwargs...) - end - - dy .= 0 - end - end - - return (nothing, nothing, nothing) -end - -end -end - -function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{typeof(NNlib._dropout!)}, ::Type{RT}, rng, dst::OutType, src, p, dims) where {OutType, RT} - - T = float(real(eltype(dst.val))) - val = convert(T, 1/(1-p.val)) - keep = if dims.val isa Colon - similar(dst.val, T, size(dst.val)) - else - similar(dst.val, T, ntuple(d -> d in dims.val ? size(dst.val,d) : 1, ndims(dst.val))) - end - rand!(rng.val, keep) - - keep = keep .> p.val - - if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.BatchDuplicated - dst.val .= (keep .* val) .* src.val - end - - primal = if EnzymeCore.EnzymeRules.needs_primal(config) - dst.val - else - nothing - end - shadow = if EnzymeCore.EnzymeRules.needs_shadow(config) - dst.dval - else - nothing - end - - if typeof(dst) <: EnzymeCore.Const || typeof(src) <: EnzymeCore.Const - keep = nothing - end - - return EnzymeCore.EnzymeRules.AugmentedReturn(primal, shadow, keep) -end - -function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NNlib._dropout!)}, ::Type{RT}, keep, rng, dst::OutType, src, p, dims) where {OutType, RT} - T = float(real(eltype(dst.val))) - val = convert(T, 1/(1-p.val)) - - ddsts = dst.dval - dsrcs = (typeof(src) <: EnzymeCore.Const) ? ddsts : src.dval - - if EnzymeCore.EnzymeRules.width(config) == 1 - ddsts = (ddsts,) - dsrcs = (dsrcs,) - end - - for (ddst, dsrc) in zip(ddsts, dsrcs) - if !(typeof(dst) <: EnzymeCore.Const) && ddst !== dst.val - - if !(typeof(src) <: EnzymeCore.Const) && dsrc !== src.val - dsrc .+= (keep .* val) .* ddst - end - - ddst .= 0 - end - end - - dp = if typeof(p) <: EnzymeCore.Active - typeof(p.val)(0) - else - nothing - end - - return (nothing, nothing, nothing, dp, nothing) -end From 7414d3d44f7b46a5d47dd492b7ee6c0d61a29664 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Thu, 28 Sep 2023 11:13:05 -0400 Subject: [PATCH 26/26] Remove inlining --- src/conv.jl | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/conv.jl b/src/conv.jl index f15551e02..3fecb9151 100644 --- a/src/conv.jl +++ b/src/conv.jl @@ -47,7 +47,7 @@ Apply convolution filter `w` to input `x`. `x` and `w` are 3d/4d/5d tensors in 1d/2d/3d convolutions respectively. `x` and `w` may have real or complex element types. """ -@inline function conv(x, w::AbstractArray{T, N}; stride = 1, pad = 0, dilation = 1, flipped = false, groups = 1) where {T, N} +function conv(x, w::AbstractArray{T, N}; stride = 1, pad = 0, dilation = 1, flipped = false, groups = 1) where {T, N} stride = expand(Val(N - 2), stride) padding = expand(Val(N - 2), pad) dilation = expand(Val(N - 2), dilation) @@ -62,7 +62,7 @@ end Depthwise convolution operation with filter `w` on input `x`. `x` and `w` are 3d/4d/5d tensors in 1d/2d/3d convolutions respectively. """ -@inline function depthwiseconv(x, w::AbstractArray{T, N}; stride=1, pad=0, dilation=1, flipped=false) where {T, N} +function depthwiseconv(x, w::AbstractArray{T, N}; stride=1, pad=0, dilation=1, flipped=false) where {T, N} stride = expand(Val(N-2), stride) pad = expand(Val(N-2), pad) dilation = expand(Val(N-2), dilation) @@ -80,7 +80,7 @@ for backend in (Symbol(), :_direct, :_im2col, :_nnpack) # First make auto-allocating versions of the conv()-like calls: for name in (:conv, :depthwiseconv) @eval begin - @inline function $(Symbol("$(name)$(backend)"))( + function $(Symbol("$(name)$(backend)"))( x::AbstractArray{xT,N}, w::AbstractArray{wT,N}, cdims::ConvDims; kwargs...) where {xT, wT, N} y = similar(x, promote_type(xT, wT), output_size(cdims)..., @@ -92,7 +92,7 @@ for backend in (Symbol(), :_direct, :_im2col, :_nnpack) for name in (:∇conv_data, :∇depthwiseconv_data) @eval begin - @inline function $(Symbol("$(name)$(backend)"))( + function $(Symbol("$(name)$(backend)"))( dy::AbstractArray{yT,N}, w::AbstractArray{wT,N}, cdims::C; kwargs...) where {yT, wT, N, C <: ConvDims} dx = similar(dy, input_size(cdims)..., channels_in(cdims), size(dy, N)) @@ -104,7 +104,7 @@ for backend in (Symbol(), :_direct, :_im2col, :_nnpack) # We do the conv/depthwiseconv filter backprops separately, as the shape calculation # for `w` is slightly different for depthwise than for normal dense convolution. @eval begin - @inline function $(Symbol("∇conv_filter$(backend)"))( + function $(Symbol("∇conv_filter$(backend)"))( x::AbstractArray{xT,N}, dy::AbstractArray{yT,N}, cdims::ConvDims; kwargs...) where {xT, yT, N} dw = similar(dy, kernel_size(cdims)..., channels_in(cdims) ÷ groupcount(cdims), @@ -114,7 +114,7 @@ for backend in (Symbol(), :_direct, :_im2col, :_nnpack) end @eval begin - @inline function $(Symbol("∇depthwiseconv_filter$(backend)"))( + function $(Symbol("∇depthwiseconv_filter$(backend)"))( x::AbstractArray{xT,N}, dy::AbstractArray{yT,N}, cdims::ConvDims; kwargs...) where {xT, yT, N} dw = similar(dy, kernel_size(cdims)..., channel_multiplier(cdims), @@ -137,7 +137,7 @@ for front_name in (:conv, :∇conv_data, :∇conv_filter, for backend in (Symbol(), :_direct, :_im2col) ## NNPACK is only for 2d conv for N in (3, 4) @eval begin - @inline function $(Symbol("$(front_name)$(backend)!"))( + function $(Symbol("$(front_name)$(backend)!"))( y::AbstractArray{yT,$N}, x::AbstractArray{xT,$N}, w::AbstractArray{wT,$N}, cdims::ConvDims; kwargs...) where {yT, xT, wT}