Added Dense and Conv BatchEnsemble layers along with unit tests and example on MNIST classification using LeNet5 (#4)

* Added Dense and Conv BatchEnsemble layers along with unit tests and example on MNIST classification using LeNet5

* Merged dense batchensemble forward passes for rank>1 and rank=1; Fixed conv batchensemble unit test

* Changes:
1. Reduced imports and moved them to the main file
2. Renamed the test files
3. Added GPU tests for the layers -- for now, basic forward passes etc.
DwaraknathT authored Sep 7, 2021
1 parent b11fbea commit ed3d16c
Showing 11 changed files with 728 additions and 18 deletions.
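The new layers expect their input to be repeated along the batch dimension, one copy per ensemble member, which is exactly what the MNIST example below does. A minimal sketch of that convention, assuming the DenseBatchEnsemble constructor added in this commit; the feature sizes and the expected output shape are illustrative, not taken from the diff:

using Flux
using DeepUncertainty

rank, ensemble_size = 1, 4
layer = DenseBatchEnsemble(784, 10, rank, ensemble_size, relu)  # (in, out, rank, ensemble_size, act)

x = rand(Float32, 784, 8)          # a mini batch of 8 samples
x = repeat(x, 1, ensemble_size)    # 32 samples: each ensemble member sees the same 8
y = layer(x)                       # expected size (10, 32); each block of 8 columns
                                   # is one ensemble member's output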
1 change: 1 addition & 0 deletions Project.toml
@@ -10,6 +10,7 @@ Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReliabilityDiagrams = "e5f51471-6270-49e4-a15a-f1cfbff4f856"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
julia = "1"
207 changes: 207 additions & 0 deletions examples/batchensemble.jl
@@ -0,0 +1,207 @@
## Classification of MNIST dataset
## with the convolutional neural network known as LeNet5.
## This script also combines various
## packages from the Julia ecosystem with Flux.
using Flux
using Flux.Data: DataLoader
using Flux.Optimise: Optimiser, WeightDecay
using Flux: onehotbatch, onecold, glorot_normal, label_smoothing
using Flux.Losses: logitcrossentropy
using Statistics, Random
using Logging: with_logger
using TensorBoardLogger: TBLogger, tb_overwrite, set_step!, set_step_increment!
using ProgressMeter: @showprogress
import MLDatasets
import BSON
using CUDA
using Formatting

using DeepUncertainty

# LeNet5 "constructor".
# The model can be adapted to any image size
# and any number of output classes.
function LeNet5(args; imgsize = (28, 28, 1), nclasses = 10)
out_conv_size = (imgsize[1] ÷ 4 - 3, imgsize[2] ÷ 4 - 3, 16)

return Chain(
ConvBatchEnsemble((5, 5), imgsize[end] => 6, args.rank, args.ensemble_size, relu),
MaxPool((2, 2)),
ConvBatchEnsemble((5, 5), 6 => 16, args.rank, args.ensemble_size, relu),
MaxPool((2, 2)),
flatten,
DenseBatchEnsemble(prod(out_conv_size), 120, args.rank, args.ensemble_size, relu),
DenseBatchEnsemble(120, 84, args.rank, args.ensemble_size, relu),
DenseBatchEnsemble(84, nclasses, args.rank, args.ensemble_size),
)
end

function get_data(args)
xtrain, ytrain = MLDatasets.MNIST.traindata(Float32)
xtest, ytest = MLDatasets.MNIST.testdata(Float32)

xtrain = reshape(xtrain, 28, 28, 1, :)
xtest = reshape(xtest, 28, 28, 1, :)

ytrain, ytest = onehotbatch(ytrain, 0:9), onehotbatch(ytest, 0:9)

train_loader = DataLoader(
(xtrain, ytrain),
batchsize = args.batchsize,
shuffle = true,
partial = false,
)
test_loader = DataLoader((xtest, ytest), batchsize = args.batchsize, partial = false)

return train_loader, test_loader
end

loss(ŷ, y) = logitcrossentropy(ŷ, y)

function accuracy(preds, labels)
acc = sum(onecold(preds |> cpu) .== onecold(labels |> cpu))
return acc
end

function eval_loss_accuracy(args, loader, model, device)
l = [0.0f0 for x = 1:args.ensemble_size]
acc = [0 for x = 1:args.ensemble_size]
ece_list = [0.0f0 for x = 1:args.ensemble_size]
ntot = 0
mean_l = 0
mean_acc = 0
mean_ece = 0
for (x, y) in loader
x = repeat(x, 1, 1, 1, args.ensemble_size)
x, y = x |> device, y |> device
# Perform the forward pass
ŷ = model(x)
ŷ = softmax(ŷ, dims = 1)
# Reshape the predictions into [classes, batch_size, ensemble_size]
reshaped_ŷ = reshape(ŷ, size(ŷ)[1], args.batchsize, args.ensemble_size)
# Loop through each model's predictions
for ensemble = 1:args.ensemble_size
model_predictions = reshaped_ŷ[:, :, ensemble]
# Calculate individual loss
l[ensemble] += loss(model_predictions, y) * size(model_predictions)[end]
acc[ensemble] += accuracy(model_predictions, y)
ece_list[ensemble] +=
ExpectedCalibrationError(model_predictions |> cpu, onecold(y |> cpu)) *
args.batchsize
end
# Get the mean predictions
mean_predictions = mean(reshaped_ŷ, dims = ndims(reshaped_ŷ))
mean_predictions = dropdims(mean_predictions, dims = ndims(mean_predictions))
mean_l += loss(mean_predictions, y) * size(mean_predictions)[end]
mean_acc += accuracy(mean_predictions, y)
mean_ece +=
ExpectedCalibrationError(mean_predictions |> cpu, onecold(y |> cpu)) *
args.batchsize
ntot += size(mean_predictions)[end]
end
# Normalize the loss
losses = [loss / ntot |> round4 for loss in l]
acc = [a / ntot * 100 |> round4 for a in acc]
ece_list = [x / ntot |> round4 for x in ece_list]
# Calculate mean loss
mean_l = mean_l / ntot |> round4
mean_acc = mean_acc / ntot * 100 |> round4
mean_ece = mean_ece / ntot |> round4

# Print the per-ensemble-member loss and accuracy
for ensemble = 1:args.ensemble_size
@info (format(
"Model {} Loss: {} Accuracy: {} ECE: {}",
ensemble,
losses[ensemble],
acc[ensemble],
ece_list[ensemble],
))
end
@info (format(
"Mean Loss: {} Mean Accuracy: {} Mean ECE: {}",
mean_l,
mean_acc,
mean_ece,
))
@info "==========================================================="
return nothing
end

## utility functions
num_params(model) = sum(length, Flux.params(model))
round4(x) = round(x, digits = 4)

# arguments for the `train` function
Base.@kwdef mutable struct Args
η = 3e-4 # learning rate
λ = 0 # L2 regularizer param, implemented as weight decay
batchsize = 32 # batch size
epochs = 10 # number of epochs
seed = 0 # set seed > 0 for reproducibility
use_cuda = true # if true use cuda (if available)
infotime = 1 # report every `infotime` epochs
checktime = 5 # Save the model every `checktime` epochs. Set to 0 for no checkpoints.
savepath = "runs/" # results path
rank = 1
ensemble_size = 4
end

function train(; kws...)
args = Args(; kws...)
args.seed > 0 && Random.seed!(args.seed)
use_cuda = args.use_cuda && CUDA.functional()

if use_cuda
device = gpu
@info "Training on GPU"
else
device = cpu
@info "Training on CPU"
end

## DATA
train_loader, test_loader = get_data(args)
@info "Dataset MNIST: $(train_loader.nobs) train and $(test_loader.nobs) test examples"

## MODEL AND OPTIMIZER
model = LeNet5(args) |> device
@info "LeNet5 model: $(num_params(model)) trainable params"

ps = Flux.params(model)

opt = ADAM(args.η)
if args.λ > 0 # add weight decay, equivalent to L2 regularization
opt = Optimiser(WeightDecay(args.λ), opt)
end

function report(epoch)
# @info "Train Metrics"
# eval_loss_accuracy(args, train_loader, model, device)
@info "Test metrics"
eval_loss_accuracy(args, test_loader, model, device)
end

## TRAINING
@info "Start Training"
report(0)
for epoch = 1:args.epochs
@showprogress for (x, y) in train_loader
# Make copies of batches for ensembles
x = repeat(x, 1, 1, 1, args.ensemble_size)
y = repeat(y, 1, args.ensemble_size)
x, y = x |> device, y |> device
gs = Flux.gradient(ps) do
ŷ = model(x)
loss(ŷ, y)
end

Flux.Optimise.update!(opt, ps, gs)
end

## Printing and logging
epoch % args.infotime == 0 && report(epoch)
end
end

train()
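Since `train` forwards its keyword arguments to `Args`, any of the hyperparameters above can be overridden at the call site. A small usage sketch; the specific values are illustrative, not from the commit:

# Any field of Args can be passed as a keyword to train
train(epochs = 2, batchsize = 64, ensemble_size = 4, use_cuda = false)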
7 changes: 7 additions & 0 deletions src/DeepUncertainty.jl
@@ -1,10 +1,17 @@
module DeepUncertainty

using Flux
using Random
using Flux: @functor, glorot_normal, create_bias

# Export layers
export MCLayer, MCDense, MCConv
export DenseBatchEnsemble, ConvBatchEnsemble
export mean_loglikelihood, brier_score, ExpectedCalibrationError, prediction_metrics

include("metrics.jl")
include("layers/mclayers.jl")
include("layers/BatchEnsemble/dense.jl")
include("layers/BatchEnsemble/conv.jl")

end
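The MNIST example above evaluates calibration by calling `ExpectedCalibrationError` on softmax probabilities and the integer labels returned by `onecold`. A minimal sketch mirroring that call; the implementation lives in metrics.jl, which is not shown in this diff, and the shapes and label range are illustrative:

using Flux: softmax
using DeepUncertainty

probs = softmax(rand(Float32, 10, 32), dims = 1)  # (classes, batch) predicted probabilities
labels = rand(0:9, 32)                            # labels as onecold returns for the 0:9 MNIST label set
ece = ExpectedCalibrationError(probs, labels)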
145 changes: 145 additions & 0 deletions src/layers/BatchEnsemble/conv.jl
@@ -0,0 +1,145 @@
"""
ConvBatchEnsemble(filter, in => out, rank,
ensemble_size, σ = identity;
stride = 1, pad = 0, dilation = 1,
groups = 1, [bias, init, alpha_init, gamma_init])
ConvBatchEnsemble(layer, alpha, gamma, ensemble_bias, ensemble_act, rank)
Creates a conv BatchEnsemble layer. BatchEnsemble is a memory-efficient alternative
to deep ensembles. In a deep ensemble of size N, N separate models are trained, making
the time and memory complexity O(N * complexity of one network). BatchEnsemble instead
generates the weight matrix of each ensemble member from a pair of rank-1 vectors R (alpha)
and S (gamma): the outer product RS' is multiplied element-wise with the shared weight
matrix W. R and S are also called fast weights.
Reference - https://arxiv.org/abs/2002.06715
During both training and testing, the samples are repeated along the batch dimension
N times, where N is the ensemble_size. For example, if each mini batch has 10 samples
and the ensemble size is 4, the actual input to the layer has 40 samples. The output
of the layer also has 40 samples, and each block of 10 samples can be considered the
output of one ensemble member.
# Fields
- `layer`: The wrapped conv layer which transforms the perturbed input into the output
- `alpha`: The first fast weight of size (in_dim, ensemble_size)
- `gamma`: The second fast weight of size (out_dim, ensemble_size)
- `ensemble_bias`: Bias added to the ensemble output, separate from the conv layer's bias
- `ensemble_act`: The activation function applied to the ensemble output
- `rank`: Rank of the fast weights (rank > 1 doesn't work on GPU for now)
# Arguments
- `filter::NTuple{N,Integer}`: Kernel dimensions, eg, (5, 5)
- `ch::Pair{<:Integer,<:Integer}`: Input channels => output channels
- `rank::Integer`: Rank of the fast weights
- `ensemble_size::Integer`: Number of models in the ensemble
- `σ::F=identity`: Activation of the conv layer, defaults to identity
- `init=glorot_normal`: Initialization function, defaults to glorot_normal
- `alpha_init=glorot_normal`: Initialization function for the alpha fast weight,
defaults to glorot_normal
- `gamma_init=glorot_normal`: Initialization function for the gamma fast weight,
defaults to glorot_normal
- `bias::Bool=true`: Toggle the usage of bias in the dense layer
- `ensemble_bias::Bool=true`: Toggle the usage of ensemble bias
- `ensemble_act::F=identity`: Activation function for ensemble outputs
"""
struct ConvBatchEnsemble{L,F,M,B}
layer::L
alpha::M
gamma::M
ensemble_bias::B
ensemble_act::F
rank::Any
function ConvBatchEnsemble(
layer::L,
alpha::M,
gamma::M,
ensemble_bias = true,
ensemble_act::F = identity,
rank = 1,
) where {M,F,L}
ensemble_bias = create_bias(gamma, ensemble_bias, size(gamma)[1], size(gamma)[2])
new{typeof(layer),F,M,typeof(ensemble_bias)}(
layer,
alpha,
gamma,
ensemble_bias,
ensemble_act,
rank,
)
end
end

function ConvBatchEnsemble(
k::NTuple{N,Integer},
ch::Pair{<:Integer,<:Integer},
rank::Integer,
ensemble_size::Integer,
σ = identity;
init = glorot_normal,
alpha_init = glorot_normal,
gamma_init = glorot_normal,
stride = 1,
pad = 0,
dilation = 1,
groups = 1,
bias = true,
ensemble_bias = true,
ensemble_act = identity,
) where {N}
layer = Flux.Conv(
k,
ch,
σ;
stride = stride,
pad = pad,
dilation = dilation,
init = init,
groups = groups,
bias = bias,
)
in_dim = ch[1]
out_dim = ch[2]
if rank >= 1
alpha_shape = (in_dim, ensemble_size)
gamma_shape = (out_dim, ensemble_size)
else
error("Rank must be >= 1.")
end
alpha = alpha_init(alpha_shape)
gamma = gamma_init(gamma_shape)

return ConvBatchEnsemble(layer, alpha, gamma, ensemble_bias, ensemble_act, rank)
end

@functor ConvBatchEnsemble

function (be::ConvBatchEnsemble)(x)
# Conv Batch Ensemble params
layer = be.layer
alpha = be.alpha
gamma = be.gamma
e_b = be.ensemble_bias
e_σ = be.ensemble_act

batch_size = size(x)[end]
in_size = size(alpha)[1]
out_size = size(gamma)[1]
ensemble_size = size(alpha)[2]
samples_per_model = batch_size ÷ ensemble_size

# Alpha, gamma shapes - [units, ensemble_size]
e_b = repeat(e_b, samples_per_model)
alpha = repeat(alpha, samples_per_model)
gamma = repeat(gamma, samples_per_model)
# Reshape alpha, gamma to [1, 1, units, batch_size] so they broadcast over the spatial dims
e_b = reshape(e_b, (1, 1, out_size, batch_size))
alpha = reshape(alpha, (1, 1, in_size, batch_size))
gamma = reshape(gamma, (1, 1, out_size, batch_size))

perturbed_x = x .* alpha
output = layer(perturbed_x) .* gamma
output = e_σ.(output .+ e_b)

return output
end
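A minimal forward-pass sketch for the layer above, following the batch-repetition convention described in its docstring; the input sizes are illustrative, and the output spatial size assumes a 5x5 kernel with no padding:

using Flux
using DeepUncertainty

layer = ConvBatchEnsemble((5, 5), 1 => 6, 1, 4, relu)  # rank 1, ensemble of 4

x = rand(Float32, 28, 28, 1, 8)   # 8 single-channel 28x28 samples
x = repeat(x, 1, 1, 1, 4)         # 32 samples: each ensemble member sees the same 8
y = layer(x)                      # expected size (24, 24, 6, 32); each block of 8 along
                                  # the batch dimension is one member's output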