diff --git a/LookupTable.lua b/LookupTable.lua index 0ca10bc87..3b2798b19 100644 --- a/LookupTable.lua +++ b/LookupTable.lua @@ -55,7 +55,14 @@ function LookupTable:updateOutput(input) end function LookupTable:accGradParameters(input, gradOutput, scale) - self.gradWeight.nn.LookupTable_accGradParameters(self, self.copiedInput and self._input or input, gradOutput, scale) + input = self.copiedInput and self._input or input + if input:dim() == 2 then + input = input:view(-1) + elseif input:dim() ~= 1 then + error("input must be a vector or matrix") + end + + self.gradWeight.nn.LookupTable_accGradParameters(self, input, gradOutput, scale) end function LookupTable:type(type)