forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
BatchNorm.cpp
381 lines (333 loc) · 11.9 KB
/
BatchNorm.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/Config.h>
#include <ATen/core/Tensor.h>
#include <ATen/cuda/CUDAConfig.h>
#if !AT_CUDNN_ENABLED()
namespace at {
namespace native {
// See Note [ATen preprocessor philosophy]
// Fallback definition compiled when ATen is built without cuDNN: every call
// fails immediately with an error saying cuDNN support is unavailable.
std::tuple<Tensor, Tensor, Tensor, Tensor> cudnn_batch_norm(
    const Tensor& input,
    const Tensor& weight,
    const c10::optional<Tensor>& bias_opt,
    const c10::optional<Tensor>& running_mean_opt,
    const c10::optional<Tensor>& running_var_opt,
    bool training,
    double exponential_average_factor,
    double epsilon) {
  // AT_ERROR is a deprecated alias; TORCH_CHECK(false, ...) is the
  // recommended way to fail unconditionally.
  TORCH_CHECK(false, "cudnn_batch_norm: ATen not compiled with cuDNN support");
}
// Fallback definition compiled when ATen is built without cuDNN: every call
// fails immediately with an error saying cuDNN support is unavailable.
std::tuple<Tensor, Tensor, Tensor> cudnn_batch_norm_backward(
    const Tensor& input,
    const Tensor& grad_output,
    const Tensor& weight,
    const c10::optional<Tensor>& running_mean_opt,
    const c10::optional<Tensor>& running_var_opt,
    const c10::optional<Tensor>& save_mean_opt,
    const c10::optional<Tensor>& save_var_opt,
    double epsilon,
    const Tensor& reservedSpace) {
  // AT_ERROR is a deprecated alias; TORCH_CHECK(false, ...) is the
  // recommended way to fail unconditionally.
  TORCH_CHECK(
      false, "cudnn_batch_norm_backward: ATen not compiled with cuDNN support");
}
} // namespace native
} // namespace at
#else // AT_CUDNN_ENABLED
#include <ATen/cuda/Exceptions.h>
#include <ATen/cudnn/Descriptors.h>
#include <ATen/cudnn/Types.h>
#include <ATen/cudnn/Utils.h>
#include <ATen/TensorUtils.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/cudnn_batch_norm_backward_native.h>
#include <ATen/ops/cudnn_batch_norm_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#endif
namespace at {
namespace native {
namespace {

// Views the 1-D per-feature tensor `t` as {1, t.numel(), 1, ..., 1}, padded
// with trailing ones up to `dim` entries (the result never has fewer than two
// entries). This lets a cuDNN descriptor of the same rank as the input be
// built for the weight/bias/statistics tensors.
Tensor expandScale(const Tensor& t, int64_t dim) {
  const int64_t rank = dim > 2 ? dim : 2;
  std::vector<int64_t> shape(static_cast<size_t>(rank), int64_t{1});
  shape[1] = t.numel();
  return t.view(shape);
}

// Chooses the cuDNN batch-norm mode for an input of rank `dim` laid out in
// `memory_format`.
//   * rank-2 inputs                      -> CUDNN_BATCHNORM_PER_ACTIVATION
//   * training, ChannelsLast             -> CUDNN_BATCHNORM_SPATIAL_PERSISTENT
//   * training, ChannelsLast3d           -> SPATIAL_PERSISTENT on cuDNN >= 8.1,
//                                           otherwise CUDNN_BATCHNORM_SPATIAL
//   * everything else (incl. inference)  -> CUDNN_BATCHNORM_SPATIAL
cudnnBatchNormMode_t getCudnnBatchNormMode(
    bool training,
    at::MemoryFormat memory_format,
    int64_t dim) {
  if (dim == 2) {
    return CUDNN_BATCHNORM_PER_ACTIVATION;
  }
  if (!training) {
    return CUDNN_BATCHNORM_SPATIAL;
  }
  switch (memory_format) {
    case at::MemoryFormat::ChannelsLast:
      return CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
    case at::MemoryFormat::ChannelsLast3d:
#if CUDNN_VERSION >= 8100
      return CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
#else
      return CUDNN_BATCHNORM_SPATIAL;
#endif // CUDNN_VERSION >= 8100
    default:
      // TODO: CUDNN_BATCHNORM_SPATIAL_PERSISTENT (added in cuDNN 7 as a
      // performance optimization) caused accuracy losses in convolution
      // models such as ResNeXt-101 and video R(2+1)D, so all other layouts
      // fall back to the normal CUDNN_BATCHNORM_SPATIAL.
      return CUDNN_BATCHNORM_SPATIAL;
  }
}

} // namespace
// Batch-normalization forward pass implemented with cuDNN.
//
// Parameters:
//   input_t     - input with 2 to 5 dimensions; size(1) is the feature count.
//                 Must be contiguous in its suggested memory format.
//   weight_t, bias_t_opt - per-feature scale and shift. Both must be defined
//                 and contiguous, with size(1) elements each. For Half input
//                 the weight must be Float; otherwise it must match the
//                 input's dtype.
//   running_mean_t_opt, running_var_t_opt - running statistics. Required
//                 when !training. In training they may be absent; when
//                 present they are handed to cuDNN together with
//                 exponential_average_factor.
//   training    - true: cudnnBatchNormalizationForwardTrainingEx;
//                 false: cudnnBatchNormalizationForwardInference.
//   epsilon     - forwarded to cuDNN unchanged.
//
// Returns (output, save_mean, save_var, reserve):
//   output      - same sizes and suggested memory format as the input.
//   save_mean, save_var - per-feature buffers filled by cuDNN in training;
//                 empty tensors in inference.
//   reserve     - byte buffer filled by cuDNN in training (sized by
//                 cudnnGetBatchNormalizationTrainingExReserveSpaceSize);
//                 empty in inference.
std::tuple<Tensor, Tensor, Tensor, Tensor> cudnn_batch_norm(
    const Tensor& input_t,
    const Tensor& weight_t,
    const c10::optional<Tensor>& bias_t_opt,
    const c10::optional<Tensor>& running_mean_t_opt,
    const c10::optional<Tensor>& running_var_t_opt,
    bool training,
    double exponential_average_factor,
    double epsilon) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> bias_t_maybe_owned =
      at::borrow_from_optional_tensor(bias_t_opt);
  const Tensor& bias_t = *bias_t_maybe_owned;
  const Tensor& running_mean_t =
      c10::value_or_else(running_mean_t_opt, [] { return Tensor(); });
  const Tensor& running_var_t =
      c10::value_or_else(running_var_t_opt, [] { return Tensor(); });

