From b5380b4c313faef84c1917983ea4a9106638f335 Mon Sep 17 00:00:00 2001 From: Philip Cook Date: Tue, 22 Oct 2024 09:59:28 -0400 Subject: [PATCH] BUG: Fix slicing. More consistent with numpy ENH: Support more indexing options --- ants/core/ants_image.py | 30 +++++-- tests/test_core_ants_image_indexing.py | 117 ++++++++++++++++--------- 2 files changed, 98 insertions(+), 49 deletions(-) diff --git a/ants/core/ants_image.py b/ants/core/ants_image.py index 87d21663..f28e8210 100644 --- a/ants/core/ants_image.py +++ b/ants/core/ants_image.py @@ -507,23 +507,37 @@ def __getitem__(self, idx): raise ValueError('images do not occupy same physical space') return self.numpy().__getitem__(idx.numpy().astype('bool')) - ndim = len(idx) + # convert idx to tuple if it is not, eg im[10] or im[10:20] + if not isinstance(idx, tuple): + idx = (idx,) + + ndim = len(self.shape) + + if len(idx) > ndim: + raise ValueError('Too many indices for image') + if len(idx) < ndim: + # If not all dimensions are indexed, assume the rest are full slices + # eg im[10] -> im[10, :, :] + idx = idx + (slice(None),) * (ndim - len(idx)) + sizes = list(self.shape) starts = [0] * ndim - + stops = list(self.shape) for i in range(ndim): ti = idx[i] if isinstance(ti, slice): if ti.start: starts[i] = ti.start if ti.stop: - sizes[i] = ti.stop - starts[i] - else: - sizes[i] = self.shape[i] - starts[i] + if ti.stop < 0: + stops[i] = self.shape[i] + ti.stop + else: + stops[i] = ti.stop + + sizes[i] = stops[i] - starts[i] - if ti.stop and ti.start: - if ti.stop < ti.start: - raise Exception('Reverse indexing is not supported.') + if stops[i] < starts[i]: + raise ValueError('Reverse indexing is not supported.') elif isinstance(ti, int): starts[i] = ti diff --git a/tests/test_core_ants_image_indexing.py b/tests/test_core_ants_image_indexing.py index 0dcd3d5d..673c9819 100644 --- a/tests/test_core_ants_image_indexing.py +++ b/tests/test_core_ants_image_indexing.py @@ -20,12 +20,12 @@ class TestClass_AntsImageIndexing(unittest.TestCase): - + def setUp(self): pass def tearDown(self): pass - + def test_pixeltype_2d(self): img = ants.image_read(ants.get_data('r16')) for ptype in ['unsigned char', 'unsigned int', 'float', 'double']: @@ -33,7 +33,7 @@ def test_pixeltype_2d(self): self.assertEqual(img.pixeltype, ptype) img2 = img[:10,:10] self.assertEqual(img2.pixeltype, ptype) - + def test_pixeltype_3d(self): img = ants.image_read(ants.get_data('mni')) for ptype in ['unsigned char', 'unsigned int', 'float', 'double']: @@ -43,41 +43,41 @@ def test_pixeltype_3d(self): self.assertEqual(img2.pixeltype, ptype) img3 = img[:10,:10,10] self.assertEqual(img3.pixeltype, ptype) - + def test_2d(self): img = ants.image_read(ants.get_data('r16')) - + img2 = img[:10,:10] self.assertEqual(img2.dimension, 2) img2 = img[:5,:5] self.assertEqual(img2.dimension, 2) - + img2 = img[1:20,1:10] self.assertEqual(img2.dimension, 2) img2 = img[:5,:5] self.assertEqual(img2.dimension, 2) img2 = img[:5,4:5] self.assertEqual(img2.dimension, 2) - + img2 = img[5:5,5:5] - + # down to 1d arr = img[10,:] self.assertTrue(isinstance(arr, np.ndarray)) - + # single value arr = img[10,10] - + def test_2d_image_index(self): img = ants.image_read(ants.get_data('r16')) idx = img > 200 - + # acts like a mask img2 = img[idx] def test_3d(self): img = ants.image_read(ants.get_data('mni')) - + img2 = img[:10,:10,:10] self.assertEqual(img2.dimension, 3) img2 = img[:5,:5,:5] @@ -86,7 +86,7 @@ def test_3d(self): self.assertEqual(img2.dimension, 3) img2 = img[:5,:5,:5] self.assertEqual(img2.dimension, 3) - + # down to 2d img2 = img[10,:,:] self.assertEqual(img2.dimension, 2) @@ -100,110 +100,145 @@ def test_3d(self): self.assertEqual(img2.dimension, 2) img2 = img[2:20,3:30,10] self.assertEqual(img2.dimension, 2) - + # down to 1d arr = img[10,:,10] self.assertTrue(isinstance(arr, np.ndarray)) arr = img[10,:,5] self.assertTrue(isinstance(arr, np.ndarray)) - + # single value arr = img[10,10,10] - + def test_double_indexing(self): img = ants.image_read(ants.get_data('mni')) img2 = img[20:,:,:] self.assertEqual(img2.shape, (162,218,182)) - + img3 = img[0,:,:] self.assertEqual(img3.shape, (218,182)) - + def test_reverse_error(self): img = ants.image_read(ants.get_data('mni')) with self.assertRaises(Exception): img2 = img[20:10,:,:] - + def test_2d_vector(self): img = ants.image_read(ants.get_data('r16')) img2 = img[:10,:10] - + img_v = ants.merge_channels([img]) img_v2 = img_v[:10,:10] - + self.assertTrue(ants.allclose(img2, ants.split_channels(img_v2)[0])) def test_2d_vector_multi(self): img = ants.image_read(ants.get_data('r16')) img2 = img[:10,:10] - + img_v = ants.merge_channels([img,img,img]) img_v2 = img_v[:10,:10] - + self.assertTrue(ants.allclose(img2, ants.split_channels(img_v2)[0])) self.assertTrue(ants.allclose(img2, ants.split_channels(img_v2)[1])) self.assertTrue(ants.allclose(img2, ants.split_channels(img_v2)[2])) - + def test_setting_3d(self): img = ants.image_read(ants.get_data('mni')) img2d = img[100,:,:] - + # setting a sub-image with an image img2 = img + 10 img2[100,:,:] = img2d - + self.assertFalse(ants.allclose(img, img2)) self.assertTrue(ants.allclose(img2d, img2[100,:,:])) - + # setting a sub-image with an array img2 = img + 10 img2[100,:,:] = img2d.numpy() - + self.assertFalse(ants.allclose(img, img2)) self.assertTrue(ants.allclose(img2d, img2[100,:,:])) - + def test_setting_2d(self): img = ants.image_read(ants.get_data('r16')) img2d = img[100,:] - + # setting a sub-image with an image img2 = img + 10 img2[100,:] = img2d - + self.assertFalse(ants.allclose(img, img2)) self.assertTrue(np.allclose(img2d, img2[100,:])) - - + + def test_setting_2d_sub_image(self): img = ants.image_read(ants.get_data('r16')) img2d = img[10:30,10:30] - + # setting a sub-image with an image img2 = img + 10 img2[10:30,10:30] = img2d - + self.assertFalse(ants.allclose(img, img2)) self.assertTrue(ants.allclose(img2d, img2[10:30,10:30])) - + # setting a sub-image with an array img2 = img + 10 img2[10:30,10:30] = img2d.numpy() - + self.assertFalse(ants.allclose(img, img2)) self.assertTrue(ants.allclose(img2d, img2[10:30,10:30])) - + def test_setting_correctness(self): - + img = ants.image_read(ants.get_data('r16')) * 0 self.assertEqual(img.sum(), 0) - + img2 = img[10:30,10:30] img2 = img2 + 10 self.assertEqual(img2.mean(), 10) - + img[:20,:20] = img2 self.assertEqual(img.sum(), img2.sum()) self.assertEqual(img.numpy()[:20,:20].sum(), img2.sum()) self.assertNotEqual(img.numpy()[10:30,10:30].sum(), img2.sum()) + + def test_slicing_3d(self): + img = ants.image_read(ants.get_data('mni')) + img2 = img[:10,:10,:10] + img3 = img[10:20,10:20,10:20] + + self.assertTrue(ants.allclose(img2, img3)) + + img_np = img.numpy() + + self.assertTrue(np.allclose(img2.numpy(), img_np[:10,:10,:10])) + self.assertTrue(np.allclose(img[20].numpy(), img_np[20])) + self.assertTrue(np.allclose(img[:,20:40].numpy(), img_np[:,20:40])) + self.assertTrue(np.allclose(img[:,:,20:-2].numpy(), img_np[:,:,20:-2])) + self.assertTrue(np.allclose(img[0:-1,].numpy(), img_np[0:-1,])) + self.assertTrue(np.allclose(img[100,10:100,0:-1].numpy(), img_np[100,10:100,0:-1])) + self.assertTrue(np.allclose(img[:,10:,30:].numpy(), img_np[:,10:,30:])) + # if the slice returns 1D, it should be a numpy array already + self.assertTrue(np.allclose(img[100:-1,30,40], img_np[100:-1,30,40])) + + def test_slicing_2d(self): + img = ants.image_read(ants.get_data('r16')) + + img2 = img[:10,:10] + + img_np = img.numpy() + + self.assertTrue(np.allclose(img2.numpy(), img_np[:10,:10])) + self.assertTrue(np.allclose(img[:,20:40].numpy(), img_np[:,20:40])) + self.assertTrue(np.allclose(img[0:-1,].numpy(), img_np[0:-1,])) + self.assertTrue(np.allclose(img[50:,10:-3].numpy(), img_np[50:,10:-3])) + # if the slice returns 1D, it should be a numpy array already + self.assertTrue(np.allclose(img[20], img_np[20])) + self.assertTrue(np.allclose(img[100:-1,30], img_np[100:-1,30])) + if __name__ == '__main__': run_tests() \ No newline at end of file