Skip to content

Commit

Permalink
Merge pull request torch#270 from colesbury/lookup
Browse files Browse the repository at this point in the history
Speed up LookupTable
  • Loading branch information
soumith committed Jun 16, 2015
2 parents ef6b2aa + 9716078 commit fdfcd12
Show file tree
Hide file tree
Showing 3 changed files with 163 additions and 120 deletions.
159 changes: 39 additions & 120 deletions LookupTable.lua
Original file line number Diff line number Diff line change
Expand Up @@ -2,156 +2,75 @@ local LookupTable, parent = torch.class('nn.LookupTable', 'nn.Module')

LookupTable.__version = 3

function LookupTable:__init(nIndex, ...)
function LookupTable:__init(nIndex, nOutput)
parent.__init(self)
local arg = {...}

if select('#', ...) == 1 and type(arg[1]) ~= "number" then
local size = arg[1]
self.size = torch.LongStorage(#size + 1)
for i=1,#size do
self.size[i+1] = size[i]
end
else
self.size = torch.LongStorage(select('#', ...)+1)
for i=1,select('#',...) do
self.size[i+1] = arg[i]
end
end
self.weight = torch.Tensor(nIndex, nOutput)
self.gradWeight = torch.Tensor(nIndex, nOutput):zero()
self._count = torch.IntTensor()
self._input = torch.LongTensor()

self.size[1] = nIndex

local batchSize = torch.LongTensor(#self.size + 1)
batchSize:narrow(1, 2,#self.size):copy(torch.LongTensor(self.size))
batchSize[1] = 1
self.batchSize = batchSize:storage()

self.weight = torch.Tensor(self.size)
self.gradWeight = torch.Tensor(self.size):zero()
self.inputs = {}

self.accUpdate = false
self.shouldScaleGradByFreq = false

self.nBackward = 0
self:reset()
end

function LookupTable:accUpdateOnly()
self.accUpdate = true
self.gradWeight = nil
return self
end

function LookupTable:scaleGradByFreq()
self.shouldScaleGradByFreq = true
return self
end

function LookupTable:reset(stdv)
stdv = stdv or 1
if nn.oldSeed then
self.weight:apply(function()
return torch.normal(0, stdv)
end)
else
self.weight:normal(0, stdv)
end
self.weight:normal(0, stdv)
end

function LookupTable:updateOutput(input)
function LookupTable:makeInputContiguous(input)
-- make sure input is a contiguous torch.LongTensor
if (not input:isContiguous()) or torch.type(input) ~= 'torch.LongTensor' then
self._indices = self._indices or torch.LongTensor()
self._indices:resize(input:size()):copy(input)
input = self._indices
if (not input:isContiguous()) or torch.type(input) ~= torch.type(self._input) then
self._input:resize(input:size()):copy(input)
return self._input
end

return input
end

function LookupTable:updateOutput(input)
input = self:makeInputContiguous(input)
if input:dim() == 1 then
local nIndex = input:size(1)
self.size[1] = nIndex
self.output:index(self.weight, 1, input)
elseif input:dim() == 2 then
local nExample = input:size(1)
local nIndex = input:size(2)
self.batchSize[1] = nExample
self.batchSize[2] = nIndex

self._inputView = self._inputView or torch.LongTensor()
self._inputView:view(input, -1)
self.output:index(self.weight, 1, self._inputView)
self.output = self.output:view(nExample, nIndex, self.size[2])
self.output:index(self.weight, 1, input:view(-1))
self.output = self.output:view(input:size(1), input:size(2), self.weight:size(2))
else
error("input must be a vector or matrix")
end

return self.output
end

function LookupTable:zeroGradParameters()
if not self.accUpdate then
for k,_ in pairs(self.inputs) do
self.gradWeight:select(1, k):zero()
end
end
self.inputs = {}
self.nBackward = 0
end

function LookupTable:accGradParameters(input, gradOutput, scale)
scale = scale or 1
if input:dim() == 1 then
self.nBackward = self.nBackward + 1
for i=1,input:size(1) do
local k = input[i]
self.inputs[k] = (self.inputs[k] or 0) + 1
self.gradWeight:select(1, k):add(scale, gradOutput:select(1, i))
end
elseif input:dim() == 2 then
self.nBackward = self.nBackward + input:size(1)
for i=1,input:size(1) do
local input = input:select(1, i)
local gradOutput = gradOutput:select(1, i)
for j=1,input:size(1) do
local k = input[j]
self.inputs[k] = (self.inputs[k] or 0) + 1
self.gradWeight:select(1, k):add(scale, gradOutput:select(1, j))
end
end
end
end

function LookupTable:accUpdateGradParameters(input, gradOutput, lr)
if input:dim() == 1 then
for i=1,input:size(1) do
local k = input[i]
local kscale = self:scaleUpdateByKey(k)
self.inputs[k] = (self.inputs[k] or 0) + 1
self.weight:select(1, input[i]):add(-lr*kscale, gradOutput:select(1, i))
end
elseif input:dim() == 2 then
for i=1,input:size(1) do
local input = input:select(1, i)
local gradOutput = gradOutput:select(1, i)
for j=1,input:size(1) do
local k = input[j]
local kscale = self:scaleUpdateByKey(k)
self.inputs[k] = (self.inputs[k] or 0) + 1
self.weight:select(1, k):add(-lr*kscale, gradOutput:select(1, j))
end
end
end
end

function LookupTable:updateParameters(learningRate)
assert(not self.accUpdate, "use accUpdateGradParameters instead")
for k,nBackward in pairs(self.inputs) do
local kscale = self:scaleUpdateByKey(k)
self.weight:select(1, k):add(-learningRate*kscale, self.gradWeight:select(1, k))
end
input = self:makeInputContiguous(input)
self.gradWeight.nn.LookupTable_accGradParameters(self, input, gradOutput, scale)
end

function LookupTable:type(type)
self._indices = nil
self._inputView = nil
parent.type(self, type)
end

-- scale the update for each key
function LookupTable:scaleUpdateByKey(inputKey)
-- default is to perform no key-based scalling
return 1
if type == 'torch.CudaTensor' then
-- CUDA uses _sorted and _indices temporary tensors
self._sorted = self.weight.new()
self._indices = self.weight.new()
else
-- self._count and self._input should only be converted if using Cuda
self._count = torch.IntTensor()
self._input = torch.LongTensor()
end

return self
end

-- we do not need to accumulate parameters when sharing
Expand Down
119 changes: 119 additions & 0 deletions generic/LookupTable.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
#ifndef TH_GENERIC_FILE
#define TH_GENERIC_FILE "generic/LookupTable.c"
#else

static void nn_(LookupTable_resetCount)(int *count_data, THLongTensor *input)
{
int i;
long *input_data = THLongTensor_data(input);
long numel = THLongTensor_nElement(input);

for (i = 0; i<numel; i++)
{
long k = input_data[i] - 1;
count_data[k] = 0;
}
for (i = 0; i<numel; i++)
{
long k = input_data[i] - 1;
count_data[k]++;
}
}

static int nn_(LookupTable_accGradParameters)(lua_State *L)
{
long i;
THLongTensor *input = luaT_checkudata(L, 2, "torch.LongTensor");
THTensor *gradOutput = luaT_checkudata(L, 3, torch_Tensor);
real lr = luaL_optnumber(L, 4, 1);
THTensor *gradWeight = luaT_getfieldcheckudata(L, 1, "gradWeight", torch_Tensor);
int *count_data = NULL;

if (luaT_getfieldcheckboolean(L, 1, "shouldScaleGradByFreq"))
{
THIntTensor *count = luaT_getfieldcheckudata(L, 1, "_count", "torch.IntTensor");
THIntTensor_resize1d(count, gradWeight->size[0]);
count_data = THIntTensor_data(count);
}

if (!THTensor_(isContiguous)(gradWeight))
luaL_error(L, "gradWeight must be contiguous");
if (!THLongTensor_isContiguous(input))
luaL_error(L, "input must be contiguous");
if (input->nDimension != 1 && input->nDimension != 2)
luaL_error(L, "input must be a vector or matrix");

long *input_data = THLongTensor_data(input);
long numel = THLongTensor_nElement(input);
long numw = THTensor_(size)(gradWeight, 0);

// check that inputs are all within range
for (i=0; i<numel; i++)
if (input_data[i] < 1 || input_data[i] > numw)
THError("input out of range");

gradOutput = THTensor_(newContiguous)(gradOutput);

real *gw = THTensor_(data)(gradWeight);
real *go = THTensor_(data)(gradOutput);
long stride = THTensor_(stride)(gradWeight, 0);

if (count_data)
nn_(LookupTable_resetCount)(count_data, input);

#ifdef _OPENMP
if (numel > 1000)
{
// The strategy is to parallelize over sections of the vocabulary, so that
// thread 1 handles updates to gradWeight[0..nVocab/nThreads]. Every thread
// has to traverse the entire input, but the dominating factor is the axpy
// BLAS call.
#pragma omp parallel private(i)
{
int tid = omp_get_thread_num();
int nthreads = omp_get_num_threads();

long start = tid * (numw/nthreads + 1);
long end = start + (numw/nthreads + 1);
for (i=0; i<numel; i++)
{
long k = input_data[i] - 1;
if (k >= start && k < end)
{
real scale = lr;
if (count_data) scale /= count_data[k];
THBlas_(axpy)(stride, scale, go + i*stride, 1, gw + k*stride, 1);
}
}
}

THTensor_(free)(gradOutput);
return 0;
}
#endif

for (i=0; i<numel; i++)
{
long k = input_data[i] - 1;
real scale = lr;
if (count_data) scale /= count_data[k];
THBlas_(axpy)(stride, scale, go + i*stride, 1, gw + k*stride, 1);
}

THTensor_(free)(gradOutput);
return 0;
}

static const struct luaL_Reg nn_(LookupTable__) [] = {
{"LookupTable_accGradParameters", nn_(LookupTable_accGradParameters)},
{NULL, NULL}
};

static void nn_(LookupTable_init)(lua_State *L)
{
luaT_pushmetatable(L, torch_Tensor);
luaT_registeratname(L, nn_(LookupTable__), "nn");
lua_pop(L,1);
}

#endif
5 changes: 5 additions & 0 deletions init.c
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,9 @@
#include "generic/SpatialUpSamplingNearest.c"
#include "THGenerateFloatTypes.h"

#include "generic/LookupTable.c"
#include "THGenerateFloatTypes.h"

LUA_EXTERNC DLL_EXPORT int luaopen_libnn(lua_State *L);

int luaopen_libnn(lua_State *L)
Expand Down Expand Up @@ -173,6 +176,7 @@ int luaopen_libnn(lua_State *L)
nn_FloatMultiLabelMarginCriterion_init(L);
nn_FloatL1Cost_init(L);
nn_FloatSpatialUpSamplingNearest_init(L);
nn_FloatLookupTable_init(L);

nn_DoubleMin_init(L);
nn_DoubleMax_init(L);
Expand Down Expand Up @@ -214,6 +218,7 @@ int luaopen_libnn(lua_State *L)
nn_DoubleMultiLabelMarginCriterion_init(L);
nn_DoubleL1Cost_init(L);
nn_DoubleSpatialUpSamplingNearest_init(L);
nn_DoubleLookupTable_init(L);

return 1;
}

0 comments on commit fdfcd12

Please sign in to comment.