diff --git a/Project.toml b/Project.toml
index 21d58df..39aeb3e 100644
--- a/Project.toml
+++ b/Project.toml
@@ -30,4 +30,4 @@ julia = "1.10"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 
 [targets]
-test = ["Test"]
+test = ["Test", "Mooncake"]
diff --git a/README.md b/README.md
index 6323e3d..cab59ca 100644
--- a/README.md
+++ b/README.md
@@ -42,3 +42,4 @@ There are no formal documentation pages, but these links to the source will show
   [`@compact(kw...) do ...`](https://github.com/FluxML/Fluxperimental.jl/blob/master/src/compact.jl), and
   [`@autostruct function Mine(d) ...`](https://github.com/FluxML/Fluxperimental.jl/blob/master/src/autostruct.jl).
 * Experimental [`apply(c::Chain, x)`](https://github.com/FluxML/Fluxperimental.jl/blob/master/src/chain.jl) interface
+* Easy way to [use Mooncake.jl](https://github.com/FluxML/Fluxperimental.jl/blob/master/ext/FluxMooncakeExt.jl) instead of Zygote.jl.
diff --git a/ext/FluxMooncakeExt.jl b/ext/FluxMooncakeExt.jl
index bc31dcc..7e55885 100644
--- a/ext/FluxMooncakeExt.jl
+++ b/ext/FluxMooncakeExt.jl
@@ -98,15 +98,21 @@ function _moon_withgradient(f, args::Moonduo...; zero)
     coduals = map(x -> Mooncake.CoDual(x.val, x.dval), args)
     val, _ = Mooncake.__value_and_gradient!!(rule, Mooncake.zero_codual(f), coduals...)
 
-    grad = map(x -> _moonstrip(x.dval), args)
+    grad = map(x -> _moongrad(x.dval), args)
     (; val, grad)
 end
 
 # _check_mutable(x::Const) = nothing
-_check_mutable(x::Moonduo) = Functors.anymutable(x) || error(
+_check_mutable(x::Moonduo) = Flux.Functors.anymutable(x) || error(
     """`Flux.gradient(f, Moonduo(x), ...)` expects `x` to contain mutable parameter arrays."""
 )
 
+function _moongrad(dx)
+    dx2 = _moonstrip(dx)  # remove all the weird types
+    isnothing(dx2) && return
+    return Flux.fmapstructure(identity, dx2; prune=nothing)
+end
+
 _moonstrip(dx::Mooncake.Tangent) = map(_moonstrip, dx.fields)
 _moonstrip(dx::Mooncake.MutableTangent) = map(_moonstrip, dx.fields)
 _moonstrip(dx::Mooncake.NoTangent) = nothing
@@ -114,7 +120,7 @@ _moonstrip(dx::Union{Tuple, NamedTuple, AbstractArray}) = map(_moonstrip, dx)
 _moonstrip(dx::AbstractArray{Mooncake.NoTangent}) = nothing
 _moonstrip(dx::AbstractArray{<:Number}) = dx
 _moonstrip(dx::AbstractArray{<:Integer}) = nothing
-_moonstrip!(dx::Number) = nothing
+_moonstrip(dx::Number) = nothing
 function _moonstrip(dx)
     @warn "not sure what to do with this type" typeof(dx)
     dx
@@ -122,10 +128,10 @@ end
 
 # Optimisers etc.
 
-Flux.setup(m::Moonduo) = Flux.setup(m.val)
+Flux.setup(rule::Flux.AbstractRule, m::Moonduo) = Flux.setup(rule, m.val)
 
 function Flux.update!(opt_state, model::Moonduo)
-    Flux.update!(opt_state, model.val, _moonstrip(model.dval))
+    Flux.update!(opt_state, model.val, _moongrad(model.dval))
     nothing
 end
 
diff --git a/test/mooncake.jl b/test/mooncake.jl
index 3d5275f..a1287cc 100644
--- a/test/mooncake.jl
+++ b/test/mooncake.jl
@@ -8,7 +8,7 @@ using Flux, Fluxperimental, Mooncake
     g1 = Flux.gradient(m -> sum(m.bias), m1) |> only
     @test iszero(g1.weight)
     @test g1.bias == [1, 1]
-    @test m1.dval.bias == [1, 1]
+    @test m1.dval.fields.bias == [1, 1]
 
     g2 = Flux.withgradient((m,x) -> sum(m(x)), m1, Moonduo([1,2,3f0]))  # would prefer Const
    @test g2.val ≈ sum(m1([1,2,3f0]))
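
Usage sketch (commentary, not part of the patch): combining the new `Flux.setup(rule, m::Moonduo)` method, the in-place `Flux.update!(opt_state, model::Moonduo)`, and the gradient calls exercised in test/mooncake.jl, a Mooncake-backed training step would plausibly look like this. The model, optimiser, and data below are illustrative, not taken from the diff; it assumes, as the tests confirm, that the gradient is accumulated into the wrapper's `dval` field.

    using Flux, Fluxperimental, Mooncake

    model = Moonduo(Dense(2 => 1))   # wrapper pairing the model with a Mooncake tangent in .dval
    data  = Moonduo([1f0, 2f0])      # non-model arguments are wrapped too

    # Returns (; val, grad) as built by _moon_withgradient above; grad is a tuple,
    # one entry per Moonduo argument, stripped to plain nested structures by _moongrad.
    out = Flux.withgradient((m, x) -> sum(m(x)), model, data)

    state = Flux.setup(Adam(0.01), model)  # new method: forwards to Flux.setup(rule, model.val)
    Flux.update!(state, model)             # reads the gradient accumulated in model.dval

Note the design choice visible in the diff: `Flux.update!` takes no explicit gradient argument, because the `Moonduo` already carries its tangent, which `_moongrad` converts from Mooncake's `Tangent`/`MutableTangent` types into the plain nested `NamedTuple`s that Optimisers.jl expects.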