Skip to content

Commit

Permalink
Fix CUDA LookupTable with 2D inputs.
Browse files Browse the repository at this point in the history
The CUDA version of LookupTable was incorrect for 2D inputs. This fixes
it by viewing 2D inputs as a 1D tensor.
  • Loading branch information
colesbury committed Jun 19, 2015
1 parent 3550f31 commit 9fcfe12
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion LookupTable.lua
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,14 @@ function LookupTable:updateOutput(input)
end

function LookupTable:accGradParameters(input, gradOutput, scale)
self.gradWeight.nn.LookupTable_accGradParameters(self, self.copiedInput and self._input or input, gradOutput, scale)
input = self.copiedInput and self._input or input
if input:dim() == 2 then
input = input:view(-1)
elseif input:dim() ~= 1 then
error("input must be a vector or matrix")
end

self.gradWeight.nn.LookupTable_accGradParameters(self, input, gradOutput, scale)
end

function LookupTable:type(type)
Expand Down

0 comments on commit 9fcfe12

Please sign in to comment.