Skip to content

Commit

Permalink
disallow trivial Duplicated
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Nov 10, 2024
1 parent db67dcf commit 2636454
Showing 1 changed file with 8 additions and 0 deletions.
8 changes: 8 additions & 0 deletions ext/FluxEnzymeExt/FluxEnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ using Flux: _make_zero!
import Flux.Train: _enzyme_train!, _rule_to_state
import Flux.Optimise
import Optimisers
import Functors
import Enzyme
using Enzyme: EnzymeRules, Active, Const, Duplicated, autodiff, ReverseWithPrimal
using Enzyme: autodiff_thunk, ReverseSplitWithPrimal
Expand All @@ -17,11 +18,17 @@ EnzymeRules.inactive(::typeof(Flux.Losses._check_sizes), args...) = true
function Flux._enzyme_gradient(f, args::Union{Const, Duplicated}...; zero::Bool=true)
for x in args
zero && x isa Duplicated && _make_zero!(x.dval)
_check_mutable(x)
end
Enzyme.autodiff(ReverseWithPrimal, f, Active, args...)
map(_grad_or_nothing, args)

Check warning on line 24 in ext/FluxEnzymeExt/FluxEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/FluxEnzymeExt/FluxEnzymeExt.jl#L18-L24

Added lines #L18 - L24 were not covered by tests
end

_check_mutable(x::Const) = nothing
_check_mutable(x::Duplicated) = Functors.anymutable(x) || error(

Check warning on line 28 in ext/FluxEnzymeExt/FluxEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/FluxEnzymeExt/FluxEnzymeExt.jl#L27-L28

Added lines #L27 - L28 were not covered by tests
"""`Flux.gradient(f, Duplicatged(x), ...)` expects `x` to contain mutable parameter arrays."""
)

# This function strips the returned gradient to be Zygote-like:
_grad_or_nothing(dup::Duplicated) = Flux.fmapstructure(_grad_or_nothing, dup.dval; prune=nothing)
_grad_or_nothing(::Const) = nothing
Expand All @@ -30,6 +37,7 @@ _grad_or_nothing(x) = Optimisers.isnumeric(x) ? x : nothing
function Flux._enzyme_withgradient(f, args::Union{Const, Duplicated}...; zero::Bool=true)
for x in args
zero && x isa Duplicated && _make_zero!(x.dval)
_check_mutable(x)
end

Check warning on line 41 in ext/FluxEnzymeExt/FluxEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/FluxEnzymeExt/FluxEnzymeExt.jl#L37-L41

Added lines #L37 - L41 were not covered by tests

# _, val = Enzyme.autodiff(ReverseWithPrimal, f, Active, args...)
Expand Down

0 comments on commit 2636454

Please sign in to comment.