diff --git a/ants/core/ants_image_io.py b/ants/core/ants_image_io.py index d6d34974..c88c5ac9 100644 --- a/ants/core/ants_image_io.py +++ b/ants/core/ants_image_io.py @@ -259,7 +259,7 @@ def image_clone(image, pixeltype=None): image : ANTsImage image to clone - dtype : string (optional) + pixeltype : string (optional) new datatype for image Returns @@ -460,8 +460,8 @@ def clone(image, pixeltype=None): Arguments --------- - dtype: string (optional) - if None, the dtype will be the same as the cloned ANTsImage. Otherwise, + pixeltype: string (optional) + if None, the pixeltype will be the same as the cloned ANTsImage. Otherwise, the data will be cast to this type. This can be a numpy type or an ITK type. Options: @@ -478,7 +478,11 @@ def clone(image, pixeltype=None): pixeltype = image.pixeltype if pixeltype not in _supported_ptypes: - raise ValueError('Pixeltype %s not supported. Supported types are %s' % (pixeltype, _supported_ptypes)) + # check if the pixeltype is a numpy type + if pixeltype in _supported_ntypes: + pixeltype = _npy_to_itk_map[pixeltype] + else: + raise ValueError('Pixeltype %s not supported. Supported types are %s' % (pixeltype, _supported_ptypes)) if image.has_components and (not image.is_rgb): comp_imgs = ants.split_channels(image) diff --git a/tests/test_core_ants_image.py b/tests/test_core_ants_image.py index a8b5f6f4..b63fb699 100644 --- a/tests/test_core_ants_image.py +++ b/tests/test_core_ants_image.py @@ -27,7 +27,8 @@ def setUp(self): img2d = ants.image_read(ants.get_ants_data('r16')) img3d = ants.image_read(ants.get_ants_data('mni')) self.imgs = [img2d, img3d] - self.pixeltypes = ['unsigned char', 'unsigned int', 'float'] + self.pixeltypes = ['unsigned char', 'unsigned int', 'float', 'double'] + self.numpy_pixeltypes = ['uint8', 'uint32', 'float32', 'float64'] def tearDown(self): pass @@ -138,10 +139,10 @@ def test_clone(self): #self.setUp() for img in self.imgs: orig_ptype = img.pixeltype - for ptype in self.pixeltypes: + for ptype in [*self.pixeltypes, *self.numpy_pixeltypes]: imgclone = img.clone(ptype) - self.assertEqual(imgclone.pixeltype, ptype) + self.assertIn(ptype, [imgclone.dtype, imgclone.pixeltype]) self.assertEqual(img.pixeltype, orig_ptype) # test physical space consistency self.assertTrue(ants.image_physical_space_consistency(img, imgclone)) @@ -530,7 +531,7 @@ def setUp(self): img2d = ants.image_read(ants.get_ants_data('r16')).clone('float') img3d = ants.image_read(ants.get_ants_data('mni')).clone('float') self.imgs = [img2d, img3d] - self.pixeltypes = ['unsigned char', 'unsigned int', 'float'] + self.pixeltypes = ['unsigned char', 'unsigned int', 'float', 'double'] def tearDown(self): pass