forked from ROCm/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
THCApply.cuh
748 lines (666 loc) · 26.5 KB
/
THCApply.cuh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
#ifndef THC_APPLY_INC
#define THC_APPLY_INC
#include "THCTensorCopy.h"
#include "THCReduceApplyUtils.cuh"
#include "THCTensorTypeUtils.cuh"
#include "THCTensorCopy.hpp"
//
// This file contains pointwise operation functions and kernels that
// work on both contiguous and non-contiguous tensor arguments of
// arbitrary (up to MAX_CUTORCH_DIMS) dimensioned arguments without
// copying or temporary storage.
//
// Rearrange dimensions for pointwise operations so that strides are in
// decreasing order as much as possible, so that kernels have better memory
// access patterns.
//
// For example, consider a binary operation on two "transposed" 2-dim tensors:
// sizes: 256 512
// aInfo->strides: 1 256
// bInfo->strides: 1 256
//
// Given this, each concurrent memory access inside kernelPointwiseApply2() is
// exactly 256 elements apart, resulting in poor performance.
//
// This function exchanges dimensions so that memory access is contiguous:
// sizes: 512 256
// aInfo->strides: 256 1
// bInfo->strides: 256 1
//
// (Actually, it becomes even better because now collapseDims() can turn each
// input into one contiguous array.)
//
// In general, given M (<=3) TensorInfos with N dimensions, we can view each
// strides[i] (0 <= i < N) as an M-tuple. For each pair i < j, we exchange
// dimensions i and j (both their sizes and their strides) if
// (1) strides[i][k] < strides[j][k] for some k (0 <= k < M)
// (exchanging them will benefit input #k), and
// (2) strides[i][k] <= strides[j][k] for all k
// (exchanging them will not make any input worse).
// Dimensions of size 1 are never exchanged, and nothing is rearranged unless
// all inputs have the same number of dimensions and identical sizes.
// Reorders, in place, the dimensions of one to three TensorInfos so that
// strides decrease as much as possible. Dimensions i < j are exchanged when
// at least one operand has strides[i] < strides[j] and no operand has
// strides[i] > strides[j].
//
// aInfo is required. bInfo and cInfo are optional and may each be nullptr
// independently of the other. The function returns without changing
// anything when the supplied operands do not all have the same `dims`, or
// do not have identical sizes. Only the sizes/strides arrays are permuted;
// `dims` itself is never written.
template <typename T1, typename IndexType,
          typename T2 = void, typename T3 = void>
void rearrangeDims(TensorInfo<T1, IndexType>* aInfo,
                   TensorInfo<T2, IndexType>* bInfo = nullptr,
                   TensorInfo<T3, IndexType>* cInfo = nullptr) {
  const int dims = aInfo->dims;

  // Pack the supplied operands densely into slots [0, numInfos) so that the
  // loops below never read an unused (null) slot. The packing matters when
  // bInfo is omitted but cInfo is given: cInfo then occupies slot 1.
  int numInfos = 1;
  IndexType* sizes[3] = { aInfo->sizes, nullptr, nullptr };
  IndexType* strides[3] = { aInfo->strides, nullptr, nullptr };
  if (bInfo != nullptr) {
    if (bInfo->dims != dims) return;
    sizes[numInfos] = bInfo->sizes;
    strides[numInfos] = bInfo->strides;
    ++numInfos;
  }
  if (cInfo != nullptr) {
    if (cInfo->dims != dims) return;
    sizes[numInfos] = cInfo->sizes;
    strides[numInfos] = cInfo->strides;
    ++numInfos;
  }

  // Bail out if sizes do not match: we are using "deprecated pointwise
  // behavior" among tensors of different shapes but same number of elements.
  for (int k = 1; k < numInfos; ++k) {
    for (int d = 0; d < dims; ++d) {
      if (sizes[k][d] != sizes[0][d]) return;
    }
  }

  for (int i = 0; i < dims - 1; ++i) {
    // No need to consider dimensions of size 1.
    if (sizes[0][i] == 1) continue;
    for (int j = i + 1; j < dims; ++j) {
      if (sizes[0][j] == 1) continue;

      // Compare the relative sizes of strides between dim #i and dim #j
      // across every operand.
      bool hasIncreasingStrides = false;
      bool hasDecreasingStrides = false;
      for (int k = 0; k < numInfos; ++k) {
        const IndexType stride_i = strides[k][i];
        const IndexType stride_j = strides[k][j];
        if (stride_i < stride_j) {
          hasIncreasingStrides = true;
        } else if (stride_i > stride_j) {
          hasDecreasingStrides = true;
        }
      }

      // Swap only if it helps some operand and hurts none.
      if (hasIncreasingStrides && !hasDecreasingStrides) {
        for (int k = 0; k < numInfos; ++k) {
          const IndexType size = sizes[k][i];
          sizes[k][i] = sizes[k][j];
          sizes[k][j] = size;
          const IndexType stride = strides[k][i];
          strides[k][i] = strides[k][j];
          strides[k][j] = stride;
        }
      }
    }
  }
}
// Threads per block for our apply kernel
// (512 threads; used by getApplyBlock() and to size the grid in
// getApplyGrid()).
// FIXME: use occupancy calculator instead
#define THC_APPLY_THREADS_PER_BLOCK (32 * 16)
// Blocks per multiprocessor. When built with CUDA_VERSION < 9000, the grid of
// non-contiguous launches is capped at multiProcessorCount * this value.
// Both macros are #undef'd at the end of this header.
#define THC_APPLY_BLOCKS_PER_SM 4
// Element-wise kernel over one operand: calls op(a.get(i)) for every linear
// index i in [0, totalElements).
//
// Uses a 1-D grid-stride loop. Each thread starts at its global thread index
// and advances by the total number of launched threads, so any 1-D grid size
// covers every element. The explicit IndexType casts matter when IndexType is
// 64-bit: without them the block/grid products would be computed in 32-bit
// unsigned arithmetic and silently truncated.
template <typename Op,
          typename Ta,
          typename IndexType,
          int ADims>
