Merge pull request torch#283 from torch/batchnormfix
batchnorm is clonable by adding the running estimates to constructor
soumith committed Jun 3, 2015
2 parents 3bbed8d + 83a3815 commit fdd6659
Showing 5 changed files with 40 additions and 36 deletions.
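
The gist of the change: the running mean/std estimates are now allocated in the constructor (and `affine` becomes an explicit boolean argument), so a batch-normalization module can be cloned or serialized straight after construction. A minimal sketch of the new usage, assuming the `nn` package is loaded and using an example feature count `m` that is not part of the diff:

```lua
require 'nn'

local m = 16  -- example feature count, chosen for illustration

-- new signature: nn.BatchNormalization(nOutput [, eps] [, momentum] [, affine])
local bn = nn.BatchNormalization(m, 1e-5, 0.1)

-- running estimates exist right after construction, so cloning works immediately
local bnClone = bn:clone()
print(bnClone.running_mean:size())  -- LongStorage of size 1, value m
print(bnClone.running_std:size())

-- the learnable affine transform is now disabled via an explicit flag
-- (previously this was requested by passing nOutput = 0)
local bnNoAffine = nn.BatchNormalization(m, nil, nil, false)
```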
26 changes: 12 additions & 14 deletions BatchNormalization.lua
@@ -29,18 +29,24 @@
 ]]--
 local BN,parent = torch.class('nn.BatchNormalization', 'nn.Module')
 
-function BN:__init(nOutput, eps, momentum)
+function BN:__init(nOutput, eps, momentum, affine)
    parent.__init(self)
    assert(nOutput and type(nOutput) == 'number',
-          'Missing argument #1: dimensionality of input. ' ..
-          'Give 0 for no affine transform')
+          'Missing argument #1: dimensionality of input. ')
+   assert(nOutput ~= 0, 'To set affine=false call BatchNormalization'
+     .. '(nOutput, eps, momentum, false) ')
+   if affine ~= nil then
+      assert(type(affine) == 'boolean', 'affine has to be true/false')
+      self.affine = affine
+   else
+      self.affine = true
+   end
    self.eps = eps or 1e-5
    self.train = true
    self.momentum = momentum or 0.1
-   self.running_mean = torch.Tensor()
-   self.running_std = torch.Tensor()
+   self.running_mean = torch.zeros(nOutput)
+   self.running_std = torch.ones(nOutput)
 
-   if nOutput > 0 then self.affine = true end
    if self.affine then
       self.weight = torch.Tensor(nOutput)
       self.bias = torch.Tensor(nOutput)
@@ -71,20 +77,12 @@ function BN:updateOutput(input)
    self.output:resizeAs(input)
    self.gradInput:resizeAs(input)
    if self.train == false then
-      assert(self.running_mean:nDimension() ~= 0,
-             'Module never run on training data. First run on some training data before evaluating.')
       self.output:copy(input)
       self.buffer:repeatTensor(self.running_mean, nBatch, 1)
       self.output:add(-1, self.buffer)
       self.buffer:repeatTensor(self.running_std, nBatch, 1)
       self.output:cmul(self.buffer)
    else -- training mode
-      if self.running_mean:nDimension() == 0 then
-         self.running_mean:resize(input:size(2)):zero()
-      end
-      if self.running_std:nDimension() == 0 then
-         self.running_std:resize(input:size(2)):zero()
-      end
       -- calculate mean over mini-batch
       self.buffer:mean(input, 1) -- E(x) = expectation of x.
       self.running_mean:mul(1 - self.momentum):add(self.momentum, self.buffer) -- add to running mean
26 changes: 12 additions & 14 deletions SpatialBatchNormalization.lua
@@ -30,18 +30,24 @@
 ]]--
 local BN,parent = torch.class('nn.SpatialBatchNormalization', 'nn.Module')
 
-function BN:__init(nFeature, eps, momentum)
+function BN:__init(nFeature, eps, momentum, affine)
    parent.__init(self)
    assert(nFeature and type(nFeature) == 'number',
-          'Missing argument #1: Number of feature planes. ' ..
-          'Give 0 for no affine transform')
+          'Missing argument #1: Number of feature planes. ')
+   assert(nFeature ~= 0, 'To set affine=false call SpatialBatchNormalization'
+     .. '(nFeature, eps, momentum, false) ')
+   if affine ~= nil then
+      assert(type(affine) == 'boolean', 'affine has to be true/false')
+      self.affine = affine
+   else
+      self.affine = true
+   end
    self.eps = eps or 1e-5
    self.train = true
    self.momentum = momentum or 0.1
 
-   self.running_mean = torch.Tensor()
-   self.running_std = torch.Tensor()
-   if nFeature > 0 then self.affine = true end
+   self.running_mean = torch.zeros(nFeature)
+   self.running_std = torch.ones(nFeature)
    if self.affine then
       self.weight = torch.Tensor(nFeature)
       self.bias = torch.Tensor(nFeature)
@@ -75,20 +81,12 @@ function BN:updateOutput(input)
    self.output:resizeAs(input)
    self.gradInput:resizeAs(input)
    if self.train == false then
-      assert(self.running_mean:nDimension() ~= 0,
-             'Module never run on training data. First run on some training data before evaluating.')
       self.output:copy(input)
       self.buffer:repeatTensor(self.running_mean:view(1, nFeature, 1, 1), nBatch, 1, iH, iW)
       self.output:add(-1, self.buffer)
       self.buffer:repeatTensor(self.running_std:view(1, nFeature, 1, 1), nBatch, 1, iH, iW)
       self.output:cmul(self.buffer)
    else -- training mode
-      if self.running_mean:nDimension() == 0 then
-         self.running_mean:resize(nFeature):zero()
-      end
-      if self.running_std:nDimension() == 0 then
-         self.running_std:resize(nFeature):zero()
-      end
       -- calculate mean over mini-batch, over feature-maps
       local in_folded = input:view(nBatch, nFeature, iH * iW)
       self.buffer:mean(in_folded, 1)
6 changes: 3 additions & 3 deletions doc/convolution.md
@@ -513,10 +513,10 @@ w2=image.display(processed)
 <a name="nn.SpatialBatchNormalization"/>
 ## SpatialBatchNormalization ##
 
-`module` = `nn.SpatialBatchNormalization(N [,eps] [, momentum])`
+`module` = `nn.SpatialBatchNormalization(N [,eps] [, momentum] [,affine])`
 where N = number of input feature maps
-giving N = 0 disables the learnable affine transform.
 eps is a small value added to the standard-deviation to avoid divide-by-zero. Defaults to 1e-5
+`affine` is a boolean. When set to false, the learnable affine transform is disabled. Defaults to true
 
 Implements Batch Normalization as described in the paper:
 "Batch Normalization: Accelerating Deep Network Training
@@ -548,7 +548,7 @@ A = torch.randn(b, m, h, w)
 C = model.forward(A) -- C will be of size `b x m x h x w`
 
 -- without learnable parameters
-model = nn.SpatialBatchNormalization(0)
+model = nn.SpatialBatchNormalization(m, nil, nil, false)
 A = torch.randn(b, m, h, w)
 C = model.forward(A) -- C will be of size `b x m x h x w`
 ```
7 changes: 4 additions & 3 deletions doc/simple.md
@@ -909,10 +909,11 @@ C = model.forward({A, B}) -- C will be of size `b x m x n`
 ## BatchNormalization ##
 
 ```lua
-module = nn.BatchNormalization(N [, eps] [, momentum])
+module = nn.BatchNormalization(N [, eps] [, momentum] [,affine])
 ```
-where `N` is the dimensionality of input, giving `N = 0` disables the learnable affine transform.
+where `N` is the dimensionality of input
 `eps` is a small value added to the standard-deviation to avoid divide-by-zero. Defaults to `1e-5`.
+`affine` is a boolean. When set to false, the learnable affine transform is disabled. Defaults to true
 
 During training, this layer keeps a running estimate of its computed mean and std.
 The running sum is kept with a default momentum of 0.1 (unless over-ridden)
@@ -939,7 +940,7 @@ A = torch.randn(b, m)
 C = model.forward(A) -- C will be of size `b x m`
 
 -- without learnable parameters
-model = nn.BatchNormalization(0)
+model = nn.BatchNormalization(m, nil, nil, false)
 A = torch.randn(b, m)
 C = model.forward(A) -- C will be of size `b x m`
 ```
11 changes: 9 additions & 2 deletions test.lua
@@ -438,6 +438,13 @@ function nntest.Sqrt()
    local err = out:dist(in1:sqrt())
    mytester:assertlt(err, 1e-15, torch.typename(module) .. ' - forward err ')
 
+   -- Test zero inputs; we will avoid a div-by-zero by setting to zero
+   local zin = torch.DoubleTensor(5, 7):zero()
+   module:forward(zin)
+   local zgradout = torch.rand(5, 7)
+   local zgradin = module:backward(zin, zgradout)
+   mytester:assertTensorEq(zgradin, torch.DoubleTensor(5, 7):zero(), 0.000001, "error in sqrt backward singularity")
+
    local ini = math.random(3,5)
    local inj = math.random(3,5)
    local ink = math.random(3,5)
@@ -3662,7 +3669,7 @@ function nntest.BatchNormalization()
    mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ')
 
    -- batch norm without affine transform
-   module = nn.BatchNormalization(0)
+   module = nn.BatchNormalization(indim, 1e-5, 0.1, false)
 
    local err = jac.testJacobian(module,input)
    mytester:assertlt(err,precision, 'error on state ')
@@ -3716,7 +3723,7 @@ function nntest.SpatialBatchNormalization()
    mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ')
 
    -- batch norm without affine transform
-   module = nn.SpatialBatchNormalization(0)
+   module = nn.SpatialBatchNormalization(indim, 1e-5, 0.1, false)
 
    local err = jac.testJacobian(module,input)
    mytester:assertlt(err,precision, 'error on state ')
