Skip to content

Commit

Permalink
don't own _make_zero!
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Nov 23, 2024
1 parent ecef1f0 commit 976be71
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 9 deletions.
7 changes: 3 additions & 4 deletions ext/FluxEnzymeExt/FluxEnzymeExt.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
module FluxEnzymeExt

using Flux
using Flux: _make_zero!
import Flux.Train: _enzyme_train!

import Optimisers
import Functors
import Enzyme
using Enzyme: EnzymeRules, Active, Const, Duplicated, autodiff, ReverseWithPrimal
using Enzyme: EnzymeRules, Active, Const, Duplicated, autodiff, ReverseWithPrimal, DuplicatedNoNeed, make_zero!
using Enzyme: autodiff_thunk, Reverse, ReverseSplitWithPrimal
using ProgressLogging: @withprogress, @logprogress

Expand Down Expand Up @@ -36,7 +35,7 @@ _grad_or_nothing(x) = Optimisers.isnumeric(x) ? x : nothing

function Flux._enzyme_withgradient(f, args::Union{Const, Duplicated}...; zero::Bool=true)
for x in args
zero && x isa Duplicated && _make_zero!(x.dval)
zero && x isa Duplicated && make_zero!(x.dval)
_check_mutable(x)
end

Expand Down Expand Up @@ -85,7 +84,7 @@ function _enzyme_train!(loss, model::Duplicated, data, opt; cb = nothing)
@withprogress for (i,d) in enumerate(data)
d_splat = d isa Tuple ? d : (d,)

_make_zero!(model.dval)
make_zero!(model.dval)
_, l = Enzyme.autodiff(ReverseWithPrimal, _applyloss,
Active, Const(loss), model, map(Const, d_splat)...)

Expand Down
6 changes: 1 addition & 5 deletions src/layers/macro.jl
Original file line number Diff line number Diff line change
Expand Up @@ -162,14 +162,10 @@ _noquotenode(s::Symbol) = s
# Normalise a macro-argument field name to a plain Symbol.
# Users may write trainable=(:x,:y) as well as trainable=(x,y); the parser
# wraps the former in QuoteNode, which this unwraps.
function _noquotenode(q::QuoteNode)
    return q.value
end

# Fallback: anything that is neither a Symbol nor a QuoteNode is a user error.
function _noquotenode(other)
    return error("expected a symbol here, as a field name, but got ", other)
end

# Leaf rule: numeric (or any) arrays are zeroed in place and returned.
function _make_zero_internal!(arr::AbstractArray)
    return fill!(arr, 0)
end

# Leaf rule: non-array leaves (numbers, symbols, functions, ...) pass through unchanged.
_make_zero_internal!(leaf) = leaf

# Walk a nested model structure and zero every array leaf, in place.
# NOTE(review): relies on `fmap` (Functors.jl) being in scope in this file.
function _make_zero!(model)
    return fmap(_make_zero_internal!, model)
end

function _macro_enzyme(type)
out = quote
# One-arg method Duplicated(m::Layer) which allocates & zeros the gradient:
$EnzymeCore.Duplicated(m::$type) = $EnzymeCore.Duplicated(m, $_make_zero!($deepcopy(m)))
$EnzymeCore.Duplicated(m::$type) = $EnzymeCore.Duplicated(m, $EnzymeCore.make_zero(m))

# Not sure we want this, but make Duplicated{<:Layer} callable?
(m::$EnzymeCore.Duplicated{<:$type})(xs...) = m.val(xs...)
Expand Down

0 comments on commit 976be71

Please sign in to comment.