diff --git a/egui_node_graph/src/editor_ui.rs b/egui_node_graph/src/editor_ui.rs index 59ccf1b..be46695 100644 --- a/egui_node_graph/src/editor_ui.rs +++ b/egui_node_graph/src/editor_ui.rs @@ -201,7 +201,7 @@ where } NodeResponse::DisconnectEvent { input, output } => { let other_node = self.graph.get_input(input).node(); - self.graph.remove_connection(input); + self.graph.remove_connection(output, input); self.connection_in_progress = Some((other_node, AnyParameterId::Output(output))); } @@ -327,7 +327,7 @@ where for (param_name, param_id) in inputs { if self.graph[param_id].shown_inline { let height_before = ui.min_rect().bottom(); - if self.graph.connection(param_id).is_some() { + if self.graph.incoming(param_id).len() != 0 { ui.label(param_name); } else { self.graph[param_id].value.value_widget(¶m_name, ui); @@ -370,7 +370,11 @@ where param_id: AnyParameterId, port_locations: &mut PortLocations, ongoing_drag: Option<(NodeId, AnyParameterId)>, - connected_to_output: Option, + // If the datatype of this node restricts it to connecting to + // at most one other node, and there is a connection, then this + // parameter should be Some(PortItIsConnectedTo), otherwise it + // should be None + unique_connection: Option, ) where DataType: DataTypeTrait, UserResponse: UserResponseTrait, @@ -395,14 +399,18 @@ where .circle(port_rect.center(), 5.0, port_color, Stroke::none()); if resp.drag_started() { - if let Some(output) = connected_to_output { - responses.push(NodeResponse::DisconnectEvent { - output, + let response = match unique_connection { + Some(AnyParameterId::Input(input)) => NodeResponse::DisconnectEvent { + input, + output: param_id.assume_output(), + }, + Some(AnyParameterId::Output(output)) => NodeResponse::DisconnectEvent { input: param_id.assume_input(), - }); - } else { - responses.push(NodeResponse::ConnectEventStarted(node_id, param_id)); - } + output, + }, + None => NodeResponse::ConnectEventStarted(node_id, param_id), + }; + responses.push(response); } if let Some((origin_node, origin_param)) = ongoing_drag { @@ -436,6 +444,16 @@ where InputParamKind::ConnectionOrConstant => true, }; + let unique_connection = if !self.graph.get_input(*param).typ.mergeable() { + self.graph + .incoming(*param) + .first() + .copied() + .map(AnyParameterId::Output) + } else { + None + }; + if should_draw { let pos_left = pos2(port_left, port_height); draw_port( @@ -447,7 +465,7 @@ where AnyParameterId::Input(*param), self.port_locations, self.ongoing_drag, - self.graph.connection(*param), + unique_connection, ); } } @@ -458,6 +476,16 @@ where .iter() .zip(output_port_heights.into_iter()) { + let unique_connection = if !self.graph.get_output(*param).typ.splittable() { + self.graph + .outgoing(*param) + .first() + .copied() + .map(AnyParameterId::Input) + } else { + None + }; + let pos_right = pos2(port_right, port_height); draw_port( ui, @@ -468,7 +496,7 @@ where AnyParameterId::Output(*param), self.port_locations, self.ongoing_drag, - None, + unique_connection, ); } diff --git a/egui_node_graph/src/graph.rs b/egui_node_graph/src/graph.rs index 32301d7..359c254 100644 --- a/egui_node_graph/src/graph.rs +++ b/egui_node_graph/src/graph.rs @@ -85,7 +85,10 @@ pub struct Graph { pub inputs: SlotMap>, /// The [`OutputParam`]s of the graph pub outputs: SlotMap>, - // Connects the input of a node, to the output of its predecessor that + // Connects the input of a node, to the output(s) of its predecessor(s) that // produces it - pub connections: SecondaryMap, + pub incoming: SecondaryMap>, + // Connects the outputs of a node, to the input(s) of its predecessor(s) that + // consumes it + pub outgoing: SecondaryMap>, } diff --git a/egui_node_graph/src/graph_impls.rs b/egui_node_graph/src/graph_impls.rs index 7f898a6..7eb5111 100644 --- a/egui_node_graph/src/graph_impls.rs +++ b/egui_node_graph/src/graph_impls.rs @@ -1,12 +1,16 @@ use super::*; -impl Graph { +impl Graph +where + DataType: DataTypeTrait, +{ pub fn new() -> Self { Self { nodes: SlotMap::default(), inputs: SlotMap::default(), outputs: SlotMap::default(), - connections: SecondaryMap::default(), + incoming: SecondaryMap::default(), + outgoing: SecondaryMap::default(), } } @@ -64,21 +68,38 @@ impl Graph { } pub fn remove_node(&mut self, node_id: NodeId) { - self.connections - .retain(|i, o| !(self.outputs[*o].node == node_id || self.inputs[i].node == node_id)); let inputs: SVec<_> = self[node_id].input_ids().collect(); for input in inputs { - self.inputs.remove(input); + self.remove_incoming_connections(input); } let outputs: SVec<_> = self[node_id].output_ids().collect(); for output in outputs { - self.outputs.remove(output); + self.remove_outgoing_connections(output); } self.nodes.remove(node_id); } - pub fn remove_connection(&mut self, input_id: InputId) -> Option { - self.connections.remove(input_id) + pub fn remove_connection(&mut self, output_id: OutputId, input_id: InputId) { + self.outgoing[output_id].retain(|&mut x| x != input_id); + self.incoming[input_id].retain(|&mut x| x != output_id); + } + + pub fn remove_incoming_connections(&mut self, input_id: InputId) { + if let Some(outputs) = self.incoming.get(input_id) { + for &output in outputs { + self.outgoing[output].retain(|&mut x| x != input_id); + } + } + self.incoming.remove(input_id); + } + + pub fn remove_outgoing_connections(&mut self, output_id: OutputId) { + if let Some(inputs) = self.outgoing.get(output_id) { + for &input in inputs { + self.incoming[input].retain(|&mut x| x != output_id); + } + } + self.outgoing.remove(output_id); } pub fn iter_nodes(&self) -> impl Iterator + '_ { @@ -86,15 +107,51 @@ impl Graph { } pub fn add_connection(&mut self, output: OutputId, input: InputId) { - self.connections.insert(input, output); + if self.get_input(input).typ.mergeable() { + self.incoming + .entry(input) + .expect("Old InputId") + .or_default() + .push(output); + } else { + self.remove_incoming_connections(input); + let mut v = SVec::new(); + v.push(output); + self.incoming.insert(input, v); + } + + if self.get_output(output).typ.splittable() { + self.outgoing + .entry(output) + .expect("Old OutputId") + .or_default() + .push(input); + } else { + self.remove_outgoing_connections(output); + let mut v = SVec::new(); + v.push(input); + self.outgoing.insert(output, v); + } } pub fn iter_connections(&self) -> impl Iterator + '_ { - self.connections.iter().map(|(o, i)| (o, *i)) + self.incoming + .iter() + .flat_map(|(o, inputs)| inputs.iter().map(move |&i| (o, i))) + } + + pub fn incoming(&self, input: InputId) -> &[OutputId] { + self.incoming + .get(input) + .map(|x| x.as_slice()) + .unwrap_or(&[]) } - pub fn connection(&self, input: InputId) -> Option { - self.connections.get(input).copied() + pub fn outgoing(&self, output: OutputId) -> &[InputId] { + self.outgoing + .get(output) + .map(|x| x.as_slice()) + .unwrap_or(&[]) } pub fn any_param_type(&self, param: AnyParameterId) -> Result<&DataType, EguiGraphError> { @@ -114,21 +171,23 @@ impl Graph { } } -impl Default for Graph { +impl Default + for Graph +{ fn default() -> Self { Self::new() } } impl Node { - pub fn inputs<'a, DataType, DataValue>( + pub fn inputs<'a, DataType: DataTypeTrait, DataValue>( &'a self, graph: &'a Graph, ) -> impl Iterator> + 'a { self.input_ids().map(|id| graph.get_input(id)) } - pub fn outputs<'a, DataType, DataValue>( + pub fn outputs<'a, DataType: DataTypeTrait, DataValue>( &'a self, graph: &'a Graph, ) -> impl Iterator> + 'a { diff --git a/egui_node_graph/src/traits.rs b/egui_node_graph/src/traits.rs index 5be8908..6cf5aea 100644 --- a/egui_node_graph/src/traits.rs +++ b/egui_node_graph/src/traits.rs @@ -16,6 +16,16 @@ pub trait DataTypeTrait: PartialEq + Eq { // The name of this datatype fn name(&self) -> &str; + + /// Whether an output of this datatype can be sent to multiple nodes + fn splittable(&self) -> bool { + true + } + + /// Whether an input of this datatype can be recieved from multiple nodes + fn mergeable(&self) -> bool { + false + } } /// This trait must be implemented for the `NodeData` generic parameter of the diff --git a/egui_node_graph/src/ui_state.rs b/egui_node_graph/src/ui_state.rs index a2ccea1..8b6e2fb 100644 --- a/egui_node_graph/src/ui_state.rs +++ b/egui_node_graph/src/ui_state.rs @@ -31,7 +31,7 @@ pub struct GraphEditorState +impl GraphEditorState { pub fn new(default_zoom: f32, user_state: UserState) -> Self { diff --git a/egui_node_graph_example/src/app.rs b/egui_node_graph_example/src/app.rs index be4ea3f..888174b 100644 --- a/egui_node_graph_example/src/app.rs +++ b/egui_node_graph_example/src/app.rs @@ -515,7 +515,7 @@ fn evaluate_input( let input_id = graph[node_id].get_input(param_name)?; // The output of another node is connected. - if let Some(other_output_id) = graph.connection(input_id) { + if let Some(&other_output_id) = graph.incoming(input_id).first() { // The value was already computed due to the evaluation of some other // node. We simply return value from the cache. if let Some(other_value) = outputs_cache.get(&other_output_id) {