From fc03e5aef3e137b3abfb36c247917c26df7e8cc3 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Fri, 22 Nov 2024 19:18:19 -0500
Subject: [PATCH] fixup, add tests

---
 ext/FluxMooncakeExt.jl | 51 +++++++++++++++++++++++++++++++-----------
 test/mooncake.jl       | 45 +++++++++++++++++++++++++++++++++++++
 test/runtests.jl       |  1 +
 3 files changed, 84 insertions(+), 13 deletions(-)
 create mode 100644 test/mooncake.jl

diff --git a/ext/FluxMooncakeExt.jl b/ext/FluxMooncakeExt.jl
index e6ac2a1..bc31dcc 100644
--- a/ext/FluxMooncakeExt.jl
+++ b/ext/FluxMooncakeExt.jl
@@ -4,8 +4,6 @@ using Flux, Fluxperimental, Mooncake
 import Fluxperimental: _moonstrip
 # using Flux: Const
 
-println("loaded mooncake ext")
-
 function Fluxperimental.Moonduo(x)
   dx = Mooncake.zero_tangent(x)
   Moonduo(x, dx)
@@ -94,7 +92,8 @@ function _moon_withgradient(f, args::Moonduo...; zero)
   rule = Mooncake.build_rrule(f, plain...)
 
   for x in args
-    zero && _moonzero!(x.dval)
+    _check_mutable(x)
+    zero && Mooncake.set_to_zero!!(x.dval)
   end
   coduals = map(x -> Mooncake.CoDual(x.val, x.dval), args)
   val, _ = Mooncake.__value_and_gradient!!(rule, Mooncake.zero_codual(f), coduals...)
@@ -103,16 +102,10 @@ function _moon_withgradient(f, args::Moonduo...; zero)
   (; val, grad)
 end
 
-_moonzero!(dx::Mooncake.Tangent) = foreach(_moonzero!, dx.fields)
-_moonzero!(dx::Mooncake.MutableTangent) = foreach(_moonzero!, dx.fields)
-_moonzero!(dx::Mooncake.NoTangent) = nothing
-_moonzero!(dx::Union{Tuple, NamedTuple, AbstractArray}) = foreach(_moonzero!, dx)
-_moonzero!(dx::AbstractArray{Mooncake.NoTangent}) = nothing
-_moonzero!(dx::AbstractArray{<:Number}) = dx .= 0
-function _moonzero!(dx)
-  @warn "not sure what to do with this type" typeof(dx)
-  dx
-end
+# _check_mutable(x::Const) = nothing
+_check_mutable(x::Moonduo) = Functors.anymutable(x) || error(
+    """`Flux.gradient(f, Moonduo(x), ...)` expects `x` to contain mutable parameter arrays."""
+)
 
 _moonstrip(dx::Mooncake.Tangent) = map(_moonstrip, dx.fields)
 _moonstrip(dx::Mooncake.MutableTangent) = map(_moonstrip, dx.fields)
@@ -120,6 +113,8 @@ _moonstrip(dx::Mooncake.NoTangent) = nothing
 _moonstrip(dx::Union{Tuple, NamedTuple, AbstractArray}) = map(_moonstrip, dx)
 _moonstrip(dx::AbstractArray{Mooncake.NoTangent}) = nothing
 _moonstrip(dx::AbstractArray{<:Number}) = dx
+_moonstrip(dx::AbstractArray{<:Integer}) = nothing
+_moonstrip(dx::Number) = nothing
 function _moonstrip(dx)
   @warn "not sure what to do with this type" typeof(dx)
   dx
@@ -134,4 +129,34 @@ function Flux.update!(opt_state, model::Moonduo)
   nothing
 end
 
+### Flux.Train, for train!
+
+_applyloss(loss, model, d...) = loss(model, d...)
+
+"""
+    train!(loss, Moonduo(model), data, opt_state)
+
+This method uses Mooncake.jl instead of Zygote.jl to compute the gradients, but is otherwise the
+same as `Flux.train!(loss, model, data, opt_state)`.
+"""
+function Flux.train!(loss, model::Moonduo, data, opt; cb=nothing, epochs::Int=1)
+  isnothing(cb) || error("""train! does not support callback functions.
+                            For more control use a loop with `gradient` and `update!`.""")
+  Flux.Train.@withprogress for (i,d) in enumerate(Iterators.cycle(data, epochs))
+    d_splat = d isa Tuple ? d : (d,)
+    rule = Mooncake.build_rrule(_applyloss, loss, model.val, d_splat...)  # perhaps not ideal to do this inside the loop?
+
+    Mooncake.set_to_zero!!(model.dval)
+    l, _ = Mooncake.__value_and_gradient!!(rule, Mooncake.zero_codual(_applyloss), Mooncake.zero_codual(loss), Mooncake.CoDual(model.val, model.dval), map(Mooncake.zero_codual, d_splat)...)
+
+    if !isfinite(l)
+      throw(DomainError(lazy"Loss is $l on data item $i, stopping training"))
+    end
+
+    Flux.update!(opt, model)
+
+    Flux.Train.@logprogress Base.haslength(data) ? i/(length(data)*epochs) : nothing
+  end
+end
+
 end # module
diff --git a/test/mooncake.jl b/test/mooncake.jl
new file mode 100644
index 0000000..3d5275f
--- /dev/null
+++ b/test/mooncake.jl
@@ -0,0 +1,45 @@
+using Flux, Fluxperimental, Mooncake
+
+@testset "gradient, withgradient, Moonduo" begin
+  # These tests are only about the interface Flux.gradient(f, Moonduo(model)) etc.,
+  # not about how Mooncake digests Flux layers.
+  m1 = Moonduo(Dense(3=>2))
+  @test m1 isa Moonduo
+  g1 = Flux.gradient(m -> sum(m.bias), m1) |> only
+  @test iszero(g1.weight)
+  @test g1.bias == [1, 1]
+  @test m1.dval.bias == [1, 1]
+
+  g2 = Flux.withgradient((m,x) -> sum(m(x)), m1, Moonduo([1,2,3f0]))  # would prefer Const
+  @test g2.val ≈ sum(m1([1,2,3f0]))
+  @test g2.grad[1].weight ≈ [1 2 3; 1 2 3]
+  @test_skip g2.grad[2] === nothing  # implicitly Const
+
+  # g3 = Flux.withgradient(Moonduo([1,2,4.], zeros(3))) do x
+  #   z = 1 ./ x
+  #   sum(z), z  # here z is an auxiliary output
+  # end
+  # @test g3.grad[1] ≈ [-1.0, -0.25, -0.0625]
+  # @test g3.val[1] ≈ 1.75
+  # @test g3.val[2] ≈ [1.0, 0.5, 0.25]
+  # g4 = Flux.withgradient(Moonduo([1,2,4.], zeros(3))) do x
+  #   z = 1 ./ x
+  #   (loss=sum(z), aux=string(z))
+  # end
+  # @test g4.grad[1] ≈ [-1.0, -0.25, -0.0625]
+  # @test g4.val.loss ≈ 1.75
+  # @test g4.val.aux == "[1.0, 0.5, 0.25]"
+
+  # setup understands Moonduo:
+  @test Flux.setup(Adam(), m1) == Flux.setup(Adam(), m1.val)
+
+  # # At least one Moonduo is required:
+  # @test_throws ArgumentError Flux.gradient(m -> sum(m.bias), Const(m1.val))
+  # @test_throws ArgumentError Flux.gradient((m,x) -> sum(m(x)), Const(m1.val), [1,2,3f0])
+  # @test_throws ArgumentError Flux.withgradient(m -> sum(m.bias), Const(m1.val))
+  # @test_throws ArgumentError Flux.withgradient((m,x) -> sum(m(x)), Const(m1.val), [1,2,3f0])
+  # # Active is disallowed:
+  # @test_throws ArgumentError Flux.gradient((m,z) -> sum(m.bias)/z, m1, Active(3f0))
+  # @test_throws ArgumentError Flux.gradient((m,z) -> sum(m.bias)/z, m1.val, Active(3f0))
+  # @test_throws ArgumentError Flux.gradient((m,z) -> sum(m.bias)/z, Const(m1.val), Active(3f0))
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index 92a4191..55f35de 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -13,4 +13,5 @@ using Flux, Fluxperimental
 
   include("new_recur.jl")
 
+  include("mooncake.jl")
 end
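
Usage sketch for review: a minimal example of how the `train!` method added above is meant to be
called, mirroring the existing Zygote path but with the model wrapped in `Moonduo`. The layer,
data, loss and optimiser below are illustrative only, not taken from the patch.

    using Flux, Fluxperimental, Mooncake

    model = Moonduo(Dense(2 => 1))                          # Moonduo holds the model and its Mooncake tangent
    data  = [(randn(Float32, 2, 8), randn(Float32, 1, 8))]  # one (input, target) batch
    opt_state = Flux.setup(Adam(0.01), model)               # Flux.setup accepts the Moonduo wrapper

    loss(m, x, y) = sum(abs2, m(x) .- y)                    # any loss(model, data...) function
    Flux.train!(loss, model, data, opt_state; epochs=3)     # gradients computed by Mooncake, not Zygote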