forked from ROCm/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathVolumetricAveragePooling.cu
279 lines (255 loc) · 9.58 KB
/
VolumetricAveragePooling.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
#include "THCUNN.h"
#include "THCTensor.hpp"
#include "common.h"
#include "THCDeviceTensor.cuh"
#include "THCDeviceTensorUtils.cuh"
#include "THCDeviceUtils.cuh"
#include "TH/THHalf.h"
#include "THCHalfAutoNumerics.cuh"
#include "THCAtomics.cuh"
// Forward average pooling over 4-D tensors indexed [slice][time][row][col],
// generic variant with a runtime window width kW.
//
// Thread mapping (from the index arithmetic below):
//   x -> output column, y -> output row,
//   (blockIdx.z + offsetZ) -> flattened (slice, output frame) pair.
// NOTE(review): offsetZ presumably lets the host split the z extent over
// several launches (e.g. to respect the gridDim.z limit) — confirm in caller.
//
// Each in-range thread writes one output element: the mean of its pooling
// window. The window is first clipped to the padded input extent to obtain
// the count_include_pad divisor, then clipped to the real input extent for
// the summation itself (padding contributes zero).
template <typename Dtype, typename Acctype>
__global__ void cuda_VolumetricAveragePooling_updateOutput(
  THCDeviceTensor<Dtype, 4> input,
  THCDeviceTensor<Dtype, 4> output,
  int kT, int kH, int kW,
  int dT, int dH, int dW,
  int padT, int padH, int padW,
  bool count_include_pad, int offsetZ)
{
  const int col = blockIdx.x * blockDim.x + threadIdx.x;
  const int row = blockIdx.y * blockDim.y + threadIdx.y;

  // Over-tiled threads beyond the output plane have nothing to do.
  if (row >= output.getSize(2) || col >= output.getSize(3))
    return;

  const int z = blockIdx.z + offsetZ;
  const int frame = z % output.getSize(1);  // output frame/time
  const int plane = z / output.getSize(1);  // output slice/feature

  // Window bounds in input coordinates, clipped only to the padded extent.
  int t0 = frame * dT - padT;
  int h0 = row * dH - padH;
  int w0 = col * dW - padW;
  int t1 = min(t0 + kT, input.getSize(1) + padT);
  int h1 = min(h0 + kH, input.getSize(2) + padH);
  int w1 = min(w0 + kW, input.getSize(3) + padW);
  const int paddedCount = (t1 - t0) * (h1 - h0) * (w1 - w0);

  // Clip to the real input extent; only these elements are read.
  t0 = max(t0, 0);
  h0 = max(h0, 0);
  w0 = max(w0, 0);
  t1 = min(t1, input.getSize(1));
  h1 = min(h1, input.getSize(2));
  w1 = min(w1, input.getSize(3));

  const Acctype divisor = count_include_pad
    ? static_cast<Acctype>(paddedCount)
    : static_cast<Acctype>((t1 - t0) * (h1 - h0) * (w1 - w0));

  Acctype acc = 0.0;
  for (int t = t0; t < t1; ++t)
  {
    for (int h = h0; h < h1; ++h)
    {
      for (int w = w0; w < w1; ++w)
      {
        Dtype v = input[plane][t][h][w];
        acc += v;
      }
    }
  }

  output[plane][frame][row][col] = ScalarConvert<Acctype, Dtype>::to(acc / divisor);
}
// Forward average pooling with the window width fixed at compile time.
//
// The inner-most loop size (kW) is supplied as the template parameter
// KERNEL_WIDTH rather than as a runtime argument, for performance reasons.
// Apart from that, the computation matches
// cuda_VolumetricAveragePooling_updateOutput.
//
// Thread mapping (from the index arithmetic below):
//   x -> output column, y -> output row,
//   (blockIdx.z + offsetZ) -> flattened (slice, output frame) pair.
// The window is clipped to the padded extent to obtain the count_include_pad
// divisor, then to the real input extent for the summation.
template<int KERNEL_WIDTH, typename Dtype, typename Acctype>
__global__ void cuda_VolumetricAveragePooling_updateOutput_fixedKW(
  THCDeviceTensor<Dtype, 4> input,
  THCDeviceTensor<Dtype, 4> output,
  int kT, int kH,
  int dT, int dH, int dW,
  int padT, int padH, int padW,
  bool count_include_pad, int offsetZ)
{
  const int col = blockIdx.x * blockDim.x + threadIdx.x;
  const int row = blockIdx.y * blockDim.y + threadIdx.y;

  // Over-tiled threads beyond the output plane have nothing to do.
  if (row >= output.getSize(2) || col >= output.getSize(3))
    return;

  const int z = blockIdx.z + offsetZ;
  const int frame = z % output.getSize(1);  // output frame/time
  const int plane = z / output.getSize(1);  // output slice/feature

  // Window bounds in input coordinates, clipped only to the padded extent.
  int t0 = frame * dT - padT;
  int h0 = row * dH - padH;
  int w0 = col * dW - padW;
  int t1 = min(t0 + kT, input.getSize(1) + padT);
  int h1 = min(h0 + kH, input.getSize(2) + padH);
  int w1 = min(w0 + KERNEL_WIDTH, input.getSize(3) + padW);
  const int paddedCount = (t1 - t0) * (h1 - h0) * (w1 - w0);

  // Clip to the real input extent; only these elements are read.
  t0 = max(t0, 0);
  h0 = max(h0, 0);
  w0 = max(w0, 0);
  t1 = min(t1, input.getSize(1));
  h1 = min(h1, input.getSize(2));
  w1 = min(w1, input.getSize(3));

  const Acctype divisor = count_include_pad
    ? static_cast<Acctype>(paddedCount)
    : static_cast<Acctype>((t1 - t0) * (h1 - h0) * (w1 - w0));

  Acctype acc = 0.0;
  for (int t = t0; t < t1; ++t)
  {
    for (int h = h0; h < h1; ++h)
    {
      for (int w = w0; w < w1; ++w)
      {
        Dtype v = input[plane][t][h][w];
        acc += v;
      }
    }
  }

  output[plane][frame][row][col] = ScalarConvert<Acctype, Dtype>::to(acc / divisor);
}
// Expands to one `case KW:` arm of a switch statement. The arm launches
// cuda_VolumetricAveragePooling_updateOutput_fixedKW with KERNEL_WIDTH = KW
// (no dynamic shared memory) on THCState_getCurrentStream(state).
// The expansion uses names that must already be in scope at the use site:
// grid, block, state, scalar_t, accreal, cudaInput, cudaOutput, kT, kH,
// dT, dH, dW, padT, padH, padW, count_include_pad and offsetZ.
// The trailing `break` has no semicolon, so presumably each use site
// supplies it, e.g. `LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(3);`.
// NOTE(review): no launch-error check happens here; presumably the caller
// checks after the switch — confirm.
#define LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(KW) case KW: \
cuda_VolumetricAveragePooling_updateOutput_fixedKW<KW, scalar_t, accreal> \
<<<grid, block, 0, THCState_getCurrentStream(state)>>>( \
cudaInput, cudaOutput, kT, kH, dT, dH, dW, padT, padH, padW, count_include_pad, offsetZ); \
break
// Backward average pooling for the unit-stride case, written as a gather.
// Each thread owns one gradInput element. It sums every gradOutput element
// whose window [o, o + k) covers it, then scales the sum by normFactor.
//
// The loop bounds encode stride 1 and no padding in all three dimensions.
// NOTE(review): the caller presumably dispatches here only under those
// conditions, and normFactor is presumably 1 / (kT * kH * kW) — confirm.
//
// Thread mapping (from the index arithmetic below):
//   x -> input column, y -> input row,
//   (blockIdx.z + offsetZ) -> flattened (slice, input frame) pair.
//
// gradOutput is walked through a raw pointer. Element offsets come from
// gradOutput's strides, not its sizes, so the walk stays correct when
// gradOutput is not contiguous. For a contiguous tensor the strides equal
// the size-derived offsets (1, W, H*W).
template <typename Dtype, typename Acctype>
__global__ void cuda_VolumetricAveragePooling_updateGradInput_Stride1(
  THCDeviceTensor<Dtype, 4> gradOutput,
  THCDeviceTensor<Dtype, 4> gradInput,
  int kT, int kH, int kW,
  Acctype normFactor, int offsetZ)
{
  int iCol = blockIdx.x * blockDim.x + threadIdx.x;
  int iRow = blockIdx.y * blockDim.y + threadIdx.y;
  int iFrame = (blockIdx.z + offsetZ) % gradInput.getSize(1); // input frame/time
  int slice = (blockIdx.z + offsetZ) / gradInput.getSize(1); // input slice/feature
  // guard against over-tiled threads
  if (iRow < gradInput.getSize(2) && iCol < gradInput.getSize(3))
  {
    // Output positions o whose window [o, o + k) contains this input position,
    // clipped to the gradOutput extent. Hoisted out of the loop conditions.
    const int oFrameStart = max(0, iFrame - kT + 1);
    const int oRowStart = max(0, iRow - kH + 1);
    const int oColStart = max(0, iCol - kW + 1);
    const int oFrameEnd = min(iFrame + 1, gradOutput.getSize(1));
    const int oRowEnd = min(iRow + 1, gradOutput.getSize(2));
    const int oColEnd = min(iCol + 1, gradOutput.getSize(3));

    // Element strides of gradOutput along frame, row and column.
    const int frameStride = gradOutput.getStride(1);
    const int rowStride = gradOutput.getStride(2);
    const int colStride = gradOutput.getStride(3);

    Acctype sum = 0.0;
    Dtype *gOut = &gradOutput[slice][oFrameStart][oRowStart][oColStart];
    int frameOffset = 0;
    for (int oFrame = oFrameStart; oFrame < oFrameEnd; ++oFrame)
    {
      int rowOffset = frameOffset;
      for (int oRow = oRowStart; oRow < oRowEnd; ++oRow)
      {
        int colOffset = rowOffset;
        for (int oCol = oColStart; oCol < oColEnd; ++oCol)
        {
          sum += gOut[colOffset];
          colOffset += colStride;
        }
        rowOffset += rowStride;
      }
      frameOffset += frameStride;
    }
    gradInput[slice][iFrame][iRow][iCol] = ScalarConvert<Acctype, Dtype>::to(sum * normFactor);
  }
}
// Backward average pooling written as a scatter with atomics. Each thread
// owns one gradOutput element. It divides that element by its window's
// divisor and atomically adds the result to every in-range gradInput element
// of the window.
//
// NOTE(review): atomicAdd suggests this variant is dispatched when windows can
// overlap (stride < kernel size); gradInput is only accumulated into here, so
// it is presumably zeroed by the caller beforehand — confirm both.
//
// Thread mapping (from the index arithmetic below):
//   x -> output column, y -> output row,
//   (blockIdx.z + offsetZ) -> flattened (slice, output frame) pair.
// The divisor follows the forward pass: the padded window size when
// count_include_pad is set, otherwise the number of real input elements.
template <typename Dtype, typename Acctype>
__global__ void cuda_VolumetricAveragePooling_updateGradInput_atomicAdd(
  THCDeviceTensor<Dtype, 4> gradOutput,
  THCDeviceTensor<Dtype, 4> gradInput,
  int kT, int kH, int kW,
  int dT, int dH, int dW,
  int padT, int padH, int padW,
  bool count_include_pad, int offsetZ)
{
  const int col = blockIdx.x * blockDim.x + threadIdx.x;
  const int row = blockIdx.y * blockDim.y + threadIdx.y;

  // guard against over-tiled threads
  if (row >= gradOutput.getSize(2) || col >= gradOutput.getSize(3))
    return;

  const int z = blockIdx.z + offsetZ;
  const int frame = z % gradOutput.getSize(1);  // gradOutput frame/time
  const int plane = z / gradOutput.getSize(1);  // gradOutput slice/feature

  // Window bounds in gradInput coordinates, clipped only to the padded extent.
  int t0 = frame * dT - padT;
  int h0 = row * dH - padH;
  int w0 = col * dW - padW;
  int t1 = min(t0 + kT, gradInput.getSize(1) + padT);
  int h1 = min(h0 + kH, gradInput.getSize(2) + padH);
  int w1 = min(w0 + kW, gradInput.getSize(3) + padW);
  const int paddedCount = (t1 - t0) * (h1 - h0) * (w1 - w0);

  // Clip to the real gradInput extent; only these elements receive gradient.
  t0 = max(t0, 0);
  h0 = max(h0, 0);
  w0 = max(w0, 0);
  t1 = min(t1, gradInput.getSize(1));
  h1 = min(h1, gradInput.getSize(2));
  w1 = min(w1, gradInput.getSize(3));

  const Acctype divisor = count_include_pad
    ? static_cast<Acctype>(paddedCount)
    : static_cast<Acctype>((t1 - t0) * (h1 - h0) * (w1 - w0));

  // Per-element share of this output's gradient, computed in Acctype.
  const Dtype share = ScalarConvert<Acctype, Dtype>::to(
    ScalarConvert<Dtype, Acctype>::to(gradOutput[plane][frame][row][col]) / divisor);

  for (int t = t0; t < t1; ++t)
  {
    for (int h = h0; h < h1; ++h)
    {
      for (int w = w0; w < w1; ++w)
      {
        atomicAdd(&gradInput[plane][t][h][w], share);
      }
    }
  }
}
// Backward average pooling written as a scatter with plain stores. Each
// thread owns one gradOutput element. It divides that element by its window's
// divisor and stores (overwrites, not accumulates) the result into every
// in-range gradInput element of the window.
//
// NOTE(review): because of the plain store, this is only correct when no two
// windows share a gradInput element (presumably dispatched only when every
// stride >= its kernel size). gradInput elements covered by no window are
// not written here, so they are presumably zeroed by the caller — confirm.
//
// Thread mapping (from the index arithmetic below):
//   x -> output column, y -> output row,
//   (blockIdx.z + offsetZ) -> flattened (slice, output frame) pair.
template <typename Dtype, typename Acctype>
__global__ void cuda_VolumetricAveragePooling_updateGradInput(
  THCDeviceTensor<Dtype, 4> gradOutput,
  THCDeviceTensor<Dtype, 4> gradInput,
  int kT, int kH, int kW,
  int dT, int dH, int dW,
  int padT, int padH, int padW,
  bool count_include_pad, int offsetZ)
{
  const int col = blockIdx.x * blockDim.x + threadIdx.x;
  const int row = blockIdx.y * blockDim.y + threadIdx.y;

  // guard against over-tiled threads
  if (row >= gradOutput.getSize(2) || col >= gradOutput.getSize(3))
    return;

  const int z = blockIdx.z + offsetZ;
  const int frame = z % gradOutput.getSize(1);  // gradOutput frame/time
  const int plane = z / gradOutput.getSize(1);  // gradOutput slice/feature

  // Window bounds in gradInput coordinates, clipped only to the padded extent.
  int t0 = frame * dT - padT;
  int h0 = row * dH - padH;
  int w0 = col * dW - padW;
  int t1 = min(t0 + kT, gradInput.getSize(1) + padT);
  int h1 = min(h0 + kH, gradInput.getSize(2) + padH);
  int w1 = min(w0 + kW, gradInput.getSize(3) + padW);
  const int paddedCount = (t1 - t0) * (h1 - h0) * (w1 - w0);

  // Clip to the real gradInput extent; only these elements are written.
  t0 = max(t0, 0);
  h0 = max(h0, 0);
  w0 = max(w0, 0);
  t1 = min(t1, gradInput.getSize(1));
  h1 = min(h1, gradInput.getSize(2));
  w1 = min(w1, gradInput.getSize(3));

  const Acctype divisor = count_include_pad
    ? static_cast<Acctype>(paddedCount)
    : static_cast<Acctype>((t1 - t0) * (h1 - h0) * (w1 - w0));

  // Per-element share of this output's gradient, computed in Acctype.
  const Dtype share = ScalarConvert<Acctype, Dtype>::to(
    ScalarConvert<Dtype, Acctype>::to(gradOutput[plane][frame][row][col]) / divisor);

  for (int t = t0; t < t1; ++t)
  {
    for (int h = h0; h < h1; ++h)
    {
      for (int w = w0; w < w1; ++w)
      {
        gradInput[plane][t][h][w] = share;
      }
    }
  }
}
#include "generic/VolumetricAveragePooling.cu"
#include "THCGenerateFloatTypes.h"