forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
THCBlas.cu
580 lines (496 loc) · 22.7 KB
/
THCBlas.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
#include "THCBlas.h"
#include "THCGeneral.h"
#include "TH/THHalf.h"
#include <algorithm>
/* Single-precision dot product, forwarded to cublasSdot on the current
 * THC BLAS handle and stream.
 *
 * n          number of elements
 * x, incx    first vector and its stride
 * y, incy    second vector and its stride
 *
 * Returns the value cublasSdot writes into the local result. Raises a
 * THError when n, incx or incy exceeds INT_MAX, because cuBLAS takes
 * plain ints for these arguments.
 */
float THCudaBlas_Sdot(THCState *state, int64_t n, float *x, int64_t incx, float *y, int64_t incy)
{
  // With exactly one element the strides are forced to 1.
  if (n == 1) {
    incx = 1;
    incy = 1;
  }

  // Reject anything that would not survive the narrowing to int below.
  const bool fitsInInt = (n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX);
  if (!fitsInInt) {
    THError("Cublas_Sdot only supports n, incx and incy "
            "up to signed integer limits: %d", INT_MAX);
    return 0;
  }

  cublasHandle_t blas = THCState_getCurrentBlasHandle(state);
  cublasSetStream(blas, THCState_getCurrentStream(state));

  float dot;
  THCublasCheck(cublasSdot(blas, (int)n, x, (int)incx, y, (int)incy, &dot));
  return dot;
}
/* Double-precision dot product, forwarded to cublasDdot on the current
 * THC BLAS handle and stream.
 *
 * n          number of elements
 * x, incx    first vector and its stride
 * y, incy    second vector and its stride
 *
 * Returns the value cublasDdot writes into the local result. Raises a
 * THError when n, incx or incy exceeds INT_MAX, because cuBLAS takes
 * plain ints for these arguments.
 */
double THCudaBlas_Ddot(THCState *state, int64_t n, double *x, int64_t incx, double *y, int64_t incy)
{
  // With exactly one element the strides are forced to 1.
  if (n == 1) {
    incx = 1;
    incy = 1;
  }

  // Reject anything that would not survive the narrowing to int below.
  const bool fitsInInt = (n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX);
  if (!fitsInInt) {
    THError("Cublas_Ddot only supports n, incx and incy "
            "up to signed integer limits: %d", INT_MAX);
    return 0;
  }

  cublasHandle_t blas = THCState_getCurrentBlasHandle(state);
  cublasSetStream(blas, THCState_getCurrentStream(state));

  double dot;
  THCublasCheck(cublasDdot(blas, (int)n, x, (int)incx, y, (int)incy, &dot));
  return dot;
}
/* Half-precision dot product via cublasDotEx: x, y and the result are
 * declared as CUDA_R_16F and the execution type is CUDA_R_32F.
 *
 * Only available when built with CUDA_VERSION >= 8000; otherwise a
 * THError is raised. Also raises a THError when n, incx or incy exceeds
 * INT_MAX.
 */
at::Half THCudaBlas_Hdot(THCState *state, int64_t n, at::Half *x, int64_t incx, at::Half *y, int64_t incy)
{
#if CUDA_VERSION >= 8000
  // With exactly one element the strides are forced to 1.
  if (n == 1) {
    incx = 1;
    incy = 1;
  }

  // cublasDotEx takes ints; refuse values that would be truncated.
  if ((n > INT_MAX) || (incx > INT_MAX) || (incy > INT_MAX)) {
    THError("Cublas_Hdot only supports n, incx and incy "
            "up to signed integer limits: %d", INT_MAX);
    return 0.0;
  }

  cublasHandle_t blas = THCState_getCurrentBlasHandle(state);
  cublasSetStream(blas, THCState_getCurrentStream(state));

  at::Half dot;
  THCublasCheck(cublasDotEx(blas, n,
                            x, CUDA_R_16F, incx,
                            y, CUDA_R_16F, incy,
                            &dot, CUDA_R_16F,
                            CUDA_R_32F));
  return dot;
#else
  THError("Cublas_Hdot requires CUDA 8.0+");
  return 0.0;
#endif
}
/* Level 2 */
/* Normalises the leading dimension used by the level-2 wrappers.
 *
 * When the matrix has at most one column (n <= 1), *lda is overwritten
 * with max(m, 1); otherwise *lda is left alone.
 */
void adjustLdLevel2(int64_t m, int64_t n, int64_t *lda)
{
  // Leading dimensions are generally validated to be > 0 and at least as
  // large as the result requires, even when the value ends up unused.
  // TODO: why does Level3 check trans but this doesn't?
  if (n > 1) {
    return;
  }
  *lda = (m < 1) ? static_cast<int64_t>(1) : m;
}
/* Single-precision matrix-vector product, forwarded to cublasSgemv on the
 * current THC BLAS handle and stream.
 *
 * trans      't', 'n' or 'c' (lowercase); anything else raises a THError
 * m, n       matrix dimensions as passed to cuBLAS
 * a, lda     matrix and its leading dimension (normalised by adjustLdLevel2)
 * x, incx    input vector and stride
 * y, incy    output vector and stride
 *
 * Raises a THError when m or n exceeds INT_MAX, or when lda, incx or incy
 * is outside (0, INT_MAX].
 */
