training on cpu
vpuri3 committed Aug 9, 2023
1 parent 05ee83d commit 81f035c
Showing 17 changed files with 81 additions and 48 deletions.
10 changes: 6 additions & 4 deletions examples/pdebench/burgers_bilinear.jl
@@ -37,7 +37,7 @@ rng = Random.default_rng()
Random.seed!(rng, 983254)

N = 1024
E = 100
E = 20

# trajectories
_K = 512
@@ -59,8 +59,8 @@ V = FourierSpace(N)
###

#============================#
w = 32 # width
l = 8
w = 16 # width
l = 4
m = (128,) # modes
c = size(_data[1], 1) # in channels
o = size(_data[2], 1) # out channels
@@ -70,10 +70,12 @@ root = Chain(
OpKernel(w, w, m, relu),
OpKernel(w, w, m, relu),
)

branch = Chain(
OpKernel(w, w, m, relu), # use_bias = true
OpKernel(w, w, m, relu),
)

fuse = OpConvBilinear(w, w, l, m)

project = Chain(
@@ -91,7 +93,7 @@ NN = Chain(
#============================#

opt = Optimisers.Adam()
batchsize = 32
batchsize = 64

learning_rates = (1f-2, 1f-3, 1f-4, 1f-5)
nepochs = E .* (0.25, 0.25, 0.25, 0.25) .|> Int
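For context, the learning_rates / nepochs pair above defines a staged schedule: each learning rate gets an equal quarter of the E total epochs. A minimal sketch of how such a schedule is typically consumed (the loop below is an illustration only; the actual staging lives in train_model in src/train.jl):

using Optimisers

E = 20
learning_rates = (1f-2, 1f-3, 1f-4, 1f-5)
nepochs = E .* (0.25, 0.25, 0.25, 0.25) .|> Int   # (5, 5, 5, 5)

for (lr, ne) in zip(learning_rates, nepochs)
    opt = Optimisers.Adam(lr)   # fresh optimiser at the new learning rate
    # ...run `ne` epochs of training with `opt` here (see optimize in src/train.jl)
end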
22 changes: 12 additions & 10 deletions examples/pdebench/darcy2d.jl
@@ -30,18 +30,18 @@ import Lux: cpu, gpu
using Tullio, Zygote

using FFTW, LinearAlgebra
BLAS.set_num_threads(20)
FFTW.set_num_threads(40)
BLAS.set_num_threads(4)
FFTW.set_num_threads(8)

rng = Random.default_rng()
Random.seed!(rng, 345)

N = 128
E = 10 # epochs
E = 40

# trajectories
_K = 512
K_ = 128
_K = 128
K_ = 32

# get data
dir = @__DIR__
@@ -70,14 +70,16 @@ NN = Lux.Chain(
)

opt = Optimisers.Adam()
batchsize = 32
learning_rates = (1f-2, 1f-3,)
nepochs = E .* (0.10, 0.90,) .|> Int
batchsize = 16
learning_rates = (1f-2, 1f-3, 5f-4, 2.5f-4,)
nepochs = E .* (0.25, 0.25, 0.25, 0.25,) .|> Int
dir = joinpath(@__DIR__, "model_darcy2D")
device = Lux.gpu
device = Lux.cpu # Lux.gpu

model, ST = train_model(rng, NN, _data, data_, V, opt;
batchsize, learning_rates, nepochs, dir, device)

nothing
plot_training(ST...)

# nothing
#
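The substantive change in this file is device = Lux.cpu, which is what the commit title refers to. In train_model the parameters and state are moved to the chosen device before training (p, st = (p, st) |> device). A minimal, self-contained sketch of that pattern, using a hypothetical stand-in model for illustration:

using Lux, Random

device = Lux.cpu                 # switch back to Lux.gpu for CUDA training

rng = Random.default_rng()
model = Lux.Dense(4 => 4)        # stand-in for the operator network above
p, st = Lux.setup(rng, model)
p, st = (p, st) |> device        # same transfer train_model performs internally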
Binary file added examples/pdebench/dump/plt_r2_test.png
Binary file added examples/pdebench/dump/plt_r2_train.png
Binary file modified examples/pdebench/dump/plt_training.png
Binary file added examples/pdebench/dump/plt_traj_test.png
Binary file added examples/pdebench/dump/plt_traj_train.png
20 changes: 11 additions & 9 deletions examples/pdebench/dump/statistics.txt
@@ -1,17 +1,19 @@
TRAIN LOSS: 0.00051646 TEST LOSS: 0.00079046
TRAIN LOSS: 0.00048486 TEST LOSS: 0.00611112
#======================#
TRAIN STATS
R² score: 0.46964792
mean SQR error: 0.00078803
mean ABS error: 0.01567526
max ABS error: 0.35974452
R² score: 0.9978241
MSE (mean SQR error): 0.00071497
RMSE (root mean SQR error): 0.02673884
MAE (mean ABS error): 0.0173665
maxAE (max ABS error) 0.29969817

#======================#
#======================#
TEST STATS
R² score: -0.22938039
mean SQR error: 0.00111063
mean ABS error: 0.02132492
max ABS error: 0.29181193
R² score: 0.9746224
MSE (mean SQR error): 0.00838597
RMSE (root mean SQR error): 0.09157494
MAE (mean ABS error): 0.06254827
maxAE (max ABS error) 0.4538644

#======================#
Binary file modified examples/pdebench/model_burgers1D_nu0.001_bilinear/plt_r2_train.png
22 changes: 11 additions & 11 deletions examples/pdebench/model_burgers1D_nu0.001_bilinear/statistics.txt
@@ -1,19 +1,19 @@
TRAIN LOSS: 0.00090023 TEST LOSS: 0.00972281
TRAIN LOSS: 0.00579758 TEST LOSS: 0.01066288
#======================#
TRAIN STATS
R² score: 0.9961141
MSE (mean SQR error): 0.00132547
RMSE (root mean SQR error): 0.03640705
MAE (mean ABS error): 0.02396612
maxAE (max ABS error) 0.30083415
R² score: 0.9823463
MSE (mean SQR error): 0.00589802
RMSE (root mean SQR error): 0.07679854
MAE (mean ABS error): 0.05752488
maxAE (max ABS error) 0.3610783

#======================#
#======================#
TEST STATS
R² score: 0.9676699
MSE (mean SQR error): 0.01123419
RMSE (root mean SQR error): 0.10599148
MAE (mean ABS error): 0.07846318
maxAE (max ABS error) 0.3812811
R² score: 0.96915275
MSE (mean SQR error): 0.01066288
RMSE (root mean SQR error): 0.10326123
MAE (mean ABS error): 0.07668994
maxAE (max ABS error) 0.4150022

#======================#
Binary file added examples/pdebench/model_darcy2D/plt_training.png
19 changes: 19 additions & 0 deletions examples/pdebench/model_darcy2D/statistics.txt
@@ -0,0 +1,19 @@
TRAIN LOSS: 0.00020804 TEST LOSS: 8.247e-5
#======================#
TRAIN STATS
R² score: 0.76355976
MSE (mean SQR error): 0.0003289
RMSE (root mean SQR error): 0.01813549
MAE (mean ABS error): 0.00971096
maxAE (max ABS error) 0.26131245

#======================#
#======================#
TEST STATS
R² score: 0.85738647
MSE (mean SQR error): 0.00025149
RMSE (root mean SQR error): 0.01585852
MAE (mean ABS error): 0.00898928
maxAE (max ABS error) 0.16604173

#======================#
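For reference, the quantities reported in these statistics files (R² score, MSE, RMSE, MAE, maxAE) follow the standard definitions. A minimal sketch of how they can be computed; this is not necessarily the exact statistics implementation in src/train.jl:

# ŷ: model predictions, y: ground truth, same shape
function summary_stats(ŷ, y)
    err   = ŷ .- y
    mse   = sum(abs2, err) / length(err)            # MSE  (mean SQR error)
    rmse  = sqrt(mse)                               # RMSE (root mean SQR error)
    mae   = sum(abs, err) / length(err)             # MAE  (mean ABS error)
    maxae = maximum(abs, err)                       # maxAE (max ABS error)
    ȳ     = sum(y) / length(y)
    r2    = 1 - sum(abs2, err) / sum(abs2, y .- ȳ)  # R² score
    (; r2, mse, rmse, mae, maxae)
end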
36 changes: 22 additions & 14 deletions src/train.jl
@@ -59,20 +59,26 @@ function train_model(
_stats = (p, st; io = io) -> statistics(NN, p, st, _loader; io)
stats_ = (p, st; io = io) -> statistics(NN, p, st, loader_; io)

# full batch losses for CB
# full batch losses for cb_stats
_loss = (p, st) -> minibatch_metric(NN, p, st, _loader, lossfun)
loss_ = (p, st) -> minibatch_metric(NN, p, st, loader_, lossfun)

# callback functions
EPOCH = Int[]
_LOSS = Float32[]
_LOSS_MINIBATCH = Float32[]
LOSS_ = Float32[]

# callback for printing statistics
CB = (p, st; io = io) -> callback(p, st; io, _loss, _stats, loss_, stats_)
cb_stats = (p, st; io = io) -> callback(p, st; io, _loss, _stats, loss_, stats_)

# cb_batch =

# early stopping: (need fullbatch validation loss for early stopping?)
# https://github.com/jeffheaton/app_deep_learning/blob/main/t81_558_class_03_4_early_stop.ipynb

# callback for training
cb = (p, st, epoch, nepoch; io = io) -> callback(p, st; io,
cb_epoch = (p, st, epoch, nepoch; io = io) -> callback(p, st; io,
_loss, _LOSS, loss_, LOSS_,
EPOCH, epoch, nepoch, step = cbstep)

@@ -86,7 +92,7 @@
p, st = (p, st) |> device

# print stats
CB(p, st)
cb_stats(p, st)

println(io, "#======================#")
println(io, "Starting Training Loop")
@@ -107,19 +113,19 @@
println(io, "#======================#")

if device == Lux.gpu
CUDA.@time p, st, opt_st = optimize(NN, p, st, _loader, nepoch; lossfun, opt, opt_st, cb, io)
CUDA.@time p, st, opt_st = optimize(NN, p, st, _loader, nepoch; lossfun, opt, opt_st, cb_epoch, io)
else
@time p, st, opt_st = optimize(NN, p, st, _loader, nepoch; lossfun, opt, opt_st, cb, io)
@time p, st, opt_st = optimize(NN, p, st, _loader, nepoch; lossfun, opt, opt_st, cb_epoch, io)
end

CB(p, st)
cb_stats(p, st)
end

# TODO - output a train.log file with timings

# save statistics
statsfile = open(joinpath(dir, "statistics.txt"), "w")
CB(p, st; io = statsfile)
cb_stats(p, st; io = statsfile)
close(statsfile)

# transfer model to host device
@@ -339,12 +345,13 @@ function optimize(NN, p, st, loader, nepochs;
lossfun = mse,
opt = Optimisers.Adam(),
opt_st = nothing,
cb = nothing,
cb_batch = nothing,
cb_epoch = nothing,
io::Union{Nothing, IO} = stdout,
)

# print stats
!isnothing(cb) && cb(p, st, 0, nepochs; io)
!isnothing(cb_epoch) && cb_epoch(p, st, 0, nepochs; io)

function loss(x, ŷ, p, st)
y, st = NN(x, p, st)
@@ -367,12 +374,13 @@
opt_st, p = Optimisers.update!(opt_st, p, g)

println(io, "Epoch [$epoch / $nepochs]" * "\t Batch loss: $l")

# GC.gc(false)
!isnothing(cb_batch) && cb_batch(p, st, step)
end

# TODO - add stopping criteria for GD

println(io, "#=======================#")
!isnothing(cb) && cb(p, st, epoch, nepochs; io)
!isnothing(cb_epoch) && cb_epoch(p, st, epoch, nepochs; io)
println(io, "#=======================#")
end

@@ -396,7 +404,7 @@ function plot_training(EPOCH, _LOSS, LOSS_; dir = nothing)
xlabel = "Epochs", ylabel = "Loss (MSE)",
ylims = (minimum(_LOSS) / 10, maximum(LOSS_) * 10))

plot!(plt, EPOCH, _LOSS, w = 2.0, c = :green, label = "Train Dataset")
plot!(plt, EPOCH, _LOSS, w = 2.0, c = :green, label = "Train Dataset") # (; ribbon = (lower, upper))
plot!(plt, EPOCH, LOSS_, w = 2.0, c = :red, label = "Test Dataset")

vline!(plt, EPOCH[z[2:end]], c = :black, w = 2.0, label = nothing)
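On the early-stopping TODO added in train_model ("need fullbatch validation loss for early stopping?"): one common pattern, sketched here as an assumption rather than the repository's implementation, is to track the full-batch validation loss once per epoch and stop when it has not improved for a fixed number of epochs (the patience):

# hypothetical early-stopping helper; loss_ is assumed to be the full-batch
# validation-loss closure (p, st) -> loss built in train_model
mutable struct EarlyStopping
    patience::Int
    best::Float32
    count::Int
end
EarlyStopping(patience::Int) = EarlyStopping(patience, Inf32, 0)

function should_stop!(es::EarlyStopping, val_loss)
    if val_loss < es.best
        es.best  = val_loss
        es.count = 0
    else
        es.count += 1
    end
    return es.count >= es.patience
end

# inside the epoch loop of optimize, one could then write e.g.
# should_stop!(es, loss_(p, st)) && break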
