-
Notifications
You must be signed in to change notification settings - Fork 51
/
SpatialReSamplingEx.lua
82 lines (73 loc) · 3.15 KB
/
SpatialReSamplingEx.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
-- Register the class under the name 'nn.SpatialReSamplingEx', with 'nn.Module'
-- given as the parent class name. torch.class returns the new class table and
-- the parent table; `parent` is used below to chain the constructor.
local SpatialReSamplingEx, parent = torch.class('nn.SpatialReSamplingEx', 'nn.Module')
-- Description text handed to xlua.unpack_class by the constructor.
local help_desc = [[
Extended spatial resampling.
]]
--- Constructor.
-- Arguments are parsed by xlua.unpack_class and stored as fields on `self`:
--   rwidth  (number)  ratio owidth/iwidth
--   rheight (number)  ratio oheight/iheight
--   owidth  (number)  output width
--   oheight (number)  output height
--   mode    (string)  'simple' | 'average' (only for downsampling) | 'bilinear';
--                     defaults to 'simple'
--   yDim    (number)  image y dimension, defaults to 2
--   xDim    (number)  image x dimension, defaults to 3; must equal yDim+1
-- Raises an error when xDim ~= yDim+1 or when `mode` is not one of the three
-- recognised names.
function SpatialReSamplingEx:__init(...)
   parent.__init(self)
   -- get args (the class name given here is what usage/help text reports)
   xlua.unpack_class(
      self, {...}, 'nn.SpatialReSamplingEx', help_desc,
      {arg='rwidth', type='number', help='ratio: owidth/iwidth'},
      {arg='rheight', type='number', help='ratio: oheight/iheight'},
      {arg='owidth', type='number', help='output width'},
      {arg='oheight', type='number', help='output height'},
      {arg='mode', type='string', help='Mode : simple | average (only for downsampling) | bilinear', default = 'simple'},
      {arg='yDim', type='number', help='image y dimension', default=2},
      {arg='xDim', type='number', help='image x dimension', default=3}
   )
   -- the two spatial dimensions have to be adjacent, y first
   if self.yDim+1 ~= self.xDim then
      error('nn.SpatialReSamplingEx: yDim must be equals to xDim-1')
   end
   -- 4-entry size buffers, filled in by updateOutput on every forward pass
   self.outputSize = torch.LongStorage(4)
   self.inputSize = torch.LongStorage(4)
   -- numeric code for the mode name, kept in self.mode_c
   -- (presumably read by the native backend -- TODO confirm)
   local mode_codes = {simple = 0, average = 1, bilinear = 2}
   self.mode_c = mode_codes[self.mode]
   if not self.mode_c then
      error('SpatialReSamplingEx: mode must be simple | average | bilinear')
   end
end
--- Round to the nearest integer; exact halves go towards +infinity
-- (0.5 -> 1, -0.5 -> 0).
-- @param a number to round
-- @return floor(a + 0.5)
local function round(a)
   local shifted = a + 0.5
   return math.floor(shifted)
end
--- Forward pass: resample `input` along dimensions yDim and xDim.
-- The target height/width are the explicit oheight/owidth when set, otherwise
-- the rounded product of the ratio (rheight/rwidth) and the input extent.
-- The input is collapsed to a 4D K1 x iheight x iwidth x K2 tensor, handed to
-- the backend routine, and the result is reshaped back so that it has the
-- input's sizes everywhere except on yDim/xDim.
-- Side effects: updates self.iheight, self.iwidth, self.oheightCurrent,
-- self.owidthCurrent, self.inputSize, self.outputSize and self.output.
-- @param input tensor to resample
-- @return self.output
-- Raises an error when neither a size nor a ratio is available for a
-- dimension, or when one dimension would grow while the other shrinks.
function SpatialReSamplingEx:updateOutput(input)
   -- compute iheight, iwidth, oheight and owidth
   self.iheight = input:size(self.yDim)
   self.iwidth = input:size(self.xDim)
   -- fail with a readable message instead of a nil-arithmetic error below
   if not (self.oheight or self.rheight) then
      error('SpatialReSamplingEx: either oheight or rheight must be provided')
   end
   if not (self.owidth or self.rwidth) then
      error('SpatialReSamplingEx: either owidth or rwidth must be provided')
   end
   self.oheightCurrent = self.oheight or round(self.rheight*self.iheight)
   self.owidthCurrent = self.owidth or round(self.rwidth*self.iwidth)
   if (self.oheightCurrent >= self.iheight) ~= (self.owidthCurrent >= self.iwidth) then
      error('SpatialReSamplingEx: Cannot upsample one dimension while downsampling the other')
   end

   -- resize input into K1 x iheight x iwidth x K2 tensor:
   -- K1 is the product of the sizes before yDim, K2 of the sizes after xDim
   self.inputSize:fill(1)
   for i = 1,self.yDim-1 do
      self.inputSize[1] = self.inputSize[1] * input:size(i)
   end
   self.inputSize[2] = self.iheight
   self.inputSize[3] = self.iwidth
   for i = self.xDim+1,input:nDimension() do
      self.inputSize[4] = self.inputSize[4] * input:size(i)
   end
   -- reshape once and hand this same tensor to the backend
   local reshapedInput = input:reshape(self.inputSize)

   -- prepare output of size K1 x oheight x owidth x K2
   self.outputSize[1] = self.inputSize[1]
   self.outputSize[2] = self.oheightCurrent
   self.outputSize[3] = self.owidthCurrent
   self.outputSize[4] = self.inputSize[4]
   self.output:resize(self.outputSize)

   -- resample over dims 2 and 3
   -- (the backend presumably fills self.output -- TODO confirm against the native code)
   input.nn.SpatialReSamplingEx_updateOutput(self, reshapedInput)

   -- resize output into the same shape as input, except on yDim/xDim
   local outputSize2 = input:size()
   outputSize2[self.yDim] = self.oheightCurrent
   outputSize2[self.xDim] = self.owidthCurrent
   self.output = self.output:reshape(outputSize2)
   return self.output
end
--- Backward pass: compute the gradient with respect to `input`.
-- Uses self.inputSize and self.outputSize as left by the last updateOutput
-- call, so a forward pass must have run first.
-- @param input      tensor that was given to updateOutput
-- @param gradOutput gradient with respect to the module output
-- @return self.gradInput, reshaped to input:size()
function SpatialReSamplingEx:updateGradInput(input, gradOutput)
   -- size the gradient buffer like the collapsed 4D input
   self.gradInput:resize(self.inputSize)

   -- collapse gradOutput to the 4D output layout before calling the backend
   local backward = input.nn.SpatialReSamplingEx_updateGradInput
   local flatGradOutput = gradOutput:reshape(self.outputSize)
   backward(self, flatGradOutput)

   -- expose the gradient with the caller's original input shape
   local inputShape = input:size()
   self.gradInput = self.gradInput:reshape(inputShape)
   return self.gradInput
end