  TensorArg input{input_t, "input", 1}, weight{weight_t, "weight", 2},
      bias{bias_t, "bias", 3}, running_mean{running_mean_t, "running_mean", 4},
      running_var{running_var_t, "running_var", 5};
  CheckedFrom c = "cudnn_batch_norm";

  checkAllDefined(c, {input, weight, bias});
  if (!training) {
    // Inference normalizes with the running statistics, so they must exist.
    checkAllDefined(c, {running_mean, running_var});
  }
  checkAllSameGPU(c, {input, weight, bias, running_mean, running_var});
  if (input->scalar_type() == ScalarType::Half) {
    checkScalarType(c, weight, ScalarType::Float);
  } else {
    checkAllSameType(c, {input, weight});
  }
  checkAllSameType(c, {weight, bias, running_mean, running_var});
  // TODO: is weight required to be contiguous?
  checkAllContiguous(c, {weight, bias, running_mean, running_var});
  // TODO: TensorArg check should start handle memory format
  TORCH_CHECK(input->is_contiguous(input->suggest_memory_format()));
  checkDimRange(c, input, 2, 6 /* exclusive */);

  const auto num_features = input->size(1);
  for (auto t : {weight, bias, running_mean, running_var}) {
    if (t->defined()) {
      checkNumel(c, t, num_features);
    }
  }

  cudnnBatchNormMode_t mode = getCudnnBatchNormMode(
      training, input->suggest_memory_format(), input->dim());

  auto output_t =
      at::empty_like(*input, input->options(), input->suggest_memory_format());
  TensorArg output{output_t, "output", 0};

  auto handle = getCudnnHandle();
  auto dataType = getCudnnDataType(*input);
  TensorDescriptor idesc{*input, 4}; // input and output descriptor
  TensorDescriptor wdesc{
      expandScale(*weight, input->dim()),
      4}; // descriptor for weight, bias, running_mean, etc.

  Constant one(dataType, 1);
  Constant zero(dataType, 0);
  Tensor save_mean, save_var;
  Tensor reserve;

