GPU code 3-4x faster on a 1080 than on CPU; expect another 3x on V100s
vpuri3 committed Jul 20, 2023
1 parent 6013a9c commit ba5dec7
Showing 4 changed files with 76 additions and 48 deletions.
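For context on the speedup claim: the timings come from the Tullio contractions exercised in examples/cuda_perf.jl. Below is a minimal sketch of that kind of CPU-vs-GPU comparison; the array sizes are illustrative only, and it assumes a setup where Tullio can emit GPU kernels for CuArrays (KernelAbstractions loaded alongside CUDA). CUDA.@sync is needed so @btime measures the completed kernel rather than just the asynchronous launch.

# CPU-vs-GPU timing of a Tullio contraction (sketch; sizes are illustrative,
# assumes Tullio's KernelAbstractions backend is available for CUDA arrays)
using CUDA, Tullio, BenchmarkTools, KernelAbstractions

Co, Ci, M, B = 4, 32, 1024, 100

W = rand(Float32, Co, Ci, M)     # CPU arrays
X = rand(Float32, M, Ci, B)
Wd, Xd = CuArray(W), CuArray(X)  # device copies

# Y[m, co, b] = W[co, ci, m] * X[m, ci, b], as in examples/cuda_perf.jl
@btime @tullio Y[m, co, b] := W[co, ci, m] * X[m, ci, b]                # CPU
@btime CUDA.@sync @tullio Yd[m, co, b] := Wd[co, ci, m] * Xd[m, ci, b]  # GPU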
45 changes: 42 additions & 3 deletions examples/cuda_perf.jl
@@ -155,7 +155,6 @@ W = CUDA.rand(Co, Ci, M)
GC.gc()

println("# Y[m, co, b] = W[co, ci, m] * X[m, ci, b]")
W = reshape(W, (Co, Ci, M))
@btime CUDA.@sync @tullio Y[m, co, b] := W[co, ci, m] * X[m, ci, b]

GC.gc()
@@ -193,7 +192,7 @@ end
4.486 ms (163 allocations: 7.92 KiB)
"""

if true
if false
println("#==================#")
println("### Tullio Bilinear tests ###")
println("# with x/y[C, M, B]")
@@ -250,5 +249,45 @@ end
"""

nothing
###
# Gradient computation
###

using Zygote

# CUDA.@captured
# https://juliagpu.org/post/2021-06-10-cuda_3.3/#high-level_graph_apis
# https://github.com/JuliaGPU/CUDA.jl/blob/a8c55aed276892aeb7bbe5220448a5ca5922a9be/test/core/cudadrv.jl#L380-L395

C1, C2, Co = 32, 32, 4
M = 1024
B = 100

X = CUDA.rand(C1, M, B)
Y = CUDA.rand(C2, M, B)
W = CUDA.rand(Co, C1, C2, M)

function loss(X, Y, W)
# @tullio Z[co, m, b] := X[c1, m, b] * W[co, c1, c2, m] * Y[c2, m, b]
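# split the fused three-array contraction above into two pairwise
# contractions, mirroring the change to opconv_wt in src/operator.jl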

@tullio Z1[co, c1, m, b] := W[co, c1, c2, m] * Y[c2, m, b]
@tullio Z2[co, m, b] := Z1[co, c1, m, b] * X[c1, m, b]
sum(Z2)
end

function grad(X, Y, W)
f = W -> loss(X, Y, W)
l, pb = Zygote.pullback(f, W)
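# seed the pullback with 1 to get the gradient of the scalar loss w.r.t. W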
pb(one.(l))
end

CUDA.@time loss(X, Y, W);
CUDA.@time loss(X, Y, W);

CUDA.@time grad(X, Y, W);
CUDA.@time grad(X, Y, W);

CUDA.@profile CUDA.@time grad(X, Y, W);

GC.gc(false)
#
50 changes: 12 additions & 38 deletions examples/diffusion_fourier/gpu/bilin_gpu.jl
@@ -37,9 +37,9 @@ include("../datagen.jl")

# parameters
N = 128 # problem size
K1 = 32 # ν-samples
K2 = 32 # f-samples
E = 20 # epochs
K1 = 16 # ν-samples
K2 = 16 # f-samples
E = 200 # epochs

rng = Random.default_rng()
Random.seed!(rng, 117)
@@ -62,9 +62,11 @@ c = size(__data[1], 1) # in channels
o = size(__data[2], 1) # out channels

NN = Lux.Chain(
PermutedBatchNorm(c, 3),
Dense(c , w, tanh),
OpKernel(w, w, m, tanh),
OpKernel(w, w, m, tanh),
OpKernel(w, w, m, tanh),
Dense(w , o)
)

@@ -75,14 +77,14 @@ dir = joinpath(@__DIR__, "dump")

model, _ = train_model(rng, NN, __data, data__, _V, opt;
learning_rates, maxiters, dir, cbstep = 1, device = gpu)

GC.gc()
end

###
# Bilinear (linear / nonlin) model
###

if false
if true

__data = split_data(_data)
data__ = split_data(data_)
@@ -94,47 +96,19 @@ c1 = size(__data[1][1], 1) # in channel nonlin
c2 = size(__data[1][2], 1) # in channel linear
o = size(__data[2] , 1) # out channel

# NN = linear_nonlinear(Dense(c1, w1, tanh), Dense(c2, w2), Bilinear((w1, w2) => o))
NN = linear_nonlinear(Dense(c1, w1, tanh), Dense(c2, w2), OpConvBilinear(w1, w2, o, m))
# NN = OpConvBilinear(c1, c2, o, m) # fast
nonlin = Chain(PermutedBatchNorm(c1, 3), Dense(c1, w1, tanh), OpKernel(w1, w1, m, tanh))
linear = Dense(c2, w2, use_bias = false)
bilin = OpConvBilinear(w1, w2, o, m)

# NN = linear_nonlinear(Dense(c1, w1, tanh), NoOpLayer(), OpConvBilinear(w1, c2, o, m))
NN = linear_nonlinear(nonlin, linear, bilin)

opt = Optimisers.Adam()
learning_rates = (1f-3,)
maxiters = E .* (1.00,) .|> Int
dir = joinpath(@__DIR__, "dump")

model, _ = train_model(rng, NN, __data, data__, _V, opt;
learning_rates, maxiters, dir, cbstep = 1, device = gpu)

end

if true

w1 = 32
w2 = 32
wo = 1
m = (32,)
K = 100

__data = ((rand(w1, N, K), rand(w2, N, K)), rand(wo, N, K))
data__ = ((rand(w1, N, K), rand(w2, N, K)), rand(wo, N, K))

c1 = size(__data[1][1], 1) # in channel nonlin
c2 = size(__data[1][2], 1) # in channel linear
o = size(__data[2] , 1) # out channel

NN = OpConvBilinear(w1, w2, wo, m)

opt = Optimisers.Adam()
learning_rates = (1f-3,)
maxiters = E .* (1.00,) .|> Int
dir = joinpath(@__DIR__, "dump")

model, _ = train_model(rng, NN, __data, data__, _V, opt;
learning_rates, maxiters, dir, cbstep = 1, device = gpu)

learning_rates, maxiters, dir, cbstep = 1, device = cpu)
end

nothing
6 changes: 4 additions & 2 deletions src/operator.jl
@@ -362,9 +362,11 @@ function opconv_wt(x, y, W)
Y = reshape(y, (C2, prod(modes), B)) # [C2, M, B]

# apply weight to get [Co, M, B]
@tullio Z[co, m, b] := X[c1, m, b] * W[co, c1, c2, m] * Y[c2, m, b]
# @tullio Z[co, m, b] := X[c1, m, b] * W[co, c1, c2, m] * Y[c2, m, b]
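# contract W with Y first, then with X: two pairwise contractions
# in place of the fused three-array form above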
@tullio Z1[co, c1, m, b] := W[co, c1, c2, m] * Y[c2, m, b]
@tullio Z2[co, m, b] := Z1[co, c1, m, b] * X[c1, m, b]

# un-reshape
reshape(Z, (Co, modes..., B))
reshape(Z2, (Co, modes..., B))
end
#
23 changes: 18 additions & 5 deletions src/train.jl
@@ -48,13 +48,13 @@ function train_model(
@assert data[2] isa AbstractArray
end

_data, data_ = (_data, data_) |> device
_devicedata, devicedata_ = (_data, data_) |> device

@assert length(learning_rates) == length(maxiters)

# utility functions
_model, _loss, _stats = model_setup(NN, _data; lossfun)
model_, loss_, stats_ = model_setup(NN, data_; lossfun)
_model, _loss, _stats = model_setup(NN, _devicedata; lossfun)
model_, loss_, stats_ = model_setup(NN, devicedata_; lossfun)

# analysis callback
CB = (p, st; io = io) -> callback(p, st; io, _loss, _stats, loss_, stats_)
@@ -110,10 +110,23 @@ function train_model(
CB(p, st; io = statsfile)
close(statsfile)

# visualization
# transfer to host device and free stuff
if device == Lux.gpu
if _data[1] isa AbstractArray
CUDA.unsafe_free!(_devicedata[1])
CUDA.unsafe_free!(devicedata_[1])
else
CUDA.unsafe_free!.(_devicedata[1])
CUDA.unsafe_free!.(devicedata_[1])
end

CUDA.unsafe_free!(_devicedata[2])
CUDA.unsafe_free!(devicedata_[2])
end

p, st = (p, st) |> Lux.cpu
_data, data_ = (_data, data_) |> Lux.cpu

# visualization
plt_train = plot_training(ITER, _LOSS, LOSS_)
plts = visualize(V, _data, data_, NN, p, st; nsamples)

