Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added Dense and Conv BatchEnsemble layers along with unit tests and example on MNIST classification using LeNet5 #4

Merged
merged 3 commits into from
Sep 7, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReliabilityDiagrams = "e5f51471-6270-49e4-a15a-f1cfbff4f856"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
julia = "1"
Expand Down
207 changes: 207 additions & 0 deletions examples/batchensemble.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
## Classification of MNIST dataset
## with the convolutional neural network known as LeNet5.
## This script also combines various
## packages from the Julia ecosystem with Flux.
using Flux
using Flux.Data: DataLoader
using Flux.Optimise: Optimiser, WeightDecay
using Flux: onehotbatch, onecold, glorot_normal, label_smoothing
using Flux.Losses: logitcrossentropy
using Statistics, Random
using Logging: with_logger
using TensorBoardLogger: TBLogger, tb_overwrite, set_step!, set_step_increment!
using ProgressMeter: @showprogress
import MLDatasets
import BSON
using CUDA
using Formatting

using DeepUncertainty

# LeNet5 "constructor".
# The model can be adapted to any image size
# and any number of output classes.
function LeNet5(args; imgsize = (28, 28, 1), nclasses = 10)
out_conv_size = (imgsize[1] ÷ 4 - 3, imgsize[2] ÷ 4 - 3, 16)

return Chain(
ConvBatchEnsemble((5, 5), imgsize[end] => 6, args.rank, args.ensemble_size, relu),
MaxPool((2, 2)),
ConvBatchEnsemble((5, 5), 6 => 16, args.rank, args.ensemble_size, relu),
MaxPool((2, 2)),
flatten,
DenseBatchEnsemble(prod(out_conv_size), 120, args.rank, args.ensemble_size, relu),
DenseBatchEnsemble(120, 84, args.rank, args.ensemble_size, relu),
DenseBatchEnsemble(84, nclasses, args.rank, args.ensemble_size),
)
end

function get_data(args)
xtrain, ytrain = MLDatasets.MNIST.traindata(Float32)
xtest, ytest = MLDatasets.MNIST.testdata(Float32)

xtrain = reshape(xtrain, 28, 28, 1, :)
xtest = reshape(xtest, 28, 28, 1, :)

ytrain, ytest = onehotbatch(ytrain, 0:9), onehotbatch(ytest, 0:9)

train_loader = DataLoader(
(xtrain, ytrain),
batchsize = args.batchsize,
shuffle = true,
partial = false,
)
test_loader = DataLoader((xtest, ytest), batchsize = args.batchsize, partial = false)

return train_loader, test_loader
end

loss(ŷ, y) = logitcrossentropy(ŷ, y)

function accuracy(preds, labels)
acc = sum(onecold(preds |> cpu) .== onecold(labels |> cpu))
return acc
end

function eval_loss_accuracy(args, loader, model, device)
l = [0.0f0 for x = 1:args.ensemble_size]
acc = [0 for x = 1:args.ensemble_size]
ece_list = [0.0f0 for x = 1:args.ensemble_size]
ntot = 0
mean_l = 0
mean_acc = 0
mean_ece = 0
for (x, y) in loader
x = repeat(x, 1, 1, 1, args.ensemble_size)
x, y = x |> device, y |> device
# Perform the forward pass
ŷ = model(x)
ŷ = softmax(ŷ, dims = 1)
# Reshape the predictions into [classes, batch_size, ensemble_size
reshaped_ŷ = reshape(ŷ, size(ŷ)[1], args.batchsize, args.ensemble_size)
# Loop through each model's predictions
for ensemble = 1:args.ensemble_size
model_predictions = reshaped_ŷ[:, :, ensemble]
# Calculate individual loss
l[ensemble] += loss(model_predictions, y) * size(model_predictions)[end]
acc[ensemble] += accuracy(model_predictions, y)
ece_list[ensemble] +=
ExpectedCalibrationError(model_predictions |> cpu, onecold(y |> cpu)) *
args.batchsize
end
# Get the mean predictions
mean_predictions = mean(reshaped_ŷ, dims = ndims(reshaped_ŷ))
mean_predictions = dropdims(mean_predictions, dims = ndims(mean_predictions))
mean_l += loss(mean_predictions, y) * size(mean_predictions)[end]
mean_acc += accuracy(mean_predictions, y)
mean_ece +=
ExpectedCalibrationError(mean_predictions |> cpu, onecold(y |> cpu)) *
args.batchsize
ntot += size(mean_predictions)[end]
end
# Normalize the loss
losses = [loss / ntot |> round4 for loss in l]
acc = [a / ntot * 100 |> round4 for a in acc]
ece_list = [x / ntot |> round4 for x in ece_list]
# Calculate mean loss
mean_l = mean_l / ntot |> round4
mean_acc = mean_acc / ntot * 100 |> round4
mean_ece = mean_ece / ntot |> round4

