Skip to content

Commit

Permalink
Continue to/from_dict (in Progress)
Browse files Browse the repository at this point in the history
  • Loading branch information
marsipu committed Nov 30, 2024
1 parent 1e08ba4 commit b997020
Show file tree
Hide file tree
Showing 7 changed files with 117 additions and 129 deletions.
5 changes: 2 additions & 3 deletions mne_pipeline_hd/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import pytest

from mne_pipeline_hd.gui.node.node_viewer import NodeViewer
from mne_pipeline_hd.gui.node.nodes import FunctionNode
from mne_pipeline_hd.gui.main_window import MainWindow
from mne_pipeline_hd.pipeline.controller import Controller, NewController
from mne_pipeline_hd.pipeline.pipeline_utils import _set_test_run
Expand Down Expand Up @@ -76,8 +75,8 @@ def nodeviewer(qtbot):
},
},
}
func_node1 = viewer.create_node(FunctionNode, **func_kwargs)
func_node2 = viewer.create_node(FunctionNode, **func_kwargs)
func_node1 = viewer.create_node("FunctionNode", **func_kwargs)
func_node2 = viewer.create_node("FunctionNode", **func_kwargs)
func_node1.set_output(0, func_node2.input(0))

func_node2.setPos(400, 100)
Expand Down
20 changes: 7 additions & 13 deletions mne_pipeline_hd/development/node_tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,6 @@

from PyQt5.QtWidgets import QApplication
from mne_pipeline_hd.gui.node.node_viewer import NodeViewer
from mne_pipeline_hd.gui.node.nodes import (
FunctionNode,
MEEGInputNode,
AssignmentNode,
MRIInputNode,
)
from pipeline.controller import NewController


Expand All @@ -22,13 +16,13 @@ def run_graph_test():
viewer = NodeViewer(ct)
viewer.resize(1600, 1000)

meeg_node = viewer.create_node(MEEGInputNode)
mri_node = viewer.create_node(MRIInputNode)
ass_node = viewer.create_node(AssignmentNode)
func_node1 = viewer.create_node(FunctionNode)
func_node2 = viewer.create_node(FunctionNode)
func_node3 = viewer.create_node(FunctionNode)
func_node4 = viewer.create_node(FunctionNode)
meeg_node = viewer.create_node("MEEGInputNode")
mri_node = viewer.create_node("MRIInputNode")
ass_node = viewer.create_node("AssignmentNode")
func_node1 = viewer.create_node("FunctionNode")
func_node2 = viewer.create_node("FunctionNode")
func_node3 = viewer.create_node("FunctionNode")
func_node4 = viewer.create_node("FunctionNode")

# Wire up the nodes
meeg_node.set_output(0, func_node1.input(0))
Expand Down
14 changes: 4 additions & 10 deletions mne_pipeline_hd/gui/main_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,6 @@
ExportDialog,
)
from mne_pipeline_hd.gui.node.node_viewer import NodeViewer
from mne_pipeline_hd.gui.node.nodes import (
MEEGInputNode,
MRIInputNode,
AssignmentNode,
FunctionNode,
)
from mne_pipeline_hd.gui.parameter_widgets import (
BoolGui,
IntGui,
Expand Down Expand Up @@ -670,17 +664,17 @@ def add_func_bts(self):
}

# Add some demo nodes
meeg_node = self.node_viewer.create_node(MEEGInputNode)
mri_node = self.node_viewer.create_node(MRIInputNode)
meeg_node = self.node_viewer.create_node("MEEGInputNode")
mri_node = self.node_viewer.create_node("MRIInputNode")
ass_node = self.node_viewer.create_node(
AssignmentNode,
"AssignmentNode",
inputs={"Evokeds": {}, "Fwd": {}},
outputs={"Evokeds": {}, "Fwd": {}},
)
fn = dict()
for func_name, func_kwargs in demo_dict.items():
fnode = self.node_viewer.create_node(
FunctionNode, function_name=func_name, **func_kwargs
"FunctionNode", function_name=func_name, **func_kwargs
)
fn[func_name] = fnode

