diff --git a/L2Normalize.lua b/L2Normalize.lua
new file mode 100644
index 000000000..f1dfd0e99
--- /dev/null
+++ b/L2Normalize.lua
@@ -0,0 +1,40 @@
+
+--[[
+   This layer expects an [n x d] Tensor and normalizes each
+   row to have unit L2 norm.
+]]--
+local L2Normalize, parent = torch.class('nn.L2Normalize', 'nn.Module')
+function L2Normalize:__init()
+   parent.__init(self)
+end
+function L2Normalize:updateOutput(input)
+   assert(input:dim() == 2, 'only mini-batch supported (2D tensor), got '
+             .. input:dim() .. 'D tensor instead')
+   self.output:resizeAs(input)
+   self.buffer = self.buffer or input.new()
+   self.normSquared = self.normSquared or input.new()
+   self.normSquared:sum(self.buffer:cmul(input, input), 2)
+   self.buffer:sqrt(self.normSquared)
+   self.output:copy(input):cdiv(self.buffer:expandAs(input))
+   return self.output
+end
+
+function L2Normalize:updateGradInput(input, gradOutput)
+   assert(input:dim() == 2, 'only mini-batch supported')
+   assert(gradOutput:dim() == 2, 'only mini-batch supported')
+   local n = input:size(1) -- batch size
+   local d = input:size(2) -- dimensionality of vectors
+   -- compute diagonal term
+   self.eye = self.eye or torch.eye(d):typeAs(input):repeatTensor(n,1):view(n,d,d)
+   self.diag = self.diag or self.eye.new()
+   self.diag:cmul(self.eye, self.normSquared:view(n,1,1):expand(n,d,d))
+   -- compute cross term
+   local b1 = input:view(n,d,1)
+   local b2 = input:view(n,1,d)
+   self.diag:add(-torch.bmm(b1,b2))
+   -- compute the local gradient of the L2 transformation
+   self.diag:cdiv(torch.pow(self.buffer,3):view(n,1,1):expand(n,d,d))
+   -- chain the gradient
+   self.gradInput:resize(n,d,1):bmm(self.diag, gradOutput:view(n,d,1)):resize(n,d)
+   return self.gradInput
+end
diff --git a/init.lua b/init.lua
index b1d36dbab..520b66e2a 100644
--- a/init.lua
+++ b/init.lua
@@ -46,6 +46,7 @@ include('WeightedEuclidean.lua')
 include('PairwiseDistance.lua')
 include('CosineDistance.lua')
 include('DotProduct.lua')
+include('L2Normalize.lua')
 
 include('Exp.lua')
 include('Log.lua')
diff --git a/test.lua b/test.lua
index 9414a66d8..959c369c7 100644
--- a/test.lua
+++ b/test.lua
@@ -3554,6 +3554,29 @@ function nntest.Padding()
    mytester:assertTensorEq(gradInput, input, 0.00001, "Padding backward error")
 end
 
+function nntest.L2Normalize()
+   local ini = math.random(6,8)
+   local inj = math.random(3,5)
+   local input = torch.randn(ini, inj)
+
+   local module = nn.L2Normalize()
+
+   -- test correctness of output
+   local output = module:forward(input)
+   local norms = torch.norm(output, 2, 2)
+   local desired_norms = torch.ones(ini)
+   mytester:assertTensorEq(norms, desired_norms, 0.000001, 'L2Normalize forward err')
+
+   -- test the Jacobian
+   local err = jac.testJacobian(module,input)
+   mytester:assertlt(err, precision, 'error on state ')
+
+   -- test IO correctness
+   local ferr, berr = jac.testIO(module,input)
+   mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ')
+   mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ')
+end
+
 mytester:add(nntest)
 
 if not nn then
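
For reference on the backward pass in L2Normalize.lua above: for each input row x, updateGradInput materializes the d x d Jacobian of y = x / ||x||, whose entries are dy_i/dx_j = (||x||^2 * delta_ij - x_i * x_j) / ||x||^3. The diag buffer first holds the diagonal term ||x||^2 * I, the outer product x * x^T is subtracted as the cross term, the result is divided by ||x||^3, and the final bmm multiplies this Jacobian with gradOutput (per batch element) to produce gradInput.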
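
A minimal usage sketch of the new module (the batch size and dimensionality below are made up for illustration; the module name and its forward/backward interface come from the diff above):

   require 'nn'

   local m = nn.L2Normalize()
   local x = torch.randn(4, 10)   -- illustrative mini-batch: 4 vectors, 10 dimensions each
   local y = m:forward(x)         -- every row of y now has unit L2 norm
   print(y:norm(2, 2))            -- 4x1 tensor of ones (up to numerical precision)

   local dy = torch.randn(4, 10)  -- some upstream gradient
   local dx = m:backward(x, dy)   -- 4x10 gradient with respect to x

The layer has no parameters, so backward reduces to updateGradInput; the Jacobian test added in test.lua checks this gradient against a finite-difference estimate.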