# Print the per ensemble mode loss and accuracy
for ensemble = 1:args.ensemble_size
@info (format(
"Model {} Loss: {} Accuracy: {} ECE: {}",
ensemble,
losses[ensemble],
acc[ensemble],
ece_list[ensemble],
))
end
@info (format(
"Mean Loss: {} Mean Accuracy: {} Mean ECE: {}",
mean_l,
mean_acc,
mean_ece,
))
@info "==========================================================="
return nothing
end

## utility functions
num_params(model) = sum(length, Flux.params(model))
round4(x) = round(x, digits = 4)

# arguments for the `train` function
Base.@kwdef mutable struct Args
η = 3e-4 # learning rate
λ = 0 # L2 regularizer param, implemented as weight decay
batchsize = 32 # batch size
epochs = 10 # number of epochs
seed = 0 # set seed > 0 for reproducibility
use_cuda = true # if true use cuda (if available)
infotime = 1 # report every `infotime` epochs
checktime = 5 # Save the model every `checktime` epochs. Set to 0 for no checkpoints.
savepath = "runs/" # results path
rank = 1
ensemble_size = 4
end

function train(; kws...)
args = Args(; kws...)
args.seed > 0 && Random.seed!(args.seed)
use_cuda = args.use_cuda && CUDA.functional()

if use_cuda
device = gpu
@info "Training on GPU"
else
device = cpu
@info "Training on CPU"
end

## DATA
train_loader, test_loader = get_data(args)
@info "Dataset MNIST: $(train_loader.nobs) train and $(test_loader.nobs) test examples"

## MODEL AND OPTIMIZER
model = LeNet5(args) |> device
@info "LeNet5 model: $(num_params(model)) trainable params"

ps = Flux.params(model)

opt = ADAM(args.η)
if args.λ > 0 # add weight decay, equivalent to L2 regularization
opt = Optimiser(WeightDecay(args.λ), opt)
end

function report(epoch)
# @info "Train Metrics"
# eval_loss_accuracy(args, train_loader, model, device)
@info "Test metrics"
eval_loss_accuracy(args, test_loader, model, device)
end

## TRAINING
@info "Start Training"
report(0)
for epoch = 1:args.epochs
@showprogress for (x, y) in train_loader
# Make copies of batches for ensembles
x = repeat(x, 1, 1, 1, args.ensemble_size)
y = repeat(y, 1, args.ensemble_size)
x, y = x |> device, y |> device
gs = Flux.gradient(ps) do
ŷ = model(x)
loss(ŷ, y)
end

Flux.Optimise.update!(opt, ps, gs)
end

## Printing and logging
epoch % args.infotime == 0 && report(epoch)
end
end

train()
7 changes: 7 additions & 0 deletions src/DeepUncertainty.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
module DeepUncertainty

using Flux
using Random
using Flux: @functor, glorot_normal, create_bias

# Export layers
export MCLayer, MCDense, MCConv
export DenseBatchEnsemble, ConvBatchEnsemble
export mean_loglikelihood, brier_score, ExpectedCalibrationError, prediction_metrics

include("metrics.jl")
include("layers/mclayers.jl")
include("layers/BatchEnsemble/dense.jl")
include("layers/BatchEnsemble/conv.jl")