__global__ void
kernelPointwiseApply1(const OffsetInfo<Ta, IndexType, ADims> a,
                      IndexType totalElements,
                      Op op) {
  const IndexType first =
      static_cast<IndexType>(blockIdx.x) * blockDim.x + threadIdx.x;
  const IndexType step =
      static_cast<IndexType>(gridDim.x) * blockDim.x;
  for (IndexType idx = first; idx < totalElements; idx += step) {
    op(a.get(idx));
  }
}
// Element-wise kernel over two operands: calls op(a.get(i), b.get(i)) for
// every linear index i in [0, totalElements). Both operands are visited in
// the same linear order.
//
// Same 1-D grid-stride scheme and 64-bit-safe casts as
// kernelPointwiseApply1.
template <typename Op,
          typename Ta, typename Tb,
          typename IndexType,
          int ADims, int BDims>
__global__ void
kernelPointwiseApply2(const OffsetInfo<Ta, IndexType, ADims> a,
                      const OffsetInfo<Tb, IndexType, BDims> b,
                      IndexType totalElements,
                      Op op) {
  const IndexType first =
      static_cast<IndexType>(blockIdx.x) * blockDim.x + threadIdx.x;
  const IndexType step =
      static_cast<IndexType>(gridDim.x) * blockDim.x;
  for (IndexType idx = first; idx < totalElements; idx += step) {
    op(a.get(idx), b.get(idx));
  }
}
// Element-wise kernel over three operands: calls
// op(a.get(i), b.get(i), c.get(i)) for every linear index i in
// [0, totalElements). All operands are visited in the same linear order.
//
// Same 1-D grid-stride scheme and 64-bit-safe casts as
// kernelPointwiseApply1.
template <typename Op,
          typename Ta, typename Tb, typename Tc,
          typename IndexType,
          int ADims, int BDims, int CDims>
