tests, II
mcabbott committed Aug 1, 2022
1 parent 01c7af1 commit cb927a8
Showing 2 changed files with 41 additions and 8 deletions.
src/train/explicit_train.jl: 11 changes (7 additions, 4 deletions)
@@ -54,6 +54,7 @@ function train!(loss::Function, model, data, opt::FluxState)
   s = opt.state
   s isa IdDict && error("""Can't mix explicit & implicit modes!
     Once `FluxState` is initialised by `train!` in one mode, it cannot be used in the other.""")
+  # TODO check whether this loop ought to be in another function, for performance / type-stability.
   for d in data
     l, (g, _...) = explicit_withgradient(loss, model, data_splat(d)...)
     s, model = Optimisers.update!(s, model, g)
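
The TODO added here refers to Julia's usual function-barrier fix for type instability: pull the hot loop into its own function so it specialises on the concrete types it receives. A minimal standalone sketch of the idea (illustrative names, not part of this commit):

# `setup_and_run` builds a vector whose element type the compiler cannot infer;
# passing it to `kernel` inserts a dispatch boundary, so the loop in `kernel`
# runs on a concretely typed argument.
function kernel(xs)
  acc = zero(eltype(xs))
  for x in xs
    acc += x * x
  end
  return acc
end

function setup_and_run(n)
  xs = rand(Bool) ? rand(n) : rand(Float32, n)  # deliberately type-unstable
  return kernel(xs)                             # the function barrier
end

setup_and_run(10)
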
@@ -110,11 +111,13 @@ function train!(loss::Function, model, opt::FluxState)
 end
 
 # This method lets you use Optimisers.Descent() instead of Flux.Descent(), when there is no state
-function train!(loss::Function, model, data, opt::Optimisers.AbstractRule)
+function train!(loss::Function, model, data, rule::Optimisers.AbstractRule)
+  opt = FluxState(rule, missing)
   _initialise!(opt, model)
-  fmap(opt.state, exclude = x -> x isa Optimsers.Leaf) do leaf
-    leaf.state isa Nothing || @warn "Optimiser state will be lost! Please wrap optimisation rule in `FluxState`, e.g. by using `Flux.Adam()`" leaf
+  @gensym warn_id
+  fmap(opt.state, exclude = x -> x isa Optimisers.Leaf) do leaf
+    leaf.state isa Nothing || @warn "Optimiser state will be discarded! Please wrap optimisation rule from Optimisers.jl in `FluxState`, e.g. by using `Flux.Adam()`" leaf maxlog=1 _id=warn_id
     leaf
   end
-  train!(loss, model, data, FluxState(opt))
+  train!(loss, model, data, opt)
 end
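
The `@gensym warn_id` plus `maxlog=1 _id=warn_id` combination makes the warning print at most once per `train!` call: `@gensym` binds a fresh id on every call, so the logger's once-per-id limit resets each time, instead of firing once per model leaf or only once per session. A standalone sketch of that logging pattern (not part of this commit):

# `@gensym` assigns a new symbol each time the function runs, so `maxlog=1`
# is counted per call of `warn_once_per_call`, not per element and not globally.
function warn_once_per_call(xs)
  @gensym warn_id
  for x in xs
    x isa Nothing || @warn "ignoring a non-`nothing` entry" x maxlog=1 _id=warn_id
  end
end

warn_once_per_call([1, 2, 3])  # warns once, not three times
warn_once_per_call([4, 5])     # warns once more, since the id is fresh
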
test/train.jl: 38 changes (34 additions, 4 deletions)
@@ -1,7 +1,7 @@
 using Flux.Train
 using Zygote: Params, gradient
 
-import FillArrays, ComponentArrays
+import Optimisers, FillArrays, ComponentArrays
 
 using Test
 using Random
@@ -29,10 +29,40 @@ end
 w2 = randn(10, 10)  # NB outside the inner @testset, else it will be exactly == w, as the RNG seed is reset.
 @testset for opt in [Descent(0.1), Adam()]
   @test opt isa FluxState
-  w′ = copy(w2)
-  b = zeros(10)
+  @test opt.state isa Missing
+
+  loss(m, x) = Flux.Losses.mse(w*x, m.weight*x .+ m.bias)
+  model = (weight=copy(w2), bias=zeros(10), ignore=nothing)
+  @test loss(model, rand(10, 10)) > 1
+
+  train!(loss, model, ((rand(10),) for _ in 1: 10^5), opt)
+  @test loss(model, rand(10, 10)) < 0.01
+  @test opt.state isa NamedTuple
+end
+
+# Test 3-arg `train!` method:
+@testset for opt in [Descent(0.1), Adam()]
+  @test opt isa FluxState
+  @test opt.state isa Missing
+
+  loss(m) = let x = rand(10)
+    Flux.Losses.mse(w*x, m.weight*x .+ m.bias)
+  end
+  model = (weight=copy(w2), bias=zeros(10), ignore=nothing)
+  @test loss(model) > 1
+
+  for i in 1:10^5
+    train!(loss, model, opt)
+  end
+  @test loss(model) < 0.01
+  @test opt.state isa NamedTuple
+end
+
+# Test direct use of Optimisers.jl rule, only really OK for `Descent`:
+@testset for opt in [Optimisers.Descent(0.1), Optimisers.Adam()]
+  @test opt isa Optimisers.AbstractRule
   loss(m, x) = Flux.Losses.mse(w*x, m.weight*x .+ m.bias)
-  model = (weight=w′, bias=b, ignore=nothing)
+  model = (weight=copy(w2), bias=zeros(10), ignore=nothing)
   @test loss(model, rand(10, 10)) > 1
   train!(loss, model, ((rand(10),) for _ in 1: 10^5), opt)
   @test loss(model, rand(10, 10)) < 0.01
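
The last testset passes a bare Optimisers.jl rule, which the new 4-arg method wraps in a throwaway `FluxState`; the "only really OK for `Descent`" caveat presumably reflects that `Descent` keeps no per-parameter state, so discarding that state between calls loses nothing, while `Adam`'s moment estimates would be rebuilt on every call. For reference, a standalone sketch of the explicit-mode update loop these tests exercise, written directly against Optimisers.jl and Zygote (independent of this commit):

using Optimisers, Zygote

loss(m, x) = sum(abs2, m.weight * x .+ m.bias)

function fit!(model, rule; steps = 1000)
  state = Optimisers.setup(rule, model)       # per-leaf optimiser state
  for _ in 1:steps
    x = rand(4)
    g = Zygote.gradient(m -> loss(m, x), model)[1]
    state, model = Optimisers.update!(state, model, g)
  end
  return model
end

model = fit!((weight = randn(4, 4), bias = zeros(4)), Optimisers.Descent(0.1))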
