-
-
Notifications
You must be signed in to change notification settings - Fork 608
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
127 additions
and
381 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,7 +4,7 @@ using LinearAlgebra | |
using Optimisers: Optimisers | ||
using Functors: fmap | ||
|
||
export train!, update!, adjust!, FluxState, @epochs, | ||
export train!, update!, adjust!, FluxState, | ||
Descent, Adam, Momentum, Nesterov, RMSProp, | ||
AdaGrad, AdaMax, AdaDelta, AMSGrad, NAdam, AdamW, RAdam, OAdam, AdaBelief #, | ||
# InvDecay, ExpDecay, WeightDecay, stop, skip, Optimiser, | ||
|
@@ -15,7 +15,7 @@ export train!, update!, adjust!, FluxState, @epochs, | |
|
||
""" | ||
FluxState(rule, state=missing) | ||
This is an interface between the all-mutable world Flux.jl likes, | ||
and the could-be-immutable world that Optimisers.jl inhabits. | ||
|
@@ -56,34 +56,14 @@ end | |
|
||
### Two styles of gradient, and their `train!` functions | ||
|
||
using ProgressLogging: @progress, @withprogress, @logprogress | ||
using ProgressLogging: @progress, @withprogress, @logprogress # TODO add progress logging again | ||
using Zygote: Zygote, Params | ||
|
||
include("explicit_train.jl.jl") # new! | ||
include("implicit_train.jl.jl") # Params etc, Zygote only | ||
include("explicit_train.jl") # new! | ||
include("implicit_train.jl") # Params etc, Zygote only | ||
|
||
explicit_withgradient(f, args...) = Zygote.withgradient(f, args...) # can overload this to use e.g. Yota / Diffractor | ||
|
||
# using Requires # Flux doesn't use this right now | ||
# @init @require Diffractor="9f5e2b26-1114-432f-b630-d3fe2085c51c" begin | ||
# @eval function explicit_withgradient(f, args...) | ||
# y, back = Diffractor.∂⃖¹(f, args...) | ||
# _, grads... = back(Zygote.sensitivity(y)) | ||
# return (; value = y, gradient = grads) | ||
# end | ||
# end | ||
|
||
#= | ||
using Diffractor | ||
function Flux.Train.explicit_withgradient(f, args...) | ||
y, back = Diffractor.∂⃖¹(f, args...) | ||
_, grads... = back(one(y)) | ||
return (; value = y, gradient = grads) | ||
end | ||
=# | ||
|
||
### Misc. related utilities | ||
|
||
""" | ||
|
@@ -107,94 +87,4 @@ function adjust!(opt::FluxState, eta::Real) | |
return opt | ||
end | ||
|
||
""" | ||
@epochs N body | ||
Run `body` expression `N` times. Mainly useful for quickly doing | ||
multiple epochs of training in a REPL. | ||
Functionally equivalent to this loop: | ||
``` | ||
for _ in 1:N | ||
body | ||
end | ||
``` | ||
... but adds progress logging and `@info` messages, | ||
and returns the result of the last iteration. | ||
# Examples | ||
```jldoctest | ||
julia> Flux.@epochs 2 println("hello") | ||
[ Info: Epoch 1 | ||
hello | ||
[ Info: Epoch 2 | ||
hello | ||
``` | ||
""" | ||
macro epochs(n, ex) | ||
@gensym val | ||
body = :(for i in 1:$(esc(n)) | ||
@info "Epoch $i" | ||
$(esc(val)) = $(esc(ex)) | ||
end) | ||
loop = Expr(:macrocall, Symbol("@progress"), __source__, body) | ||
Expr(:block, :($(esc(val)) = nothing), loop, :($(esc(val)))) | ||
# TODO make this actualy return the value? Names aren't right. | ||
# | ||
# $loop | ||
# # @progress for i in 1:$(esc(n)) | ||
# # @info "Epoch $i" | ||
# # $(esc(val)) = $(esc(ex)) | ||
# # end | ||
# $val # DOESN"T WORK! Expr(:macrocall, ...) ? | ||
# end | ||
end | ||
|
||
end | ||
|
||
|
||
#= | ||
using Flux, Random | ||
data = [(rand(3,2).*[i,1,20/i], [i i]) for i in 1:50] |> shuffle!; | ||
# This exact code works on [email protected]. There, train! returns nothing: | ||
model2 = Chain(Dense(3 => 7, relu), Dense(7 => 1)) | ||
opt2 = Flux.Adam() | ||
Flux.train!(Flux.params(model2), data, opt2) do x, y | ||
Flux.mse(model2(x), y) | ||
end | ||
opt2 # contains an IdDict | ||
# This is the new "explicit" method of Train | ||
model1 = Chain(Dense(3 => 7, relu), Dense(7 => 1)) | ||
opt1 = Flux.Adam() | ||
Flux.train!(model1, data, opt1) do m, x, y | ||
Flux.mse(m(x), y) | ||
end |> sum | ||
opt1 # contains state tree | ||
# This is new 3-arg train!, one step not an iteration over data: | ||
x1, y1 = data[1] | ||
Flux.train!(model1, opt1) do m | ||
Flux.mse(m(x1), y1) | ||
end | ||
julia> using ProgressLogging | ||
julia> @macroexpand1 @loop N body | ||
begin | ||
x = nothing | ||
@progress for i in 1:N | ||
@info "step $i" | ||
x = body | ||
end | ||
x | ||
end | ||
=# | ||
end # module |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.