__global__ void
kernelPointwiseApply3(const OffsetInfo<Ta, IndexType, ADims> a,
                      const OffsetInfo<Tb, IndexType, BDims> b,
                      const OffsetInfo<Tc, IndexType, CDims> c,
                      IndexType totalElements,
                      Op op) {
  const IndexType first =
      static_cast<IndexType>(blockIdx.x) * blockDim.x + threadIdx.x;
  const IndexType step =
      static_cast<IndexType>(gridDim.x) * blockDim.x;
  for (IndexType idx = first; idx < totalElements; idx += step) {
    op(a.get(idx), b.get(idx), c.get(idx));
  }
}
// Block shape shared by every pointwise-apply launch: a 1-D block of
// THC_APPLY_THREADS_PER_BLOCK threads.
inline dim3 getApplyBlock() {
  const dim3 block(THC_APPLY_THREADS_PER_BLOCK);
  return block;
}
// Computes a 1-D launch grid for `totalElements` elements, with
// THC_APPLY_THREADS_PER_BLOCK threads per block, on device `curDevice`.
//
// The block count is ceil(totalElements / threads-per-block), clamped to the
// device's maxGridSize[0]. When totalElements <= INT32_MAX it is further
// clamped so that gridDim.x * blockDim.x fits in 32-bit index math.
//
// Returns false without writing `grid` when curDevice is -1. Otherwise it
// stores the grid and returns true.
inline bool getApplyGrid(THCState* state, uint64_t totalElements, dim3& grid, int curDevice) {
  if (curDevice == -1) {
    return false;
  }

  const uint64_t threadsPerBlock =
      static_cast<uint64_t>(THC_APPLY_THREADS_PER_BLOCK);
  uint64_t blockCount = THCCeilDiv(totalElements, threadsPerBlock);

  const uint64_t deviceGridLimit =
      THCState_getDeviceProperties(state, curDevice)->maxGridSize[0];
  if (blockCount > deviceGridLimit) {
    blockCount = deviceGridLimit;
  }

  // For 32-bit indices, make sure that gridDim.x * blockDim.x fits in 32 bits.
  const uint64_t int32GridLimit = INT32_MAX / THC_APPLY_THREADS_PER_BLOCK;
  if (totalElements <= INT32_MAX && blockCount > int32GridLimit) {
    blockCount = int32GridLimit;
  }

  grid = dim3(blockCount);
  return true;
}
// Applies `op` element-wise to tensor `a` on the current device.
//
// Template parameters:
//   ScalarTypeA - element type of `a`.
//   TensorTypeA - THC tensor type of `a`.
//   Op          - functor called on the device as op(a.get(i)) for every
//                 linear index i of `a`.
// Parameters:
//   state - THC state.
//   a     - tensor to operate on.
//   op    - functor; passed by value into the kernel launch.
//   aType - ReadWrite (default) or ReadOnly. A ReadWrite tensor whose
//           indices may overlap is processed through a temporary contiguous
//           copy. That copy is written back (ignoring overlaps) afterwards,
//           so each element is operated on exactly once.
// Returns:
//   false if `a` has more than MAX_CUTORCH_DIMS dimensions, or if no launch
//   grid could be computed (e.g. cudaGetDevice left the device unset).
//   Otherwise true, including when THCTensor_nDimensionLegacyAll(a) == 0,
//   which is treated as "nothing to do".
// All launches use the current stream of `curDevice`, the same device the
// grid was sized for.
// NOTE(review): kernel launch errors are not checked here.
template <typename ScalarTypeA,
          typename TensorTypeA,
          typename Op>
bool THC_pointwiseApply1(THCState* state,
                         TensorTypeA* a,
                         const Op& op,
                         TensorArgType aType = ReadWrite) {
  if (THCTensor_nDimensionLegacyAll(state, a) > MAX_CUTORCH_DIMS) {
    return false;
  }

  if (THCTensor_nDimensionLegacyAll(state, a) == 0) {
    // Zero-dim tensor; do nothing
    return true;
  }

  const dim3 block = getApplyBlock();

  dim3 grid;
  ptrdiff_t totalElements = THCTensor_nElement(state, a);

  int curDevice = -1;
  cudaGetDevice(&curDevice);
  if (!getApplyGrid(state, totalElements, grid, curDevice)) {
    return false;
  }

  /*
  Expands readable/writable tensors whose indices may be "overlapped."
  This ensures that each element of the tensor is operated on once and only
  once.
  */
  TensorTypeA* oldA = NULL;

  if (aType == ReadWrite &&
      THCTensor_maybeOverlappingIndices(state, a)) {
    // Must perform in contiguous space
    oldA = a;
    a = (TensorTypeA*)THCTensor_newContiguous<ScalarTypeA>(state, a);
  }

  // It is possible that the tensor dimensions are able to be collapsed,
  // and thus we can reduce the actual code complexity of the copy by
  // exploiting this knowledge statically, since the div/mod is the
  // most expensive part of the operation, more so than memory accesses.
  // For instance, when copying a non-contiguous to a contiguous tensor
  // (or vice versa), the contiguous tensor can be collapsed to one
  // dimension, and the loop to translate the linear index to the array
  // index can be similarly collapsed. That is what this unrolling is for.
  // HANDLE_A_CASE maps a runtime dim count to the static OffsetInfo
  // specializations 1, 2, or -1 (generic N-d).
#define HANDLE_CASE(TYPE, A)                                              \
  kernelPointwiseApply1<Op,                                               \
                        ScalarTypeA,                                      \
                        TYPE, A>                                          \
    <<<grid, block, 0, THCState_getCurrentStreamOnDevice(state, curDevice)>>>( \
      OffsetInfo<ScalarTypeA, TYPE, A>                                    \
        (aInfo),                                                          \
      (TYPE) totalElements, op);

#define HANDLE_A_CASE(TYPE, A) {            \
  switch (A) {                              \
    case 1:                                 \
      HANDLE_CASE(TYPE, 1);                 \
      break;                                \
    case 2:                                 \
      HANDLE_CASE(TYPE, 2);                 \
      break;                                \
    default:                                \
      HANDLE_CASE(TYPE, -1);                \
      break;                                \
  }                                         \
}

  // Can we use 32-bit integer math in the kernel (the linear ID for the copy
  // and the resulting non-linear offset is all computable using 32-bit math?)
  // We also use unsigned index math in the kernel, as signed div/mod has
  // additional overhead.
  if (THCTensor_canUse32BitIndexMath(state, a)) {
    TensorInfo<ScalarTypeA, unsigned int> aInfo =
      getTensorInfo<ScalarTypeA, TensorTypeA, unsigned int>(state, a);
    rearrangeDims(&aInfo);
    aInfo.collapseDims();
#if CUDA_VERSION < 9000
    // Pre-CUDA-9 builds cap the grid for non-contiguous inputs.
    if (!aInfo.isContiguous()) {
      grid.x = min(THCState_getCurrentDeviceProperties(state)->multiProcessorCount * THC_APPLY_BLOCKS_PER_SM , grid.x);
    }
#endif
    HANDLE_A_CASE(unsigned int, aInfo.dims);
  } else {
    TensorInfo<ScalarTypeA, uint64_t> aInfo =
      getTensorInfo<ScalarTypeA, TensorTypeA, uint64_t>(state, a);
    rearrangeDims(&aInfo);
    aInfo.collapseDims();

    /*
    Only instantiates the all 1D special case and the fallback all nD case for
    large (64-bit indexed) tensors to reduce compilation time.
    */
    if (aInfo.dims == 1) {
      OffsetInfo<ScalarTypeA, uint64_t, 1>
        aOffset(aInfo);
      kernelPointwiseApply1<Op,
                            ScalarTypeA,
                            uint64_t, 1>
        <<<grid, block, 0, THCState_getCurrentStreamOnDevice(state, curDevice)>>>(
          aOffset, (uint64_t) totalElements, op);
    } else {

#if CUDA_VERSION < 9000
      grid.x = min(THCState_getCurrentDeviceProperties(state)->multiProcessorCount * THC_APPLY_BLOCKS_PER_SM , grid.x);
#endif
      OffsetInfo<ScalarTypeA, uint64_t, -1>
        aOffset(aInfo);
      kernelPointwiseApply1<Op,
                            ScalarTypeA,
                            uint64_t, -1>
        <<<grid, block, 0, THCState_getCurrentStreamOnDevice(state, curDevice)>>>(
          aOffset, (uint64_t) totalElements, op);
    }
  }
