diff --git a/.DS_Store b/.DS_Store deleted file mode 100644 index f1112a46a..000000000 Binary files a/.DS_Store and /dev/null differ diff --git a/mixnet/node/Cargo.toml b/mixnet/node/Cargo.toml index 37601022a..c3fc05279 100644 --- a/mixnet/node/Cargo.toml +++ b/mixnet/node/Cargo.toml @@ -4,9 +4,10 @@ version = "0.1.0" edition = "2021" [dependencies] +crossbeam-skiplist = "0.1" serde = { version = "1.0", features = ["derive"] } tracing = "0.1.37" -tokio = { version = "1.32", features = ["net", "time"] } +tokio = { version = "1.32", features = ["net", "time", "signal"] } thiserror = "1" sphinx-packet = "0.1.0" nym-sphinx = { package = "nym-sphinx", git = "https://github.com/nymtech/nym", tag = "v1.1.22" } diff --git a/mixnet/node/src/config.rs b/mixnet/node/src/config.rs index 6d325c9b8..50b1061f1 100644 --- a/mixnet/node/src/config.rs +++ b/mixnet/node/src/config.rs @@ -7,7 +7,7 @@ use nym_sphinx::{PrivateKey, PublicKey}; use serde::{Deserialize, Serialize}; use sphinx_packet::crypto::{PRIVATE_KEY_SIZE, PUBLIC_KEY_SIZE}; -#[derive(Serialize, Deserialize, Clone, Debug)] +#[derive(Serialize, Deserialize, Copy, Clone, Debug)] pub struct MixnetNodeConfig { /// A listen address for receiving Sphinx packets pub listen_address: SocketAddr, diff --git a/mixnet/node/src/lib.rs b/mixnet/node/src/lib.rs index 613376898..adb8fc18f 100644 --- a/mixnet/node/src/lib.rs +++ b/mixnet/node/src/lib.rs @@ -1,16 +1,15 @@ mod client_notifier; pub mod config; -use std::{net::SocketAddr, time::Duration}; +use std::{collections::HashMap, net::SocketAddr, time::Duration}; use client_notifier::ClientNotifier; pub use config::MixnetNodeConfig; use mixnet_protocol::{Body, ProtocolError}; use mixnet_topology::MixnetNodeId; -use mixnet_util::ConnectionPool; use nym_sphinx::{ addressing::nodes::{NymNodeRoutingAddress, NymNodeRoutingAddressError}, - Delay, DestinationAddressBytes, NodeAddressBytes, Payload, PrivateKey, + Delay, DestinationAddressBytes, NodeAddressBytes, PrivateKey, }; pub use sphinx_packet::crypto::PRIVATE_KEY_SIZE; use sphinx_packet::{crypto::PUBLIC_KEY_SIZE, ProcessedPacket, SphinxPacket}; @@ -28,7 +27,9 @@ pub enum MixnetNodeError { #[error("invalid routing address: {0}")] InvalidRoutingAddress(#[from] NymNodeRoutingAddressError), #[error("send error: {0}")] - SendError(#[from] tokio::sync::mpsc::error::TrySendError), + MessageSendError(#[from] tokio::sync::mpsc::error::SendError), + #[error("send error: fail to send {0} to client")] + ClientSendError(#[from] tokio::sync::mpsc::error::TrySendError), #[error("client: {0}")] Client(ProtocolError), } @@ -36,13 +37,11 @@ pub enum MixnetNodeError { // A mix node that routes packets in the Mixnet. pub struct MixnetNode { config: MixnetNodeConfig, - pool: ConnectionPool, } impl MixnetNode { pub fn new(config: MixnetNodeConfig) -> Self { - let pool = ConnectionPool::new(config.connection_pool_size); - Self { config, pool } + Self { config } } pub fn id(&self) -> MixnetNodeId { @@ -77,205 +76,279 @@ impl MixnetNode { self.config.listen_address ); + let (tx, rx) = mpsc::unbounded_channel(); + + let message_handler = MessageHandler::new(tx.clone(), rx, self.config); + + tokio::spawn(async move { + message_handler.run().await; + }); + + let runner = MixnetNodeRunner { + config: self.config, + client_tx, + message_tx: tx, + }; + loop { - match listener.accept().await { - Ok((socket, remote_addr)) => { - tracing::debug!("Accepted incoming connection from {remote_addr:?}"); - - let client_tx = client_tx.clone(); - let private_key = self.config.private_key; - let pool = self.pool.clone(); - tokio::spawn(async move { - if let Err(e) = Self::handle_connection( - socket, - self.config.max_retries, - self.config.retry_delay, - pool, - private_key, - client_tx, - ) - .await - { - tracing::error!("failed to handle conn: {e}"); + tokio::select! { + res = listener.accept() => { + match res { + Ok((socket, remote_addr)) => { + tracing::debug!("Accepted incoming connection from {remote_addr:?}"); + + let runner = runner.clone(); + tokio::spawn(async move { + if let Err(e) = runner.handle_connection(socket).await { + tracing::error!("failed to handle conn: {e}"); + } + }); } - }); + Err(e) => tracing::warn!("Failed to accept incoming connection: {e}"), + } + } + _ = tokio::signal::ctrl_c() => { + tracing::info!("Shutting down..."); + return Ok(()); } - Err(e) => tracing::warn!("Failed to accept incoming connection: {e}"), } } } +} - async fn handle_connection( - mut socket: TcpStream, - max_retries: usize, - retry_delay: Duration, - pool: ConnectionPool, - private_key: [u8; PRIVATE_KEY_SIZE], - client_tx: mpsc::Sender, - ) -> Result<()> { +#[derive(Clone)] +struct MixnetNodeRunner { + config: MixnetNodeConfig, + client_tx: mpsc::Sender, + message_tx: mpsc::UnboundedSender, +} + +impl MixnetNodeRunner { + async fn handle_connection(&self, mut socket: TcpStream) -> Result<()> { loop { let body = Body::read(&mut socket).await?; - - let pool = pool.clone(); - let private_key = PrivateKey::from(private_key); - let client_tx = client_tx.clone(); - + let this = self.clone(); tokio::spawn(async move { - if let Err(e) = Self::handle_body( - max_retries, - retry_delay, - body, - &pool, - &private_key, - &client_tx, - ) - .await - { + if let Err(e) = this.handle_body(body).await { tracing::error!("failed to handle body: {e}"); } }); } } - // TODO: refactor this fn to make it receive less arguments - #[allow(clippy::too_many_arguments)] - async fn handle_body( - max_retries: usize, - retry_delay: Duration, - body: Body, - pool: &ConnectionPool, - private_key: &PrivateKey, - client_tx: &mpsc::Sender, - ) -> Result<()> { - match body { - Body::SphinxPacket(packet) => { - Self::handle_sphinx_packet(pool, max_retries, retry_delay, private_key, packet) - .await - } + async fn handle_body(&self, msg: Body) -> Result<()> { + match msg { + Body::SphinxPacket(packet) => self.handle_sphinx_packet(packet).await, Body::FinalPayload(payload) => { - Self::forward_body_to_client_notifier( - private_key, - client_tx, - Body::FinalPayload(payload), - ) - .await + self.forward_body_to_client_notifier(Body::FinalPayload(payload)) + .await } _ => unreachable!(), } } - async fn handle_sphinx_packet( - pool: &ConnectionPool, - max_retries: usize, - retry_delay: Duration, - private_key: &PrivateKey, - packet: Box, - ) -> Result<()> { + async fn handle_sphinx_packet(&self, packet: Box) -> Result<()> { match packet - .process(private_key) + .process(&PrivateKey::from(self.config.private_key)) .map_err(ProtocolError::InvalidSphinxPacket)? { ProcessedPacket::ForwardHop(packet, next_node_addr, delay) => { - Self::forward_packet_to_next_hop( - pool, - max_retries, - retry_delay, - packet, - next_node_addr, - delay, - ) - .await + self.forward_packet_to_next_hop(Body::SphinxPacket(packet), next_node_addr, delay) + .await } ProcessedPacket::FinalHop(destination_addr, _, payload) => { - Self::forward_payload_to_destination( - pool, - max_retries, - retry_delay, - payload, - destination_addr, - ) - .await + self.forward_payload_to_destination(Body::FinalPayload(payload), destination_addr) + .await } } } - async fn forward_body_to_client_notifier( - _private_key: &PrivateKey, - client_tx: &mpsc::Sender, - body: Body, - ) -> Result<()> { + async fn forward_body_to_client_notifier(&self, body: Body) -> Result<()> { // TODO: Decrypt the final payload using the private key, if it's encrypted // Do not wait when the channel is full or no receiver exists - client_tx.try_send(body)?; + self.client_tx.try_send(body)?; Ok(()) } async fn forward_packet_to_next_hop( - pool: &ConnectionPool, - max_retries: usize, - retry_delay: Duration, - packet: Box, + &self, + packet: Body, next_node_addr: NodeAddressBytes, delay: Delay, ) -> Result<()> { tracing::debug!("Delaying the packet for {delay:?}"); tokio::time::sleep(delay.to_duration()).await; - Self::forward( - pool, - max_retries, - retry_delay, - Body::new_sphinx(packet), - NymNodeRoutingAddress::try_from(next_node_addr)?, - ) - .await + self.forward(packet, NymNodeRoutingAddress::try_from(next_node_addr)?) + .await } async fn forward_payload_to_destination( - pool: &ConnectionPool, - max_retries: usize, - retry_delay: Duration, - payload: Payload, + &self, + payload: Body, destination_addr: DestinationAddressBytes, ) -> Result<()> { tracing::debug!("Forwarding final payload to destination mixnode"); - Self::forward( - pool, - max_retries, - retry_delay, - Body::new_final_payload(payload), + self.forward( + payload, NymNodeRoutingAddress::try_from_bytes(&destination_addr.as_bytes())?, ) .await } - async fn forward( - pool: &ConnectionPool, - max_retries: usize, - retry_delay: Duration, - body: Body, - to: NymNodeRoutingAddress, - ) -> Result<()> { + async fn forward(&self, msg: Body, to: NymNodeRoutingAddress) -> Result<()> { let addr = SocketAddr::from(to); - let arc_socket = pool.get_or_init(&addr).await?; - - if let Err(e) = { - let mut socket = arc_socket.lock().await; - body.write(&mut *socket).await - } { - tracing::error!("Failed to forward packet to {addr} with error: {e}. Retrying..."); - return mixnet_protocol::retry_backoff( - addr, - max_retries, - retry_delay, - body, - arc_socket, - ) - .await - .map_err(Into::into); - } + + self.message_tx.send(TargetedMessage::new(addr, msg))?; Ok(()) } } + +struct MessageHandler { + // TODO: remove this allow when we implement the retry logic + #[allow(dead_code)] + config: MixnetNodeConfig, + message_rx: mpsc::UnboundedReceiver, + message_tx: mpsc::UnboundedSender, + connections: HashMap, +} + +impl MessageHandler { + pub fn new( + message_tx: mpsc::UnboundedSender, + message_rx: mpsc::UnboundedReceiver, + config: MixnetNodeConfig, + ) -> Self { + Self { + message_tx, + message_rx, + connections: HashMap::with_capacity(config.connection_pool_size), + config, + } + } + + pub async fn run(mut self) { + loop { + tokio::select! { + msg = self.message_rx.recv() => { + if let Some(msg) = msg { + match self.send(msg).await { + Ok(Some(msg)) => { + let _ = self.message_tx.send(msg); + } + Ok(None) => {}, + Err(e) => { + tracing::error!("failed to send msg: {e}"); + }, + } + } else { + // Channel closed, we should shutdown the message handler thread + return; + } + }, + _ = tokio::signal::ctrl_c() => { + tracing::info!("Shutting down message handler thread..."); + return; + } + } + } + } + + /// Send a message to the remote node, + /// return Ok(Some((Duration, Message))) if the message is not sent and the error is retryable + /// return Ok(None) if the message is sent successfully + /// return Err(e) if the message is not sent and the error is not retryable + async fn send(&mut self, mut msg: TargetedMessage) -> Result> { + if msg.retry_count >= self.config.max_retries { + return Err(MixnetNodeError::Protocol(ProtocolError::ReachMaxRetries( + self.config.max_retries, + ))); + } + use std::io::ErrorKind; + + if let std::collections::hash_map::Entry::Vacant(e) = self.connections.entry(msg.target) { + match TcpStream::connect(msg.target).await { + Ok(tcp) => { + e.insert(tcp); + } + Err(e) => { + tracing::error!("failed to connect to {}: {e}", msg.target); + return Ok(Some(msg)); + } + } + } + + let tcp = self.connections.get_mut(&msg.target).unwrap(); + + if msg.retry_count > 0 { + let wait = Duration::from_millis( + (self.config.retry_delay.as_millis() as u64).pow(msg.retry_count as u32), + ); + tokio::time::sleep(wait).await; + } + + match msg.body.write(tcp).await { + Ok(_) => Ok(None), + // we should only retry io errors (exclude unsupported for now, may be more in future) + Err(ProtocolError::IO(e)) if e.kind() != ErrorKind::Unsupported => { + // Update the connection, I actully do not want to do it unless the connection is in broken + // situation, but rust does not provide a method to let us check if the connection is broken + // or not, so I just hard code the possible situations worth to refresh. + if matches!( + e.kind(), + ErrorKind::ConnectionAborted + | ErrorKind::ConnectionReset + | ErrorKind::ConnectionRefused + | ErrorKind::NotConnected + | ErrorKind::BrokenPipe + | ErrorKind::TimedOut + ) { + match TcpStream::connect(msg.target).await { + Ok(fresh_tcp) => { + *tcp = fresh_tcp; + } + Err(e) => { + tracing::error!("failed to update connection to {}, the local machine network is down: {e}", msg.target); + return Err(MixnetNodeError::Protocol(ProtocolError::IO(e))); + } + } + } + + if msg.retry_count < self.config.max_retries { + msg.retry_count += 1; + return Ok(Some(msg)); + } + + tracing::error!( + "failed to forward msg to {}: reach the maximum retries", + msg.target, + ); + Err(MixnetNodeError::Protocol(ProtocolError::ReachMaxRetries( + self.config.max_retries, + ))) + } + Err(e) => { + tracing::error!("failed to forward msg to {}: {e}", msg.target); + Err(MixnetNodeError::Protocol(e)) + } + } + } +} + +pub struct TargetedMessage { + target: SocketAddr, + body: Body, + retry_count: usize, +} + +impl TargetedMessage { + fn new(target: SocketAddr, body: Body) -> Self { + Self { + target, + body, + retry_count: 0, + } + } +} diff --git a/mixnet/protocol/Cargo.toml b/mixnet/protocol/Cargo.toml index e21c7ff48..bdbf5c726 100644 --- a/mixnet/protocol/Cargo.toml +++ b/mixnet/protocol/Cargo.toml @@ -6,7 +6,7 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -tokio = { version = "1.32", features = ["sync", "net"] } +tokio = { version = "1.32", features = ["sync", "net", "time"] } sphinx-packet = "0.1.0" futures = "0.3" tokio-util = { version = "0.7", features = ["io", "io-util"] } diff --git a/mixnet/protocol/src/lib.rs b/mixnet/protocol/src/lib.rs index 459eac7bc..af6fdf7c5 100644 --- a/mixnet/protocol/src/lib.rs +++ b/mixnet/protocol/src/lib.rs @@ -1,6 +1,7 @@ use sphinx_packet::{payload::Payload, SphinxPacket}; use std::{io::ErrorKind, net::SocketAddr, sync::Arc, time::Duration}; + use tokio::{ io::{self, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}, net::TcpStream, @@ -122,7 +123,6 @@ pub async fn retry_backoff( for idx in 0..max_retries { // backoff let wait = Duration::from_millis((retry_delay.as_millis() as u64).pow(idx as u32)); - tokio::time::sleep(wait).await; let mut socket = socket.lock().await; match body.write(&mut *socket).await { diff --git a/mixnet/util/Cargo.toml b/mixnet/util/Cargo.toml index 7ab677749..66b323f0b 100644 --- a/mixnet/util/Cargo.toml +++ b/mixnet/util/Cargo.toml @@ -5,5 +5,4 @@ edition = "2021" [dependencies] tokio = { version = "1.32", default-features = false, features = ["sync", "net"] } -parking_lot = { version = "0.12", features = ["send_guard"] } -mixnet-protocol = { path = "../protocol" } \ No newline at end of file +mixnet-protocol = { path = "../protocol" } diff --git a/mixnet/util/src/lib.rs b/mixnet/util/src/lib.rs index e00f3714b..a8431eb21 100644 --- a/mixnet/util/src/lib.rs +++ b/mixnet/util/src/lib.rs @@ -1,7 +1,6 @@ use std::{collections::HashMap, net::SocketAddr, sync::Arc}; -use tokio::net::TcpStream; -use tokio::sync::Mutex; +use tokio::{net::TcpStream, sync::Mutex}; #[derive(Clone)] pub struct ConnectionPool {