make informative errors if you use Duplicated without loading Enzyme

mcabbott committed Oct 29, 2024
1 parent 717c5f0 commit 7e155f8
Showing 5 changed files with 145 additions and 108 deletions.
104 changes: 3 additions & 101 deletions ext/FluxEnzymeExt/FluxEnzymeExt.jl
@@ -13,55 +13,7 @@ EnzymeRules.inactive(::typeof(Flux.Losses._check_sizes), args...) = true

### gradient & withgradient

"""
gradient(f, args::Union{Const,Duplicated}...)
This should return the same answer as `gradient(f, args...)`,
but it uses Enzyme.jl instead of Zygote.jl to compute the derivative.
Only available when Enzyme is loaded!
Besides being returned, the gradient is also stored within the `Duplicated` object.
Calling `Enzyme.Duplicated(model)` allocates space for the gradient,
which is zeroed before use when calling `gradient`.
With the keyword `zero=false`, the new gradient will instead be added to what is already stored.
!!! warning "Experimental"
Enzyme support like this is new and somewhat experimental.
# Example
```
julia> using Flux
julia> model = Chain(Dense([3.0;;]));
julia> Flux.gradient(model, [1]) do m, x # computed using Zygote
sum(abs2, m(x))
end
((layers = ((weight = [6.0;;], bias = [6.0], σ = nothing),),), [18.0])
julia> using Enzyme
julia> dup_model = Duplicated(model); # allocates space for gradient
julia> Flux.gradient(dup_model, Const([1])) do m, x # Enzyme, returns the same
sum(abs2, m(x))
end
((layers = ((weight = [6.0;;], bias = [6.0], σ = nothing),),), nothing)
julia> dup_model # same gradient is also stored within Duplicated
Duplicated(
Chain(
Dense(1 => 1), # 2 parameters
),
# norm(∇) ≈ 8.49
)
julia> Flux.destructure((weight = [6.0;;], bias = [6.0]))[1] |> norm
8.48528137423857
```
"""
function Flux.gradient(f, args::Union{Const, Duplicated}...; zero::Bool=true)
function Flux._enzyme_gradient(f, args::Union{Const, Duplicated}...; zero::Bool=true)
for x in args
zero && x isa Duplicated && _make_zero!(x.dval)
end
@@ -74,41 +26,7 @@ _grad_or_nothing(dup::Duplicated) = Flux.fmapstructure(_grad_or_nothing, dup.dva
_grad_or_nothing(::Const) = nothing
_grad_or_nothing(x) = Optimisers.isnumeric(x) ? x : nothing

# _const_unless_dup(x) = Const(x)
# _const_unless_dup(dup::Duplicated) = dup

# TODO allow for Duplicated as 2nd argument, assume others const? This produces ambiguities...
# Flux.withgradient(f, dup::Duplicated, rest...) = Flux.withgradient(f, dup, map(_const_unless_dup, rest)...)
# Flux.gradient(f, dup::Duplicated, rest...) = Flux.gradient(f, dup, map(_const_unless_dup, rest)...)

"""
withgradient(f, args::Union{Const,Duplicated}...)
This should return the same answer as `withgradient(f, args...)`,
but it uses Enzyme.jl instead of Zygote.jl to compute the derivative.
Only available when Enzyme is loaded!
Does not at present allow `f` to return a tuple of `(loss, aux)` the way `Zygote.withgradient` does.
# Example
```
julia> using Flux, Enzyme
julia> model = Chain(Embedding([1.1 2.2 3.3]), Dense([4.4;;]), only);
julia> model(3)
14.52
julia> Flux.withgradient(m -> m(3), model) # this uses Zygote
(val = 14.52, grad = ((layers = ((weight = [0.0 0.0 4.4],), (weight = [3.3;;], bias = [1.0], σ = nothing), nothing),),))
julia> Flux.withgradient(m -> m(3), Duplicated(model)) # this uses Enzyme
(val = 14.52, grad = ((layers = ((weight = [0.0 0.0 4.4],), (weight = [3.3;;], bias = [1.0], σ = nothing), nothing),),))
```
"""
function Flux.withgradient(f, args::Union{Const, Duplicated}...; zero::Bool=true)
function Flux._enzyme_withgradient(f, args::Union{Const, Duplicated}...; zero::Bool=true)
for x in args
zero && x isa Duplicated && _make_zero!(x.dval)
end
@@ -121,23 +39,7 @@ end

_applyloss(loss, model, d...) = loss(model, d...)

using Flux: _old_to_new # from src/deprecations.jl
train!(loss, model::Duplicated, data, opt::Optimise.AbstractOptimiser; cb=nothing) =
train!(loss, model, data, _old_to_new(opt); cb)

function train!(loss, model::Duplicated, data, rule::Optimisers.AbstractRule; cb = nothing)
train!(loss, model, data, _rule_to_state(model, rule); cb)
end

"""
train!(loss, Duplicated(model), data, opt_state)
This method uses Enzyme.jl instead of Zygote.jl to compute the gradients,
but is otherwise the same as `train!(loss, model, data, opt_state)`.
Only available when Enzyme is loaded.
"""
function train!(loss, model::Duplicated, data, opt; cb = nothing)
function _enzyme_train!(loss, model::Duplicated, data, opt; cb = nothing)
isnothing(cb) || error("""train! does not support callback functions.
For more control use a loop with `gradient` and `update!`.""")
@withprogress for (i,d) in enumerate(data)
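The shape of the change in this file: the working Enzyme implementations stay in the extension, renamed to `_enzyme_gradient`, `_enzyme_withgradient`, and `_enzyme_train!`, while Flux itself (in `src/gradient.jl` and `src/train.jl` below) gains fallback methods under the same names that throw informative errors. A minimal self-contained sketch of that weak-dependency pattern, using illustrative names rather than Flux's actual code:

```julia
module TinyPkg

# Fallback shipped with the main package: its only job is to err informatively
# until the package extension supplies the real method.
_backend_gradient(f, args...; zero::Bool=true) =
    error("methods like `gradient(f, x::Duplicated)` need the backend package to be loaded.")

end # module TinyPkg

# ext/TinyPkgBackendExt.jl, loaded automatically once the backend is imported,
# would then add the working, more specific method, roughly:
#
#     function TinyPkg._backend_gradient(f, args::Union{Const,Duplicated}...; zero::Bool=true)
#         # real reverse-mode implementation goes here
#     end
```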
4 changes: 3 additions & 1 deletion src/Flux.jl
@@ -13,9 +13,11 @@ const stack = MLUtils.stack # now exported by Base
import Optimisers: Optimisers, trainable, destructure # before v0.13, Flux owned these functions
using Optimisers: freeze!, thaw!, adjust!, trainables
using Random: default_rng

using Zygote, ChainRulesCore
using Zygote: Params, @adjoint, pullback
using Zygote.ForwardDiff: value
using EnzymeCore: EnzymeCore

@reexport using MLDataDevices: MLDataDevices, supported_gpu_backends, reset_gpu_device!,
default_device_rng,
@@ -76,7 +78,7 @@ include("functor.jl")

@compat(public, (
# from OneHotArrays.jl
onehot, onehotbatch, onecold,
onehot, onehotbatch, onecold,
# from Functors.jl
functor, @functor, KeyPath, haskeypath, getkeypath,
# from Optimise/Train/Optimisers.jl
110 changes: 108 additions & 2 deletions src/gradient.jl
@@ -28,8 +28,77 @@ julia> gradient([7, 11], 0, 1) do x, y, d
([14.0, 22.0], 2.0, nothing)
```
"""
gradient(f, args...) = Zygote.gradient(f, args...)
function gradient(f, args...; zero::Bool=true)
for a in args
a isa EnzymeCore.Duplicated && return _enzyme_gradient(f, map(_ensure_enzyme, args)...; zero)
end
Zygote.gradient(f, args...)
end

_ensure_enzyme(x::EnzymeCore.Duplicated) = x
_ensure_enzyme(x::EnzymeCore.Const) = x
_ensure_enzyme(x) = EnzymeCore.Const(x)
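A short illustration of what these helpers do (assumed REPL-level behaviour matching the three definitions above; `_ensure_enzyme` is internal and unexported, hence the `Flux.` prefix):

```julia
using Flux, EnzymeCore

Flux._ensure_enzyme(EnzymeCore.Const([1f0]))  # a Const passes through unchanged
Flux._ensure_enzyme([1f0])                    # -> EnzymeCore.Const(Float32[1.0])
# A Duplicated argument also passes through unchanged, so only the plain
# arguments get wrapped before `_enzyme_gradient` is called.
```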

"""
gradient(f, args::Union{Const,Duplicated}...)
This should return the same answer as `gradient(f, args...)`,
but it uses Enzyme.jl instead of Zygote.jl to compute the derivative.
Only available when Enzyme is loaded!
This method is used when at least one argument is of type `Duplicated`,
and any argument not already wrapped is treated as `Const`.
Besides being returned, the gradient is also stored within the `Duplicated` object.
Calling `Enzyme.Duplicated(model)` allocates space for the gradient,
which is zeroed before use when calling `gradient`.
With the keyword `zero=false`, the new gradient will instead be added to what is already stored.
!!! warning "Experimental"
Enzyme support like this is new and somewhat experimental.
# Example
```
julia> using Flux
julia> model = Chain(Dense([3.0;;]));
julia> Flux.gradient(model, [1]) do m, x # computed using Zygote
sum(abs2, m(x))
end
((layers = ((weight = [6.0;;], bias = [6.0], σ = nothing),),), [18.0])
julia> using Enzyme
julia> dup_model = Duplicated(model); # allocates space for gradient
julia> Flux.gradient(dup_model, Const([1])) do m, x # Enzyme, returns the same
sum(abs2, m(x))
end
((layers = ((weight = [6.0;;], bias = [6.0], σ = nothing),),), nothing)
julia> dup_model # same gradient is also stored within Duplicated
Duplicated(
Chain(
Dense(1 => 1), # 2 parameters
),
# norm(∇) ≈ 8.49
)
julia> Flux.destructure((weight = [6.0;;], bias = [6.0]))[1] |> norm
8.48528137423857
julia> Flux.gradient(dup_model, [1]; zero=false) do m, x # implicit Const([1]), and gradient accumulation
sum(abs2, m(x))
end
((layers = ((weight = [12.0;;], bias = [12.0], σ = nothing),),), nothing)
```
"""
gradient(f, args::Union{EnzymeCore.Const, EnzymeCore.Duplicated}...; zero::Bool=true) = _enzyme_gradient(f, args...; zero)

# FluxEnzymeExt defines more specific _enzyme_gradient(f, args::Union{Const, Duplicated}...; zero)
_enzyme_gradient(f, args...; zero) = error("methods like `gradient(f, x::Duplicated)` are only available when Enzyme is loaded.")


"""
@@ -67,4 +136,41 @@ julia> withgradient(3.0, 4.0) do x, y
(val = (div = 0.75, mul = 12.0), grad = (0.25, -0.1875))
```
"""
withgradient(f, args...) = Zygote.withgradient(f, args...)
function withgradient(f, args...; zero::Bool=true)
for a in args
a isa EnzymeCore.Duplicated && return _enzyme_withgradient(f, map(_ensure_enzyme, args)...; zero)
end
Zygote.withgradient(f, args...)
end

"""
withgradient(f, args::Union{Const,Duplicated}...)
This should return the same answer as `withgradient(f, args...)`,
but it uses Enzyme.jl instead of Zygote.jl to compute the derivative.
Only available when Enzyme is loaded!
Does not at present allow `f` to return a tuple of `(loss, aux)` the way `Zygote.withgradient` does.
# Example
```
julia> using Flux, Enzyme
julia> model = Chain(Embedding([1.1 2.2 3.3]), Dense([4.4;;]), only);
julia> model(3)
14.52
julia> Flux.withgradient(m -> m(3), model) # this uses Zygote
(val = 14.52, grad = ((layers = ((weight = [0.0 0.0 4.4],), (weight = [3.3;;], bias = [1.0], σ = nothing), nothing),),))
julia> Flux.withgradient(m -> m(3), Duplicated(model)) # this uses Enzyme
(val = 14.52, grad = ((layers = ((weight = [0.0 0.0 4.4],), (weight = [3.3;;], bias = [1.0], σ = nothing), nothing),),))
```
"""
withgradient(f, args::Union{EnzymeCore.Const, EnzymeCore.Duplicated}...; zero::Bool=true) = _enzyme_withgradient(f, args...; zero)

# FluxEnzymeExt defines more specific _enzyme_withgradient(f, args::Union{Const, Duplicated}...; zero)
_enzyme_withgradient(f, args...; zero) = error("methods like `withgradient(f, x::Duplicated)` are only available when Enzyme is loaded.")
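To make the new failure mode concrete, here is a hedged REPL-style sketch of what a user now sees when passing a `Duplicated` argument without loading Enzyme. The shadow is built by hand with `fmap` purely for illustration, since the convenient one-argument `Duplicated(model)` constructor used in the docstring above requires the Enzyme extension:

```julia
using Flux, EnzymeCore

model = Chain(Dense(1 => 1))
# Hand-made shadow with the same structure as the model, arrays zeroed.
shadow = Flux.fmap(x -> x isa AbstractArray ? zero(x) : x, model)
dup_model = EnzymeCore.Duplicated(model, shadow)

Flux.gradient(m -> sum(abs2, m([1f0])), dup_model)
# ERROR: methods like `gradient(f, x::Duplicated)` are only available when Enzyme is loaded.
```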
1 change: 0 additions & 1 deletion src/layers/macro.jl
@@ -1,4 +1,3 @@
import EnzymeCore

"""
@layer Dense
34 changes: 31 additions & 3 deletions src/train.jl
@@ -3,13 +3,14 @@ module Train
using LinearAlgebra
using Optimisers: Optimisers
using Functors: fmap, fmapstructure
using ..Flux: Flux # used only in docstring
import ..Flux.Optimise: train!, update! # during 0.13, we add methods to the old functions
using ..Flux: Flux # used only in docstring
import ..Flux.Optimise: train!, update!, Optimise # during 0.13, we add methods to the old functions

export setup, train!

using ProgressLogging: @progress, @withprogress, @logprogress
using Zygote: Zygote, Params
using EnzymeCore: Duplicated

"""
opt_state = setup(rule, model)
@@ -56,7 +57,7 @@ end
train!(loss, model, data, opt_state)
Uses a `loss` function and training `data` to improve the `model`'s parameters
according to a particular optimisation rule encoded in `opt_state`.
according to a particular optimisation rule encoded in `opt_state`.
Iterates through `data` once, evaluating for each `d in data` either
`loss(model, d...)` if `d isa Tuple`, or else `loss(model, d)` for other `d`.
@@ -134,4 +135,31 @@ function _rule_to_state(model, rule::Optimisers.AbstractRule)
state
end

"""
train!(loss, Duplicated(model), data, opt_state)
This method uses Enzyme.jl instead of Zygote.jl to compute the gradients,
but is otherwise the same as `train!(loss, model, data, opt_state)`.
Only available when Enzyme is loaded.
!!! compat "New"
This method was added in Flux 0.13.9.
"""
train!(loss, model::Duplicated, data, opt; cb = nothing) = _enzyme_train!(loss, model, data, opt; cb)  # forward the keyword so the callback check in the extension still fires

# FluxEnzymeExt defines more specific _enzyme_train!(loss, model::Duplicated, data, opt; cb)
_enzyme_train!(loss, model, data, opt; cb = nothing) = error("The method `train!(loss, Duplicated(model), data, opt_state)` is only available when Enzyme.jl is loaded")

# Following src/deprecations.jl
function train!(loss, model::Duplicated, data, opt::Optimise.AbstractOptimiser; cb=nothing)
train!(loss, model, data, Flux._old_to_new(opt); cb)  # `_old_to_new` lives in the parent Flux module, not in Train
end

# This method lets you use Optimisers.Descent() without setup, when the rule has no state
function train!(loss, model::Duplicated, data, rule::Optimisers.AbstractRule; cb=nothing)
train!(loss, model, data, _rule_to_state(model, rule); cb)
end

end # module Train
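For completeness, a hedged sketch of how the new `train!` entry point is intended to be called once Enzyme is loaded; the model, data, and optimiser here are illustrative, and setting up the optimiser state on the primal model rather than the `Duplicated` wrapper is an assumption:

```julia
using Flux, Enzyme

model = Chain(Dense(2 => 1))
dup_model = Duplicated(model)   # extension method: allocates a zeroed shadow

data = [(randn(Float32, 2), randn(Float32, 1)) for _ in 1:4]
opt_state = Flux.setup(Adam(0.01), model)

# Each `d in data` is a tuple, so the loss is called as `loss(m, x, y)`.
Flux.train!(dup_model, data, opt_state) do m, x, y
    Flux.mse(m(x), y)
end
```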