#undef HANDLE_CASE
#undef HANDLE_A_CASE

  if (oldA) {
    // Ignore overlaps when copying back; if we use THCTensor_copy
    // instead, it will recursively try and invoke ourselves to make
    // oldA contiguous.
    THCTensor_copyIgnoringOverlaps<ScalarTypeA>(state, oldA, a);
    THCTensor_free(state, a);
  }

  return true;
}
// Applies `op` element-wise to the operand pair (a, b) on the current device.
//
// Template parameters:
//   ScalarTypeA, ScalarTypeB - element types of `a` and `b`.
//   TensorTypeA, TensorTypeB - THC tensor types of `a` and `b`.
//   Op - functor called on the device as op(a.get(i), b.get(i)) for every
//        linear index i. Both operands are traversed in the same linear order.
// Parameters:
//   state        - THC state.
//   a, b         - operands; they must have the same number of elements.
//   op           - functor; passed by value into the kernel launch.
//   aType, bType - ReadWrite or ReadOnly (defaults: a ReadWrite, b ReadOnly).
//                  A ReadWrite operand whose indices may overlap is processed
//                  through a temporary contiguous copy. That copy is written
//                  back (ignoring overlaps) afterwards.
// Returns:
//   false if the element counts differ, if either operand has more than
//   MAX_CUTORCH_DIMS dimensions, or if no launch grid could be computed.
//   Otherwise true, including when THCTensor_nDimensionLegacyAll(a) == 0,
//   which is treated as "nothing to do".
// All launches use the current stream of `curDevice`, the same device the
// grid was sized for.
// NOTE(review): kernel launch errors are not checked here.
template <typename ScalarTypeA,
          typename ScalarTypeB,
          typename TensorTypeA,
          typename TensorTypeB,
          typename Op>
