From 2b1028aea0194bc5005e6fc84f11c275495c183a Mon Sep 17 00:00:00 2001 From: Greg Morenz Date: Fri, 8 Apr 2022 00:09:57 -0400 Subject: [PATCH] Many to many connections Currently graph connections are always one output to many inputs. This is appropriate for most "data type values", but not always correct. The motivating example for a place where it is incorrect is "control flow", where many nodes might "output" control flow to the same node's control flow input. This changes that so that it is configurable via data-type whether nodes are one-many, many-one, many-many, or even one-one with the concept of splittable data-types (data types that can be copied from one output to many inputs) and mergeable data-types (data types that can accept many outputs into a single input). --- egui_node_graph/src/editor_ui.rs | 56 ++++++++++++++----- egui_node_graph/src/graph.rs | 7 ++- egui_node_graph/src/graph_impls.rs | 89 +++++++++++++++++++++++++----- egui_node_graph/src/traits.rs | 10 ++++ egui_node_graph/src/ui_state.rs | 2 +- egui_node_graph_example/src/app.rs | 2 +- 6 files changed, 133 insertions(+), 33 deletions(-) diff --git a/egui_node_graph/src/editor_ui.rs b/egui_node_graph/src/editor_ui.rs index 59ccf1b..2807b66 100644 --- a/egui_node_graph/src/editor_ui.rs +++ b/egui_node_graph/src/editor_ui.rs @@ -201,7 +201,7 @@ where } NodeResponse::DisconnectEvent { input, output } => { let other_node = self.graph.get_input(input).node(); - self.graph.remove_connection(input); + self.graph.remove_connection(output, input); self.connection_in_progress = Some((other_node, AnyParameterId::Output(output))); } @@ -327,10 +327,10 @@ where for (param_name, param_id) in inputs { if self.graph[param_id].shown_inline { let height_before = ui.min_rect().bottom(); - if self.graph.connection(param_id).is_some() { - ui.label(param_name); - } else { + if self.graph.incoming(param_id).is_empty() { self.graph[param_id].value.value_widget(¶m_name, ui); + } else { + ui.label(param_name); } let height_after = ui.min_rect().bottom(); input_port_heights.push((height_before + height_after) / 2.0); @@ -370,7 +370,11 @@ where param_id: AnyParameterId, port_locations: &mut PortLocations, ongoing_drag: Option<(NodeId, AnyParameterId)>, - connected_to_output: Option, + // If the datatype of this node restricts it to connecting to + // at most one other node, and there is a connection, then this + // parameter should be Some(PortItIsConnectedTo), otherwise it + // should be None + unique_connection: Option, ) where DataType: DataTypeTrait, UserResponse: UserResponseTrait, @@ -395,14 +399,18 @@ where .circle(port_rect.center(), 5.0, port_color, Stroke::none()); if resp.drag_started() { - if let Some(output) = connected_to_output { - responses.push(NodeResponse::DisconnectEvent { - output, + let response = match unique_connection { + Some(AnyParameterId::Input(input)) => NodeResponse::DisconnectEvent { + input, + output: param_id.assume_output(), + }, + Some(AnyParameterId::Output(output)) => NodeResponse::DisconnectEvent { input: param_id.assume_input(), - }); - } else { - responses.push(NodeResponse::ConnectEventStarted(node_id, param_id)); - } + output, + }, + None => NodeResponse::ConnectEventStarted(node_id, param_id), + }; + responses.push(response); } if let Some((origin_node, origin_param)) = ongoing_drag { @@ -436,6 +444,16 @@ where InputParamKind::ConnectionOrConstant => true, }; + let unique_connection = if !self.graph.get_input(*param).typ.mergeable() { + self.graph + .incoming(*param) + .first() + .copied() + .map(AnyParameterId::Output) + } else { + None + }; + if should_draw { let pos_left = pos2(port_left, port_height); draw_port( @@ -447,7 +465,7 @@ where AnyParameterId::Input(*param), self.port_locations, self.ongoing_drag, - self.graph.connection(*param), + unique_connection, ); } } @@ -458,6 +476,16 @@ where .iter() .zip(output_port_heights.into_iter()) { + let unique_connection = if !self.graph.get_output(*param).typ.splittable() { + self.graph + .outgoing(*param) + .first() + .copied() + .map(AnyParameterId::Input) + } else { + None + }; + let pos_right = pos2(port_right, port_height); draw_port( ui, @@ -468,7 +496,7 @@ where AnyParameterId::Output(*param), self.port_locations, self.ongoing_drag, - None, + unique_connection, ); } diff --git a/egui_node_graph/src/graph.rs b/egui_node_graph/src/graph.rs index 32301d7..359c254 100644 --- a/egui_node_graph/src/graph.rs +++ b/egui_node_graph/src/graph.rs @@ -85,7 +85,10 @@ pub struct Graph { pub inputs: SlotMap>, /// The [`OutputParam`]s of the graph pub outputs: SlotMap>, - // Connects the input of a node, to the output of its predecessor that + // Connects the input of a node, to the output(s) of its predecessor(s) that // produces it - pub connections: SecondaryMap, + pub incoming: SecondaryMap>, + // Connects the outputs of a node, to the input(s) of its predecessor(s) that + // consumes it + pub outgoing: SecondaryMap>, } diff --git a/egui_node_graph/src/graph_impls.rs b/egui_node_graph/src/graph_impls.rs index 7f898a6..7eb5111 100644 --- a/egui_node_graph/src/graph_impls.rs +++ b/egui_node_graph/src/graph_impls.rs @@ -1,12 +1,16 @@ use super::*; -impl Graph { +impl Graph +where + DataType: DataTypeTrait, +{ pub fn new() -> Self { Self { nodes: SlotMap::default(), inputs: SlotMap::default(), outputs: SlotMap::default(), - connections: SecondaryMap::default(), + incoming: SecondaryMap::default(), + outgoing: SecondaryMap::default(), } } @@ -64,21 +68,38 @@ impl Graph { } pub fn remove_node(&mut self, node_id: NodeId) { - self.connections - .retain(|i, o| !(self.outputs[*o].node == node_id || self.inputs[i].node == node_id)); let inputs: SVec<_> = self[node_id].input_ids().collect(); for input in inputs { - self.inputs.remove(input); + self.remove_incoming_connections(input); } let outputs: SVec<_> = self[node_id].output_ids().collect(); for output in outputs { - self.outputs.remove(output); + self.remove_outgoing_connections(output); } self.nodes.remove(node_id); } - pub fn remove_connection(&mut self, input_id: InputId) -> Option { - self.connections.remove(input_id) + pub fn remove_connection(&mut self, output_id: OutputId, input_id: InputId) { + self.outgoing[output_id].retain(|&mut x| x != input_id); + self.incoming[input_id].retain(|&mut x| x != output_id); + } + + pub fn remove_incoming_connections(&mut self, input_id: InputId) { + if let Some(outputs) = self.incoming.get(input_id) { + for &output in outputs { + self.outgoing[output].retain(|&mut x| x != input_id); + } + } + self.incoming.remove(input_id); + } + + pub fn remove_outgoing_connections(&mut self, output_id: OutputId) { + if let Some(inputs) = self.outgoing.get(output_id) { + for &input in inputs { + self.incoming[input].retain(|&mut x| x != output_id); + } + } + self.outgoing.remove(output_id); } pub fn iter_nodes(&self) -> impl Iterator + '_ { @@ -86,15 +107,51 @@ impl Graph { } pub fn add_connection(&mut self, output: OutputId, input: InputId) { - self.connections.insert(input, output); + if self.get_input(input).typ.mergeable() { + self.incoming + .entry(input) + .expect("Old InputId") + .or_default() + .push(output); + } else { + self.remove_incoming_connections(input); + let mut v = SVec::new(); + v.push(output); + self.incoming.insert(input, v); + } + + if self.get_output(output).typ.splittable() { + self.outgoing + .entry(output) + .expect("Old OutputId") + .or_default() + .push(input); + } else { + self.remove_outgoing_connections(output); + let mut v = SVec::new(); + v.push(input); + self.outgoing.insert(output, v); + } } pub fn iter_connections(&self) -> impl Iterator + '_ { - self.connections.iter().map(|(o, i)| (o, *i)) + self.incoming + .iter() + .flat_map(|(o, inputs)| inputs.iter().map(move |&i| (o, i))) + } + + pub fn incoming(&self, input: InputId) -> &[OutputId] { + self.incoming + .get(input) + .map(|x| x.as_slice()) + .unwrap_or(&[]) } - pub fn connection(&self, input: InputId) -> Option { - self.connections.get(input).copied() + pub fn outgoing(&self, output: OutputId) -> &[InputId] { + self.outgoing + .get(output) + .map(|x| x.as_slice()) + .unwrap_or(&[]) } pub fn any_param_type(&self, param: AnyParameterId) -> Result<&DataType, EguiGraphError> { @@ -114,21 +171,23 @@ impl Graph { } } -impl Default for Graph { +impl Default + for Graph +{ fn default() -> Self { Self::new() } } impl Node { - pub fn inputs<'a, DataType, DataValue>( + pub fn inputs<'a, DataType: DataTypeTrait, DataValue>( &'a self, graph: &'a Graph, ) -> impl Iterator> + 'a { self.input_ids().map(|id| graph.get_input(id)) } - pub fn outputs<'a, DataType, DataValue>( + pub fn outputs<'a, DataType: DataTypeTrait, DataValue>( &'a self, graph: &'a Graph, ) -> impl Iterator> + 'a { diff --git a/egui_node_graph/src/traits.rs b/egui_node_graph/src/traits.rs index 5be8908..6cf5aea 100644 --- a/egui_node_graph/src/traits.rs +++ b/egui_node_graph/src/traits.rs @@ -16,6 +16,16 @@ pub trait DataTypeTrait: PartialEq + Eq { // The name of this datatype fn name(&self) -> &str; + + /// Whether an output of this datatype can be sent to multiple nodes + fn splittable(&self) -> bool { + true + } + + /// Whether an input of this datatype can be recieved from multiple nodes + fn mergeable(&self) -> bool { + false + } } /// This trait must be implemented for the `NodeData` generic parameter of the diff --git a/egui_node_graph/src/ui_state.rs b/egui_node_graph/src/ui_state.rs index a2ccea1..8b6e2fb 100644 --- a/egui_node_graph/src/ui_state.rs +++ b/egui_node_graph/src/ui_state.rs @@ -31,7 +31,7 @@ pub struct GraphEditorState +impl GraphEditorState { pub fn new(default_zoom: f32, user_state: UserState) -> Self { diff --git a/egui_node_graph_example/src/app.rs b/egui_node_graph_example/src/app.rs index be4ea3f..888174b 100644 --- a/egui_node_graph_example/src/app.rs +++ b/egui_node_graph_example/src/app.rs @@ -515,7 +515,7 @@ fn evaluate_input( let input_id = graph[node_id].get_input(param_name)?; // The output of another node is connected. - if let Some(other_output_id) = graph.connection(input_id) { + if let Some(&other_output_id) = graph.incoming(input_id).first() { // The value was already computed due to the evaluation of some other // node. We simply return value from the cache. if let Some(other_value) = outputs_cache.get(&other_output_id) {