void THCudaBlas_Sgemv(THCState *state, char trans, int64_t m, int64_t n, float alpha, float *a, int64_t lda, float *x, int64_t incx, float beta, float *y, int64_t incy)
{
  adjustLdLevel2(m, n, &lda);

  cublasOperation_t op;
  if (trans == 't') op = CUBLAS_OP_T;
  else if (trans == 'n') op = CUBLAS_OP_N;
  else if (trans == 'c') op = CUBLAS_OP_C;
  else THError("Cublas_Sgemv parameter trans should be 't', 'n' or 'c'.");

  if( (m <= INT_MAX) && (n <= INT_MAX) &&
      (lda > 0) && (lda <= INT_MAX) &&
      (incx > 0) && (incx <= INT_MAX) &&
      (incy > 0) && (incy <= INT_MAX) )
  {
    // cuBLAS takes ints; the range check above makes these casts lossless.
    int i_m = (int)m;
    int i_n = (int)n;
    int i_lda = (int)lda;
    int i_incx = (int)incx;
    int i_incy = (int)incy;

    cublasHandle_t handle = THCState_getCurrentBlasHandle(state);
    cublasSetStream(handle, THCState_getCurrentStream(state));
    THCublasCheck(cublasSgemv(handle, op, i_m, i_n, &alpha, a, i_lda, x, i_incx, &beta, y, i_incy));
    return;
  }
  // Trailing space after "incy" keeps the concatenated literals readable.
  THError("Cublas_Sgemv only supports m, n, lda, incx, incy "
          "in the range 0 < [val] <= %d", INT_MAX);
}
/* Double-precision matrix-vector product, forwarded to cublasDgemv on the
 * current THC BLAS handle and stream.
 *
 * trans      't', 'n' or 'c' (lowercase); anything else raises a THError
 * m, n       matrix dimensions as passed to cuBLAS
 * a, lda     matrix and its leading dimension (normalised by adjustLdLevel2)
 * x, incx    input vector and stride
 * y, incy    output vector and stride
 *
 * Raises a THError when m or n exceeds INT_MAX, or when lda, incx or incy
 * is outside (0, INT_MAX].
 */
void THCudaBlas_Dgemv(THCState *state, char trans, int64_t m, int64_t n, double alpha, double *a, int64_t lda, double *x, int64_t incx, double beta, double *y, int64_t incy)
{
  adjustLdLevel2(m, n, &lda);

  cublasOperation_t op;
  if (trans == 't') op = CUBLAS_OP_T;
  else if (trans == 'n') op = CUBLAS_OP_N;
  else if (trans == 'c') op = CUBLAS_OP_C;
  // The message names Dgemv so the failure points at the right wrapper.
  else THError("Cublas_Dgemv parameter trans should be 't', 'n' or 'c'.");

  if( (m <= INT_MAX) && (n <= INT_MAX) &&
      (lda > 0) && (lda <= INT_MAX) &&
      (incx > 0) && (incx <= INT_MAX) &&
      (incy > 0) && (incy <= INT_MAX) )
  {
    // cuBLAS takes ints; the range check above makes these casts lossless.
    int i_m = (int)m;
    int i_n = (int)n;
    int i_lda = (int)lda;
    int i_incx = (int)incx;
    int i_incy = (int)incy;

    cublasHandle_t handle = THCState_getCurrentBlasHandle(state);
    cublasSetStream(handle, THCState_getCurrentStream(state));
    THCublasCheck(cublasDgemv(handle, op, i_m, i_n, &alpha, a, i_lda, x, i_incx, &beta, y, i_incy));
    return;
  }
  // Trailing space after "incy" keeps the concatenated literals readable.
  THError("Cublas_Dgemv only supports m, n, lda, incx, incy "
          "in the range 0 < [val] <= %d", INT_MAX);
}
/* Single-precision rank-1 update, forwarded to cublasSger on the current
 * THC BLAS handle and stream.
 *
 * m, n       matrix dimensions as passed to cuBLAS
 * x, incx    first vector and stride
 * y, incy    second vector and stride
 * a, lda     matrix updated by cuBLAS and its leading dimension
 *            (normalised by adjustLdLevel2)
 *
 * Raises a THError when any of m, n, lda, incx, incy exceeds INT_MAX.
 */
