diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py
index 4938ecc42f..845d6afc7a 100644
--- a/pytensor/link/numba/dispatch/basic.py
+++ b/pytensor/link/numba/dispatch/basic.py
@@ -31,7 +31,6 @@
     fgraph_to_python,
 )
 from pytensor.scalar.basic import ScalarType
-from pytensor.scalar.math import Softplus
 from pytensor.sparse import SparseTensorType
 from pytensor.tensor.basic import Nonzero
 from pytensor.tensor.blas import BatchedDot
@@ -466,7 +465,7 @@ def argort_vec(X, axis):
             axis = axis.item()
 
             Y = np.swapaxes(X, axis, 0)
-            result = np.empty_like(Y)
+            result = np.empty_like(Y, dtype="int64")
 
             indices = list(np.ndindex(Y.shape[1:]))
 
@@ -607,25 +606,6 @@ def dot(x, y):
     return dot
 
 
-@numba_funcify.register(Softplus)
-def numba_funcify_Softplus(op, node, **kwargs):
-    x_dtype = np.dtype(node.inputs[0].dtype)
-
-    @numba_njit
-    def softplus(x):
-        if x < -37.0:
-            value = np.exp(x)
-        elif x < 18.0:
-            value = np.log1p(np.exp(x))
-        elif x < 33.3:
-            value = x + np.exp(-x)
-        else:
-            value = x
-        return direct_cast(value, x_dtype)
-
-    return softplus
-
-
 @numba_funcify.register(Solve)
 def numba_funcify_Solve(op, node, **kwargs):
     assume_a = op.assume_a
@@ -689,11 +669,6 @@ def batched_dot(x, y):
     return batched_dot
 
 
-# NOTE: The remaining `pytensor.tensor.blas` `Op`s appear unnecessary, because
-# they're only used to optimize basic `Dot` nodes, and those GEMV and GEMM
-# optimizations are apparently already performed by Numba
-
-
 @numba_funcify.register(IfElse)
 def numba_funcify_IfElse(op, **kwargs):
     n_outs = op.n_outs
diff --git a/pytensor/link/numba/dispatch/elemwise.py b/pytensor/link/numba/dispatch/elemwise.py
index 9fd81dadcf..7244762b93 100644
--- a/pytensor/link/numba/dispatch/elemwise.py
+++ b/pytensor/link/numba/dispatch/elemwise.py
@@ -561,7 +561,7 @@ def numba_funcify_Argmax(op, node, **kwargs):
 
         @numba_basic.numba_njit(inline="always")
         def argmax(x):
-            return 0
+            return np.array(0, dtype="int64")
 
     else:
         axes = tuple(int(ax) for ax in axis)
diff --git a/pytensor/link/numba/dispatch/linalg/decomposition/lu.py b/pytensor/link/numba/dispatch/linalg/decomposition/lu.py
index 570c024b07..739f0a6990 100644
--- a/pytensor/link/numba/dispatch/linalg/decomposition/lu.py
+++ b/pytensor/link/numba/dispatch/linalg/decomposition/lu.py
@@ -30,7 +30,7 @@ def _lu_factor_to_lu(a, dtype, overwrite_a):
     # Fortran is 1 indexed, so we need to subtract 1 from the IPIV array
     IPIV = IPIV - 1
     p_inv = _pivot_to_permutation(IPIV, dtype=dtype)
-    perm = np.argsort(p_inv)
+    perm = np.argsort(p_inv).astype("int32")
 
     return perm, L, U
 
diff --git a/pytensor/link/numba/dispatch/nlinalg.py b/pytensor/link/numba/dispatch/nlinalg.py
index 860560d0a6..3271b5bd26 100644
--- a/pytensor/link/numba/dispatch/nlinalg.py
+++ b/pytensor/link/numba/dispatch/nlinalg.py
@@ -52,7 +52,7 @@ def numba_funcify_Det(op, node, **kwargs):
 
     @numba_basic.numba_njit(inline="always")
     def det(x):
-        return numba_basic.direct_cast(np.linalg.det(inputs_cast(x)), out_dtype)
+        return np.array(np.linalg.det(inputs_cast(x))).astype(out_dtype)
 
     return det
 
@@ -68,8 +68,8 @@ def numba_funcify_SLogDet(op, node, **kwargs):
     def slogdet(x):
         sign, det = np.linalg.slogdet(inputs_cast(x))
         return (
-            numba_basic.direct_cast(sign, out_dtype_1),
-            numba_basic.direct_cast(det, out_dtype_2),
+            np.array(sign).astype(out_dtype_1),
+            np.array(det).astype(out_dtype_2),
         )
 
     return slogdet
