Skip to content

Commit

Permalink
Add pooling
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Sep 25, 2023
1 parent 02bfb60 commit 1bb3081
Show file tree
Hide file tree
Showing 2 changed files with 150 additions and 30 deletions.
161 changes: 131 additions & 30 deletions src/enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ for name in (typeof(NNlib.conv!), typeof(NNlib.depthwiseconv!))

function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{$name}, ::Type{RT}, y::OutType, x, w, cdims; kwargs...) where {OutType, RT}

@assert !(OutType <: EnzymeCore.Const)
if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.BatchDuplicated
func.val(y.val, x.val, w.val, cdims.val; kwargs...)
end
Expand All @@ -22,10 +21,16 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{
end

# Cache x if its overwritten and w is active (and thus required)
cache_x = ( EnzymeCore.EnzymeRules.overwritten(config)[3] && !(typeof(w) <: EnzymeCore.Const) ) ? copy(x.val) : nothing
cache_x = ( EnzymeCore.EnzymeRules.overwritten(config)[3]
&& !(typeof(w) <: EnzymeCore.Const)
&& !(typeof(y) <: EnzymeCore.Const)
) ? copy(x.val) : nothing

# Cache w if its overwritten and x is active (and thus required)
cache_w = ( EnzymeCore.EnzymeRules.overwritten(config)[4] && !(typeof(x) <: EnzymeCore.Const) ) ? copy(w.val) : nothing
cache_w = ( EnzymeCore.EnzymeRules.overwritten(config)[4]
&& !(typeof(x) <: EnzymeCore.Const)
&& !(typeof(y) <: EnzymeCore.Const)
) ? copy(w.val) : nothing

cache = (cache_x, cache_w)

Expand All @@ -36,14 +41,14 @@ function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{$name}, :
cache_x, cache_w = cache

# Don't cache x if not overwritten and w is active (and thus required)
if !(typeof(w) <: EnzymeCore.Const)
if !(typeof(w) <: EnzymeCore.Const) && !(typeof(y) <: EnzymeCore.Const)
if !EnzymeCore.EnzymeRules.overwritten(config)[3]
cache_x = x.val
end
end

# Don't cache w if not overwritten and x is active (and thus required)
if !(typeof(x) <: EnzymeCore.Const)
if !(typeof(x) <: EnzymeCore.Const) && !(typeof(y) <: EnzymeCore.Const)
if !EnzymeCore.EnzymeRules.overwritten(config)[4]
cache_w = w.val
end
Expand All @@ -60,15 +65,19 @@ function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{$name}, :
end

for (dy, dx, dw) in zip(dys, dxs, dws)
if !(typeof(x) <: EnzymeCore.Const) && dx !== x.val
# dx += grad wrt x.val
NNlib.∇conv_data!(dx, dy, cache_w, cdims.val; alpha=eltype(dw)(1), beta=eltype(dw)(1), kwargs...)
end
if !(typeof(w) <: EnzymeCore.Const) && dw !== w.val
# dw += grad wrt w.val
NNlib.∇conv_filter!(dw, cache_x, dy, cdims.val; alpha=eltype(dw)(1), beta=eltype(dw)(1), kwargs...)
if !(typeof(y) <: EnzymeCore.Const) && dy !== w.val

if !(typeof(x) <: EnzymeCore.Const) && dx !== x.val
# dx += grad wrt x.val
NNlib.∇conv_data!(dx, dy, cache_w, cdims.val; alpha=eltype(dw)(1), beta=eltype(dw)(1), kwargs...)
end
if !(typeof(w) <: EnzymeCore.Const) && dw !== w.val
# dw += grad wrt w.val
NNlib.∇conv_filter!(dw, cache_x, dy, cdims.val; alpha=eltype(dw)(1), beta=eltype(dw)(1), kwargs...)
end

dy .= 0
end
dy .= 0
end

return (nothing, nothing, nothing, nothing)
Expand All @@ -79,7 +88,6 @@ end

function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{typeof(NNlib.gather!)}, ::Type{RT}, dst::OutType, src, idx::EnzymeCore.Const) where {OutType, RT}

