fixup, add tests
mcabbott committed Nov 23, 2024
1 parent c4d1de4 commit fc03e5a
Showing 3 changed files with 84 additions and 13 deletions.
51 changes: 38 additions & 13 deletions ext/FluxMooncakeExt.jl
@@ -4,8 +4,6 @@ using Flux, Fluxperimental, Mooncake
import Fluxperimental: _moonstrip
# using Flux: Const

println("loaded mooncake ext")

function Fluxperimental.Moonduo(x)
dx = Mooncake.zero_tangent(x)
Moonduo(x, dx)
@@ -94,7 +92,8 @@ function _moon_withgradient(f, args::Moonduo...; zero)
rule = Mooncake.build_rrule(f, plain...)

for x in args
zero && _moonzero!(x.dval)
_check_mutable(x)
zero && Mooncake.set_to_zero!!(x.dval)
end
coduals = map(x -> Mooncake.CoDual(x.val, x.dval), args)
val, _ = Mooncake.__value_and_gradient!!(rule, Mooncake.zero_codual(f), coduals...)
@@ -103,23 +102,19 @@
(; val, grad)
end

_moonzero!(dx::Mooncake.Tangent) = foreach(_moonzero!, dx.fields)
_moonzero!(dx::Mooncake.MutableTangent) = foreach(_moonzero!, dx.fields)
_moonzero!(dx::Mooncake.NoTangent) = nothing
_moonzero!(dx::Union{Tuple, NamedTuple, AbstractArray}) = foreach(_moonzero!, dx)
_moonzero!(dx::AbstractArray{Mooncake.NoTangent}) = nothing
_moonzero!(dx::AbstractArray{<:Number}) = dx .= 0
function _moonzero!(dx)
@warn "not sure what to do with this type" typeof(dx)
dx
end
# _check_mutable(x::Const) = nothing
_check_mutable(x::Moonduo) = Flux.Functors.anymutable(x) || error(
"""`Flux.gradient(f, Moonduo(x), ...)` expects `x` to contain mutable parameter arrays."""
)

_moonstrip(dx::Mooncake.Tangent) = map(_moonstrip, dx.fields)
_moonstrip(dx::Mooncake.MutableTangent) = map(_moonstrip, dx.fields)
_moonstrip(dx::Mooncake.NoTangent) = nothing
_moonstrip(dx::Union{Tuple, NamedTuple, AbstractArray}) = map(_moonstrip, dx)
_moonstrip(dx::AbstractArray{Mooncake.NoTangent}) = nothing
_moonstrip(dx::AbstractArray{<:Number}) = dx
_moonstrip(dx::AbstractArray{<:Integer}) = nothing
_moonstrip(dx::Number) = nothing
function _moonstrip(dx)
@warn "not sure what to do with this type" typeof(dx)
dx
@@ -134,4 +129,34 @@ function Flux.update!(opt_state, model::Moonduo)
nothing
end

### Flux.Train, for train!

_applyloss(loss, model, d...) = loss(model, d...)

"""
train!(loss, Moonduo(model), data, opt_state)
This method uses Mooncake.jl instead of Zygote.jl to compute the gradients, but is otherwise the
same as `Flux.train!(loss, model, data, opt_state)`.
"""
function Flux.train!(loss, model::Moonduo, data, opt; cb=nothing, epochs::Int=1)
    isnothing(cb) || error("""train! does not support callback functions.
                              For more control use a loop with `gradient` and `update!`.""")
    Flux.Train.@withprogress for (i,d) in enumerate(Iterators.cycle(data, epochs))
        d_splat = d isa Tuple ? d : (d,)
        rule = Mooncake.build_rrule(_applyloss, loss, model.val, d_splat...)  # perhaps not ideal to do this inside the loop?

        Mooncake.set_to_zero!!(model.dval)
        l, _ = Mooncake.__value_and_gradient!!(rule, Mooncake.zero_codual(_applyloss), Mooncake.zero_codual(loss),
                                               Mooncake.CoDual(model.val, model.dval), map(Mooncake.zero_codual, d_splat)...)

        if !isfinite(l)
            throw(DomainError(lazy"Loss is $l on data item $i, stopping training"))
        end

        Flux.update!(opt, model)

        Flux.Train.@logprogress Base.haslength(data) ? i/(length(data)*epochs) : nothing
    end
end

end # module
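
(Not part of the commit: a minimal usage sketch of the extension above, mirroring the calls exercised in test/mooncake.jl below. The names `model`, `g`, `res`, `opt_state` and the layer sizes are illustrative only; it assumes Flux, Fluxperimental and Mooncake are all in the environment so that this extension is loaded.)

```julia
using Flux, Fluxperimental, Mooncake

model = Moonduo(Dense(3 => 2))   # wraps the layer together with a zero Mooncake tangent

# Flux.gradient dispatches to the Mooncake-backed methods when given Moonduo arguments;
# _moonstrip turns Mooncake's tangent types back into plain nested (Named)Tuples and arrays:
g = Flux.gradient(m -> sum(m.bias), model) |> only
g.bias                           # == [1, 1]; the same gradient is accumulated into model.dval

# The tangent stored in model.dval lets update! work directly on the wrapper:
opt_state = Flux.setup(Adam(), model)
Flux.update!(opt_state, model)

# withgradient also returns the loss value; the tests below wrap non-model arguments in Moonduo too,
# since Const is not yet supported:
res = Flux.withgradient((m, x) -> sum(m(x)), model, Moonduo([1, 2, 3f0]))
res.val, res.grad[1].weight
```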
45 changes: 45 additions & 0 deletions test/mooncake.jl
@@ -0,0 +1,45 @@
using Flux, Fluxperimental, Mooncake

@testset "gradient, withgradient, Moonduo" begin
# Similar tests for Enzyme cover how the AD digests Flux layers.
# Tests here are just the interface Flux.gradient(f, Moonduo(model)) etc.
m1 = Moonduo(Dense(3=>2))
@test m1 isa Moonduo
g1 = Flux.gradient(m -> sum(m.bias), m1) |> only
@test iszero(g1.weight)
@test g1.bias == [1, 1]
@test m1.dval.bias == [1, 1]

g2 = Flux.withgradient((m,x) -> sum(m(x)), m1, Moonduo([1,2,3f0])) # would prefer Const
@test g2.val ≈ sum(m1([1,2,3f0]))
@test g2.grad[1].weight ≈ [1 2 3; 1 2 3]
@test_skip g2.grad[2] === nothing # implicitly Const

# g3 = Flux.withgradient(Moonduo([1,2,4.], zeros(3))) do x
# z = 1 ./ x
# sum(z), z # here z is an auxiliary output
# end
# @test g3.grad[1] ≈ [-1.0, -0.25, -0.0625]
# @test g3.val[1] ≈ 1.75
# @test g3.val[2] ≈ [1.0, 0.5, 0.25]
# g4 = Flux.withgradient(Moonduo([1,2,4.], zeros(3))) do x
# z = 1 ./ x
# (loss=sum(z), aux=string(z))
# end
# @test g4.grad[1] ≈ [-1.0, -0.25, -0.0625]
# @test g4.val.loss ≈ 1.75
# @test g4.val.aux == "[1.0, 0.5, 0.25]"

# setup understands Moonduo:
@test Flux.setup(Adam(), m1) == Flux.setup(Adam(), m1.val)

# # At least one Moonduo is required:
# @test_throws ArgumentError Flux.gradient(m -> sum(m.bias), Const(m1.val))
# @test_throws ArgumentError Flux.gradient((m,x) -> sum(m(x)), Const(m1.val), [1,2,3f0])
# @test_throws ArgumentError Flux.withgradient(m -> sum(m.bias), Const(m1.val))
# @test_throws ArgumentError Flux.withgradient((m,x) -> sum(m(x)), Const(m1.val), [1,2,3f0])
# # Active is disallowed:
# @test_throws ArgumentError Flux.gradient((m,z) -> sum(m.bias)/z, m1, Active(3f0))
# @test_throws ArgumentError Flux.gradient((m,z) -> sum(m.bias)/z, m1.val, Active(3f0))
# @test_throws ArgumentError Flux.gradient((m,z) -> sum(m.bias)/z, Const(m1.val), Active(3f0))
end
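
(Not part of the commit: the new `Flux.train!` method is not exercised by this test file. As a rough sketch of the call pattern described in its docstring, with a made-up toy model, dataset and optimiser, none of which come from the commit:)

```julia
using Flux, Fluxperimental, Mooncake

model = Moonduo(Dense(2 => 1))                                           # model plus its Mooncake tangent
data  = [(randn(Float32, 2, 16), randn(Float32, 1, 16)) for _ in 1:10]   # toy (x, y) batches
opt_state = Flux.setup(Adam(0.01), model)                                # setup accepts a Moonduo, as tested above

# Same signature as Flux.train!(loss, model, data, opt_state), but gradients come from Mooncake:
Flux.train!(model, data, opt_state; epochs=2) do m, x, y
    Flux.mse(m(x), y)
end
```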
1 change: 1 addition & 0 deletions test/runtests.jl
@@ -13,4 +13,5 @@ using Flux, Fluxperimental

include("new_recur.jl")

include("mooncake.jl")
end