void THCudaBlas_Sger(THCState *state, int64_t m, int64_t n, float alpha, float *x, int64_t incx, float *y, int64_t incy, float *a, int64_t lda)
{
  adjustLdLevel2(m, n, &lda);

  if( (m <= INT_MAX) && (n <= INT_MAX) && (lda <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) )
  {
    // cuBLAS takes ints; the range check above keeps these casts lossless.
    int i_m = (int)m;
    int i_n = (int)n;
    int i_lda = (int)lda;
    int i_incx = (int)incx;
    int i_incy = (int)incy;

    cublasHandle_t handle = THCState_getCurrentBlasHandle(state);
    cublasSetStream(handle, THCState_getCurrentStream(state));
    THCublasCheck(cublasSger(handle, i_m, i_n, &alpha, x, i_incx, y, i_incy, a, i_lda));
    return;
  }
  // Trailing space after "incy" keeps the concatenated literals readable.
  THError("Cublas_Sger only supports m, n, lda, incx, incy "
          "with the bound [val] <= %d", INT_MAX);
}
/* Double-precision rank-1 update, forwarded to cublasDger on the current
 * THC BLAS handle and stream.
 *
 * m, n       matrix dimensions as passed to cuBLAS
 * x, incx    first vector and stride
 * y, incy    second vector and stride
 * a, lda     matrix updated by cuBLAS and its leading dimension
 *            (normalised by adjustLdLevel2)
 *
 * Raises a THError when any of m, n, lda, incx, incy exceeds INT_MAX.
 */
void THCudaBlas_Dger(THCState *state, int64_t m, int64_t n, double alpha, double *x, int64_t incx, double *y, int64_t incy, double *a, int64_t lda)
{
  adjustLdLevel2(m, n, &lda);

  if( (m <= INT_MAX) && (n <= INT_MAX) && (lda <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) )
  {
    // cuBLAS takes ints; the range check above keeps these casts lossless.
    int i_m = (int)m;
    int i_n = (int)n;
    int i_lda = (int)lda;
    int i_incx = (int)incx;
    int i_incy = (int)incy;

    cublasHandle_t handle = THCState_getCurrentBlasHandle(state);
    cublasSetStream(handle, THCState_getCurrentStream(state));
    THCublasCheck(cublasDger(handle, i_m, i_n, &alpha, x, i_incx, y, i_incy, a, i_lda));
    return;
  }
  // Trailing space after "incy" keeps the concatenated literals readable.
  THError("Cublas_Dger only supports m, n, lda, incx, incy "
          "with the bound [val] <= %d", INT_MAX);
}
/* Maps a BLAS-style transpose flag onto the matching cublasOperation_t.
 *
 * Accepts 'n', 't' and 'c' in either case; other code in this file
 * (adjustLdLevel3, the HIP batched paths) already recognises the
 * uppercase spellings, so they are accepted here as well.
 * Any other character raises a THError.
 */
cublasOperation_t convertTransToCublasOperation(char trans) {
  switch (trans) {
    case 't':
    case 'T':
      return CUBLAS_OP_T;
    case 'n':
    case 'N':
      return CUBLAS_OP_N;
    case 'c':
    case 'C':
      return CUBLAS_OP_C;
    default:
      THError("trans must be one of: t, n, c");
      // Fallback value for the return statement that follows the THError call.
      return CUBLAS_OP_T;
  }
}
/* Normalises the leading dimensions used by the level-3 (gemm) wrappers
 * for degenerate shapes.
 *
 * transa, transb   transpose flags for A and B; 't'/'T' and 'c'/'C' both
 *                  count as transposed, since a conjugate transpose has
 *                  the same storage shape as a plain transpose
 * m, n, k          gemm dimensions
 * lda, ldb, ldc    in/out leading dimensions; each is overwritten only
 *                  when the corresponding stored matrix has at most one
 *                  column
 */
void adjustLdLevel3(char transa, char transb, int64_t m, int64_t n, int64_t k, int64_t *lda, int64_t *ldb, int64_t *ldc)
{
  const bool transa_ = (transa == 't') || (transa == 'T') || (transa == 'c') || (transa == 'C');
  const bool transb_ = (transb == 't') || (transb == 'T') || (transb == 'c') || (transb == 'C');

  // Note: leading dimensions generally are checked that they are > 0 and at least as big the result
  // requires (even if the value won't be used).
  if(n <= 1)
    *ldc = std::max<int64_t>(m, 1);

  if(transa_)
  {
    // A is stored as k x m: one column or fewer when m <= 1.
    if(m <= 1)
      *lda = std::max<int64_t>(k, 1);
  }
  else
  {
    // A is stored as m x k: one column or fewer when k <= 1.
    if(k <= 1)
      *lda = std::max<int64_t>(m, 1);
  }

  if(transb_)
  {
    // B is stored as n x k: one column or fewer when k <= 1.
    if(k <= 1)
      *ldb = std::max<int64_t>(n, 1);
  }
  else
  {
    // B is stored as k x n: one column or fewer when n <= 1.
    if(n <= 1)
      *ldb = std::max<int64_t>(k, 1);
  }
}
/* Level 3 */
/* Single-precision matrix-matrix product, forwarded to cublasSgemm on the
 * current THC BLAS handle and stream.
 *
 * transa, transb   transpose flags, translated by convertTransToCublasOperation
 * m, n, k          gemm dimensions
 * a/lda, b/ldb     input matrices and leading dimensions
 * c/ldc            matrix passed to cuBLAS as the output and its leading dimension
 *
 * Leading dimensions are first normalised by adjustLdLevel3. Raises a
 * THError when any of m, n, k, lda, ldb, ldc exceeds INT_MAX.
 */
