diff --git a/Linear.lua b/Linear.lua
index eccea0ebc..31980bc06 100644
--- a/Linear.lua
+++ b/Linear.lua
@@ -39,19 +39,12 @@ function Linear:updateOutput(input)
       self.output:addmv(1, self.weight, input)
    elseif input:dim() == 2 then
       local nframe = input:size(1)
-      local nunit = self.bias:size(1)
-      self.output:resize(nframe, nunit)
+      self.output:resize(nframe, self.bias:size(1))
       if not self.addBuffer or self.addBuffer:nElement() ~= nframe then
          self.addBuffer = input.new(nframe):fill(1)
       end
-      if nunit == 1 then
-         -- Special case to fix output size of 1 bug:
-         self.output:copy(self.bias:view(1,nunit):expand(#self.output))
-         self.output:select(2,1):addmv(1, input, self.weight:select(1,1))
-      else
-         self.output:addmm(0, self.output, 1, input, self.weight:t())
-         self.output:addr(1, self.addBuffer, self.bias)
-      end
+      self.output:addmm(0, self.output, 1, input, self.weight:t())
+      self.output:addr(1, self.addBuffer, self.bias)
    else
       error('input must be vector or matrix')
    end
@@ -79,23 +72,13 @@ end
 
 function Linear:accGradParameters(input, gradOutput, scale)
    scale = scale or 1
-
    if input:dim() == 1 then
       self.gradWeight:addr(scale, gradOutput, input)
       self.gradBias:add(scale, gradOutput)
    elseif input:dim() == 2 then
-      local nunit = self.bias:size(1)
-
-      if nunit == 1 then
-         -- Special case to fix output size of 1 bug:
-         self.gradWeight:select(1,1):addmv(scale, input:t(), gradOutput:select(2,1))
-         self.gradBias:addmv(scale, gradOutput:t(), self.addBuffer)
-      else
-         self.gradWeight:addmm(scale, gradOutput:t(), input)
-         self.gradBias:addmv(scale, gradOutput:t(), self.addBuffer)
-      end
+      self.gradWeight:addmm(scale, gradOutput:t(), input)
+      self.gradBias:addmv(scale, gradOutput:t(), self.addBuffer)
    end
-
 end
 
 -- we do not need to accumulate parameters when sharing
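
Not part of the patch: a minimal Lua sanity-check sketch, assuming torch and nn are installed, showing that a Linear layer with a single output unit (the formerly special-cased nunit == 1 configuration) now runs through the same generic addmm/addr and addmm/addmv paths as any other size. The sizes and tolerance below are arbitrary illustration choices.

-- Sanity-check sketch (not part of the patch).
require 'nn'

local lin = nn.Linear(5, 1)              -- one output unit: the formerly special case
local input = torch.randn(4, 5)          -- batch of 4 frames

local batched = lin:forward(input):clone()      -- 4x1 output via the batched addmm/addr path
local single  = lin:forward(input[1]):clone()   -- 1D path via addmv

-- Row 1 of the batched output should agree with the single-frame output.
assert((batched[1] - single):abs():max() < 1e-6, 'batched and single paths disagree')

-- Backward side: accGradParameters for the batch now also uses the generic
-- addmm/addmv path; gradWeight is 1x5 and gradBias has a single element.
lin:zeroGradParameters()
lin:forward(input)
lin:backward(input, torch.randn(4, 1))
print(lin.gradWeight:size(), lin.gradBias:size())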