Skip to content

Commit 3ff344b

Browse files
committed
Fix bug in handling of row/column matrices in GEMV c_code
Bug was caused by reusing the adjusted strides in the logic to decide whether the call to GEMV should be transposed or not. Particularly the +1 in the strides variable was causing the error branch (no double-strides) to be reached wrongly. The +1 was supposedly there for the case of matrix with length 0, but that triggers a branch where the adjusted strides are never used. This bug was introduced in afe934b
1 parent 8362f6a commit 3ff344b

File tree

3 files changed

+94
-113
lines changed

3 files changed

+94
-113
lines changed

pytensor/tensor/blas_c.py

Lines changed: 88 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -344,6 +344,7 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non
344344
"""
345345
code = """
346346
347+
bool is_float;
347348
int elemsize;
348349
float fbeta;
349350
double dbeta;
@@ -361,11 +362,23 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non
361362
%(fail)s;
362363
}
363364
364-
if (PyArray_DESCR(%(y)s)->type_num == NPY_DOUBLE) { elemsize = 8; }
365-
else if (PyArray_DESCR(%(y)s)->type_num == NPY_FLOAT) { elemsize = 4;}
365+
if ((PyArray_DESCR(%(y)s)->type_num != PyArray_DESCR(%(x)s)->type_num)
366+
|| (PyArray_DESCR(%(y)s)->type_num != PyArray_DESCR(%(A)s)->type_num))
367+
{
368+
PyErr_SetString(PyExc_TypeError, "GEMV: dtypes of A, x, y do not match");
369+
%(fail)s;
370+
}
371+
if (PyArray_DESCR(%(y)s)->type_num == NPY_DOUBLE) {
372+
is_float = 0;
373+
elemsize = 8;
374+
}
375+
else if (PyArray_DESCR(%(y)s)->type_num == NPY_FLOAT) {
376+
elemsize = 4;
377+
is_float = 1;
378+
}
366379
else {
367-
PyErr_SetString(PyExc_NotImplementedError, "complex Gemv");
368380
%(fail)s;
381+
PyErr_SetString(PyExc_NotImplementedError, "GEMV: Inputs must be float or double");
369382
}
370383
371384
fbeta = dbeta = ((dtype_%(beta)s*)PyArray_DATA(%(beta)s))[0];
@@ -408,37 +421,40 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non
408421
Py_INCREF(%(z)s);
409422
}
410423
}
424+
411425
{
412-
char TRANS = 'T';
413-
char NOTRANS = 'N';
414426
int NA0 = PyArray_DIMS(%(A)s)[0];
415427
int NA1 = PyArray_DIMS(%(A)s)[1];
416-
/* This formula is needed in the case where A is actually a row or
417-
* column matrix, because BLAS sometimes insists that the strides:
418-
* - are not smaller than the number of elements in the array
419-
* - are not 0.
420-
*/
421-
int SA0 = (NA0 > 1) ? (PyArray_STRIDES(%(A)s)[0] / elemsize) : (NA1 + 1);
422-
int SA1 = (NA1 > 1) ? (PyArray_STRIDES(%(A)s)[1] / elemsize) : (NA0 + 1);
423-
int Sz = PyArray_STRIDES(%(z)s)[0] / elemsize;
424-
int Sx = PyArray_STRIDES(%(x)s)[0] / elemsize;
425-
426-
dtype_%(A)s* A_data = (dtype_%(A)s*) PyArray_DATA(%(A)s);
427-
dtype_%(x)s* x_data = (dtype_%(x)s*) PyArray_DATA(%(x)s);
428-
dtype_%(z)s* z_data = (dtype_%(z)s*) PyArray_DATA(%(z)s);
429-
// gemv expects pointers to the beginning of memory arrays,
430-
// but numpy provides a pointer to the first element,
431-
// so when the stride is negative, we need to get the last one.
432-
if (Sx < 0)
433-
x_data += (NA1 - 1) * Sx;
434-
if (Sz < 0)
435-
z_data += (NA0 - 1) * Sz;
436428
437429
if (NA0 * NA1)
438430
{
431+
// Non-empty A matrix
432+
433+
/* In the case where A is actually a row or column matrix,
434+
* the strides corresponding to the dummy dimension don't matter,
435+
* but BLAS requires these to be no smaller than the number of elements in the array.
436+
*/
437+
int SA0 = (NA0 > 1) ? (PyArray_STRIDES(%(A)s)[0] / elemsize) : NA1;
438+
int SA1 = (NA1 > 1) ? (PyArray_STRIDES(%(A)s)[1] / elemsize) : NA0;
439+
int Sz = PyArray_STRIDES(%(z)s)[0] / elemsize;
440+
int Sx = PyArray_STRIDES(%(x)s)[0] / elemsize;
441+
442+
dtype_%(A)s* A_data = (dtype_%(A)s*) PyArray_DATA(%(A)s);
443+
dtype_%(x)s* x_data = (dtype_%(x)s*) PyArray_DATA(%(x)s);
444+
dtype_%(z)s* z_data = (dtype_%(z)s*) PyArray_DATA(%(z)s);
445+
446+
// gemv expects pointers to the beginning of memory arrays,
447+
// but numpy provides a pointer to the first element,
448+
// so when the stride is negative, we need to get the last one.
449+
if (Sx < 0)
450+
x_data += (NA1 - 1) * Sx;
451+
if (Sz < 0)
452+
z_data += (NA0 - 1) * Sz;
453+
439454
if ( ((SA0 < 0) || (SA1 < 0)) && (abs(SA0) == 1 || (abs(SA1) == 1)) )
440455
{
441456
// We can treat the array A as C-or F-contiguous by changing the order of iteration
457+
// printf("GEMV: Iterating in reverse NA0=%%d, NA1=%%d, SA0=%%d, SA1=%%d\\n", NA0, NA1, SA0, SA1);
442458
if (SA0 < 0){
443459
A_data += (NA0 -1) * SA0; // Jump to first row
444460
SA0 = -SA0; // Iterate over rows in reverse
@@ -452,27 +468,45 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non
452468
} else if ((SA0 < 0) || (SA1 < 0) || ((SA0 != 1) && (SA1 != 1)))
453469
{
454470
// Array isn't contiguous, we have to make a copy
455-
// - if the copy is too long, maybe call vector/vector dot on
456-
// each row instead
457-
// printf("GEMV: Making a copy SA0=%%d, SA1=%%d\\n", SA0, SA1);
471+
// - if the copy is too long, maybe call vector/vector dot on each row instead
472+
// printf("GEMV: Making a copy NA0=%%d, NA1=%%d, SA0=%%d, SA1=%%d\\n", NA0, NA1, SA0, SA1);
458473
npy_intp dims[2];
459474
dims[0] = NA0;
460475
dims[1] = NA1;
461-
462-
PyArrayObject * A_copy = (PyArrayObject *) PyArray_Copy(
463-
%(A)s);
476+
PyArrayObject * A_copy = (PyArrayObject *) PyArray_Copy(%(A)s);
464477
if (!A_copy)
465478
%(fail)s
466479
Py_XDECREF(%(A)s);
467480
%(A)s = A_copy;
468-
SA0 = (NA0 > 1) ? (PyArray_STRIDES(%(A)s)[0] / elemsize) : (NA1 + 1);
469-
SA1 = (NA1 > 1) ? (PyArray_STRIDES(%(A)s)[1] / elemsize) : (NA0 + 1);
481+
SA0 = (NA0 > 1) ? (PyArray_STRIDES(%(A)s)[0] / elemsize) : NA1;
482+
SA1 = (NA1 > 1) ? (PyArray_STRIDES(%(A)s)[1] / elemsize) : NA0;
470483
A_data = (dtype_%(A)s*) PyArray_DATA(%(A)s);
471484
}
485+
//else {printf("GEMV: Using the original array NA0=%%d, NA1=%%d, SA0=%%d, SA1=%%d\\n", NA0, NA1, SA0, SA1);}
472486
473-
if (SA0 == 1)
487+
if (NA0 == 1)
488+
{
489+
// Vector-vector dot product, it seems faster to avoid GEMV
490+
dtype_%(alpha)s alpha = ((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0];
491+
492+
if (is_float)
493+
{
494+
z_data[0] *= fbeta;
495+
z_data[0] += alpha * sdot_(&NA1, (float*)(A_data), &SA1,
496+
(float*)x_data, &Sx);
497+
}
498+
else
499+
{
500+
z_data[0] *= dbeta;
501+
z_data[0] += alpha * ddot_(&NA1, (double*)(A_data), &SA1,
502+
(double*)x_data, &Sx);
503+
}
504+
}
505+
else if (SA0 == 1)
474506
{
475-
if (PyArray_DESCR(%(A)s)->type_num == NPY_FLOAT)
507+
// F-contiguous
508+
char NOTRANS = 'N';
509+
if (is_float)
476510
{
477511
float alpha = ((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0];
478512
sgemv_(&NOTRANS, &NA0, &NA1,
@@ -482,7 +516,7 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non
482516
&fbeta,
483517
(float*)z_data, &Sz);
484518
}
485-
else if (PyArray_DESCR(%(A)s)->type_num == NPY_DOUBLE)
519+
else
486520
{
487521
double alpha = ((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0];
488522
dgemv_(&NOTRANS, &NA0, &NA1,
@@ -492,97 +526,39 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non
492526
&dbeta,
493527
(double*)z_data, &Sz);
494528
}
495-
else
496-
{
497-
PyErr_SetString(PyExc_AssertionError,
498-
"neither float nor double dtype");
499-
%(fail)s
500-
}
501529
}
502530
else if (SA1 == 1)
503531
{
504-
if (PyArray_DESCR(%(A)s)->type_num == NPY_FLOAT)
532+
// C-contiguous
533+
char TRANS = 'T';
534+
if (is_float)
505535
{
506536
float alpha = ((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0];
507-
508-
// Check for vector-vector dot (NA0 == 1). The code may work
509-
// for SA1 != 1 as well, but has not been tested for this case,
510-
// so SA1 == 1 is required for safety.
511-
if (NA0 == 1 && SA1 == 1)
512-
{
513-
if (fbeta != 0.f) {
514-
z_data[0] = fbeta*z_data[0];
515-
} else {
516-
z_data[0] = 0.f;
517-
}
518-
z_data[0] += alpha*sdot_(&NA1,
519-
(float*)(A_data), &SA1,
520-
(float*)x_data, &Sx);
521-
}
522-
else
523-
{
524-
sgemv_(&TRANS, &NA1, &NA0,
525-
&alpha,
526-
(float*)(A_data), &SA0,
527-
(float*)x_data, &Sx,
528-
&fbeta,
529-
(float*)z_data, &Sz);
530-
}
531-
}
532-
else if (PyArray_DESCR(%(A)s)->type_num == NPY_DOUBLE)
533-
{
534-
double alpha = ((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0];
535-
536-
// Check for vector-vector dot (NA0 == 1). The code may work
537-
// for SA1 != 1 as well, but has not been tested for this case,
538-
// so SA1 == 1 is required for safety.
539-
if (NA0 == 1 && SA1 == 1)
540-
{
541-
if (dbeta != 0.) {
542-
z_data[0] = dbeta*z_data[0];
543-
} else {
544-
z_data[0] = 0.;
545-
}
546-
z_data[0] += alpha*ddot_(&NA1,
547-
(double*)(A_data), &SA1,
548-
(double*)x_data, &Sx);
549-
}
550-
else
551-
{
552-
dgemv_(&TRANS, &NA1, &NA0,
553-
&alpha,
554-
(double*)(A_data), &SA0,
555-
(double*)x_data, &Sx,
556-
&dbeta,
557-
(double*)z_data, &Sz);
558-
}
537+
sgemv_(&TRANS, &NA1, &NA0,
538+
&alpha,
539+
(float*)(A_data), &SA0,
540+
(float*)x_data, &Sx,
541+
&fbeta,
542+
(float*)z_data, &Sz);
559543
}
560544
else
561545
{
562-
PyErr_SetString(PyExc_AssertionError,
563-
"neither float nor double dtype");
564-
%(fail)s
546+
double alpha = ((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0];
547+
dgemv_(&TRANS, &NA1, &NA0,
548+
&alpha,
549+
(double*)(A_data), &SA0,
550+
(double*)x_data, &Sx,
551+
&dbeta,
552+
(double*)z_data, &Sz);
565553
}
566554
}
567555
else
568556
{
569557
PyErr_SetString(PyExc_AssertionError,
570-
"xx is a double-strided matrix, and should have been "
571-
"copied into a memory-contiguous one.");
558+
"A is neither C nor F-contiguous, it should have been copied into a memory-contiguous array;");
572559
%(fail)s
573560
}
574561
}
575-
else if (dbeta != 1.0)
576-
{
577-
// the matrix has at least one dim of length 0
578-
// so we do this loop, which either iterates over 0 elements
579-
// or else it does the right thing for length-0 A.
580-
dtype_%(z)s * zptr = (dtype_%(z)s*)(PyArray_DATA(%(z)s));
581-
for (int i = 0; i < NA0; ++i)
582-
{
583-
zptr[i * Sz] = (dbeta == 0.0 ? 0.0 : zptr[i * Sz] * dbeta);
584-
}
585-
}
586562
}
587563
"""
588564
return code % locals()
@@ -613,7 +589,7 @@ def c_code(self, node, name, inp, out, sub):
613589
return code
614590

615591
def c_code_cache_version(self):
616-
return (15, blas_header_version(), check_force_gemv_init())
592+
return (16, blas_header_version(), check_force_gemv_init())
617593

618594

619595
cgemv_inplace = CGemv(inplace=True)

tests/tensor/test_blas.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2226,8 +2226,10 @@ def cmp_gemv(self, a_shp, b_shp, c_shp, rng):
22262226

22272227
a.set_value(a_dev.copy()[::a_step], borrow=True)
22282228
b.set_value(b_dev.copy()[::b_step1, ::b_step2], borrow=True)
2229+
# Copy as C so that it becomes F after the transpose in the graph
22292230
b_t.set_value(
2230-
np.transpose(b_dev.copy())[::b_step2, ::b_step1], borrow=True
2231+
np.transpose(b_dev).copy(order="C")[::b_step2, ::b_step1],
2232+
borrow=True,
22312233
)
22322234
c.set_value(c_dev.copy()[::c_step], borrow=True)
22332235

@@ -2244,6 +2246,7 @@ def test_gemv(self):
22442246
self.cmp_gemv(3, (3, 5), 5, rng)
22452247
self.cmp_gemv(1, (1, 5), 5, rng)
22462248
self.cmp_gemv(3, (3, 1), 1, rng)
2249+
self.cmp_gemv(1, (1, 1), 1, rng)
22472250
self.cmp_gemv(0, (0, 5), 5, rng)
22482251
self.cmp_gemv(3, (3, 0), 0, rng)
22492252
self.cmp_gemv(0, (0, 1), 1, rng)
@@ -2301,6 +2304,7 @@ def test_ger_strides(self):
23012304
self.cmp_ger((3, 5), 3, 5, rng)
23022305
self.cmp_ger((1, 5), 1, 5, rng)
23032306
self.cmp_ger((3, 1), 3, 1, rng)
2307+
self.cmp_ger((1, 1), 1, 1, rng)
23042308
self.cmp_ger((0, 5), 0, 5, rng)
23052309
self.cmp_ger((3, 0), 3, 0, rng)
23062310
self.cmp_ger((0, 1), 0, 1, rng)

tests/tensor/test_blas_c.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,7 @@ def test_gemv1(self):
243243
self.t_gemv1((0, 2))
244244
self.t_gemv1((3, 1))
245245
self.t_gemv1((3, 0))
246+
self.t_gemv1((1, 1))
246247
self.t_gemv1((1, 0))
247248
self.t_gemv1((0, 1))
248249
self.t_gemv1((0, 0))

0 commit comments

Comments
 (0)