bool THC_pointwiseApply2(THCState* state,
                         TensorTypeA* a,
                         TensorTypeB* b,
                         const Op& op,
                         TensorArgType aType = ReadWrite,
                         TensorArgType bType = ReadOnly) {
  ptrdiff_t totalElements = THCTensor_nElement(state, a);
  if (totalElements != THCTensor_nElement(state, b)) {
    return false;
  }

  if (THCTensor_nDimensionLegacyAll(state, a) > MAX_CUTORCH_DIMS ||
      THCTensor_nDimensionLegacyAll(state, b) > MAX_CUTORCH_DIMS) {
    return false;
  }

  if (THCTensor_nDimensionLegacyAll(state, a) == 0) {
    // Zero-dim tensor; do nothing
    return true;
  }

  const dim3 block = getApplyBlock();

  dim3 grid;
  int curDevice = -1;
  cudaGetDevice(&curDevice);
  if (!getApplyGrid(state, totalElements, grid, curDevice)) {
    return false;
  }

  /*
  Expands readable/writable tensors whose indices may be "overlapped."
  This ensures that each element of the tensor is operated on once and only
  once.
  */
  TensorTypeA* oldA = NULL;
  TensorTypeB* oldB = NULL;

  if (aType == ReadWrite &&
      THCTensor_maybeOverlappingIndices(state, a)) {
    // Must perform in contiguous space
    oldA = a;
    a = (TensorTypeA*)THCTensor_newContiguous<ScalarTypeA>(state, a);
  }
  if (bType == ReadWrite &&
      THCTensor_maybeOverlappingIndices(state, b)) {
    // Must perform in contiguous space
    oldB = b;
    b = (TensorTypeB*)THCTensor_newContiguous<ScalarTypeB>(state, b);
  }

  // It is possible that the tensor dimensions are able to be collapsed,
  // and thus we can reduce the actual code complexity of the copy by
  // exploiting this knowledge statically, since the div/mod is the
  // most expensive part of the operation, more so than memory accesses.
  // For instance, when copying a non-contiguous to a contiguous tensor
  // (or vice versa), the contiguous tensor can be collapsed to one
  // dimension, and the loop to translate the linear index to the array
  // index can be similarly collapsed. That is what this unrolling is for.
  // HANDLE_A_CASE / HANDLE_B_CASE map each operand's runtime dim count to the
  // static OffsetInfo specializations 1, 2, or -1 (generic N-d).
#define HANDLE_CASE(TYPE, A, B)                                           \
  kernelPointwiseApply2<Op,                                               \
                        ScalarTypeA,                                      \
                        ScalarTypeB,                                      \
                        TYPE, A, B>                                       \
    <<<grid, block, 0, THCState_getCurrentStreamOnDevice(state, curDevice)>>>( \
      OffsetInfo<ScalarTypeA, TYPE, A>                                    \
        (aInfo),                                                          \
      OffsetInfo<ScalarTypeB, TYPE, B>                                    \
        (bInfo),                                                          \
      (TYPE) totalElements, op);

#define HANDLE_B_CASE(TYPE, A, B) {         \
  switch (B) {                              \
    case 1:                                 \
      HANDLE_CASE(TYPE, A, 1);              \
      break;                                \
    case 2:                                 \
      HANDLE_CASE(TYPE, A, 2);              \
      break;                                \
    default:                                \
      HANDLE_CASE(TYPE, A, -1);             \
      break;                                \
  }                                         \
}

