|
54 | 54 | lazy_xp_function(setdiff1d, jax_jit=False) |
55 | 55 | lazy_xp_function(sinc) |
56 | 56 |
|
| 57 | +NestedFloatList = list[float] | list["NestedFloatList"] |
| 58 | + |
57 | 59 |
|
58 | 60 | class TestApplyWhere: |
59 | 61 | @staticmethod |
@@ -291,68 +293,139 @@ def test_0D(self, xp: ModuleType): |
291 | 293 | y = atleast_nd(x, ndim=5) |
292 | 294 | xp_assert_equal(y, xp.ones((1, 1, 1, 1, 1))) |
293 | 295 |
|
294 | | - def test_1D(self, xp: ModuleType): |
295 | | - x = xp.asarray([0, 1]) |
296 | | - |
297 | | - y = atleast_nd(x, ndim=0) |
298 | | - xp_assert_equal(y, x) |
299 | | - |
300 | | - y = atleast_nd(x, ndim=1) |
301 | | - xp_assert_equal(y, x) |
302 | | - |
303 | | - y = atleast_nd(x, ndim=2) |
304 | | - xp_assert_equal(y, xp.asarray([[0, 1]])) |
305 | | - |
306 | | - y = atleast_nd(x, ndim=5) |
307 | | - xp_assert_equal(y, xp.asarray([[[[[0, 1]]]]])) |
308 | | - |
309 | | - def test_2D(self, xp: ModuleType): |
310 | | - x = xp.asarray([[3.0]]) |
311 | | - |
312 | | - y = atleast_nd(x, ndim=0) |
313 | | - xp_assert_equal(y, x) |
314 | | - |
315 | | - y = atleast_nd(x, ndim=2) |
316 | | - xp_assert_equal(y, x) |
317 | | - |
318 | | - y = atleast_nd(x, ndim=3) |
319 | | - xp_assert_equal(y, 3 * xp.ones((1, 1, 1))) |
320 | | - |
321 | | - y = atleast_nd(x, ndim=5) |
322 | | - xp_assert_equal(y, 3 * xp.ones((1, 1, 1, 1, 1))) |
323 | | - |
324 | | - def test_3D(self, xp: ModuleType): |
325 | | - x = xp.asarray([[[3.0], [2.0]]]) |
326 | | - |
327 | | - y = atleast_nd(x, ndim=0) |
328 | | - xp_assert_equal(y, x) |
329 | | - |
330 | | - y = atleast_nd(x, ndim=2) |
331 | | - xp_assert_equal(y, x) |
332 | | - |
333 | | - y = atleast_nd(x, ndim=3) |
334 | | - xp_assert_equal(y, x) |
335 | | - |
336 | | - y = atleast_nd(x, ndim=5) |
337 | | - xp_assert_equal(y, xp.asarray([[[[[3.0], [2.0]]]]])) |
338 | | - |
339 | | - def test_5D(self, xp: ModuleType): |
340 | | - x = xp.ones((1, 1, 1, 1, 1)) |
341 | | - |
342 | | - y = atleast_nd(x, ndim=0) |
343 | | - xp_assert_equal(y, x) |
| 296 | + @pytest.mark.parametrize( |
| 297 | + ("x_data", "ndim", "expected_data"), |
| 298 | + [ |
| 299 | + # --- size-1 vector --- |
| 300 | + ([3.0], 0, [3.0]), |
| 301 | + ([3.0], 1, [3.0]), |
| 302 | + ([3.0], 2, [[3.0]]), |
| 303 | + ([3.0], 3, [[[3.0]]]), |
| 304 | + ([3.0], 5, [[[[[3.0]]]]]), |
| 305 | + # --- size-2 vector --- |
| 306 | + ([0.0, 1.0], 0, [0.0, 1.0]), |
| 307 | + ([0.0, 1.0], 1, [0.0, 1.0]), |
| 308 | + ([0.0, 1.0], 2, [[0.0, 1.0]]), |
| 309 | + ([0.0, 1.0], 5, [[[[[0.0, 1.0]]]]]), |
| 310 | + ], |
| 311 | + ) |
| 312 | + def test_1D( |
| 313 | + self, |
| 314 | + x_data: NestedFloatList, |
| 315 | + ndim: int, |
| 316 | + expected_data: NestedFloatList, |
| 317 | + xp: ModuleType, |
| 318 | + ): |
| 319 | + x = xp.asarray(x_data) |
| 320 | + expected = xp.asarray(expected_data) |
| 321 | + y = atleast_nd(x, ndim=ndim) |
| 322 | + xp_assert_equal(y, expected) |
344 | 323 |
|
345 | | - y = atleast_nd(x, ndim=4) |
346 | | - xp_assert_equal(y, x) |
| 324 | + @pytest.mark.parametrize( |
| 325 | + ("x_data", "ndim", "expected_data"), |
| 326 | + [ |
| 327 | + # --- size-1 vector --- |
| 328 | + ([[3.0]], 0, [[3.0]]), |
| 329 | + ([[3.0]], 1, [[3.0]]), |
| 330 | + ([[3.0]], 2, [[3.0]]), |
| 331 | + ([[3.0]], 3, [[[3.0]]]), |
| 332 | + ([[3.0]], 5, [[[[[3.0]]]]]), |
| 333 | + # --- size-2 vector --- |
| 334 | + ([[0.0], [1.0]], 0, [[0.0], [1.0]]), |
| 335 | + ([[0.0, 1.0]], 1, [[0.0, 1.0]]), |
| 336 | + ([[0.0, 1.0]], 2, [[0.0, 1.0]]), |
| 337 | + ([[0.0], [1.0]], 3, [[[0.0], [1.0]]]), |
| 338 | + ([[0.0, 1.0]], 5, [[[[[0.0, 1.0]]]]]), |
| 339 | + ], |
| 340 | + ) |
| 341 | + def test_2D( |
| 342 | + self, |
| 343 | + x_data: NestedFloatList, |
| 344 | + ndim: int, |
| 345 | + expected_data: NestedFloatList, |
| 346 | + xp: ModuleType, |
| 347 | + ): |
| 348 | + x = xp.asarray(x_data) |
| 349 | + expected = xp.asarray(expected_data) |
| 350 | + y = atleast_nd(x, ndim=ndim) |
| 351 | + xp_assert_equal(y, expected) |
347 | 352 |
|
348 | | - y = atleast_nd(x, ndim=5) |
349 | | - xp_assert_equal(y, x) |
| 353 | + @pytest.mark.parametrize( |
| 354 | + ("x_data", "ndim", "expected_data"), |
| 355 | + [ |
| 356 | + ([[[0.0]], [[1.0]]], 0, [[[0.0]], [[1.0]]]), |
| 357 | + ([[[0.0], [1.0]]], 1, [[[0.0], [1.0]]]), |
| 358 | + ([[[0.0, 1.0]]], 2, [[[0.0, 1.0]]]), |
| 359 | + ([[[0.0]], [[1.0]]], 3, [[[0.0]], [[1.0]]]), |
| 360 | + ([[[0.0], [1.0]]], 5, [[[[[0.0], [1.0]]]]]), |
| 361 | + ], |
| 362 | + ) |
| 363 | + def test_3D( |
| 364 | + self, |
| 365 | + x_data: NestedFloatList, |
| 366 | + ndim: int, |
| 367 | + expected_data: NestedFloatList, |
| 368 | + xp: ModuleType, |
| 369 | + ): |
| 370 | + x = xp.asarray(x_data) |
| 371 | + expected = xp.asarray(expected_data) |
| 372 | + y = atleast_nd(x, ndim=ndim) |
| 373 | + xp_assert_equal(y, expected) |
350 | 374 |
|
351 | | - y = atleast_nd(x, ndim=6) |
352 | | - xp_assert_equal(y, xp.ones((1, 1, 1, 1, 1, 1))) |
| 375 | + @pytest.mark.parametrize( |
| 376 | + ("x_data", "ndim", "expected_data"), |
| 377 | + [ |
| 378 | + ([[[[3.0], [2.0]]]], 0, [[[[3.0], [2.0]]]]), |
| 379 | + ([[[[3.0, 2.0]]]], 2, [[[[3.0, 2.0]]]]), |
| 380 | + ([[[[3.0]], [[2.0]]]], 4, [[[[3.0]], [[2.0]]]]), |
| 381 | + ([[[[3.0]]], [[[2.0]]]], 5, [[[[[3.0]]], [[[2.0]]]]]), |
| 382 | + ], |
| 383 | + ) |
| 384 | + def test_4D( |
| 385 | + self, |
| 386 | + x_data: NestedFloatList, |
| 387 | + ndim: int, |
| 388 | + expected_data: NestedFloatList, |
| 389 | + xp: ModuleType, |
| 390 | + ): |
| 391 | + x = xp.asarray(x_data) |
| 392 | + expected = xp.asarray(expected_data) |
| 393 | + y = atleast_nd(x, ndim=ndim) |
| 394 | + xp_assert_equal(y, expected) |
353 | 395 |
|
354 | | - y = atleast_nd(x, ndim=9) |
355 | | - xp_assert_equal(y, xp.ones((1, 1, 1, 1, 1, 1, 1, 1, 1))) |
| 396 | + @pytest.mark.parametrize( |
| 397 | + ("x_data", "ndim", "expected_data"), |
| 398 | + [ |
| 399 | + ([[[[[3.0]], [[2.0]], [[1.0]]]]], 0, [[[[[3.0]], [[2.0]], [[1.0]]]]]), |
| 400 | + ([[[[[3.0, 2.0, 6.0]]]]], 2, [[[[[3.0, 2.0, 6.0]]]]]), |
| 401 | + ( |
| 402 | + [[[[[3.0]]], [[[2.0]]], [[[1.0]]]]], |
| 403 | + 4, |
| 404 | + [[[[[3.0]]], [[[2.0]]], [[[1.0]]]]], |
| 405 | + ), |
| 406 | + ( |
| 407 | + [[[[[3.0]], [[1.0]]], [[[2.0]], [[1.0]]], [[[1.0]], [[1.0]]]]], |
| 408 | + 6, |
| 409 | + [[[[[[3.0]], [[1.0]]], [[[2.0]], [[1.0]]], [[[1.0]], [[1.0]]]]]], |
| 410 | + ), |
| 411 | + ( |
| 412 | + [[[[[3.0]], [[1.0]]], [[[2.0]], [[1.0]]], [[[1.0]], [[1.0]]]]], |
| 413 | + 9, |
| 414 | + [[[[[[[[[3.0]], [[1.0]]], [[[2.0]], [[1.0]]], [[[1.0]], [[1.0]]]]]]]]], |
| 415 | + ), |
| 416 | + ], |
| 417 | + ) |
| 418 | + def test_5D( |
| 419 | + self, |
| 420 | + x_data: NestedFloatList, |
| 421 | + ndim: int, |
| 422 | + expected_data: NestedFloatList, |
| 423 | + xp: ModuleType, |
| 424 | + ): |
| 425 | + x = xp.asarray(x_data) |
| 426 | + expected = xp.asarray(expected_data) |
| 427 | + y = atleast_nd(x, ndim=ndim) |
| 428 | + xp_assert_equal(y, expected) |
356 | 429 |
|
357 | 430 | def test_device(self, xp: ModuleType, device: Device): |
358 | 431 | x = xp.asarray([1, 2, 3], device=device) |
|
0 commit comments