diff --git a/LookupTable.lua b/LookupTable.lua
index 5b5f5654e..378ab5ece 100644
--- a/LookupTable.lua
+++ b/LookupTable.lua
@@ -2,156 +2,75 @@ local LookupTable, parent = torch.class('nn.LookupTable', 'nn.Module')
 
 LookupTable.__version = 3
 
-function LookupTable:__init(nIndex, ...)
+function LookupTable:__init(nIndex, nOutput)
    parent.__init(self)
-   local arg = {...}
-   if select('#', ...) == 1 and type(arg[1]) ~= "number" then
-      local size = arg[1]
-      self.size = torch.LongStorage(#size + 1)
-      for i=1,#size do
-         self.size[i+1] = size[i]
-      end
-   else
-      self.size = torch.LongStorage(select('#', ...)+1)
-      for i=1,select('#',...) do
-         self.size[i+1] = arg[i]
-      end
-   end
+   self.weight = torch.Tensor(nIndex, nOutput)
+   self.gradWeight = torch.Tensor(nIndex, nOutput):zero()
+   self._count = torch.IntTensor()
+   self._input = torch.LongTensor()
 
-   self.size[1] = nIndex
-
-   local batchSize = torch.LongTensor(#self.size + 1)
-   batchSize:narrow(1, 2,#self.size):copy(torch.LongTensor(self.size))
-   batchSize[1] = 1
-   self.batchSize = batchSize:storage()
-
-   self.weight = torch.Tensor(self.size)
-   self.gradWeight = torch.Tensor(self.size):zero()
-   self.inputs = {}
-
-   self.accUpdate = false
+   self.shouldScaleGradByFreq = false
 
-   self.nBackward = 0
    self:reset()
 end
 
 function LookupTable:accUpdateOnly()
-   self.accUpdate = true
    self.gradWeight = nil
+   return self
+end
+
+function LookupTable:scaleGradByFreq()
+   self.shouldScaleGradByFreq = true
+   return self
 end
 
 function LookupTable:reset(stdv)
    stdv = stdv or 1
-   if nn.oldSeed then
-      self.weight:apply(function()
-         return torch.normal(0, stdv)
-      end)
-   else
-      self.weight:normal(0, stdv)
-   end
+   self.weight:normal(0, stdv)
 end
 
-function LookupTable:updateOutput(input)
+function LookupTable:makeInputContiguous(input)
    -- make sure input is a contiguous torch.LongTensor
-   if (not input:isContiguous()) or torch.type(input) ~= 'torch.LongTensor' then
-      self._indices = self._indices or torch.LongTensor()
-      self._indices:resize(input:size()):copy(input)
-      input = self._indices
+   if (not input:isContiguous()) or torch.type(input) ~= torch.type(self._input) then
+      self._input:resize(input:size()):copy(input)
+      return self._input
    end
-
+   return input
+end
+
+function LookupTable:updateOutput(input)
+   input = self:makeInputContiguous(input)
   if input:dim() == 1 then
-      local nIndex = input:size(1)
-      self.size[1] = nIndex
       self.output:index(self.weight, 1, input)
    elseif input:dim() == 2 then
-      local nExample = input:size(1)
-      local nIndex = input:size(2)
-      self.batchSize[1] = nExample
-      self.batchSize[2] = nIndex
-
-      self._inputView = self._inputView or torch.LongTensor()
-      self._inputView:view(input, -1)
-      self.output:index(self.weight, 1, self._inputView)
-      self.output = self.output:view(nExample, nIndex, self.size[2])
+      self.output:index(self.weight, 1, input:view(-1))
+      self.output = self.output:view(input:size(1), input:size(2), self.weight:size(2))
+   else
+      error("input must be a vector or matrix")
    end
-
    return self.output
 end
 
-function LookupTable:zeroGradParameters()
-   if not self.accUpdate then
-      for k,_ in pairs(self.inputs) do
-         self.gradWeight:select(1, k):zero()
-      end
-   end
-   self.inputs = {}
-   self.nBackward = 0
-end
-
 function LookupTable:accGradParameters(input, gradOutput, scale)
-   scale = scale or 1
-   if input:dim() == 1 then
-      self.nBackward = self.nBackward + 1
-      for i=1,input:size(1) do
-         local k = input[i]
-         self.inputs[k] = (self.inputs[k] or 0) + 1
-         self.gradWeight:select(1, k):add(scale, gradOutput:select(1, i))
-      end
-   elseif input:dim() == 2 then
-      self.nBackward = self.nBackward + input:size(1)
-      for i=1,input:size(1) do
-         local input = input:select(1, i)
-         local gradOutput = gradOutput:select(1, i)
-         for j=1,input:size(1) do
-            local k = input[j]
-            self.inputs[k] = (self.inputs[k] or 0) + 1
-            self.gradWeight:select(1, k):add(scale, gradOutput:select(1, j))
-         end
-      end
-   end
-end
-
-function LookupTable:accUpdateGradParameters(input, gradOutput, lr)
-   if input:dim() == 1 then
-      for i=1,input:size(1) do
-         local k = input[i]
-         local kscale = self:scaleUpdateByKey(k)
-         self.inputs[k] = (self.inputs[k] or 0) + 1
-         self.weight:select(1, input[i]):add(-lr*kscale, gradOutput:select(1, i))
-      end
-   elseif input:dim() == 2 then
-      for i=1,input:size(1) do
-         local input = input:select(1, i)
-         local gradOutput = gradOutput:select(1, i)
-         for j=1,input:size(1) do
-            local k = input[j]
-            local kscale = self:scaleUpdateByKey(k)
-            self.inputs[k] = (self.inputs[k] or 0) + 1
-            self.weight:select(1, k):add(-lr*kscale, gradOutput:select(1, j))
-         end
-      end
-   end
-end
-
-function LookupTable:updateParameters(learningRate)
-   assert(not self.accUpdate, "use accUpdateGradParameters instead")
-   for k,nBackward in pairs(self.inputs) do
-      local kscale = self:scaleUpdateByKey(k)
-      self.weight:select(1, k):add(-learningRate*kscale, self.gradWeight:select(1, k))
-   end
+   input = self:makeInputContiguous(input)
+   self.gradWeight.nn.LookupTable_accGradParameters(self, input, gradOutput, scale)
 end
 
 function LookupTable:type(type)
-   self._indices = nil
-   self._inputView = nil
    parent.type(self, type)
-end
 
--- scale the update for each key
-function LookupTable:scaleUpdateByKey(inputKey)
-   -- default is to perform no key-based scalling
-   return 1
+   if type == 'torch.CudaTensor' then
+      -- CUDA uses _sorted and _indices temporary tensors
+      self._sorted = self.weight.new()
+      self._indices = self.weight.new()
+   else
+      -- self._count and self._input should only be converted if using Cuda
+      self._count = torch.IntTensor()
+      self._input = torch.LongTensor()
+   end
+
+   return self
 end
 
 -- we do not need to accumulate parameters when sharing
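Usage sketch (reviewer note, not part of the patch): with the Lua changes above, the module is built from an explicit (nIndex, nOutput) pair, scaleGradByFreq() and accUpdateOnly() return self so they can be chained, and accGradParameters defers to the C kernel added below. The snippet assumes the patched nn has been rebuilt so that LookupTable_accGradParameters is registered on the tensor metatable; everything else is stock torch/nn.

    require 'nn'
    -- 10-word vocabulary, 4-dimensional embeddings, frequency-scaled gradients
    local lut = nn.LookupTable(10, 4):scaleGradByFreq()

    local seq = torch.LongTensor{1, 2, 2, 5, 9}            -- 1D input: 5 indices -> 5x4 output
    local out = lut:forward(seq)
    lut:zeroGradParameters()
    lut:backward(seq, torch.Tensor(5, 4):fill(1))          -- gradWeight accumulation now runs in C

    local batch = torch.LongTensor{{1, 2, 3}, {4, 5, 6}}   -- 2D input: 2x3 batch -> 2x3x4 output
    out = lut:forward(batch)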
diff --git a/generic/LookupTable.c b/generic/LookupTable.c
new file mode 100644
index 000000000..c47a92935
--- /dev/null
+++ b/generic/LookupTable.c
@@ -0,0 +1,119 @@
+#ifndef TH_GENERIC_FILE
+#define TH_GENERIC_FILE "generic/LookupTable.c"
+#else
+
+static void nn_(LookupTable_resetCount)(int *count_data, THLongTensor *input)
+{
+  int i;
+  long *input_data = THLongTensor_data(input);
+  long numel = THLongTensor_nElement(input);
+
+  for (i = 0; i < numel; i++)
+  {
+    long k = input_data[i] - 1;
+    count_data[k] = 0;
+  }
+  for (i = 0; i < numel; i++)
+  {
+    long k = input_data[i] - 1;
+    count_data[k]++;
+  }
+}
+
+static int nn_(LookupTable_accGradParameters)(lua_State *L)
+{
+  long i;
+  THLongTensor *input = luaT_checkudata(L, 2, "torch.LongTensor");
+  THTensor *gradOutput = luaT_checkudata(L, 3, torch_Tensor);
+  real lr = luaL_optnumber(L, 4, 1);
+  THTensor *gradWeight = luaT_getfieldcheckudata(L, 1, "gradWeight", torch_Tensor);
+  THIntTensor *count = NULL;
+  int *count_data = NULL;
+
+  if (luaT_getfieldcheckboolean(L, 1, "shouldScaleGradByFreq"))
+  {
+    count = luaT_getfieldcheckudata(L, 1, "_count", "torch.IntTensor");
+    THIntTensor_resize1d(count, gradWeight->size[0]);
+    count_data = THIntTensor_data(count);
+  }
+
+  if (!THTensor_(isContiguous)(gradWeight))
+    luaL_error(L, "gradWeight must be contiguous");
+  if (!THLongTensor_isContiguous(input))
+    luaL_error(L, "input must be contiguous");
+  if (input->nDimension != 1 && input->nDimension != 2)
+    luaL_error(L, "input must be a vector or matrix");
+
+  long *input_data = THLongTensor_data(input);
+  long numel = THLongTensor_nElement(input);
+  long numw = THTensor_(size)(gradWeight, 0);
+
+  // check that inputs are all within range
+  for (i=0; i < numel; i++)
+    if (input_data[i] < 1 || input_data[i] > numw)
+      THError("input out of range");
+
+  gradOutput = THTensor_(newContiguous)(gradOutput);
+
+  real *gw = THTensor_(data)(gradWeight);
+  real *go = THTensor_(data)(gradOutput);
+  long stride = THTensor_(stride)(gradWeight, 0);
+
+  if (count_data)
+    nn_(LookupTable_resetCount)(count_data, input);
+
+#ifdef _OPENMP
+  if (numel > 1000)
+  {
+    // The strategy is to parallelize over sections of the vocabulary, so that
+    // thread 1 handles updates to gradWeight[0..nVocab/nThreads]. Every thread
+    // has to traverse the entire input, but the dominating factor is the axpy
+    // BLAS call.
+    #pragma omp parallel private(i)
+    {
+      int tid = omp_get_thread_num();
+      int nthreads = omp_get_num_threads();
+
+      long start = tid * (numw/nthreads + 1);
+      long end = start + (numw/nthreads + 1);
+      for (i=0; i < numel; i++)
+      {
+        long k = input_data[i] - 1;
+        if (k >= start && k < end)
+        {
+          real scale = lr;
+          if (count_data) scale /= count_data[k];
+          THBlas_(axpy)(stride, scale, go + i*stride, 1, gw + k*stride, 1);
+        }
+      }
+    }
+
+    THTensor_(free)(gradOutput);
+    return 0;
+  }
+#endif
+
+  for (i=0; i < numel; i++)
+  {
+    long k = input_data[i] - 1;
+    real scale = lr;
+    if (count_data) scale /= count_data[k];
+    THBlas_(axpy)(stride, scale, go + i*stride, 1, gw + k*stride, 1);
+  }
+
+  THTensor_(free)(gradOutput);
+  return 0;
+}
+
+static const struct luaL_Reg nn_(LookupTable__) [] = {
+  {"LookupTable_accGradParameters", nn_(LookupTable_accGradParameters)},
+  {NULL, NULL}
+};
+
+static void nn_(LookupTable_init)(lua_State *L)
+{
+  luaT_pushmetatable(L, torch_Tensor);
+  luaT_registeratname(L, nn_(LookupTable__), "nn");
+  lua_pop(L, 1);
+}
+
+#endif
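A second reviewer sketch, not part of the patch, to illustrate what count_data buys: when scaleGradByFreq() is enabled the kernel scales each axpy by lr / count[k], so an index that appears k times in the input ends up with the same accumulated gradWeight row as an index that appears once, given identical gradOutput rows. Everything used here comes from the patch or stock torch/nn.

    require 'nn'
    local lut = nn.LookupTable(4, 3):scaleGradByFreq()
    local input = torch.LongTensor{2, 2, 3}          -- index 2 appears twice, index 3 once
    lut:zeroGradParameters()
    lut:accGradParameters(input, torch.Tensor(3, 3):fill(1), 1)
    print(lut.gradWeight[2])                         -- all ones: two updates, each scaled by 1/2
    print(lut.gradWeight[3])                         -- all ones: one update, scaled by 1/1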