@assert !(OutType <: EnzymeCore.Const)
if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.BatchDuplicated
func.val(dst.val, src.val, idx.val)
end
Expand All @@ -96,15 +104,18 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{
end

# Cache idx if its overwritten
cache_idx = ( EnzymeCore.EnzymeRules.overwritten(config)[4] && !(typeof(src) <: EnzymeCore.Const) ) ? copy(idx.val) : nothing
cache_idx = ( EnzymeCore.EnzymeRules.overwritten(config)[4]
&& !(typeof(src) <: EnzymeCore.Const)
&& !(typeof(dst) <: EnzymeCore.Const)
) ? copy(idx.val) : nothing

return EnzymeCore.EnzymeRules.AugmentedReturn(primal, shadow, cache_idx)
end

function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NNlib.gather!)}, ::Type{RT}, cache_idx, dst::OutType, src, idx::EnzymeCore.Const) where {OutType, RT}

# Don't cache idx if not overwritten
if !(typeof(src) <: EnzymeCore.Const)
if !(typeof(src) <: EnzymeCore.Const) && !(typeof(dst) <: EnzymeCore.Const)
if !EnzymeCore.EnzymeRules.overwritten(config)[4]
cache_idx = idx.val
end
Expand All @@ -119,11 +130,12 @@ function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NN
end

for (ddst, dsrc) in zip(ddsts, dsrcs)
if !(typeof(src) <: EnzymeCore.Const) && dsrc !== src.val &&
!(typeof(dst) <: EnzymeCore.Const) && ddst !== dst.val
NNlib.scatter!(+, dsrc, ddst, cache_idx)
end
if !(typeof(dst) <: EnzymeCore.Const) && ddst !== dst.val

if !(typeof(src) <: EnzymeCore.Const) && dsrc !== src.val
NNlib.scatter!(+, dsrc, ddst, cache_idx)
end

ddst .= 0
end
end
Expand Down Expand Up @@ -152,15 +164,18 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{
end

# Cache idx if its overwritten
cache_idx = ( EnzymeCore.EnzymeRules.overwritten(config)[4] && !(typeof(src) <: EnzymeCore.Const) ) ? copy(idx.val) : nothing
cache_idx = ( EnzymeCore.EnzymeRules.overwritten(config)[4]
&& !(typeof(src) <: EnzymeCore.Const)
&& !(typeof(dst) <: EnzymeCore.Const)
) ? copy(idx.val) : nothing

return EnzymeCore.EnzymeRules.AugmentedReturn(primal, shadow, cache_idx)
end

function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NNlib.scatter!)}, ::Type{RT}, cache_idx, op::Union{EnzymeCore.Const{typeof(+)},EnzymeCore.Const{typeof(-)}}, dst::OutType, src, idx::EnzymeCore.Const) where {OutType, RT}

# Don't cache idx if not overwritten
if !(typeof(src) <: EnzymeCore.Const)
if !(typeof(src) <: EnzymeCore.Const) && !(typeof(dst) <: EnzymeCore.Const)
if !EnzymeCore.EnzymeRules.overwritten(config)[4]
cache_idx = idx.val
end
Expand All @@ -175,15 +190,20 @@ function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NN
end

for (ddst, dsrc) in zip(ddsts, dsrcs)
if !(typeof(src) <: EnzymeCore.Const) && dsrc !== src.val &&
!(typeof(dst) <: EnzymeCore.Const) && ddst !== dst.val

if eltype(typeof(op)) == typeof(+)
dsrc .+= NNlib.gather(ddst, cache_idx)
else
@assert eltype(typeof(op)) == typeof(-)
dsrc .-= NNlib.gather(ddst, cache_idx)
if !(typeof(dst) <: EnzymeCore.Const) && ddst !== dst.val

