diff --git a/egui_node_graph2/src/editor_ui.rs b/egui_node_graph2/src/editor_ui.rs index 1feff64..28736e6 100644 --- a/egui_node_graph2/src/editor_ui.rs +++ b/egui_node_graph2/src/editor_ui.rs @@ -1,4 +1,5 @@ use std::collections::HashSet; +use std::num::NonZeroU32; use crate::color_hex_utils::*; use crate::utils::ColorUtils; @@ -7,7 +8,20 @@ use super::*; use egui::epaint::{CubicBezierShape, RectShape}; use egui::*; -pub type PortLocations = std::collections::HashMap; +/// Mapping from parameter id to positions of hooks it contains. +/// +/// Outputs and short inputs always only have one hook, so the value is +/// just `vec![port_position]`. Wide inputs may have multiple hooks. +pub type PortLocations = std::collections::HashMap>; + +/// Destination positions of connections made to a given input. +/// +/// This is not equivalent to [`PortLocations`] because connections may be moved +/// around (e.g. while an in-progress connection is hovered over a wide port), +/// while hooks within a port are strictly a function of the port. +pub type ConnLocations = std::collections::HashMap>; + +/// Rectangle containing each node. pub type NodeRects = std::collections::HashMap; const DISTANCE_TO_CONNECT: f32 = 10.0; @@ -21,6 +35,10 @@ pub enum NodeResponse ConnectEventEnded { output: OutputId, input: InputId, + /// Index of the connection in wide input ports. + /// + /// If the input isn't a wide port this is always 0 and may be ignored. + input_hook: usize, }, CreatedNode(NodeId), SelectNode(NodeId), @@ -76,6 +94,7 @@ pub struct GraphNodeWidget<'a, NodeData, DataType, ValueType> { pub position: &'a mut Pos2, pub graph: &'a mut Graph, pub port_locations: &'a mut PortLocations, + pub conn_locations: &'a mut ConnLocations, pub node_rects: &'a mut NodeRects, pub node_id: NodeId, pub ongoing_drag: Option<(NodeId, AnyParameterId)>, @@ -194,6 +213,9 @@ where let mut port_locations = PortLocations::new(); let mut node_rects = NodeRects::new(); + // actual dest location of each connection + let mut conn_locations = ConnLocations::default(); + // The responses returned from node drawing have side effects that are best // executed at the end of this function. let mut delayed_responses: Vec> = prepend_responses; @@ -229,6 +251,7 @@ where position: self.node_positions.get_mut(node_id).unwrap(), graph: &mut self.graph, port_locations: &mut port_locations, + conn_locations: &mut conn_locations, node_rects: &mut node_rects, node_id, ongoing_drag: self.connection_in_progress, @@ -282,11 +305,13 @@ where self.node_finder = None; } - /* Draw connections */ + // draw in-progress connections if let Some((_, ref locator)) = self.connection_in_progress { let port_type = self.graph.any_param_type(*locator).unwrap(); let connection_color = port_type.data_type_color(user_state); - let start_pos = port_locations[locator]; + + // outputs can't be wide yet so this is fine. + let start_pos = *port_locations[locator].last().unwrap(); // Find a port to connect to fn snap_to_ports< @@ -312,12 +337,19 @@ where .unwrap_or(false); if compatible_ports { - port_locations.get(&port_id.into()).and_then(|port_pos| { - if port_pos.distance(cursor_pos) < DISTANCE_TO_CONNECT { - Some(*port_pos) - } else { - None - } + port_locations.get(&port_id.into()).and_then(|hooks| { + hooks + .iter() + .min_by(|hook1, hook2| { + hook1 + .distance(cursor_pos) + .partial_cmp(&hook2.distance(cursor_pos)) + .unwrap() + }) + .filter(|nearest_hook| { + nearest_hook.distance(cursor_pos) < DISTANCE_TO_CONNECT + }) + .copied() }) } else { None @@ -357,21 +389,19 @@ where ); } - for (input, output) in self.graph.iter_connections() { - let port_type = self - .graph - .any_param_type(AnyParameterId::Output(output)) - .unwrap(); - let connection_color = port_type.data_type_color(user_state); - let src_pos = port_locations[&AnyParameterId::Output(output)]; - let dst_pos = port_locations[&AnyParameterId::Input(input)]; - draw_connection( - &self.pan_zoom, - ui.painter(), - src_pos, - dst_pos, - connection_color, - ); + // draw existing connections + for (input, outputs) in self.graph.iter_connection_groups() { + for (hook_n, &output) in outputs.iter().enumerate() { + let port_type = self + .graph + .any_param_type(AnyParameterId::Output(output)) + .unwrap(); + let connection_color = port_type.data_type_color(user_state); + // outputs can't be wide yet so this is fine. + let src_pos = port_locations[&AnyParameterId::Output(output)][0]; + let dst_pos = conn_locations[&input][hook_n]; + draw_connection(&self.pan_zoom, ui.painter(), src_pos, dst_pos, connection_color); + } } /* Handle responses from drawing nodes */ @@ -385,9 +415,11 @@ where NodeResponse::ConnectEventStarted(node_id, port) => { self.connection_in_progress = Some((*node_id, *port)); } - NodeResponse::ConnectEventEnded { input, output } => { - self.graph.add_connection(*output, *input) - } + NodeResponse::ConnectEventEnded { + output, + input, + input_hook, + } => self.graph.add_connection(*output, *input, *input_hook), NodeResponse::CreatedNode(_) => { //Convenience NodeResponse for users } @@ -396,6 +428,7 @@ where } NodeResponse::DeleteNodeUi(node_id) => { let (node, disc_events) = self.graph.remove_node(*node_id); + // Pass the disconnection responses first so user code can perform cleanup // before node removal response. extra_responses.extend( @@ -416,7 +449,7 @@ where } NodeResponse::DisconnectEvent { input, output } => { let other_node = self.graph.get_output(*output).node; - self.graph.remove_connection(*input); + self.graph.remove_connection(*input, *output); self.connection_in_progress = Some((other_node, AnyParameterId::Output(*output))); } @@ -673,38 +706,62 @@ where for (param_name, param_id) in inputs { if self.graph[param_id].shown_inline { let height_before = ui.min_rect().bottom(); - // NOTE: We want to pass the `user_data` to - // `value_widget`, but we can't since that would require - // borrowing the graph twice. Here, we make the - // assumption that the value is cheaply replaced, and - // use `std::mem::take` to temporarily replace it with a - // dummy value. This requires `ValueType` to implement - // Default, but results in a totally safe alternative. - let mut value = std::mem::take(&mut self.graph[param_id].value); - - if self.graph.connection(param_id).is_some() { - let node_responses = value.value_widget_connected( - ¶m_name, - self.node_id, - ui, - user_state, - &self.graph[self.node_id].user_data, - ); - - responses.extend(node_responses.into_iter().map(NodeResponse::User)); + + if self.graph[param_id].max_connections == NonZeroU32::new(1) { + // NOTE: We want to pass the `user_data` to + // `value_widget`, but we can't since that would require + // borrowing the graph twice. Here, we make the + // assumption that the value is cheaply replaced, and + // use `std::mem::take` to temporarily replace it with a + // dummy value. This requires `ValueType` to implement + // Default, but results in a totally safe alternative. + let mut value = std::mem::take(&mut self.graph[param_id].value); + + if !self.graph.connections(param_id).is_empty() { + let node_responses = value.value_widget_connected( + ¶m_name, + self.node_id, + ui, + user_state, + &self.graph[self.node_id].user_data, + ); + + responses.extend(node_responses.into_iter().map(NodeResponse::User)); + } else { + let node_responses = value.value_widget( + ¶m_name, + self.node_id, + ui, + user_state, + &self.graph[self.node_id].user_data, + ); + + responses.extend(node_responses.into_iter().map(NodeResponse::User)); + } + + self.graph[param_id].value = value; } else { - let node_responses = value.value_widget( - ¶m_name, - self.node_id, - ui, - user_state, - &self.graph[self.node_id].user_data, - ); - - responses.extend(node_responses.into_iter().map(NodeResponse::User)); + ui.label(param_name); } - self.graph[param_id].value = value; + let height_intermediate = ui.min_rect().bottom(); + + let max_connections = self.graph[param_id] + .max_connections + .map(NonZeroU32::get) + .unwrap_or(std::u32::MAX) + as usize; + let port_height = port_height( + max_connections != 1, + self.graph.connections(param_id).len(), + max_connections, + ); + let margin = 5.0; + let missing_space = + port_height - (height_intermediate - height_before) + margin; + if missing_space > 0.0 { + ui.add_space(missing_space); + } self.graph[self.node_id].user_data.separator( ui, @@ -715,6 +772,7 @@ where ); let height_after = ui.min_rect().bottom(); + input_port_heights.push((height_before + height_after) / 2.0); } } @@ -762,6 +820,17 @@ where .insert_temp(child_ui.id(), OuterRectMemory(outer_rect)) }); + fn port_height(wide_port: bool, connections: usize, max_connections: usize) -> f32 { + let port_full = connections == max_connections; + if wide_port { + let hooks = connections + if port_full { 0 } else { 1 }; + + 5.0 + (10.0 * hooks as f32).max(10.0) + } else { + 10.0 + } + } + #[allow(clippy::too_many_arguments)] fn draw_port( pan_zoom: &PanZoom, @@ -773,8 +842,11 @@ where responses: &mut Vec>, param_id: AnyParameterId, port_locations: &mut PortLocations, + conn_locations: &mut ConnLocations, ongoing_drag: Option<(NodeId, AnyParameterId)>, - is_connected_input: bool, + wide_port: bool, + connections: usize, + max_connections: usize, ) where DataType: DataTypeTrait, UserResponse: UserResponseTrait, @@ -782,8 +854,29 @@ where { let port_type = graph.any_param_type(param_id).unwrap(); - let port_rect = - Rect::from_center_size(port_pos, egui::vec2(10.0, 10.0) * pan_zoom.zoom); + let port_rect = Rect::from_center_size( + port_pos, + egui::vec2(10.0, port_height(wide_port, connections, max_connections)) * pan_zoom.zoom, + ); + + let port_full = connections == max_connections; + + let inner_ports = if wide_port { + connections + if port_full { 0 } else { 1 } + } else { + 1 + }; + + port_locations.insert( + param_id, + (0..inner_ports) + .map(|k| { + port_rect.center_top() + + Vec2::new(0.0, 5.0) + + Vec2::new(0.0, 10.0) * k as f32 + }) + .collect(), + ); let sense = if ongoing_drag.is_some() { Sense::hover() @@ -805,47 +898,90 @@ where } else { port_type.data_type_color(user_state) }; - ui.painter().circle( - port_rect.center(), - 5.0 * pan_zoom.zoom, - port_color, - Stroke::NONE, - ); + + if wide_port { + ui.painter().rect_filled(port_rect, 5.0 * pan_zoom.zoom, port_color); + } else { + ui.painter() + .circle(port_rect.center(), 5.0 * pan_zoom.zoom, port_color, Stroke::NONE); + } + + if connections > 0 { + if let AnyParameterId::Input(input) = param_id { + for (k, dst_pos) in port_locations[&AnyParameterId::Input(input)] + .iter() + .enumerate() + { + conn_locations.entry(input).or_default().insert(k, *dst_pos); + } + } + } + + let nearest_hook = ui + .input(|in_state| in_state.pointer.hover_pos()) + .and_then(|mouse_pos| match param_id { + AnyParameterId::Input(input) => Some((mouse_pos, input)), + AnyParameterId::Output(_) => None, + }) + .and_then(|(mouse_pos, input)| { + let hooks = 0..inner_ports; + hooks.min_by(|&hook1, &hook2| { + let out1_dist = conn_locations[&input][hook1].distance(mouse_pos); + let out2_dist = conn_locations[&input][hook2].distance(mouse_pos); + + out1_dist.partial_cmp(&out2_dist).unwrap() + }) + }); if resp.drag_started() { - if is_connected_input { - let input = param_id.assume_input(); - let corresp_output = graph - .connection(input) - .expect("Connection data should be valid"); - responses.push(NodeResponse::DisconnectEvent { - input: param_id.assume_input(), - output: corresp_output, - }); - } else { - responses.push(NodeResponse::ConnectEventStarted(node_id, param_id)); + match param_id { + AnyParameterId::Input(input) => { + match nearest_hook + .and_then(|hook| graph.connections(input).get(hook).copied()) + { + Some(output) => { + responses.push(NodeResponse::DisconnectEvent { input, output }); + } + None => { + responses + .push(NodeResponse::ConnectEventStarted(node_id, param_id)); + } + } + } + AnyParameterId::Output(_) => { + responses.push(NodeResponse::ConnectEventStarted(node_id, param_id)); + } } } if let Some((origin_node, origin_param)) = ongoing_drag { if origin_node != node_id { // Don't allow self-loops - if graph.any_param_type(origin_param).unwrap() == port_type - && close_enough - && ui.input(|i| i.pointer.any_released()) - { + if graph.any_param_type(origin_param).unwrap() == port_type && close_enough { match (param_id, origin_param) { (AnyParameterId::Input(input), AnyParameterId::Output(output)) | (AnyParameterId::Output(output), AnyParameterId::Input(input)) => { - responses.push(NodeResponse::ConnectEventEnded { input, output }); + let input_hook = + nearest_hook.unwrap_or(graph.connections(input).len()); + + if ui.input(|i| i.pointer.any_released()) { + responses.push(NodeResponse::ConnectEventEnded { + output, + input, + input_hook, + }); + } else if wide_port && !port_full { + // move connections below the in-progress one to a lower position + for k in input_hook..graph.connections(input).len() { + conn_locations.get_mut(&input).unwrap()[k].y += 7.5; + } + } } _ => { /* Ignore in-in or out-out connections */ } } } } } - - port_locations.insert(param_id, port_rect.center()); } // Input ports @@ -862,6 +998,10 @@ where if should_draw { let pos_left = pos2(port_left, port_height); + let max_connections = self.graph[*param] + .max_connections + .map(NonZeroU32::get) + .unwrap_or(std::u32::MAX) as usize; draw_port( pan_zoom, ui, @@ -872,8 +1012,11 @@ where &mut responses, AnyParameterId::Input(*param), self.port_locations, + self.conn_locations, self.ongoing_drag, - self.graph.connection(*param).is_some(), + max_connections > 1, + self.graph.connections(*param).len(), + max_connections, ); } } @@ -895,8 +1038,11 @@ where &mut responses, AnyParameterId::Output(*param), self.port_locations, + self.conn_locations, self.ongoing_drag, false, + 0, + 1, ); } diff --git a/egui_node_graph2/src/graph.rs b/egui_node_graph2/src/graph.rs index 32301d7..85c142e 100644 --- a/egui_node_graph2/src/graph.rs +++ b/egui_node_graph2/src/graph.rs @@ -1,3 +1,5 @@ +use std::num::NonZeroU32; + use super::*; #[cfg(feature = "persistence")] @@ -55,6 +57,8 @@ pub struct InputParam { pub kind: InputParamKind, /// Back-reference to the node containing this parameter. pub node: NodeId, + /// How many connections can be made with this input. `None` means no limit. + pub max_connections: Option, /// When true, the node is shown inline inside the node graph. #[cfg_attr(feature = "persistence", serde(default = "shown_inline_default"))] pub shown_inline: bool, @@ -87,5 +91,5 @@ pub struct Graph { pub outputs: SlotMap>, // Connects the input of a node, to the output of its predecessor that // produces it - pub connections: SecondaryMap, + pub connections: SecondaryMap>, } diff --git a/egui_node_graph2/src/graph_impls.rs b/egui_node_graph2/src/graph_impls.rs index 44b5a5d..a117b95 100644 --- a/egui_node_graph2/src/graph_impls.rs +++ b/egui_node_graph2/src/graph_impls.rs @@ -1,3 +1,5 @@ +use std::num::NonZeroU32; + use super::*; impl Graph { @@ -32,13 +34,15 @@ impl Graph { node_id } - pub fn add_input_param( + #[allow(clippy::too_many_arguments)] + pub fn add_wide_input_param( &mut self, node_id: NodeId, name: String, typ: DataType, value: ValueType, kind: InputParamKind, + max_connections: Option, shown_inline: bool, ) -> InputId { let input_id = self.inputs.insert_with_key(|input_id| InputParam { @@ -47,12 +51,33 @@ impl Graph { value, kind, node: node_id, + max_connections, shown_inline, }); self.nodes[node_id].inputs.push((name, input_id)); input_id } + pub fn add_input_param( + &mut self, + node_id: NodeId, + name: String, + typ: DataType, + value: ValueType, + kind: InputParamKind, + shown_inline: bool, + ) -> InputId { + self.add_wide_input_param( + node_id, + name, + typ, + value, + kind, + NonZeroU32::new(1), + shown_inline, + ) + } + pub fn remove_input_param(&mut self, param: InputId) { let node = self[param].node; self[node].inputs.retain(|(_, id)| *id != param); @@ -64,7 +89,9 @@ impl Graph { let node = self[param].node; self[node].outputs.retain(|(_, id)| *id != param); self.outputs.remove(param); - self.connections.retain(|_, o| *o != param); + for (_, conns) in &mut self.connections { + conns.retain(|o| *o != param); + } } pub fn add_output_param(&mut self, node_id: NodeId, name: String, typ: DataType) -> OutputId { @@ -87,14 +114,16 @@ impl Graph { pub fn remove_node(&mut self, node_id: NodeId) -> (Node, Vec<(InputId, OutputId)>) { let mut disconnect_events = vec![]; - self.connections.retain(|i, o| { - if self.outputs[*o].node == node_id || self.inputs[i].node == node_id { - disconnect_events.push((i, *o)); - false - } else { - true - } - }); + for (i, conns) in &mut self.connections { + conns.retain(|o| { + if self.outputs[*o].node == node_id || self.inputs[i].node == node_id { + disconnect_events.push((i, *o)); + false + } else { + true + } + }); + } // NOTE: Collect is needed because we can't borrow the input ids while // we remove them inside the loop. @@ -109,24 +138,72 @@ impl Graph { (removed_node, disconnect_events) } - pub fn remove_connection(&mut self, input_id: InputId) -> Option { - self.connections.remove(input_id) + pub fn remove_connection(&mut self, input_id: InputId, output_id: OutputId) -> bool { + self.connections + .get_mut(input_id) + .map(|conns| { + let old_size = conns.len(); + conns.retain(|id| id != &output_id); + + // connection removed if `conn` size changes + old_size != conns.len() + }) + .unwrap_or(false) } pub fn iter_nodes(&self) -> impl Iterator + '_ { self.nodes.iter().map(|(id, _)| id) } - pub fn add_connection(&mut self, output: OutputId, input: InputId) { - self.connections.insert(input, output); + pub fn add_connection(&mut self, output: OutputId, input: InputId, pos: usize) { + if !self.connections.contains_key(input) { + self.connections.insert(input, Vec::default()); + } + + let max_connections = self + .get_input(input) + .max_connections + .map(NonZeroU32::get) + .unwrap_or(std::u32::MAX) as usize; + let already_in = self.connections[input].contains(&output); + + // connecting twice to the same port is a no-op + // even for wide ports. + if already_in { + return; + } + + if self.connections[input].len() == max_connections { + // if full, replace the connected output + self.connections[input][pos] = output; + } else { + // otherwise, insert at a selected position + self.connections[input].insert(pos, output); + } + } + + pub fn iter_connection_groups(&self) -> impl Iterator)> + '_ { + self.connections.iter().map(|(i, conns)| (i, conns.clone())) } pub fn iter_connections(&self) -> impl Iterator + '_ { - self.connections.iter().map(|(o, i)| (o, *i)) + self.iter_connection_groups() + .flat_map(|(i, conns)| conns.into_iter().map(move |o| (i, o))) + } + + pub fn connections(&self, input: InputId) -> Vec { + self.connections.get(input).cloned().unwrap_or_default() } pub fn connection(&self, input: InputId) -> Option { - self.connections.get(input).copied() + let is_limit_1 = self.get_input(input).max_connections == NonZeroU32::new(1); + let connections = self.connections(input); + + if is_limit_1 && connections.len() == 1 { + connections.into_iter().next() + } else { + None + } } pub fn any_param_type(&self, param: AnyParameterId) -> Result<&DataType, EguiGraphError> {