Skip to content

Commit 1657d4a

Browse files
committed
Fix numba dispatch of Det and SlogDet returning non-arrays
1 parent 5ffe17a commit 1657d4a

File tree

3 files changed

+18
-65
lines changed

3 files changed

+18
-65
lines changed

pytensor/link/numba/dispatch/nlinalg.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def numba_funcify_Det(op, node, **kwargs):
5252

5353
@numba_basic.numba_njit(inline="always")
5454
def det(x):
55-
return numba_basic.direct_cast(np.linalg.det(inputs_cast(x)), out_dtype)
55+
return np.array(np.linalg.det(inputs_cast(x))).astype(out_dtype)
5656

5757
return det
5858

@@ -68,8 +68,8 @@ def numba_funcify_SLogDet(op, node, **kwargs):
6868
def slogdet(x):
6969
sign, det = np.linalg.slogdet(inputs_cast(x))
7070
return (
71-
numba_basic.direct_cast(sign, out_dtype_1),
72-
numba_basic.direct_cast(det, out_dtype_2),
71+
np.array(sign).astype(out_dtype_1),
72+
np.array(det).astype(out_dtype_2),
7373
)
7474

7575
return slogdet

tests/link/numba/test_basic.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -260,9 +260,12 @@ def compare_numba_and_py(
260260
if assert_fn is None:
261261

262262
def assert_fn(x, y):
263-
return np.testing.assert_allclose(x, y, rtol=1e-4) and compare_shape_dtype(
264-
x, y
265-
)
263+
np.testing.assert_allclose(x, y, rtol=1e-4, strict=True)
264+
# Make sure we don't have one input be a np.ndarray while the other is not
265+
if isinstance(x, np.ndarray):
266+
assert isinstance(y, np.ndarray), "y is not a NumPy array, but x is"
267+
else:
268+
assert not isinstance(y, np.ndarray), "y is a NumPy array, but x is not"
266269

267270
if any(
268271
inp.owner is not None

tests/link/numba/test_nlinalg.py

Lines changed: 9 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -11,68 +11,18 @@
1111
rng = np.random.default_rng(42849)
1212

1313

14-
@pytest.mark.parametrize(
15-
"x, exc",
16-
[
17-
(
18-
(
19-
pt.dmatrix(),
20-
(lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")),
21-
),
22-
None,
23-
),
24-
(
25-
(
26-
pt.lmatrix(),
27-
(lambda x: x.T.dot(x))(rng.poisson(size=(3, 3)).astype("int64")),
28-
),
29-
None,
30-
),
31-
],
32-
)
33-
def test_Det(x, exc):
34-
x, test_x = x
35-
g = nlinalg.Det()(x)
14+
@pytest.mark.parametrize("dtype", ("float64", "int64"))
15+
@pytest.mark.parametrize("op", (nlinalg.Det(), nlinalg.SLogDet()))
16+
def test_Det_SLogDet(op, dtype):
17+
x = pt.matrix(dtype=dtype)
3618

37-
cm = contextlib.suppress() if exc is None else pytest.warns(exc)
38-
with cm:
39-
compare_numba_and_py(
40-
[x],
41-
g,
42-
[test_x],
43-
)
19+
rng = np.random.default_rng([50, sum(map(ord, dtype))])
20+
x_ = rng.random(size=(3, 3)).astype(dtype)
21+
test_x = x_.T.dot(x_)
4422

23+
g = op(x)
4524

46-
@pytest.mark.parametrize(
47-
"x, exc",
48-
[
49-
(
50-
(
51-
pt.dmatrix(),
52-
(lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")),
53-
),
54-
None,
55-
),
56-
(
57-
(
58-
pt.lmatrix(),
59-
(lambda x: x.T.dot(x))(rng.poisson(size=(3, 3)).astype("int64")),
60-
),
61-
None,
62-
),
63-
],
64-
)
65-
def test_SLogDet(x, exc):
66-
x, test_x = x
67-
g = nlinalg.SLogDet()(x)
68-
69-
cm = contextlib.suppress() if exc is None else pytest.warns(exc)
70-
with cm:
71-
compare_numba_and_py(
72-
[x],
73-
g,
74-
[test_x],
75-
)
25+
compare_numba_and_py([x], g, [test_x])
7626

7727

7828
# We were seeing some weird results in CI where the following two almost

0 commit comments

Comments
 (0)