void THCudaBlas_Sgemm(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k, float alpha, float *a, int64_t lda, float *b, int64_t ldb, float beta, float *c, int64_t ldc)
{
  adjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
  cublasOperation_t opa = convertTransToCublasOperation(transa);
  cublasOperation_t opb = convertTransToCublasOperation(transb);

  if( (m <= INT_MAX) && (n <= INT_MAX) && (k <= INT_MAX) && (lda <= INT_MAX) && (ldb <= INT_MAX) && (ldc <= INT_MAX) )
  {
    // cuBLAS takes ints; the range check above keeps these casts lossless.
    int i_m = (int)m;
    int i_n = (int)n;
    int i_k = (int)k;
    int i_lda = (int)lda;
    int i_ldb = (int)ldb;
    int i_ldc = (int)ldc;

    cublasHandle_t handle = THCState_getCurrentBlasHandle(state);
    cublasSetStream(handle, THCState_getCurrentStream(state));
    THCublasCheck(cublasSgemm(handle, opa, opb, i_m, i_n, i_k, &alpha, a, i_lda, b, i_ldb, &beta, c, i_ldc));
    return;
  }
  // Trailing space after "ldc" keeps the concatenated literals readable.
  THError("Cublas_Sgemm only supports m, n, k, lda, ldb, ldc "
          "with the bound [val] <= %d", INT_MAX);
}
// In CUDA 8.0, definition of data types for sgemmex changed
#if CUDA_VERSION < 8000
# define CUDA_R_16F CUBLAS_DATA_HALF
#endif
/* Half-precision matrix-matrix product.
 *
 * Backend selection, as visible in the preprocessor branches below:
 *   - HIP build:            rocblas_hgemm with half alpha/beta.
 *   - CUDA_VERSION < 9000:  cublasSgemmEx with CUDA_R_16F a/b/c and float
 *                           alpha/beta.
 *   - otherwise:            on devices with prop->major >= 5,
 *                           cublasGemmEx with CUDA_R_16F a/b/c, CUDA_R_32F
 *                           compute type and tensor-op math switched on
 *                           for the duration of the call; on older
 *                           devices, cublasSgemmEx as above.
 *
 * Leading dimensions are first normalised by adjustLdLevel3. Raises a
 * THError when any of m, n, k, lda, ldb, ldc exceeds INT_MAX.
 */
void THCudaBlas_Hgemm(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k, at::Half alpha, at::Half *a, int64_t lda, at::Half *b, int64_t ldb, at::Half beta, at::Half *c, int64_t ldc)
{
  adjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
  cublasOperation_t opa = convertTransToCublasOperation(transa);
  cublasOperation_t opb = convertTransToCublasOperation(transb);

  if( (m <= INT_MAX) && (n <= INT_MAX) && (k <= INT_MAX) && (lda <= INT_MAX) && (ldb <= INT_MAX) && (ldc <= INT_MAX) )
  {
    // cuBLAS takes ints; the range check above keeps these casts lossless.
    int i_m = (int)m;
    int i_n = (int)n;
    int i_k = (int)k;
    int i_lda = (int)lda;
    int i_ldb = (int)ldb;
    int i_ldc = (int)ldc;

    cublasHandle_t handle = THCState_getCurrentBlasHandle(state);
    cublasSetStream(handle, THCState_getCurrentStream(state));

#ifdef __HIP_PLATFORM_HCC__
    THCublasCheck(rocblas_hgemm(handle, opa, opb, i_m, i_n, i_k,
                                reinterpret_cast<rocblas_half*>(&alpha), reinterpret_cast<rocblas_half*>(a), i_lda,
                                reinterpret_cast<rocblas_half*>(b), i_ldb, reinterpret_cast<rocblas_half*>(&beta),
                                reinterpret_cast<rocblas_half*>(c), i_ldc));
#else
    // Simulated Hgemm: scalars are widened to float for the Ex entry points.
    float fAlpha = alpha;
    float fBeta = beta;

#if CUDA_VERSION < 9000
    THCublasCheck(cublasSgemmEx(handle, opa, opb,
                                i_m, i_n, i_k, &fAlpha,
                                a, CUDA_R_16F, i_lda, b, CUDA_R_16F,
                                i_ldb, &fBeta, c, CUDA_R_16F, i_ldc));
#else
    cudaDeviceProp* prop = THCState_getCurrentDeviceProperties(state);
    if (prop->major >= 5){
#ifndef __HIP_PLATFORM_HCC__
      // Enable tensor-op math only around this call; restored right after.
      THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
#endif
      THCublasCheck(cublasGemmEx(handle, opa, opb,
                                 i_m, i_n, i_k, &fAlpha,
                                 a, CUDA_R_16F, i_lda, b, CUDA_R_16F,
                                 i_ldb, &fBeta, c, CUDA_R_16F, i_ldc,
                                 CUDA_R_32F, CUBLAS_GEMM_DFALT_TENSOR_OP));
#ifndef __HIP_PLATFORM_HCC__
      THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
#endif
    }else{
      THCublasCheck(cublasSgemmEx(handle, opa, opb,
                                  i_m, i_n, i_k, &fAlpha,
                                  a, CUDA_R_16F, i_lda, b, CUDA_R_16F,
                                  i_ldb, &fBeta, c, CUDA_R_16F, i_ldc));
    }
#endif
#endif
    return;
  }
  // Message fixed: "th" -> "the", and a separating space after "ldc".
  THError("Cublas_Hgemm only supports m, n, k, lda, ldb, ldc "
          "with the bound [val] <= %d", INT_MAX);
}
/* Double-precision matrix-matrix product, forwarded to cublasDgemm on the
 * current THC BLAS handle and stream.
 *
 * transa, transb   transpose flags, translated by convertTransToCublasOperation
 * m, n, k          gemm dimensions
 * a/lda, b/ldb     input matrices and leading dimensions
 * c/ldc            matrix passed to cuBLAS as the output and its leading dimension
 *
 * Leading dimensions are first normalised by adjustLdLevel3. Raises a
 * THError when any of m, n, k, lda, ldb, ldc exceeds INT_MAX.
 */
