From 83a3815dc70255c978405e8e966d7b02d580cc11 Mon Sep 17 00:00:00 2001
From: soumith
Date: Thu, 14 May 2015 14:29:42 -0700
Subject: [PATCH] batchnorm is clonable by adding the running estimates to
 constructor fixing batchnorm tests

---
 BatchNormalization.lua        | 26 ++++++++++++--------------
 SpatialBatchNormalization.lua | 26 ++++++++++++--------------
 doc/convolution.md            |  6 +++---
 doc/simple.md                 |  7 ++++---
 test.lua                      | 11 +++++++++--
 5 files changed, 40 insertions(+), 36 deletions(-)

diff --git a/BatchNormalization.lua b/BatchNormalization.lua
index 85f9bdb08..ba96cc271 100644
--- a/BatchNormalization.lua
+++ b/BatchNormalization.lua
@@ -29,18 +29,24 @@
 ]]--
 local BN,parent = torch.class('nn.BatchNormalization', 'nn.Module')
 
-function BN:__init(nOutput, eps, momentum)
+function BN:__init(nOutput, eps, momentum, affine)
    parent.__init(self)
    assert(nOutput and type(nOutput) == 'number',
-          'Missing argument #1: dimensionality of input. ' ..
-          'Give 0 for no affine transform')
+          'Missing argument #1: dimensionality of input. ')
+   assert(nOutput ~= 0, 'To set affine=false call BatchNormalization'
+          .. '(nOutput, eps, momentum, false) ')
+   if affine ~= nil then
+      assert(type(affine) == 'boolean', 'affine has to be true/false')
+      self.affine = affine
+   else
+      self.affine = true
+   end
    self.eps = eps or 1e-5
    self.train = true
    self.momentum = momentum or 0.1
-   self.running_mean = torch.Tensor()
-   self.running_std = torch.Tensor()
+   self.running_mean = torch.zeros(nOutput)
+   self.running_std = torch.ones(nOutput)
 
-   if nOutput > 0 then self.affine = true end
    if self.affine then
       self.weight = torch.Tensor(nOutput)
       self.bias = torch.Tensor(nOutput)
@@ -71,20 +77,12 @@ function BN:updateOutput(input)
    self.output:resizeAs(input)
    self.gradInput:resizeAs(input)
    if self.train == false then
-      assert(self.running_mean:nDimension() ~= 0,
-             'Module never run on training data. First run on some training data before evaluating.')
       self.output:copy(input)
       self.buffer:repeatTensor(self.running_mean, nBatch, 1)
       self.output:add(-1, self.buffer)
       self.buffer:repeatTensor(self.running_std, nBatch, 1)
       self.output:cmul(self.buffer)
    else -- training mode
-      if self.running_mean:nDimension() == 0 then
-         self.running_mean:resize(input:size(2)):zero()
-      end
-      if self.running_std:nDimension() == 0 then
-         self.running_std:resize(input:size(2)):zero()
-      end
       -- calculate mean over mini-batch
       self.buffer:mean(input, 1) -- E(x) = expectation of x.
       self.running_mean:mul(1 - self.momentum):add(self.momentum, self.buffer) -- add to running mean
diff --git a/SpatialBatchNormalization.lua b/SpatialBatchNormalization.lua
index 3f09c3f5b..cbc50d310 100644
--- a/SpatialBatchNormalization.lua
+++ b/SpatialBatchNormalization.lua
@@ -30,18 +30,24 @@
 ]]--
 local BN,parent = torch.class('nn.SpatialBatchNormalization', 'nn.Module')
 
-function BN:__init(nFeature, eps, momentum)
+function BN:__init(nFeature, eps, momentum, affine)
    parent.__init(self)
    assert(nFeature and type(nFeature) == 'number',
-          'Missing argument #1: Number of feature planes. ' ..
-          'Give 0 for no affine transform')
+          'Missing argument #1: Number of feature planes. ')
+   assert(nFeature ~= 0, 'To set affine=false call SpatialBatchNormalization'
+          .. '(nFeature, eps, momentum, false) ')
+   if affine ~= nil then
+      assert(type(affine) == 'boolean', 'affine has to be true/false')
+      self.affine = affine
+   else
+      self.affine = true
+   end
    self.eps = eps or 1e-5
    self.train = true
    self.momentum = momentum or 0.1
-   self.running_mean = torch.Tensor()
-   self.running_std = torch.Tensor()
-   if nFeature > 0 then self.affine = true end
+   self.running_mean = torch.zeros(nFeature)
+   self.running_std = torch.ones(nFeature)
 
    if self.affine then
       self.weight = torch.Tensor(nFeature)
       self.bias = torch.Tensor(nFeature)
@@ -75,20 +81,12 @@ function BN:updateOutput(input)
    self.output:resizeAs(input)
    self.gradInput:resizeAs(input)
    if self.train == false then
-      assert(self.running_mean:nDimension() ~= 0,
-             'Module never run on training data. First run on some training data before evaluating.')
       self.output:copy(input)
       self.buffer:repeatTensor(self.running_mean:view(1, nFeature, 1, 1), nBatch, 1, iH, iW)
       self.output:add(-1, self.buffer)
       self.buffer:repeatTensor(self.running_std:view(1, nFeature, 1, 1), nBatch, 1, iH, iW)
       self.output:cmul(self.buffer)
    else -- training mode
-      if self.running_mean:nDimension() == 0 then
-         self.running_mean:resize(nFeature):zero()
-      end
-      if self.running_std:nDimension() == 0 then
-         self.running_std:resize(nFeature):zero()
-      end
       -- calculate mean over mini-batch, over feature-maps
       local in_folded = input:view(nBatch, nFeature, iH * iW)
       self.buffer:mean(in_folded, 1)
diff --git a/doc/convolution.md b/doc/convolution.md
index a94d44126..906c164d7 100755
--- a/doc/convolution.md
+++ b/doc/convolution.md
@@ -512,10 +512,10 @@ w2=image.display(processed)
 ## SpatialBatchNormalization ##
 
-`module` = `nn.SpatialBatchNormalization(N [,eps] [, momentum])`
+`module` = `nn.SpatialBatchNormalization(N [,eps] [, momentum] [,affine])`
 where N = number of input feature maps
-giving N = 0 disables the learnable affine transform. eps is a small value added to the standard-deviation to avoid divide-by-zero. Defaults to 1e-5
+`eps` is a small value added to the standard-deviation to avoid divide-by-zero. Defaults to `1e-5`. `affine` is a boolean. When set to `false`, the learnable affine transform is disabled. Defaults to `true`.
 
 Implements Batch Normalization as described in the paper:
    "Batch Normalization: Accelerating Deep Network Training
@@ -547,7 +547,7 @@ A = torch.randn(b, m, h, w)
 C = model.forward(A) -- C will be of size `b x m x h x w`
 
 -- without learnable parameters
-model = nn.SpatialBatchNormalization(0)
+model = nn.SpatialBatchNormalization(m, nil, nil, false)
 A = torch.randn(b, m, h, w)
 C = model.forward(A) -- C will be of size `b x m x h x w`
 ```
diff --git a/doc/simple.md b/doc/simple.md
index 7c97a6b8c..44432976b 100755
--- a/doc/simple.md
+++ b/doc/simple.md
@@ -905,10 +905,11 @@ C = model.forward({A, B}) -- C will be of size `b x m x n`
 ## BatchNormalization ##
 
 ```lua
-module = nn.BatchNormalization(N [, eps] [, momentum])
+module = nn.BatchNormalization(N [, eps] [, momentum] [,affine])
 ```
-where `N` is the dimensionality of input, giving `N = 0` disables the learnable affine transform.
+where `N` is the dimensionality of input.
 `eps` is a small value added to the standard-deviation to avoid divide-by-zero. Defaults to `1e-5`.
+`affine` is a boolean. When set to `false`, the learnable affine transform is disabled. Defaults to `true`.
 
 During training, this layer keeps a running estimate of its computed mean and std.
 The running sum is kept with a default momentum of 0.1 (unless over-ridden)
@@ -935,7 +936,7 @@ A = torch.randn(b, m)
 C = model.forward(A) -- C will be of size `b x m`
 
 -- without learnable parameters
-model = nn.BatchNormalization(0)
+model = nn.BatchNormalization(m, nil, nil, false)
 A = torch.randn(b, m)
 C = model.forward(A) -- C will be of size `b x m`
 ```
diff --git a/test.lua b/test.lua
index d82e470fe..1b4847cfc 100644
--- a/test.lua
+++ b/test.lua
@@ -438,6 +438,13 @@ function nntest.Sqrt()
    local err = out:dist(in1:sqrt())
    mytester:assertlt(err, 1e-15, torch.typename(module) .. ' - forward err ')
 
+   -- Test zero inputs; backward should avoid the div-by-zero by setting the gradient to zero
+   local zin = torch.DoubleTensor(5, 7):zero()
+   module:forward(zin)
+   local zgradout = torch.rand(5, 7)
+   local zgradin = module:backward(zin, zgradout)
+   mytester:assertTensorEq(zgradin, torch.DoubleTensor(5, 7):zero(), 0.000001, "error in sqrt backward singularity")
+
    local ini = math.random(3,5)
    local inj = math.random(3,5)
    local ink = math.random(3,5)
@@ -3471,7 +3478,7 @@ function nntest.BatchNormalization()
    mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ')
 
    -- batch norm without affine transform
-   module = nn.BatchNormalization(0)
+   module = nn.BatchNormalization(indim, 1e-5, 0.1, false)
    local err = jac.testJacobian(module,input)
    mytester:assertlt(err,precision, 'error on state ')
@@ -3525,7 +3532,7 @@ function nntest.SpatialBatchNormalization()
    mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ')
 
    -- batch norm without affine transform
-   module = nn.SpatialBatchNormalization(0)
+   module = nn.SpatialBatchNormalization(indim, 1e-5, 0.1, false)
    local err = jac.testJacobian(module,input)
    mytester:assertlt(err,precision, 'error on state ')
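
Usage sketch (not part of the patch): a minimal example of the new constructor signature and of the clone-before-training behaviour this change enables, assuming the patch is applied and `nn` is available. The feature count, batch size, and variable names are illustrative only.

```lua
require 'nn'

-- Batch normalization over 4 features; eps and momentum keep their defaults
-- (1e-5 and 0.1), and the new fourth argument disables the affine transform.
local bn = nn.BatchNormalization(4, nil, nil, false)

-- running_mean/running_std are now allocated in the constructor, so the
-- module can be cloned before it has seen any training data.
local bnClone = bn:clone()

-- Training mode (the default): normalize with mini-batch statistics and
-- update the running estimates.
local output = bn:forward(torch.randn(8, 4))

-- Test mode: with bn.train set to false, the stored running estimates are
-- used instead of the statistics of the current mini-batch.
bn.train = false
local testOutput = bn:forward(torch.randn(8, 4))
print(output:size(), testOutput:size())  -- both 8 x 4
```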