Skip to content

Commit 664f853

Browse files
committed
fix bad delegation behaviour with atleast_3d and improve atleast_nd unittests.
1 parent 7b3f326 commit 664f853

File tree

2 files changed

+132
-59
lines changed

2 files changed

+132
-59
lines changed

src/array_api_extra/_delegation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType | None = None) -> Array
6767
if xp is None:
6868
xp = array_namespace(x)
6969

70-
if 1 <= ndim <= 3 and (
70+
if 1 <= ndim <= 2 and (
7171
is_numpy_namespace(xp)
7272
or is_jax_namespace(xp)
7373
or is_dask_namespace(xp)

tests/test_funcs.py

Lines changed: 131 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@
5454
lazy_xp_function(setdiff1d, jax_jit=False)
5555
lazy_xp_function(sinc)
5656

57+
NestedFloatList = list[float] | list["NestedFloatList"]
58+
5759

5860
class TestApplyWhere:
5961
@staticmethod
@@ -291,68 +293,139 @@ def test_0D(self, xp: ModuleType):
291293
y = atleast_nd(x, ndim=5)
292294
xp_assert_equal(y, xp.ones((1, 1, 1, 1, 1)))
293295

294-
def test_1D(self, xp: ModuleType):
295-
x = xp.asarray([0, 1])
296-
297-
y = atleast_nd(x, ndim=0)
298-
xp_assert_equal(y, x)
299-
300-
y = atleast_nd(x, ndim=1)
301-
xp_assert_equal(y, x)
302-
303-
y = atleast_nd(x, ndim=2)
304-
xp_assert_equal(y, xp.asarray([[0, 1]]))
305-
306-
y = atleast_nd(x, ndim=5)
307-
xp_assert_equal(y, xp.asarray([[[[[0, 1]]]]]))
308-
309-
def test_2D(self, xp: ModuleType):
310-
x = xp.asarray([[3.0]])
311-
312-
y = atleast_nd(x, ndim=0)
313-
xp_assert_equal(y, x)
314-
315-
y = atleast_nd(x, ndim=2)
316-
xp_assert_equal(y, x)
317-
318-
y = atleast_nd(x, ndim=3)
319-
xp_assert_equal(y, 3 * xp.ones((1, 1, 1)))
320-
321-
y = atleast_nd(x, ndim=5)
322-
xp_assert_equal(y, 3 * xp.ones((1, 1, 1, 1, 1)))
323-
324-
def test_3D(self, xp: ModuleType):
325-
x = xp.asarray([[[3.0], [2.0]]])
326-
327-
y = atleast_nd(x, ndim=0)
328-
xp_assert_equal(y, x)
329-
330-
y = atleast_nd(x, ndim=2)
331-
xp_assert_equal(y, x)
332-
333-
y = atleast_nd(x, ndim=3)
334-
xp_assert_equal(y, x)
335-
336-
y = atleast_nd(x, ndim=5)
337-
xp_assert_equal(y, xp.asarray([[[[[3.0], [2.0]]]]]))
338-
339-
def test_5D(self, xp: ModuleType):
340-
x = xp.ones((1, 1, 1, 1, 1))
341-
342-
y = atleast_nd(x, ndim=0)
343-
xp_assert_equal(y, x)
296+
@pytest.mark.parametrize(
297+
("x_data", "ndim", "expected_data"),
298+
[
299+
# --- size-1 vector ---
300+
([3.0], 0, [3.0]),
301+
([3.0], 1, [3.0]),
302+
([3.0], 2, [[3.0]]),
303+
([3.0], 3, [[[3.0]]]),
304+
([3.0], 5, [[[[[3.0]]]]]),
305+
# --- size-2 vector ---
306+
([0.0, 1.0], 0, [0.0, 1.0]),
307+
([0.0, 1.0], 1, [0.0, 1.0]),
308+
([0.0, 1.0], 2, [[0.0, 1.0]]),
309+
([0.0, 1.0], 5, [[[[[0.0, 1.0]]]]]),
310+
],
311+
)
312+
def test_1D(
313+
self,
314+
x_data: NestedFloatList,
315+
ndim: int,
316+
expected_data: NestedFloatList,
317+
xp: ModuleType,
318+
):
319+
x = xp.asarray(x_data)
320+
expected = xp.asarray(expected_data)
321+
y = atleast_nd(x, ndim=ndim)
322+
xp_assert_equal(y, expected)
344323

345-
y = atleast_nd(x, ndim=4)
346-
xp_assert_equal(y, x)
324+
@pytest.mark.parametrize(
325+
("x_data", "ndim", "expected_data"),
326+
[
327+
# --- size-1 vector ---
328+
([[3.0]], 0, [[3.0]]),
329+
([[3.0]], 1, [[3.0]]),
330+
([[3.0]], 2, [[3.0]]),
331+
([[3.0]], 3, [[[3.0]]]),
332+
([[3.0]], 5, [[[[[3.0]]]]]),
333+
# --- size-2 vector ---
334+
([[0.0], [1.0]], 0, [[0.0], [1.0]]),
335+
([[0.0, 1.0]], 1, [[0.0, 1.0]]),
336+
([[0.0, 1.0]], 2, [[0.0, 1.0]]),
337+
([[0.0], [1.0]], 3, [[[0.0], [1.0]]]),
338+
([[0.0, 1.0]], 5, [[[[[0.0, 1.0]]]]]),
339+
],
340+
)
341+
def test_2D(
342+
self,
343+
x_data: NestedFloatList,
344+
ndim: int,
345+
expected_data: NestedFloatList,
346+
xp: ModuleType,
347+
):
348+
x = xp.asarray(x_data)
349+
expected = xp.asarray(expected_data)
350+
y = atleast_nd(x, ndim=ndim)
351+
xp_assert_equal(y, expected)
347352

348-
y = atleast_nd(x, ndim=5)
349-
xp_assert_equal(y, x)
353+
@pytest.mark.parametrize(
354+
("x_data", "ndim", "expected_data"),
355+
[
356+
([[[0.0]], [[1.0]]], 0, [[[0.0]], [[1.0]]]),
357+
([[[0.0], [1.0]]], 1, [[[0.0], [1.0]]]),
358+
([[[0.0, 1.0]]], 2, [[[0.0, 1.0]]]),
359+
([[[0.0]], [[1.0]]], 3, [[[0.0]], [[1.0]]]),
360+
([[[0.0], [1.0]]], 5, [[[[[0.0], [1.0]]]]]),
361+
],
362+
)
363+
def test_3D(
364+
self,
365+
x_data: NestedFloatList,
366+
ndim: int,
367+
expected_data: NestedFloatList,
368+
xp: ModuleType,
369+
):
370+
x = xp.asarray(x_data)
371+
expected = xp.asarray(expected_data)
372+
y = atleast_nd(x, ndim=ndim)
373+
xp_assert_equal(y, expected)
350374

351-
y = atleast_nd(x, ndim=6)
352-
xp_assert_equal(y, xp.ones((1, 1, 1, 1, 1, 1)))
375+
@pytest.mark.parametrize(
376+
("x_data", "ndim", "expected_data"),
377+
[
378+
([[[[3.0], [2.0]]]], 0, [[[[3.0], [2.0]]]]),
379+
([[[[3.0, 2.0]]]], 2, [[[[3.0, 2.0]]]]),
380+
([[[[3.0]], [[2.0]]]], 4, [[[[3.0]], [[2.0]]]]),
381+
([[[[3.0]]], [[[2.0]]]], 5, [[[[[3.0]]], [[[2.0]]]]]),
382+
],
383+
)
384+
def test_4D(
385+
self,
386+
x_data: NestedFloatList,
387+
ndim: int,
388+
expected_data: NestedFloatList,
389+
xp: ModuleType,
390+
):
391+
x = xp.asarray(x_data)
392+
expected = xp.asarray(expected_data)
393+
y = atleast_nd(x, ndim=ndim)
394+
xp_assert_equal(y, expected)
353395

354-
y = atleast_nd(x, ndim=9)
355-
xp_assert_equal(y, xp.ones((1, 1, 1, 1, 1, 1, 1, 1, 1)))
396+
@pytest.mark.parametrize(
397+
("x_data", "ndim", "expected_data"),
398+
[
399+
([[[[[3.0]], [[2.0]], [[1.0]]]]], 0, [[[[[3.0]], [[2.0]], [[1.0]]]]]),
400+
([[[[[3.0, 2.0, 6.0]]]]], 2, [[[[[3.0, 2.0, 6.0]]]]]),
401+
(
402+
[[[[[3.0]]], [[[2.0]]], [[[1.0]]]]],
403+
4,
404+
[[[[[3.0]]], [[[2.0]]], [[[1.0]]]]],
405+
),
406+
(
407+
[[[[[3.0]], [[1.0]]], [[[2.0]], [[1.0]]], [[[1.0]], [[1.0]]]]],
408+
6,
409+
[[[[[[3.0]], [[1.0]]], [[[2.0]], [[1.0]]], [[[1.0]], [[1.0]]]]]],
410+
),
411+
(
412+
[[[[[3.0]], [[1.0]]], [[[2.0]], [[1.0]]], [[[1.0]], [[1.0]]]]],
413+
9,
414+
[[[[[[[[[3.0]], [[1.0]]], [[[2.0]], [[1.0]]], [[[1.0]], [[1.0]]]]]]]]],
415+
),
416+
],
417+
)
418+
def test_5D(
419+
self,
420+
x_data: NestedFloatList,
421+
ndim: int,
422+
expected_data: NestedFloatList,
423+
xp: ModuleType,
424+
):
425+
x = xp.asarray(x_data)
426+
expected = xp.asarray(expected_data)
427+
y = atleast_nd(x, ndim=ndim)
428+
xp_assert_equal(y, expected)
356429

357430
def test_device(self, xp: ModuleType, device: Device):
358431
x = xp.asarray([1, 2, 3], device=device)

0 commit comments

Comments
 (0)