void THCudaBlas_Dgemm(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k, double alpha, double *a, int64_t lda, double *b, int64_t ldb, double beta, double *c, int64_t ldc)
{
  adjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
  cublasOperation_t opa = convertTransToCublasOperation(transa);
  cublasOperation_t opb = convertTransToCublasOperation(transb);

  if( (m <= INT_MAX) && (n <= INT_MAX) && (k <= INT_MAX) && (lda <= INT_MAX) && (ldb <= INT_MAX) && (ldc <= INT_MAX) )
  {
    // cuBLAS takes ints; the range check above keeps these casts lossless.
    int i_m = (int)m;
    int i_n = (int)n;
    int i_k = (int)k;
    int i_lda = (int)lda;
    int i_ldb = (int)ldb;
    int i_ldc = (int)ldc;

    cublasHandle_t handle = THCState_getCurrentBlasHandle(state);
    cublasSetStream(handle, THCState_getCurrentStream(state));
    THCublasCheck(cublasDgemm(handle, opa, opb, i_m, i_n, i_k, &alpha, a, i_lda, b, i_ldb, &beta, c, i_ldc));
    return;
  }
  // Trailing space after "ldc" keeps the concatenated literals readable.
  THError("Cublas_Dgemm only supports m, n, k, lda, ldb, ldc "
          "with the bound [val] <= %d", INT_MAX);
}
#if CUDA_VERSION >= 9010
/* Strided-batched half-precision gemm via cublasGemmStridedBatchedEx:
 * a, b and c are declared CUDA_R_16F, the compute type is CUDA_R_32F, and
 * alpha/beta are widened to float. Tensor-op math is switched on for the
 * call and reset to the default afterwards.
 *
 * strideA/strideB/strideC are forwarded unchanged as the per-batch strides.
 * Leading dimensions are normalised by adjustLdLevel3. Raises a THError
 * when any of m, n, k, lda, ldb, ldc, batchCount exceeds INT_MAX.
 */
void THCudaBlas_HgemmStridedBatched(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k,
                             at::Half alpha, const at::Half *a, int64_t lda, int64_t strideA, const at::Half *b, int64_t ldb, int64_t strideB,
                             at::Half beta, at::Half *c, int64_t ldc, int64_t strideC, int64_t batchCount)
{
  // Strict '>' so that the check matches the documented bound [val] <= INT_MAX.
  if( (m > INT_MAX) || (n > INT_MAX) || (k > INT_MAX) || (lda > INT_MAX) || (ldb > INT_MAX) || (ldc > INT_MAX) || (batchCount > INT_MAX) )
  {
    THError("Cublas_HgemmStridedBatched only supports m, n, k, lda, ldb, ldc, batchCount "
            "with the bound [val] <= %d", INT_MAX);
  }

  adjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
  cublasOperation_t opa = convertTransToCublasOperation(transa);
  cublasOperation_t opb = convertTransToCublasOperation(transb);

  cublasHandle_t handle = THCState_getCurrentBlasHandle(state);
  cublasSetStream(handle, THCState_getCurrentStream(state));

  float fAlpha = alpha;
  float fBeta = beta;

  THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
  THCublasCheck(cublasGemmStridedBatchedEx(handle,
                                   opa, opb, (int)m, (int)n, (int)k,
                                   (void*)&fAlpha, a, CUDA_R_16F, (int)lda, strideA,
                                   b, CUDA_R_16F, (int)ldb, strideB,
                                   (void*)&fBeta, c, CUDA_R_16F, (int)ldc, strideC,
                                   (int)batchCount, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));
  THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
}
#endif
/* Batched single-precision gemm over arrays of matrix pointers.
 *
 * CUDA build: forwards a, b, c (pointer arrays) to cublasSgemmBatched
 * after normalising the leading dimensions with adjustLdLevel3.
 * HIP build: derives per-batch strides from the leading dimensions and
 * calls THCudaBlas_SgemmStridedBatched starting at the first pointer of
 * each array.
 *
 * Raises a THError when any of m, n, k, lda, ldb, ldc, batchCount exceeds
 * INT_MAX.
 */
