ICLR 2016 has a really interesting paper, "All You Need is a Good Init". In this post I will try to reproduce the authors' results, and I will do that in Torch.
General idea
So, what is the motivation? Batch Normalization helps, but it slows down training (the authors claim by about 30%). Can we do better without that overhead?
Yes: we can spend a little more time on a smart weight initialization, and in return get faster training, the ability to use bigger learning rates, and better results.
Pre-initialize the network with orthonormal matrices as in Saxe et al. (2014), then:

```
for each layer L do
    while |Var(B_L) - 1.0| >= Tol_var and T_i < T_max do
        do Forward pass with a mini-batch
        calculate Var(B_L)
        W_L = W_L / sqrt(Var(B_L))
    end while
end for
```
What I found important for the implementation:

- L is a convolutional or fully connected layer.
- For each layer we use a new mini-batch.
- The variance Var(B_L) is computed over the whole output blob for the mini-batch (at first I thought it was computed feature-wise).
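Before looking at the actual usage, here is what the loop above boils down to in Torch. This is only a minimal sketch under my own assumptions (the tolerance and trial-count defaults, and a `get_batch` helper that returns a fresh mini-batch of inputs on every call), not the exact code from the repository:

```lua
require 'nn'

-- Minimal LSUV sketch (an illustration, not the repository's exact code).
-- `model` is an nn.Sequential; `get_batch()` returns a new mini-batch each call.
local function lsuvSketch(model, get_batch, tolVar, tMax)
  tolVar = tolVar or 0.1   -- assumed variance tolerance
  tMax   = tMax or 10      -- assumed maximum number of trials per layer
  for _, layer in ipairs(model.modules) do
    local kind = torch.type(layer)
    -- only convolutional and fully connected layers are rescaled
    if kind == 'nn.SpatialConvolutionMM' or kind == 'nn.Linear' then
      for t = 1, tMax do
        model:forward(get_batch())       -- forward pass with a *new* mini-batch
        local v = layer.output:var()     -- variance over the whole output blob
        if math.abs(v - 1.0) < tolVar then break end
        layer.weight:div(math.sqrt(v))   -- W_L = W_L / sqrt(Var(B_L))
      end
    end
  end
end
```

The only state it touches is `layer.weight`, so it runs once before training starts and adds no overhead afterwards.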
Torch implementation

Usage (from MNIST example):

```lua
require 'nn'
local nninit = require 'nninit'

-- add nninit.orthogonal to all convolutional and fully connected layers
model:add(nn.SpatialConvolutionMM(1, 32, 5, 5):init('weight', nninit.orthogonal, {gain = 'relu'}))
model:add(nn.ReLU())
...
model:add(nn.Linear(200, #classes):init('weight', nninit.orthogonal, {gain = 'relu'}))

-- do LSUV after the orthogonal init above
if opt.lsuv then
  model:lsuvInit(get_batch)
end
```
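For completeness, `get_batch` only has to hand the initializer a fresh mini-batch of inputs on each call. A hypothetical helper could look like this (`makeGetBatch`, `trainInputs` and `batchSize` are my names, not the repository's):

```lua
-- hypothetical helper: returns a closure that samples a random mini-batch
-- of inputs from a training tensor (names are assumptions, not the repo's code)
local function makeGetBatch(trainInputs, batchSize)
  return function()
    local perm = torch.randperm(trainInputs:size(1)):long()
    return trainInputs:index(1, perm[{{1, batchSize}}])
  end
end

local get_batch = makeGetBatch(trainInputs, 128)
model:lsuvInit(get_batch)  -- as in the snippet above
```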
MNIST example
I used the following bash command to run the experiment (-f selects the full MNIST dataset: 60,000 images for training and 10,000 for testing; lr below stands for the learning rate value):

```bash
th mnist-example.lua --lsuv -r lr
```
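For context, the `opt.lsuv` flag checked in the usage snippet and the `-r` learning-rate option come from command-line parsing along these lines; this is a hedged sketch with assumed defaults and help strings, not necessarily the script's exact options:

```lua
-- hypothetical option parsing with torch.CmdLine; flag names follow the
-- command above, but the defaults and help strings are assumptions
local cmd = torch.CmdLine()
cmd:option('-f', false, 'use the full 60,000/10,000 MNIST dataset')
cmd:option('--lsuv', false, 'run LSUV after the orthogonal initialization')
cmd:option('-r', 0.1, 'learning rate')
local opt = cmd:parse(arg or {})
-- opt.lsuv and opt.r are then used as in the usage snippet above
```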
| epoch | with LSUV (lr=0.1) | with LSUV (lr=0.05) | without LSUV (lr=0.001) | with LSUV (lr=0.001) |
|-------|--------------------|---------------------|-------------------------|----------------------|
| 1     | 97.77%             | 96.69%              | 83.39%                  | 78.28%               |
| 2     | 98.45%             | 97.94%              | 89.25%                  | 87.75%               |
| 3     | 98.63%             | 98.37%              | 91.23%                  | 91.19%               |
| 4     | 98.74%             | 98.57%              | 92.46%                  | 92.82%               |
| 5     | 98.88%             | 98.72%              | 93.23%                  | 93.81%               |
| 6     | 98.97%             | 98.75%              | 93.88%                  | 94.53%               |
| 7     | 99.03%             | 98.86%              | 94.44%                  | 95.06%               |
| 8     | 99.01%             | 98.86%              | 94.81%                  | 95.4%                |
| 9     | 99.01%             | 98.9%               | 95.03%                  | 95.87%               |
| 10    | 98.96%             | 98.91%              | 95.29%                  | 96.15%               |
I did not wait for 100 epochs as the authors of the original paper did. At first I thought these results show that we can use bigger learning rates with LSUV, but then I realised that the no-LSUV MNIST case does not use Batch Normalization either, so that conclusion does not follow. The MNIST results simply show that training works and the accuracy rates are pretty comparable. Let's have a look at the CIFAR-10 experiment.
CIFAR example
I did not push for the best accuracy we can achieve, I just checked whether training is comparable in general, and it is. Test-set accuracy is shown in the figure.
The post has moved to my new page.
References

Mishkin, D. and Matas, J. "All you need is a good init." ICLR 2016.
Saxe, A. M., McClelland, J. L. and Ganguli, S. "Exact solutions to the nonlinear dynamics of learning in deep linear neural networks." ICLR 2014.
Thanks to @ikostrikov for the debugging and help.
If you want to ask me a question, you can find me here.