diff --git a/README.rst b/README.rst index 1668b50..1aacbec 100644 --- a/README.rst +++ b/README.rst @@ -20,6 +20,10 @@ It requires the following packages: - Pandas, - Rich. + +CLI +*** + usage: roiloc [-h] -p PATH -i INPUTPATTERN [-r ROI [ROI ...]] -c CONTRAST [-b] [-t TRANSFORM] [-m MARGIN [MARGIN ...]] [--mask MASK] [--extracrops EXTRACROPS [EXTRACROPS ...]] [--savesteps] @@ -57,6 +61,41 @@ arguments:: atlas). +Python API +********** + +Even though the CLI is the main use case, a Python API has also been available since v0.2.0. + +The API follows sklearn's syntax, with a ``RoiLocator`` class providing ``fit``, ``transform``, ``fit_transform`` and ``inverse_transform`` methods, as seen below. + +.. code-block:: python + + import ants + from roiloc.locator import RoiLocator + + image = ants.image_read("./sub00_t2w.nii.gz", + reorient="LPI") + + locator = RoiLocator(contrast="t2", roi="hippocampus", bet=False) + + # Fit the locator and get the transformed MRIs + right, left = locator.fit_transform(image) + # Coordinates can be obtained through the `coords` attribute + print(locator.get_coords()) + + # Let 'model' be a segmentation model of the hippocampus + right_seg = model(right) + left_seg = model(left) + + # Transform the segmentation back to the original image + right_seg = locator.inverse_transform(right_seg) + left_seg = locator.inverse_transform(left_seg) + + # Save the resulting segmentations in the original space + ants.image_write(right_seg, "./sub00_hippocampus_right.nii.gz") + ants.image_write(left_seg, "./sub00_hippocampus_left.nii.gz") + + Installation ************ diff --git a/roiloc/locator.py b/roiloc/locator.py index 8d1d9f8..a50bc13 100644 --- a/roiloc/locator.py +++ b/roiloc/locator.py @@ -4,11 +4,35 @@ from ants.core import ANTsImage from .location import crop, get_coords -from .registration import get_roi, register +from .registration import get_roi from .template import get_atlas, get_mni, get_roi_indices class RoiLocator: + """Crop 
an MRI image to a ROI. + + Args: + contrast (str): Contrast to use for registration. + roi (str): ROI to use for registration. + bet (bool, optional): Use brain-extracted MNI template. Defaults to False. + transform_type (str, optional): Type of transformation for the registration. + Defaults to "AffineFast". + margin (list, optional): Margin to apply. Defaults to [4, 4, 4]. + mask (Optional[ANTsImage], optional): Brain mask to improve registration quality. + Defaults to None. + + Attributes: + coords (dict): Dictionary of coordinates for each side of the ROI. + _fwdtransforms (list): List of forward transforms. + _invtransforms (list): List of inverse transforms. + _mni (ANTsImage): MNI template. + _atlas (ANTsImage): CerebrA atlas image. + _roi_idx (list): List of indices for the ROI in the CerebrA atlas. + _image (ANTsImage): Input image used to inverse transform. + + Examples: + >>> from roiloc.locator import RoiLocator + """ def __init__(self, contrast: str, @@ -17,6 +41,7 @@ def __init__(self, transform_type: str = "AffineFast", margin: list = [4, 4, 4], mask: Optional[ANTsImage] = None): + self.transform_type = transform_type self.margin = margin self.mask = mask @@ -31,9 +56,19 @@ def __init__(self, self.coords = {} def get_coords(self) -> dict: + """Get the coordinates of the ROI. + + Returns: + dict: Dictionary of coordinates for each side of the ROI. + """ return self.coords def fit(self, image: ANTsImage): + """Fit the ROI to the image and set coords. + + Args: + image (ANTsImage): Image to fit the ROI to. + """ self._image = image registration = ants.registration(fixed=image, @@ -59,16 +94,40 @@ def fit(self, image: ANTsImage): self.coords[side] = coords def transform(self, image: ANTsImage) -> list: + """Crop the image to the ROI. + + Args: + image (ANTsImage): Image to transform. + + Returns: + list: List of transformed images. 
+ """ return [ crop(image, self.coords[side], log_coords=False, ri=True) for side in ["right", "left"] ] def fit_transform(self, image: ANTsImage) -> list: + """Fit the ROI to the image and transform. + + Args: + image (ANTsImage): Image to fit the ROI to. + + Returns: + list: List of transformed images. + """ self.fit(image) return self.transform(image) def inverse_transform(self, image: ANTsImage) -> ANTsImage: + """Inverse transform the image to the native space. + + Args: + image (ANTsImage): Image to inverse transform. + + Returns: + ANTsImage: Inverse transformed image. + """ return ants.decrop_image(image, self._image)