diff --git a/Dropout.lua b/Dropout.lua index ee9388e09..66eda2107 100644 --- a/Dropout.lua +++ b/Dropout.lua @@ -40,3 +40,7 @@ end function Dropout:setp(p) self.p = p end + +function Dropout:__tostring__() + return string.format('%s(%f)', torch.type(self), self.p) +end diff --git a/SpatialConvolution.lua b/SpatialConvolution.lua index 4a42b637f..c5fce40cb 100644 --- a/SpatialConvolution.lua +++ b/SpatialConvolution.lua @@ -119,3 +119,15 @@ function SpatialConvolution:type(type) self.fgradInput = torch.Tensor() return parent.type(self,type) end + +function SpatialConvolution:__tostring__() + local s = string.format('%s(in: %d, out: %d, kW: %d, kH: %d', torch.type(self), + self.nInputPlane, self.nOutputPlane, self.kW, self.kH) + if self.dW ~= 1 or self.dH ~= 1 then + s = s .. string.format(', dW: %d, dH: %d', self.dW, self.dH) + end + if self.padding ~= 0 then + s = s .. ', padding: ' .. self.padding + end + return s .. ')' +end diff --git a/SpatialConvolutionMM.lua b/SpatialConvolutionMM.lua index f27c6fae6..e9352573b 100644 --- a/SpatialConvolutionMM.lua +++ b/SpatialConvolutionMM.lua @@ -83,3 +83,15 @@ function SpatialConvolutionMM:type(type) self.fgradInput = torch.Tensor() return parent.type(self,type) end + +function SpatialConvolutionMM:__tostring__() + local s = string.format('%s(in: %d, out: %d, kW: %d, kH: %d', torch.type(self), + self.nInputPlane, self.nOutputPlane, self.kW, self.kH) + if self.dW ~= 1 or self.dH ~= 1 then + s = s .. string.format(', dW: %d, dH: %d', self.dW, self.dH) + end + if self.padding ~= 0 then + s = s .. ', padding: ' .. self.padding + end + return s .. ')' +end diff --git a/SpatialDropout.lua b/SpatialDropout.lua index 673678344..094e5bdc7 100755 --- a/SpatialDropout.lua +++ b/SpatialDropout.lua @@ -41,3 +41,7 @@ end function SpatialDropout:setp(p) self.p = p end + +function SpatialDropout:__tostring__() + return string.format('%s(%f)', torch.type(self), self.p) +end diff --git a/SpatialMaxPooling.lua b/SpatialMaxPooling.lua index 21197ac43..de5110869 100644 --- a/SpatialMaxPooling.lua +++ b/SpatialMaxPooling.lua @@ -32,3 +32,8 @@ function SpatialMaxPooling:empty() self.indices:resize() self.indices:storage():resize(0) end + +function SpatialMaxPooling:__tostring__() + return string.format('%s(kW: %d, kH: %d, dW: %d, dH: %d)', torch.type(self), + self.kW, self.kH, self.dW, self.dH) +end