From c4d1de41caef4166870ef633d1eeeef4e585c9b0 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Fri, 22 Nov 2024 15:35:54 -0500 Subject: [PATCH 1/7] first working Mooncake ext --- Project.toml | 9 ++- ext/FluxMooncakeExt.jl | 137 +++++++++++++++++++++++++++++++++++++++++ src/Fluxperimental.jl | 3 + src/mooncake.jl | 40 ++++++++++++ 4 files changed, 188 insertions(+), 1 deletion(-) create mode 100644 ext/FluxMooncakeExt.jl create mode 100644 src/mooncake.jl diff --git a/Project.toml b/Project.toml index 8d274be..21d58df 100644 --- a/Project.toml +++ b/Project.toml @@ -10,9 +10,16 @@ Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" +[weakdeps] +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" + +[extensions] +FluxMooncakeExt = "Mooncake" + [compat] Compat = "4" -Flux = "0.14.23" +Flux = "0.14.23, 0.15" +Mooncake = "0.4.42" NNlib = "0.9" Optimisers = "0.3, 0.4" ProgressMeter = "1.7.2" diff --git a/ext/FluxMooncakeExt.jl b/ext/FluxMooncakeExt.jl new file mode 100644 index 0000000..e6ac2a1 --- /dev/null +++ b/ext/FluxMooncakeExt.jl @@ -0,0 +1,137 @@ +module FluxMooncakeExt + +using Flux, Fluxperimental, Mooncake +import Fluxperimental: _moonstrip +# using Flux: Const + +println("loaded mooncake ext") + +function Fluxperimental.Moonduo(x) + dx = Mooncake.zero_tangent(x) + Moonduo(x, dx) +end + +# Flux gradient etc. + +""" + Flux.gradient(f, args::Moonduo...) + +This uses Mooncake.jl to compute the derivative, +which is both stored within `Moonduo` and returned. +Similar to the Enzyme.jl methods like `Flux.gradient(f, m::Duplicated)`. + +# Example + +```julia +julia> using Flux + +julia> model = Chain(Dense([3.0;;])); + +julia> Flux.gradient(model, [1]) do m, x # computed using Zygote + sum(abs2, m(x)) + end +((layers = ((weight = [6.0;;], bias = [6.0], σ = nothing),),), [18.0]) + +julia> using Fluxperimental, Mooncake + +julia> dup_model = Moonduo(model); # allocates space for gradient + +julia> Flux.gradient(dup_model, Moonduo([1])) do m, x # Mooncake, returns the same + sum(abs2, m(x)) + end +((layers = ((weight = [6.0;;], bias = [6.0], σ = nothing),),), nothing) + +julia> dup_model # same gradient is also stored within Duplicated +Moonduo( + Chain( + Dense(1 => 1), # 2 parameters + ), + # norm(∇) ≈ 8.49 +) + +julia> Flux.destructure((weight = [6.0;;], bias = [6.0]))[1] |> norm +8.48528137423857 + +julia> Flux.gradient(dup_model, Moonduo([1]); zero=false) do m, x # grad accumulation + sum(abs2, m(x)) + end +((layers = ((weight = [12.0;;], bias = [12.0], σ = nothing),),), nothing) +``` +""" +Flux.gradient(f, args::Moonduo...; zero::Bool=true) = _moon_withgradient(f, args...; zero).grad + +""" + Flux.withgradient(f, args::Moonduo...) + +This should return the same answer as `withgradient(f, model, args...)`, +but it uses Mooncake.jl instead of Zygote.jl to compute the derivative. + +# Example + +```julia +julia> using Flux, Fluxperimental, Mooncake + +julia> model = Chain(Embedding([1.1 2.2 3.3]), Dense([4.4;;]), only); + +julia> model(3) +14.52 + +julia> Flux.withgradient(m -> m(3), model) # this uses Zygote +(val = 14.52, grad = ((layers = ((weight = [0.0 0.0 4.4],), (weight = [3.3;;], bias = [1.0], σ = nothing), nothing),),)) + +julia> Flux.withgradient(m -> m(3), Moonduo(model)) # this uses Mooncake +(val = 14.52, grad = ((layers = ((weight = [0.0 0.0 4.4],), (weight = [3.3;;], bias = [1.0], σ = nothing), nothing),),)) +``` + +!!! 
warning + With Zygote, the function `f` may return Tuple or NamedTuple, with the loss as the first element. + This feature is not supported here, for now. +""" +Flux.withgradient(f, args::Moonduo...; zero::Bool=true) = _moon_withgradient(f, args...; zero) + +function _moon_withgradient(f, args::Moonduo...; zero) + plain = map(x -> x.val, args) + rule = Mooncake.build_rrule(f, plain...) + + for x in args + zero && _moonzero!(x.dval) + end + coduals = map(x -> Mooncake.CoDual(x.val, x.dval), args) + val, _ = Mooncake.__value_and_gradient!!(rule, Mooncake.zero_codual(f), coduals...) + + grad = map(x -> _moonstrip(x.dval), args) + (; val, grad) +end + +_moonzero!(dx::Mooncake.Tangent) = foreach(_moonzero!, dx.fields) +_moonzero!(dx::Mooncake.MutableTangent) = foreach(_moonzero!, dx.fields) +_moonzero!(dx::Mooncake.NoTangent) = nothing +_moonzero!(dx::Union{Tuple, NamedTuple, AbstractArray}) = foreach(_moonzero!, dx) +_moonzero!(dx::AbstractArray{Mooncake.NoTangent}) = nothing +_moonzero!(dx::AbstractArray{<:Number}) = dx .= 0 +function _moonzero!(dx) + @warn "not sure what to do with this type" typeof(dx) + dx +end + +_moonstrip(dx::Mooncake.Tangent) = map(_moonstrip, dx.fields) +_moonstrip(dx::Mooncake.MutableTangent) = map(_moonstrip, dx.fields) +_moonstrip(dx::Mooncake.NoTangent) = nothing +_moonstrip(dx::Union{Tuple, NamedTuple, AbstractArray}) = map(_moonstrip, dx) +_moonstrip(dx::AbstractArray{Mooncake.NoTangent}) = nothing +_moonstrip(dx::AbstractArray{<:Number}) = dx +function _moonstrip(dx) + @warn "not sure what to do with this type" typeof(dx) + dx +end + +# Optimisers etc. + +Flux.setup(m::Moonduo) = Flux.setup(m.val) + +function Flux.update!(opt_state, model::Moonduo) + Flux.update!(opt_state, model.val, _moonstrip(model.dval)) + nothing +end + +end # module diff --git a/src/Fluxperimental.jl b/src/Fluxperimental.jl index 703b63b..83a871d 100644 --- a/src/Fluxperimental.jl +++ b/src/Fluxperimental.jl @@ -21,4 +21,7 @@ export @autostruct include("new_recur.jl") +include("mooncake.jl") +export Moonduo + end # module Fluxperimental diff --git a/src/mooncake.jl b/src/mooncake.jl new file mode 100644 index 0000000..56cfec0 --- /dev/null +++ b/src/mooncake.jl @@ -0,0 +1,40 @@ +""" + Moonduo(x, [dx]) + +This stores both an object `x` and its gradient `dx`, +with `dx` in the format used by Mooncake.jl. This is automatically allocated +when you call `Moonduo(x)`. + +This serves the same purpose as Enzyme.jl's `Duplicated` type. +Both of these AD engines prefer that space for the gradient be pre-allocated. + +Maybe this is like Mooncake.CoDual, except that it's marked private and seems discouraged: +https://github.com/compintell/Mooncake.jl/issues/275 + +""" +struct Moonduo{X,DX} + val::X + dval::DX +end + +function Moonduo(args...) + if length(args)==1 + error("The method `Moonduo(x)` is only available when Mooncake.jl is loaded!") + else + error("The only legal methods are `Moonduo(x)` and `Moonduo(x, dx)`.") + end +end + +Optimisers.trainable(m::Moonduo) = (; m.val) + +Flux.@layer :expand Moonduo + +(m::Moonduo)(x...) = m.val(x...) 
+ +function _moonstrip end + +function Flux._show_pre_post(obj::Moonduo) + nrm = Flux.norm(destructure(_moonstrip(obj.dval))[1]) + str = repr(round(nrm; sigdigits=3)) + "Moonduo(", " # norm(∇) ≈ $str\n) " +end From fc03e5aef3e137b3abfb36c247917c26df7e8cc3 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Fri, 22 Nov 2024 19:18:19 -0500 Subject: [PATCH 2/7] fixup, add tests --- ext/FluxMooncakeExt.jl | 51 +++++++++++++++++++++++++++++++----------- test/mooncake.jl | 45 +++++++++++++++++++++++++++++++++++++ test/runtests.jl | 1 + 3 files changed, 84 insertions(+), 13 deletions(-) create mode 100644 test/mooncake.jl diff --git a/ext/FluxMooncakeExt.jl b/ext/FluxMooncakeExt.jl index e6ac2a1..bc31dcc 100644 --- a/ext/FluxMooncakeExt.jl +++ b/ext/FluxMooncakeExt.jl @@ -4,8 +4,6 @@ using Flux, Fluxperimental, Mooncake import Fluxperimental: _moonstrip # using Flux: Const -println("loaded mooncake ext") - function Fluxperimental.Moonduo(x) dx = Mooncake.zero_tangent(x) Moonduo(x, dx) @@ -94,7 +92,8 @@ function _moon_withgradient(f, args::Moonduo...; zero) rule = Mooncake.build_rrule(f, plain...) for x in args - zero && _moonzero!(x.dval) + _check_mutable(x) + zero && Mooncake.set_to_zero!!(x.dval) end coduals = map(x -> Mooncake.CoDual(x.val, x.dval), args) val, _ = Mooncake.__value_and_gradient!!(rule, Mooncake.zero_codual(f), coduals...) @@ -103,16 +102,10 @@ function _moon_withgradient(f, args::Moonduo...; zero) (; val, grad) end -_moonzero!(dx::Mooncake.Tangent) = foreach(_moonzero!, dx.fields) -_moonzero!(dx::Mooncake.MutableTangent) = foreach(_moonzero!, dx.fields) -_moonzero!(dx::Mooncake.NoTangent) = nothing -_moonzero!(dx::Union{Tuple, NamedTuple, AbstractArray}) = foreach(_moonzero!, dx) -_moonzero!(dx::AbstractArray{Mooncake.NoTangent}) = nothing -_moonzero!(dx::AbstractArray{<:Number}) = dx .= 0 -function _moonzero!(dx) - @warn "not sure what to do with this type" typeof(dx) - dx -end +# _check_mutable(x::Const) = nothing +_check_mutable(x::Moonduo) = Functors.anymutable(x) || error( + """`Flux.gradient(f, Moonduo(x), ...)` expects `x` to contain mutable parameter arrays.""" +) _moonstrip(dx::Mooncake.Tangent) = map(_moonstrip, dx.fields) _moonstrip(dx::Mooncake.MutableTangent) = map(_moonstrip, dx.fields) @@ -120,6 +113,8 @@ _moonstrip(dx::Mooncake.NoTangent) = nothing _moonstrip(dx::Union{Tuple, NamedTuple, AbstractArray}) = map(_moonstrip, dx) _moonstrip(dx::AbstractArray{Mooncake.NoTangent}) = nothing _moonstrip(dx::AbstractArray{<:Number}) = dx +_moonstrip(dx::AbstractArray{<:Integer}) = nothing +_moonstrip!(dx::Number) = nothing function _moonstrip(dx) @warn "not sure what to do with this type" typeof(dx) dx @@ -134,4 +129,34 @@ function Flux.update!(opt_state, model::Moonduo) nothing end +### Flux.Train, for train! + +_applyloss(loss, model, d...) = loss(model, d...) + +""" + train!(loss, Moonduo(model), data, opt_state) + +This method uses Mooncake.jl instead of Zygote.jl to compute the gradients, but is otherwise the +same as `Flux.train!(loss, model, data, opt_state)`. +""" +function Flux.train!(loss, model::Moonduo, data, opt; cb=nothing, epochs::Int=1) + isnothing(cb) || error("""train! does not support callback functions. + For more control use a loop with `gradient` and `update!`.""") + Flux.Train.@withprogress for (i,d) in enumerate(Iterators.cycle(data, epochs)) + d_splat = d isa Tuple ? d : (d,) + rule = Mooncake.build_rrule(f, model.val, d_splat...) # perhaps not ideal to do this inside the loop? 
+ + Mooncake.set_to_zero!!(model.dval) + l, _ = Mooncake.__value_and_gradient!!(rule, Mooncake.zero_codual(f), model, map(Mooncake.zero_codual, d_splat)...) + + if !isfinite(l) + throw(DomainError(lazy"Loss is $l on data item $i, stopping training")) + end + + Flux.update!(opt, model) + + Flux.Train.@logprogress Base.haslength(data) ? i/(length(data)*epochs) : nothing + end +end + end # module diff --git a/test/mooncake.jl b/test/mooncake.jl new file mode 100644 index 0000000..3d5275f --- /dev/null +++ b/test/mooncake.jl @@ -0,0 +1,45 @@ +using Flux, Fluxperimental, Mooncake + +@testset "gradient, withgradient, Moonduo" begin + # Tests above are about how Enzyme digests Flux layers. + # Tests here are just the interface Flux.gradient(f, Moonduo(model)) etc. + m1 = Moonduo(Dense(3=>2)) + @test m1 isa Moonduo + g1 = Flux.gradient(m -> sum(m.bias), m1) |> only + @test iszero(g1.weight) + @test g1.bias == [1, 1] + @test m1.dval.bias == [1, 1] + + g2 = Flux.withgradient((m,x) -> sum(m(x)), m1, Moonduo([1,2,3f0])) # would prefer Const + @test g2.val ≈ sum(m1([1,2,3f0])) + @test g2.grad[1].weight ≈ [1 2 3; 1 2 3] + @test_skip g2.grad[2] === nothing # implicitly Const + + # g3 = Flux.withgradient(Moonduo([1,2,4.], zeros(3))) do x + # z = 1 ./ x + # sum(z), z # here z is an auxillary output + # end + # @test g3.grad[1] ≈ [-1.0, -0.25, -0.0625] + # @test g3.val[1] ≈ 1.75 + # @test g3.val[2] ≈ [1.0, 0.5, 0.25] + # g4 = Flux.withgradient(Moonduo([1,2,4.], zeros(3))) do x + # z = 1 ./ x + # (loss=sum(z), aux=string(z)) + # end + # @test g4.grad[1] ≈ [-1.0, -0.25, -0.0625] + # @test g4.val.loss ≈ 1.75 + # @test g4.val.aux == "[1.0, 0.5, 0.25]" + + # setup understands Moonduo: + @test Flux.setup(Adam(), m1) == Flux.setup(Adam(), m1.val) + + # # At least one Moonduo is required: + # @test_throws ArgumentError Flux.gradient(m -> sum(m.bias), Const(m1.val)) + # @test_throws ArgumentError Flux.gradient((m,x) -> sum(m(x)), Const(m1.val), [1,2,3f0]) + # @test_throws ArgumentError Flux.withgradient(m -> sum(m.bias), Const(m1.val)) + # @test_throws ArgumentError Flux.withgradient((m,x) -> sum(m(x)), Const(m1.val), [1,2,3f0]) + # # Active is disallowed: + # @test_throws ArgumentError Flux.gradient((m,z) -> sum(m.bias)/z, m1, Active(3f0)) + # @test_throws ArgumentError Flux.gradient((m,z) -> sum(m.bias)/z, m1.val, Active(3f0)) + # @test_throws ArgumentError Flux.gradient((m,z) -> sum(m.bias)/z, Const(m1.val), Active(3f0)) +end diff --git a/test/runtests.jl b/test/runtests.jl index 92a4191..55f35de 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -13,4 +13,5 @@ using Flux, Fluxperimental include("new_recur.jl") + include("mooncake.jl") end From 139ff3ee6b99575f5ef11328c8126d3ba4a6f2b9 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Fri, 22 Nov 2024 19:40:30 -0500 Subject: [PATCH 3/7] fixup --- Project.toml | 3 ++- README.md | 1 + ext/FluxMooncakeExt.jl | 16 +++++++++++----- test/mooncake.jl | 2 +- 4 files changed, 15 insertions(+), 7 deletions(-) diff --git a/Project.toml b/Project.toml index 21d58df..2dd8c98 100644 --- a/Project.toml +++ b/Project.toml @@ -5,6 +5,7 @@ version = "0.2.2" [deps] Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" +Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" @@ -30,4 +31,4 @@ julia = "1.10" Test = 
"8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test"] +test = ["Test", "Mooncake"] diff --git a/README.md b/README.md index 6323e3d..cab59ca 100644 --- a/README.md +++ b/README.md @@ -42,3 +42,4 @@ There are no formal documentation pages, but these links to the source will show [`@compact(kw...) do ...`](https://github.com/FluxML/Fluxperimental.jl/blob/master/src/compact.jl), and [`@autostruct function Mine(d) ...`](https://github.com/FluxML/Fluxperimental.jl/blob/master/src/autostruct.jl). * Experimental [`apply(c::Chain, x)`](https://github.com/FluxML/Fluxperimental.jl/blob/master/src/chain.jl) interface +* Easy way to [use Mooncake.jl](https://github.com/FluxML/Fluxperimental.jl/blob/master/ext/FluxMooncakeExt.jl) instead of Zygote.jl. diff --git a/ext/FluxMooncakeExt.jl b/ext/FluxMooncakeExt.jl index bc31dcc..a3a8eab 100644 --- a/ext/FluxMooncakeExt.jl +++ b/ext/FluxMooncakeExt.jl @@ -1,6 +1,6 @@ module FluxMooncakeExt -using Flux, Fluxperimental, Mooncake +using Flux, Fluxperimental, Optimisers, Functors, Mooncake import Fluxperimental: _moonstrip # using Flux: Const @@ -98,7 +98,7 @@ function _moon_withgradient(f, args::Moonduo...; zero) coduals = map(x -> Mooncake.CoDual(x.val, x.dval), args) val, _ = Mooncake.__value_and_gradient!!(rule, Mooncake.zero_codual(f), coduals...) - grad = map(x -> _moonstrip(x.dval), args) + grad = map(x -> _moongrad(x.dval), args) (; val, grad) end @@ -107,6 +107,12 @@ _check_mutable(x::Moonduo) = Functors.anymutable(x) || error( """`Flux.gradient(f, Moonduo(x), ...)` expects `x` to contain mutable parameter arrays.""" ) +function _moongrad(dx) + dx2 = _moonstrip(dx) # remove all the weird types + isnothing(dx2) && return + return Flux.fmapstructure(identity, dx2; prune=nothing) +end + _moonstrip(dx::Mooncake.Tangent) = map(_moonstrip, dx.fields) _moonstrip(dx::Mooncake.MutableTangent) = map(_moonstrip, dx.fields) _moonstrip(dx::Mooncake.NoTangent) = nothing @@ -114,7 +120,7 @@ _moonstrip(dx::Union{Tuple, NamedTuple, AbstractArray}) = map(_moonstrip, dx) _moonstrip(dx::AbstractArray{Mooncake.NoTangent}) = nothing _moonstrip(dx::AbstractArray{<:Number}) = dx _moonstrip(dx::AbstractArray{<:Integer}) = nothing -_moonstrip!(dx::Number) = nothing +_moonstrip(dx::Number) = nothing function _moonstrip(dx) @warn "not sure what to do with this type" typeof(dx) dx @@ -122,10 +128,10 @@ end # Optimisers etc. 
-Flux.setup(m::Moonduo) = Flux.setup(m.val) +Flux.setup(rule::Optimisers.AbstractRule, m::Moonduo) = Flux.setup(rule, m.val) function Flux.update!(opt_state, model::Moonduo) - Flux.update!(opt_state, model.val, _moonstrip(model.dval)) + Flux.update!(opt_state, model.val, _moongrad(model.dval)) nothing end diff --git a/test/mooncake.jl b/test/mooncake.jl index 3d5275f..a1287cc 100644 --- a/test/mooncake.jl +++ b/test/mooncake.jl @@ -8,7 +8,7 @@ using Flux, Fluxperimental, Mooncake g1 = Flux.gradient(m -> sum(m.bias), m1) |> only @test iszero(g1.weight) @test g1.bias == [1, 1] - @test m1.dval.bias == [1, 1] + @test m1.dval.fields.bias == [1, 1] g2 = Flux.withgradient((m,x) -> sum(m(x)), m1, Moonduo([1,2,3f0])) # would prefer Const @test g2.val ≈ sum(m1([1,2,3f0])) From 227e01fd4bc9df3ccdea7dea394bff3cbd343796 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sat, 23 Nov 2024 13:05:14 -0500 Subject: [PATCH 4/7] note about const --- ext/FluxMooncakeExt.jl | 5 +++++ src/mooncake.jl | 4 ++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/ext/FluxMooncakeExt.jl b/ext/FluxMooncakeExt.jl index a3a8eab..a5acee8 100644 --- a/ext/FluxMooncakeExt.jl +++ b/ext/FluxMooncakeExt.jl @@ -55,6 +55,11 @@ julia> Flux.gradient(dup_model, Moonduo([1]); zero=false) do m, x # grad accumu end ((layers = ((weight = [12.0;;], bias = [12.0], σ = nothing),),), nothing) ``` + +!!! note + At present there is no way to mark some arguments constant. + Instead of `gradient(loss, Duplicated(model), Const(data))`, + you can write `gradient(m -> loss(m, data), Moonduo(model))`. """ Flux.gradient(f, args::Moonduo...; zero::Bool=true) = _moon_withgradient(f, args...; zero).grad diff --git a/src/mooncake.jl b/src/mooncake.jl index 56cfec0..8b4f736 100644 --- a/src/mooncake.jl +++ b/src/mooncake.jl @@ -8,9 +8,9 @@ when you call `Moonduo(x)`. This serves the same purpose as Enzyme.jl's `Duplicated` type. Both of these AD engines prefer that space for the gradient be pre-allocated. -Maybe this is like Mooncake.CoDual, except that it's marked private and seems discouraged: +Maybe this is like `Mooncake.CoDual`, except that is marked private, and seems discouraged: https://github.com/compintell/Mooncake.jl/issues/275 - +An advantage of Flux owning this type is that we can provide pretty printing without piracy. """ struct Moonduo{X,DX} val::X From a095888422b8d2d936782f1c304f14a6d177de12 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Mon, 25 Nov 2024 20:21:11 -0500 Subject: [PATCH 5/7] tweaks, including state & loadmodel --- Project.toml | 2 +- ext/FluxMooncakeExt.jl | 20 ++++++++++++++++---- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/Project.toml b/Project.toml index 2dd8c98..e19540e 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Fluxperimental" uuid = "3102ee7a-c841-4564-8f7f-ec69bd4fd658" -version = "0.2.2" +version = "0.2.3" [deps] Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" diff --git a/ext/FluxMooncakeExt.jl b/ext/FluxMooncakeExt.jl index a5acee8..6b1e920 100644 --- a/ext/FluxMooncakeExt.jl +++ b/ext/FluxMooncakeExt.jl @@ -9,7 +9,7 @@ function Fluxperimental.Moonduo(x) Moonduo(x, dx) end -# Flux gradient etc. +### Flux gradient etc. """ Flux.gradient(f, args::Moonduo...) 
@@ -124,14 +124,13 @@ _moonstrip(dx::Mooncake.NoTangent) = nothing _moonstrip(dx::Union{Tuple, NamedTuple, AbstractArray}) = map(_moonstrip, dx) _moonstrip(dx::AbstractArray{Mooncake.NoTangent}) = nothing _moonstrip(dx::AbstractArray{<:Number}) = dx -_moonstrip(dx::AbstractArray{<:Integer}) = nothing _moonstrip(dx::Number) = nothing function _moonstrip(dx) - @warn "not sure what to do with this type" typeof(dx) + @error "not sure what to do with this type, in a gradient from Mooncake" typeof(dx) dx end -# Optimisers etc. +### Optimisers etc. Flux.setup(rule::Optimisers.AbstractRule, m::Moonduo) = Flux.setup(rule, m.val) @@ -170,4 +169,17 @@ function Flux.train!(loss, model::Moonduo, data, opt; cb=nothing, epochs::Int=1) end end +### Model state & loading + +Flux.state(x::Moonduo) = Flux.state(x.val) + +function Flux.loadmodel!(dst::Moonduo, src::Moonduo; kw...) + Flux.loadmodel!(dst.val, src.val; kw...) + dst +end +function Flux.loadmodel!(dst::Moonduo, src; kw...) + Flux.loadmodel!(dst.val, src; kw...) + dst +end + end # module From 3041e7a528cde3f03f1060b0070a10b890e78dfb Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Mon, 25 Nov 2024 21:34:03 -0500 Subject: [PATCH 6/7] tweak --- src/mooncake.jl | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/mooncake.jl b/src/mooncake.jl index 8b4f736..e8826f0 100644 --- a/src/mooncake.jl +++ b/src/mooncake.jl @@ -7,10 +7,6 @@ when you call `Moonduo(x)`. This serves the same purpose as Enzyme.jl's `Duplicated` type. Both of these AD engines prefer that space for the gradient be pre-allocated. - -Maybe this is like `Mooncake.CoDual`, except that is marked private, and seems discouraged: -https://github.com/compintell/Mooncake.jl/issues/275 -An advantage of Flux owning this type is that we can provide pretty printing without piracy. """ struct Moonduo{X,DX} val::X @@ -29,7 +25,12 @@ Optimisers.trainable(m::Moonduo) = (; m.val) Flux.@layer :expand Moonduo -(m::Moonduo)(x...) = m.val(x...) +function (m::Moonduo)(x...) + Zygote.isderiving() && error("""`Moonduo(flux_model)` is only for use with Mooncake.jl. + Calling `Zygote.gradient` directly on such a wrapped model is not supported. + You may have accidentally called `Flux.gradient(loss, Moonduo(model), x)` without wrapping `x`.""") + m.val(x...) +end function _moonstrip end From 2a536d1a080d2623ecefe718a199704266a79c1f Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Mon, 25 Nov 2024 22:57:20 -0500 Subject: [PATCH 7/7] Update mooncake.jl --- src/mooncake.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mooncake.jl b/src/mooncake.jl index e8826f0..215ce74 100644 --- a/src/mooncake.jl +++ b/src/mooncake.jl @@ -26,7 +26,7 @@ Optimisers.trainable(m::Moonduo) = (; m.val) Flux.@layer :expand Moonduo function (m::Moonduo)(x...) - Zygote.isderiving() && error("""`Moonduo(flux_model)` is only for use with Mooncake.jl. + Flux.Zygote.isderiving() && error("""`Moonduo(flux_model)` is only for use with Mooncake.jl. Calling `Zygote.gradient` directly on such a wrapped model is not supported. You may have accidentally called `Flux.gradient(loss, Moonduo(model), x)` without wrapping `x`.""") m.val(x...)
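
Taken together, these patches add a `Moonduo` wrapper plus `Flux.gradient`/`Flux.withgradient`, `Flux.setup`/`Flux.update!`, and `Flux.train!` methods that route differentiation through Mooncake.jl instead of Zygote.jl. Below is a minimal end-to-end sketch of that workflow, based only on the docstrings and methods introduced above; the toy layer size, `Adam(0.01)` rule, and random data are invented for illustration, and Mooncake.jl must be loaded for the package extension to activate.

```julia
# Sketch only: the model, optimiser settings, and data here are placeholders.
using Flux, Fluxperimental, Mooncake

model = Moonduo(Dense(3 => 2))             # wraps the model, allocates Mooncake tangent storage
opt_state = Flux.setup(Adam(0.01), model)  # forwarded to Flux.setup on the wrapped model

x = randn(Float32, 3, 8)                   # fake batch of 8 inputs
y = randn(Float32, 2, 8)                   # fake targets

for epoch in 1:10
    # There is no Const marker yet, so the data is closed over rather than wrapped:
    val, _ = Flux.withgradient(m -> Flux.mse(m(x), y), model)
    Flux.update!(opt_state, model)         # reads the gradient stored in model.dval
    println("epoch $epoch: loss = $val")
end
```

Because the gradient is written in place into `model.dval` on each call (zeroed first unless `zero=false`), `Flux.update!` only needs the optimiser state and the wrapped model, mirroring the Enzyme `Duplicated` workflow that this PR deliberately follows.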