void THCudaBlas_SgemmBatched(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k,
                             float alpha, const float *a[], int64_t lda, const float *b[], int64_t ldb,
                             float beta, float *c[], int64_t ldc, int64_t batchCount)
{
  // Strict '>' so that the check matches the documented bound [val] <= INT_MAX.
  if( (m > INT_MAX) || (n > INT_MAX) || (k > INT_MAX) || (lda > INT_MAX) || (ldb > INT_MAX) || (ldc > INT_MAX) || (batchCount > INT_MAX) )
  {
    THError("Cublas_SgemmBatched only supports m, n, k, lda, ldb, ldc, batchCount "
            "with the bound [val] <= %d", INT_MAX);
  }

#ifdef __HIP_PLATFORM_HCC__
  // NOTE(review): this path only uses a[0], b[0], c[0] plus computed
  // strides, i.e. it presumes the batch is laid out contiguously — confirm
  // against callers. Also, for a transposed A the stored matrix would have
  // m columns, so lda*m might be expected instead of lda*n — TODO confirm.
  const int64_t stridea = (transa == 'N' || transa == 'n') ? lda*k : lda*n;
  const int64_t strideb = (transb == 'N' || transb == 'n') ? ldb*n : ldb*k;
  const int64_t stridec = ldc*n;

  THCudaBlas_SgemmStridedBatched(state, transa, transb, m, n, k, alpha, *a, lda, stridea, *b, ldb, strideb, beta, *c, ldc, stridec, batchCount);
#else
  adjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
  cublasOperation_t opa = convertTransToCublasOperation(transa);
  cublasOperation_t opb = convertTransToCublasOperation(transb);

  cublasHandle_t handle = THCState_getCurrentBlasHandle(state);
  cublasSetStream(handle, THCState_getCurrentStream(state));
  THCublasCheck(cublasSgemmBatched(handle,
                                   opa, opb, (int)m, (int)n, (int)k,
                                   &alpha, a, (int)lda, b, (int)ldb, &beta, c, (int)ldc,
                                   (int)batchCount));
#endif
}
#if CUDA_VERSION >= 8000 || defined __HIP_PLATFORM_HCC__
/* Strided-batched single-precision gemm, forwarded to
 * cublasSgemmStridedBatched on the current THC BLAS handle and stream.
 *
 * strideA/strideB/strideC are forwarded unchanged as the per-batch strides.
 * Leading dimensions are normalised by adjustLdLevel3. Raises a THError
 * when any of m, n, k, lda, ldb, ldc, batchCount exceeds INT_MAX.
 */
void THCudaBlas_SgemmStridedBatched(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k,
                             float alpha, const float *a, int64_t lda, int64_t strideA, const float *b, int64_t ldb, int64_t strideB,
                             float beta, float *c, int64_t ldc, int64_t strideC, int64_t batchCount)
{
  // Strict '>' so that the check matches the documented bound [val] <= INT_MAX.
  if( (m > INT_MAX) || (n > INT_MAX) || (k > INT_MAX) || (lda > INT_MAX) || (ldb > INT_MAX) || (ldc > INT_MAX) || (batchCount > INT_MAX) )
  {
    THError("Cublas_SgemmStridedBatched only supports m, n, k, lda, ldb, ldc, batchCount "
            "with the bound [val] <= %d", INT_MAX);
  }

  adjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
  cublasOperation_t opa = convertTransToCublasOperation(transa);
  cublasOperation_t opb = convertTransToCublasOperation(transb);

  cublasHandle_t handle = THCState_getCurrentBlasHandle(state);
  cublasSetStream(handle, THCState_getCurrentStream(state));
  THCublasCheck(cublasSgemmStridedBatched(handle,
                                   opa, opb, (int)m, (int)n, (int)k,
                                   &alpha, a, (int)lda, strideA, b, (int)ldb, strideB, &beta, c, (int)ldc, strideC,
                                   (int)batchCount));
}
#endif
/* Batched double-precision gemm over arrays of matrix pointers.
 *
 * CUDA build: forwards a, b, c (pointer arrays) to cublasDgemmBatched
 * after normalising the leading dimensions with adjustLdLevel3.
 * HIP build: derives per-batch strides from the leading dimensions and
 * calls THCudaBlas_DgemmStridedBatched starting at the first pointer of
 * each array.
 *
 * Raises a THError when any of m, n, k, lda, ldb, ldc, batchCount exceeds
 * INT_MAX.
 */