diff --git a/pytensor/link/numba/dispatch/scalar.py b/pytensor/link/numba/dispatch/scalar.py
index e9b637b00f..7e4703c8df 100644
--- a/pytensor/link/numba/dispatch/scalar.py
+++ b/pytensor/link/numba/dispatch/scalar.py
@@ -28,7 +28,7 @@
     Second,
     Switch,
 )
-from pytensor.scalar.math import Erf, Erfc, GammaLn, Log1mexp, Sigmoid
+from pytensor.scalar.math import Erf, Erfc, GammaLn, Log1mexp, Sigmoid, Softplus
 
 
 @numba_funcify.register(ScalarOp)
@@ -312,3 +312,22 @@ def erfc(x):
 @numba_funcify.register(Erfc)
 def numba_funcify_Erfc(op, **kwargs):
     return numba_basic.global_numba_func(erfc)
+
+
+@numba_funcify.register(Softplus)
+def numba_funcify_Softplus(op, node, **kwargs):
+    out_dtype = np.dtype(node.outputs[0].type.dtype)
+
+    @numba_basic.numba_njit
+    def softplus(x):
+        if x < -37.0:
+            value = np.exp(x)
+        elif x < 18.0:
+            value = np.log1p(np.exp(x))
+        elif x < 33.3:
+            value = x + np.exp(-x)
+        else:
+            value = x
+        return numba_basic.direct_cast(value, out_dtype)
+
+    return softplus
diff --git a/tests/link/numba/test_basic.py b/tests/link/numba/test_basic.py
index e8390b8ebf..9132d7b202 100644
--- a/tests/link/numba/test_basic.py
+++ b/tests/link/numba/test_basic.py
@@ -14,7 +14,6 @@
 numba = pytest.importorskip("numba")
 
 import pytensor.scalar as ps
-import pytensor.scalar.math as psm
 import pytensor.tensor as pt
 import pytensor.tensor.math as ptm
 from pytensor import config, shared
