diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6e4144e6..62635132 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -22,3 +22,10 @@ repos: hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix] + + # docformatter + - repo: https://github.com/PyCQA/docformatter + rev: eb1df347edd128b30cd3368dddc3aa65edcfac38 + hooks: + - id: docformatter + additional_dependencies: [ tomli ] diff --git a/README.md b/README.md index 314e76b8..b2fa9d36 100644 --- a/README.md +++ b/README.md @@ -99,6 +99,7 @@ The colorpalettes for light and dark theme are inspired from [PyQtDarkTheme](htt Many ideas and basics for GUI-Programming where taken from [LearnPyQt](https://www.learnpyqt.com/) and numerous stackoverflow-questions/solutions. +For the nodes, code and major influences from [NodeGraphQt](https://github.com/jchanvfx/NodeGraphQt) are used. The development is financially supported by [Heidelberg University](https://www.uni-heidelberg.de/de/forschung/forschungsprofil/fields-of-focus/field-of-focus-i). diff --git a/mne_pipeline_hd/conftest.py b/mne_pipeline_hd/conftest.py index 20e53e15..4c7e7dc9 100644 --- a/mne_pipeline_hd/conftest.py +++ b/mne_pipeline_hd/conftest.py @@ -8,8 +8,9 @@ import pytest +from mne_pipeline_hd.gui.node.node_viewer import NodeViewer from mne_pipeline_hd.gui.main_window import MainWindow -from mne_pipeline_hd.pipeline.controller import Controller +from mne_pipeline_hd.pipeline.controller import Controller, NewController from mne_pipeline_hd.pipeline.pipeline_utils import _set_test_run @@ -32,3 +33,58 @@ def main_window(controller, qtbot): qtbot.addWidget(mw) return mw + + +@pytest.fixture +def nodeviewer(qtbot): + viewer = NodeViewer(NewController(), debug_mode=True) + viewer.resize(1000, 1000) + qtbot.addWidget(viewer) + viewer.show() + + func_kwargs = { + "ports": [ + { + "name": "In1", + "port_type": "in", + "accepted_ports": ["Out1"], + }, + { + "name": "In2", + "port_type": "in", + "accepted_ports": ["Out1, Out2"], + }, + { + "name": "Out1", + "port_type": "out", + "accepted_ports": ["In1"], + "multi_connection": True, + }, + { + "name": "Out2", + "port_type": "out", + "accepted_ports": ["In1", "In2"], + "multi_connection": True, + }, + ], + "name": "test_func", + "parameters": { + "low_cutoff": { + "alias": "Low-Cutoff", + "gui": "FloatGui", + "default": 0.1, + }, + "high_cutoff": { + "alias": "High-Cutoff", + "gui": "FloatGui", + "default": 0.2, + }, + }, + } + func_node1 = viewer.create_node("FunctionNode", **func_kwargs) + func_node2 = viewer.create_node("FunctionNode", **func_kwargs) + func_node1.output(port_idx=0).connect_to(func_node2.input(port_idx=0)) + + func_node2.setPos(400, 100) + + return viewer diff --git a/mne_pipeline_hd/development/animation_test.py b/mne_pipeline_hd/development/animation_test.py new file mode 100644 index 00000000..588b65e0 --- /dev/null +++ b/mne_pipeline_hd/development/animation_test.py @@ -0,0 +1,50 @@ +# -*- coding: utf-8 -*- +import sys +from qtpy.QtWidgets import ( + QApplication, + QGraphicsPathItem, + QGraphicsView, + QGraphicsScene, +) +from qtpy.QtGui import QPen, QColor, QPainterPath +from qtpy.QtCore import Qt, QVariantAnimation + + +class Edge(QGraphicsPathItem): + def __init__(self): + super().__init__() + self.setPen(QPen(Qt.white, 5, Qt.SolidLine)) + + self.animation = QVariantAnimation() + self.animation.setLoopCount(-1) + self.animation.valueChanged.connect(self.handle_valueChanged) + self.animation.setStartValue(QColor("blue")) + self.animation.setKeyValueAt(0.25, QColor("green")) + self.animation.setKeyValueAt(0.5, QColor("yellow")) + self.animation.setKeyValueAt(0.75, QColor("red")) + self.animation.setEndValue(QColor("blue")) + self.animation.setDuration(2000) + + def start_animation(self): + self.animation.start() + + def handle_valueChanged(self, value): + pen = self.pen() + pen.setColor(value) + self.setPen(pen) + + +app = QApplication(sys.argv) +viewer = QGraphicsView() +scene = QGraphicsScene() +viewer.setScene(scene) +edge = Edge() +scene.addItem(edge) +path = QPainterPath() +path.moveTo(0, 0) +path.lineTo(100, 100) +edge.setPath(path) +edge.start_animation() +viewer.show() + +sys.exit(app.exec()) diff --git a/mne_pipeline_hd/development/development_considerations.md b/mne_pipeline_hd/development/development_considerations.md index 8fdc81c4..269b34cd 100644 --- a/mne_pipeline_hd/development/development_considerations.md +++ b/mne_pipeline_hd/development/development_considerations.md @@ -48,3 +48,11 @@ setting should be device/OS-dependent: 2. QSettings(), which is stored by Qt on an OS-depending location and which may differ between devices/OS. Settings which dependent on the device/OS should go here (e.g. `n_jobs` or `use_cuda`) + +## Nodes +Nodes should improve usability and the representation of the pipeline by the following: +- The order of execution is now clearer and renders the function-dependency considerations obsolete. +- The user can now see the input and output of each function. +- Parameters will now go to each function directly, overview only optional +- Using multiple File-Lists or Projets side-by-side will be more easy to handle. +- diff --git a/mne_pipeline_hd/extra/parameters.csv b/mne_pipeline_hd/extra/parameters.csv index 07aca06d..3be7d3c8 100644 --- a/mne_pipeline_hd/extra/parameters.csv +++ b/mne_pipeline_hd/extra/parameters.csv @@ -78,7 +78,7 @@ stc_animation_span;;Inverse;(0,0.5);s;time-span for stc-animation[s];TupleGui; stc_animation_dilat;;Inverse;20;;time-dilation for stc-animation;IntGui; target_labels;Target Labels;Inverse;[];;;LabelGui; label_colors;;Inverse;{};;Set custom colors for labels.;ColorGui;{'keys': 'target_labels', 'none_select':True} -extract_mode;Label-Extraction-Mode;Inverse;auto;;mode for extracting label-time-course from Source-Estimate;ComboGui;{'options': ['auto', 'max', 'mean', 'mean_flip', 'pca_flip']} +extract_mode;Label-Extraction-Mode;Inverse;mean;;mode for extracting label-time-course from Source-Estimate;ComboGui;{'options': ['auto', 'max', 'mean', 'mean_flip', 'pca_flip']} con_methods;;Connectivity;['coh'];;methods for connectivity;CheckListGui;{'options': ['coh', 'cohy', 'imcoh', 'plv', 'ciplv', 'ppc', 'pli', 'pli2_unbiased', 'wpli', 'wpli2_debiased']} con_fmin;;Connectivity;30;;lower frequency/frequencies for connectivity;MultiTypeGui;{'type_selection': True, 'types': ['float', 'list']} con_fmax;;Connectivity;80;;upper frequency/frequencies for connectivity;MultiTypeGui;{'type_selection': True, 'types': ['float', 'list']} diff --git a/mne_pipeline_hd/functions/operations.py b/mne_pipeline_hd/functions/operations.py index d87f376d..fd8f4a0c 100644 --- a/mne_pipeline_hd/functions/operations.py +++ b/mne_pipeline_hd/functions/operations.py @@ -893,8 +893,6 @@ def tfr( powers = list() itcs = list() - meeg.load_epochs() - # Calculate Time-Frequency for each trial from epochs # using the selected method for trial, epoch in meeg.get_trial_epochs(): diff --git a/mne_pipeline_hd/functions/plot.py b/mne_pipeline_hd/functions/plot.py index 8d7ca81f..f282c097 100644 --- a/mne_pipeline_hd/functions/plot.py +++ b/mne_pipeline_hd/functions/plot.py @@ -370,7 +370,7 @@ def plot_compare_evokeds(meeg, show_plots): evokeds = {f"{evoked.comment}={evoked.nave}": evoked for evoked in evokeds} - fig = mne.viz.plot_compare_evokeds(evokeds, show=show_plots) + fig = mne.viz.plot_compare_evokeds(evokeds, title=meeg.name, show=show_plots) meeg.plot_save("evokeds", subfolder="compare", matplotlib_figure=fig) diff --git a/mne_pipeline_hd/gui/base_widgets.py b/mne_pipeline_hd/gui/base_widgets.py index e54fe7f9..fc5abc78 100644 --- a/mne_pipeline_hd/gui/base_widgets.py +++ b/mne_pipeline_hd/gui/base_widgets.py @@ -139,11 +139,11 @@ def _data_changed(self, index, _): logger().debug(f"{data} changed at {index}") def content_changed(self): - """Informs ModelView about external change made in data""" + """Informs ModelView about external change made in data.""" self.model.layoutChanged.emit() def replace_data(self, new_data): - """Replaces model._data with new_data""" + """Replaces model._data with new_data.""" self.model._data = new_data self.content_changed() @@ -431,12 +431,12 @@ def _checked_changed(self): logger().debug(f"Changed values: {self.model._checked}") def replace_checked(self, new_checked): - """Replaces model._checked with new checked list""" + """Replaces model._checked with new checked list.""" self.model._checked = new_checked self.content_changed() def select_all(self): - """Select all Items while leaving reference to model._checked intact""" + """Select all Items while leaving reference to model._checked intact.""" for item in [i for i in self.model._data if i not in self.model._checked]: self.model._checked.append(item) # Inform Model about changes @@ -444,8 +444,7 @@ def select_all(self): self._checked_changed() def clear_all(self): - """Deselect all Items while leaving reference - to model._checked intact""" + """Deselect all Items while leaving reference to model._checked intact.""" self.model._checked.clear() # Inform Model about changes self.content_changed() @@ -453,8 +452,8 @@ def clear_all(self): class CheckDictList(BaseList): - """A List-Widget to display the items of a list and mark them depending on - their appearance in check_dict. + """A List-Widget to display the items of a list and mark them depending on their + appearance in check_dict. Parameters ---------- @@ -512,15 +511,15 @@ def __init__( ) def replace_check_dict(self, new_check_dict=None): - """Replaces model.check_dict with new check_dict""" + """Replaces model.check_dict with new check_dict.""" if new_check_dict: self.model._check_dict = new_check_dict self.content_changed() class CheckDictEditList(EditList): - """A List-Widget to display the items of a list and mark them - depending of their appearance in check_dict. + """A List-Widget to display the items of a list and mark them depending of their + appearance in check_dict. Parameters ---------- @@ -589,7 +588,7 @@ def __init__( ) def replace_check_dict(self, new_check_dict=None): - """Replaces model.check_dict with new check_dict""" + """Replaces model.check_dict with new check_dict.""" if new_check_dict: self.model._check_dict = new_check_dict self.content_changed() @@ -616,10 +615,9 @@ def __init__( model.layoutChanged.emit() def get_keyvalue_by_index(self, index): - """For the given index, make an entry in item_dict with the data - at index as key and a dict as value defining. - if data is key or value and refering to the corresponding key/value - of data depending on its type. + """For the given index, make an entry in item_dict with the data at index as key + and a dict as value defining. if data is key or value and refering to the + corresponding key/value of data depending on its type. Parameters ---------- @@ -683,7 +681,7 @@ def select(self, keys, values, clear_selection=True): class SimpleDict(BaseDict): - """A Widget to display a Dictionary + """A Widget to display a Dictionary. Parameters ---------- @@ -699,7 +697,6 @@ class SimpleDict(BaseDict): Set True to resize the rows to contents. resize_columns : bool Set True to resize the columns to contents. - """ def __init__( @@ -723,8 +720,9 @@ def __init__( # ToDo: DataChanged somehow not emitted when row is removed +# ToDo: Bug when removing multiple rows (fix and add tests) class EditDict(BaseDict): - """A Widget to display and edit a Dictionary + """A Widget to display and edit a Dictionary. Parameters ---------- @@ -745,7 +743,6 @@ class EditDict(BaseDict): Set True to resize the rows to contents. resize_columns : bool Set True to resize the columns to contents. - """ def __init__( @@ -831,8 +828,7 @@ def edit_item(self): class BasePandasTable(Base): - """ - The Base-Class for a table from a pandas DataFrame + """The Base-Class for a table from a pandas DataFrame. Parameters ---------- @@ -870,7 +866,7 @@ def __init__( model.layoutChanged.emit() def get_rowcol_by_index(self, index, data_list): - """Get the data at index and the row and column of this data + """Get the data at index and the row and column of this data. Parameters ---------- @@ -883,7 +879,6 @@ def get_rowcol_by_index(self, index, data_list): ----- Because this function is supposed to be called consecutively, the information is stored in an existing list (data_list) - """ data = self.model.getData(index) row = self.model.headerData( @@ -928,9 +923,7 @@ def _selection_changed(self): logger().debug(f"Selection changed to {selection_list}") def select(self, values=None, rows=None, columns=None, clear_selection=True): - """ - Select items in Pandas DataFrame by value - or select complete rows/columns. + """Select items in Pandas DataFrame by value or select complete rows/columns. Parameters ---------- @@ -942,7 +935,6 @@ def select(self, values=None, rows=None, columns=None, clear_selection=True): Names of columns. clear_selection: bool | None Set True if you want to clear the selection before selecting. - """ indexes = list() # Get indexes for matching items in pd_data @@ -980,7 +972,7 @@ def select(self, values=None, rows=None, columns=None, clear_selection=True): class SimplePandasTable(BasePandasTable): - """A Widget to display a pandas DataFrame + """A Widget to display a pandas DataFrame. Parameters ---------- @@ -1024,7 +1016,7 @@ def __init__( class EditPandasTable(BasePandasTable): - """A Widget to display and edit a pandas DataFrame + """A Widget to display and edit a pandas DataFrame. Parameters ---------- @@ -1148,8 +1140,8 @@ def init_ui(self): self.setLayout(layout) def update_data(self): - """Has to be called, when model._data is rereferenced - by for example add_row to keep external data updated. + """Has to be called, when model._data is rereferenced by for example add_row to + keep external data updated. Returns ------- @@ -1319,7 +1311,7 @@ def closeEvent(self, event): class AssignWidget(QWidget): - """ """ + """""" def __init__( self, diff --git a/mne_pipeline_hd/gui/dialogs.py b/mne_pipeline_hd/gui/dialogs.py index cd74adfd..d006b059 100644 --- a/mne_pipeline_hd/gui/dialogs.py +++ b/mne_pipeline_hd/gui/dialogs.py @@ -33,9 +33,8 @@ class CheckListDlg(QDialog): def __init__(self, parent, data, checked): - """ - BaseClass for A Dialog with a Check-List, - open() has to be called in SubClass or directly. + """BaseClass for A Dialog with a Check-List, open() has to be called in SubClass + or directly. Parameters ---------- diff --git a/mne_pipeline_hd/gui/function_widgets.py b/mne_pipeline_hd/gui/function_widgets.py index 8936a2d2..6c1a8911 100644 --- a/mne_pipeline_hd/gui/function_widgets.py +++ b/mne_pipeline_hd/gui/function_widgets.py @@ -1736,8 +1736,7 @@ def init_ui(self): self.setLayout(layout) def _check_empty(self): - """Check if the dict for current_func in add_kwargs is empty, - then remove it""" + """Check if the dict for current_func in add_kwargs is empty, then remove it.""" if self.current_func: if self.current_func in self.ct.pr.add_kwargs: if len(self.ct.pr.add_kwargs[self.current_func]) == 0: diff --git a/mne_pipeline_hd/gui/gui_utils.py b/mne_pipeline_hd/gui/gui_utils.py index 2669a30c..2fbccc26 100644 --- a/mne_pipeline_hd/gui/gui_utils.py +++ b/mne_pipeline_hd/gui/gui_utils.py @@ -27,8 +27,10 @@ Signal, Slot, QTimer, + QEvent, ) -from qtpy.QtGui import QFont, QTextCursor, QPalette, QColor, QIcon +from qtpy.QtTest import QTest +from qtpy.QtGui import QFont, QTextCursor, QMouseEvent, QPalette, QColor, QIcon from qtpy.QtWidgets import ( QApplication, QDialog, @@ -159,9 +161,9 @@ def init_ui(self): def show_error_dialog(exc_str): - """Checks if a QApplication instance is available - and shows the Error-Dialog. - If unavailable (non-console application), log an additional notice. + """Checks if a QApplication instance is available and shows the Error-Dialog. + + If unavailable (non-console application), log an additional notice. """ if QApplication.instance() is not None: ErrorDialog(exc_str, title="A unexpected error occurred") @@ -205,14 +207,14 @@ def __init__(self, *args, **kwargs): def exception_hook(self, exc_type, exc_value, exc_traceback): """Function handling uncaught exceptions. + It is triggered each time an uncaught exception occurs. """ if issubclass(exc_type, KeyboardInterrupt): # ignore keyboard interrupt to support console applications sys.__excepthook__(exc_type, exc_value, exc_traceback) else: - # Print Error to Console - traceback.print_exception(exc_type, exc_value, exc_traceback) + # Error logging exc_info = (exc_type, exc_value, exc_traceback) exc_str = ( exc_type.__name__, @@ -220,9 +222,7 @@ def exception_hook(self, exc_type, exc_value, exc_traceback): "".join(traceback.format_tb(exc_traceback)), ) logger().critical( - f"Uncaught exception:\n" - f"{exc_str[0]}: {exc_str[1]}\n" - f"{exc_str[2]}", + "Uncaught exception:", exc_info=exc_info, ) @@ -236,7 +236,7 @@ def __init__(self, parent=None): class ConsoleWidget(QPlainTextEdit): - """A Widget displaying formatted stdout/stderr-output""" + """A Widget displaying formatted stdout/stderr-output.""" def __init__(self): super().__init__() @@ -317,8 +317,8 @@ def mouseDoubleClickEvent(self, event): class MainConsoleWidget(ConsoleWidget): - """A subclass of ConsoleWidget which is linked to stdout/stderr - of the main process""" + """A subclass of ConsoleWidget which is linked to stdout/stderr of the main + process.""" def __init__(self): super().__init__() @@ -357,7 +357,7 @@ def flush(self): class WorkerSignals(QObject): - """Class for standard Worker-Signals""" + """Class for standard Worker-Signals.""" # Emitted when the function finished and returns the return-value finished = Signal(object) @@ -379,7 +379,7 @@ class WorkerSignals(QObject): class Worker(QRunnable): - """A class to execute a function in a seperate Thread + """A class to execute a function in a seperate Thread. Parameters ---------- @@ -391,7 +391,6 @@ class Worker(QRunnable): Any Arguments passed to the executed function kwargs Any Keyword-Arguments passed to the executed function - """ def __init__(self, function, *args, **kwargs): @@ -405,9 +404,7 @@ def __init__(self, function, *args, **kwargs): @Slot() def run(self): - """ - Initialise the runner function with passed args, kwargs. - """ + """Initialise the runner function with passed args, kwargs.""" # Add signals to kwargs if in parameters of function if "worker_signals" in signature(self.function).parameters: self.kwargs["worker_signals"] = self.signals @@ -430,7 +427,7 @@ def cancel(self): # ToDo: Make PyQt-independent with tqdm class WorkerDialog(QDialog): - """A Dialog for a Worker doing a function""" + """A Dialog for a Worker doing a function.""" thread_finished = Signal(object) @@ -757,6 +754,76 @@ def get_user_input_string(prompt, title="Input required!", force=False): return user_input +def invert_rgb_color(color_tuple): + return tuple(map(lambda i, j: i - j, (255, 255, 255), color_tuple)) + + +def format_color(clr): + """This converts a hex-color-string to a tuple of RGB-values.""" + if isinstance(clr, str): + clr = clr.strip("#") + return tuple(int(clr[i : i + 2], 16) for i in (0, 2, 4)) + return clr + + +def mouse_interaction(func): + def wrapper(**kwargs): + QTest.qWaitForWindowExposed(kwargs["widget"]) + QTest.qWait(10) + func(**kwargs) + QTest.qWait(10) + + return wrapper + + +@mouse_interaction +def mousePress(widget=None, pos=None, button=None, modifier=None): + if modifier is None: + modifier = Qt.KeyboardModifier.NoModifier + event = QMouseEvent( + QEvent.Type.MouseButtonPress, pos, button, Qt.MouseButton.NoButton, modifier + ) + QApplication.sendEvent(widget, event) + + +@mouse_interaction +def mouseRelease(widget=None, pos=None, button=None, modifier=None): + if modifier is None: + modifier = Qt.KeyboardModifier.NoModifier + event = QMouseEvent( + QEvent.Type.MouseButtonRelease, pos, button, Qt.MouseButton.NoButton, modifier + ) + QApplication.sendEvent(widget, event) + + +@mouse_interaction +def mouseMove(widget=None, pos=None, button=None, modifier=None): + if button is None: + button = Qt.MouseButton.NoButton + if modifier is None: + modifier = Qt.KeyboardModifier.NoModifier + event = QMouseEvent( + QEvent.Type.MouseMove, pos, Qt.MouseButton.NoButton, button, modifier + ) + QApplication.sendEvent(widget, event) + + +def mouseClick(widget, pos, button, modifier=None): + mouseMove(widget=widget, pos=pos) + mousePress(widget=widget, pos=pos, button=button, modifier=modifier) + mouseRelease(widget=widget, pos=pos, button=button, modifier=modifier) + + +def mouseDrag(widget, positions, button, modifier=None): + mouseMove(widget=widget, pos=positions[0]) + mousePress(widget=widget, pos=positions[0], button=button, modifier=modifier) + for pos in positions[1:]: + mouseMove(widget=widget, pos=pos, button=button, modifier=modifier) + # For some reason moeve again to last position + mouseMove(widget=widget, pos=positions[-1], button=button, modifier=modifier) + mouseRelease(widget=widget, pos=positions[-1], button=button, modifier=modifier) + + def get_palette(theme): color_roles = { "foreground": ["WindowText", "ToolTipText", "Text"], diff --git a/mne_pipeline_hd/gui/loading_widgets.py b/mne_pipeline_hd/gui/loading_widgets.py index d346d82b..47c3a897 100644 --- a/mne_pipeline_hd/gui/loading_widgets.py +++ b/mne_pipeline_hd/gui/loading_widgets.py @@ -4,6 +4,7 @@ License: BSD 3-Clause Github: https://github.com/marsipu/mne-pipeline-hd """ +import logging import os import re import shutil @@ -82,8 +83,7 @@ def index_parser(index, all_items, groups=None): - """ - Parses indices from a index-string in all_items + """Parses indices from a index-string in all_items. Parameters ---------- @@ -93,7 +93,6 @@ def index_parser(index, all_items, groups=None): All items Returns ------- - """ indices = list() rm = list() @@ -158,7 +157,11 @@ def index_parser(index, all_items, groups=None): indices = [int(index)] indices = [i for i in indices if i not in rm] - files = np.asarray(all_items)[indices].tolist() + try: + files = np.asarray(all_items)[indices].tolist() + except IndexError: + logging.warning("Index out of range") + files = [] return files @@ -533,11 +536,9 @@ def sel_all(self): # Todo: Enable Drag&Drop class AddFilesWidget(QWidget): - def __init__(self, main_win): - super().__init__(main_win) - self.mw = main_win - self.ct = main_win.ct - self.pr = main_win.ct.pr + def __init__(self, ct): + super().__init__() + self.ct = ct self.layout = QVBoxLayout() self.erm_keywords = [ @@ -651,6 +652,7 @@ def insert_files(self, files_list): # Get already existing files and skip them if ( file_path in list(self.pd_files["Path"]) + # ToDo: Fix all pr stuff or file_name in self.pr.all_meeg or file_name in self.pr.all_erm ): @@ -856,11 +858,9 @@ def __init__(self, main_win): class AddMRIWidget(QWidget): - def __init__(self, main_win): - super().__init__(main_win) - self.mw = main_win - self.ct = main_win.ct - self.pr = main_win.ct.pr + def __init__(self, ct): + super().__init__() + self.ct = ct self.layout = QVBoxLayout() self.folders = list() @@ -1151,7 +1151,7 @@ def copy_bads(self): class SubBadsWidget(QWidget): - """A Dialog to select Bad-Channels for the files""" + """A Dialog to select Bad-Channels for the files.""" def __init__(self, main_win): """ @@ -1502,7 +1502,7 @@ def init_ui(self): self.setLayout(self.layout) def get_event_id(self): - """Get unique event-ids from events""" + """Get unique event-ids from events.""" if self.name in self.pr.meeg_event_id: self.event_id = self.pr.meeg_event_id[self.name] else: @@ -1547,7 +1547,7 @@ def save_event_id(self): self.pr.sel_event_id[self.name] = sel_event_id def file_selected(self, current, _): - """Called when File from file_widget is selected""" + """Called when File from file_widget is selected.""" # Save event_id for previous file self.save_event_id() @@ -1571,7 +1571,7 @@ def file_selected(self, current, _): self.checked_labels = list() self.update_check_list() - # ToDo: Make all combinations possible + # ToDo: Make all combinations possible and also int-keys (can't split) def update_check_list(self): self.labels = [k for k in self.queries.keys()] # Get selectable trials and update widget @@ -1739,7 +1739,7 @@ def copy_trans(self): class FileManagment(QDialog): - """A Dialog for File-Management + """A Dialog for File-Management. Parameters ---------- @@ -1998,7 +1998,7 @@ def _get_current(self, kind): return obj_name, path_type def show_parameters(self, kind): - """Show the parameters, which are different for the selected cell + """Show the parameters, which are different for the selected cell. Parameters ---------- @@ -2085,13 +2085,12 @@ def _remove_finished(self, kind): obj_table.content_changed() def remove_file(self, kind): - """Remove the file at the path of the current cell + """Remove the file at the path of the current cell. Parameters ---------- kind : str If it is MEEG, FSMRI or Group - """ msgbx = QMessageBox.question( diff --git a/mne_pipeline_hd/gui/main_window.py b/mne_pipeline_hd/gui/main_window.py index 9042c9b0..3a55d76e 100644 --- a/mne_pipeline_hd/gui/main_window.py +++ b/mne_pipeline_hd/gui/main_window.py @@ -70,6 +70,7 @@ SubjectWizard, ExportDialog, ) +from mne_pipeline_hd.gui.node.node_viewer import NodeViewer from mne_pipeline_hd.gui.parameter_widgets import ( BoolGui, IntGui, @@ -558,6 +559,245 @@ def add_func_bts(self): self.tab_func_widget.addTab(tab, tab_name) set_app_theme() + # Add experimental Node-Tab + self.node_viewer = NodeViewer(self.ct, self) + self.tab_func_widget.addTab(self.node_viewer, "Node-Graph") + self.tab_func_widget.setCurrentWidget(self.node_viewer) + + demo_dict = { + "Filter Raw": { + "parameters": { + "low_cutoff": { + "alias": "Low-Cutoff", + "gui": "FloatGui", + "default": 0.1, + }, + "high_cutoff": { + "alias": "High-Cutoff", + "gui": "FloatGui", + "default": 0.2, + }, + }, + "ports": [ + { + "name": "Raw", + "port_type": "in", + }, + { + "name": "Raw", + "port_type": "out", + "multi_connection": True, + }, + ], + }, + "Get Events": { + "parameters": { + "event_id": { + "alias": "Event-ID", + "gui": "IntGui", + "default": 1, + }, + }, + "ports": [ + { + "name": "Raw", + "port_type": "in", + }, + { + "name": "Events", + "port_type": "out", + "multi_connection": True, + }, + ], + }, + "Epoch Data": { + "parameters": { + "epochs_tmin": { + "alias": "tmin", + "gui": "FloatGui", + "default": -0.2, + }, + "epochs_tmax": { + "alias": "tmax", + "gui": "FloatGui", + "default": 0.5, + }, + "apply_baseline": { + "alias": "Baseline", + "gui": "BoolGui", + "default": True, + }, + }, + "ports": [ + { + "name": "Raw", + "port_type": "in", + }, + { + "name": "Events", + "port_type": "in", + }, + { + "name": "Epochs", + "port_type": "out", + "multi_connection": True, + }, + ], + }, + "Average Epochs": { + "parameters": { + "event_id": { + "alias": "Event-ID", + "gui": "IntGui", + "default": 1, + }, + }, + "ports": [ + { + "name": "Epochs", + "port_type": "in", + }, + { + "name": "Evokeds", + "port_type": "out", + "multi_connection": True, + }, + ], + }, + "Make Forward Model": { + "parameters": { + "fwd_subject": { + "alias": "Forward Subject", + "gui": "StringGui", + "default": "fsaverage", + }, + }, + "ports": [ + { + "name": "MRI", + "port_type": "in", + }, + { + "name": "Fwd", + "port_type": "out", + "multi_connection": True, + }, + ], + }, + "Make Inverse Operator": { + "parameters": { + "inv_subject": { + "alias": "Inverse Subject", + "gui": "StringGui", + "default": "fsaverage", + }, + }, + "ports": [ + { + "name": "Evokeds", + "port_type": "in", + }, + { + "name": "Fwd", + "port_type": "in", + }, + { + "name": "Inv", + "port_type": "out", + "multi_connection": True, + }, + ], + }, + "Plot Source Estimates": { + "parameters": { + "subject": { + "alias": "Subject", + "gui": "StringGui", + "default": "fsaverage", + }, + }, + "ports": [ + { + "name": "Inv", + "port_type": "in", + }, + { + "name": "Plot", + "port_type": "out", + "multi_connection": True, + }, + ], + }, + } + + # Add some demo nodes + meeg_node = self.node_viewer.create_node("MEEGInputNode") + mri_node = self.node_viewer.create_node("MRIInputNode") + ass_node = self.node_viewer.create_node( + node_class="AssignmentNode", + name="Assignment", + ports=[ + { + "name": "Evokeds", + "port_type": "in", + }, + { + "name": "Fwd", + "port_type": "in", + }, + { + "name": "Evokeds", + "port_type": "out", + }, + { + "name": "Fwd", + "port_type": "out", + }, + ], + ) + fn = dict() + for func_name, func_kwargs in demo_dict.items(): + fnode = self.node_viewer.create_node("FunctionNode", **func_kwargs) + fn[func_name] = fnode + + # Wire up the nodes + meeg_node.output(port_idx=0).connect_to(fn["Filter Raw"].input(port_idx=0)) + meeg_node.output(port_idx=0).connect_to(fn["Get Events"].input(port_idx=0)) + fn["Epoch Data"].input(port_name="Raw").connect_to( + fn["Filter Raw"].output(port_idx=0) + ) + fn["Epoch Data"].input(port_name="Events").connect_to( + fn["Get Events"].output(port_idx=0) + ) + fn["Epoch Data"].output(port_name="Epochs").connect_to( + fn["Average Epochs"].input(port_name="Epochs") + ) + + mri_node.output(port_idx=0).connect_to( + fn["Make Forward Model"].input(port_name="MRI") + ) + + ass_node.input(port_idx=0).connect_to( + fn["Average Epochs"].output(port_name="Evokeds") + ) + ass_node.input(port_idx=1).connect_to( + fn["Make Forward Model"].output(port_name="Fwd") + ) + ass_node.output(port_idx=0).connect_to( + fn["Make Inverse Operator"].input(port_name="Evokeds") + ) + ass_node.output(port_idx=1).connect_to( + fn["Make Inverse Operator"].input(port_name="Fwd") + ) + + fn["Plot Source Estimates"].input(port_name="Inv").connect_to( + fn["Make Inverse Operator"].output(port_name="Inv") + ) + + self.node_viewer.auto_layout_nodes() + self.node_viewer.clear_selection() + self.node_viewer.fit_to_selection() + def update_func_bts(self): # Remove tabs in tab_func_widget while self.tab_func_widget.count(): diff --git a/mne_pipeline_hd/gui/models.py b/mne_pipeline_hd/gui/models.py index c9707cd6..e73d8761 100644 --- a/mne_pipeline_hd/gui/models.py +++ b/mne_pipeline_hd/gui/models.py @@ -25,7 +25,7 @@ class BaseListModel(QAbstractListModel): - """A basic List-Model + """A basic List-Model. Parameters ---------- @@ -99,7 +99,7 @@ def supportedDragActions(self): class EditListModel(BaseListModel): - """An editable List-Model + """An editable List-Model. Parameters ---------- @@ -133,8 +133,7 @@ def setData(self, index, value, role=None): class CheckListModel(BaseListModel): - """ - A Model for a Check-List + """A Model for a Check-List. Parameters ---------- @@ -146,7 +145,6 @@ class CheckListModel(BaseListModel): Set True if you want to display the list-index in front of each value drag_drop: bool Set True to enable Drag&Drop. - """ def __init__( @@ -204,8 +202,7 @@ def flags(self, index): class CheckDictModel(BaseListModel): - """ - A Model for a list, which marks items which are present in a dictionary + """A Model for a list, which marks items which are present in a dictionary. Parameters ---------- @@ -263,7 +260,7 @@ def data(self, index, role=None): class CheckDictEditModel(CheckDictModel, EditListModel): - """An editable List-Model + """An editable List-Model. Parameters ---------- @@ -305,7 +302,7 @@ def __init__( class BaseDictModel(QAbstractTableModel): - """Basic Model for Dictonaries + """Basic Model for Dictonaries. Parameters ---------- @@ -360,7 +357,7 @@ def columnCount(self, parent=None, *args, **kwargs): # ToDo: Somehow inputs are automatically sorted (annoyig, disable-toggle) class EditDictModel(BaseDictModel): - """An editable model for Dictionaries + """An editable model for Dictionaries. Parameters ---------- @@ -431,7 +428,7 @@ def removeRows(self, row, count, parent=None, *args, **kwargs): class BasePandasModel(QAbstractTableModel): - """Basic Model for pandas DataFrame + """Basic Model for pandas DataFrame. Parameters ---------- @@ -469,7 +466,8 @@ def columnCount(self, parent=None, *args, **kwargs): class EditPandasModel(BasePandasModel): - """Editable TableModel for Pandas DataFrames + """Editable TableModel for Pandas DataFrames. + Parameters ---------- data : pandas.DataFrame | None @@ -810,8 +808,8 @@ def removeRows(self, row, count, parent=None, *args, **kwargs): class FileManagementModel(BasePandasModel): - """A model for the Pandas-DataFrames containing information - about the existing files""" + """A model for the Pandas-DataFrames containing information about the existing + files.""" def __init__(self, data, **kwargs): super().__init__(data, **kwargs) @@ -857,9 +855,8 @@ def data(self, index, role=None): class CustomFunctionModel(QAbstractListModel): - """A Model for the Pandas-DataFrames containing information about - new custom functions/their paramers to display only their name - and if they are ready. + """A Model for the Pandas-DataFrames containing information about new custom + functions/their paramers to display only their name and if they are ready. Parameters ---------- @@ -893,7 +890,7 @@ def rowCount(self, parent=None, *args, **kwargs): class RunModel(QAbstractListModel): - """A model for the items/functions of a Pipeline-Run""" + """A model for the items/functions of a Pipeline-Run.""" def __init__(self, data, mode, **kwargs): super().__init__(**kwargs) diff --git a/mne_pipeline_hd/gui/node/__init__.py b/mne_pipeline_hd/gui/node/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/mne_pipeline_hd/gui/node/base_node.py b/mne_pipeline_hd/gui/node/base_node.py new file mode 100644 index 00000000..8d7cef5d --- /dev/null +++ b/mne_pipeline_hd/gui/node/base_node.py @@ -0,0 +1,707 @@ +# -*- coding: utf-8 -*- +import logging +from collections import OrderedDict + +from mne_pipeline_hd.gui.gui_utils import format_color +from mne_pipeline_hd.gui.node.node_defaults import defaults +from mne_pipeline_hd.gui.node.ports import Port +from qtpy.QtCore import QRectF, Qt +from qtpy.QtGui import QColor, QPen, QPainterPath +from qtpy.QtWidgets import QGraphicsItem, QGraphicsTextItem, QGraphicsProxyWidget + + +class NodeTextItem(QGraphicsTextItem): + def __init__(self, text, parent=None): + super().__init__(text, parent) + self.setTextInteractionFlags(Qt.TextInteractionFlag.NoTextInteraction) + + +class BaseNode(QGraphicsItem): + """Base class for all nodes in the NodeGraph. + + Parameters + ---------- + ct : Controller + A Controller-instance, where all session information is stored and managed. + name : str + Name of the node. + ports : dict, list + Dictionary with keys as (old) port id and values as dictionaries which contain kwargs for the :meth:`BaseNode.add_port()`. + Can also be just a list with kwargs for the :meth:`BaseNode.add_port()`. + old_id : int, None, optional + Old id for reestablishing connections. + """ + + def __init__(self, ct, name=None, ports=None, old_id=None): + self.ct = ct + # Initialize QGraphicsItem + super().__init__() + self.setFlags( + self.GraphicsItemFlag.ItemIsSelectable | self.GraphicsItemFlag.ItemIsMovable + ) + self.setCacheMode(QGraphicsItem.DeviceCoordinateCache) + self.setZValue(1) + + # Initialize hidden attributes for properties (with node_defaults) + self.id = id(self) + self.old_id = old_id + self._name = name + + self._width = defaults["nodes"]["width"] + self._height = defaults["nodes"]["height"] + self._color = defaults["nodes"]["color"] + self._selected_color = defaults["nodes"]["selected_color"] + self._border_color = defaults["nodes"]["border_color"] + self._selected_border_color = defaults["nodes"]["selected_border_color"] + self._text_color = defaults["nodes"]["text_color"] + + self._title_item = NodeTextItem(self._name, self) + self._inputs = OrderedDict() + self._outputs = OrderedDict() + self._widgets = list() + + # Initialize iports + ports = ports or list() + # If old id is added for reestablishing connections + if isinstance(ports, dict): + for port_kwargs in ports.values(): + self.add_port(**port_kwargs) + else: + for port_kwargs in ports: + self.add_port(**port_kwargs) + + @property + def name(self): + return self._name + + @name.setter + def name(self, name): + self._name = name + self._title_item.setPlainText(name) + + @property + def width(self): + return self._width + + @width.setter + def width(self, width): + width = max(width, defaults["nodes"]["width"]) + self._width = width + + @property + def height(self): + return self._height + + @height.setter + def height(self, height): + height = max(height, defaults["nodes"]["height"]) + self._height = height + + @property + def color(self): + return self._color + + @color.setter + def color(self, color): + self._color = format_color(color) + self.update() + + @property + def selected_color(self): + return self._selected_color + + @selected_color.setter + def selected_color(self, color): + self._selected_color = format_color(color) + self.update() + + @property + def border_color(self): + return self._border_color + + @border_color.setter + def border_color(self, color): + self._border_color = format_color(color) + self.update() + + @property + def selected_border_color(self): + return self._selected_border_color + + @selected_border_color.setter + def selected_border_color(self, color): + self._selected_border_color = format_color(color) + self.update() + + @property + def text_color(self): + return self._text_color + + @text_color.setter + def text_color(self, color): + self._text_color = format_color(color) + self.update() + + @property + def inputs(self): + """Returns the input ports in a list (self._inputs is an OrderedDict and can be + accessed internally when necessary)""" + return list(self._inputs.values()) + + @property + def outputs(self): + """Returns the output ports in a list (self._outputs is an OrderedDict and can + be accessed internally when necessary)""" + return list(self._outputs.values()) + + @property + def ports(self): + """Returns all ports in a list.""" + return list(self._inputs.values()) + list(self._outputs.values()) + + @property + def widgets(self): + return self._widgets + + @property + def xy_pos(self): + """Return the item scene postion. ("node.pos" conflicted with + "QGraphicsItem.pos()" so it was refactored to "xy_pos".) + + Returns: + list[float]: x, y scene position. + """ + return [float(self.scenePos().x()), float(self.scenePos().y())] + + @xy_pos.setter + def xy_pos(self, pos=None): + """Set the item scene postion. ("node.pos" conflicted with "QGraphicsItem.pos()" + so it was refactored to "xy_pos".) + + Args: + pos (list[float]): x, y scene position. + """ + pos = pos or (0.0, 0.0) + self.setPos(*pos) + + @property + def viewer(self): + if self.scene(): + return self.scene().viewer() + + # ---------------------------------------------------------------------------------- + # Logic methods + # ---------------------------------------------------------------------------------- + def add_port( + self, + name, + port_type, + multi_connection=False, + accepted_ports=None, + old_id=None, + ): + """Adds a Port QGraphicsItem into the node. + + Parameters + ---------- + name : str + name for the port. + port_type : str + "in" or "out". + multi_connection : bool + allow multiple connections. + accepted_ports : list, None + list of accepted port names, if None all ports are accepted. + old_id : int, None, optional + old port id for reestablishing connections. + + Returns + ------- + PortItem + Port QGraphicsItem. + """ + # Check port type + if port_type not in ["in", "out"]: + raise ValueError(f"Invalid port type: {port_type}") + # port names must be unique for inputs/outputs + existing = self.inputs if port_type == "in" else self.outputs + if name in [p.name for p in existing]: + logging.warning(f"Input port {name} already exists.") + return + # Create port + port = Port( + self, name, port_type, multi_connection, accepted_ports, old_id=old_id + ) + # Add port to port-container + ports = self._inputs if port_type == "in" else self._outputs + ports[port.id] = port + # Update scene + if self.scene(): + self.draw_node() + + return port + + def add_input( + self, + name, + multi_connection=False, + accepted_ports=None, + ): + """Adds a Port QGraphicsItem into the node as input. + + Parameters + ---------- + name : str + name for the port. + multi_connection : bool + allow multiple connections. + accepted_ports : list, None + list of accepted port names, if None all ports are accepted. + + Returns + ------- + PortItem + Port QGraphicsItem. + """ + port = self.add_port(name, "in", multi_connection, accepted_ports) + + return port + + def add_output( + self, + name, + multi_connection=False, + accepted_ports=None, + ): + """Adds a Port QGraphicsItem into the node as output. + + Parameters + ---------- + name : str + name for the port. + multi_connection : bool + allow multiple connections. + accepted_ports : list, None + list of accepted port names, if None all ports are accepted. + + Returns + ------- + PortItem + Port QGraphicsItem. + """ + port = self.add_port(name, "out", multi_connection, accepted_ports) + + return port + + def port( + self, port_type=None, port_idx=None, port_name=None, port_id=None, old_id=None + ): + """Get port by the name or index. + + Parameters + ---------- + port_type : str, None + "in" or "out". If None, inputs and outputs will be searched. + port_idx : int + Index of the port. + port_name : str, optional + Name of the port. + port_id : int, optional + Id of the port. + old_id : int, optional + Old id of the port for reestablishing connections. + + Returns + ------- + Port + The port that matches the provided index, name, or id. If multiple + parameters are provided, the method will prioritize them in + the following order: port_idx, port_name, port_id, old_id. + If no parameters are provided or if no match is found. + the method will return None. + """ + if port_type is None: + ports = self._inputs | self._outputs + elif port_type not in ["in", "out"]: + raise ValueError(f"Invalid port type: {port_type}") + else: + ports = self._inputs if port_type == "in" else self._outputs + port_list = list(ports.values()) + + if port_idx is not None: + if not isinstance(port_idx, int): + raise ValueError(f"Invalid port index: {port_idx}") + if port_idx < len(port_list): + return port_list[port_idx] + else: + logging.warning(f"{port_type} port {port_idx} not found.") + elif port_name is not None: + if not isinstance(port_name, str): + raise ValueError(f"Invalid port name: {port_name}") + port_names = [p for p in port_list if p.name == port_name] + if len(port_names) > 1: + logging.warning( + "More than one port with the same name. This should not be allowed." + ) + elif len(port_names) == 0: + logging.warning(f"{port_type} port {port_name} not found.") + else: + return port_names[0] + elif port_id is not None: + if not isinstance(port_id, int): + raise ValueError(f"Invalid port id: {port_id}") + if port_id in ports: + return ports[port_id] + else: + logging.warning(f"{port_type} port {port_id} not found.") + elif old_id is not None: + if not isinstance(old_id, int): + raise ValueError(f"Invalid old port id: {old_id}") + old_id_ports = [p for p in port_list if p.old_id == old_id] + if len(old_id_ports) > 1: + logging.warning( + "More than one port with the same old id. This should not be allowed." + ) + elif len(old_id_ports) == 0: + logging.warning(f"{port_type} port with old id {old_id} not found.") + else: + return old_id_ports[0] + else: + logging.warning("No port identifier provided.") + + def input(self, **port_kwargs): + """Get input port by the name, index, id or old id as in port().""" + return self.port(port_type="in", **port_kwargs) + + def output(self, **port_kwargs): + """Get output port by the name, index, id or old id as in port().""" + return self.port(port_type="out", **port_kwargs) + + def connected_input_nodes(self): + """Returns all nodes connected from the input ports. + + Returns: + dict: {: } + """ + nodes = OrderedDict() + for p in self.inputs: + nodes[p] = [cp.node for cp in p.connected_ports] + return nodes + + def connected_output_nodes(self): + """Returns all nodes connected from the output ports. + + Returns: + dict: {: } + """ + nodes = OrderedDict() + for p in self.outputs: + nodes[p] = [cp.node for cp in p.connected_ports] + return nodes + + def add_widget(self, widget): + """Add widget to the node.""" + proxy_widget = QGraphicsProxyWidget(self) + proxy_widget.setWidget(widget) + self.widgets.append(proxy_widget) + + def delete(self): + """Remove node from the scene.""" + if self.scene() is not None: + self.scene().removeItem(self) + del self + + def to_dict(self): + node_dict = { + "name": self.name, + "class": self.__class__.__name__, + "pos": self.xy_pos, + "ports": {p.id: p.to_dict() for p in self.ports}, + "old_id": self.id, + } + + return node_dict + + @classmethod + def from_dict(cls, ct, node_dict): + node_kwargs = {k: v for k, v in node_dict.items() if k not in ["class", "pos"]} + node = cls(ct, **node_kwargs) + node.xy_pos = node_dict["pos"] + + return node + + # ---------------------------------------------------------------------------------- + # Qt methods + # ---------------------------------------------------------------------------------- + def boundingRect(self): + # NodeViewer.node_position_scene() depends + # on the position of boundingRect to be (0, 0). + return QRectF(0, 0, self.width, self.height) + + def paint(self, painter, option, widget=None): + painter.save() + painter.setPen(Qt.PenStyle.NoPen) + painter.setBrush(Qt.BrushStyle.NoBrush) + + # base background. + margin = 1.0 + rect = self.boundingRect() + rect = QRectF( + rect.left() + margin, + rect.top() + margin, + rect.width() - (margin * 2), + rect.height() - (margin * 2), + ) + + radius = 4.0 + painter.setBrush(QColor(*self.color)) + painter.drawRoundedRect(rect, radius, radius) + + # light overlay on background when selected. + if self.isSelected(): + painter.setBrush(QColor(*self.selected_color)) + painter.drawRoundedRect(rect, radius, radius) + + # node name background. + padding = 3.0, 2.0 + text_rect = self._title_item.boundingRect() + text_rect = QRectF( + text_rect.x() + padding[0], + rect.y() + padding[1], + rect.width() - padding[0] - margin, + text_rect.height() - (padding[1] * 2), + ) + if self.isSelected(): + painter.setBrush(QColor(*self.selected_color)) + else: + painter.setBrush(QColor(0, 0, 0, 80)) + painter.drawRoundedRect(text_rect, 3.0, 3.0) + + # node border + if self.isSelected(): + border_width = 1.2 + border_color = QColor(*self.selected_border_color) + else: + border_width = 0.8 + border_color = QColor(*self.border_color) + + border_rect = QRectF(rect.left(), rect.top(), rect.width(), rect.height()) + + pen = QPen(border_color, border_width) + pen.setCosmetic(True) + path = QPainterPath() + path.addRoundedRect(border_rect, radius, radius) + painter.setBrush(Qt.BrushStyle.NoBrush) + painter.setPen(pen) + painter.drawPath(path) + + painter.restore() + + def mousePressEvent(self, event): + """Re-implemented to ignore event if LMB is over port collision area. + + Args: + event (QtWidgets.QGraphicsSceneMouseEvent): mouse event. + """ + if event.button() == Qt.MouseButton.LeftButton: + for p in self.inputs + self.outputs: + if p.hovered: + event.ignore() + return + super().mousePressEvent(event) + + def mouseReleaseEvent(self, event): + """Re-implemented to ignore event if Alt modifier is pressed. + + Args: + event (QtWidgets.QGraphicsSceneMouseEvent): mouse event. + """ + if event.modifiers() == Qt.KeyboardModifier.AltModifier: + event.ignore() + return + super().mouseReleaseEvent(event) + + def itemChange(self, change, value): + """Re-implemented to update pipes on selection changed. + + Args: + change: + value: + """ + if change == self.GraphicsItemChange.ItemSelectedChange and self.scene(): + self.reset_pipes() + if value: + self.highlight_pipes() + if self.isSelected(): + self.setZValue(1) + else: + self.setZValue(2) + + return super().itemChange(change, value) + + def _set_base_size(self, add_w=0.0, add_h=0.0): + """Sets the initial base size for the node. + + Args: + add_w (float): add additional width. + add_h (float): add additional height. + """ + self.width, self.height = self.calc_size(add_w, add_h) + + def _set_text_color(self, color): + """Set text color. + + Args: + color (tuple): color value in (r, g, b, a). + """ + for port in self.inputs + self.outputs: + port.text_color = color + self._title_item.setDefaultTextColor(QColor(*color)) + + def activate_pipes(self): + """Active pipe color.""" + ports = self.inputs + self.outputs + for port in ports: + for pipe in port.connected_pipes.values(): + pipe.activate() + + def highlight_pipes(self): + """Highlight pipe color.""" + ports = self.inputs + self.outputs + for port in ports: + for pipe in port.connected_pipes.values(): + pipe.highlight() + + def reset_pipes(self): + """Reset all the pipe colors.""" + ports = self.inputs + self.outputs + for port in ports: + for pipe in port.connected_pipes.values(): + pipe.reset() + + @staticmethod + def _get_ports_size(ports): + width = 0.0 + height = 0.0 + for port in ports: + if not port.isVisible(): + continue + port_width = port.boundingRect().width() / 2 + text_width = port.text.boundingRect().width() + width = max([width, port_width + text_width]) + height += port.boundingRect().height() + return width, height + + def calc_size(self, add_w=0.0, add_h=0.0): + # width, height from node name text. + title_width = self._title_item.boundingRect().width() + title_height = self._title_item.boundingRect().height() + + # width, height from node ports. + input_width, input_height = self._get_ports_size(self.inputs) + + # width, height from outputs + output_width, output_height = self._get_ports_size(self.outputs) + + # width, height from node embedded widgets. + widget_width = 0.0 + widget_height = 0.0 + for proxy_widget in self.widgets: + if not proxy_widget.isVisible(): + continue + w_width = proxy_widget.boundingRect().width() + w_height = proxy_widget.boundingRect().height() + widget_width = max([widget_width, w_width]) + widget_height += w_height + + width = input_width + output_width + height = max([title_height, input_height, output_height, widget_height]) + # add additional width for node widget. + if widget_width: + width += widget_width + # add padding if no inputs or outputs. + if not self.inputs or not self.outputs: + width += 10 + # add bottom margin for node widget. + if widget_height: + height += 10 + + width += add_w + height += add_h + + width = max([width, title_width]) + + return width, height + + def align_title(self): + rect = self.boundingRect() + text_rect = self._title_item.boundingRect() + x = rect.center().x() - (text_rect.width() / 2) + self._title_item.setPos(x, rect.y()) + + def align_widgets(self, v_offset=0.0): + if not self.widgets: + return + rect = self.boundingRect() + y = rect.y() + v_offset + inputs = [p for p in self.inputs if p.isVisible()] + outputs = [p for p in self.outputs if p.isVisible()] + for widget in self.widgets: + if not widget.isVisible(): + continue + widget_rect = widget.boundingRect() + if not inputs: + x = rect.left() + 10 + elif not outputs: + x = rect.right() - widget_rect.width() - 10 + else: + x = rect.center().x() - (widget_rect.width() / 2) + widget.setPos(x, y) + y += widget_rect.height() + + def align_ports(self, v_offset=0.0): + width = self._width + spacing = 1 + + # adjust input position + inputs = [p for p in self.inputs if p.isVisible()] + if inputs: + port_width = inputs[0].boundingRect().width() + port_height = inputs[0].boundingRect().height() + port_x = (port_width / 2) * -1 + port_y = v_offset + for port in inputs: + port.setPos(port_x, port_y) + port_y += port_height + spacing + + # adjust output position + outputs = [p for p in self.outputs if p.isVisible()] + if outputs: + port_width = outputs[0].boundingRect().width() + port_height = outputs[0].boundingRect().height() + port_x = width - (port_width / 2) + port_y = v_offset + for port in outputs: + port.setPos(port_x, port_y) + port_y += port_height + spacing + + def draw_node(self): + if self.scene(): + height = self._title_item.boundingRect().height() + 4 + + # setup initial base size. + self._set_base_size(add_h=height) + # set text color when node is initialized. + self._set_text_color(self.text_color) + + # --- set the initial node layout --- + # align title text + self.align_title() + # arrange input and output ports. + self.align_ports(v_offset=height) + # arrange node widgets + self.align_widgets(v_offset=height) + + self.update() + else: + logging.warning("Node not in scene.") diff --git a/mne_pipeline_hd/gui/node/node_defaults.py b/mne_pipeline_hd/gui/node/node_defaults.py new file mode 100644 index 00000000..49628a2b --- /dev/null +++ b/mne_pipeline_hd/gui/node/node_defaults.py @@ -0,0 +1,42 @@ +# -*- coding: utf-8 -*- +from qtpy.QtCore import Qt + +defaults = { + "nodes": { + "width": 160, + "height": 60, + "color": (23, 32, 41, 255), + "selected_color": (255, 255, 255, 30), + "border_color": (74, 84, 85, 255), + "selected_border_color": (254, 207, 42, 255), + "text_color": (255, 255, 255, 180), + }, + "ports": { + "size": 22, + "color": (49, 115, 100, 255), + "border_color": (29, 202, 151, 255), + "active_color": (14, 45, 59, 255), + "active_border_color": (107, 166, 193, 255), + "hover_color": (17, 43, 82, 255), + "hover_border_color": (136, 255, 35, 255), + "click_falloff": 15, + }, + "pipes": { + "width": 1.2, + "color": (175, 95, 30, 255), + "disabled_color": (200, 60, 60, 255), + "active_color": (70, 255, 220, 255), + "highlight_color": (232, 184, 13, 255), + "style": Qt.SolidLine, + }, + "viewer": { + "background_color": (35, 35, 35), + "grid_mode": "lines", + "grid_size": 50, + "grid_color": (45, 45, 45), + "zoom_min": -0.95, + "zoom_max": 2.0, + "pipe_layout": "curved", + }, + "slicer": {"width": 1.5, "color": (255, 50, 75)}, +} diff --git a/mne_pipeline_hd/gui/node/node_scene.py b/mne_pipeline_hd/gui/node/node_scene.py new file mode 100644 index 00000000..86cfbfb7 --- /dev/null +++ b/mne_pipeline_hd/gui/node/node_scene.py @@ -0,0 +1,168 @@ +# -*- coding: utf-8 -*- +from mne_pipeline_hd.gui.node.node_defaults import defaults +from qtpy.QtCore import Qt, QLineF +from qtpy.QtGui import QColor, QPen, QPainter +from qtpy.QtWidgets import QGraphicsScene + + +class NodeScene(QGraphicsScene): + def __init__(self, parent=None): + super(NodeScene, self).__init__(parent) + self._grid_mode = "lines" + self._grid_size = defaults["viewer"]["grid_size"] + self._grid_color = defaults["viewer"]["grid_color"] + self._bg_color = defaults["viewer"]["background_color"] + self.setBackgroundBrush(QColor(*self._bg_color)) + + @property + def grid_mode(self): + return self._grid_mode + + @grid_mode.setter + def grid_mode(self, mode=None): + if mode is None: + mode = defaults["viewer"]["grid_mode"] + self._grid_mode = mode + + @property + def grid_size(self): + return self._grid_size + + @grid_size.setter + def grid_size(self, size=None): + if size is None: + size = defaults["viewer"]["grid_size"] + self._grid_size = size + + @property + def grid_color(self): + return self._grid_color + + @grid_color.setter + def grid_color(self, color=None): + if color is None: + color = defaults["viewer"]["grid_color"] + self._grid_color = color + + @property + def bg_color(self): + return self._bg_color + + @bg_color.setter + def bg_color(self, color=None): + if color is None: + color = defaults["viewer"]["background_color"] + self._bg_color = color + self.setBackgroundBrush(QColor(*self._bg_color)) + + def _draw_grid(self, painter, rect, pen, grid_size): + """Draws the grid lines in the scene. + + Args: + painter (QPainter): painter object. + rect (QRectF): rect object. + pen (QPen): pen object. + grid_size (int): grid size. + """ + left = int(rect.left()) + right = int(rect.right()) + top = int(rect.top()) + bottom = int(rect.bottom()) + + first_left = left - (left % grid_size) + first_top = top - (top % grid_size) + + lines = [] + lines.extend( + [QLineF(x, top, x, bottom) for x in range(first_left, right, grid_size)] + ) + lines.extend( + [QLineF(left, y, right, y) for y in range(first_top, bottom, grid_size)] + ) + + painter.setPen(pen) + painter.drawLines(lines) + + def _draw_dots(self, painter, rect, pen, grid_size): + """Draws the grid dots in the scene. + + Args: + painter (QPainter): painter object. + rect (QRectF): rect object. + pen (QPen): pen object. + grid_size (int): grid size. + """ + zoom = self.viewer().get_zoom() + if zoom < 0: + grid_size = int(abs(zoom) / 0.3 + 1) * grid_size + + left = int(rect.left()) + right = int(rect.right()) + top = int(rect.top()) + bottom = int(rect.bottom()) + + first_left = left - (left % grid_size) + first_top = top - (top % grid_size) + + pen.setWidth(grid_size / 10) + painter.setPen(pen) + + [ + painter.drawPoint(int(x), int(y)) + for x in range(first_left, right, grid_size) + for y in range(first_top, bottom, grid_size) + ] + + def drawBackground(self, painter, rect): + super(NodeScene, self).drawBackground(painter, rect) + + painter.save() + painter.setRenderHint(QPainter.RenderHint.Antialiasing, False) + painter.setBrush(self.backgroundBrush()) + + if self._grid_mode == "dots": + pen = QPen(QColor(*self.grid_color), 0.65) + self._draw_dots(painter, rect, pen, self._grid_size) + + elif self._grid_mode == "lines": + zoom = self.viewer().get_zoom() + if zoom > -0.5: + pen = QPen(QColor(*self.grid_color), 0.65) + self._draw_grid(painter, rect, pen, self.grid_size) + + color = QColor(*self._bg_color).darker(200) + if zoom < -0.0: + color = color.darker(100 - int(zoom * 110)) + pen = QPen(color, 0.65) + self._draw_grid(painter, rect, pen, self.grid_size * 8) + + painter.restore() + + def mousePressEvent(self, event): + selected_nodes = self.viewer().selected_nodes() + if self.viewer(): + self.viewer().sceneMousePressEvent(event) + super(NodeScene, self).mousePressEvent(event) + keep_selection = any( + [ + event.button() == Qt.MouseButton.MiddleButton, + event.button() == Qt.MouseButton.RightButton, + event.modifiers() == Qt.KeyboardModifier.AltModifier, + ] + ) + if keep_selection: + for node in selected_nodes: + node.setSelected(True) + + def mouseMoveEvent(self, event): + if self.viewer(): + self.viewer().sceneMouseMoveEvent(event) + super(NodeScene, self).mouseMoveEvent(event) + + def mouseReleaseEvent(self, event): + if self.viewer(): + self.viewer().sceneMouseReleaseEvent(event) + super(NodeScene, self).mouseReleaseEvent(event) + + def viewer(self): + return self.views()[0] if self.views() else None diff --git a/mne_pipeline_hd/gui/node/node_viewer.py b/mne_pipeline_hd/gui/node/node_viewer.py new file mode 100644 index 00000000..4c24a87e --- /dev/null +++ b/mne_pipeline_hd/gui/node/node_viewer.py @@ -0,0 +1,1299 @@ +# -*- coding: utf-8 -*- +import logging +import math +from collections import OrderedDict + +import qtpy + +from mne_pipeline_hd.gui.node import nodes +from mne_pipeline_hd.gui.gui_utils import invert_rgb_color +from mne_pipeline_hd.gui.node.base_node import BaseNode +from mne_pipeline_hd.gui.node.node_defaults import defaults +from mne_pipeline_hd.gui.node.node_scene import NodeScene +from mne_pipeline_hd.gui.node.pipes import LivePipeItem, SlicerPipeItem, Pipe +from mne_pipeline_hd.gui.node.ports import Port +from qtpy.QtCore import QMimeData, QPointF, QPoint, QRectF, Qt, QRect, QSize, Signal +from qtpy.QtGui import QColor, QPainter, QPainterPath +from qtpy.QtWidgets import ( + QGraphicsView, + QRubberBand, + QGraphicsTextItem, + QGraphicsPathItem, +) + + +class NodeViewer(QGraphicsView): + """The NodeGraph displays the nodes and connections and manages them.""" + + # ---------------------------------------------------------------------------------- + # Signals + # ---------------------------------------------------------------------------------- + NodesCreated = Signal(list) + NodesDeleted = Signal(list) + NodeDoubleClicked = Signal(BaseNode) + PortConnected = Signal(Port, Port) + PortDisconnected = Signal(Port, Port) + DataDropped = Signal(QMimeData, QPointF) + + MovedNodes = Signal(dict) + ConnectionChanged = Signal(list, list) + InsertNode = Signal(object, str, dict) + NodeNameChanged = Signal(str, str) + + def __init__(self, ct, parent=None, debug_mode=False): + super().__init__(parent) + self.ct = ct + self._debug_mode = debug_mode + + # attributes + self._nodes = OrderedDict() + self._pipe_layout = defaults["viewer"]["pipe_layout"] + self._last_size = self.size() + self._detached_port = None + self._start_port = None + self._origin_pos = None + self._previous_pos = QPoint(int(self.width() / 2), int(self.height() / 2)) + self._prev_selection_nodes = [] + self._prev_selection_pipes = [] + self._node_positions = {} + self.LMB_state = False + self.RMB_state = False + self.MMB_state = False + self.COLLIDING_state = False + + # init QGraphicsView + self.setScene(NodeScene(self)) + self.setRenderHint(QPainter.RenderHint.Antialiasing, True) + self.setHorizontalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAlwaysOff) + self.setVerticalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAlwaysOff) + self.setViewportUpdateMode(QGraphicsView.ViewportUpdateMode.FullViewportUpdate) + self.setCacheMode(QGraphicsView.CacheModeFlag.CacheBackground) + self.setOptimizationFlag( + QGraphicsView.OptimizationFlag.DontAdjustForAntialiasing + ) + self.setAcceptDrops(True) + + # set initial range + self._scene_range = QRectF(0, 0, self.size().width(), self.size().height()) + self._update_scene() + + # initialize rubberband + self._rubber_band = QRubberBand(QRubberBand.Shape.Rectangle, self) + self._rubber_band.isActive = False + + # initialize cursor text + text_color = QColor(*invert_rgb_color(defaults["viewer"]["background_color"])) + text_color.setAlpha(50) + self._cursor_text = QGraphicsTextItem() + self._cursor_text.setFlag( + self._cursor_text.GraphicsItemFlag.ItemIsSelectable, False + ) + self._cursor_text.setDefaultTextColor(text_color) + self._cursor_text.setZValue(-2) + font = self._cursor_text.font() + font.setPointSize(7) + self._cursor_text.setFont(font) + self.scene().addItem(self._cursor_text) + + # initialize live pipe + self._LIVE_PIPE = LivePipeItem() + self._LIVE_PIPE.setVisible(False) + self.scene().addItem(self._LIVE_PIPE) + + # initialize slicer pipe + self._SLICER_PIPE = SlicerPipeItem() + self._SLICER_PIPE.setVisible(False) + self.scene().addItem(self._SLICER_PIPE) + + # initialize debug path + self._debug_path = QGraphicsPathItem() + self._debug_path.setZValue(1) + pen = self._debug_path.pen() + pen.setColor(QColor(255, 0, 0, 255)) + pen.setWidth(2) + self._debug_path.setPen(pen) + self._debug_path.setPath(QPainterPath()) + self.scene().addItem(self._debug_path) + + # ---------------------------------------------------------------------------------- + # Properties + # ---------------------------------------------------------------------------------- + @property + def nodes(self): + """Return list of nodes in the node graph. + + Returns: + -------- + nodes: OrderedDict + The nodes are stored in an OrderedDict with the node id as the key. + """ + return self._nodes + + @property + def pipe_layout(self): + """Return the pipe layout mode. + + Returns + ------- + layout: str + Pipe layout mode (either 'straight', 'curved', or 'angle'). + """ + return self._pipe_layout + + @pipe_layout.setter + def pipe_layout(self, layout): + """Set the pipe layout mode. + + Parameters + ---------- + layout: str + Pipe layout mode (either 'straight', 'curved', or 'angle'). + """ + if layout not in ["straight", "curved", "angle"]: + logging.warning( + f"{layout} is not a valid pipe layout, " f"defaulting to 'curved'." + ) + layout = "curved" + self._pipe_layout = layout + + # ---------------------------------------------------------------------------------- + # Logic methods + # ---------------------------------------------------------------------------------- + def add_node(self, node): + """Add a node to the node graph. + + See Also + -------- + NodeGraph.registered_nodes : To list all node types + + Parameters + ---------- + node : BaseNode + The node to add to the node graph. + """ + self.scene().addItem(node) + self._nodes[node.id] = node + + # draw node (necessary to redraw after it is added to the scene) + node.draw_node() + + return node + + def remove_node(self, node): + """Remove a node from the node graph. + + Parameters + ---------- + node : BaseNode + Node instance to remove. + """ + # Remove connected pipes + for port in node.ports: + for connected_port in port.connected_ports: + port.disconnect_from(connected_port) + # Remove node + if node in self.scene().items(): + self.scene().removeItem(node) + # Deliberately with room for KeyError to detect, + # if nodes are not correctly added in the first place + self.nodes.pop(node.id) + + node.delete() + + def create_node(self, node_class="BaseNode", **kwargs): + """Create a node from the given class. + + Parameters + ---------- + node_class: str + A string to speficy the node class. + kwargs: dict + Additional keyword arguments to pass into BaseNode.__init__(). + + Returns + ------- + node + The created node. + """ + if isinstance(node_class, str): + node_class = getattr(nodes, node_class) + node = node_class(self.ct, **kwargs) + else: + raise ValueError("node_info must be a string.") + + self.add_node(node) + + return node + + def node(self, node_idx=None, node_name=None, node_id=None, old_id=None): + """Get a node from the node graph based on either its index, name, or id. + + Parameters + ---------- + node_idx : int, optional + Index of the node in the node graph. + node_name : str, optional + Name of the node in the node graph. + node_id : str, optional + Unique identifier of the node in the node graph. + old_id: int, optional + Old id of the node for reestablishing connections. + + Returns + ------- + BaseNode + The node that matches the provided index, name, or id. If multiple + parameters are provided, the method will prioritize them in + the following order: node_idx, node_name, node_id, old_id. + If no parameters are provided or if no match is found, + the method will return None. + """ + if node_idx is not None: + return list(self.nodes.values())[node_idx] + elif node_name is not None: + return [n for n in self.nodes.values() if n.name == node_name] + elif node_id is not None: + return self.nodes[node_id] + elif old_id is not None: + for node in self.nodes.values(): + if node.old_id == old_id: + return node + logging.warning("No node found with the provided parameters.") + + def to_dict(self): + viewer_dict = dict() + viewer_dict["nodes"] = { + node_id: node.to_dict() for node_id, node in self.nodes.items() + } + + # Save connections + viewer_dict["connections"] = dict() + for node_id, node in self.nodes.items(): + viewer_dict["connections"][node.id] = dict() + for port in node.ports: + viewer_dict["connections"][node.id][port.id] = dict() + for connected_port in port.connected_ports: + viewer_dict["connections"][node.id][port.id][ + connected_port.node.id + ] = connected_port.id + + return viewer_dict + + def from_dict(self, viewer_dict): + self.clear() + # Create nodes + for node_info in viewer_dict["nodes"].values(): + node_class = getattr(nodes, node_info["class"]) + node = node_class.from_dict(self.ct, node_info) + self.add_node(node) + # Initialize connections + for node_id, port_dict in viewer_dict["connections"].items(): + node = self.node(old_id=node_id) + for port_id, connected_dict in port_dict.items(): + port = node.port(old_id=port_id) + for con_node_id, con_port_id in connected_dict.items(): + connected_node = self.node(old_id=con_node_id) + connected_port = connected_node.port(old_id=con_port_id) + port.connect_to(connected_port) + + def clear(self): + """Clear the node graph.""" + # list conversion necessary because self.nodes is mutated + for node in list(self.nodes.values()): + self.remove_node(node) + + # ---------------------------------------------------------------------------------- + # Qt methods + # ---------------------------------------------------------------------------------- + def _set_viewer_zoom(self, value, sensitivity=None, pos=None): + """Sets the zoom level. + + Args: + value (float): zoom factor. + sensitivity (float): zoom sensitivity. + pos (QPoint): mapped position. + """ + if pos: + pos = self.mapToScene(pos) + if sensitivity is None: + scale = 1.001**value + self.scale(scale, scale, pos) + return + + if value == 0.0: + return + + scale = (0.9 + sensitivity) if value < 0.0 else (1.1 - sensitivity) + zoom = self.get_zoom() + if defaults["viewer"]["zoom_min"] >= zoom: + if scale == 0.9: + return + if defaults["viewer"]["zoom_max"] <= zoom: + if scale == 1.1: + return + self.scale(scale, scale, pos) + + def _set_viewer_pan(self, pos_x, pos_y): + """Set the viewer in panning mode. + + Args: + pos_x (float): x pos. + pos_y (float): y pos. + """ + self._scene_range.adjust(pos_x, pos_y, pos_x, pos_y) + self._update_scene() + + def scale(self, sx, sy, pos=None): + scale = [sx, sx] + center = pos or self._scene_range.center() + w = self._scene_range.width() / scale[0] + h = self._scene_range.height() / scale[1] + self._scene_range = QRectF( + center.x() - (center.x() - self._scene_range.left()) / scale[0], + center.y() - (center.y() - self._scene_range.top()) / scale[1], + w, + h, + ) + self._update_scene() + + def _update_scene(self): + """Redraw the scene.""" + self.setSceneRect(self._scene_range) + self.fitInView(self._scene_range, Qt.AspectRatioMode.KeepAspectRatio) + + def _combined_rect(self, nodes): + """Returns a QRectF with the combined size of the provided node items. + + Args: + nodes (list[AbstractNodeItem]): list of node qgraphics items. + + Returns: + QRectF: combined rect + """ + group = self.scene().createItemGroup(nodes) + rect = group.boundingRect() + self.scene().destroyItemGroup(group) + return rect + + def _items_near(self, pos, width=20, height=20): + """Filter node graph items from the specified position, width and height area. + + Args: + pos (QPointF): scene pos. + width (int): width area. + height (int): height area. + + Returns: + list: qgraphics items from the scene. + """ + x, y = pos.x() - width, pos.y() - height + rect = QRectF(x, y, width, height) + items = [] + excl = [self._LIVE_PIPE, self._SLICER_PIPE] + for item in self.scene().items(rect): + if item in excl: + continue + items.append(item) + return items + + # Reimplement events + def resizeEvent(self, event): + w, h = self.size().width(), self.size().height() + if 0 in [w, h]: + self.resize(self._last_size) + delta = max(w / self._last_size.width(), h / self._last_size.height()) + self._set_viewer_zoom(delta) + self._last_size = self.size() + super().resizeEvent(event) + + def contextMenuEvent(self, event): + # ToDo: reimplement context menu. + pass + + return super().contextMenuEvent(event) + + def mousePressEvent(self, event): + if event.button() == Qt.MouseButton.LeftButton: + self.LMB_state = True + elif event.button() == Qt.MouseButton.RightButton: + self.RMB_state = True + elif event.button() == Qt.MouseButton.MiddleButton: + self.MMB_state = True + + self._origin_pos = event.pos() + self._previous_pos = event.pos() + self._prev_selection_nodes, self._prev_selection_pipes = self.selected_items() + + # cursor pos. + map_pos = self.mapToScene(event.pos()) + + # debug path + if self._debug_mode: + if self.LMB_state: + path = self._debug_path.path() + path.moveTo(map_pos) + self._debug_path.setPath(path) + + # pipe slicer enabled. + if self.LMB_state and event.modifiers() == ( + Qt.KeyboardModifier.AltModifier | Qt.KeyboardModifier.ShiftModifier + ): + self._SLICER_PIPE.draw_path(map_pos, map_pos) + self._SLICER_PIPE.setVisible(True) + return + + # pan mode. + if event.modifiers() == Qt.KeyboardModifier.AltModifier: + return + + items = self._items_near(map_pos, 20, 20) + nodes = [i for i in items if self.isnode(i)] + + if len(nodes) > 0: + self.MMB_state = False + + # update the recorded node positions. + selection = set([]) + selection.update(self.selected_nodes()) + self._node_positions.update({n: n.xy_pos for n in selection}) + + # show selection marquee. + if self.LMB_state and not items: + rect = QRect(self._previous_pos, QSize()) + rect = rect.normalized() + map_rect = self.mapToScene(rect).boundingRect() + self.scene().update(map_rect) + self._rubber_band.setGeometry(rect) + self._rubber_band.isActive = True + + if not self._LIVE_PIPE.isVisible(): + super().mousePressEvent(event) + + def mouseReleaseEvent(self, event): + if event.button() == Qt.MouseButton.LeftButton: + self.LMB_state = False + elif event.button() == Qt.MouseButton.RightButton: + self.RMB_state = False + elif event.button() == Qt.MouseButton.MiddleButton: + self.MMB_state = False + + # hide pipe slicer. + if self._SLICER_PIPE.isVisible(): + for i in self.scene().items(self._SLICER_PIPE.path()): + if self.ispipe(i) and i != self._LIVE_PIPE: + i.input_port.disconnect_from(i.output_port) + p = QPointF(0.0, 0.0) + self._SLICER_PIPE.draw_path(p, p) + self._SLICER_PIPE.setVisible(False) + + # hide selection marquee + if self._rubber_band.isActive: + self._rubber_band.isActive = False + if self._rubber_band.isVisible(): + rect = self._rubber_band.rect() + map_rect = self.mapToScene(rect).boundingRect() + self._rubber_band.hide() + self.scene().update(map_rect) + return + + # find position changed nodes and emit signal. + moved_nodes = { + n: xy_pos + for n, xy_pos in self._node_positions.items() + if n.xy_pos != xy_pos + } + # only emit of node is not colliding with a pipe. + if moved_nodes and not self.COLLIDING_state: + self.MovedNodes.emit(moved_nodes) + + # reset recorded positions. + self._node_positions = {} + + # emit signal if selected node collides with pipe. + # Note: if collide state is true then only 1 node is selected. + # ToDo: Implement colliding if necessary + # nodes, pipes = self.selected_items() + # if self.COLLIDING_state and nodes and pipes: + # self.InsertNode.emit(pipes[0], nodes[0].id, moved_nodes) + + super().mouseReleaseEvent(event) + + def mouseMoveEvent(self, event): + alt_modifier = event.modifiers() == Qt.KeyboardModifier.AltModifier + if self._debug_mode: + # Debug mouse + if self.LMB_state: + to_pos = self.mapToScene(event.pos()) + path = self._debug_path.path() + path.lineTo(to_pos) + self._debug_path.setPath(path) + + # Draw slicer + if self.LMB_state and event.modifiers() == ( + Qt.KeyboardModifier.AltModifier | Qt.KeyboardModifier.ShiftModifier + ): + if self._SLICER_PIPE.isVisible(): + p1 = self._SLICER_PIPE.path().pointAtPercent(0) + p2 = self.mapToScene(self._previous_pos) + self._SLICER_PIPE.draw_path(p1, p2) + self._SLICER_PIPE.show() + self._previous_pos = event.pos() + super().mouseMoveEvent(event) + return + + # Pan view + if self.MMB_state or ( + self.LMB_state and alt_modifier and not self._LIVE_PIPE.isVisible() + ): + previous_pos = self.mapToScene(self._previous_pos) + current_pos = self.mapToScene(event.pos()) + delta = previous_pos - current_pos + self._set_viewer_pan(delta.x(), delta.y()) + + if self.LMB_state and self._rubber_band.isActive: + rect = QRect(self._origin_pos, event.pos()).normalized() + # if the rubber band is too small, do not show it. + if max(rect.width(), rect.height()) > 5: + if not self._rubber_band.isVisible(): + self._rubber_band.show() + map_rect = self.mapToScene(rect).boundingRect() + path = QPainterPath() + path.addRect(map_rect) + self._rubber_band.setGeometry(rect) + self.scene().setSelectionArea( + path, mode=Qt.ItemSelectionMode.IntersectsItemShape + ) + self.scene().update(map_rect) + + elif self.LMB_state: + self.COLLIDING_state = False + nodes, pipes = self.selected_items() + if len(nodes) == 1: + node = nodes[0] + [p.setSelected(False) for p in pipes] + + colliding_pipes = [ + i for i in node.collidingItems() if self.ispipe(i) and i.isVisible() + ] + for pipe in colliding_pipes: + if not pipe.input_port: + continue + port_node_check = all( + [ + pipe.input_port.node is not node, + pipe.output_port.node is not node, + ] + ) + if port_node_check: + pipe.setSelected(True) + self.COLLIDING_state = True + break + + self._previous_pos = event.pos() + super(NodeViewer, self).mouseMoveEvent(event) + + def wheelEvent(self, event): + try: + delta = event.delta() + except AttributeError: + # For PyQt5 + delta = event.angleDelta().y() + if delta == 0: + delta = event.angleDelta().x() + self._set_viewer_zoom(delta, pos=event.pos()) + + def dropEvent(self, event): + pos = self.mapToScene(event.pos()) + event.setDropAction(Qt.DropAction.CopyAction) + self.DataDropped.emit(event.mimeData(), QPointF(pos.x(), pos.y())) + + def dragEnterEvent(self, event): + is_acceptable = any( + [ + event.mimeData().hasFormat(i) + for i in ["nodegraphqt/nodes", "text/plain", "text/uri-list"] + ] + ) + if is_acceptable: + event.accept() + else: + event.ignore() + + def dragMoveEvent(self, event): + is_acceptable = any( + [ + event.mimeData().hasFormat(i) + for i in ["nodegraphqt/nodes", "text/plain", "text/uri-list"] + ] + ) + if is_acceptable: + event.accept() + else: + event.ignore() + + def dragLeaveEvent(self, event): + event.ignore() + + def keyPressEvent(self, event): + if self._LIVE_PIPE.isVisible(): + super(NodeViewer, self).keyPressEvent(event) + return + + # show cursor text + overlay_text = None + self._cursor_text.setVisible(False) + + if ( + event.modifiers() + == Qt.KeyboardModifier.AltModifier | Qt.KeyboardModifier.ShiftModifier + ): + overlay_text = "\n ALT + SHIFT:\n Pipe Slicer Enabled" + if overlay_text: + self._cursor_text.setPlainText(overlay_text) + self._cursor_text.setPos(self.mapToScene(self._previous_pos)) + self._cursor_text.setVisible(True) + + super(NodeViewer, self).keyPressEvent(event) + + def keyReleaseEvent(self, event): + # hide and reset cursor text. + self._cursor_text.setPlainText("") + self._cursor_text.setVisible(False) + + super(NodeViewer, self).keyReleaseEvent(event) + + # ---------------------------------------------------------------------------------- + # Scene Events + # ---------------------------------------------------------------------------------- + + def sceneMouseMoveEvent(self, event): + """ + triggered mouse move event for the scene. + - redraw the live connection pipe. + + Args: + event (QtWidgets.QGraphicsSceneMouseEvent): + The event handler from the QtWidgets.QGraphicsScene + """ + if not self._LIVE_PIPE.isVisible(): + return + if not self._start_port: + return + + pos = event.scenePos() + pointer_color = None + for item in self.scene().items(pos): + if not self.isport(item): + continue + + x = item.boundingRect().width() / 2 + y = item.boundingRect().height() / 2 + pos = item.scenePos() + pos.setX(pos.x() + x) + pos.setY(pos.y() + y) + if item == self._start_port: + break + pointer_color = defaults["pipes"]["highlight_color"] + # ToDo: Accept implementation + accept = True + if not accept: + pointer_color = [150, 60, 255] + break + + if item.node == self._start_port.node: + pointer_color = defaults["pipes"]["disabled_color"] + elif item.port_type == self._start_port.port_type: + pointer_color = defaults["pipes"]["disabled_color"] + break + + self._LIVE_PIPE.draw_path(self._start_port, cursor_pos=pos, color=pointer_color) + + def sceneMousePressEvent(self, event): + """ + triggered mouse press event for the scene (takes priority over viewer event). + - detect selected pipe and start connection. + + Args: + event (QtWidgets.QGraphicsScenePressEvent): + The event handler from the QtWidgets.QGraphicsScene + """ + # pipe slicer enabled. + if event.modifiers() == ( + Qt.KeyboardModifier.AltModifier | Qt.KeyboardModifier.ShiftModifier + ): + return + + # viewer pan mode. + if event.modifiers() == Qt.KeyboardModifier.AltModifier: + return + + if self._LIVE_PIPE.isVisible(): + self.apply_live_connection(event) + return + + pos = event.scenePos() + items = self._items_near(pos, 5, 5) + + # filter from the selection stack in the following order + # "node, port, pipe" this is to avoid selecting items under items. + node, port, pipe = None, None, None + for item in items: + if self.isnode(item): + node = item + elif self.isport(item): + port = item + elif self.ispipe(item): + pipe = item + if any([node, port, pipe]): + break + + if port: + if not port.multi_connection and len(port.connected_ports) > 0: + # ToDo: Might cause problems with multi-connections + self._detached_port = port.connected_ports[0] + self.start_live_connection(port) + if not port.multi_connection: + [p.delete() for p in port.connected_pipes.values()] + return + + if node: + node_items = [i for i in self._items_near(pos, 3, 3) if self.isnode(i)] + + # record the node positions at selection time. + for n in node_items: + self._node_positions[n] = n.xy_pos + + if pipe: + if not self.LMB_state: + return + + from_port = pipe.port_from_pos(pos, True) + from_port.hovered = True + + attr = { + "in": "input_port", + "out": "output_port", + } + self._detached_port = getattr(pipe, attr[from_port.port_type]) + self.start_live_connection(from_port) + self._LIVE_PIPE.draw_path(self._start_port, cursor_pos=pos) + + if event.modifiers() == Qt.KeyboardModifier.ShiftModifier: + self._LIVE_PIPE.shift_selected = True + return + + pipe.delete() + + def sceneMouseReleaseEvent(self, event): + """Triggered mouse release event for the scene. + + Args: + event (QtWidgets.QGraphicsSceneMouseEvent): + The event handler from the QtWidgets.QGraphicsScene + """ + if event.button() != Qt.MouseButton.MiddleButton: + self.apply_live_connection(event) + + def apply_live_connection(self, event): + """Triggered mouse press/release event for the scene. + + - verifies the live connection pipe. + - makes a connection pipe if valid. + - emits the "connection changed" signal. + + Args: + event (QtWidgets.QGraphicsSceneMouseEvent): + The event handler from the QtWidgets.QGraphicsScene + """ + if not self._LIVE_PIPE.isVisible(): + return + + self._start_port.hovered = False + + # find the end port. + end_port = None + for item in self.scene().items(event.scenePos()): + if self.isport(item): + end_port = item + break + + # if port disconnected from existing pipe. + if end_port is None: + if self._detached_port and not self._LIVE_PIPE.shift_selected: + dist = math.hypot( + self._previous_pos.x() - self._origin_pos.x(), + self._previous_pos.y() - self._origin_pos.y(), + ) + if dist <= 2.0: # cursor pos threshold. + self._start_port.connect_to(self._detached_port) + self._detached_port = None + else: + self._start_port.disconnect_from(self._detached_port) + + self._detached_port = None + self.end_live_connection() + return + + else: + if self._start_port is end_port: + return + + # constrain check + compatible = self._start_port.compatible(end_port, verbose=False) + + # restore connection if ports are not compatible + if not compatible: + if self._detached_port: + to_port = self._detached_port or end_port + self._start_port.connect_to(to_port) + self._detached_port = None + self.end_live_connection() + return + + # end connection if starting port is already connected. + if self._start_port.multi_connection and self._start_port.connected(end_port): + self._detached_port = None + self.end_live_connection() + logging.debug("Target Port is already connected.") + return + + # disconnect target port from its connections if not multi connection. + if not end_port.multi_connection and len(end_port.connected_ports) > 0: + end_port.clear_connections() + + # Connect from detached port if available. + if self._detached_port: + self._start_port.disconnect_from(self._detached_port) + + # Make connection + self._start_port.connect_to(end_port) + + self._detached_port = None + self.end_live_connection() + + def start_live_connection(self, selected_port): + """Create new pipe for the connection. + + (show the live pipe visibility from the port following the cursor position) + """ + if not selected_port: + return + self._start_port = selected_port + if self._start_port.port_type == "in": + self._LIVE_PIPE.input_port = self._start_port + elif self._start_port == "out": + self._LIVE_PIPE.output_port = self._start_port + self._LIVE_PIPE.setVisible(True) + self._LIVE_PIPE.draw_index_pointer( + selected_port, self.mapToScene(self._origin_pos) + ) + + def end_live_connection(self): + """Delete live connection pipe and reset start port. + + (hides the pipe item used for drawing the live connection) + """ + self._LIVE_PIPE.reset_path() + self._LIVE_PIPE.setVisible(False) + self._LIVE_PIPE.shift_selected = False + self._start_port = None + + def isnode(self, item): + """Check if the item is a node. + + Parameters + ---------- + item: QGraphicsItem + + Returns + ------- + result: bool + True if the item is a node. + """ + # For some reason, issubclass(item.__class__, BaseNode) does not work + if item in self.nodes.values(): + return True + return False + + def isport(self, item): + """Check if the item is a port. + + Parameters + ---------- + item: QGraphicsItem + + Returns + ------- + result: bool + True if the item is a port. + """ + return isinstance(item, Port) + + def ispipe(self, item): + """Check if the item is a pipe. + + Parameters + ---------- + item: QGraphicsItem + + Returns + ------- + result: bool + True if the item is a pipe. + """ + return isinstance(item, Pipe) + + def all_pipes(self): + """Returns all pipe qgraphic items. + + Returns: + list[PipeItem]: instances of pipe items. + """ + return [i for i in self.scene().items() if self.ispipe(i)] + + def selected_nodes(self): + """Returns selected node qgraphic items. + + Returns: + list[AbstractNodeItem]: instances of node items. + """ + return [i for i in self.scene().selectedItems() if self.isnode(i)] + + def selected_pipes(self): + """Returns selected pipe qgraphic items. + + Returns: + list[Pipe]: pipe items. + """ + return [i for i in self.scene().selectedItems() if self.ispipe(i)] + + def selected_items(self): + """Return selected graphic items in the scene. + + Returns: + tuple(list[AbstractNodeItem], list[Pipe]): + selected (node items, pipe items). + """ + nodes = [i for i in self.scene().selectedItems() if self.isnode(i)] + pipes = [i for i in self.scene().selectedItems() if self.ispipe(i)] + + return nodes, pipes + + def move_nodes(self, nodes, pos=None, offset=None): + """Globally move specified nodes. + + Args: + nodes (list[AbstractNodeItem]): node items. + pos (tuple or list): custom x, y position. + offset (tuple or list): x, y position offset. + """ + group = self.scene().createItemGroup(nodes) + group_rect = group.boundingRect() + if pos: + x, y = pos + else: + pos = self.mapToScene(self._previous_pos) + x = pos.x() - group_rect.center().x() + y = pos.y() - group_rect.center().y() + if offset: + x += offset[0] + y += offset[1] + group.setPos(x, y) + self.scene().destroyItemGroup(group) + + def get_pipes_from_nodes(self, nodes=None): + nodes = nodes or self.selected_nodes() + if not nodes: + return + pipes = [] + for node in nodes: + n_inputs = node.inputs if hasattr(node, "inputs") else [] + n_outputs = node.outputs if hasattr(node, "outputs") else [] + + for port in n_inputs: + for pipe in port.connected_pipes.values(): + connected_node = pipe.output_port.node + if connected_node in nodes: + pipes.append(pipe) + for port in n_outputs: + for pipe in port.connected_pipes.values(): + connected_node = pipe.input_port.node + if connected_node in nodes: + pipes.append(pipe) + return pipes + + def center_selection(self, nodes=None): + """Center on the given nodes or all nodes by default. + + Args: + nodes (list[AbstractNodeItem]): a list of node items. + """ + nodes = nodes or self.selected_nodes() or self.nodes.values() + if not nodes: + return + + rect = self._combined_rect(nodes) + self._scene_range.translate(rect.center() - self._scene_range.center()) + self.setSceneRect(self._scene_range) + + def clear_selection(self): + """Clear the selected items in the scene.""" + for node in self.nodes.values(): + node.setSelected(False) + + def reset_zoom(self, cent=None): + """Reset the viewer zoom level. + + Args: + cent (QtCore.QPoint): specified center. + """ + self._scene_range = QRectF(0, 0, self.size().width(), self.size().height()) + if cent: + self._scene_range.translate(cent - self._scene_range.center()) + self._update_scene() + + def get_zoom(self): + """Returns the viewer zoom level. + + Returns: + float: zoom level. + """ + transform = self.transform() + cur_scale = (transform.m11(), transform.m22()) + return float("{:0.2f}".format(cur_scale[0] - 1.0)) + + def set_zoom(self, value=0.0): + """Set the viewer zoom level. + + Args: + value (float): zoom level + """ + if value == 0.0: + self.reset_zoom() + return + zoom = self.get_zoom() + if zoom < 0.0: + if not ( + defaults["viewer"]["zoom_min"] <= zoom <= defaults["viewer"]["zoom_max"] + ): + return + else: + if not ( + defaults["viewer"]["zoom_min"] + <= value + <= defaults["viewer"]["zoom_max"] + ): + return + value = value - zoom + self._set_viewer_zoom(value, 0.0) + + def zoom_to_nodes(self, nodes): + self._scene_range = self._combined_rect(nodes) + self._update_scene() + + if self.get_zoom() > 0.1: + self.reset_zoom(self._scene_range.center()) + + def fit_to_selection(self): + """Sets the zoom level to fit selected nodes. + + If no nodes are selected then all nodes in the graph will be framed. + """ + nodes = self.selected_nodes() or self.nodes.values() + if not nodes: + return + self.zoom_to_nodes(nodes) + + def force_update(self): + """Redraw the current node graph scene.""" + self._update_scene() + + def scene_rect(self): + """Returns the scene rect size. + + Returns: + list[float]: x, y, width, height + """ + return [ + self._scene_range.x(), + self._scene_range.y(), + self._scene_range.width(), + self._scene_range.height(), + ] + + def set_scene_rect(self, rect): + """Sets the scene rect and redraws the scene. + + Args: + rect (list[float]): x, y, width, height + """ + self._scene_range = QRectF(*rect) + self._update_scene() + + def scene_center(self): + """Get the center x,y pos from the scene. + + Returns: + list[float]: x, y position. + """ + cent = self._scene_range.center() + return [cent.x(), cent.y()] + + def scene_cursor_pos(self): + """Returns the cursor last position mapped to the scene. + + Returns: + QtCore.QPoint: cursor position. + """ + return self.mapToScene(self._previous_pos) + + def nodes_rect_center(self, nodes): + """Get the center x,y pos from the specified nodes. + + Args: + nodes (list[AbstractNodeItem]): list of node qgrphics items. + + Returns: + list[float]: x, y position. + """ + cent = self._combined_rect(nodes).center() + return [cent.x(), cent.y()] + + def use_OpenGL(self): + """Use QOpenGLWidget as the viewer.""" + if qtpy.PYQT5 or qtpy.PYSIDE2: + from qtpy.QtWidgets import QOpenGLWidget + else: + from qtpy.QtOpenGLWidgets import QOpenGLWidget + self.setViewport(QOpenGLWidget()) + + def node_position_scene(self, **node_kwargs): + node = self.node(**node_kwargs) + scene_pos = node.scenePos() + node.boundingRect().center() + + return scene_pos + + def node_position_view(self, **node_kwargs): + scene_pos = self.node_position_scene(**node_kwargs) + view_pos = self.mapFromScene(scene_pos) + + return view_pos + + def port_position_scene(self, node, **port_kwargs): + port = node.port(**port_kwargs) + scene_pos = port.scenePos() + port.boundingRect().center() + + return scene_pos + + def port_position_view(self, node, **port_kwargs): + scene_pos = self.port_position_scene(node, **port_kwargs) + view_pos = self.mapFromScene(scene_pos) + + return view_pos + + # -------------------------------------------------------------------------------------- + # AutoLayout + # -------------------------------------------------------------------------------------- + + @staticmethod + def _update_node_rank(node, nodes_rank, down_stream=True): + """Recursive function for updating the node ranking. + + Args: + node (BaseNode): node to start from. + nodes_rank (dict): node ranking object to be updated. + down_stream (bool): true to rank down stram. + """ + if down_stream: + node_values = node.connected_output_nodes().values() + else: + node_values = node.connected_input_nodes().values() + + connected_nodes = set() + for nds in node_values: + connected_nodes.update(nds) + + rank = nodes_rank[node] + 1 + for n in connected_nodes: + if n in nodes_rank: + nodes_rank[n] = max(nodes_rank[n], rank) + else: + nodes_rank[n] = rank + NodeViewer._update_node_rank(n, nodes_rank, down_stream) + + @staticmethod + def _compute_node_rank(nodes, down_stream=True): + """Compute the ranking of nodes. + + Args: + nodes (list[BaseNode]): nodes to start ranking from. + down_stream (bool): true to compute down stream. + + Returns: + dict: {BaseNode: node_rank, ...} + """ + nodes_rank = {} + for node in nodes: + nodes_rank[node] = 0 + NodeViewer._update_node_rank(node, nodes_rank, down_stream) + return nodes_rank + + def auto_layout_nodes(self, nodes=None, down_stream=True, start_nodes=None): + """Auto layout the nodes in the node graph. + + Note: + If the node graph is acyclic then the ``start_nodes`` will need + to be specified. + + Args: + nodes (list[BaseNode]): list of nodes to auto layout + if nodes is None then all nodes is layed out. + down_stream (bool): false to layout up stream. + start_nodes (list[BaseNode]): + list of nodes to start the auto layout from (Optional). + """ + nodes = nodes or self.nodes.values() + + start_nodes = start_nodes or [] + if down_stream: + start_nodes += [ + n for n in nodes if not any(n.connected_input_nodes().values()) + ] + else: + start_nodes += [ + n for n in nodes if not any(n.connected_output_nodes().values()) + ] + + if not start_nodes: + return + + nodes_center_0 = self.nodes_rect_center(nodes) + + nodes_rank = NodeViewer._compute_node_rank(start_nodes, down_stream) + + rank_map = {} + for node, rank in nodes_rank.items(): + if rank in rank_map: + rank_map[rank].append(node) + else: + rank_map[rank] = [node] + + current_x = 0 + node_height = 120 + for rank in sorted(range(len(rank_map)), reverse=not down_stream): + ranked_nodes = rank_map[rank] + max_width = max([node.width for node in ranked_nodes]) + current_x += max_width + current_y = 0 + for idx, node in enumerate(ranked_nodes): + dy = max(node_height, node.height) + current_y += 0 if idx == 0 else dy + node.setPos(current_x, current_y) + current_y += dy * 0.5 + 10 + + current_x += max_width * 0.5 + 100 + + nodes_center_1 = self.nodes_rect_center(nodes) + dx = nodes_center_0[0] - nodes_center_1[0] + dy = nodes_center_0[1] - nodes_center_1[1] + [n.setPos(n.x() + dx, n.y() + dy) for n in nodes] diff --git a/mne_pipeline_hd/gui/node/nodegraphqt_license.md b/mne_pipeline_hd/gui/node/nodegraphqt_license.md new file mode 100644 index 00000000..94fc8bb9 --- /dev/null +++ b/mne_pipeline_hd/gui/node/nodegraphqt_license.md @@ -0,0 +1,22 @@ +MIT License +=========== + +Copyright (c) 2017 Johnny Chan + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/mne_pipeline_hd/gui/node/nodes.py b/mne_pipeline_hd/gui/node/nodes.py new file mode 100644 index 00000000..eaa00eee --- /dev/null +++ b/mne_pipeline_hd/gui/node/nodes.py @@ -0,0 +1,166 @@ +# -*- coding: utf-8 -*- +import logging + +from mne_pipeline_hd.gui.gui_utils import get_exception_tuple +from mne_pipeline_hd.gui.node.base_node import BaseNode +from qtpy.QtWidgets import ( + QWidget, + QVBoxLayout, + QPushButton, + QDialog, + QScrollArea, + QGroupBox, +) + +from mne_pipeline_hd.gui.base_widgets import CheckList +from mne_pipeline_hd.gui.loading_widgets import AddFilesWidget, AddMRIWidget +from mne_pipeline_hd.gui import parameter_widgets + + +class BaseInputNode(BaseNode): + """Node for input data like MEEG, FSMRI, etc.""" + + # Add Start button (Depending from where we start, we get different orders of execution) + # There can be secondary inputs + def __init__(self, ct): + super().__init__(ct) + self.data_type = None + self.widget = QWidget() + + def init_widgets(self, data_type): + self.data_type = data_type + self.name = f"{data_type} Input Node" + # Add the output port + self.add_output("Data", multi_connection=True) + # Initialize the other widgets inside the node + layout = QVBoxLayout(self.widget) + import_bt = QPushButton("Import") + import_bt.clicked.connect(self.add_files) + layout.addWidget(import_bt) + input_list = CheckList( + # self.ct.inputs[data_type], + # self.ct.selected_inputs[data_type], + self.ct.pr.all_meeg, + self.ct.pr.sel_meeg, + ui_button_pos="bottom", + show_index=True, + title=f"Select {data_type}", + ) + layout.addWidget(input_list) + self.add_widget(self.widget) + + def add_files(self): + # This decides, wether the dialog is rendered outside or inside the scene + dlg = QDialog(self.viewer) + dlg.setWindowTitle("Import Files") + if self.data_type == "MEEG": + widget = AddFilesWidget(self.ct) + else: + widget = AddMRIWidget(self.ct) + dlg_layout = QVBoxLayout(dlg) + dlg_layout.addWidget(widget) + dlg.open() + + +class MEEGInputNode(BaseInputNode): + def __init__(self, ct): + super().__init__(ct) + self.init_widgets("MEEG") + + +class MRIInputNode(BaseInputNode): + def __init__(self, ct): + super().__init__(ct) + self.init_widgets("FSMRI") + + +class GroupNode(BaseNode): + def __init__(self, ct): + super().__init__(ct, name="Group Node") + # This node should be adaptive, when a new input data-type is connected, + # it should change the names of input-ports and output-ports accordingly + self.add_input("Data-In", multi_connection=True, accepted_ports=None) + self.add_output("Data-Out", multi_connection=True, accepted_ports=None) + + # ToDo: This will have a widget for selecting and organizing groups + + +class FunctionNode(BaseNode): + """This node is a prototype for a function node, which also displays parameters.""" + + def __init__( + self, ct, name=None, parameters=None, **kwargs + ): # **kwargs just for demo, later not needed + super().__init__(ct, name, **kwargs) + self.parameters = parameters + + self.init_parameters() + + def init_parameters(self): + group_box = QGroupBox("Parameters") + layout = QVBoxLayout(group_box) + if len(self.parameters) > 5: + widget = QScrollArea() + sub_widget = QWidget() + layout = QVBoxLayout(sub_widget) + widget.setWidget(sub_widget) + + for name, param_kwargs in self.parameters.items(): + alias = param_kwargs.get("alias", name) + gui = param_kwargs.get("gui", None) + default = param_kwargs.get("default", None) + if default is None: + logging.error(f"For parameter {name} no default value was defined.") + continue + if gui is None: + logging.error(f"For parameter {name} no GUI was defined.") + continue + extra_kwargs = { + k: v + for k, v in param_kwargs.items() + if k not in ["alias", "gui", "default"] + } + try: + parameter_gui = getattr(parameter_widgets, gui)( + data=self.ct, + name=name, + alias=alias, + default=default, + **extra_kwargs, + ) + except Exception: + err_tuple = get_exception_tuple() + logging.error( + f'Initialization of Parameter-Widget "{name}" ' + f"with value={default} " + f"failed:\n" + f"{err_tuple[1]}" + ) + else: + layout.addWidget(parameter_gui) + self.add_widget(group_box) + + def to_dict(self): + """Override dictionary representation because of additional attributes.""" + node_dict = super().to_dict() + node_dict["parameters"] = self.parameters + + return node_dict + + def mouseDoubleClickEvent(self, event): + # Open a dialog to show the code of the function (maybe even small editor) + pass + + +class AssignmentNode(BaseNode): + """This node assigns the input from 1 to an input upstream from 2, which then leads + to runningo the functions before for input 2 while caching input 1.""" + + # ToDo: + # Checks for assignments and if there are pairs for each input. + # Checks also for inputs in multiple pairs. + # Status color and status message (like "24/28 assigned") + # Should change port names depending on data-type connected + def __init__(self, ct, **kwargs): # **kwargs just for demo, later not needed + super().__init__(ct, **kwargs) + self.name = "Assignment Node" diff --git a/mne_pipeline_hd/gui/node/nodes_readme.md b/mne_pipeline_hd/gui/node/nodes_readme.md new file mode 100644 index 00000000..555e440d --- /dev/null +++ b/mne_pipeline_hd/gui/node/nodes_readme.md @@ -0,0 +1,12 @@ +## Using code from NodeGraphQt +The code for nodes in mne-pipeline-hd is partially copied or heavily inspired by code from +[NodeGraphQt](https://github.com/jchanvfx/NodeGraphQt). +There are various reasons, why this package is not directly used. +Among others this usecase requires some heavy customization for node creation and +implementing new logic seemed to require subclassing a lot of the base objects from NodeGraphQt. +While the original package with its MVC-architecture is very flexible, it is also quite complex. +It supports features, which are not needed here e.g. properties, multiple layouts and widgets. +And last but not least, there is no official support for PySide6/PyQt6 (yet, 2024/04). +NodeGraphQt is licensed under the MIT License. +The license for NodeGraphQt is included [here](./nodegraphqt_license.md). +Thank you to the maintainers of NodeGraphQt especially Johnny Chan for their work. diff --git a/mne_pipeline_hd/gui/node/pipes.py b/mne_pipeline_hd/gui/node/pipes.py new file mode 100644 index 00000000..c970431c --- /dev/null +++ b/mne_pipeline_hd/gui/node/pipes.py @@ -0,0 +1,527 @@ +# -*- coding: utf-8 -*- +import math + +from mne_pipeline_hd.gui.gui_utils import format_color +from mne_pipeline_hd.gui.node.node_defaults import defaults +from qtpy.QtCore import QPointF, Qt, QLineF, QRectF +from qtpy.QtGui import QPolygonF, QColor, QPainterPath, QBrush, QTransform, QPen +from qtpy.QtWidgets import ( + QGraphicsPathItem, + QGraphicsItem, + QGraphicsPolygonItem, + QGraphicsTextItem, +) + + +class Pipe(QGraphicsPathItem): + def __init__(self, input_port=None, output_port=None): + """Initialize the pipe item. + + Notes + ----- + The method "draw_path" has to be called at least once + after the pipe is added to the scene. + """ + super().__init__() + + # init QGraphicsPathItem + self.setZValue(-1) + self.setAcceptHoverEvents(True) + self.setFlag(QGraphicsItem.GraphicsItemFlag.ItemIsSelectable) + self.setCacheMode(QGraphicsItem.DeviceCoordinateCache) + + # Hidden attributes + self._input_port = input_port + self._output_port = output_port + self._color = defaults["pipes"]["color"] + self._style = defaults["pipes"]["style"] + self._active = False + self._highlight = False + + size = 6.0 + self._poly = QPolygonF() + self._poly.append(QPointF(-size, size)) + self._poly.append(QPointF(0.0, -size * 1.5)) + self._poly.append(QPointF(size, size)) + + self._dir_pointer = QGraphicsPolygonItem(self) + self._dir_pointer.setPolygon(self._poly) + self._dir_pointer.setFlag(self.GraphicsItemFlag.ItemIsSelectable, False) + + self.set_pipe_styling(color=self.color, width=2, style=self.style) + + # -------------------------------------------------------------------------------------- + # Properties + # -------------------------------------------------------------------------------------- + @property + def input_port(self): + return self._input_port + + @input_port.setter + def input_port(self, port): + self._input_port = port if hasattr(port, "connect_to") else None + + @property + def output_port(self): + return self._output_port + + @output_port.setter + def output_port(self, port): + self._output_port = port if hasattr(port, "connect_to") else None + + @property + def color(self): + return self._color + + @color.setter + def color(self, color): + self._color = format_color(color) + + @property + def style(self): + return self._style + + @style.setter + def style(self, style): + self._style = style + + # ---------------------------------------------------------------------------------- + # Qt methods + # ---------------------------------------------------------------------------------- + def hoverEnterEvent(self, event): + self.activate() + + def hoverLeaveEvent(self, event): + self.reset() + if self.input_port and self.output_port: + if self.input_port.node.isSelected() or self.output_port.node.isSelected(): + self.highlight() + if self.isSelected(): + self.highlight() + + def itemChange(self, change, value): + if change == self.GraphicsItemChange.ItemSelectedChange and self.scene(): + if value: + self.highlight() + else: + self.reset() + return super().itemChange(change, value) + + def paint(self, painter, option, widget): + """Draws the connection line between nodes. + + Args: + painter (QtGui.QPainter): painter used for drawing the item. + option (QtGui.QStyleOptionGraphicsItem): + used to describe the parameters needed to draw. + widget (QtWidgets.QWidget): not used. + """ + painter.save() + + pen = self.pen() + if not self.isEnabled() and not self._active: + pen.setColor(QColor(*defaults["pipes"]["disabled_color"])) + pen.setStyle(Qt.PenStyle.DotLine) + pen.setWidth(3) + + painter.setPen(pen) + painter.setBrush(self.brush()) + painter.setRenderHint(painter.RenderHint.Antialiasing, True) + painter.drawPath(self.path()) + + # QPaintDevice: Cannot destroy paint device that is being painted. + painter.restore() + + @staticmethod + def _calc_distance(p1, p2): + x = math.pow((p2.x() - p1.x()), 2) + y = math.pow((p2.y() - p1.y()), 2) + return math.sqrt(x + y) + + def _draw_direction_pointer(self): + """Updates the pipe direction pointer arrow.""" + if not (self.input_port and self.output_port): + self._dir_pointer.setVisible(False) + return + + if not self.isEnabled() and not (self._active or self._highlight): + color = QColor(*defaults["pipes"]["disabled_color"]) + pen = self._dir_pointer.pen() + pen.setColor(color) + self._dir_pointer.setPen(pen) + self._dir_pointer.setBrush(color.darker(200)) + + self._dir_pointer.setVisible(True) + loc_pt = self.path().pointAtPercent(0.49) + tgt_pt = self.path().pointAtPercent(0.51) + radians = math.atan2(tgt_pt.y() - loc_pt.y(), tgt_pt.x() - loc_pt.x()) + degrees = math.degrees(radians) - 90 + self._dir_pointer.setRotation(degrees) + self._dir_pointer.setPos(self.path().pointAtPercent(0.5)) + + cen_x = self.path().pointAtPercent(0.5).x() + cen_y = self.path().pointAtPercent(0.5).y() + dist = math.hypot(tgt_pt.x() - cen_x, tgt_pt.y() - cen_y) + + self._dir_pointer.setVisible(True) + if dist < 0.3: + self._dir_pointer.setVisible(False) + return + if dist < 1.0: + self._dir_pointer.setScale(dist) + + def draw_path(self, start_port, end_port=None, cursor_pos=None): + """Draws the path between ports. + + Args: + start_port (PortItem): port used to draw the starting point. + end_port (PortItem): port used to draw the end point. + cursor_pos (QtCore.QPointF): cursor position if specified this + will be the draw end point. + """ + if not start_port: + return + + # get start / end positions. + pos1 = start_port.scenePos() + pos1.setX(pos1.x() + (start_port.boundingRect().width() / 2)) + pos1.setY(pos1.y() + (start_port.boundingRect().height() / 2)) + if cursor_pos: + pos2 = cursor_pos + elif end_port: + pos2 = end_port.scenePos() + pos2.setX(pos2.x() + (start_port.boundingRect().width() / 2)) + pos2.setY(pos2.y() + (start_port.boundingRect().height() / 2)) + else: + return + + # visibility check for connected pipe. + if self.input_port and self.output_port: + is_visible = all( + [ + self._input_port.isVisible(), + self._output_port.isVisible(), + self._input_port.node.isVisible(), + self._output_port.node.isVisible(), + ] + ) + self.setVisible(is_visible) + + # don't draw pipe if a port or node is not visible. + if not is_visible: + return + + line = QLineF(pos1, pos2) + path = QPainterPath() + + path.moveTo(line.x1(), line.y1()) + + if self.scene(): + layout = self.scene().viewer().pipe_layout + else: + layout = "straight" + + if layout == "straight": + path.lineTo(pos2) + elif layout == "curved": + ctr_offset_x1, ctr_offset_x2 = pos1.x(), pos2.x() + tangent = abs(ctr_offset_x1 - ctr_offset_x2) + + max_width = start_port.node.boundingRect().width() + tangent = min(tangent, max_width) + if start_port.port_type == "in": + ctr_offset_x1 -= tangent + ctr_offset_x2 += tangent + else: + ctr_offset_x1 += tangent + ctr_offset_x2 -= tangent + + ctr_point1 = QPointF(ctr_offset_x1, pos1.y()) + ctr_point2 = QPointF(ctr_offset_x2, pos2.y()) + path.cubicTo(ctr_point1, ctr_point2, pos2) + elif layout == "angle": + ctr_offset_x1, ctr_offset_x2 = pos1.x(), pos2.x() + distance = abs(ctr_offset_x1 - ctr_offset_x2) / 2 + if start_port.port_type == "in": + ctr_offset_x1 -= distance + ctr_offset_x2 += distance + else: + ctr_offset_x1 += distance + ctr_offset_x2 -= distance + + ctr_point1 = QPointF(ctr_offset_x1, pos1.y()) + ctr_point2 = QPointF(ctr_offset_x2, pos2.y()) + path.lineTo(ctr_point1) + path.lineTo(ctr_point2) + path.lineTo(pos2) + self.setPath(path) + + self._draw_direction_pointer() + + def reset_path(self): + """Reset the pipe initial path position.""" + path = QPainterPath(QPointF(0.0, 0.0)) + self.setPath(path) + self._draw_direction_pointer() + + def port_from_pos(self, pos, reverse=False): + """ + Args: + pos (QtCore.QPointF): current scene position. + reverse (bool): false to return the nearest port. + + Returns: + PortItem: port item. + """ + inport_pos = self.input_port.scenePos() + outport_pos = self.output_port.scenePos() + input_dist = self._calc_distance(inport_pos, pos) + output_dist = self._calc_distance(outport_pos, pos) + if input_dist < output_dist: + port = self.output_port if reverse else self.input_port + else: + port = self.input_port if reverse else self.output_port + return port + + def set_pipe_styling(self, color, width=2, style=Qt.PenStyle.SolidLine): + """ + Args: + color (list or tuple): (r, g, b, a) values 0-255 + width (int): pipe width. + style (int): pipe style. + """ + pen = self.pen() + pen.setWidth(width) + pen.setColor(QColor(*color)) + pen.setStyle(style) + pen.setJoinStyle(Qt.PenJoinStyle.MiterJoin) + pen.setCapStyle(Qt.PenCapStyle.RoundCap) + self.setPen(pen) + self.setBrush(QBrush(Qt.BrushStyle.NoBrush)) + + pen = self._dir_pointer.pen() + pen.setJoinStyle(Qt.PenJoinStyle.MiterJoin) + pen.setCapStyle(Qt.PenCapStyle.RoundCap) + pen.setWidth(width) + pen.setColor(QColor(*color)) + self._dir_pointer.setPen(pen) + self._dir_pointer.setBrush(QColor(*color).darker(200)) + + def activate(self): + self._active = True + self.set_pipe_styling( + color=defaults["pipes"]["active_color"], + width=3, + style=defaults["pipes"]["style"], + ) + + def active(self): + return self._active + + def highlight(self): + self._highlight = True + self.set_pipe_styling( + color=defaults["pipes"]["highlight_color"], + width=2, + style=defaults["pipes"]["style"], + ) + + def highlighted(self): + return self._highlight + + def reset(self): + """Reset the pipe state and styling.""" + self._active = False + self._highlight = False + self.set_pipe_styling(color=self.color, width=2, style=self.style) + self._draw_direction_pointer() + + def delete(self): + # Remove pipe from connected_pipes in ports + if self.input_port: + self.input_port.connected_pipes.pop(self.output_port.id, None) + if self.output_port: + self.output_port.connected_pipes.pop(self.input_port.id, None) + if self.scene(): + self.scene().removeItem(self) + + +class LivePipeItem(Pipe): + """Live Pipe item used for drawing the live connection with the cursor.""" + + def __init__(self): + super(LivePipeItem, self).__init__() + self.setZValue(4) + + self.color = defaults["pipes"]["active_color"] + self.style = Qt.PenStyle.DashLine + self.set_pipe_styling(color=self.color, width=3, style=self.style) + + self.shift_selected = False + + self._idx_pointer = LivePipePolygonItem(self) + self._idx_pointer.setPolygon(self._poly) + self._idx_pointer.setBrush(QColor(*self.color).darker(300)) + pen = self._idx_pointer.pen() + pen.setWidth(self.pen().width()) + pen.setColor(self.pen().color()) + pen.setJoinStyle(Qt.PenJoinStyle.MiterJoin) + self._idx_pointer.setPen(pen) + + color = self.pen().color() + color.setAlpha(80) + self._idx_text = QGraphicsTextItem(self) + self._idx_text.setDefaultTextColor(color) + font = self._idx_text.font() + font.setPointSize(7) + self._idx_text.setFont(font) + + def hoverEnterEvent(self, event): + """Re-implemented back to the base default behaviour or the pipe will lose it + styling when another pipe is selected.""" + QGraphicsPathItem.hoverEnterEvent(self, event) + + def draw_path(self, start_port, end_port=None, cursor_pos=None, color=None): + """Re-implemented to also update the index pointer arrow position. + + Args: + start_port (PortItem): port used to draw the starting point. + end_port (PortItem): port used to draw the end point. + cursor_pos (QtCore.QPointF): cursor position if specified this + will be the draw end point. + color (list[int]): override arrow index pointer color. (r, g, b) + """ + super(LivePipeItem, self).draw_path(start_port, end_port, cursor_pos) + self.draw_index_pointer(start_port, cursor_pos, color) + + def draw_index_pointer(self, start_port, cursor_pos, color=None): + """Update the index pointer arrow position and direction when the live pipe path + is redrawn. + + Args: + start_port (PortItem): start port item. + cursor_pos (QtCore.QPoint): cursor scene position. + color (list[int]): override arrow index pointer color. (r, g, b). + """ + text_rect = self._idx_text.boundingRect() + + transform = QTransform() + transform.translate(cursor_pos.x(), cursor_pos.y()) + text_pos = ( + cursor_pos.x() - (text_rect.width() / 2), + cursor_pos.y() - (text_rect.height() * 1.25), + ) + if start_port.port_type == "in": + transform.rotate(-90) + else: + transform.rotate(90) + self._idx_text.setPos(*text_pos) + self._idx_text.setPlainText("{}".format(start_port.name)) + + self._idx_pointer.setPolygon(transform.map(self._poly)) + + pen_color = QColor(*defaults["pipes"]["highlight_color"]) + if isinstance(color, (list, tuple)): + pen_color = QColor(*color) + + pen = self._idx_pointer.pen() + pen.setColor(pen_color) + self._idx_pointer.setBrush(pen_color.darker(300)) + self._idx_pointer.setPen(pen) + + +class LivePipePolygonItem(QGraphicsPolygonItem): + """Custom live pipe polygon shape.""" + + def __init__(self, parent): + super(LivePipePolygonItem, self).__init__(parent) + self.setFlag(QGraphicsItem.GraphicsItemFlag.ItemIsSelectable, True) + + def paint(self, painter, option, widget): + """ + Args: + painter (QtGui.QPainter): painter used for drawing the item. + option (QtGui.QStyleOptionGraphicsItem): + used to describe the parameters needed to draw. + widget (QtWidgets.QWidget): not used. + """ + painter.save() + painter.setBrush(self.brush()) + painter.setPen(self.pen()) + painter.drawPolygon(self.polygon()) + painter.restore() + + +class SlicerPipeItem(QGraphicsPathItem): + """Base item used for drawing the pipe connection slicer.""" + + def __init__(self): + super(SlicerPipeItem, self).__init__() + self.setZValue(5) + + def paint(self, painter, option, widget): + """Draws the slicer pipe. + + Args: + painter (QtGui.QPainter): painter used for drawing the item. + option (QtGui.QStyleOptionGraphicsItem): + used to describe the parameters needed to draw. + widget (QtWidgets.QWidget): not used. + """ + color = QColor(*defaults["slicer"]["color"]) + p1 = self.path().pointAtPercent(0) + p2 = self.path().pointAtPercent(1) + size = 6.0 + offset = size / 2 + arrow_size = 4.0 + + painter.save() + painter.setRenderHint(painter.RenderHint.Antialiasing, True) + + font = painter.font() + font.setPointSize(12) + painter.setFont(font) + text = "slice" + text_x = painter.fontMetrics().width(text) / 2 + text_y = painter.fontMetrics().height() / 1.5 + text_pos = QPointF(p1.x() - text_x, p1.y() - text_y) + text_color = QColor(color) + text_color.setAlpha(80) + painter.setPen( + QPen(text_color, defaults["slicer"]["width"], Qt.PenStyle.SolidLine) + ) + painter.drawText(text_pos, text) + + painter.setPen( + QPen(color, defaults["slicer"]["width"], Qt.PenStyle.DashDotLine) + ) + painter.drawPath(self.path()) + + pen = QPen(color, defaults["slicer"]["width"], Qt.PenStyle.SolidLine) + pen.setCapStyle(Qt.PenCapStyle.RoundCap) + pen.setJoinStyle(Qt.PenJoinStyle.MiterJoin) + painter.setPen(pen) + painter.setBrush(color) + + rect = QRectF(p1.x() - offset, p1.y() - offset, size, size) + painter.drawEllipse(rect) + + arrow = QPolygonF() + arrow.append(QPointF(-arrow_size, arrow_size)) + arrow.append(QPointF(0.0, -arrow_size * 0.9)) + arrow.append(QPointF(arrow_size, arrow_size)) + + transform = QTransform() + transform.translate(p2.x(), p2.y()) + radians = math.atan2(p2.y() - p1.y(), p2.x() - p1.x()) + degrees = math.degrees(radians) - 90 + transform.rotate(degrees) + + painter.drawPolygon(transform.map(arrow)) + painter.restore() + + def draw_path(self, p1, p2): + path = QPainterPath() + path.moveTo(p1) + path.lineTo(p2) + self.setPath(path) diff --git a/mne_pipeline_hd/gui/node/ports.py b/mne_pipeline_hd/gui/node/ports.py new file mode 100644 index 00000000..9ee89322 --- /dev/null +++ b/mne_pipeline_hd/gui/node/ports.py @@ -0,0 +1,472 @@ +# -*- coding: utf-8 -*- +import logging +from collections import OrderedDict + +from mne_pipeline_hd.gui.gui_utils import format_color +from mne_pipeline_hd.gui.node.node_defaults import defaults +from mne_pipeline_hd.gui.node.pipes import Pipe +from qtpy.QtCore import QRectF +from qtpy.QtGui import QColor, QPen +from qtpy.QtWidgets import QGraphicsItem, QGraphicsTextItem + + +class PortText(QGraphicsTextItem): + def __init__(self, text, parent=None): + super().__init__(text, parent) + self.font().setPointSize(8) + self.setFont(self.font()) + self.setCacheMode(QGraphicsItem.DeviceCoordinateCache) + + +class Port(QGraphicsItem): + """A graphical representation of a port in a node-based interface. + + This class represents a port which can be an input or output port of a node. + It supports hover events, multiple connections, and connection compatibility checks. + It also handles the drawing of the port and its connections. + + Parameters + ---------- + node : Node + The node this port is part of. + name : str + The name of the port. + port_type : str + The type of the port, can be either 'in' or 'out'. + multi_connection : bool + Whether the port supports multiple connections or not, defaults to False. + accepted_ports : list, None + List of port names that this port can connect to. + If None, it can connect to any port. + old_id : int, None, optional + old id for reestablishing connections. + """ + + def __init__( + self, + node, + name, + port_type, + multi_connection=False, + accepted_ports=None, + old_id=None, + ): + super().__init__(node) + + # init Qt graphics item + self.setAcceptHoverEvents(True) + self.setCacheMode(QGraphicsItem.DeviceCoordinateCache) + self.setFlag(self.GraphicsItemFlag.ItemIsSelectable, False) + self.setFlag(self.GraphicsItemFlag.ItemSendsScenePositionChanges, True) + self.setZValue(2) + + # init text item + self.text = PortText(name, self) + + # (hidden) attributes + self.node = node + self.id = id(self) + self.old_id = old_id + self._name = name + self._port_type = port_type + self.multi_connection = multi_connection + self.connected_ports = list() + self.connected_pipes = OrderedDict() + self.accepted_ports = accepted_ports + + self._width = defaults["ports"]["size"] + self._height = defaults["ports"]["size"] + self._color = defaults["ports"]["color"] + self._border_color = defaults["ports"]["border_color"] + self._active_color = defaults["ports"]["active_color"] + self._active_border_color = defaults["ports"]["active_border_color"] + self._hover_color = defaults["ports"]["hover_color"] + self._hover_border_color = defaults["ports"]["hover_border_color"] + self._text_color = defaults["nodes"]["text_color"] + self._hovered = False + + # -------------------------------------------------------------------------------------- + # Properties + # -------------------------------------------------------------------------------------- + @property + def name(self): + return self._name + + @name.setter + def name(self, value): + self._name = value + self.text.setPlainText(value) + + @property + def port_type(self): + return self._port_type + + @port_type.setter + def port_type(self, value): + if value not in ["in", "out"]: + raise ValueError(f"Invalid port type: {value}") + self._port_type = value + + @property + def width(self): + return self._width + + @width.setter + def width(self, width): + width = max(width, defaults["ports"]["size"]) + self._width = width + + @property + def height(self): + return self._height + + @height.setter + def height(self, height): + height = max(height, defaults["ports"]["size"]) + self._height = height + + @property + def color(self): + return self._color + + @color.setter + def color(self, color): + self._color = format_color(color) + self.update() + + @property + def border_color(self): + return self._border_color + + @border_color.setter + def border_color(self, color): + self._border_color = format_color(color) + self.update() + + @property + def active_color(self): + return self._active_color + + @active_color.setter + def active_color(self, color): + self._active_color = format_color(color) + self.update() + + @property + def active_border_color(self): + return self._active_border_color + + @active_border_color.setter + def active_border_color(self, color): + self._active_border_color = format_color(color) + self.update() + + @property + def hover_color(self): + return self._hover_color + + @hover_color.setter + def hover_color(self, color): + self._hover_color = format_color(color) + self.update() + + @property + def hover_border_color(self): + return self._hover_border_color + + @hover_border_color.setter + def hover_border_color(self, color): + self._hover_border_color = format_color(color) + self.update() + + @property + def text_color(self): + return self._text_color + + @text_color.setter + def text_color(self, color): + self._text_color = format_color(color) + self.text.setDefaultTextColor(QColor(*self._text_color)) + + @property + def hovered(self): + return self._hovered + + @hovered.setter + def hovered(self, hovered): + self._hovered = hovered + self.update() + + # -------------------------------------------------------------------------------------- + # Qt methods + # -------------------------------------------------------------------------------------- + def boundingRect(self): + # NodeViewer.port_position_scene() depends + # on the position of boundingRect to be (0, 0). + return QRectF( + 0.0, + 0.0, + self._width + defaults["ports"]["click_falloff"], + self._height, + ) + + def setPos(self, x, y): + super().setPos(x, y) + falloff = defaults["ports"]["click_falloff"] - 2 + if self.port_type == "in": + offset = self.boundingRect().width() - falloff + else: + offset = -self.text.boundingRect().width() + falloff + self.text.setPos(offset, -1.5) + + def paint(self, painter, option, widget=None): + """Draws the circular port. + + Args: + painter (QtGui.QPainter): painter used for drawing the item. + option (QtGui.QStyleOptionGraphicsItem): + used to describe the parameters needed to draw. + widget (QtWidgets.QWidget): not used. + """ + painter.save() + + rect_w = self._width / 1.8 + rect_h = self._height / 1.8 + rect_x = self.boundingRect().center().x() - (rect_w / 2) + rect_y = self.boundingRect().center().y() - (rect_h / 2) + port_rect = QRectF(rect_x, rect_y, rect_w, rect_h) + + if self._hovered: + color = QColor(*self.hover_color) + border_color = QColor(*self.hover_border_color) + elif len(self.connected_pipes) > 0: + color = QColor(*self.active_color) + border_color = QColor(*self.active_border_color) + else: + color = QColor(*self.color) + border_color = QColor(*self.border_color) + + pen = QPen(border_color, 1.8) + painter.setPen(pen) + painter.setBrush(color) + painter.drawEllipse(port_rect) + + if self.connected_pipes and not self._hovered: + painter.setBrush(border_color) + w = port_rect.width() / 2.5 + h = port_rect.height() / 2.5 + rect = QRectF( + port_rect.center().x() - w / 2, port_rect.center().y() - h / 2, w, h + ) + border_color = QColor(*self.border_color) + pen = QPen(border_color, 1.6) + painter.setPen(pen) + painter.setBrush(border_color) + painter.drawEllipse(rect) + elif self._hovered: + if self.multi_connection: + pen = QPen(border_color, 1.4) + painter.setPen(pen) + painter.setBrush(color) + w = port_rect.width() / 1.8 + h = port_rect.height() / 1.8 + else: + painter.setBrush(border_color) + w = port_rect.width() / 3.5 + h = port_rect.height() / 3.5 + rect = QRectF( + port_rect.center().x() - w / 2, port_rect.center().y() - h / 2, w, h + ) + painter.drawEllipse(rect) + painter.restore() + + def redraw_connected_pipes(self): + if len(self.connected_pipes) == 0: + return + for node_id, pipe in self.connected_pipes.items(): + if self.port_type == "in": + pipe.draw_path(self, pipe.output_port) + elif self.port_type == "out": + pipe.draw_path(pipe.input_port, self) + + def itemChange(self, change, value): + if change == self.GraphicsItemChange.ItemScenePositionHasChanged: + self.redraw_connected_pipes() + return super().itemChange(change, value) + + def hoverEnterEvent(self, event): + self._hovered = True + super().hoverEnterEvent(event) + + def hoverLeaveEvent(self, event): + self._hovered = False + super().hoverLeaveEvent(event) + + # -------------------------------------------------------------------------------------- + # Logic methods + # -------------------------------------------------------------------------------------- + def to_dict(self): + return { + "name": self.name, + "port_type": self.port_type, + "multi_connection": self.multi_connection, + "accepted_ports": self.accepted_ports, + "old_id": self.id, + } + + def add_accepted_ports(self, ports): + if isinstance(ports, list): + self.accepted_ports.extend(ports) + elif isinstance(ports, str): + self._accepted_ports.append(ports) + else: + raise ValueError("Invalid port type") + + def connected(self, target_port): + """Check if the specified port (port object, port name or port id) is connected + to this port.""" + if isinstance(target_port, str): + if target_port in [port.name for port in self.connected_ports]: + return True + elif isinstance(target_port, int): + if target_port in [port.id for port in self.connected_ports]: + return True + elif isinstance(target_port, Port): + if target_port in self.connected_ports: + return True + else: + logging.warning( + "Invalid port type for connection check " + "(only port object, port name or port id accepted)." + ) + return False + + def compatible(self, port, verbose=True): + """Check if the specified port is compatible with this port.""" + # check if the ports are the same. + if self is port: + if verbose: + logging.debug("Can't connect the same port.") + # check if the ports are from the same node. + elif self.node is port.node: + if verbose: + logging.debug("Can't connect ports from the same node.") + # check if the ports are from the same type (can't connect input to input). + elif self.port_type == port.port_type: + if verbose: + logging.debug("Can't connect the same port type.") + # check if the ports are already connected. + elif self.connected(port): + if verbose: + logging.debug("Ports are already connected.") + # check if the ports are compatible. + elif self.accepted_ports is not None and port.name not in self.accepted_ports: + if verbose: + logging.debug("Ports are not compatible.") + else: + if verbose: + logging.debug("Ports are compatible.") + return True + return False + + def connect_to(self, target_port=None): + """Create connection to the specified port and emits the + :attr:`NodeGraph.port_connected` signal from the parent node graph. + + Args: + target_port (Port): port object. + """ + if target_port is None: + for pipe in self.connected_pipes.values(): + pipe.delete() + logging.debug("No target port specified.") + return + + # validate accept connection. + if not self.compatible(target_port): + return + + # Remove existing connections from this port and the target port, + # if not multi-connection. + for port in [self, target_port]: + if not port.multi_connection and len(port.connected_ports) > 0: + for d_port in list(port.connected_ports): + port.disconnect_from(d_port) + + # Add to connected_ports + for port, trg_port in [(self, target_port), (target_port, self)]: + if trg_port not in port.connected_ports: + port.connected_ports.append(trg_port) + + if self.port_type == "in": + input_port = self + output_port = target_port + else: + input_port = target_port + output_port = self + # Draw pipe + if self.scene(): + pipe = Pipe(input_port, output_port) + input_port.connected_pipes[output_port.id] = pipe + output_port.connected_pipes[input_port.id] = pipe + self.scene().addItem(pipe) + pipe.draw_path(input_port, output_port) + if self.node.isSelected() or target_port.node.isSelected(): + pipe.highlight() + if not self.node.isVisible() or not target_port.node.isVisible(): + pipe.hide() + else: + logging.warning( + f"Scene not found, could not draw pipe from " + f"{self.name} to {target_port.name}." + ) + + # Emit Signal + self.node.viewer.PortConnected.emit(input_port, output_port) + + self.update() + target_port.update() + logging.debug( + f"Connected {self.node.name}/{self.name} to " + f"{target_port.node.name}/{target_port.name}" + ) + + def disconnect_from(self, target_port=None): + """Disconnect from the specified port and emits the + :attr:`NodeGraph.port_disconnected` signal from the parent node graph. + + Args: + target_port (NodeGrapchQt.Port): port object. + """ + if not target_port: + return + + # Remove ids from connected ports of this port and the target port. + for port, trg_port in [(self, target_port), (target_port, self)]: + port.connected_ports.remove(trg_port) + + # Remove the pipe connected to target_port + rm_pipe = self.connected_pipes.pop(target_port.id, None) + if rm_pipe is not None: + rm_pipe.delete() + + # emit signal + if self.port_type == "in": + self.node.viewer.PortDisconnected.emit(self, target_port) + else: + self.node.viewer.PortDisconnected.emit(target_port, self) + + self.update() + target_port.update() + logging.debug( + f"Disconnected {self.node.name}/{self.name} from " + f"{target_port.node.name}/{target_port.name}" + ) + + def clear_connections(self): + """Disconnect from all port connections and emit the + :attr:`NodeGraph.port_disconnected` signals from the node graph.""" + # Copy to avoid iteration failure + remove_ports = list(self.connected_ports) + for port in remove_ports: + self.disconnect_from(port) diff --git a/mne_pipeline_hd/gui/parameter_widgets.py b/mne_pipeline_hd/gui/parameter_widgets.py index bbe827b8..f1c78cf7 100644 --- a/mne_pipeline_hd/gui/parameter_widgets.py +++ b/mne_pipeline_hd/gui/parameter_widgets.py @@ -65,11 +65,10 @@ # ToDo: Unify None-select and more +# ToDo: potentially use docstring-inheritance to avoid repetition class Param(QWidget): - """ - Base-Class Parameter-GUIs, not to be called directly - Inherited Clases should have "Gui" in their name to get - identified correctly. + """Base-Class Parameter-GUIs, not to be called directly Inherited Clases should have + "Gui" in their name to get identified correctly. Attributes ---------- @@ -160,9 +159,11 @@ def __init__( _object_refs["parameter_widgets"][self.name] = self def init_ui(self, layout=None): - """Base layout initialization, which adds the given layout to a - group-box with the parameters name if groupbox_layout is enabled. - Else the layout will be horizontal with a QLabel for the name""" + """Base layout initialization, which adds the given layout to a group-box with + the parameters name if groupbox_layout is enabled. + + Else the layout will be horizontal with a QLabel for the name + """ main_layout = QHBoxLayout() @@ -215,11 +216,11 @@ def check_groupbox_state(self): self.save_param() def get_value(self): - """This should be implemented for each widget""" + """This should be implemented for each widget.""" pass def set_value(self, value): - """This should be implemented for each widget""" + """This should be implemented for each widget.""" pass def _get_param(self): @@ -288,7 +289,7 @@ def save_param(self): class IntGui(Param): - """A GUI for Integer-Parameters""" + """A GUI for Integer-Parameters.""" data_type = int @@ -336,7 +337,7 @@ def get_value(self): class FloatGui(Param): - """A GUI for Float-Parameters""" + """A GUI for Float-Parameters.""" data_type = float @@ -385,9 +386,7 @@ def get_value(self): class StringGui(Param): - """ - A GUI for String-Parameters - """ + """A GUI for String-Parameters.""" data_type = str @@ -431,7 +430,7 @@ def _eval_param(param_exp): class FuncGui(Param): - """A GUI for Parameters defined by small functions, e.g from numpy""" + """A GUI for Parameters defined by small functions, e.g from numpy.""" data_type = "multiple" @@ -520,7 +519,7 @@ def save_param(self): class BoolGui(Param): - """A GUI for Boolean-Parameters""" + """A GUI for Boolean-Parameters.""" data_type = bool @@ -561,7 +560,7 @@ def get_value(self): class TupleGui(Param): - """A GUI for Tuple-Parameters""" + """A GUI for Tuple-Parameters.""" data_type = tuple @@ -641,7 +640,7 @@ def get_value(self): # ToDo: make options replacable class ComboGui(Param): - """A GUI for a Parameter with limited options""" + """A GUI for a Parameter with limited options.""" data_type = "multiple" @@ -748,7 +747,7 @@ def closeEvent(self, event): class ListGui(Param): - """A GUI for as list""" + """A GUI for as list.""" data_type = list @@ -847,7 +846,7 @@ def closeEvent(self, event): # ToDo: make options replacable class CheckListGui(Param): - """A GUI to select items from a list of options""" + """A GUI to select items from a list of options.""" data_type = list @@ -947,7 +946,7 @@ def closeEvent(self, event): class DictGui(Param): - """A GUI for a dictionary""" + """A GUI for a dictionary.""" data_type = dict @@ -1024,7 +1023,7 @@ def get_value(self): class SliderGui(Param): - """A GUI to show a slider for Int/Float-Parameters""" + """A GUI to show a slider for Int/Float-Parameters.""" data_type = "multiple" @@ -1118,7 +1117,7 @@ def get_value(self): class MultiTypeGui(Param): - """A GUI which accepts multiple types of values in a single LineEdit""" + """A GUI which accepts multiple types of values in a single LineEdit.""" data_type = "multiple" diff --git a/mne_pipeline_hd/gui/plot_widgets.py b/mne_pipeline_hd/gui/plot_widgets.py index 75249624..b95f90de 100644 --- a/mne_pipeline_hd/gui/plot_widgets.py +++ b/mne_pipeline_hd/gui/plot_widgets.py @@ -138,9 +138,8 @@ def show_plot_manager(): class PlotViewSelection(QDialog): - """The user selects the plot-function and the objects - to show for this plot_function - """ + """The user selects the plot-function and the objects to show for this + plot_function.""" def __init__(self, main_win): super().__init__(main_win) @@ -239,8 +238,7 @@ def update_objects(self): self.obj_select.replace_data(self.objects) def func_selected(self, func): - """Get selected function and adjust contents - of Object-Selection to target""" + """Get selected function and adjust contents of Object-Selection to target.""" self.selected_func = func self.target = self.ct.pd_funcs.loc[func, "target"] self.update_objects() diff --git a/mne_pipeline_hd/gui/syntax_highlight.py b/mne_pipeline_hd/gui/syntax_highlight.py new file mode 100644 index 00000000..35296af6 --- /dev/null +++ b/mne_pipeline_hd/gui/syntax_highlight.py @@ -0,0 +1,245 @@ +# -*- coding: utf-8 -*- +""" +This is code from https://wiki.python.org/moin/PyQt/Python%20syntax%20highlighting +""" +from qtpy import QtCore, QtGui + + +def qformat(color, style=""): + """Return a QTextCharFormat with the given attributes.""" + _color = QtGui.QColor() + _color.setNamedColor(color) + + _format = QtGui.QTextCharFormat() + _format.setForeground(_color) + if "bold" in style: + _format.setFontWeight(QtGui.QFont.Bold) + if "italic" in style: + _format.setFontItalic(True) + + return _format + + +# Syntax styles that can be shared by all languages +STYLES = { + "keyword": qformat("blue"), + "operator": qformat("red"), + "brace": qformat("darkGray"), + "defclass": qformat("black", "bold"), + "string": qformat("magenta"), + "string2": qformat("darkMagenta"), + "comment": qformat("darkGreen", "italic"), + "self": qformat("black", "italic"), + "numbers": qformat("brown"), +} + + +class PythonHighlighter(QtGui.QSyntaxHighlighter): + """Syntax highlighter for the Python language.""" + + # Python keywords + keywords = [ + "and", + "assert", + "break", + "class", + "continue", + "def", + "del", + "elif", + "else", + "except", + "exec", + "finally", + "for", + "from", + "global", + "if", + "import", + "in", + "is", + "lambda", + "not", + "or", + "pass", + "print", + "raise", + "return", + "try", + "while", + "yield", + "None", + "True", + "False", + ] + + # Python operators + operators = [ + "=", + # Comparison + "==", + "!=", + "<", + "<=", + ">", + ">=", + # Arithmetic + "\+", + "-", + "\*", + "/", + "//", + "\%", + "\*\*", + # In-place + "\+=", + "-=", + "\*=", + "/=", + "\%=", + # Bitwise + "\^", + "\|", + "\&", + "\~", + ">>", + "<<", + ] + + # Python braces + braces = [ + "\{", + "\}", + "\(", + "\)", + "\[", + "\]", + ] + + def __init__(self, parent: QtGui.QTextDocument) -> None: + super().__init__(parent) + + # Multi-line strings (expression, flag, style) + self.tri_single = (QtCore.QRegExp("'''"), 1, STYLES["string2"]) + self.tri_double = (QtCore.QRegExp('"""'), 2, STYLES["string2"]) + + rules = [] + + # Keyword, operator, and brace rules + rules += [ + (r"\b%s\b" % w, 0, STYLES["keyword"]) for w in PythonHighlighter.keywords + ] + rules += [ + (r"%s" % o, 0, STYLES["operator"]) for o in PythonHighlighter.operators + ] + rules += [(r"%s" % b, 0, STYLES["brace"]) for b in PythonHighlighter.braces] + + # All other rules + rules += [ + # 'self' + (r"\bself\b", 0, STYLES["self"]), + # 'def' followed by an identifier + (r"\bdef\b\s*(\w+)", 1, STYLES["defclass"]), + # 'class' followed by an identifier + (r"\bclass\b\s*(\w+)", 1, STYLES["defclass"]), + # Numeric literals + (r"\b[+-]?[0-9]+[lL]?\b", 0, STYLES["numbers"]), + (r"\b[+-]?0[xX][0-9A-Fa-f]+[lL]?\b", 0, STYLES["numbers"]), + (r"\b[+-]?[0-9]+(?:\.[0-9]+)?(?:[eE][+-]?[0-9]+)?\b", 0, STYLES["numbers"]), + # Double-quoted string, possibly containing escape sequences + (r'"[^"\\]*(\\.[^"\\]*)*"', 0, STYLES["string"]), + # Single-quoted string, possibly containing escape sequences + (r"'[^'\\]*(\\.[^'\\]*)*'", 0, STYLES["string"]), + # From '#' until a newline + (r"#[^\n]*", 0, STYLES["comment"]), + ] + + # Build a QRegExp for each pattern + self.rules = [(QtCore.QRegExp(pat), index, fmt) for (pat, index, fmt) in rules] + + def highlightBlock(self, text): + """Apply syntax highlighting to the given block of text.""" + self.tripleQuoutesWithinStrings = [] + # Do other syntax formatting + for expression, nth, format in self.rules: + index = expression.indexIn(text, 0) + if index >= 0: + # if there is a string we check + # if there are some triple quotes within the string + # they will be ignored if they are matched again + if expression.pattern() in [ + r'"[^"\\]*(\\.[^"\\]*)*"', + r"'[^'\\]*(\\.[^'\\]*)*'", + ]: + innerIndex = self.tri_single[0].indexIn(text, index + 1) + if innerIndex == -1: + innerIndex = self.tri_double[0].indexIn(text, index + 1) + + if innerIndex != -1: + tripleQuoteIndexes = range(innerIndex, innerIndex + 3) + self.tripleQuoutesWithinStrings.extend(tripleQuoteIndexes) + + while index >= 0: + # skipping triple quotes within strings + if index in self.tripleQuoutesWithinStrings: + index += 1 + expression.indexIn(text, index) + continue + + # We actually want the index of the nth match + index = expression.pos(nth) + length = len(expression.cap(nth)) + self.setFormat(index, length, format) + index = expression.indexIn(text, index + length) + + self.setCurrentBlockState(0) + + # Do multi-line strings + in_multiline = self.match_multiline(text, *self.tri_single) + if not in_multiline: + in_multiline = self.match_multiline(text, *self.tri_double) + + def match_multiline(self, text, delimiter, in_state, style): + """Do highlighting of multi-line strings. + + ``delimiter`` should be a + ``QRegExp`` for triple-single-quotes or triple-double-quotes, and + ``in_state`` should be a unique integer to represent the corresponding + state changes when inside those strings. Returns True if we're still + inside a multi-line string when this function is finished. + """ + # If inside triple-single quotes, start at 0 + if self.previousBlockState() == in_state: + start = 0 + add = 0 + # Otherwise, look for the delimiter on this line + else: + start = delimiter.indexIn(text) + # skipping triple quotes within strings + if start in self.tripleQuoutesWithinStrings: + return False + # Move past this match + add = delimiter.matchedLength() + + # As long as there's a delimiter match on this line... + while start >= 0: + # Look for the ending delimiter + end = delimiter.indexIn(text, start + add) + # Ending delimiter on this line? + if end >= add: + length = end - start + add + delimiter.matchedLength() + self.setCurrentBlockState(0) + # No; multi-line string + else: + self.setCurrentBlockState(in_state) + length = len(text) - start + add + # Apply formatting + self.setFormat(start, length, style) + # Look for the next match + start = delimiter.indexIn(text, start + length) + + # Return True if still inside a multi-line string, False otherwise + if self.currentBlockState() == in_state: + return True + else: + return False diff --git a/mne_pipeline_hd/pipeline/controller.py b/mne_pipeline_hd/pipeline/controller.py index dbd09978..1747b7e3 100644 --- a/mne_pipeline_hd/pipeline/controller.py +++ b/mne_pipeline_hd/pipeline/controller.py @@ -12,6 +12,7 @@ import shutil import sys import traceback +from datetime import datetime from importlib import reload, resources, import_module from os import listdir from os.path import isdir, join @@ -23,7 +24,12 @@ from mne_pipeline_hd import functions, extra from mne_pipeline_hd.gui.gui_utils import get_user_input_string from mne_pipeline_hd.pipeline.legacy import transfer_file_params_to_single_subject -from mne_pipeline_hd.pipeline.pipeline_utils import QS, logger +from mne_pipeline_hd.pipeline.pipeline_utils import ( + QS, + logger, + type_json_hook, + TypedJSONEncoder, +) from mne_pipeline_hd.pipeline.project import Project home_dirs = ["custom_packages", "freesurfer", "projects"] @@ -309,9 +315,7 @@ def load_edu(self): ] def import_custom_modules(self): - """ - Load all modules in functions and custom_functions - """ + """Load all modules in functions and custom_functions.""" # Load basic-modules # Add functions to sys.path @@ -423,3 +427,99 @@ def reload_modules(self): # be caught by the UncaughtHook spec.loader.exec_module(module) sys.modules[module_name] = module + + +class NewController: + """New controller, that combines the former old controller and project class and + loads a controller for each "project". + + The home-path structure should no longer be as rigid as before, just specifying the + path to meeg- and fsmri-data. For each controller, there is a config-file stored, + where paths to the meeg-data, the freesurfer-dir and the custom-packages are stored. + """ + + def __init__(self, config_file=None): + self.config_file = config_file + self.config = self.load_config() + + def load_config(self): + if self.config_file is not None: + return json.load(self.config_file, object_hook=type_json_hook) + else: + return dict() + + def save_config(self): + if self.config_file is None: + logging.error("No config-file set!") + with open(self.config_file, "w") as file: + json.dump(self.config, file, indent=2, cls=TypedJSONEncoder) + + @property + def name(self): + name_default = f"Project_{datetime.now().strftime('%Y%m%d%H%M%S')}" + return self.config.get("name", name_default) + + # ToDo: Rename function (rename all files etc.) + def rename(self, new_name): + pass + + @property + def meeg_root(self): + if "meeg_root" not in self.config: + raise ValueError("The path to the MEEG data is not set!") + return self.config["meeg_root"] + + @meeg_root.setter + def meeg_root(self, value): + if not isdir(value): + raise ValueError(f"Path {value} does not exist!") + self.config["meeg_root"] = value + self.save_config() + + @property + def fsmri_root(self): + if "fsmri_root" not in self.config: + raise ValueError("The path to the FreeSurfer MRI data is not set!") + return self.config["fsmri_root"] + + @fsmri_root.setter + def fsmri_root(self, value): + if not isdir(value): + raise ValueError(f"Path {value} does not exist!") + self.config["fsmri_root"] = value + self.save_config() + + @property + def plots_path(self): + if "plots_path" not in self.config: + raise ValueError("The path for plots is not set!") + return self.config["plots_path"] + + @plots_path.setter + def plots_path(self, value): + if not isdir(value): + raise ValueError(f"Path {value} does not exist!") + self.config["plots_path"] = value + self.save_config() + + @property + def inputs(self): + """This holds all data inputs from MEEG, FSMRI, etc.""" + if "inputs" not in self.config: + self.config["inputs"] = { + "MEEG": list(), + "FSMRI": list(), + "EmptyRoom": list(), + } + return self.config["inputs"] + + @property + def selected_inputs(self): + """This holds all selected inputs.""" + if "selected_inputs" not in self.config: + self.config["selected_inputs"] = { + "MEEG": list(), + "FSMRI": list(), + "EmptyRoom": list(), + } + return self.config["selected_inputs"] diff --git a/mne_pipeline_hd/pipeline/function_utils.py b/mne_pipeline_hd/pipeline/function_utils.py index e84e8fbb..3366d1ce 100644 --- a/mne_pipeline_hd/pipeline/function_utils.py +++ b/mne_pipeline_hd/pipeline/function_utils.py @@ -141,6 +141,7 @@ def run_func(func, keywargs, pipe=None): return get_exception_tuple(is_mp=pipe is not None) +# Continue: Implement Run-Controller for Nodes-GUI class RunController: def __init__(self, controller): self.ct = controller diff --git a/mne_pipeline_hd/pipeline/legacy.py b/mne_pipeline_hd/pipeline/legacy.py index 37d66411..99d7c49b 100644 --- a/mne_pipeline_hd/pipeline/legacy.py +++ b/mne_pipeline_hd/pipeline/legacy.py @@ -68,10 +68,8 @@ def uninstall_package(package_name): def legacy_import_check(test_package=None): - """ - This function checks for recent package changes - and offers installation or manual installation instructions. - """ + """This function checks for recent package changes and offers installation or manual + installation instructions.""" # For testing purposes if test_package is not None: new_packages[test_package] = test_package diff --git a/mne_pipeline_hd/pipeline/loading.py b/mne_pipeline_hd/pipeline/loading.py index 17e91d7d..75e45182 100644 --- a/mne_pipeline_hd/pipeline/loading.py +++ b/mne_pipeline_hd/pipeline/loading.py @@ -33,6 +33,13 @@ logger, ) +# BIDS-Considerations: +# - BIDS might be a bit work at first (implement session, run etc.) +# - Might pay off since data and derivative might be used then by mne-bids-pipeline +# and hopefully other analysis tools complying to the bids-standard +# - For freesurfer data use recon-all --bids-out +# - bids-apps/freesurfer might be also interesting + def _get_data_type_from_func(self, func, method): # Get matching data-type from IO-Dict @@ -149,8 +156,8 @@ def save_wrapper(self, *args, **kwargs): # ToDo: Unify all objects to one loading-class # For example Group and MEEG can have the same load-method for Source-Estimates then class BaseLoading: - """Base-Class for Sub (The current File/MRI-File/Grand-Average-Group, - which is executed)""" + """Base-Class for Sub (The current File/MRI-File/Grand-Average-Group, which is + executed)""" def __init__(self, name, controller): # Basic Attributes (partly taking parameters or main-win-attributes @@ -184,7 +191,7 @@ def init_plot_files(self): self.plot_files = self.pr.plot_files[self.name][self.p_preset] def get_parameter(self, parameter_name): - """Get parameter from parameter-dictionary""" + """Get parameter from parameter-dictionary.""" if parameter_name in self.pa: return self.pa[parameter_name] @@ -192,13 +199,13 @@ def get_parameter(self, parameter_name): raise KeyError(f"Parameter {parameter_name} not found in parameters") def init_attributes(self): - """Initialization of additional attributes, should be overridden - in inherited classes""" + """Initialization of additional attributes, should be overridden in inherited + classes.""" pass def init_paths(self): - """Initialization of all paths and the io_dict, should be overridden - in inherited classes""" + """Initialization of all paths and the io_dict, should be overridden in + inherited classes.""" self.save_dir = "" self.io_dict = dict() self.deprecated_paths = dict() @@ -340,10 +347,8 @@ def plot_save( dpi=None, img_format=None, ): - """ - Save a plot with this method either by letting the figure be detected - by the backend (pyplot, mayavi) or by - supplying the figure directly. + """Save a plot with this method either by letting the figure be detected by the + backend (pyplot, mayavi) or by supplying the figure directly. Parameters ---------- @@ -516,8 +521,7 @@ def remove_json(self, file_name): logger().warning(f"{file_path} was removed") def get_existing_paths(self): - """Get existing paths and add the mapped File-Type - to existing_paths (set)""" + """Get existing paths and add the mapped File-Type to existing_paths (set)""" self.existing_paths.clear() for data_type in self.io_dict: paths = self._return_path_list(data_type) @@ -590,7 +594,7 @@ def remove_path(self, data_type): # In the future there should be only one, # favor io_dict (better than attribute since easier to set from config-files) class MEEG(BaseLoading): - """Class for File-Data in File-Loop""" + """Class for File-Data in File-Loop.""" def __init__(self, name, controller, fsmri=None, suppress_warnings=True): self.fsmri = fsmri @@ -601,7 +605,7 @@ def __init__(self, name, controller, fsmri=None, suppress_warnings=True): self.init_sample() def init_attributes(self): - """Initialize additional attributes for MEEG""" + """Initialize additional attributes for MEEG.""" # The assigned Empty-Room-Measurement if existing if self.name not in self.pr.meeg_to_erm: self.erm = None @@ -679,8 +683,7 @@ def init_attributes(self): self.ica_exclude = self.pr.meeg_ica_exclude[self.name] def init_paths(self): - """Load Paths as attributes - (depending on which Parameter-Preset is selected)""" + """Load Paths as attributes (depending on which Parameter-Preset is selected)""" # Main save directory self.save_dir = join(self.pr.data_path, self.name) @@ -1051,7 +1054,17 @@ def set_ica_exclude(self, ica_exclude): ########################################################################### def load_info(self): - return mne.io.read_info(self.raw_path) + if isfile(self.raw_path): + path = self.raw_path + elif isfile(self.epochs_path): + path = self.epochs_path + elif isfile(self.evokeds_path): + path = self.evokeds_path + else: + raise FileNotFoundError( + f"Could not find file to load info from for {self.name}" + ) + return mne.io.read_info(path) @load_decorator def load_raw(self): @@ -1110,7 +1123,7 @@ def save_epochs(self, epochs): epochs.save(self.epochs_path, overwrite=True) def get_trial_epochs(self): - """Return epochs for each trial in self.sel_trials""" + """Return epochs for each trial in self.sel_trials.""" epochs = self.load_epochs() for trial, meta_query in self.sel_trials.items(): epoch_trial = meta_query or trial @@ -1440,7 +1453,7 @@ def __init__(self, name, controller, load_labels=False): self.init_fsaverage() def init_attributes(self): - """Initialize additional attributes for FSMRI""" + """Initialize additional attributes for FSMRI.""" self.fs_path = QS().value("fs_path") self.mne_path = QS().value("mne_path") @@ -1643,7 +1656,7 @@ def __init__(self, name, controller, suppress_warnings=True): super().__init__(name, controller) def init_attributes(self): - """Initialize additional attributes for Group""" + """Initialize additional attributes for Group.""" if self.name not in self.pr.all_groups: self.group_list = [] if not self.suppress_warnings: @@ -1760,6 +1773,7 @@ def init_paths(self): ########################################################################### def load_items(self, obj_type="MEEG", data_type=None): """Returns a generator for group items.""" + # ToDO: Also with obj_type=None and only data_type for obj_name in self.group_list: if obj_type == "MEEG": obj = MEEG(obj_name, self.ct) diff --git a/mne_pipeline_hd/pipeline/pipeline_utils.py b/mne_pipeline_hd/pipeline/pipeline_utils.py index 8b3ca3ba..997c53a2 100644 --- a/mne_pipeline_hd/pipeline/pipeline_utils.py +++ b/mne_pipeline_hd/pipeline/pipeline_utils.py @@ -59,7 +59,7 @@ def logger(): def get_n_jobs(n_jobs): - """Get the number of jobs to use for parallel processing""" + """Get the number of jobs to use for parallel processing.""" if n_jobs == -1 or n_jobs in ["auto", "max"]: n_cores = multiprocessing.cpu_count() else: @@ -69,8 +69,10 @@ def get_n_jobs(n_jobs): def encode_tuples(input_dict): - """Encode tuples in a dictionary, because JSON does not recognize them - (CAVE: input_dict is changed in place)""" + """Encode tuples in a dictionary, because JSON does not recognize them (CAVE: + + input_dict is changed in place) + """ for key, value in input_dict.items(): if isinstance(value, dict): encode_tuples(value) @@ -80,19 +82,20 @@ def encode_tuples(input_dict): class TypedJSONEncoder(json.JSONEncoder): - def default(self, obj): - if isinstance(obj, np.integer): - return int(obj) - elif isinstance(obj, np.floating): - return float(obj) - elif isinstance(obj, np.ndarray): - return {"numpy_array": obj.tolist()} - elif isinstance(obj, datetime): - return {"datetime": obj.strftime(datetime_format)} - elif isinstance(obj, set): - return {"set_type": list(obj)} + def default(self, o): + if isinstance(o, np.integer): + return int(o) + elif isinstance(o, np.floating): + return float(o) + # Only onedimensional arrays are supported + elif isinstance(o, np.ndarray): + return {"numpy_array": o.tolist()} + elif isinstance(o, datetime): + return {"datetime": o.strftime(datetime_format)} + elif isinstance(o, set): + return {"set_type": list(o)} else: - return json.JSONEncoder.default(self, obj) + return json.JSONEncoder.default(self, o) def type_json_hook(obj): @@ -100,6 +103,7 @@ def type_json_hook(obj): return obj["numpy_int"] elif "numpy_float" in obj.keys(): return obj["numpy_float"] + # Only onedimensional arrays are supported elif "numpy_array" in obj.keys(): return np.asarray(obj["numpy_array"]) elif "datetime" in obj.keys(): @@ -113,8 +117,8 @@ def type_json_hook(obj): def compare_filep(obj, path, target_parameters=None, verbose=True): - """Compare the parameters of the previous run to the current - parameters for the given path + """Compare the parameters of the previous run to the current parameters for the + given path. Parameters ---------- @@ -210,7 +214,7 @@ def check_kwargs(kwargs, function): def count_dict_keys(d, max_level=None): - """Count the number of keys of a nested dictionary""" + """Count the number of keys of a nested dictionary.""" keys = 0 for value in d.values(): if isinstance(value, dict): @@ -236,8 +240,7 @@ def shutdown(): def restart_program(): - """Restarts the current program, with file objects and descriptors - cleanup.""" + """Restarts the current program, with file objects and descriptors cleanup.""" logger().info("Restarting") try: p = psutil.Process(os.getpid()) diff --git a/mne_pipeline_hd/pipeline/project.py b/mne_pipeline_hd/pipeline/project.py index 8f4a9cd9..c588bf05 100644 --- a/mne_pipeline_hd/pipeline/project.py +++ b/mne_pipeline_hd/pipeline/project.py @@ -29,10 +29,8 @@ class Project: - """ - A class with attributes for all the paths, file-lists/dicts - and parameters of the selected project - """ + """A class with attributes for all the paths, file-lists/dicts and parameters of the + selected project.""" def __init__(self, controller, name): self.ct = controller @@ -385,11 +383,13 @@ def save(self, worker_signals=None): def add_meeg(self, name, file_path=None, is_erm=False): if is_erm: - # Organize Empty-Room-FIles - self.all_erm.append(name) + # Organize Empty-Room-Files + if name not in self.all_erm: + self.all_erm.append(name) else: # Organize other files - self.all_meeg.append(name) + if name not in self.all_meeg: + self.all_meeg.append(name) # Copy sub_files to destination (with MEEG-Class # to also include raw into file_parameters) diff --git a/mne_pipeline_hd/tests/test_nodes.py b/mne_pipeline_hd/tests/test_nodes.py new file mode 100644 index 00000000..8d1f2c38 --- /dev/null +++ b/mne_pipeline_hd/tests/test_nodes.py @@ -0,0 +1,42 @@ +# -*- coding: utf-8 -*- +from mne_pipeline_hd.gui.gui_utils import mouseDrag +from qtpy.QtCore import Qt, QPointF + + +def test_nodes_basic_interaction(nodeviewer): + node1 = nodeviewer.node(node_idx=0) + node2 = nodeviewer.node(node_idx=1) + + out1_pos = nodeviewer.port_position_view("out", 1, node_idx=0) + in2_pos = nodeviewer.port_position_view("in", 1, node_idx=1) + mouseDrag( + widget=nodeviewer.viewport(), + positions=[out1_pos, in2_pos], + button=Qt.MouseButton.LeftButton, + ) + # Check if new connection was created + assert node1.output(port_idx=1) in node2.input(port_idx=1).connected_ports + + # Slice both connections + start_slice_pos = nodeviewer.mapFromScene(QPointF(200, 180)) + end_slice_pos = nodeviewer.mapFromScene(QPointF(320, 0)) + + mouseDrag( + widget=nodeviewer.viewport(), + positions=[start_slice_pos, end_slice_pos], + button=Qt.MouseButton.LeftButton, + modifier=Qt.KeyboardModifier.AltModifier | Qt.KeyboardModifier.ShiftModifier, + ) + # Check if connection was sliced + assert len(node1.output(1).connected_ports) == 0 + + +def test_node_serialization(qtbot, nodeviewer): + viewer_dict = nodeviewer.to_dict() + qtbot.wait(1000) + nodeviewer.clear() + qtbot.wait(1000) + nodeviewer.from_dict(viewer_dict) + second_viewer_dict = nodeviewer.to_dict() + qtbot.wait(10000) + assert viewer_dict == second_viewer_dict diff --git a/mne_pipeline_hd/tests/test_parameter_widgets.py b/mne_pipeline_hd/tests/test_parameter_widgets.py index 473951ab..4536ad67 100644 --- a/mne_pipeline_hd/tests/test_parameter_widgets.py +++ b/mne_pipeline_hd/tests/test_parameter_widgets.py @@ -166,7 +166,7 @@ def test_basic_param_guis(qtbot, gui_name): def test_label_gui(qtbot, controller): - """Test opening label-gui without error""" + """Test opening label-gui without error.""" # Add fsaverage controller.pr.add_fsmri("fsaverage") diff --git a/pyproject.toml b/pyproject.toml index 815a40d6..39a8eb8e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -85,3 +85,8 @@ PYQT5 = false PYSIDE2 = false PYQT6 = true PYSIDE6 = false + +[tool.docformatter] +black = true +in-place = true +recursive = true