-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathbaseline_testing.py
148 lines (121 loc) · 5.24 KB
/
baseline_testing.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
###
from ieeg.navigate import channel_outlier_marker, trial_ieeg, crop_empty_data, \
outliers_to_nan
from ieeg.io import raw_from_layout, get_data
from ieeg.timefreq.utils import crop_pad
from ieeg.timefreq import gamma
from ieeg.calc.scaling import rescale
import mne
import os
import numpy as np
from ieeg.calc.reshape import make_data_same
from ieeg.calc.stats import time_perm_cluster
from utils import calculate_RTs
import matplotlib.pyplot as plt
def get_baseline(inst: mne.io.BaseRaw, base_times: tuple[float, float], *,
                 event: str = "experimentStart", pad: float = 0.5):
    """
    Extract the high gamma (HG) baseline epochs around an event.

    A copy of the recording is loaded, re-referenced to the average of its
    data channels, epoched around ``event`` over ``base_times`` widened by
    ``pad`` on both sides, passed through ``gamma.extract`` and finally
    trimmed by ``pad`` with ``crop_pad``.

    Parameters:
    - inst (mne.io.BaseRaw): The recording to take the baseline from. It is copied first, so the steps below operate on the copy.
    - base_times (tuple): Start and end of the baseline window relative to ``event``.
    - event (str, optional): Name of the event to epoch around. Defaults to "experimentStart".
    - pad (float, optional): Extra time added on each side before HG extraction and cropped off afterwards (presumably seconds, given the "s" suffix handed to crop_pad). Defaults to 0.5.

    Returns:
    The object returned by ``gamma.extract`` after ``crop_pad`` has been
    applied to it.
    """
    raw = inst.copy()
    raw.load_data()

    # The average reference is applied to the channel type of the first data
    # channel only.
    ch_type = raw.get_channel_types(only_data_chs=True)[0]
    raw.set_eeg_reference(ref_channels="average", ch_type=ch_type)

    # Widen the window by `pad` on each side; the same amount is cropped off
    # again after the HG extraction.
    padded_times = [base_times[0] - pad, base_times[1] + pad]
    trials = trial_ieeg(raw, event, padded_times, preload=True)

    # NOTE(review): outliers_to_nan(trials, outliers=10) was commented out
    # here in the original; confirm the baseline should skip outlier removal.
    HG_base = gamma.extract(trials, copy=False, n_jobs=1)
    crop_pad(HG_base, f"{pad}s")
    return HG_base
def plot_HG_and_stats(sub, task, output_name, event=None, times=(-1, 1.5),
                      base_times=(-0.5, 0), LAB_root=None, channels=None,
                      full_trial_base=False):
    """
    Placeholder entry point for the high gamma (HG) plot-and-statistics run.

    The body is empty, so calling this function currently performs no
    processing and returns None. The arguments document the intended
    interface only.

    Parameters:
    - sub (str): Subject identifier.
    - task (str): Task identifier.
    - output_name (str): Base name for the output files.
    - event (str, optional): Name of the event to process. Defaults to None.
    - times (tuple, optional): Start and end times of the processing window. Defaults to (-1, 1.5).
    - base_times (tuple, optional): Start and end times of the baseline window. Defaults to (-0.5, 0).
    - LAB_root (str, optional): Lab root directory; meant to be chosen from the OS when omitted. Defaults to None.
    - channels (list of str, optional): Channels to plot and test; meant to cover all channels when omitted. Defaults to None.
    - full_trial_base (bool, optional): Whether the full trial should serve as the baseline. Defaults to False.
    """
    return None
# Run configuration for this script.
sub = 'D0057'
task = 'GlobalLocal'
output_name = "response_experimentStartBase1secTo101Sec"
event = "Response"
times = (-1, 1.5)  # epoch window around `event`
base_times = (1, 101)  # baseline window handed to get_baseline
LAB_root = None
channels = None  # None means "use every channel"
full_trial_base = False  # NOTE(review): not read anywhere in the visible code