end
145 changes: 145 additions & 0 deletions src/layers/BatchEnsemble/conv.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
"""
ConvBatchEnsemble(filter, in => out, rank,
ensemble_size, σ = identity;
stride = 1, pad = 0, dilation = 1,
groups = 1, [bias, weight, init])
ConvBatchEnsemble(layer, alpha, gamma, ensemble_bias, ensemble_act, rank)

Creates a conv BatchEnsemble layer. Batch ensemble is a memory efficient alternative
for deep ensembles. In deep ensembles, if the ensemble size is N, N different models
are trained, making the time and memory complexity O(N * complexity of one network).
BatchEnsemble generates weight matrices for each member in the ensemble using a
couple of rank 1 vectors R (alpha), S (gamma), RS' and multiplying the result with
weight matrix W element wise. We also call R and S as fast weights.

Reference - https://arxiv.org/abs/2002.06715

During both training and testing, we repeat the samples along the batch dimension
N times, where N is the ensemble_size. For example, if each mini batch has 10 samples
and our ensemble size is 4, then the actual input to the layer has 40 samples.
The output of the layer has 40 samples as well, and each 10 samples can be considered
as the output of an esnemble member.

# Fields
- `layer`: The dense layer which transforms the pertubed input to output
- `alpha`: The first Fast weight of size (in_dim, ensemble_size)
- `gamma`: The second Fast weight of size (out_dim, ensemble_size)
- `ensemble_bias`: Bias added to the ensemble output, separate from dense layer bias
- `ensemble_act`: The activation function to be applied on ensemble output
- `rank`: Rank of the fast weights (rank > 1 doesn't work on GPU for now)

# Arguments
- `filter::NTuple{N,Integer}`: Kernel dimensions, eg, (5, 5)
- `ch::Pair{<:Integer,<:Integer}`: Input channels => output channels
- `rank::Integer`: Rank of the fast weights
- `ensemble_size::Integer`: Number of models in the ensemble
- `σ::F=identity`: Activation of the dense layer, defaults to identity
- `init=glorot_normal`: Initialization function, defaults to glorot_normal
- `alpha_init=glorot_normal`: Initialization function for the alpha fast weight,
defaults to glorot_normal
- `gamma_init=glorot_normal`: Initialization function for the gamma fast weight,
defaults to glorot_normal
- `bias::Bool=true`: Toggle the usage of bias in the dense layer
- `ensemble_bias::Bool=true`: Toggle the usage of ensemble bias
- `ensemble_act::F=identity`: Activation function for enseble outputs
"""
struct ConvBatchEnsemble{L,F,M,B}
layer::L
alpha::M
gamma::M
ensemble_bias::B
ensemble_act::F
rank::Any
DwaraknathT marked this conversation as resolved.
Show resolved Hide resolved
function ConvBatchEnsemble(
layer::L,
alpha::M,
gamma::M,
ensemble_bias = true,
ensemble_act::F = identity,
rank = 1,
) where {M,F,L}
ensemble_bias = create_bias(gamma, ensemble_bias, size(gamma)[1], size(gamma)[2])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you test it with FluxML/Flux.jl#1402

new{typeof(layer),F,M,typeof(ensemble_bias)}(
layer,
alpha,
gamma,
ensemble_bias,
ensemble_act,
rank,
)
end
end

function ConvBatchEnsemble(
k::NTuple{N,Integer},
ch::Pair{<:Integer,<:Integer},
rank::Integer,
ensemble_size::Integer,
σ = identity;
init = glorot_normal,
alpha_init = glorot_normal,
gamma_init = glorot_normal,
stride = 1,
pad = 0,
dilation = 1,
groups = 1,
bias = true,
ensemble_bias = true,
ensemble_act = identity,
Comment on lines +73 to +88
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comment as last time about keeping things simple and general.

Maybe it makes sense to have a constructor that takes in a Conv layer directly?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, it does. I guess we can have both as well.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We actually need the input/output dimensions to create the alpha/gamma matrices. Might as well keep them in the signature, or we'll have to infer them from the conv layer's struct and that might change anytime in flux source ?

) where {N}
layer = Flux.Conv(
k,
ch,
σ;
stride = stride,
pad = pad,
dilation = dilation,
init = init,
groups = groups,
bias = bias,
)
in_dim = ch[1]
out_dim = ch[2]
if rank >= 1
alpha_shape = (in_dim, ensemble_size)
gamma_shape = (out_dim, ensemble_size)
else
error("Rank must be >= 1.")
end
alpha = alpha_init(alpha_shape)
gamma = gamma_init(gamma_shape)

return ConvBatchEnsemble(layer, alpha, gamma, ensemble_bias, ensemble_act, rank)
end

@functor ConvBatchEnsemble

function (be::ConvBatchEnsemble)(x)
# Conv Batch Ensemble params
layer = be.layer
alpha = be.alpha
gamma = be.gamma
e_b = be.ensemble_bias
e_σ = be.ensemble_act

batch_size = size(x)[end]
in_size = size(alpha)[1]
out_size = size(gamma)[1]
ensemble_size = size(alpha)[2]
samples_per_model = batch_size ÷ ensemble_size

# Alpha, gamma shapes - [units, ensembles, rank]
e_b = repeat(e_b, samples_per_model)
alpha = repeat(alpha, samples_per_model)
gamma = repeat(gamma, samples_per_model)
# Reshape alpha, gamma to [units, batch_size, rank]
e_b = reshape(e_b, (1, 1, out_size, batch_size))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Size of the bias seems relevant here.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How do we know that the shape of the bias allocated can fit into the container its expected to be in

alpha = reshape(alpha, (1, 1, in_size, batch_size))
gamma = reshape(gamma, (1, 1, out_size, batch_size))

perturbed_x = x .* alpha
output = layer(perturbed_x) .* gamma
output = e_σ.(output .+ e_b)

return output
end
Loading