@@ -260,9 +259,12 @@ def compare_numba_and_py(
     if assert_fn is None:
 
         def assert_fn(x, y):
-            return np.testing.assert_allclose(x, y, rtol=1e-4) and compare_shape_dtype(
-                x, y
-            )
+            np.testing.assert_allclose(x, y, rtol=1e-4, strict=True)
+            # Make sure we don't have one input be a np.ndarray while the other is not
+            if isinstance(x, np.ndarray):
+                assert isinstance(y, np.ndarray), "y is not a NumPy array, but x is"
+            else:
+                assert not isinstance(y, np.ndarray), "y is a NumPy array, but x is not"
 
     if any(
         inp.owner is not None
@@ -295,8 +297,8 @@ def assert_fn(x, y):
     test_inputs_copy = (inp.copy() for inp in test_inputs) if inplace else test_inputs
     numba_res = pytensor_numba_fn(*test_inputs_copy)
     if isinstance(graph_outputs, tuple | list):
-        for j, p in zip(numba_res, py_res, strict=True):
-            assert_fn(j, p)
+        for numba_res_i, python_res_i in zip(numba_res, py_res, strict=True):
+            assert_fn(numba_res_i, python_res_i)
     else:
         assert_fn(numba_res, py_res)
 
@@ -640,48 +642,6 @@ def test_Dot(x, y, exc):
         )
 
 
-@pytest.mark.parametrize(
-    "x, exc",
-    [
-        (
-            (ps.float64(), np.array(0.0, dtype="float64")),
-            None,
-        ),
-        (
-            (ps.float64(), np.array(-32.0, dtype="float64")),
-            None,
-        ),
-        (
-            (ps.float64(), np.array(-40.0, dtype="float64")),
-            None,
-        ),
-        (
-            (ps.float64(), np.array(32.0, dtype="float64")),
-            None,
-        ),
-        (
-            (ps.float64(), np.array(40.0, dtype="float64")),
-            None,
-        ),
-        (
-            (ps.int64(), np.array(32, dtype="int64")),
-            None,
-        ),
-    ],
-)
-def test_Softplus(x, exc):
-    x, x_test_value = x
-    g = psm.Softplus(ps.upgrade_to_float)(x)
-
-    cm = contextlib.suppress() if exc is None else pytest.warns(exc)
-    with cm:
-        compare_numba_and_py(
-            [x],
-            [g],
-            [x_test_value],
-        )
-
-
 @pytest.mark.parametrize(
     "x, y, exc",
     [
diff --git a/tests/link/numba/test_nlinalg.py b/tests/link/numba/test_nlinalg.py
index 67bdc6f1a0..8d7c3a449c 100644
--- a/tests/link/numba/test_nlinalg.py
+++ b/tests/link/numba/test_nlinalg.py
@@ -11,68 +11,20 @@
 rng = np.random.default_rng(42849)
 
 
+@pytest.mark.parametrize("dtype", ("float64", "int64"))
 @pytest.mark.parametrize(
-    "x, exc",
-    [
-        (
-            (
-                pt.dmatrix(),
-                (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")),
-            ),
-            None,
-        ),
-        (
-            (
-                pt.lmatrix(),
-                (lambda x: x.T.dot(x))(rng.poisson(size=(3, 3)).astype("int64")),
-            ),
-            None,
-        ),
-    ],
+    "op", (nlinalg.Det(), nlinalg.SLogDet()), ids=["det", "slogdet"]
 )
-def test_Det(x, exc):
-    x, test_x = x
-    g = nlinalg.Det()(x)
+def test_Det_SLogDet(op, dtype):
+    x = pt.matrix(dtype=dtype)
 
-    cm = contextlib.suppress() if exc is None else pytest.warns(exc)
-    with cm:
-        compare_numba_and_py(
-            [x],
-            g,
-            [test_x],
-        )
+    rng = np.random.default_rng([50, sum(map(ord, dtype))])
+    x_ = rng.random(size=(3, 3)).astype(dtype)
+    test_x = x_.T.dot(x_)
 
+    g = op(x)
 
-@pytest.mark.parametrize(
-    "x, exc",
-    [
-        (
-            (
-                pt.dmatrix(),
-                (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")),
-            ),
-            None,
-        ),
-        (
-            (
-                pt.lmatrix(),
-                (lambda x: x.T.dot(x))(rng.poisson(size=(3, 3)).astype("int64")),
-            ),
-            None,
-        ),
-    ],
-)
-def test_SLogDet(x, exc):
-    x, test_x = x
-    g = nlinalg.SLogDet()(x)
-
-    cm = contextlib.suppress() if exc is None else pytest.warns(exc)
-    with cm:
-        compare_numba_and_py(
-            [x],
-            g,
-            [test_x],
-        )
+    compare_numba_and_py([x], g, [test_x])
 
 
 # We were seeing some weird results in CI where the following two almost
diff --git a/tests/link/numba/test_scalar.py b/tests/link/numba/test_scalar.py
index 504d2a163c..2125d7cc0e 100644
--- a/tests/link/numba/test_scalar.py
+++ b/tests/link/numba/test_scalar.py
@@ -3,12 +3,13 @@
 
 import pytensor.scalar as ps
 import pytensor.scalar.basic as psb
+import pytensor.scalar.math as psm
 import pytensor.tensor as pt
-from pytensor import config
+from pytensor import config, function
 from pytensor.scalar.basic import Composite
 from pytensor.tensor import tensor
 from pytensor.tensor.elemwise import Elemwise
-from tests.link.numba.test_basic import compare_numba_and_py
+from tests.link.numba.test_basic import compare_numba_and_py, numba_mode, py_mode
 
 
 rng = np.random.default_rng(42849)
@@ -99,7 +100,11 @@ def test_Composite(inputs, input_values, scalar_fn):
     "v, dtype",
     [
         ((pt.fscalar(), np.array(1.0, dtype="float32")), psb.float64),
-        ((pt.dscalar(), np.array(1.0, dtype="float64")), psb.float32),
+        pytest.param(
+            (pt.dscalar(), np.array(1.0, dtype="float64")),
+            psb.float32,
+            marks=pytest.mark.xfail(reason="Scalar downcasting not supported in numba"),
+        ),
     ],
 )
 def test_Cast(v, dtype):
@@ -145,3 +150,37 @@ def test_isnan(composite):
         [out],
         [np.array([1, 0], dtype="float64")],
     )
+
+
+@pytest.mark.parametrize(
+    "dtype",
+    [
+        pytest.param(
+            "float32",
+            marks=pytest.mark.xfail(reason="Scalar downcasting not supported in numba"),
+        ),
+        "float64",
+        pytest.param(
+            "int16",
+            marks=pytest.mark.xfail(reason="Scalar downcasting not supported in numba"),
+        ),
+        "int64",
+        "uint32",
+    ],
+)
+def test_Softplus(dtype):
+    x = ps.get_scalar_type(dtype)("x")
+    g = psm.softplus(x)
+
+    py_fn = function([x], g, mode=py_mode)
+    numba_fn = function([x], g, mode=numba_mode)
+    for value in (-40, -32, 0, 32, 40):
+        if value < 0 and dtype.startswith("u"):
+            continue
+        test_x = np.dtype(dtype).type(value)
+        np.testing.assert_allclose(
+            py_fn(test_x),
+            numba_fn(test_x),
+            strict=True,
+            err_msg=f"Failed for value {value}",
+        )