Add easy way to use Mooncake.jl for gradients #26

Merged · 7 commits · Nov 26, 2024

Changes from 4 commits
12 changes: 10 additions & 2 deletions Project.toml
@@ -5,14 +5,22 @@ version = "0.2.2"
[deps]
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[weakdeps]
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"

[extensions]
FluxMooncakeExt = "Mooncake"

[compat]
Compat = "4"
Flux = "0.14.23"
Flux = "0.14.23, 0.15"
Mooncake = "0.4.42"
NNlib = "0.9"
Optimisers = "0.3, 0.4"
ProgressMeter = "1.7.2"
@@ -23,4 +31,4 @@ julia = "1.10"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test"]
test = ["Test", "Mooncake"]
1 change: 1 addition & 0 deletions README.md
@@ -42,3 +42,4 @@ There are no formal documentation pages, but these links to the source will show
[`@compact(kw...) do ...`](https://github.com/FluxML/Fluxperimental.jl/blob/master/src/compact.jl), and
[`@autostruct function Mine(d) ...`](https://github.com/FluxML/Fluxperimental.jl/blob/master/src/autostruct.jl).
* Experimental [`apply(c::Chain, x)`](https://github.com/FluxML/Fluxperimental.jl/blob/master/src/chain.jl) interface
* Easy way to [use Mooncake.jl](https://github.com/FluxML/Fluxperimental.jl/blob/master/ext/FluxMooncakeExt.jl) instead of Zygote.jl.
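A condensed sketch of what that looks like in practice, taken from the `Flux.gradient(f, args::Moonduo...)` docstring added by this PR (toy one-parameter model; the expected output is copied from that docstring):

```julia
using Flux, Fluxperimental, Mooncake

model = Chain(Dense([3.0;;]))   # tiny model, weight fixed at 3.0
dup_model = Moonduo(model)      # allocates space for the gradient

Flux.gradient(dup_model, Moonduo([1])) do m, x   # computed by Mooncake, not Zygote
    sum(abs2, m(x))
end
# ((layers = ((weight = [6.0;;], bias = [6.0], σ = nothing),),), nothing)
```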
173 changes: 173 additions & 0 deletions ext/FluxMooncakeExt.jl
@@ -0,0 +1,173 @@
module FluxMooncakeExt

using Flux, Fluxperimental, Optimisers, Functors, Mooncake
import Fluxperimental: _moonstrip
# using Flux: Const

function Fluxperimental.Moonduo(x)
dx = Mooncake.zero_tangent(x)
Moonduo(x, dx)
end

# Flux gradient etc.

"""
Flux.gradient(f, args::Moonduo...)

This uses Mooncake.jl to compute the derivative,
which is both stored within `Moonduo` and returned.
Similar to the Enzyme.jl methods like `Flux.gradient(f, m::Duplicated)`.

# Example

```julia
julia> using Flux

julia> model = Chain(Dense([3.0;;]));

julia> Flux.gradient(model, [1]) do m, x # computed using Zygote
sum(abs2, m(x))
end
((layers = ((weight = [6.0;;], bias = [6.0], σ = nothing),),), [18.0])

julia> using Fluxperimental, Mooncake

julia> dup_model = Moonduo(model); # allocates space for gradient

julia> Flux.gradient(dup_model, Moonduo([1])) do m, x # Mooncake, returns the same
sum(abs2, m(x))
end
((layers = ((weight = [6.0;;], bias = [6.0], σ = nothing),),), nothing)

julia> dup_model # same gradient is also stored within the Moonduo
Moonduo(
Chain(
Dense(1 => 1), # 2 parameters
),
# norm(∇) ≈ 8.49
)

julia> Flux.destructure((weight = [6.0;;], bias = [6.0]))[1] |> norm
8.48528137423857

julia> Flux.gradient(dup_model, Moonduo([1]); zero=false) do m, x # grad accumulation
sum(abs2, m(x))
end
((layers = ((weight = [12.0;;], bias = [12.0], σ = nothing),),), nothing)
```

!!! note
At present there is no way to mark some arguments constant.
Instead of `gradient(loss, Duplicated(model), Const(data))`,
you can write `gradient(m -> loss(m, data), Moonduo(model))`.
"""
Flux.gradient(f, args::Moonduo...; zero::Bool=true) = _moon_withgradient(f, args...; zero).grad
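To make the note above concrete, a small sketch of the closure workaround for constant arguments, reusing `dup_model` from the docstring example (the commented output is approximate):

```julia
x = [1]   # data we want to treat as constant
Flux.gradient(m -> sum(abs2, m(x)), dup_model)   # close over x instead of wrapping it in a Moonduo
# ((layers = ((weight = [6.0;;], bias = [6.0], σ = nothing),),),)
```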

"""
Flux.withgradient(f, args::Moonduo...)

This should return the same answer as `withgradient(f, model, args...)`,
but it uses Mooncake.jl instead of Zygote.jl to compute the derivative.

# Example

```julia
julia> using Flux, Fluxperimental, Mooncake

julia> model = Chain(Embedding([1.1 2.2 3.3]), Dense([4.4;;]), only);

julia> model(3)
14.52

julia> Flux.withgradient(m -> m(3), model) # this uses Zygote
(val = 14.52, grad = ((layers = ((weight = [0.0 0.0 4.4],), (weight = [3.3;;], bias = [1.0], σ = nothing), nothing),),))

julia> Flux.withgradient(m -> m(3), Moonduo(model)) # this uses Mooncake
(val = 14.52, grad = ((layers = ((weight = [0.0 0.0 4.4],), (weight = [3.3;;], bias = [1.0], σ = nothing), nothing),),))
```

!!! warning
With Zygote, the function `f` may return Tuple or NamedTuple, with the loss as the first element.
This feature is not supported here, for now.
"""
Flux.withgradient(f, args::Moonduo...; zero::Bool=true) = _moon_withgradient(f, args...; zero)

function _moon_withgradient(f, args::Moonduo...; zero)
plain = map(x -> x.val, args)
rule = Mooncake.build_rrule(f, plain...)

for x in args
_check_mutable(x)
zero && Mooncake.set_to_zero!!(x.dval)
end
coduals = map(x -> Mooncake.CoDual(x.val, x.dval), args)
val, _ = Mooncake.__value_and_gradient!!(rule, Mooncake.zero_codual(f), coduals...)

Review comment (Member): Is there any chance at all that `f` will contain trainable parameters, or does Flux insist that you not do that?

Reply (Member Author): Yeah, `Flux.gradient(f, x, y)` has length 2, à la Zygote. I agree that's not the maximally flexible thing, and very occasionally you end up with `gradient(|>, x, f)`... but in real use it seems like never.

Reply (Member): Cool


grad = map(x -> _moongrad(x.dval), args)
(; val, grad)
end

# _check_mutable(x::Const) = nothing
_check_mutable(x::Moonduo) = Functors.anymutable(x) || error(
"""`Flux.gradient(f, Moonduo(x), ...)` expects `x` to contain mutable parameter arrays."""
)

function _moongrad(dx)
dx2 = _moonstrip(dx) # remove all the weird types
isnothing(dx2) && return
return Flux.fmapstructure(identity, dx2; prune=nothing)
end

_moonstrip(dx::Mooncake.Tangent) = map(_moonstrip, dx.fields)
_moonstrip(dx::Mooncake.MutableTangent) = map(_moonstrip, dx.fields)
_moonstrip(dx::Mooncake.NoTangent) = nothing
_moonstrip(dx::Union{Tuple, NamedTuple, AbstractArray}) = map(_moonstrip, dx)
_moonstrip(dx::AbstractArray{Mooncake.NoTangent}) = nothing
_moonstrip(dx::AbstractArray{<:Number}) = dx
_moonstrip(dx::AbstractArray{<:Integer}) = nothing

Review comment (Member): To the best of my knowledge, this shouldn't ever be a case that you see.

Reply (Member Author): Ok, I didn't try very hard! Do you know if there are other types which can occur, besides those handled here?

Reply (Member, @willtebbutt, Nov 25, 2024): Do you place any constraints on the types of the things that you take gradients w.r.t.?

I ask because there's nothing to stop people adding a new type and declaring a non-standard tangent type for it (i.e. while some subtype of `Tangent` is the thing returned by `tangent_type` by default for structs, there's nothing to stop people making `tangent_type` return something else for a type that they own). So in principle you could see literally anything. In practice, assuming that you're working with `Array`s, and structs / mutable structs / Tuples / NamedTuples of `Array`s, I think you should be fine.

My honest advice would be to do what you're doing at the minute, i.e. check that it works for models that you care about, and ensure that there's a good error message so that a user knows where to ask for help if they encounter something you weren't expecting.

Reply (Member Author): Ok, sounds good!

No constraints, but in reality it's going to be arrays & all kinds of structs.

_moonstrip(dx::Number) = nothing
function _moonstrip(dx)
@warn "not sure what to do with this type" typeof(dx)
dx
end
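A small sketch of how these `_moonstrip` methods behave at the leaves. This imports the unexported function for illustration, and assumes `Mooncake.NoTangent` has a zero-argument constructor:

```julia
using Fluxperimental, Mooncake
import Fluxperimental: _moonstrip   # unexported; its methods are defined in this extension

_moonstrip([1.0, 2.0])            # array of floats: kept as the gradient
_moonstrip([1, 2, 3])             # array of integers: treated as non-differentiable, becomes nothing
_moonstrip(Mooncake.NoTangent())  # explicit "no gradient" marker: becomes nothing
_moonstrip((a = [1.0], b = 2.0))  # NamedTuples are mapped recursively: (a = [1.0], b = nothing)
```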

# Optimisers etc.

Flux.setup(rule::Optimisers.AbstractRule, m::Moonduo) = Flux.setup(rule, m.val)

function Flux.update!(opt_state, model::Moonduo)
Flux.update!(opt_state, model.val, _moongrad(model.dval))
nothing
end
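Together with the `Flux.gradient` method above, these two definitions support a hand-written training step; a sketch with made-up model, data, and hyperparameters:

```julia
using Flux, Fluxperimental, Mooncake

model = Moonduo(Dense(3 => 2))
opt_state = Flux.setup(Adam(0.01), model)     # unwraps to Flux.setup(rule, model.val)

x, y = randn(Float32, 3, 16), randn(Float32, 2, 16)
Flux.gradient(m -> Flux.mse(m(x), y), model)  # gradient lands in model.dval
Flux.update!(opt_state, model)                # reads model.dval, mutates model.val in place
```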

### Flux.Train, for train!

_applyloss(loss, model, d...) = loss(model, d...)

"""
train!(loss, Moonduo(model), data, opt_state)

This method uses Mooncake.jl instead of Zygote.jl to compute the gradients, but is otherwise the
same as `Flux.train!(loss, model, data, opt_state)`.
"""
function Flux.train!(loss, model::Moonduo, data, opt; cb=nothing, epochs::Int=1)
isnothing(cb) || error("""train! does not support callback functions.
For more control use a loop with `gradient` and `update!`.""")
Flux.Train.@withprogress for (i,d) in enumerate(Iterators.cycle(data, epochs))
d_splat = d isa Tuple ? d : (d,)
rule = Mooncake.build_rrule(_applyloss, loss, model.val, d_splat...) # perhaps not ideal to do this inside the loop?

Mooncake.set_to_zero!!(model.dval)
l, _ = Mooncake.__value_and_gradient!!(rule, Mooncake.zero_codual(_applyloss), Mooncake.zero_codual(loss), Mooncake.CoDual(model.val, model.dval), map(Mooncake.zero_codual, d_splat)...)

if !isfinite(l)
throw(DomainError(lazy"Loss is $l on data item $i, stopping training"))
end

Flux.update!(opt, model)

Flux.Train.@logprogress Base.haslength(data) ? i/(length(data)*epochs) : nothing
end
end
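And a sketch of how this `train!` method would be called (the model, loss, and data are placeholders):

```julia
using Flux, Fluxperimental, Mooncake

model = Moonduo(Chain(Dense(2 => 8, relu), Dense(8 => 1)))
opt_state = Flux.setup(Adam(), model)
data = [(randn(Float32, 2, 32), randn(Float32, 1, 32)) for _ in 1:10]

loss(m, x, y) = Flux.mse(m(x), y)
Flux.train!(loss, model, data, opt_state)   # each batch is splatted into loss(model, x, y)
```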

end # module
3 changes: 3 additions & 0 deletions src/Fluxperimental.jl
@@ -21,4 +21,7 @@ export @autostruct

include("new_recur.jl")

include("mooncake.jl")
export Moonduo

end # module Fluxperimental
40 changes: 40 additions & 0 deletions src/mooncake.jl
@@ -0,0 +1,40 @@
"""
Moonduo(x, [dx])

This stores both an object `x` and its gradient `dx`,
with `dx` in the format used by Mooncake.jl. This is automatically allocated
when you call `Moonduo(x)`.

This serves the same purpose as Enzyme.jl's `Duplicated` type.
Both of these AD engines prefer that space for the gradient be pre-allocated.

Maybe this is like `Mooncake.CoDual`, except that type is marked private, and its use seems discouraged:
https://github.com/compintell/Mooncake.jl/issues/275
An advantage of Flux owning this type is that we can provide pretty printing without piracy.
"""
struct Moonduo{X,DX}
val::X
dval::DX
end

function Moonduo(args...)
if length(args)==1
error("The method `Moonduo(x)` is only available when Mooncake.jl is loaded!")
else
error("The only legal methods are `Moonduo(x)` and `Moonduo(x, dx)`.")
end
end

Optimisers.trainable(m::Moonduo) = (; m.val)

Flux.@layer :expand Moonduo

(m::Moonduo)(x...) = m.val(x...)

function _moonstrip end

function Flux._show_pre_post(obj::Moonduo)
nrm = Flux.norm(destructure(_moonstrip(obj.dval))[1])
str = repr(round(nrm; sigdigits=3))
"Moonduo(", " # norm(∇) ≈ $str\n) "
end
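A brief sketch of how the two construction paths and call forwarding behave once Mooncake is loaded (values are arbitrary):

```julia
using Flux, Fluxperimental, Mooncake

m = Dense(2 => 3)
duo  = Moonduo(m)                            # extension method: fills in Mooncake.zero_tangent(m)
duo2 = Moonduo(m, Mooncake.zero_tangent(m))  # two-argument constructor with an explicit tangent

duo(ones(Float32, 2))   # calls forward to the wrapped model, i.e. m(ones(Float32, 2))
```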
45 changes: 45 additions & 0 deletions test/mooncake.jl
@@ -0,0 +1,45 @@
using Flux, Fluxperimental, Mooncake

@testset "gradient, withgradient, Moonduo" begin
# Tests above are about how Enzyme digests Flux layers.
# Tests here are just the interface Flux.gradient(f, Moonduo(model)) etc.
m1 = Moonduo(Dense(3=>2))
@test m1 isa Moonduo
g1 = Flux.gradient(m -> sum(m.bias), m1) |> only
@test iszero(g1.weight)
@test g1.bias == [1, 1]
@test m1.dval.fields.bias == [1, 1]

g2 = Flux.withgradient((m,x) -> sum(m(x)), m1, Moonduo([1,2,3f0])) # would prefer Const
@test g2.val ≈ sum(m1([1,2,3f0]))
@test g2.grad[1].weight ≈ [1 2 3; 1 2 3]
@test_skip g2.grad[2] === nothing # implicitly Const

# g3 = Flux.withgradient(Moonduo([1,2,4.], zeros(3))) do x
# z = 1 ./ x
# sum(z), z # here z is an auxiliary output
# end
# @test g3.grad[1] ≈ [-1.0, -0.25, -0.0625]
# @test g3.val[1] ≈ 1.75
# @test g3.val[2] ≈ [1.0, 0.5, 0.25]
# g4 = Flux.withgradient(Moonduo([1,2,4.], zeros(3))) do x
# z = 1 ./ x
# (loss=sum(z), aux=string(z))
# end
# @test g4.grad[1] ≈ [-1.0, -0.25, -0.0625]
# @test g4.val.loss ≈ 1.75
# @test g4.val.aux == "[1.0, 0.5, 0.25]"

# setup understands Moonduo:
@test Flux.setup(Adam(), m1) == Flux.setup(Adam(), m1.val)

# # At least one Moonduo is required:
# @test_throws ArgumentError Flux.gradient(m -> sum(m.bias), Const(m1.val))
# @test_throws ArgumentError Flux.gradient((m,x) -> sum(m(x)), Const(m1.val), [1,2,3f0])
# @test_throws ArgumentError Flux.withgradient(m -> sum(m.bias), Const(m1.val))
# @test_throws ArgumentError Flux.withgradient((m,x) -> sum(m(x)), Const(m1.val), [1,2,3f0])
# # Active is disallowed:
# @test_throws ArgumentError Flux.gradient((m,z) -> sum(m.bias)/z, m1, Active(3f0))
# @test_throws ArgumentError Flux.gradient((m,z) -> sum(m.bias)/z, m1.val, Active(3f0))
# @test_throws ArgumentError Flux.gradient((m,z) -> sum(m.bias)/z, Const(m1.val), Active(3f0))
end
1 change: 1 addition & 0 deletions test/runtests.jl
@@ -13,4 +13,5 @@ using Flux, Fluxperimental

include("new_recur.jl")

include("mooncake.jl")
end