Merge pull request torch#277 from szagoruyko/master
Removed special case for nunit = 1 in Linear
soumith committed May 27, 2015
2 parents e35f09a + ed48291 commit 97f5d1d
Showing 1 changed file with 5 additions and 22 deletions.
Linear.lua: 27 changes (5 additions, 22 deletions)
@@ -39,19 +39,12 @@ function Linear:updateOutput(input)
       self.output:addmv(1, self.weight, input)
    elseif input:dim() == 2 then
       local nframe = input:size(1)
-      local nunit = self.bias:size(1)
-      self.output:resize(nframe, nunit)
+      self.output:resize(nframe, self.bias:size(1))
       if not self.addBuffer or self.addBuffer:nElement() ~= nframe then
          self.addBuffer = input.new(nframe):fill(1)
       end
-      if nunit == 1 then
-         -- Special case to fix output size of 1 bug:
-         self.output:copy(self.bias:view(1,nunit):expand(#self.output))
-         self.output:select(2,1):addmv(1, input, self.weight:select(1,1))
-      else
-         self.output:addmm(0, self.output, 1, input, self.weight:t())
-         self.output:addr(1, self.addBuffer, self.bias)
-      end
+      self.output:addmm(0, self.output, 1, input, self.weight:t())
+      self.output:addr(1, self.addBuffer, self.bias)
    else
       error('input must be vector or matrix')
    end
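
The removed branch existed to work around an "output size of 1" bug, but the general path already handles that shape: for a batch input of size nframe x inputSize, addmm computes input * weight^T and addr adds the bias to every row via an outer product with the ones buffer, and both produce a valid nframe x 1 result when the layer has a single output unit. A minimal sketch of that check (not part of the commit; it assumes a Lua interpreter with the torch and nn packages installed):

require 'nn'

local lin = nn.Linear(4, 1)          -- one output unit: the formerly special-cased shape
local input = torch.randn(3, 4)      -- batch of 3 frames

-- Redo by hand what updateOutput now does for 2D input:
local output = torch.Tensor(3, 1)
output:addmm(0, output, 1, input, lin.weight:t())   -- output = input * W^T
output:addr(1, torch.ones(3), lin.bias)             -- add the bias to every row

-- Agrees with the module's own forward pass:
print((output - lin:forward(input)):abs():max())    -- prints ~0

With the branch gone, the same two BLAS calls cover every output width, so there is no second code path to keep in sync.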
@@ -79,23 +72,13 @@ end
 
 function Linear:accGradParameters(input, gradOutput, scale)
    scale = scale or 1
-
    if input:dim() == 1 then
       self.gradWeight:addr(scale, gradOutput, input)
       self.gradBias:add(scale, gradOutput)
    elseif input:dim() == 2 then
-      local nunit = self.bias:size(1)
-
-      if nunit == 1 then
-         -- Special case to fix output size of 1 bug:
-         self.gradWeight:select(1,1):addmv(scale, input:t(), gradOutput:select(2,1))
-         self.gradBias:addmv(scale, gradOutput:t(), self.addBuffer)
-      else
-         self.gradWeight:addmm(scale, gradOutput:t(), input)
-         self.gradBias:addmv(scale, gradOutput:t(), self.addBuffer)
-      end
+      self.gradWeight:addmm(scale, gradOutput:t(), input)
+      self.gradBias:addmv(scale, gradOutput:t(), self.addBuffer)
    end
-
 end
 
 -- we do not need to accumulate parameters when sharing
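
The same reasoning applies to the gradient accumulation: gradOutput:t() * input has the shape of gradWeight and gradOutput:t() * addBuffer the shape of gradBias even when there is only one output unit, so the nunit == 1 branch adds nothing. A minimal sketch comparing a hand-rolled accumulation against the module (again not part of the commit; assumes torch and nn are installed):

require 'nn'

local lin = nn.Linear(4, 1)
local input = torch.randn(3, 4)
local gradOutput = torch.randn(3, 1)

lin:zeroGradParameters()
lin:forward(input)                     -- also fills lin.addBuffer with ones
lin:backward(input, gradOutput)

-- The same accumulation done by hand, as accGradParameters now does it:
local gradWeight = torch.zeros(1, 4)
local gradBias = torch.zeros(1)
gradWeight:addmm(1, gradOutput:t(), input)          -- dE/dW = gradOutput^T * input
gradBias:addmv(1, gradOutput:t(), torch.ones(3))    -- dE/db = gradOutput^T * ones

print((gradWeight - lin.gradWeight):abs():max())    -- prints ~0
print((gradBias - lin.gradBias):abs():max())        -- prints ~0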