Expand Down
69 changes: 36 additions & 33 deletions mne_pipeline_hd/gui/node/base_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,11 +224,11 @@ def add_input(
port qgraphics item.
"""
# port names must be unique
if name in self._inputs:
if name in [p.name for p in self.inputs]:
logging.warning(f"Input port {name} already exists.")
return
port = Port(self, name, "in", multi_connection, accepted_ports)
self._inputs[port.name] = port
self._inputs[port.id] = port
if self.scene():
self.draw_node()

Expand Down Expand Up @@ -257,49 +257,61 @@ def add_output(
port qgraphics item.
"""
# port names must be unique
if name in self._outputs:
if name in [p.name for p in self.outputs]:
logging.warning(f"Output port {name} already exists.")
return
port = Port(self, name, "out", multi_connection, accepted_ports)
self._outputs[port.name] = port
self._outputs[port.id] = port
if self.scene():
self.draw_node()

return port

def input(self, port):
"""
Get input port by the name or index.
Get input port by the name, index or id.
Args:
port (str or int): port name or index.
port (str or int): port name, index or id.
Returns:
NodeGraphQt.Port: node port.
"""
if isinstance(port, int):
if port < len(self.inputs):
return self.inputs[port]
elif isinstance(port, str):
# Get input port by id
if port in self._inputs:
return self._inputs[port]
# Get input port by index (self.inputs returns a list)
elif port < len(self.inputs):
return self.inputs[port]
elif isinstance(port, str):
port_names = [p.name for p in self.inputs]
if port in port_names:
name_index = port_names.index(port)
return self.inputs[name_index]

def output(self, port):
"""
Get output port by the name or index.
Get output port by the name, index or id.
Args:
port (str or int): port name or index.
port (str or int): port name, index or id.
Returns:
NodeGraphQt.Port: node port.
"""
if isinstance(port, int):
if port < len(self.outputs):
return self.outputs[port]
elif isinstance(port, str):
# Get output port by id
if port in self._outputs:
return self._outputs[port]
# Get output port by index (self.outputs returns a list)
elif port < len(self.outputs):
return self.outputs[port]
elif isinstance(port, str):
port_names = [p.name for p in self.outputs]
if port in port_names:
name_index = port_names.index(port)
return self.outputs[name_index]

def port(self, port_type, port):
"""
Expand Down Expand Up @@ -385,36 +397,27 @@ def delete(self):
"""
Remove node from the scene.
"""
self.scene().removeItem(self)
if self.scene() is not None:
self.scene().removeItem(self)
del self

def to_dict(self):
node_dict = {
"name": self.name,
"class": self.__class__.__name__,
"pos": self.xy_pos,
"inputs": self.inputs,
"outputs": self.outputs,
"connections": {
"inputs": {
p.name: {
nid: [cp.name for cp in cpts]
for nid, cpts in p.connected_ports.items()
}
for p in self.inputs
},
"outputs": {
p.name: {
nid: [cp.name for cp in cpts]
for nid, cpts in p.connected_ports.items()
}
for p in self.outputs
},
},
"inputs": {p.id: p.to_dict() for p in self.inputs},
"outputs": {p.id: p.to_dict() for p in self.outputs},
}

return node_dict

@classmethod
def from_dict(cls, node_dict, ct):
node = cls(ct, name=node_dict["name"])
print(node) ## only for ruff
# Todo: Continue here

# ----------------------------------------------------------------------------------
# Qt methods
# ----------------------------------------------------------------------------------
Expand Down
49 changes: 29 additions & 20 deletions mne_pipeline_hd/gui/node/node_viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from collections import OrderedDict

import qtpy

from gui.node import nodes
from mne_pipeline_hd.gui.gui_utils import invert_rgb_color
from mne_pipeline_hd.gui.node.base_node import BaseNode
from mne_pipeline_hd.gui.node.node_defaults import defaults
Expand Down Expand Up @@ -197,23 +199,35 @@ def remove_node(self, node):

node.delete()

def create_node(self, node_class, **kwargs):
def create_node(self, node_info, **kwargs):
"""
Create a node from the given class.
Parameters
----------
node_class
The node class to create.
kwargs
node_info: str or dict
Can be a string to speficy the node class or a dictionary
from node.to_dict().
kwargs: dict
Additional keyword arguments to pass into BaseNode.__init__()
(replacing the values from the dictionary if provided).
Returns
-------
node
The created node.
"""
node = node_class(self.ct, **kwargs)
if isinstance(node_info, dict):
node_class = getattr(nodes, node_info["class"])
for key in node_info:
if key in kwargs:
node_info[key] = kwargs[key]
node = node_class.from_dict(node_info, self.ct)
elif isinstance(node_info, str):
node_class = getattr(nodes, node_info)
node = node_class(self.ct, **kwargs)
else:
raise ValueError("node_info must be a string or a dictionary.")
self.add_node(node)

return node
Expand Down Expand Up @@ -253,23 +267,17 @@ def to_dict(self):
return viewer_dict

def from_dict(self, viewer_dict):
# ToDo: Implement this
for node_id, node_data in viewer_dict["nodes"].items():
node = self.add_node(node_data["type"])
node.from_dict(node_data)

for conn_id, conn_data in viewer_dict["connections"].items():
start_port = self.nodes[conn_data["start_node"]].outputs[
conn_data["start_port"]
]
end_port = self.nodes[conn_data["end_node"]].inputs[conn_data["end_port"]]
start_port.connect_to(end_port)
self.clear()
# Create nodes
for node_dict in viewer_dict.values():
self.create_node(node_class=node_dict)
# Initialize connections

def clear(self):
"""
Clear the node graph.
"""
for node in self.nodes.values():
for node in list(self.nodes.values()):
self.remove_node(node)

# ----------------------------------------------------------------------------------
Expand Down Expand Up @@ -726,7 +734,8 @@ def sceneMousePressEvent(self, event):

if port:
if not port.multi_connection and len(port.connected_ports) > 0:
self._detached_port = port.get_connected_ports(port_idx=0)
# ToDo: Might cause problems with multi-connections
self._detached_port = port.connected_ports[0]
self.start_live_connection(port)
if not port.multi_connection:
[p.delete() for p in port.connected_pipes.values()]
Expand Down Expand Up @@ -1212,8 +1221,8 @@ def _update_node_rank(node, nodes_rank, down_stream=True):
node_values = node.connected_input_nodes().values()

connected_nodes = set()
for nodes in node_values:
connected_nodes.update(nodes)
for nds in node_values:
connected_nodes.update(nds)

rank = nodes_rank[node] + 1
for n in connected_nodes:
Expand Down
Loading

0 comments on commit b997020

Please sign in to comment.