Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve cluster script & HCP pipeline #187

Merged
merged 22 commits into from
Mar 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 29 additions & 9 deletions PUMI/pipelines/func/deconfound.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,32 @@ def fieldmap_correction_qc(wf, volume='middle', **kwargs):

"""

def create_montage(vol_1, vol_2, vol_corrected):
def get_cut_cords(func, n_slices=10):
import nibabel as nib
import numpy as np

func_img = nib.load(func)
y_dim = func_img.shape[1] # y-dimension (coronal direction) is the second dimension in the image shape

slices = np.linspace(-y_dim / 2, y_dim / 2, n_slices)
# slices might contain floats but this is not a problem since nilearn will round floats to the
# nearest integer value!
return slices

def create_montage(vol_1, vol_2, vol_corrected, n_slices=10):
from matplotlib import pyplot as plt
from pathlib import Path
from nilearn import plotting
import os

fig, axes = plt.subplots(3, 1, facecolor='black', figsize=(10, 15))

plotting.plot_anat(vol_1, display_mode='ortho', title='Image #1', black_bg=True, axes=axes[0])
plotting.plot_anat(vol_2, display_mode='ortho', title='Image #2', black_bg=True, axes=axes[1])
plotting.plot_anat(vol_corrected, display_mode='ortho', title='Corrected', black_bg=True, axes=axes[2])
plotting.plot_anat(vol_1, display_mode='y', cut_coords=get_cut_cords(vol_1, n_slices=n_slices),
title='Image #1', black_bg=True, axes=axes[0])
plotting.plot_anat(vol_2, display_mode='y', cut_coords=get_cut_cords(vol_2, n_slices=n_slices),
title='Image #2', black_bg=True, axes=axes[1])
plotting.plot_anat(vol_corrected, display_mode='y', cut_coords=get_cut_cords(vol_corrected, n_slices=n_slices),
title='Corrected', black_bg=True, axes=axes[2])

path = str(Path(os.getcwd() + '/fieldmap_correction_comparison.png'))
plt.savefig(path)
Expand All @@ -66,19 +81,21 @@ def create_montage(vol_1, vol_2, vol_corrected):
wf.connect(vol_corrected, 'out_file', montage, 'vol_corrected')

wf.connect(montage, 'out_file', 'outputspec', 'out_file')
wf.connect(montage, 'out_file', 'sinker', 'out_file')
wf.connect(montage, 'out_file', 'sinker', 'qc_fieldmap_correction')


@FuncPipeline(inputspec_fields=['func_1', 'func_2'],
outputspec_fields=['out_file'])
def fieldmap_correction(wf, encoding_direction=['y-', 'y'], readout_times=[0.08264, 0.08264], tr=0.72, **kwargs):
def fieldmap_correction(wf, encoding_direction=['x-', 'x'], trt=[0.0522, 0.0522], tr=0.72, **kwargs):
"""

Fieldmap correction pipeline.

Parameters:
encoding_direction (list): List of encoding directions (default is left-right and right-left phase encoding).
readout_times (list): List of readout times (default adapted to rsfMRI data of the HCP WU 1200 dataset).
trt (list): List of total readout times (default adapted to rsfMRI data of the HCP WU 1200 dataset).
Default is:
1*(10**(-3))*EchoSpacingMS*EpiFactor = 1*(10**(-3))*0.58*90 = 0.0522 (for LR and RL image)
tr (float): Repetition time (default adapted to rsfMRI data of the HCP WU 1200 dataset).

