Skip to content

Commit

Permalink
Merge pull request torch#261 from torch/convfix
Browse files Browse the repository at this point in the history
fixing typing in SpatialConvolution
  • Loading branch information
soumith committed May 7, 2015
2 parents ed4653c + 3e6a5d4 commit 28b0d2a
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions SpatialConvolution.lua
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ function SpatialConvolution:__init(nInputPlane, nOutputPlane, kW, kH, dW, dH, pa
self.bias = torch.Tensor(nOutputPlane)
self.gradWeight = torch.Tensor(nOutputPlane, nInputPlane, kH, kW)
self.gradBias = torch.Tensor(nOutputPlane)

self:reset()
end

Expand All @@ -35,7 +35,7 @@ function SpatialConvolution:reset(stdv)
end)
self.bias:apply(function()
return torch.uniform(-stdv, stdv)
end)
end)
else
self.weight:uniform(-stdv, stdv)
self.bias:uniform(-stdv, stdv)
Expand All @@ -46,10 +46,10 @@ local function backCompatibility(self)
self.finput = self.finput or self.weight.new()
self.fgradInput = self.fgradInput or self.weight.new()
self.padding = self.padding or 0
if self.weight:dim() == 2 then
if self.weight:dim() == 2 then
self.weight = self.weight:view(self.nOutputPlane, self.nInputPlane, self.kH, self.kW)
end
if self.gradWeight and self.gradWeight:dim() == 2 then
if self.gradWeight and self.gradWeight:dim() == 2 then
self.gradWeight = self.gradWeight:view(self.nOutputPlane, self.nInputPlane, self.kH, self.kW)
end
end
Expand Down Expand Up @@ -109,3 +109,9 @@ function SpatialConvolution:accGradParameters(input, gradOutput, scale)
unviewWeight(self)
return out
end

function SpatialConvolution:type(type)
self.finput = torch.Tensor()
self.fgradInput = torch.Tensor()
return parent.type(self,type)
end

0 comments on commit 28b0d2a

Please sign in to comment.