  if (training) {
    // Reuses the num_features computed above (both read input size(1)).
    save_mean = at::empty({num_features}, weight_t.options());
    save_var = at::empty({num_features}, weight_t.options());

    auto op = CUDNN_BATCHNORM_OPS_BN;
    size_t workspace_size;
    AT_CUDNN_CHECK(cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize(
        handle,
        mode,
        op,
        idesc.desc(),
        idesc.desc(),
        idesc.desc(),
        wdesc.desc(),
        nullptr,
        &workspace_size));
    Tensor workspace = at::empty(workspace_size, input->options().dtype(kByte));

    // get the reserved size and allocate as tensor
    size_t reserve_size;
    AT_CUDNN_CHECK(cudnnGetBatchNormalizationTrainingExReserveSpaceSize(
        handle, mode, op, nullptr, idesc.desc(), &reserve_size));
    reserve = at::empty(reserve_size, input->options().dtype(kByte));

    AT_CUDNN_CHECK(cudnnBatchNormalizationForwardTrainingEx(
        handle,
        mode,
        op,
        &one,
        &zero,
        idesc.desc(),
        input->data_ptr(),
        nullptr, // z descriptor for BN-Add-Relu
        nullptr, // z for BN-Add-ReLU
        idesc.desc(),
        output->data_ptr(),
        wdesc.desc(),
        weight->data_ptr(),
        bias->data_ptr(),
        exponential_average_factor,
        // Running stats are optional in training; a null pointer is passed
        // when they are undefined.
        at::maybe_data_ptr(running_mean),
        at::maybe_data_ptr(running_var),
        epsilon,
        save_mean.mutable_data_ptr(),
        save_var.mutable_data_ptr(),
        nullptr,
        workspace.data_ptr(),
        workspace_size,
        reserve.mutable_data_ptr(),
        reserve_size));
  } else {
    reserve = at::empty({0}, input->options().dtype(kByte));
    // This keeps a consistent output with native_batch_norm
    save_mean = at::empty({0}, weight_t.options());
    save_var = at::empty({0}, weight_t.options());
    AT_CUDNN_CHECK(cudnnBatchNormalizationForwardInference(
        handle,
        mode,
        &one,
        &zero,
        idesc.desc(),
        input->data_ptr(),
        idesc.desc(),
        output->data_ptr(),
        wdesc.desc(),
        weight->data_ptr(),
        bias->data_ptr(),
        running_mean->data_ptr(),
        running_var->data_ptr(),
        epsilon));
  }