Inputs:
Expand All @@ -91,8 +108,11 @@ def fieldmap_correction(wf, encoding_direction=['y-', 'y'], readout_times=[0.082
Sinking:
- 4d distortion corrected image.

For more information regarding the parameters:

For more information:
https://fsl.fmrib.ox.ac.uk/fsl/fslwiki/topup/ExampleTopupFollowedByApplytopup
https://fsl.fmrib.ox.ac.uk/fsl/fslwiki/topup/Faq#How_do_I_know_what_phase-encode_vectors_to_put_into_my_--datain_text_file.3F
https://www.humanconnectome.org/storage/app/media/documentation/s1200/HCP_S1200_Release_Appendix_I.pdf

"""

Expand Down Expand Up @@ -126,7 +146,7 @@ def fieldmap_correction(wf, encoding_direction=['y-', 'y'], readout_times=[0.082
# Estimate susceptibility induced distortions
topup = Node(fsl.TOPUP(), name='topup')
topup.inputs.encoding_direction = encoding_direction
topup.inputs.readout_times = readout_times
topup.inputs.readout_times = trt
wf.connect(merger, 'merged_file', topup, 'in_file')

# The two original 4D files are also needed inside a list
Expand Down
286 changes: 117 additions & 169 deletions pipelines/hcp_rcpl.py → pipelines/hcp.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,20 @@
#!/usr/bin/env python3

import argparse
import glob
from PUMI import globals
from nipype import IdentityInterface, DataGrabber
from nipype.interfaces.fsl import Reorient2Std
from nipype.interfaces import afni
from PUMI.engine import BidsPipeline, NestedNode as Node, FuncPipeline, GroupPipeline, BidsApp
from PUMI.pipelines.anat.anat_proc import anat_proc
from PUMI.pipelines.func.compcor import anat_noise_roi
from PUMI.pipelines.func.compcor import anat_noise_roi, compcor
from PUMI.pipelines.anat.func_to_anat import func2anat
from PUMI.pipelines.func.deconfound import fieldmap_correction
from nipype.interfaces import utility

from PUMI.pipelines.func.deconfound import fieldmap_correction
from PUMI.pipelines.func.func_proc import func_proc_despike_afni
from PUMI.pipelines.func.timeseries_extractor import extract_timeseries_nativespace
from PUMI.pipelines.func.timeseries_extractor import pick_atlas, extract_timeseries_nativespace
from PUMI.utils import mist_modules, mist_labels, get_reference
from PUMI.pipelines.func.func2standard import func2standard
from PUMI.engine import NestedWorkflow as Workflow

from pathlib import Path
from PUMI.pipelines.multimodal.image_manipulation import pick_volume
from PUMI.engine import save_software_versions
import traits
import os

Expand Down Expand Up @@ -369,169 +365,121 @@ def merge_predictions(rpn_out_file, rcpl_out_file):
wf.connect(merge_predictions_wf, 'out_file', 'sinker', 'pain_predictions')


parser = argparse.ArgumentParser()
@BidsPipeline(output_query={
'T1w': dict(
datatype='anat',
suffix='T1w',
extension=['nii', 'nii.gz']
),
'bold_lr': dict(
datatype='func',
suffix='bold',
acquisition='LR',
extension=['nii', 'nii.gz']
),
'bold_rl': dict(
datatype='func',
suffix='bold',
acquisition='RL',
extension=['nii', 'nii.gz']
)
})
def hcp(wf, bbr=True, **kwargs):
"""
The HCP pipeline is the RCPL pipeline but with different inputs (two bold images with different phase encodings
instead of one bold image) and with additional fieldmap correction.

parser.add_argument(
'--bids_dir',
required=True,
help='Root directory of the input dataset.'
)
CAUTION: This pipeline assumes that you converted the HCP dataset into the BIDS format!
"""

parser.add_argument(
'--output_dir',
required=True,
help='Directory where the results will be stored.'
print('* bbr:', bbr)

reorient_struct_wf = Node(Reorient2Std(output_type='NIFTI_GZ'), name="reorient_struct_wf")
wf.connect('inputspec', 'T1w', reorient_struct_wf, 'in_file')

reorient_func_lr_wf = Node(Reorient2Std(output_type='NIFTI_GZ'), name="reorient_func_lr_wf")
wf.connect('inputspec', 'bold_lr', reorient_func_lr_wf, 'in_file')

reorient_func_rl_wf = Node(Reorient2Std(output_type='NIFTI_GZ'), name="reorient_func_rl_wf")
wf.connect('inputspec', 'bold_rl', reorient_func_rl_wf, 'in_file')

fieldmap_corr = fieldmap_correction('fieldmap_corr')
wf.connect(reorient_func_lr_wf, 'out_file', fieldmap_corr, 'func_1')
wf.connect(reorient_func_rl_wf, 'out_file', fieldmap_corr, 'func_2')

anatomical_preprocessing_wf = anat_proc(name='anatomical_preprocessing_wf', bet_tool='deepbet')
wf.connect(reorient_struct_wf, 'out_file', anatomical_preprocessing_wf, 'in_file')

func2anat_wf = func2anat(name='func2anat_wf', bbr=bbr)
wf.connect(fieldmap_corr, 'out_file', func2anat_wf, 'func')
wf.connect(anatomical_preprocessing_wf, 'brain', func2anat_wf, 'head')
wf.connect(anatomical_preprocessing_wf, 'probmap_wm', func2anat_wf, 'anat_wm_segmentation')
wf.connect(anatomical_preprocessing_wf, 'probmap_csf', func2anat_wf, 'anat_csf_segmentation')
wf.connect(anatomical_preprocessing_wf, 'probmap_gm', func2anat_wf, 'anat_gm_segmentation')
wf.connect(anatomical_preprocessing_wf, 'probmap_ventricle', func2anat_wf, 'anat_ventricle_segmentation')

compcor_roi_wf = anat_noise_roi('compcor_roi_wf')
wf.connect(func2anat_wf, 'wm_mask_in_funcspace', compcor_roi_wf, 'wm_mask')
wf.connect(func2anat_wf, 'ventricle_mask_in_funcspace', compcor_roi_wf, 'ventricle_mask')

func_proc_wf = func_proc_despike_afni('func_proc_wf', bet_tool='deepbet', deepbet_n_dilate=2)
wf.connect(fieldmap_corr, 'out_file', func_proc_wf, 'func')
wf.connect(compcor_roi_wf, 'out_file', func_proc_wf, 'cc_noise_roi')

pick_atlas_wf = mist_atlas('pick_atlas_wf')
mist_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../data_in/atlas/MIST"))
pick_atlas_wf.get_node('inputspec').inputs.labelmap = os.path.join(mist_dir, 'Parcellations/MIST_122.nii.gz')
pick_atlas_wf.get_node('inputspec').inputs.modules = mist_modules(mist_directory=mist_dir, resolution="122")
pick_atlas_wf.get_node('inputspec').inputs.labels = mist_labels(mist_directory=mist_dir, resolution="122")

extract_timeseries = extract_timeseries_nativespace('extract_timeseries')
wf.connect(pick_atlas_wf, 'relabeled_atlas', extract_timeseries, 'atlas')
wf.connect(pick_atlas_wf, 'reordered_labels', extract_timeseries, 'labels')
wf.connect(pick_atlas_wf, 'reordered_modules', extract_timeseries, 'modules')
wf.connect(anatomical_preprocessing_wf, 'brain', extract_timeseries, 'anat')
wf.connect(func2anat_wf, 'anat_to_func_linear_xfm', extract_timeseries, 'inv_linear_reg_mtrx')
wf.connect(anatomical_preprocessing_wf, 'mni2anat_warpfield', extract_timeseries, 'inv_nonlinear_reg_mtrx')
wf.connect(func2anat_wf, 'gm_mask_in_funcspace', extract_timeseries, 'gm_mask')
wf.connect(func_proc_wf, 'func_preprocessed', extract_timeseries, 'func')
wf.connect(func_proc_wf, 'FD', extract_timeseries, 'confounds')

func2std = func2standard('func2std')
wf.connect(anatomical_preprocessing_wf, 'brain', func2std, 'anat')
wf.connect(func2anat_wf, 'func_to_anat_linear_xfm', func2std, 'linear_reg_mtrx')
wf.connect(anatomical_preprocessing_wf, 'anat2mni_warpfield', func2std, 'nonlinear_reg_mtrx')
wf.connect(anatomical_preprocessing_wf, 'std_template', func2std, 'reference_brain')
wf.connect(func_proc_wf, 'func_preprocessed', func2std, 'func')
wf.connect(func_proc_wf, 'mc_ref_vol', func2std, 'bbr2ants_source_file')

calculate_connectivity_wf = calculate_connectivity('calculate_connectivity_wf')
wf.connect(extract_timeseries, 'timeseries', calculate_connectivity_wf, 'ts_files')
wf.connect(func_proc_wf, 'FD', calculate_connectivity_wf, 'fd_files')

predict_pain_sensitivity_rpn_wf = predict_pain_sensitivity_rpn('predict_pain_sensitivity_rpn_wf')
wf.connect(calculate_connectivity_wf, 'features', predict_pain_sensitivity_rpn_wf, 'X')
wf.connect(fieldmap_corr, 'out_file', predict_pain_sensitivity_rpn_wf, 'in_file')

predict_pain_sensitivity_rcpl_wf = predict_pain_sensitivity_rcpl('predict_pain_sensitivity_rcpl_wf')
wf.connect(calculate_connectivity_wf, 'features', predict_pain_sensitivity_rcpl_wf, 'X')
wf.connect(fieldmap_corr, 'out_file', predict_pain_sensitivity_rcpl_wf, 'in_file')

collect_pain_predictions_wf = collect_pain_predictions('collect_pain_predictions_wf')
wf.connect(predict_pain_sensitivity_rpn_wf, 'out_file', collect_pain_predictions_wf, 'rpn_out_file')
wf.connect(predict_pain_sensitivity_rcpl_wf, 'out_file', collect_pain_predictions_wf, 'rcpl_out_file')

wf.write_graph('HCP-pipeline.png')
save_software_versions(wf)


hcp_app = BidsApp(
pipeline=hcp,
name='hcp'
)

parser.add_argument(
hcp_app.parser.add_argument(
'--bbr',
default='yes',
type=lambda x: (str(x).lower() == ['true', '1', 'yes']),
help='Use BBR registration: yes/no (default: yes)'
type=lambda x: (str(x).lower() in ['true', '1', 'yes']),
help="Use BBR registration: yes/no (default: yes)"
)

parser.add_argument('--n_procs', type=int,
help='Amount of threads to execute in parallel.'
+ 'If not set, the amount of CPU cores is used.'
+ 'Caution: Does only work with the MultiProc-plugin!')

parser.add_argument('--memory_gb', type=int,
help='Memory limit in GB. If not set, use 90% of the available memory'
+ 'Caution: Does only work with the MultiProc-plugin!')


cli_args = parser.parse_args()

input_dir = cli_args.bids_dir
output_dir = cli_args.output_dir
bbr = cli_args.bbr

plugin_args = {}
if cli_args.n_procs is not None:
plugin_args['n_procs'] = cli_args.n_procs

if cli_args.memory_gb is not None:
plugin_args['memory_gb'] = cli_args.memory_gb

subjects = []
excluded = []
for path in glob.glob(str(input_dir) + '/*'):
id = path.split('/')[-1]

base = path + '/unprocessed/3T/'

t1w_base = str(Path(base + '/T1w_MPR1/' + id + '_3T_T1w_MPR1.nii'))
has_t1w = os.path.isfile(t1w_base) or os.path.isfile(t1w_base + '.gz')

lr_base = str(Path(base + '/rfMRI_REST1_LR/' + id + '_3T_rfMRI_REST1_LR.nii'))
has_lr = os.path.isfile(lr_base) or os.path.isfile(lr_base + '.gz')

rl_base = str(Path(base + '/rfMRI_REST1_RL/' + id + '_3T_rfMRI_REST1_RL.nii'))
has_rl = os.path.isfile(rl_base) or os.path.isfile(rl_base + '.gz')

if has_t1w and has_lr and has_rl:
subjects.append(id)
else:
excluded.append(id)

print('-' * 100)
print(f'Included %d subjects.' % len(subjects))
print(f'Excluded %d subjects.' % len(excluded))
print('-' * 100)


wf = Workflow(name='HCP-RCPL')
wf.base_dir = '.'
globals.cfg_parser.set('SINKING', 'sink_dir', str(Path(os.path.abspath(output_dir + '/derivatives'))))
globals.cfg_parser.set('SINKING', 'qc_dir', str(Path(os.path.abspath(output_dir + '/derivatives/qc'))))


# Create a subroutine (subgraph) for every subject
inputspec = Node(interface=IdentityInterface(fields=['subject']), name='inputspec')
inputspec.iterables = [('subject', subjects)]

T1w_grabber = Node(DataGrabber(infields=['subject'], outfields=['T1w']), name='T1w_grabber')
T1w_grabber.inputs.base_directory = os.path.abspath(input_dir)
T1w_grabber.inputs.template = '%s/unprocessed/3T/T1w_MPR1/*T1w_MPR1.nii*'
T1w_grabber.inputs.sort_filelist = True
wf.connect(inputspec, 'subject', T1w_grabber, 'subject')

bold_lr_grabber = Node(DataGrabber(infields=['subject'], outfields=['bold_lr']), name='bold_lr_grabber')
bold_lr_grabber.inputs.base_directory = os.path.abspath(input_dir)
bold_lr_grabber.inputs.template = '%s/unprocessed/3T/rfMRI_REST1_LR/*_3T_rfMRI_REST1_LR.nii*'
bold_lr_grabber.inputs.sort_filelist = True
wf.connect(inputspec, 'subject', bold_lr_grabber, 'subject')

bold_rl_grabber = Node(DataGrabber(infields=['subject'], outfields=['bold_rl']), name='bold_rl_grabber')
bold_rl_grabber.inputs.base_directory = os.path.abspath(input_dir)
bold_rl_grabber.inputs.template = '%s/unprocessed/3T/rfMRI_REST1_RL/*_3T_rfMRI_REST1_RL.nii*'
bold_rl_grabber.inputs.sort_filelist = True
wf.connect(inputspec, 'subject', bold_rl_grabber, 'subject')

reorient_struct_wf = Node(Reorient2Std(output_type='NIFTI_GZ'), name="reorient_struct_wf")
wf.connect(T1w_grabber, 'T1w', reorient_struct_wf, 'in_file')

reorient_func_lr_wf = Node(Reorient2Std(output_type='NIFTI_GZ'), name="reorient_func_lr_wf")
wf.connect(bold_lr_grabber, 'bold_lr', reorient_func_lr_wf, 'in_file')

reorient_func_rl_wf = Node(Reorient2Std(output_type='NIFTI_GZ'), name="reorient_func_rl_wf")
wf.connect(bold_rl_grabber, 'bold_rl', reorient_func_rl_wf, 'in_file')

fieldmap_corr = fieldmap_correction('fieldmap_corr')
wf.connect(reorient_func_lr_wf, 'out_file', fieldmap_corr, 'func_1')
wf.connect(reorient_func_rl_wf, 'out_file', fieldmap_corr, 'func_2')

anatomical_preprocessing_wf = anat_proc(name='anatomical_preprocessing_wf', bet_tool='deepbet')
wf.connect(reorient_struct_wf, 'out_file', anatomical_preprocessing_wf, 'in_file')

func2anat_wf = func2anat(name='func2anat_wf', bbr=bbr)
wf.connect(fieldmap_corr, 'out_file', func2anat_wf, 'func')
wf.connect(anatomical_preprocessing_wf, 'brain', func2anat_wf, 'head')
wf.connect(anatomical_preprocessing_wf, 'probmap_wm', func2anat_wf, 'anat_wm_segmentation')
wf.connect(anatomical_preprocessing_wf, 'probmap_csf', func2anat_wf, 'anat_csf_segmentation')
wf.connect(anatomical_preprocessing_wf, 'probmap_gm', func2anat_wf, 'anat_gm_segmentation')
wf.connect(anatomical_preprocessing_wf, 'probmap_ventricle', func2anat_wf, 'anat_ventricle_segmentation')

compcor_roi_wf = anat_noise_roi('compcor_roi_wf')
wf.connect(func2anat_wf, 'wm_mask_in_funcspace', compcor_roi_wf, 'wm_mask')
wf.connect(func2anat_wf, 'ventricle_mask_in_funcspace', compcor_roi_wf, 'ventricle_mask')

func_proc_wf = func_proc_despike_afni('func_proc_wf', bet_tool='deepbet', deepbet_n_dilate=2)
wf.connect(fieldmap_corr, 'out_file', func_proc_wf, 'func')
wf.connect(compcor_roi_wf, 'out_file', func_proc_wf, 'cc_noise_roi')

pick_atlas_wf = mist_atlas('pick_atlas_wf')
mist_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../data_in/atlas/MIST"))
pick_atlas_wf.get_node('inputspec').inputs.labelmap = os.path.join(mist_dir, 'Parcellations/MIST_122.nii.gz')
pick_atlas_wf.get_node('inputspec').inputs.modules = mist_modules(mist_directory=mist_dir, resolution="122")
pick_atlas_wf.get_node('inputspec').inputs.labels = mist_labels(mist_directory=mist_dir, resolution="122")

extract_timeseries = extract_timeseries_nativespace('extract_timeseries')
wf.connect(pick_atlas_wf, 'relabeled_atlas', extract_timeseries, 'atlas')
wf.connect(pick_atlas_wf, 'reordered_labels', extract_timeseries, 'labels')
wf.connect(pick_atlas_wf, 'reordered_modules', extract_timeseries, 'modules')
wf.connect(anatomical_preprocessing_wf, 'brain', extract_timeseries, 'anat')
wf.connect(func2anat_wf, 'anat_to_func_linear_xfm', extract_timeseries, 'inv_linear_reg_mtrx')
wf.connect(anatomical_preprocessing_wf, 'mni2anat_warpfield', extract_timeseries, 'inv_nonlinear_reg_mtrx')
wf.connect(func2anat_wf, 'gm_mask_in_funcspace', extract_timeseries, 'gm_mask')
wf.connect(func_proc_wf, 'func_preprocessed', extract_timeseries, 'func')
wf.connect(func_proc_wf, 'FD', extract_timeseries, 'confounds')

calculate_connectivity_wf = calculate_connectivity('calculate_connectivity_wf')
wf.connect(extract_timeseries, 'timeseries', calculate_connectivity_wf, 'ts_files')
wf.connect(func_proc_wf, 'FD', calculate_connectivity_wf, 'fd_files')

predict_pain_sensitivity_rpn_wf = predict_pain_sensitivity_rpn('predict_pain_sensitivity_rpn_wf')
wf.connect(calculate_connectivity_wf, 'features', predict_pain_sensitivity_rpn_wf, 'X')
wf.connect(fieldmap_corr, 'out_file', predict_pain_sensitivity_rpn_wf, 'in_file')

predict_pain_sensitivity_rcpl_wf = predict_pain_sensitivity_rcpl('predict_pain_sensitivity_rcpl_wf')
wf.connect(calculate_connectivity_wf, 'features', predict_pain_sensitivity_rcpl_wf, 'X')
wf.connect(fieldmap_corr, 'out_file', predict_pain_sensitivity_rcpl_wf, 'in_file')

collect_pain_predictions_wf = collect_pain_predictions('collect_pain_predictions_wf')
wf.connect(predict_pain_sensitivity_rpn_wf, 'out_file', collect_pain_predictions_wf, 'rpn_out_file')
wf.connect(predict_pain_sensitivity_rcpl_wf, 'out_file', collect_pain_predictions_wf, 'rcpl_out_file')

wf.write_graph('Pipeline.png')
wf.run(plugin='MultiProc', plugin_args=plugin_args)
hcp_app.run()
Loading
Loading