
Commit

fixup
mcabbott committed Nov 23, 2024
1 parent fc03e5a commit 4fd4443
Showing 4 changed files with 14 additions and 7 deletions.
2 changes: 1 addition & 1 deletion Project.toml
@@ -30,4 +30,4 @@ julia = "1.10"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test"]
test = ["Test", "Mooncake"]
1 change: 1 addition & 0 deletions README.md
@@ -42,3 +42,4 @@ There are no formal documentation pages, but these links to the source will show
[`@compact(kw...) do ...`](https://github.com/FluxML/Fluxperimental.jl/blob/master/src/compact.jl), and
[`@autostruct function Mine(d) ...`](https://github.com/FluxML/Fluxperimental.jl/blob/master/src/autostruct.jl).
* Experimental [`apply(c::Chain, x)`](https://github.com/FluxML/Fluxperimental.jl/blob/master/src/chain.jl) interface
* Easy way to [use Mooncake.jl](https://github.com/FluxML/Fluxperimental.jl/blob/master/ext/FluxMooncakeExt.jl) instead of Zygote.jl.
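A minimal usage sketch of the Moonduo route mentioned in that new README line, pieced together from the calls exercised in test/mooncake.jl further down this page; the Dense layer and the input vector are illustrative, not part of the commit, and `Moonduo(x)` with a single argument is assumed to fill in a zero tangent once Mooncake is loaded:

using Flux, Fluxperimental, Mooncake

model = Moonduo(Dense(3 => 2))                      # wrap the model; Flux.gradient now runs via Mooncake, not Zygote
g = Flux.gradient(m -> sum(m.bias), model) |> only  # same calling convention as the Zygote path
g.bias                                              # plain arrays in a nested NamedTuple, here [1, 1]

res = Flux.withgradient((m, x) -> sum(m(x)), model, Moonduo([1, 2, 3f0]))
res.val, res.grad                                   # the (; val, grad) pair built in ext/FluxMooncakeExt.jl below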
16 changes: 11 additions & 5 deletions ext/FluxMooncakeExt.jl
@@ -98,34 +98,40 @@ function _moon_withgradient(f, args::Moonduo...; zero)
coduals = map(x -> Mooncake.CoDual(x.val, x.dval), args)
val, _ = Mooncake.__value_and_gradient!!(rule, Mooncake.zero_codual(f), coduals...)

grad = map(x -> _moonstrip(x.dval), args)
grad = map(x -> _moongrad(x.dval), args)
(; val, grad)
end

# _check_mutable(x::Const) = nothing
_check_mutable(x::Moonduo) = Functors.anymutable(x) || error(
_check_mutable(x::Moonduo) = Flux.Functors.anymutable(x) || error(
"""`Flux.gradient(f, Moonduo(x), ...)` expects `x` to contain mutable parameter arrays."""
)

function _moongrad(dx)
dx2 = _moonstrip(dx) # remove all the weird types
isnothing(dx2) && return
return Flux.fmapstructure(identity, dx2; prune=nothing)
end

_moonstrip(dx::Mooncake.Tangent) = map(_moonstrip, dx.fields)
_moonstrip(dx::Mooncake.MutableTangent) = map(_moonstrip, dx.fields)
_moonstrip(dx::Mooncake.NoTangent) = nothing
_moonstrip(dx::Union{Tuple, NamedTuple, AbstractArray}) = map(_moonstrip, dx)
_moonstrip(dx::AbstractArray{Mooncake.NoTangent}) = nothing
_moonstrip(dx::AbstractArray{<:Number}) = dx
_moonstrip(dx::AbstractArray{<:Integer}) = nothing
_moonstrip!(dx::Number) = nothing
_moonstrip(dx::Number) = nothing
function _moonstrip(dx)
@warn "not sure what to do with this type" typeof(dx)
dx
end

# Optimisers etc.

Flux.setup(m::Moonduo) = Flux.setup(m.val)
Flux.setup(rule::Flux.AbstractRule, m::Moonduo) = Flux.setup(rule, m.val)

function Flux.update!(opt_state, model::Moonduo)
Flux.update!(opt_state, model.val, _moonstrip(model.dval))
Flux.update!(opt_state, model.val, _moongrad(model.dval))
nothing
end

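Putting the extension methods above together, a hedged sketch of a training step: it assumes, as the test file below also checks, that Flux.gradient writes the gradient into model.dval, which Flux.update! then reads via _moongrad; the layer, optimiser, loss and data are made up for illustration:

using Flux, Fluxperimental, Mooncake

model = Moonduo(Dense(2 => 1))                # illustrative model
opt_state = Flux.setup(Adam(0.01), model)     # forwards to Flux.setup(rule, model.val) as defined above

x = rand(Float32, 2, 8)                       # made-up data batch
y = rand(Float32, 1, 8)

for step in 1:100
    Flux.gradient(m -> Flux.mse(m(x), y), model)  # stores the gradient in model.dval
    Flux.update!(opt_state, model)                # applies _moongrad(model.dval) to model.val
end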
2 changes: 1 addition & 1 deletion test/mooncake.jl
@@ -8,7 +8,7 @@ using Flux, Fluxperimental, Mooncake
g1 = Flux.gradient(m -> sum(m.bias), m1) |> only
@test iszero(g1.weight)
@test g1.bias == [1, 1]
@test m1.dval.bias == [1, 1]
@test m1.dval.fields.bias == [1, 1]

g2 = Flux.withgradient((m,x) -> sum(m(x)), m1, Moonduo([1,2,3f0])) # would prefer Const
@test g2.val ≈ sum(m1([1,2,3f0]))
