forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Reduce.cuh
483 lines (413 loc) · 14.4 KB
/
Reduce.cuh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
#pragma once
#include <ATen/ATen.h>
#include <ATen/cuda/Array.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/OffsetCalculator.cuh>
#include <ATen/detail/FunctionTraits.h>
#include <THC/THCDeviceUtils.cuh>
#include <THC/THCGeneral.hpp>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cuda/Loops.cuh>
#include <iosfwd>
namespace at { namespace native {
using at::cuda::Array;
// Ceiling division: for a >= 0 and b > 0, the smallest q with q * b >= a.
// Used below to size the launch grid and to count per-thread work.
static inline int64_t div_up(int64_t a, int64_t b) {
  // Bump the numerator by (b - 1) so the truncating division rounds up.
  const int64_t padded = a + b - 1;
  return padded / b;
}
// Describes how a reduction is mapped onto a CUDA launch.
//
// Parallelism is applied at three levels, indexed by LANE (threadIdx.x),
// WARP (threadIdx.y) and CTA (blockIdx.y). At each level either the inputs
// of an output are split (a non-zero `input_mult` entry, which means partial
// results along that level must later be combined) or the outputs are split
// (`output_mult`). blockIdx.x walks over groups of `step_output` outputs.
struct ReduceConfig {
  static constexpr int LANE = 0;
  static constexpr int WARP = 1;
  static constexpr int CTA = 2;

  // Threads per block; block() shapes them as warp_size x (NUM_THREADS / warp_size).
  static constexpr int NUM_THREADS = 512;

  // element_size_bytes: bytes per staged element, used to size the shared
  // (shared_memory_size) and global (global_memory_size) scratch areas.
  ReduceConfig(int element_size_bytes, int num_outputs, int num_inputs)
    : element_size_bytes(element_size_bytes)
    , num_inputs(num_inputs)
    , num_outputs(num_outputs) {}

  int element_size_bytes;
  int num_inputs;           // inputs reduced into each output
  int num_outputs;          // total number of outputs
  int step_input = 1;       // product of all input-splitting parallelism so far
  int step_output = 1;      // product of all output-splitting parallelism so far
  int ctas_per_output = 1;  // becomes gridDim.y (see grid())

  // Per-level index multipliers; 0 means "this level does not split inputs/outputs".
  int input_mult[3] = {0, 0, 0};
  int output_mult[2] = {0, 0};

  // Splits the inputs `parallelism` ways and returns the multiplier for this level
  // (the input step before the split).
  int split_input(int parallelism) {
    int step = step_input;
    step_input *= parallelism;
    return step;
  }

  // Splits the outputs `parallelism` ways and returns the multiplier for this level
  // (the output step before the split).
  int split_output(int parallelism) {
    int step = step_output;
    step_output *= parallelism;
    return step;
  }

  dim3 block() const {
    int warp_size = at::cuda::warp_size();
    return dim3(warp_size, NUM_THREADS / warp_size);
  }

  dim3 grid() const {
    return dim3(div_up(num_outputs, step_output), ctas_per_output);
  }

  // True when lanes of a warp share inputs and must be combined via shuffles.
  AT_HOST_DEVICE bool should_warp_reduce() const {
    return input_mult[LANE] != 0;
  }

  // True when warps of a block share inputs and must be combined via shared memory.
  AT_HOST_DEVICE bool should_block_reduce() const {
    return input_mult[WARP] != 0;
  }

  // True when CTAs along blockIdx.y share inputs and must be combined via global memory.
  AT_HOST_DEVICE bool should_global_reduce() const {
    return input_mult[CTA] != 0;
  }

  // Only one thread per output writes it: lane 0 when lanes were combined,
  // warp 0 when warps were combined, and never for out-of-range outputs.
  AT_DEVICE bool should_store(int output_idx) const {
    return output_idx < num_outputs &&
      (!should_warp_reduce() || threadIdx.x == 0) &&
      (!should_block_reduce() || threadIdx.y == 0);
  }

