diff --git a/ants/registration/registration.py b/ants/registration/registration.py index 1a377524..5741dc2a 100644 --- a/ants/registration/registration.py +++ b/ants/registration/registration.py @@ -1621,8 +1621,8 @@ def label_image_registration(fixed_label_images, of the moving image. type_of_linear_transform : string - Use label images with the centers of mass to a calculate linear - transform of type 'rigid', 'similarity', or 'affine'. + Use label images with the centers of mass to a calculate linear + transform of type 'identity', 'rigid', 'similarity', or 'affine'. type_of_deformable_transform : string Only works with deformable-only transforms, specifically the family @@ -1707,7 +1707,7 @@ def label_image_registration(fixed_label_images, if output_prefix == "" or output_prefix is None or len(output_prefix) == 0: output_prefix = mktemp() - allowable_linear_transforms = ['rigid', 'similarity', 'affine'] + allowable_linear_transforms = ['rigid', 'similarity', 'affine', 'identity'] if not type_of_linear_transform in allowable_linear_transforms: raise ValueError("Unrecognized linear transform.") @@ -1729,6 +1729,8 @@ def label_image_registration(fixed_label_images, if len(common_label_ids[i]) == 0: raise ValueError("No common labels for image pair " + str(i)) + deformable_multivariate_extras = list() + if verbose: print("Total number of labels: " + str(total_number_of_labels)) @@ -1739,7 +1741,7 @@ def label_image_registration(fixed_label_images, ############################## linear_xfrm = None - if type_of_linear_transform is not None: + if type_of_linear_transform != 'identity': if verbose: print("\n\nComputing linear transform.\n") @@ -1749,8 +1751,6 @@ def label_image_registration(fixed_label_images, fixed_centers_of_mass = np.zeros((total_number_of_labels, image_dimension)) moving_centers_of_mass = np.zeros((total_number_of_labels, image_dimension)) - deformable_multivariate_extras = list() - count = 0 for i in range(len(common_label_ids)): for j in range(len(common_label_ids[i])): @@ -1783,6 +1783,16 @@ def label_image_registration(fixed_label_images, if do_deformable: + if type_of_linear_transform == "identity": + for i in range(len(common_label_ids)): + for j in range(len(common_label_ids[i])): + label = common_label_ids[i][j] + fixed_single_label_image = ants.threshold_image(fixed_label_images[i], label, label, 1, 0) + moving_single_label_image = ants.threshold_image(moving_label_images[i], label, label, 1, 0) + deformable_multivariate_extras.append(["MSQ", fixed_single_label_image, + moving_single_label_image, + label_image_weights[i], 0]) + if verbose: print("\n\nComputing deformable transform using images.\n") @@ -1940,13 +1950,20 @@ def label_image_registration(fixed_label_images, find_inverse_warps = np.where([re.search("[0-9]InverseWarp.nii.gz", ff) for ff in all_xfrms])[0] find_forward_warps = np.where([re.search("[0-9]Warp.nii.gz", ff) for ff in all_xfrms])[0] - - if len(find_inverse_warps) > 0: - fwdtransforms = [all_xfrms[find_forward_warps[0]], linear_xfrm_file] - invtransforms = [linear_xfrm_file, all_xfrms[find_inverse_warps[0]]] + + fwdtransforms = [] + invtransforms = [] + if linear_xfrm is not None: + if len(find_inverse_warps) > 0: + fwdtransforms = [all_xfrms[find_forward_warps[0]], linear_xfrm_file] + invtransforms = [linear_xfrm_file, all_xfrms[find_inverse_warps[0]]] + else: + fwdtransforms = [linear_xfrm_file] + invtransforms = [linear_xfrm_file] else: - fwdtransforms = [linear_xfrm_file] - invtransforms = [linear_xfrm_file] + if len(find_inverse_warps) > 0: + fwdtransforms = [all_xfrms[find_forward_warps[0]]] + invtransforms = [all_xfrms[find_inverse_warps[0]]] if verbose: print("\n\nResulting transforms")