Skip to content

Commit

Permalink
[FIX] Fix some issues around the usage of Tracer and `SUVRReference…
Browse files Browse the repository at this point in the history
…Region` (aramis-lab#1280)
  • Loading branch information
NicolasGensollen authored Sep 5, 2024
1 parent 1183c61 commit e562e13
Show file tree
Hide file tree
Showing 6 changed files with 69 additions and 55 deletions.
7 changes: 4 additions & 3 deletions clinica/pipelines/machine_learning/classification_cli.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from typing import Optional
from typing import Optional, Union

import click

from clinica import option
from clinica.pipelines import cli_param
from clinica.pipelines.engine import clinica_pipeline
from clinica.utils.atlas import T1AndPetVolumeAtlasName
from clinica.utils.pet import SUVRReferenceRegion, Tracer

pipeline_name = "machinelearning-classification"

Expand Down Expand Up @@ -62,8 +63,8 @@ def cli(
subjects_visits_tsv: str,
diagnoses_tsv: str,
output_directory: str,
acq_label: Optional[str] = None,
suvr_reference_region: Optional[str] = None,
acq_label: Optional[Union[str, Tracer]] = None,
suvr_reference_region: Optional[Union[str, SUVRReferenceRegion]] = None,
atlas: Optional[str] = None,
n_procs: Optional[int] = None,
) -> None:
Expand Down
12 changes: 6 additions & 6 deletions clinica/pipelines/machine_learning/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,9 +296,9 @@ def get_images(self):
"preprocessing",
f"group-{self._input_params['group_label']}",
"atlas_statistics",
f"{self._subjects[i]}_{self._sessions[i]}_trc-{self._input_params['acq_label']}_pet"
f"{self._subjects[i]}_{self._sessions[i]}_trc-{self._input_params['acq_label'].value}_pet"
f"_space-{self._input_params['atlas']}{pvc_key_value}"
f"_suvr-{self._input_params['suvr_reference_region']}_statistics.tsv",
f"_suvr-{self._input_params['suvr_reference_region'].value}_statistics.tsv",
)
for i in range(len(self._subjects))
]
Expand Down Expand Up @@ -383,8 +383,8 @@ def get_images(self):
self._sessions[i],
"pet",
"surface",
f"{self._subjects[i]}_{self._sessions[i]}_trc-{self._input_params['acq_label']}_pet"
f"_space-fsaverage_suvr-{self._input_params['suvr_reference_region']}"
f"{self._subjects[i]}_{self._sessions[i]}_trc-{self._input_params['acq_label'].value}_pet"
f"_space-fsaverage_suvr-{self._input_params['suvr_reference_region'].value}"
f"_pvc-iy_hemi-{h}_fwhm-{self._input_params['fwhm']}_projection.mgh",
)
for h in hemi
Expand Down Expand Up @@ -579,8 +579,8 @@ def get_images(self):
"pet",
"preprocessing",
f"group-{self._input_params['group_label']}",
f"{self._subjects[i]}_{self._sessions[i]}_trc-{self._input_params['acq_label']}_pet"
f"_space-Ixi549Space{pvc_key_value}_suvr-{self._input_params['suvr_reference_region']}"
f"{self._subjects[i]}_{self._sessions[i]}_trc-{self._input_params['acq_label'].value}_pet"
f"_space-Ixi549Space{pvc_key_value}_suvr-{self._input_params['suvr_reference_region'].value}"
f"_mask-brain{fwhm_key_value}_pet.nii.gz",
)
for i in range(len(self._subjects))
Expand Down
57 changes: 30 additions & 27 deletions clinica/pipelines/machine_learning/ml_workflows.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from typing import Optional, Union

import numpy as np

from clinica.pipelines.machine_learning import algorithm, base, input, validation
from clinica.utils.pet import SUVRReferenceRegion, Tracer


