diff --git a/Project.toml b/Project.toml
index 07a7098b01..86ad76dd52 100644
--- a/Project.toml
+++ b/Project.toml
@@ -39,6 +39,7 @@ ProgressLogging = "0.1"
 Reexport = "0.2, 1.0"
 SpecialFunctions = "1.8.2, 2.1.2"
 StatsBase = "0.33"
+Yota = "0.7.4"
 Zygote = "0.6.34"
 julia = "1.6"
 
@@ -49,6 +50,7 @@ FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
 IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
+Yota = "cd998857-8626-517d-b929-70ad188a48f0"
 
 [targets]
-test = ["Test", "Documenter", "IterTools", "LinearAlgebra", "FillArrays", "ComponentArrays"]
+test = ["Test", "Documenter", "IterTools", "LinearAlgebra", "FillArrays", "ComponentArrays", "Yota"]
diff --git a/src/train/Train.jl b/src/train/Train.jl
index ac5724dbc6..44a210df8b 100644
--- a/src/train/Train.jl
+++ b/src/train/Train.jl
@@ -4,7 +4,7 @@ using LinearAlgebra
 using Optimisers: Optimisers
 using Functors: fmap
 
-export train!, update!, adjust!, FluxState,
+export train!, update!, adjust!, FluxState, @train_autodiff,
   Descent, Adam, Momentum, Nesterov, RMSProp,
   AdaGrad, AdaMax, AdaDelta, AMSGrad, NAdam, AdamW, RAdam, OAdam, AdaBelief
   #, # InvDecay, ExpDecay, WeightDecay, stop, skip, Optimiser,
@@ -108,6 +108,52 @@ include("implicit_train.jl")  # Params etc, Zygote only
 
 explicit_withgradient(f, args...) = Zygote.withgradient(f, args...)  # can overload this to use e.g. Yota / Diffractor
 
+"""
+    Flux.@train_autodiff Zygote
+    Flux.@train_autodiff Yota
+    Flux.@train_autodiff Diffractor
+
+This macro selects which automatic differentiation package `train!` uses,
+instead of the default Zygote.jl.
+Load the chosen package first, then call this macro.
+
+Only affects "explicit-mode" versions `train!(loss, model, data, opt)` or `train!(loss, model, opt)`,
+since the (deprecated) "implicit-mode" `train!(loss, ps::Params, data, opt)` is Zygote-specific.
+
+Only works with [Yota.jl](https://github.com/dfdx/Yota.jl) and [Diffractor.jl](https://github.com/JuliaDiff/Diffractor.jl),
+and with the default [Zygote.jl](https://github.com/FluxML/Zygote.jl).
+
+!!! note
+    This is experimental!
+"""
+macro train_autodiff(pkg)
+    if pkg == :Diffractor
+        return quote
+            Diffractor.gradient(sin, 0.0)[1] ≈ 1.0  # ensures an error if Diffractor is not loaded
+            function Flux.Train.explicit_withgradient(f, args...)
+                y, back = Diffractor.∂⃖¹(f, args...)
+                dy1 = Flux.Zygote.sensitivity(y)  # Zygote is loaded, and this gives nice errors
+                return (; value = y, gradient = Base.tail(back(dy1)))
+            end
+        end |> esc
+    elseif pkg == :Yota
+        return quote
+            Yota.grad(sin, 0.0)  # ensures an error if Yota is not loaded
+            function Flux.Train.explicit_withgradient(f, args...)
+                value, (_, gradient...) = Yota.grad(f, args...)  # discard the gradient w.r.t. f itself
+                return (; value, gradient)
+            end
+        end |> esc
+    elseif pkg == :Zygote
+        return quote
+            Flux.Train.explicit_withgradient(f, args...) = Flux.Zygote.withgradient(f, args...)
+        end |> esc
+    else
+        throw(ArgumentError("@train_autodiff expects either Zygote, Yota, or Diffractor. No other arguments are understood."))
+    end
+end
+
+
 ### Misc. related utilities
 
 """
diff --git a/src/train/explicit_train.jl b/src/train/explicit_train.jl
index 78d5eb46fc..cb1e2aea5d 100644
--- a/src/train/explicit_train.jl
+++ b/src/train/explicit_train.jl
@@ -25,6 +25,9 @@ end
 which evaluates the gradient of `loss(model, x1, y1)` with respect to `model`,
 to know how to update the parameters stored within `model`.
 
+To change the package used to calculate gradients, enter `using Yota; Flux.@train_autodiff Yota`
+to switch to [Yota.jl](https://github.com/dfdx/Yota.jl). The same command works with [Diffractor.jl](https://github.com/JuliaDiff/Diffractor.jl).
+
 It is often convenient to provide the function `loss` using `do` block syntax,
 instead of defining a named function:
 ```
@@ -100,6 +103,9 @@ This calls `Zygote.withgradient(m -> Flux.crossentropy(m(x1), y1), model)`.
 
 Then it updates the parameters contained within `model` according to the chosen `opt`imiser.
 Finally it returns the value of the loss function.
+
+To change the package used to calculate gradients, enter `using Yota; Flux.@train_autodiff Yota`
+to switch to [Yota.jl](https://github.com/dfdx/Yota.jl). The same command works with [Diffractor.jl](https://github.com/JuliaDiff/Diffractor.jl).
 """
 function train!(loss::Function, model, opt::FluxState)
     _initialise!(opt, model)
diff --git a/test/train.jl b/test/train.jl
index f6f54860dd..8e15cd3e15 100644
--- a/test/train.jl
+++ b/test/train.jl
@@ -1,7 +1,7 @@
 using Flux.Train
 using Zygote: Params, gradient
 
-import Optimisers, FillArrays, ComponentArrays
+import Optimisers, FillArrays, ComponentArrays, Yota
 
 using Test
 using Random
@@ -22,7 +22,7 @@ using Random
     end
 end
 
-@testset "Explicit train!" begin
+@testset "Explicit train! with Zygote" begin
     Random.seed!(84)
     w = randn(10, 10)
     w2 = randn(10, 10)  # NB outside the inner @testset, else it will be exactly == w, as the RNG seed is reset.
@@ -70,6 +70,48 @@ end
     end
 end
 
+using Yota
+using Flux: Descent, Adam, AdamW, FluxState
+Flux.@train_autodiff Yota
+
+@testset "Explicit train! with Yota" begin
+    Random.seed!(84)
+    w = randn(10, 10)
+    w2 = randn(10, 10)  # NB outside the inner @testset, else it will be exactly == w, as the RNG seed is reset.
+    @testset for opt in [Descent(0.1), Adam(), AdamW()]
+        @test opt isa FluxState
+        @test opt.state isa Missing
+
+        loss(m, x) = Flux.Losses.mse(w*x, m.weight*x .+ m.bias)
+        model = (weight=copy(w2), bias=zeros(10), ignore=nothing)
+        @test loss(model, rand(10, 10)) > 1
+
+        train!(loss, model, ((rand(10),) for _ in 1:10^5), opt)
+        @test loss(model, rand(10, 10)) < 0.01
+        @test opt.state isa NamedTuple
+    end
+
+    # Test the 3-arg `train!` method:
+    @testset for opt in [Descent(0.1), Adam(), AdamW()]
+        @test opt isa FluxState
+        @test opt.state isa Missing
+
+        loss(m) = let x = rand(10)
+            Flux.Losses.mse(w*x, m.weight*x .+ m.bias)
+        end
+        model = (weight=copy(w2), bias=zeros(10), ignore=nothing)
+        @test loss(model) > 1
+
+        for i in 1:10^5
+            train!(loss, model, opt)
+        end
+        @test loss(model) < 0.01
+        @test opt.state isa NamedTuple
+    end
+end
+
+Flux.@train_autodiff Zygote
+
 #=
 @testset "update!: handle Fills from Zygote" begin
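For context, here is a minimal end-to-end sketch of how the new macro is intended to be used on this branch. The toy model, loss, and data below are illustrative, not part of the patch:

```julia
using Flux, Yota           # the AD package must be loaded before calling the macro

Flux.@train_autodiff Yota  # route train!'s gradients through Yota instead of Zygote

# Any NamedTuple (or struct) of arrays works in "explicit" mode:
model = (weight = randn(10, 10), bias = zeros(10))
loss(m, x) = sum(abs2, m.weight * x .+ m.bias)

opt = Flux.Adam()          # a FluxState optimiser, as in the tests above
train!(loss, model, ((rand(10),) for _ in 1:1000), opt)

Flux.@train_autodiff Zygote  # restore the default, as the tests do
```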
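The extension point each branch of the macro overloads is `Flux.Train.explicit_withgradient`, which returns the loss value together with a tuple of gradients, one per argument of `f` (Yota also returns a gradient with respect to `f` itself, which the macro discards). A quick, hypothetical check that the swap took effect, with expected results shown as comments:

```julia
using Flux, Yota
Flux.@train_autodiff Yota

nt = Flux.Train.explicit_withgradient(x -> sum(abs2, x), [1.0, 2.0])
nt.value     # 5.0, the loss value
nt.gradient  # ([2.0, 4.0],), one entry per argument of the function
```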