From ed48291e63e5564f4cb33cd4ef6f468c8f209dda Mon Sep 17 00:00:00 2001 From: Sergey Zagoruyko Date: Wed, 27 May 2015 20:36:07 +0200 Subject: [PATCH] removed special case --- Linear.lua | 27 +++++---------------------- 1 file changed, 5 insertions(+), 22 deletions(-) diff --git a/Linear.lua b/Linear.lua index eccea0ebc..31980bc06 100644 --- a/Linear.lua +++ b/Linear.lua @@ -39,19 +39,12 @@ function Linear:updateOutput(input) self.output:addmv(1, self.weight, input) elseif input:dim() == 2 then local nframe = input:size(1) - local nunit = self.bias:size(1) - self.output:resize(nframe, nunit) + self.output:resize(nframe, self.bias:size(1)) if not self.addBuffer or self.addBuffer:nElement() ~= nframe then self.addBuffer = input.new(nframe):fill(1) end - if nunit == 1 then - -- Special case to fix output size of 1 bug: - self.output:copy(self.bias:view(1,nunit):expand(#self.output)) - self.output:select(2,1):addmv(1, input, self.weight:select(1,1)) - else - self.output:addmm(0, self.output, 1, input, self.weight:t()) - self.output:addr(1, self.addBuffer, self.bias) - end + self.output:addmm(0, self.output, 1, input, self.weight:t()) + self.output:addr(1, self.addBuffer, self.bias) else error('input must be vector or matrix') end @@ -79,23 +72,13 @@ end function Linear:accGradParameters(input, gradOutput, scale) scale = scale or 1 - if input:dim() == 1 then self.gradWeight:addr(scale, gradOutput, input) self.gradBias:add(scale, gradOutput) elseif input:dim() == 2 then - local nunit = self.bias:size(1) - - if nunit == 1 then - -- Special case to fix output size of 1 bug: - self.gradWeight:select(1,1):addmv(scale, input:t(), gradOutput:select(2,1)) - self.gradBias:addmv(scale, gradOutput:t(), self.addBuffer) - else - self.gradWeight:addmm(scale, gradOutput:t(), input) - self.gradBias:addmv(scale, gradOutput:t(), self.addBuffer) - end + self.gradWeight:addmm(scale, gradOutput:t(), input) + self.gradBias:addmv(scale, gradOutput:t(), self.addBuffer) end - end -- we do not need to accumulate parameters when sharing