-
-
Notifications
You must be signed in to change notification settings - Fork 24
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
112 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,7 @@ | ||
name = "Optimisers" | ||
uuid = "3bd65402-5787-11e9-1adc-39752487f4e2" | ||
authors = ["Mike J Innes <[email protected]>"] | ||
version = "0.2.2" | ||
version = "0.2.2" | ||
|
||
[deps] | ||
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" | ||
|
@@ -16,9 +16,12 @@ Functors = "0.2.8" | |
julia = "1.6" | ||
|
||
[extras] | ||
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" | ||
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" | ||
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" | ||
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" | ||
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" | ||
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" | ||
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" | ||
|
||
[targets] | ||
test = ["Test", "StaticArrays", "Zygote"] | ||
test = ["Adapt", "CUDA", "GPUArrays", "StaticArrays", "Test", "Zygote"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
using Optimisers | ||
using ChainRulesCore #, Functors, StaticArrays, Zygote | ||
using LinearAlgebra, Statistics, Test | ||
|
||
import CUDA | ||
# Choose the array backend used by every test below: real CUDA arrays when a working | ||
# GPU is present, otherwise the fake GPU array shipped in GPUArrays' test folder, bound | ||
# to the same `cu` / `CuArray` names so the tests read the same on either path. | ||
if CUDA.functional() | ||
    using CUDA # exports CuArray, etc | ||
    # Forbid scalar indexing here as well, so both backends run under the same strictness. | ||
    CUDA.allowscalar(false) | ||
    @info "starting CUDA tests" | ||
else | ||
    @info "CUDA not functional, testing via GPUArrays" | ||
    using GPUArrays | ||
    GPUArrays.allowscalar(false) | ||
|
||
    # GPUArrays provides a fake GPU array, for testing | ||
    jl_file = normpath(joinpath(pathof(GPUArrays), "..", "..", "test", "jlarray.jl")) | ||
    using Random, Adapt # loaded within jl_file | ||
    include(jl_file) | ||
    using .JLArrays | ||
    cu = jl | ||
    CuArray{T,N} = JLArray{T,N} | ||
end | ||
|
||
# Smoke test: broadcasting over a `cu` array must give back the backend's array type. | ||
@test cu(rand(3)) .+ 1 isa CuArray | ||
|
||
@testset "very basics" begin | ||
    # A single-leaf model on the device, paired with an integer-valued gradient. | ||
    model = (cu([1.0, 2.0]),) | ||
    leaf_id = objectid(model[1]) | ||
    grad = (cu([25, 33]),) | ||
    rule = Descent(0.1f0) | ||
    state = Optimisers.setup(rule, model) | ||
    # Values that every update below is compared against. | ||
    expected = [1,2] .- 0.1 .* [25, 33] | ||
|
||
    # Out-of-place update: the input leaf keeps its values, the result is a CuArray. | ||
    _, stepped = Optimisers.update(state, model, grad) | ||
    @test Array(model[1]) == 1:2 # not mutated | ||
    @test stepped[1] isa CuArray | ||
    @test Array(stepped[1]) ≈ expected atol=1e-6 | ||
|
||
    # In-place update: the returned leaf is the very same object as the input leaf. | ||
    _, reused = Optimisers.update!(state, model, grad) | ||
    @test objectid(reused[1]) == leaf_id | ||
    @test Array(reused[1]) ≈ expected atol=1e-6 | ||
|
||
    # Same step on a fresh model, with the gradient wrapped in a Tangent. | ||
    tangent = Tangent{typeof(model)}(grad...) | ||
    _, from_tangent = Optimisers.update!(state, (cu([1.0, 2.0]),), tangent) | ||
    @test Array(from_tangent[1]) ≈ expected atol=1e-6 | ||
end | ||
|
||
@testset "basic mixed" begin | ||
    # Works trivially as every element of the tree is either here or there | ||
    model = ( | ||
        device = cu([1.0, 2.0]), | ||
        host = [3.0, 4.0], | ||
        neither = (5, 6, sin), | ||
    ) | ||
    state = Optimisers.setup(ADAM(0.1), model) | ||
    # Optimiser state must be allocated alongside the leaf it belongs to. | ||
    @test state.device.state[1] isa CuArray | ||
    @test state.host.state[1] isa Array | ||
|
||
    grad = ( | ||
        device = cu([1, 0.1]), | ||
        host = [1, 10], | ||
        neither = nothing, | ||
    ) | ||
    _, updated = Optimisers.update(state, model, grad) | ||
|
||
    # Each leaf keeps its array type after the step. | ||
    @test updated.device isa CuArray | ||
    @test Array(updated.device) ≈ [0.9, 1.9] atol=1e-6 | ||
|
||
    @test updated.host isa Array | ||
    @test updated.host ≈ [2.9, 3.9] | ||
end | ||
|
||
# Optimisation rules exercised by the "rules: simple sum" testset. | ||
RULES = [ | ||
    # Stand-alone rules (a selection, not the full set): | ||
    Descent(), | ||
    ADAM(), | ||
    RMSProp(), | ||
    NADAM(), | ||
    # A few rules composed through OptimiserChain: | ||
    OptimiserChain(WeightDecay(), ADAM(0.001)), | ||
    OptimiserChain(ClipNorm(), ADAM(0.001)), | ||
    OptimiserChain(ClipGrad(0.5), Momentum()), | ||
] | ||
|
||
# Symbol naming the type of `o`, without its type parameters. | ||
name(o) = nameof(typeof(o)) # just for printing testset headings | ||
# For a chain, join the names of the rules in `o.opts` with arrows, in order. | ||
name(o::OptimiserChain) = join(map(name, o.opts), " → ") | ||
|
||
@testset "rules: simple sum" begin | ||
    @testset "$(name(o))" for o in RULES | ||
        # Same problem for every rule: the values 1:64 as floats, shuffled, in an 8×8 | ||
        # matrix passed through `cu`. | ||
        params = cu(shuffle!(reshape(1:64, 8, 8) .+ 0.0)) | ||
        state = Optimisers.setup(o, params) | ||
        loss = x -> sum(abs2, x + x') | ||
        for _ in 1:10 | ||
            grad = Zygote.gradient(loss, params)[1] | ||
            state, params = Optimisers.update!(state, params, grad) | ||
        end | ||
        # Ten steps must leave the element total below its starting value. | ||
        @test sum(params) < sum(1:64) | ||
    end | ||
end | ||
|
||
@testset "destructure GPU" begin | ||
    # Two device arrays with a non-array tuple between them. | ||
    model = (x = cu(Float32[1,2,3]), y = (0, 99), z = cu(Float32[4,5])) | ||
    flat, rebuild = destructure(model) | ||
    # Both the flat vector and the rebuilt leaf must be CuArrays. | ||
    @test flat isa CuArray | ||
    @test rebuild(2 * flat).x isa CuArray | ||
end | ||
|
||
@testset "destructure mixed" begin | ||
    # Not sure what should happen here! | ||
    # Case 1: device array first, host array second. | ||
    device_first = (x = cu(Float32[1,2,3]), y = Float32[4,5]) | ||
    flat, rebuild = destructure(device_first) | ||
    @test rebuild(2 * flat).x isa CuArray | ||
    @test_broken rebuild(2 * flat).y isa Array | ||
|
||
    # Case 2: host array first, device array second. | ||
    host_first = (x = Float32[1,2,3], y = cu(Float32[4,5])) | ||
    @test_skip destructure(host_first) # ERROR: Scalar indexing | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters