From 52cabec23b56ad9d63d82acdf84e83deb5db2afc Mon Sep 17 00:00:00 2001 From: Mats Date: Mon, 6 Jan 2025 16:47:01 +0000 Subject: [PATCH] reorganize and clean up surface based minimum_norm --- osl_ephys/report/src_report.py | 19 +- osl_ephys/source_recon/freesurfer_utils.py | 57 +-- osl_ephys/source_recon/minimum_norm.py | 117 +++--- .../source_recon/parcellation/parcellation.py | 31 +- osl_ephys/source_recon/wrappers.py | 356 +++++++++++++++--- 5 files changed, 394 insertions(+), 186 deletions(-) diff --git a/osl_ephys/report/src_report.py b/osl_ephys/report/src_report.py index f6baab5..8ac6037 100644 --- a/osl_ephys/report/src_report.py +++ b/osl_ephys/report/src_report.py @@ -76,16 +76,19 @@ def gen_html_data(config, outdir, subject, reportdir, logger=None, extra_funcs=N data["coregister"] = data.pop("coregister", False) data["beamform"] = data.pop("beamform", False) data["beamform_and_parcellate"] = data.pop("beamform_and_parcellate", False) + data["minimum_norm"] = data.pop("minimum_norm", False) + data["minimum_norm_and_parcellate"] = data.pop("minimum_norm_and_parcellate", False) data["fix_sign_ambiguity"] = data.pop("fix_sign_ambiguity", False) # Save info - if data["beamform_and_parcellate"]: + if data["beamform_and_parcellate"] or data['minimum_norm_and_parcellate']: data["n_samples"] = data["n_samples"] if data["coregister"]: data["fid_err"] = data["fid_err"] - if data["beamform_and_parcellate"]: + if data["beamform_and_parcellate"] or data['minimum_norm_and_parcellate']: data["parcellation_file"] = data["parcellation_file"] - data["parcellation_filename"] = Path(data["parcellation_file"]).name + if data["beamform_and_parcellate"]: + data["parcellation_filename"] = Path(data["parcellation_file"]).name if data["fix_sign_ambiguity"]: data["template"] = data["template"] data["metrics"] = data["metrics"] @@ -242,6 +245,8 @@ def gen_html_summary(reportdir, logsdir=None): data["coregister"] = subject_data[0]["coregister"] data["beamform"] = subject_data[0]["beamform"] data["beamform_and_parcellate"] = subject_data[0]["beamform_and_parcellate"] + data["minimum_norm"] = subject_data[0]["minimum_norm"] + data["minimum_norm_and_parcellate"] = subject_data[0]["minimum_norm_and_parcellate"] data["fix_sign_ambiguity"] = subject_data[0]["fix_sign_ambiguity"] if data["coregister"]: @@ -249,8 +254,12 @@ def gen_html_summary(reportdir, logsdir=None): fid_err_table = pd.DataFrame() fid_err_table["Session ID"] = [subject_data[i]["fif_id"] for i in range(len(subject_data))] - for i_err, hdr in enumerate(["Nasion", "LPA", "RPA"]): - fid_err_table[hdr] = [np.round(subject_data[i]['fid_err'][i_err], decimals=2) if 'fid_err' in subject_data[i].keys() else None for i in range(len(subject_data))] + if len(subject_data[0]['fid_err'])==4: + for i_err, hdr in enumerate(["Nasion", "LPA", "RPA", "Med(HSP-MRI)"]): + fid_err_table[hdr] = [np.round(subject_data[i]['fid_err'][i_err], decimals=2) if 'fid_err' in subject_data[i].keys() else None for i in range(len(subject_data))] + else: + for i_err, hdr in enumerate(["Nasion", "LPA", "RPA"]): + fid_err_table[hdr] = [np.round(subject_data[i]['fid_err'][i_err], decimals=2) if 'fid_err' in subject_data[i].keys() else None for i in range(len(subject_data))] fid_err_table.index += 1 # Start indexing from 1 data['coreg_table'] = fid_err_table.to_html(classes="display", table_id="coreg_tbl") diff --git a/osl_ephys/source_recon/freesurfer_utils.py b/osl_ephys/source_recon/freesurfer_utils.py index 5d554e7..578de3a 100644 --- a/osl_ephys/source_recon/freesurfer_utils.py +++ b/osl_ephys/source_recon/freesurfer_utils.py @@ -93,20 +93,7 @@ def get_freesurfer_files(subjects_dir, subject): "std_brain_mri": op.join(os.environ["FREESURFER_HOME"], "subjects", "fsaverage", "mri", "T1.mgz"), "completed": op.join(surfaces_dir, "completed.txt"), } - # "mni2mri_flirt_xform_file": op.join(surfaces_dir, "tranforms", "talairach.xfm"), - # "mni_mri_t_file": op.join(surfaces_dir, "mni_mri-trans.fif"), - # "bet_outskin_mesh_vtk_file": op.join(surfaces_dir, "outskin_mesh.vtk"), # BET output - # "bet_inskull_mesh_vtk_file": op.join(surfaces_dir, "inskull_mesh.vtk"), # BET output - # "bet_outskull_mesh_vtk_file": op.join(surfaces_dir, "outskull_mesh.vtk"), # BET output - # "bet_outskin_mesh_file": op.join(surfaces_dir, "outskin_mesh.nii.gz"), - # "bet_outskin_plus_nose_mesh_file": op.join(surfaces_dir, "outskin_plus_nose_mesh.nii.gz"), - # "bet_inskull_mesh_file": op.join(surfaces_dir, "inskull_mesh.nii.gz"), - # "bet_outskull_mesh_file": op.join(surfaces_dir, "outskull_mesh.nii.gz"), - # "std_brain": op.join(os.environ["FSLDIR"], "data", "standard", "MNI152_T1_1mm_brain.nii.gz"), - # "std_brain_bigfov": op.join(os.environ["FSLDIR"], "data", "standard", "MNI152_T1_1mm_BigFoV_facemask.nii.gz"), - # "completed": op.join(surfaces_dir, "completed.txt"), - # } - + # Coregistration files coreg_dir = op.join(fs_dir, "mne_src") os.makedirs(coreg_dir, exist_ok=True) @@ -116,46 +103,10 @@ def get_freesurfer_files(subjects_dir, subject): "coreg_trans": op.join(coreg_dir, "coreg-trans.fif"), "coreg_html": op.join(coreg_dir, "coreg.html"), "source_space": op.join(coreg_dir, "space-src.fif"), - # "stc": op.join(coreg_dir, "stc-data", f"{subject}-lh.stc"), - # "filters_plot_cov": op.join(coreg_dir, "noise_cov.png"), - # "filters_plot_svd": op.join(coreg_dir, "noise_svd.png"), + "inverse_solution": op.join(coreg_dir, "{0}-inv.fif"), + "source_estimate_raw": op.join(coreg_dir, "src-raw"), # followed by -lh/rh.stc + "source_estimate_epo": op.join(coreg_dir, "src-epo"), } - - - # "info_fif_file": op.join(coreg_dir, "info-raw.fif"), - # "smri_file": op.join(coreg_dir, "scaled_smri.nii.gz"), - # "head_scaledmri_t_file": op.join(coreg_dir, "head_scaledmri-trans.fif"), - # "head_mri_t_file": op.join(coreg_dir, "head_mri-trans.fif"), - # "ctf_head_mri_t_file": op.join(coreg_dir, "ctf_head_mri-trans.fif"), - # "mrivoxel_scaledmri_t_file": op.join(coreg_dir, "mrivoxel_scaledmri_t_file-trans.fif"), - # "mni_nasion_mni_file": op.join(coreg_dir, "mni_nasion.txt"), - # "mni_rpa_mni_file": op.join(coreg_dir, "mni_rpa.txt"), - # "mni_lpa_mni_file": op.join(coreg_dir, "mni_lpa.txt"), - # "smri_nasion_file": op.join(coreg_dir, "smri_nasion.txt"), - # "smri_rpa_file": op.join(coreg_dir, "smri_rpa.txt"), - # "smri_lpa_file": op.join(coreg_dir, "smri_lpa.txt"), - # "polhemus_nasion_file": op.join(coreg_dir, "polhemus_nasion.txt"), - # "polhemus_rpa_file": op.join(coreg_dir, "polhemus_rpa.txt"), - # "polhemus_lpa_file": op.join(coreg_dir, "polhemus_lpa.txt"), - # "polhemus_headshape_file": op.join(coreg_dir, "polhemus_headshape.txt"), - # # BET mesh output in native space - # "bet_outskin_mesh_vtk_file": op.join(coreg_dir, "scaled_outskin_mesh.vtk"), - # "bet_inskull_mesh_vtk_file": op.join(coreg_dir, "scaled_inskull_mesh.vtk"), - # "bet_outskull_mesh_vtk_file": op.join(coreg_dir, "scaled_outskull_mesh.vtk"), - # # Freesurfer mesh in native space - # # - these are the ones shown in coreg_display() if doing surf plot - # # - these are also used by MNE forward modelling - # "bet_outskin_surf_file": op.join(coreg_dir, "scaled_outskin_surf.surf"), - # "bet_inskull_surf_file": op.join(coreg_dir, "scaled_inskull_surf.surf"), - # "bet_outskull_surf_file": op.join(coreg_dir, "scaled_outskull_surf.surf"), - # "bet_outskin_plus_nose_surf_file": op.join(coreg_dir, "scaled_outskin_plus_nose_surf.surf"), - # # BET output surface mask as nii in native space - # "bet_outskin_mesh_file": op.join(coreg_dir, "scaled_outskin_mesh.nii.gz"), - # "bet_outskin_plus_nose_mesh_file": op.join(coreg_dir, "scaled_outskin_plus_nose_mesh.nii.gz"), - # "bet_inskull_mesh_file": op.join(coreg_dir, "scaled_inskull_mesh.nii.gz"), - # "bet_outskull_mesh_file": op.join(coreg_dir, "scaled_outskull_mesh.nii.gz"), - # "std_brain": op.join(os.environ["FSLDIR"], "data", "standard", "MNI152_T1_1mm_brain.nii.gz"), - # All Freesurfer files files files = {"surf": surf_files, "coreg": coreg_files, "fwd_model": op.join(fs_dir, "model-fwd.fif")} diff --git a/osl_ephys/source_recon/minimum_norm.py b/osl_ephys/source_recon/minimum_norm.py index 9d09c72..7786bad 100644 --- a/osl_ephys/source_recon/minimum_norm.py +++ b/osl_ephys/source_recon/minimum_norm.py @@ -29,17 +29,12 @@ def minimum_norm( outdir, subject, - preproc_file, - epoch_file, + data, chantypes, method, rank, - morph=False, - lambda2=0.1, - depth=0.8, - loose='auto', - freq_range=None, - pick_ori="normal", + depth, + loose, ): """Run minimum norm source localization. @@ -49,10 +44,8 @@ def minimum_norm( Output directory. subject : str Subject ID. - preproc_file : str - Preprocessed file. - epoch_file : str - Epoch file. + data : mnep.io.Raw, mne.Epochs + Preprocessed data. chantypes : list List of channel types to include. method : str @@ -61,77 +54,93 @@ def minimum_norm( Rank of the data covariance matrix. morph : bool, str Morph method, e.g. fsaverage. Can be False. - lambda2 : float - Regularization parameter. depth : float Depth weighting. loose : float Loose parameter. - freq_range : list - Band pass filter applied before source estimation. - weight_norm : str - Weight normalization. - pick_ori : str - Orientation to pick. reg : float Regularization parameter. reportdir : str Report directory. - - """ - - if preproc_file is None: - preproc_file = epoch_file + """ log_or_print("*** RUNNING MNE SOURCE LOCALIZATION ***") fwd_fname = freesurfer_utils.get_freesurfer_files(outdir, subject)['fwd_model'] coreg_files = freesurfer_utils.get_coreg_filenames(outdir, subject) - if epoch_file is not None: - data = mne.read_epochs(epoch_file, preload=True) - else: - data = mne.io.read_raw(preproc_file, preload=True) - - # Bandpass filter - if freq_range is not None: - logger.info(f"bandpass filtering: {freq_range[0]}-{freq_range[1]} Hz") - data = data.filter( - l_freq=freq_range[0], - h_freq=freq_range[1], - method="iir", - iir_params={"order": 5, "ftype": "butter"}, - ) - - - if isinstance(data, mne.io.Raw): - data_cov = mne.compute_raw_covariance(data, method="empirical", rank=rank) - else: - data_cov = mne.compute_covariance(data, method="empirical", rank=rank) noise_cov = calc_noise_cov(data, rank, chantypes) fwd = mne.read_forward_solution(fwd_fname) + log_or_print(f"*** Making {method} inverse solution ***") inverse_operator = mne.minimum_norm.make_inverse_operator(data.info, fwd, noise_cov, loose=loose, depth=depth) - del fwd + + log_or_print(f"*** Saving {method} inverse operator ***") + mne.minimum_norm.write_inverse_operator(coreg_files['inverse_operator'].format(method), inverse_operator, overwrite=True) + return inverse_operator + + +def apply_inverse_solution( + outdir, + subject, + data, + method, + lambda2, + pick_ori, + inverse_operator=None, + morph="fsaverage", + save=False, + ): + """ Apply previously computed minimum norm inverse solution. + + Parameters + ---------- + outdir : str + Output directory. + subject : str + Subject ID. + data : mne.io.Raw, mne.Epochs + Raw or Epochs object. + inverse_operator : mne.minimum_norm.InverseOperator + Inverse operator. + method : str + Inverse method. + lambda2 : float + Regularization parameter. + pick_ori : str + Orientation to pick. + morph : bool, str + Morph method, e.g. fsaverage. Can be False. + save : bool + Save source estimate (default: False). + """ + + coreg_files = freesurfer_utils.get_coreg_filenames(outdir, subject) + if inverse_operator is None: + inverse_operator = mne.minimum_norm.read_inverse_operator(coreg_files['inverse_operator'].format(method)) log_or_print(f"*** Applying {method} inverse solution ***") - if epoch_file is not None: + if isinstance(data, mne.Epochs): stc = mne.minimum_norm.apply_inverse_epochs(data, inverse_operator, lambda2=lambda2, method=method, pick_ori=pick_ori) else: stc = mne.minimum_norm.apply_inverse_raw(data, inverse_operator, lambda2=lambda2, method=method, pick_ori=pick_ori) if morph: + log_or_print(f"*** Morphing source estimate to {morph} ***") src_from = mne.read_source_spaces(coreg_files['source_space']) morph = morph_surface(outdir, subject, src_from, subject_to=morph) morph.save(op.join(outdir, subject, "mne_src", morph), overwrite=True) stc = morph.apply(stc) - - if epoch_file is not None: - stc.save(op.join(outdir, subject, "mne_src", "src-epo"), overwrite=True) - else: - stc.save(op.join(outdir, subject, "mne_src", "src-raw"), overwrite=True) - + + if save: + log_or_print(f"*** Saving Source estimate ***") + if isinstance(data, mne.Epochs): + stc.save(op.join(outdir, subject, "mne_src", "src-epo"), overwrite=True) + else: + stc.save(op.join(outdir, subject, "mne_src", "src-raw"), overwrite=True) + return stc + def calc_noise_cov(data, data_cov_rank, chantypes): """Calculate noise covariance. @@ -175,7 +184,7 @@ def calc_noise_cov(data, data_cov_rank, chantypes): bads = [b for b in data.info["bads"] if b in data_cov.ch_names] noise_cov = mne.Covariance( - noise_cov_diag, data_cov.ch_names, bads, data.info["projs"], nfree=1e10 + noise_cov_diag, data_cov.ch_names, bads, data.info["projs"], nfree=data.n_times ) return noise_cov diff --git a/osl_ephys/source_recon/parcellation/parcellation.py b/osl_ephys/source_recon/parcellation/parcellation.py index 404b613..e4379ec 100644 --- a/osl_ephys/source_recon/parcellation/parcellation.py +++ b/osl_ephys/source_recon/parcellation/parcellation.py @@ -24,13 +24,15 @@ from osl_ephys.utils.logger import log_or_print from osl_ephys.source_recon import freesurfer_utils -def load_parcellation(parcellation_file): +def load_parcellation(parcellation_file, subject=None): """Load a parcellation file. Parameters ---------- parcellation_file : str Path to parcellation file. + subject : str + Subject ID. Only needed for FreeSurfer parcellations. Returns ------- @@ -40,9 +42,12 @@ def load_parcellation(parcellation_file): # check if it's a freesurfer parcellation if 'SUBJECTS_DIR' in os.environ: - avail = mne.label._read_annot_cands(os.path.join(os.environ["SUBJECTS_DIR"], 'fsaverage', 'label')) + if subject is None: + subject = "fsaverage" + + avail = mne.label._read_annot_cands(os.path.join(os.environ["SUBJECTS_DIR"], subject, 'label')) if parcellation_file in avail: - labels = mne.label.read_labels_from_annot('fsaverage', parcellation_file) + labels = mne.label.read_labels_from_annot(subject, parcellation_file) if parcellation_file == 'aparc' or parcellation_file == "oasis.chubs": labels = [l for l in labels if "unknown" not in l.name] elif parcellation_file == 'aparc.a2009s': @@ -55,7 +60,7 @@ def load_parcellation(parcellation_file): labels = [l for l in labels if "LOBE" in l.name] return labels - # otherwise, load the parcellation file + # otherwise, load the nifti parcellation file parcellation_file = find_file(parcellation_file) return nib.load(parcellation_file) @@ -363,7 +368,7 @@ def _get_parcel_timeseries(voxel_timeseries, parcellation_asmatrix, method="spat return parcel_timeseries, voxel_weightings, voxel_assignments -def surf_parcellate_timeseries(subject_dir, subject, preproc_file, epoch_file, method, parcellation_file): +def surf_parcellate_timeseries(subject_dir, subject, stc, method, parc): """Save parcellated data as a fif file. Parameters @@ -372,23 +377,17 @@ def surf_parcellate_timeseries(subject_dir, subject, preproc_file, epoch_file, m Path to subject directory. subject : str Subject ID. - preproc_file : str - Path to preprocessed file. - epoch_file : str or None - Path to epoch file. + stc : mne.SourceEstimate + Source estimate. method : str Parcellation method. Can be 'pca_flip', 'max', 'mean', 'mean_flip', 'auto' + parc : str + Parcellation name. """ fs_files = freesurfer_utils.get_freesurfer_files(subject_dir, subject) - if epoch_file is not None: - src_file = op.join(subject_dir, subject, "mne_src", "src-epo") - else: - src_file = op.join(subject_dir, subject, "mne_src", "src-raw") - labels = mne.read_labels_from_annot(subjects_dir=subject_dir, subject=subject, parc=parcellation_file) - labels = [l for l in labels if "unknown" not in l.name] + labels = load_parcellation(parc, subject=subject) - stc = mne.read_source_estimate(src_file, subject=subject) src = mne.read_source_spaces(fs_files['coreg']['source_space']) parcel_data = mne.extract_label_time_course(stc, labels, src, mode=method) return parcel_data diff --git a/osl_ephys/source_recon/wrappers.py b/osl_ephys/source_recon/wrappers.py index 9f171dd..bd83829 100644 --- a/osl_ephys/source_recon/wrappers.py +++ b/osl_ephys/source_recon/wrappers.py @@ -20,7 +20,7 @@ import numpy as np from . import rhino, beamforming, parcellation, sign_flipping, freesurfer_utils -from .minimum_norm import minimum_norm as minimum_norm_estimate +from .minimum_norm import minimum_norm, apply_inverse_solution as minimum_norm_estimate, apply_inverse_solution from ..report import src_report from ..utils.logger import log_or_print @@ -343,6 +343,7 @@ def coregister( elif mode=="mne" or mode=='freesurfer': coreg_files = freesurfer_utils.get_coreg_filenames(outdir, subject) + coreg_filename = coreg_files['coreg_html'] def save_coreg_html(filename): fig = mne.viz.plot_alignment(info, trans=coreg.trans, **plot_kwargs) @@ -366,9 +367,11 @@ def save_coreg_html(filename): if n_init is None: n_init = 20 - coreg.fit_fiducials(verbose=True) + fiducials_kwargs = kwargs.pop("fit_fiducials", {}) + coreg.fit_fiducials(**fiducials_kwargs) - coreg.fit_icp(n_iterations=n_init, verbose=True, **kwargs) + icp_kwargs = kwargs.pop("fit_icp", {}) + coreg.fit_icp(n_iterations=n_init, **icp_kwargs) #coreg.omit_head_shape_points(distance=1e-3) plot_kwargs = dict( @@ -411,6 +414,15 @@ def save_coreg_html(filename): "coreg_plot": coreg_filename, }, ) + + if mode=="mne" or mode=='freesurfer': + src_report.add_to_data( + f"{reportdir}/{subject}/data.pkl", + { + "fiducials_kwargs": fiducials_kwargs, + "icp_kwargs": icp_kwargs, + } + ) def forward_model( @@ -470,16 +482,22 @@ def forward_model( ) mne.write_source_spaces(filenames['source_space'], src, overwrite=True) - if model == "Single Layer": - conductivity = (0.3,) # for single layer - elif model == "Triple Layer": - conductivity = (0.3, 0.006, 0.3) # for three layers + conductivity = kwargs.pop("conductivity", None) + if conductivity is None: + if model == "Single Layer": + conductivity = (0.3,) # for single layer + elif model == "Triple Layer": + conductivity = (0.3, 0.006, 0.3) # for three layers + + + ico = kwargs.pop("ico", 4) + mindist = kwargs.pop("mindist", 0) model = mne.make_bem_model( subjects_dir=outdir, subject=subject, conductivity=conductivity, - **kwargs + ico=ico, ) bem = mne.make_bem_solution(model) @@ -491,10 +509,7 @@ def forward_model( trans=trans, src=src, bem=bem, - meg=True, - eeg=False, - mindist=5.0, - verbose=True, + mindist=mindist, ) mne.write_forward_solution(fwd_fname, fwd, overwrite=True) log_or_print("*** FINISHED SURFACE BASED FORWARD MODEL ***") @@ -511,6 +526,16 @@ def forward_model( "eeg": eeg, }, ) + + if mode == 'surface' or mode == 'surf': + src_report.add_to_data( + f"{reportdir}/{subject}/data.pkl", + { + "conductivity": conductivity, + "ico": ico, + "mindist": mindist, + } + ) # ------------------------------------- @@ -628,9 +653,13 @@ def minimum_norm( subject, preproc_file, epoch_file, + source_method, chantypes, - method, rank, + depth, + loose, + reg, + pick_ori, freq_range=None, reportdir=None, **kwargs, @@ -646,44 +675,69 @@ def minimum_norm( preproc_file : str Path to the preprocessed fif file. epoch_file : str - Path to epoched preprocessed fif file. + Path to epoched preprocessed fif file. + source_method : str + Method to use in the source localization. E.g., 'MNE' or 'dSPM'. chantypes : list List of channel types to include. method : str Method to use in the source localization. rank : int Rank of the noise covariance matrix. + depth : float + Depth weighting. + reg : float + Regularization parameter for the minimum norm estimate. + pick_ori : str + Orientation of the dipoles. freq_range : list, optional Lower and upper band to bandpass filter before beamforming. If None, no filtering is done reportdir : str, optional Path to report directory. """ - - logger.info("MNE source localize") + if epoch_file is not None: + data = mne.read_epochs(epoch_file, preload=True) + else: + data = mne.io.read_raw(preproc_file, preload=True) + + # Bandpass filter + if freq_range is not None: + logger.info(f"bandpass filtering: {freq_range[0]}-{freq_range[1]} Hz") + data = data.filter( + l_freq=freq_range[0], + h_freq=freq_range[1], + method="iir", + iir_params={"order": 5, "ftype": "butter"}, + ) + + logger.info("MNE source localize") minimum_norm_estimate( outdir, subject, - preproc_file, - epoch_file, + data, chantypes, - method, + source_method, rank, - freq_range=freq_range, + depth, + loose, **kwargs, ) if reportdir is not None: - # Save info for the report src_report.add_to_data( f"{reportdir}/{subject}/data.pkl", { "minimum_norm": True, "chantypes": chantypes, - "method": method, + "method": source_method, "rank": rank, + "depth": depth, + "loose": loose, + "lambda2": reg, + "pick_ori": pick_ori, "freq_range": freq_range, }, ) @@ -699,9 +753,10 @@ def parcellate( orthogonalisation, source_method='lcmv', spatial_resolution=None, - reference_brain="mni", + reference_brain="mni/fsaverage", extra_chans="stim", reportdir=None, + **kwargs, ): """Wrapper function for parcellation. @@ -721,6 +776,10 @@ def parcellate( Method to use in the parcellation. orthogonalisation : bool Should we do orthogonalisation? + lambda2 : float + The regularisation parameter for the minimum norm estimate. + pick_ori : str + Orientation of the dipoles, used for minimum_norm. source_method : str, optional Method used for source reconstruction. Can be 'lcmv' or 'mne'. spatial_resolution : int, optional @@ -728,7 +787,10 @@ def parcellate( (must be an integer, or will be cast to nearest int). If None, then the gridstep used in coreg_filenames['forward_model_file'] is used. reference_brain : str, optional - 'mni' indicates that the reference_brain is the stdbrain in MNI space. + 'mni' indicates that the reference_brain is the stdbrain in MNI space (volumetric). + 'fsaverage' indicates that the reference_brain is the fsaverage subject (surface). + 'mni/fsaverage' indicates that the reference_brain depends on the source_method + ('mni' for lcmv, 'fsaverage' for mne). 'mri' indicates that the reference_brain is the subject's sMRI in the scaled native/mri space. 'unscaled_mri' indicates that the reference_brain is the subject's @@ -752,31 +814,35 @@ def parcellate( else: data = mne.io.read_raw_fif(preproc_file, preload=True) - # beamforming is applied in place, whereas linear inverse methods are loaded from disk. - if source_method == 'lcmv' or source_method == 'beamform': - if reportdir is None: - raise ValueError( + if reportdir is None: + raise ValueError( "This function can only be used when a report was generated " - "when beamforming. Please use beamform_and_parcellate." - ) - - # Get settings passed to the beamform wrapper - report_data = pickle.load(open(f"{reportdir}/{subject}/data.pkl", "rb")) - freq_range = report_data.pop("freq_range") - chantypes = report_data.pop("chantypes") - if isinstance(chantypes, str): - chantypes = [chantypes] - - # Bandpass filter - if freq_range is not None: - logger.info(f"bandpass filtering: {freq_range[0]}-{freq_range[1]} Hz") - data = data.filter( - l_freq=freq_range[0], - h_freq=freq_range[1], - method="iir", - iir_params={"order": 5, "ftype": "butter"}, + "when using source estimation (beamforming/minimum_norm). Please use beamform_and_parcellate" + "or minimum_norm_and_parcellate." ) + # Get settings passed to the beamform/minimum_norm wrapper + report_data = pickle.load(open(f"{reportdir}/{subject}/data.pkl", "rb")) + freq_range = report_data.pop("freq_range") + chantypes = report_data.pop("chantypes") + if isinstance(chantypes, str): + chantypes = [chantypes] + + # Bandpass filter + if freq_range is not None: + logger.info(f"bandpass filtering: {freq_range[0]}-{freq_range[1]} Hz") + data = data.filter( + l_freq=freq_range[0], + h_freq=freq_range[1], + method="iir", + iir_params={"order": 5, "ftype": "butter"}, + ) + + # source recon is applied in place + if source_method == 'lcmv' or source_method == 'beamform': + if reference_brain == 'mni/fsaverage': + reference_brain = 'mni' + # Pick channels chantype_data = data.copy().pick(chantypes) @@ -806,18 +872,20 @@ def parcellate( working_dir=f"{outdir}/{subject}/parc", ) - elif source_method=='minimum_norm': - if reportdir is not None: - report_data = pickle.load(open(f"{reportdir}/{subject}/data.pkl", "rb")) - freq_range = report_data.pop("freq_range") - chantypes = report_data.pop("chantypes") - if isinstance(chantypes, str): - chantypes = [chantypes] - else: - freq_range = [np.max([data.info['highpass'], 0]), np.min([data.info['lowpass'], 100])] - - - parcel_data = parcellation.surf_parcellate_timeseries(outdir, subject, preproc_file, epoch_file, method, parcellation_file) + elif source_method=='minimum_norm': + if reference_brain == 'mni/fsaverage': + reference_brain = 'fsaverage' + elif reference_brain == 'mri': + reference_brain = subject + + pick_ori = report_data.pop("pick_ori") + lambda2 = report_data.pop("lambda2") + + # sources are not estimated yet; so first read in inverse solution + stc = apply_inverse_solution(outdir, subject, data, source_method, lambda2=lambda2, + pick_ori=pick_ori, inverse_operator=None, morph=reference_brain, save=False, + ) + parcel_data = parcellation.surf_parcellate_timeseries(subject_dir=outdir, subject=reference_brain, stc=stc, method=method, parcellation=parcellation_file) # Orthogonalisation if orthogonalisation not in [None, "symmetric", "none", "None"]: @@ -870,6 +938,7 @@ def parcellate( "parcellate": True, "parcellation_file": parcellation_file, "method": method, + "reference_brain": reference_brain, "orthogonalisation": orthogonalisation, "parc_fif_file": str(parc_fif_file), "n_samples": n_samples, @@ -1086,6 +1155,7 @@ def beamform_and_parcellate( "filters_svd_plot": filters_svd_plot, "parcellation_file": parcellation_file, "method": method, + "reference_brain": reference_brain, "orthogonalisation": orthogonalisation, "parc_fif_file": str(parc_fif_file), "n_samples": n_samples, @@ -1097,6 +1167,176 @@ def beamform_and_parcellate( ) +def minimum_norm_and_parcellate( + outdir, + subject, + preproc_file, + epoch_file, + chantypes, + source_method, + rank, + depth, + loose, + reg, + pick_ori, + method, + parcellation_file, + orthogonalisation, + reference_brain="fsaverage", + extra_chans="stim", + freq_range=None, + reportdir=None, +): + """Wrapper function for minimum_norm and parcellation. + + Parameters + ---------- + outdir : str + Path to where to output the source reconstruction files. + subject : str + Subject name/id. + preproc_file : str + Path to the preprocessed fif file. + epoch_file : str + Path to epoched preprocessed fif file. + chantypes : list + List of channel types to include. + source_method : str + Method to used for inverse modelling (e.g., MNE, eLORETA). + rank : int + Rank of the noise covariance matrix. + depth : float + Depth weighting factor. + loose : float + Loose orientation constraint. + reg : float + The regularisation parameter for the minimum norm estimate. + pick_ori : str + Orientation of the dipoles, used for minimum_norm. + method : str + Method to use in the parcellation. + parcellation_file : str + Path to the parcellation file to use. + orthogonalisation : bool + Should we do orthogonalisation? + freq_range : list, optional + Lower and upper band to bandpass filter before beamforming. + If None, no filtering is done. + reportdir : str, optional + Path to report directory. + """ + logger.info("minimum_norm_and_parcellate") + + # Load sensor-level data + if epoch_file is not None: + logger.info("using epoched data") + data = mne.read_epochs(epoch_file, preload=True) + else: + data = mne.io.read_raw_fif(preproc_file, preload=True) + + # Bandpass filter + if freq_range is not None: + logger.info(f"bandpass filtering: {freq_range[0]}-{freq_range[1]} Hz") + data = data.filter( + l_freq=freq_range[0], + h_freq=freq_range[1], + method="iir", + iir_params={"order": 5, "ftype": "butter"}, + ) + + logger.info("MNE source localize") + inverse_operator = minimum_norm_estimate( + outdir, + subject, + data, + chantypes, + source_method, + rank, + depth, + loose, + ) + + if reference_brain == 'mri': + reference_brain = subject + + # sources are not estimated yet; so first read in inverse solution + stc = apply_inverse_solution(outdir, subject, data, source_method, lambda2=reg, pick_ori=pick_ori, + inverse_operator=inverse_operator, morph=reference_brain, save=False, + ) + parcel_data = parcellation.surf_parcellate_timeseries(subject_dir=outdir, subject=reference_brain, stc=stc, method=method, parcellation=parcellation_file) + + # Orthogonalisation + if orthogonalisation not in [None, "symmetric", "none", "None"]: + raise NotImplementedError(orthogonalisation) + + if orthogonalisation == "symmetric": + logger.info(f"{orthogonalisation} orthogonalisation") + parcel_data = parcellation.symmetric_orthogonalise( + parcel_data, maintain_magnitudes=True + ) + + os.makedirs(f"{outdir}/{subject}/parc", exist_ok=True) + if epoch_file is None: + # Save parcellated data as a MNE Raw object + parc_fif_file = f"{outdir}/{subject}/parc/parc-raw.fif" + logger.info(f"saving {parc_fif_file}") + parc_raw = parcellation.convert2mne_raw( + parcel_data, data, extra_chans=extra_chans + ) + parc_raw.save(parc_fif_file, overwrite=True) + else: + # Save parcellated data as a MNE Epochs object + parc_fif_file = f"{outdir}/{subject}/parc/parc-epo.fif" + logger.info(f"saving {parc_fif_file}") + parc_epo = parcellation.convert2mne_epochs(parcel_data, data) + parc_epo.save(parc_fif_file, overwrite=True) + + # Save plots + parc_psd_plot = f"{subject}/parc/psd.png" + parcellation.plot_psd( + parcel_data, + fs=data.info["sfreq"], + freq_range=freq_range, + parcellation_file=parcellation_file, + filename=f"{outdir}/{parc_psd_plot}", + ) + parc_corr_plot = f"{subject}/parc/corr.png" + parcellation.plot_correlation(parcel_data, filename=f"{outdir}/{parc_corr_plot}") + + if reportdir is not None: + # Save info for the report + n_parcels = parcel_data.shape[0] + n_samples = parcel_data.shape[1] + if parcel_data.ndim == 3: + n_epochs = parcel_data.shape[2] + else: + n_epochs = None + src_report.add_to_data( + f"{reportdir}/{subject}/data.pkl", + { + "minimum_norm_and_parcellate": True, + "minimum_norm": True, + "parcellate": True, + "chantypes": chantypes, + "reference_brain": reference_brain, + "method": source_method, + "rank": rank, + "depth": depth, + "loose": loose, + "lambda2": reg, + "pick_ori": pick_ori, + "parcellation_file": parcellation_file, + "method": method, + "orthogonalisation": orthogonalisation, + "parc_fif_file": str(parc_fif_file), + "n_samples": n_samples, + "n_parcels": n_parcels, + "n_epochs": n_epochs, + "parc_psd_plot": parc_psd_plot, + "parc_corr_plot": parc_corr_plot, + }, + ) + # ---------------------- # Sign flipping wrappers