# Fall back to the default Box location when no lab root was configured.
if LAB_root is None:
    HOME = os.path.expanduser("~")
    if os.name == 'nt':  # windows
        LAB_root = os.path.join(HOME, "Box", "CoganLab")
    else:  # mac
        # NOTE(review): every non-Windows OS takes this macOS-style path.
        LAB_root = os.path.join(HOME, "Library", "CloudStorage", "Box-Box",
                                "CoganLab")

layout = get_data(task, root=LAB_root)
# preload=False: the data is only loaded later, where it is needed.
filt = raw_from_layout(layout.derivatives['derivatives/clean'], subject=sub,
                       extension='.edf', desc='clean', preload=False)

# Per-subject figure directory; exist_ok avoids the check-then-create race.
save_dir = os.path.join(layout.root, 'derivatives', 'freqFilt', 'figs', sub)
os.makedirs(save_dir, exist_ok=True)
###
good = crop_empty_data(filt)

# %%
print(f"good channels before dropping bads: {len(good.ch_names)}")
print(f"filt channels before dropping bads: {len(filt.ch_names)}")

good.info['bads'] = channel_outlier_marker(good, 3, 2)
print("Bad channels in 'good':", good.info['bads'])

# Order matters: prune `filt` while good.info['bads'] is still populated.
# Dropping from `good` first would leave that list empty.
filt.drop_channels(good.info['bads'])
good.drop_channels(good.info['bads'])
print("Bad channels in 'good' after dropping once:", good.info['bads'])

print(f"good channels after dropping bads: {len(good.ch_names)}")
print(f"filt channels after dropping bads: {len(filt.ch_names)}")

# The baseline comes from `filt`, which still holds every non-bad channel.
HG_base = get_baseline(filt, base_times)

good.load_data()

if channels is not None:
    # Reject any requested channel that the recording does not contain.
    invalid_channels = [ch for ch in channels if ch not in good.ch_names]
    if invalid_channels:
        raise ValueError(
            f"The following channels are not valid: {invalid_channels}")

    # Restrict the recording to the requested channels.
    # NOTE(review): HG_base above is not restricted the same way; confirm
    # the later baseline comparison copes with differing channel sets.
    good.pick_channels(channels)
else:
    # No selection given: fall back to every remaining channel.
    channels = good.ch_names

# Average reference over the channel type of the first data channel.
ch_type = filt.get_channel_types(only_data_chs=True)[0]
good.set_eeg_reference(ref_channels="average", ch_type=ch_type)

# Epoch around the event with 0.5 of padding on each side, extract HG, then
# crop the same padding off again.
times_adj = [times[0] - 0.5, times[1] + 0.5]
trials = trial_ieeg(good, event, times_adj, preload=True,
                    reject_by_annotation=False)
outliers_to_nan(trials, outliers=10)
HG_ev1 = gamma.extract(trials, copy=False, n_jobs=1)
crop_pad(HG_ev1, "0.5s")

# The z-scored copy is built before the decimation below.
HG_ev1_rescaled = rescale(HG_ev1, HG_base, copy=True, mode='zscore')

HG_base.decimate(2)
HG_ev1.decimate(2)

RTs, skipped = calculate_RTs(good)
avg_RT = np.median(RTs)
###
print(f"Shape of HG_ev1._data: {HG_ev1._data.shape}")
print(f"Shape of HG_base._data: {HG_base._data.shape}")

sig1 = HG_ev1._data
sig2 = HG_base._data

# Bring the baseline array to the event array's shape before comparing them.
sig2 = make_data_same(sig2, sig1.shape)
print(f"Shape of sig1: {sig1.shape}")
print(f"Shape of sig2: {sig2.shape}")

# NOTE(review): 0.05 is presumably the significance threshold and
# ignore_adjacency=1 presumably excludes axis 1 from clustering; confirm
# against the time_perm_cluster documentation.
mat = time_perm_cluster(sig1, sig2, 0.05, n_jobs=6, ignore_adjacency=1)

fig = plt.figure()
plt.imshow(mat, aspect='auto')
# Save next to the other figures for this subject under the run's output
# name, instead of overwriting "fig" in the current working directory.
fig.savefig(os.path.join(save_dir, f"{output_name}.png"))