Skip to content

Commit

Permalink
Merge pull request torch#293 from fmassa/conv_pad
Browse files Browse the repository at this point in the history
SpatialConvolutionMM supports padW/padH != 1
  • Loading branch information
soumith committed Jun 18, 2015
2 parents 25d46e6 + 793b6bf commit 8fa3ee9
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 44 deletions.
18 changes: 13 additions & 5 deletions SpatialConvolution.lua
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
local SpatialConvolution, parent = torch.class('nn.SpatialConvolution', 'nn.Module')

function SpatialConvolution:__init(nInputPlane, nOutputPlane, kW, kH, dW, dH, padding)
function SpatialConvolution:__init(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH)
parent.__init(self)

dW = dW or 1
Expand All @@ -13,7 +13,8 @@ function SpatialConvolution:__init(nInputPlane, nOutputPlane, kW, kH, dW, dH, pa

self.dW = dW
self.dH = dH
self.padding = padding or 0
self.padW = padW or 0
self.padH = padH or self.padW

self.weight = torch.Tensor(nOutputPlane, nInputPlane, kH, kW)
self.bias = torch.Tensor(nOutputPlane)
Expand Down Expand Up @@ -45,7 +46,14 @@ end
local function backCompatibility(self)
self.finput = self.finput or self.weight.new()
self.fgradInput = self.fgradInput or self.weight.new()
self.padding = self.padding or 0
if self.padding then
self.padW = self.padding
self.padH = self.padding
self.padding = nil
else
self.padW = self.padW or 0
self.padH = self.padH or 0
end
if self.weight:dim() == 2 then
self.weight = self.weight:view(self.nOutputPlane, self.nInputPlane, self.kH, self.kW)
end
Expand Down Expand Up @@ -128,8 +136,8 @@ function SpatialConvolution:__tostring__()
end
if self.padding and self.padding ~= 0 then
s = s .. ', ' .. self.padding .. ',' .. self.padding
elseif self.pad_w or self.pad_h then
s = s .. ', ' .. self.pad_w .. ',' .. self.pad_h
elseif (self.padW or self.padH) and (self.padW ~= 0 or self.padH ~= 0) then
s = s .. ', ' .. self.padW .. ',' .. self.padH
end
return s .. ')'
end
15 changes: 11 additions & 4 deletions SpatialConvolutionMM.lua
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
local SpatialConvolutionMM, parent = torch.class('nn.SpatialConvolutionMM', 'nn.Module')

function SpatialConvolutionMM:__init(nInputPlane, nOutputPlane, kW, kH, dW, dH, padding)
function SpatialConvolutionMM:__init(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH)
parent.__init(self)

dW = dW or 1
Expand All @@ -13,7 +13,8 @@ function SpatialConvolutionMM:__init(nInputPlane, nOutputPlane, kW, kH, dW, dH,

self.dW = dW
self.dH = dH
self.padding = padding or 0
self.padW = padW or 0
self.padH = padH or self.padW

self.weight = torch.Tensor(nOutputPlane, nInputPlane*kH*kW)
self.bias = torch.Tensor(nOutputPlane)
Expand Down Expand Up @@ -62,6 +63,12 @@ local function makeContiguous(self, input, gradOutput)
end

function SpatialConvolutionMM:updateOutput(input)
-- backward compatibility
if self.padding then
self.padW = self.padding
self.padH = self.padding
self.padding = nil
end
input = makeContiguous(self, input)
return input.nn.SpatialConvolutionMM_updateOutput(self, input)
end
Expand Down Expand Up @@ -92,8 +99,8 @@ function SpatialConvolutionMM:__tostring__()
end
if self.padding and self.padding ~= 0 then
s = s .. ', ' .. self.padding .. ',' .. self.padding
elseif self.pad_w or self.pad_h then
s = s .. ', ' .. self.pad_w .. ',' .. self.pad_h
elseif (self.padW or self.padH) and (self.padW ~= 0 or self.padH ~= 0) then
s = s .. ', ' .. self.padW .. ',' .. self.padH
end
return s .. ')'
end
9 changes: 5 additions & 4 deletions doc/convolution.md
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ are spatial (e.g. `height x width`). These are commonly used for processing imag
### SpatialConvolution ###

```lua
module = nn.SpatialConvolution(nInputPlane, nOutputPlane, kW, kH, [dW], [dH], [padding])
module = nn.SpatialConvolution(nInputPlane, nOutputPlane, kW, kH, [dW], [dH], [padW], [padH])
```

Applies a 2D convolution over an input image composed of several input planes. The `input` tensor in
Expand All @@ -276,7 +276,8 @@ The parameters are the following:
* `kH`: The kernel height of the convolution
* `dW`: The step of the convolution in the width dimension. Default is `1`.
* `dH`: The step of the convolution in the height dimension. Default is `1`.
* `padding`: The additional zeros added per side to the input planes. Default is `0`, a good number is `(kernelSize-1)/2` for square kernels.
* `padW`: The number of zeros added to each side of the input planes in the width dimension. Default is `0`; a good value is `(kW-1)/2`.
* `padH`: The number of zeros added to each side of the input planes in the height dimension. Default is `0`; a good value is `(kH-1)/2`.

Note that depending on the size of your kernel, several (of the last)
columns or rows of the input image might be lost. It is up to the user to
Expand All @@ -285,8 +286,8 @@ add proper padding in images.
If the input image is a 3D tensor `nInputPlane x height x width`, the output image size
will be `nOutputPlane x oheight x owidth` where
```lua
owidth = floor((width + 2*padding - kW) / dW + 1)
oheight = floor((height + 2*padding - kH) / dH + 1)
owidth = floor((width + 2*padW - kW) / dW + 1)
oheight = floor((height + 2*padH - kH) / dH + 1)
```

The parameters of the convolution can be found in `self.weight` (Tensor of
Expand Down
54 changes: 28 additions & 26 deletions generic/SpatialConvolutionMM.c
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
static void nn_(unfolded_acc)(THTensor *finput, THTensor *input,
int kW, int kH,
int dW, int dH,
int padding,
int padW, int padH,
int nInputPlane,
int inputWidth, int inputHeight,
int outputWidth, int outputHeight)
Expand All @@ -25,21 +25,21 @@ static void nn_(unfolded_acc)(THTensor *finput, THTensor *input,
{
real *src = finput_data + nip*(kH*kW*outputHeight*outputWidth) + kh*(kW*outputHeight*outputWidth) + kw*(outputHeight*outputWidth);
real *dst = input_data + nip*(inputHeight*inputWidth);
if (padding > 0) {
if (padW > 0 || padH > 0) {
int lpad,rpad;
for(y = 0; y < outputHeight; y++) {
iy = y*dH - padding + kh;
iy = y*dH - padH + kh;
if (iy < 0 || iy >= inputHeight) {
} else {
if (dW==1){
ix = 0 - padding + kw;
lpad = fmaxf(0,padding-kw);
rpad = fmaxf(0,padding-(kW-kw-1));
ix = 0 - padW + kw;
lpad = fmaxf(0,padW-kw);
rpad = fmaxf(0,padW-(kW-kw-1));
THVector_(add)(dst+iy*inputWidth+ix+lpad, src+y*outputWidth+lpad, 1, outputWidth - lpad - rpad); /* note: THVector_add could handle 1 value better */
}
else{
for (x=0; x<outputWidth; x++){
ix = x*dW - padding + kw;
ix = x*dW - padW + kw;
if (ix < 0 || ix >= inputWidth){
}else
THVector_(add)(dst+iy*inputWidth+ix, src+y*outputWidth+x, 1, 1);
Expand Down Expand Up @@ -67,7 +67,7 @@ static void nn_(unfolded_acc)(THTensor *finput, THTensor *input,
static void nn_(unfolded_copy)(THTensor *finput, THTensor *input,
int kW, int kH,
int dW, int dH,
int padding,
int padW, int padH,
int nInputPlane,
int inputWidth, int inputHeight,
int outputWidth, int outputHeight)
Expand All @@ -85,17 +85,17 @@ static void nn_(unfolded_copy)(THTensor *finput, THTensor *input,
int x,y,ix,iy;
real *dst = finput_data + nip*(kH*kW*outputHeight*outputWidth) + kh*(kW*outputHeight*outputWidth) + kw*(outputHeight*outputWidth);
real *src = input_data + nip*(inputHeight*inputWidth);
if (padding > 0) {
if (padW > 0 || padH > 0) {
int lpad,rpad;
for(y = 0; y < outputHeight; y++) {
iy = y*dH - padding + kh;
iy = y*dH - padH + kh;
if (iy < 0 || iy >= inputHeight) {
memset(dst+y*outputWidth, 0, sizeof(real)*outputWidth);
} else {
if (dW==1){
ix = 0 - padding + kw;
lpad = fmaxf(0,padding-kw);
rpad = fmaxf(0,padding-(kW-kw-1));
ix = 0 - padW + kw;
lpad = fmaxf(0,padW-kw);
rpad = fmaxf(0,padW-(kW-kw-1));
if (outputWidth-rpad-lpad <= 0) {
memset(dst+y*outputWidth, 0, sizeof(real)*outputWidth);
} else {
Expand All @@ -106,7 +106,7 @@ static void nn_(unfolded_copy)(THTensor *finput, THTensor *input,
}
else{
for (x=0; x<outputWidth; x++){
ix = x*dW - padding + kw;
ix = x*dW - padW + kw;
if (ix < 0 || ix >= inputWidth)
memset(dst+y*outputWidth+x, 0, sizeof(real)*1);
else
Expand All @@ -131,14 +131,14 @@ static void nn_(unfolded_copy)(THTensor *finput, THTensor *input,
}

static void nn_(SpatialConvolutionMM_updateOutput_frame)(THTensor *input, THTensor *output, THTensor *weight, THTensor *bias, THTensor *finput,
int kW, int kH, int dW, int dH, int padding,
int kW, int kH, int dW, int dH, int padW, int padH,
long nInputPlane, long inputWidth, long inputHeight,
long nOutputPlane, long outputWidth, long outputHeight)
{
long i;
THTensor *output2d;

nn_(unfolded_copy)(finput, input, kW, kH, dW, dH, padding, nInputPlane, inputWidth, inputHeight, outputWidth, outputHeight);
nn_(unfolded_copy)(finput, input, kW, kH, dW, dH, padW, padH, nInputPlane, inputWidth, inputHeight, outputWidth, outputHeight);

output2d = THTensor_(newWithStorage2d)(output->storage, output->storageOffset,
nOutputPlane, -1,
Expand All @@ -159,7 +159,8 @@ static int nn_(SpatialConvolutionMM_updateOutput)(lua_State *L)
int kH = luaT_getfieldcheckint(L, 1, "kH");
int dW = luaT_getfieldcheckint(L, 1, "dW");
int dH = luaT_getfieldcheckint(L, 1, "dH");
int padding = luaT_getfieldcheckint(L, 1, "padding");
int padW = luaT_getfieldcheckint(L, 1, "padW");
int padH = luaT_getfieldcheckint(L, 1, "padH");

THTensor *finput = luaT_getfieldcheckudata(L, 1, "finput", torch_Tensor);
THTensor *weight = luaT_getfieldcheckudata(L, 1, "weight", torch_Tensor);
Expand Down Expand Up @@ -190,16 +191,16 @@ static int nn_(SpatialConvolutionMM_updateOutput)(lua_State *L)
inputWidth = input->size[dimw];
inputHeight = input->size[dimh];
nOutputPlane = weight->size[0];
outputWidth = (inputWidth + 2*padding - kW) / dW + 1;
outputHeight = (inputHeight + 2*padding - kH) / dH + 1;
outputWidth = (inputWidth + 2*padW - kW) / dW + 1;
outputHeight = (inputHeight + 2*padH - kH) / dH + 1;

if(input->nDimension == 3)
{
THTensor_(resize2d)(finput, kW*kH*nInputPlane, outputHeight*outputWidth);
THTensor_(resize3d)(output, nOutputPlane, outputHeight, outputWidth);

nn_(SpatialConvolutionMM_updateOutput_frame)(input, output, weight, bias, finput,
kW, kH, dW, dH, padding,
kW, kH, dW, dH, padW, padH,
nInputPlane, inputWidth, inputHeight,
nOutputPlane, outputWidth, outputHeight);
}
Expand All @@ -219,7 +220,7 @@ static int nn_(SpatialConvolutionMM_updateOutput)(lua_State *L)
THTensor *finput_t = THTensor_(newSelect)(finput, 0, t);

nn_(SpatialConvolutionMM_updateOutput_frame)(input_t, output_t, weight, bias, finput_t,
kW, kH, dW, dH, padding,
kW, kH, dW, dH, padW, padH,
nInputPlane, inputWidth, inputHeight,
nOutputPlane, outputWidth, outputHeight);

Expand All @@ -234,7 +235,7 @@ static int nn_(SpatialConvolutionMM_updateOutput)(lua_State *L)


static void nn_(SpatialConvolutionMM_updateGradInput_frame)(THTensor *gradInput, THTensor *gradOutput, THTensor *weight, THTensor *fgradInput,
int kW, int kH, int dW, int dH, int padding)
int kW, int kH, int dW, int dH, int padW, int padH)
{
THTensor *gradOutput2d = THTensor_(newWithStorage2d)(gradOutput->storage, gradOutput->storageOffset,
gradOutput->size[0], -1,
Expand All @@ -244,7 +245,7 @@ static void nn_(SpatialConvolutionMM_updateGradInput_frame)(THTensor *gradInput,

THTensor_(zero)(gradInput);

nn_(unfolded_acc)(fgradInput, gradInput, kW, kH, dW, dH, padding, gradInput->size[0], gradInput->size[2], gradInput->size[1], gradOutput->size[2], gradOutput->size[1]);
nn_(unfolded_acc)(fgradInput, gradInput, kW, kH, dW, dH, padW, padH, gradInput->size[0], gradInput->size[2], gradInput->size[1], gradOutput->size[2], gradOutput->size[1]);
}

static int nn_(SpatialConvolutionMM_updateGradInput)(lua_State *L)
Expand All @@ -255,7 +256,8 @@ static int nn_(SpatialConvolutionMM_updateGradInput)(lua_State *L)
int kH = luaT_getfieldcheckint(L, 1, "kH");
int dW = luaT_getfieldcheckint(L, 1, "dW");
int dH = luaT_getfieldcheckint(L, 1, "dH");
int padding = luaT_getfieldcheckint(L, 1, "padding");
int padW = luaT_getfieldcheckint(L, 1, "padW");
int padH = luaT_getfieldcheckint(L, 1, "padH");
int nOutputPlane = luaT_getfieldcheckint(L, 1, "nOutputPlane");

THTensor *finput = luaT_getfieldcheckudata(L, 1, "finput", torch_Tensor);
Expand All @@ -271,7 +273,7 @@ static int nn_(SpatialConvolutionMM_updateGradInput)(lua_State *L)

if(input->nDimension == 3)
{
nn_(SpatialConvolutionMM_updateGradInput_frame)(gradInput, gradOutput, weight, fgradInput, kW, kH, dW, dH, padding);
nn_(SpatialConvolutionMM_updateGradInput_frame)(gradInput, gradOutput, weight, fgradInput, kW, kH, dW, dH, padW, padH);
}
else
{
Expand All @@ -285,7 +287,7 @@ static int nn_(SpatialConvolutionMM_updateGradInput)(lua_State *L)
THTensor *gradOutput_t = THTensor_(newSelect)(gradOutput, 0, t);
THTensor *fgradInput_t = THTensor_(newSelect)(fgradInput, 0, t);

nn_(SpatialConvolutionMM_updateGradInput_frame)(gradInput_t, gradOutput_t, weight, fgradInput_t, kW, kH, dW, dH, padding);
nn_(SpatialConvolutionMM_updateGradInput_frame)(gradInput_t, gradOutput_t, weight, fgradInput_t, kW, kH, dW, dH, padW, padH);

THTensor_(free)(gradInput_t);
THTensor_(free)(gradOutput_t);
Expand Down
11 changes: 6 additions & 5 deletions test.lua
Original file line number Diff line number Diff line change
Expand Up @@ -1446,12 +1446,13 @@ function nntest.SpatialConvolutionMM()
local kj = math.random(1,5)
local di = math.random(1,4)
local dj = math.random(1,4)
local padding = math.random(0,2)
local padW = math.random(0,2)
local padH = math.random(0,2)
local outi = math.random(5,9)
local outj = math.random(5,9)
local ini = (outi-1)*di+ki-padding*2
local inj = (outj-1)*dj+kj-padding*2
local module = nn.SpatialConvolutionMM(from, to, ki, kj, di, dj, padding)
local ini = (outi-1)*di+ki-padW*2
local inj = (outj-1)*dj+kj-padH*2
local module = nn.SpatialConvolutionMM(from, to, ki, kj, di, dj, padW, padH)
local input = torch.Tensor(from, inj, ini):zero()

-- stochastic
Expand Down Expand Up @@ -1486,7 +1487,7 @@ function nntest.SpatialConvolutionMM()
--verbose = true
local batch = math.random(2,5)

module = nn.SpatialConvolutionMM(from, to, ki, kj, di, dj, padding)
module = nn.SpatialConvolutionMM(from, to, ki, kj, di, dj, padW, padH)
input = torch.Tensor(batch,from,inj,ini):zero()

local err = jac.testJacobian(module, input)
Expand Down

0 comments on commit 8fa3ee9

Please sign in to comment.