Add easy way to use Mooncake.jl for gradients #26
@@ -0,0 +1,173 @@
module FluxMooncakeExt

using Flux, Fluxperimental, Optimisers, Functors, Mooncake
import Fluxperimental: _moonstrip
# using Flux: Const

function Fluxperimental.Moonduo(x)
    dx = Mooncake.zero_tangent(x)
    Moonduo(x, dx)
end

# Flux gradient etc.

"""
    Flux.gradient(f, args::Moonduo...)

This uses Mooncake.jl to compute the derivative,
which is both stored within `Moonduo` and returned.
Similar to the Enzyme.jl methods like `Flux.gradient(f, m::Duplicated)`.

# Example

```julia
julia> using Flux

julia> model = Chain(Dense([3.0;;]));

julia> Flux.gradient(model, [1]) do m, x  # computed using Zygote
         sum(abs2, m(x))
       end
((layers = ((weight = [6.0;;], bias = [6.0], σ = nothing),),), [18.0])

julia> using Fluxperimental, Mooncake

julia> dup_model = Moonduo(model);  # allocates space for gradient

julia> Flux.gradient(dup_model, Moonduo([1])) do m, x  # Mooncake, returns the same
         sum(abs2, m(x))
       end
((layers = ((weight = [6.0;;], bias = [6.0], σ = nothing),),), nothing)

julia> dup_model  # same gradient is also stored within the Moonduo
Moonduo(
  Chain(
    Dense(1 => 1),   # 2 parameters
  ),
  # norm(∇) ≈ 8.49
)

julia> Flux.destructure((weight = [6.0;;], bias = [6.0]))[1] |> norm
8.48528137423857

julia> Flux.gradient(dup_model, Moonduo([1]); zero=false) do m, x  # grad accumulation
         sum(abs2, m(x))
       end
((layers = ((weight = [12.0;;], bias = [12.0], σ = nothing),),), nothing)
```

!!! note
    At present there is no way to mark some arguments constant.
    Instead of `gradient(loss, Duplicated(model), Const(data))`,
    you can write `gradient(m -> loss(m, data), Moonduo(model))`.
"""
Flux.gradient(f, args::Moonduo...; zero::Bool=true) = _moon_withgradient(f, args...; zero).grad
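Since `Const` is not yet supported, the note above suggests closing over any non-differentiated inputs. A minimal sketch of that pattern; the `model`, `data`, and `loss` names here are illustrative, not part of this PR:

```julia
using Flux, Fluxperimental, Mooncake

model = Chain(Dense(2 => 1))
data  = ([1.0f0, 2.0f0], [3.0f0])            # some fixed (x, y) pair
loss(m, (x, y)) = Flux.mse(m(x), y)

dup = Moonduo(model)                          # allocates the Mooncake tangent

# Instead of gradient(loss, dup, Const(data)), capture the constant data in a closure:
grads = Flux.gradient(m -> loss(m, data), dup)
```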
"""
    Flux.withgradient(f, args::Moonduo...)

This should return the same answer as `withgradient(f, model, args...)`,
but it uses Mooncake.jl instead of Zygote.jl to compute the derivative.

# Example

```julia
julia> using Flux, Fluxperimental, Mooncake

julia> model = Chain(Embedding([1.1 2.2 3.3]), Dense([4.4;;]), only);

julia> model(3)
14.52

julia> Flux.withgradient(m -> m(3), model)  # this uses Zygote
(val = 14.52, grad = ((layers = ((weight = [0.0 0.0 4.4],), (weight = [3.3;;], bias = [1.0], σ = nothing), nothing),),))

julia> Flux.withgradient(m -> m(3), Moonduo(model))  # this uses Mooncake
(val = 14.52, grad = ((layers = ((weight = [0.0 0.0 4.4],), (weight = [3.3;;], bias = [1.0], σ = nothing), nothing),),))
```

!!! warning
    With Zygote, the function `f` may return a Tuple or NamedTuple, with the loss as the first element.
    This feature is not supported here, for now.
"""
Flux.withgradient(f, args::Moonduo...; zero::Bool=true) = _moon_withgradient(f, args...; zero)

function _moon_withgradient(f, args::Moonduo...; zero)
    plain = map(x -> x.val, args)
    rule = Mooncake.build_rrule(f, plain...)

    for x in args
        _check_mutable(x)
        zero && Mooncake.set_to_zero!!(x.dval)
    end
    coduals = map(x -> Mooncake.CoDual(x.val, x.dval), args)
    val, _ = Mooncake.__value_and_gradient!!(rule, Mooncake.zero_codual(f), coduals...)

    grad = map(x -> _moongrad(x.dval), args)
    (; val, grad)
end
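For orientation, here is roughly what this wrapper boils down to for a single array argument, using only the Mooncake calls that already appear above. This is a sketch under the assumption that these internal functions behave as they are used in this file; `__value_and_gradient!!` is not public API:

```julia
using Mooncake

f(x) = sum(abs2, x)
x  = [1.0, 2.0]
dx = Mooncake.zero_tangent(x)                 # plays the role of Moonduo's dval

rule = Mooncake.build_rrule(f, x)
val, _ = Mooncake.__value_and_gradient!!(rule, Mooncake.zero_codual(f),
                                         Mooncake.CoDual(x, dx))
# val == 5.0; the gradient (here [2.0, 4.0]) is accumulated into dx in place,
# which is why Moonduo keeps dval alongside val
```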
# _check_mutable(x::Const) = nothing
_check_mutable(x::Moonduo) = Functors.anymutable(x) || error(
    """`Flux.gradient(f, Moonduo(x), ...)` expects `x` to contain mutable parameter arrays."""
)

function _moongrad(dx)
    dx2 = _moonstrip(dx)  # remove all the weird types
    isnothing(dx2) && return
    return Flux.fmapstructure(identity, dx2; prune=nothing)
end

_moonstrip(dx::Mooncake.Tangent) = map(_moonstrip, dx.fields)
_moonstrip(dx::Mooncake.MutableTangent) = map(_moonstrip, dx.fields)
_moonstrip(dx::Mooncake.NoTangent) = nothing
_moonstrip(dx::Union{Tuple, NamedTuple, AbstractArray}) = map(_moonstrip, dx)
_moonstrip(dx::AbstractArray{Mooncake.NoTangent}) = nothing
_moonstrip(dx::AbstractArray{<:Number}) = dx
_moonstrip(dx::AbstractArray{<:Integer}) = nothing

Review thread on `_moonstrip(dx::AbstractArray{<:Integer}) = nothing`:

- Reviewer: To the best of my knowledge, this shouldn't ever be a case that you see.
- Author: Ok, I didn't try very hard! Do you know if there are other types which can occur, besides those handled here?
- Reviewer: Do you place any constraints on the types of the things that you take gradients w.r.t.? I ask because there's nothing to stop people adding a new type and declaring a non-standard tangent type for it (i.e. while some subtype of …). My honest advice would be to do what you're doing at the minute, i.e. check that it works for models that you care about, and ensure that there's a good error message so that a user knows where to ask for help if they encounter something you weren't expecting.
- Author: Ok, sounds good! No constraints, but in reality it's going to be arrays & all kinds of structs.

_moonstrip(dx::Number) = nothing
function _moonstrip(dx)
    @warn "not sure what to do with this type" typeof(dx)
    dx
end
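As the review thread above notes, a model containing a type with a non-standard tangent representation would currently fall through to the `@warn` fallback. One extra method would be enough to cover such a case. A hedged sketch, where `MyTangent` is a hypothetical tangent type invented here for illustration, not something Mooncake defines:

```julia
# Hypothetical tangent type that keeps its differentiable content in one field:
struct MyTangent{T}
    data::T
end

# A single extra dispatch routes it through the machinery above instead of warning:
_moonstrip(dx::MyTangent) = _moonstrip(dx.data)
```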
# Optimisers etc.

Flux.setup(rule::Optimisers.AbstractRule, m::Moonduo) = Flux.setup(rule, m.val)

function Flux.update!(opt_state, model::Moonduo)
    Flux.update!(opt_state, model.val, _moongrad(model.dval))
    nothing
end
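Together with the `gradient` method above, these two definitions are enough for a hand-written training loop. A minimal sketch; the model, optimiser settings, and data are illustrative:

```julia
using Flux, Fluxperimental, Mooncake

model = Moonduo(Dense(2 => 1))
opt_state = Flux.setup(Adam(0.01), model)            # delegates to setup(rule, m.val)
data = [([1.0f0, 2.0f0], [3.0f0]) for _ in 1:10]     # illustrative (x, y) pairs

for (x, y) in data
    Flux.gradient(m -> Flux.mse(m(x), y), model)     # writes the gradient into model.dval
    Flux.update!(opt_state, model)                   # reads it back out via _moongrad
end
```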
### Flux.Train, for train!

_applyloss(loss, model, d...) = loss(model, d...)

"""
    train!(loss, Moonduo(model), data, opt_state)

This method uses Mooncake.jl instead of Zygote.jl to compute the gradients, but is otherwise the
same as `Flux.train!(loss, model, data, opt_state)`.
"""
function Flux.train!(loss, model::Moonduo, data, opt; cb=nothing, epochs::Int=1)
    isnothing(cb) || error("""train! does not support callback functions.
                              For more control use a loop with `gradient` and `update!`.""")
    Flux.Train.@withprogress for (i, d) in enumerate(Iterators.cycle(data, epochs))
        d_splat = d isa Tuple ? d : (d,)
        rule = Mooncake.build_rrule(_applyloss, loss, model.val, d_splat...)  # perhaps not ideal to do this inside the loop?

        Mooncake.set_to_zero!!(model.dval)
        l, _ = Mooncake.__value_and_gradient!!(rule, Mooncake.zero_codual(_applyloss),
            Mooncake.zero_codual(loss), Mooncake.CoDual(model.val, model.dval),
            map(Mooncake.zero_codual, d_splat)...)

        if !isfinite(l)
            throw(DomainError(lazy"Loss is $l on data item $i, stopping training"))
        end

        Flux.update!(opt, model)

        Flux.Train.@logprogress Base.haslength(data) ? i/(length(data)*epochs) : nothing
    end
end
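A hedged usage sketch of this method, following the signature in the docstring above; the model, data, and loss are illustrative:

```julia
using Flux, Fluxperimental, Mooncake

model = Moonduo(Chain(Dense(2 => 8, relu), Dense(8 => 1)))
opt_state = Flux.setup(Adam(), model)
data = [(randn(Float32, 2), randn(Float32, 1)) for _ in 1:32]

Flux.train!(model, data, opt_state) do m, x, y   # called as loss(m, x, y) via the do-block
    Flux.mse(m(x), y)
end
```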

end # module
@@ -21,4 +21,7 @@ export @autostruct

include("new_recur.jl")

include("mooncake.jl")
export Moonduo

end # module Fluxperimental
@@ -0,0 +1,40 @@
"""
    Moonduo(x, [dx])

This stores both an object `x` and its gradient `dx`,
with `dx` in the format used by Mooncake.jl. This is automatically allocated
when you call `Moonduo(x)`.

This serves the same purpose as Enzyme.jl's `Duplicated` type.
Both of these AD engines prefer that space for the gradient be pre-allocated.

Maybe this is like `Mooncake.CoDual`, except that type is marked private, and its direct use seems discouraged:
https://github.com/compintell/Mooncake.jl/issues/275
An advantage of Flux owning this type is that we can provide pretty printing without piracy.
"""
struct Moonduo{X,DX}
    val::X
    dval::DX
end
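A short sketch of how the two fields are used, assuming Mooncake.jl is loaded so that the one-argument constructor from the package extension is available:

```julia
using Flux, Fluxperimental, Mooncake

m   = Dense(2 => 3)
dup = Moonduo(m)              # dval is Mooncake.zero_tangent(m), allocated up front

dup.val === m                 # true: the model itself, used for the forward pass
dup.dval                      # a Mooncake tangent mirroring the model's structure
dup([1.0f0, 2.0f0])           # calling a Moonduo forwards to dup.val (see below)
```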
function Moonduo(args...)
    if length(args) == 1
        error("The method `Moonduo(x)` is only available when Mooncake.jl is loaded!")
    else
        error("The only legal methods are `Moonduo(x)` and `Moonduo(x, dx)`.")
    end
end

Optimisers.trainable(m::Moonduo) = (; m.val)

Flux.@layer :expand Moonduo

(m::Moonduo)(x...) = m.val(x...)

function _moonstrip end

function Flux._show_pre_post(obj::Moonduo)
    nrm = Flux.norm(destructure(_moonstrip(obj.dval))[1])
    str = repr(round(nrm; sigdigits=3))
    "Moonduo(", " # norm(∇) ≈ $str\n) "
end
@@ -0,0 +1,45 @@
using Flux, Fluxperimental, Mooncake

@testset "gradient, withgradient, Moonduo" begin
    # Tests above are about how Enzyme digests Flux layers.
    # Tests here are just the interface Flux.gradient(f, Moonduo(model)) etc.
    m1 = Moonduo(Dense(3=>2))
    @test m1 isa Moonduo
    g1 = Flux.gradient(m -> sum(m.bias), m1) |> only
    @test iszero(g1.weight)
    @test g1.bias == [1, 1]
    @test m1.dval.fields.bias == [1, 1]

    g2 = Flux.withgradient((m,x) -> sum(m(x)), m1, Moonduo([1,2,3f0]))  # would prefer Const
    @test g2.val ≈ sum(m1([1,2,3f0]))
    @test g2.grad[1].weight ≈ [1 2 3; 1 2 3]
    @test_skip g2.grad[2] === nothing  # implicitly Const

    # g3 = Flux.withgradient(Moonduo([1,2,4.], zeros(3))) do x
    #     z = 1 ./ x
    #     sum(z), z  # here z is an auxiliary output
    # end
    # @test g3.grad[1] ≈ [-1.0, -0.25, -0.0625]
    # @test g3.val[1] ≈ 1.75
    # @test g3.val[2] ≈ [1.0, 0.5, 0.25]
    # g4 = Flux.withgradient(Moonduo([1,2,4.], zeros(3))) do x
    #     z = 1 ./ x
    #     (loss=sum(z), aux=string(z))
    # end
    # @test g4.grad[1] ≈ [-1.0, -0.25, -0.0625]
    # @test g4.val.loss ≈ 1.75
    # @test g4.val.aux == "[1.0, 0.5, 0.25]"

    # setup understands Moonduo:
    @test Flux.setup(Adam(), m1) == Flux.setup(Adam(), m1.val)

    # # At least one Moonduo is required:
    # @test_throws ArgumentError Flux.gradient(m -> sum(m.bias), Const(m1.val))
    # @test_throws ArgumentError Flux.gradient((m,x) -> sum(m(x)), Const(m1.val), [1,2,3f0])
    # @test_throws ArgumentError Flux.withgradient(m -> sum(m.bias), Const(m1.val))
    # @test_throws ArgumentError Flux.withgradient((m,x) -> sum(m(x)), Const(m1.val), [1,2,3f0])
    # # Active is disallowed:
    # @test_throws ArgumentError Flux.gradient((m,z) -> sum(m.bias)/z, m1, Active(3f0))
    # @test_throws ArgumentError Flux.gradient((m,z) -> sum(m.bias)/z, m1.val, Active(3f0))
    # @test_throws ArgumentError Flux.gradient((m,z) -> sum(m.bias)/z, Const(m1.val), Active(3f0))
end
@@ -13,4 +13,5 @@ using Flux, Fluxperimental

include("new_recur.jl")

include("mooncake.jl")
end
Review thread:

- Reviewer: Is there any chance at all that `f` will contain trainable parameters, or does Flux insist that you not do that?
- Author: Yeah, `Flux.gradient(f, x, y)` has length 2, à la Zygote. I agree that's not the maximally flexible thing, and very occasionally you end up with `gradient(|>, x, f)`... but in real use it seems like never.
- Reviewer: Cool
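For reference, a small sketch of the `gradient(|>, x, f)` pattern mentioned above, using the default Zygote-based `Flux.gradient`; the `sum` wrapper is only there because `gradient` needs a scalar output:

```julia
using Flux

x = [1.0f0, 2.0f0]
m = Dense(2 => 1)

# Differentiating only w.r.t. x: the model's parameters get no gradient here.
Flux.gradient(v -> sum(m(v)), x)

# Passing the model as a plain argument, in the spirit of gradient(|>, x, f),
# returns gradients for both the input and the model:
Flux.gradient((v, f) -> sum(f(v)), x, m)
```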