  // First logical input index handled by the calling thread.
  AT_HOST_DEVICE int input_idx() const {
    int lane = threadIdx.x;
    int warp = threadIdx.y;
    int cta2 = blockIdx.y;
    return (lane * input_mult[LANE] +
            warp * input_mult[WARP] +
            cta2 * input_mult[CTA]);
  }

  // Output index handled by the calling thread; each blockIdx.x covers
  // step_output consecutive outputs.
  AT_HOST_DEVICE int output_idx() const {
    int lane = threadIdx.x;
    int warp = threadIdx.y;
    int cta1 = blockIdx.x;
    return (lane * output_mult[LANE] +
            warp * output_mult[WARP] +
            cta1 * step_output);
  }

  // Shared-memory slot of this thread, shifted `offset` warps down.
  AT_DEVICE int shared_memory_offset(int offset) const {
    return threadIdx.x + (threadIdx.y + offset) * blockDim.x;
  }

  // Global staging slot for CTA `cta2` of this blockIdx.x column. Without a
  // warp reduce every lane owns its own output, so lanes get separate slots.
  AT_DEVICE int staging_memory_offset(int cta2) const {
    int offset = cta2 + blockIdx.x * gridDim.y;
    if (!should_warp_reduce()) {
      offset = threadIdx.x + offset * blockDim.x;
    }
    return offset;
  }

  // Bytes of dynamic shared memory needed by the inter-warp reduction.
  int shared_memory_size() const {
    if (!should_block_reduce()) {
      return 0;
    }
    return element_size_bytes * NUM_THREADS;
  }

  // Bytes of global staging memory needed by the cross-CTA reduction.
  // Computed in 64-bit: element size * outputs * CTAs (* lanes) can exceed INT_MAX.
  int64_t global_memory_size() const {
    if (!should_global_reduce()) {
      return 0;
    }
    int64_t size = static_cast<int64_t>(element_size_bytes) * num_outputs * ctas_per_output;
    if (!should_warp_reduce()) {
      size *= block().x;
    }
    return size;
  }

  // Bytes for one int completion counter per blockIdx.x.
  int semaphore_size() const {
    if (!should_global_reduce()) {
      return 0;
    }
    return sizeof(int) * grid().x;
  }

