diff --git a/pytensor/tensor/blas_c.py b/pytensor/tensor/blas_c.py index 6d1d830cf8..9513163027 100644 --- a/pytensor/tensor/blas_c.py +++ b/pytensor/tensor/blas_c.py @@ -344,6 +344,7 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non """ code = """ + bool is_float; int elemsize; float fbeta; double dbeta; @@ -361,11 +362,23 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non %(fail)s; } - if (PyArray_DESCR(%(y)s)->type_num == NPY_DOUBLE) { elemsize = 8; } - else if (PyArray_DESCR(%(y)s)->type_num == NPY_FLOAT) { elemsize = 4;} + if ((PyArray_DESCR(%(y)s)->type_num != PyArray_DESCR(%(x)s)->type_num) + || (PyArray_DESCR(%(y)s)->type_num != PyArray_DESCR(%(A)s)->type_num)) + { + PyErr_SetString(PyExc_TypeError, "GEMV: dtypes of A, x, y do not match"); + %(fail)s; + } + if (PyArray_DESCR(%(y)s)->type_num == NPY_DOUBLE) { + is_float = 0; + elemsize = 8; + } + else if (PyArray_DESCR(%(y)s)->type_num == NPY_FLOAT) { + elemsize = 4; + is_float = 1; + } else { - PyErr_SetString(PyExc_NotImplementedError, "complex Gemv"); %(fail)s; + PyErr_SetString(PyExc_NotImplementedError, "GEMV: Inputs must be float or double"); } fbeta = dbeta = ((dtype_%(beta)s*)PyArray_DATA(%(beta)s))[0]; @@ -408,37 +421,40 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non Py_INCREF(%(z)s); } } + { - char TRANS = 'T'; - char NOTRANS = 'N'; int NA0 = PyArray_DIMS(%(A)s)[0]; int NA1 = PyArray_DIMS(%(A)s)[1]; - /* This formula is needed in the case where A is actually a row or - * column matrix, because BLAS sometimes insists that the strides: - * - are not smaller than the number of elements in the array - * - are not 0. - */ - int SA0 = (NA0 > 1) ? (PyArray_STRIDES(%(A)s)[0] / elemsize) : (NA1 + 1); - int SA1 = (NA1 > 1) ? (PyArray_STRIDES(%(A)s)[1] / elemsize) : (NA0 + 1); - int Sz = PyArray_STRIDES(%(z)s)[0] / elemsize; - int Sx = PyArray_STRIDES(%(x)s)[0] / elemsize; - - dtype_%(A)s* A_data = (dtype_%(A)s*) PyArray_DATA(%(A)s); - dtype_%(x)s* x_data = (dtype_%(x)s*) PyArray_DATA(%(x)s); - dtype_%(z)s* z_data = (dtype_%(z)s*) PyArray_DATA(%(z)s); - // gemv expects pointers to the beginning of memory arrays, - // but numpy provides a pointer to the first element, - // so when the stride is negative, we need to get the last one. - if (Sx < 0) - x_data += (NA1 - 1) * Sx; - if (Sz < 0) - z_data += (NA0 - 1) * Sz; if (NA0 * NA1) { + // Non-empty A matrix + + /* In the case where A is actually a row or column matrix, + * the strides corresponding to the dummy dimension don't matter, + * but BLAS requires these to be no smaller than the number of elements in the array. + */ + int SA0 = (NA0 > 1) ? (PyArray_STRIDES(%(A)s)[0] / elemsize) : NA1; + int SA1 = (NA1 > 1) ? (PyArray_STRIDES(%(A)s)[1] / elemsize) : NA0; + int Sz = PyArray_STRIDES(%(z)s)[0] / elemsize; + int Sx = PyArray_STRIDES(%(x)s)[0] / elemsize; + + dtype_%(A)s* A_data = (dtype_%(A)s*) PyArray_DATA(%(A)s); + dtype_%(x)s* x_data = (dtype_%(x)s*) PyArray_DATA(%(x)s); + dtype_%(z)s* z_data = (dtype_%(z)s*) PyArray_DATA(%(z)s); + + // gemv expects pointers to the beginning of memory arrays, + // but numpy provides a pointer to the first element, + // so when the stride is negative, we need to get the last one. + if (Sx < 0) + x_data += (NA1 - 1) * Sx; + if (Sz < 0) + z_data += (NA0 - 1) * Sz; + if ( ((SA0 < 0) || (SA1 < 0)) && (abs(SA0) == 1 || (abs(SA1) == 1)) ) { // We can treat the array A as C-or F-contiguous by changing the order of iteration + // printf("GEMV: Iterating in reverse NA0=%%d, NA1=%%d, SA0=%%d, SA1=%%d\\n", NA0, NA1, SA0, SA1); if (SA0 < 0){ A_data += (NA0 -1) * SA0; // Jump to first row SA0 = -SA0; // Iterate over rows in reverse @@ -452,27 +468,45 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non } else if ((SA0 < 0) || (SA1 < 0) || ((SA0 != 1) && (SA1 != 1))) { // Array isn't contiguous, we have to make a copy - // - if the copy is too long, maybe call vector/vector dot on - // each row instead - // printf("GEMV: Making a copy SA0=%%d, SA1=%%d\\n", SA0, SA1); + // - if the copy is too long, maybe call vector/vector dot on each row instead + // printf("GEMV: Making a copy NA0=%%d, NA1=%%d, SA0=%%d, SA1=%%d\\n", NA0, NA1, SA0, SA1); npy_intp dims[2]; dims[0] = NA0; dims[1] = NA1; - - PyArrayObject * A_copy = (PyArrayObject *) PyArray_Copy( - %(A)s); + PyArrayObject * A_copy = (PyArrayObject *) PyArray_Copy(%(A)s); if (!A_copy) %(fail)s Py_XDECREF(%(A)s); %(A)s = A_copy; - SA0 = (NA0 > 1) ? (PyArray_STRIDES(%(A)s)[0] / elemsize) : (NA1 + 1); - SA1 = (NA1 > 1) ? (PyArray_STRIDES(%(A)s)[1] / elemsize) : (NA0 + 1); + SA0 = (NA0 > 1) ? (PyArray_STRIDES(%(A)s)[0] / elemsize) : NA1; + SA1 = (NA1 > 1) ? (PyArray_STRIDES(%(A)s)[1] / elemsize) : NA0; A_data = (dtype_%(A)s*) PyArray_DATA(%(A)s); } + //else {printf("GEMV: Using the original array NA0=%%d, NA1=%%d, SA0=%%d, SA1=%%d\\n", NA0, NA1, SA0, SA1);} - if (SA0 == 1) + if (NA0 == 1) + { + // Vector-vector dot product, it seems faster to avoid GEMV + dtype_%(alpha)s alpha = ((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0]; + + if (is_float) + { + z_data[0] *= fbeta; + z_data[0] += alpha * sdot_(&NA1, (float*)(A_data), &SA1, + (float*)x_data, &Sx); + } + else + { + z_data[0] *= dbeta; + z_data[0] += alpha * ddot_(&NA1, (double*)(A_data), &SA1, + (double*)x_data, &Sx); + } + } + else if (SA0 == 1) { - if (PyArray_DESCR(%(A)s)->type_num == NPY_FLOAT) + // F-contiguous + char NOTRANS = 'N'; + if (is_float) { float alpha = ((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0]; sgemv_(&NOTRANS, &NA0, &NA1, @@ -482,7 +516,7 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non &fbeta, (float*)z_data, &Sz); } - else if (PyArray_DESCR(%(A)s)->type_num == NPY_DOUBLE) + else { double alpha = ((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0]; dgemv_(&NOTRANS, &NA0, &NA1, @@ -492,97 +526,39 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non &dbeta, (double*)z_data, &Sz); } - else - { - PyErr_SetString(PyExc_AssertionError, - "neither float nor double dtype"); - %(fail)s - } } else if (SA1 == 1) { - if (PyArray_DESCR(%(A)s)->type_num == NPY_FLOAT) + // C-contiguous + char TRANS = 'T'; + if (is_float) { float alpha = ((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0]; - - // Check for vector-vector dot (NA0 == 1). The code may work - // for SA1 != 1 as well, but has not been tested for this case, - // so SA1 == 1 is required for safety. - if (NA0 == 1 && SA1 == 1) - { - if (fbeta != 0.f) { - z_data[0] = fbeta*z_data[0]; - } else { - z_data[0] = 0.f; - } - z_data[0] += alpha*sdot_(&NA1, - (float*)(A_data), &SA1, - (float*)x_data, &Sx); - } - else - { - sgemv_(&TRANS, &NA1, &NA0, - &alpha, - (float*)(A_data), &SA0, - (float*)x_data, &Sx, - &fbeta, - (float*)z_data, &Sz); - } - } - else if (PyArray_DESCR(%(A)s)->type_num == NPY_DOUBLE) - { - double alpha = ((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0]; - - // Check for vector-vector dot (NA0 == 1). The code may work - // for SA1 != 1 as well, but has not been tested for this case, - // so SA1 == 1 is required for safety. - if (NA0 == 1 && SA1 == 1) - { - if (dbeta != 0.) { - z_data[0] = dbeta*z_data[0]; - } else { - z_data[0] = 0.; - } - z_data[0] += alpha*ddot_(&NA1, - (double*)(A_data), &SA1, - (double*)x_data, &Sx); - } - else - { - dgemv_(&TRANS, &NA1, &NA0, - &alpha, - (double*)(A_data), &SA0, - (double*)x_data, &Sx, - &dbeta, - (double*)z_data, &Sz); - } + sgemv_(&TRANS, &NA1, &NA0, + &alpha, + (float*)(A_data), &SA0, + (float*)x_data, &Sx, + &fbeta, + (float*)z_data, &Sz); } else { - PyErr_SetString(PyExc_AssertionError, - "neither float nor double dtype"); - %(fail)s + double alpha = ((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0]; + dgemv_(&TRANS, &NA1, &NA0, + &alpha, + (double*)(A_data), &SA0, + (double*)x_data, &Sx, + &dbeta, + (double*)z_data, &Sz); } } else { PyErr_SetString(PyExc_AssertionError, - "xx is a double-strided matrix, and should have been " - "copied into a memory-contiguous one."); + "A is neither C nor F-contiguous, it should have been copied into a memory-contiguous array;"); %(fail)s } } - else if (dbeta != 1.0) - { - // the matrix has at least one dim of length 0 - // so we do this loop, which either iterates over 0 elements - // or else it does the right thing for length-0 A. - dtype_%(z)s * zptr = (dtype_%(z)s*)(PyArray_DATA(%(z)s)); - for (int i = 0; i < NA0; ++i) - { - zptr[i * Sz] = (dbeta == 0.0 ? 0.0 : zptr[i * Sz] * dbeta); - } - } } """ return code % locals() @@ -613,7 +589,7 @@ def c_code(self, node, name, inp, out, sub): return code def c_code_cache_version(self): - return (15, blas_header_version(), check_force_gemv_init()) + return (16, blas_header_version(), check_force_gemv_init()) cgemv_inplace = CGemv(inplace=True) diff --git a/tests/tensor/test_blas.py b/tests/tensor/test_blas.py index 37e2c380b9..f3fcf72cc5 100644 --- a/tests/tensor/test_blas.py +++ b/tests/tensor/test_blas.py @@ -2226,8 +2226,10 @@ def cmp_gemv(self, a_shp, b_shp, c_shp, rng): a.set_value(a_dev.copy()[::a_step], borrow=True) b.set_value(b_dev.copy()[::b_step1, ::b_step2], borrow=True) + # Copy as C so that it becomes F after the transpose in the graph b_t.set_value( - np.transpose(b_dev.copy())[::b_step2, ::b_step1], borrow=True + np.transpose(b_dev).copy(order="C")[::b_step2, ::b_step1], + borrow=True, ) c.set_value(c_dev.copy()[::c_step], borrow=True) @@ -2244,6 +2246,7 @@ def test_gemv(self): self.cmp_gemv(3, (3, 5), 5, rng) self.cmp_gemv(1, (1, 5), 5, rng) self.cmp_gemv(3, (3, 1), 1, rng) + self.cmp_gemv(1, (1, 1), 1, rng) self.cmp_gemv(0, (0, 5), 5, rng) self.cmp_gemv(3, (3, 0), 0, rng) self.cmp_gemv(0, (0, 1), 1, rng) @@ -2301,6 +2304,7 @@ def test_ger_strides(self): self.cmp_ger((3, 5), 3, 5, rng) self.cmp_ger((1, 5), 1, 5, rng) self.cmp_ger((3, 1), 3, 1, rng) + self.cmp_ger((1, 1), 1, 1, rng) self.cmp_ger((0, 5), 0, 5, rng) self.cmp_ger((3, 0), 3, 0, rng) self.cmp_ger((0, 1), 0, 1, rng) diff --git a/tests/tensor/test_blas_c.py b/tests/tensor/test_blas_c.py index 26747d2199..e46c036766 100644 --- a/tests/tensor/test_blas_c.py +++ b/tests/tensor/test_blas_c.py @@ -243,6 +243,7 @@ def test_gemv1(self): self.t_gemv1((0, 2)) self.t_gemv1((3, 1)) self.t_gemv1((3, 0)) + self.t_gemv1((1, 1)) self.t_gemv1((1, 0)) self.t_gemv1((0, 1)) self.t_gemv1((0, 0)) @@ -413,6 +414,32 @@ class TestBlasStridesC(TestBlasStrides): mode = mode_blas_opt +def test_gemv_vector_dot_perf(benchmark): + n = 400_000 + a = pt.vector("A", shape=(n,)) + b = pt.vector("x", shape=(n,)) + + out = CGemv(inplace=True)( + pt.empty((1,)), + 1.0, + a[None], + b, + 0.0, + ) + fn = pytensor.function([a, b], out, accept_inplace=True, trust_input=True) + + rng = np.random.default_rng(430) + test_a = rng.normal(size=n) + test_b = rng.normal(size=n) + + np.testing.assert_allclose( + fn(test_a, test_b), + np.dot(test_a, test_b), + ) + + benchmark(fn, test_a, test_b) + + @pytest.mark.parametrize( "neg_stride1", (True, False), ids=["neg_stride1", "pos_stride1"] )