  // save_mean, save_var and reserve are always defined here: filled by cuDNN
  // in training, empty (numel 0) tensors in inference.
  return std::tuple<Tensor, Tensor, Tensor, Tensor>{
      output_t, save_mean, save_var, reserve};
}
// NB: CuDNN only implements the backward algorithm for batchnorm
// in training mode (evaluation mode batchnorm has a different algorithm),
// which is why this doesn't accept a 'training' parameter.
//
// Computes (grad_input, grad_weight, grad_bias) with
// cudnnBatchNormalizationBackwardEx. save_mean/save_var and reserveSpace are
// presumably the matching outputs of a training-mode cudnn_batch_norm call
// (TODO confirm for any non-autograd callers). grad_input is allocated in the
// input's suggested memory format; grad_weight and grad_bias are contiguous
// tensors with the weight's sizes and options.
std::tuple<Tensor, Tensor, Tensor> cudnn_batch_norm_backward(
    const Tensor& input_t,
    const Tensor& grad_output_t,
    const Tensor& weight_t,
    // Unused: but we require them to be passed so that double backwards
    // has access
    const c10::optional<Tensor>& running_mean_opt,
    const c10::optional<Tensor>& running_var_opt,
    const c10::optional<Tensor>& save_mean_t_opt,
    const c10::optional<Tensor>& save_var_t_opt,
    double epsilon,
    const Tensor& reserveSpace) {
  // See [Note: hacky wrapper removal for optional tensor]
  const Tensor& save_mean_t =
      c10::value_or_else(save_mean_t_opt, [] { return Tensor(); });
  const Tensor& save_var_t =
      c10::value_or_else(save_var_t_opt, [] { return Tensor(); });

  // TODO: Is it worth it to have a contiguous call or maybe we should go with
  // whatever format is given here.
  auto grad_output_contig =
      grad_output_t.contiguous(input_t.suggest_memory_format());

  // wdesc (built from the weight below) is also the descriptor cuDNN uses for
  // the freshly allocated, contiguous grad_weight/grad_bias buffers. A strided
  // weight would make cuDNN write those buffers with the wrong stride, so use
  // a contiguous weight (a no-op when it already is). Undefined weights are
  // left alone so checkAllDefined reports them.
  const Tensor weight_contig =
      weight_t.defined() ? weight_t.contiguous() : weight_t;

  TensorArg input{input_t, "input", 1},
      grad_output{grad_output_contig, "grad_output", 2},
      weight{weight_contig, "weight", 3}, save_mean{save_mean_t, "save_mean", 4},
      save_var{save_var_t, "save_var", 5},
      reserve{reserveSpace, "reserve_space", 6};
  CheckedFrom c = "cudnn_batch_norm_backward";

  checkAllDefined(c, {input, grad_output, weight, save_mean, save_var});
  // reserve_space is handed to cuDNN as a device pointer, so it must live on
  // the same GPU as everything else (undefined tensors are skipped).
  checkAllSameGPU(
      c, {input, grad_output, weight, save_mean, save_var, reserve});
  if (input->scalar_type() == ScalarType::Half) {
    checkScalarType(c, weight, ScalarType::Float);
  } else {
    checkAllSameType(c, {input, weight});
  }
  checkAllSameType(c, {input, grad_output});
  checkAllSameType(c, {weight, save_mean, save_var});
  checkAllContiguous(c, {save_mean, save_var});
  // TODO: TensorArg check should start handle memory format
  TORCH_CHECK(input->is_contiguous(input->suggest_memory_format()));
  TORCH_CHECK(grad_output->is_contiguous(input->suggest_memory_format()));
  checkDimRange(c, input, 2, 6 /* exclusive */);
  checkSameSize(c, input, grad_output);

  const auto num_features = input->size(1);
  for (auto t : {weight, save_mean, save_var}) {
    checkNumel(c, t, num_features);
  }

  cudnnBatchNormMode_t mode = getCudnnBatchNormMode(
      true, // training
      input->suggest_memory_format(),
      input->dim());

  auto grad_input_t = at::empty(
      input->sizes(), input->options(), input->suggest_memory_format());
  auto grad_weight_t = at::empty(weight->sizes(), weight->options());
  auto grad_bias_t = at::empty(weight->sizes(), weight->options());

  auto handle = getCudnnHandle();
  auto dataType = getCudnnDataType(*input);
  TensorDescriptor idesc{*input, 4}; // input and grad_input descriptor
  TensorDescriptor odesc{*grad_output, 4}; // grad_output descriptor
  TensorDescriptor wdesc{
      expandScale(*weight, input->dim()),
      4}; // descriptor for weight, save_mean, grad_weight, grad_bias, etc.

  Constant one(dataType, 1);
  Constant zero(dataType, 0);
  auto op = CUDNN_BATCHNORM_OPS_BN;

  size_t workspace_size;
  AT_CUDNN_CHECK(cudnnGetBatchNormalizationBackwardExWorkspaceSize(
      handle,
      mode,
      op,
      idesc.desc(),
      idesc.desc(),
      idesc.desc(),
      nullptr,
      odesc.desc(),
      wdesc.desc(),
      nullptr,
      &workspace_size));
  Tensor workspace = at::empty(workspace_size, input->options().dtype(kByte));

  AT_CUDNN_CHECK(cudnnBatchNormalizationBackwardEx(
      handle,
      mode,
      op,
      &one,
      &zero,
      &one,
      &zero,
      idesc.desc(),
      input->data_ptr(),
      nullptr,
      nullptr,
      odesc.desc(),
      grad_output->data_ptr(),
      nullptr,
      nullptr,
      idesc.desc(),
      grad_input_t.data_ptr(),
      wdesc.desc(),
      weight->data_ptr(),
      nullptr,
      grad_weight_t.data_ptr(),
      grad_bias_t.data_ptr(),
      epsilon,
      save_mean->data_ptr(),
      save_var->data_ptr(),
      nullptr,
      workspace.data_ptr(),
      workspace_size,
      reserve->data_ptr(),
      reserve->numel()));

  return std::tuple<Tensor, Tensor, Tensor>{
      grad_input_t, grad_weight_t, grad_bias_t};
}
} // namespace native
} // namespace at
#endif