add train_autodiff macro
mcabbott committed Aug 15, 2022
1 parent 7aa74e0 commit faf572f
Showing 4 changed files with 100 additions and 4 deletions.
4 changes: 3 additions & 1 deletion Project.toml
@@ -39,6 +39,7 @@ ProgressLogging = "0.1"
Reexport = "0.2, 1.0"
SpecialFunctions = "1.8.2, 2.1.2"
StatsBase = "0.33"
Yota = "0.7.4"
Zygote = "0.6.34"
julia = "1.6"

@@ -49,6 +50,7 @@ FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Yota = "cd998857-8626-517d-b929-70ad188a48f0"

[targets]
test = ["Test", "Documenter", "IterTools", "LinearAlgebra", "FillArrays", "ComponentArrays"]
test = ["Test", "Documenter", "IterTools", "LinearAlgebra", "FillArrays", "ComponentArrays", "Yota"]
48 changes: 47 additions & 1 deletion src/train/Train.jl
@@ -4,7 +4,7 @@ using LinearAlgebra
using Optimisers: Optimisers
using Functors: fmap

export train!, update!, adjust!, FluxState,
export train!, update!, adjust!, FluxState, @train_autodiff,
  Descent, Adam, Momentum, Nesterov, RMSProp,
  AdaGrad, AdaMax, AdaDelta, AMSGrad, NAdam, AdamW, RAdam, OAdam, AdaBelief #,
  # InvDecay, ExpDecay, WeightDecay, stop, skip, Optimiser,
@@ -108,6 +108,52 @@ include("implicit_train.jl") # Params etc, Zygote only

explicit_withgradient(f, args...) = Zygote.withgradient(f, args...) # can overload this to use e.g. Yota / Diffractor

"""
Flux.@train_autodiff Zygote
Flux.@train_autodiff Yota
Flux.@train_autodiff Diffractor
This macro allows the use of `train!` with various automatic differentiation packages,
instead of the default Zygote.jl.
You should load the package, then call this macro.
Only affects "explicit-mode" versions `train!(loss, model, data, opt)` or `train!(loss, model, opt)`,
since the (deprecated) "implicit-mode" `train!(loss, ps::Params, data, opt)` is Zygote-specific.
Only works with [Yota.jl](https://github.com/dfdx/Yota.jl) and [Diffractor.jl](https://github.com/JuliaDiff/Diffractor.jl),
and with the default [Zygote.jl](https://github.com/FluxML/Zygote.jl).
!!! note
This is experimental!
"""
macro train_autodiff(pkg)
  if pkg == :Diffractor
    return quote
      Diffractor.gradient(sin, 0.0)[1] ≈ 1.0  # ensures an error if not loaded
      function Flux.Train.explicit_withgradient(f, args...)
        y, back = Diffractor.∂⃖¹(f, args...)
        dy1 = Flux.Zygote.sensitivity(y)  # Zygote is loaded, and this gives nice errors
        return (; value = y, gradient = Base.tail(back(dy1)))
      end
    end |> esc
  elseif pkg == :Yota
    return quote
      Yota.grad(sin, 0.0)  # [2][1] ≈ 1.0; also ensures an error if Yota is not loaded
      function Flux.Train.explicit_withgradient(f, args...)
        value, (_, gradient...) = Yota.grad(f, args...)
        return (; value, gradient)
      end
    end |> esc
  elseif pkg == :Zygote
    return quote
      Flux.Train.explicit_withgradient(f, args...) = Flux.Zygote.withgradient(f, args...)
    end |> esc
  else
    throw("@train_autodiff expects either Zygote, Yota, or Diffractor. No other arguments are understood.")
  end
end


### Misc. related utilities

"""
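For context, and not part of this commit's diff: the intended workflow is to load one of the supported AD packages and then invoke the macro, which overwrites `Flux.Train.explicit_withgradient`. A minimal sketch under that assumption follows; the function `f`, its arguments, and the numbers in the comments are illustrative only.

```julia
using Flux, Yota

Flux.@train_autodiff Yota   # train! will now call Yota.grad instead of Zygote.withgradient

# The overloaded helper returns a named tuple, with one gradient per argument of `f`:
f(w, x) = sum(abs2, w .* x)
nt = Flux.Train.explicit_withgradient(f, [1.0, 2.0], [3.0, 4.0])
nt.value        # == f([1.0, 2.0], [3.0, 4.0]) == 73.0
nt.gradient[1]  # ≈ [18.0, 64.0], the gradient with respect to the first argument

Flux.@train_autodiff Zygote   # restore the default backend
```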
6 changes: 6 additions & 0 deletions src/train/explicit_train.jl
@@ -25,6 +25,9 @@ end
which evaluates the gradient of `loss(model, x1, y1)` with respect to `model`,
to know how to update the parameters stored within `model`.
To change the package used to calculate gradients, enter `using Yota; Flux.@train_autodiff Yota`
to use [Yota.jl](https://github.com/dfdx/Yota.jl). The same command works with [Diffractor.jl](https://github.com/JuliaDiff/Diffractor.jl).
It is often convenient to provide the function `loss` using `do` block syntax,
instead of defining a named function:
```
@@ -100,6 +103,9 @@ This calls `Zygote.withgradient(m -> Flux.crossentropy(m(x1), y1), model)`.
Then it updates the parameters contained within `model` according
to the chosen `opt`imiser.
Finally it returns the value of the loss function.
To change the package used to calculate gradients, enter `using Yota; Flux.@train_autodiff Yota`
to use [Yota.jl](https://github.com/dfdx/Yota.jl). The same command works with [Diffractor.jl](https://github.com/JuliaDiff/Diffractor.jl).
"""
function train!(loss::Function, model, opt::FluxState)
  _initialise!(opt, model)
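To make the docstring additions above concrete, here is a hedged sketch of the two explicit-mode `train!` forms they describe, run with the default Zygote backend; the toy model, data, and learning rate are illustrative and not taken from the commit.

```julia
using Flux
using Flux: Descent   # the FluxState optimisers from this commit's Train module

model = (weight = randn(2, 2), bias = zeros(2))
opt = Descent(0.1)

# 4-arg form: `loss(model, d...)` is evaluated for each `d` in `data`,
# written here with the `do`-block syntax the docstring mentions.
data = ((randn(2), randn(2)) for _ in 1:100)
train!(model, data, opt) do m, x, y
  Flux.Losses.mse(m.weight * x .+ m.bias, y)
end

# 3-arg form: no data argument; the loss closes over what it needs,
# and each call performs a single update step and returns the loss.
x1, y1 = randn(2), randn(2)
train!(m -> Flux.Losses.mse(m.weight * x1 .+ m.bias, y1), model, opt)
```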
46 changes: 44 additions & 2 deletions test/train.jl
@@ -1,7 +1,7 @@
using Flux.Train
using Zygote: Params, gradient

import Optimisers, FillArrays, ComponentArrays
import Optimisers, FillArrays, ComponentArrays, Yota

using Test
using Random
@@ -22,7 +22,7 @@ using Random
  end
end

@testset "Explicit train!" begin
@testset "Explicit train! with Zygote" begin
  Random.seed!(84)
  w = randn(10, 10)
  w2 = randn(10, 10)  # NB outside the inner @testset, else it will be exactly == w, as the RNG seed is reset.
@@ -70,6 +70,48 @@ end
  end
end

using Yota
using Flux: Descent, Adam, AdamW, FluxState
Flux.@train_autodiff Yota

@testset "Explicit train! with Yota" begin
Random.seed!(84)
w = randn(10, 10)
w2 = randn(10, 10) # NB outside the inner @testset, else it will be exactly == w, as the RNG seed is reset.
@testset for opt in [Descent(0.1), Adam(), AdamW()]
@test opt isa FluxState
@test opt.state isa Missing

loss(m, x) = Flux.Losses.mse(w*x, m.weight*x .+ m.bias)
model = (weight=copy(w2), bias=zeros(10), ignore=nothing)
@test loss(model, rand(10, 10)) > 1

train!(loss, model, ((rand(10),) for _ in 1: 10^5), opt)
@test loss(model, rand(10, 10)) < 0.01
@test opt.state isa NamedTuple
end

# Test 3-arg `train!` method:
@testset for opt in [Descent(0.1), Adam(), AdamW()]
@test opt isa FluxState
@test opt.state isa Missing

loss(m) = let x = rand(10)
Flux.Losses.mse(w*x, m.weight*x .+ m.bias)
end
model = (weight=copy(w2), bias=zeros(10), ignore=nothing)
@test loss(model) > 1

for i in 1:10^5
train!(loss, model, opt)
end
@test loss(model) < 0.01
@test opt.state isa NamedTuple
end
end

Flux.@train_autodiff Zygote

#=
@testset "update!: handle Fills from Zygote" begin