void THCudaBlas_DgemmBatched(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k,
                             double alpha, const double *a[], int64_t lda, const double *b[], int64_t ldb,
                             double beta, double *c[], int64_t ldc, int64_t batchCount)
{
  // Strict '>' so that the check matches the documented bound [val] <= INT_MAX.
  if( (m > INT_MAX) || (n > INT_MAX) || (k > INT_MAX) || (lda > INT_MAX) || (ldb > INT_MAX) || (ldc > INT_MAX) || (batchCount > INT_MAX) )
  {
    THError("Cublas_DgemmBatched only supports m, n, k, lda, ldb, ldc, batchCount "
            "with the bound [val] <= %d", INT_MAX);
  }

#ifdef __HIP_PLATFORM_HCC__
  // NOTE(review): this path only uses a[0], b[0], c[0] plus computed
  // strides, i.e. it presumes the batch is laid out contiguously — confirm
  // against callers. Also, for a transposed A the stored matrix would have
  // m columns, so lda*m might be expected instead of lda*n — TODO confirm.
  const int64_t stridea = (transa == 'N' || transa == 'n') ? lda*k : lda*n;
  const int64_t strideb = (transb == 'N' || transb == 'n') ? ldb*n : ldb*k;
  const int64_t stridec = ldc*n;

  THCudaBlas_DgemmStridedBatched(state, transa, transb, m, n, k, alpha, *a, lda, stridea, *b, ldb, strideb, beta, *c, ldc, stridec, batchCount);
#else
  adjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
  cublasOperation_t opa = convertTransToCublasOperation(transa);
  cublasOperation_t opb = convertTransToCublasOperation(transb);

  cublasHandle_t handle = THCState_getCurrentBlasHandle(state);
  cublasSetStream(handle, THCState_getCurrentStream(state));
  THCublasCheck(cublasDgemmBatched(handle,
                                   opa, opb, (int)m, (int)n, (int)k,
                                   &alpha, a, (int)lda, b, (int)ldb, &beta, c, (int)ldc,
                                   (int)batchCount));
#endif
}
#if CUDA_VERSION >= 8000 || defined __HIP_PLATFORM_HCC__
/* Strided-batched double-precision gemm, forwarded to
 * cublasDgemmStridedBatched on the current THC BLAS handle and stream.
 *
 * strideA/strideB/strideC are forwarded unchanged as the per-batch strides.
 * Leading dimensions are normalised by adjustLdLevel3. Raises a THError
 * when any of m, n, k, lda, ldb, ldc, batchCount exceeds INT_MAX.
 */
void THCudaBlas_DgemmStridedBatched(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k,
                             double alpha, const double *a, int64_t lda, int64_t strideA, const double *b, int64_t ldb, int64_t strideB,
                             double beta, double *c, int64_t ldc, int64_t strideC, int64_t batchCount)
{
  // Strict '>' so that the check matches the documented bound [val] <= INT_MAX.
  if( (m > INT_MAX) || (n > INT_MAX) || (k > INT_MAX) || (lda > INT_MAX) || (ldb > INT_MAX) || (ldc > INT_MAX) || (batchCount > INT_MAX) )
  {
    THError("Cublas_DgemmStridedBatched only supports m, n, k, lda, ldb, ldc, batchCount "
            "with the bound [val] <= %d", INT_MAX);
  }

  adjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
  cublasOperation_t opa = convertTransToCublasOperation(transa);
  cublasOperation_t opb = convertTransToCublasOperation(transb);

  cublasHandle_t handle = THCState_getCurrentBlasHandle(state);
  cublasSetStream(handle, THCState_getCurrentStream(state));
  THCublasCheck(cublasDgemmStridedBatched(handle,
                                   opa, opb, (int)m, (int)n, (int)k,
                                   &alpha, a, (int)lda, strideA, b, (int)ldb, strideB, &beta, c, (int)ldc, strideC,
                                   (int)batchCount));
}
#endif
/* Inverse */
/* Batched single-precision LU factorisation, forwarded to
 * cublasSgetrfBatched on the current THC BLAS handle and stream.
 *
 * n, lda, batchSize are already ints, so the guard below can only fire
 * when one of them equals INT_MAX; in that case a THError is raised.
 */
void THCudaBlas_Sgetrf(THCState *state, int n, float **a, int lda, int *pivot, int *info, int batchSize) {
  if( (n >= INT_MAX) || (lda >= INT_MAX) || (batchSize >= INT_MAX) )
  {
    // Trailing space after "batchSize" keeps the concatenated literals readable.
    THError("Cublas_Sgetrf only supports n, lda, batchSize "
            "with the bound [val] <= %d", INT_MAX);
  }
  cublasHandle_t handle = THCState_getCurrentBlasHandle(state);
  cublasSetStream(handle, THCState_getCurrentStream(state));
  THCublasCheck(cublasSgetrfBatched(handle, n, a, lda, pivot, info, batchSize));
}
/* Batched double-precision LU factorisation, forwarded to
 * cublasDgetrfBatched on the current THC BLAS handle and stream.
 *
 * n, lda, batchSize are already ints, so the guard below can only fire
 * when one of them equals INT_MAX; in that case a THError is raised.
 */
