Commit
Added Dense and Conv BatchEnsemble layers along with unit tests and example on MNIST classification using LeNet5 (#4)

* Added Dense and Conv BatchEnsemble layers along with unit tests and an example of MNIST classification using LeNet5
* Merged the dense BatchEnsemble forward passes for rank > 1 and rank = 1; fixed the conv BatchEnsemble unit test
* Changes: 1. Reduced imports and moved them to the main file 2. Renamed the test files 3. Added GPU tests for the layers -- basic forward-pass checks for now
1 parent b11fbea, commit ed3d16c. Showing 11 changed files with 728 additions and 18 deletions.
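The new layers are used as ordinary Flux layers. As a quick orientation before the diffs, here is a minimal sketch (not part of the commit) of constructing a DenseBatchEnsemble with the signature used in the example script below and feeding it a batch repeated once per ensemble member; the exact output layout is defined by the forward pass in layers/BatchEnsemble/dense.jl, which is included by the module below but not shown in this excerpt.

using Flux
using DeepUncertainty  # exports DenseBatchEnsemble and ConvBatchEnsemble (see the module file below)

ensemble_size = 4
rank = 1

# DenseBatchEnsemble(in, out, rank, ensemble_size, activation)
layer = DenseBatchEnsemble(784, 10, rank, ensemble_size, relu)

# Mini-batches are repeated along the batch dimension, one copy per ensemble member,
# so a batch of 32 samples becomes 32 * 4 = 128 columns.
x = rand(Float32, 784, 32)
x = repeat(x, 1, ensemble_size)
y = layer(x)  # expected size (10, 128); each consecutive block of 32 columns is one member's output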
@@ -0,0 +1,207 @@
## Classification of MNIST dataset
## with the convolutional neural network known as LeNet5.
## This script also combines various
## packages from the Julia ecosystem with Flux.
using Flux
using Flux.Data: DataLoader
using Flux.Optimise: Optimiser, WeightDecay
using Flux: onehotbatch, onecold, glorot_normal, label_smoothing
using Flux.Losses: logitcrossentropy
using Statistics, Random
using Logging: with_logger
using TensorBoardLogger: TBLogger, tb_overwrite, set_step!, set_step_increment!
using ProgressMeter: @showprogress
import MLDatasets
import BSON
using CUDA
using Formatting

using DeepUncertainty

# LeNet5 "constructor".
# The model can be adapted to any image size
# and any number of output classes.
function LeNet5(args; imgsize = (28, 28, 1), nclasses = 10)
    out_conv_size = (imgsize[1] ÷ 4 - 3, imgsize[2] ÷ 4 - 3, 16)

    return Chain(
        ConvBatchEnsemble((5, 5), imgsize[end] => 6, args.rank, args.ensemble_size, relu),
        MaxPool((2, 2)),
        ConvBatchEnsemble((5, 5), 6 => 16, args.rank, args.ensemble_size, relu),
        MaxPool((2, 2)),
        flatten,
        DenseBatchEnsemble(prod(out_conv_size), 120, args.rank, args.ensemble_size, relu),
        DenseBatchEnsemble(120, 84, args.rank, args.ensemble_size, relu),
        DenseBatchEnsemble(84, nclasses, args.rank, args.ensemble_size),
    )
end

function get_data(args)
    xtrain, ytrain = MLDatasets.MNIST.traindata(Float32)
    xtest, ytest = MLDatasets.MNIST.testdata(Float32)

    xtrain = reshape(xtrain, 28, 28, 1, :)
    xtest = reshape(xtest, 28, 28, 1, :)

    ytrain, ytest = onehotbatch(ytrain, 0:9), onehotbatch(ytest, 0:9)

    train_loader = DataLoader(
        (xtrain, ytrain),
        batchsize = args.batchsize,
        shuffle = true,
        partial = false,
    )
    test_loader = DataLoader((xtest, ytest), batchsize = args.batchsize, partial = false)

    return train_loader, test_loader
end

loss(ŷ, y) = logitcrossentropy(ŷ, y)

function accuracy(preds, labels)
    acc = sum(onecold(preds |> cpu) .== onecold(labels |> cpu))
    return acc
end

function eval_loss_accuracy(args, loader, model, device)
    l = [0.0f0 for x = 1:args.ensemble_size]
    acc = [0 for x = 1:args.ensemble_size]
    ece_list = [0.0f0 for x = 1:args.ensemble_size]
    ntot = 0
    mean_l = 0
    mean_acc = 0
    mean_ece = 0
    for (x, y) in loader
        x = repeat(x, 1, 1, 1, args.ensemble_size)
        x, y = x |> device, y |> device
        # Perform the forward pass
        ŷ = model(x)
        ŷ = softmax(ŷ, dims = 1)
        # Reshape the predictions into [classes, batch_size, ensemble_size]
        reshaped_ŷ = reshape(ŷ, size(ŷ)[1], args.batchsize, args.ensemble_size)
        # Loop through each member's predictions
        for ensemble = 1:args.ensemble_size
            model_predictions = reshaped_ŷ[:, :, ensemble]
            # Calculate the individual loss
            l[ensemble] += loss(model_predictions, y) * size(model_predictions)[end]
            acc[ensemble] += accuracy(model_predictions, y)
            ece_list[ensemble] +=
                ExpectedCalibrationError(model_predictions |> cpu, onecold(y |> cpu)) *
                args.batchsize
        end
        # Get the mean predictions across ensemble members
        mean_predictions = mean(reshaped_ŷ, dims = ndims(reshaped_ŷ))
        mean_predictions = dropdims(mean_predictions, dims = ndims(mean_predictions))
        mean_l += loss(mean_predictions, y) * size(mean_predictions)[end]
        mean_acc += accuracy(mean_predictions, y)
        mean_ece +=
            ExpectedCalibrationError(mean_predictions |> cpu, onecold(y |> cpu)) *
            args.batchsize
        ntot += size(mean_predictions)[end]
    end
    # Normalize the loss
    losses = [loss / ntot |> round4 for loss in l]
    acc = [a / ntot * 100 |> round4 for a in acc]
    ece_list = [x / ntot |> round4 for x in ece_list]
    # Calculate the mean loss
    mean_l = mean_l / ntot |> round4
    mean_acc = mean_acc / ntot * 100 |> round4
    mean_ece = mean_ece / ntot |> round4

    # Print each ensemble member's loss, accuracy, and ECE
    for ensemble = 1:args.ensemble_size
        @info (format(
            "Model {} Loss: {} Accuracy: {} ECE: {}",
            ensemble,
            losses[ensemble],
            acc[ensemble],
            ece_list[ensemble],
        ))
    end
    @info (format(
        "Mean Loss: {} Mean Accuracy: {} Mean ECE: {}",
        mean_l,
        mean_acc,
        mean_ece,
    ))
    @info "==========================================================="
    return nothing
end

## utility functions
num_params(model) = sum(length, Flux.params(model))
round4(x) = round(x, digits = 4)

# arguments for the `train` function
Base.@kwdef mutable struct Args
    η = 3e-4             # learning rate
    λ = 0                # L2 regularizer param, implemented as weight decay
    batchsize = 32       # batch size
    epochs = 10          # number of epochs
    seed = 0             # set seed > 0 for reproducibility
    use_cuda = true      # if true use cuda (if available)
    infotime = 1         # report every `infotime` epochs
    checktime = 5        # Save the model every `checktime` epochs. Set to 0 for no checkpoints.
    savepath = "runs/"   # results path
    rank = 1             # rank of the BatchEnsemble fast weights
    ensemble_size = 4    # number of ensemble members
end

function train(; kws...)
    args = Args(; kws...)
    args.seed > 0 && Random.seed!(args.seed)
    use_cuda = args.use_cuda && CUDA.functional()

    if use_cuda
        device = gpu
        @info "Training on GPU"
    else
        device = cpu
        @info "Training on CPU"
    end

    ## DATA
    train_loader, test_loader = get_data(args)
    @info "Dataset MNIST: $(train_loader.nobs) train and $(test_loader.nobs) test examples"

    ## MODEL AND OPTIMIZER
    model = LeNet5(args) |> device
    @info "LeNet5 model: $(num_params(model)) trainable params"

    ps = Flux.params(model)

    opt = ADAM(args.η)
    if args.λ > 0 # add weight decay, equivalent to L2 regularization
        opt = Optimiser(WeightDecay(args.λ), opt)
    end

    function report(epoch)
        # @info "Train Metrics"
        # eval_loss_accuracy(args, train_loader, model, device)
        @info "Test metrics"
        eval_loss_accuracy(args, test_loader, model, device)
    end

    ## TRAINING
    @info "Start Training"
    report(0)
    for epoch = 1:args.epochs
        @showprogress for (x, y) in train_loader
            # Make copies of batches for ensembles
            x = repeat(x, 1, 1, 1, args.ensemble_size)
            y = repeat(y, 1, args.ensemble_size)
            x, y = x |> device, y |> device
            gs = Flux.gradient(ps) do
                ŷ = model(x)
                loss(ŷ, y)
            end

            Flux.Optimise.update!(opt, ps, gs)
        end

        ## Printing and logging
        epoch % args.infotime == 0 && report(epoch)
    end
end

train()
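Since Args is a Base.@kwdef struct and train forwards its keyword arguments to it, individual hyperparameters can be overridden at the call site instead of editing the script. A small illustrative invocation (the values here are arbitrary, not from this commit):

# Quick CPU smoke test with a smaller ensemble:
# train(epochs = 1, use_cuda = false, ensemble_size = 2, batchsize = 64)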
@@ -1,10 +1,17 @@
module DeepUncertainty

using Flux
using Random
using Flux: @functor, glorot_normal, create_bias

# Export layers
export MCLayer, MCDense, MCConv
export DenseBatchEnsemble, ConvBatchEnsemble
export mean_loglikelihood, brier_score, ExpectedCalibrationError, prediction_metrics

include("metrics.jl")
include("layers/mclayers.jl")
include("layers/BatchEnsemble/dense.jl")
include("layers/BatchEnsemble/conv.jl")

end
@@ -0,0 +1,145 @@
""" | ||
ConvBatchEnsemble(filter, in => out, rank, | ||
ensemble_size, σ = identity; | ||
stride = 1, pad = 0, dilation = 1, | ||
groups = 1, [bias, weight, init]) | ||
ConvBatchEnsemble(layer, alpha, gamma, ensemble_bias, ensemble_act, rank) | ||
Creates a conv BatchEnsemble layer. Batch ensemble is a memory efficient alternative | ||
for deep ensembles. In deep ensembles, if the ensemble size is N, N different models | ||
are trained, making the time and memory complexity O(N * complexity of one network). | ||
BatchEnsemble generates weight matrices for each member in the ensemble using a | ||
couple of rank 1 vectors R (alpha), S (gamma), RS' and multiplying the result with | ||
weight matrix W element wise. We also call R and S as fast weights. | ||
Reference - https://arxiv.org/abs/2002.06715 | ||
During both training and testing, we repeat the samples along the batch dimension | ||
N times, where N is the ensemble_size. For example, if each mini batch has 10 samples | ||
and our ensemble size is 4, then the actual input to the layer has 40 samples. | ||
The output of the layer has 40 samples as well, and each 10 samples can be considered | ||
as the output of an esnemble member. | ||
# Fields | ||
- `layer`: The dense layer which transforms the pertubed input to output | ||
- `alpha`: The first Fast weight of size (in_dim, ensemble_size) | ||
- `gamma`: The second Fast weight of size (out_dim, ensemble_size) | ||
- `ensemble_bias`: Bias added to the ensemble output, separate from dense layer bias | ||
- `ensemble_act`: The activation function to be applied on ensemble output | ||
- `rank`: Rank of the fast weights (rank > 1 doesn't work on GPU for now) | ||
# Arguments | ||
- `filter::NTuple{N,Integer}`: Kernel dimensions, eg, (5, 5) | ||
- `ch::Pair{<:Integer,<:Integer}`: Input channels => output channels | ||
- `rank::Integer`: Rank of the fast weights | ||
- `ensemble_size::Integer`: Number of models in the ensemble | ||
- `σ::F=identity`: Activation of the dense layer, defaults to identity | ||
- `init=glorot_normal`: Initialization function, defaults to glorot_normal | ||
- `alpha_init=glorot_normal`: Initialization function for the alpha fast weight, | ||
defaults to glorot_normal | ||
- `gamma_init=glorot_normal`: Initialization function for the gamma fast weight, | ||
defaults to glorot_normal | ||
- `bias::Bool=true`: Toggle the usage of bias in the dense layer | ||
- `ensemble_bias::Bool=true`: Toggle the usage of ensemble bias | ||
- `ensemble_act::F=identity`: Activation function for enseble outputs | ||
""" | ||
struct ConvBatchEnsemble{L,F,M,B}
    layer::L
    alpha::M
    gamma::M
    ensemble_bias::B
    ensemble_act::F
    rank::Any
    function ConvBatchEnsemble(
        layer::L,
        alpha::M,
        gamma::M,
        ensemble_bias = true,
        ensemble_act::F = identity,
        rank = 1,
    ) where {M,F,L}
        ensemble_bias = create_bias(gamma, ensemble_bias, size(gamma)[1], size(gamma)[2])
        new{typeof(layer),F,M,typeof(ensemble_bias)}(
            layer,
            alpha,
            gamma,
            ensemble_bias,
            ensemble_act,
            rank,
        )
    end
end

function ConvBatchEnsemble(
    k::NTuple{N,Integer},
    ch::Pair{<:Integer,<:Integer},
    rank::Integer,
    ensemble_size::Integer,
    σ = identity;
    init = glorot_normal,
    alpha_init = glorot_normal,
    gamma_init = glorot_normal,
    stride = 1,
    pad = 0,
    dilation = 1,
    groups = 1,
    bias = true,
    ensemble_bias = true,
    ensemble_act = identity,
) where {N}
    layer = Flux.Conv(
        k,
        ch,
        σ;
        stride = stride,
        pad = pad,
        dilation = dilation,
        init = init,
        groups = groups,
        bias = bias,
    )
    in_dim = ch[1]
    out_dim = ch[2]
    if rank >= 1
        alpha_shape = (in_dim, ensemble_size)
        gamma_shape = (out_dim, ensemble_size)
    else
        error("Rank must be >= 1.")
    end
    alpha = alpha_init(alpha_shape)
    gamma = gamma_init(gamma_shape)

    return ConvBatchEnsemble(layer, alpha, gamma, ensemble_bias, ensemble_act, rank)
end

@functor ConvBatchEnsemble

function (be::ConvBatchEnsemble)(x)
    # Conv BatchEnsemble params
    layer = be.layer
    alpha = be.alpha
    gamma = be.gamma
    e_b = be.ensemble_bias
    e_σ = be.ensemble_act

    batch_size = size(x)[end]
    in_size = size(alpha)[1]
    out_size = size(gamma)[1]
    ensemble_size = size(alpha)[2]
    samples_per_model = batch_size ÷ ensemble_size

    # Alpha, gamma shapes - [units, ensemble_size]; tile them once per sample
    e_b = repeat(e_b, samples_per_model)
    alpha = repeat(alpha, samples_per_model)
    gamma = repeat(gamma, samples_per_model)
    # Reshape to [1, 1, units, batch_size] so they broadcast over the spatial dims
    e_b = reshape(e_b, (1, 1, out_size, batch_size))
    alpha = reshape(alpha, (1, 1, in_size, batch_size))
    gamma = reshape(gamma, (1, 1, out_size, batch_size))

    # Perturb the input by alpha, run the shared conv, then perturb the output by gamma
    perturbed_x = x .* alpha
    output = layer(perturbed_x) .* gamma
    output = e_σ.(output .+ e_b)

    return output
end
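The input/output perturbation above is the standard BatchEnsemble trick: scaling the input by alpha and the output by gamma is equivalent to multiplying the shared weight element-wise by the outer product gamma * alpha', as described in the docstring. A minimal numerical check of that identity for the dense, rank-1, single-member, bias-free case (illustration only, not part of this commit):

using LinearAlgebra

W = randn(Float32, 3, 5)      # shared slow weight (out x in)
alpha = randn(Float32, 5)     # fast weight on the input side (in)
gamma = randn(Float32, 3)     # fast weight on the output side (out)
x = randn(Float32, 5)         # one input sample

member_W = W .* (gamma * alpha')       # explicit per-member weight matrix
lhs = member_W * x                     # apply the perturbed weight directly
rhs = gamma .* (W * (alpha .* x))      # what the BatchEnsemble forward pass computes
@assert lhs ≈ rhs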