diff --git a/SpatialConvolution.lua b/SpatialConvolution.lua index ac98d9d30..4a42b637f 100644 --- a/SpatialConvolution.lua +++ b/SpatialConvolution.lua @@ -73,12 +73,16 @@ end -- function to re-view the weight layout in a way that would make the MM ops happy local function viewWeight(self) self.weight = self.weight:view(self.nOutputPlane, self.nInputPlane * self.kH * self.kW) - self.gradWeight = self.gradWeight and self.gradWeight:view(self.nOutputPlane, self.nInputPlane * self.kH * self.kW) + if self.gradWeight and self.gradWeight:dim() > 0 then + self.gradWeight = self.gradWeight:view(self.nOutputPlane, self.nInputPlane * self.kH * self.kW) + end end local function unviewWeight(self) self.weight = self.weight:view(self.nOutputPlane, self.nInputPlane, self.kH, self.kW) - self.gradWeight = self.gradWeight and self.gradWeight:view(self.nOutputPlane, self.nInputPlane, self.kH, self.kW) + if self.gradWeight and self.gradWeight:dim() > 0 then + self.gradWeight = self.gradWeight:view(self.nOutputPlane, self.nInputPlane, self.kH, self.kW) + end end function SpatialConvolution:updateOutput(input)