#define HANDLE_A_CASE(TYPE, A, B) {         \
  switch (A) {                              \
    case 1:                                 \
      HANDLE_B_CASE(TYPE, 1, B);            \
      break;                                \
    case 2:                                 \
      HANDLE_B_CASE(TYPE, 2, B);            \
      break;                                \
    default:                                \
      HANDLE_B_CASE(TYPE, -1, B);           \
      break;                                \
  }                                         \
}

  // Use 32-bit unsigned index math only when every operand allows it.
  if (THCTensor_canUse32BitIndexMath(state, a) &&
      THCTensor_canUse32BitIndexMath(state, b)) {
    TensorInfo<ScalarTypeA, unsigned int> aInfo =
      getTensorInfo<ScalarTypeA, TensorTypeA, unsigned int>(state, a);

    TensorInfo<ScalarTypeB, unsigned int> bInfo =
      getTensorInfo<ScalarTypeB, TensorTypeB, unsigned int>(state, b);

    rearrangeDims(&aInfo, &bInfo);
    aInfo.collapseDims();
    bInfo.collapseDims();
#if CUDA_VERSION < 9000
    // Pre-CUDA-9 builds cap the grid unless every operand is contiguous.
    if (!(aInfo.isContiguous() && bInfo.isContiguous()))
      grid.x = min(THCState_getCurrentDeviceProperties(state)->multiProcessorCount * THC_APPLY_BLOCKS_PER_SM , grid.x);
#endif

    HANDLE_A_CASE(unsigned int, aInfo.dims, bInfo.dims);
  } else {
    TensorInfo<ScalarTypeA, uint64_t> aInfo =
      getTensorInfo<ScalarTypeA, TensorTypeA, uint64_t>(state, a);

    TensorInfo<ScalarTypeB, uint64_t> bInfo =
      getTensorInfo<ScalarTypeB, TensorTypeB, uint64_t>(state, b);

    rearrangeDims(&aInfo, &bInfo);
    aInfo.collapseDims();
    bInfo.collapseDims();

    /*
    Only instantiates the all 1D special case and the fallback all nD case for
    large (64-bit indexed) tensors to reduce compilation time.
    */
    if (aInfo.dims == 1 && bInfo.dims == 1) {
      OffsetInfo<ScalarTypeA, uint64_t, 1>
        aOffset(aInfo);
      OffsetInfo<ScalarTypeB, uint64_t, 1>
        bOffset(bInfo);
      kernelPointwiseApply2<Op,
                            ScalarTypeA,
                            ScalarTypeB,
                            uint64_t, 1, 1>
        <<<grid, block, 0, THCState_getCurrentStreamOnDevice(state, curDevice)>>>(
          aOffset, bOffset, (uint64_t) totalElements, op);
    } else {
#if CUDA_VERSION < 9000
      grid.x = min(THCState_getCurrentDeviceProperties(state)->multiProcessorCount * THC_APPLY_BLOCKS_PER_SM , grid.x);
#endif
      OffsetInfo<ScalarTypeA, uint64_t, -1>
        aOffset(aInfo);
      OffsetInfo<ScalarTypeB, uint64_t, -1>
        bOffset(bInfo);
      kernelPointwiseApply2<Op,
                            ScalarTypeA,
                            ScalarTypeB,
                            uint64_t, -1, -1>
        <<<grid, block, 0, THCState_getCurrentStreamOnDevice(state, curDevice)>>>(
          aOffset, bOffset, (uint64_t) totalElements, op);
    }
  }
#undef HANDLE_CASE
#undef HANDLE_B_CASE
#undef HANDLE_A_CASE

  if (oldA) {
    // Ignore overlaps when copying back; if we use THCTensor_copy
    // instead, it will recursively try and invoke ourselves to make
    // oldA contiguous.
    THCTensor_copyIgnoringOverlaps<ScalarTypeA>(state, oldA, a);
    THCTensor_free(state, a);
  }

  if (oldB) {
    // Ignore overlaps when copying back; if we use THCTensor_copy
    // instead, it will recursively try and invoke ourselves to make
    // oldB contiguous.
    THCTensor_copyIgnoringOverlaps<ScalarTypeB>(state, oldB, b);
    THCTensor_free(state, b);
  }

  return true;
}
// Applies `op` element-wise to the operand triple (a, b, c) on the current
// device.
//
// Template parameters:
//   ScalarTypeA/B/C - element types of `a`, `b`, `c`.
//   TensorTypeA/B/C - THC tensor types of `a`, `b`, `c`.
//   Op - functor called on the device as op(a.get(i), b.get(i), c.get(i))
//        for every linear index i. All operands are traversed in the same
//        linear order.
// Parameters:
//   state               - THC state.
//   a, b, c             - operands; they must have the same number of
//                         elements.
//   op                  - functor; passed by value into the kernel launch.
//   aType, bType, cType - ReadWrite or ReadOnly (defaults: a ReadWrite,
//                         b and c ReadOnly). A ReadWrite operand whose
//                         indices may overlap is processed through a
//                         temporary contiguous copy. That copy is written
//                         back (ignoring overlaps) afterwards.
// Returns:
//   false if the element counts differ, if any operand has more than
//   MAX_CUTORCH_DIMS dimensions, or if no launch grid could be computed.
//   Otherwise true, including when THCTensor_nDimensionLegacyAll(a) == 0,
//   which is treated as "nothing to do".
// All launches use the current stream of `curDevice`, the same device the
// grid was sized for.
// NOTE(review): kernel launch errors are not checked here.
template <typename ScalarTypeA,
          typename ScalarTypeB,
          typename ScalarTypeC,
          typename TensorTypeA,
          typename TensorTypeB,
          typename TensorTypeC,
          typename Op>
