From 03484104740f5e5601bab56c7931c701ff894e56 Mon Sep 17 00:00:00 2001 From: Jeff Stout Date: Wed, 15 Jan 2025 11:33:40 -0500 Subject: [PATCH] ENH - qa gui writes out bad epochs to .ds file --- nih2mne/GUI/qt_gui.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/nih2mne/GUI/qt_gui.py b/nih2mne/GUI/qt_gui.py index 8488e05..e49c17c 100755 --- a/nih2mne/GUI/qt_gui.py +++ b/nih2mne/GUI/qt_gui.py @@ -35,7 +35,7 @@ import glob import time from PyQt5.QtCore import QTimer - +import pandas as pd ## Create subject tile @@ -265,27 +265,41 @@ def return_message_box_response(self, i): raw = self.bids_info.meg_list[idx].raw fname = self.bids_info.meg_list[idx].rel_path _bads = self.bids_info.current_meg_dset.info['bads'] + _annots = self.bids_info.current_meg_dset.annotations if i.text()=='&No': print(f'NOT saving bads to raw data') elif i.text()=='&Save': print(f'Save bad chans from Data Editor: {_bads}') self.write_bad_chans_to_raw(fname=fname, bad_chs=_bads) + self.write_bad_segments(fname=fname, annotations=_annots) else: print(f'No operation performed') def write_bad_chans_to_raw(self, fname=None, bad_chs=None): if fname==None: raise ValueError - idx = self.b_chooser_meg.currentIndex() print(fname) bads_fname = op.join(fname, 'BadChannels') bads_str = ''.join([i+'\n' for i in bad_chs]) with open(bads_fname, 'w') as f: f.writelines(bads_str) - # def write_bad_segments(self, fname=None, annotations=None): + def write_bad_segments(self, fname=None, annotations=None): + if fname==None: + raise ValueError + # Change the annotations to dataframe, filter for bads, format for CTF and write + annots_dframe = pd.DataFrame(annotations) + annots_dframe['description_lower']=annots_dframe.description.str.lower() + bads_dframe = annots_dframe[annots_dframe.description_lower.str[0:4]=='bad_'] + bads_dframe['offset'] = bads_dframe.onset + bads_dframe.duration + bads_dframe['epoch'] = 0 + #Format columns correctly + bads_dframe = bads_dframe[['epoch','onset', 'offset', 'description']] + #Write tsv + bads_segment_fname = op.join(fname, 'bad.segments') + bads_dframe.to_csv(bads_segment_fname, index=False, header=False, sep='\t') def plot_fids(self):