Adding Batch L2 Normalization Layer that makes all rows of input Tensor unit L2 norm
karpathy authored and soumith committed May 13, 2015
1 parent 28b0d2a commit 905ea8c
Showing 3 changed files with 64 additions and 0 deletions.
40 changes: 40 additions & 0 deletions L2Normalize.lua
@@ -0,0 +1,40 @@

--[[
This layer expects an [n x d] Tensor and normalizes each
row to have unit L2 norm.
]]--
local L2Normalize, parent = torch.class('nn.L2Normalize', 'nn.Module')
function L2Normalize:__init()
   parent.__init(self)
end
function L2Normalize:updateOutput(input)
   assert(input:dim() == 2, 'only mini-batch supported (2D tensor), got '
      .. input:dim() .. 'D tensor instead')
   self.output:resizeAs(input)
   self.buffer = self.buffer or input.new()
   self.normSquared = self.normSquared or input.new()
   self.normSquared:sum(self.buffer:cmul(input, input), 2)
   self.buffer:sqrt(self.normSquared)
   self.output:copy(input):cdiv(self.buffer:expandAs(input))
   return self.output
end
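For reference, writing x for one row of the input and y for the matching output row (symbols introduced here, not in the code), updateOutput computes

\[
y = \frac{x}{\lVert x \rVert_2}, \qquad \lVert x \rVert_2 = \sqrt{\sum_j x_j^2},
\]

i.e. the sum-of-squares, square-root, and element-wise division steps above, so every output row has unit L2 norm.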

function L2Normalize:updateGradInput(input, gradOutput)
   assert(input:dim() == 2, 'only mini-batch supported')
   assert(gradOutput:dim() == 2, 'only mini-batch supported')
   local n = input:size(1) -- batch size
   local d = input:size(2) -- dimensionality of vectors
   -- compute diagonal term
   self.eye = self.eye or torch.eye(d):typeAs(input):repeatTensor(n,1):view(n,d,d)
   self.diag = self.diag or self.eye.new()
   self.diag:cmul(self.eye, self.normSquared:view(n,1,1):expand(n,d,d))
   -- compute cross term
   local b1 = input:view(n,d,1)
   local b2 = input:view(n,1,d)
   self.diag:add(-torch.bmm(b1,b2))
   -- compute the local gradient of the L2 transformation
   self.diag:cdiv(torch.pow(self.buffer,3):view(n,1,1):expand(n,d,d))
   -- chain the gradient
   self.gradInput:resize(n,d,1):bmm(self.diag, gradOutput:view(n,d,1)):resize(n,d)
   return self.gradInput
end
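The backward pass applies the Jacobian of that mapping to gradOutput, one batch element at a time. Writing s = \lVert x \rVert_2^2 for a row's squared norm (the normSquared buffer), the Jacobian of y = x / \lVert x \rVert_2 is

\[
\frac{\partial y}{\partial x} = \frac{s\, I - x x^{\top}}{\lVert x \rVert_2^{3}},
\]

which is what the code assembles: the diagonal term s I (self.eye scaled by normSquared), the cross term x x^T (the batched outer product via torch.bmm), and the division by the cubed norm (torch.pow(self.buffer, 3)). Since this matrix is symmetric, the final bmm with gradOutput needs no transpose.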
1 change: 1 addition & 0 deletions init.lua
@@ -46,6 +46,7 @@ include('WeightedEuclidean.lua')
include('PairwiseDistance.lua')
include('CosineDistance.lua')
include('DotProduct.lua')
include('L2Normalize.lua')

include('Exp.lua')
include('Log.lua')
23 changes: 23 additions & 0 deletions test.lua
@@ -3554,6 +3554,29 @@ function nntest.Padding()
   mytester:assertTensorEq(gradInput, input, 0.00001, "Padding backward error")
end

function nntest.L2Normalize()
   local ini = math.random(6,8)
   local inj = math.random(3,5)
   local input = torch.randn(ini, inj)

   local module = nn.L2Normalize()

   -- test correctness of output
   local output = module:forward(input)
   local norms = torch.norm(output, 2, 2)
   local desired_norms = torch.ones(ini)
   mytester:assertTensorEq(norms, desired_norms, 0.000001, 'L2Normalize forward err')

   -- test the Jacobian
   local err = jac.testJacobian(module,input)
   mytester:assertlt(err, precision, 'error on state ')

   -- test IO correctness
   local ferr, berr = jac.testIO(module,input)
   mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ')
   mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ')
end

mytester:add(nntest)

if not nn then
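Beyond the unit test, a minimal usage sketch of the new layer (assuming a Torch7 setup with this version of nn installed; the variable names are illustrative):

require 'nn'

local x = torch.randn(4, 10)        -- a batch of 4 vectors, 10 dimensions each
local module = nn.L2Normalize()
local y = module:forward(x)         -- each row of y now has unit L2 norm
print(y:norm(2, 2))                 -- a column of ones, up to floating-point error

local dy = torch.randn(4, 10)       -- some upstream gradient
local dx = module:backward(x, dy)   -- gradient w.r.t. x, same shape as the input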