  // Number of inputs each thread folds serially.
  int values_per_thread() const {
    return div_up(num_inputs, step_input);
  }
};

std::ostream& operator<<(std::ostream& out, const ReduceConfig& config);
// Kernel entry point: every thread simply runs the reduction functor.
// `nt` is the maximum block size passed to __launch_bounds__, together with
// a request for at least 4 resident blocks per multiprocessor.
template <int nt, typename R>
__global__ void __launch_bounds__(nt, 4) reduce_kernel(R op) {
  op.run();
}
// Builds the offset calculator over the non-reduced (output) dimensions.
// The leading iter.num_reduce_dims() dimensions are skipped. The calculator
// yields offsets for operand 0 (output) and operand 1 (input) at an output index.
static OffsetCalculator<2> make_output_calculator(const TensorIterator& iter) {
  const int reduce_dims = iter.num_reduce_dims();
  const int output_dims = iter.ndim() - reduce_dims;

  const int64_t* out_strides = iter.strides(0).data() + reduce_dims;
  const int64_t* in_strides = iter.strides(1).data() + reduce_dims;
  std::array<const int64_t*, 2> strides = {out_strides, in_strides};

  auto shape = iter.shape().data() + reduce_dims;
  return OffsetCalculator<2>(output_dims, shape, strides.data());
}
// Builds the offset calculator over the reduced dimensions of the input
// (operand 1), i.e. the leading iter.num_reduce_dims() dimensions.
static OffsetCalculator<1> make_input_calculator(const TensorIterator& iter) {
  const int reduce_dims = iter.num_reduce_dims();
  std::array<const int64_t*, 1> strides = {iter.strides(1).data()};
  return OffsetCalculator<1>(reduce_dims, iter.shape().data(), strides.data());
}
// Calls f(i, begin + i * stride) for i in [0, vt), skipping indices that are
// not < end. If the last of the vt indices is already < end, the loop runs
// without per-iteration bounds checks. This assumes a positive stride, so that
// every earlier index is then in range as well.
template <int vt, typename func_t>
__device__ void strided_iterate(func_t f, int begin, int end, int stride) {
  const int last = begin + (vt - 1) * stride;
  if (!(last < end)) {
    // Tail case: some indices may fall past `end`.
    #pragma unroll
    for (int i = 0; i < vt; i++) {
      const int idx = begin + i * stride;
      if (idx < end) {
        f(i, idx);
      }
    }
    return;
  }
  // Fast path: all vt indices are in range.
  #pragma unroll
  for (int i = 0; i < vt; i++) {
    f(i, begin + i * stride);
  }
}
// Gathers in[foo(begin + i * stride)] into slot i of an Array, for i in [0, vt).
// Slots whose logical index is not < end are not written.
template <int vt, typename type_t, typename foo_t>
__device__ Array<type_t, vt> load_memory(const type_t* in, int begin, int end, int stride, foo_t foo) {
  Array<type_t, vt> out;
  auto load_one = [&](int slot, int idx) {
    out[slot] = in[foo(idx)];
  };
  strided_iterate<vt>(load_one, begin, end, stride);
  return out;
}
// Contiguous variant of load_memory: the logical index is used directly as
// the element offset into `in`.
template <int vt, typename type_t>
__device__ Array<type_t, vt> load_memory(const type_t* in, int begin, int end, int stride) {
  auto identity = [](int idx) { return idx; };
  return load_memory<vt>(in, begin, end, stride, identity);
}
// Device-side reduction functor executed by reduce_kernel.
//
// Each thread first folds a strided subset of its output's inputs
// (thread_reduce). Partial results are then combined as `config` dictates:
// - across warps through dynamic shared memory (block_reduce);
// - across lanes with warp shuffles (warp_reduce);
// - across the gridDim.y CTAs of a blockIdx.x column through a global
//   staging buffer plus a completion counter (global_reduce).
//
// `ident` seeds every partial result, including threads with no inputs, so it
// must be an identity element of `op`. When `accumulate` is set, the existing
// output value is folded into the result before it is stored.
template <typename scalar_t, typename func_t>
struct ReduceOp {
  using traits = binary_function_traits<func_t>;
  using arg_t = typename traits::arg2_t;  // type of partial results

  using InputCalculator = OffsetCalculator<1>;
  using OutputCalculator = OffsetCalculator<2>;

  // Inputs loaded per thread_reduce_once() call.
  static constexpr int vt0 = 4;

  func_t op;
  arg_t ident;
  ReduceConfig config;
  InputCalculator input_calc;
  OutputCalculator output_calc;
  const void* src;
  void* dst;
  void* buffer;       // global staging buffer (global_reduce only)
  int* semaphores;    // one completion counter per blockIdx.x (global_reduce only)
  bool accumulate = false;

  ReduceOp(func_t op, ReduceConfig config, InputCalculator input_calc, OutputCalculator output_calc,
           const void* src, void* dst, void* buffer, int* semaphores)
    : op(op)
    , config(config)
    , input_calc(input_calc)
    , output_calc(output_calc)
    , src(src)
    , dst(dst)
    , buffer(buffer)
    , semaphores(semaphores) {
  }

