Skip to content

Commit

Permalink
Merge pull request #372 from OHBA-analysis/glm
Browse files Browse the repository at this point in the history
Add GLM wrappers to preprocessing
  • Loading branch information
matsvanes authored Jan 7, 2025
2 parents 3000af9 + 8739ab2 commit 8b917cf
Show file tree
Hide file tree
Showing 3 changed files with 265 additions and 9 deletions.
23 changes: 22 additions & 1 deletion osl_ephys/preprocessing/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,7 @@ def write_dataset(dataset, outbase, run_id, ftype='preproc-raw', overwrite=False
outnames = {"raw": outbase.format(run_id=run_id, ftype=ftype, fext="fif")}
if Path(outnames["raw"]).exists() and not overwrite:
raise ValueError(
"{} already exists. Please delete or do use overwrite=True.".format(fif_outname)
"{} already exists. Please delete or do use overwrite=True.".format(outnames['raw'])
)
logger.info(f"Saving dataset['raw'] as {outnames['raw']}")
dataset["raw"].save(outnames['raw'], overwrite=overwrite)
Expand Down Expand Up @@ -494,6 +494,14 @@ def write_dataset(dataset, outbase, run_id, ftype='preproc-raw', overwrite=False
logger.info(f"Saving dataset['glm'] as {outnames['glm']}")
dataset["glm"].save_pkl(outnames['glm'], overwrite=overwrite)

if "fig" in dataset and "fig" not in skip and dataset['fig'] is not None:
keys = dataset["fig"].keys()
outnames['fig'] = {}
for key in keys:
outnames['fig'][key] = outbase.format(run_id=run_id, ftype=key, fext="png")
logger.info(f"Saving dataset['fig'][{key}] as {outnames['fig'][key]}")
dataset["fig"][key].savefig(outnames['fig'][key])

# save remaining keys as pickle files
for key in dataset:
if key not in outnames and key not in skip:
Expand Down Expand Up @@ -857,13 +865,15 @@ def run_proc_chain(
"epochs": None,
"event_id": config["meta"]["event_codes"],
"ica": None,
"fig": {},
}

# Do the preprocessing
for stage in deepcopy(config["preproc"]):
method, userargs = next(iter(stage.items()))
target = userargs.get("target", "raw") # Raw is default
func = find_func(method, target=target, extra_funcs=extra_funcs)

# Actual function call
dataset = func(dataset, userargs)

Expand Down Expand Up @@ -959,6 +969,7 @@ def run_proc_batch(
overwrite=False,
skip_save=None,
extra_funcs=None,
covs=None,
random_seed='auto',
verbose="INFO",
mneverbose="WARNING",
Expand Down Expand Up @@ -995,6 +1006,8 @@ def run_proc_batch(
List of keys to skip writing to disk. If None, we don't skip any keys.
extra_funcs : list
User-defined functions.
covs : dict or pd.DataFrame
Covariates to use for building the GLM design
random_seed : 'auto' (default), int or None
Random seed to set. If 'auto', a random seed will be generated. Random seeds are set for both Python and NumPy.
If None, no random seed is set.
Expand Down Expand Up @@ -1153,11 +1166,19 @@ def run_proc_batch(
for key in group_inputs[0]:
dataset[key] = [group_inputs[i][key] for i in range(len(group_inputs))]
skip_save.append(key)

if covs is not None:
dataset['covs'] = covs
dataset['fig'] = {}

for stage in deepcopy(config["group"]):
method, userargs = next(iter(stage.items()))
# make sure the function always knows it's a group processing
userargs['run_on_group'] = True
target = userargs.get("target", "raw") # Raw is default
# skip.append(stage if userargs.get("skip_save") is True else None) # skip saving this stage to disk
func = find_func(method, target=target, extra_funcs=extra_funcs)

# Actual function call
dataset = func(dataset, userargs)
outbase = os.path.join(outdir, "{ftype}.{fext}")
Expand Down
242 changes: 241 additions & 1 deletion osl_ephys/preprocessing/osl_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@
from os.path import exists
from scipy import stats
from pathlib import Path

import matplotlib.pyplot as plt
import glmtools
from ..glm import glm_epochs, glm_spectrum, glm_irasa, group_glm_epochs, group_glm_spectrum, MaxStatPermuteGLMSpectrum, ClusterPermuteGLMSpectrum
from ..glm.glm_base import SensorMaxStatPerm, SensorClusterPerm
logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -922,3 +925,240 @@ def run_osl_ica_manualreject(dataset, userargs):
else:
logger.info("Components were not removed from raw data")
return dataset

#%% GLM wrappers

def run_osl_zscore_present_data(dataset, userargs):
    """Z-score parametric regressors while ignoring missing values.

    Each selected covariate is z-scored using only its non-NaN entries.
    Entries that were NaN in the input are set to zero in the output.

    Parameters
    ----------
    dataset: dict
        Dictionary containing the covariates under the key ``covs``.
        ``dataset['covs'][key]`` must be numeric for every selected key.
    userargs: dict
        Dictionary of additional arguments containing the key ``keys``:
        a single covariate name, a list of names, or a string of the form
        ``"[name1 name2]"``.

    Returns
    -------
    dataset: dict
        Input dictionary in which the selected covariates have been replaced
        by their z-scored versions.

    Raises
    ------
    ValueError
        If ``keys`` is not specified.
    """
    keys = userargs.pop("keys", None)
    if keys is None:
        raise ValueError("keys not specified")

    # Normalise ``keys`` to a list of names. A bare string must be wrapped,
    # otherwise the loop below would iterate over its characters.
    if isinstance(keys, str):
        if keys.startswith('[') and keys.endswith(']'):
            keys = keys[1:-1].split()
        else:
            keys = [keys]

    for key in keys:
        values = dataset["covs"][key]
        missing = np.isnan(values)
        # nan_policy='omit' computes mean/std from the present values only;
        # the NaN positions stay NaN in the result and are zeroed afterwards.
        new = stats.zscore(values, nan_policy='omit')
        new[missing] = 0
        dataset["covs"][key] = new
    return dataset


def _parse_regressor_codes(codes):
    """Convert a user-supplied ``codes`` option into a float or a float array.

    Accepts a number, a whitespace-separated string (optionally wrapped in
    square brackets), or a sequence of numbers and/or such strings.
    """
    if codes is None:
        raise ValueError("codes not specified for Categorical regressor")
    if isinstance(codes, (int, float, np.integer, np.floating)):
        return float(codes)
    if isinstance(codes, str):
        codes = [codes]

    tokens = []
    for item in codes:
        if isinstance(item, str):
            # A single string may hold several codes, e.g. "1 2 3".
            tokens.extend(item.strip("[]").split())
        else:
            tokens.append(item)
    return np.array(tokens).astype(float)


def run_osl_glm_add_regressor(dataset, userargs):
    """osl-ephys Batch wrapper that adds a regressor to a glmtools design config.

    The design config is stored in ``dataset['design_config']``; it is created
    (as a ``glmtools.design.DesignConfig``) if it is missing or of another type.

    Parameters
    ----------
    dataset: dict
        Dictionary holding the design config under ``design_config`` and, for
        ``codes='unique'``, the covariates under ``covs``.
    userargs: dict
        Dictionary of additional arguments. The following keys are used:
        ``rtype`` ('Constant', 'Categorical', 'Parametric' or 'MeanEffects'),
        ``name``, ``codes`` (Categorical only: a number, a string/list of
        numbers, or 'unique' to add one regressor per unique value of
        ``dataset['covs'][key]``), ``preproc`` (Parametric only) and ``key``.

    Returns
    -------
    dataset: dict
        Input dictionary with an updated ``design_config``.

    Raises
    ------
    ValueError
        If ``rtype`` is not one of the supported regressor types, or if a
        Categorical regressor is requested without ``codes``.
    """
    logger.info("osl-ephys Stage - {0}".format("GLM Add Regressor"))
    if 'design_config' not in dataset or not isinstance(dataset['design_config'], glmtools.design.DesignConfig):
        dataset['design_config'] = glmtools.design.DesignConfig()
    design = dataset['design_config']

    rtype = userargs.pop("rtype", None)
    name = userargs.pop("name", None)
    codes = userargs.pop("codes", None)
    preproc = userargs.pop("preproc", None)
    key = userargs.pop("key", None)

    if rtype == 'Constant':
        design.add_regressor(name, rtype)
    elif rtype == 'Categorical':
        if isinstance(codes, str) and codes == 'unique':
            # One regressor per unique covariate value, named "<name>_<value>".
            for code in np.unique(dataset['covs'][key]):
                design.add_regressor(name=name + '_{0}'.format(code), rtype=rtype, codes=code)
        else:
            design.add_regressor(name=name, rtype=rtype, codes=_parse_regressor_codes(codes))
    elif rtype == 'Parametric':
        design.add_regressor(name=name, rtype=rtype, datainfo=key, preproc=preproc)
    elif rtype == 'MeanEffects':
        # NOTE(review): the literal '_{0}' suffix is passed through unformatted;
        # presumably glmtools fills it in per level - confirm.
        design.add_regressor(name=name + '_{0}', rtype=rtype, datainfo=key)
    else:
        raise ValueError("Unknown regressor type")
    return dataset


def run_osl_glm_add_contrast(dataset, userargs):
    """osl-ephys Batch wrapper that adds a contrast to ``dataset['design_config']``.

    Parameters
    ----------
    dataset: dict
        Dictionary holding the design config under ``design_config`` and, for
        ``values='unique'``, the covariates under ``covs``.
    userargs: dict
        Dictionary of additional arguments. The following keys are used:
        ``simple`` (if True, ``add_simple_contrasts()`` is called and the
        other options are ignored), ``name``, ``values`` (a dict mapping
        regressor names to weights, or 'unique' to weight every
        ``"<key>_<value>"`` regressor equally) and ``key``.

    Returns
    -------
    dataset: dict
        Input dictionary with an updated ``design_config``.

    Raises
    ------
    ValueError
        If ``simple`` is False and ``values`` is not specified.
    """
    logger.info("osl-ephys Stage - {0}".format("GLM Add Contrast"))

    simple = userargs.pop("simple", False)
    name = userargs.pop("name", None)
    values = userargs.pop("values", None)
    key = userargs.pop("key", None)

    if simple:
        dataset['design_config'].add_simple_contrasts()
        return dataset

    if values is None:
        raise ValueError("values not specified (expected a dict of regressor weights or 'unique')")

    if isinstance(values, str) and values == 'unique':
        # Equal weighting over all unique levels of the covariate.
        levels = np.unique(dataset['covs'][key])
        weights = {f"{key}_{level}": 1 / len(levels) for level in levels}
    else:
        weights = {}
        for regressor, weight in values.items():
            if isinstance(weight, str):
                # NOTE(review): eval() executes arbitrary code from the config
                # (it allows weights such as "1/3"); only use trusted configs.
                weights[regressor] = float(eval(weight))
            else:
                weights[regressor] = float(weight)
    dataset['design_config'].add_contrast(name=name, values=weights)

    return dataset


def _pop_spectrum_covs(userargs, argname):
    """Pop ``argname`` from ``userargs`` and return the covariates it selects.

    Returns None when the argument was not supplied.
    """
    selection = userargs.pop(argname, None)
    if selection is None:
        return None
    if selection[0] == '[' and selection[-1] == ']':
        selection = selection[1:-1].split(' ')
    # NOTE(review): covariates are looked up in userargs['covs'], not
    # dataset['covs'] - confirm this is the intended source.
    return userargs["covs"][selection]


def run_osl_glm_fit(dataset, userargs):
    """Wrapper for the different GLM functions in the glm module.

    Parameters
    ----------
    dataset: dict
        Dictionary containing the data to fit under the key given by
        ``target``; epochs and group fits also read ``design_config``, and
        group fits read ``covs``.
    userargs: dict
        Dictionary of additional arguments. ``method`` is required
        ('epochs'/'glm_epochs', 'spectrum'/'glm_spectrum' or
        'irasa'/'glm_irasa'). Optional keys: ``run_on_group``, ``target``,
        ``name``, ``metric``, ``baseline`` (epochs only, a string such as
        ``"[-0.1 0]"``), ``reg_categorical``, ``reg_ztrans``, ``reg_unitmax``
        and the flags ``plot_summary``, ``plot_efficiency``,
        ``plot_leverage``. For first-level spectrum/irasa fits the remaining
        entries are forwarded to the GLM function.

    Returns
    -------
    dataset: dict
        Input dictionary with the fitted model stored under ``name`` and any
        requested design figures stored in ``dataset['fig']``.

    Raises
    ------
    ValueError
        If ``method`` is missing or not one of the supported methods.
    """
    run_on_group = userargs.pop("run_on_group", False)

    method = userargs.pop("method", None)
    if method is None:
        raise ValueError("method not specified")
    is_epochs = method in ('epochs', 'glm_epochs')
    is_spectrum = method in ('spectrum', 'glm_spectrum')
    is_irasa = method in ('irasa', 'glm_irasa')
    if not (is_epochs or is_spectrum or is_irasa):
        raise ValueError("Unknown method: {0}".format(method))

    target = userargs.pop("target", None)
    if target is None:
        if run_on_group:
            target = "glm"
        elif is_epochs:
            target = "epochs"
        else:
            # spectrum and irasa fits both operate on the continuous data
            target = "raw"

    name = userargs.pop("name", None)
    if name is None:
        name = "group_glm" if run_on_group else "glm"
    metric = userargs.pop("metric", 'copes')

    plot_summary = userargs.pop("plot_summary", True)
    plot_efficiency = userargs.pop("plot_efficiency", True)
    plot_leverage = userargs.pop("plot_leverage", True)

    if is_epochs:
        baseline = userargs.pop("baseline", None)
        if baseline is not None:
            baseline = np.array(baseline[1:-1].split(" ")).astype(float)

        if run_on_group:
            dataset[name] = group_glm_epochs(dataset[target], dataset['design_config'], dataset['covs'], metric, baseline)
        else:
            dataset[name] = glm_epochs(dataset['design_config'], dataset[target])
    elif run_on_group:
        # ``baseline`` only exists for epochs; it was previously passed here
        # as an undefined name, which raised NameError.
        dataset[name] = group_glm_spectrum(dataset[target], dataset['design_config'], dataset['covs'], metric)
    else:
        reg_categorical = _pop_spectrum_covs(userargs, "reg_categorical")
        reg_ztrans = _pop_spectrum_covs(userargs, "reg_ztrans")
        reg_unitmax = _pop_spectrum_covs(userargs, "reg_unitmax")

        fit_func = glm_spectrum if is_spectrum else glm_irasa
        dataset[name] = fit_func(dataset[target], reg_unitmax=reg_unitmax, reg_ztrans=reg_ztrans, reg_categorical=reg_categorical, **userargs)

    if plot_summary or plot_efficiency or plot_leverage:
        # Make sure there is a figure container to write into.
        if dataset.get('fig') is None:
            dataset['fig'] = {}

    if plot_summary:
        dataset['fig'][name + 'design_summary'] = dataset[name].design.plot_summary(show=False)

    if plot_efficiency:
        dataset['fig'][name + 'design_efficiency'] = dataset[name].design.plot_efficiency(show=False)

    if plot_leverage:
        dataset['fig'][name + 'design_leverage'] = dataset[name].design.plot_leverage(show=False)

    return dataset


def run_osl_glm_permutations(dataset, userargs):
    """Wrapper for the different permutation options in the glm module.

    Parameters
    ----------
    dataset: dict
        Dictionary containing the fitted GLM under the key given by
        ``target`` (default ``group_glm``).
    userargs: dict
        Dictionary of additional arguments. ``method``
        ('epochs'/'glm_epochs' or 'spectrum'/'glm_spectrum') and ``type``
        ('max'/'maxstat' or 'cluster') are required. Optional keys:
        ``target``, ``name``, ``threshold``, ``plot_sig``, ``contrast``
        (a name listed in the GLM's ``contrast_names``) and ``fl_contrast``
        (a name listed in ``fl_contrast_names``; 0 selects index 0). The
        remaining entries are forwarded to the permutation class.

    Returns
    -------
    dataset: dict
        Input dictionary with the permutation object stored under ``name``
        and, if ``plot_sig`` is set, a figure stored in ``dataset['fig']``.

    Raises
    ------
    ValueError
        If ``method`` or ``type`` is missing or not supported.
    """
    # Popped so that it is not forwarded to the permutation class below.
    userargs.pop("run_on_group", False)
    target = userargs.pop("target", "group_glm")
    name = userargs.pop("name", "group_glm_perm")
    method = userargs.pop("method", None)
    if method is None:
        raise ValueError("method not specified")
    perm_type = userargs.pop("type", None)
    if perm_type is None:
        raise ValueError("type not specified (e.g. 'max', 'cluster')")

    # Validate the requested combination before touching the data, so that a
    # typo fails here rather than as a KeyError when plotting.
    use_epochs = method in ('epochs', 'glm_epochs')
    use_spectrum = method in ('spectrum', 'glm_spectrum')
    if not (use_epochs or use_spectrum):
        raise ValueError("Unknown method: {0}".format(method))
    use_maxstat = perm_type in ('max', 'maxstat')
    use_cluster = perm_type == 'cluster'
    if not (use_maxstat or use_cluster):
        raise ValueError("Unknown type: {0} (expected 'max', 'maxstat' or 'cluster')".format(perm_type))

    thresh = userargs.pop("threshold", 95)
    plot_sig = userargs.pop("plot_sig", True)

    # Translate contrast names into the indices the permutation classes expect.
    contrast = userargs.pop("contrast", None)
    contrast = dataset[target].contrast_names.index(contrast)
    fl_contrast = userargs.pop("fl_contrast", 0)
    if fl_contrast != 0:
        fl_contrast = dataset[target].fl_contrast_names.index(fl_contrast)

    if use_maxstat:
        perm_class = SensorMaxStatPerm if use_epochs else MaxStatPermuteGLMSpectrum
    else:
        perm_class = SensorClusterPerm if use_epochs else ClusterPermuteGLMSpectrum
    dataset[name] = perm_class(dataset[target], contrast, fl_contrast, **userargs)

    if plot_sig:
        if dataset.get('fig') is None:
            dataset['fig'] = {}
        fig, ax = plt.subplots()
        dataset[name].plot_sig_clusters(thresh, ax=ax)
        dataset['fig'][name + 'sig' + str(thresh)] = fig
    return dataset
9 changes: 2 additions & 7 deletions osl_ephys/report/preproc_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,6 @@

from ..utils import process_file_inputs, validate_outdir
from ..utils.logger import log_or_print
from ..preprocessing import (
read_dataset,
load_config,
get_config_from_fif,
plot_preproc_flowchart,
)


# ----------------------------------------------------------------------------------
Expand All @@ -64,6 +58,7 @@ def gen_report_from_fif(infiles, outdir, ftype=None, logsdir=None, run_id=None):
run_id : str
Run ID.
"""
from ..preprocessing import read_dataset

# Validate input files and directory to save html file and plots to
infiles, outnames, good_files = process_file_inputs(infiles)
Expand Down Expand Up @@ -493,7 +488,7 @@ def plot_flowchart(raw, savebase=None):
Path to saved figure.
"""

from ..preprocessing import get_config_from_fif, plot_preproc_flowchart
# Get config info from raw.info['description']
config_list = get_config_from_fif(raw)

Expand Down

0 comments on commit 8b917cf

Please sign in to comment.