bool THC_pointwiseApply3(THCState* state,
                         TensorTypeA* a,
                         TensorTypeB* b,
                         TensorTypeC* c,
                         const Op& op,
                         TensorArgType aType = ReadWrite,
                         TensorArgType bType = ReadOnly,
                         TensorArgType cType = ReadOnly) {
  ptrdiff_t totalElements = THCTensor_nElement(state, a);

  if (totalElements != THCTensor_nElement(state, b) ||
      totalElements != THCTensor_nElement(state, c)) {
    return false;
  }

  if (THCTensor_nDimensionLegacyAll(state, a) > MAX_CUTORCH_DIMS ||
      THCTensor_nDimensionLegacyAll(state, b) > MAX_CUTORCH_DIMS ||
      THCTensor_nDimensionLegacyAll(state, c) > MAX_CUTORCH_DIMS) {
    return false;
  }

  if (THCTensor_nDimensionLegacyAll(state, a) == 0) {
    // Zero-dim tensor; do nothing
    return true;
  }

  const dim3 block = getApplyBlock();

  dim3 grid;
  int curDevice = -1;
  cudaGetDevice(&curDevice);
  if (!getApplyGrid(state, totalElements, grid, curDevice)) {
    return false;
  }

  /*
  Expands readable/writable tensors whose indices may be "overlapped."
  This ensures that each element of the tensor is operated on once and only
  once.
  */
  TensorTypeA* oldA = NULL;
  TensorTypeB* oldB = NULL;
  TensorTypeC* oldC = NULL;

  if (aType == ReadWrite &&
      THCTensor_maybeOverlappingIndices(state, a)) {
    // Must perform in contiguous space
    oldA = a;
    a = (TensorTypeA*)THCTensor_newContiguous<ScalarTypeA>(state, a);
  }
  if (bType == ReadWrite &&
      THCTensor_maybeOverlappingIndices(state, b)) {
    // Must perform in contiguous space
    oldB = b;
    b = (TensorTypeB*)THCTensor_newContiguous<ScalarTypeB>(state, b);
  }
  if (cType == ReadWrite &&
      THCTensor_maybeOverlappingIndices(state, c)) {
    // Must perform in contiguous space
    oldC = c;
    c = (TensorTypeC*)THCTensor_newContiguous<ScalarTypeC>(state, c);
  }

  // HANDLE_A_CASE / HANDLE_B_CASE / HANDLE_C_CASE map each operand's runtime
  // dim count to the static OffsetInfo specializations 1, 2, or -1
  // (generic N-d), so collapsed tensors avoid general div/mod indexing.
#define HANDLE_CASE(TYPE, A, B, C)                                        \
  kernelPointwiseApply3<Op,                                               \
                        ScalarTypeA,                                      \
                        ScalarTypeB,                                      \
                        ScalarTypeC,                                      \
                        TYPE, A, B, C>                                    \
    <<<grid, block, 0, THCState_getCurrentStreamOnDevice(state, curDevice)>>>( \
      OffsetInfo<ScalarTypeA, TYPE, A>                                    \
        (aInfo),                                                          \
      OffsetInfo<ScalarTypeB, TYPE, B>                                    \
        (bInfo),                                                          \
      OffsetInfo<ScalarTypeC, TYPE, C>                                    \
        (cInfo),                                                          \
      (TYPE) totalElements, op);

#define HANDLE_C_CASE(TYPE, A, B, C) {      \
  switch (C) {                              \
    case 1:                                 \
      HANDLE_CASE(TYPE, A, B, 1);           \
      break;                                \
    case 2:                                 \
      HANDLE_CASE(TYPE, A, B, 2);           \
      break;                                \
    default:                                \
      HANDLE_CASE(TYPE, A, B, -1);          \
      break;                                \
  }                                         \
}

#define HANDLE_B_CASE(TYPE, A, B, C) {      \
  switch (B) {                              \
    case 1:                                 \
      HANDLE_C_CASE(TYPE, A, 1, C);         \
      break;                                \
    case 2:                                 \
      HANDLE_C_CASE(TYPE, A, 2, C);         \
      break;                                \
    default:                                \
      HANDLE_C_CASE(TYPE, A, -1, C);        \
      break;                                \
  }                                         \
}