class VoxelBasedKFoldDualSVM(base.MLWorkflow):
Expand All @@ -14,11 +17,11 @@ def __init__(
output_dir,
fwhm=0,
modulated="on",
acq_label=None,
suvr_reference_region=None,
acq_label: Optional[Union[str, Tracer]] = None,
suvr_reference_region: Optional[Union[str, SUVRReferenceRegion]] = None,
use_pvc_data=False,
precomputed_kernel=None,
mask_zeros=True,
mask_zeros: bool = True,
n_threads=15,
n_folds=10,
grid_search_folds=10,
Expand Down Expand Up @@ -46,8 +49,8 @@ def __init__(
output_dir,
fwhm=0,
modulated="on",
acq_label=None,
suvr_reference_region=None,
acq_label: Optional[Union[str, Tracer]] = None,
suvr_reference_region: Optional[Union[str, SUVRReferenceRegion]] = None,
use_pvc_data=False,
precomputed_kernel=None,
mask_zeros=True,
Expand Down Expand Up @@ -79,8 +82,8 @@ def __init__(
output_dir,
fwhm=0,
modulated="on",
acq_label=None,
suvr_reference_region=None,
acq_label: Optional[Union[str, Tracer]] = None,
suvr_reference_region: Optional[Union[str, SUVRReferenceRegion]] = None,
use_pvc_data=False,
precomputed_kernel=None,
mask_zeros=True,
Expand Down Expand Up @@ -110,8 +113,8 @@ def __init__(
group_label,
output_dir,
image_type="PET",
acq_label=None,
suvr_reference_region=None,
acq_label: Optional[Union[str, Tracer]] = None,
suvr_reference_region: Optional[Union[str, SUVRReferenceRegion]] = None,
fwhm=20,
precomputed_kernel=None,
n_threads=15,
Expand Down Expand Up @@ -141,8 +144,8 @@ def __init__(
image_type,
atlas,
output_dir,
acq_label=None,
suvr_reference_region=None,
acq_label: Optional[Union[str, Tracer]] = None,
suvr_reference_region: Optional[Union[str, SUVRReferenceRegion]] = None,
use_pvc_data=False,
n_threads=15,
n_iterations=100,
Expand Down Expand Up @@ -171,8 +174,8 @@ def __init__(
image_type,
atlas,
output_dir,
acq_label=None,
suvr_reference_region=None,
acq_label: Optional[Union[str, Tracer]] = None,
suvr_reference_region: Optional[Union[str, SUVRReferenceRegion]] = None,
use_pvc_data=False,
n_threads=15,
n_iterations=100,
Expand Down Expand Up @@ -201,8 +204,8 @@ def __init__(
image_type,
atlas,
output_dir,
acq_label=None,
suvr_reference_region=None,
acq_label: Optional[Union[str, Tracer]] = None,
suvr_reference_region: Optional[Union[str, SUVRReferenceRegion]] = None,
use_pvc_data=False,
n_threads=15,
n_iterations=100,
Expand Down Expand Up @@ -234,8 +237,8 @@ def __init__(
image_type,
atlas,
output_dir,
acq_label=None,
suvr_reference_region=None,
acq_label: Optional[Union[str, Tracer]] = None,
suvr_reference_region: Optional[Union[str, SUVRReferenceRegion]] = None,
use_pvc_data=False,
precomputed_kernel=None,
n_threads=15,
Expand Down Expand Up @@ -266,8 +269,8 @@ def __init__(
output_dir,
fwhm=0,
modulated="on",
acq_label=None,
suvr_reference_region=None,
acq_label: Optional[Union[str, Tracer]] = None,
suvr_reference_region: Optional[Union[str, SUVRReferenceRegion]] = None,
use_pvc_data=False,
precomputed_kernel=None,
mask_zeros=True,
Expand Down Expand Up @@ -298,8 +301,8 @@ def __init__(
image_type,
atlas,
output_dir,
acq_label=None,
suvr_reference_region=None,
acq_label: Optional[Union[str, Tracer]] = None,
suvr_reference_region: Optional[Union[str, SUVRReferenceRegion]] = None,
use_pvc_data=False,
n_threads=15,
n_iterations=100,
Expand Down Expand Up @@ -330,8 +333,8 @@ def __init__(
atlas,
dataset,
output_dir,
acq_label=None,
suvr_reference_region=None,
acq_label: Optional[Union[str, Tracer]] = None,
suvr_reference_region: Optional[Union[str, SUVRReferenceRegion]] = None,
use_pvc_data=False,
n_threads=15,
n_iterations=100,
Expand Down Expand Up @@ -361,8 +364,8 @@ def __init__(
atlas,
dataset,
output_dir,
acq_label=None,
suvr_reference_region=None,
acq_label: Optional[Union[str, Tracer]] = None,
suvr_reference_region: Optional[Union[str, SUVRReferenceRegion]] = None,
use_pvc_data=False,
n_threads=15,
n_iterations=100,
Expand Down Expand Up @@ -398,8 +401,8 @@ def __init__(
output_dir,
fwhm=0,
modulated="on",
acq_label=None,
suvr_reference_region=None,
acq_label: Optional[Union[str, Tracer]] = None,
suvr_reference_region: Optional[Union[str, SUVRReferenceRegion]] = None,
use_pvc_data=False,
precomputed_kernel=None,
mask_zeros=True,
Expand Down
35 changes: 20 additions & 15 deletions clinica/pipelines/pet/engine.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,31 @@
from clinica.pipelines.engine import Pipeline
from clinica.utils.pet import ReconstructionMethod, Tracer
from clinica.utils.stream import log_and_raise


class PETPipeline(Pipeline):
def _check_pipeline_parameters(self) -> None:
"""Check pipeline parameters."""
if "acq_label" not in self.parameters.keys():
raise KeyError("Missing compulsory acq_label key in pipeline parameter.")
self.parameters.setdefault("reconstruction_method", None)
from clinica.utils.exceptions import ClinicaPipelineConfigurationError

if "acq_label" not in self.parameters:
log_and_raise(
"Missing compulsory 'acq_label' key in pipeline parameter.",
ClinicaPipelineConfigurationError,
)
self.parameters["acq_label"] = Tracer(self.parameters["acq_label"])
if "reconstruction_method" in self.parameters:
if self.parameters["reconstruction_method"]:
self.parameters["reconstruction_method"] = ReconstructionMethod(
self.parameters["reconstruction_method"]
)
else:
self.parameters["reconstruction_method"] = None

def _get_pet_scans_query(self) -> dict:
"""Return the query to retrieve PET scans."""
from clinica.utils.input_files import bids_pet_nii
from clinica.utils.pet import ReconstructionMethod, Tracer

pet_tracer = None
if self.parameters["acq_label"] is not None:
pet_tracer = Tracer(self.parameters["acq_label"])

reconstruction_method = None
if self.parameters["reconstruction_method"] is not None:
reconstruction_method = ReconstructionMethod(
self.parameters["reconstruction_method"]
)

return bids_pet_nii(pet_tracer, reconstruction_method)
return bids_pet_nii(
self.parameters["acq_label"], self.parameters["reconstruction_method"]
)
9 changes: 5 additions & 4 deletions clinica/pipelines/pet/volume/cli.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from typing import List, Optional
from typing import List, Optional, Union

import click

from clinica import option
from clinica.pipelines import cli_param
from clinica.pipelines.engine import clinica_pipeline
from clinica.utils.pet import ReconstructionMethod, SUVRReferenceRegion, Tracer

pipeline_name = "pet-volume"

Expand Down Expand Up @@ -62,9 +63,9 @@ def cli(
bids_directory: str,
caps_directory: str,
group_label: str,
acq_label: str,
suvr_reference_region: Optional[str] = None,
reconstruction_method: Optional[str] = None,
acq_label: Union[str, Tracer],
suvr_reference_region: Optional[Union[str, SUVRReferenceRegion]] = None,
reconstruction_method: Optional[Union[str, ReconstructionMethod]] = None,
pvc_psf_tsv: Optional[str] = None,
mask_tissues: List[int] = (1, 2, 3),
mask_threshold: float = 0.3,
Expand Down
4 changes: 4 additions & 0 deletions clinica/utils/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ class ClinicaXMLParserError(ClinicaParserError):
"""Base class for XML parser errors."""


class ClinicaPipelineConfigurationError(ClinicaException):
"""Base class for configuration errors of clinica pipelines."""


class ClinicaInconsistentDatasetError(ClinicaException):
"""Base class for inconsistent datasets errors."""

Expand Down

0 comments on commit e562e13

Please sign in to comment.