Skip to content

Commit

Permalink
test with GPUArrays
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Apr 22, 2022
1 parent dd571ca commit d8a5eb1
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 3 deletions.
9 changes: 6 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Optimisers"
uuid = "3bd65402-5787-11e9-1adc-39752487f4e2"
authors = ["Mike J Innes <[email protected]>"]
version = "0.2.2"
version = "0.2.2"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand All @@ -16,9 +16,12 @@ Functors = "0.2.8"
julia = "1.6"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Test", "StaticArrays", "Zygote"]
test = ["Adapt", "CUDA", "GPUArrays", "StaticArrays", "Test", "Zygote"]
103 changes: 103 additions & 0 deletions test/gpuarrays.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
using Optimisers
using ChainRulesCore #, Functors, StaticArrays, Zygote
using LinearAlgebra, Statistics, Test

import CUDA
if CUDA.functional()
using CUDA # exports CuArray, etc
@info "starting CUDA tests"
else
@info "CUDA not functional, testing via GPUArrays"
using GPUArrays
GPUArrays.allowscalar(false)

# GPUArrays provides a fake GPU array, for testing
jl_file = normpath(joinpath(pathof(GPUArrays), "..", "..", "test", "jlarray.jl"))
using Random, Adapt # loaded within jl_file
include(jl_file)
using .JLArrays
cu = jl
CuArray{T,N} = JLArray{T,N}
end

@test cu(rand(3)) .+ 1 isa CuArray

@testset "very basics" begin
m = (cu([1.0, 2.0]),)
mid = objectid(m[1])
g = (cu([25, 33]),)
o = Descent(0.1f0)
s = Optimisers.setup(o, m)

s2, m2 = Optimisers.update(s, m, g)
@test Array(m[1]) == 1:2 # not mutated
@test m2[1] isa CuArray
@test Array(m2[1]) [1,2] .- 0.1 .* [25, 33] atol=1e-6

s3, m3 = Optimisers.update!(s, m, g)
@test objectid(m3[1]) == mid
@test Array(m3[1]) [1,2] .- 0.1 .* [25, 33] atol=1e-6

g4 = Tangent{typeof(m)}(g...)
s4, m4 = Optimisers.update!(s, (cu([1.0, 2.0]),), g4)
@test Array(m4[1]) [1,2] .- 0.1 .* [25, 33] atol=1e-6
end

@testset "basic mixed" begin
# Works trivially as every element of the tree is either here or there
m = (device = cu([1.0, 2.0]), host = [3.0, 4.0], neither = (5, 6, sin))
s = Optimisers.setup(ADAM(0.1), m)
@test s.device.state[1] isa CuArray
@test s.host.state[1] isa Array

g = (device = cu([1, 0.1]), host = [1, 10], neither = nothing)
s2, m2 = Optimisers.update(s, m, g)

@test m2.device isa CuArray
@test Array(m2.device) [0.9, 1.9] atol=1e-6

@test m2.host isa Array
@test m2.host [2.9, 3.9]
end

RULES = [
# Just a selection:
Descent(), ADAM(), RMSProp(), NADAM(),
# A few chained combinations:
OptimiserChain(WeightDecay(), ADAM(0.001)),
OptimiserChain(ClipNorm(), ADAM(0.001)),
OptimiserChain(ClipGrad(0.5), Momentum()),
]

name(o) = typeof(o).name.name # just for printing testset headings
name(o::OptimiserChain) = join(name.(o.opts), "")

@testset "rules: simple sum" begin
@testset "$(name(o))" for o in RULES
m = cu(shuffle!(reshape(1:64, 8, 8) .+ 0.0))
s = Optimisers.setup(o, m)
for _ in 1:10
g = Zygote.gradient(x -> sum(abs2, x + x'), m)[1]
s, m = Optimisers.update!(s, m, g)
end
@test sum(m) < sum(1:64)
end
end

@testset "destructure GPU" begin
m = (x = cu(Float32[1,2,3]), y = (0, 99), z = cu(Float32[4,5]))
v, re = destructure(m)
@test v isa CuArray
@test re(2v).x isa CuArray
end

@testset "destructure mixed" begin
# Not sure what should happen here!
m_c1 = (x = cu(Float32[1,2,3]), y = Float32[4,5])
v, re = destructure(m_c1)
@test re(2v).x isa CuArray
@test_broken re(2v).y isa Array

m_c2 = (x = Float32[1,2,3], y = cu(Float32[4,5]))
@test_skip destructure(m_c2) # ERROR: Scalar indexing
end
3 changes: 3 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -172,4 +172,7 @@ Optimisers.trainable(x::TwoThirds) = (a = x.a,)
@testset verbose=true "Optimisation Rules" begin
include("rules.jl")
end
@testset verbose=true "GPU" begin
include("gpuarrays.jl")
end
end

0 comments on commit d8a5eb1

Please sign in to comment.