void THCudaBlas_Dgetrf(THCState *state, int n, double **a, int lda, int *pivot, int *info, int batchSize) {
  if( (n >= INT_MAX) || (lda >= INT_MAX) || (batchSize >= INT_MAX) )
  {
    // Trailing space after "batchSize" keeps the concatenated literals readable.
    THError("Cublas_Dgetrf only supports n, lda, batchSize "
            "with the bound [val] <= %d", INT_MAX);
  }
  cublasHandle_t handle = THCState_getCurrentBlasHandle(state);
  cublasSetStream(handle, THCState_getCurrentStream(state));
  THCublasCheck(cublasDgetrfBatched(handle, n, a, lda, pivot, info, batchSize));
}
/* Batched single-precision solve, forwarded to cublasSgetrsBatched on the
 * current THC BLAS handle and stream.
 *
 * transa is translated by convertTransToCublasOperation. All size
 * arguments are already ints, so the guard below can only fire when one
 * of them equals INT_MAX; in that case a THError is raised.
 */
void THCudaBlas_Sgetrs(THCState *state, char transa, int n, int nrhs, const float **a, int lda, int *pivot, float **b, int ldb, int *info, int batchSize)
{
  if( (n >= INT_MAX) || (nrhs >= INT_MAX) || (lda >= INT_MAX) || (ldb >= INT_MAX) || (batchSize >= INT_MAX) )
  {
    // The message names Sgetrs so the failure points at the right wrapper.
    THError("Cublas_Sgetrs only supports n, nrhs, lda, ldb, batchSize "
            "with the bound [val] <= %d", INT_MAX);
  }

  // no need to adjust leading dimensions, since matrices are square
  cublasOperation_t opa = convertTransToCublasOperation(transa);

  cublasHandle_t handle = THCState_getCurrentBlasHandle(state);
  cublasSetStream(handle, THCState_getCurrentStream(state));
  THCublasCheck(cublasSgetrsBatched(handle, opa, n, nrhs, a, lda, pivot, b, ldb, info, batchSize));
}
/* Batched double-precision solve, forwarded to cublasDgetrsBatched on the
 * current THC BLAS handle and stream.
 *
 * transa is translated by convertTransToCublasOperation. All size
 * arguments are already ints, so the guard below can only fire when one
 * of them equals INT_MAX; in that case a THError is raised.
 */
void THCudaBlas_Dgetrs(THCState *state, char transa, int n, int nrhs, const double **a, int lda, int *pivot, double **b, int ldb, int *info, int batchSize)
{
  if( (n >= INT_MAX) || (nrhs >= INT_MAX) || (lda >= INT_MAX) || (ldb >= INT_MAX) || (batchSize >= INT_MAX) )
  {
    // Trailing space after "batchSize" keeps the concatenated literals readable.
    THError("Cublas_Dgetrs only supports n, nrhs, lda, ldb, batchSize "
            "with the bound [val] <= %d", INT_MAX);
  }

  // no need to adjust leading dimensions, since matrices are square
  cublasOperation_t opa = convertTransToCublasOperation(transa);

  cublasHandle_t handle = THCState_getCurrentBlasHandle(state);
  cublasSetStream(handle, THCState_getCurrentStream(state));
  THCublasCheck(cublasDgetrsBatched(handle, opa, n, nrhs, a, lda, pivot, b, ldb, info, batchSize));
}
/* Batched single-precision inversion step, forwarded to
 * cublasSgetriBatched on the current THC BLAS handle and stream.
 *
 * n, lda, ldc, batchSize are already ints, so the guard below can only
 * fire when one of them equals INT_MAX; in that case a THError is raised.
 */
void THCudaBlas_Sgetri(THCState *state, int n, const float **a, int lda, int *pivot, float **c, int ldc, int *info, int batchSize) {
  if( (n >= INT_MAX) || (lda >= INT_MAX)|| (ldc >= INT_MAX) || (batchSize >= INT_MAX) )
  {
    // Trailing space after "batchSize" keeps the concatenated literals readable.
    THError("Cublas_Sgetri only supports n, lda, ldc, batchSize "
            "with the bound [val] <= %d", INT_MAX);
  }
  cublasHandle_t handle = THCState_getCurrentBlasHandle(state);
  cublasSetStream(handle, THCState_getCurrentStream(state));
  THCublasCheck(cublasSgetriBatched(handle, n, a, lda, pivot, c, ldc, info, batchSize));
}
/* Batched double-precision inversion step, forwarded to
 * cublasDgetriBatched on the current THC BLAS handle and stream.
 *
 * n, lda, ldc, batchSize are already ints, so the guard below can only
 * fire when one of them equals INT_MAX; in that case a THError is raised.
 */
void THCudaBlas_Dgetri(THCState *state, int n, const double **a, int lda, int *pivot, double **c, int ldc, int *info, int batchSize) {
  if( (n >= INT_MAX) || (lda >= INT_MAX)|| (ldc >= INT_MAX) || (batchSize >= INT_MAX) )
  {
    // Trailing space after "batchSize" keeps the concatenated literals readable.
    THError("Cublas_Dgetri only supports n, lda, ldc, batchSize "
            "with the bound [val] <= %d", INT_MAX);
  }
  cublasHandle_t handle = THCState_getCurrentBlasHandle(state);
  cublasSetStream(handle, THCState_getCurrentStream(state));
  THCublasCheck(cublasDgetriBatched(handle, n, a, lda, pivot, c, ldc, info, batchSize));
}