  // Kernel body: computes this thread's contribution, combines it with the
  // other threads/CTAs as configured, and stores the result if this thread owns it.
  AT_DEVICE void run() const {
    int output_idx = config.output_idx();
    int input_idx = config.input_idx();
    auto base_offsets = output_calc.get(output_idx);

    // Threads without work still take part in the later combines, carrying `ident`.
    arg_t value = ident;
    if (output_idx < config.num_outputs && input_idx < config.num_inputs) {
      auto input_slice = (const char*)src + base_offsets[1];
      value = thread_reduce((const scalar_t*)input_slice);
    }

    bool should_block_reduce = config.should_block_reduce();
    if (should_block_reduce) {
      value = block_reduce(value);
    }
    // After block_reduce only warp 0 (threadIdx.y == 0) holds fully combined values.
    if (config.should_warp_reduce() && (!should_block_reduce || threadIdx.y == 0)) {
      value = warp_reduce(value);
    }

    auto out = (scalar_t*)((char*)dst + base_offsets[0]);
    if (config.should_global_reduce()) {
      value = global_reduce(value, out);
    } else if (config.should_store(output_idx)) {
      if (accumulate) {
        value = op(*out, value);
      }
      *out = value;
    }
  }

  // Loads up to vt0 inputs starting at logical input index `offset`, spaced
  // config.step_input apart. With a single reduced dimension the element offset
  // is idx * stride; otherwise input_calc maps idx to a byte offset.
  AT_DEVICE Array<scalar_t, vt0> load_inputs(const scalar_t* data, int offset) const {
    int end = config.num_inputs;
    int stride = input_calc.strides_[0][0] / sizeof(scalar_t);
    if (input_calc.dims == 1) {
      return load_memory<vt0>(data, offset, end, config.step_input, [&](int idx) {
        return idx * stride;
      });
    } else {
      return load_memory<vt0>(data, offset, end, config.step_input, [&](int idx) {
        return input_calc.get(idx)[0] / sizeof(scalar_t);
      });
    }
  }

  // Folds the (up to vt0) inputs loaded at `offset`. The caller guarantees
  // offset < num_inputs, so the i == 0 step always runs and initializes `value`.
  AT_DEVICE arg_t thread_reduce_once(const scalar_t* data, int offset) const {
    auto values = load_inputs(data, offset);

    arg_t value;
    strided_iterate<vt0>([&](int i, int idx) {
      value = i == 0 ? (arg_t)values[0] : op(value, values[i]);
    }, offset, config.num_inputs, config.step_input);

    return value;
  }

  // Serially folds all inputs assigned to this thread, vt0 at a time,
  // advancing by config.step_input * vt0 per iteration.
  AT_DEVICE arg_t thread_reduce(const scalar_t* data) const {
    arg_t value = ident;
    int idx = config.input_idx();
    while (idx < config.num_inputs) {
      arg_t next = thread_reduce_once(data, idx);
      value = op(value, next);
      idx += config.step_input * vt0;
    }
    return value;
  }

  // Combines values across a warp with shuffle-down; lane 0 ends up with the
  // combination of all lanes (other lanes hold partial results).
  AT_DEVICE arg_t warp_reduce(arg_t value) const {
    for (int offset = 1; offset < warpSize; offset <<= 1) {
      arg_t other = WARP_SHFL_DOWN(value, offset);
      value = op(value, other);
    }
    return value;
  }

  // Tree-reduces `value` across the warps (threadIdx.y) of the block through
  // dynamic shared memory. Threads with threadIdx.y == 0 end up with the full
  // result for their lane; other threads return a partial.
  AT_DEVICE arg_t block_reduce(arg_t value) const {
    extern __shared__ char shared_memory[];
    arg_t* shared = (arg_t*)shared_memory;
    shared[config.shared_memory_offset(0)] = value;
    int num_warps = (blockDim.x * blockDim.y) / warpSize;
    for (int offset = num_warps / 2; offset > 0; offset >>= 1) {
      __syncthreads();
      if (threadIdx.y < offset && threadIdx.y + offset < num_warps) {
        arg_t other = shared[config.shared_memory_offset(offset)];
        value = op(value, other);
        shared[config.shared_memory_offset(0)] = value;
      }
    }
    return value;
  }

