diff --git a/pytensor/link/numba/dispatch/__init__.py b/pytensor/link/numba/dispatch/__init__.py index 1fefb1d06d..3b15f93442 100644 --- a/pytensor/link/numba/dispatch/__init__.py +++ b/pytensor/link/numba/dispatch/__init__.py @@ -14,6 +14,6 @@ import pytensor.link.numba.dispatch.sparse import pytensor.link.numba.dispatch.subtensor import pytensor.link.numba.dispatch.tensor_basic - +import pytensor.link.numba.dispatch.blas # isort: on diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py index f6e62ae2f8..523bdcdff9 100644 --- a/pytensor/link/numba/dispatch/basic.py +++ b/pytensor/link/numba/dispatch/basic.py @@ -75,7 +75,7 @@ def numba_njit(*args, fastmath=None, **kwargs): message=( "(\x1b\\[1m)*" # ansi escape code for bold text "Cannot cache compiled function " - '"(numba_funcified_fgraph|store_core_outputs|cholesky|solve|solve_triangular|cho_solve|lu_factor)" ' + '"(numba_funcified_fgraph|store_core_outputs|cholesky|solve|solve_triangular|cho_solve|lu_factor|banded_dot)" ' "as it uses dynamic globals" ), category=NumbaWarning, diff --git a/pytensor/link/numba/dispatch/blas.py b/pytensor/link/numba/dispatch/blas.py new file mode 100644 index 0000000000..85d28c3cb4 --- /dev/null +++ b/pytensor/link/numba/dispatch/blas.py @@ -0,0 +1,59 @@ +from pytensor.link.numba.dispatch import numba_funcify +from pytensor.link.numba.dispatch.basic import numba_njit +from pytensor.link.numba.dispatch.linalg.dot.banded import _gbmv +from pytensor.link.numba.dispatch.linalg.dot.general import _matrix_vector_product +from pytensor.link.numba.dispatch.slinalg import _COMPLEX_DTYPE_NOT_SUPPORTED_MSG +from pytensor.tensor.blas import BandedGEMV, Gemv +from pytensor.tensor.type import complex_dtypes + + +@numba_funcify.register(Gemv) +def numba_funcify_Gemv(op, node, **kwargs): + """ + Function to handle the Gemv operation in Numba. 
+ """ + overwrite_y = op.inplace + + @numba_njit() + def numba_gemv(y, alpha, A, x, beta): + """ + Numba implementation of the Gemv operation. + """ + return _matrix_vector_product( + alpha=alpha, + A=A, + x=x, + beta=beta, + y=y, + overwrite_y=overwrite_y, + ) + + return numba_gemv + + +@numba_funcify.register(BandedGEMV) +def numba_funcify_BandedGEMV(op, node, **kwargs): + kl = op.lower_diags + ku = op.upper_diags + overwrite_y = op.overwrite_y + trans = int(op.transpose) + dtype = node.inputs[0].dtype + + if dtype in complex_dtypes: + raise NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op)) + + @numba_njit(cache=False) + def banded_gemv(A, x, y, alpha, beta): + return _gbmv( + A=A, + x=x, + kl=kl, + ku=ku, + y=y, + alpha=alpha, + beta=beta, + overwrite_y=overwrite_y, + trans=trans, + ) + + return banded_gemv diff --git a/pytensor/link/numba/dispatch/linalg/_BLAS.py b/pytensor/link/numba/dispatch/linalg/_BLAS.py new file mode 100644 index 0000000000..9d002c37e1 --- /dev/null +++ b/pytensor/link/numba/dispatch/linalg/_BLAS.py @@ -0,0 +1,93 @@ +import ctypes + +from numba.core.extending import get_cython_function_address +from numba.np.linalg import ensure_blas, ensure_lapack, get_blas_kind + +from pytensor.link.numba.dispatch.linalg._LAPACK import ( + _get_float_pointer_for_dtype, + _ptr_int, +) + + +def _get_blas_ptr_and_ptr_type(dtype, name): + d = get_blas_kind(dtype) + func_name = f"{d}{name}" + float_pointer = _get_float_pointer_for_dtype(d) + lapack_ptr = get_cython_function_address("scipy.linalg.cython_blas", func_name) + + return lapack_ptr, float_pointer + + +class _BLAS: + """ + Functions to return type signatures for wrapped BLAS functions. + + Here we are specifically concered with BLAS functions exposed by scipy, and not used by numpy. 
+ + Patterned after https://github.com/numba/numba/blob/bd7ebcfd4b850208b627a3f75d4706000be36275/numba/np/linalg.py#L74 + """ + + def __init__(self): + ensure_lapack() + ensure_blas() + + @classmethod + def numba_xgemv(cls, dtype): + """ + xGEMV performs one of the following matrix operations: + + y = alpha * A @ x + beta * y, or y = alpha * A.T @ x + beta * y + + Where alpha and beta are scalars, x and y are vectors, and A is a general matrix. + """ + + blas_ptr, float_pointer = _get_blas_ptr_and_ptr_type(dtype, "gemv") + + functype = ctypes.CFUNCTYPE( + None, + _ptr_int, # TRANS + _ptr_int, # M + _ptr_int, # N + float_pointer, # ALPHA + float_pointer, # A + _ptr_int, # LDA + float_pointer, # X + _ptr_int, # INCX + float_pointer, # BETA + float_pointer, # Y + _ptr_int, # INCY + ) + + return functype(blas_ptr) + + @classmethod + def numba_xgbmv(cls, dtype): + """ + xGBMV performs one of the following matrix operations: + + y = alpha * A @ x + beta * y, or y = alpha * A.T @ x + beta * y + + Where alpha and beta are scalars, x and y are vectors, and A is a band matrix with kl sub-diagonals and ku + super-diagonals. 
+ """ + + blas_ptr, float_pointer = _get_blas_ptr_and_ptr_type(dtype, "gbmv") + + functype = ctypes.CFUNCTYPE( + None, + _ptr_int, # TRANS + _ptr_int, # M + _ptr_int, # N + _ptr_int, # KL + _ptr_int, # KU + float_pointer, # ALPHA + float_pointer, # A + _ptr_int, # LDA + float_pointer, # X + _ptr_int, # INCX + float_pointer, # BETA + float_pointer, # Y + _ptr_int, # INCY + ) + + return functype(blas_ptr) diff --git a/pytensor/link/numba/dispatch/linalg/dot/__init__.py b/pytensor/link/numba/dispatch/linalg/dot/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/pytensor/link/numba/dispatch/linalg/dot/banded.py b/pytensor/link/numba/dispatch/linalg/dot/banded.py new file mode 100644 index 0000000000..4667b87417 --- /dev/null +++ b/pytensor/link/numba/dispatch/linalg/dot/banded.py @@ -0,0 +1,179 @@ +from collections.abc import Callable +from typing import Any + +import numpy as np +from numba import njit as numba_njit +from numba.core.extending import overload +from numba.np.linalg import ensure_blas, ensure_lapack +from scipy import linalg + +from pytensor.link.numba.dispatch.linalg._BLAS import _BLAS +from pytensor.link.numba.dispatch.linalg._LAPACK import ( + _get_underlying_float, + val_to_int_ptr, +) +from pytensor.link.numba.dispatch.linalg.utils import ( + _check_scipy_linalg_matrix, + _copy_to_fortran_order_even_if_1d, + _trans_char_to_int, +) + + +@numba_njit(inline="always") +def A_to_banded(A: np.ndarray, kl: int, ku: int) -> np.ndarray: + m, n = A.shape + + # This matrix is build backwards then transposed to get it into Fortran order + # (order="F" is not allowed in Numba land) + A_banded = np.zeros((n, kl + ku + 1), dtype=A.dtype).T + + for i, k in enumerate(range(ku, -kl - 1, -1)): + if k >= 0: + A_banded[i, k:] = np.diag(A, k=k) + else: + A_banded[i, : n + k] = np.diag(A, k=k) + + return A_banded + + +def _gbmv( + alpha: np.ndarray, + A: np.ndarray, + x: np.ndarray, + kl: int, + ku: int, + beta: np.ndarray | None = None, + y: 
np.ndarray | None = None, + overwrite_y: bool = False, + trans: int = 1, +) -> Any: + """ + Thin wrapper around gmbv. This code will only be called if njit is disabled globally + (e.g. during testing) + """ + (fn,) = linalg.get_blas_funcs(("gbmv",), (A, x)) + m, n = A.shape + A_banded = A_to_banded(A, kl=kl, ku=ku) + + incx = x.strides[0] // x.itemsize + offx = 0 if incx >= 0 else -x.size + 1 + + if y is not None: + incy = y.strides[0] // y.itemsize + offy = 0 if incy >= 0 else -y.size + 1 + else: + incy = 1 + offy = 0 + + return fn( + m=m, + n=n, + kl=kl, + ku=ku, + a=A_banded, + alpha=alpha, + x=x, + incx=incx, + offx=offx, + beta=beta, + y=y, + overwrite_y=overwrite_y, + incy=incy, + offy=offy, + trans=trans, + ) + + +@overload(_gbmv) +def gbmv_impl( + alpha: np.ndarray, + A: np.ndarray, + x: np.ndarray, + kl: int, + ku: int, + beta: np.ndarray | None = None, + y: np.ndarray | None = None, + overwrite_y: bool = False, + trans: int = 1, +) -> Callable[ + [ + np.ndarray, + np.ndarray, + np.ndarray, + int, + int, + np.ndarray | None, + np.ndarray | None, + bool, + int, + ], + np.ndarray, +]: + ensure_lapack() + ensure_blas() + _check_scipy_linalg_matrix(A, "dot_banded") + dtype = A.dtype + w_type = _get_underlying_float(dtype) + numba_gbmv = _BLAS().numba_xgbmv(dtype) + + def impl( + alpha: np.ndarray, + A: np.ndarray, + x: np.ndarray, + kl: int, + ku: int, + beta: np.ndarray | None = None, + y: np.ndarray | None = None, + overwrite_y: bool = False, + trans: int = 1, + ) -> np.ndarray: + m, n = A.shape + + A_banded = A_to_banded(A, kl=kl, ku=ku) + x_stride = x.strides[0] // x.itemsize + + if beta is None: + beta = np.zeros((), dtype=dtype) + + if y is None: + y_copy = np.empty(shape=(m,), dtype=dtype) + elif overwrite_y and y.flags.f_contiguous: + y_copy = y + else: + y_copy = _copy_to_fortran_order_even_if_1d(y) + + y_stride = y_copy.strides[0] // y_copy.itemsize + + TRANS = val_to_int_ptr(_trans_char_to_int(trans)) + M = val_to_int_ptr(m) + N = val_to_int_ptr(n) 
+ LDA = val_to_int_ptr(A_banded.shape[0]) + + KL = val_to_int_ptr(kl) + KU = val_to_int_ptr(ku) + + INCX = val_to_int_ptr(x_stride) + INCY = val_to_int_ptr(y_stride) + + numba_gbmv( + TRANS, + M, + N, + KL, + KU, + alpha.view(w_type).ctypes, + A_banded.view(w_type).ctypes, + LDA, + # x.view().ctypes is creating a pointer to the beginning of the memory where the array is. When we have + # a negative stride, we need to trick BLAS by pointing to the last element of the array. + # The [-1:] slice is a workaround to make sure x remains an array (otherwise it has no .ctypes) + (x if x_stride >= 0 else x[-1:]).view(w_type).ctypes, + INCX, + beta.view(w_type).ctypes, + y_copy.view(w_type).ctypes, + INCY, + ) + + return y_copy + + return impl diff --git a/pytensor/link/numba/dispatch/linalg/dot/general.py b/pytensor/link/numba/dispatch/linalg/dot/general.py new file mode 100644 index 0000000000..9ebe9408e3 --- /dev/null +++ b/pytensor/link/numba/dispatch/linalg/dot/general.py @@ -0,0 +1,146 @@ +from collections.abc import Callable +from typing import cast as type_cast + +import numpy as np +from numba.core.extending import overload +from numba.np.linalg import ensure_blas, ensure_lapack +from scipy import linalg + +from pytensor.link.numba.dispatch.linalg._BLAS import _BLAS +from pytensor.link.numba.dispatch.linalg._LAPACK import ( + _get_underlying_float, + val_to_int_ptr, +) +from pytensor.link.numba.dispatch.linalg.utils import ( + _check_scipy_linalg_matrix, + _copy_to_fortran_order_even_if_1d, + _trans_char_to_int, +) + + +def _matrix_vector_product( + alpha: np.ndarray, + A: np.ndarray, + x: np.ndarray, + beta: np.ndarray | None = None, + y: np.ndarray | None = None, + overwrite_y: bool = False, + trans: int = 1, +) -> np.ndarray: + """ + Thin wrapper around gmv. This code will only be called if njit is disabled globally + (e.g. 
during testing) + """ + (fn,) = linalg.get_blas_funcs(("gemv",), (A, x)) + + incx = x.strides[0] // x.itemsize + offx = 0 if incx >= 0 else -x.size + 1 + + if y is not None: + incy = y.strides[0] // y.itemsize if y is not None else 1 + offy = 0 if incy >= 0 else -y.size + 1 + else: + incy = 1 + offy = 0 + + res = fn( + alpha=alpha, + a=A, + x=x, + beta=beta, + y=y, + overwrite_y=overwrite_y, + offx=offx, + incx=incx, + offy=offy, + incy=incy, + trans=trans, + ) + + return type_cast(np.ndarray, res) + + +@overload(_matrix_vector_product) +def matrix_vector_product_impl( + alpha: np.ndarray, + A: np.ndarray, + x: np.ndarray, + beta: np.ndarray | None = None, + y: np.ndarray | None = None, + overwrite_y: bool = False, + trans: int = 1, +) -> Callable[ + [ + np.ndarray, + np.ndarray, + np.ndarray, + np.ndarray | None, + np.ndarray | None, + bool, + int, + ], + np.ndarray, +]: + ensure_lapack() + ensure_blas() + _check_scipy_linalg_matrix(A, "matrix_vector_product") + dtype = A.dtype + w_type = _get_underlying_float(dtype) + numba_gemv = _BLAS().numba_xgemv(dtype) + + def impl( + alpha: np.ndarray, + A: np.ndarray, + x: np.ndarray, + beta: np.ndarray | None = None, + y: np.ndarray | None = None, + overwrite_y: bool = False, + trans: int = 1, + ) -> np.ndarray: + m, n = A.shape + x_stride = x.strides[0] // x.itemsize + + if beta is None: + beta = np.zeros((), dtype=dtype) + + if y is None: + y_copy = np.empty(shape=(m,), dtype=dtype) + elif overwrite_y and y.flags.f_contiguous: + y_copy = y + else: + y_copy = _copy_to_fortran_order_even_if_1d(y) + + y_stride = y_copy.strides[0] // y_copy.itemsize + + TRANS = val_to_int_ptr(_trans_char_to_int(trans)) + M = val_to_int_ptr(m) + N = val_to_int_ptr(n) + LDA = val_to_int_ptr(A.shape[0]) + + # ALPHA = np.array(alpha, dtype=dtype) + + INCX = val_to_int_ptr(x_stride) + # BETA = np.array(beta, dtype=dtype) + INCY = val_to_int_ptr(y_stride) + + numba_gemv( + TRANS, + M, + N, + alpha.view(w_type).ctypes, + A.view(w_type).ctypes, + 
LDA, + # x.view().ctypes is creating a pointer to the beginning of the memory where the array is. When we have + # a negative stride, we need to trick BLAS by pointing to the last element of the array. + # The [-1:] slice is a workaround to make sure x remains an array (otherwise it has no .ctypes) + (x if x_stride >= 0 else x[-1:]).view(w_type).ctypes, + INCX, + beta.view(w_type).ctypes, + y_copy.view(w_type).ctypes, + # (y_copy if y_stride >= 0 else y_copy[:-1:]).view(w_type).ctypes, + INCY, + ) + + return y_copy + + return impl diff --git a/pytensor/tensor/blas.py b/pytensor/tensor/blas.py index fc8afcea50..7f702be745 100644 --- a/pytensor/tensor/blas.py +++ b/pytensor/tensor/blas.py @@ -83,9 +83,15 @@ from pathlib import Path import numpy as np +from numpy import zeros +from scipy import linalg as scipy_linalg -from pytensor.graph import vectorize_graph +import pytensor +from pytensor import tensor as pt +from pytensor.graph import Op, vectorize_graph from pytensor.npy_2_compat import normalize_axis_tuple +from pytensor.tensor import TensorLike, as_tensor_variable +from pytensor.tensor.blockwise import Blockwise try: @@ -97,13 +103,12 @@ import pytensor.scalar from pytensor.configdefaults import config from pytensor.graph.basic import Apply, view_roots -from pytensor.graph.op import Op from pytensor.graph.utils import InconsistencyError, MethodNotDefined, TestValueError from pytensor.link.c.op import COp from pytensor.link.c.params_type import ParamsType from pytensor.printing import FunctionPrinter, pprint from pytensor.scalar import bool as bool_t -from pytensor.tensor.basic import as_tensor_variable, cast +from pytensor.tensor.basic import cast from pytensor.tensor.blas_headers import blas_header_text, blas_header_version from pytensor.tensor.math import dot, tensordot from pytensor.tensor.shape import specify_broadcastable @@ -1788,3 +1793,172 @@ def batched_tensordot(x, y, axes=2): core_tensordot = tensordot(core_x, core_y, axes=core_axes) return 
vectorize_graph(core_tensordot, replace={core_x: x, core_y: y})
+
+
+class BandedGEMV(Op):
+    """
+    Banded matrix-vector product ``alpha * op(A) @ x + beta * y`` computed with BLAS xGBMV.
+
+    ``op(A)`` is ``A``, or ``A.T`` when ``transpose`` is True. ``A`` is passed as a dense 2D matrix; only its main
+    diagonal, ``lower_diags`` sub-diagonals and ``upper_diags`` super-diagonals are read, everything else is ignored.
+
+    The inputs of a node are, in order, ``(A, x, y, alpha, beta)``. When ``overwrite_y`` is True the Op is allowed
+    to reuse the memory of ``y`` for its output.
+    """
+
+    __props__ = ("lower_diags", "upper_diags", "transpose", "overwrite_y")
+    # Core signature of the non-transposed product: y and the output both have one entry per row of A.
+    # Instances created with transpose=True replace it in __init__.
+    gufunc_signature = "(m,n),(n),(m),(),()->(m)"
+
+    def __init__(
+        self,
+        lower_diags: int,
+        upper_diags: int,
+        transpose: bool = False,
+        overwrite_y: bool = False,
+    ):
+        self.lower_diags = lower_diags
+        self.upper_diags = upper_diags
+        self.overwrite_y = overwrite_y
+        self.transpose = transpose
+
+        if self.transpose:
+            # A.T @ x consumes a vector of length m and produces one of length n
+            self.gufunc_signature = "(m,n),(m),(n),(),()->(n)"
+
+        # Output 0 may alias input 2 (y) when overwriting is enabled
+        self.destroy_map = {0: [2]} if self.overwrite_y else {}
+
+    def make_node(self, A, x, y, alpha, beta):
+        """
+        Build the Apply node for ``alpha * op(A) @ x + beta * y``.
+
+        Raises
+        ------
+        TypeError
+            If A is not 2D, or if x or y is not 1D.
+        """
+        # Convert first so that array-likes (lists, numpy arrays, scalars) are accepted as well as variables
+        A = as_tensor_variable(A)
+        x = as_tensor_variable(x)
+        y = as_tensor_variable(y)
+        alpha = as_tensor_variable(alpha)
+        beta = as_tensor_variable(beta)
+
+        if A.ndim != 2:
+            raise TypeError("A must be a 2D tensor")
+        if x.ndim != 1:
+            raise TypeError("x must be a 1D tensor")
+        if y.ndim != 1:
+            raise TypeError("y must be a 1D tensor")
+
+        out_dtype = pytensor.scalar.upcast(A.dtype, x.dtype)
+        # The output has one entry per row of op(A), which differs from len(x) when A is not square
+        out_length = A.type.shape[1 if self.transpose else 0]
+        output = x.type.clone(dtype=out_dtype, shape=(out_length,))()
+
+        return pytensor.graph.basic.Apply(self, [A, x, y, alpha, beta], [output])
+
+    def infer_shape(self, fgraph, nodes, shapes):
+        """Return the output shape: the number of rows of op(A)."""
+        # shapes holds one entry per input (A, x, y, alpha, beta); only A determines the output length
+        A_shape = shapes[0]
+        return [(A_shape[1 if self.transpose else 0],)]
+
+    def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> Op:
+        """Return a copy of this Op that overwrites y if input 2 may be destroyed, otherwise self."""
+        if 2 in allowed_inplace_inputs:
+            new_props = self._props_dict()  # type: ignore
+            new_props["overwrite_y"] = True
+            return type(self)(**new_props)
+        else:
+            return self
+
+    def perform(self, node, inputs, outputs_storage):
+        """Evaluate the product with scipy's gbmv wrapper and store the result in ``outputs_storage[0][0]``."""
+        A, x, y, alpha, beta = inputs
+        m, n = A.shape
+
+        kl = self.lower_diags
+        ku = self.upper_diags
+
+        # Use the declared output dtype so that the value returned by BLAS matches the node's output type
+        out_dtype = node.outputs[0].type.dtype
+
+        # Hand BLAS unit-stride vectors. The scipy wrapper copies non-contiguous arrays and rejects negative
+        # offsets, so forwarding numpy strides as incx/incy is not reliable for strided or reversed inputs.
+        x = np.ascontiguousarray(x)
+        y = np.ascontiguousarray(y)
+
+        # Pack A into LAPACK band storage: A[i, j] is stored at A_banded[ku + i - j, j]
+        A_banded = np.zeros((kl + ku + 1, n), dtype=out_dtype, order="F")
+
+        for row, k in enumerate(range(ku, -kl - 1, -1)):
+            diag = np.diag(A, k=k)
+            # Diagonal k starts at column max(k, 0); slicing by its length also covers non-square A
+            start = max(k, 0)
+            A_banded[row, start : start + diag.size] = diag
+
+        (fn,) = scipy_linalg.get_blas_funcs(("gbmv",), dtype=out_dtype)
+        outputs_storage[0][0] = fn(
+            m=m,
+            n=n,
+            kl=kl,
+            ku=ku,
+            a=A_banded,
+            alpha=alpha,
+            x=x,
+            beta=beta,
+            y=y,
+            overwrite_y=self.overwrite_y,
+            trans=int(self.transpose),
+        )
+
+    def L_op(self, inputs, outputs, output_grads):
+        """
+        Gradients of ``alpha * op(A) @ x + beta * y`` with respect to ``(A, x, y, alpha, beta)``.
+
+        This is the usual gradient of a matrix-vector product, with the products against A computed by banded
+        Ops. The gradient with respect to A is returned dense; entries outside the band are not masked.
+        """
+        A, x, y, alpha, beta = inputs
+        (G_bar,) = output_grads
+
+        # Gradient graphs must not destroy their inputs, so never reuse an inplace instance here
+        forward = type(self)(
+            self.lower_diags,
+            self.upper_diags,
+            transpose=self.transpose,
+            overwrite_y=False,
+        )
+        # Transposing A swaps the roles of its sub- and super-diagonals
+        adjoint = type(self)(
+            self.upper_diags,
+            self.lower_diags,
+            transpose=self.transpose,
+            overwrite_y=False,
+        )
+
+        one = pt.ones_like(alpha)
+        zero = pt.zeros_like(beta)
+
+        # d/dA = alpha * outer(G, x), transposed when the forward pass used A.T
+        outer = pt.outer(x, G_bar) if self.transpose else pt.outer(G_bar, x)
+        A_bar = alpha * outer
+
+        # d/dx = alpha * op(A).T @ G
+        x_bar = adjoint(A.T, G_bar, pt.zeros_like(x), alpha, zero)
+
+        # d/dy = beta * G and d/dbeta = <G, y>
+        y_bar = beta * G_bar
+        beta_bar = pt.sum(G_bar * y)
+
+        # d/dalpha = <G, op(A) @ x>
+        Ax = forward(A, x, pt.zeros_like(G_bar), one, zero)
+        alpha_bar = pt.sum(G_bar * Ax)
+
+        return [A_bar, x_bar, y_bar, alpha_bar, beta_bar]
+
+
+def banded_gemv(
+    A: TensorLike,
+    x: TensorLike,
+    lower_diags: int,
+    upper_diags: int,
+    y: TensorLike | None = None,
+    alpha: TensorLike | None = None,
+    beta: TensorLike | None = None,
+):
+    """
+    Specialized matrix-vector multiplication for cases when A is a banded matrix.
+
+    In BLAS, matrix-vector multiplication is done by the GEMV family of routines, and computes alpha * A @ x + beta * y
+
+    Unlike other dot functions in Pytensor, banded_gemv uses a low-level API. No rewrites (yet!) exist to try to infer
+    the values of alpha, beta, or y from a compute graph surrounding a Dot22 Op. To get the most out of this Op, the
+    user will thus have to explicitly declare each argument.
+
+    In addition, no type-checking is done on A at runtime, so all data in A off the banded diagonals will be ignored.
+    This will lead to incorrect results if A is not actually a banded matrix.
+
+    Parameters
+    ----------
+    A: TensorLike
+        Matrix to perform banded dot on.
+    x: TensorLike
+        Vector to perform banded dot on.
+    lower_diags: int
+        Number of nonzero lower diagonals of A
+    upper_diags: int
+        Number of nonzero upper diagonals of A
+    y: TensorLike, optional
+        Vector to be added into the dot-product A @ x. This is often called a "rank-one update" term. If not provided,
+        no rank-one update is performed on the dot product. Ignored if beta is zero.
+ alpha: TensorLike, optional + Scalar factor multiplying the dot-product A @ x. Default is 1.0 + beta: TensorLike, optional + Scalar factor multiplying the rank-one update vector y. Ignored if y is None. Default is 0.0 + + Returns + ------- + out: Tensor + The matrix multiplication result + """ + A = as_tensor_variable(A) + x = as_tensor_variable(x) + + if alpha is None: + alpha = pt.ones((), dtype=A.type.dtype) + else: + alpha = as_tensor_variable(alpha) + + if beta is None: + beta = pt.zeros((), dtype=A.type.dtype) + else: + beta = as_tensor_variable(beta) + + if y is None: + y = pt.empty(A.shape[:-1], dtype=A.type.dtype) + else: + y = as_tensor_variable(y) + + return Blockwise(BandedGEMV(lower_diags, upper_diags, overwrite_y=False))( + A, x, y, alpha, beta + ) diff --git a/tests/link/numba/test_blas.py b/tests/link/numba/test_blas.py new file mode 100644 index 0000000000..8fe546fc15 --- /dev/null +++ b/tests/link/numba/test_blas.py @@ -0,0 +1,135 @@ +import itertools + +import numpy as np +import pytest + +import pytensor.tensor as pt +from pytensor import In, Mode, config, function +from pytensor.compile import get_mode +from pytensor.graph import RewriteDatabaseQuery +from pytensor.link.numba import NumbaLinker +from pytensor.tensor.blas import Gemv, banded_gemv +from tests.link.numba.test_basic import compare_numba_and_py, numba_inplace_mode +from tests.tensor.test_slinalg import _make_banded_A + + +numba_blas_mode = Mode( + NumbaLinker(), + RewriteDatabaseQuery( + include=["fast_run", "numba", "BlasOpt"], + exclude=[ + "cxx_only", + "c_blas", + "local_careduce_fusion", + "scan_save_mem_prealloc", + ], + ), +) + + +def test_banded_dot(): + rng = np.random.default_rng() + + A = pt.tensor("A", shape=(10, 10), dtype=config.floatX) + A_val = _make_banded_A(rng.normal(size=(10, 10)), kl=1, ku=1).astype(config.floatX) + + x = pt.tensor("x", shape=(10,), dtype=config.floatX) + x_val = rng.normal(size=(10,)).astype(config.floatX) + + output = banded_gemv(A, x, 
upper_diags=1, lower_diags=1) + + fn, _ = compare_numba_and_py( + [A, x], + output, + test_inputs=[A_val, x_val], + numba_mode=numba_inplace_mode, + eval_obj_mode=False, + ) + + for stride in [2, -1, -2]: + x_shape = (10 * abs(stride),) + x_val = rng.normal(size=x_shape).astype(config.floatX) + x_val = x_val[::stride] + + nb_output = fn(A_val, x_val) + expected = A_val @ x_val + + np.testing.assert_allclose( + nb_output, + expected, + strict=True, + err_msg=f"Test failed for stride = {stride}", + ) + + +def test_numba_gemv(): + rng = np.random.default_rng() + A = pt.tensor("A", shape=(10, 10)) + x = pt.tensor("x", shape=(10,)) + y = pt.tensor("y", shape=(10,)) + alpha, beta = pt.dscalars("alpha", "beta") + + output = alpha * A @ x + beta * y + + A_val = rng.normal(size=(10, 10)).astype(config.floatX) + x_val = rng.normal(size=(10,)).astype(config.floatX) + y_val = rng.normal(size=(10,)).astype(config.floatX) + alpha_val, beta_val = rng.normal(size=(2,)).astype(config.floatX) + + fn, _ = compare_numba_and_py( + [A, x, y, alpha, beta], + output, + test_inputs=[A_val, x_val, y_val, alpha_val, beta_val], + numba_mode=numba_blas_mode, + eval_obj_mode=False, + ) + assert any(isinstance(node.op, Gemv) for node in fn.maker.fgraph.toposort()) + + for stride, matrix in itertools.product([2, -1, -2], ["x", "y"]): + shape = (10 * abs(stride),) + val = rng.normal(size=shape).astype(config.floatX) + + if matrix == "x": + x_val = val[::stride] + else: + y_val = val[::stride] + + nb_output = fn(A_val, x_val, y_val, alpha_val, beta_val) + expected = alpha_val * A_val @ x_val + beta_val * y_val + + np.testing.assert_allclose( + nb_output, + expected, + strict=True, + err_msg=f"Test failed for stride = {stride}", + ) + + +@pytest.mark.parametrize("size", [10, 100, 1000], ids=str) +@pytest.mark.parametrize("use_blas_gemv", [True, False], ids=["numba+blas", "numba"]) +def test_numba_gemv_benchmark(size, use_blas_gemv, benchmark): + rng = np.random.default_rng() + mode = numba_blas_mode 
if use_blas_gemv else get_mode("NUMBA") + + A = pt.tensor("A", shape=(None, None)) + x = pt.tensor("x", shape=(None,)) + y = pt.tensor("y", shape=(None,)) + alpha, beta = pt.dscalars("alpha", "beta") + + out = alpha * (A @ x) + beta * y + fn = function([A, x, In(y, mutable=True), alpha, beta], out, mode=mode) + + if use_blas_gemv: + assert any(isinstance(node.op, Gemv) for node in fn.maker.fgraph.toposort()) + else: + assert not any(isinstance(node.op, Gemv) for node in fn.maker.fgraph.toposort()) + + A_val = rng.normal(size=(size, size)).astype(config.floatX) + x_val = rng.normal(size=(size,)).astype(config.floatX) + y_val = rng.normal(size=(size,)).astype(config.floatX) + alpha_val, beta_val = rng.normal(size=(2,)).astype(config.floatX) + + res = fn(A=A_val, x=x_val, y=y_val, alpha=alpha_val, beta=beta_val) + np.testing.assert_allclose(res, y_val) + + benchmark(fn, A=A_val, x=x_val, y=y_val, alpha=alpha_val, beta=beta_val) diff --git a/tests/tensor/test_slinalg.py b/tests/tensor/test_slinalg.py index f18f514244..c0fba78f8e 100644 --- a/tests/tensor/test_slinalg.py +++ b/tests/tensor/test_slinalg.py @@ -11,6 +11,7 @@ from pytensor.configdefaults import config from pytensor.graph.basic import equal_computations from pytensor.tensor import TensorVariable +from pytensor.tensor.blas import BandedGEMV, banded_gemv from pytensor.tensor.slinalg import ( Cholesky, CholeskySolve, @@ -1051,3 +1052,69 @@ def test_block_diagonal_blockwise(): B = np.random.normal(size=(1, batch_size, 4, 4)).astype(config.floatX) result = block_diag(A, B).eval() assert result.shape == (10, batch_size, 6, 6) + + +def _make_banded_A(A, kl, ku): + diag_idxs = range(-kl, ku + 1) + diags = (np.diag(A, k=k) for k in diag_idxs) + return sum(np.diag(d, k=k) for k, d in zip(diag_idxs, diags)) + + +@pytest.mark.parametrize( + "kl, ku", [(1, 1), (0, 1), (2, 2)], ids=["tridiag", "upper-only", "banded"] +) +@pytest.mark.parametrize("stride", [1, 2, -1], ids=lambda x: f"stride={x}") +def test_banded_dot(kl, 
ku, stride): + rng = np.random.default_rng() + + size = 10 + + A_val = _make_banded_A(rng.normal(size=(size, size)), kl=kl, ku=ku).astype( + config.floatX + ) + x_val = rng.normal(size=(size * abs(stride),)).astype(config.floatX) + x_val = x_val[::stride] + + A = pt.tensor("A", shape=A_val.shape, dtype=A_val.dtype) + x = pt.tensor("x", shape=x_val.shape, dtype=x_val.dtype) + res = banded_gemv(A, x, kl, ku) + res_2 = A @ x + + fn = function([A, x], [res, res_2], trust_input=True) + assert any(isinstance(node.op, BandedGEMV) for node in fn.maker.fgraph.apply_nodes) + + out_val, out_2_val = fn(A_val, x_val) + + atol = 1e-4 if config.floatX == "float32" else 1e-8 + rtol = 1e-4 if config.floatX == "float32" else 1e-8 + + np.testing.assert_allclose(out_val, out_2_val, atol=atol, rtol=rtol) + + +def test_banded_dot_grad(): + rng = np.random.default_rng() + size = 10 + + A_val = _make_banded_A(rng.normal(size=(size, size)), kl=1, ku=1).astype( + config.floatX + ) + x_val = rng.normal(size=(size,)).astype(config.floatX) + + def make_banded_pt(A): + # Like structured solve Ops, we have to incldue the transformation from an unconstrained matrix A to a banded + # matrix on the compute graph. Otherwise, the random perturbations used by verify_grad will result in invalid + # input matrices. + + diag_idxs = range(-1, 2) + diags = (pt.diag(A, k=k) for k in diag_idxs) + return sum(pt.diag(d, k=k) for k, d in zip(diag_idxs, diags)) + + def test_fn(A, x): + return banded_gemv(make_banded_pt(A), x, lower_diags=1, upper_diags=1).sum() + + utt.verify_grad( + test_fn, + [A_val, x_val], + rng=rng, + eps=1e-4 if config.floatX == "float32" else 1e-8, + )