All You Need is a Good Init #3

yobibyte opened this issue Sep 5, 2016 · 0 comments

The post has moved to my new page

ICLR 2016 has a really interesting paper, "All You Need is a Good Init". In this post I try to reproduce the authors' results in Torch.

General idea

So, what is the motivation? Batch Normalization helps, but it slows down training (by about 30%, according to the authors). Can we do better without the additional overhead?

Yes: we can spend a little more time on a smarter weight initialization and, in return, get faster training, the ability to use larger learning rates, and better results. The paper's procedure, layer-sequential unit-variance (LSUV) initialization, is:

Pre-initialize network with orthonormal matrices as in Saxe et al. (2014)
for each layer L do
    while |Var(B_L) - 1.0| >= Tol_var and (T_i < T_max) do
        do Forward pass with a mini-batch
        calculate Var(B_L)
        W_L = W_L / sqrt(Var(B_L))
    end while
end for
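
The rescaling step can be motivated with a short derivation. This is a sketch under two simplifying assumptions that the post does not state: the layer has no bias, so its output B_L is linear in W_L, and the same minibatch is reused after the rescale. Writing v for the measured variance:

```latex
v = \operatorname{Var}(B_L), \qquad
W_L' = \frac{W_L}{\sqrt{v}}
\;\Longrightarrow\;
B_L' = \frac{B_L}{\sqrt{v}}, \qquad
\operatorname{Var}(B_L') = \frac{\operatorname{Var}(B_L)}{v} = 1 .
```

With a nonzero bias or a fresh minibatch on every iteration, the equality holds only approximately, which is presumably why the loop repeats until the tolerance Tol_var is met or T_max iterations have been used.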

What I found important for the implementation:

  • L is a convolutional or fully connected layer.
  • A new minibatch is used for each layer.
  • Var(B_L) is the variance over the whole output of layer L for the minibatch, i.e. a single scalar (at first I thought it was computed feature-wise).
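
The difference between the two readings of Var(B_L) in the last point can be illustrated in Torch. This is a minimal sketch: the tensor `out` is a made-up stand-in for a layer output (64 samples, 10 features), not data from the experiments.

```lua
require 'torch'

torch.manualSeed(1)

-- Hypothetical layer output: 64 samples x 10 features.
local out = torch.randn(64, 10)
-- Make the first feature much wider than the others.
out:select(2, 1):mul(5)

-- Whole-output variance: one number for the entire minibatch.
-- This is the quantity LSUV uses.
local whole = torch.var(out)

-- Feature-wise variance: one value per feature (a 1 x 10 tensor).
local per_feature = torch.var(out, 1)

print(whole)
print(per_feature)
```

Dividing the weights by the square root of the single number `whole` rescales the layer as a whole; it does not equalize the individual features.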

Torch implementation

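require 'nn'

-- LSUV init: rescale each weighted layer until the variance of its output is close to 1.
-- get_batch: function returning a minibatch
-- tol_var:   tolerance on |variance - 1| (default 0.1)
-- t_max:     maximum number of rescales per layer (default 10)
nn.Sequential.lsuvInit = function (self, get_batch, tol_var, t_max)
   local tol_var = tol_var or 0.1
   local t_max   = t_max or 10

   for _,m in ipairs(self:listModules()) do
      -- only modules with weights (e.g. convolutional and fully connected layers)
      if m.weight ~= nil then
         local t_i = 1
         while true do
            -- fresh minibatch, forward pass through the whole model
            local input = get_batch()
            self:forward(input)
            -- variance over the whole layer output (a single number)
            local var = torch.var(m.output)
            -- var is a plain Lua number, so math.abs is enough
            if math.abs(var - 1.0) < tol_var or t_i > t_max then
               break
            end
            m.weight:div(math.sqrt(var))
            t_i = t_i + 1
         end
      end
   end
end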
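
One way to check the effect of such an initialization is to print the output variance of every weighted layer after a forward pass. The helper below is a sketch of my own, not part of the original code: `report_variances` is a hypothetical name, and the toy model and random input exist only to exercise it.

```lua
require 'nn'

-- Hypothetical helper: run one forward pass and collect the output
-- variance of every module that has weights.
local function report_variances(model, input)
   model:forward(input)
   local result = {}
   for i, m in ipairs(model:listModules()) do
      if m.weight ~= nil then
         result[#result + 1] = torch.var(m.output)
         print(i, torch.type(m), result[#result])
      end
   end
   return result
end

-- Toy model, used only to exercise the helper.
local model = nn.Sequential()
model:add(nn.Linear(20, 30))
model:add(nn.ReLU())
model:add(nn.Linear(30, 10))

local variances = report_variances(model, torch.randn(64, 20))
```

Calling it before and after the initialization shows whether the variances moved towards 1. Note that an in-place activation such as nn.ReLU(true) overwrites the previous module's output tensor, so in that case the measured variance would be the post-activation one.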

Usage (from MNIST example):

require 'nn'
local nninit = require 'nninit'

-- add nninit.orthogonal to all convolutional and fully connected layers
model:add(nn.SpatialConvolutionMM(1, 32, 5, 5):init('weight', nninit.orthogonal, {gain = 'relu'}))
model:add(nn.ReLU())
...
model:add(nn.Linear(200, #classes):init('weight', nninit.orthogonal, {gain = 'relu'}))

-- run LSUV after the orthogonal init above
if opt.lsuv then
  model:lsuvInit(get_batch)
end
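
The usage above takes a `get_batch` function that is not shown here. A minimal sketch of what it could look like follows, assuming the training data is a single tensor whose first dimension indexes samples. `make_get_batch`, the stand-in data, and the batch size are all hypothetical and may differ from the MNIST example.

```lua
require 'torch'

-- Hypothetical helper: returns a closure that yields a fresh random
-- minibatch from `data` on every call.
local function make_get_batch(data, batch_size)
   return function()
      -- random permutation of sample indices, truncated to batch_size
      local perm = torch.randperm(data:size(1))
      local idx = perm:narrow(1, 1, batch_size):long()
      -- gather the selected samples along the first dimension
      return data:index(1, idx)
   end
end

-- Stand-in data of arbitrary shape, used only for illustration.
local data = torch.randn(100, 1, 32, 32)
local get_batch = make_get_batch(data, 16)
local batch = get_batch()
print(batch:size())
```

Because the closure draws new indices on each call, every iteration of the initialization loop sees a different minibatch.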

MNIST example

I ran the experiment with the following command (add -f to use the full MNIST dataset: 60,000 training and 10,000 test images):

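th mnist-example.lua --lsuv -r lr

| Epoch | With LSUV (lr=0.1) | With LSUV (lr=0.05) | Without LSUV (lr=0.001) | With LSUV (lr=0.001) |
|-------|--------------------|---------------------|-------------------------|----------------------|
| 1     | 97.77%             | 96.69%              | 83.39%                  | 78.28%               |
| 2     | 98.45%             | 97.94%              | 89.25%                  | 87.75%               |
| 3     | 98.63%             | 98.37%              | 91.23%                  | 91.19%               |
| 4     | 98.74%             | 98.57%              | 92.46%                  | 92.82%               |
| 5     | 98.88%             | 98.72%              | 93.23%                  | 93.81%               |
| 6     | 98.97%             | 98.75%              | 93.88%                  | 94.53%               |
| 7     | 99.03%             | 98.86%              | 94.44%                  | 95.06%               |
| 8     | 99.01%             | 98.86%              | 94.81%                  | 95.4%                |
| 9     | 99.01%             | 98.9%               | 95.03%                  | 95.87%               |
| 10    | 98.96%             | 98.91%              | 95.29%                  | 96.15%               |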

I did not run 100 epochs as the authors of the original paper did. At first I thought these results showed that LSUV lets us use larger learning rates, but then I realised that the MNIST run without LSUV does not use BN, so that conclusion does not hold. The MNIST results only show that training works and that the accuracies are comparable. Let's look at the CIFAR-10 experiment.

CIFAR example

I did not look for the best achievable accuracy; I only checked whether training is comparable in general, and it is. Test-set accuracy is shown in the figure.

References

Thanks to @ikostrikov for the debugging and help.

If you want to ask me a question, you can find me here