  // Increments this column's counter in `semaphores` (expected to be zeroed
  // before launch). Returns true in the CTA that arrives last among the
  // gridDim.y CTAs sharing blockIdx.x. The flag lives in dynamic shared memory,
  // which aliases block_reduce's buffer, hence the surrounding barriers.
  AT_DEVICE bool mark_block_finished() const {
    extern __shared__ int is_last_block_done_shared[];

    __syncthreads();
    if (threadIdx.x == 0 && threadIdx.y == 0) {
      int prev_blocks_finished = atomicAdd(&semaphores[blockIdx.x], 1);
      is_last_block_done_shared[0] = (prev_blocks_finished == gridDim.y - 1);
    }

    __syncthreads();
    bool is_last_block_done = is_last_block_done_shared[0];
    __syncthreads();

    return is_last_block_done;
  }

  // Combines partial results of the gridDim.y CTAs in this blockIdx.x column.
  // Every CTA stages its partials in `buffer`; the last CTA to finish
  // re-reduces them and writes the final output. Returns the final value in
  // that CTA, otherwise this thread's unchanged partial.
  AT_DEVICE arg_t global_reduce(arg_t value, scalar_t* out) const {
    arg_t* reduce_buffer = (arg_t*)buffer;

    bool should_store = config.should_store(config.output_idx());
    if (should_store) {
      int offset = config.staging_memory_offset(blockIdx.y);
      reduce_buffer[offset] = value;
    }

    __threadfence(); // make sure writes are globally visible
    __syncthreads(); // if multiple warps in this block wrote to staging, make sure they're all done
    bool is_last_block_done = mark_block_finished();

    if (is_last_block_done) {
      // Seed with the identity of `op`; a literal 0 is only correct for sum-like ops.
      value = ident;
      if (config.should_warp_reduce()) {
        // One staged value per CTA: spread them over every thread of the block.
        int input_offset = threadIdx.x + threadIdx.y * blockDim.x;
        int step = blockDim.x * blockDim.y;
        for (; input_offset < config.ctas_per_output; input_offset += step) {
          int idx = config.staging_memory_offset(input_offset);
          arg_t next = reduce_buffer[idx];
          value = op(value, next);
        }
      } else {
        // One staged value per CTA and lane: each lane walks its own column over the warps.
        int input_offset = threadIdx.y;
        int step = blockDim.y;
        for (; input_offset < config.ctas_per_output; input_offset += step) {
          int idx = config.staging_memory_offset(input_offset);
          arg_t next = reduce_buffer[idx];
          value = op(value, next);
        }
      }
      value = block_reduce(value);
      if (config.should_warp_reduce()) {
        value = warp_reduce(value);
      }
      if (should_store) {
        if (accumulate) {
          value = op(*out, value);
        }
        *out = value;
      }
    }

    return value;
  }
};
// Launches reduce_kernel on the current CUDA stream, using the grid/block shape
// and dynamic shared memory size described by `config`, then reports launch errors.
template<int nt, typename R>
static void launch_reduce_kernel(const ReduceConfig& config, const R& reduction) {
  const dim3 threads_per_block = config.block();
  const dim3 blocks = config.grid();
  auto launch_stream = at::cuda::getCurrentCUDAStream();
  const int smem_bytes = config.shared_memory_size();

  reduce_kernel<nt, R><<<blocks, threads_per_block, smem_bytes, launch_stream>>>(reduction);
  AT_CUDA_CHECK(cudaGetLastError());
}
// Reduces operand 1 of `iter` into operand 0 on the GPU with `op`.
//
// `op` folds a value into an accumulator of type arg_t (op's second parameter
// type). `ident` must be an identity element of `op` (e.g. 0 for sum, 1 for
// product); it is converted to arg_t. Iterators that cannot use 32-bit
// indexing are split into sub-iterators, each handled by a recursive call with
// the same `op` and `ident`.
template <typename scalar_t, typename func_t, typename ident_t=double>
inline void gpu_reduce_kernel(TensorIterator& iter, const func_t& op, ident_t ident=0) {
  ASSERT_HOST_DEVICE_LAMBDA(func_t);
  AT_ASSERT(iter.numel() > 0 && iter.ntensors() == 2);

  if (!iter.can_use_32bit_indexing()) {
    for (auto& sub_iter : iter.with_32bit_indexing()) {
      // Forward `ident`: omitting it would reset it to the default 0, which is
      // wrong for reductions whose identity is not zero.
      gpu_reduce_kernel<scalar_t>(sub_iter, op, ident);
    }
    return;
  }

  char* out_data = (char*)iter.data_ptr(0);
  const char* in_data = (char*)iter.data_ptr(1);

  using traits = binary_function_traits<func_t>;
  using arg_t = typename traits::arg2_t;

  int warp_size = at::cuda::warp_size();
  int warps_per_cta = ReduceConfig::NUM_THREADS / warp_size;

  // Start by assuming that each thread handles a single output and all
  // the inputs for that output.
  int64_t num_outputs = iter.num_output_elements();
  int64_t inputs_per_output = iter.numel() / num_outputs;

  // Shared memory and the global staging buffer hold partial results of type
  // arg_t, so size them by sizeof(arg_t); sizeof(scalar_t) under-allocates when
  // the accumulator is wider than the input (e.g. float accumulation of half).
  auto config = ReduceConfig(sizeof(arg_t), num_outputs, inputs_per_output);

  if (iter.ndim() == 0 || iter.strides(/*arg=*/1)[0] == sizeof(scalar_t)) {
    // Split the input across lanes if the input is contiguous in the reduced
    // dimension. This will require reduction between threads using warp
    // shuffle instructions.
    config.input_mult[0] = config.split_input(warp_size);
  } else {
    // Otherwise split the output across lanes in a warp.
    config.output_mult[0] = config.split_output(warp_size);
  }

  if (config.values_per_thread() >= warps_per_cta * 16) {
    // Divide the input across warps in a thread-block, if that leaves at least
    // 16 elements to be summed by each thread. This will require inter-warp
    // reduction using shared memory.
    config.input_mult[1] = config.split_input(warps_per_cta);
  } else {
    // Otherwise, each warp handles a separate output.
    config.output_mult[1] = config.split_output(warps_per_cta);
  }

  if (config.values_per_thread() >= 256 && num_outputs <= 4096) {
    // Divide the input across thread-blocks if the amount of work per-thread
    // is large enough and the size of the output is small enough. This will
    // require a reduction using global memory.
    config.ctas_per_output = div_up(config.values_per_thread(), 16);
    // gridDim.y is limited to 65535.
    if (config.ctas_per_output > 65535) {
      config.ctas_per_output = 65535;
    }
    config.input_mult[2] = config.split_input(config.ctas_per_output);
  }

  auto output_calc = make_output_calculator(iter);
  auto input_calc = make_input_calculator(iter);

  at::DataPtr buffer;
  at::DataPtr semaphores;
  if (config.should_global_reduce()) {
    auto& allocator = *at::globalContext().getTHCState()->cudaDeviceAllocator;
    buffer = allocator.allocate(config.global_memory_size());
    semaphores = allocator.allocate(config.semaphore_size());

    // The completion counters must start at zero for mark_block_finished.
    auto stream = at::cuda::getCurrentCUDAStream();
    AT_CUDA_CHECK(cudaMemsetAsync(semaphores.get(), 0, config.semaphore_size(), stream));
  }

  auto reduce = ReduceOp<scalar_t, func_t>(
      op,
      config,
      input_calc,
      output_calc,
      in_data,
      out_data,
      buffer.get(),
      (int*)semaphores.get());
  reduce.ident = ident;
  reduce.accumulate = iter.should_accumulate();

  launch_reduce_kernel<ReduceConfig::NUM_THREADS>(config, reduce);
}
}} // namespace at::native