Skip to content

Commit

Permalink
Added support for multiple ROIs
Browse files Browse the repository at this point in the history
  • Loading branch information
clementpoiret committed Jul 10, 2021
1 parent f94b898 commit da12564
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 54 deletions.
46 changes: 27 additions & 19 deletions roiloc.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
from rich.progress import track

from roiloc.location import crop, get_coords
from roiloc.registration import register
from roiloc.template import get_mni, get_roi, get_roi_indices
from roiloc.registration import get_roi, register
from roiloc.template import get_mni, get_roi_indices

console = Console()

Expand All @@ -28,7 +28,8 @@ def main(args):
path = Path(args.path).expanduser()

# Getting roi from cerebra's csv
roi_idx = get_roi_indices(args.roi.title())
rois_idx = {roi: get_roi_indices(roi) for roi in args.roi}
# roi_idx = get_roi_indices(args.roi)

# Loading mris, template and atlas
images = list(path.glob(args.inputpattern))
Expand All @@ -55,23 +56,29 @@ def main(args):
path=image_path.parent,
mask=args.mask)

print("Transforming and saving rois...")
for i, side in enumerate(["right", "left"]):
region = get_roi(
image=image,
atlas=atlas,
idx=int(roi_idx[i]),
transform=registration["fwdtransforms"],
output_dir=str(image_path.parent),
output_file=
f"{stem}_{args.roi}_{side}_{args.transform}_mask.nii.gz",
save=True)
registered_atlas = ants.apply_transforms(
fixed=image,
moving=atlas,
transformlist=registration["fwdtransforms"],
interpolator="nearestNeighbor")

coords = get_coords(region.numpy(), margin=args.margin)
for roi in rois_idx:
print(f"Transforming and saving {roi}...")

crop(
image_path, coords, image_path.parent /
f"{stem}_{args.roi}_{side}_{args.transform}_crop.nii.gz")
for i, side in enumerate(["right", "left"]):
region = get_roi(
registered_atlas=registered_atlas,
idx=int(rois_idx[roi][i]),
output_dir=str(image_path.parent),
output_file=
f"{stem}_{args.roi}_{side}_{args.transform}_mask.nii.gz",
save=True)

coords = get_coords(region.numpy(), margin=args.margin)

crop(
image_path, coords, image_path.parent /
f"{stem}_{args.roi}_{side}_{args.transform}_crop.nii.gz")

print("[bold green]Done! :)")

Expand Down Expand Up @@ -99,10 +106,11 @@ def main(args):
parser.add_argument(
"-r",
"--roi",
nargs='+',
help=
"ROI included in CerebrA. See `roiloc/MNI/cerebra/CerebrA_LabelDetails.csv` for more details. Default: 'Hippocampus'.",
required=False,
default="Hippocampus",
default=["Hippocampus"],
type=str)

parser.add_argument(
Expand Down
30 changes: 30 additions & 0 deletions roiloc/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import ants
from ants.core.ants_image import ANTsImage
from rich import print


def register(fixed: ANTsImage,
Expand Down Expand Up @@ -35,3 +36,32 @@ def register(fixed: ANTsImage,
moving=moving,
type_of_transform=type_of_transform,
mask=mask)


def get_roi(registered_atlas: ANTsImage,
idx: int,
output_dir: str,
output_file: str,
save: bool = True) -> ANTsImage:
"""Get the registered ROI from CerebrA atlas, into a
subject's native space.
Args:
image (ANTsImage): Subject's MRI
atlas (ANTsImage): CerebrA Atlas
idx (int): Index of the ROI
transform (list): Transformation from MNI to Native space
output_dir (str): Where to save the ROIs
output_file (str): Name of the ROIs
save (bool, optional): Save or not the ROIs. Defaults to True.
Returns:
ANTsImage: ROI in native space
"""
roi = registered_atlas.copy()
roi[roi != idx] = 0

if save:
roi.to_file(f"{output_dir}/{output_file}")

return roi
36 changes: 1 addition & 35 deletions roiloc/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,41 +26,7 @@ def get_mni(contrast: str, bet: bool) -> ANTsImage:


def get_roi_indices(roi: str) -> list:
roi = roi.title()
cerebra = pd.read_csv("roiloc/MNI/cerebra/CerebrA_LabelDetails.csv",
index_col="Label Name")
return [cerebra.loc[roi, "RH Label"], cerebra.loc[roi, "LH Labels"]]


def get_roi(image: ANTsImage,
atlas: ANTsImage,
idx: int,
transform: list,
output_dir: str,
output_file: str,
save: bool = True) -> ANTsImage:
"""Get the registered ROI from CerebrA atlas, into a
subject's native space.
Args:
image (ANTsImage): Subject's MRI
atlas (ANTsImage): CerebrA Atlas
idx (int): Index of the ROI
transform (list): Transformation from MNI to Native space
output_dir (str): Where to save the ROIs
output_file (str): Name of the ROIs
save (bool, optional): Save or not the ROIs. Defaults to True.
Returns:
ANTsImage: ROI in native space
"""
hippocampus = ants.apply_transforms(fixed=image,
moving=atlas,
transformlist=transform,
interpolator="nearestNeighbor")

hippocampus[hippocampus != idx] = 0

if save:
hippocampus.to_file(f"{output_dir}/{output_file}")

return hippocampus

0 comments on commit da12564

Please sign in to comment.