if !(typeof(src) <: EnzymeCore.Const) && dsrc !== src.val

if eltype(typeof(op)) == typeof(+)
dsrc .+= NNlib.gather(ddst, cache_idx)
else
@assert eltype(typeof(op)) == typeof(-)
dsrc .-= NNlib.gather(ddst, cache_idx)
end
end

ddst .= 0

end
end

Expand All @@ -192,3 +212,84 @@ end



for pool in [:maxpool, :meanpool, :lpnormpool]
pool! = Symbol(pool, :!)
∇pool = Symbol(:∇, pool)

@eval begin

function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{typeof($pool!)}, ::Type{RT}, y::OutType, x, dims; kwargs...) where {OutType, RT}

if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.BatchDuplicated
func.val(y.val, x.val, dims.val; kwargs...)
end

primal = if EnzymeCore.EnzymeRules.needs_primal(config)
y.val
else
nothing
end
shadow = if EnzymeCore.EnzymeRules.needs_shadow(config)
y.dval
else
nothing
end

cache_y = ( EnzymeCore.EnzymeRules.overwritten(config)[2]
&& !(typeof(x) <: EnzymeCore.Const)
&& !(typeof(y) <: EnzymeCore.Const)
) ? copy(y.val) : nothing

cache_x = ( EnzymeCore.EnzymeRules.overwritten(config)[3]
&& !(typeof(x) <: EnzymeCore.Const)
&& !(typeof(y) <: EnzymeCore.Const)
) ? copy(x.val) : nothing

cache = (cache_y, cache_x)

return EnzymeCore.EnzymeRules.AugmentedReturn(primal, shadow, cache)
end

function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof($pool!)}, ::Type{RT}, cache, y, x, dims; kwargs...) where {RT}
cache_y, cache_x = cache

# Don't cache y if not overwritten
if !(typeof(x) <: EnzymeCore.Const) && !(typeof(y) <: EnzymeCore.Const)
if !EnzymeCore.EnzymeRules.overwritten(config)[2]
cache_y = y.val
end
end

# Don't cache x if not overwritten
if !(typeof(x) <: EnzymeCore.Const) && !(typeof(y) <: EnzymeCore.Const)
if !EnzymeCore.EnzymeRules.overwritten(config)[3]
cache_x = x.val
end
end

dys = y.dval
dxs = (typeof(x) <: EnzymeCore.Const) ? dys : x.dval

if EnzymeCore.EnzymeRules.width(config) == 1
dys = (dys,)
dxs = (dxs,)
end

for (dy, dx, dw) in zip(dys, dxs)
if !(typeof(y) <: EnzymeCore.Const) && dy !== y.val

if !(typeof(x) <: EnzymeCore.Const) && dx !== x.val
NNlib.$(∇pool)(dx, dy, cache_y, cache_x, dims; alpha=eltype(dx)(1), beta=eltype(dx)(1), kwargs...)
end

dy .= 0
end
end

return (nothing, nothing, nothing)
end

end
end


19 changes: 19 additions & 0 deletions test/pooling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -967,3 +967,22 @@ end
gradtest(x -> sum(maxpool(x, k)), x, skip = spatial_rank==2)
gradtest(x -> sum(meanpool(x, k)), x)
end


@testset "EnzymeRules: pooling! $pool spatial_rank=$spatial_rank " for spatial_rank in (1, 2),
(pool, pool!) in ((maxpool, maxpool!), (meanpool, meanpool!))

x = rand(rng, repeat([10], spatial_rank)..., 3, 2)
pdims = PoolDims(x, 2)
y = pool(x, pdims)

for Tret in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated),
Tdst in (EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated),
Tsrc in (EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated)

EnzymeTestUtils.are_activities_compatible(Tret, Tdst, Tsrc) || continue

EnzymeTestUtils.test_reverse(pool!, Tret, (y, Tdst), (x, Tsrc), (pdims, EnzymeCore.Const))
end

end

0 comments on commit 1bb3081

Please sign in to comment.