@@ -344,6 +344,7 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non
344
344
"""
345
345
code = """
346
346
347
+ bool is_float;
347
348
int elemsize;
348
349
float fbeta;
349
350
double dbeta;
@@ -361,11 +362,23 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non
361
362
%(fail)s;
362
363
}
363
364
364
- if (PyArray_DESCR(%(y)s)->type_num == NPY_DOUBLE) { elemsize = 8; }
365
- else if (PyArray_DESCR(%(y)s)->type_num == NPY_FLOAT) { elemsize = 4;}
365
+ if ((PyArray_DESCR(%(y)s)->type_num != PyArray_DESCR(%(x)s)->type_num)
366
+ || (PyArray_DESCR(%(y)s)->type_num != PyArray_DESCR(%(A)s)->type_num))
367
+ {
368
+ PyErr_SetString(PyExc_TypeError, "GEMV: dtypes of A, x, y do not match");
369
+ %(fail)s;
370
+ }
371
+ if (PyArray_DESCR(%(y)s)->type_num == NPY_DOUBLE) {
372
+ is_float = 0;
373
+ elemsize = 8;
374
+ }
375
+ else if (PyArray_DESCR(%(y)s)->type_num == NPY_FLOAT) {
376
+ elemsize = 4;
377
+ is_float = 1;
378
+ }
366
379
else {
367
- PyErr_SetString(PyExc_NotImplementedError, "complex Gemv");
368
380
%(fail)s;
381
+ PyErr_SetString(PyExc_NotImplementedError, "GEMV: Inputs must be float or double");
369
382
}
370
383
371
384
fbeta = dbeta = ((dtype_%(beta)s*)PyArray_DATA(%(beta)s))[0];
@@ -408,37 +421,40 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non
408
421
Py_INCREF(%(z)s);
409
422
}
410
423
}
424
+
411
425
{
412
- char TRANS = 'T';
413
- char NOTRANS = 'N';
414
426
int NA0 = PyArray_DIMS(%(A)s)[0];
415
427
int NA1 = PyArray_DIMS(%(A)s)[1];
416
- /* This formula is needed in the case where A is actually a row or
417
- * column matrix, because BLAS sometimes insists that the strides:
418
- * - are not smaller than the number of elements in the array
419
- * - are not 0.
420
- */
421
- int SA0 = (NA0 > 1) ? (PyArray_STRIDES(%(A)s)[0] / elemsize) : (NA1 + 1);
422
- int SA1 = (NA1 > 1) ? (PyArray_STRIDES(%(A)s)[1] / elemsize) : (NA0 + 1);
423
- int Sz = PyArray_STRIDES(%(z)s)[0] / elemsize;
424
- int Sx = PyArray_STRIDES(%(x)s)[0] / elemsize;
425
-
426
- dtype_%(A)s* A_data = (dtype_%(A)s*) PyArray_DATA(%(A)s);
427
- dtype_%(x)s* x_data = (dtype_%(x)s*) PyArray_DATA(%(x)s);
428
- dtype_%(z)s* z_data = (dtype_%(z)s*) PyArray_DATA(%(z)s);
429
- // gemv expects pointers to the beginning of memory arrays,
430
- // but numpy provides a pointer to the first element,
431
- // so when the stride is negative, we need to get the last one.
432
- if (Sx < 0)
433
- x_data += (NA1 - 1) * Sx;
434
- if (Sz < 0)
435
- z_data += (NA0 - 1) * Sz;
436
428
437
429
if (NA0 * NA1)
438
430
{
431
+ // Non-empty A matrix
432
+
433
+ /* In the case where A is actually a row or column matrix,
434
+ * the strides corresponding to the dummy dimension don't matter,
435
+ * but BLAS requires these to be no smaller than the number of elements in the array.
436
+ */
437
+ int SA0 = (NA0 > 1) ? (PyArray_STRIDES(%(A)s)[0] / elemsize) : NA1;
438
+ int SA1 = (NA1 > 1) ? (PyArray_STRIDES(%(A)s)[1] / elemsize) : NA0;
439
+ int Sz = PyArray_STRIDES(%(z)s)[0] / elemsize;
440
+ int Sx = PyArray_STRIDES(%(x)s)[0] / elemsize;
441
+
442
+ dtype_%(A)s* A_data = (dtype_%(A)s*) PyArray_DATA(%(A)s);
443
+ dtype_%(x)s* x_data = (dtype_%(x)s*) PyArray_DATA(%(x)s);
444
+ dtype_%(z)s* z_data = (dtype_%(z)s*) PyArray_DATA(%(z)s);
445
+
446
+ // gemv expects pointers to the beginning of memory arrays,
447
+ // but numpy provides a pointer to the first element,
448
+ // so when the stride is negative, we need to get the last one.
449
+ if (Sx < 0)
450
+ x_data += (NA1 - 1) * Sx;
451
+ if (Sz < 0)
452
+ z_data += (NA0 - 1) * Sz;
453
+
439
454
if ( ((SA0 < 0) || (SA1 < 0)) && (abs(SA0) == 1 || (abs(SA1) == 1)) )
440
455
{
441
456
// We can treat the array A as C-or F-contiguous by changing the order of iteration
457
+ // printf("GEMV: Iterating in reverse NA0=%%d, NA1=%%d, SA0=%%d, SA1=%%d\\ n", NA0, NA1, SA0, SA1);
442
458
if (SA0 < 0){
443
459
A_data += (NA0 -1) * SA0; // Jump to first row
444
460
SA0 = -SA0; // Iterate over rows in reverse
@@ -452,27 +468,45 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non
452
468
} else if ((SA0 < 0) || (SA1 < 0) || ((SA0 != 1) && (SA1 != 1)))
453
469
{
454
470
// Array isn't contiguous, we have to make a copy
455
- // - if the copy is too long, maybe call vector/vector dot on
456
- // each row instead
457
- // printf("GEMV: Making a copy SA0=%%d, SA1=%%d\\ n", SA0, SA1);
471
+ // - if the copy is too long, maybe call vector/vector dot on each row instead
472
+ // printf("GEMV: Making a copy NA0=%%d, NA1=%%d, SA0=%%d, SA1=%%d\\ n", NA0, NA1, SA0, SA1);
458
473
npy_intp dims[2];
459
474
dims[0] = NA0;
460
475
dims[1] = NA1;
461
-
462
- PyArrayObject * A_copy = (PyArrayObject *) PyArray_Copy(
463
- %(A)s);
476
+ PyArrayObject * A_copy = (PyArrayObject *) PyArray_Copy(%(A)s);
464
477
if (!A_copy)
465
478
%(fail)s
466
479
Py_XDECREF(%(A)s);
467
480
%(A)s = A_copy;
468
- SA0 = (NA0 > 1) ? (PyArray_STRIDES(%(A)s)[0] / elemsize) : ( NA1 + 1) ;
469
- SA1 = (NA1 > 1) ? (PyArray_STRIDES(%(A)s)[1] / elemsize) : ( NA0 + 1) ;
481
+ SA0 = (NA0 > 1) ? (PyArray_STRIDES(%(A)s)[0] / elemsize) : NA1;
482
+ SA1 = (NA1 > 1) ? (PyArray_STRIDES(%(A)s)[1] / elemsize) : NA0;
470
483
A_data = (dtype_%(A)s*) PyArray_DATA(%(A)s);
471
484
}
485
+ //else {printf("GEMV: Using the original array NA0=%%d, NA1=%%d, SA0=%%d, SA1=%%d\\ n", NA0, NA1, SA0, SA1);}
472
486
473
- if (SA0 == 1)
487
+ if (NA0 == 1)
488
+ {
489
+ // Vector-vector dot product, it seems faster to avoid GEMV
490
+ dtype_%(alpha)s alpha = ((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0];
491
+
492
+ if (is_float)
493
+ {
494
+ z_data[0] *= fbeta;
495
+ z_data[0] += alpha * sdot_(&NA1, (float*)(A_data), &SA1,
496
+ (float*)x_data, &Sx);
497
+ }
498
+ else
499
+ {
500
+ z_data[0] *= dbeta;
501
+ z_data[0] += alpha * ddot_(&NA1, (double*)(A_data), &SA1,
502
+ (double*)x_data, &Sx);
503
+ }
504
+ }
505
+ else if (SA0 == 1)
474
506
{
475
- if (PyArray_DESCR(%(A)s)->type_num == NPY_FLOAT)
507
+ // F-contiguous
508
+ char NOTRANS = 'N';
509
+ if (is_float)
476
510
{
477
511
float alpha = ((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0];
478
512
sgemv_(&NOTRANS, &NA0, &NA1,
@@ -482,7 +516,7 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non
482
516
&fbeta,
483
517
(float*)z_data, &Sz);
484
518
}
485
- else if (PyArray_DESCR(%(A)s)->type_num == NPY_DOUBLE)
519
+ else
486
520
{
487
521
double alpha = ((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0];
488
522
dgemv_(&NOTRANS, &NA0, &NA1,
@@ -492,97 +526,39 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non
492
526
&dbeta,
493
527
(double*)z_data, &Sz);
494
528
}
495
- else
496
- {
497
- PyErr_SetString(PyExc_AssertionError,
498
- "neither float nor double dtype");
499
- %(fail)s
500
- }
501
529
}
502
530
else if (SA1 == 1)
503
531
{
504
- if (PyArray_DESCR(%(A)s)->type_num == NPY_FLOAT)
532
+ // C-contiguous
533
+ char TRANS = 'T';
534
+ if (is_float)
505
535
{
506
536
float alpha = ((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0];
507
-
508
- // Check for vector-vector dot (NA0 == 1). The code may work
509
- // for SA1 != 1 as well, but has not been tested for this case,
510
- // so SA1 == 1 is required for safety.
511
- if (NA0 == 1 && SA1 == 1)
512
- {
513
- if (fbeta != 0.f) {
514
- z_data[0] = fbeta*z_data[0];
515
- } else {
516
- z_data[0] = 0.f;
517
- }
518
- z_data[0] += alpha*sdot_(&NA1,
519
- (float*)(A_data), &SA1,
520
- (float*)x_data, &Sx);
521
- }
522
- else
523
- {
524
- sgemv_(&TRANS, &NA1, &NA0,
525
- &alpha,
526
- (float*)(A_data), &SA0,
527
- (float*)x_data, &Sx,
528
- &fbeta,
529
- (float*)z_data, &Sz);
530
- }
531
- }
532
- else if (PyArray_DESCR(%(A)s)->type_num == NPY_DOUBLE)
533
- {
534
- double alpha = ((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0];
535
-
536
- // Check for vector-vector dot (NA0 == 1). The code may work
537
- // for SA1 != 1 as well, but has not been tested for this case,
538
- // so SA1 == 1 is required for safety.
539
- if (NA0 == 1 && SA1 == 1)
540
- {
541
- if (dbeta != 0.) {
542
- z_data[0] = dbeta*z_data[0];
543
- } else {
544
- z_data[0] = 0.;
545
- }
546
- z_data[0] += alpha*ddot_(&NA1,
547
- (double*)(A_data), &SA1,
548
- (double*)x_data, &Sx);
549
- }
550
- else
551
- {
552
- dgemv_(&TRANS, &NA1, &NA0,
553
- &alpha,
554
- (double*)(A_data), &SA0,
555
- (double*)x_data, &Sx,
556
- &dbeta,
557
- (double*)z_data, &Sz);
558
- }
537
+ sgemv_(&TRANS, &NA1, &NA0,
538
+ &alpha,
539
+ (float*)(A_data), &SA0,
540
+ (float*)x_data, &Sx,
541
+ &fbeta,
542
+ (float*)z_data, &Sz);
559
543
}
560
544
else
561
545
{
562
- PyErr_SetString(PyExc_AssertionError,
563
- "neither float nor double dtype");
564
- %(fail)s
546
+ double alpha = ((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0];
547
+ dgemv_(&TRANS, &NA1, &NA0,
548
+ &alpha,
549
+ (double*)(A_data), &SA0,
550
+ (double*)x_data, &Sx,
551
+ &dbeta,
552
+ (double*)z_data, &Sz);
565
553
}
566
554
}
567
555
else
568
556
{
569
557
PyErr_SetString(PyExc_AssertionError,
570
- "xx is a double-strided matrix, and should have been "
571
- "copied into a memory-contiguous one.");
558
+ "A is neither C nor F-contiguous, it should have been copied into a memory-contiguous array;");
572
559
%(fail)s
573
560
}
574
561
}
575
- else if (dbeta != 1.0)
576
- {
577
- // the matrix has at least one dim of length 0
578
- // so we do this loop, which either iterates over 0 elements
579
- // or else it does the right thing for length-0 A.
580
- dtype_%(z)s * zptr = (dtype_%(z)s*)(PyArray_DATA(%(z)s));
581
- for (int i = 0; i < NA0; ++i)
582
- {
583
- zptr[i * Sz] = (dbeta == 0.0 ? 0.0 : zptr[i * Sz] * dbeta);
584
- }
585
- }
586
562
}
587
563
"""
588
564
return code % locals ()
@@ -613,7 +589,7 @@ def c_code(self, node, name, inp, out, sub):
613
589
return code
614
590
615
591
def c_code_cache_version (self ):
616
- return (15 , blas_header_version (), check_force_gemv_init ())
592
+ return (16 , blas_header_version (), check_force_gemv_init ())
617
593
618
594
619
595
cgemv_inplace = CGemv (inplace = True )
0 commit comments