|
1 | 1 | # TODO add typing and docs |
2 | | -from abc import abstractmethod |
| 2 | +import shutil |
| 3 | +from abc import ABC, abstractmethod |
3 | 4 | from pathlib import Path |
4 | | -from shutil import copyfile |
| 5 | +from typing import Optional, Union |
| 6 | +from enum import Enum |
5 | 7 |
|
6 | 8 | from auxiliary.nifti.io import read_nifti, write_nifti |
7 | | -from auxiliary.turbopath import name_extractor |
8 | 9 | from brainles_hd_bet import run_hd_bet |
9 | 10 |
|
10 | 11 |
|
| 12 | +class Mode(Enum): |
| 13 | + FAST = "fast" |
| 14 | + ACCURATE = "accurate" |
| 15 | + |
| 16 | + |
11 | 17 | class BrainExtractor: |
12 | 18 | @abstractmethod |
13 | 19 | def extract( |
14 | 20 | self, |
15 | | - input_image_path: str, |
16 | | - masked_image_path: str, |
17 | | - brain_mask_path: str, |
18 | | - log_file_path: str, |
19 | | - # TODO convert mode to enum |
20 | | - mode: str, |
| 21 | + input_image_path: Union[str, Path], |
| 22 | + masked_image_path: Union[str, Path], |
| 23 | + brain_mask_path: Union[str, Path], |
| 24 | + log_file_path: Optional[Union[str, Path]], |
| 25 | + mode: Union[str, Mode], |
| 26 | + **kwargs, |
21 | 27 | ) -> None: |
| 28 | + """ |
| 29 | + Abstract method to extract the brain from an input image. |
| 30 | +
|
| 31 | + Args: |
| 32 | + input_image_path (str or Path): Path to the input image. |
| 33 | + masked_image_path (str or Path): Path where the brain-extracted image will be saved. |
| 34 | + brain_mask_path (str or Path): Path where the brain mask will be saved. |
| 35 | + log_file_path (str or Path, Optional): Path to the log file. |
| 36 | + mode (str or Mode): Extraction mode. |
| 37 | + **kwargs: Additional keyword arguments. |
| 38 | + """ |
22 | 39 | pass |
23 | 40 |
|
24 | 41 | def apply_mask( |
25 | 42 | self, |
26 | | - input_image_path: Path, |
27 | | - mask_path: Path, |
28 | | - bet_image_path: Path, |
| 43 | + input_image_path: Union[str, Path], |
| 44 | + mask_path: Union[str, Path], |
| 45 | + bet_image_path: Union[str, Path], |
29 | 46 | ) -> None: |
30 | 47 | """ |
31 | 48 | Apply a brain mask to an input image. |
32 | 49 |
|
33 | 50 | Args: |
34 | | - input_image_path (str): Path to the input image (NIfTI format). |
35 | | - mask_path (str): Path to the brain mask image (NIfTI format). |
36 | | - bet_image_path (str): Path to save the resulting masked image (NIfTI format). |
| 51 | + input_image_path (str or Path): Path to the input image (NIfTI format). |
| 52 | + mask_path (str or Path): Path to the brain mask image (NIfTI format). |
| 53 | + bet_image_path (str or Path): Path to save the resulting masked image (NIfTI format). |
37 | 54 | """ |
38 | 55 |
|
39 | | - # read data |
40 | | - input_data = read_nifti(input_image_path) |
41 | | - mask_data = read_nifti(mask_path) |
| 56 | + try: |
| 57 | + # Read data |
| 58 | + input_data = read_nifti(str(input_image_path)) |
| 59 | + mask_data = read_nifti(str(mask_path)) |
| 60 | + except FileNotFoundError as e: |
| 61 | + raise FileNotFoundError(f"File not found: {e.filename}") from e |
| 62 | + except Exception as e: |
| 63 | + raise RuntimeError(f"Error reading files: {e}") from e |
| 64 | + |
| 65 | + # Check that the input and mask have the same shape |
| 66 | + if input_data.shape != mask_data.shape: |
| 67 | + raise ValueError("Input image and mask must have the same dimensions.") |
42 | 68 |
|
43 | | - # mask and save it |
| 69 | + # Mask and save it |
44 | 70 | masked_data = input_data * mask_data |
45 | 71 |
|
46 | | - write_nifti( |
47 | | - input_array=masked_data, |
48 | | - output_nifti_path=bet_image_path, |
49 | | - reference_nifti_path=input_image_path, |
50 | | - create_parent_directory=True, |
51 | | - ) |
| 72 | + try: |
| 73 | + write_nifti( |
| 74 | + input_array=masked_data, |
| 75 | + output_nifti_path=str(bet_image_path), |
| 76 | + reference_nifti_path=str(input_image_path), |
| 77 | + create_parent_directory=True, |
| 78 | + ) |
| 79 | + except Exception as e: |
| 80 | + raise RuntimeError(f"Error writing output file: {e}") from e |
52 | 81 |
|
53 | 82 |
|
54 | 83 | class HDBetExtractor(BrainExtractor): |
55 | 84 | def extract( |
56 | 85 | self, |
57 | | - input_image_path: str, |
58 | | - masked_image_path: str, |
59 | | - brain_mask_path: str, |
60 | | - log_file_path: str = None, |
| 86 | + input_image_path: Union[str, Path], |
| 87 | + masked_image_path: Union[str, Path], |
| 88 | + brain_mask_path: Union[str, Path], |
| 89 | + log_file_path: Optional[Union[str, Path]] = None, |
61 | 90 | # TODO convert mode to enum |
62 | | - mode: str = "accurate", |
63 | | - device: int | str = 0, |
64 | | - do_tta: bool = True, |
| 91 | + mode: Union[str, Mode] = Mode.ACCURATE, |
| 92 | + device: Optional[Union[int, str]] = 0, |
| 93 | + do_tta: Optional[bool] = True, |
65 | 94 | ) -> None: |
66 | 95 | # GPU + accurate + TTA |
67 | | - """skullstrips images with HD-BET generates a skullstripped file and mask""" |
| 96 | + """ |
| 97 | + Skull-strips images with HD-BET and generates a skull-stripped file and mask. |
| 98 | +
|
| 99 | + Args: |
| 100 | + input_image_path (str or Path): Path to the input image. |
| 101 | + masked_image_path (str or Path): Path where the brain-extracted image will be saved. |
| 102 | + brain_mask_path (str or Path): Path where the brain mask will be saved. |
| 103 | + log_file_path (str or Path, Optional): Path to the log file. |
| 104 | + mode (str or Mode): Extraction mode ('fast' or 'accurate'). |
| 105 | + device (str or int): Device to use for computation (e.g., 0 for GPU 0, 'cpu' for CPU). |
| 106 | + do_tta (bool): whether to do test time data augmentation by mirroring along all axes. |
| 107 | + """ |
| 108 | + |
| 109 | + # Ensure mode is a Mode enum instance |
| 110 | + if isinstance(mode, str): |
| 111 | + try: |
| 112 | + mode_enum = Mode(mode.lower()) |
| 113 | + except ValueError: |
| 114 | + raise ValueError(f"'{mode}' is not a valid Mode.") |
| 115 | + elif isinstance(mode, Mode): |
| 116 | + mode_enum = mode |
| 117 | + else: |
| 118 | + raise TypeError("Mode must be a string or a Mode enum instance.") |
| 119 | + |
| 120 | + # Run HD-BET |
68 | 121 | run_hd_bet( |
69 | | - mri_fnames=[input_image_path], |
70 | | - output_fnames=[masked_image_path], |
71 | | - # device=0, |
72 | | - # TODO consider postprocessing |
73 | | - # postprocess=False, |
74 | | - mode=mode, |
| 122 | + mri_fnames=[str(input_image_path)], |
| 123 | + output_fnames=[str(masked_image_path)], |
| 124 | + mode=mode_enum.value, |
75 | 125 | device=device, |
| 126 | + # TODO consider postprocessing |
76 | 127 | postprocess=False, |
77 | 128 | do_tta=do_tta, |
78 | 129 | keep_mask=True, |
79 | 130 | overwrite=True, |
80 | 131 | ) |
81 | 132 |
|
82 | | - hdbet_mask_path = ( |
83 | | - Path(masked_image_path).parent |
84 | | - / f"{name_extractor(masked_image_path)}_mask.nii.gz" |
| 133 | + # Construct the path to the generated mask |
| 134 | + masked_image_path = Path(masked_image_path) |
| 135 | + hdbet_mask_path = masked_image_path.with_name( |
| 136 | + masked_image_path.name.replace(".nii.gz", "_mask.nii.gz") |
85 | 137 | ) |
| 138 | + |
86 | 139 | if hdbet_mask_path.resolve() != Path(brain_mask_path).resolve(): |
87 | | - copyfile( |
88 | | - src=hdbet_mask_path, |
89 | | - dst=brain_mask_path, |
90 | | - ) |
| 140 | + try: |
| 141 | + shutil.copyfile( |
| 142 | + src=str(hdbet_mask_path), |
| 143 | + dst=str(brain_mask_path), |
| 144 | + ) |
| 145 | + except Exception as e: |
| 146 | + raise RuntimeError(f"Error copying mask file: {e}") from e |
0 commit comments