Skip to content

Commit

Permalink
add nibabel conversion (lacey import to prevent forced dependency)
Browse files Browse the repository at this point in the history
  • Loading branch information
ga84mun committed Aug 15, 2024
1 parent 5f550b8 commit a7190e2
Showing 1 changed file with 112 additions and 18 deletions.
130 changes: 112 additions & 18 deletions ants/utils/nifti_to_ants.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,135 @@
__all__ = ["nifti_to_ants"]
__all__ = ["nifti_to_ants", "ants_to_nifti", "from_nibabel", "to_nibabel"]

from typing import TYPE_CHECKING

import os
from tempfile import mkstemp
import numpy as np
import ants
import numpy as np

if TYPE_CHECKING:
from nibabel.nifti1 import Nifti1Image


def nifti_to_ants( nib_image ):
def nifti_to_ants(nib_image: "Nifti1Image") -> ants.ANTsImage:
"""
Converts a given Nifti image into an ANTsPy image
Convert a Nifti image to an ANTsPy image.
Parameters
----------
img: NiftiImage
nib_image : Nifti1Image
The Nifti image to be converted.
Returns
-------
ants_image: ANTsImage
ants_image : ants.ANTsImage
The converted ANTs image.
"""
ndim = nib_image.ndim

if ndim < 3:
print("Dimensionality is less than 3.")
return None
raise NotImplementedError("Conversion is only implemented for 3D or higher images.")

q_form = nib_image.get_qform()
spacing = nib_image.header["pixdim"][1 : ndim + 1]

origin = np.zeros((ndim))
origin = np.zeros(ndim)
origin[:3] = q_form[:3, 3]

direction = np.diag(np.ones(ndim))
direction = np.eye(ndim)
direction[:3, :3] = q_form[:3, :3] / spacing[:3]

ants_img = ants.from_numpy(
data = nib_image.get_data().astype( np.float ),
origin = origin.tolist(),
spacing = spacing.tolist(),
direction = direction )

ants_img = ants.from_numpy(data=nib_image.get_fdata(), origin=origin.tolist(), spacing=spacing.tolist(), direction=direction)

return ants_img


def get_ras_affine_from_ants(ants_img: ants.ANTsImage) -> np.ndarray:
"""
Convert ANTs image affine to RAS coordinate system.
Parameters
----------
ants_img : ants.ANTsImage
The ANTs image whose affine is to be converted.
Returns
-------
affine : np.ndarray
The affine matrix in RAS coordinates.
"""
spacing = np.array(ants_img.spacing)
direction_lps = np.array(ants_img.direction)
origin_lps = np.array(ants_img.origin)
direction_length = direction_lps.shape[0] * direction_lps.shape[1]
if direction_length == 9:
rotation_lps = direction_lps.reshape(3, 3)
elif direction_length == 4: # 2D case (1, W, H, 1)
rotation_lps_2d = direction_lps.reshape(2, 2)
rotation_lps = np.eye(3)
rotation_lps[:2, :2] = rotation_lps_2d
spacing = np.append(spacing, 1)
origin_lps = np.append(origin_lps, 0)
elif direction_length == 16: # Fix potential bad NIfTI
rotation_lps = direction_lps.reshape(4, 4)[:3, :3]
spacing = spacing[:-1]
origin_lps = origin_lps[:-1]
else:
raise NotImplementedError(f"Unexpected direction length = {direction_length}.")

rotation_ras = np.dot(np.diag([-1, -1, 1]), rotation_lps)
rotation_ras_zoom = rotation_ras * spacing
translation_ras = np.dot(np.diag([-1, -1, 1]), origin_lps)

affine = np.eye(4)
affine[:3, :3] = rotation_ras_zoom
affine[:3, 3] = translation_ras

return affine


def ants_to_nifti(img: ants.ANTsImage, header=None) -> "Nifti1Image":
"""
Convert an ANTs image to a Nifti image.
Parameters
----------
img : ants.ANTsImage
The ANTs image to be converted.
header : Nifti1Header, optional
Optional header to use for the Nifti image.
Returns
-------
img : Nifti1Image
The converted Nifti image.
"""
from nibabel.nifti1 import Nifti1Image

affine = get_ras_affine_from_ants(img)
arr = img.numpy()

if header is not None:
header.set_data_dtype(arr.dtype)

return Nifti1Image(arr, affine, header)


# Legacy names for backwards compatibility
from_nibabel = nifti_to_ants
to_nibabel = ants_to_nifti

if __name__ == "__main__":
import nibabel as nib

fn = ants.get_ants_data("mni")
ants_img = ants.image_read(fn)
nii_mni: "Nifti1Image" = nib.load(fn)
ants_mni = to_nibabel(ants_img)
assert (ants_mni.get_qform() == nii_mni.get_qform()).all()
temp = ants.from_nibabel(nii_mni)
assert ants.image_physical_space_consistency(ants_img, temp)

fn = ants.get_data("ch2")
ants_mni = ants.image_read(fn)
nii_mni = nib.load(fn)
ants_mni = to_nibabel(ants_mni)
assert (ants_mni.get_qform() == nii_mni.get_qform()).all()

0 comments on commit a7190e2

Please sign in to comment.