#define HANDLE_A_CASE(TYPE, A, B, C) {      \
  switch (A) {                              \
    case 1:                                 \
      HANDLE_B_CASE(TYPE, 1, B, C);         \
      break;                                \
    case 2:                                 \
      HANDLE_B_CASE(TYPE, 2, B, C);         \
      break;                                \
    default:                                \
      HANDLE_B_CASE(TYPE, -1, B, C);        \
      break;                                \
  }                                         \
}

  // Use 32-bit unsigned index math only when every operand allows it.
  if (THCTensor_canUse32BitIndexMath(state, a) &&
      THCTensor_canUse32BitIndexMath(state, b) &&
      THCTensor_canUse32BitIndexMath(state, c)) {
    TensorInfo<ScalarTypeA, unsigned int> aInfo =
      getTensorInfo<ScalarTypeA, TensorTypeA, unsigned int>(state, a);

    TensorInfo<ScalarTypeB, unsigned int> bInfo =
      getTensorInfo<ScalarTypeB, TensorTypeB, unsigned int>(state, b);

    TensorInfo<ScalarTypeC, unsigned int> cInfo =
      getTensorInfo<ScalarTypeC, TensorTypeC, unsigned int>(state, c);

    rearrangeDims(&aInfo, &bInfo, &cInfo);
    aInfo.collapseDims();
    bInfo.collapseDims();
    cInfo.collapseDims();

#if CUDA_VERSION < 9000
    // Pre-CUDA-9 builds cap the grid unless every operand is contiguous.
    if (!(aInfo.isContiguous() && bInfo.isContiguous() && cInfo.isContiguous()))
      grid.x = min(THCState_getCurrentDeviceProperties(state)->multiProcessorCount * THC_APPLY_BLOCKS_PER_SM , grid.x);
#endif
    HANDLE_A_CASE(unsigned int, aInfo.dims, bInfo.dims, cInfo.dims);
  } else {
    TensorInfo<ScalarTypeA, uint64_t> aInfo =
      getTensorInfo<ScalarTypeA, TensorTypeA, uint64_t>(state, a);

    TensorInfo<ScalarTypeB, uint64_t> bInfo =
      getTensorInfo<ScalarTypeB, TensorTypeB, uint64_t>(state, b);

    TensorInfo<ScalarTypeC, uint64_t> cInfo =
      getTensorInfo<ScalarTypeC, TensorTypeC, uint64_t>(state, c);

    rearrangeDims(&aInfo, &bInfo, &cInfo);
    aInfo.collapseDims();
    bInfo.collapseDims();
    cInfo.collapseDims();

    /*
    Only instantiates the all 1D special case and the fallback all nD case for
    large (64-bit indexed) tensors to reduce compilation time.
    */
    if (aInfo.dims == 1 && bInfo.dims == 1 && cInfo.dims == 1) {
      OffsetInfo<ScalarTypeA, uint64_t, 1>
        aOffset(aInfo);
      OffsetInfo<ScalarTypeB, uint64_t, 1>
        bOffset(bInfo);
      OffsetInfo<ScalarTypeC, uint64_t, 1>
        cOffset(cInfo);
      kernelPointwiseApply3<Op,
                            ScalarTypeA,
                            ScalarTypeB,
                            ScalarTypeC,
                            uint64_t, 1, 1, 1>
        <<<grid, block, 0, THCState_getCurrentStreamOnDevice(state, curDevice)>>>(
          aOffset, bOffset, cOffset, (uint64_t) totalElements, op);
    } else {
#if CUDA_VERSION < 9000
      grid.x = min(THCState_getCurrentDeviceProperties(state)->multiProcessorCount * THC_APPLY_BLOCKS_PER_SM , grid.x);
#endif

      OffsetInfo<ScalarTypeA, uint64_t, -1>
        aOffset(aInfo);
      OffsetInfo<ScalarTypeB, uint64_t, -1>
        bOffset(bInfo);
      OffsetInfo<ScalarTypeC, uint64_t, -1>
        cOffset(cInfo);
      kernelPointwiseApply3<Op,
                            ScalarTypeA,
                            ScalarTypeB,
                            ScalarTypeC,
                            uint64_t, -1, -1, -1>
        <<<grid, block, 0, THCState_getCurrentStreamOnDevice(state, curDevice)>>>(
          aOffset, bOffset, cOffset, (uint64_t) totalElements, op);
    }
  }
#undef HANDLE_CASE
#undef HANDLE_C_CASE
#undef HANDLE_B_CASE
#undef HANDLE_A_CASE

  if (oldA) {
    // Ignore overlaps when copying back; if we use THCTensor_copy
    // instead, it will recursively try and invoke ourselves to make
    // oldA contiguous.
    THCTensor_copyIgnoringOverlaps<ScalarTypeA>(state, oldA, a);
    THCTensor_free(state, a);
  }

  if (oldB) {
    // Ignore overlaps when copying back; if we use THCTensor_copy
    // instead, it will recursively try and invoke ourselves to make
    // oldB contiguous.
    THCTensor_copyIgnoringOverlaps<ScalarTypeB>(state, oldB, b);
    THCTensor_free(state, b);
  }

  if (oldC) {
    // Ignore overlaps when copying back; if we use THCTensor_copy
    // instead, it will recursively try and invoke ourselves to make
    // oldC contiguous.
    THCTensor_copyIgnoringOverlaps<ScalarTypeC>(state, oldC, c);
    THCTensor_free(state, c);
  }

  return true;
}
#undef THC_APPLY_THREADS_PER_BLOCK
#undef THC_APPLY_BLOCKS_PER_SM
#endif // THC_APPLY_INC