From aac93d5270c2670231687137f568ba764172209d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Cig=C3=A1nek?= Date: Mon, 2 Sep 2024 14:02:58 +0200 Subject: [PATCH 01/55] Remove unused dependency on quinn-proto --- net/Cargo.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/net/Cargo.toml b/net/Cargo.toml index f56c764c4..800da2760 100644 --- a/net/Cargo.toml +++ b/net/Cargo.toml @@ -12,7 +12,6 @@ bytecodec = "0.4.15" bytes = "1.1.0" futures-util = { workspace = true } quinn = "0.10.2" -quinn-proto = "0.10.2" rand = { package = "ouisync-rand", path = "../rand" } rcgen = { workspace = true } rustls = { workspace = true, features = ["quic", "dangerous_configuration"] } From d1fb87df78ebfeea39b4d17b35b0b548cd177a20 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Cig=C3=A1nek?= Date: Mon, 2 Sep 2024 18:25:00 +0200 Subject: [PATCH 02/55] Bump quinn to v0.11 and rustls to v0.23 --- Cargo.toml | 7 +- bridge/Cargo.toml | 4 +- bridge/src/transport/remote.rs | 35 ++-- bridge/src/transport/tls.rs | 21 +- cli/Cargo.toml | 5 +- cli/src/handler/remote.rs | 13 +- cli/src/metrics.rs | 171 +++++++++++---- cli/src/repository.rs | 1 + cli/src/state.rs | 1 + cli/tests/utils.rs | 6 +- lib/Cargo.toml | 2 +- net/Cargo.toml | 4 +- net/src/quic.rs | 337 +++++++++++++++++++----------- utils/stun-server-list/Cargo.toml | 2 +- 14 files changed, 402 insertions(+), 207 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index e6864fe63..8c1982c32 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -48,10 +48,11 @@ metrics-exporter-prometheus = { version = "0.13.0", default-features = false } metrics-util = { version = "0.16.0", default-features = false } num_enum = { version = "0.7.0", default-features = false } once_cell = "1.18.0" +pin-project-lite = "0.2.13" rand = { package = "ouisync-rand", path = "rand" } -rcgen = { version = "0.11.3", default-features = false } +rcgen = "0.13" rmp-serde = "1.1.0" -rustls = { version = "0.21.0", default-features = false } +rustls = { version = "0.23.5", default-features = false } serde = { version = "1.0", features = ["derive", "rc"] } serde_bytes = "0.11.8" serde_json = "1.0.94" @@ -59,7 +60,7 @@ sqlx = { version = "0.7.4", default-features = false, features = ["runtime-tokio tempfile = "3.2" thiserror = "1.0.49" tokio = { version = "1.38.0", default-features = false } -tokio-rustls = "0.24.1" +tokio-rustls = { version = "0.26", default-features = false } tokio-stream = { version = "0.1.15", default-features = false } tokio-util = "0.7.11" tracing = { version = "0.1.38" } diff --git a/bridge/Cargo.toml b/bridge/Cargo.toml index b110e7db8..662db1048 100644 --- a/bridge/Cargo.toml +++ b/bridge/Cargo.toml @@ -29,11 +29,11 @@ serde_json = { workspace = true } state_monitor = { path = "../state_monitor" } thiserror = { workspace = true } tokio = { workspace = true } -tokio-tungstenite = { version = "0.20.0", features = ["rustls-tls-webpki-roots"] } +tokio-tungstenite = { version = "0.23.1", features = ["rustls-tls-webpki-roots"] } tokio-rustls = { workspace = true } tracing = { workspace = true } tracing-subscriber = { workspace = true, features = ["env-filter", "json"] } -webpki-roots = "0.22.6" +webpki-roots = "0.26.5" [target.'cfg(target_os = "android")'.dependencies] libc = "0.2.147" diff --git a/bridge/src/transport/remote.rs b/bridge/src/transport/remote.rs index dfb7fe27f..e6fd9b927 100644 --- a/bridge/src/transport/remote.rs +++ b/bridge/src/transport/remote.rs @@ -21,7 +21,11 @@ use tokio::{ task::JoinSet, }; use tokio_rustls::{ - rustls::{self, ConnectionCommon}, + rustls::{ + 
        self,
+        pki_types::{CertificateDer, PrivateKeyDer},
+        ConnectionCommon,
+    },
     TlsAcceptor,
 };
 use tokio_tungstenite::{
@@ -32,11 +36,10 @@ use tracing::Instrument;
 
 /// Shared config for `RemoteServer`
 pub fn make_server_config(
-    cert_chain: Vec<rustls::Certificate>,
-    key: rustls::PrivateKey,
+    cert_chain: Vec<CertificateDer<'static>>,
+    key: PrivateKeyDer<'static>,
 ) -> io::Result<Arc<rustls::ServerConfig>> {
     let config = rustls::ServerConfig::builder()
-        .with_safe_defaults()
         .with_no_client_auth()
         .with_single_cert(cert_chain, key)
         .map_err(|error| io::Error::new(io::ErrorKind::InvalidInput, error))?;
@@ -46,28 +49,25 @@ pub fn make_server_config(
 
 /// Shared config for `RemoteClient`
 pub fn make_client_config(
-    additional_root_certs: &[rustls::Certificate],
+    additional_root_certs: &[CertificateDer<'_>],
 ) -> io::Result<Arc<rustls::ClientConfig>> {
     let mut root_cert_store = rustls::RootCertStore::empty();
 
     // Add default root certificates
-    root_cert_store.add_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| {
-        rustls::OwnedTrustAnchor::from_subject_spki_name_constraints(
-            ta.subject,
-            ta.spki,
-            ta.name_constraints,
-        )
-    }));
+    root_cert_store.extend(
+        webpki_roots::TLS_SERVER_ROOTS
+            .iter()
+            .map(|ta| ta.to_owned()),
+    );
 
     // Add custom root certificates (if any)
     for cert in additional_root_certs {
         root_cert_store
-            .add(cert)
+            .add(cert.clone())
             .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?;
     }
 
     let config = rustls::ClientConfig::builder()
-        .with_safe_defaults()
         .with_root_certificates(root_cert_store)
         .with_no_client_auth();
 
@@ -289,6 +289,7 @@ mod tests {
         sync::atomic::{AtomicUsize, Ordering},
     };
     use tokio::task;
+    use tokio_rustls::rustls::pki_types::PrivatePkcs8KeyDer;
 
     #[tokio::test]
     async fn basic() {
@@ -350,10 +351,10 @@ mod tests {
 
     fn make_configs() -> (Arc<rustls::ServerConfig>, Arc<rustls::ClientConfig>) {
         let gen = rcgen::generate_simple_self_signed(["localhost".to_owned()]).unwrap();
-        let cert = rustls::Certificate(gen.serialize_der().unwrap());
-        let key = rustls::PrivateKey(gen.serialize_private_key_der());
+        let cert = CertificateDer::from(gen.cert);
+        let key = PrivatePkcs8KeyDer::from(gen.key_pair.serialize_der());
 
-        let server_config = make_server_config(vec![cert.clone()], key).unwrap();
+        let server_config = make_server_config(vec![cert.clone()], key.into()).unwrap();
         let client_config = make_client_config(&[cert]).unwrap();
 
         (server_config, client_config)
diff --git a/bridge/src/transport/tls.rs b/bridge/src/transport/tls.rs
index 049b1597f..5c95e896a 100644
--- a/bridge/src/transport/tls.rs
+++ b/bridge/src/transport/tls.rs
@@ -2,10 +2,10 @@
 use std::{io, path::Path};
 
 use tokio::fs;
-use tokio_rustls::rustls::{Certificate, PrivateKey};
+use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer};
 
 /// Loads all certificates in the given directory (non-recursively).
-pub async fn load_certificates_from_dir(dir: &Path) -> io::Result<Vec<Certificate>> {
+pub async fn load_certificates_from_dir(dir: &Path) -> io::Result<Vec<CertificateDer<'static>>> {
     let mut read_dir = match fs::read_dir(dir).await {
         Ok(read_dir) => read_dir,
         Err(error) if error.kind() == io::ErrorKind::NotFound => return Ok(Vec::new()),
@@ -33,17 +33,22 @@ pub async fn load_certificates_from_dir(dir: &Path) -> io::Result<Vec<CertificateDer<'static>>>
-pub async fn load_certificates_from_file(path: impl AsRef<Path>) -> io::Result<Vec<Certificate>> {
+pub async fn load_certificates_from_file(
+    path: impl AsRef<Path>,
+) -> io::Result<Vec<CertificateDer<'static>>> {
     load_pems(path.as_ref(), "CERTIFICATE")
         .await
-        .map(|pems| pems.map(Certificate).collect())
+        .map(|pems| pems.map(|content| content.into()).collect())
 }
 
 /// Loads private keys from the given file.
-pub async fn load_keys_from_file(path: impl AsRef<Path>) -> io::Result<Vec<PrivateKey>> {
-    load_pems(path.as_ref(), "PRIVATE KEY")
-        .await
-        .map(|pems| pems.map(PrivateKey).collect())
+pub async fn load_keys_from_file(
+    path: impl AsRef<Path>,
+) -> io::Result<Vec<PrivateKeyDer<'static>>> {
+    load_pems(path.as_ref(), "PRIVATE KEY").await.map(|pems| {
+        pems.map(|content| PrivatePkcs8KeyDer::from(content).into())
+            .collect()
+    })
 }
 
 async fn load_pems<'a>(
diff --git a/cli/Cargo.toml b/cli/Cargo.toml
index 420167919..c287e3d7e 100644
--- a/cli/Cargo.toml
+++ b/cli/Cargo.toml
@@ -22,8 +22,7 @@ chrono = { workspace = true }
 clap = { workspace = true }
 dirs = "4.0.0"
 futures-util = { workspace = true }
-hyper = { version = "0.14.27", features = ["server", "http1", "http2"] }
-hyper-rustls = { version = "0.24.1", default-features = false, features = ["acceptor"] }
+hyper = { version = "1.4.1", features = ["server", "http1"] }
 interprocess = { version = "1.2.1", features = ["tokio_support"] }
 maxminddb = "0.23.0"
 metrics = { workspace = true }
@@ -32,12 +31,12 @@ ouisync-bridge = { path = "../bridge" }
 ouisync-lib = { package = "ouisync", path = "../lib" }
 ouisync-vfs = { path = "../vfs" }
 rand = { workspace = true }
-rustls = { workspace = true }
 scoped_task = { path = "../scoped_task" }
 serde = { workspace = true }
 state_monitor = { path = "../state_monitor" }
 thiserror = { workspace = true }
 tokio = { workspace = true, features = ["signal", "io-std"] }
+tokio-rustls = { workspace = true }
 tokio-stream = { workspace = true }
 tokio-util = { workspace = true, features = ["codec", "compat"] }
 tracing = { workspace = true }
diff --git a/cli/src/handler/remote.rs b/cli/src/handler/remote.rs
index c89991816..9e26444dd 100644
--- a/cli/src/handler/remote.rs
+++ b/cli/src/handler/remote.rs
@@ -205,11 +205,14 @@ mod tests {
         make_client_config, make_server_config, RemoteClient, RemoteServer,
     };
     use ouisync_lib::{crypto::sign::Keypair, AccessMode, WriteSecrets};
-    use rustls::{Certificate, ClientConfig, PrivateKey};
     use state_monitor::StateMonitor;
     use std::net::Ipv4Addr;
     use tempfile::TempDir;
     use tokio::task;
+    use tokio_rustls::rustls::{
+        pki_types::{CertificateDer, PrivatePkcs8KeyDer},
+        ClientConfig,
+    };
 
     #[test]
     fn insert_separators_test() {
@@ -467,11 +470,11 @@ mod tests {
             mount_dir: temp_dir.path().join("mount"),
         };
 
-        let certs = rcgen::generate_simple_self_signed(vec!["localhost".to_owned()]).unwrap();
-        let cert = Certificate(certs.serialize_der().unwrap());
-        let private_key = PrivateKey(certs.serialize_private_key_der());
+        let gen = rcgen::generate_simple_self_signed(vec!["localhost".to_owned()]).unwrap();
+        let cert = CertificateDer::from(gen.cert);
+        let private_key = PrivatePkcs8KeyDer::from(gen.key_pair.serialize_der());
 
-        let server_config = make_server_config(vec![cert.clone()], private_key).unwrap();
+        let server_config = make_server_config(vec![cert.clone()], private_key.into()).unwrap();
         let client_config = make_client_config(&[cert]).unwrap();
 
         let state = State::init(&dirs, StateMonitor::make_root()).await.unwrap();
diff --git a/cli/src/metrics.rs b/cli/src/metrics.rs
index fccdde370..e103423d1 100644
--- a/cli/src/metrics.rs
+++ b/cli/src/metrics.rs
@@ -3,12 +3,7 @@ use crate::{
     protocol::Error,
     state::State,
 };
-use hyper::{
-    server::{conn::AddrIncoming, Server},
-    service::{make_service_fn, service_fn},
-    Body, Response,
-};
-use hyper_rustls::TlsAcceptor;
+use hyper::{server::conn::http1, service::service_fn, Response};
 use metrics::{Gauge, Key, KeyName, Label, Level, Metadata, Recorder, Unit};
 use 
metrics_exporter_prometheus::{PrometheusBuilder, PrometheusRecorder}; use ouisync_bridge::config::{ConfigError, ConfigKey}; @@ -20,10 +15,17 @@ use std::{ io, net::SocketAddr, path::PathBuf, + pin::Pin, sync::Mutex, + task::{Context, Poll}, time::{Duration, Instant}, }; -use tokio::task; +use tokio::{ + io::{AsyncRead, AsyncWrite, ReadBuf}, + net::TcpListener, + task::{self, JoinSet}, +}; +use tokio_rustls::TlsAcceptor; const BIND_METRICS_KEY: ConfigKey = ConfigKey::new("bind_metrics", "Addresses to bind the metrics endpoint to"); @@ -88,34 +90,17 @@ async fn start(state: &State, addr: SocketAddr) -> Result(service_fn(move |_| { - let recorder_handle = recorder_handle.clone(); - let collect_requester = collect_requester.clone(); - - async move { - collect_requester.request().await; - tracing::trace!("Serving metrics"); - - let content = recorder_handle.render(); - let content = Body::from(content); + let tcp_listener = TcpListener::bind(&addr).await?; - Ok::<_, Infallible>(Response::new(content)) - } - })) - } - }); - - let incoming = - AddrIncoming::bind(&addr).map_err(|error| io::Error::new(io::ErrorKind::Other, error))?; - tracing::info!("Metrics server listening on {}", incoming.local_addr()); + match tcp_listener.local_addr() { + Ok(addr) => tracing::info!("Metrics server listening on {addr}"), + Err(error) => tracing::error!( + ?error, + "Metrics server failed to retrieve the listening address" + ), + } - let acceptor = TlsAcceptor::new(state.get_server_config().await?, incoming); - let server = Server::builder(acceptor); + let tls_acceptor = TlsAcceptor::from(state.get_server_config().await?); task::spawn(collect( collect_acceptor, @@ -125,8 +110,57 @@ async fn start(state: &State, addr: SocketAddr) -> Result conn, + Err(error) => { + tracing::error!(?error, "Metrics server failed to accept new connection"); + break; + } + }; + + let stream = match tls_acceptor.accept(stream).await { + Ok(stream) => stream, + Err(error) => { + tracing::warn!( + ?error, + %addr, + "Metrics server failed to perform TLS handshake" + ); + continue; + } + }; + + let recorder_handle = recorder_handle.clone(); + let collect_requester = collect_requester.clone(); + + tasks.spawn(async move { + let service = move |_req| { + let recorder_handle = recorder_handle.clone(); + let collect_requester = collect_requester.clone(); + + async move { + collect_requester.request().await; + tracing::trace!("Serving metrics"); + + let content = recorder_handle.render(); + + Ok::<_, Infallible>(Response::new(content)) + } + }; + + match http1::Builder::new() + .serve_connection(StreamCompat(stream), service_fn(service)) + .await + { + Ok(()) => (), + Err(error) => { + tracing::error!(?error, %addr, "Metrics server connection failed") + } + } + }); } }) .abort_handle() @@ -288,3 +322,70 @@ mod sync { } } } + +// hyper no longer accepts `tokio::AsyncRead` and `tokio::AsyncWrite` and uses its own traits +// instead so we now need this compatibility wrapper. 
+struct StreamCompat<T>(T);
+
+impl<T> hyper::rt::Read for StreamCompat<T>
+where
+    T: AsyncRead + Unpin,
+{
+    fn poll_read(
+        mut self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+        mut buf: hyper::rt::ReadBufCursor<'_>,
+    ) -> Poll<Result<(), std::io::Error>> {
+        unsafe {
+            let mut buf = ReadBuf::uninit(buf.as_mut());
+
+            let n = match Pin::new(&mut self.0).poll_read(cx, &mut buf) {
+                Poll::Ready(Ok(())) => buf.filled().len(),
+                other => return other,
+            };
+
+            buf.advance(n);
+        }
+
+        Poll::Ready(Ok(()))
+    }
+}
+
+impl<T> hyper::rt::Write for StreamCompat<T>
+where
+    T: AsyncWrite + Unpin,
+{
+    fn poll_write(
+        mut self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+        buf: &[u8],
+    ) -> Poll<Result<usize, std::io::Error>> {
+        Pin::new(&mut self.0).poll_write(cx, buf)
+    }
+
+    fn poll_flush(
+        mut self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+    ) -> Poll<Result<(), std::io::Error>> {
+        Pin::new(&mut self.0).poll_flush(cx)
+    }
+
+    fn poll_shutdown(
+        mut self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+    ) -> Poll<Result<(), std::io::Error>> {
+        Pin::new(&mut self.0).poll_shutdown(cx)
+    }
+
+    fn is_write_vectored(&self) -> bool {
+        self.0.is_write_vectored()
+    }
+
+    fn poll_write_vectored(
+        mut self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+        bufs: &[std::io::IoSlice<'_>],
+    ) -> Poll<Result<usize, std::io::Error>> {
+        Pin::new(&mut self.0).poll_write_vectored(cx, bufs)
+    }
+}
diff --git a/cli/src/repository.rs b/cli/src/repository.rs
index 0a1aa0754..1864ef534 100644
--- a/cli/src/repository.rs
+++ b/cli/src/repository.rs
@@ -15,6 +15,7 @@ use std::{
 };
 use thiserror::Error;
 use tokio::{fs, runtime, task};
+use tokio_rustls::rustls;
 use tokio_stream::StreamExt;
 
 // Config keys
diff --git a/cli/src/state.rs b/cli/src/state.rs
index 037ec55bc..a957043e2 100644
--- a/cli/src/state.rs
+++ b/cli/src/state.rs
@@ -19,6 +19,7 @@ use std::{
     time::Duration,
 };
 use tokio::{sync::OnceCell, time};
+use tokio_rustls::rustls;
 
 pub(crate) struct State {
     pub config: ConfigStore,
diff --git a/cli/tests/utils.rs b/cli/tests/utils.rs
index 48684d8b2..2a96502b4 100644
--- a/cli/tests/utils.rs
+++ b/cli/tests/utils.rs
@@ -27,7 +27,7 @@ const CONFIG_DIR: &str = "config";
 const API_SOCKET: &str = "api.sock";
 const DEFAULT_REPO: &str = "test";
 
-static CERT: Lazy<rcgen::Certificate> =
+static CERT: Lazy<rcgen::CertifiedKey> =
     Lazy::new(|| rcgen::generate_simple_self_signed(vec!["localhost".to_owned()]).unwrap());
 
 impl Bin {
@@ -43,11 +43,11 @@ impl Bin {
         fs::create_dir_all(config_dir.join("root_certs")).unwrap();
 
         // Install the certificate
-        let cert = CERT.serialize_pem().unwrap();
+        let cert = CERT.cert.pem();
 
         // For server:
         fs::write(config_dir.join("cert.pem"), &cert).unwrap();
-        fs::write(config_dir.join("key.pem"), CERT.serialize_private_key_pem()).unwrap();
+        fs::write(config_dir.join("key.pem"), CERT.key_pair.serialize_pem()).unwrap();
 
         // For client:
         fs::write(config_dir.join("root_certs").join("localhost.pem"), &cert).unwrap();
diff --git a/lib/Cargo.toml b/lib/Cargo.toml
index 1bfe3cbfb..cc04c1259 100644
--- a/lib/Cargo.toml
+++ b/lib/Cargo.toml
@@ -56,7 +56,7 @@ noise-rust-crypto = { version = "0.6.1", default-features = false, features = ["
 num_enum = { workspace = true }
 once_cell = { workspace = true }
 parse-size = { version = "1.0.0", features = ["std"] }
-pin-project-lite = "0.2.13"
+pin-project-lite = { workspace = true }
 rand = { workspace = true }
 ref-cast = "1.0.14"
 rupnp = { version = "1.1.0", default-features = false, features = [] }
diff --git a/net/Cargo.toml b/net/Cargo.toml
index 800da2760..fb80d9daf 100644
--- a/net/Cargo.toml
+++ b/net/Cargo.toml
@@ -11,10 +11,10 @@ bytecodec = "0.4.15"
 bytes = "1.1.0"
 futures-util = { workspace = true }
-quinn = "0.10.2"
+pin-project-lite = { workspace = true }
+quinn = "0.11.4"
 rand = { package = "ouisync-rand", path = "../rand" }
 rcgen = { workspace = true }
-rustls = { workspace = true, features = ["quic", "dangerous_configuration"] }
 socket2 = "0.5.7" # To be able to setsockopts before a socket is bound
 stun_codec = "0.3.4"
 thiserror = "1.0.31"
diff --git a/net/src/quic.rs b/net/src/quic.rs
index 20e23af31..62e926fbb 100644
--- a/net/src/quic.rs
+++ b/net/src/quic.rs
@@ -1,6 +1,19 @@
 use crate::KEEP_ALIVE_INTERVAL;
 use bytes::BytesMut;
+use pin_project_lite::pin_project;
+use quinn::{
+    crypto::rustls::QuicClientConfig,
+    rustls::{
+        self,
+        client::danger::{HandshakeSignatureValid, ServerCertVerified},
+        pki_types::{CertificateDer, PrivatePkcs8KeyDer, ServerName, UnixTime},
+        DigitallySignedStruct, SignatureScheme,
+    },
+    UdpPoller,
+};
 use std::{
+    fmt,
+    future::Future,
     io,
     net::SocketAddr,
     pin::Pin,
     sync::{
         atomic::{AtomicBool, Ordering},
         Arc,
     },
     task::{Context, Poll},
 };
 use tokio::{
@@ -20,15 +33,13 @@ use tokio::{
 
 const CERT_DOMAIN: &str = "ouisync.net";
 
-pub type Result<T> = std::result::Result<T, Error>;
-
 //------------------------------------------------------------------------------
 pub struct Connector {
     endpoint: quinn::Endpoint,
 }
 
 impl Connector {
-    pub async fn connect(&self, remote_addr: SocketAddr) -> Result<Connection> {
+    pub async fn connect(&self, remote_addr: SocketAddr) -> Result<Connection, Error> {
         let connection = self.endpoint.connect(remote_addr, CERT_DOMAIN)?.await?;
         let (tx, rx) = connection.open_bi().await?;
         Ok(Connection::new(rx, tx, connection.remote_address()))
@@ -52,7 +63,7 @@ impl Acceptor {
         self.endpoint
             .accept()
             .await
-            .map(|connecting| Connecting { connecting })
+            .map(|incoming| Connecting { incoming })
     }
 
     pub fn local_addr(&self) -> &SocketAddr {
@@ -61,12 +72,12 @@ impl Acceptor {
 }
 
 pub struct Connecting {
-    connecting: quinn::Connecting,
+    incoming: quinn::Incoming,
 }
 
 impl Connecting {
-    pub async fn finish(self) -> Result<Connection> {
-        let connection = self.connecting.await?;
+    pub async fn finish(self) -> Result<Connection, Error> {
+        let connection = self.incoming.await?;
         let (tx, rx) = connection.accept_bi().await?;
         Ok(Connection::new(rx, tx, connection.remote_address()))
     }
@@ -108,24 +119,6 @@ impl Connection {
             OwnedWriteHalf { tx, can_finish },
         )
     }
-
-    /// Make sure all data is sent, no more data can be sent afterwards.
-    #[cfg(test)]
-    pub async fn finish(&mut self) -> Result<()> {
-        if !self.can_finish {
-            return Err(Error::Write(quinn::WriteError::UnknownStream));
-        }
-
-        self.can_finish = false;
-
-        match self.tx.take() {
-            Some(mut tx) => {
-                tx.finish().await?;
-                Ok(())
-            }
-            None => Err(Error::Write(quinn::WriteError::UnknownStream)),
-        }
-    }
 }
 
 impl AsyncRead for Connection {
@@ -159,13 +152,9 @@ impl AsyncWrite for Connection {
     ) -> Poll<io::Result<usize>> {
         let this = self.get_mut();
         match &mut this.tx {
-            Some(tx) => {
-                let poll = Pin::new(tx).poll_write(cx, buf);
-                if let Poll::Ready(r) = &poll {
-                    this.can_finish &= r.is_ok();
-                }
-                poll
-            }
+            Some(tx) => Pin::new(tx)
+                .poll_write(cx, buf)
+                .map_err(|error| error.into()),
             None => Poll::Ready(Err(io::Error::new(
                 io::ErrorKind::BrokenPipe,
                 "already finished",
@@ -174,15 +163,8 @@ impl AsyncWrite for Connection {
     }
 
     fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
-        let this = self.get_mut();
-        match &mut this.tx {
-            Some(tx) => {
-                let poll = Pin::new(tx).poll_flush(cx);
-                if let Poll::Ready(r) = &poll {
-                    this.can_finish &= r.is_ok();
-                }
-                poll
-            }
+        match &mut self.get_mut().tx {
+            Some(tx) => Pin::new(tx).poll_flush(cx),
             None => Poll::Ready(Err(io::Error::new(
                 io::ErrorKind::BrokenPipe,
                 "already finished",
@@ -191,12 +173,8 @@ impl AsyncWrite for Connection {
     }
 
     fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
-        let this = self.get_mut();
-        match &mut this.tx {
-            Some(tx) => {
-                this.can_finish = false;
-                Pin::new(tx).poll_shutdown(cx)
-            }
+        match &mut self.get_mut().tx {
+            Some(tx) => Pin::new(tx).poll_shutdown(cx),
             None => Poll::Ready(Err(io::Error::new(
                 io::ErrorKind::BrokenPipe,
                 "already finished",
@@ -212,7 +190,7 @@ impl Drop for Connection {
         }
 
         if let Some(mut tx) = self.tx.take() {
-            tokio::task::spawn(async move { tx.finish().await.unwrap_or(()) });
+            tx.finish().ok();
         }
     }
 }
@@ -252,17 +230,14 @@ impl AsyncWrite for OwnedWriteHalf {
     ) -> Poll<io::Result<usize>> {
         let this = self.get_mut();
         match &mut this.tx {
-            Some(tx) => {
-                let poll = Pin::new(tx).poll_write(cx, buf);
-
-                if let Poll::Ready(r) = &poll {
-                    if r.is_err() {
-                        this.can_finish.store(false, Ordering::SeqCst);
-                    }
+            Some(tx) => match Pin::new(tx).poll_write(cx, buf) {
+                Poll::Ready(Ok(n)) => Poll::Ready(Ok(n)),
+                Poll::Ready(Err(error)) => {
+                    this.can_finish.store(false, Ordering::SeqCst);
+                    Poll::Ready(Err(error.into()))
                 }
-
-                poll
-            }
+                Poll::Pending => Poll::Pending,
+            },
             None => Poll::Ready(Err(io::Error::new(
                 io::ErrorKind::BrokenPipe,
                 "already finished",
@@ -313,16 +288,18 @@ impl Drop for OwnedWriteHalf {
         }
 
         if let Some(mut tx) = self.tx.take() {
-            tokio::task::spawn(async move { tx.finish().await.unwrap_or(()) });
+            tx.finish().ok();
         }
     }
 }
 
 //------------------------------------------------------------------------------
-pub async fn configure(bind_addr: SocketAddr) -> Result<(Connector, Acceptor, SideChannelMaker)> {
+pub async fn configure(
+    bind_addr: SocketAddr,
+) -> Result<(Connector, Acceptor, SideChannelMaker), Error> {
     let server_config = make_server_config()?;
-    let custom_socket = CustomUdpSocket::bind(bind_addr).await?;
-    let side_channel_maker = custom_socket.side_channel_maker();
+    let custom_socket = Arc::new(CustomUdpSocket::bind(bind_addr).await?);
+    let side_channel_maker = custom_socket.clone().side_channel_maker();
 
     let mut endpoint = quinn::Endpoint::new_with_abstract_socket(
         quinn::EndpointConfig::default(),
@@ -369,29 +346,71 @@ pub enum Error {
// Dummy certificate verifier that treats any certificate as valid. In our P2P system there are no
 // certification authorities. TODO: I think this still makes the TLS encryption provided by QUIC
 // useful against passive MitM attacks (eavesdropping), but not against the active ones.
-struct SkipServerVerification;
+#[derive(Debug)]
+struct SkipServerVerification(rustls::crypto::CryptoProvider);
+
+impl SkipServerVerification {
+    fn new() -> Self {
+        Self(rustls::crypto::ring::default_provider())
+    }
+}
 
-impl rustls::client::ServerCertVerifier for SkipServerVerification {
+impl rustls::client::danger::ServerCertVerifier for SkipServerVerification {
     fn verify_server_cert(
         &self,
-        _end_entity: &rustls::Certificate,
-        _intermediates: &[rustls::Certificate],
-        _server_name: &rustls::ServerName,
-        _scts: &mut dyn Iterator<Item = &[u8]>,
-        _ocsp_response: &[u8],
-        _now: std::time::SystemTime,
-    ) -> std::result::Result<rustls::client::ServerCertVerified, rustls::Error> {
-        Ok(rustls::client::ServerCertVerified::assertion())
+        _end_entity: &CertificateDer<'_>,
+        _intermediates: &[CertificateDer<'_>],
+        _server_name: &ServerName<'_>,
+        _ocsp: &[u8],
+        _now: UnixTime,
+    ) -> Result<ServerCertVerified, rustls::Error> {
+        Ok(ServerCertVerified::assertion())
+    }
+
+    fn verify_tls12_signature(
+        &self,
+        message: &[u8],
+        cert: &CertificateDer<'_>,
+        dss: &DigitallySignedStruct,
+    ) -> Result<HandshakeSignatureValid, rustls::Error> {
+        rustls::crypto::verify_tls12_signature(
+            message,
+            cert,
+            dss,
+            &self.0.signature_verification_algorithms,
+        )
+    }
+
+    fn verify_tls13_signature(
+        &self,
+        message: &[u8],
+        cert: &CertificateDer<'_>,
+        dss: &DigitallySignedStruct,
+    ) -> Result<HandshakeSignatureValid, rustls::Error> {
+        rustls::crypto::verify_tls13_signature(
+            message,
+            cert,
+            dss,
+            &self.0.signature_verification_algorithms,
+        )
+    }
+
+    fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
+        self.0.signature_verification_algorithms.supported_schemes()
     }
 }
 
 fn make_client_config() -> quinn::ClientConfig {
-    let crypto = rustls::ClientConfig::builder()
-        .with_safe_defaults()
-        .with_custom_certificate_verifier(Arc::new(SkipServerVerification {}))
+    let crypto_config = rustls::ClientConfig::builder()
+        .dangerous()
+        .with_custom_certificate_verifier(Arc::new(SkipServerVerification::new()))
         .with_no_client_auth();
 
-    let mut client_config = quinn::ClientConfig::new(Arc::new(crypto));
+    let mut client_config = quinn::ClientConfig::new(Arc::new(
+        // `expect` should be OK because we made sure we constructed the rustls::ClientConfig in a
+        // compliant way.
+        QuicClientConfig::try_from(crypto_config).expect("failed to create quic client config"),
+    ));
 
     let mut transport_config = quinn::TransportConfig::default();
 
@@ -408,24 +427,20 @@ fn make_client_config() -> quinn::ClientConfig {
     client_config
 }
 
-fn make_server_config() -> Result<quinn::ServerConfig> {
+fn make_server_config() -> Result<quinn::ServerConfig, Error> {
     // Generate a self signed certificate.
     let cert = rcgen::generate_simple_self_signed(vec![CERT_DOMAIN.into()]).unwrap();
-    let cert_der = cert.serialize_der().unwrap();
-    let priv_key = cert.serialize_private_key_der();
-    let priv_key = rustls::PrivateKey(priv_key);
-    let cert_chain = vec![rustls::Certificate(cert_der)];
-
-    let mut server_config = quinn::ServerConfig::with_single_cert(cert_chain, priv_key)?;
+    let cert_der = CertificateDer::from(cert.cert);
+    let priv_key = PrivatePkcs8KeyDer::from(cert.key_pair.serialize_der());
 
-    let mut transport_config = quinn::TransportConfig::default();
+    let mut server_config =
+        quinn::ServerConfig::with_single_cert(vec![cert_der.clone()], priv_key.into())?;
 
+    let transport_config = Arc::get_mut(&mut server_config.transport).unwrap();
     transport_config
         .max_concurrent_uni_streams(0_u8.into())
         .max_idle_timeout((2 * KEEP_ALIVE_INTERVAL).try_into().ok());
 
-    server_config.transport_config(Arc::new(transport_config));
-
     Ok(server_config)
 }
 
@@ -451,8 +466,8 @@ struct Packet {
 
 #[derive(Debug)]
 struct CustomUdpSocket {
-    io: Arc<tokio::net::UdpSocket>,
-    quinn_socket_state: quinn::udp::UdpSocketState,
+    io: tokio::net::UdpSocket,
+    state: quinn::udp::UdpSocketState,
     side_channel_tx: broadcast::Sender<Packet>,
 }
 
impl CustomUdpSocket {
     async fn bind(addr: SocketAddr) -> io::Result<Self> {
         let socket = crate::udp::UdpSocket::bind(addr).await?;
         let socket = socket.into_std()?;
 
-        quinn::udp::UdpSocketState::configure((&socket).into())?;
+        let state = quinn::udp::UdpSocketState::new((&socket).into())?;
 
         Ok(Self {
-            io: Arc::new(tokio::net::UdpSocket::from_std(socket)?),
-            quinn_socket_state: quinn::udp::UdpSocketState::new(),
+            io: tokio::net::UdpSocket::from_std(socket)?,
+            state,
             side_channel_tx: broadcast::channel(MAX_SIDE_CHANNEL_PENDING_PACKETS).0,
         })
     }
 
-    fn side_channel_maker(&self) -> SideChannelMaker {
+    fn side_channel_maker(self: Arc<Self>) -> SideChannelMaker {
         SideChannelMaker {
-            io: self.io.clone(),
+            socket: self.clone(),
             packet_tx: self.side_channel_tx.clone(),
         }
     }
 }
 
 impl quinn::AsyncUdpSocket for CustomUdpSocket {
-    fn poll_send(
-        &self,
-        state: &quinn::udp::UdpState,
-        cx: &mut Context,
-        transmits: &[quinn::udp::Transmit],
-    ) -> Poll<io::Result<usize>> {
-        let quinn_socket_state = &self.quinn_socket_state;
-        let io = &*self.io;
-        loop {
-            ready!(io.poll_send_ready(cx))?;
-            if let Ok(res) = io.try_io(Interest::WRITABLE, || {
-                quinn_socket_state.send(io.into(), state, transmits)
-            }) {
-                return Poll::Ready(Ok(res));
-            }
-        }
+    fn create_io_poller(self: Arc<Self>) -> Pin<Box<dyn UdpPoller>> {
+        Box::pin(UdpPollHelper::new(move || {
+            let socket = self.clone();
+            async move { socket.io.writable().await }
+        }))
+    }
+
+    fn try_send(&self, transmit: &quinn::udp::Transmit) -> io::Result<()> {
+        self.io.try_io(Interest::WRITABLE, || {
+            self.state.send((&self.io).into(), transmit)
+        })
     }
 
     fn poll_recv(
         &self,
         cx: &mut Context,
         bufs: &mut [io::IoSliceMut<'_>],
         metas: &mut [quinn::udp::RecvMeta],
     ) -> Poll<io::Result<usize>> {
         loop {
             ready!(self.io.poll_recv_ready(cx))?;
             if let Ok(res) = self.io.try_io(Interest::READABLE, || {
-                let res = self
-                    .quinn_socket_state
-                    .recv((&*self.io).into(), bufs, metas);
+                let res = self.state.recv((&self.io).into(), bufs, metas);
 
                 if let Ok(msg_count) = res {
                     send_to_side_channels(&self.side_channel_tx, bufs, metas, msg_count);
                 }
 
                 res
             }) {
                 return Poll::Ready(Ok(res));
             }
         }
     }
 
     fn local_addr(&self) -> io::Result<SocketAddr> {
         self.io.local_addr()
     }
+
+    fn may_fragment(&self) -> bool {
+        self.state.may_fragment()
+    }
+
+    fn max_transmit_segments(&self) -> usize {
+        self.state.max_gso_segments()
+    }
+
+    fn max_receive_segments(&self) -> usize {
+        self.state.gro_segments()
+    }
 }
 
 fn send_to_side_channels(
@@ -552,39 +572,86 @@ fn send_to_side_channels(
     }
 }
 
+// This is copied verbatim from [quinn]
+// (https://github.com/quinn-rs/quinn/blob/main/quinn/src/runtime.rs) as it's unfortunately not
+// exposed from there.
+pin_project! {
+    struct UdpPollHelper<MakeFut, Fut> {
+        make_fut: MakeFut,
+        #[pin]
+        fut: Option<Fut>,
+    }
+}
+
+impl<MakeFut, Fut> UdpPollHelper<MakeFut, Fut> {
+    fn new(make_fut: MakeFut) -> Self {
+        Self {
+            make_fut,
+            fut: None,
+        }
+    }
+}
+
+impl<MakeFut, Fut> UdpPoller for UdpPollHelper<MakeFut, Fut>
+where
+    MakeFut: Fn() -> Fut + Send + Sync + 'static,
+    Fut: Future<Output = io::Result<()>> + Send + Sync + 'static,
+{
+    fn poll_writable(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
+        let mut this = self.project();
+
+        if this.fut.is_none() {
+            this.fut.set(Some((this.make_fut)()));
+        }
+
+        let result = this.fut.as_mut().as_pin_mut().unwrap().poll(cx);
+
+        if result.is_ready() {
+            this.fut.set(None);
+        }
+
+        result
+    }
+}
+
+impl<MakeFut, Fut> fmt::Debug for UdpPollHelper<MakeFut, Fut> {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_struct("UdpPollHelper").finish_non_exhaustive()
+    }
+}
 //------------------------------------------------------------------------------
 
 /// Makes new `SideChannel`s.
 pub struct SideChannelMaker {
-    io: Arc<tokio::net::UdpSocket>,
+    socket: Arc<CustomUdpSocket>,
     packet_tx: broadcast::Sender<Packet>,
 }
 
 impl SideChannelMaker {
     pub fn make(&self) -> SideChannel {
         SideChannel {
-            io: self.io.clone(),
+            socket: self.socket.clone(),
             packet_rx: AsyncMutex::new(self.packet_tx.subscribe()),
         }
     }
 }
 
 pub struct SideChannel {
-    io: Arc<tokio::net::UdpSocket>,
+    socket: Arc<CustomUdpSocket>,
     packet_rx: AsyncMutex<broadcast::Receiver<Packet>>,
 }
 
 impl SideChannel {
     pub fn sender(&self) -> SideChannelSender {
         SideChannelSender {
-            io: self.io.clone(),
+            socket: self.socket.clone(),
         }
     }
 }
 
 impl DatagramSocket for SideChannel {
     async fn send_to<'a>(&'a self, buf: &'a [u8], target: SocketAddr) -> io::Result<usize> {
-        self.io.send_to(buf, target).await
+        self.socket.io.send_to(buf, target).await
     }
 
     // Note: receiving on side channels will only work when quinn is calling `poll_recv`. This
@@ -615,7 +682,7 @@ impl DatagramSocket for SideChannel {
     }
 
     fn local_addr(&self) -> io::Result<SocketAddr> {
-        self.io.local_addr()
+        self.socket.io.local_addr()
     }
 }
 
@@ -623,12 +690,12 @@
 // `broadcast::Receiver` that the `CustomUdpSocket` would need to pass messages to.
 #[derive(Clone)]
 pub struct SideChannelSender {
-    io: Arc<tokio::net::UdpSocket>,
+    socket: Arc<CustomUdpSocket>,
 }
 
 impl SideChannelSender {
     pub async fn send_to(&self, buf: &[u8], target: &SocketAddr) -> io::Result<()> {
-        self.io.send_to(buf, target).await.map(|_| ())
+        self.socket.io.send_to(buf, target).await.map(|_| ())
     }
 }
 
@@ -650,19 +717,35 @@ mod tests {
 
         let addr = *acceptor.local_addr();
 
-        let message = b"hello world";
-
         let h1 = task::spawn(async move {
             let mut conn = acceptor.accept().await.unwrap().finish().await.unwrap();
-            let mut buf = [0; 32];
-            let n = conn.read(&mut buf).await.unwrap();
-            assert_eq!(message, &buf[..n]);
+
+            let mut buf = [0; 4];
+            conn.read_exact(&mut buf).await.unwrap();
+            assert_eq!(&buf, b"ping");
+
+            conn.write_all(b"pong").await.unwrap();
         });
 
         let h2 = task::spawn(async move {
             let mut conn = connector.connect(addr).await.unwrap();
-            conn.write_all(message).await.unwrap();
-            conn.finish().await.unwrap();
+            conn.write_all(b"ping").await.unwrap();
+
+            let mut buf = [0; 4];
+            match conn.read_exact(&mut buf).await {
+                Ok(_) => (),
+                Err(error) => match error.downcast::<quinn::ReadError>() {
+                    Ok(error) => match error {
+                        quinn::ReadError::ConnectionLost(
+                            quinn::ConnectionError::ApplicationClosed(_),
+                        ) => {
+                            // connection graceflly closed by the peer, this is expected.
+ } + error => panic!("unexpected error: {:?}", error), + }, + Err(error) => panic!("unexpected error: {:?}", error), + }, + } }); h1.await.unwrap(); diff --git a/utils/stun-server-list/Cargo.toml b/utils/stun-server-list/Cargo.toml index 879a4921a..a68e1e60b 100644 --- a/utils/stun-server-list/Cargo.toml +++ b/utils/stun-server-list/Cargo.toml @@ -11,4 +11,4 @@ version.workspace = true [dependencies] tokio = { workspace = true, features = [ "macros", "rt-multi-thread" ] } -reqwest = { version = "0.11.23", default-features = false, features = ["rustls-tls"] } +reqwest = { version = "0.12.7", default-features = false, features = ["rustls-tls"] } From 18a0278982eb7a818ec2a77ab73c94a240341e15 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Cig=C3=A1nek?= Date: Tue, 3 Sep 2024 09:27:38 +0200 Subject: [PATCH 03/55] Simplify how quic connections are accepted --- lib/src/network/gateway.rs | 77 ++++++++++++-------------------------- net/examples/peer.rs | 6 +-- net/src/quic.rs | 25 +++++++------ 3 files changed, 41 insertions(+), 67 deletions(-) diff --git a/lib/src/network/gateway.rs b/lib/src/network/gateway.rs index 453b41454..801d36116 100644 --- a/lib/src/network/gateway.rs +++ b/lib/src/network/gateway.rs @@ -6,15 +6,12 @@ use net::{ tcp::{TcpListener, TcpStream}, }; use scoped_task::ScopedJoinHandle; -use std::{ - collections::HashMap, - net::{IpAddr, SocketAddr}, - sync::{Arc, Mutex}, -}; +use std::net::{IpAddr, SocketAddr}; use thiserror::Error; use tokio::{ select, - sync::{mpsc, oneshot, watch}, + sync::{mpsc, watch}, + task::JoinSet, time::{self, Duration}, }; use tracing::{field, Instrument, Span}; @@ -563,59 +560,33 @@ async fn run_tcp_listener(listener: TcpListener, tx: mpsc::Sender<(raw::Stream, } } -async fn run_quic_listener( - mut listener: quic::Acceptor, - tx: mpsc::Sender<(raw::Stream, PeerAddr)>, -) { - // Using `futures_util::stream::FuturesUnordered` may have been a nicer solution but I'm not - // sure whether `quic::Acceptor::accept()` is cancel safe. - let connectings = Arc::new(Mutex::new(HashMap::new())); - let mut next_connecting_id = 0; +async fn run_quic_listener(listener: quic::Acceptor, tx: mpsc::Sender<(raw::Stream, PeerAddr)>) { + let mut tasks = JoinSet::new(); loop { - let result = select! { - result = listener.accept() => result, + let connecting = select! { + connecting = listener.accept() => connecting, _ = tx.closed() => break, }; - match result { - Some(connecting) => { - // Using this channel to ensure the task is not removed from `connectings` before - // it's inserted. - let (start_task_tx, start_task_rx) = oneshot::channel(); - - let connecting_id = next_connecting_id; - next_connecting_id += 1; - - // Spawn so we can start listening for the next connection ASAP. - let task = scoped_task::spawn({ - let tx = tx.clone(); - let connectings = connectings.clone(); - async move { - if start_task_rx.await.is_ok() { - match connecting.finish().await { - Ok(socket) => { - let addr = *socket.remote_address(); - tx.send((raw::Stream::Quic(socket), PeerAddr::Quic(addr))) - .await - .ok(); - } - Err(error) => { - tracing::error!(?error, "Failed to accept connection"); - } - }; - } - connectings.lock().unwrap().remove(&connecting_id); + if let Some(connecting) = connecting { + let tx = tx.clone(); + let addr = connecting.remote_addr(); + + // Spawn so we can start listening for the next connection ASAP. 
+            tasks.spawn(async move {
+                match connecting.complete().await {
+                    Ok(connection) => {
+                        tx.send((raw::Stream::Quic(connection), PeerAddr::Quic(addr)))
+                            .await
+                            .ok();
                     }
+                    Err(error) => tracing::error!(?error, %addr, "Failed to accept connection"),
+                }
+            });
+        } else {
+            tracing::error!("Stopped accepting new connections");
+            break;
         }
     }
 }
diff --git a/net/examples/peer.rs b/net/examples/peer.rs
index d401713b6..34c3d696b 100644
--- a/net/examples/peer.rs
+++ b/net/examples/peer.rs
@@ -106,7 +106,7 @@ async fn run_tcp_server(addr: SocketAddr) -> Result<()> {
 }
 
 async fn run_quic_server(addr: SocketAddr) -> Result<()> {
-    let (_, mut acceptor, _) = quic::configure(addr).await?;
+    let (_, acceptor, _) = quic::configure(addr).await?;
     println!("bound to {}", acceptor.local_addr());
 
     loop {
@@ -114,9 +114,9 @@ async fn run_quic_server(addr: SocketAddr) -> Result<()> {
             .accept()
             .await
             .context("failed to accept")?
-            .finish()
+            .complete()
             .await?;
-        let addr = *connection.remote_address();
+        let addr = *connection.remote_addr();
         task::spawn(run_server_connection(connection, addr));
     }
 }
diff --git a/net/src/quic.rs b/net/src/quic.rs
index 62e926fbb..4b08fa1a5 100644
--- a/net/src/quic.rs
+++ b/net/src/quic.rs
@@ -59,7 +59,7 @@ pub struct Acceptor {
 }
 
 impl Acceptor {
-    pub async fn accept(&mut self) -> Option<Connecting> {
+    pub async fn accept(&self) -> Option<Connecting> {
         self.endpoint
             .accept()
             .await
@@ -76,7 +76,11 @@ pub struct Connecting {
 }
 
 impl Connecting {
-    pub async fn finish(self) -> Result<Connection, Error> {
+    pub fn remote_addr(&self) -> SocketAddr {
+        self.incoming.remote_address()
+    }
+
+    pub async fn complete(self) -> Result<Connection, Error> {
         let connection = self.incoming.await?;
         let (tx, rx) = connection.accept_bi().await?;
         Ok(Connection::new(rx, tx, connection.remote_address()))
@@ -87,22 +91,22 @@ impl Connecting {
 
 pub struct Connection {
     rx: Option<quinn::RecvStream>,
     tx: Option<quinn::SendStream>,
-    remote_address: SocketAddr,
+    remote_addr: SocketAddr,
     can_finish: bool,
 }
 
 impl Connection {
-    pub fn new(rx: quinn::RecvStream, tx: quinn::SendStream, remote_address: SocketAddr) -> Self {
+    pub fn new(rx: quinn::RecvStream, tx: quinn::SendStream, remote_addr: SocketAddr) -> Self {
         Self {
             rx: Some(rx),
             tx: Some(tx),
-            remote_address,
+            remote_addr,
             can_finish: true,
         }
     }
 
-    pub fn remote_address(&self) -> &SocketAddr {
-        &self.remote_address
+    pub fn remote_addr(&self) -> &SocketAddr {
+        &self.remote_addr
     }
 
     pub fn into_split(mut self) -> (OwnedReadHalf, OwnedWriteHalf) {
@@ -712,13 +716,12 @@ mod tests {
 
     #[tokio::test(flavor = "multi_thread")]
     async fn small_data_exchange() {
-        let (connector, mut acceptor, _) =
-            configure((Ipv4Addr::LOCALHOST, 0).into()).await.unwrap();
+        let (connector, acceptor, _) = configure((Ipv4Addr::LOCALHOST, 0).into()).await.unwrap();
 
         let addr = *acceptor.local_addr();
 
         let h1 = task::spawn(async move {
-            let mut conn = acceptor.accept().await.unwrap().finish().await.unwrap();
+            let mut conn = acceptor.accept().await.unwrap().complete().await.unwrap();
 
             let mut buf = [0; 4];
             conn.read_exact(&mut buf).await.unwrap();
@@ -754,7 +757,7 @@ mod tests {
 
     #[tokio::test(flavor = "multi_thread")]
     async fn side_channel() {
-        let (_connector, mut acceptor, side_channel_maker) =
+        let (_connector, acceptor, side_channel_maker) =
             configure((Ipv4Addr::LOCALHOST, 0).into()).await.unwrap();
         let addr = *acceptor.local_addr();
         let side_channel = side_channel_maker.make();
From 661321e9a6434642370e1b3ac9b2fa9fb73457b9 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adam=20Cig=C3=A1nek?=
Date: Tue, 3 Sep 2024 09:59:50 +0200
Subject: [PATCH 04/55] Move raw::Stream to the net crate and rename to
 Connection

---
 lib/src/network/gateway.rs                    | 29 +++++----
 lib/src/network/message_broker.rs             |  6 +-
 lib/src/network/message_dispatcher.rs         | 65 ++++++++++---------
 lib/src/network/mod.rs                        | 29 +++++----
 lib/src/network/stats.rs                      | 11 +---
 .../network/raw.rs => net/src/connection.rs   | 56 ++++++++--------
 net/src/lib.rs                                |  1 +
 net/src/quic.rs                               |  2 +-
 8 files changed, 100 insertions(+), 99 deletions(-)
 rename lib/src/network/raw.rs => net/src/connection.rs (68%)

diff --git a/lib/src/network/gateway.rs b/lib/src/network/gateway.rs
index 801d36116..7705033ac 100644
--- a/lib/src/network/gateway.rs
+++ b/lib/src/network/gateway.rs
@@ -1,7 +1,8 @@
-use super::{ip, peer_addr::PeerAddr, peer_source::PeerSource, raw, seen_peers::SeenPeer};
+use super::{ip, peer_addr::PeerAddr, peer_source::PeerSource, seen_peers::SeenPeer};
 use crate::sync::atomic_slot::AtomicSlot;
 use backoff::{backoff::Backoff, ExponentialBackoffBuilder};
 use net::{
+    connection::Connection,
     quic,
     tcp::{TcpListener, TcpStream},
 };
@@ -19,7 +20,7 @@ use tracing::{field, Instrument, Span};
 /// Established incoming and outgoing connections.
 pub(super) struct Gateway {
     stacks: AtomicSlot<Stacks>,
-    incoming_tx: mpsc::Sender<(raw::Stream, PeerAddr)>,
+    incoming_tx: mpsc::Sender<(Connection, PeerAddr)>,
     connectivity_tx: watch::Sender<Connectivity>,
 }
 
@@ -27,7 +28,7 @@ impl Gateway {
     /// Create a new `Gateway` that is initially disabled.
     ///
     /// `incoming_tx` is the sender for the incoming connections.
-    pub fn new(incoming_tx: mpsc::Sender<(raw::Stream, PeerAddr)>) -> Self {
+    pub fn new(incoming_tx: mpsc::Sender<(Connection, PeerAddr)>) -> Self {
         let stacks = Stacks::unbound();
         let stacks = AtomicSlot::new(stacks);
 
@@ -120,7 +121,7 @@ impl Gateway {
         &self,
         peer: &SeenPeer,
         source: PeerSource,
-    ) -> Option<raw::Stream> {
+    ) -> Option<Connection> {
         if !ok_to_connect(peer.addr_if_seen()?.socket_addr(), source) {
             tracing::debug!("Invalid peer address - discarding");
             return None;
@@ -262,7 +263,7 @@ impl Stacks {
 
     async fn bind(
         bind: &StackAddresses,
-        incoming_tx: mpsc::Sender<(raw::Stream, PeerAddr)>,
+        incoming_tx: mpsc::Sender<(Connection, PeerAddr)>,
     ) -> (
         Self,
         Option<quic::SideChannelMaker>,
@@ -337,11 +338,11 @@ impl Stacks {
         self.tcp_v6.as_ref().map(|stack| &stack.listener_local_addr)
     }
 
-    async fn connect(&self, addr: PeerAddr) -> Result<raw::Stream, ConnectError> {
+    async fn connect(&self, addr: PeerAddr) -> Result<Connection, ConnectError> {
         match addr {
             PeerAddr::Tcp(addr) => TcpStream::connect(addr)
                 .await
-                .map(raw::Stream::Tcp)
+                .map(Connection::Tcp)
                 .map_err(ConnectError::Tcp),
             PeerAddr::Quic(addr) => {
                 let stack = self
@@ -352,7 +353,7 @@ impl Stacks {
                     .connector
                     .connect(addr)
                     .await
-                    .map(raw::Stream::Quic)
+                    .map(Connection::Quic)
                     .map_err(ConnectError::Quic)
             }
         }
@@ -439,7 +440,7 @@ struct QuicStack {
 impl QuicStack {
     async fn new(
         bind_addr: SocketAddr,
-        incoming_tx: mpsc::Sender<(raw::Stream, PeerAddr)>,
+        incoming_tx: mpsc::Sender<(Connection, PeerAddr)>,
     ) -> Option<(Self, quic::SideChannelMaker)> {
         let span = tracing::info_span!("listener", addr = field::Empty);
 
@@ -494,7 +495,7 @@ struct TcpStack {
 impl TcpStack {
     async fn new(
         bind_addr: SocketAddr,
-        incoming_tx: mpsc::Sender<(raw::Stream, PeerAddr)>,
+        incoming_tx: mpsc::Sender<(Connection, PeerAddr)>,
     ) -> Option<Self> {
         let span = tracing::info_span!("listener", addr = field::Empty);
 
@@ -539,7 +540,7 @@ impl TcpStack {
     }
 }
 
-async fn run_tcp_listener(listener: TcpListener,
 tx: mpsc::Sender<(raw::Stream, PeerAddr)>) {
+async fn run_tcp_listener(listener: TcpListener, tx: mpsc::Sender<(Connection, PeerAddr)>) {
     loop {
         let result = select! {
             result = listener.accept() => result,
@@ -548,7 +549,7 @@
 
         match result {
             Ok((stream, addr)) => {
-                tx.send((raw::Stream::Tcp(stream), PeerAddr::Tcp(addr)))
+                tx.send((Connection::Tcp(stream), PeerAddr::Tcp(addr)))
                     .await
                     .ok();
             }
@@ -560,7 +561,7 @@
     }
 }
 
-async fn run_quic_listener(listener: quic::Acceptor, tx: mpsc::Sender<(raw::Stream, PeerAddr)>) {
+async fn run_quic_listener(listener: quic::Acceptor, tx: mpsc::Sender<(Connection, PeerAddr)>) {
     let mut tasks = JoinSet::new();
 
     loop {
@@ -577,7 +578,7 @@
             tasks.spawn(async move {
                 match connecting.complete().await {
                     Ok(connection) => {
-                        tx.send((raw::Stream::Quic(connection), PeerAddr::Quic(addr)))
+                        tx.send((Connection::Quic(connection), PeerAddr::Quic(addr)))
                             .await
                             .ok();
                     }
diff --git a/lib/src/network/message_broker.rs b/lib/src/network/message_broker.rs
index 02620c2ed..56bdbda14 100644
--- a/lib/src/network/message_broker.rs
+++ b/lib/src/network/message_broker.rs
@@ -6,7 +6,6 @@ use super::{
     message::{Content, MessageChannelId, Request, Response},
     message_dispatcher::{ContentSink, ContentStream, MessageDispatcher},
     peer_exchange::{PexPeer, PexReceiver, PexRepository, PexSender},
-    raw,
     runtime_id::PublicRuntimeId,
     server::Server,
     stats::{ByteCounters, Instrumented},
@@ -18,6 +17,7 @@ use crate::{
     repository::Vault,
 };
 use backoff::{backoff::Backoff, ExponentialBackoffBuilder};
+use net::connection::Connection;
 use state_monitor::StateMonitor;
 use std::{future, sync::Arc};
 use tokio::{
@@ -67,10 +67,10 @@ impl MessageBroker {
         }
     }
 
-    pub fn add_connection(&self, stream: Instrumented<raw::Stream>, permit: ConnectionPermit) {
+    pub fn add_connection(&self, connection: Instrumented<Connection>, permit: ConnectionPermit) {
         self.pex_peer
             .handle_connection(permit.addr(), permit.source(), permit.released());
-        self.dispatcher.bind(stream, permit)
+        self.dispatcher.bind(connection, permit)
     }
 
     /// Has this broker at least one live connection?
diff --git a/lib/src/network/message_dispatcher.rs b/lib/src/network/message_dispatcher.rs
index fffb29997..d7114a7b8 100644
--- a/lib/src/network/message_dispatcher.rs
+++ b/lib/src/network/message_dispatcher.rs
@@ -4,12 +4,12 @@ use super::{
     connection::{ConnectionId, ConnectionPermit, ConnectionPermitHalf},
     message::{Message, MessageChannelId},
     message_io::{MessageSink, MessageStream, MESSAGE_OVERHEAD},
-    raw,
     stats::Instrumented,
 };
 use crate::{collections::HashMap, sync::AwaitDrop};
 use async_trait::async_trait;
 use futures_util::{future, ready, stream::SelectAll, FutureExt, Sink, SinkExt, Stream, StreamExt};
+use net::connection::{Connection, OwnedReadHalf, OwnedWriteHalf};
 use std::{
     io,
     pin::Pin,
@@ -56,8 +56,10 @@ impl MessageDispatcher {
 
     /// Bind this dispatcher to the given TCP or QUIC socket. Can be bound to multiple sockets and
     /// the failed ones are automatically removed.
-    pub fn bind(&self, socket: Instrumented<raw::Stream>, permit: ConnectionPermit) {
-        self.command_tx.send(Command::Bind { socket, permit }).ok();
+    pub fn bind(&self, connection: Instrumented<Connection>, permit: ConnectionPermit) {
+        self.command_tx
+            .send(Command::Bind { connection, permit })
+            .ok();
     }
 
     /// Is this dispatcher bound to at least one connection?
@@ -250,7 +252,7 @@ pub(super) struct ChannelClosed;
 struct ConnectionStream {
     // The reader is doubly instrumented - first time to track per connection stats and second time
     // to track cumulative stats across all connections.
-    reader: MessageStream<Instrumented<Instrumented<raw::OwnedReadHalf>>>,
+    reader: MessageStream<Instrumented<Instrumented<OwnedReadHalf>>>,
     permit: ConnectionPermitHalf,
     permit_released: AwaitDrop,
     connection_count: Arc<AtomicUsize>,
@@ -258,7 +260,7 @@ struct ConnectionStream {
 
 impl ConnectionStream {
     fn new(
-        reader: Instrumented<raw::OwnedReadHalf>,
+        reader: Instrumented<OwnedReadHalf>,
         permit: ConnectionPermitHalf,
         connection_count: Arc<AtomicUsize>,
     ) -> Self {
@@ -305,13 +307,13 @@ impl Drop for ConnectionStream {
 struct ConnectionSink {
     // The writer is doubly instrumented - first time to track per connection stats and second time
     // to track cumulative stats across all connections.
-    writer: MessageSink<Instrumented<Instrumented<raw::OwnedWriteHalf>>>,
+    writer: MessageSink<Instrumented<Instrumented<OwnedWriteHalf>>>,
     _permit: ConnectionPermitHalf,
     permit_released: AwaitDrop,
 }
 
 impl ConnectionSink {
-    fn new(writer: Instrumented<raw::OwnedWriteHalf>, permit: ConnectionPermitHalf) -> Self {
+    fn new(writer: Instrumented<OwnedWriteHalf>, permit: ConnectionPermitHalf) -> Self {
         let permit_released = permit.released();
 
         Self {
@@ -407,8 +409,8 @@ impl Worker {
             Command::Close { channel } => {
                 self.recv.channels.remove(&channel);
             }
-            Command::Bind { socket, permit } => {
-                let (reader, writer) = socket.into_split();
+            Command::Bind { connection, permit } => {
+                let (reader, writer) = connection.into_split();
                 let (send_permit, recv_permit) = permit.into_split();
 
                 self.send
@@ -453,7 +455,7 @@ enum Command {
         channel: MessageChannelId,
     },
     Bind {
-        socket: Instrumented<raw::Stream>,
+        connection: Instrumented<Connection>,
         permit: ConnectionPermit,
     },
     Shutdown {
@@ -548,7 +550,10 @@ mod tests {
     use super::{super::stats::ByteCounters, *};
     use assert_matches::assert_matches;
     use futures_util::stream;
-    use net::tcp::{TcpListener, TcpStream};
+    use net::{
+        connection::Connection,
+        tcp::{TcpListener, TcpStream},
+    };
     use std::{collections::BTreeSet, net::Ipv4Addr, str::from_utf8, time::Duration};
 
     #[tokio::test(flavor = "multi_thread")]
@@ -559,9 +564,9 @@ mod tests {
         let server_dispatcher = MessageDispatcher::new();
         let mut server_stream = server_dispatcher.open_recv(channel);
 
-        let (client_socket, server_socket) = create_connected_sockets().await;
-        let mut client_sink = MessageSink::new(client_socket);
-        server_dispatcher.bind(server_socket, ConnectionPermit::dummy());
+        let (client, server) = create_connection_pair().await;
+        let mut client_sink = MessageSink::new(client);
+        server_dispatcher.bind(server, ConnectionPermit::dummy());
 
         client_sink
             .send(Message {
@@ -587,9 +592,9 @@ mod tests {
         let server_stream0 = server_dispatcher.open_recv(channel0);
         let server_stream1 = server_dispatcher.open_recv(channel1);
 
-        let (client_socket, server_socket) = create_connected_sockets().await;
-        let mut client_sink = MessageSink::new(client_socket);
-        server_dispatcher.bind(server_socket, ConnectionPermit::dummy());
+        let (client, server) = create_connection_pair().await;
+        let mut client_sink = MessageSink::new(client);
+        server_dispatcher.bind(server, ConnectionPermit::dummy());
 
         for (channel, content) in [(channel0, send_content0), (channel1, send_content1)] {
             client_sink
@@ -625,9 +630,9 @@ mod tests {
         let server_stream0 = server_dispatcher.open_recv(channel0);
         let server_stream1 = server_dispatcher.open_recv(channel1);
 
-        let (client_socket, server_socket) = create_connected_sockets().await;
-        client_dispatcher.bind(client_socket, ConnectionPermit::dummy());
-        server_dispatcher.bind(server_socket, ConnectionPermit::dummy());
+        let (client, server) = create_connection_pair().await;
+        client_dispatcher.bind(client, ConnectionPermit::dummy());
+        server_dispatcher.bind(server, ConnectionPermit::dummy());
 
         let num_messages = 20;
         let mut send_tasks = vec![];
@@ -671,9 +676,9 @@ mod tests {
         let mut server_stream0 = server_dispatcher.open_recv(channel);
         let mut server_stream1 = server_dispatcher.open_recv(channel);
 
-        let (client_socket, server_socket) = create_connected_sockets().await;
-        let mut client_sink = MessageSink::new(client_socket);
-        server_dispatcher.bind(server_socket, ConnectionPermit::dummy());
+        let (client, server) = create_connection_pair().await;
+        let mut client_sink = MessageSink::new(client);
+        server_dispatcher.bind(server, ConnectionPermit::dummy());
 
         for content in [send_content0, send_content1] {
             client_sink
@@ -703,8 +708,8 @@ mod tests {
         let server_dispatcher = MessageDispatcher::new();
         let mut server_stream = server_dispatcher.open_recv(channel);
 
-        let (client_socket0, server_socket0) = create_connected_sockets().await;
-        let (client_socket1, server_socket1) = create_connected_sockets().await;
+        let (client_socket0, server_socket0) = create_connection_pair().await;
+        let (client_socket1, server_socket1) = create_connection_pair().await;
 
         let client_sink0 = MessageSink::new(client_socket0);
         let client_sink1 = MessageSink::new(client_socket1);
@@ -754,8 +759,8 @@ mod tests {
         let server_dispatcher = MessageDispatcher::new();
         let server_sink = server_dispatcher.open_send(channel);
 
-        let (client_socket0, server_socket0) = create_connected_sockets().await;
-        let (client_socket1, server_socket1) = create_connected_sockets().await;
+        let (client_socket0, server_socket0) = create_connection_pair().await;
+        let (client_socket1, server_socket1) = create_connection_pair().await;
 
         let client_stream0 = MessageStream::new(client_socket0);
         let client_stream1 = MessageStream::new(client_socket1);
@@ -798,7 +803,7 @@ mod tests {
         assert_matches!(server_sink.send(vec![]).await, Err(ChannelClosed));
     }
 
-    async fn create_connected_sockets() -> (Instrumented<raw::Stream>, Instrumented<raw::Stream>) {
+    async fn create_connection_pair() -> (Instrumented<Connection>, Instrumented<Connection>) {
         let listener = TcpListener::bind((Ipv4Addr::LOCALHOST, 0u16))
             .await
             .unwrap();
@@ -808,8 +813,8 @@ mod tests {
         let (server, _) = listener.accept().await.unwrap();
 
         (
-            Instrumented::new(raw::Stream::Tcp(client), Arc::new(ByteCounters::default())),
-            Instrumented::new(raw::Stream::Tcp(server), Arc::new(ByteCounters::default())),
+            Instrumented::new(Connection::Tcp(client), Arc::new(ByteCounters::default())),
+            Instrumented::new(Connection::Tcp(server), Arc::new(ByteCounters::default())),
         )
     }
 }
diff --git a/lib/src/network/mod.rs b/lib/src/network/mod.rs
index 71dbc9f28..0db3ad3b4 100644
--- a/lib/src/network/mod.rs
+++ b/lib/src/network/mod.rs
@@ -20,7 +20,6 @@ mod peer_source;
 mod peer_state;
 mod pending;
 mod protocol;
-mod raw;
 mod runtime_id;
 mod seen_peers;
 mod server;
@@ -41,6 +40,7 @@ pub use self::{
     runtime_id::{PublicRuntimeId, SecretRuntimeId},
     stats::Stats,
 };
+use net::connection::Connection;
 pub use net::stun::NatBehavior;
 
 use self::{
@@ -718,9 +718,9 @@ impl Inner {
 
     async fn handle_incoming_connections(
         self: Arc<Self>,
-        mut rx: mpsc::Receiver<(raw::Stream, PeerAddr)>,
+        mut rx: mpsc::Receiver<(Connection, PeerAddr)>,
     ) {
-        while let Some((stream, addr)) = rx.recv().await {
+        while let Some((connection, addr)) = rx.recv().await {
             match self.connections.reserve(addr, PeerSource::Listener) {
                 ReserveResult::Permit(permit) => {
                     if self.is_shutdown() {
@@ -739,7 +739,7 @@ impl Inner {
                     monitor.mark_as_connecting(permit.id());
 
                     self.spawn(async move {
-                        this.handle_connection(stream, permit, &monitor).await;
+                        this.handle_connection(connection, permit, &monitor).await;
                     });
                 }
                 ReserveResult::Occupied(_, _their_source, permit_id) => {
@@ -840,7 +840,7 @@ impl Inner {
     /// Return true iff the peer is suitable for reconnection.
     async fn handle_connection(
         &self,
-        mut stream: raw::Stream,
+        mut connection: Connection,
         permit: ConnectionPermit,
         monitor: &ConnectionMonitor,
     ) -> bool {
@@ -849,7 +849,8 @@ impl Inner {
         permit.mark_as_handshaking();
         monitor.mark_as_handshaking();
 
-        let handshake_result = perform_handshake(&mut stream, VERSION, &self.this_runtime_id).await;
+        let handshake_result =
+            perform_handshake(&mut connection, VERSION, &self.this_runtime_id).await;
 
         if let Err(error) = &handshake_result {
             tracing::debug!(parent: monitor.span(), ?error, "Handshake failed");
@@ -915,8 +916,8 @@ impl Inner {
                 broker
             });
 
-            let stream = Instrumented::new(stream, self.stats_tracker.bytes.clone());
-            broker.add_connection(stream, permit);
+            let connection = Instrumented::new(connection, self.stats_tracker.bytes.clone());
+            broker.add_connection(connection, permit);
         }
 
         let _remover = MessageBrokerEntryGuard {
@@ -962,28 +963,28 @@ impl Inner {
 
 // Exchange runtime ids with the peer. Returns their (verified) runtime id.
 async fn perform_handshake(
-    stream: &mut raw::Stream,
+    connection: &mut Connection,
     this_version: Version,
     this_runtime_id: &SecretRuntimeId,
 ) -> Result<PublicRuntimeId, HandshakeError> {
     let result = tokio::time::timeout(std::time::Duration::from_secs(5), async move {
-        stream.write_all(MAGIC).await?;
+        connection.write_all(MAGIC).await?;
 
-        this_version.write_into(stream).await?;
+        this_version.write_into(connection).await?;
 
         let mut that_magic = [0; MAGIC.len()];
-        stream.read_exact(&mut that_magic).await?;
+        connection.read_exact(&mut that_magic).await?;
 
         if MAGIC != &that_magic {
             return Err(HandshakeError::BadMagic);
         }
 
-        let that_version = Version::read_from(stream).await?;
+        let that_version = Version::read_from(connection).await?;
         if that_version > this_version {
             return Err(HandshakeError::ProtocolVersionMismatch(that_version));
         }
 
-        let that_runtime_id = runtime_id::exchange(this_runtime_id, stream).await?;
+        let that_runtime_id = runtime_id::exchange(this_runtime_id, connection).await?;
 
         Ok(that_runtime_id)
     })
diff --git a/lib/src/network/stats.rs b/lib/src/network/stats.rs
index 74286586f..80441a7fb 100644
--- a/lib/src/network/stats.rs
+++ b/lib/src/network/stats.rs
@@ -1,4 +1,4 @@
-use super::raw;
+use net::connection::{Connection, OwnedReadHalf, OwnedWriteHalf};
 use pin_project_lite::pin_project;
 use serde::{Deserialize, Serialize};
 use std::{
@@ -223,13 +223,8 @@ where
     }
 }
 
-impl Instrumented<raw::Stream> {
-    pub fn into_split(
-        self,
-    ) -> (
-        Instrumented<raw::OwnedReadHalf>,
-        Instrumented<raw::OwnedWriteHalf>,
-    ) {
+impl Instrumented<Connection> {
+    pub fn into_split(self) -> (Instrumented<OwnedReadHalf>, Instrumented<OwnedWriteHalf>) {
         let (reader, writer) = self.inner.into_split();
 
         (
diff --git a/lib/src/network/raw.rs b/net/src/connection.rs
similarity index 68%
rename from lib/src/network/raw.rs
rename to net/src/connection.rs
index e742545e4..57d6b44d5 100644
--- a/lib/src/network/raw.rs
+++ b/net/src/connection.rs
@@ -1,7 +1,4 @@
-use net::{
-    quic,
-    tcp::{self, TcpStream},
-};
+use crate::{quic, tcp};
 use std::{
     io,
     pin::Pin,
@@ -9,48 +6,49 @@
 };
 use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
 
-pub enum Stream {
-    Tcp(TcpStream),
+/// Network connection that supports multiple protocols.
+pub enum Connection { + Tcp(tcp::TcpStream), Quic(quic::Connection), } -impl Stream { +impl Connection { pub fn into_split(self) -> (OwnedReadHalf, OwnedWriteHalf) { match self { - Stream::Tcp(con) => { - let (rx, tx) = con.into_split(); - (OwnedReadHalf::Tcp(rx), OwnedWriteHalf::Tcp(tx)) + Self::Tcp(conn) => { + let (reader, writer) = conn.into_split(); + (OwnedReadHalf::Tcp(reader), OwnedWriteHalf::Tcp(writer)) } - Stream::Quic(con) => { - let (rx, tx) = con.into_split(); - (OwnedReadHalf::Quic(rx), OwnedWriteHalf::Quic(tx)) + Self::Quic(conn) => { + let (reader, writer) = conn.into_split(); + (OwnedReadHalf::Quic(reader), OwnedWriteHalf::Quic(writer)) } } } } -impl AsyncRead for Stream { +impl AsyncRead for Connection { fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll> { match self.get_mut() { - Stream::Tcp(s) => Pin::new(s).poll_read(cx, buf), - Stream::Quic(s) => Pin::new(s).poll_read(cx, buf), + Self::Tcp(s) => Pin::new(s).poll_read(cx, buf), + Self::Quic(s) => Pin::new(s).poll_read(cx, buf), } } } -impl AsyncWrite for Stream { +impl AsyncWrite for Connection { fn poll_write( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { match self.get_mut() { - Stream::Tcp(s) => Pin::new(s).poll_write(cx, buf), - Stream::Quic(s) => Pin::new(s).poll_write(cx, buf), + Self::Tcp(s) => Pin::new(s).poll_write(cx, buf), + Self::Quic(s) => Pin::new(s).poll_write(cx, buf), } } @@ -60,29 +58,29 @@ impl AsyncWrite for Stream { bufs: &[io::IoSlice<'_>], ) -> Poll> { match self.get_mut() { - Stream::Tcp(s) => Pin::new(s).poll_write_vectored(cx, bufs), - Stream::Quic(s) => Pin::new(s).poll_write_vectored(cx, bufs), + Self::Tcp(s) => Pin::new(s).poll_write_vectored(cx, bufs), + Self::Quic(s) => Pin::new(s).poll_write_vectored(cx, bufs), } } fn is_write_vectored(&self) -> bool { match self { - Stream::Tcp(s) => s.is_write_vectored(), - Stream::Quic(s) => s.is_write_vectored(), + Self::Tcp(s) => s.is_write_vectored(), + Self::Quic(s) => s.is_write_vectored(), } } fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match self.get_mut() { - Stream::Tcp(s) => Pin::new(s).poll_flush(cx), - Stream::Quic(s) => Pin::new(s).poll_flush(cx), + Self::Tcp(s) => Pin::new(s).poll_flush(cx), + Self::Quic(s) => Pin::new(s).poll_flush(cx), } } fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match self.get_mut() { - Stream::Tcp(s) => Pin::new(s).poll_shutdown(cx), - Stream::Quic(s) => Pin::new(s).poll_shutdown(cx), + Self::Tcp(s) => Pin::new(s).poll_shutdown(cx), + Self::Quic(s) => Pin::new(s).poll_shutdown(cx), } } } @@ -99,8 +97,8 @@ impl AsyncRead for OwnedReadHalf { buf: &mut ReadBuf<'_>, ) -> Poll> { match self.get_mut() { - OwnedReadHalf::Tcp(rx) => Pin::new(rx).poll_read(cx, buf), - OwnedReadHalf::Quic(rx) => Pin::new(rx).poll_read(cx, buf), + Self::Tcp(rx) => Pin::new(rx).poll_read(cx, buf), + Self::Quic(rx) => Pin::new(rx).poll_read(cx, buf), } } } diff --git a/net/src/lib.rs b/net/src/lib.rs index bb4b94659..bee42acfc 100644 --- a/net/src/lib.rs +++ b/net/src/lib.rs @@ -1,5 +1,6 @@ use std::time::Duration; +pub mod connection; pub mod quic; pub mod stun; pub mod tcp; diff --git a/net/src/quic.rs b/net/src/quic.rs index 4b08fa1a5..6d564d539 100644 --- a/net/src/quic.rs +++ b/net/src/quic.rs @@ -742,7 +742,7 @@ mod tests { quinn::ReadError::ConnectionLost( quinn::ConnectionError::ApplicationClosed(_), ) => { - // connection graceflly closed by the peer, this is expected. 
+ // connection gracefully closed by the peer, this is expected. } error => panic!("unexpected error: {:?}", error), }, From 10f601194e8b432d7aa9733dc291897e4ea5929f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Cig=C3=A1nek?= Date: Tue, 3 Sep 2024 12:02:32 +0200 Subject: [PATCH 05/55] Simplify quic Connection/OwnedReadHalf/OwnedWriteHalf implementation --- net/src/quic.rs | 188 +++++++++--------------------------------------- 1 file changed, 34 insertions(+), 154 deletions(-) diff --git a/net/src/quic.rs b/net/src/quic.rs index 6d564d539..9c718f960 100644 --- a/net/src/quic.rs +++ b/net/src/quic.rs @@ -17,10 +17,7 @@ use std::{ io, net::SocketAddr, pin::Pin, - sync::{ - atomic::{AtomicBool, Ordering}, - Arc, - }, + sync::Arc, task::{Context, Poll}, }; use tokio::{ @@ -89,19 +86,21 @@ impl Connecting { //------------------------------------------------------------------------------ pub struct Connection { - rx: Option, - tx: Option, + reader: quinn::RecvStream, + writer: quinn::SendStream, remote_addr: SocketAddr, - can_finish: bool, } impl Connection { - pub fn new(rx: quinn::RecvStream, tx: quinn::SendStream, remote_addr: SocketAddr) -> Self { + pub fn new( + reader: quinn::RecvStream, + writer: quinn::SendStream, + remote_addr: SocketAddr, + ) -> Self { Self { - rx: Some(rx), - tx: Some(tx), + reader, + writer, remote_addr, - can_finish: true, } } @@ -109,191 +108,76 @@ impl Connection { &self.remote_addr } - pub fn into_split(mut self) -> (OwnedReadHalf, OwnedWriteHalf) { - // Unwrap OK because `self` can't be split more than once and we're not `taking` from `rx` - // anywhere else. - let rx = self.rx.take().unwrap(); - let tx = self.tx.take(); - let can_finish = Arc::new(AtomicBool::new(self.can_finish)); + pub fn into_split(self) -> (OwnedReadHalf, OwnedWriteHalf) { ( - OwnedReadHalf { - rx, - can_finish: can_finish.clone(), - }, - OwnedWriteHalf { tx, can_finish }, + OwnedReadHalf { inner: self.reader }, + OwnedWriteHalf { inner: self.writer }, ) } } impl AsyncRead for Connection { fn poll_read( - self: Pin<&mut Self>, + mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll> { - let this = self.get_mut(); - match &mut this.rx { - Some(rx) => { - let poll = Pin::new(rx).poll_read(cx, buf); - if let Poll::Ready(r) = &poll { - this.can_finish &= r.is_ok(); - }; - poll - } - None => Poll::Ready(Err(io::Error::new( - io::ErrorKind::BrokenPipe, - "connection was split", - ))), - } + AsyncRead::poll_read(Pin::new(&mut self.reader), cx, buf) } } impl AsyncWrite for Connection { fn poll_write( - self: Pin<&mut Self>, + mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { - let this = self.get_mut(); - match &mut this.tx { - Some(tx) => Pin::new(tx) - .poll_write(cx, buf) - .map_err(|error| error.into()), - None => Poll::Ready(Err(io::Error::new( - io::ErrorKind::BrokenPipe, - "already finished", - ))), - } - } - - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - match &mut self.get_mut().tx { - Some(tx) => Pin::new(tx).poll_flush(cx), - None => Poll::Ready(Err(io::Error::new( - io::ErrorKind::BrokenPipe, - "already finished", - ))), - } + AsyncWrite::poll_write(Pin::new(&mut self.writer), cx, buf) } - fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - match &mut self.get_mut().tx { - Some(tx) => Pin::new(tx).poll_shutdown(cx), - None => Poll::Ready(Err(io::Error::new( - io::ErrorKind::BrokenPipe, - "already finished", - ))), - } + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + 
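+        // Flushing, like the other methods in this impl, delegates straight to
+        // the underlying `quinn::SendStream` stored in `self.writer`.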
AsyncWrite::poll_flush(Pin::new(&mut self.writer), cx) } -} - -impl Drop for Connection { - fn drop(&mut self) { - if !self.can_finish { - return; - } - if let Some(mut tx) = self.tx.take() { - tx.finish().ok(); - } + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + AsyncWrite::poll_shutdown(Pin::new(&mut self.writer), cx) } } //------------------------------------------------------------------------------ pub struct OwnedReadHalf { - rx: quinn::RecvStream, - can_finish: Arc, + inner: quinn::RecvStream, } + pub struct OwnedWriteHalf { - tx: Option, - can_finish: Arc, + inner: quinn::SendStream, } impl AsyncRead for OwnedReadHalf { fn poll_read( - self: Pin<&mut Self>, + mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll> { - let this = self.get_mut(); - let poll = Pin::new(&mut this.rx).poll_read(cx, buf); - if let Poll::Ready(r) = &poll { - if r.is_err() { - this.can_finish.store(false, Ordering::SeqCst); - } - } - poll + AsyncRead::poll_read(Pin::new(&mut self.inner), cx, buf) } } impl AsyncWrite for OwnedWriteHalf { fn poll_write( - self: Pin<&mut Self>, + mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { - let this = self.get_mut(); - match &mut this.tx { - Some(tx) => match Pin::new(tx).poll_write(cx, buf) { - Poll::Ready(Ok(n)) => Poll::Ready(Ok(n)), - Poll::Ready(Err(error)) => { - this.can_finish.store(false, Ordering::SeqCst); - Poll::Ready(Err(error.into())) - } - Poll::Pending => Poll::Pending, - }, - None => Poll::Ready(Err(io::Error::new( - io::ErrorKind::BrokenPipe, - "already finished", - ))), - } - } - - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - let this = self.get_mut(); - match &mut this.tx { - Some(tx) => { - let poll = Pin::new(tx).poll_flush(cx); - - if let Poll::Ready(r) = &poll { - if r.is_err() { - this.can_finish.store(false, Ordering::SeqCst); - } - } - - poll - } - None => Poll::Ready(Err(io::Error::new( - io::ErrorKind::BrokenPipe, - "already finished", - ))), - } + AsyncWrite::poll_write(Pin::new(&mut self.inner), cx, buf) } - fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - let this = self.get_mut(); - match &mut this.tx { - Some(tx) => { - this.can_finish.store(false, Ordering::SeqCst); - Pin::new(tx).poll_shutdown(cx) - } - None => Poll::Ready(Err(io::Error::new( - io::ErrorKind::BrokenPipe, - "already finished", - ))), - } + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + AsyncWrite::poll_flush(Pin::new(&mut self.inner), cx) } -} -impl Drop for OwnedWriteHalf { - fn drop(&mut self) { - if !self.can_finish.load(Ordering::SeqCst) { - return; - } - - if let Some(mut tx) = self.tx.take() { - tx.finish().ok(); - } + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + AsyncWrite::poll_shutdown(Pin::new(&mut self.inner), cx) } } @@ -490,10 +374,7 @@ impl CustomUdpSocket { } fn side_channel_maker(self: Arc) -> SideChannelMaker { - SideChannelMaker { - socket: self.clone(), - packet_tx: self.side_channel_tx.clone(), - } + SideChannelMaker { socket: self } } } @@ -628,14 +509,13 @@ impl fmt::Debug for UdpPollHelper { /// Makes new `SideChannel`s. 
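 ///
 /// Side channels piggyback on the QUIC endpoint's UDP socket, so non-QUIC
 /// datagrams (used e.g. by the gateway for hole punching) can be exchanged on
 /// the same port as the QUIC traffic.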
pub struct SideChannelMaker { socket: Arc, - packet_tx: broadcast::Sender, } impl SideChannelMaker { pub fn make(&self) -> SideChannel { SideChannel { socket: self.socket.clone(), - packet_rx: AsyncMutex::new(self.packet_tx.subscribe()), + packet_rx: AsyncMutex::new(self.socket.side_channel_tx.subscribe()), } } } From 37193e6f7bce5d1243230fa538bc51e13dd3c009 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Cig=C3=A1nek?= Date: Tue, 3 Sep 2024 14:32:58 +0200 Subject: [PATCH 06/55] Make the TCP API more similar to the QUIC API --- lib/src/network/gateway.rs | 115 ++++++++++++-------------- lib/src/network/local_discovery.rs | 2 +- lib/src/network/message_dispatcher.rs | 18 ++-- lib/src/network/mod.rs | 2 +- net/Cargo.toml | 1 + net/examples/peer.rs | 10 +-- net/examples/stun.rs | 10 ++- net/src/quic.rs | 14 ++-- net/src/tcp.rs | 50 ++++++++++- net/src/udp.rs | 4 +- 10 files changed, 130 insertions(+), 96 deletions(-) diff --git a/lib/src/network/gateway.rs b/lib/src/network/gateway.rs index 7705033ac..39288a5e1 100644 --- a/lib/src/network/gateway.rs +++ b/lib/src/network/gateway.rs @@ -1,11 +1,7 @@ use super::{ip, peer_addr::PeerAddr, peer_source::PeerSource, seen_peers::SeenPeer}; use crate::sync::atomic_slot::AtomicSlot; use backoff::{backoff::Backoff, ExponentialBackoffBuilder}; -use net::{ - connection::Connection, - quic, - tcp::{TcpListener, TcpStream}, -}; +use net::{connection::Connection, quic, tcp}; use scoped_task::ScopedJoinHandle; use std::net::{IpAddr, SocketAddr}; use thiserror::Error; @@ -65,7 +61,7 @@ impl Gateway { } /// Binds the gateway to the specified addresses. Rebinds if already bound. - pub async fn bind( + pub fn bind( &self, bind: &StackAddresses, ) -> ( @@ -73,7 +69,7 @@ impl Gateway { Option, ) { let (next, side_channel_maker_v4, side_channel_maker_v6) = - Stacks::bind(bind, self.incoming_tx.clone()).await; + Stacks::bind(bind, self.incoming_tx.clone()); let prev = self.stacks.swap(next); let next = self.stacks.read(); @@ -226,11 +222,11 @@ impl Gateway { #[derive(Debug, Error)] pub(super) enum ConnectError { #[error("TCP error")] - Tcp(std::io::Error), + Tcp(tcp::Error), #[error("QUIC error")] Quic(quic::Error), - #[error("No corresponding QUIC connector")] - NoSuitableQuicConnector, + #[error("No corresponding connector")] + NoSuitableConnector, } impl ConnectError { @@ -261,7 +257,7 @@ impl Stacks { } } - async fn bind( + fn bind( bind: &StackAddresses, incoming_tx: mpsc::Sender<(Connection, PeerAddr)>, ) -> ( @@ -271,7 +267,6 @@ impl Stacks { ) { let (quic_v4, side_channel_maker_v4) = if let Some(addr) = bind.quic_v4 { QuicStack::new(addr, incoming_tx.clone()) - .await .map(|(stack, side_channel)| (Some(stack), Some(side_channel))) .unwrap_or((None, None)) } else { @@ -280,7 +275,6 @@ impl Stacks { let (quic_v6, side_channel_maker_v6) = if let Some(addr) = bind.quic_v6 { QuicStack::new(addr, incoming_tx.clone()) - .await .map(|(stack, side_channel)| (Some(stack), Some(side_channel))) .unwrap_or((None, None)) } else { @@ -288,13 +282,13 @@ impl Stacks { }; let tcp_v4 = if let Some(addr) = bind.tcp_v4 { - TcpStack::new(addr, incoming_tx.clone()).await + TcpStack::new(addr, incoming_tx.clone()) } else { None }; let tcp_v6 = if let Some(addr) = bind.tcp_v6 { - TcpStack::new(addr, incoming_tx).await + TcpStack::new(addr, incoming_tx) } else { None }; @@ -340,22 +334,23 @@ impl Stacks { async fn connect(&self, addr: PeerAddr) -> Result { match addr { - PeerAddr::Tcp(addr) => TcpStream::connect(addr) + PeerAddr::Tcp(addr) => self + .tcp_stack_for(&addr.ip()) + 
.ok_or(ConnectError::NoSuitableConnector)? + .connector + .connect(addr) .await .map(Connection::Tcp) .map_err(ConnectError::Tcp), - PeerAddr::Quic(addr) => { - let stack = self - .quic_stack_for(&addr.ip()) - .ok_or(ConnectError::NoSuitableQuicConnector)?; - - stack - .connector - .connect(addr) - .await - .map(Connection::Quic) - .map_err(ConnectError::Quic) - } + + PeerAddr::Quic(addr) => self + .quic_stack_for(&addr.ip()) + .ok_or(ConnectError::NoSuitableConnector)? + .connector + .connect(addr) + .await + .map(Connection::Quic) + .map_err(ConnectError::Quic), } } @@ -412,6 +407,13 @@ impl Stacks { Some(task) } + fn tcp_stack_for(&self, ip: &IpAddr) -> Option<&TcpStack> { + match ip { + IpAddr::V4(_) => self.tcp_v4.as_ref(), + IpAddr::V6(_) => self.tcp_v6.as_ref(), + } + } + fn quic_stack_for(&self, ip: &IpAddr) -> Option<&QuicStack> { match ip { IpAddr::V4(_) => self.quic_v4.as_ref(), @@ -438,36 +440,36 @@ struct QuicStack { } impl QuicStack { - async fn new( + fn new( bind_addr: SocketAddr, incoming_tx: mpsc::Sender<(Connection, PeerAddr)>, ) -> Option<(Self, quic::SideChannelMaker)> { - let span = tracing::info_span!("listener", addr = field::Empty); + let span = tracing::info_span!("quic", addr = field::Empty); - let (connector, listener, side_channel_maker) = match quic::configure(bind_addr).await { - Ok((connector, listener, side_channel_maker)) => { + let (connector, acceptor, side_channel_maker) = match quic::configure(bind_addr) { + Ok((connector, acceptor, side_channel_maker)) => { span.record( "addr", - field::display(PeerAddr::Quic(*listener.local_addr())), + field::display(PeerAddr::Quic(*acceptor.local_addr())), ); - tracing::info!(parent: &span, "Listener started"); + tracing::info!(parent: &span, "Stack configured"); - (connector, listener, side_channel_maker) + (connector, acceptor, side_channel_maker) } Err(error) => { tracing::warn!( parent: &span, bind_addr = %PeerAddr::Quic(bind_addr), ?error, - "Failed to start listener" + "Failed to configure stack" ); return None; } }; - let listener_local_addr = *listener.local_addr(); + let listener_local_addr = *acceptor.local_addr(); let listener_task = - scoped_task::spawn(run_quic_listener(listener, incoming_tx).instrument(span)); + scoped_task::spawn(run_quic_listener(acceptor, incoming_tx).instrument(span)); let hole_puncher = side_channel_maker.make().sender(); @@ -490,60 +492,45 @@ impl QuicStack { struct TcpStack { listener_local_addr: SocketAddr, _listener_task: ScopedJoinHandle<()>, + connector: tcp::Connector, } impl TcpStack { - async fn new( + fn new( bind_addr: SocketAddr, incoming_tx: mpsc::Sender<(Connection, PeerAddr)>, ) -> Option { - let span = tracing::info_span!("listener", addr = field::Empty); - - let listener = match TcpListener::bind(bind_addr).await { - Ok(listener) => listener, - Err(error) => { - tracing::warn!( - parent: &span, - bind_addr = %PeerAddr::Tcp(bind_addr), - ?error, - "Failed to start listener", - ); - return None; - } - }; + let span = tracing::info_span!("tcp", addr = field::Empty); - let listener_local_addr = match listener.local_addr() { - Ok(addr) => { - span.record("addr", field::display(PeerAddr::Tcp(addr))); - tracing::info!(parent: &span, "Listener started"); - - addr - } + let (connector, acceptor) = match tcp::configure(bind_addr) { + Ok(stack) => stack, Err(error) => { tracing::warn!( parent: &span, bind_addr = %PeerAddr::Tcp(bind_addr), ?error, - "Failed to get listener local address", + "Failed to configure stack", ); return None; } }; + let listener_local_addr = 
*acceptor.local_addr(); let listener_task = - scoped_task::spawn(run_tcp_listener(listener, incoming_tx).instrument(span)); + scoped_task::spawn(run_tcp_listener(acceptor, incoming_tx).instrument(span)); Some(Self { listener_local_addr, _listener_task: listener_task, + connector, }) } } -async fn run_tcp_listener(listener: TcpListener, tx: mpsc::Sender<(Connection, PeerAddr)>) { +async fn run_tcp_listener(acceptor: tcp::Acceptor, tx: mpsc::Sender<(Connection, PeerAddr)>) { loop { let result = select! { - result = listener.accept() => result, + result = acceptor.accept() => result, _ = tx.closed() => break, }; diff --git a/lib/src/network/local_discovery.rs b/lib/src/network/local_discovery.rs index 3092f9703..a33ecba77 100644 --- a/lib/src/network/local_discovery.rs +++ b/lib/src/network/local_discovery.rs @@ -401,7 +401,7 @@ impl SocketProvider { let mut last_error: Option = None; let socket = loop { - match UdpSocket::bind_multicast(self.interface).await { + match UdpSocket::bind_multicast(self.interface) { Ok(socket) => break Arc::new(socket), Err(error) => { if last_error != Some(error.kind()) { diff --git a/lib/src/network/message_dispatcher.rs b/lib/src/network/message_dispatcher.rs index d7114a7b8..775836564 100644 --- a/lib/src/network/message_dispatcher.rs +++ b/lib/src/network/message_dispatcher.rs @@ -550,10 +550,7 @@ mod tests { use super::{super::stats::ByteCounters, *}; use assert_matches::assert_matches; use futures_util::stream; - use net::{ - connection::Connection, - tcp::{TcpListener, TcpStream}, - }; + use net::{connection::Connection, tcp}; use std::{collections::BTreeSet, net::Ipv4Addr, str::from_utf8, time::Duration}; #[tokio::test(flavor = "multi_thread")] @@ -804,13 +801,16 @@ mod tests { } async fn create_connection_pair() -> (Instrumented, Instrumented) { - let listener = TcpListener::bind((Ipv4Addr::LOCALHOST, 0u16)) - .await - .unwrap(); - let client = TcpStream::connect(listener.local_addr().unwrap()) + let (_server_connector, server_acceptor) = + tcp::configure((Ipv4Addr::LOCALHOST, 0u16).into()).unwrap(); + let (client_connector, _client_acceptor) = + tcp::configure((Ipv4Addr::LOCALHOST, 0u16).into()).unwrap(); + + let client = client_connector + .connect(*server_acceptor.local_addr()) .await .unwrap(); - let (server, _) = listener.accept().await.unwrap(); + let (server, _) = server_acceptor.accept().await.unwrap(); ( Instrumented::new(Connection::Tcp(client), Arc::new(ByteCounters::default())), diff --git a/lib/src/network/mod.rs b/lib/src/network/mod.rs index 0db3ad3b4..b0e22b004 100644 --- a/lib/src/network/mod.rs +++ b/lib/src/network/mod.rs @@ -553,7 +553,7 @@ impl Inner { } // Gateway - let side_channel_makers = self.gateway.bind(&bind).instrument(self.span.clone()).await; + let side_channel_makers = self.span.in_scope(|| self.gateway.bind(&bind)); let conn = self.gateway.connectivity(); diff --git a/net/Cargo.toml b/net/Cargo.toml index fb80d9daf..e352a565a 100644 --- a/net/Cargo.toml +++ b/net/Cargo.toml @@ -20,6 +20,7 @@ stun_codec = "0.3.4" thiserror = "1.0.31" tokio = { workspace = true, features = ["io-util", "macros", "net", "rt-multi-thread", "sync"] } turmoil = { workspace = true, optional = true } +yamux = "0.13.3" [dev-dependencies] anyhow = { workspace = true } diff --git a/net/examples/peer.rs b/net/examples/peer.rs index 34c3d696b..4fefff8b7 100644 --- a/net/examples/peer.rs +++ b/net/examples/peer.rs @@ -2,7 +2,7 @@ use anyhow::{Context, Result}; use clap::{Parser, ValueEnum}; use ouisync_net::{ quic, - tcp::{TcpListener, TcpStream}, + 
tcp::{self, TcpStream}, }; use std::{ future, @@ -81,7 +81,7 @@ async fn run_tcp_client(addr: SocketAddr, count: Option) -> Result<()> { } async fn run_quic_client(addr: SocketAddr, count: Option) -> Result<()> { - let (connector, _, _) = quic::configure((Ipv4Addr::UNSPECIFIED, 0).into()).await?; + let (connector, _, _) = quic::configure((Ipv4Addr::UNSPECIFIED, 0).into())?; let connection = connector.connect(addr).await?; run_client_connection(connection, count).await } @@ -96,8 +96,8 @@ async fn run_server(options: &Options) -> Result<()> { } async fn run_tcp_server(addr: SocketAddr) -> Result<()> { - let acceptor = TcpListener::bind(addr).await?; - println!("bound to {}", acceptor.local_addr()?); + let (_, acceptor) = tcp::configure(addr)?; + println!("bound to {}", acceptor.local_addr()); loop { let (stream, addr) = acceptor.accept().await?; @@ -106,7 +106,7 @@ async fn run_tcp_server(addr: SocketAddr) -> Result<()> { } async fn run_quic_server(addr: SocketAddr) -> Result<()> { - let (_, acceptor, _) = quic::configure(addr).await?; + let (_, acceptor, _) = quic::configure(addr)?; println!("bound to {}", acceptor.local_addr()); loop { diff --git a/net/examples/stun.rs b/net/examples/stun.rs index b4e059076..a86ad6dd8 100644 --- a/net/examples/stun.rs +++ b/net/examples/stun.rs @@ -10,10 +10,12 @@ use tokio::net; async fn main() -> io::Result<()> { let options = Options::parse(); - let client_v4 = - StunClient::new(UdpSocket::bind((Ipv4Addr::UNSPECIFIED, options.port).into()).await?); - let client_v6 = - StunClient::new(UdpSocket::bind((Ipv6Addr::UNSPECIFIED, options.port).into()).await?); + let client_v4 = StunClient::new(UdpSocket::bind( + (Ipv4Addr::UNSPECIFIED, options.port).into(), + )?); + let client_v6 = StunClient::new(UdpSocket::bind( + (Ipv6Addr::UNSPECIFIED, options.port).into(), + )?); for server_name in options.servers { for server_addr in net::lookup_host(server_name).await? 
{ diff --git a/net/src/quic.rs b/net/src/quic.rs index 9c718f960..4e7143ded 100644 --- a/net/src/quic.rs +++ b/net/src/quic.rs @@ -182,11 +182,9 @@ impl AsyncWrite for OwnedWriteHalf { } //------------------------------------------------------------------------------ -pub async fn configure( - bind_addr: SocketAddr, -) -> Result<(Connector, Acceptor, SideChannelMaker), Error> { +pub fn configure(bind_addr: SocketAddr) -> Result<(Connector, Acceptor, SideChannelMaker), Error> { let server_config = make_server_config()?; - let custom_socket = Arc::new(CustomUdpSocket::bind(bind_addr).await?); + let custom_socket = Arc::new(CustomUdpSocket::bind(bind_addr)?); let side_channel_maker = custom_socket.clone().side_channel_maker(); let mut endpoint = quinn::Endpoint::new_with_abstract_socket( @@ -360,8 +358,8 @@ struct CustomUdpSocket { } impl CustomUdpSocket { - async fn bind(addr: SocketAddr) -> io::Result { - let socket = crate::udp::UdpSocket::bind(addr).await?; + fn bind(addr: SocketAddr) -> io::Result { + let socket = crate::udp::UdpSocket::bind(addr)?; let socket = socket.into_std()?; let state = quinn::udp::UdpSocketState::new((&socket).into())?; @@ -596,7 +594,7 @@ mod tests { #[tokio::test(flavor = "multi_thread")] async fn small_data_exchange() { - let (connector, acceptor, _) = configure((Ipv4Addr::LOCALHOST, 0).into()).await.unwrap(); + let (connector, acceptor, _) = configure((Ipv4Addr::LOCALHOST, 0).into()).unwrap(); let addr = *acceptor.local_addr(); @@ -638,7 +636,7 @@ mod tests { #[tokio::test(flavor = "multi_thread")] async fn side_channel() { let (_connector, acceptor, side_channel_maker) = - configure((Ipv4Addr::LOCALHOST, 0).into()).await.unwrap(); + configure((Ipv4Addr::LOCALHOST, 0).into()).unwrap(); let addr = *acceptor.local_addr(); let side_channel = side_channel_maker.make(); diff --git a/net/src/tcp.rs b/net/src/tcp.rs index ac388aecb..315b4aa81 100644 --- a/net/src/tcp.rs +++ b/net/src/tcp.rs @@ -1,4 +1,50 @@ -pub use self::implementation::*; +pub use self::implementation::{OwnedReadHalf, OwnedWriteHalf, TcpStream}; + +use self::implementation::TcpListener; +use std::{io, net::SocketAddr}; + +/// Configure TCP endpoint +pub fn configure(bind_addr: SocketAddr) -> Result<(Connector, Acceptor), Error> { + let listener = TcpListener::bind(bind_addr)?; + let local_addr = listener.local_addr()?; + + Ok(( + Connector, + Acceptor { + listener, + local_addr, + }, + )) +} + +pub struct Connector; + +impl Connector { + pub async fn connect(&self, addr: SocketAddr) -> Result { + Ok(TcpStream::connect(addr).await?) + } +} + +pub struct Acceptor { + listener: TcpListener, + local_addr: SocketAddr, +} + +impl Acceptor { + pub async fn accept(&self) -> Result<(TcpStream, SocketAddr), Error> { + Ok(self.listener.accept().await?) + } + + pub fn local_addr(&self) -> &SocketAddr { + &self.local_addr + } +} + +#[derive(Debug, thiserror::Error)] +pub enum Error { + #[error("IO error")] + Io(#[from] io::Error), +} // Real #[cfg(not(feature = "simulation"))] @@ -20,7 +66,7 @@ mod implementation { impl TcpListener { /// Binds TCP socket to the given address. 
If the port is taken, uses a random one, - pub async fn bind(addr: impl Into) -> io::Result { + pub fn bind(addr: impl Into) -> io::Result { let addr = addr.into(); let socket = Socket::new(Domain::for_address(addr), Type::STREAM, None)?; diff --git a/net/src/udp.rs b/net/src/udp.rs index f84b91ab3..64950a032 100644 --- a/net/src/udp.rs +++ b/net/src/udp.rs @@ -37,7 +37,7 @@ mod implementation { impl UdpSocket { /// Binds UDP socket to the given address. If the port is taken, uses a random one, - pub async fn bind(addr: SocketAddr) -> io::Result { + pub fn bind(addr: SocketAddr) -> io::Result { let socket = Socket::new(Domain::for_address(addr), Type::DGRAM, None)?; socket.set_nonblocking(true)?; // Ignore errors - reuse address is nice to have but not required. @@ -47,7 +47,7 @@ mod implementation { Ok(Self(tokio::net::UdpSocket::from_std(socket.into())?)) } - pub async fn bind_multicast(interface: Ipv4Addr) -> io::Result { + pub fn bind_multicast(interface: Ipv4Addr) -> io::Result { let addr = SocketAddr::from((Ipv4Addr::UNSPECIFIED, MULTICAST_PORT)); let socket = Socket::new(Domain::for_address(addr), Type::DGRAM, None)?; From 8912396fcdb2d3141bf18860ef39ec874823540a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Cig=C3=A1nek?= Date: Tue, 3 Sep 2024 17:56:08 +0200 Subject: [PATCH 07/55] Create unified wrapper for TCP and QUIC --- net/Cargo.toml | 25 ++-- net/examples/peer.rs | 121 ++++++++---------- net/src/connection.rs | 285 +++++++++++++++++++++++++++++------------- net/src/quic.rs | 189 ++++++---------------------- net/src/tcp.rs | 183 ++++++++++++++++++++++++--- 5 files changed, 467 insertions(+), 336 deletions(-) diff --git a/net/Cargo.toml b/net/Cargo.toml index e352a565a..4d671967f 100644 --- a/net/Cargo.toml +++ b/net/Cargo.toml @@ -8,19 +8,20 @@ license.workspace = true version.workspace = true [dependencies] -bytecodec = "0.4.15" -bytes = "1.1.0" -futures-util = { workspace = true } +bytecodec = "0.4.15" +bytes = "1.1.0" +futures-util = { workspace = true } pin-project-lite = { workspace = true } -quinn = "0.11.4" -rand = { package = "ouisync-rand", path = "../rand" } -rcgen = { workspace = true } -socket2 = "0.5.7" # To be able to setsockopts before a socket is bound -stun_codec = "0.3.4" -thiserror = "1.0.31" -tokio = { workspace = true, features = ["io-util", "macros", "net", "rt-multi-thread", "sync"] } -turmoil = { workspace = true, optional = true } -yamux = "0.13.3" +quinn = "0.11.4" +rand = { package = "ouisync-rand", path = "../rand" } +rcgen = { workspace = true } +socket2 = "0.5.7" # To be able to setsockopts before a socket is bound +stun_codec = "0.3.4" +thiserror = "1.0.31" +tokio = { workspace = true, features = ["io-util", "macros", "net", "rt-multi-thread", "sync"] } +tokio-util = { workspace = true, features = ["compat"] } +turmoil = { workspace = true, optional = true } +yamux = "0.13.3" [dev-dependencies] anyhow = { workspace = true } diff --git a/net/examples/peer.rs b/net/examples/peer.rs index 4fefff8b7..6a7a24497 100644 --- a/net/examples/peer.rs +++ b/net/examples/peer.rs @@ -1,8 +1,8 @@ -use anyhow::{Context, Result}; +use anyhow::Result; use clap::{Parser, ValueEnum}; use ouisync_net::{ - quic, - tcp::{self, TcpStream}, + connection::{Acceptor, Connection, Connector, RecvStream, SendStream}, + quic, tcp, }; use std::{ future, @@ -11,7 +11,7 @@ use std::{ time::Duration, }; use tokio::{ - io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}, + io::{AsyncReadExt, AsyncWriteExt}, task, time, }; @@ -68,79 +68,36 @@ enum Proto { async fn 
run_client(options: &Options) -> Result<()> { let addr: SocketAddr = (options.addr.unwrap_or(DEFAULT_CONNECT_ADDR), options.port).into(); + let connector = match options.proto { + Proto::Tcp => { + let (connector, _) = tcp::configure((DEFAULT_BIND_ADDR, 0).into())?; + Connector::Tcp(connector) + } + Proto::Quic => { + let (connector, _, _) = quic::configure((DEFAULT_BIND_ADDR, 0).into())?; + Connector::Quic(connector) + } + }; - match options.proto { - Proto::Tcp => run_tcp_client(addr, options.count).await, - Proto::Quic => run_quic_client(addr, options.count).await, - } -} - -async fn run_tcp_client(addr: SocketAddr, count: Option) -> Result<()> { - let stream = TcpStream::connect(addr).await?; - run_client_connection(stream, count).await -} - -async fn run_quic_client(addr: SocketAddr, count: Option) -> Result<()> { - let (connector, _, _) = quic::configure((Ipv4Addr::UNSPECIFIED, 0).into())?; let connection = connector.connect(addr).await?; - run_client_connection(connection, count).await -} - -async fn run_server(options: &Options) -> Result<()> { - let bind_addr: SocketAddr = (options.addr.unwrap_or(DEFAULT_BIND_ADDR), options.port).into(); - - match options.proto { - Proto::Tcp => run_tcp_server(bind_addr).await, - Proto::Quic => run_quic_server(bind_addr).await, - } -} - -async fn run_tcp_server(addr: SocketAddr) -> Result<()> { - let (_, acceptor) = tcp::configure(addr)?; - println!("bound to {}", acceptor.local_addr()); - - loop { - let (stream, addr) = acceptor.accept().await?; - task::spawn(run_server_connection(stream, addr)); - } -} - -async fn run_quic_server(addr: SocketAddr) -> Result<()> { - let (_, acceptor, _) = quic::configure(addr)?; - println!("bound to {}", acceptor.local_addr()); - - loop { - let connection = acceptor - .accept() - .await - .context("failed to accept")? 
- .complete() - .await?; - let addr = *connection.remote_addr(); - task::spawn(run_server_connection(connection, addr)); - } -} + let (mut tx, mut rx) = connection.outgoing().await?; -async fn run_client_connection( - mut stream: T, - count: Option, -) -> Result<()> { println!("connected"); let message = "hello world"; let mut i = 0; loop { - if count.map(|count| i >= count).unwrap_or(false) { + if options.count.map(|count| i >= count).unwrap_or(false) { break; } i = i.saturating_add(1); println!("sending \"{message}\""); - write_message(&mut stream, message).await?; + write_message(&mut tx, message).await?; - let response = read_message(&mut stream).await?; + let response = read_message(&mut rx).await?; println!("received \"{response}\""); time::sleep(SEND_DELAY).await; @@ -149,11 +106,43 @@ async fn run_client_connection( future::pending().await } -async fn run_server_connection(mut stream: T, addr: SocketAddr) { +async fn run_server(options: &Options) -> Result<()> { + let bind_addr: SocketAddr = (options.addr.unwrap_or(DEFAULT_BIND_ADDR), options.port).into(); + + let acceptor = match options.proto { + Proto::Tcp => { + let (_, acceptor) = tcp::configure(bind_addr)?; + Acceptor::Tcp(acceptor) + } + Proto::Quic => { + let (_, acceptor, _) = quic::configure(bind_addr)?; + Acceptor::Quic(acceptor) + } + }; + + println!("bound to {}", acceptor.local_addr()); + + loop { + let connection = acceptor.accept().await?.await?; + task::spawn(run_server_connection(connection)); + } +} + +async fn run_server_connection(connection: Connection) { + let addr = connection.remote_addr(); + println!("[{}] accepted", addr); + let (mut tx, mut rx) = match connection.incoming().await { + Ok(stream) => stream, + Err(error) => { + println!("[{}] accept stream failed: {}", addr, error); + return; + } + }; + loop { - let message = match read_message(&mut stream).await { + let message = match read_message(&mut rx).await { Ok(message) => message, Err(error) => { println!("[{}] read failed: {}", addr, error); @@ -163,7 +152,7 @@ async fn run_server_connection(mut stream: T, println!("[{}] received \"{}\"", addr, message); - match write_message(&mut stream, "ok").await { + match write_message(&mut tx, "ok").await { Ok(_) => (), Err(error) => { println!("[{}] write failed: {}", addr, error); @@ -175,14 +164,14 @@ async fn run_server_connection(mut stream: T, println!("[{}] closed", addr); } -async fn read_message(reader: &mut T) -> Result { +async fn read_message(reader: &mut RecvStream) -> Result { let size = reader.read_u32().await? as usize; let mut buffer = vec![0; size]; reader.read_exact(&mut buffer).await?; Ok(String::from_utf8(buffer)?) } -async fn write_message(writer: &mut T, message: &str) -> Result<()> { +async fn write_message(writer: &mut SendStream, message: &str) -> Result<()> { writer.write_u32(message.len() as u32).await?; writer.write_all(message.as_bytes()).await?; Ok(()) diff --git a/net/src/connection.rs b/net/src/connection.rs index 57d6b44d5..a4f90aced 100644 --- a/net/src/connection.rs +++ b/net/src/connection.rs @@ -1,122 +1,168 @@ use crate::{quic, tcp}; use std::{ + future::Future, io, + net::SocketAddr, pin::Pin, task::{Context, Poll}, }; +use thiserror::Error; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; -/// Network connection that supports multiple protocols. 
-pub enum Connection { - Tcp(tcp::TcpStream), - Quic(quic::Connection), +/// Unified connector +pub enum Connector { + Tcp(tcp::Connector), + Quic(quic::Connector), } -impl Connection { - pub fn into_split(self) -> (OwnedReadHalf, OwnedWriteHalf) { +impl Connector { + pub async fn connect(&self, addr: SocketAddr) -> Result { match self { - Self::Tcp(conn) => { - let (reader, writer) = conn.into_split(); - (OwnedReadHalf::Tcp(reader), OwnedWriteHalf::Tcp(writer)) - } - Self::Quic(conn) => { - let (reader, writer) = conn.into_split(); - (OwnedReadHalf::Quic(reader), OwnedWriteHalf::Quic(writer)) - } + Self::Tcp(inner) => inner + .connect(addr) + .await + .map(Connection::Tcp) + .map_err(Into::into), + Self::Quic(inner) => inner + .connect(addr) + .await + .map(Connection::Quic) + .map_err(Into::into), } } } -impl AsyncRead for Connection { - fn poll_read( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, - ) -> Poll> { - match self.get_mut() { - Self::Tcp(s) => Pin::new(s).poll_read(cx, buf), - Self::Quic(s) => Pin::new(s).poll_read(cx, buf), +/// Unified acceptor +pub enum Acceptor { + Tcp(tcp::Acceptor), + Quic(quic::Acceptor), +} + +impl Acceptor { + pub fn local_addr(&self) -> &SocketAddr { + match self { + Self::Tcp(inner) => inner.local_addr(), + Self::Quic(inner) => inner.local_addr(), + } + } + + pub async fn accept(&self) -> Result { + match self { + Self::Tcp(inner) => Ok(Connecting::Tcp(Some(inner.accept().await?))), + Self::Quic(inner) => Ok(Connecting::Quic(inner.accept().await?)), } } } -impl AsyncWrite for Connection { - fn poll_write( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { +/// Incoming connection which being established. +pub enum Connecting { + // Note TCP doesn't support two phase accept so this is already a fully established + // connection. + Tcp(Option), + Quic(quic::Connecting), +} + +impl Future for Connecting { + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { match self.get_mut() { - Self::Tcp(s) => Pin::new(s).poll_write(cx, buf), - Self::Quic(s) => Pin::new(s).poll_write(cx, buf), + Self::Tcp(connection) => Poll::Ready(Ok(Connection::Tcp( + connection.take().expect("future polled after completion"), + ))), + Self::Quic(connecting) => Pin::new(connecting) + .poll(cx) + .map_ok(Connection::Quic) + .map_err(Into::into), } } +} - fn poll_write_vectored( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - bufs: &[io::IoSlice<'_>], - ) -> Poll> { - match self.get_mut() { - Self::Tcp(s) => Pin::new(s).poll_write_vectored(cx, bufs), - Self::Quic(s) => Pin::new(s).poll_write_vectored(cx, bufs), +/// Unified connection. 
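+///
+/// Both variants expose the same stream-oriented API: `incoming` accepts and
+/// `outgoing` opens bidirectional `(SendStream, RecvStream)` pairs, carried
+/// over yamux for TCP and natively for QUIC.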
+pub enum Connection { + Tcp(tcp::Connection), + Quic(quic::Connection), +} + +impl Connection { + pub fn remote_addr(&self) -> SocketAddr { + match self { + Self::Tcp(inner) => inner.remote_addr(), + Self::Quic(inner) => inner.remote_addr(), } } - fn is_write_vectored(&self) -> bool { + /// Accept a new incoming stream + pub async fn incoming(&self) -> Result<(SendStream, RecvStream), Error> { match self { - Self::Tcp(s) => s.is_write_vectored(), - Self::Quic(s) => s.is_write_vectored(), + Self::Tcp(inner) => inner + .incoming() + .await + .map(|(send, recv)| (SendStream::Tcp(send), RecvStream::Tcp(recv))) + .map_err(Into::into), + Self::Quic(inner) => inner + .incoming() + .await + .map(|(send, recv)| (SendStream::Quic(send), RecvStream::Quic(recv))) + .map_err(Into::into), } } - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match self.get_mut() { - Self::Tcp(s) => Pin::new(s).poll_flush(cx), - Self::Quic(s) => Pin::new(s).poll_flush(cx), + /// Open a new outgoing stream + pub async fn outgoing(&self) -> Result<(SendStream, RecvStream), Error> { + match self { + Self::Tcp(inner) => inner + .outgoing() + .await + .map(|(send, recv)| (SendStream::Tcp(send), RecvStream::Tcp(recv))) + .map_err(Into::into), + Self::Quic(inner) => inner + .outgoing() + .await + .map(|(send, recv)| (SendStream::Quic(send), RecvStream::Quic(recv))) + .map_err(Into::into), } } - fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match self.get_mut() { - Self::Tcp(s) => Pin::new(s).poll_shutdown(cx), - Self::Quic(s) => Pin::new(s).poll_shutdown(cx), + /// Gracefully close the connection + pub async fn close(&self) -> Result<(), Error> { + match self { + Self::Tcp(inner) => inner.close().await?, + Self::Quic(inner) => inner.close(), } + + Ok(()) } } -pub enum OwnedReadHalf { - Tcp(tcp::OwnedReadHalf), - Quic(quic::OwnedReadHalf), +pub enum SendStream { + Tcp(tcp::SendStream), + Quic(quic::SendStream), } -impl AsyncRead for OwnedReadHalf { - fn poll_read( +impl AsyncWrite for SendStream { + fn poll_write( self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, - ) -> Poll> { + buf: &[u8], + ) -> Poll> { match self.get_mut() { - Self::Tcp(rx) => Pin::new(rx).poll_read(cx, buf), - Self::Quic(rx) => Pin::new(rx).poll_read(cx, buf), + Self::Tcp(inner) => AsyncWrite::poll_write(Pin::new(inner), cx, buf), + Self::Quic(inner) => AsyncWrite::poll_write(Pin::new(inner), cx, buf), } } -} -pub enum OwnedWriteHalf { - Tcp(tcp::OwnedWriteHalf), - Quic(quic::OwnedWriteHalf), -} + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.get_mut() { + Self::Tcp(inner) => AsyncWrite::poll_flush(Pin::new(inner), cx), + Self::Quic(inner) => AsyncWrite::poll_flush(Pin::new(inner), cx), + } + } -impl AsyncWrite for OwnedWriteHalf { - fn poll_write( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match self.get_mut() { - Self::Tcp(s) => Pin::new(s).poll_write(cx, buf), - Self::Quic(s) => Pin::new(s).poll_write(cx, buf), + Self::Tcp(inner) => AsyncWrite::poll_shutdown(Pin::new(inner), cx), + Self::Quic(inner) => AsyncWrite::poll_shutdown(Pin::new(inner), cx), } } @@ -124,31 +170,100 @@ impl AsyncWrite for OwnedWriteHalf { self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[io::IoSlice<'_>], - ) -> Poll> { + ) -> Poll> { match self.get_mut() { - Self::Tcp(s) => Pin::new(s).poll_write_vectored(cx, bufs), - Self::Quic(s) => 
Pin::new(s).poll_write_vectored(cx, bufs), + Self::Tcp(inner) => AsyncWrite::poll_write_vectored(Pin::new(inner), cx, bufs), + Self::Quic(inner) => AsyncWrite::poll_write_vectored(Pin::new(inner), cx, bufs), } } fn is_write_vectored(&self) -> bool { match self { - Self::Tcp(s) => s.is_write_vectored(), - Self::Quic(s) => s.is_write_vectored(), + Self::Tcp(inner) => inner.is_write_vectored(), + Self::Quic(inner) => inner.is_write_vectored(), } } +} + +pub enum RecvStream { + Tcp(tcp::RecvStream), + Quic(quic::RecvStream), +} - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { +impl AsyncRead for RecvStream { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { match self.get_mut() { - Self::Tcp(s) => Pin::new(s).poll_flush(cx), - Self::Quic(s) => Pin::new(s).poll_flush(cx), + Self::Tcp(inner) => AsyncRead::poll_read(Pin::new(inner), cx, buf), + Self::Quic(inner) => AsyncRead::poll_read(Pin::new(inner), cx, buf), } } +} - fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match self.get_mut() { - Self::Tcp(s) => Pin::new(s).poll_shutdown(cx), - Self::Quic(s) => Pin::new(s).poll_shutdown(cx), - } +#[derive(Error, Debug)] +pub enum Error { + #[error("tcp")] + Tcp(#[from] tcp::Error), + #[error("quic")] + Quic(#[from] quic::Error), +} + +#[cfg(test)] +mod tests { + use super::*; + use std::net::Ipv4Addr; + use tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, + task, + }; + + #[tokio::test] + async fn ping_tcp() { + let (client, _) = tcp::configure((Ipv4Addr::LOCALHOST, 0).into()).unwrap(); + let (_, server) = tcp::configure((Ipv4Addr::LOCALHOST, 0).into()).unwrap(); + + ping_case(Connector::Tcp(client), Acceptor::Tcp(server)).await + } + + #[tokio::test] + async fn ping_quic() { + let (client, _, _) = quic::configure((Ipv4Addr::LOCALHOST, 0).into()).unwrap(); + let (_, server, _) = quic::configure((Ipv4Addr::LOCALHOST, 0).into()).unwrap(); + + ping_case(Connector::Quic(client), Acceptor::Quic(server)).await + } + + async fn ping_case(client_connector: Connector, server_acceptor: Acceptor) { + let addr = *server_acceptor.local_addr(); + + let server = task::spawn(async move { + let conn = server_acceptor.accept().await.unwrap().await.unwrap(); + let (mut tx, mut rx) = conn.incoming().await.unwrap(); + + let mut buf = [0; 4]; + rx.read_exact(&mut buf).await.unwrap(); + assert_eq!(&buf, b"ping"); + + tx.write_all(b"pong").await.unwrap(); + }); + + let client = task::spawn(async move { + let conn = client_connector.connect(addr).await.unwrap(); + let (mut tx, mut rx) = conn.outgoing().await.unwrap(); + + tx.write_all(b"ping").await.unwrap(); + + let mut buf = [0; 4]; + + // Ignore error as it likely means the connection was closed by the peer, which is + // expected. 
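+            // (On QUIC this typically surfaces as
+            // `ConnectionLost(ApplicationClosed)` when the server side finishes
+            // and closes first.)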
+ rx.read_exact(&mut buf).await.ok(); + }); + + server.await.unwrap(); + client.await.unwrap(); } } diff --git a/net/src/quic.rs b/net/src/quic.rs index 4e7143ded..e19e5ac1b 100644 --- a/net/src/quic.rs +++ b/net/src/quic.rs @@ -13,19 +13,16 @@ use quinn::{ }; use std::{ fmt, - future::Future, + future::{Future, IntoFuture}, io, net::SocketAddr, pin::Pin, sync::Arc, task::{Context, Poll}, }; -use tokio::{ - io::{AsyncRead, AsyncWrite, ReadBuf}, - sync::{ - broadcast::{self, error::RecvError}, - Mutex as AsyncMutex, - }, +use tokio::sync::{ + broadcast::{self, error::RecvError}, + Mutex as AsyncMutex, }; const CERT_DOMAIN: &str = "ouisync.net"; @@ -37,9 +34,11 @@ pub struct Connector { impl Connector { pub async fn connect(&self, remote_addr: SocketAddr) -> Result { - let connection = self.endpoint.connect(remote_addr, CERT_DOMAIN)?.await?; - let (tx, rx) = connection.open_bi().await?; - Ok(Connection::new(rx, tx, connection.remote_address())) + self.endpoint + .connect(remote_addr, CERT_DOMAIN)? + .await + .map(|inner| Connection { inner }) + .map_err(Into::into) } // forcefully close all connections (any pending operation on any connection will immediatelly @@ -56,11 +55,14 @@ pub struct Acceptor { } impl Acceptor { - pub async fn accept(&self) -> Option { + pub async fn accept(&self) -> Result { self.endpoint .accept() .await - .map(|incoming| Connecting { incoming }) + .map(|inner| Connecting { + inner: inner.into_future(), + }) + .ok_or(Error::EndpointClosed) } pub fn local_addr(&self) -> &SocketAddr { @@ -69,117 +71,44 @@ impl Acceptor { } pub struct Connecting { - incoming: quinn::Incoming, + inner: quinn::IncomingFuture, } -impl Connecting { - pub fn remote_addr(&self) -> SocketAddr { - self.incoming.remote_address() - } +impl Future for Connecting { + type Output = Result; - pub async fn complete(self) -> Result { - let connection = self.incoming.await?; - let (tx, rx) = connection.accept_bi().await?; - Ok(Connection::new(rx, tx, connection.remote_address())) + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + Pin::new(&mut self.inner) + .poll(cx) + .map_ok(|inner| Connection { inner }) + .map_err(Into::into) } } -//------------------------------------------------------------------------------ pub struct Connection { - reader: quinn::RecvStream, - writer: quinn::SendStream, - remote_addr: SocketAddr, + inner: quinn::Connection, } impl Connection { - pub fn new( - reader: quinn::RecvStream, - writer: quinn::SendStream, - remote_addr: SocketAddr, - ) -> Self { - Self { - reader, - writer, - remote_addr, - } - } - - pub fn remote_addr(&self) -> &SocketAddr { - &self.remote_addr - } - - pub fn into_split(self) -> (OwnedReadHalf, OwnedWriteHalf) { - ( - OwnedReadHalf { inner: self.reader }, - OwnedWriteHalf { inner: self.writer }, - ) - } -} - -impl AsyncRead for Connection { - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, - ) -> Poll> { - AsyncRead::poll_read(Pin::new(&mut self.reader), cx, buf) - } -} - -impl AsyncWrite for Connection { - fn poll_write( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - AsyncWrite::poll_write(Pin::new(&mut self.writer), cx, buf) + pub fn remote_addr(&self) -> SocketAddr { + self.inner.remote_address() } - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - AsyncWrite::poll_flush(Pin::new(&mut self.writer), cx) + pub async fn incoming(&self) -> Result<(SendStream, RecvStream), Error> { + self.inner.accept_bi().await.map_err(Into::into) } - 
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - AsyncWrite::poll_shutdown(Pin::new(&mut self.writer), cx) + pub async fn outgoing(&self) -> Result<(SendStream, RecvStream), Error> { + self.inner.open_bi().await.map_err(Into::into) } -} - -//------------------------------------------------------------------------------ -pub struct OwnedReadHalf { - inner: quinn::RecvStream, -} -pub struct OwnedWriteHalf { - inner: quinn::SendStream, -} - -impl AsyncRead for OwnedReadHalf { - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, - ) -> Poll> { - AsyncRead::poll_read(Pin::new(&mut self.inner), cx, buf) + pub fn close(&self) { + self.inner.close(0u8.into(), &[]); } } -impl AsyncWrite for OwnedWriteHalf { - fn poll_write( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - AsyncWrite::poll_write(Pin::new(&mut self.inner), cx, buf) - } - - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - AsyncWrite::poll_flush(Pin::new(&mut self.inner), cx) - } - - fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - AsyncWrite::poll_shutdown(Pin::new(&mut self.inner), cx) - } -} +pub type SendStream = quinn::SendStream; +pub type RecvStream = quinn::RecvStream; //------------------------------------------------------------------------------ pub fn configure(bind_addr: SocketAddr) -> Result<(Connector, Acceptor, SideChannelMaker), Error> { @@ -218,10 +147,8 @@ pub enum Error { Connect(#[from] ConnectError), #[error("connection error")] Connection(#[from] ConnectionError), - #[error("write error")] - Write(#[from] WriteError), - #[error("done accepting error")] - DoneAccepting, + #[error("endpoint closed")] + EndpointClosed, #[error("IO error")] Io(#[from] std::io::Error), #[error("TLS error")] @@ -587,51 +514,7 @@ impl SideChannelSender { mod tests { use super::*; use std::net::Ipv4Addr; - use tokio::{ - io::{AsyncReadExt, AsyncWriteExt}, - task, - }; - - #[tokio::test(flavor = "multi_thread")] - async fn small_data_exchange() { - let (connector, acceptor, _) = configure((Ipv4Addr::LOCALHOST, 0).into()).unwrap(); - - let addr = *acceptor.local_addr(); - - let h1 = task::spawn(async move { - let mut conn = acceptor.accept().await.unwrap().complete().await.unwrap(); - - let mut buf = [0; 4]; - conn.read_exact(&mut buf).await.unwrap(); - assert_eq!(&buf, b"ping"); - - conn.write_all(b"pong").await.unwrap(); - }); - - let h2 = task::spawn(async move { - let mut conn = connector.connect(addr).await.unwrap(); - conn.write_all(b"ping").await.unwrap(); - - let mut buf = [0; 4]; - match conn.read_exact(&mut buf).await { - Ok(_) => (), - Err(error) => match error.downcast::() { - Ok(error) => match error { - quinn::ReadError::ConnectionLost( - quinn::ConnectionError::ApplicationClosed(_), - ) => { - // connection gracefully closed by the peer, this is expected. 
- } - error => panic!("unexpected error: {:?}", error), - }, - Err(error) => panic!("unexpected error: {:?}", error), - }, - } - }); - - h1.await.unwrap(); - h2.await.unwrap(); - } + use tokio::task; #[tokio::test(flavor = "multi_thread")] async fn side_channel() { diff --git a/net/src/tcp.rs b/net/src/tcp.rs index 315b4aa81..d4b32ddeb 100644 --- a/net/src/tcp.rs +++ b/net/src/tcp.rs @@ -1,7 +1,12 @@ -pub use self::implementation::{OwnedReadHalf, OwnedWriteHalf, TcpStream}; - -use self::implementation::TcpListener; -use std::{io, net::SocketAddr}; +use self::implementation::{TcpListener, TcpStream}; +use std::{collections::VecDeque, future, io, net::SocketAddr}; +use tokio::{ + io::{ReadHalf, WriteHalf}, + select, + sync::{mpsc, oneshot}, + task, +}; +use tokio_util::compat::{Compat, FuturesAsyncReadCompatExt, TokioAsyncReadCompatExt}; /// Configure TCP endpoint pub fn configure(bind_addr: SocketAddr) -> Result<(Connector, Acceptor), Error> { @@ -20,19 +25,24 @@ pub fn configure(bind_addr: SocketAddr) -> Result<(Connector, Acceptor), Error> pub struct Connector; impl Connector { - pub async fn connect(&self, addr: SocketAddr) -> Result { - Ok(TcpStream::connect(addr).await?) + pub async fn connect(&self, addr: SocketAddr) -> Result { + let stream = TcpStream::connect(addr).await?; + + Ok(Connection::new(stream, yamux::Mode::Client, addr)) } } +/// TCP acceptor pub struct Acceptor { listener: TcpListener, local_addr: SocketAddr, } impl Acceptor { - pub async fn accept(&self) -> Result<(TcpStream, SocketAddr), Error> { - Ok(self.listener.accept().await?) + pub async fn accept(&self) -> Result { + let (stream, addr) = self.listener.accept().await?; + + Ok(Connection::new(stream, yamux::Mode::Server, addr)) } pub fn local_addr(&self) -> &SocketAddr { @@ -40,17 +50,157 @@ impl Acceptor { } } +/// TCP connection +pub struct Connection { + remote_addr: SocketAddr, + command_tx: mpsc::Sender, +} + +impl Connection { + fn new(stream: TcpStream, mode: yamux::Mode, remote_addr: SocketAddr) -> Self { + let connection = yamux::Connection::new(stream.compat(), connection_config(), mode); + let (command_tx, command_rx) = mpsc::channel(1); + + task::spawn(drive_connection(connection, command_rx)); + + Self { + command_tx, + remote_addr, + } + } + + pub fn remote_addr(&self) -> SocketAddr { + self.remote_addr + } + + /// Accept the next incoming stream + pub async fn incoming(&self) -> Result<(SendStream, RecvStream), Error> { + let (reply_tx, reply_rx) = oneshot::channel(); + + self.command_tx + .send(Command::Incoming(reply_tx)) + .await + .map_err(|_| yamux::ConnectionError::Closed)?; + + let stream = reply_rx + .await + .map_err(|_| yamux::ConnectionError::Closed)??; + let (recv, send) = tokio::io::split(stream.compat()); + + Ok((send, recv)) + } + + /// Open a new outgoing stream + pub async fn outgoing(&self) -> Result<(SendStream, RecvStream), Error> { + let (reply_tx, reply_rx) = oneshot::channel(); + + self.command_tx + .send(Command::Outgoing(reply_tx)) + .await + .map_err(|_| yamux::ConnectionError::Closed)?; + + let stream = reply_rx + .await + .map_err(|_| yamux::ConnectionError::Closed)??; + let (recv, send) = tokio::io::split(stream.compat()); + + Ok((send, recv)) + } + + /// Gracefully close the connection + pub async fn close(&self) -> Result<(), Error> { + let (reply_tx, reply_rx) = oneshot::channel(); + + // If send or receive return an error it means the connection is already closed. Returning + // `Ok` in that case to make this function idempotent. 
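+        // (The task driving the connection exits after handling a `Close`
+        // command, so a repeated `close` fails to send and falls through to
+        // `Ok`.)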
+ self.command_tx + .send(Command::Close(Some(reply_tx))) + .await + .ok(); + + reply_rx.await.unwrap_or(Ok(())).map_err(Into::into) + } +} + +pub type SendStream = WriteHalf>; +pub type RecvStream = ReadHalf>; + #[derive(Debug, thiserror::Error)] pub enum Error { #[error("IO error")] Io(#[from] io::Error), + #[error("connection error")] + Connection(#[from] yamux::ConnectionError), +} + +// Yamux requires that `poll_next_inbound` is called continuously in order to drive the connection +// forward. We spawn a task a do it there. +async fn drive_connection( + mut conn: yamux::Connection>, + mut command_rx: mpsc::Receiver, +) { + // Buffer for incoming streams. This buffer is unbounded but yamux itself has a limit on the + // total number of concurrently open streams which effectively puts a bound on this buffer as + // well. + let mut incoming = VecDeque::new(); + + loop { + let command = select! { + command = command_rx.recv() => command, + result = future::poll_fn(|cx| conn.poll_next_inbound(cx)) => { + match result { + Some(result) => { + incoming.push_front(result); + continue; + } + None => break, + } + } + }; + + match command.unwrap_or(Command::Close(None)) { + Command::Incoming(reply_tx) => { + if let Some(result) = incoming.pop_back() { + reply_tx.send(result).ok(); + continue; + } + + if let Some(result) = future::poll_fn(|cx| conn.poll_next_inbound(cx)).await { + reply_tx.send(result).ok(); + } else { + break; + } + } + Command::Outgoing(reply_tx) => { + let result = future::poll_fn(|cx| conn.poll_new_outbound(cx)).await; + reply_tx.send(result).ok(); + } + Command::Close(reply_tx) => { + let result = future::poll_fn(|cx| conn.poll_close(cx)).await; + + if let Some(reply_tx) = reply_tx { + reply_tx.send(result).ok(); + } + + break; + } + } + } +} + +enum Command { + Incoming(oneshot::Sender>), + Outgoing(oneshot::Sender>), + Close(Option>>), +} + +fn connection_config() -> yamux::Config { + yamux::Config::default() } // Real #[cfg(not(feature = "simulation"))] mod implementation { - pub use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf}; - use crate::{socket, KEEP_ALIVE_INTERVAL}; use socket2::{Domain, Socket, TcpKeepalive, Type}; use std::{ @@ -62,7 +212,7 @@ mod implementation { use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; /// TCP listener - pub struct TcpListener(tokio::net::TcpListener); + pub(super) struct TcpListener(tokio::net::TcpListener); impl TcpListener { /// Binds TCP socket to the given address. 
If the port is taken, uses a random one, @@ -98,7 +248,7 @@ mod implementation { } /// TCP stream - pub struct TcpStream(tokio::net::TcpStream); + pub(super) struct TcpStream(tokio::net::TcpStream); impl TcpStream { pub async fn connect(addr: SocketAddr) -> io::Result { @@ -112,10 +262,6 @@ mod implementation { .await?, )) } - - pub fn into_split(self) -> (OwnedReadHalf, OwnedWriteHalf) { - self.0.into_split() - } } fn set_keep_alive(socket: &Socket) -> io::Result<()> { @@ -179,8 +325,5 @@ mod implementation { // Simulation #[cfg(feature = "simulation")] mod implementation { - pub use turmoil::net::{ - tcp::{OwnedReadHalf, OwnedWriteHalf}, - TcpListener, TcpStream, - }; + pub(super) use turmoil::net::{TcpListener, TcpStream}; } From 264dc53548be7366094ce1c7344efb5deaa24c66 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Cig=C3=A1nek?= Date: Wed, 4 Sep 2024 15:21:05 +0200 Subject: [PATCH 08/55] Fix cancellation safety in tcp::Connection --- Cargo.toml | 1 + lib/Cargo.toml | 2 +- net/Cargo.toml | 37 +++++---- net/src/connection.rs | 143 ++++++++++++++++++++++++++++++--- net/src/lib.rs | 1 + net/src/sync.rs | 179 ++++++++++++++++++++++++++++++++++++++++++ net/src/tcp.rs | 76 +++++++++++++----- 7 files changed, 391 insertions(+), 48 deletions(-) create mode 100644 net/src/sync.rs diff --git a/Cargo.toml b/Cargo.toml index 8c1982c32..6ab8ffedd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -56,6 +56,7 @@ rustls = { version = "0.23.5", default-features = false } serde = { version = "1.0", features = ["derive", "rc"] } serde_bytes = "0.11.8" serde_json = "1.0.94" +similar-asserts = "1.5.0" sqlx = { version = "0.7.4", default-features = false, features = ["runtime-tokio", "sqlite"] } tempfile = "3.2" thiserror = "1.0.49" diff --git a/lib/Cargo.toml b/lib/Cargo.toml index cc04c1259..137e14f73 100644 --- a/lib/Cargo.toml +++ b/lib/Cargo.toml @@ -91,7 +91,7 @@ proptest = "1.0" rmp-serde = { workspace = true } serde_json = { workspace = true } serde_test = "1.0.176" -similar-asserts = "1.5.0" +similar-asserts = { workspace = true } tempfile = { workspace = true } test-strategy = "0.2.1" tokio = { workspace = true, features = ["process", "test-util"] } diff --git a/net/Cargo.toml b/net/Cargo.toml index 4d671967f..f7e538b08 100644 --- a/net/Cargo.toml +++ b/net/Cargo.toml @@ -8,25 +8,28 @@ license.workspace = true version.workspace = true [dependencies] -bytecodec = "0.4.15" -bytes = "1.1.0" -futures-util = { workspace = true } -pin-project-lite = { workspace = true } -quinn = "0.11.4" -rand = { package = "ouisync-rand", path = "../rand" } -rcgen = { workspace = true } -socket2 = "0.5.7" # To be able to setsockopts before a socket is bound -stun_codec = "0.3.4" -thiserror = "1.0.31" -tokio = { workspace = true, features = ["io-util", "macros", "net", "rt-multi-thread", "sync"] } -tokio-util = { workspace = true, features = ["compat"] } -turmoil = { workspace = true, optional = true } -yamux = "0.13.3" +bytecodec = "0.4.15" +bytes = "1.1.0" +futures-util = { workspace = true } +pin-project-lite = { workspace = true } +quinn = "0.11.4" +rand = { package = "ouisync-rand", path = "../rand" } +rcgen = { workspace = true } +socket2 = "0.5.7" # To be able to setsockopts before a socket is bound +stun_codec = "0.3.4" +thiserror = "1.0.31" +tokio = { workspace = true, features = ["io-util", "macros", "net", "rt-multi-thread", "sync"] } +tokio-util = { workspace = true, features = ["compat"] } +tracing = { workspace = true } +turmoil = { workspace = true, optional = true } +yamux = "0.13.3" [dev-dependencies] 
-anyhow = { workspace = true } -clap = { workspace = true } -tokio = { workspace = true } +anyhow = { workspace = true } +clap = { workspace = true } +similar-asserts = { workspace = true } +tokio = { workspace = true } +tracing-subscriber = { workspace = true } [features] simulation = ["turmoil"] diff --git a/net/src/connection.rs b/net/src/connection.rs index a4f90aced..27ba8d1b8 100644 --- a/net/src/connection.rs +++ b/net/src/connection.rs @@ -214,26 +214,25 @@ pub enum Error { #[cfg(test)] mod tests { use super::*; + use futures_util::{future, stream::FuturesUnordered, StreamExt}; + use rand::{distributions::Standard, Rng}; use std::net::Ipv4Addr; use tokio::{ io::{AsyncReadExt, AsyncWriteExt}, - task, + select, task, }; + use tracing::Instrument; #[tokio::test] async fn ping_tcp() { - let (client, _) = tcp::configure((Ipv4Addr::LOCALHOST, 0).into()).unwrap(); - let (_, server) = tcp::configure((Ipv4Addr::LOCALHOST, 0).into()).unwrap(); - - ping_case(Connector::Tcp(client), Acceptor::Tcp(server)).await + let (client, server) = setup_tcp_peers(); + ping_case(client, server).await } #[tokio::test] async fn ping_quic() { - let (client, _, _) = quic::configure((Ipv4Addr::LOCALHOST, 0).into()).unwrap(); - let (_, server, _) = quic::configure((Ipv4Addr::LOCALHOST, 0).into()).unwrap(); - - ping_case(Connector::Quic(client), Acceptor::Quic(server)).await + let (client, server) = setup_quic_peers(); + ping_case(client, server).await } async fn ping_case(client_connector: Connector, server_acceptor: Acceptor) { @@ -266,4 +265,130 @@ mod tests { server.await.unwrap(); client.await.unwrap(); } + + #[tokio::test] + async fn multi_streams_tcp() { + let (client, server) = setup_tcp_peers(); + multi_streams_case(client, server).await; + } + + #[tokio::test] + async fn multi_streams_quic() { + let (client, server) = setup_quic_peers(); + multi_streams_case(client, server).await; + } + + async fn multi_streams_case(client_connector: Connector, server_acceptor: Acceptor) { + tracing_subscriber::fmt() + .with_max_level(tracing::Level::DEBUG) + .compact() + .init(); + + let num_messages = 32; + let min_message_size = 1; + let max_message_size = 256 * 1024; + + let mut rng = rand::thread_rng(); + let mut messages: Vec> = (0..num_messages) + .map(|_| { + let size = rng.gen_range(min_message_size..=max_message_size); + (&mut rng).sample_iter(Standard).take(size).collect() + }) + .collect(); + + let server_addr = *server_acceptor.local_addr(); + + let client = async { + let conn = client_connector.connect(server_addr).await.unwrap(); + let tasks = FuturesUnordered::new(); + + for message in &messages { + tasks.push(async { + let (mut tx, mut rx) = conn.outgoing().await.unwrap(); + + // Send message + tx.write_u32(message.len() as u32).await.unwrap(); + tx.write_all(message).await.unwrap(); + + // Receive response and close the stream + let mut buf = [0; 2]; + rx.read_exact(&mut buf).await.unwrap(); + assert_eq!(&buf, b"ok"); + + tx.shutdown().await.unwrap(); + + message.clone() + }); + } + + let sent_messages: Vec<_> = tasks.collect().await; + + conn.close().await.unwrap(); + + sent_messages + } + .instrument(tracing::info_span!("client")); + + let server = async { + let conn = server_acceptor.accept().await.unwrap().await.unwrap(); + let mut tasks = FuturesUnordered::new(); + let mut received_messages = Vec::new(); + + loop { + let (mut tx, mut rx) = select! 
{
+ Ok(stream) = conn.incoming() => stream,
+ Some(message) = tasks.next() => {
+ received_messages.push(message);
+ continue;
+ }
+ else => break,
+ };
+
+ tasks.push(async move {
+ // Read message len
+ let len = rx.read_u32().await.unwrap() as usize;
+
+ // Read message content
+ let mut message = vec![0; len];
+ rx.read_exact(&mut message).await.unwrap();
+
+ // Send response and close the stream
+ tx.write_all(b"ok").await.unwrap();
+
+ tx.shutdown().await.unwrap();
+
+ message
+ });
+ }
+
+ received_messages
+ }
+ .instrument(tracing::info_span!("server"));
+
+ let (mut sent_messages, mut received_messages) = future::join(client, server).await;
+
+ assert_eq!(sent_messages.len(), messages.len());
+ assert_eq!(received_messages.len(), messages.len());
+
+ sent_messages.sort();
+ received_messages.sort();
+ messages.sort();
+
+ similar_asserts::assert_eq!(sent_messages, messages);
+ similar_asserts::assert_eq!(received_messages, messages);
+ }
+
+ fn setup_tcp_peers() -> (Connector, Acceptor) {
+ let (client, _) = tcp::configure((Ipv4Addr::LOCALHOST, 0).into()).unwrap();
+ let (_, server) = tcp::configure((Ipv4Addr::LOCALHOST, 0).into()).unwrap();
+
+ (Connector::Tcp(client), Acceptor::Tcp(server))
+ }
+
+ fn setup_quic_peers() -> (Connector, Acceptor) {
+ let (client, _, _) = quic::configure((Ipv4Addr::LOCALHOST, 0).into()).unwrap();
+ let (_, server, _) = quic::configure((Ipv4Addr::LOCALHOST, 0).into()).unwrap();
+
+ (Connector::Quic(client), Acceptor::Quic(server))
+ }
 }
diff --git a/net/src/lib.rs b/net/src/lib.rs
index bee42acfc..773f79e6a 100644
--- a/net/src/lib.rs
+++ b/net/src/lib.rs
@@ -8,5 +8,6 @@ pub mod udp;
 
 #[cfg(not(feature = "simulation"))]
 mod socket;
+mod sync;
 
 pub const KEEP_ALIVE_INTERVAL: Duration = Duration::from_secs(10);
diff --git a/net/src/sync.rs b/net/src/sync.rs
new file mode 100644
index 000000000..2fb5d972e
--- /dev/null
+++ b/net/src/sync.rs
@@ -0,0 +1,179 @@
+pub(crate) mod rendezvous {
+ //! Single producer, single consumer, oneshot, rendezvous channel.
+ //!
+ //! Unlike `tokio::sync::oneshot`, this one guarantees that the message is never lost even when
+ //! the receiver is dropped before receiving the message. Because of this, [`Sender::send`] must
+ //! be `async`.
+ //!
+ //! # Cancel safety
+ //!
+ //! If `send` is cancelled before completion, the value is still guaranteed to be received by
+ //! the receiver. If `recv` is cancelled before completion, the value is returned back from
+ //! `send`. If both `send` and `recv` are cancelled, the value is lost.
+
+ use std::{
+ fmt,
+ sync::{Arc, Mutex},
+ };
+ use tokio::sync::Notify;
+
+ /// Sends a value to the associated [`Receiver`].
+ pub struct Sender<T> {
+ shared: Arc<Shared<T>>,
+ }
+
+ impl<T> Sender<T> {
+ /// Sends the `value` to the [`Receiver`].
+ pub async fn send(self, value: T) -> Result<(), T> {
+ self.shared.state.lock().unwrap().value = Some(value);
+ self.shared.tx_notify.notify_one();
+
+ loop {
+ self.shared.rx_notify.notified().await;
+
+ let mut state = self.shared.state.lock().unwrap();
+
+ if state.value.is_none() {
+ return Ok(());
+ }
+
+ if state.rx_drop {
+ return Err(state.value.take().unwrap());
+ }
+ }
+ }
+ }
+
+ impl<T> Drop for Sender<T> {
+ fn drop(&mut self) {
+ self.shared.state.lock().unwrap().tx_drop = true;
+ self.shared.tx_notify.notify_one();
+ }
+ }
+
+ /// Receives a value from the associated [`Sender`].
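+ ///
+ /// A minimal usage sketch (illustrative only, mirroring the `sanity_check` test below; both
+ /// halves are driven together because `send` completes only on rendezvous):
+ ///
+ /// ```ignore
+ /// let (tx, rx) = channel();
+ /// let (tx_result, rx_result) = future::join(tx.send(1), rx.recv()).await;
+ /// assert_eq!(tx_result, Ok(()));
+ /// assert_eq!(rx_result, Ok(1));
+ /// ```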
+ pub struct Receiver<T> {
+ shared: Arc<Shared<T>>,
+ }
+
+ impl<T> Receiver<T> {
+ /// Receives the value from the [`Sender`].
+ ///
+ /// If the sender is dropped before calling `send`, returns `RecvError`. Otherwise this is
+ /// guaranteed to return the sent value.
+ pub async fn recv(self) -> Result<T, RecvError> {
+ loop {
+ self.shared.tx_notify.notified().await;
+
+ let mut state = self.shared.state.lock().unwrap();
+
+ if let Some(value) = state.value.take() {
+ self.shared.rx_notify.notify_one();
+ return Ok(value);
+ }
+
+ if state.tx_drop {
+ return Err(RecvError);
+ }
+ }
+ }
+ }
+
+ impl<T> Drop for Receiver<T> {
+ fn drop(&mut self) {
+ self.shared.state.lock().unwrap().rx_drop = true;
+ self.shared.rx_notify.notify_one();
+ }
+ }
+
+ /// Create a rendezvous channel for sending a single message of type `T`.
+ pub fn channel<T>() -> (Sender<T>, Receiver<T>) {
+ let shared = Arc::new(Shared {
+ state: Mutex::new(State {
+ value: None,
+ tx_drop: false,
+ rx_drop: false,
+ }),
+ tx_notify: Notify::new(),
+ rx_notify: Notify::new(),
+ });
+
+ (
+ Sender {
+ shared: shared.clone(),
+ },
+ Receiver { shared },
+ )
+ }
+
+ #[derive(Debug, Eq, PartialEq)]
+ pub struct RecvError;
+
+ impl fmt::Display for RecvError {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ write!(f, "channel closed")
+ }
+ }
+
+ impl std::error::Error for RecvError {}
+
+ struct Shared<T> {
+ state: Mutex<State<T>>,
+ tx_notify: Notify,
+ rx_notify: Notify,
+ }
+
+ struct State<T> {
+ value: Option<T>,
+ tx_drop: bool,
+ rx_drop: bool,
+ }
+
+ #[cfg(test)]
+ mod tests {
+ use super::*;
+ use futures_util::future;
+ use tokio::task;
+
+ #[tokio::test]
+ async fn sanity_check() {
+ let (tx, rx) = channel();
+
+ let (tx_result, rx_result) = future::join(tx.send(1), rx.recv()).await;
+
+ assert_eq!(tx_result, Ok(()));
+ assert_eq!(rx_result, Ok(1));
+ }
+
+ #[tokio::test]
+ async fn drop_tx_before_send() {
+ let (tx, rx) = channel::<u32>();
+
+ drop(tx);
+
+ assert_eq!(rx.recv().await, Err(RecvError));
+ }
+
+ #[tokio::test]
+ async fn drop_rx_before_send() {
+ let (tx, rx) = channel::<u32>();
+
+ drop(rx);
+
+ assert_eq!(tx.send(1).await, Err(1));
+ }
+
+ #[tokio::test]
+ async fn drop_rx_before_recv() {
+ let (tx, rx) = channel::<u32>();
+
+ let (tx_result, _) = future::join(tx.send(1), async move {
+ task::yield_now().await;
+ drop(rx)
+ })
+ .await;
+
+ assert_eq!(tx_result, Err(1));
+ }
+ }
+}
diff --git a/net/src/tcp.rs b/net/src/tcp.rs
index d4b32ddeb..927d08929 100644
--- a/net/src/tcp.rs
+++ b/net/src/tcp.rs
@@ -1,5 +1,7 @@
+use crate::sync::rendezvous;
+
 use self::implementation::{TcpListener, TcpStream};
-use std::{collections::VecDeque, future, io, net::SocketAddr};
+use std::{future, io, net::SocketAddr};
 use tokio::{
 io::{ReadHalf, WriteHalf},
 select,
@@ -7,6 +9,7 @@
 task,
 };
 use tokio_util::compat::{Compat, FuturesAsyncReadCompatExt, TokioAsyncReadCompatExt};
+use tracing::{Instrument, Span};
 
 /// Configure TCP endpoint
 pub fn configure(bind_addr: SocketAddr) -> Result<(Connector, Acceptor), Error> {
@@ -61,7 +64,7 @@ impl Connection {
 let connection = yamux::Connection::new(stream.compat(), connection_config(), mode);
 
 let (command_tx, command_rx) = mpsc::channel(1);
-        task::spawn(drive_connection(connection, command_rx));
+        task::spawn(drive_connection(connection, command_rx).instrument(Span::current()));
 
 Self {
 command_tx,
@@ -73,9 +76,13 @@
 self.remote_addr
 }
 
-    /// Accept the next incoming stream
+    /// Accept the next incoming stream.
+ /// + /// # Cancel safety + /// + /// In case this function is cancelled, no stream gets lost and the call can be safely retried. pub async fn incoming(&self) -> Result<(SendStream, RecvStream), Error> { - let (reply_tx, reply_rx) = oneshot::channel(); + let (reply_tx, reply_rx) = rendezvous::channel(); self.command_tx .send(Command::Incoming(reply_tx)) @@ -83,6 +90,7 @@ impl Connection { .map_err(|_| yamux::ConnectionError::Closed)?; let stream = reply_rx + .recv() .await .map_err(|_| yamux::ConnectionError::Closed)??; let (recv, send) = tokio::io::split(stream.compat()); @@ -91,8 +99,12 @@ impl Connection { } /// Open a new outgoing stream + /// + /// # Cancel safety + /// + /// In case this function is cancelled, no stream gets lost and the call can be safely retried. pub async fn outgoing(&self) -> Result<(SendStream, RecvStream), Error> { - let (reply_tx, reply_rx) = oneshot::channel(); + let (reply_tx, reply_rx) = rendezvous::channel(); self.command_tx .send(Command::Outgoing(reply_tx)) @@ -100,6 +112,7 @@ impl Connection { .map_err(|_| yamux::ConnectionError::Closed)?; let stream = reply_rx + .recv() .await .map_err(|_| yamux::ConnectionError::Closed)??; let (recv, send) = tokio::io::split(stream.compat()); @@ -108,6 +121,10 @@ impl Connection { } /// Gracefully close the connection + /// + /// # Cancel safety + /// + /// This function is idempotent even in the presence of cancellation. pub async fn close(&self) -> Result<(), Error> { let (reply_tx, reply_rx) = oneshot::channel(); @@ -139,10 +156,11 @@ async fn drive_connection( mut conn: yamux::Connection>, mut command_rx: mpsc::Receiver, ) { - // Buffer for incoming streams. This buffer is unbounded but yamux itself has a limit on the - // total number of concurrently open streams which effectively puts a bound on this buffer as - // well. - let mut incoming = VecDeque::new(); + // Buffers for incoming and outgoing streams. These guarantee that no streams are ever lost, + // even if `Connection::incoming` or `Connection::outgoing` are cancelled. Due to the limit on + // the number of streams per connection, these buffers are effectively bounded. + let mut incoming = Vec::new(); + let mut outgoing = Vec::new(); loop { let command = select! { @@ -150,7 +168,7 @@ async fn drive_connection( result = future::poll_fn(|cx| conn.poll_next_inbound(cx)) => { match result { Some(result) => { - incoming.push_front(result); + incoming.push(result); continue; } None => break, @@ -160,20 +178,33 @@ async fn drive_connection( match command.unwrap_or(Command::Close(None)) { Command::Incoming(reply_tx) => { - if let Some(result) = incoming.pop_back() { - reply_tx.send(result).ok(); - continue; - } - - if let Some(result) = future::poll_fn(|cx| conn.poll_next_inbound(cx)).await { - reply_tx.send(result).ok(); + let result = if let Some(result) = incoming.pop() { + result + } else if let Some(result) = future::poll_fn(|cx| conn.poll_next_inbound(cx)).await + { + result } else { break; + }; + + if let Err(result) = reply_tx.send(result).await { + // reply_rx dropped before receiving the result, save it for next time. 
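+ // It will be handed out to whoever sends the next `Command::Incoming`.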
+ incoming.push(result);
+ }
+
+ continue;
+ }
+ Command::Outgoing(reply_tx) => {
+ let result = if let Some(result) = outgoing.pop() {
+ result
+ } else {
+ future::poll_fn(|cx| conn.poll_new_outbound(cx)).await
+ };
+
+ if let Err(result) = reply_tx.send(result).await {
+ // reply_rx dropped before receiving the result, save it for next time.
+ outgoing.push(result);
+ }
+ }
+ Command::Close(reply_tx) => {
+ let result = future::poll_fn(|cx| conn.poll_close(cx)).await;
@@ -189,8 +220,11 @@
 }
 
 enum Command {
-    Incoming(oneshot::Sender<Result<yamux::Stream, yamux::ConnectionError>>),
-    Outgoing(oneshot::Sender<Result<yamux::Stream, yamux::ConnectionError>>),
+    // Using rendezvous to guarantee the reply is either received or we get it back if the receive
+    // got cancelled.
+    Incoming(rendezvous::Sender<Result<yamux::Stream, yamux::ConnectionError>>),
+    Outgoing(rendezvous::Sender<Result<yamux::Stream, yamux::ConnectionError>>),
+    // Using regular oneshot as we don't care about cancellation here:
     Close(Option<oneshot::Sender<Result<(), yamux::ConnectionError>>>),
 }
 
From 179c1c32e1b7b7a781fceded39cce73eef3d89ba Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adam=20Cig=C3=A1nek?=
Date: Wed, 4 Sep 2024 16:24:58 +0200
Subject: [PATCH 09/55] Updating the network API

---
 lib/src/network/connection.rs         |  10 +-
 lib/src/network/gateway.rs            |  75 +++----
 lib/src/network/message_broker.rs     |   9 +-
 lib/src/network/message_dispatcher.rs | 293 +++++++++++++++++---------
 lib/src/network/mod.rs                |  54 +++--
 lib/src/network/runtime_id.rs         |  20 +-
 lib/src/network/stats.rs              |  16 +-
 net/src/connection.rs                 |  76 +++++--
 net/src/quic.rs                       |  27 ++-
 net/src/sync.rs                       |  11 +
 net/src/tcp.rs                        |  47 +++--
 11 files changed, 418 insertions(+), 220 deletions(-)

diff --git a/lib/src/network/connection.rs b/lib/src/network/connection.rs
index 747409a5f..fe02604e4 100644
--- a/lib/src/network/connection.rs
+++ b/lib/src/network/connection.rs
@@ -255,18 +255,22 @@ impl ConnectionPermit {
 
     /// Dummy connection permit for tests.
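+    /// (Illustrative usage, as in the dispatcher tests: `ConnectionPermit::dummy(ConnectionDirection::Incoming)`
+    /// for the accepting side, `Outgoing` for the connecting side.)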
#[cfg(test)] - pub fn dummy() -> Self { + pub fn dummy(dir: ConnectionDirection) -> Self { use std::net::Ipv4Addr; let key = Key { addr: PeerAddr::Tcp((Ipv4Addr::UNSPECIFIED, 0).into()), - dir: ConnectionDirection::Incoming, + dir, }; let id = ConnectionId::next(); + let source = match dir { + ConnectionDirection::Incoming => PeerSource::Listener, + ConnectionDirection::Outgoing => PeerSource::UserProvided, + }; let data = Data { id, state: PeerState::Known, - source: PeerSource::UserProvided, + source, stats_tracker: StatsTracker::default(), on_release: DropAwaitable::new(), }; diff --git a/lib/src/network/gateway.rs b/lib/src/network/gateway.rs index 39288a5e1..d255be30e 100644 --- a/lib/src/network/gateway.rs +++ b/lib/src/network/gateway.rs @@ -1,7 +1,10 @@ use super::{ip, peer_addr::PeerAddr, peer_source::PeerSource, seen_peers::SeenPeer}; use crate::sync::atomic_slot::AtomicSlot; use backoff::{backoff::Backoff, ExponentialBackoffBuilder}; -use net::{connection::Connection, quic, tcp}; +use net::{ + connection::{Acceptor, Connection}, + quic, tcp, +}; use scoped_task::ScopedJoinHandle; use std::net::{IpAddr, SocketAddr}; use thiserror::Error; @@ -468,8 +471,9 @@ impl QuicStack { }; let listener_local_addr = *acceptor.local_addr(); - let listener_task = - scoped_task::spawn(run_quic_listener(acceptor, incoming_tx).instrument(span)); + let listener_task = scoped_task::spawn( + run_listener(Acceptor::Quic(acceptor), incoming_tx).instrument(span), + ); let hole_puncher = side_channel_maker.make().sender(); @@ -517,7 +521,7 @@ impl TcpStack { let listener_local_addr = *acceptor.local_addr(); let listener_task = - scoped_task::spawn(run_tcp_listener(acceptor, incoming_tx).instrument(span)); + scoped_task::spawn(run_listener(Acceptor::Tcp(acceptor), incoming_tx).instrument(span)); Some(Self { listener_local_addr, @@ -527,28 +531,7 @@ impl TcpStack { } } -async fn run_tcp_listener(acceptor: tcp::Acceptor, tx: mpsc::Sender<(Connection, PeerAddr)>) { - loop { - let result = select! { - result = acceptor.accept() => result, - _ = tx.closed() => break, - }; - - match result { - Ok((stream, addr)) => { - tx.send((Connection::Tcp(stream), PeerAddr::Tcp(addr))) - .await - .ok(); - } - Err(error) => { - tracing::error!(?error, "Failed to accept connection"); - break; - } - } - } -} - -async fn run_quic_listener(listener: quic::Acceptor, tx: mpsc::Sender<(Connection, PeerAddr)>) { +async fn run_listener(listener: Acceptor, tx: mpsc::Sender<(Connection, PeerAddr)>) { let mut tasks = JoinSet::new(); loop { @@ -557,24 +540,30 @@ async fn run_quic_listener(listener: quic::Acceptor, tx: mpsc::Sender<(Connectio _ = tx.closed() => break, }; - if let Some(connecting) = connecting { - let tx = tx.clone(); - let addr = connecting.remote_addr(); - - // Spawn so we can start listening for the next connection ASAP. - tasks.spawn(async move { - match connecting.complete().await { - Ok(connection) => { - tx.send((Connection::Quic(connection), PeerAddr::Quic(addr))) - .await - .ok(); + match connecting { + Ok(connecting) => { + let tx = tx.clone(); + + let addr = connecting.remote_addr(); + let addr = match listener { + Acceptor::Tcp(_) => PeerAddr::Tcp(addr), + Acceptor::Quic(_) => PeerAddr::Quic(addr), + }; + + // Spawn so we can start listening for the next connection ASAP. 
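+                // (For QUIC, awaiting `connecting` still has to complete the handshake; for TCP
+                // the connection is already fully established, so the await resolves immediately.)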
+ tasks.spawn(async move {
+ match connecting.await {
+ Ok(connection) => {
+ tx.send((connection, addr)).await.ok();
+ }
+ Err(error) => tracing::error!(?error, %addr, "Failed to accept connection"),
+ }
+ });
+ }
+ Err(error) => {
+ tracing::error!(?error, "Stopped accepting new connections");
+ break;
+ }
 }
 }
 }
diff --git a/lib/src/network/message_broker.rs b/lib/src/network/message_broker.rs
index 56bdbda14..ffd5f9b6e 100644
--- a/lib/src/network/message_broker.rs
+++ b/lib/src/network/message_broker.rs
@@ -67,10 +67,15 @@ impl MessageBroker {
 }
 }
 
-    pub fn add_connection(&self, connection: Instrumented<Connection>, permit: ConnectionPermit) {
+    pub fn add_connection(
+        &self,
+        connection: Connection,
+        permit: ConnectionPermit,
+        byte_counters: Arc<ByteCounters>,
+    ) {
 self.pex_peer
 .handle_connection(permit.addr(), permit.source(), permit.released());
-        self.dispatcher.bind(connection, permit)
+        self.dispatcher.bind(connection, permit, byte_counters)
 }
 
 /// Has this broker at least one live connection?
diff --git a/lib/src/network/message_dispatcher.rs b/lib/src/network/message_dispatcher.rs
index 775836564..fe809f165 100644
--- a/lib/src/network/message_dispatcher.rs
+++ b/lib/src/network/message_dispatcher.rs
@@ -1,15 +1,19 @@
 //! Utilities for sending and receiving messages across the network.
 
 use super::{
-    connection::{ConnectionId, ConnectionPermit, ConnectionPermitHalf},
+    connection::{ConnectionDirection, ConnectionId, ConnectionPermit, ConnectionPermitHalf},
 message::{Message, MessageChannelId},
 message_io::{MessageSink, MessageStream, MESSAGE_OVERHEAD},
-    stats::Instrumented,
+    stats::{ByteCounters, Instrumented},
 };
 use crate::{collections::HashMap, sync::AwaitDrop};
 use async_trait::async_trait;
-use futures_util::{future, ready, stream::SelectAll, FutureExt, Sink, SinkExt, Stream, StreamExt};
-use net::connection::{Connection, OwnedReadHalf, OwnedWriteHalf};
+use futures_util::{
+    future, ready,
+    stream::{FuturesUnordered, SelectAll},
+    FutureExt, Sink, SinkExt, Stream, StreamExt,
+};
+use net::connection::{Connection, Error as ConnectionError, RecvStream, SendStream};
 use std::{
 io,
 pin::Pin,
@@ -56,9 +60,18 @@ impl MessageDispatcher {
 
     /// Bind this dispatcher to the given TCP or QUIC socket. Can be bound to multiple sockets and
     /// the failed ones are automatically removed.
-    pub fn bind(&self, connection: Instrumented<Connection>, permit: ConnectionPermit) {
+    pub fn bind(
+        &self,
+        connection: Connection,
+        permit: ConnectionPermit,
+        byte_counters: Arc<ByteCounters>,
+    ) {
 self.command_tx
-            .send(Command::Bind { connection, permit })
+            .send(Command::Bind {
+                connection,
+                permit,
+                byte_counters,
+            })
 .ok();
 }
 
@@ -252,7 +265,7 @@
 pub(super) struct ChannelClosed;
 
 struct ConnectionStream {
 // The reader is doubly instrumented - first time to track per connection stats and second time
 // to track cumulative stats across all connections.
- reader: MessageStream>>, + reader: MessageStream>>, permit: ConnectionPermitHalf, permit_released: AwaitDrop, connection_count: Arc, @@ -260,7 +273,7 @@ struct ConnectionStream { impl ConnectionStream { fn new( - reader: Instrumented, + reader: Instrumented, permit: ConnectionPermitHalf, connection_count: Arc, ) -> Self { @@ -307,13 +320,13 @@ impl Drop for ConnectionStream { struct ConnectionSink { // The writer is doubly instrumented - first time to track per connection stats and second time // to track cumulative stats across all connections. - writer: MessageSink>>, + writer: MessageSink>>, _permit: ConnectionPermitHalf, permit_released: AwaitDrop, } impl ConnectionSink { - fn new(writer: Instrumented, permit: ConnectionPermitHalf) -> Self { + fn new(writer: Instrumented, permit: ConnectionPermitHalf) -> Self { let permit_released = permit.released(); Self { @@ -384,59 +397,77 @@ impl Worker { } async fn run(mut self) { + let mut streams = FuturesUnordered::new(); + loop { - select! { - command = self.command_rx.recv() => { - if let Some(command) = command { - self.handle_command(command).await; - } else { - break; + let command = select! { + command = self.command_rx.recv() => command, + Some(result) = streams.next() => { + match result { + Ok((connection, tx, rx)) => { + self.send.sinks.push((connection, tx)); + self.recv.streams.push(rx); + } + Err(error) => { + tracing::debug!(?error, "Failed to establish a new connection stream"); + } } + + continue; } - _ = self.send.run()=> unreachable!(), - _ = self.recv.run()=> unreachable!(), - } - } + _ = self.send.run() => unreachable!(), + _ = self.recv.run() => unreachable!(), + }; - self.shutdown().await; - } + let Some(command) = command else { + break; + }; - async fn handle_command(&mut self, command: Command) { - match command { - Command::Open { channel, stream_tx } => { - self.recv.channels.insert(channel, stream_tx); - } - Command::Close { channel } => { - self.recv.channels.remove(&channel); - } - Command::Bind { connection, permit } => { - let (reader, writer) = connection.into_split(); - let (send_permit, recv_permit) = permit.into_split(); - - self.send - .sinks - .push(ConnectionSink::new(writer, send_permit)); - - self.recv.streams.push(ConnectionStream::new( - reader, - recv_permit, - self.connection_count.clone(), - )); - } - Command::Shutdown { tx } => { - self.shutdown().await; - tx.send(()).ok(); + match command { + Command::Open { channel, stream_tx } => { + self.recv.channels.insert(channel, stream_tx); + } + Command::Close { channel } => { + self.recv.channels.remove(&channel); + } + Command::Bind { + connection, + permit, + byte_counters, + } => { + let connection_count = self.connection_count.clone(); + + streams.push(async move { + let (tx, rx) = match ConnectionDirection::from_source(permit.source()) { + ConnectionDirection::Incoming => connection.incoming().await?, + ConnectionDirection::Outgoing => connection.outgoing().await?, + }; + + let (tx_permit, rx_permit) = permit.into_split(); + + let tx = Instrumented::new(tx, byte_counters.clone()); + let tx = ConnectionSink::new(tx, tx_permit); + + let rx = Instrumented::new(rx, byte_counters.clone()); + let rx = ConnectionStream::new(rx, rx_permit, connection_count); + + Ok::<_, ConnectionError>((connection, tx, rx)) + }); + } + Command::Shutdown { tx } => { + self.shutdown().await; + tx.send(()).ok(); + } } } + + self.shutdown().await; } async fn shutdown(&mut self) { - future::join_all( - self.send - .sinks - .drain(..) 
- .map(|mut sink| async move { sink.close().await.ok() }), - ) + future::join_all(self.send.sinks.drain(..).map(|(connection, _)| async move { + connection.close().await; + })) .await; self.send.sink_rx.close(); @@ -455,8 +486,9 @@ enum Command { channel: MessageChannelId, }, Bind { - connection: Instrumented, + connection: Connection, permit: ConnectionPermit, + byte_counters: Arc, }, Shutdown { tx: oneshot::Sender<()>, @@ -465,13 +497,15 @@ enum Command { struct SendState { sink_rx: mpsc::Receiver, - sinks: Vec, + // We need to keep the `Connection` around so the sink/stream stay opened. We can store it here + // or in the `RecvState` but storing it here is slightly simpler. + sinks: Vec<(Connection, ConnectionSink)>, } impl SendState { // Keep sending outgoing messages. This function never returns, but it's safe to cancel. async fn run(&mut self) { - while let Some(sink) = self.sinks.first_mut() { + while let Some((_, sink)) = self.sinks.first_mut() { // The order of operations here is important for cancel-safety: first wait for the sink // to become ready for sending, then receive the message to be sent and finally send // the message on the sink. This order ensures that if this function is cancelled at @@ -549,8 +583,8 @@ impl RecvState { mod tests { use super::{super::stats::ByteCounters, *}; use assert_matches::assert_matches; - use futures_util::stream; - use net::{connection::Connection, tcp}; + use futures_util::{future, stream}; + use net::connection::{Acceptor, Connection, Connector}; use std::{collections::BTreeSet, net::Ipv4Addr, str::from_utf8, time::Duration}; #[tokio::test(flavor = "multi_thread")] @@ -562,8 +596,15 @@ mod tests { let mut server_stream = server_dispatcher.open_recv(channel); let (client, server) = create_connection_pair().await; - let mut client_sink = MessageSink::new(client); - server_dispatcher.bind(server, ConnectionPermit::dummy()); + + let (client_tx, _client_rx) = client.outgoing().await.unwrap(); + let mut client_sink = MessageSink::new(client_tx); + + server_dispatcher.bind( + server, + ConnectionPermit::dummy(ConnectionDirection::Incoming), + Arc::new(ByteCounters::default()), + ); client_sink .send(Message { @@ -590,8 +631,14 @@ mod tests { let server_stream1 = server_dispatcher.open_recv(channel1); let (client, server) = create_connection_pair().await; - let mut client_sink = MessageSink::new(client); - server_dispatcher.bind(server, ConnectionPermit::dummy()); + + let mut client_sink = MessageSink::new(client.outgoing().await.unwrap().0); + + server_dispatcher.bind( + server, + ConnectionPermit::dummy(ConnectionDirection::Incoming), + Arc::new(ByteCounters::default()), + ); for (channel, content) in [(channel0, send_content0), (channel1, send_content1)] { client_sink @@ -628,8 +675,16 @@ mod tests { let server_stream1 = server_dispatcher.open_recv(channel1); let (client, server) = create_connection_pair().await; - client_dispatcher.bind(client, ConnectionPermit::dummy()); - server_dispatcher.bind(server, ConnectionPermit::dummy()); + client_dispatcher.bind( + client, + ConnectionPermit::dummy(ConnectionDirection::Outgoing), + Arc::new(ByteCounters::new()), + ); + server_dispatcher.bind( + server, + ConnectionPermit::dummy(ConnectionDirection::Incoming), + Arc::new(ByteCounters::new()), + ); let num_messages = 20; let mut send_tasks = vec![]; @@ -674,8 +729,14 @@ mod tests { let mut server_stream1 = server_dispatcher.open_recv(channel); let (client, server) = create_connection_pair().await; - let mut client_sink = MessageSink::new(client); - 
server_dispatcher.bind(server, ConnectionPermit::dummy()); + + let mut client_sink = MessageSink::new(client.outgoing().await.unwrap().0); + + server_dispatcher.bind( + server, + ConnectionPermit::dummy(ConnectionDirection::Incoming), + Arc::new(ByteCounters::new()), + ); for content in [send_content0, send_content1] { client_sink @@ -697,6 +758,8 @@ mod tests { #[tokio::test(flavor = "multi_thread")] async fn multiple_connections_recv() { + crate::test_utils::init_log(); + let channel = MessageChannelId::random(); let send_content0 = b"one two three"; @@ -705,14 +768,22 @@ mod tests { let server_dispatcher = MessageDispatcher::new(); let mut server_stream = server_dispatcher.open_recv(channel); - let (client_socket0, server_socket0) = create_connection_pair().await; - let (client_socket1, server_socket1) = create_connection_pair().await; + let (client0, server0) = create_connection_pair().await; + let (client1, server1) = create_connection_pair().await; - let client_sink0 = MessageSink::new(client_socket0); - let client_sink1 = MessageSink::new(client_socket1); + let client_sink0 = MessageSink::new(client0.outgoing().await.unwrap().0); + let client_sink1 = MessageSink::new(client1.outgoing().await.unwrap().0); - server_dispatcher.bind(server_socket0, ConnectionPermit::dummy()); - server_dispatcher.bind(server_socket1, ConnectionPermit::dummy()); + server_dispatcher.bind( + server0, + ConnectionPermit::dummy(ConnectionDirection::Incoming), + Arc::new(ByteCounters::new()), + ); + server_dispatcher.bind( + server1, + ConnectionPermit::dummy(ConnectionDirection::Incoming), + Arc::new(ByteCounters::new()), + ); for (mut client_sink, content) in [(client_sink0, send_content0), (client_sink1, send_content1)] @@ -737,13 +808,13 @@ mod tests { // The messages may be received in any order assert_eq!( - [recv_content0.as_slice(), recv_content1.as_slice()] - .into_iter() - .collect::>(), - [send_content0.as_slice(), send_content1.as_slice()] - .into_iter() - .collect::>(), + BTreeSet::from([recv_content0.as_slice(), recv_content1.as_slice()]), + BTreeSet::from([send_content0.as_slice(), send_content1.as_slice()]), ); + + client0.close().await; + client1.close().await; + server_dispatcher.shutdown().await; } #[tokio::test(flavor = "multi_thread")] @@ -756,21 +827,46 @@ mod tests { let server_dispatcher = MessageDispatcher::new(); let server_sink = server_dispatcher.open_send(channel); - let (client_socket0, server_socket0) = create_connection_pair().await; - let (client_socket1, server_socket1) = create_connection_pair().await; + let (client0, server0) = create_connection_pair().await; + let (client1, server1) = create_connection_pair().await; - let client_stream0 = MessageStream::new(client_socket0); - let client_stream1 = MessageStream::new(client_socket1); + let (client0_tx, client0_rx) = client0.outgoing().await.unwrap(); + let (client1_tx, client1_rx) = client1.outgoing().await.unwrap(); - server_dispatcher.bind(server_socket0, ConnectionPermit::dummy()); - server_dispatcher.bind(server_socket1, ConnectionPermit::dummy()); + // The incoming streams are accepted only after something is sent on the corresponding + // outgoing streams first. 
+ let mut client0_sink = MessageSink::new(client0_tx); + let mut client1_sink = MessageSink::new(client1_tx); + + for sink in [&mut client0_sink, &mut client1_sink] { + sink.send(Message { + channel, + content: Vec::new(), + }) + .await + .unwrap(); + } + + let client0_stream = MessageStream::new(client0_rx); + let client1_stream = MessageStream::new(client1_rx); + + server_dispatcher.bind( + server0, + ConnectionPermit::dummy(ConnectionDirection::Incoming), + Arc::new(ByteCounters::new()), + ); + server_dispatcher.bind( + server1, + ConnectionPermit::dummy(ConnectionDirection::Incoming), + Arc::new(ByteCounters::new()), + ); for content in [send_content0, send_content1] { server_sink.send(content.to_vec()).await.unwrap(); } // The messages may be received on any stream - let recv_contents: BTreeSet<_> = stream::select(client_stream0, client_stream1) + let recv_contents: BTreeSet<_> = stream::select(client0_stream, client1_stream) .map(|message| message.unwrap().content) .take(2) .collect() @@ -800,21 +896,20 @@ mod tests { assert_matches!(server_sink.send(vec![]).await, Err(ChannelClosed)); } - async fn create_connection_pair() -> (Instrumented, Instrumented) { - let (_server_connector, server_acceptor) = - tcp::configure((Ipv4Addr::LOCALHOST, 0u16).into()).unwrap(); - let (client_connector, _client_acceptor) = - tcp::configure((Ipv4Addr::LOCALHOST, 0u16).into()).unwrap(); + async fn create_connection_pair() -> (Connection, Connection) { + let client = net::tcp::configure((Ipv4Addr::LOCALHOST, 0).into()) + .unwrap() + .0; + let server = net::tcp::configure((Ipv4Addr::LOCALHOST, 0).into()) + .unwrap() + .1; - let client = client_connector - .connect(*server_acceptor.local_addr()) - .await - .unwrap(); - let (server, _) = server_acceptor.accept().await.unwrap(); + let client = Connector::from(client); + let server = Acceptor::from(server); + + let client = client.connect(*server.local_addr()); + let server = async { server.accept().await?.await }; - ( - Instrumented::new(Connection::Tcp(client), Arc::new(ByteCounters::default())), - Instrumented::new(Connection::Tcp(server), Arc::new(ByteCounters::default())), - ) + future::try_join(client, server).await.unwrap() } } diff --git a/lib/src/network/mod.rs b/lib/src/network/mod.rs index b0e22b004..14e5d6bfd 100644 --- a/lib/src/network/mod.rs +++ b/lib/src/network/mod.rs @@ -40,7 +40,6 @@ pub use self::{ runtime_id::{PublicRuntimeId, SecretRuntimeId}, stats::Stats, }; -use net::connection::Connection; pub use net::stun::NatBehavior; use self::{ @@ -60,7 +59,7 @@ use self::{ }; use crate::{ collections::{hash_map::Entry, HashMap, HashSet}, - network::stats::Instrumented, + network::connection::ConnectionDirection, protocol::RepositoryId, repository::{RepositoryHandle, Vault}, sync::uninitialized_watch, @@ -69,6 +68,7 @@ use backoff::{backoff::Backoff, ExponentialBackoffBuilder}; use btdht::{self, InfoHash, INFO_HASH_LEN}; use deadlock::BlockingMutex; use futures_util::future; +use net::connection::{Connection, Error as ConnectionError}; use scoped_task::ScopedAbortHandle; use slab::Slab; use state_monitor::StateMonitor; @@ -840,7 +840,7 @@ impl Inner { /// Return true iff the peer is suitable for reconnection. 
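+    /// (As implemented below: e.g. a handshake failure or a protocol version mismatch makes the
+    /// peer unsuitable.)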
async fn handle_connection( &self, - mut connection: Connection, + connection: Connection, permit: ConnectionPermit, monitor: &ConnectionMonitor, ) -> bool { @@ -849,8 +849,13 @@ impl Inner { permit.mark_as_handshaking(); monitor.mark_as_handshaking(); - let handshake_result = - perform_handshake(&mut connection, VERSION, &self.this_runtime_id).await; + let handshake_result = perform_handshake( + &connection, + VERSION, + &self.this_runtime_id, + ConnectionDirection::from_source(permit.source()), + ) + .await; if let Err(error) = &handshake_result { tracing::debug!(parent: monitor.span(), ?error, "Handshake failed"); @@ -862,9 +867,12 @@ impl Inner { self.on_protocol_mismatch(their_version); return false; } - Err(HandshakeError::Timeout | HandshakeError::BadMagic | HandshakeError::Fatal(_)) => { - return false - } + Err( + HandshakeError::Timeout + | HandshakeError::BadMagic + | HandshakeError::Io(_) + | HandshakeError::Connection(_), + ) => return false, }; // prevent self-connections. @@ -916,8 +924,7 @@ impl Inner { broker }); - let connection = Instrumented::new(connection, self.stats_tracker.bytes.clone()); - broker.add_connection(connection, permit); + broker.add_connection(connection, permit, self.stats_tracker.bytes.clone()); } let _remover = MessageBrokerEntryGuard { @@ -963,28 +970,37 @@ impl Inner { // Exchange runtime ids with the peer. Returns their (verified) runtime id. async fn perform_handshake( - connection: &mut Connection, + connection: &Connection, this_version: Version, this_runtime_id: &SecretRuntimeId, + dir: ConnectionDirection, ) -> Result { let result = tokio::time::timeout(std::time::Duration::from_secs(5), async move { - connection.write_all(MAGIC).await?; + let (mut writer, mut reader) = match dir { + ConnectionDirection::Incoming => connection.incoming().await?, + ConnectionDirection::Outgoing => connection.outgoing().await?, + }; - this_version.write_into(connection).await?; + writer.write_all(MAGIC).await?; + + this_version.write_into(&mut writer).await?; let mut that_magic = [0; MAGIC.len()]; - connection.read_exact(&mut that_magic).await?; + reader.read_exact(&mut that_magic).await?; if MAGIC != &that_magic { return Err(HandshakeError::BadMagic); } - let that_version = Version::read_from(connection).await?; + let that_version = Version::read_from(&mut reader).await?; if that_version > this_version { return Err(HandshakeError::ProtocolVersionMismatch(that_version)); } - let that_runtime_id = runtime_id::exchange(this_runtime_id, connection).await?; + let that_runtime_id = + runtime_id::exchange(this_runtime_id, &mut writer, &mut reader).await?; + + writer.shutdown().await?; Ok(that_runtime_id) }) @@ -1004,8 +1020,10 @@ enum HandshakeError { BadMagic, #[error("timeout")] Timeout, - #[error("fatal error")] - Fatal(#[from] io::Error), + #[error("IO error")] + Io(#[from] io::Error), + #[error("connection error")] + Connection(#[from] ConnectionError), } // RAII guard which when dropped removes the broker from the network state if it has no connections. 
diff --git a/lib/src/network/runtime_id.rs b/lib/src/network/runtime_id.rs index 15c793cb5..5c00a8cfe 100644 --- a/lib/src/network/runtime_id.rs +++ b/lib/src/network/runtime_id.rs @@ -85,26 +85,28 @@ impl Hashable for PublicRuntimeId { } } -pub async fn exchange( +pub async fn exchange( our_runtime_id: &SecretRuntimeId, - io: &mut IO, + writer: &mut W, + reader: &mut R, ) -> io::Result where - IO: AsyncRead + AsyncWrite + Unpin, + W: AsyncWrite + Unpin, + R: AsyncRead + Unpin, { let our_challenge: [u8; 32] = OsRng.gen(); - io.write_all(&our_challenge).await?; - our_runtime_id.public().write_into(io).await?; + writer.write_all(&our_challenge).await?; + our_runtime_id.public().write_into(writer).await?; - let their_challenge = read_bytes::<32, IO>(io).await?; - let their_runtime_id = PublicRuntimeId::read_from(io).await?; + let their_challenge: [_; 32] = read_bytes(reader).await?; + let their_runtime_id = PublicRuntimeId::read_from(reader).await?; let our_signature = our_runtime_id.keypair.sign(&to_sign(&their_challenge)); - io.write_all(&our_signature.to_bytes()).await?; + writer.write_all(&our_signature.to_bytes()).await?; - let their_signature = read_bytes::<{ Signature::SIZE }, IO>(io).await?; + let their_signature: [_; Signature::SIZE] = read_bytes(reader).await?; let their_signature = Signature::from(&their_signature); if !their_runtime_id diff --git a/lib/src/network/stats.rs b/lib/src/network/stats.rs index 80441a7fb..878bfdbcf 100644 --- a/lib/src/network/stats.rs +++ b/lib/src/network/stats.rs @@ -1,4 +1,3 @@ -use net::connection::{Connection, OwnedReadHalf, OwnedWriteHalf}; use pin_project_lite::pin_project; use serde::{Deserialize, Serialize}; use std::{ @@ -65,6 +64,10 @@ pub(super) struct ByteCounters { } impl ByteCounters { + pub fn new() -> Self { + Self::default() + } + pub fn increment_tx(&self, by: u64) { self.tx.fetch_add(by, Ordering::Relaxed); } @@ -223,17 +226,6 @@ where } } -impl Instrumented { - pub fn into_split(self) -> (Instrumented, Instrumented) { - let (reader, writer) = self.inner.into_split(); - - ( - Instrumented::new(reader, self.counters.clone()), - Instrumented::new(writer, self.counters), - ) - } -} - #[cfg(test)] mod tests { use super::*; diff --git a/net/src/connection.rs b/net/src/connection.rs index 27ba8d1b8..319be649f 100644 --- a/net/src/connection.rs +++ b/net/src/connection.rs @@ -1,6 +1,6 @@ use crate::{quic, tcp}; use std::{ - future::Future, + future::{self, Future, IntoFuture, Ready}, io, net::SocketAddr, pin::Pin, @@ -32,6 +32,18 @@ impl Connector { } } +impl From for Connector { + fn from(inner: tcp::Connector) -> Self { + Self::Tcp(inner) + } +} + +impl From for Connector { + fn from(inner: quic::Connector) -> Self { + Self::Quic(inner) + } +} + /// Unified acceptor pub enum Acceptor { Tcp(tcp::Acceptor), @@ -48,29 +60,67 @@ impl Acceptor { pub async fn accept(&self) -> Result { match self { - Self::Tcp(inner) => Ok(Connecting::Tcp(Some(inner.accept().await?))), + Self::Tcp(inner) => Ok(Connecting::Tcp(inner.accept().await?)), Self::Quic(inner) => Ok(Connecting::Quic(inner.accept().await?)), } } } -/// Incoming connection which being established. +impl From for Acceptor { + fn from(inner: tcp::Acceptor) -> Self { + Self::Tcp(inner) + } +} + +impl From for Acceptor { + fn from(inner: quic::Acceptor) -> Self { + Self::Quic(inner) + } +} + +/// Incoming connection while being established. pub enum Connecting { // Note TCP doesn't support two phase accept so this is already a fully established // connection. 
- Tcp(Option), + Tcp(tcp::Connection), Quic(quic::Connecting), } -impl Future for Connecting { +impl Connecting { + pub fn remote_addr(&self) -> SocketAddr { + match self { + Self::Tcp(inner) => inner.remote_addr(), + Self::Quic(inner) => inner.remote_addr(), + } + } +} + +impl IntoFuture for Connecting { + type Output = Result; + type IntoFuture = ConnectingFuture; + + fn into_future(self) -> Self::IntoFuture { + match self { + Self::Tcp(inner) => ConnectingFuture::Tcp(future::ready(inner)), + Self::Quic(inner) => ConnectingFuture::Quic(inner.into_future()), + } + } +} + +pub enum ConnectingFuture { + Tcp(Ready), + Quic(quic::ConnectingFuture), +} + +impl Future for ConnectingFuture { type Output = Result; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { match self.get_mut() { - Self::Tcp(connection) => Poll::Ready(Ok(Connection::Tcp( - connection.take().expect("future polled after completion"), - ))), - Self::Quic(connecting) => Pin::new(connecting) + Self::Tcp(inner) => Pin::new(inner) + .poll(cx) + .map(|inner| Ok(Connection::Tcp(inner))), + Self::Quic(inner) => Pin::new(inner) .poll(cx) .map_ok(Connection::Quic) .map_err(Into::into), @@ -125,13 +175,11 @@ impl Connection { } /// Gracefully close the connection - pub async fn close(&self) -> Result<(), Error> { + pub async fn close(&self) { match self { - Self::Tcp(inner) => inner.close().await?, + Self::Tcp(inner) => inner.close().await, Self::Quic(inner) => inner.close(), } - - Ok(()) } } @@ -323,7 +371,7 @@ mod tests { let sent_messages: Vec<_> = tasks.collect().await; - conn.close().await.unwrap(); + conn.close().await; sent_messages } diff --git a/net/src/quic.rs b/net/src/quic.rs index e19e5ac1b..7ecd20dc6 100644 --- a/net/src/quic.rs +++ b/net/src/quic.rs @@ -59,9 +59,7 @@ impl Acceptor { self.endpoint .accept() .await - .map(|inner| Connecting { - inner: inner.into_future(), - }) + .map(|inner| Connecting { inner }) .ok_or(Error::EndpointClosed) } @@ -71,10 +69,31 @@ impl Acceptor { } pub struct Connecting { + inner: quinn::Incoming, +} + +impl Connecting { + pub fn remote_addr(&self) -> SocketAddr { + self.inner.remote_address() + } +} + +impl IntoFuture for Connecting { + type Output = Result; + type IntoFuture = ConnectingFuture; + + fn into_future(self) -> Self::IntoFuture { + ConnectingFuture { + inner: self.inner.into_future(), + } + } +} + +pub struct ConnectingFuture { inner: quinn::IncomingFuture, } -impl Future for Connecting { +impl Future for ConnectingFuture { type Output = Result; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { diff --git a/net/src/sync.rs b/net/src/sync.rs index 2fb5d972e..6167e2da8 100644 --- a/net/src/sync.rs +++ b/net/src/sync.rs @@ -42,6 +42,17 @@ pub(crate) mod rendezvous { } } } + + /// Waits for the associated [`Receiver`] to close (that is, to be dropped). + pub async fn closed(&self) { + loop { + if self.shared.state.lock().unwrap().rx_drop { + break; + } + + self.shared.rx_notify.notified().await; + } + } } impl Drop for Sender { diff --git a/net/src/tcp.rs b/net/src/tcp.rs index 927d08929..73e5cc0ce 100644 --- a/net/src/tcp.rs +++ b/net/src/tcp.rs @@ -1,6 +1,5 @@ -use crate::sync::rendezvous; - use self::implementation::{TcpListener, TcpStream}; +use crate::sync::rendezvous; use std::{future, io, net::SocketAddr}; use tokio::{ io::{ReadHalf, WriteHalf}, @@ -125,17 +124,22 @@ impl Connection { /// # Cancel safety /// /// This function is idempotent even in the presence of cancellation. 
- pub async fn close(&self) -> Result<(), Error> { + pub async fn close(&self) { let (reply_tx, reply_rx) = oneshot::channel(); - // If send or receive return an error it means the connection is already closed. Returning - // `Ok` in that case to make this function idempotent. - self.command_tx + if self + .command_tx .send(Command::Close(Some(reply_tx))) .await - .ok(); + .is_err() + { + return; + } - reply_rx.await.unwrap_or(Ok(())).map_err(Into::into) + match reply_rx.await { + Ok(Ok(())) | Err(_) => (), + Ok(Err(error)) => tracing::debug!(?error, "failed to close connection"), + } } } @@ -180,19 +184,24 @@ async fn drive_connection( Command::Incoming(reply_tx) => { let result = if let Some(result) = incoming.pop() { result - } else if let Some(result) = future::poll_fn(|cx| conn.poll_next_inbound(cx)).await - { - result } else { - break; + select! { + result = future::poll_fn(|cx| conn.poll_next_inbound(cx)) => { + if let Some(result) = result { + result + } else { + // connection closed + break; + } + } + _ = reply_tx.closed() => continue, + } }; if let Err(result) = reply_tx.send(result).await { // reply_rx dropped before receiving the result, save it for next time. incoming.push(result); } - - continue; } Command::Outgoing(reply_tx) => { let result = if let Some(result) = outgoing.pop() { @@ -224,7 +233,7 @@ enum Command { // got cancelled. Incoming(rendezvous::Sender>), Outgoing(rendezvous::Sender>), - // Using regular oneshot as we don't care about cancellation here: + // Using regular oneshot as we don't care about cancellation here Close(Option>>), } @@ -238,7 +247,7 @@ mod implementation { use crate::{socket, KEEP_ALIVE_INTERVAL}; use socket2::{Domain, Socket, TcpKeepalive, Type}; use std::{ - io, + fmt, io, net::SocketAddr, pin::Pin, task::{Context, Poll}, @@ -298,6 +307,12 @@ mod implementation { } } + impl fmt::Debug for TcpStream { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{:?}", self.0) + } + } + fn set_keep_alive(socket: &Socket) -> io::Result<()> { let options = TcpKeepalive::new() .with_time(KEEP_ALIVE_INTERVAL) From 502ad31a502455d16195d12c219920205deeaddd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Cig=C3=A1nek?= Date: Mon, 9 Sep 2024 08:22:25 +0200 Subject: [PATCH 10/55] net: Use atomics to track rendezvous state --- lib/src/network/message_dispatcher.rs | 4 +- net/src/sync.rs | 58 +++++++++++++++++---------- 2 files changed, 38 insertions(+), 24 deletions(-) diff --git a/lib/src/network/message_dispatcher.rs b/lib/src/network/message_dispatcher.rs index fe809f165..f3099cfd5 100644 --- a/lib/src/network/message_dispatcher.rs +++ b/lib/src/network/message_dispatcher.rs @@ -897,10 +897,10 @@ mod tests { } async fn create_connection_pair() -> (Connection, Connection) { - let client = net::tcp::configure((Ipv4Addr::LOCALHOST, 0).into()) + let client = net::quic::configure((Ipv4Addr::LOCALHOST, 0).into()) .unwrap() .0; - let server = net::tcp::configure((Ipv4Addr::LOCALHOST, 0).into()) + let server = net::quic::configure((Ipv4Addr::LOCALHOST, 0).into()) .unwrap() .1; diff --git a/net/src/sync.rs b/net/src/sync.rs index 6167e2da8..a2e3d5a4b 100644 --- a/net/src/sync.rs +++ b/net/src/sync.rs @@ -13,7 +13,10 @@ pub(crate) mod rendezvous { use std::{ fmt, - sync::{Arc, Mutex}, + sync::{ + atomic::{AtomicU8, Ordering}, + Arc, Mutex, + }, }; use tokio::sync::Notify; @@ -25,20 +28,20 @@ pub(crate) mod rendezvous { impl Sender { /// Sends the `value` to the [`Receiver`]. 
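+        ///
+        /// Completes only once the value has been taken by the receiver, or returns the value
+        /// back in `Err` if the receiver was dropped first.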
pub async fn send(self, value: T) -> Result<(), T> { - self.shared.state.lock().unwrap().value = Some(value); + *self.shared.value.lock().unwrap() = Some(value); self.shared.tx_notify.notify_one(); loop { self.shared.rx_notify.notified().await; - let mut state = self.shared.state.lock().unwrap(); + let mut value = self.shared.value.lock().unwrap(); - if state.value.is_none() { + if value.is_none() { return Ok(()); } - if state.rx_drop { - return Err(state.value.take().unwrap()); + if self.shared.state.get(RX_DROP) { + return Err(value.take().unwrap()); } } } @@ -46,7 +49,7 @@ pub(crate) mod rendezvous { /// Waits for the associated [`Receiver`] to close (that is, to be dropped). pub async fn closed(&self) { loop { - if self.shared.state.lock().unwrap().rx_drop { + if self.shared.state.get(RX_DROP) { break; } @@ -57,7 +60,7 @@ pub(crate) mod rendezvous { impl Drop for Sender { fn drop(&mut self) { - self.shared.state.lock().unwrap().tx_drop = true; + self.shared.state.set(TX_DROP); self.shared.tx_notify.notify_one(); } } @@ -76,14 +79,14 @@ pub(crate) mod rendezvous { loop { self.shared.tx_notify.notified().await; - let mut state = self.shared.state.lock().unwrap(); + let value = self.shared.value.lock().unwrap().take(); - if let Some(value) = state.value.take() { + if let Some(value) = value { self.shared.rx_notify.notify_one(); return Ok(value); } - if state.tx_drop { + if self.shared.state.get(TX_DROP) { return Err(RecvError); } } @@ -92,7 +95,7 @@ pub(crate) mod rendezvous { impl Drop for Receiver { fn drop(&mut self) { - self.shared.state.lock().unwrap().rx_drop = true; + self.shared.state.set(RX_DROP); self.shared.rx_notify.notify_one(); } } @@ -100,11 +103,8 @@ pub(crate) mod rendezvous { /// Create a rendezvous channel for sending a single message of type `T`. 
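+    ///
+    /// (See the module-level `# Cancel safety` notes: a value is only ever lost when both
+    /// `send` and `recv` are cancelled.)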
pub fn channel() -> (Sender, Receiver) { let shared = Arc::new(Shared { - state: Mutex::new(State { - value: None, - tx_drop: false, - rx_drop: false, - }), + value: Mutex::new(None), + state: State::new(), tx_notify: Notify::new(), rx_notify: Notify::new(), }); @@ -129,17 +129,31 @@ pub(crate) mod rendezvous { impl std::error::Error for RecvError {} struct Shared { - state: Mutex>, + value: Mutex>, + state: State, tx_notify: Notify, rx_notify: Notify, } - struct State { - value: Option, - tx_drop: bool, - rx_drop: bool, + struct State(AtomicU8); + + impl State { + fn new() -> Self { + Self(AtomicU8::new(0)) + } + + fn get(&self, flag: u8) -> bool { + self.0.load(Ordering::Acquire) & flag == flag + } + + fn set(&self, flag: u8) { + self.0.fetch_or(flag, Ordering::AcqRel); + } } + const TX_DROP: u8 = 1; + const RX_DROP: u8 = 2; + #[cfg(test)] mod tests { use super::*; From 9b08c83ccd80cf0ffc6bc50e2ef860a8b24419f3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Cig=C3=A1nek?= Date: Mon, 9 Sep 2024 08:29:41 +0200 Subject: [PATCH 11/55] net: rename module connection -> unified --- lib/src/network/gateway.rs | 2 +- lib/src/network/message_broker.rs | 2 +- lib/src/network/message_dispatcher.rs | 4 ++-- lib/src/network/mod.rs | 2 +- net/examples/peer.rs | 2 +- net/src/lib.rs | 2 +- net/src/{connection.rs => unified.rs} | 16 +++++++++------- 7 files changed, 16 insertions(+), 14 deletions(-) rename net/src/{connection.rs => unified.rs} (97%) diff --git a/lib/src/network/gateway.rs b/lib/src/network/gateway.rs index d255be30e..2e2129331 100644 --- a/lib/src/network/gateway.rs +++ b/lib/src/network/gateway.rs @@ -2,8 +2,8 @@ use super::{ip, peer_addr::PeerAddr, peer_source::PeerSource, seen_peers::SeenPe use crate::sync::atomic_slot::AtomicSlot; use backoff::{backoff::Backoff, ExponentialBackoffBuilder}; use net::{ - connection::{Acceptor, Connection}, quic, tcp, + unified::{Acceptor, Connection}, }; use scoped_task::ScopedJoinHandle; use std::net::{IpAddr, SocketAddr}; diff --git a/lib/src/network/message_broker.rs b/lib/src/network/message_broker.rs index ffd5f9b6e..d45f6e8e4 100644 --- a/lib/src/network/message_broker.rs +++ b/lib/src/network/message_broker.rs @@ -17,7 +17,7 @@ use crate::{ repository::Vault, }; use backoff::{backoff::Backoff, ExponentialBackoffBuilder}; -use net::connection::Connection; +use net::unified::Connection; use state_monitor::StateMonitor; use std::{future, sync::Arc}; use tokio::{ diff --git a/lib/src/network/message_dispatcher.rs b/lib/src/network/message_dispatcher.rs index f3099cfd5..b562775c9 100644 --- a/lib/src/network/message_dispatcher.rs +++ b/lib/src/network/message_dispatcher.rs @@ -13,7 +13,7 @@ use futures_util::{ stream::{FuturesUnordered, SelectAll}, FutureExt, Sink, SinkExt, Stream, StreamExt, }; -use net::connection::{Connection, Error as ConnectionError, RecvStream, SendStream}; +use net::unified::{Connection, ConnectionError, RecvStream, SendStream}; use std::{ io, pin::Pin, @@ -584,7 +584,7 @@ mod tests { use super::{super::stats::ByteCounters, *}; use assert_matches::assert_matches; use futures_util::{future, stream}; - use net::connection::{Acceptor, Connection, Connector}; + use net::unified::{Acceptor, Connection, Connector}; use std::{collections::BTreeSet, net::Ipv4Addr, str::from_utf8, time::Duration}; #[tokio::test(flavor = "multi_thread")] diff --git a/lib/src/network/mod.rs b/lib/src/network/mod.rs index 14e5d6bfd..38ac149ca 100644 --- a/lib/src/network/mod.rs +++ b/lib/src/network/mod.rs @@ -68,7 +68,7 @@ use backoff::{backoff::Backoff, 
ExponentialBackoffBuilder}; use btdht::{self, InfoHash, INFO_HASH_LEN}; use deadlock::BlockingMutex; use futures_util::future; -use net::connection::{Connection, Error as ConnectionError}; +use net::unified::{Connection, ConnectionError}; use scoped_task::ScopedAbortHandle; use slab::Slab; use state_monitor::StateMonitor; diff --git a/net/examples/peer.rs b/net/examples/peer.rs index 6a7a24497..76d12d4c6 100644 --- a/net/examples/peer.rs +++ b/net/examples/peer.rs @@ -1,8 +1,8 @@ use anyhow::Result; use clap::{Parser, ValueEnum}; use ouisync_net::{ - connection::{Acceptor, Connection, Connector, RecvStream, SendStream}, quic, tcp, + unified::{Acceptor, Connection, Connector, RecvStream, SendStream}, }; use std::{ future, diff --git a/net/src/lib.rs b/net/src/lib.rs index 773f79e6a..bb354280e 100644 --- a/net/src/lib.rs +++ b/net/src/lib.rs @@ -1,10 +1,10 @@ use std::time::Duration; -pub mod connection; pub mod quic; pub mod stun; pub mod tcp; pub mod udp; +pub mod unified; #[cfg(not(feature = "simulation"))] mod socket; diff --git a/net/src/connection.rs b/net/src/unified.rs similarity index 97% rename from net/src/connection.rs rename to net/src/unified.rs index 319be649f..26ce298d1 100644 --- a/net/src/connection.rs +++ b/net/src/unified.rs @@ -1,3 +1,5 @@ +//! Unified interface over different network protocols (currently TCP and QUIC). + use crate::{quic, tcp}; use std::{ future::{self, Future, IntoFuture, Ready}, @@ -16,7 +18,7 @@ pub enum Connector { } impl Connector { - pub async fn connect(&self, addr: SocketAddr) -> Result { + pub async fn connect(&self, addr: SocketAddr) -> Result { match self { Self::Tcp(inner) => inner .connect(addr) @@ -58,7 +60,7 @@ impl Acceptor { } } - pub async fn accept(&self) -> Result { + pub async fn accept(&self) -> Result { match self { Self::Tcp(inner) => Ok(Connecting::Tcp(inner.accept().await?)), Self::Quic(inner) => Ok(Connecting::Quic(inner.accept().await?)), @@ -96,7 +98,7 @@ impl Connecting { } impl IntoFuture for Connecting { - type Output = Result; + type Output = Result; type IntoFuture = ConnectingFuture; fn into_future(self) -> Self::IntoFuture { @@ -113,7 +115,7 @@ pub enum ConnectingFuture { } impl Future for ConnectingFuture { - type Output = Result; + type Output = Result; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { match self.get_mut() { @@ -143,7 +145,7 @@ impl Connection { } /// Accept a new incoming stream - pub async fn incoming(&self) -> Result<(SendStream, RecvStream), Error> { + pub async fn incoming(&self) -> Result<(SendStream, RecvStream), ConnectionError> { match self { Self::Tcp(inner) => inner .incoming() @@ -159,7 +161,7 @@ impl Connection { } /// Open a new outgoing stream - pub async fn outgoing(&self) -> Result<(SendStream, RecvStream), Error> { + pub async fn outgoing(&self) -> Result<(SendStream, RecvStream), ConnectionError> { match self { Self::Tcp(inner) => inner .outgoing() @@ -252,7 +254,7 @@ impl AsyncRead for RecvStream { } #[derive(Error, Debug)] -pub enum Error { +pub enum ConnectionError { #[error("tcp")] Tcp(#[from] tcp::Error), #[error("quic")] From 44d187074ebf5a1ccc504cd748abddc5c5082c80 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Cig=C3=A1nek?= Date: Mon, 9 Sep 2024 08:53:11 +0200 Subject: [PATCH 12/55] net: Allow configuring whether to enable reuse_addr In production it should be on, but in tests it's better to be off so that we prevent one test from accidentally stealing a binding from another test which could cause those tests to connect to each other which almost certainly 
causes them to fail. --- lib/src/network/gateway.rs | 65 ++++++++++++++------------- lib/src/network/message_dispatcher.rs | 23 +++++++--- net/examples/peer.rs | 11 +++-- net/src/lib.rs | 2 + net/src/quic.rs | 15 ++++--- net/src/socket.rs | 13 ++++++ net/src/tcp.rs | 31 ++++++++----- net/src/udp.rs | 20 ++++++--- net/src/unified.rs | 14 ++++-- 9 files changed, 127 insertions(+), 67 deletions(-) diff --git a/lib/src/network/gateway.rs b/lib/src/network/gateway.rs index 2e2129331..57ff72034 100644 --- a/lib/src/network/gateway.rs +++ b/lib/src/network/gateway.rs @@ -4,6 +4,7 @@ use backoff::{backoff::Backoff, ExponentialBackoffBuilder}; use net::{ quic, tcp, unified::{Acceptor, Connection}, + SocketOptions, }; use scoped_task::ScopedJoinHandle; use std::net::{IpAddr, SocketAddr}; @@ -449,26 +450,27 @@ impl QuicStack { ) -> Option<(Self, quic::SideChannelMaker)> { let span = tracing::info_span!("quic", addr = field::Empty); - let (connector, acceptor, side_channel_maker) = match quic::configure(bind_addr) { - Ok((connector, acceptor, side_channel_maker)) => { - span.record( - "addr", - field::display(PeerAddr::Quic(*acceptor.local_addr())), - ); - tracing::info!(parent: &span, "Stack configured"); + let (connector, acceptor, side_channel_maker) = + match quic::configure(bind_addr, SocketOptions::default().with_reuse_addr()) { + Ok((connector, acceptor, side_channel_maker)) => { + span.record( + "addr", + field::display(PeerAddr::Quic(*acceptor.local_addr())), + ); + tracing::info!(parent: &span, "Stack configured"); - (connector, acceptor, side_channel_maker) - } - Err(error) => { - tracing::warn!( - parent: &span, - bind_addr = %PeerAddr::Quic(bind_addr), - ?error, - "Failed to configure stack" - ); - return None; - } - }; + (connector, acceptor, side_channel_maker) + } + Err(error) => { + tracing::warn!( + parent: &span, + bind_addr = %PeerAddr::Quic(bind_addr), + ?error, + "Failed to configure stack" + ); + return None; + } + }; let listener_local_addr = *acceptor.local_addr(); let listener_task = scoped_task::spawn( @@ -506,18 +508,19 @@ impl TcpStack { ) -> Option { let span = tracing::info_span!("tcp", addr = field::Empty); - let (connector, acceptor) = match tcp::configure(bind_addr) { - Ok(stack) => stack, - Err(error) => { - tracing::warn!( - parent: &span, - bind_addr = %PeerAddr::Tcp(bind_addr), - ?error, - "Failed to configure stack", - ); - return None; - } - }; + let (connector, acceptor) = + match tcp::configure(bind_addr, SocketOptions::default().with_reuse_addr()) { + Ok(stack) => stack, + Err(error) => { + tracing::warn!( + parent: &span, + bind_addr = %PeerAddr::Tcp(bind_addr), + ?error, + "Failed to configure stack", + ); + return None; + } + }; let listener_local_addr = *acceptor.local_addr(); let listener_task = diff --git a/lib/src/network/message_dispatcher.rs b/lib/src/network/message_dispatcher.rs index b562775c9..31a3f7e15 100644 --- a/lib/src/network/message_dispatcher.rs +++ b/lib/src/network/message_dispatcher.rs @@ -584,7 +584,10 @@ mod tests { use super::{super::stats::ByteCounters, *}; use assert_matches::assert_matches; use futures_util::{future, stream}; - use net::unified::{Acceptor, Connection, Connector}; + use net::{ + unified::{Acceptor, Connection, Connector}, + SocketOptions, + }; use std::{collections::BTreeSet, net::Ipv4Addr, str::from_utf8, time::Duration}; #[tokio::test(flavor = "multi_thread")] @@ -897,12 +900,18 @@ mod tests { } async fn create_connection_pair() -> (Connection, Connection) { - let client = 
net::quic::configure((Ipv4Addr::LOCALHOST, 0).into()) - .unwrap() - .0; - let server = net::quic::configure((Ipv4Addr::LOCALHOST, 0).into()) - .unwrap() - .1; + // NOTE: Make sure to keep the `reuse_addr` option disabled here to avoid one test to + // accidentally connect to a different test. More details here: + // https://gavv.net/articles/ephemeral-port-reuse/. + + let client = + net::quic::configure((Ipv4Addr::LOCALHOST, 0).into(), SocketOptions::default()) + .unwrap() + .0; + let server = + net::quic::configure((Ipv4Addr::LOCALHOST, 0).into(), SocketOptions::default()) + .unwrap() + .1; let client = Connector::from(client); let server = Acceptor::from(server); diff --git a/net/examples/peer.rs b/net/examples/peer.rs index 76d12d4c6..173a2a4da 100644 --- a/net/examples/peer.rs +++ b/net/examples/peer.rs @@ -3,6 +3,7 @@ use clap::{Parser, ValueEnum}; use ouisync_net::{ quic, tcp, unified::{Acceptor, Connection, Connector, RecvStream, SendStream}, + SocketOptions, }; use std::{ future, @@ -70,11 +71,13 @@ async fn run_client(options: &Options) -> Result<()> { let addr: SocketAddr = (options.addr.unwrap_or(DEFAULT_CONNECT_ADDR), options.port).into(); let connector = match options.proto { Proto::Tcp => { - let (connector, _) = tcp::configure((DEFAULT_BIND_ADDR, 0).into())?; + let (connector, _) = + tcp::configure((DEFAULT_BIND_ADDR, 0).into(), SocketOptions::default())?; Connector::Tcp(connector) } Proto::Quic => { - let (connector, _, _) = quic::configure((DEFAULT_BIND_ADDR, 0).into())?; + let (connector, _, _) = + quic::configure((DEFAULT_BIND_ADDR, 0).into(), SocketOptions::default())?; Connector::Quic(connector) } }; @@ -111,11 +114,11 @@ async fn run_server(options: &Options) -> Result<()> { let acceptor = match options.proto { Proto::Tcp => { - let (_, acceptor) = tcp::configure(bind_addr)?; + let (_, acceptor) = tcp::configure(bind_addr, SocketOptions::default())?; Acceptor::Tcp(acceptor) } Proto::Quic => { - let (_, acceptor, _) = quic::configure(bind_addr)?; + let (_, acceptor, _) = quic::configure(bind_addr, SocketOptions::default())?; Acceptor::Quic(acceptor) } }; diff --git a/net/src/lib.rs b/net/src/lib.rs index bb354280e..255825252 100644 --- a/net/src/lib.rs +++ b/net/src/lib.rs @@ -10,4 +10,6 @@ pub mod unified; mod socket; mod sync; +pub use socket::SocketOptions; + pub const KEEP_ALIVE_INTERVAL: Duration = Duration::from_secs(10); diff --git a/net/src/quic.rs b/net/src/quic.rs index 7ecd20dc6..57bd24e10 100644 --- a/net/src/quic.rs +++ b/net/src/quic.rs @@ -1,4 +1,4 @@ -use crate::KEEP_ALIVE_INTERVAL; +use crate::{SocketOptions, KEEP_ALIVE_INTERVAL}; use bytes::BytesMut; use pin_project_lite::pin_project; use quinn::{ @@ -130,9 +130,12 @@ pub type SendStream = quinn::SendStream; pub type RecvStream = quinn::RecvStream; //------------------------------------------------------------------------------ -pub fn configure(bind_addr: SocketAddr) -> Result<(Connector, Acceptor, SideChannelMaker), Error> { +pub fn configure( + bind_addr: SocketAddr, + options: SocketOptions, +) -> Result<(Connector, Acceptor, SideChannelMaker), Error> { let server_config = make_server_config()?; - let custom_socket = Arc::new(CustomUdpSocket::bind(bind_addr)?); + let custom_socket = Arc::new(CustomUdpSocket::bind(bind_addr, options)?); let side_channel_maker = custom_socket.clone().side_channel_maker(); let mut endpoint = quinn::Endpoint::new_with_abstract_socket( @@ -304,8 +307,8 @@ struct CustomUdpSocket { } impl CustomUdpSocket { - fn bind(addr: SocketAddr) -> io::Result { - let socket = 
crate::udp::UdpSocket::bind(addr)?;
+    fn bind(addr: SocketAddr, options: SocketOptions) -> io::Result<Self> {
+        let socket = crate::udp::UdpSocket::bind_with_options(addr, options)?;
         let socket = socket.into_std()?;
         let state = quinn::udp::UdpSocketState::new((&socket).into())?;
 
@@ -538,7 +541,7 @@ mod tests {
     #[tokio::test(flavor = "multi_thread")]
     async fn side_channel() {
         let (_connector, acceptor, side_channel_maker) =
-            configure((Ipv4Addr::LOCALHOST, 0).into()).unwrap();
+            configure((Ipv4Addr::LOCALHOST, 0).into(), Default::default()).unwrap();
 
         let addr = *acceptor.local_addr();
         let side_channel = side_channel_maker.make();
diff --git a/net/src/socket.rs b/net/src/socket.rs
index 2cf137230..0d80edf0c 100644
--- a/net/src/socket.rs
+++ b/net/src/socket.rs
@@ -17,3 +17,16 @@ pub(crate) fn bind_with_fallback(socket: &Socket, mut addr: SocketAddr) -> io::R
         }
     }
 }
+
+/// Options for the underlying network socket.
+#[derive(Clone, Copy, Default, Debug)]
+pub struct SocketOptions {
+    pub(crate) reuse_addr: bool,
+}
+
+impl SocketOptions {
+    /// Enables the `SO_REUSEADDR` option on the socket.
+    pub fn with_reuse_addr(self) -> Self {
+        Self { reuse_addr: true }
+    }
+}
diff --git a/net/src/tcp.rs b/net/src/tcp.rs
index 73e5cc0ce..fb936e64d 100644
--- a/net/src/tcp.rs
+++ b/net/src/tcp.rs
@@ -1,5 +1,5 @@
 use self::implementation::{TcpListener, TcpStream};
-use crate::sync::rendezvous;
+use crate::{sync::rendezvous, SocketOptions};
 use std::{future, io, net::SocketAddr};
 use tokio::{
     io::{ReadHalf, WriteHalf},
@@ -11,8 +11,11 @@ use tokio_util::compat::{Compat, FuturesAsyncReadCompatExt, TokioAsyncReadCompat
 use tracing::{Instrument, Span};
 
 /// Configure TCP endpoint
-pub fn configure(bind_addr: SocketAddr) -> Result<(Connector, Acceptor), Error> {
-    let listener = TcpListener::bind(bind_addr)?;
+pub fn configure(
+    bind_addr: SocketAddr,
+    options: SocketOptions,
+) -> Result<(Connector, Acceptor), Error> {
+    let listener = TcpListener::bind_with_options(bind_addr, options)?;
     let local_addr = listener.local_addr()?;
 
     Ok((
@@ -244,7 +247,7 @@ fn connection_config() -> yamux::Config {
 // Real
 #[cfg(not(feature = "simulation"))]
 mod implementation {
-    use crate::{socket, KEEP_ALIVE_INTERVAL};
+    use crate::{socket, SocketOptions, KEEP_ALIVE_INTERVAL};
     use socket2::{Domain, Socket, TcpKeepalive, Type};
     use std::{
         fmt, io,
@@ -258,14 +261,17 @@ mod implementation {
     pub(super) struct TcpListener(tokio::net::TcpListener);
 
     impl TcpListener {
-        /// Binds TCP socket to the given address. If the port is taken, uses a random one,
-        pub fn bind(addr: impl Into<SocketAddr>) -> io::Result<Self> {
-            let addr = addr.into();
-
+        /// Configures a TCP socket with the given options and binds it to the given address. If
+        /// the port is taken, uses a random one.
+        pub fn bind_with_options(addr: SocketAddr, options: SocketOptions) -> io::Result<Self> {
             let socket = Socket::new(Domain::for_address(addr), Type::STREAM, None)?;
             socket.set_nonblocking(true)?;
-            // Ignore errors - reuse address is nice to have but not required.
-            socket.set_reuse_address(true).ok();
+
+            if options.reuse_addr {
+                // Ignore errors - reuse address is nice to have but not required.
+                socket.set_reuse_address(true).ok();
+            }
+
             set_keep_alive(&socket)?;
             socket::bind_with_fallback(&socket, addr)?;
 
@@ -278,6 +284,11 @@ mod implementation {
             Ok(Self(tokio::net::TcpListener::from_std(socket.into())?))
         }
 
+        /// Binds TCP socket to the given address. 
If the port is taken, uses a random one.
+        pub fn bind(addr: SocketAddr) -> io::Result<Self> {
+            Self::bind_with_options(addr, SocketOptions::default())
+        }
+
         pub async fn accept(&self) -> io::Result<(TcpStream, SocketAddr)> {
             self.0
                 .accept()
diff --git a/net/src/udp.rs b/net/src/udp.rs
index 64950a032..3381cda51 100644
--- a/net/src/udp.rs
+++ b/net/src/udp.rs
@@ -30,23 +30,33 @@ pub trait DatagramSocket {
 #[cfg(not(feature = "simulation"))]
 mod implementation {
     use super::*;
-    use crate::socket;
+    use crate::{socket, SocketOptions};
     use socket2::{Domain, Socket, Type};
 
     pub struct UdpSocket(tokio::net::UdpSocket);
 
     impl UdpSocket {
-        /// Binds UDP socket to the given address. If the port is taken, uses a random one,
-        pub fn bind(addr: SocketAddr) -> io::Result<Self> {
+        /// Configures a UDP socket with the given options and binds it to the given address. If
+        /// the port is taken, uses a random one.
+        pub fn bind_with_options(addr: SocketAddr, options: SocketOptions) -> io::Result<Self> {
             let socket = Socket::new(Domain::for_address(addr), Type::DGRAM, None)?;
             socket.set_nonblocking(true)?;
-            // Ignore errors - reuse address is nice to have but not required.
-            socket.set_reuse_address(true).ok();
+
+            if options.reuse_addr {
+                // Ignore errors - reuse address is nice to have but not required.
+                socket.set_reuse_address(true).ok();
+            }
+
             socket::bind_with_fallback(&socket, addr)?;
 
             Ok(Self(tokio::net::UdpSocket::from_std(socket.into())?))
         }
 
+        /// Binds UDP socket to the given address. If the port is taken, uses a random one.
+        pub fn bind(addr: SocketAddr) -> io::Result<Self> {
+            Self::bind_with_options(addr, SocketOptions::default())
+        }
+
         pub fn bind_multicast(interface: Ipv4Addr) -> io::Result<Self> {
             let addr = SocketAddr::from((Ipv4Addr::UNSPECIFIED, MULTICAST_PORT));
 
diff --git a/net/src/unified.rs b/net/src/unified.rs
index 26ce298d1..935127c18 100644
--- a/net/src/unified.rs
+++ b/net/src/unified.rs
@@ -263,6 +263,8 @@ pub enum ConnectionError {
 
 #[cfg(test)]
 mod tests {
+    use crate::SocketOptions;
+
     use super::*;
     use futures_util::{future, stream::FuturesUnordered, StreamExt};
     use rand::{distributions::Standard, Rng};
@@ -429,15 +431,19 @@ mod tests {
     }
 
     fn setup_tcp_peers() -> (Connector, Acceptor) {
-        let (client, _) = tcp::configure((Ipv4Addr::LOCALHOST, 0).into()).unwrap();
-        let (_, server) = tcp::configure((Ipv4Addr::LOCALHOST, 0).into()).unwrap();
+        let (client, _) =
+            tcp::configure((Ipv4Addr::LOCALHOST, 0).into(), SocketOptions::default()).unwrap();
+        let (_, server) =
+            tcp::configure((Ipv4Addr::LOCALHOST, 0).into(), SocketOptions::default()).unwrap();
 
         (Connector::Tcp(client), Acceptor::Tcp(server))
     }
 
     fn setup_quic_peers() -> (Connector, Acceptor) {
-        let (client, _, _) = quic::configure((Ipv4Addr::LOCALHOST, 0).into()).unwrap();
-        let (_, server, _) = quic::configure((Ipv4Addr::LOCALHOST, 0).into()).unwrap();
+        let (client, _, _) =
+            quic::configure((Ipv4Addr::LOCALHOST, 0).into(), SocketOptions::default()).unwrap();
+        let (_, server, _) =
+            quic::configure((Ipv4Addr::LOCALHOST, 0).into(), SocketOptions::default()).unwrap();
 
         (Connector::Quic(client), Acceptor::Quic(server))
     }

From 48458c4134d28132a41a334a4927d4e71d721c4e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adam=20Cig=C3=A1nek?=
Date: Mon, 9 Sep 2024 09:15:19 +0200
Subject: [PATCH 13/55] net: Better handle outgoing cancellation

---
 lib/src/network/message_dispatcher.rs | 13 +++++++++++++
 net/src/tcp.rs                        |  5 ++++-
 net/src/unified.rs                    |  4 ++--
 3 files changed, 19 insertions(+), 3 deletions(-)

diff --git a/lib/src/network/message_dispatcher.rs 
b/lib/src/network/message_dispatcher.rs index 31a3f7e15..c057718d3 100644 --- a/lib/src/network/message_dispatcher.rs +++ b/lib/src/network/message_dispatcher.rs @@ -660,6 +660,9 @@ mod tests { let recv_content = server_stream.recv().await.unwrap(); assert_eq!(recv_content, send_content); } + + client.close().await; + server_dispatcher.shutdown().await; } #[tokio::test(flavor = "multi_thread")] @@ -718,6 +721,9 @@ mod tests { ); } } + + client_dispatcher.shutdown().await; + server_dispatcher.shutdown().await; } #[tokio::test(flavor = "multi_thread")] @@ -757,6 +763,9 @@ mod tests { ); assert_eq!(server_stream1.recv().await.unwrap(), send_content0); assert_eq!(server_stream1.recv().await.unwrap(), send_content1); + + client.close().await; + server_dispatcher.shutdown().await; } #[tokio::test(flavor = "multi_thread")] @@ -881,6 +890,10 @@ mod tests { .into_iter() .collect::>(), ); + + client0.close().await; + client1.close().await; + server_dispatcher.shutdown().await; } #[tokio::test(flavor = "multi_thread")] diff --git a/net/src/tcp.rs b/net/src/tcp.rs index fb936e64d..6b6eb82fe 100644 --- a/net/src/tcp.rs +++ b/net/src/tcp.rs @@ -210,7 +210,10 @@ async fn drive_connection( let result = if let Some(result) = outgoing.pop() { result } else { - future::poll_fn(|cx| conn.poll_new_outbound(cx)).await + select! { + result = future::poll_fn(|cx| conn.poll_new_outbound(cx)) => result, + _ = reply_tx.closed() => continue, + } }; if let Err(result) = reply_tx.send(result).await { diff --git a/net/src/unified.rs b/net/src/unified.rs index 935127c18..3c27630b6 100644 --- a/net/src/unified.rs +++ b/net/src/unified.rs @@ -436,7 +436,7 @@ mod tests { let (_, server) = tcp::configure((Ipv4Addr::LOCALHOST, 0).into(), SocketOptions::default()).unwrap(); - (Connector::Tcp(client), Acceptor::Tcp(server)) + (client.into(), server.into()) } fn setup_quic_peers() -> (Connector, Acceptor) { @@ -445,6 +445,6 @@ mod tests { let (_, server, _) = quic::configure((Ipv4Addr::LOCALHOST, 0).into(), SocketOptions::default()).unwrap(); - (Connector::Quic(client), Acceptor::Quic(server)) + (client.into(), server.into()) } } From 053f60afdb17f7b290d3ad060e067668612a95cb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Cig=C3=A1nek?= Date: Mon, 9 Sep 2024 10:33:21 +0200 Subject: [PATCH 14/55] Replace the format module with the hex_fmt crate --- Cargo.toml | 1 + lib/Cargo.toml | 1 + lib/src/conflict.rs | 11 ++++++++++- lib/src/crypto/hash.rs | 5 ++--- lib/src/crypto/sign.rs | 9 +++------ lib/src/format.rs | 36 ------------------------------------ lib/src/lib.rs | 1 - lib/src/macros.rs | 16 ++++++++-------- lib/src/protocol/block.rs | 7 ++----- lib/src/protocol/summary.rs | 3 +-- 10 files changed, 28 insertions(+), 62 deletions(-) delete mode 100644 lib/src/format.rs diff --git a/Cargo.toml b/Cargo.toml index 6ab8ffedd..8d1a1e322 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -43,6 +43,7 @@ camino = "1.1.6" chrono = { version = "0.4.31", default-features = false, features = ["clock"] } clap = { version = "4.4.6", features = ["derive"] } futures-util = { version = "0.3.30", default-features = false } +hex_fmt = "0.3.0" metrics = "0.22.0" metrics-exporter-prometheus = { version = "0.13.0", default-features = false } metrics-util = { version = "0.16.0", default-features = false } diff --git a/lib/Cargo.toml b/lib/Cargo.toml index 137e14f73..37da64267 100644 --- a/lib/Cargo.toml +++ b/lib/Cargo.toml @@ -44,6 +44,7 @@ either = { version = "1.6.1", default-features = false } futures-util = { workspace = true } generic-array = { version 
= "0.14.5", features = ["serde"] } hex = "0.4.3" +hex_fmt = { workspace = true } if-watch = { version = "3.2.0", features = ["tokio"] } include_dir = "0.7.3" indexmap = "1.9.3" diff --git a/lib/src/conflict.rs b/lib/src/conflict.rs index da6428fa4..7fc1ea284 100644 --- a/lib/src/conflict.rs +++ b/lib/src/conflict.rs @@ -6,7 +6,16 @@ const SUFFIX_SEPARATOR: &str = ".v"; /// Create non-ambiguous name for a file/directory with `name` by appending a suffix derived from /// `branch_id`. pub fn create_unique_name(name: &str, branch_id: &PublicKey) -> String { - format!("{}{}{:-3$x}", name, SUFFIX_SEPARATOR, branch_id, SUFFIX_LEN) + let mut output = format!( + "{}{}{:<3$x}", + name, + SUFFIX_SEPARATOR, + branch_id, + SUFFIX_LEN + 2 + ); + // Trim the trailing ellipsis + output.truncate(name.len() + SUFFIX_SEPARATOR.len() + SUFFIX_LEN); + output } /// Parse a name created with `create_unique_name` into the original name and the disambiguation diff --git a/lib/src/crypto/hash.rs b/lib/src/crypto/hash.rs index 775c00b8c..2a2b9077c 100644 --- a/lib/src/crypto/hash.rs +++ b/lib/src/crypto/hash.rs @@ -1,6 +1,5 @@ pub use blake3::traits::digest::Digest; -use crate::format; use generic_array::{typenum::U32, GenericArray}; use serde::{Deserialize, Serialize}; use std::{ @@ -81,13 +80,13 @@ impl fmt::Display for Hash { impl fmt::Debug for Hash { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{self:8x}") + write!(f, "{self:<8x}") } } impl fmt::LowerHex for Hash { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - format::hex(f, self.as_ref()) + write!(f, "{}", hex_fmt::HexFmt(self.as_ref())) } } diff --git a/lib/src/crypto/sign.rs b/lib/src/crypto/sign.rs index dfbd6c966..11cbbca8b 100644 --- a/lib/src/crypto/sign.rs +++ b/lib/src/crypto/sign.rs @@ -1,7 +1,4 @@ -use crate::{ - crypto::{Digest, Hashable}, - format, -}; +use crate::crypto::{Digest, Hashable}; use ed25519_dalek::{self as ext, Signer, Verifier}; use rand::{rngs::OsRng, CryptoRng, Rng}; use serde::{Deserialize, Serialize}; @@ -104,7 +101,7 @@ impl Ord for PublicKey { impl fmt::LowerHex for PublicKey { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - format::hex(f, self.0.as_bytes()) + hex_fmt::HexFmt(self.0.as_bytes()).fmt(f) } } @@ -142,7 +139,7 @@ impl fmt::Display for PublicKey { impl fmt::Debug for PublicKey { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(f, "{self:8x}") + write!(f, "{self:<8x}") } } diff --git a/lib/src/format.rs b/lib/src/format.rs deleted file mode 100644 index c467d87c9..000000000 --- a/lib/src/format.rs +++ /dev/null @@ -1,36 +0,0 @@ -use std::fmt; - -// Format the byte slice as hex with optional truncation. -// This is a helper for implementing the `LowerHex` trait. 
-pub(crate) fn hex(f: &mut fmt::Formatter, bytes: &[u8]) -> fmt::Result { - let len = f - .width() - .map(|w| w / 2) - .unwrap_or(bytes.len()) - .min(bytes.len()); - - let (len, ellipsis) = match (len, f.sign_minus()) { - (0, _) => (0, false), - (len, _) if len == bytes.len() => (len, false), - (len, true) => (len, false), - (len, false) => (len - 1, true), - }; - - for byte in &bytes[..len] { - write!(f, "{:02x}", byte)?; - } - - if ellipsis { - write!(f, "..")?; - } - - Ok(()) -} - -pub(crate) struct Hex<'a>(pub &'a [u8]); - -impl fmt::LowerHex for Hex<'_> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - hex(f, self.0) - } -} diff --git a/lib/src/lib.rs b/lib/src/lib.rs index f33c4cd38..7a6145863 100644 --- a/lib/src/lib.rs +++ b/lib/src/lib.rs @@ -21,7 +21,6 @@ mod directory; mod error; mod event; mod file; -mod format; mod future; mod iterator; mod joint_directory; diff --git a/lib/src/macros.rs b/lib/src/macros.rs index 9b5d46182..f8a1bd4e2 100644 --- a/lib/src/macros.rs +++ b/lib/src/macros.rs @@ -50,13 +50,13 @@ macro_rules! define_byte_array_wrapper { impl std::fmt::Debug for $name { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(f, "{:8x}", self) + write!(f, "{:<8x}", self) } } impl std::fmt::LowerHex for $name { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - crate::format::hex(f, &self.0) + hex_fmt::HexFmt(&self.0).fmt(f) } } }; @@ -144,12 +144,12 @@ mod tests { format!("{:x}", id), "0001020305070b0d1113171d1f25292b2f353b3d4347494f53596165676b6d71" ); - assert_eq!(format!("{:1x}", id), ""); - assert_eq!(format!("{:2x}", id), ".."); - assert_eq!(format!("{:3x}", id), ".."); - assert_eq!(format!("{:4x}", id), "00.."); - assert_eq!(format!("{:6x}", id), "0001.."); - assert_eq!(format!("{:8x}", id), "000102.."); + assert_eq!(format!("{:<1x}", id), "."); + assert_eq!(format!("{:<2x}", id), ".."); + assert_eq!(format!("{:<3x}", id), "0.."); + assert_eq!(format!("{:<4x}", id), "00.."); + assert_eq!(format!("{:<6x}", id), "0001.."); + assert_eq!(format!("{:<8x}", id), "000102.."); assert_eq!(format!("{:?}", id), "000102.."); assert_eq!( diff --git a/lib/src/protocol/block.rs b/lib/src/protocol/block.rs index 15630a53a..9693bb61b 100644 --- a/lib/src/protocol/block.rs +++ b/lib/src/protocol/block.rs @@ -1,7 +1,4 @@ -use crate::{ - crypto::{Digest, Hash, Hashable}, - format::Hex, -}; +use crate::crypto::{Digest, Hash, Hashable}; use rand::{distributions::Standard, prelude::Distribution, Rng}; use serde::{Deserialize, Serialize}; use std::{ @@ -174,6 +171,6 @@ impl Distribution for Standard { impl fmt::Debug for BlockContent { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{:6x}", Hex(&self[..])) + write!(f, "{:<8}", hex_fmt::HexFmt(&self[..])) } } diff --git a/lib/src/protocol/summary.rs b/lib/src/protocol/summary.rs index cafc29356..2fbd3c7f3 100644 --- a/lib/src/protocol/summary.rs +++ b/lib/src/protocol/summary.rs @@ -1,5 +1,4 @@ use super::{InnerNodes, LeafNodes}; -use crate::format::Hex; use serde::{Deserialize, Serialize}; use sqlx::{ encode::IsNull, @@ -300,7 +299,7 @@ impl fmt::Debug for MultiBlockPresence { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Self::None => write!(f, "None"), - Self::Some(checksum) => write!(f, "Some({:10x})", Hex(checksum)), + Self::Some(checksum) => write!(f, "Some({:<8})", hex_fmt::HexFmt(checksum)), Self::Full => write!(f, "Full"), } } From 915da7828e52b09caa9bc73e966ae17d824565c9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Cig=C3=A1nek?= 
Date: Tue, 10 Sep 2024 14:51:42 +0200 Subject: [PATCH 15/55] net: Implement Bus --- Cargo.toml | 3 + lib/Cargo.toml | 6 +- net/Cargo.toml | 8 +- net/src/bus.rs | 260 ++++++++++++++++++++++++++++++++++++++ net/src/bus/dispatch.rs | 230 ++++++++++++++++++++++++++++++++++ net/src/bus/topic.rs | 267 ++++++++++++++++++++++++++++++++++++++++ net/src/bus/worker.rs | 219 ++++++++++++++++++++++++++++++++ net/src/lib.rs | 4 + net/src/tcp.rs | 93 ++++++++------ net/src/test_utils.rs | 53 ++++++++ net/src/unified.rs | 154 ++++++++++++++--------- 11 files changed, 1201 insertions(+), 96 deletions(-) create mode 100644 net/src/bus.rs create mode 100644 net/src/bus/dispatch.rs create mode 100644 net/src/bus/topic.rs create mode 100644 net/src/bus/worker.rs create mode 100644 net/src/test_utils.rs diff --git a/Cargo.toml b/Cargo.toml index 8d1a1e322..93f056c5c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -50,6 +50,7 @@ metrics-util = { version = "0.16.0", default-features = false } num_enum = { version = "0.7.0", default-features = false } once_cell = "1.18.0" pin-project-lite = "0.2.13" +proptest = "1.0" rand = { package = "ouisync-rand", path = "rand" } rcgen = "0.13" rmp-serde = "1.1.0" @@ -58,8 +59,10 @@ serde = { version = "1.0", features = ["derive", "rc"] } serde_bytes = "0.11.8" serde_json = "1.0.94" similar-asserts = "1.5.0" +slab = "0.4.9" sqlx = { version = "0.7.4", default-features = false, features = ["runtime-tokio", "sqlite"] } tempfile = "3.2" +test-strategy = "0.4.0" thiserror = "1.0.49" tokio = { version = "1.38.0", default-features = false } tokio-rustls = { version = "0.26", default-features = false } diff --git a/lib/Cargo.toml b/lib/Cargo.toml index 37da64267..9e86f7cca 100644 --- a/lib/Cargo.toml +++ b/lib/Cargo.toml @@ -64,7 +64,7 @@ rupnp = { version = "1.1.0", default-features = false, features = [] } scoped_task = { path = "../scoped_task" } serde = { workspace = true } serde_bytes = { workspace = true } -slab = "0.4.6" +slab = { workspace = true } sqlx = { workspace = true } ssdp-client = "1.0" state_monitor = { path = "../state_monitor" } @@ -88,13 +88,13 @@ criterion = { version = "0.4", features = ["html_reports"] } hdrhistogram = { version = "7.5.4", default-features = false, features = ["sync"] } metrics_ext = { path = "../metrics_ext" } ouisync-tracing-fmt = { path = "../tracing_fmt" } -proptest = "1.0" +proptest = { workspace = true } rmp-serde = { workspace = true } serde_json = { workspace = true } serde_test = "1.0.176" similar-asserts = { workspace = true } tempfile = { workspace = true } -test-strategy = "0.2.1" +test-strategy = { workspace = true } tokio = { workspace = true, features = ["process", "test-util"] } [features] diff --git a/net/Cargo.toml b/net/Cargo.toml index f7e538b08..0b7f43b5a 100644 --- a/net/Cargo.toml +++ b/net/Cargo.toml @@ -11,10 +11,12 @@ version.workspace = true bytecodec = "0.4.15" bytes = "1.1.0" futures-util = { workspace = true } +hex_fmt = { workspace = true } pin-project-lite = { workspace = true } quinn = "0.11.4" rand = { package = "ouisync-rand", path = "../rand" } rcgen = { workspace = true } +slab = { workspace = true } socket2 = "0.5.7" # To be able to setsockopts before a socket is bound stun_codec = "0.3.4" thiserror = "1.0.31" @@ -26,10 +28,14 @@ yamux = "0.13.3" [dev-dependencies] anyhow = { workspace = true } +assert_matches = { workspace = true } clap = { workspace = true } +itertools = "0.13.0" +proptest = { workspace = true } similar-asserts = { workspace = true } +test-strategy = { workspace = true } tokio = { workspace 
= true }
-tracing-subscriber = { workspace = true }
+tracing-subscriber = { workspace = true, features = ["env-filter"] }
 
 [features]
 simulation = ["turmoil"]
diff --git a/net/src/bus.rs b/net/src/bus.rs
new file mode 100644
index 000000000..58a480ecc
--- /dev/null
+++ b/net/src/bus.rs
@@ -0,0 +1,260 @@
+mod dispatch;
+mod topic;
+mod worker;
+
+pub use topic::TopicId;
+
+use crate::unified::{Connection, RecvStream, SendStream};
+use std::{
+    future::Future,
+    io,
+    pin::Pin,
+    task::{ready, Context, Poll},
+};
+use tokio::{
+    io::{AsyncRead, AsyncWrite},
+    sync::{mpsc, oneshot},
+    task,
+};
+use worker::Command;
+
+/// Wrapper around a connection that allows creating an arbitrary number of independent streams
+/// (up to a limit determined by the underlying connection), each bound to a specific topic. When
+/// the two peers create streams bound to the same topic, they can communicate on them with each
+/// other.
+pub struct Bus {
+    command_tx: mpsc::UnboundedSender<Command>,
+}
+
+impl Bus {
+    pub fn new(connection: Connection) -> Self {
+        let (command_tx, command_rx) = mpsc::unbounded_channel();
+
+        task::spawn(worker::run(connection, command_rx));
+
+        Self { command_tx }
+    }
+
+    /// Creates a pair of send and receive streams bound to the given topic.
+    ///
+    /// It doesn't matter when or in what order the peers create the streams - as long as both
+    /// eventually create them, they will connect.
+    ///
+    /// Note: topics don't have to be unique. If multiple streams are created with the same topic,
+    /// each one is still connected to exactly one corresponding stream of the remote peer, but
+    /// it's undefined which local stream gets connected to which remote stream.
+    pub fn create_topic(&self, topic_id: TopicId) -> (BusSendStream, BusRecvStream) {
+        let (send_stream_tx, send_stream_rx) = oneshot::channel();
+        let (recv_stream_tx, recv_stream_rx) = oneshot::channel();
+
+        let send = BusSendStream {
+            inner: StreamInner::Pending(send_stream_rx),
+        };
+
+        let recv = BusRecvStream {
+            inner: StreamInner::Pending(recv_stream_rx),
+        };
+
+        self.command_tx
+            .send(Command::Create {
+                topic_id,
+                send_stream_tx,
+                recv_stream_tx,
+            })
+            .expect("bus worker unexpectedly terminated");
+
+        (send, recv)
+    }
+
+    /// Gracefully shuts down the underlying connection.
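+    ///
+    /// Once this returns, the worker task has exited, so any subsequent call to `create_topic`
+    /// will panic on the closed command channel.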
+    pub async fn shutdown(&self) {
+        let (reply_tx, reply_rx) = oneshot::channel();
+
+        if self
+            .command_tx
+            .send(Command::Shutdown { reply_tx })
+            .is_err()
+        {
+            return;
+        }
+
+        reply_rx.await.ok();
+    }
+}
+
+pub struct BusSendStream {
+    inner: StreamInner<SendStream>,
+}
+
+impl AsyncWrite for BusSendStream {
+    fn poll_write(
+        self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+        buf: &[u8],
+    ) -> Poll<io::Result<usize>> {
+        ready!(self.get_mut().inner.poll(cx))?.poll_write(cx, buf)
+    }
+
+    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
+        ready!(self.get_mut().inner.poll(cx))?.poll_flush(cx)
+    }
+
+    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
+        ready!(self.get_mut().inner.poll(cx))?.poll_shutdown(cx)
+    }
+
+    fn poll_write_vectored(
+        self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+        bufs: &[io::IoSlice<'_>],
+    ) -> Poll<io::Result<usize>> {
+        ready!(self.get_mut().inner.poll(cx))?.poll_write_vectored(cx, bufs)
+    }
+
+    fn is_write_vectored(&self) -> bool {
+        if let StreamInner::Active(stream) = &self.inner {
+            stream.is_write_vectored()
+        } else {
+            false
+        }
+    }
+}
+
+pub struct BusRecvStream {
+    inner: StreamInner<RecvStream>,
+}
+
+impl AsyncRead for BusRecvStream {
+    fn poll_read(
+        self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+        buf: &mut tokio::io::ReadBuf<'_>,
+    ) -> Poll<io::Result<()>> {
+        ready!(self.get_mut().inner.poll(cx))?.poll_read(cx, buf)
+    }
+}
+
+enum StreamInner<T> {
+    Pending(oneshot::Receiver<io::Result<T>>),
+    Active(T),
+}
+
+impl<T> StreamInner<T>
+where
+    T: Unpin,
+{
+    fn poll(&mut self, cx: &mut Context<'_>) -> Poll<Result<Pin<&mut T>, io::Error>> {
+        loop {
+            return match self {
+                Self::Pending(rx) => match ready!(Pin::new(rx).poll(cx)) {
+                    Ok(Ok(stream)) => {
+                        *self = Self::Active(stream);
+                        continue;
+                    }
+                    Ok(Err(error)) => Poll::Ready(Err(error)),
+                    Err(_) => Poll::Ready(Err(io::ErrorKind::BrokenPipe.into())),
+                },
+                Self::Active(stream) => Poll::Ready(Ok(Pin::new(stream))),
+            };
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::test_utils::{
+        create_connected_connections, create_connected_peers, init_log, Proto,
+    };
+    use tokio::io::{AsyncReadExt, AsyncWriteExt};
+
+    #[tokio::test]
+    async fn sanity_check_tcp() {
+        sanity_check_case(Proto::Tcp).await
+    }
+
+    #[tokio::test]
+    async fn sanity_check_quic() {
+        sanity_check_case(Proto::Quic).await
+    }
+
+    async fn sanity_check_case(proto: Proto) {
+        init_log();
+
+        let (client, server) = create_connected_peers(proto);
+        let (client, server) = create_connected_connections(&client, &server).await;
+
+        let client = Bus::new(client);
+        let server = Bus::new(server);
+
+        let topic_id = TopicId::random();
+
+        let (mut client_send_stream, mut client_recv_stream) = client.create_topic(topic_id);
+        let (mut server_send_stream, mut server_recv_stream) = server.create_topic(topic_id);
+
+        let client_message = b"hello from client";
+        let server_message = b"hello from server";
+
+        client_send_stream.write_all(client_message).await.unwrap();
+
+        let mut buffer = vec![0; client_message.len()];
+        server_recv_stream.read_exact(&mut buffer).await.unwrap();
+        assert_eq!(&buffer, client_message);
+
+        server_send_stream.write_all(server_message).await.unwrap();
+
+        let mut buffer = vec![0; server_message.len()];
+        client_recv_stream.read_exact(&mut buffer).await.unwrap();
+        assert_eq!(&buffer, server_message);
+    }
+
+    #[tokio::test]
+    async fn duplicate_topic_tcp() {
+        duplicate_topic_case(Proto::Tcp).await
+    }
+
+    #[tokio::test]
+    async fn duplicate_topic_quic() {
+        duplicate_topic_case(Proto::Quic).await
+    }
+
+    async fn duplicate_topic_case(proto: Proto) {
+        init_log();
+
+        
let (client, server) = create_connected_peers(proto);
+        let (client, server) = create_connected_connections(&client, &server).await;
+
+        let client = Bus::new(client);
+        let server = Bus::new(server);
+
+        let topic_id = TopicId::random();
+
+        let (mut client_send_stream_0, _client_recv_stream_0) = client.create_topic(topic_id);
+        let (mut client_send_stream_1, _client_recv_stream_1) = client.create_topic(topic_id);
+
+        let (_server_send_stream_0, mut server_recv_stream_0) = server.create_topic(topic_id);
+        let (_server_send_stream_1, mut server_recv_stream_1) = server.create_topic(topic_id);
+
+        client_send_stream_0.write_all(b"ping 0").await.unwrap();
+        client_send_stream_1.write_all(b"ping 1").await.unwrap();
+
+        let mut buffer_0 = [0; 6];
+        server_recv_stream_0
+            .read_exact(&mut buffer_0)
+            .await
+            .unwrap();
+
+        let mut buffer_1 = [0; 6];
+        server_recv_stream_1
+            .read_exact(&mut buffer_1)
+            .await
+            .unwrap();
+
+        // The streams can be connected in any order
+        match (&buffer_0, &buffer_1) {
+            (b"ping 0", b"ping 1") => (),
+            (b"ping 1", b"ping 0") => (),
+            _ => panic!("unexpected {:?}", (&buffer_0, &buffer_1)),
+        }
+    }
+}
diff --git a/net/src/bus/dispatch.rs b/net/src/bus/dispatch.rs
new file mode 100644
index 000000000..ce4cbff02
--- /dev/null
+++ b/net/src/bus/dispatch.rs
@@ -0,0 +1,230 @@
+use super::TopicId;
+use crate::unified::{Connection, RecvStream, SendStream};
+use slab::Slab;
+use std::{io, sync::Mutex};
+use tokio::{
+    io::{AsyncReadExt, AsyncWriteExt},
+    select,
+    sync::oneshot,
+};
+use tracing::instrument;
+
+/// Wrapper around `Connection` that allows establishing incoming and outgoing streams that are
+/// bound to a given topic. When the two connected peers establish streams bound to the same topic
+/// (one peer incoming and one peer outgoing), those streams will be connected and will allow
+/// communication between them.
+///
+/// NOTE: incoming streams wait for the corresponding outgoing streams but not the other way
+/// around. That is, if one peer calls `incoming`, it will wait until the other peer calls
+/// `outgoing` with the same topic. However, if one peer calls `outgoing` without the other
+/// calling `incoming` with the same topic, the returned stream will be immediately closed.
+pub(super) struct Dispatcher {
+    connection: Connection,
+    registry: Mutex<Registry>,
+}
+
+impl Dispatcher {
+    pub fn new(connection: Connection) -> Self {
+        Self {
+            connection,
+            registry: Mutex::new(Registry::default()),
+        }
+    }
+
+    /// Establish an incoming stream bound to the given topic.
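+    ///
+    /// Cancel safe: dropping the returned future before it completes removes the topic's entry
+    /// from the registry again (see `CancelGuard` below).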
+    #[instrument(skip_all)]
+    pub async fn incoming(&self, topic_id: TopicId) -> io::Result<(SendStream, RecvStream)> {
+        let (reply_tx, reply_rx) = oneshot::channel();
+
+        let key = self.registry.lock().unwrap().insert((topic_id, reply_tx));
+        let guard = CancelGuard::arm(&self.registry, key);
+
+        let accept = async {
+            loop {
+                let (send_stream, mut recv_stream) = match self.connection.incoming().await {
+                    Ok(streams) => streams,
+                    Err(error) => {
+                        tracing::debug!(?error, "failed to accept incoming stream");
+                        return Err(io::Error::other(error));
+                    }
+                };
+
+                let mut buffer = [0; TopicId::SIZE];
+                let topic_id = match recv_stream.read_exact(&mut buffer).await {
+                    Ok(_) => TopicId::from(buffer),
+                    Err(error) => {
+                        tracing::debug!(?error, "failed to read topic id from incoming stream");
+                        continue;
+                    }
+                };
+
+                let reply_tx = {
+                    let mut registry = self.registry.lock().unwrap();
+                    let key = registry
+                        .iter()
+                        .find(|(_, (registry_topic_id, _))| *registry_topic_id == topic_id)
+                        .map(|(key, _)| key);
+
+                    key.map(|key| registry.remove(key).1)
+                };
+
+                if let Some(reply_tx) = reply_tx {
+                    reply_tx.send((send_stream, recv_stream)).ok();
+                } else {
+                    tracing::debug!(?topic_id, "unsolicited incoming stream");
+                    continue;
+                }
+            }
+        };
+
+        select! {
+            result = accept => result,
+            result = reply_rx => {
+                guard.disarm();
+
+                // unwrap is OK because the associated `reply_tx` is only dropped after we send on
+                // it.
+                Ok(result.unwrap())
+            }
+        }
+    }
+
+    /// Establish an outgoing stream bound to the given topic.
+    #[instrument(skip_all)]
+    pub async fn outgoing(&self, topic_id: TopicId) -> io::Result<(SendStream, RecvStream)> {
+        let (mut send_stream, recv_stream) = self
+            .connection
+            .outgoing()
+            .await
+            .inspect_err(|error| tracing::error!(?error))
+            .map_err(io::Error::other)?;
+
+        send_stream
+            .write_all(topic_id.as_bytes())
+            .await
+            .inspect_err(|error| tracing::error!(?error))?;
+
+        Ok((send_stream, recv_stream))
+    }
+
+    pub async fn close(&self) {
+        self.connection.close().await
+    }
+}
+
+type Registry = Slab<(TopicId, oneshot::Sender<(SendStream, RecvStream)>)>;
+
+struct CancelGuard<'a> {
+    armed: bool,
+    key: usize,
+    registry: &'a Mutex<Registry>,
+}
+
+impl<'a> CancelGuard<'a> {
+    fn arm(registry: &'a Mutex<Registry>, key: usize) -> Self {
+        Self {
+            armed: true,
+            registry,
+            key,
+        }
+    }
+
+    fn disarm(mut self) {
+        self.armed = false;
+    }
+}
+
+impl Drop for CancelGuard<'_> {
+    fn drop(&mut self) {
+        if !self.armed {
+            return;
+        }
+
+        self.registry
+            .lock()
+            .unwrap_or_else(|error| error.into_inner())
+            .try_remove(self.key);
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::test_utils::{create_connected_connections, create_connected_peers, Proto};
+    use assert_matches::assert_matches;
+    use futures_util::future;
+
+    #[tokio::test]
+    async fn sanity_check_tcp() {
+        sanity_check_case(Proto::Tcp).await
+    }
+
+    #[tokio::test]
+    async fn sanity_check_quic() {
+        sanity_check_case(Proto::Quic).await
+    }
+
+    async fn sanity_check_case(proto: Proto) {
+        let (client, server) = create_connected_peers(proto);
+        let (client, server) = create_connected_connections(&client, &server).await;
+
+        let client = Dispatcher::new(client);
+        let server = Dispatcher::new(server);
+
+        let topic_id = TopicId::random();
+
+        let (
+            (mut client_send_stream, mut client_recv_stream),
+            (mut server_send_stream, mut server_recv_stream),
+        ) = future::try_join(client.outgoing(topic_id), server.incoming(topic_id))
+            .await
+            .unwrap();
+
+        let client_message = b"hello from client";
+        let server_message = b"hello 
from server";
+
+        client_send_stream.write_all(client_message).await.unwrap();
+
+        let mut buffer = [0; 17];
+        server_recv_stream.read_exact(&mut buffer).await.unwrap();
+        assert_eq!(&buffer, client_message);
+
+        server_send_stream.write_all(server_message).await.unwrap();
+
+        let mut buffer = [0; 17];
+        client_recv_stream.read_exact(&mut buffer).await.unwrap();
+        assert_eq!(&buffer, server_message);
+    }
+
+    #[tokio::test]
+    async fn unsolicited_stream() {
+        let (client, server) = create_connected_peers(Proto::Quic);
+        let (client, server) = create_connected_connections(&client, &server).await;
+
+        let client = Dispatcher::new(client);
+        let server = Dispatcher::new(server);
+
+        let client_task = async {
+            let (mut send_stream, mut recv_stream) =
+                client.outgoing(TopicId::random()).await.unwrap();
+
+            send_stream.write_all(b"ping").await.unwrap();
+
+            let mut buffer = [0; 1];
+            assert_matches!(
+                recv_stream.read_exact(&mut buffer).await,
+                Err(error) if error.kind() == io::ErrorKind::UnexpectedEof
+            );
+        };
+
+        let server_task = async {
+            server.incoming(TopicId::random()).await.ok();
+            unreachable!();
+        };
+
+        select! {
+            _ = client_task => (),
+            _ = server_task => (),
+        };
+    }
+}
diff --git a/net/src/bus/topic.rs b/net/src/bus/topic.rs
new file mode 100644
index 000000000..a6188865d
--- /dev/null
+++ b/net/src/bus/topic.rs
@@ -0,0 +1,267 @@
+use rand::Rng;
+use std::fmt;
+
+#[derive(Clone, Copy, Eq, PartialEq, Hash)]
+#[repr(transparent)]
+pub struct TopicId([u8; Self::SIZE]);
+
+impl TopicId {
+    pub const SIZE: usize = 32;
+
+    pub fn generate<R: Rng>(rng: &mut R) -> Self {
+        Self(rng.gen())
+    }
+
+    pub fn random() -> Self {
+        Self::generate(&mut rand::thread_rng())
+    }
+
+    pub fn as_bytes(&self) -> &[u8] {
+        self.0.as_slice()
+    }
+}
+
+impl From<[u8; Self::SIZE]> for TopicId {
+    fn from(array: [u8; Self::SIZE]) -> Self {
+        Self(array)
+    }
+}
+
+impl fmt::Debug for TopicId {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(f, "{:<8}", hex_fmt::HexFmt(&self.0))
+    }
+}
+
+/// Random value to disambiguate incoming and outgoing streams bound to the same topic.
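+///
+/// Each peer sends its nonce on the outgoing stream it creates. Once a peer has both an incoming
+/// and an outgoing stream for a topic, it keeps the pair created by the peer with the greater
+/// nonce and drops the other one (see `TopicState::poll` in this module).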
+#[derive(Clone, Copy, Eq, PartialEq, Ord, PartialOrd)] +pub struct TopicNonce([u8; Self::SIZE]); + +impl TopicNonce { + pub const SIZE: usize = 16; + + pub fn generate(rng: &mut R) -> Self { + Self(rng.gen()) + } + + pub fn random() -> Self { + Self::generate(&mut rand::thread_rng()) + } + + pub fn as_bytes(&self) -> &[u8] { + self.0.as_slice() + } +} + +impl From<[u8; Self::SIZE]> for TopicNonce { + fn from(array: [u8; Self::SIZE]) -> Self { + Self(array) + } +} + +impl From for TopicNonce { + fn from(n: u128) -> Self { + Self(n.to_be_bytes()) + } +} + +impl From for TopicNonce { + fn from(n: u64) -> Self { + Self::from(u128::from(n)) + } +} + +impl fmt::Debug for TopicNonce { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{:<8}", hex_fmt::HexFmt(&self.0)) + } +} + +pub(super) enum Input { + IncomingCreated(TopicNonce), + IncomingFailed, + OutgoingCreated, + OutgoingFailed, +} + +#[derive(Eq, PartialEq, Debug)] +pub(super) enum Output { + OutgoingCreate(TopicNonce), + OutgoingAccept, + IncomingCreate, + IncomingAccept, +} + +/// State machine for establishing topic streams +pub(super) struct TopicState { + pub nonce: TopicNonce, + incoming: IncomingState, + outgoing: OutgoingState, +} + +impl TopicState { + pub fn new(nonce: TopicNonce) -> Self { + Self { + nonce, + incoming: IncomingState::Init, + outgoing: OutgoingState::Init, + } + } + + pub fn handle(&mut self, input: Input) { + match input { + Input::IncomingCreated(nonce) => { + self.incoming = IncomingState::Created(nonce); + } + Input::IncomingFailed => { + self.incoming = IncomingState::Failed; + } + Input::OutgoingCreated => { + self.outgoing = OutgoingState::Created; + } + Input::OutgoingFailed => { + self.outgoing = OutgoingState::Failed; + } + } + } + + pub fn poll(&mut self) -> Option { + match (&self.incoming, &self.outgoing) { + (IncomingState::Init, OutgoingState::Init) + | (IncomingState::Init, OutgoingState::Creating) + | (IncomingState::Init, OutgoingState::Created) + | (IncomingState::Init, OutgoingState::Failed) + | (IncomingState::Failed, OutgoingState::Created) => { + // Create the incoming stream initially or try to create it again if it failed + // previously but we just got the outgoing stream. + self.incoming = IncomingState::Creating; + Some(Output::IncomingCreate) + } + (IncomingState::Creating, OutgoingState::Init) + | (IncomingState::Created(_), OutgoingState::Init) + | (IncomingState::Created(_), OutgoingState::Failed) + | (IncomingState::Failed, OutgoingState::Init) => { + // Create the outgoing stream initially or try to create it again if it failed + // previously but we got the incoming stream. + self.outgoing = OutgoingState::Creating; + Some(Output::OutgoingCreate(self.nonce)) + } + (IncomingState::Created(incoming_nonce), OutgoingState::Created) => { + // Both streams created, break the ties using the nonces. + if *incoming_nonce > self.nonce { + Some(Output::IncomingAccept) + } else { + Some(Output::OutgoingAccept) + } + } + (IncomingState::Creating, OutgoingState::Creating) + | (IncomingState::Creating, OutgoingState::Created) + | (IncomingState::Creating, OutgoingState::Failed) + | (IncomingState::Created(_), OutgoingState::Creating) + | (IncomingState::Failed, OutgoingState::Creating) + | (IncomingState::Failed, OutgoingState::Failed) => { + // Nothing to do, waiting for further inputs. 
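+                // Either at least one stream is still being created, or both attempts have
+                // failed; in the latter case the handler runs out of pending tasks and reports
+                // the last error.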
+ None + } + } + } +} + +enum IncomingState { + Init, + Creating, + Created(TopicNonce), + Failed, +} + +enum OutgoingState { + Init, + Creating, + Created, + Failed, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn sanity_check() { + for (this_nonce, that_nonce, expected_final_output) in [ + ( + TopicNonce::from(1u64), + TopicNonce::from(2u64), + Some(Output::IncomingAccept), + ), + ( + TopicNonce::from(2u64), + TopicNonce::from(1u64), + Some(Output::OutgoingAccept), + ), + ] { + let mut state = TopicState::new(this_nonce); + + assert_eq!(state.poll(), Some(Output::IncomingCreate)); + assert_eq!(state.poll(), Some(Output::OutgoingCreate(this_nonce))); + assert_eq!(state.poll(), None); + + state.handle(Input::IncomingCreated(that_nonce)); + + assert_eq!(state.poll(), None); + + state.handle(Input::OutgoingCreated); + + assert_eq!(state.poll(), expected_final_output); + } + } + + #[test] + fn outgoing_fail() { + let this_nonce = TopicNonce::from(1u64); + let that_nonce = TopicNonce::from(2u64); + + let mut state = TopicState::new(this_nonce); + + assert_eq!(state.poll(), Some(Output::IncomingCreate)); + assert_eq!(state.poll(), Some(Output::OutgoingCreate(this_nonce))); + assert_eq!(state.poll(), None); + + state.handle(Input::OutgoingFailed); + + assert_eq!(state.poll(), None); + + state.handle(Input::IncomingCreated(that_nonce)); + + assert_eq!(state.poll(), Some(Output::OutgoingCreate(this_nonce))); + assert_eq!(state.poll(), None); + + state.handle(Input::OutgoingCreated); + + assert_eq!(state.poll(), Some(Output::IncomingAccept)); + } + + #[test] + fn incoming_fail() { + let this_nonce = TopicNonce::from(1u64); + let that_nonce = TopicNonce::from(2u64); + + let mut state = TopicState::new(this_nonce); + + assert_eq!(state.poll(), Some(Output::IncomingCreate)); + assert_eq!(state.poll(), Some(Output::OutgoingCreate(this_nonce))); + assert_eq!(state.poll(), None); + + state.handle(Input::IncomingFailed); + + assert_eq!(state.poll(), None); + + state.handle(Input::OutgoingCreated); + + assert_eq!(state.poll(), Some(Output::IncomingCreate)); + assert_eq!(state.poll(), None); + + state.handle(Input::IncomingCreated(that_nonce)); + + assert_eq!(state.poll(), Some(Output::IncomingAccept)); + } +} diff --git a/net/src/bus/worker.rs b/net/src/bus/worker.rs new file mode 100644 index 000000000..9ee5eadd2 --- /dev/null +++ b/net/src/bus/worker.rs @@ -0,0 +1,219 @@ +use super::{ + dispatch::Dispatcher, + topic::{Input, Output, TopicId, TopicNonce, TopicState}, +}; +use crate::unified::{Connection, RecvStream, SendStream}; +use futures_util::{stream::FuturesUnordered, StreamExt}; +use std::io; +use tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, + select, + sync::{mpsc, oneshot}, +}; +use tracing::instrument; + +pub(super) enum Command { + Create { + topic_id: TopicId, + send_stream_tx: oneshot::Sender>, + recv_stream_tx: oneshot::Sender>, + }, + Shutdown { + reply_tx: oneshot::Sender<()>, + }, +} + +#[instrument(name = "worker", skip_all, fields(addr = %connection.remote_addr()))] +pub(super) async fn run(connection: Connection, mut command_rx: mpsc::UnboundedReceiver) { + let dispatcher = Dispatcher::new(connection); + let mut topic_tasks = FuturesUnordered::new(); + + loop { + let command = select! 
{ + Some(command) = command_rx.recv() => command, + Some(_) = topic_tasks.next() => continue, + }; + + match command { + Command::Create { + topic_id, + send_stream_tx, + recv_stream_tx, + } => { + topic_tasks.push(create_topic( + topic_id, + send_stream_tx, + recv_stream_tx, + &dispatcher, + )); + } + Command::Shutdown { reply_tx } => { + dispatcher.close().await; + reply_tx.send(()).ok(); + break; + } + } + } +} + +async fn create_topic( + topic_id: TopicId, + send_stream_tx: oneshot::Sender>, + recv_stream_tx: oneshot::Sender>, + dispatcher: &Dispatcher, +) { + // TODO: handle receiver cancellation + + match TopicHandler::new(topic_id, TopicNonce::random(), dispatcher) + .run() + .await + { + Ok((send_stream, recv_stream)) => { + send_stream_tx.send(Ok(send_stream)).ok(); + recv_stream_tx.send(Ok(recv_stream)).ok(); + } + Err(error) => { + send_stream_tx + .send(Err(io::ErrorKind::BrokenPipe.into())) + .ok(); + recv_stream_tx.send(Err(error)).ok(); + } + } +} + +struct TopicHandler<'a> { + topic_id: TopicId, + state: TopicState, + dispatcher: &'a Dispatcher, +} + +impl<'a> TopicHandler<'a> { + fn new(topic_id: TopicId, nonce: TopicNonce, dispatcher: &'a Dispatcher) -> Self { + let state = TopicState::new(nonce); + + Self { + topic_id, + state, + dispatcher, + } + } + + #[instrument(name = "topic", skip_all, fields(topic_id = ?self.topic_id, nonce = ?self.state.nonce))] + async fn run(mut self) -> io::Result<(SendStream, RecvStream)> { + let mut tasks = FuturesUnordered::new(); + let mut incoming = None; + let mut outgoing = None; + let mut last_error = None; + + loop { + while let Some(output) = self.state.poll() { + match output { + Output::OutgoingCreate(nonce) => { + tasks.push(create_stream( + self.dispatcher, + CreateInput::Outgoing(self.topic_id, nonce), + )); + } + Output::OutgoingAccept => match outgoing { + Some(streams) => return Ok(streams), + None => unreachable!(), + }, + Output::IncomingCreate => { + tasks.push(create_stream( + self.dispatcher, + CreateInput::Incoming(self.topic_id), + )); + } + Output::IncomingAccept => match incoming { + Some(streams) => return Ok(streams), + None => unreachable!(), + }, + } + } + + let Some(output) = tasks.next().await else { + break; + }; + + match output { + CreateOutput::Incoming(Ok((nonce, send_stream, recv_stream))) => { + incoming = Some((send_stream, recv_stream)); + self.state.handle(Input::IncomingCreated(nonce)); + } + CreateOutput::Incoming(Err(error)) => { + tracing::debug!(?error, "failed to create incoming stream"); + last_error = Some(error); + self.state.handle(Input::IncomingFailed); + } + CreateOutput::Outgoing(Ok((send_stream, recv_stream))) => { + outgoing = Some((send_stream, recv_stream)); + self.state.handle(Input::OutgoingCreated); + } + CreateOutput::Outgoing(Err(error)) => { + tracing::debug!(?error, "failed to create outgoing stream"); + last_error = Some(error); + self.state.handle(Input::OutgoingFailed); + } + } + } + + return Err(last_error.unwrap_or_else(|| io::ErrorKind::ConnectionAborted.into())); + } +} + +async fn create_stream(dispatcher: &Dispatcher, input: CreateInput) -> CreateOutput { + match input { + CreateInput::Incoming(topic_id) => { + CreateOutput::Incoming(create_incoming_stream(dispatcher, topic_id).await) + } + CreateInput::Outgoing(topic_id, nonce) => { + CreateOutput::Outgoing(create_outgoing_stream(dispatcher, topic_id, nonce).await) + } + } +} + +async fn create_incoming_stream( + dispatcher: &Dispatcher, + topic_id: TopicId, +) -> io::Result<(TopicNonce, SendStream, RecvStream)> { + 
let (mut send_stream, mut recv_stream) = dispatcher.incoming(topic_id).await?;
+
+    let mut buffer = [0; TopicNonce::SIZE];
+    recv_stream.read_exact(&mut buffer).await?;
+    let nonce = TopicNonce::from(buffer);
+
+    send_stream.write_all(ACK).await?;
+
+    Ok((nonce, send_stream, recv_stream))
+}
+
+async fn create_outgoing_stream(
+    dispatcher: &Dispatcher,
+    topic_id: TopicId,
+    nonce: TopicNonce,
+) -> io::Result<(SendStream, RecvStream)> {
+    let (mut send_stream, mut recv_stream) = dispatcher.outgoing(topic_id).await?;
+
+    send_stream.write_all(nonce.as_bytes()).await?;
+
+    let mut buffer = [0; ACK.len()];
+    recv_stream.read_exact(&mut buffer).await?;
+
+    if buffer != ACK {
+        return Err(io::Error::new(io::ErrorKind::InvalidData, "invalid ack"));
+    }
+
+    Ok((send_stream, recv_stream))
+}
+
+const ACK: &[u8] = &[0xac];
+
+enum CreateInput {
+    Incoming(TopicId),
+    Outgoing(TopicId, TopicNonce),
+}
+
+enum CreateOutput {
+    Incoming(io::Result<(TopicNonce, SendStream, RecvStream)>),
+    Outgoing(io::Result<(SendStream, RecvStream)>),
+}
diff --git a/net/src/lib.rs b/net/src/lib.rs
index 255825252..8aea8ffde 100644
--- a/net/src/lib.rs
+++ b/net/src/lib.rs
@@ -1,5 +1,6 @@
 use std::time::Duration;
 
+pub mod bus;
 pub mod quic;
 pub mod stun;
 pub mod tcp;
@@ -10,6 +11,9 @@ pub mod unified;
 mod socket;
 mod sync;
 
+#[cfg(test)]
+mod test_utils;
+
 pub use socket::SocketOptions;
 
 pub const KEEP_ALIVE_INTERVAL: Duration = Duration::from_secs(10);
diff --git a/net/src/tcp.rs b/net/src/tcp.rs
index 6b6eb82fe..bd2dda461 100644
--- a/net/src/tcp.rs
+++ b/net/src/tcp.rs
@@ -163,66 +163,69 @@ async fn drive_connection(
     mut conn: yamux::Connection<Compat<TcpStream>>,
     mut command_rx: mpsc::Receiver<Command>,
 ) {
-    // Buffers for incoming and outgoing streams. These guarantee that no streams are ever lost,
-    // even if `Connection::incoming` or `Connection::outgoing` are cancelled. Due to the limit on
-    // the number of streams per connection, these buffers are effectively bounded.
-    let mut incoming = Vec::new();
-    let mut outgoing = Vec::new();
+    // Incoming streams are being polled continuously and placed here, then taken from here one by
+    // one on the next call to `Connection::incoming`. Note that yamux has a limit on the number of
+    // simultaneously open streams per connection which effectively puts a bound on this collection
+    // as well.
+    let mut incoming_results = Vec::<Result<yamux::Stream, yamux::ConnectionError>>::new();
+    let mut incoming_senders =
+        Vec::<rendezvous::Sender<Result<yamux::Stream, yamux::ConnectionError>>>::new();
+
+    // If a `Connection::outgoing` call gets cancelled after the outgoing stream has already been
+    // created, we store the stream here and use it next time `Connection::outgoing` is called.
+    // This ensures no outgoing stream gets lost and thus makes `outgoing` cancel safe.
+    let mut outgoing_result = None;
 
     loop {
+        while let Some(tx) = incoming_senders.pop() {
+            if let Some(result) = incoming_results.pop() {
+                if let Err(result) = tx.send(result).await {
+                    incoming_results.push(result);
+                }
+            } else {
+                incoming_senders.push(tx);
+                break;
+            }
+        }
+
         let command = select! 
{
             command = command_rx.recv() => command,
-            result = future::poll_fn(|cx| conn.poll_next_inbound(cx)) => {
-                match result {
-                    Some(result) => {
-                        incoming.push(result);
-                        continue;
+            result = incoming(&mut conn) => {
+                if let Some(result) = result {
+                    // Store at most one error
+                    if result.is_err() {
+                        incoming_results.retain(|result| result.is_ok());
                     }
-                    None => break,
+
+                    incoming_results.push(result);
+
+                    continue;
+                } else {
+                    // Connection closed by the peer
+                    break;
                 }
             }
         };
 
         match command.unwrap_or(Command::Close(None)) {
             Command::Incoming(reply_tx) => {
-                let result = if let Some(result) = incoming.pop() {
-                    result
-                } else {
-                    select! {
-                        result = future::poll_fn(|cx| conn.poll_next_inbound(cx)) => {
-                            if let Some(result) = result {
-                                result
-                            } else {
-                                // connection closed
-                                break;
-                            }
-                        }
-                        _ = reply_tx.closed() => continue,
-                    }
-                };
-
-                if let Err(result) = reply_tx.send(result).await {
-                    // reply_rx dropped before receiving the result, save it for next time.
-                    incoming.push(result);
-                }
+                incoming_senders.push(reply_tx);
             }
             Command::Outgoing(reply_tx) => {
                 let result = if let Some(result) = outgoing_result.take() {
                     result
                 } else {
                     select! {
-                        result = future::poll_fn(|cx| conn.poll_new_outbound(cx)) => result,
+                        result = outgoing(&mut conn) => result,
                         _ = reply_tx.closed() => continue,
                     }
                 };
 
                 if let Err(result) = reply_tx.send(result).await {
-                    // reply_rx dropped before receiving the result, save it for next time.
-                    outgoing.push(result);
+                    outgoing_result = Some(result);
                 }
             }
             Command::Close(reply_tx) => {
-                let result = future::poll_fn(|cx| conn.poll_close(cx)).await;
+                let result = close(&mut conn).await;
 
                 if let Some(reply_tx) = reply_tx {
                     reply_tx.send(result).ok();
@@ -234,6 +237,24 @@ async fn drive_connection(
     }
 }
 
+async fn incoming(
+    conn: &mut yamux::Connection<Compat<TcpStream>>,
+) -> Option<Result<yamux::Stream, yamux::ConnectionError>> {
+    future::poll_fn(|cx| conn.poll_next_inbound(cx)).await
+}
+
+async fn outgoing(
+    conn: &mut yamux::Connection<Compat<TcpStream>>,
+) -> Result<yamux::Stream, yamux::ConnectionError> {
+    future::poll_fn(|cx| conn.poll_new_outbound(cx)).await
+}
+
+async fn close(
+    conn: &mut yamux::Connection<Compat<TcpStream>>,
+) -> Result<(), yamux::ConnectionError> {
+    future::poll_fn(|cx| conn.poll_close(cx)).await
+}
+
 enum Command {
     // Using rendezvous to guarantee the reply is either received or we get it back if the receive
     // got cancelled.
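
The cancel-safety pattern used by `drive_connection` above - finish the operation even when the
requester has gone away, park the completed result, and hand it to the next caller - is
independent of yamux. A minimal self-contained sketch of just that pattern (not part of the
patch; `Item`, `produce` and `drive` are illustrative stand-ins):

    use tokio::{select, sync::mpsc};

    type Item = String;

    // Stand-in for an operation whose result must not be lost once it completes
    // (`poll_new_outbound` plays this role above).
    async fn produce() -> Item {
        "stream".to_owned()
    }

    async fn drive(mut requests: mpsc::Receiver<mpsc::Sender<Item>>) {
        // A result whose requester went away is parked here instead of being dropped.
        let mut parked: Option<Item> = None;

        while let Some(reply_tx) = requests.recv().await {
            let item = match parked.take() {
                Some(item) => item,
                None => select! {
                    item = produce() => item,
                    // The requester lost interest before anything was produced; skip it.
                    _ = reply_tx.closed() => continue,
                },
            };

            // `send` fails only when the requester is gone; park the item for the next caller.
            if let Err(mpsc::error::SendError(item)) = reply_tx.send(item).await {
                parked = Some(item);
            }
        }
    }
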
diff --git a/net/src/test_utils.rs b/net/src/test_utils.rs new file mode 100644 index 000000000..8e852f0d0 --- /dev/null +++ b/net/src/test_utils.rs @@ -0,0 +1,53 @@ +use crate::{ + quic, tcp, + unified::{Acceptor, Connection, Connector}, + SocketOptions, +}; +use futures_util::future; +use std::net::Ipv4Addr; +use test_strategy::Arbitrary; + +#[derive(Clone, Copy, Debug, Arbitrary)] +pub(crate) enum Proto { + Tcp, + Quic, +} + +pub(crate) fn create_connected_peers(proto: Proto) -> (Connector, Acceptor) { + let addr = (Ipv4Addr::LOCALHOST, 0).into(); + let options = SocketOptions::default(); + + match proto { + Proto::Tcp => { + let (client, _) = tcp::configure(addr, options).unwrap(); + let (_, server) = tcp::configure(addr, options).unwrap(); + + (client.into(), server.into()) + } + Proto::Quic => { + let (client, _, _) = quic::configure(addr, options).unwrap(); + let (_, server, _) = quic::configure(addr, options).unwrap(); + + (client.into(), server.into()) + } + } +} + +pub(crate) async fn create_connected_connections( + client: &Connector, + server: &Acceptor, +) -> (Connection, Connection) { + future::try_join(client.connect(*server.local_addr()), async { + server.accept().await?.await + }) + .await + .unwrap() +} + +pub(crate) fn init_log() { + tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .compact() + .try_init() + .ok(); +} diff --git a/net/src/unified.rs b/net/src/unified.rs index 3c27630b6..3f6abe2e4 100644 --- a/net/src/unified.rs +++ b/net/src/unified.rs @@ -263,12 +263,14 @@ pub enum ConnectionError { #[cfg(test)] mod tests { - use crate::SocketOptions; - - use super::*; + use super::Connection; + use crate::test_utils::{ + create_connected_connections, create_connected_peers, init_log, Proto, + }; use futures_util::{future, stream::FuturesUnordered, StreamExt}; - use rand::{distributions::Standard, Rng}; - use std::net::Ipv4Addr; + use itertools::Itertools; + use proptest::{arbitrary::any, collection::vec}; + use test_strategy::proptest; use tokio::{ io::{AsyncReadExt, AsyncWriteExt}, select, task, @@ -277,21 +279,21 @@ mod tests { #[tokio::test] async fn ping_tcp() { - let (client, server) = setup_tcp_peers(); - ping_case(client, server).await + ping_case(Proto::Tcp).await } #[tokio::test] async fn ping_quic() { - let (client, server) = setup_quic_peers(); - ping_case(client, server).await + ping_case(Proto::Quic).await } - async fn ping_case(client_connector: Connector, server_acceptor: Acceptor) { - let addr = *server_acceptor.local_addr(); + async fn ping_case(proto: Proto) { + let (client, server) = create_connected_peers(proto); + + let addr = *server.local_addr(); let server = task::spawn(async move { - let conn = server_acceptor.accept().await.unwrap().await.unwrap(); + let conn = server.accept().await.unwrap().await.unwrap(); let (mut tx, mut rx) = conn.incoming().await.unwrap(); let mut buf = [0; 4]; @@ -302,7 +304,7 @@ mod tests { }); let client = task::spawn(async move { - let conn = client_connector.connect(addr).await.unwrap(); + let conn = client.connect(addr).await.unwrap(); let (mut tx, mut rx) = conn.outgoing().await.unwrap(); tx.write_all(b"ping").await.unwrap(); @@ -318,40 +320,24 @@ mod tests { client.await.unwrap(); } - #[tokio::test] - async fn multi_streams_tcp() { - let (client, server) = setup_tcp_peers(); - multi_streams_case(client, server).await; - } - - #[tokio::test] - async fn multi_streams_quic() { - let (client, server) = setup_quic_peers(); - multi_streams_case(client, server).await; + 
#[proptest] + fn multi_streams( + proto: Proto, + #[strategy(vec(vec(any::(), 1..=1024), 1..32))] messages: Vec>, + ) { + tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap() + .block_on(multi_streams_case(proto, messages)); } - async fn multi_streams_case(client_connector: Connector, server_acceptor: Acceptor) { - tracing_subscriber::fmt() - .with_max_level(tracing::Level::DEBUG) - .compact() - .init(); - - let num_messages = 32; - let min_message_size = 1; - let max_message_size = 256 * 1024; - - let mut rng = rand::thread_rng(); - let mut messages: Vec> = (0..num_messages) - .map(|_| { - let size = rng.gen_range(min_message_size..=max_message_size); - (&mut rng).sample_iter(Standard).take(size).collect() - }) - .collect(); - - let server_addr = *server_acceptor.local_addr(); + async fn multi_streams_case(proto: Proto, mut messages: Vec>) { + let (client, server) = create_connected_peers(proto); + let server_addr = *server.local_addr(); let client = async { - let conn = client_connector.connect(server_addr).await.unwrap(); + let conn = client.connect(server_addr).await.unwrap(); let tasks = FuturesUnordered::new(); for message in &messages { @@ -382,7 +368,7 @@ mod tests { .instrument(tracing::info_span!("client")); let server = async { - let conn = server_acceptor.accept().await.unwrap().await.unwrap(); + let conn = server.accept().await.unwrap().await.unwrap(); let mut tasks = FuturesUnordered::new(); let mut received_messages = Vec::new(); @@ -430,21 +416,77 @@ mod tests { similar_asserts::assert_eq!(received_messages, messages); } - fn setup_tcp_peers() -> (Connector, Acceptor) { - let (client, _) = - tcp::configure((Ipv4Addr::LOCALHOST, 0).into(), SocketOptions::default()).unwrap(); - let (_, server) = - tcp::configure((Ipv4Addr::LOCALHOST, 0).into(), SocketOptions::default()).unwrap(); + #[tokio::test] + async fn concurrent_streams_tcp() { + concurrent_streams_case(Proto::Tcp).await + } - (client.into(), server.into()) + #[tokio::test] + async fn concurrent_streams_quic() { + concurrent_streams_case(Proto::Quic).await } - fn setup_quic_peers() -> (Connector, Acceptor) { - let (client, _, _) = - quic::configure((Ipv4Addr::LOCALHOST, 0).into(), SocketOptions::default()).unwrap(); - let (_, server, _) = - quic::configure((Ipv4Addr::LOCALHOST, 0).into(), SocketOptions::default()).unwrap(); + // Test concurrent establishment of both incoming and outgoing streams + async fn concurrent_streams_case(proto: Proto) { + init_log(); + + let (client, server) = create_connected_peers(proto); + let (client, server) = create_connected_connections(&client, &server).await; + + // Exhaustively test all permutations of the operations + let ops = [ + "ping(client)", + "ping(server)", + "pong(client)", + "pong(server)", + ]; + + for order in ops.iter().permutations(ops.len()) { + async { + tracing::info!("init"); + + future::try_join_all(order.iter().map(|op| async { + match **op { + "ping(client)" => ping(&client).await, + "ping(server)" => ping(&server).await, + "pong(client)" => pong(&client).await, + "pong(server)" => pong(&server).await, + _ => unreachable!(), + } + })) + .await?; + + tracing::info!("done"); + + Ok::<_, anyhow::Error>(()) + } + .instrument(tracing::info_span!("order", message = ?order)) + .await + .unwrap() + } + } + + async fn ping(connection: &Connection) -> anyhow::Result<()> { + let (mut send_stream, mut recv_stream) = connection.outgoing().await?; + + send_stream.write_all(b"ping").await?; + + let mut buffer = [0; 4]; + 
recv_stream.read_exact(&mut buffer).await?;
+        assert_eq!(&buffer, b"pong");
+
+        Ok(())
+    }
+
+    async fn pong(connection: &Connection) -> anyhow::Result<()> {
+        let (mut send_stream, mut recv_stream) = connection.incoming().await?;
+
+        let mut buffer = [0; 4];
+        recv_stream.read_exact(&mut buffer).await?;
+        assert_eq!(&buffer, b"ping");
+
+        send_stream.write_all(b"pong").await?;
 
-        (client.into(), server.into())
+        Ok(())
     }
 }

From 1600045b6f457e2bfe188f217f9d95ae98ce2104 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adam=20Cig=C3=A1nek?=
Date: Wed, 11 Sep 2024 13:27:32 +0200
Subject: [PATCH 16/55] Minor renaming in the network module

---
 lib/src/network/mod.rs | 113 +++++++++++++++++++++--------------------
 1 file changed, 57 insertions(+), 56 deletions(-)

diff --git a/lib/src/network/mod.rs b/lib/src/network/mod.rs
index 38ac149ca..858755766 100644
--- a/lib/src/network/mod.rs
+++ b/lib/src/network/mod.rs
@@ -140,9 +140,9 @@ impl Network {
             span: Span::current(),
             gateway,
             this_runtime_id,
-            state: BlockingMutex::new(State {
-                message_brokers: Some(HashMap::default()),
-                registry: Slab::new(),
+            registry: BlockingMutex::new(Registry {
+                peers: Some(HashMap::default()),
+                repos: Slab::new(),
             }),
             port_forwarder,
             port_forwarder_state: BlockingMutex::new(ComponentState::disabled(
@@ -360,16 +360,16 @@ impl Network {
         let response_limiter = Arc::new(Semaphore::new(MAX_UNCHOKED_COUNT));
         let stats_tracker = StatsTracker::default();
 
-        let mut network_state = self.inner.state.lock().unwrap();
+        let mut registry = self.inner.registry.lock().unwrap();
 
-        network_state.create_link(
+        registry.create_link(
             handle.vault.clone(),
             &pex,
             response_limiter.clone(),
             stats_tracker.bytes.clone(),
         );
 
-        let key = network_state.registry.insert(RegistrationHolder {
+        let key = registry.repos.insert(RegistrationHolder {
             vault: handle.vault,
             dht,
             pex,
@@ -391,12 +391,12 @@ impl Network {
     pub async fn shutdown(&self) {
         // TODO: Would be a nice-to-have to also wait for all the spawned tasks here (e.g. discovery
         // mechanisms).
-        let Some(message_brokers) = self.inner.state.lock().unwrap().message_brokers.take() else {
+        let Some(peers) = self.inner.registry.lock().unwrap().peers.take() else {
             tracing::warn!("Network already shut down");
             return;
         };
 
-        shutdown_brokers(message_brokers).await;
+        shutdown_peers(peers).await;
     }
 }
@@ -409,8 +409,8 @@ impl Registration {
     pub async fn set_dht_enabled(&self, enabled: bool) {
         set_metadata_bool(&self.inner, self.key, DHT_ENABLED, enabled).await;
 
-        let mut state = self.inner.state.lock().unwrap();
-        let holder = &mut state.registry[self.key];
+        let mut registry = self.inner.registry.lock().unwrap();
+        let holder = &mut registry.repos[self.key];
 
         if enabled {
             holder.dht = Some(
@@ -427,7 +427,7 @@ impl Registration {
     /// difference is that this function should return true even when e.g. the whole network
     /// is disabled.
pub fn is_dht_enabled(&self) -> bool { - self.inner.state.lock().unwrap().registry[self.key] + self.inner.registry.lock().unwrap().repos[self.key] .dht .is_some() } @@ -440,19 +440,19 @@ impl Registration { pub async fn set_pex_enabled(&self, enabled: bool) { set_metadata_bool(&self.inner, self.key, PEX_ENABLED, enabled).await; - let state = self.inner.state.lock().unwrap(); - state.registry[self.key].pex.set_enabled(enabled); + let registry = self.inner.registry.lock().unwrap(); + registry.repos[self.key].pex.set_enabled(enabled); } pub fn is_pex_enabled(&self) -> bool { - self.inner.state.lock().unwrap().registry[self.key] + self.inner.registry.lock().unwrap().repos[self.key] .pex .is_enabled() } /// Fetch per-repository network statistics. pub fn stats(&self) -> Stats { - self.inner.state.lock().unwrap().registry[self.key] + self.inner.registry.lock().unwrap().repos[self.key] .stats_tracker .read() } @@ -460,12 +460,16 @@ impl Registration { impl Drop for Registration { fn drop(&mut self) { - let mut state = self.inner.state.lock().unwrap(); + let mut registry = self + .inner + .registry + .lock() + .unwrap_or_else(|error| error.into_inner()); - if let Some(holder) = state.registry.try_remove(self.key) { - if let Some(brokers) = &mut state.message_brokers { - for broker in brokers.values_mut() { - broker.destroy_link(holder.vault.repository_id()); + if let Some(holder) = registry.repos.try_remove(self.key) { + if let Some(peers) = &mut registry.peers { + for peer in peers.values_mut() { + peer.destroy_link(holder.vault.repository_id()); } } } @@ -473,7 +477,7 @@ impl Drop for Registration { } async fn set_metadata_bool(inner: &Inner, key: usize, name: &str, value: bool) { - let metadata = inner.state.lock().unwrap().registry[key].vault.metadata(); + let metadata = inner.registry.lock().unwrap().repos[key].vault.metadata(); metadata.set(name, value).await.ok(); } @@ -492,7 +496,7 @@ struct Inner { span: Span, gateway: Gateway, this_runtime_id: SecretRuntimeId, - state: BlockingMutex, + registry: BlockingMutex, port_forwarder: upnp::PortForwarder, port_forwarder_state: BlockingMutex>, local_discovery_state: BlockingMutex>, @@ -512,13 +516,13 @@ struct Inner { stats_tracker: StatsTracker, } -struct State { +struct Registry { // This is None once the network calls shutdown. - message_brokers: Option>, - registry: Slab, + peers: Option>, + repos: Slab, } -impl State { +impl Registry { fn create_link( &mut self, repo: Vault, @@ -526,9 +530,9 @@ impl State { response_limiter: Arc, byte_counters: Arc, ) { - if let Some(brokers) = &mut self.message_brokers { - for broker in brokers.values_mut() { - broker.create_link( + if let Some(peers) = &mut self.peers { + for peer in peers.values_mut() { + peer.create_link( repo.clone(), pex, response_limiter.clone(), @@ -541,7 +545,7 @@ impl State { impl Inner { fn is_shutdown(&self) -> bool { - self.state.lock().unwrap().message_brokers.is_none() + self.registry.lock().unwrap().peers.is_none() } async fn bind(self: &Arc, bind: &[PeerAddr]) { @@ -616,14 +620,14 @@ impl Inner { // Disconnect from all currently connected peers, regardless of their source. 
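    // (Reviewer note: the function below swaps the peer map out while the
    // blocking mutex is held and awaits the shutdowns only afterwards, so the
    // lock is never held across an await point.)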
async fn disconnect_all(&self) { - let Some(message_brokers) = mem::replace( - &mut self.state.lock().unwrap().message_brokers, + let Some(peers) = mem::replace( + &mut self.registry.lock().unwrap().peers, Some(HashMap::default()), ) else { return; }; - shutdown_brokers(message_brokers).await; + shutdown_peers(peers).await; } fn spawn_local_discovery(self: &Arc) -> Option { @@ -889,17 +893,16 @@ impl Inner { let released = permit.released(); { - let mut state = self.state.lock().unwrap(); - let state = &mut *state; + let mut registry = self.registry.lock().unwrap(); + let registry = &mut *registry; - let brokers = match &mut state.message_brokers { - Some(brokers) => brokers, + let Some(peers) = &mut registry.peers else { // Network has been shut down. - None => return false, + return false; }; - let broker = brokers.entry(that_runtime_id).or_insert_with(|| { - let mut broker = self.span.in_scope(|| { + let peer = peers.entry(that_runtime_id).or_insert_with(|| { + let mut peer = self.span.in_scope(|| { MessageBroker::new( self.this_runtime_id.public(), that_runtime_id, @@ -912,8 +915,8 @@ impl Inner { // TODO: for DHT connection we should only link the repository for which we did the // lookup but make sure we correctly handle edge cases, for example, when we have // more than one repository shared with the peer. - for (_, holder) in &state.registry { - broker.create_link( + for (_, holder) in ®istry.repos { + peer.create_link( holder.vault.clone(), &holder.pex, holder.response_limiter.clone(), @@ -921,14 +924,14 @@ impl Inner { ); } - broker + peer }); - broker.add_connection(connection, permit, self.stats_tracker.bytes.clone()); + peer.add_connection(connection, permit, self.stats_tracker.bytes.clone()); } let _remover = MessageBrokerEntryGuard { - state: &self.state, + registry: &self.registry, that_runtime_id, monitor, }; @@ -1028,7 +1031,7 @@ enum HandshakeError { // RAII guard which when dropped removes the broker from the network state if it has no connections. 
struct MessageBrokerEntryGuard<'a> { - state: &'a BlockingMutex, + registry: &'a BlockingMutex, that_runtime_id: PublicRuntimeId, monitor: &'a ConnectionMonitor, } @@ -1037,9 +1040,12 @@ impl Drop for MessageBrokerEntryGuard<'_> { fn drop(&mut self) { tracing::info!(parent: self.monitor.span(), "Disconnected"); - let mut state = self.state.lock().unwrap(); - if let Some(brokers) = &mut state.message_brokers { - if let Entry::Occupied(entry) = brokers.entry(self.that_runtime_id) { + let mut registry = self + .registry + .lock() + .unwrap_or_else(|error| error.into_inner()); + if let Some(peers) = &mut registry.peers { + if let Entry::Occupied(entry) = peers.entry(self.that_runtime_id) { if !entry.get().has_connections() { entry.remove(); } @@ -1151,11 +1157,6 @@ pub fn repository_info_hash(id: &RepositoryId) -> InfoHash { .unwrap() } -async fn shutdown_brokers(message_brokers: HashMap) { - future::join_all( - message_brokers - .into_values() - .map(|message_broker| message_broker.shutdown()), - ) - .await; +async fn shutdown_peers(peers: HashMap) { + future::join_all(peers.into_values().map(|peer| peer.shutdown())).await; } From 8991108d644e85f19d22e8d0236b2c3f3bd2da32 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Cig=C3=A1nek?= Date: Wed, 11 Sep 2024 13:55:14 +0200 Subject: [PATCH 17/55] Create one MessageBroker per connection --- lib/src/network/message_broker.rs | 5 -- lib/src/network/message_dispatcher.rs | 43 ++---------- lib/src/network/mod.rs | 95 +++++++++++++-------------- 3 files changed, 49 insertions(+), 94 deletions(-) diff --git a/lib/src/network/message_broker.rs b/lib/src/network/message_broker.rs index d45f6e8e4..ba93427b4 100644 --- a/lib/src/network/message_broker.rs +++ b/lib/src/network/message_broker.rs @@ -78,11 +78,6 @@ impl MessageBroker { self.dispatcher.bind(connection, permit, byte_counters) } - /// Has this broker at least one live connection? - pub fn has_connections(&self) -> bool { - self.dispatcher.is_bound() - } - /// Try to establish a link between a local repository and a remote repository. The remote /// counterpart needs to call this too with matching repository id for the link to actually be /// created. diff --git a/lib/src/network/message_dispatcher.rs b/lib/src/network/message_dispatcher.rs index c057718d3..7d06aa504 100644 --- a/lib/src/network/message_dispatcher.rs +++ b/lib/src/network/message_dispatcher.rs @@ -17,10 +17,7 @@ use net::unified::{Connection, ConnectionError, RecvStream, SendStream}; use std::{ io, pin::Pin, - sync::{ - atomic::{AtomicUsize, Ordering}, - Arc, - }, + sync::Arc, task::{Context, Poll}, }; use tokio::{ @@ -39,22 +36,19 @@ const CONTENT_STREAM_BUFFER_SIZE: usize = 1024; pub(super) struct MessageDispatcher { command_tx: mpsc::UnboundedSender, sink_tx: mpsc::Sender, - connection_count: Arc, } impl MessageDispatcher { pub fn new() -> Self { let (command_tx, command_rx) = mpsc::unbounded_channel(); let (sink_tx, sink_rx) = mpsc::channel(1); - let connection_count = Arc::new(AtomicUsize::new(0)); - let worker = Worker::new(command_rx, sink_rx, connection_count.clone()); + let worker = Worker::new(command_rx, sink_rx); task::spawn(worker.run()); Self { command_tx, sink_tx, - connection_count, } } @@ -75,11 +69,6 @@ impl MessageDispatcher { .ok(); } - /// Is this dispatcher bound to at least one connection? - pub fn is_bound(&self) -> bool { - self.connection_count.load(Ordering::Acquire) > 0 - } - /// Opens a stream for receiving messages on the given channel. 
Any messages received on /// `channel` before the stream's been opened are discarded. When a stream is opened, all /// previously opened streams on the same channel (if any) get automatically closed. @@ -268,24 +257,16 @@ struct ConnectionStream { reader: MessageStream>>, permit: ConnectionPermitHalf, permit_released: AwaitDrop, - connection_count: Arc, } impl ConnectionStream { - fn new( - reader: Instrumented, - permit: ConnectionPermitHalf, - connection_count: Arc, - ) -> Self { - connection_count.fetch_add(1, Ordering::Release); - + fn new(reader: Instrumented, permit: ConnectionPermitHalf) -> Self { let permit_released = permit.released(); Self { reader: MessageStream::new(Instrumented::new(reader, permit.byte_counters())), permit, permit_released, - connection_count, } } } @@ -309,12 +290,6 @@ impl Stream for ConnectionStream { } } -impl Drop for ConnectionStream { - fn drop(&mut self) { - self.connection_count.fetch_sub(1, Ordering::Release); - } -} - // Sink for sending messages on a single connection. Contains a connection permit half which gets // released on drop. Automatically closes when the corresponding `ConnectionStream` is closed. struct ConnectionSink { @@ -370,20 +345,14 @@ impl Sink for ConnectionSink { struct Worker { command_rx: mpsc::UnboundedReceiver, - connection_count: Arc, send: SendState, recv: RecvState, } impl Worker { - fn new( - command_rx: mpsc::UnboundedReceiver, - sink_rx: mpsc::Receiver, - connection_count: Arc, - ) -> Self { + fn new(command_rx: mpsc::UnboundedReceiver, sink_rx: mpsc::Receiver) -> Self { Self { command_rx, - connection_count, send: SendState { sink_rx, sinks: Vec::new(), @@ -435,8 +404,6 @@ impl Worker { permit, byte_counters, } => { - let connection_count = self.connection_count.clone(); - streams.push(async move { let (tx, rx) = match ConnectionDirection::from_source(permit.source()) { ConnectionDirection::Incoming => connection.incoming().await?, @@ -449,7 +416,7 @@ impl Worker { let tx = ConnectionSink::new(tx, tx_permit); let rx = Instrumented::new(rx, byte_counters.clone()); - let rx = ConnectionStream::new(rx, rx_permit, connection_count); + let rx = ConnectionStream::new(rx, rx_permit); Ok::<_, ConnectionError>((connection, tx, rx)) }); diff --git a/lib/src/network/mod.rs b/lib/src/network/mod.rs index 858755766..804c59826 100644 --- a/lib/src/network/mod.rs +++ b/lib/src/network/mod.rs @@ -58,7 +58,7 @@ use self::{ stun::StunClients, }; use crate::{ - collections::{hash_map::Entry, HashMap, HashSet}, + collections::HashSet, network::connection::ConnectionDirection, protocol::RepositoryId, repository::{RepositoryHandle, Vault}, @@ -141,7 +141,7 @@ impl Network { gateway, this_runtime_id, registry: BlockingMutex::new(Registry { - peers: Some(HashMap::default()), + peers: Some(Slab::new()), repos: Slab::new(), }), port_forwarder, @@ -467,10 +467,8 @@ impl Drop for Registration { .unwrap_or_else(|error| error.into_inner()); if let Some(holder) = registry.repos.try_remove(self.key) { - if let Some(peers) = &mut registry.peers { - for peer in peers.values_mut() { - peer.destroy_link(holder.vault.repository_id()); - } + for (_, peer) in registry.peers.as_mut().into_iter().flatten() { + peer.destroy_link(holder.vault.repository_id()); } } } @@ -518,7 +516,7 @@ struct Inner { struct Registry { // This is None once the network calls shutdown. 
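    // (Reviewer note on the change below: the switch from a HashMap keyed by
    // runtime id to a Slab follows from the one-MessageBroker-per-connection
    // model this patch introduces; the same peer may now occupy several
    // entries at once.)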
- peers: Option>, + peers: Option>, repos: Slab, } @@ -531,7 +529,7 @@ impl Registry { byte_counters: Arc, ) { if let Some(peers) = &mut self.peers { - for peer in peers.values_mut() { + for (_, peer) in peers { peer.create_link( repo.clone(), pex, @@ -622,7 +620,7 @@ impl Inner { async fn disconnect_all(&self) { let Some(peers) = mem::replace( &mut self.registry.lock().unwrap().peers, - Some(HashMap::default()), + Some(Slab::default()), ) else { return; }; @@ -892,7 +890,7 @@ impl Inner { let released = permit.released(); - { + let key = { let mut registry = self.registry.lock().unwrap(); let registry = &mut *registry; @@ -901,38 +899,36 @@ impl Inner { return false; }; - let peer = peers.entry(that_runtime_id).or_insert_with(|| { - let mut peer = self.span.in_scope(|| { - MessageBroker::new( - self.this_runtime_id.public(), - that_runtime_id, - self.pex_discovery.new_peer(), - self.peers_monitor - .make_child(format!("{:?}", that_runtime_id.as_public_key())), - ) - }); - - // TODO: for DHT connection we should only link the repository for which we did the - // lookup but make sure we correctly handle edge cases, for example, when we have - // more than one repository shared with the peer. - for (_, holder) in ®istry.repos { - peer.create_link( - holder.vault.clone(), - &holder.pex, - holder.response_limiter.clone(), - holder.stats_tracker.bytes.clone(), - ); - } - - peer + let mut peer = self.span.in_scope(|| { + MessageBroker::new( + self.this_runtime_id.public(), + that_runtime_id, + self.pex_discovery.new_peer(), + self.peers_monitor + .make_child(format!("{:?}", that_runtime_id.as_public_key())), + ) }); + // TODO: for DHT connection we should only link the repository for which we did the + // lookup but make sure we correctly handle edge cases, for example, when we have + // more than one repository shared with the peer. + for (_, holder) in ®istry.repos { + peer.create_link( + holder.vault.clone(), + &holder.pex, + holder.response_limiter.clone(), + holder.stats_tracker.bytes.clone(), + ); + } + peer.add_connection(connection, permit, self.stats_tracker.bytes.clone()); - } - let _remover = MessageBrokerEntryGuard { + peers.insert(peer) + }; + + let _guard = PeerGuard { registry: &self.registry, - that_runtime_id, + key, monitor, }; @@ -1029,27 +1025,24 @@ enum HandshakeError { Connection(#[from] ConnectionError), } -// RAII guard which when dropped removes the broker from the network state if it has no connections. -struct MessageBrokerEntryGuard<'a> { +// RAII guard which when dropped removes the peer from the registry. 
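+// (Reviewer note: removal is unconditional here; the old has_connections()
+// check became unnecessary once each connection got its own MessageBroker.)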
+struct PeerGuard<'a> { registry: &'a BlockingMutex, - that_runtime_id: PublicRuntimeId, + key: usize, monitor: &'a ConnectionMonitor, } -impl Drop for MessageBrokerEntryGuard<'_> { +impl Drop for PeerGuard<'_> { fn drop(&mut self) { tracing::info!(parent: self.monitor.span(), "Disconnected"); - let mut registry = self + if let Some(peers) = &mut self .registry .lock() - .unwrap_or_else(|error| error.into_inner()); - if let Some(peers) = &mut registry.peers { - if let Entry::Occupied(entry) = peers.entry(self.that_runtime_id) { - if !entry.get().has_connections() { - entry.remove(); - } - } + .unwrap_or_else(|error| error.into_inner()) + .peers + { + peers.try_remove(self.key); } } } @@ -1157,6 +1150,6 @@ pub fn repository_info_hash(id: &RepositoryId) -> InfoHash { .unwrap() } -async fn shutdown_peers(peers: HashMap) { - future::join_all(peers.into_values().map(|peer| peer.shutdown())).await; +async fn shutdown_peers(peers: Slab) { + future::join_all(peers.into_iter().map(|(_, peer)| peer.shutdown())).await; } From b0d1607838ae3b52797cef496a1315a67210593a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Cig=C3=A1nek?= Date: Wed, 11 Sep 2024 14:51:49 +0200 Subject: [PATCH 18/55] net: Add notification on connection close --- net/src/bus.rs | 10 ++--- net/src/bus/dispatch.rs | 4 ++ net/src/bus/worker.rs | 5 ++- net/src/quic.rs | 7 ++++ net/src/sync.rs | 22 +++++------ net/src/tcp.rs | 14 ++++++- net/src/unified.rs | 82 ++++++++++++++++++++++++++++++++++------- 7 files changed, 108 insertions(+), 36 deletions(-) diff --git a/net/src/bus.rs b/net/src/bus.rs index 58a480ecc..62642019a 100644 --- a/net/src/bus.rs +++ b/net/src/bus.rs @@ -66,15 +66,11 @@ impl Bus { (send, recv) } - /// Gracefully shuts down the underlying connection. - pub async fn shutdown(&self) { + /// Gracefully close the underlying connection. + pub async fn close(&self) { let (reply_tx, reply_rx) = oneshot::channel(); - if self - .command_tx - .send(Command::Shutdown { reply_tx }) - .is_err() - { + if self.command_tx.send(Command::Close { reply_tx }).is_err() { return; } diff --git a/net/src/bus/dispatch.rs b/net/src/bus/dispatch.rs index ce4cbff02..7a5411605 100644 --- a/net/src/bus/dispatch.rs +++ b/net/src/bus/dispatch.rs @@ -110,6 +110,10 @@ impl Dispatcher { pub async fn close(&self) { self.connection.close().await } + + pub async fn closed(&self) { + self.connection.closed().await + } } type Registry = Slab<(TopicId, oneshot::Sender<(SendStream, RecvStream)>)>; diff --git a/net/src/bus/worker.rs b/net/src/bus/worker.rs index 9ee5eadd2..d24450e12 100644 --- a/net/src/bus/worker.rs +++ b/net/src/bus/worker.rs @@ -18,7 +18,7 @@ pub(super) enum Command { send_stream_tx: oneshot::Sender>, recv_stream_tx: oneshot::Sender>, }, - Shutdown { + Close { reply_tx: oneshot::Sender<()>, }, } @@ -32,6 +32,7 @@ pub(super) async fn run(connection: Connection, mut command_rx: mpsc::UnboundedR let command = select! 
{ Some(command) = command_rx.recv() => command, Some(_) = topic_tasks.next() => continue, + _ = dispatcher.closed() => break, }; match command { @@ -47,7 +48,7 @@ pub(super) async fn run(connection: Connection, mut command_rx: mpsc::UnboundedR &dispatcher, )); } - Command::Shutdown { reply_tx } => { + Command::Close { reply_tx } => { dispatcher.close().await; reply_tx.send(()).ok(); break; diff --git a/net/src/quic.rs b/net/src/quic.rs index 57bd24e10..6f16abba7 100644 --- a/net/src/quic.rs +++ b/net/src/quic.rs @@ -124,6 +124,13 @@ impl Connection { pub fn close(&self) { self.inner.close(0u8.into(), &[]); } + + pub fn closed(&self) -> impl Future + 'static { + let inner = self.inner.clone(); + async move { + inner.closed().await; + } + } } pub type SendStream = quinn::SendStream; diff --git a/net/src/sync.rs b/net/src/sync.rs index a2e3d5a4b..8506f3d3f 100644 --- a/net/src/sync.rs +++ b/net/src/sync.rs @@ -1,15 +1,15 @@ +/// Single producer, single consumer, oneshot, rendezvous channel. +/// +/// Unlike `tokio::sync::oneshot`, this one guarantees that the message is never lost even when the +/// receiver is dropped before receiving the message. Because of this, [`Sender::send`] must be +/// `async`. +/// +/// # Cancel safety +/// +/// If `send` is cancelled before completion, the value is still guaranteed to be received by the +/// receiver. If `recv` is cancelled before completion, the value is returned back from `send`. If +/// both `send` and `recv` are cancelled, the value is lost. pub(crate) mod rendezvous { - //! Single producer, single consumer, oneshot, rendezvous channel. - //! - //! Unlike `tokio::sync::oneshot`, this one guarantees that the message is never lost even when - //! the receiver is dropped before receiving the message. Because of this,[`Sender::send`] must - //! be `async`. - //! - //! # Cancel safety - //! - //! If `send` is cancelled before completion, the value is still guaranteed to be received by - //! the receiver. If `recv` is cancelled before completion, the value is returned back from - //! `send`. If both `send` and `recv` are cancelled, the value is lost. use std::{ fmt, diff --git a/net/src/tcp.rs b/net/src/tcp.rs index bd2dda461..08be110e1 100644 --- a/net/src/tcp.rs +++ b/net/src/tcp.rs @@ -1,6 +1,10 @@ use self::implementation::{TcpListener, TcpStream}; use crate::{sync::rendezvous, SocketOptions}; -use std::{future, io, net::SocketAddr}; +use std::{ + future::{self, Future}, + io, + net::SocketAddr, +}; use tokio::{ io::{ReadHalf, WriteHalf}, select, @@ -69,8 +73,8 @@ impl Connection { task::spawn(drive_connection(connection, command_rx).instrument(Span::current())); Self { - command_tx, remote_addr, + command_tx, } } @@ -144,6 +148,12 @@ impl Connection { Ok(Err(error)) => tracing::debug!(?error, "failed to close connection"), } } + + /// Waits for the connection to be closed. + pub fn closed(&self) -> impl Future + 'static { + let command_tx = self.command_tx.clone(); + async move { command_tx.closed().await } + } } pub type SendStream = WriteHalf>; diff --git a/net/src/unified.rs b/net/src/unified.rs index 3f6abe2e4..8f19de3ac 100644 --- a/net/src/unified.rs +++ b/net/src/unified.rs @@ -1,6 +1,7 @@ //! Unified interface over different network protocols (currently TCP and QUIC). 
use crate::{quic, tcp}; +use futures_util::future::Either; use std::{ future::{self, Future, IntoFuture, Ready}, io, @@ -183,6 +184,17 @@ impl Connection { Self::Quic(inner) => inner.close(), } } + + /// Wait for the connection to be closed for any reason (e.g., locally or by the remote peer) + /// + /// Note the returned future has a `'static` lifetime, so it can be moved to another task/thread + /// and awaited there. + pub fn closed(&self) -> impl Future + 'static { + match self { + Self::Tcp(inner) => Either::Left(inner.closed()), + Self::Quic(inner) => Either::Right(inner.closed()), + } + } } pub enum SendStream { @@ -464,29 +476,71 @@ mod tests { .await .unwrap() } + + async fn ping(connection: &Connection) -> anyhow::Result<()> { + let (mut send_stream, mut recv_stream) = connection.outgoing().await?; + + send_stream.write_all(b"ping").await?; + + let mut buffer = [0; 4]; + recv_stream.read_exact(&mut buffer).await?; + assert_eq!(&buffer, b"pong"); + + Ok(()) + } + + async fn pong(connection: &Connection) -> anyhow::Result<()> { + let (mut send_stream, mut recv_stream) = connection.incoming().await?; + + let mut buffer = [0; 4]; + recv_stream.read_exact(&mut buffer).await?; + assert_eq!(&buffer, b"ping"); + + send_stream.write_all(b"pong").await?; + + Ok(()) + } } - async fn ping(connection: &Connection) -> anyhow::Result<()> { - let (mut send_stream, mut recv_stream) = connection.outgoing().await?; + #[tokio::test] + async fn close_tcp() { + close_case(Proto::Tcp).await + } - send_stream.write_all(b"ping").await?; + #[tokio::test] + async fn close_quic() { + close_case(Proto::Quic).await + } - let mut buffer = [0; 4]; - recv_stream.read_exact(&mut buffer).await?; - assert_eq!(&buffer, b"pong"); + async fn close_case(proto: Proto) { + let (client, server) = create_connected_peers(proto); + let (client, server) = create_connected_connections(&client, &server).await; - Ok(()) + future::join(client.closed(), async { + task::yield_now().await; + server.close().await; + }) + .await; } - async fn pong(connection: &Connection) -> anyhow::Result<()> { - let (mut send_stream, mut recv_stream) = connection.incoming().await?; + #[tokio::test] + async fn drop_tcp() { + drop_case(Proto::Tcp).await + } - let mut buffer = [0; 4]; - recv_stream.read_exact(&mut buffer).await?; - assert_eq!(&buffer, b"ping"); + #[tokio::test] + async fn drop_quic() { + drop_case(Proto::Quic).await + } - send_stream.write_all(b"pong").await?; + async fn drop_case(proto: Proto) { + let (client, server) = create_connected_peers(proto); + let (client, server) = create_connected_connections(&client, &server).await; - Ok(()) + future::join(client.closed(), async { + task::yield_now().await; + drop(server); + }) + .await; } } From c88f5e86d26667cb14df121221874e86cd912519 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Cig=C3=A1nek?= Date: Wed, 11 Sep 2024 16:42:40 +0200 Subject: [PATCH 19/55] Refactor MessageDispatcher to use Bus under the hood --- Cargo.toml | 2 +- lib/Cargo.toml | 1 + lib/src/network/barrier.rs | 702 ------------------- lib/src/network/connection.rs | 34 - lib/src/network/crypto.rs | 206 +++--- lib/src/network/message.rs | 124 +--- lib/src/network/message_broker.rs | 126 ++-- lib/src/network/message_dispatcher.rs | 939 +++----------------------- lib/src/network/message_io.rs | 406 ----------- lib/src/network/mod.rs | 18 +- lib/src/network/stats.rs | 12 - 11 files changed, 317 insertions(+), 2253 deletions(-) delete mode 100644 lib/src/network/barrier.rs delete mode 100644 lib/src/network/message_io.rs 
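Among the hunks below, crypto.rs replaces the copy-into-scratch-buffer
encryption with in-place operations on `BytesMut`. The following standalone
sketch shows the in-place AEAD round trip; it uses the `chacha20poly1305`
crate directly purely for illustration (the patch itself goes through
`noise_protocol`'s `CipherState`), and the all-zero key and nonce are
placeholders:

    use chacha20poly1305::{
        aead::{AeadInPlace, KeyInit},
        ChaCha20Poly1305, Key, Nonce,
    };

    fn main() {
        let cipher = ChaCha20Poly1305::new(Key::from_slice(&[0u8; 32]));
        let nonce = Nonce::from_slice(&[0u8; 12]); // never reuse a nonce with a real key

        let mut buffer = b"ping".to_vec();

        // Encryption appends the 16-byte tag to the same buffer, so no
        // scratch buffer (and no copy back) is needed.
        cipher.encrypt_in_place(nonce, b"", &mut buffer).unwrap();
        assert_eq!(buffer.len(), 4 + 16);

        // Decryption verifies and strips the tag, truncating the buffer back
        // to the plaintext length.
        cipher.decrypt_in_place(nonce, b"", &mut buffer).unwrap();
        assert_eq!(buffer, b"ping");
    }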
diff --git a/Cargo.toml b/Cargo.toml index 93f056c5c..ce025eca7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -38,7 +38,7 @@ anyhow = "1.0.86" assert_matches = "1.5" async-trait = "0.1.73" btdht = { git = "https://github.com/equalitie/btdht.git", rev = "e7ddf5607b20f0b82cbc3ea6259425c00bd8d16b" } -bytes = "1.5.0" +bytes = "1.7.1" camino = "1.1.6" chrono = { version = "0.4.31", default-features = false, features = ["clock"] } clap = { version = "4.4.6", features = ["derive"] } diff --git a/lib/Cargo.toml b/lib/Cargo.toml index 9e86f7cca..06810e84d 100644 --- a/lib/Cargo.toml +++ b/lib/Cargo.toml @@ -34,6 +34,7 @@ base64 = "0.13.0" bincode = "1.3" blake3 = { version = "1.5.0", features = ["traits-preview"] } btdht = { workspace = true } +bytes = { workspace = true } camino = { workspace = true } chacha20 = "0.9.1" chrono = { workspace = true } diff --git a/lib/src/network/barrier.rs b/lib/src/network/barrier.rs deleted file mode 100644 index 648f14de3..000000000 --- a/lib/src/network/barrier.rs +++ /dev/null @@ -1,702 +0,0 @@ -use super::message_dispatcher::{ - ChannelClosed, ContentSinkTrait, ContentStreamError, ContentStreamTrait, -}; -use state_monitor::{MonitoredValue, StateMonitor}; -use std::{fmt, mem::size_of}; -use tokio::time::{self, Duration}; - -type BarrierId = u64; -type Round = u32; -#[derive(Eq, PartialEq, Copy, Clone)] -enum Step { - Zero, - One, -} -type Msg = (BarrierId, Round, Step); - -/// Ensures there are no more in-flight messages beween us and the peer. -/// -/// There are two aspects of this, first one is that we need to ignore all peer's message up until -/// they indicate to us that they are starting a new round of communication (this is necessary so -/// we can restart the crypto with a common state). -/// -/// The second aspect is that we also need to cover the edge case when one peer restart its link -/// without the other one noticing. Then it could happen that the first peer sends a barrier -/// message to the second peer and that gets lost, the second peer (after the repo reload) will -/// then send a barrier message to the first while the first one will think it's the response to -/// it's first message. The second peer will then not receive its response. -/// -/// The construction of the algorithm went as follows: the two peers need to ensure that the entire -/// barrier agreement happens within a single instance of Barrier on one side and a single instance -/// of a Barrier on the other side. To ensure this, both peers choose a random `barrier_id` which -/// they send to each other. This is then echoed from the other side and upon reception of the echo -/// each peer is able to check that the other side knows it's `barrier_id`. Let's call this process -/// "sync on barrier ID". -/// -/// To be able to do the above, we needed to perform two steps: -/// -/// Step #1: Send our barrier ID to the peer, and -/// Step #2: Receive our barrier ID from the peer. -/// -/// As such, it may happen that one peer is currently performing the step #1 while the other peer -/// is performing the step #2. Thus we need to ensure that they're both performing the two steps in -/// sync. Let's call this process "sync on step". -/// -/// Finally, because each "step" consists of sending and receiving (exchanging) a message, we must -/// ensure that the exchange does not happen across steps. Or in other words: it must not be the -/// case that a peer sends a message in one step, but receives a message from the other peer's -/// previous step. 
Let's call this process "sync on exchange". -/// -/// TODO: This is one of those algorithms where a formal correctness proof would be welcome. -pub(super) struct Barrier<'a> { - // Barrier ID is used to ensure that the other peer is communicating with this instance of - // Barrier by sending us the ID back. - barrier_id: BarrierId, - stream: &'a mut (dyn ContentStreamTrait + Send + Sync + 'a), - sink: &'a (dyn ContentSinkTrait + Send + Sync + 'a), - state: MonitoredValue<&'static str>, - #[cfg(test)] - marker: Option, -} - -impl<'a> Barrier<'a> { - pub fn new(stream: &'a mut Stream, sink: &'a Sink, monitor: &StateMonitor) -> Self - where - Stream: ContentStreamTrait + Send + Sync, - Sink: ContentSinkTrait + Send + Sync, - { - Self { - barrier_id: rand::random(), - stream, - sink, - state: monitor.make_value("barrier", "idle"), - #[cfg(test)] - marker: None, - } - } - - pub async fn run(&mut self) -> Result<(), BarrierError> { - use std::cmp::max; - - #[cfg(test)] - self.mark_step().await; - - #[cfg(test)] - println!("{:x} >> RST", self.barrier_id); - - *self.state.get() = "sending reset"; - - // I think we send this empty message in order to break the encryption on the other side and - // thus forcing it to start this barrier process again. - self.sink.send(vec![]).await?; - - let mut next_round: u32 = 0; - - loop { - let mut round = next_round; - - if round > 64 { - tracing::error!("Barrier algorithm failed"); - return Err(BarrierError::Failure); - } - - let (their_barrier_id, their_round, their_step) = - self.exchange0(self.barrier_id, round).await?; - - if their_step != Step::Zero { - next_round = max(round, their_round) + 1; - continue; - } - - // Just for info, this is ensured inside `exchange0`. - assert!(round <= their_round); - - if round < their_round { - *self.state.get() = "catching up on step 0"; - // They are ahead of us, but on the same step. So play along, bump our round to - // theirs and pretend we did the step zero with the same round. - round = their_round; - self.send(self.barrier_id, round, Step::Zero).await?; - } - - let (our_barrier_id, their_round, their_step) = - match self.exchange1(their_barrier_id, round).await? { - Some(msg) => msg, - None => { - next_round += 1; - continue; - } - }; - - if their_step != Step::One { - next_round = max(round, their_round) + 1; - continue; - } - - if our_barrier_id != self.barrier_id { - // Peer was communicating with our previous barrier, ignoring that. - next_round = max(round, their_round) + 1; - continue; - } - - // Ensure we end at the same time. - if round != their_round { - next_round = max(round, their_round) + 1; - continue; - } - - break; - } - - *self.state.get() = "done"; - - Ok(()) - } - - async fn exchange0( - &mut self, - barrier_id: BarrierId, - our_round: Round, - ) -> Result { - let our_step = Step::Zero; - - loop { - *self.state.get() = "step 0 sending"; - self.send(barrier_id, our_round, our_step).await?; - - loop { - *self.state.get() = "step 0 receiving"; - let (barrier_id, their_round, their_step) = match self - .recv( - #[cfg(test)] - our_round, - #[cfg(test)] - our_step, - ) - .await? - { - Some(msg) => msg, - None => continue, - }; - - if their_round < our_round { - // The peer is behind, so we resend our previous message. If they receive it - // more than once, they should ignore the duplicates. Note that we do need to - // resend it as opposed to just ignore and start receiving again because they - // could have dropped the previous message we sent them (e.g. 
because they did - // not have the repo that the two are about to sync). - break; - } - - return Ok((barrier_id, their_round, their_step)); - } - } - } - - async fn exchange1( - &mut self, - barrier_id: BarrierId, - our_round: Round, - ) -> Result, BarrierError> { - *self.state.get() = "step 1 sending"; - let our_step = Step::One; - self.send(barrier_id, our_round, our_step).await?; - - loop { - *self.state.get() = "step 1 receiving"; - let recv = self.recv( - #[cfg(test)] - our_round, - #[cfg(test)] - our_step, - ); - - // Timing out shouldn't be necessary, but it may still be useful if the peer is buggy. - let result = match time::timeout(Duration::from_secs(5), recv).await { - Ok(result) => result, - Err(_) => { - // timeout - return Ok(None); - } - }; - - match result? { - Some((barrier, round, step)) => { - if step == Step::Zero && round == our_round { - // They resent the same message from previous step, ignore it. - continue; - } - return Ok(Some((barrier, round, step))); - } - None => return Ok(None), - } - } - } - - async fn recv( - &mut self, - #[cfg(test)] round: Round, - #[cfg(test)] our_step: Step, - ) -> Result, BarrierError> { - #[cfg(test)] - self.mark_step().await; - - let msg = self.stream.recv().await?; - - match parse_message(&msg) { - Some((barrier_id, their_round, their_step)) => { - match their_step { - Step::Zero => { - #[cfg(test)] - println!( - "{:x} R{} S{:?} << their_barrier_id:{:x} their_round:{} their_step:{:?}", - self.barrier_id, round, our_step, barrier_id, their_round, their_step - ) - } - Step::One => { - #[cfg(test)] - println!( - "{:x} R{} S{:?} << our_barrier_id:{:x} their_round:{} their_step:{:?}", - self.barrier_id, round, our_step, barrier_id, their_round, their_step - ) - } - } - Ok(Some((barrier_id, their_round, their_step))) - } - // Ignore messages that belonged to whatever communication was going on prior us - // starting this barrier process. 
- None => Ok(None), - } - } - - async fn send( - &mut self, - barrier_id: BarrierId, - our_round: Round, - our_step: Step, - ) -> Result<(), ChannelClosed> { - #[cfg(test)] - self.mark_step().await; - - #[cfg(test)] - match our_step { - Step::Zero => { - assert_eq!(self.barrier_id, barrier_id); - println!( - "{:x} R{} S0 >> self.barrier_id:{:x}", - self.barrier_id, our_round, barrier_id - ); - } - Step::One => { - assert_ne!(self.barrier_id, barrier_id); - println!( - "{:x} R{} S1 >> their_barrier_id:{:x}", - self.barrier_id, our_round, barrier_id - ); - } - } - self.sink - .send(construct_message(barrier_id, our_round, our_step).to_vec()) - .await - } - - #[cfg(test)] - async fn mark_step(&mut self) { - if let Some(marker) = &mut self.marker { - marker.mark_step().await - } - } -} - -const MSG_STEP_SIZE: usize = size_of::(); -const MSG_ID_SIZE: usize = size_of::(); -const MSG_ROUND_SIZE: usize = size_of::(); -const MSG_PREFIX: &[u8; 13] = b"barrier-start"; -const MSG_SUFFIX: &[u8; 11] = b"barrier-end"; -const MSG_PREFIX_SIZE: usize = MSG_PREFIX.len(); -const MSG_SUFFIX_SIZE: usize = MSG_SUFFIX.len(); -const MSG_SIZE: usize = - MSG_PREFIX_SIZE + MSG_ID_SIZE + MSG_ROUND_SIZE + MSG_STEP_SIZE + MSG_SUFFIX_SIZE; - -type MsgData = [u8; MSG_SIZE]; - -fn construct_message(barrier_id: BarrierId, round: Round, step: Step) -> MsgData { - let mut msg = [0u8; size_of::()]; - let s = &mut msg[..]; - - s[..MSG_PREFIX_SIZE].clone_from_slice(MSG_PREFIX); - let s = &mut s[MSG_PREFIX_SIZE..]; - - s[..MSG_ID_SIZE].clone_from_slice(&barrier_id.to_le_bytes()); - let s = &mut s[MSG_ID_SIZE..]; - - s[..MSG_ROUND_SIZE].clone_from_slice(&round.to_le_bytes()); - let s = &mut s[MSG_ROUND_SIZE..]; - - match step { - Step::Zero => s[..MSG_STEP_SIZE].clone_from_slice(&0u8.to_le_bytes()), - Step::One => s[..MSG_STEP_SIZE].clone_from_slice(&1u8.to_le_bytes()), - } - let s = &mut s[MSG_STEP_SIZE..]; - - s[..MSG_SUFFIX_SIZE].clone_from_slice(MSG_SUFFIX); - - msg -} - -fn parse_message(data: &[u8]) -> Option { - if data.len() != MSG_SIZE { - return None; - } - - let (prefix, rest) = data.split_at(MSG_PREFIX_SIZE); - - if prefix != MSG_PREFIX { - return None; - } - - let (id_data, rest) = rest.split_at(MSG_ID_SIZE); - let (round_data, rest) = rest.split_at(MSG_ROUND_SIZE); - let (step_data, suffix) = rest.split_at(MSG_STEP_SIZE); - - if suffix != MSG_SUFFIX { - return None; - } - - let step_num = u8::from_le_bytes(step_data.try_into().unwrap()); - - let step = match step_num { - 0 => Step::Zero, - 1 => Step::One, - _ => return None, - }; - - // Unwraps OK because we know the sizes at compile time. 
- Some(( - BarrierId::from_le_bytes(id_data.try_into().unwrap()), - Round::from_le_bytes(round_data.try_into().unwrap()), - step, - )) -} - -#[derive(Debug, thiserror::Error)] -pub enum BarrierError { - #[error("Barrier algorithm failed")] - Failure, - #[error("Channel closed")] - ChannelClosed, - #[error("Network transport changed")] - TransportChanged, -} - -impl From for BarrierError { - fn from(_: ChannelClosed) -> Self { - Self::ChannelClosed - } -} - -impl From for BarrierError { - fn from(error: ContentStreamError) -> Self { - match error { - ContentStreamError::ChannelClosed => Self::ChannelClosed, - ContentStreamError::TransportChanged => Self::TransportChanged, - } - } -} - -impl std::fmt::Debug for Step { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Self::Zero => write!(f, "0"), - Self::One => write!(f, "1"), - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - use async_trait::async_trait; - use scoped_task::ScopedJoinHandle; - use std::sync::Arc; - use tokio::{ - sync::{mpsc, Mutex}, - task, - time::{timeout, Duration}, - }; - - struct Stepper { - first: bool, - pause_rx: mpsc::Receiver<()>, - resume_tx: mpsc::Sender<()>, - barrier_task: Option>>, - } - - impl Stepper { - fn new(barrier_id: BarrierId, sink: Sink, mut stream: Stream) -> Stepper { - let (pause_tx, pause_rx) = mpsc::channel(1); - let (resume_tx, resume_rx) = mpsc::channel(1); - - let barrier_task = scoped_task::spawn(async move { - Barrier { - barrier_id, - stream: &mut stream, - sink: &sink, - state: dummy_value(), - marker: Some(StepMarker { - pause_tx, - resume_rx, - }), - } - .run() - .await - }); - - Self { - first: true, - pause_rx, - resume_tx, - barrier_task: Some(barrier_task), - } - } - - // When the `barrier_task` finishes, this returns `Some(result of the task)`, otherwise it - // returns None. 
- async fn step(&mut self) -> Option> { - if !self.first { - self.resume_tx.send(()).await.unwrap(); - } - self.first = false; - - if self.pause_rx.recv().await.is_some() { - None - } else { - let barrier_task = self.barrier_task.take(); - Some(barrier_task.unwrap().await.unwrap()) - } - } - - async fn run_to_completion(&mut self) -> Result<(), BarrierError> { - loop { - if let Some(result) = self.step().await { - break result; - } - } - } - } - - pub(super) struct StepMarker { - pause_tx: mpsc::Sender<()>, - resume_rx: mpsc::Receiver<()>, - } - - impl StepMarker { - pub(super) async fn mark_step(&mut self) { - self.pause_tx.send(()).await.unwrap(); - self.resume_rx.recv().await.unwrap(); - } - } - - // --- Sink --------------------------------------------------------------- - #[derive(Clone)] - struct Sink { - drop_count: Arc>, - tx: mpsc::Sender>, - } - - #[async_trait] - impl ContentSinkTrait for Sink { - async fn send(&self, message: Vec) -> Result<(), ChannelClosed> { - { - let mut drop_count = self.drop_count.lock().await; - if *drop_count > 0 { - *drop_count -= 1; - return Ok(()); - } - } - self.tx.send(message).await.map_err(|_| ChannelClosed) - } - } - - // --- Stream -------------------------------------------------------------- - #[derive(Clone)] - struct Stream { - rx: Arc>>>, - } - - #[async_trait] - impl ContentStreamTrait for Stream { - async fn recv(&mut self) -> Result, ContentStreamError> { - let mut guard = self.rx.lock().await; - let vec = guard.recv().await.unwrap(); - Ok(vec) - } - } - - // ------------------------------------------------------------------------- - fn new_test_channel(drop_count: u32) -> (Sink, Stream) { - // Exchanging messages would normally require only a mpsc channel of size one, but at the - // beginning of the Barrier algorithm we also send one "reset" message which increases the - // channel size requirement by one. - let (tx, rx) = mpsc::channel(2); - ( - Sink { - drop_count: Arc::new(Mutex::new(drop_count)), - tx, - }, - Stream { - rx: Arc::new(Mutex::new(rx)), - }, - ) - } - // ------------------------------------------------------------------------- - - #[derive(Debug)] - enum Task1Result { - CFinished, - AFinished(Result<(), BarrierError>), - } - - // When this returns true, it's no longer needed to test with higher `n`. - async fn test_restart_case(n: u32) -> bool { - println!( - ">>>>>>>>>>>>>>>>>>> TEST RESTART AFTER n:{} <<<<<<<<<<<<<<<<<<<<<<", - n - ); - - let (ac_to_b, b_from_ac) = new_test_channel(0 /* don't drop anything */); - let (b_to_ac, ac_from_b) = new_test_channel(0 /* don't drop anything */); - - let task_1 = task::spawn(async move { - let mut stepper_c = Stepper::new(0xc, ac_to_b.clone(), ac_from_b.clone()); - - for _ in 0..n { - if let Some(result) = stepper_c.step().await { - assert!(result.is_ok()); - return Task1Result::CFinished; - } - } - - drop(stepper_c); - - let mut stepper_a = Stepper::new(0xa, ac_to_b, ac_from_b); - Task1Result::AFinished(stepper_a.run_to_completion().await) - }); - - let task_2 = task::spawn(async move { - let mut stepper = Stepper::new(0xb, b_to_ac, b_from_ac); - stepper.run_to_completion().await.unwrap() - }); - - let task_c = task::spawn(async move { - let r1 = task_1.await.unwrap(); - task_2.await.unwrap(); - - match r1 { - Task1Result::CFinished => (), - Task1Result::AFinished(Ok(_)) => (), - // This is a pathological case where 0xb finished while communicating with 0xc, but - // 0xc has been interrupted right before it could finish. Then 0xa starts but 0xb - // already moved on. 
I believe due to the CAP theorem there's nothing that can be - // done in this case apart from 0xa restarting the process. - Task1Result::AFinished(Err(BarrierError::ChannelClosed)) => (), - result => panic!("Invalid result from task '0xa' {:?}", result), - } - - matches!(r1, Task1Result::CFinished) - }); - - match timeout(Duration::from_secs(5), task_c).await { - Err(_) => panic!("Test case n:{} timed out", n), - Ok(Err(err)) => panic!("Test case n:{} failed with {:?}", n, err), - Ok(Ok(is_done)) => is_done, - } - } - - #[tokio::test] - async fn test_restarts() { - let mut n = 0; - loop { - if test_restart_case(n).await { - break; - } - n += 1; - } - } - - async fn test_drop_from_start_case(a: u32, b: u32) { - println!( - ">>>>>>>>>>>>>>>>>>> TEST DROP a:{} b:{} <<<<<<<<<<<<<<<<<<<<<<", - a, b - ); - - let (a_to_b, mut b_from_a) = new_test_channel(a); - let (b_to_a, mut a_from_b) = new_test_channel(b); - - let task_a = task::spawn(async move { - Barrier { - barrier_id: 0xa, - stream: &mut a_from_b, - sink: &a_to_b, - state: dummy_value(), - marker: None, - } - .run() - .await - .unwrap() - }); - - let task_b = task::spawn(async move { - Barrier { - barrier_id: 0xb, - stream: &mut b_from_a, - sink: &b_to_a, - state: dummy_value(), - marker: None, - } - .run() - .await - .unwrap() - }); - - match timeout(Duration::from_secs(5), task_a).await { - Err(_) => panic!( - "Test case drop_from_start (task_a, a:{}, b:{}) timed out", - a, b - ), - Ok(Err(err)) => panic!( - "Test case drop_from_start (task_a, a:{}, b:{}) failed with {:?}", - a, b, err - ), - Ok(Ok(_)) => (), - } - - match timeout(Duration::from_secs(5), task_b).await { - Err(_) => panic!( - "Test case drop_from_start (task_b, a:{}, b:{}) timed out", - a, b - ), - Ok(Err(err)) => panic!( - "Test case drop_from_start (task_b, a:{}, b:{}) failed with {:?}", - a, b, err - ), - Ok(Ok(_)) => (), - } - } - - #[tokio::test] - async fn test_drop_from_start() { - // Drop first `a` packets from the barrier A and first `b` packets from the barrier B. - // After sending first two packets, both nodes start reading, so max 2 dropped packets - // make sense to consider. - // Note that the case (2,2) should not be possible because that would mean that they both - // sent a message while not receiving one. - for (a, b) in [(0, 0), (1, 0), (1, 1), (2, 0), (2, 1)] { - test_drop_from_start_case(a, b).await; - } - } - - fn dummy_value() -> MonitoredValue<&'static str> { - StateMonitor::make_root().make_value("dummy", "") - } -} diff --git a/lib/src/network/connection.rs b/lib/src/network/connection.rs index fe02604e4..a71ea0c03 100644 --- a/lib/src/network/connection.rs +++ b/lib/src/network/connection.rs @@ -182,22 +182,6 @@ pub(super) struct ConnectionPermit { } impl ConnectionPermit { - /// Split the permit into two halves where dropping any of them releases the whole permit. - /// This is useful when the connection needs to be split into a reader and a writer Then if any - /// of them closes, the whole connection closes. So both the reader and the writer should be - /// associated with one half of the permit so that when any of them closes, the permit is - /// released. 
- pub fn into_split(self) -> (ConnectionPermitHalf, ConnectionPermitHalf) { - ( - ConnectionPermitHalf(Self { - connections: self.connections.clone(), - key: self.key, - id: self.id, - }), - ConnectionPermitHalf(self), - ) - } - pub fn mark_as_connecting(&self) { self.set_state(PeerState::Connecting); } @@ -316,24 +300,6 @@ impl Drop for ConnectionPermit { } } -/// Half of a connection permit. Dropping it drops the whole permit. -/// See [`ConnectionPermit::split`] for more details. -pub(super) struct ConnectionPermitHalf(ConnectionPermit); - -impl ConnectionPermitHalf { - pub fn id(&self) -> ConnectionId { - self.0.id - } - - pub fn byte_counters(&self) -> Arc { - self.0.byte_counters() - } - - pub fn released(&self) -> AwaitDrop { - self.0.released() - } -} - #[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)] struct Key { addr: PeerAddr, diff --git a/lib/src/network/crypto.rs b/lib/src/network/crypto.rs index 41620129c..b6fcea636 100644 --- a/lib/src/network/crypto.rs +++ b/lib/src/network/crypto.rs @@ -8,14 +8,19 @@ //! based on the identity of the replicas is needed. use super::{ - message_dispatcher::{ChannelClosed, ContentSink, ContentStream, ContentStreamError}, + message_dispatcher::{ContentSink, ContentStream}, runtime_id::PublicRuntimeId, - stats::Instrumented, }; use crate::protocol::RepositoryId; +use bytes::{Bytes, BytesMut}; +use futures_util::{Sink, SinkExt, Stream, StreamExt, TryStreamExt}; use noise_protocol::Cipher as _; use noise_rust_crypto::{Blake2s, ChaCha20Poly1305, X25519}; -use std::mem; +use std::{ + io, + pin::Pin, + task::{ready, Context, Poll}, +}; use thiserror::Error; type Cipher = ChaCha20Poly1305; @@ -62,54 +67,68 @@ const MAX_NONCE: u64 = u64::MAX - 1; /// Wrapper for [`ContentStream`] that decrypts incoming messages. pub(super) struct DecryptingStream<'a> { - inner: &'a mut Instrumented, + inner: &'a mut ContentStream, cipher: CipherState, - buffer: Vec, } -impl DecryptingStream<'_> { - pub async fn recv(&mut self) -> Result, RecvError> { +impl Stream for DecryptingStream<'_> { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { if self.cipher.get_next_n() >= MAX_NONCE { - return Err(RecvError::Exhausted); + return Poll::Ready(Some(Err(RecvError::Exhausted))); } - let mut content = self.inner.recv().await?; + let mut item = match ready!(self.inner.poll_next_unpin(cx)) { + Some(Ok(item)) => item, + Some(Err(error)) => return Poll::Ready(Some(Err(error.into()))), + None => return Poll::Ready(None), + }; - let plain_len = content - .len() - .checked_sub(Cipher::tag_len()) - .ok_or(RecvError::Crypto)?; - self.buffer.resize(plain_len, 0); - self.cipher - .decrypt_ad(self.inner.channel().as_ref(), &content, &mut self.buffer) - .map_err(|_| RecvError::Crypto)?; + let ciphertext_len = item.len(); - mem::swap(&mut content, &mut self.buffer); - - Ok(content) + match self.cipher.decrypt_in_place(&mut item, ciphertext_len) { + Ok(n) => Poll::Ready(Some(Ok(item.split_to(n)))), + Err(_) => Poll::Ready(Some(Err(RecvError::Crypto))), + } } } /// Wrapper for [`ContentSink`] that encrypts outgoing messages. 
pub(super) struct EncryptingSink<'a> { - inner: &'a mut Instrumented, + inner: &'a mut ContentSink, cipher: CipherState, - buffer: Vec, } -impl EncryptingSink<'_> { - pub async fn send(&mut self, mut content: Vec) -> Result<(), SendError> { +impl Sink for EncryptingSink<'_> { + type Error = SendError; + + fn start_send(mut self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> { if self.cipher.get_next_n() >= MAX_NONCE { return Err(SendError::Exhausted); } - self.buffer.resize(content.len() + Cipher::tag_len(), 0); - self.cipher - .encrypt_ad(self.inner.channel().as_ref(), &content, &mut self.buffer); + let plaintext_len = item.len(); + let mut item = BytesMut::from(item); + + item.resize(plaintext_len + Cipher::tag_len(), 0); + let n = self.cipher.encrypt_in_place(&mut item, plaintext_len); + + self.inner.start_send_unpin(item.split_to(n).freeze())?; + + Ok(()) + } + + fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready_unpin(cx).map_err(Into::into) + } - mem::swap(&mut content, &mut self.buffer); + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_flush_unpin(cx).map_err(Into::into) + } - Ok(self.inner.send(content).await?) + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_close_unpin(cx).map_err(Into::into) } } @@ -118,8 +137,8 @@ impl EncryptingSink<'_> { pub(super) async fn establish_channel<'a>( role: Role, repo_id: &RepositoryId, - stream: &'a mut Instrumented, - sink: &'a mut Instrumented, + stream: &'a mut ContentStream, + sink: &'a mut ContentSink, ) -> Result<(DecryptingStream<'a>, EncryptingSink<'a>), EstablishError> { let mut handshake_state = build_handshake_state(role, repo_id); @@ -146,13 +165,11 @@ pub(super) async fn establish_channel<'a>( let stream = DecryptingStream { inner: stream, cipher: recv_cipher, - buffer: vec![], }; let sink = EncryptingSink { inner: sink, cipher: send_cipher, - buffer: vec![], }; Ok((stream, sink)) @@ -160,47 +177,28 @@ pub(super) async fn establish_channel<'a>( #[derive(Debug, Error)] pub(super) enum SendError { - #[error("channel closed")] - Closed, #[error("nonce counter exhausted")] Exhausted, -} - -impl From for SendError { - fn from(_: ChannelClosed) -> Self { - Self::Closed - } + #[error("IO error")] + Io(#[from] io::Error), } #[derive(Debug, Error)] pub(super) enum RecvError { #[error("decryption failed")] Crypto, - #[error("channel closed")] - Closed, #[error("nonce counter exhausted")] Exhausted, - #[error("network transport changed")] - TransportChanged, -} - -impl From for RecvError { - fn from(error: ContentStreamError) -> Self { - match error { - ContentStreamError::ChannelClosed => Self::Closed, - ContentStreamError::TransportChanged => Self::TransportChanged, - } - } + #[error("IO error")] + Io(#[from] io::Error), } #[derive(Debug, Error)] pub(super) enum EstablishError { #[error("encryption / decryption failed")] Crypto, - #[error("channel closed")] - Closed, - #[error("network transport changed")] - TransportChanged, + #[error("IO error")] + Io(#[from] io::Error), } impl From for EstablishError { @@ -209,21 +207,6 @@ impl From for EstablishError { } } -impl From for EstablishError { - fn from(_: ChannelClosed) -> Self { - Self::Closed - } -} - -impl From for EstablishError { - fn from(error: ContentStreamError) -> Self { - match error { - ContentStreamError::ChannelClosed => Self::Closed, - ContentStreamError::TransportChanged => Self::TransportChanged, - } - } -} - fn 
build_handshake_state(role: Role, repo_id: &RepositoryId) -> HandshakeState { use noise_protocol::patterns; @@ -242,17 +225,86 @@ fn build_handshake_state(role: Role, repo_id: &RepositoryId) -> HandshakeState { async fn handshake_send( state: &mut HandshakeState, - sink: &mut Instrumented, + sink: &mut ContentSink, msg: &[u8], ) -> Result<(), EstablishError> { let content = state.write_message_vec(msg)?; - Ok(sink.send(content).await?) + sink.send(content.into()).await?; + Ok(()) } async fn handshake_recv( state: &mut HandshakeState, - stream: &mut Instrumented, + stream: &mut ContentStream, ) -> Result, EstablishError> { - let content = stream.recv().await?; + let content = stream + .try_next() + .await? + .ok_or_else(|| io::Error::from(io::ErrorKind::BrokenPipe))?; + Ok(state.read_message_vec(&content)?) } + +#[cfg(test)] +mod tests { + use super::*; + use crate::network::{ + message::MessageChannelId, + message_dispatcher::{create_connection_pair, MessageDispatcher}, + runtime_id::SecretRuntimeId, + stats::ByteCounters, + }; + use futures_util::future; + use std::sync::Arc; + + #[tokio::test] + async fn sanity_check() { + let (client, server) = create_connection_pair().await; + + let client = MessageDispatcher::builder(client).build(); + let server = MessageDispatcher::builder(server).build(); + + let repo_id = RepositoryId::random(); + + let client_id = SecretRuntimeId::random().public(); + let server_id = SecretRuntimeId::random().public(); + + let channel_id = MessageChannelId::new(&repo_id, &client_id, &server_id); + + let (mut client_sink, mut client_stream) = + client.open(channel_id, Arc::new(ByteCounters::new())); + + let (mut server_sink, mut server_stream) = + server.open(channel_id, Arc::new(ByteCounters::new())); + + let ((mut client_stream, mut client_sink), (mut server_stream, mut server_sink)) = + future::try_join( + establish_channel( + Role::Initiator, + &repo_id, + &mut client_stream, + &mut client_sink, + ), + establish_channel( + Role::Responder, + &repo_id, + &mut server_stream, + &mut server_sink, + ), + ) + .await + .unwrap(); + + client_sink.send(Bytes::from_static(b"ping")).await.unwrap(); + assert_eq!( + server_stream.try_next().await.unwrap().unwrap().as_ref(), + b"ping" + ); + + server_sink.send(Bytes::from_static(b"pong")).await.unwrap(); + assert_eq!( + client_stream.try_next().await.unwrap().unwrap().as_ref(), + b"pong" + ); + } +} diff --git a/lib/src/network/message.rs b/lib/src/network/message.rs index f980d5227..bc7996f43 100644 --- a/lib/src/network/message.rs +++ b/lib/src/network/message.rs @@ -1,5 +1,4 @@ use super::{ - crypto::Role, debug_payload::{DebugRequest, DebugResponse}, peer_exchange::PexPayload, runtime_id::PublicRuntimeId, @@ -11,8 +10,8 @@ use crate::{ UntrustedProof, }, }; +use net::bus::TopicId; use serde::{Deserialize, Serialize}; -use std::{fmt, io::Write}; #[derive(Clone, PartialEq, Serialize, Deserialize, Debug)] pub(crate) enum Request { @@ -60,98 +59,6 @@ pub(crate) enum Response { BlockError(BlockId, DebugResponse), } -const LEGACY_TAG: u8 = 2; - -#[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] -pub(crate) struct Header { - pub channel: MessageChannelId, -} - -impl Header { - pub(crate) const SIZE: usize = 1 + // One byte for the tag. 
- Hash::SIZE; // Channel - - pub(crate) fn serialize(&self) -> [u8; Self::SIZE] { - let mut hdr = [0; Self::SIZE]; - let mut w = ArrayWriter { array: &mut hdr }; - - w.write_u8(LEGACY_TAG); - w.write_channel(&self.channel); - - hdr - } - - pub(crate) fn deserialize(hdr: &[u8; Self::SIZE]) -> Option
{ - let mut r = ArrayReader { array: &hdr[..] }; - // Tag is no longer used but we still read it for backwards compatibility. - let _ = r.read_u8(); - let channel = r.read_channel(); - - Some(Header { channel }) - } -} - -#[derive(Clone, Eq, PartialEq, Serialize, Deserialize)] -pub(crate) struct Message { - pub channel: MessageChannelId, - pub content: Vec, -} - -impl Message { - pub fn header(&self) -> Header { - Header { - channel: self.channel, - } - } -} - -impl fmt::Debug for Message { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!( - f, - "Message {{ channel: {:?}, content-hash: {:?} }}", - self.channel, - self.content.hash() - ) - } -} - -struct ArrayReader<'a> { - array: &'a [u8], -} - -impl ArrayReader<'_> { - // Unwraps are OK because all sizes are known at compile time. - - fn read_u8(&mut self) -> u8 { - let n = u8::from_le_bytes(self.array[..1].try_into().unwrap()); - self.array = &self.array[1..]; - n - } - - fn read_channel(&mut self) -> MessageChannelId { - let hash: [u8; Hash::SIZE] = self.array[..Hash::SIZE].try_into().unwrap(); - self.array = &self.array[Hash::SIZE..]; - hash.into() - } -} - -struct ArrayWriter<'a> { - array: &'a mut [u8], -} - -impl ArrayWriter<'_> { - // Unwraps are OK because all sizes are known at compile time. - - fn write_u8(&mut self, n: u8) { - self.array.write_all(&n.to_le_bytes()).unwrap(); - } - - fn write_channel(&mut self, channel: &MessageChannelId) { - self.array.write_all(channel.as_ref()).unwrap(); - } -} - #[derive(Serialize, Deserialize, Debug)] pub(crate) enum Content { Request(Request), @@ -193,14 +100,14 @@ define_byte_array_wrapper! { impl MessageChannelId { pub(super) fn new( - repo_id: &'_ RepositoryId, - this_runtime_id: &'_ PublicRuntimeId, - that_runtime_id: &'_ PublicRuntimeId, - role: Role, + repo_id: &RepositoryId, + this_runtime_id: &PublicRuntimeId, + that_runtime_id: &PublicRuntimeId, ) -> Self { - let (id1, id2) = match role { - Role::Initiator => (this_runtime_id, that_runtime_id), - Role::Responder => (that_runtime_id, this_runtime_id), + let (id1, id2) = if this_runtime_id > that_runtime_id { + (this_runtime_id, that_runtime_id) + } else { + (that_runtime_id, this_runtime_id) }; Self( @@ -222,17 +129,8 @@ impl Default for MessageChannelId { } } -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn header_serialization() { - let header = Header { - channel: MessageChannelId::random(), - }; - - let serialized = header.serialize(); - assert_eq!(Header::deserialize(&serialized), Some(header)); +impl From for TopicId { + fn from(id: MessageChannelId) -> Self { + TopicId::from(id.0) } } diff --git a/lib/src/network/message_broker.rs b/lib/src/network/message_broker.rs index ba93427b4..5a9069d04 100644 --- a/lib/src/network/message_broker.rs +++ b/lib/src/network/message_broker.rs @@ -1,14 +1,12 @@ use super::{ - barrier::{Barrier, BarrierError}, client::Client, - connection::ConnectionPermit, crypto::{self, DecryptingStream, EncryptingSink, EstablishError, RecvError, Role, SendError}, message::{Content, MessageChannelId, Request, Response}, message_dispatcher::{ContentSink, ContentStream, MessageDispatcher}, peer_exchange::{PexPeer, PexReceiver, PexRepository, PexSender}, runtime_id::PublicRuntimeId, server::Server, - stats::{ByteCounters, Instrumented}, + stats::ByteCounters, }; use crate::{ collections::{hash_map::Entry, HashMap}, @@ -17,26 +15,23 @@ use crate::{ repository::Vault, }; use backoff::{backoff::Backoff, ExponentialBackoffBuilder}; +use bytes::{BufMut, BytesMut}; +use 
futures_util::{SinkExt, StreamExt}; use net::unified::Connection; use state_monitor::StateMonitor; use std::{future, sync::Arc}; use tokio::{ select, - sync::{mpsc, oneshot, Semaphore}, + sync::{ + mpsc::{self, error::TryRecvError}, + oneshot, Semaphore, + }, task, time::Duration, }; use tracing::{instrument::Instrument, Span}; -/// Maintains one or more connections to a single peer, listening on all of them at the same time. -/// Note that at the present all the connections are UDP/QUIC based and so dropping some of them -/// would make sense. However, in the future we may also have other transports (e.g. TCP, -/// Bluetooth) and thus keeping all may make sence because even if one is dropped, the others may -/// still function. -/// -/// Once a message is received, it is determined whether it is a request or a response. Based on -/// that it either goes to the ClientStream or ServerStream for processing by the Client and Server -/// structures respectively. +/// Handler for communication with one peer. pub(super) struct MessageBroker { this_runtime_id: PublicRuntimeId, that_runtime_id: PublicRuntimeId, @@ -48,18 +43,25 @@ pub(super) struct MessageBroker { } impl MessageBroker { + #[allow(clippy::too_many_arguments)] pub fn new( this_runtime_id: PublicRuntimeId, that_runtime_id: PublicRuntimeId, + connection: Connection, pex_peer: PexPeer, monitor: StateMonitor, + total_counters: Arc, + peer_counters: Arc, ) -> Self { let span = SpanGuard::new(&that_runtime_id); Self { this_runtime_id, that_runtime_id, - dispatcher: MessageDispatcher::new(), + dispatcher: MessageDispatcher::builder(connection) + .with_total_counters(total_counters) + .with_peer_counters(peer_counters) + .build(), links: HashMap::default(), pex_peer, monitor, @@ -67,17 +69,6 @@ impl MessageBroker { } } - pub fn add_connection( - &self, - connection: Connection, - permit: ConnectionPermit, - byte_counters: Arc, - ) { - self.pex_peer - .handle_connection(permit.addr(), permit.source(), permit.released()); - self.dispatcher.bind(connection, permit, byte_counters) - } - /// Try to establish a link between a local repository and a remote repository. The remote /// counterpart needs to call this too with matching repository id for the link to actually be /// created. 
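For illustration: the link can only form because both peers derive the same channel id independently of who initiates; as of this patch the id is a hash of the repository id plus the two runtime ids taken in a canonical order (see the `MessageChannelId::new` hunk above). A small test-style sketch of that property, using the crate-internal types from this series (the surrounding test harness is assumed):

    // Both peers compute the same channel id for a shared repository,
    // regardless of which side does the computing.
    let repo_id = RepositoryId::random();
    let peer_a = SecretRuntimeId::random().public();
    let peer_b = SecretRuntimeId::random().public();

    assert_eq!(
        MessageChannelId::new(&repo_id, &peer_a, &peer_b),
        MessageChannelId::new(&repo_id, &peer_b, &peer_a),
    );
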
@@ -86,7 +77,7 @@ impl MessageBroker { vault: Vault, pex_repo: &PexRepository, response_limiter: Arc, - byte_counters: Arc, + repo_counters: Arc, ) { let monitor = self.monitor.make_child(vault.monitor.name()); let span = tracing::info_span!( @@ -123,14 +114,11 @@ impl MessageBroker { vault.repository_id(), &self.this_runtime_id, &self.that_runtime_id, - role, ); - let (pex_tx, pex_rx) = self.pex_peer.new_link(pex_repo); + let (sink, stream) = self.dispatcher.open(channel_id, repo_counters); - let stream = - Instrumented::new(self.dispatcher.open_recv(channel_id), byte_counters.clone()); - let sink = Instrumented::new(self.dispatcher.open_send(channel_id), byte_counters); + let (pex_tx, pex_rx) = self.pex_peer.new_link(pex_repo); let mut link = Link { role, @@ -190,8 +178,8 @@ impl Drop for SpanGuard { struct Link { role: Role, - stream: Instrumented, - sink: Instrumented, + stream: ContentStream, + sink: ContentSink, vault: Vault, response_limiter: Arc, pex_tx: PexSender, @@ -205,7 +193,6 @@ impl Link { #[derive(Debug)] enum State { Sleeping(#[allow(dead_code)] Duration), - AwaitingBarrier, EstablishingChannel, Running, } @@ -217,7 +204,7 @@ impl Link { .build(); let mut next_sleep = None; - let state = self.monitor.make_value("state", State::AwaitingBarrier); + let state = self.monitor.make_value("state", State::EstablishingChannel); loop { if let Some(sleep) = next_sleep { @@ -227,18 +214,6 @@ impl Link { next_sleep = backoff.next_backoff(); - *state.get() = State::AwaitingBarrier; - - match Barrier::new(self.stream.as_mut(), self.sink.as_ref(), &self.monitor) - .run() - .await - { - Ok(()) => (), - Err(BarrierError::Failure) => continue, - Err(BarrierError::ChannelClosed) => break, - Err(BarrierError::TransportChanged) => continue, - } - *state.get() = State::EstablishingChannel; let (crypto_stream, crypto_sink) = @@ -247,8 +222,7 @@ impl Link { { Ok(io) => io, Err(EstablishError::Crypto) => continue, - Err(EstablishError::Closed) => break, - Err(EstablishError::TransportChanged) => continue, + Err(EstablishError::Io(_)) => break, }; *state.get() = State::Running; @@ -272,8 +246,8 @@ impl Link { async fn establish_channel<'a>( role: Role, - stream: &'a mut Instrumented, - sink: &'a mut Instrumented, + stream: &'a mut ContentStream, + sink: &'a mut ContentSink, vault: &Vault, ) -> Result<(DecryptingStream<'a>, EncryptingSink<'a>), EstablishError> { match crypto::establish_channel(role, vault.repository_id(), stream, sink).await { @@ -328,23 +302,23 @@ async fn recv_messages( pex_rx: &PexReceiver, ) -> ControlFlow { loop { - let content = match stream.recv().await { - Ok(content) => content, - Err(RecvError::Crypto) => { + let content = match stream.next().await { + Some(Ok(content)) => content, + Some(Err(RecvError::Crypto)) => { tracing::warn!("Failed to decrypt incoming message",); return ControlFlow::Continue; } - Err(RecvError::Exhausted) => { + Some(Err(RecvError::Exhausted)) => { tracing::debug!("Incoming message nonce counter exhausted",); return ControlFlow::Continue; } - Err(RecvError::Closed) => { - tracing::debug!("Message stream closed"); + Some(Err(RecvError::Io(error))) => { + tracing::warn!(?error, "Failed to receive incoming message"); return ControlFlow::Break; } - Err(RecvError::TransportChanged) => { - tracing::debug!("Transport has changed"); - return ControlFlow::Continue; + None => { + tracing::debug!("Message channel closed"); + return ControlFlow::Break; } }; @@ -352,7 +326,7 @@ async fn recv_messages( Ok(content) => content, Err(error) => { 
tracing::warn!(?error, "Failed to deserialize incoming message"); - continue; // TODO: should we return `ControlFlow::Continue` here as well? + continue; } }; @@ -369,25 +343,41 @@ async fn send_messages( mut content_rx: mpsc::UnboundedReceiver, mut sink: EncryptingSink<'_>, ) -> ControlFlow { + let mut writer = BytesMut::new().writer(); + loop { - let content = if let Some(content) = content_rx.recv().await { - content - } else { + let content = match content_rx.try_recv() { + Ok(content) => Some(content), + Err(TryRecvError::Empty) => { + match sink.flush().await { + Ok(()) => (), + Err(error) => { + tracing::warn!(?error, "Failed to flush outgoing messages"); + return ControlFlow::Break; + } + } + + content_rx.recv().await + } + Err(TryRecvError::Disconnected) => None, + }; + + let Some(content) = content else { forever().await }; // unwrap is OK because serialization into a vec should never fail unless we have a bug // somewhere. - let content = bincode::serialize(&content).unwrap(); + bincode::serialize_into(&mut writer, &content).unwrap(); - match sink.send(content).await { + match sink.feed(writer.get_mut().split().freeze()).await { Ok(()) => (), Err(SendError::Exhausted) => { tracing::debug!("Outgoing message nonce counter exhausted"); return ControlFlow::Continue; } - Err(SendError::Closed) => { - tracing::debug!("Message sink closed"); + Err(SendError::Io(error)) => { + tracing::warn!(?error, "Failed to send outgoing message"); return ControlFlow::Break; } } diff --git a/lib/src/network/message_dispatcher.rs b/lib/src/network/message_dispatcher.rs index 7d06aa504..51f6045c2 100644 --- a/lib/src/network/message_dispatcher.rs +++ b/lib/src/network/message_dispatcher.rs @@ -1,904 +1,177 @@ //! Utilities for sending and receiving messages across the network. use super::{ - connection::{ConnectionDirection, ConnectionId, ConnectionPermit, ConnectionPermitHalf}, - message::{Message, MessageChannelId}, - message_io::{MessageSink, MessageStream, MESSAGE_OVERHEAD}, + message::MessageChannelId, stats::{ByteCounters, Instrumented}, }; -use crate::{collections::HashMap, sync::AwaitDrop}; -use async_trait::async_trait; -use futures_util::{ - future, ready, - stream::{FuturesUnordered, SelectAll}, - FutureExt, Sink, SinkExt, Stream, StreamExt, +use net::{ + bus::{Bus, BusRecvStream, BusSendStream}, + unified::Connection, }; -use net::unified::{Connection, ConnectionError, RecvStream, SendStream}; -use std::{ - io, - pin::Pin, - sync::Arc, - task::{Context, Poll}, -}; -use tokio::{ - select, - sync::{mpsc, oneshot}, - task, -}; - -const CONTENT_STREAM_BUFFER_SIZE: usize = 1024; +use std::sync::Arc; +use tokio_util::codec::{length_delimited, FramedRead, FramedWrite, LengthDelimitedCodec}; -/// Reads/writes messages from/to the underlying TCP or QUIC streams and dispatches them to +/// Reads/writes messages from/to the underlying TCP or QUIC connection and dispatches them to /// individual streams/sinks based on their channel ids (in the MessageDispatcher's and /// MessageBroker's contexts, there is a one-to-one relationship between the channel id and a /// repository id). 
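In outline, the dispatcher API this rewrite converges on is used like this; a sketch condensed from the `sanity_check` test further down, with `connection` and `channel_id` assumed to be in scope and error handling elided:

    use futures_util::{SinkExt, TryStreamExt};

    // Wrap an established connection, then open one framed sink/stream pair per channel.
    let dispatcher = MessageDispatcher::builder(connection).build();
    let (mut sink, mut stream) = dispatcher.open(channel_id, Arc::new(ByteCounters::new()));

    sink.send(Bytes::from_static(b"hello")).await?;
    let reply = stream.try_next().await?; // `None` once the channel closes
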
-#[derive(Clone)] pub(super) struct MessageDispatcher { - command_tx: mpsc::UnboundedSender, - sink_tx: mpsc::Sender, + bus: net::bus::Bus, + total_counters: Arc, + peer_counters: Arc, } impl MessageDispatcher { - pub fn new() -> Self { - let (command_tx, command_rx) = mpsc::unbounded_channel(); - let (sink_tx, sink_rx) = mpsc::channel(1); - - let worker = Worker::new(command_rx, sink_rx); - task::spawn(worker.run()); - - Self { - command_tx, - sink_tx, + pub fn builder(connection: Connection) -> Builder { + Builder { + connection, + total_counters: None, + peer_counters: None, } } - /// Bind this dispatcher to the given TCP of QUIC socket. Can be bound to multiple sockets and - /// the failed ones are automatically removed. - pub fn bind( + /// Opens a sink and a stream for communication on the given channel. + pub fn open( &self, - connection: Connection, - permit: ConnectionPermit, - byte_counters: Arc, - ) { - self.command_tx - .send(Command::Bind { - connection, - permit, - byte_counters, - }) - .ok(); - } - - /// Opens a stream for receiving messages on the given channel. Any messages received on - /// `channel` before the stream's been opened are discarded. When a stream is opened, all - /// previously opened streams on the same channel (if any) get automatically closed. - pub fn open_recv(&self, channel: MessageChannelId) -> ContentStream { - let (stream_tx, stream_rx) = mpsc::channel(CONTENT_STREAM_BUFFER_SIZE); - - self.command_tx - .send(Command::Open { channel, stream_tx }) - .ok(); - - ContentStream { - channel, - command_tx: self.command_tx.clone(), - stream_rx, - last_transport_id: None, - parked_message: None, - } - } - - /// Opens a sink for sending messages on the given channel. - pub fn open_send(&self, channel: MessageChannelId) -> ContentSink { - ContentSink { - channel, - sink_tx: self.sink_tx.clone(), - } - } - - /// Gracefully shuts down this dispatcher. This closes all bound connections and all open - /// message streams and sinks. - /// - /// Note: the dispatcher also shutdowns automatically when it and all its message streams and - /// sinks have been dropped. Calling this function is still useful when one wants to force the - /// existing streams/sinks to close and/or to wait until the shutdown has been completed. - pub async fn shutdown(self) { - let (tx, rx) = oneshot::channel(); - self.command_tx.send(Command::Shutdown { tx }).ok(); - rx.await.ok(); - } -} - -pub(super) struct ContentStream { - channel: MessageChannelId, - command_tx: mpsc::UnboundedSender, - stream_rx: mpsc::Receiver<(ConnectionId, Vec)>, - last_transport_id: Option, - parked_message: Option>, -} - -impl ContentStream { - /// Receive the next message content. 
- pub async fn recv(&mut self) -> Result, ContentStreamError> { - if let Some(content) = self.parked_message.take() { - return Ok(content); - } - - let (connection_id, content) = self - .stream_rx - .recv() - .await - .ok_or(ContentStreamError::ChannelClosed)?; - - if let Some(last_transport_id) = self.last_transport_id { - if last_transport_id == connection_id { - Ok(content) - } else { - self.last_transport_id = Some(connection_id); - self.parked_message = Some(content); - Err(ContentStreamError::TransportChanged) - } - } else { - self.last_transport_id = Some(connection_id); - Ok(content) - } - } - - pub fn channel(&self) -> &MessageChannelId { - &self.channel - } -} - -impl Instrumented { - pub async fn recv(&mut self) -> Result, ContentStreamError> { - let content = self.as_mut().recv().await?; - self.counters() - .increment_rx(content.len() as u64 + MESSAGE_OVERHEAD as u64); - Ok(content) - } - - pub fn channel(&self) -> &MessageChannelId { - self.as_ref().channel() - } -} - -impl Drop for ContentStream { - fn drop(&mut self) { - self.command_tx - .send(Command::Close { - channel: self.channel, - }) - .ok(); - } -} - -#[derive(Eq, PartialEq, Debug)] -pub(super) enum ContentStreamError { - ChannelClosed, - TransportChanged, -} - -#[derive(Clone)] -pub(super) struct ContentSink { - channel: MessageChannelId, - sink_tx: mpsc::Sender, -} - -impl ContentSink { - pub fn channel(&self) -> &MessageChannelId { - &self.channel - } - - /// Returns whether the send succeeded. - pub async fn send(&self, content: Vec) -> Result<(), ChannelClosed> { - self.sink_tx - .send(Message { - channel: self.channel, - content, - }) - .await - .map_err(|_| ChannelClosed) - } -} - -impl Instrumented { - pub async fn send(&self, content: Vec) -> Result<(), ChannelClosed> { - let len = content.len(); - self.as_ref().send(content).await?; - self.counters() - .increment_tx(len as u64 + MESSAGE_OVERHEAD as u64); - Ok(()) - } + channel_id: MessageChannelId, + repo_counters: Arc, + ) -> (ContentSink, ContentStream) { + let (writer, reader) = self.bus.create_topic(channel_id.into()); - pub fn channel(&self) -> &MessageChannelId { - self.as_ref().channel() - } -} + let writer = Instrumented::new(writer, self.total_counters.clone()); + let writer = Instrumented::new(writer, self.peer_counters.clone()); + let writer = Instrumented::new(writer, repo_counters.clone()); -//------------------------------------------------------------------------ -// These traits are useful for testing. -// TODO: Move these traits and impls to barrier.rs as they are not used anywhere else. + let reader = Instrumented::new(reader, self.total_counters.clone()); + let reader = Instrumented::new(reader, self.peer_counters.clone()); + let reader = Instrumented::new(reader, repo_counters); -#[async_trait] -pub(super) trait ContentSinkTrait { - async fn send(&self, content: Vec) -> Result<(), ChannelClosed>; -} + let codec = make_codec(); -#[async_trait] -pub(super) trait ContentStreamTrait { - async fn recv(&mut self) -> Result, ContentStreamError>; -} + let sink = FramedWrite::new(writer, codec.clone()); + let stream = FramedRead::new(reader, codec); -#[async_trait] -impl ContentSinkTrait for ContentSink { - async fn send(&self, content: Vec) -> Result<(), ChannelClosed> { - self.send(content).await + (sink, stream) } -} -#[async_trait] -impl ContentStreamTrait for ContentStream { - async fn recv(&mut self) -> Result, ContentStreamError> { - self.recv().await + /// Gracefully shuts down this dispatcher. 
This closes the underlying connection and all open + /// message streams and sinks. + /// + /// Note: the dispatcher also shuts down automatically when it's been dropped. Calling this + /// function is still useful when one wants to force the existing streams/sinks to close and/or + /// to wait until the shutdown has been completed. + pub async fn shutdown(self) { + self.bus.close().await; } } -#[derive(Debug)] -pub(super) struct ChannelClosed; - -/////////////////////////////////////////////////////////////////////////////////////////////////// -// Internal - -// Stream for receiving messages from a single connection. Contains a connection permit half which -// gets released on drop. Automatically closes when the corresponding `ConnectionSink` closes. -struct ConnectionStream { - // The reader is doubly instrumented - first time to track per connection stats and second time - // to track cumulative stats across all connections. - reader: MessageStream>>, - permit: ConnectionPermitHalf, - permit_released: AwaitDrop, +pub(super) struct Builder { + connection: Connection, + total_counters: Option>, + peer_counters: Option>, } -impl ConnectionStream { - fn new(reader: Instrumented, permit: ConnectionPermitHalf) -> Self { - let permit_released = permit.released(); - +impl Builder { + pub fn with_total_counters(self, counters: Arc) -> Self { Self { - reader: MessageStream::new(Instrumented::new(reader, permit.byte_counters())), - permit, - permit_released, + total_counters: Some(counters), + ..self } } -} - -impl Stream for ConnectionStream { - type Item = (ConnectionId, Message); - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - // Check if our sink was closed. - match self.permit_released.poll_unpin(cx) { - Poll::Pending => (), - Poll::Ready(()) => { - return Poll::Ready(None); - } - } - - match ready!(self.reader.poll_next_unpin(cx)) { - Some(Ok(message)) => Poll::Ready(Some((self.permit.id(), message))), - Some(Err(_)) | None => Poll::Ready(None), - } - } -} - -// Sink for sending messages on a single connection. Contains a connection permit half which gets -// released on drop. Automatically closes when the corresponding `ConnectionStream` is closed. -struct ConnectionSink { - // The writer is doubly instrumented - first time to track per connection stats and second time - // to track cumulative stats across all connections. - writer: MessageSink>>, - _permit: ConnectionPermitHalf, - permit_released: AwaitDrop, -} - -impl ConnectionSink { - fn new(writer: Instrumented, permit: ConnectionPermitHalf) -> Self { - let permit_released = permit.released(); + pub fn with_peer_counters(self, counters: Arc) -> Self { Self { - writer: MessageSink::new(Instrumented::new(writer, permit.byte_counters())), - _permit: permit, - permit_released, + peer_counters: Some(counters), + ..self } } -} -impl Sink for ConnectionSink { - type Error = io::Error; - - fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - // Check if our stream was closed. 
-        match self.permit_released.poll_unpin(cx) {
-            Poll::Pending => (),
-            Poll::Ready(()) => {
-                return Poll::Ready(Err(io::Error::new(
-                    io::ErrorKind::Other,
-                    "message channel closed",
-                )));
-            }
+    pub fn build(self) -> MessageDispatcher {
+        MessageDispatcher {
+            bus: Bus::new(self.connection),
+            total_counters: self.total_counters.unwrap_or_default(),
+            peer_counters: self.peer_counters.unwrap_or_default(),
         }
-
-        self.writer.poll_ready_unpin(cx)
     }
-
-    fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
-        self.writer.start_send_unpin(item)
-    }
-
-    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
-        self.writer.poll_flush_unpin(cx)
-    }
-
-    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
-        self.writer.poll_close_unpin(cx)
-    }
 }
 
-struct Worker {
-    command_rx: mpsc::UnboundedReceiver<Command>,
-    send: SendState,
-    recv: RecvState,
+fn make_codec() -> LengthDelimitedCodec {
+    length_delimited::Builder::new()
+        .big_endian()
+        .length_field_type::<u64>()
+        .new_codec()
 }
 
-impl Worker {
-    fn new(command_rx: mpsc::UnboundedReceiver<Command>, sink_rx: mpsc::Receiver<Message>) -> Self {
-        Self {
-            command_rx,
-            send: SendState {
-                sink_rx,
-                sinks: Vec::new(),
-            },
-            recv: RecvState {
-                streams: SelectAll::default(),
-                channels: HashMap::default(),
-                message: None,
-            },
-        }
-    }
-
-    async fn run(mut self) {
-        let mut streams = FuturesUnordered::new();
-
-        loop {
-            let command = select! {
-                command = self.command_rx.recv() => command,
-                Some(result) = streams.next() => {
-                    match result {
-                        Ok((connection, tx, rx)) => {
-                            self.send.sinks.push((connection, tx));
-                            self.recv.streams.push(rx);
-                        }
-                        Err(error) => {
-                            tracing::debug!(?error, "Failed to establish a new connection stream");
-                        }
-                    }
-
-                    continue;
-                }
-                _ = self.send.run() => unreachable!(),
-                _ = self.recv.run() => unreachable!(),
-            };
-
-            let Some(command) = command else {
-                break;
-            };
-
-            match command {
-                Command::Open { channel, stream_tx } => {
-                    self.recv.channels.insert(channel, stream_tx);
-                }
-                Command::Close { channel } => {
-                    self.recv.channels.remove(&channel);
-                }
-                Command::Bind {
-                    connection,
-                    permit,
-                    byte_counters,
-                } => {
-                    streams.push(async move {
-                        let (tx, rx) = match ConnectionDirection::from_source(permit.source()) {
-                            ConnectionDirection::Incoming => connection.incoming().await?,
-                            ConnectionDirection::Outgoing => connection.outgoing().await?,
-                        };
-
-                        let (tx_permit, rx_permit) = permit.into_split();
-
-                        let tx = Instrumented::new(tx, byte_counters.clone());
-                        let tx = ConnectionSink::new(tx, tx_permit);
-
-                        let rx = Instrumented::new(rx, byte_counters.clone());
-                        let rx = ConnectionStream::new(rx, rx_permit);
-
-                        Ok::<_, ConnectionError>((connection, tx, rx))
-                    });
-                }
-                Command::Shutdown { tx } => {
-                    self.shutdown().await;
-                    tx.send(()).ok();
-                }
-            }
-        }
-
-        self.shutdown().await;
-    }
+// The streams/sinks are triple-instrumented: once to collect the total cumulative traffic across
+// all peers, once to collect the traffic per peer and once to collect the traffic per repo.
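For reference, the codec built by `make_codec` above is plain length-delimited framing: every message travels as an 8-byte big-endian length prefix followed by the payload bytes. A self-contained sketch of what ends up on the wire, using only the `bytes` and `tokio-util` APIs already depended on here (the asserted values are illustrative):

    use bytes::{Bytes, BytesMut};
    use tokio_util::codec::{length_delimited, Encoder};

    let mut codec = length_delimited::Builder::new()
        .big_endian()
        .length_field_type::<u64>()
        .new_codec();

    let mut wire = BytesMut::new();
    codec.encode(Bytes::from_static(b"ping"), &mut wire).unwrap();

    // 8-byte big-endian length prefix, then the raw payload.
    assert_eq!(wire[..8], 4u64.to_be_bytes());
    assert_eq!(&wire[8..], b"ping");
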
+pub(super) type ContentStream = + FramedRead>>, LengthDelimitedCodec>; - async fn shutdown(&mut self) { - future::join_all(self.send.sinks.drain(..).map(|(connection, _)| async move { - connection.close().await; - })) - .await; - - self.send.sink_rx.close(); - - self.recv.streams.clear(); - self.recv.channels.clear(); - } -} - -enum Command { - Open { - channel: MessageChannelId, - stream_tx: mpsc::Sender<(ConnectionId, Vec)>, - }, - Close { - channel: MessageChannelId, - }, - Bind { - connection: Connection, - permit: ConnectionPermit, - byte_counters: Arc, - }, - Shutdown { - tx: oneshot::Sender<()>, - }, -} - -struct SendState { - sink_rx: mpsc::Receiver, - // We need to keep the `Connection` around so the sink/stream stay opened. We can store it here - // or in the `RecvState` but storing it here is slightly simpler. - sinks: Vec<(Connection, ConnectionSink)>, -} - -impl SendState { - // Keep sending outgoing messages. This function never returns, but it's safe to cancel. - async fn run(&mut self) { - while let Some((_, sink)) = self.sinks.first_mut() { - // The order of operations here is important for cancel-safety: first wait for the sink - // to become ready for sending, then receive the message to be sent and finally send - // the message on the sink. This order ensures that if this function is cancelled at - // any point, the message to be sent is never lost. - match future::poll_fn(|cx| sink.poll_ready_unpin(cx)).await { - Ok(()) => (), - Err(_) => { - self.sinks.swap_remove(0); - continue; - } - } - - let Some(message) = self.sink_rx.recv().await else { - break; - }; - - match sink.start_send_unpin(message) { - Ok(()) => (), - Err(_) => { - self.sinks.swap_remove(0); - continue; - } - } - } - - future::pending().await - } -} - -struct RecvState { - streams: SelectAll, - channels: HashMap)>>, - message: Option<(MessageChannelId, ConnectionId, Vec)>, -} - -impl RecvState { - // Keeps receiving incomming messages and dispatches them to their respective message channels. - // This function never returns but it's safe to cancel. - async fn run(&mut self) { - loop { - let (channel, connection_id, content) = match self.message.take() { - Some(message) => message, - None => match self.streams.next().await { - Some((connection_id, message)) => { - (message.channel, connection_id, message.content) - } - None => break, - }, - }; - - let Some(tx) = self.channels.get(&channel) else { - continue; - }; - - // Cancel safety: Remember the message while we are awaiting the send permit, so that if - // this function is cancelled here we can resume sending of the message on the next - // invocation. - self.message = Some((channel, connection_id, content)); - - let Ok(send_permit) = tx.reserve().await else { - continue; - }; - - // unwrap is ok because `self.message` is `Some` here. - let (_, connection_id, content) = self.message.take().unwrap(); - - send_permit.send((connection_id, content)); - } - - future::pending().await - } -} +pub(super) type ContentSink = + FramedWrite>>, LengthDelimitedCodec>; +/// Create pair of Connections connected to each other. For tests only. 
#[cfg(test)] -mod tests { - use super::{super::stats::ByteCounters, *}; - use assert_matches::assert_matches; - use futures_util::{future, stream}; +pub(super) async fn create_connection_pair() -> (Connection, Connection) { + use futures_util::future; use net::{ - unified::{Acceptor, Connection, Connector}, + unified::{Acceptor, Connector}, SocketOptions, }; - use std::{collections::BTreeSet, net::Ipv4Addr, str::from_utf8, time::Duration}; - - #[tokio::test(flavor = "multi_thread")] - async fn recv_on_stream() { - let channel = MessageChannelId::random(); - let send_content = b"hello world"; - - let server_dispatcher = MessageDispatcher::new(); - let mut server_stream = server_dispatcher.open_recv(channel); - - let (client, server) = create_connection_pair().await; + use std::net::Ipv4Addr; - let (client_tx, _client_rx) = client.outgoing().await.unwrap(); - let mut client_sink = MessageSink::new(client_tx); + // NOTE: Make sure to keep the `reuse_addr` option disabled here to avoid one test to + // accidentally connect to a different test (even from a different process). More details + // here: https://gavv.net/articles/ephemeral-port-reuse/. - server_dispatcher.bind( - server, - ConnectionPermit::dummy(ConnectionDirection::Incoming), - Arc::new(ByteCounters::default()), - ); + let client = net::quic::configure((Ipv4Addr::LOCALHOST, 0).into(), SocketOptions::default()) + .unwrap() + .0; + let server = net::quic::configure((Ipv4Addr::LOCALHOST, 0).into(), SocketOptions::default()) + .unwrap() + .1; - client_sink - .send(Message { - channel, - content: send_content.to_vec(), - }) - .await - .unwrap(); - - let recv_content = server_stream.recv().await.unwrap(); - assert_eq!(recv_content, send_content); - } - - #[tokio::test(flavor = "multi_thread")] - async fn recv_on_two_streams() { - let channel0 = MessageChannelId::random(); - let channel1 = MessageChannelId::random(); + let client = Connector::from(client); + let server = Acceptor::from(server); - let send_content0 = b"one two three"; - let send_content1 = b"four five six"; + let client = client.connect(*server.local_addr()); + let server = async { server.accept().await?.await }; - let server_dispatcher = MessageDispatcher::new(); - let server_stream0 = server_dispatcher.open_recv(channel0); - let server_stream1 = server_dispatcher.open_recv(channel1); - - let (client, server) = create_connection_pair().await; - - let mut client_sink = MessageSink::new(client.outgoing().await.unwrap().0); - - server_dispatcher.bind( - server, - ConnectionPermit::dummy(ConnectionDirection::Incoming), - Arc::new(ByteCounters::default()), - ); - - for (channel, content) in [(channel0, send_content0), (channel1, send_content1)] { - client_sink - .send(Message { - channel, - content: content.to_vec(), - }) - .await - .unwrap(); - } - - for (mut server_stream, send_content) in [ - (server_stream0, send_content0), - (server_stream1, send_content1), - ] { - let recv_content = server_stream.recv().await.unwrap(); - assert_eq!(recv_content, send_content); - } - - client.close().await; - server_dispatcher.shutdown().await; - } - - #[tokio::test(flavor = "multi_thread")] - async fn send_on_two_streams_parallel() { - use tokio::{task, time::timeout}; - - let channel0 = MessageChannelId::random(); - let channel1 = MessageChannelId::random(); - - let client_dispatcher = MessageDispatcher::new(); - let client_sink0 = client_dispatcher.open_send(channel0); - let client_sink1 = client_dispatcher.open_send(channel1); - - let server_dispatcher = MessageDispatcher::new(); - 
let server_stream0 = server_dispatcher.open_recv(channel0); - let server_stream1 = server_dispatcher.open_recv(channel1); - - let (client, server) = create_connection_pair().await; - client_dispatcher.bind( - client, - ConnectionPermit::dummy(ConnectionDirection::Outgoing), - Arc::new(ByteCounters::new()), - ); - server_dispatcher.bind( - server, - ConnectionPermit::dummy(ConnectionDirection::Incoming), - Arc::new(ByteCounters::new()), - ); - - let num_messages = 20; - let mut send_tasks = vec![]; - - let build_message = |channel, i| format!("{:?}:{}", channel, i).as_bytes().to_vec(); - - for sink in [client_sink0, client_sink1] { - send_tasks.push(task::spawn(async move { - for i in 0..num_messages { - sink.send(build_message(sink.channel, i)).await.unwrap(); - } - })); - } - - for task in send_tasks { - timeout(Duration::from_secs(3), task) - .await - .expect("Timed out") - .expect("Send failed"); - } - - for mut server_stream in [server_stream0, server_stream1] { - for i in 0..num_messages { - let recv_content = server_stream.recv().await.unwrap(); - assert_eq!( - from_utf8(&recv_content).unwrap(), - from_utf8(&build_message(server_stream.channel, i)).unwrap() - ); - } - } - - client_dispatcher.shutdown().await; - server_dispatcher.shutdown().await; - } - - #[tokio::test(flavor = "multi_thread")] - async fn duplicate_stream() { - let channel = MessageChannelId::random(); - - let send_content0 = b"one two three"; - let send_content1 = b"four five six"; - - let server_dispatcher = MessageDispatcher::new(); - let mut server_stream0 = server_dispatcher.open_recv(channel); - let mut server_stream1 = server_dispatcher.open_recv(channel); - - let (client, server) = create_connection_pair().await; - - let mut client_sink = MessageSink::new(client.outgoing().await.unwrap().0); - - server_dispatcher.bind( - server, - ConnectionPermit::dummy(ConnectionDirection::Incoming), - Arc::new(ByteCounters::new()), - ); - - for content in [send_content0, send_content1] { - client_sink - .send(Message { - channel, - content: content.to_vec(), - }) - .await - .unwrap(); - } - - assert_matches!( - server_stream0.recv().await, - Err(ContentStreamError::ChannelClosed) - ); - assert_eq!(server_stream1.recv().await.unwrap(), send_content0); - assert_eq!(server_stream1.recv().await.unwrap(), send_content1); + future::try_join(client, server).await.unwrap() +} - client.close().await; - server_dispatcher.shutdown().await; - } +#[cfg(test)] +mod tests { + use super::{super::stats::ByteCounters, *}; + use bytes::Bytes; + use futures_util::SinkExt; + use tokio_stream::StreamExt; - #[tokio::test(flavor = "multi_thread")] - async fn multiple_connections_recv() { + #[tokio::test] + async fn sanity_check() { crate::test_utils::init_log(); let channel = MessageChannelId::random(); + let send_content = b"hello world"; - let send_content0 = b"one two three"; - let send_content1 = b"four five six"; - - let server_dispatcher = MessageDispatcher::new(); - let mut server_stream = server_dispatcher.open_recv(channel); - - let (client0, server0) = create_connection_pair().await; - let (client1, server1) = create_connection_pair().await; - - let client_sink0 = MessageSink::new(client0.outgoing().await.unwrap().0); - let client_sink1 = MessageSink::new(client1.outgoing().await.unwrap().0); - - server_dispatcher.bind( - server0, - ConnectionPermit::dummy(ConnectionDirection::Incoming), - Arc::new(ByteCounters::new()), - ); - server_dispatcher.bind( - server1, - ConnectionPermit::dummy(ConnectionDirection::Incoming), - 
Arc::new(ByteCounters::new()), - ); - - for (mut client_sink, content) in - [(client_sink0, send_content0), (client_sink1, send_content1)] - { - client_sink - .send(Message { - channel, - content: content.to_vec(), - }) - .await - .unwrap(); - } - - let recv_content0 = server_stream.recv().await.unwrap(); - - assert_eq!( - server_stream.recv().await, - Err(ContentStreamError::TransportChanged) - ); - - let recv_content1 = server_stream.recv().await.unwrap(); - - // The messages may be received in any order - assert_eq!( - BTreeSet::from([recv_content0.as_slice(), recv_content1.as_slice()]), - BTreeSet::from([send_content0.as_slice(), send_content1.as_slice()]), - ); - - client0.close().await; - client1.close().await; - server_dispatcher.shutdown().await; - } - - #[tokio::test(flavor = "multi_thread")] - async fn multiple_connections_send() { - let channel = MessageChannelId::random(); - - let send_content0 = b"one two three"; - let send_content1 = b"four five six"; + let (client, server) = create_connection_pair().await; - let server_dispatcher = MessageDispatcher::new(); - let server_sink = server_dispatcher.open_send(channel); + let server_dispatcher = MessageDispatcher::builder(server).build(); - let (client0, server0) = create_connection_pair().await; - let (client1, server1) = create_connection_pair().await; + let (_server_sink, mut server_stream) = + server_dispatcher.open(channel, Arc::new(ByteCounters::new())); - let (client0_tx, client0_rx) = client0.outgoing().await.unwrap(); - let (client1_tx, client1_rx) = client1.outgoing().await.unwrap(); + let client_dispatcher = MessageDispatcher::builder(client).build(); - // The incoming streams are accepted only after something is sent on the corresponding - // outgoing streams first. - let mut client0_sink = MessageSink::new(client0_tx); - let mut client1_sink = MessageSink::new(client1_tx); + let (mut client_sink, _client_stream) = + client_dispatcher.open(channel, Arc::new(ByteCounters::new())); - for sink in [&mut client0_sink, &mut client1_sink] { - sink.send(Message { - channel, - content: Vec::new(), - }) + client_sink + .send(Bytes::from_static(send_content)) .await .unwrap(); - } - - let client0_stream = MessageStream::new(client0_rx); - let client1_stream = MessageStream::new(client1_rx); - - server_dispatcher.bind( - server0, - ConnectionPermit::dummy(ConnectionDirection::Incoming), - Arc::new(ByteCounters::new()), - ); - server_dispatcher.bind( - server1, - ConnectionPermit::dummy(ConnectionDirection::Incoming), - Arc::new(ByteCounters::new()), - ); - - for content in [send_content0, send_content1] { - server_sink.send(content.to_vec()).await.unwrap(); - } - - // The messages may be received on any stream - let recv_contents: BTreeSet<_> = stream::select(client0_stream, client1_stream) - .map(|message| message.unwrap().content) - .take(2) - .collect() - .await; - - assert_eq!( - recv_contents, - [send_content0.to_vec(), send_content1.to_vec()] - .into_iter() - .collect::>(), - ); - - client0.close().await; - client1.close().await; - server_dispatcher.shutdown().await; - } - - #[tokio::test(flavor = "multi_thread")] - async fn shutdown() { - let server_dispatcher = MessageDispatcher::new(); - let mut server_stream = server_dispatcher.open_recv(MessageChannelId::random()); - let server_sink = server_dispatcher.open_send(MessageChannelId::random()); - - server_dispatcher.shutdown().await; - - assert_matches!( - server_stream.recv().await, - Err(ContentStreamError::ChannelClosed) - ); - - 
assert_matches!(server_sink.send(vec![]).await, Err(ChannelClosed)); - } - - async fn create_connection_pair() -> (Connection, Connection) { - // NOTE: Make sure to keep the `reuse_addr` option disabled here to avoid one test to - // accidentally connect to a different test. More details here: - // https://gavv.net/articles/ephemeral-port-reuse/. - - let client = - net::quic::configure((Ipv4Addr::LOCALHOST, 0).into(), SocketOptions::default()) - .unwrap() - .0; - let server = - net::quic::configure((Ipv4Addr::LOCALHOST, 0).into(), SocketOptions::default()) - .unwrap() - .1; - - let client = Connector::from(client); - let server = Acceptor::from(server); - - let client = client.connect(*server.local_addr()); - let server = async { server.accept().await?.await }; - future::try_join(client, server).await.unwrap() + let recv_content = server_stream.try_next().await.unwrap().unwrap(); + assert_eq!(recv_content.as_ref(), send_content.as_ref()); } } diff --git a/lib/src/network/message_io.rs b/lib/src/network/message_io.rs deleted file mode 100644 index 19ac6e1b5..000000000 --- a/lib/src/network/message_io.rs +++ /dev/null @@ -1,406 +0,0 @@ -use super::message::{Header, Message}; -use futures_util::{ready, Sink, Stream}; -use std::{ - io, mem, - pin::Pin, - task::{Context, Poll}, -}; -use thiserror::Error; -use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; - -/// Max message size when serialized in bytes. -/// This is also the maximum allowed message size in the Noise Protocol Framework. -const MAX_MESSAGE_SIZE: u16 = u16::MAX - 1; - -/// Size overhead of a single message in addition to the length of its content. -pub(super) const MESSAGE_OVERHEAD: usize = Header::SIZE + 2; - -// Messages are encoded like this: -// -// [ header: `Header::SIZE` bytes ][ len: 2 bytes ][ content: `len` bytes ] -// - -/// Wrapper that turns a reader (`AsyncRead`) into a `Stream` of `Message`. -pub(crate) struct MessageStream { - read: R, - decoder: Decoder, -} - -impl MessageStream { - pub fn new(read: R) -> Self { - Self { - read, - decoder: Decoder::default(), - } - } -} - -impl Stream for MessageStream -where - R: AsyncRead + Unpin, -{ - type Item = io::Result; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - let this = &mut *self; - let mut read = Pin::new(&mut this.read); - this.decoder.poll_next(read.as_mut(), cx).map(Some) - } -} - -/// Wrapper that turns a writer (`AsyncWrite`) into a `Sink` of `Message`. -pub(crate) struct MessageSink { - write: W, - encoder: Encoder, -} - -impl MessageSink { - pub fn new(write: W) -> Self { - Self { - write, - encoder: Encoder::default(), - } - } -} - -impl Sink for MessageSink -where - W: AsyncWrite + Unpin, -{ - type Error = io::Error; - - fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> { - self.encoder.start(item) - } - - fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - let this = &mut *self; - let write = Pin::new(&mut this.write); - this.encoder.poll_ready(write, cx) - } - - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - ready!(self.as_mut().poll_ready(cx))?; - - match &self.encoder.state { - EncodeState::Idle => Poll::Ready(Ok(())), - EncodeState::Sending { - phase: SendingPhase::Done, - .. - } => { - let result = ready!(Pin::new(&mut self.write).poll_flush(cx)); - self.encoder.state = EncodeState::Idle; - Poll::Ready(result) - } - EncodeState::Sending { .. 
} => unreachable!(), - } - } - - fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - ready!(self.as_mut().poll_ready(cx))?; - - match &self.encoder.state { - EncodeState::Idle - | EncodeState::Sending { - phase: SendingPhase::Done, - .. - } => { - let result = ready!(Pin::new(&mut self.write).poll_shutdown(cx)); - self.encoder.state = EncodeState::Idle; - Poll::Ready(result) - } - EncodeState::Sending { .. } => unreachable!(), - } - } -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// -// Encoder - -struct Encoder { - state: EncodeState, - offset: usize, -} - -enum EncodeState { - Idle, - Sending { - message: Message, - phase: SendingPhase, - }, -} - -enum SendingPhase { - Header, - Len, - Content, - Done, -} - -impl Default for Encoder { - fn default() -> Self { - Self { - state: EncodeState::Idle, - offset: 0, - } - } -} - -impl Encoder { - fn is_sending(&self) -> bool { - match &self.state { - EncodeState::Idle => false, - EncodeState::Sending { - phase: SendingPhase::Done, - .. - } => false, - EncodeState::Sending { .. } => true, - } - } - - fn start(&mut self, message: Message) -> Result<(), io::Error> { - assert!( - !self.is_sending(), - "start_send called while already sending" - ); - - if message.content.len() > MAX_MESSAGE_SIZE as usize { - return Err(io::Error::new(io::ErrorKind::InvalidInput, LengthError)); - } - - self.state = EncodeState::Sending { - message, - phase: SendingPhase::Header, - }; - self.offset = 0; - - Ok(()) - } - - fn poll_ready( - &mut self, - mut io: Pin<&mut W>, - cx: &mut Context, - ) -> Poll> - where - W: AsyncWrite, - { - loop { - match &mut self.state { - EncodeState::Idle => return Poll::Ready(Ok(())), - EncodeState::Sending { message, phase } => match phase { - SendingPhase::Header => { - match ready!(poll_write_all( - io.as_mut(), - cx, - &message.header().serialize(), - &mut self.offset - )) { - Ok(true) => { - *phase = SendingPhase::Len; - self.offset = 0; - } - Ok(false) => (), - Err(error) => { - self.state = EncodeState::Idle; - return Poll::Ready(Err(error)); - } - } - } - SendingPhase::Len => { - let buffer = (message.content.len() as u16).to_be_bytes(); - - match ready!(poll_write_all(io.as_mut(), cx, &buffer, &mut self.offset)) { - Ok(true) => { - if message.content.is_empty() { - *phase = SendingPhase::Done; - } else { - *phase = SendingPhase::Content; - } - self.offset = 0; - } - Ok(false) => (), - Err(error) => { - self.state = EncodeState::Idle; - return Poll::Ready(Err(error)); - } - } - } - SendingPhase::Content => { - match ready!(poll_write_all( - io.as_mut(), - cx, - &message.content, - &mut self.offset - )) { - Ok(true) => { - *phase = SendingPhase::Done; - self.offset = 0; - } - Ok(false) => (), - Err(error) => { - self.state = EncodeState::Idle; - return Poll::Ready(Err(error)); - } - } - } - SendingPhase::Done => return Poll::Ready(Ok(())), - }, - } - } - } -} - -fn poll_write_all( - io: Pin<&mut W>, - cx: &mut Context, - buffer: &[u8], - offset: &mut usize, -) -> Poll> -where - W: AsyncWrite, -{ - let len = ready!(io.poll_write(cx, &buffer[*offset..]))?; - - if len == 0 { - return Poll::Ready(Err(io::ErrorKind::WriteZero.into())); - } - - *offset += len; - - Poll::Ready(Ok(*offset >= buffer.len())) -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// -// Decoder - -struct Decoder { - phase: DecodePhase, - buffer: Vec, - offset: usize, -} - -#[derive(Clone, Copy)] -enum DecodePhase { - Header, - Len { 
header: Header }, - Content { header: Header }, -} - -impl Default for Decoder { - fn default() -> Self { - Self { - phase: DecodePhase::Header, - buffer: vec![0; Header::SIZE], - offset: 0, - } - } -} - -impl Decoder { - fn poll_next(&mut self, mut io: Pin<&mut R>, cx: &mut Context) -> Poll> - where - R: AsyncRead, - { - loop { - ready!(self.poll_read_exact(io.as_mut(), cx))?; - - match self.phase { - DecodePhase::Header => { - let header: [u8; Header::SIZE] = self - .filled() - .try_into() - .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?; - - let header = match Header::deserialize(&header) { - Some(header) => header, - None => { - return Poll::Ready(Err(io::Error::new( - io::ErrorKind::InvalidData, - BadHeader, - ))) - } - }; - - self.phase = DecodePhase::Len { header }; - self.buffer.resize(2, 0); - self.offset = 0; - } - DecodePhase::Len { header } => { - let len = u16::from_be_bytes( - self.filled() - .try_into() - .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?, - ); - - if len > MAX_MESSAGE_SIZE { - return Poll::Ready(Err(io::Error::new( - io::ErrorKind::InvalidData, - LengthError, - ))); - } - - if len > 0 { - self.phase = DecodePhase::Content { header }; - self.buffer.resize(len as usize, 0); - self.offset = 0; - } else { - self.phase = DecodePhase::Header; - self.buffer.resize(Header::SIZE, 0); - self.offset = 0; - - return Poll::Ready(Ok(Message { - channel: header.channel, - content: Vec::new(), - })); - } - } - DecodePhase::Content { header } => { - let content = mem::take(&mut self.buffer); - - self.phase = DecodePhase::Header; - self.buffer.resize(Header::SIZE, 0); - self.offset = 0; - - return Poll::Ready(Ok(Message { - channel: header.channel, - content, - })); - } - } - } - } - - fn poll_read_exact(&mut self, mut io: Pin<&mut R>, cx: &mut Context) -> Poll> - where - R: AsyncRead, - { - loop { - let mut buf = ReadBuf::new(&mut self.buffer[self.offset..]); - - match ready!(io.as_mut().poll_read(cx, &mut buf)) { - Ok(()) if !buf.filled().is_empty() => { - self.offset += buf.filled().len(); - - if self.offset >= self.buffer.len() { - return Poll::Ready(Ok(())); - } - } - Ok(()) => return Poll::Ready(Err(io::ErrorKind::UnexpectedEof.into())), - Err(error) => return Poll::Ready(Err(error)), - } - } - } - - fn filled(&self) -> &[u8] { - &self.buffer[..self.offset] - } -} - -#[derive(Debug, Error)] -#[error("message too big")] -struct LengthError; - -#[derive(Debug, Error)] -#[error("bad header")] -struct BadHeader; diff --git a/lib/src/network/mod.rs b/lib/src/network/mod.rs index 804c59826..af6cb346f 100644 --- a/lib/src/network/mod.rs +++ b/lib/src/network/mod.rs @@ -1,4 +1,3 @@ -mod barrier; mod client; mod connection; mod connection_monitor; @@ -12,7 +11,6 @@ mod local_discovery; mod message; mod message_broker; mod message_dispatcher; -mod message_io; mod peer_addr; mod peer_exchange; mod peer_info; @@ -888,7 +886,7 @@ impl Inner { monitor.mark_as_active(that_runtime_id); tracing::info!(parent: monitor.span(), "Connected"); - let released = permit.released(); + let closed = connection.closed(); let key = { let mut registry = self.registry.lock().unwrap(); @@ -899,13 +897,19 @@ impl Inner { return false; }; + let pex_peer = self.pex_discovery.new_peer(); + pex_peer.handle_connection(permit.addr(), permit.source(), permit.released()); + let mut peer = self.span.in_scope(|| { MessageBroker::new( self.this_runtime_id.public(), that_runtime_id, - self.pex_discovery.new_peer(), + connection, + pex_peer, self.peers_monitor 
                    .make_child(format!("{:?}", that_runtime_id.as_public_key())),
+                self.stats_tracker.bytes.clone(),
+                permit.byte_counters(),
             )
         });
 
@@ -921,18 +925,18 @@ impl Inner {
             );
         }
 
-        peer.add_connection(connection, permit, self.stats_tracker.bytes.clone());
-
             peers.insert(peer)
         };
 
+        // Wait until the connection gets closed, then remove the `MessageBroker` instance. Using an
+        // RAII guard to also remove it in case this function gets cancelled.
         let _guard = PeerGuard {
             registry: &self.registry,
             key,
             monitor,
         };
 
-        released.await;
+        closed.await;
 
         true
     }
diff --git a/lib/src/network/stats.rs b/lib/src/network/stats.rs
index 878bfdbcf..71ce5f54d 100644
--- a/lib/src/network/stats.rs
+++ b/lib/src/network/stats.rs
@@ -142,18 +142,6 @@ impl<T> Instrumented<T> {
     pub fn new(inner: T, counters: Arc<ByteCounters>) -> Self {
         Self { inner, counters }
     }
-
-    pub fn as_ref(&self) -> &T {
-        &self.inner
-    }
-
-    pub fn as_mut(&mut self) -> &mut T {
-        &mut self.inner
-    }
-
-    pub fn counters(&self) -> &ByteCounters {
-        &self.counters
-    }
 }
 
 impl<T> AsyncRead for Instrumented<T>

From 5b3ca89bf7359f732ed736ae9fd23f4ff78e0fb7 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adam=20Cig=C3=A1nek?=
Date: Thu, 12 Sep 2024 13:27:40 +0200
Subject: [PATCH 20/55] Remove MessageChannelId, use TopicId instead

---
 lib/src/network/crypto.rs             | 13 +++----
 lib/src/network/message.rs            | 50 ++-------------------------
 lib/src/network/message_broker.rs     | 27 ++++++++++++---
 lib/src/network/message_dispatcher.rs | 23 ++++++------
 4 files changed, 39 insertions(+), 74 deletions(-)

diff --git a/lib/src/network/crypto.rs b/lib/src/network/crypto.rs
index b6fcea636..7bee8572b 100644
--- a/lib/src/network/crypto.rs
+++ b/lib/src/network/crypto.rs
@@ -249,12 +249,11 @@ mod tests {
 mod tests {
     use super::*;
     use crate::network::{
-        message::MessageChannelId,
         message_dispatcher::{create_connection_pair, MessageDispatcher},
-        runtime_id::SecretRuntimeId,
         stats::ByteCounters,
     };
     use futures_util::future;
+    use net::bus::TopicId;
    use std::sync::Arc;
 
     #[tokio::test]
@@ -265,17 +264,13 @@ mod tests {
         let server = MessageDispatcher::builder(server).build();
 
         let repo_id = RepositoryId::random();
-
-        let client_id = SecretRuntimeId::random().public();
-        let server_id = SecretRuntimeId::random().public();
-
-        let channel_id = MessageChannelId::new(&repo_id, &client_id, &server_id);
+        let topic_id = TopicId::random();
 
         let (mut client_sink, mut client_stream) =
-            client.open(channel_id, Arc::new(ByteCounters::new()));
+            client.open(topic_id, Arc::new(ByteCounters::new()));
 
         let (mut server_sink, mut server_stream) =
-            server.open(channel_id, Arc::new(ByteCounters::new()));
+            server.open(topic_id, Arc::new(ByteCounters::new()));
 
         let ((mut client_stream, mut client_sink), (mut server_stream, mut server_sink)) =
             future::try_join(
diff --git a/lib/src/network/message.rs b/lib/src/network/message.rs
index bc7996f43..f8a2024c5 100644
--- a/lib/src/network/message.rs
+++ b/lib/src/network/message.rs
@@ -1,16 +1,14 @@
 use super::{
     debug_payload::{DebugRequest, DebugResponse},
     peer_exchange::PexPayload,
-    runtime_id::PublicRuntimeId,
 };
 use crate::{
-    crypto::{sign::PublicKey, Hash, Hashable},
+    crypto::{sign::PublicKey, Hash},
     protocol::{
-        BlockContent, BlockId, BlockNonce, InnerNodes, LeafNodes, MultiBlockPresence, RepositoryId,
+        BlockContent, BlockId, BlockNonce, InnerNodes, LeafNodes, MultiBlockPresence,
         UntrustedProof,
     },
 };
-use net::bus::TopicId;
 use serde::{Deserialize, Serialize};
 
 #[derive(Clone, PartialEq, Serialize, Deserialize, Debug)]
@@ -90,47 +88,3 @@ impl From for Response {
     }
 }
-
-define_byte_array_wrapper! { - // TODO: consider lower size (truncate the hash) which should still be enough to be unique - // while reducing the message size. - #[derive(Serialize, Deserialize)] - pub(crate) struct MessageChannelId([u8; Hash::SIZE]); -} - -impl MessageChannelId { - pub(super) fn new( - repo_id: &RepositoryId, - this_runtime_id: &PublicRuntimeId, - that_runtime_id: &PublicRuntimeId, - ) -> Self { - let (id1, id2) = if this_runtime_id > that_runtime_id { - (this_runtime_id, that_runtime_id) - } else { - (that_runtime_id, this_runtime_id) - }; - - Self( - (repo_id, id1, id2, b"ouisync message channel id") - .hash() - .into(), - ) - } - - #[cfg(test)] - pub(crate) fn random() -> Self { - Self(rand::random()) - } -} - -impl Default for MessageChannelId { - fn default() -> Self { - Self([0; Self::SIZE]) - } -} - -impl From for TopicId { - fn from(id: MessageChannelId) -> Self { - TopicId::from(id.0) - } -} diff --git a/lib/src/network/message_broker.rs b/lib/src/network/message_broker.rs index 5a9069d04..cad15bfb3 100644 --- a/lib/src/network/message_broker.rs +++ b/lib/src/network/message_broker.rs @@ -1,7 +1,7 @@ use super::{ client::Client, crypto::{self, DecryptingStream, EncryptingSink, EstablishError, RecvError, Role, SendError}, - message::{Content, MessageChannelId, Request, Response}, + message::{Content, Request, Response}, message_dispatcher::{ContentSink, ContentStream, MessageDispatcher}, peer_exchange::{PexPeer, PexReceiver, PexRepository, PexSender}, runtime_id::PublicRuntimeId, @@ -10,6 +10,7 @@ use super::{ }; use crate::{ collections::{hash_map::Entry, HashMap}, + crypto::Hashable, network::constants::{REQUEST_BUFFER_SIZE, RESPONSE_BUFFER_SIZE}, protocol::RepositoryId, repository::Vault, @@ -17,7 +18,7 @@ use crate::{ use backoff::{backoff::Backoff, ExponentialBackoffBuilder}; use bytes::{BufMut, BytesMut}; use futures_util::{SinkExt, StreamExt}; -use net::unified::Connection; +use net::{bus::TopicId, unified::Connection}; use state_monitor::StateMonitor; use std::{future, sync::Arc}; use tokio::{ @@ -110,13 +111,13 @@ impl MessageBroker { &self.that_runtime_id, ); - let channel_id = MessageChannelId::new( + let topic_id = make_topic_id( vault.repository_id(), &self.this_runtime_id, &self.that_runtime_id, ); - let (sink, stream) = self.dispatcher.open(channel_id, repo_counters); + let (sink, stream) = self.dispatcher.open(topic_id, repo_counters); let (pex_tx, pex_rx) = self.pex_peer.new_link(pex_repo); @@ -176,6 +177,24 @@ impl Drop for SpanGuard { } } +fn make_topic_id( + repo_id: &RepositoryId, + this_runtime_id: &PublicRuntimeId, + that_runtime_id: &PublicRuntimeId, +) -> TopicId { + let (id1, id2) = if this_runtime_id > that_runtime_id { + (this_runtime_id, that_runtime_id) + } else { + (that_runtime_id, this_runtime_id) + }; + + let bytes: [_; TopicId::SIZE] = (repo_id, id1, id2, b"ouisync message topic id") + .hash() + .into(); + + TopicId::from(bytes) +} + struct Link { role: Role, stream: ContentStream, diff --git a/lib/src/network/message_dispatcher.rs b/lib/src/network/message_dispatcher.rs index 51f6045c2..99cb636a3 100644 --- a/lib/src/network/message_dispatcher.rs +++ b/lib/src/network/message_dispatcher.rs @@ -1,19 +1,16 @@ //! Utilities for sending and receiving messages across the network. 
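The `make_topic_id` function added to message_broker.rs above takes over the role of the
removed `MessageChannelId::new`: it hashes the repository id together with the two runtime
ids, ordering the ids first so that both ends of a connection derive the same topic id. A
minimal sketch of that symmetry invariant, written as a hypothetical test against the names
used in this patch (it is not part of the series itself):

    #[test]
    fn make_topic_id_is_symmetric() {
        let repo_id = RepositoryId::random();
        let a = SecretRuntimeId::random().public();
        let b = SecretRuntimeId::random().public();

        // Swapping the "this" and "that" sides must not change the result,
        // otherwise the two peers would open different topics.
        assert_eq!(
            make_topic_id(&repo_id, &a, &b),
            make_topic_id(&repo_id, &b, &a)
        );
    }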
-use super::{ - message::MessageChannelId, - stats::{ByteCounters, Instrumented}, -}; +use super::stats::{ByteCounters, Instrumented}; use net::{ - bus::{Bus, BusRecvStream, BusSendStream}, + bus::{Bus, BusRecvStream, BusSendStream, TopicId}, unified::Connection, }; use std::sync::Arc; use tokio_util::codec::{length_delimited, FramedRead, FramedWrite, LengthDelimitedCodec}; /// Reads/writes messages from/to the underlying TCP or QUIC connection and dispatches them to -/// individual streams/sinks based on their channel ids (in the MessageDispatcher's and -/// MessageBroker's contexts, there is a one-to-one relationship between the channel id and a +/// individual streams/sinks based on their topic ids (in the MessageDispatcher's and +/// MessageBroker's contexts, there is a one-to-one relationship between the topic id and a /// repository id). pub(super) struct MessageDispatcher { bus: net::bus::Bus, @@ -30,13 +27,13 @@ impl MessageDispatcher { } } - /// Opens a sink and a stream for communication on the given channel. + /// Opens a sink and a stream for communication on the given topic. pub fn open( &self, - channel_id: MessageChannelId, + topic_id: TopicId, repo_counters: Arc, ) -> (ContentSink, ContentStream) { - let (writer, reader) = self.bus.create_topic(channel_id.into()); + let (writer, reader) = self.bus.create_topic(topic_id); let writer = Instrumented::new(writer, self.total_counters.clone()); let writer = Instrumented::new(writer, self.peer_counters.clone()); @@ -151,7 +148,7 @@ mod tests { async fn sanity_check() { crate::test_utils::init_log(); - let channel = MessageChannelId::random(); + let topic_id = TopicId::random(); let send_content = b"hello world"; let (client, server) = create_connection_pair().await; @@ -159,12 +156,12 @@ mod tests { let server_dispatcher = MessageDispatcher::builder(server).build(); let (_server_sink, mut server_stream) = - server_dispatcher.open(channel, Arc::new(ByteCounters::new())); + server_dispatcher.open(topic_id, Arc::new(ByteCounters::new())); let client_dispatcher = MessageDispatcher::builder(client).build(); let (mut client_sink, _client_stream) = - client_dispatcher.open(channel, Arc::new(ByteCounters::new())); + client_dispatcher.open(topic_id, Arc::new(ByteCounters::new())); client_sink .send(Bytes::from_static(send_content)) From e86daa2aa38aa8e4e380e39b01bf4ce71a142ff7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Cig=C3=A1nek?= Date: Thu, 12 Sep 2024 13:51:59 +0200 Subject: [PATCH 21/55] Rename Content -> Message --- lib/src/network/client.rs | 16 ++++---- lib/src/network/crypto.rs | 18 ++++----- lib/src/network/message.rs | 26 ++++++------ lib/src/network/message_broker.rs | 58 +++++++++++++-------------- lib/src/network/message_dispatcher.rs | 6 +-- lib/src/network/peer_exchange.rs | 6 +-- lib/src/network/server.rs | 10 ++--- lib/src/network/tests.rs | 10 ++--- 8 files changed, 75 insertions(+), 75 deletions(-) diff --git a/lib/src/network/client.rs b/lib/src/network/client.rs index c04338d99..a222b51af 100644 --- a/lib/src/network/client.rs +++ b/lib/src/network/client.rs @@ -1,7 +1,7 @@ use super::{ constants::RESPONSE_BATCH_SIZE, debug_payload::{DebugResponse, PendingDebugRequest}, - message::{Content, Response, ResponseDisambiguator}, + message::{Message, Response, ResponseDisambiguator}, pending::{ EphemeralResponse, PendingRequest, PendingRequests, PersistableResponse, PreparedResponse, }, @@ -35,7 +35,7 @@ pub(super) struct Client { impl Client { pub fn new( vault: Vault, - content_tx: mpsc::UnboundedSender, + 
message_tx: mpsc::UnboundedSender, response_rx: mpsc::Receiver, ) -> Self { let pending_requests = PendingRequests::new(vault.monitor.clone()); @@ -45,7 +45,7 @@ impl Client { vault, pending_requests, block_tracker, - content_tx, + message_tx, }; Self { inner, response_rx } @@ -64,7 +64,7 @@ struct Inner { vault: Vault, pending_requests: PendingRequests, block_tracker: TrackerClient, - content_tx: mpsc::UnboundedSender, + message_tx: mpsc::UnboundedSender, } impl Inner { @@ -78,8 +78,8 @@ impl Inner { fn send_request(&self, request: PendingRequest) { if let Some(request) = self.pending_requests.insert(request) { - self.content_tx - .send(Content::Request(request)) + self.message_tx + .send(Message::Request(request)) .unwrap_or(()); } } @@ -554,13 +554,13 @@ mod tests { let pending_requests = PendingRequests::new(vault.monitor.clone()); let block_tracker = vault.block_tracker.client(); - let (content_tx, _content_rx) = mpsc::unbounded_channel(); + let (message_tx, _message_rx) = mpsc::unbounded_channel(); let inner = Inner { vault, pending_requests, block_tracker, - content_tx, + message_tx, }; (base_dir, inner, secrets) diff --git a/lib/src/network/crypto.rs b/lib/src/network/crypto.rs index 7bee8572b..0ce94e89c 100644 --- a/lib/src/network/crypto.rs +++ b/lib/src/network/crypto.rs @@ -8,7 +8,7 @@ //! based on the identity of the replicas is needed. use super::{ - message_dispatcher::{ContentSink, ContentStream}, + message_dispatcher::{MessageSink, MessageStream}, runtime_id::PublicRuntimeId, }; use crate::protocol::RepositoryId; @@ -65,9 +65,9 @@ impl Role { // session. const MAX_NONCE: u64 = u64::MAX - 1; -/// Wrapper for [`ContentStream`] that decrypts incoming messages. +/// Wrapper for [`MessageStream`] that decrypts incoming messages. pub(super) struct DecryptingStream<'a> { - inner: &'a mut ContentStream, + inner: &'a mut MessageStream, cipher: CipherState, } @@ -94,9 +94,9 @@ impl Stream for DecryptingStream<'_> { } } -/// Wrapper for [`ContentSink`] that encrypts outgoing messages. +/// Wrapper for [`MessageSink`] that encrypts outgoing messages. 
pub(super) struct EncryptingSink<'a> { - inner: &'a mut ContentSink, + inner: &'a mut MessageSink, cipher: CipherState, } @@ -137,8 +137,8 @@ impl Sink for EncryptingSink<'_> { pub(super) async fn establish_channel<'a>( role: Role, repo_id: &RepositoryId, - stream: &'a mut ContentStream, - sink: &'a mut ContentSink, + stream: &'a mut MessageStream, + sink: &'a mut MessageSink, ) -> Result<(DecryptingStream<'a>, EncryptingSink<'a>), EstablishError> { let mut handshake_state = build_handshake_state(role, repo_id); @@ -225,7 +225,7 @@ fn build_handshake_state(role: Role, repo_id: &RepositoryId) -> HandshakeState { async fn handshake_send( state: &mut HandshakeState, - sink: &mut ContentSink, + sink: &mut MessageSink, msg: &[u8], ) -> Result<(), EstablishError> { let content = state.write_message_vec(msg)?; @@ -235,7 +235,7 @@ async fn handshake_send( async fn handshake_recv( state: &mut HandshakeState, - stream: &mut ContentStream, + stream: &mut MessageStream, ) -> Result, EstablishError> { let content = stream .try_next() diff --git a/lib/src/network/message.rs b/lib/src/network/message.rs index f8a2024c5..478cac2b5 100644 --- a/lib/src/network/message.rs +++ b/lib/src/network/message.rs @@ -58,7 +58,7 @@ pub(crate) enum Response { } #[derive(Serialize, Deserialize, Debug)] -pub(crate) enum Content { +pub(crate) enum Message { Request(Request), Response(Response), // Peer exchange @@ -66,24 +66,24 @@ pub(crate) enum Content { } #[cfg(test)] -impl From for Request { - fn from(content: Content) -> Self { - match content { - Content::Request(request) => request, - Content::Response(_) | Content::Pex(_) => { - panic!("not a request: {:?}", content) +impl From for Request { + fn from(message: Message) -> Self { + match message { + Message::Request(request) => request, + Message::Response(_) | Message::Pex(_) => { + panic!("not a request: {:?}", message) } } } } #[cfg(test)] -impl From for Response { - fn from(content: Content) -> Self { - match content { - Content::Response(response) => response, - Content::Request(_) | Content::Pex(_) => { - panic!("not a response: {:?}", content) +impl From for Response { + fn from(message: Message) -> Self { + match message { + Message::Response(response) => response, + Message::Request(_) | Message::Pex(_) => { + panic!("not a response: {:?}", message) } } } diff --git a/lib/src/network/message_broker.rs b/lib/src/network/message_broker.rs index cad15bfb3..032e2533e 100644 --- a/lib/src/network/message_broker.rs +++ b/lib/src/network/message_broker.rs @@ -1,8 +1,8 @@ use super::{ client::Client, crypto::{self, DecryptingStream, EncryptingSink, EstablishError, RecvError, Role, SendError}, - message::{Content, Request, Response}, - message_dispatcher::{ContentSink, ContentStream, MessageDispatcher}, + message::{Message, Request, Response}, + message_dispatcher::{MessageDispatcher, MessageSink, MessageStream}, peer_exchange::{PexPeer, PexReceiver, PexRepository, PexSender}, runtime_id::PublicRuntimeId, server::Server, @@ -197,8 +197,8 @@ fn make_topic_id( struct Link { role: Role, - stream: ContentStream, - sink: ContentSink, + stream: MessageStream, + sink: MessageSink, vault: Vault, response_limiter: Arc, pex_tx: PexSender, @@ -265,8 +265,8 @@ impl Link { async fn establish_channel<'a>( role: Role, - stream: &'a mut ContentStream, - sink: &'a mut ContentSink, + stream: &'a mut MessageStream, + sink: &'a mut MessageSink, vault: &Vault, ) -> Result<(DecryptingStream<'a>, EncryptingSink<'a>), EstablishError> { match crypto::establish_channel(role, 
vault.repository_id(), stream, sink).await { @@ -295,17 +295,17 @@ async fn run_link( let (request_tx, request_rx) = mpsc::channel(REQUEST_BUFFER_SIZE); let (response_tx, response_rx) = mpsc::channel(RESPONSE_BUFFER_SIZE); // Outgoing message channel is unbounded because we fully control how much stuff goes into it. - let (content_tx, content_rx) = mpsc::unbounded_channel(); + let (message_tx, message_rx) = mpsc::unbounded_channel(); tracing::info!("Link opened"); // Run everything in parallel: let flow = select! { - flow = run_client(repo.clone(), content_tx.clone(), response_rx) => flow, - flow = run_server(repo.clone(), content_tx.clone(), request_rx, response_limiter) => flow, + flow = run_client(repo.clone(), message_tx.clone(), response_rx) => flow, + flow = run_server(repo.clone(), message_tx.clone(), request_rx, response_limiter) => flow, flow = recv_messages(stream, request_tx, response_tx, pex_rx) => flow, - flow = send_messages(content_rx, sink) => flow, - _ = pex_tx.run(content_tx) => ControlFlow::Continue, + flow = send_messages(message_rx, sink) => flow, + _ = pex_tx.run(message_tx) => ControlFlow::Continue, }; tracing::info!("Link closed"); @@ -321,8 +321,8 @@ async fn recv_messages( pex_rx: &PexReceiver, ) -> ControlFlow { loop { - let content = match stream.next().await { - Some(Ok(content)) => content, + let message = match stream.next().await { + Some(Ok(message)) => message, Some(Err(RecvError::Crypto)) => { tracing::warn!("Failed to decrypt incoming message",); return ControlFlow::Continue; @@ -341,32 +341,32 @@ async fn recv_messages( } }; - let content: Content = match bincode::deserialize(&content) { - Ok(content) => content, + let message: Message = match bincode::deserialize(&message) { + Ok(message) => message, Err(error) => { tracing::warn!(?error, "Failed to deserialize incoming message"); continue; } }; - match content { - Content::Request(request) => request_tx.send(request).await.unwrap_or(()), - Content::Response(response) => response_tx.send(response).await.unwrap_or(()), - Content::Pex(payload) => pex_rx.handle_message(payload).await, + match message { + Message::Request(request) => request_tx.send(request).await.unwrap_or(()), + Message::Response(response) => response_tx.send(response).await.unwrap_or(()), + Message::Pex(payload) => pex_rx.handle_message(payload).await, } } } // Handle outgoing messages async fn send_messages( - mut content_rx: mpsc::UnboundedReceiver, + mut message_rx: mpsc::UnboundedReceiver, mut sink: EncryptingSink<'_>, ) -> ControlFlow { let mut writer = BytesMut::new().writer(); loop { - let content = match content_rx.try_recv() { - Ok(content) => Some(content), + let message = match message_rx.try_recv() { + Ok(message) => Some(message), Err(TryRecvError::Empty) => { match sink.flush().await { Ok(()) => (), @@ -376,18 +376,18 @@ async fn send_messages( } } - content_rx.recv().await + message_rx.recv().await } Err(TryRecvError::Disconnected) => None, }; - let Some(content) = content else { + let Some(message) = message else { forever().await }; // unwrap is OK because serialization into a vec should never fail unless we have a bug // somewhere. - bincode::serialize_into(&mut writer, &content).unwrap(); + bincode::serialize_into(&mut writer, &message).unwrap(); match sink.feed(writer.get_mut().split().freeze()).await { Ok(()) => (), @@ -406,10 +406,10 @@ async fn send_messages( // Create and run client. Returns only on error. 
 async fn run_client(
     repo: Vault,
-    content_tx: mpsc::UnboundedSender<Content>,
+    message_tx: mpsc::UnboundedSender<Message>,
     response_rx: mpsc::Receiver<Response>,
 ) -> ControlFlow {
-    let mut client = Client::new(repo, content_tx, response_rx);
+    let mut client = Client::new(repo, message_tx, response_rx);
     let result = client.run().await;
 
     tracing::debug!("Client stopped running with result {:?}", result);
@@ -423,11 +423,11 @@ async fn run_client(
 // Create and run server. Returns only on error.
 async fn run_server(
     repo: Vault,
-    content_tx: mpsc::UnboundedSender<Content>,
+    message_tx: mpsc::UnboundedSender<Message>,
     request_rx: mpsc::Receiver<Request>,
     response_limiter: Arc<Semaphore>,
 ) -> ControlFlow {
-    let mut server = Server::new(repo, content_tx, request_rx, response_limiter);
+    let mut server = Server::new(repo, message_tx, request_rx, response_limiter);
 
     let result = server.run().await;
 
diff --git a/lib/src/network/message_dispatcher.rs b/lib/src/network/message_dispatcher.rs
index 99cb636a3..72e710634 100644
--- a/lib/src/network/message_dispatcher.rs
+++ b/lib/src/network/message_dispatcher.rs
@@ -32,7 +32,7 @@ impl MessageDispatcher {
         &self,
         topic_id: TopicId,
         repo_counters: Arc<ByteCounters>,
-    ) -> (ContentSink, ContentStream) {
+    ) -> (MessageSink, MessageStream) {
         let (writer, reader) = self.bus.create_topic(topic_id);
 
         let writer = Instrumented::new(writer, self.total_counters.clone());
@@ -101,10 +101,10 @@ fn make_codec() -> LengthDelimitedCodec {
 
 // The streams/sinks are triple-instrumented: once to collect the total cumulative traffic across
 // all peers, once to collect the traffic per peer and once to collect the traffic per repo.
-pub(super) type ContentStream =
+pub(super) type MessageStream =
     FramedRead<Instrumented<Instrumented<Instrumented<BusRecvStream>>>, LengthDelimitedCodec>;
 
-pub(super) type ContentSink =
+pub(super) type MessageSink =
     FramedWrite<Instrumented<Instrumented<Instrumented<BusSendStream>>>, LengthDelimitedCodec>;
 
 /// Create pair of Connections connected to each other. For tests only.
diff --git a/lib/src/network/peer_exchange.rs b/lib/src/network/peer_exchange.rs
index dfefb5612..e3f5a6f0a 100644
--- a/lib/src/network/peer_exchange.rs
+++ b/lib/src/network/peer_exchange.rs
@@ -4,7 +4,7 @@
 use super::{
     connection::ConnectionDirection,
     ip,
-    message::Content,
+    message::Message,
     peer_addr::PeerAddr,
     seen_peers::{SeenPeer, SeenPeers},
     PeerSource,
@@ -232,7 +232,7 @@ pub(crate) struct PexSender {
 impl PexSender {
     /// While this method is running, it periodically sends contacts of other peers that share the
     /// same repo to this peer and makes the contacts of this peer available to them.
-    pub async fn run(&mut self, content_tx: mpsc::UnboundedSender<Content>) {
+    pub async fn run(&mut self, message_tx: mpsc::UnboundedSender<Message>) {
         let Some(collector) = self.enable() else {
             // Another collector for this link already exists.
return; @@ -253,7 +253,7 @@ impl PexSender { }; if !addrs.is_empty() { - content_tx.send(Content::Pex(PexPayload(addrs))).ok(); + message_tx.send(Message::Pex(PexPayload(addrs))).ok(); } let interval = rand::thread_rng().gen_range(SEND_INTERVAL_RANGE); diff --git a/lib/src/network/server.rs b/lib/src/network/server.rs index faeea9ab4..a3a533762 100644 --- a/lib/src/network/server.rs +++ b/lib/src/network/server.rs @@ -1,7 +1,7 @@ use super::{ constants::{INTEREST_TIMEOUT, MAX_UNCHOKED_DURATION}, debug_payload::{DebugRequest, DebugResponse}, - message::{Content, Request, Response, ResponseDisambiguator}, + message::{Message, Request, Response, ResponseDisambiguator}, }; use crate::{ crypto::{sign::PublicKey, Hash}, @@ -32,7 +32,7 @@ pub(crate) struct Server { impl Server { pub fn new( vault: Vault, - content_tx: mpsc::UnboundedSender, + message_tx: mpsc::UnboundedSender, request_rx: mpsc::Receiver, response_limiter: Arc, ) -> Self { @@ -42,7 +42,7 @@ impl Server { inner: Inner { vault, response_tx, - content_tx, + message_tx, response_limiter, }, request_rx, @@ -64,7 +64,7 @@ impl Server { struct Inner { vault: Vault, response_tx: mpsc::Sender, - content_tx: mpsc::UnboundedSender, + message_tx: mpsc::UnboundedSender, response_limiter: Arc, } @@ -361,7 +361,7 @@ impl Inner { } fn send_response(&self, response: Response) { - if self.content_tx.send(Content::Response(response)).is_ok() { + if self.message_tx.send(Message::Response(response)).is_ok() { self.vault.monitor.responses_sent.increment(1); } } diff --git a/lib/src/network/tests.rs b/lib/src/network/tests.rs index 16c512330..6cb92dbba 100644 --- a/lib/src/network/tests.rs +++ b/lib/src/network/tests.rs @@ -1,7 +1,7 @@ use super::{ client::Client, constants::MAX_UNCHOKED_COUNT, - message::{Content, Request, Response}, + message::{Message, Request, Response}, server::Server, }; use crate::{ @@ -620,12 +620,12 @@ where type ServerData = ( Server, - mpsc::UnboundedReceiver, + mpsc::UnboundedReceiver, mpsc::Sender, ); type ClientData = ( Client, - mpsc::UnboundedReceiver, + mpsc::UnboundedReceiver, mpsc::Sender, ); @@ -647,13 +647,13 @@ fn create_client(repo: Vault) -> ClientData { // Simulated connection between a server and a client. 
struct Connection<'a, T> { - send_rx: &'a mut mpsc::UnboundedReceiver, + send_rx: &'a mut mpsc::UnboundedReceiver, recv_tx: &'a mut mpsc::Sender, } impl Connection<'_, T> where - T: From + fmt::Debug, + T: From + fmt::Debug, { async fn run(&mut self) { while let Some(content) = self.send_rx.recv().await { From 85acb3a665425fbd55fd4f208b619f09778ceb4d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Cig=C3=A1nek?= Date: Thu, 12 Sep 2024 15:08:42 +0200 Subject: [PATCH 22/55] Fix sync failure after relinking repo --- lib/src/crypto/hash.rs | 2 +- lib/src/network/message_broker.rs | 141 +++++++++++--------------- lib/src/network/message_dispatcher.rs | 1 + net/src/bus.rs | 66 ++++++++++++ 4 files changed, 130 insertions(+), 80 deletions(-) diff --git a/lib/src/crypto/hash.rs b/lib/src/crypto/hash.rs index 2a2b9077c..a38e11fd2 100644 --- a/lib/src/crypto/hash.rs +++ b/lib/src/crypto/hash.rs @@ -86,7 +86,7 @@ impl fmt::Debug for Hash { impl fmt::LowerHex for Hash { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{}", hex_fmt::HexFmt(self.as_ref())) + hex_fmt::HexFmt(self.as_ref()).fmt(f) } } diff --git a/lib/src/network/message_broker.rs b/lib/src/network/message_broker.rs index 032e2533e..5daa3e906 100644 --- a/lib/src/network/message_broker.rs +++ b/lib/src/network/message_broker.rs @@ -1,6 +1,6 @@ use super::{ client::Client, - crypto::{self, DecryptingStream, EncryptingSink, EstablishError, RecvError, Role, SendError}, + crypto::{self, DecryptingStream, EncryptingSink, EstablishError, Role}, message::{Message, Request, Response}, message_dispatcher::{MessageDispatcher, MessageSink, MessageStream}, peer_exchange::{PexPeer, PexReceiver, PexRepository, PexSender}, @@ -20,7 +20,7 @@ use bytes::{BufMut, BytesMut}; use futures_util::{SinkExt, StreamExt}; use net::{bus::TopicId, unified::Connection}; use state_monitor::StateMonitor; -use std::{future, sync::Arc}; +use std::{sync::Arc, time::Instant}; use tokio::{ select, sync::{ @@ -117,19 +117,18 @@ impl MessageBroker { &self.that_runtime_id, ); - let (sink, stream) = self.dispatcher.open(topic_id, repo_counters); - let (pex_tx, pex_rx) = self.pex_peer.new_link(pex_repo); let mut link = Link { role, - stream, - sink, + topic_id, + dispatcher: self.dispatcher.clone(), vault, response_limiter, pex_tx, pex_rx, monitor, + repo_counters, }; drop(span_enter); @@ -161,7 +160,7 @@ struct SpanGuard(Span); impl SpanGuard { fn new(that_runtime_id: &PublicRuntimeId) -> Self { let span = tracing::info_span!( - "message_broker", + "peer", message = ?that_runtime_id.as_public_key(), ); @@ -197,13 +196,14 @@ fn make_topic_id( struct Link { role: Role, - stream: MessageStream, - sink: MessageSink, + topic_id: TopicId, + dispatcher: MessageDispatcher, vault: Vault, response_limiter: Arc, pex_tx: PexSender, pex_rx: PexReceiver, monitor: StateMonitor, + repo_counters: Arc, } impl Link { @@ -216,9 +216,12 @@ impl Link { Running, } + let min_backoff = Duration::from_millis(100); + let max_backoff = Duration::from_secs(5); + let mut backoff = ExponentialBackoffBuilder::new() - .with_initial_interval(Duration::from_millis(100)) - .with_max_interval(Duration::from_secs(5)) + .with_initial_interval(min_backoff) + .with_max_interval(max_backoff) .with_max_elapsed_time(None) .build(); @@ -235,18 +238,21 @@ impl Link { *state.get() = State::EstablishingChannel; - let (crypto_stream, crypto_sink) = - match establish_channel(self.role, &mut self.stream, &mut self.sink, &self.vault) - .await - { - Ok(io) => io, - Err(EstablishError::Crypto) => continue, - 
Err(EstablishError::Io(_)) => break, - }; + let (mut sink, mut stream) = self + .dispatcher + .open(self.topic_id, self.repo_counters.clone()); + + let Ok((crypto_stream, crypto_sink)) = + establish_channel(self.role, &mut stream, &mut sink, &self.vault).await + else { + continue; + }; *state.get() = State::Running; - match run_link( + let start = Instant::now(); + + run_link( crypto_stream, crypto_sink, &self.vault, @@ -254,10 +260,10 @@ impl Link { &mut self.pex_tx, &mut self.pex_rx, ) - .await - { - ControlFlow::Continue => continue, - ControlFlow::Break => break, + .await; + + if start.elapsed() > max_backoff { + backoff.reset(); } } } @@ -276,7 +282,6 @@ async fn establish_channel<'a>( } Err(error) => { tracing::warn!(?error, "Failed to establish encrypted channel"); - Err(error) } } @@ -289,7 +294,7 @@ async fn run_link( response_limiter: Arc, pex_tx: &mut PexSender, pex_rx: &mut PexReceiver, -) -> ControlFlow { +) { // Incoming message channels are bounded to prevent malicious peers from sending us too many // messages and exhausting our memory. let (request_tx, request_rx) = mpsc::channel(REQUEST_BUFFER_SIZE); @@ -297,20 +302,30 @@ async fn run_link( // Outgoing message channel is unbounded because we fully control how much stuff goes into it. let (message_tx, message_rx) = mpsc::unbounded_channel(); - tracing::info!("Link opened"); + let _guard = LinkGuard::new(); - // Run everything in parallel: - let flow = select! { - flow = run_client(repo.clone(), message_tx.clone(), response_rx) => flow, - flow = run_server(repo.clone(), message_tx.clone(), request_rx, response_limiter) => flow, - flow = recv_messages(stream, request_tx, response_tx, pex_rx) => flow, - flow = send_messages(message_rx, sink) => flow, - _ = pex_tx.run(message_tx) => ControlFlow::Continue, + select! 
{ + _ = run_client(repo.clone(), message_tx.clone(), response_rx) => (), + _ = run_server(repo.clone(), message_tx.clone(), request_rx, response_limiter) => (), + _ = recv_messages(stream, request_tx, response_tx, pex_rx) => (), + _ = send_messages(message_rx, sink) => (), + _ = pex_tx.run(message_tx) => (), }; +} + +struct LinkGuard; - tracing::info!("Link closed"); +impl LinkGuard { + fn new() -> Self { + tracing::info!("Link opened"); + Self + } +} - flow +impl Drop for LinkGuard { + fn drop(&mut self) { + tracing::info!("Link closed"); + } } // Handle incoming messages @@ -319,25 +334,17 @@ async fn recv_messages( request_tx: mpsc::Sender, response_tx: mpsc::Sender, pex_rx: &PexReceiver, -) -> ControlFlow { +) { loop { let message = match stream.next().await { Some(Ok(message)) => message, - Some(Err(RecvError::Crypto)) => { - tracing::warn!("Failed to decrypt incoming message",); - return ControlFlow::Continue; - } - Some(Err(RecvError::Exhausted)) => { - tracing::debug!("Incoming message nonce counter exhausted",); - return ControlFlow::Continue; - } - Some(Err(RecvError::Io(error))) => { + Some(Err(error)) => { tracing::warn!(?error, "Failed to receive incoming message"); - return ControlFlow::Break; + break; } None => { tracing::debug!("Message channel closed"); - return ControlFlow::Break; + break; } }; @@ -361,7 +368,7 @@ async fn recv_messages( async fn send_messages( mut message_rx: mpsc::UnboundedReceiver, mut sink: EncryptingSink<'_>, -) -> ControlFlow { +) { let mut writer = BytesMut::new().writer(); loop { @@ -372,7 +379,7 @@ async fn send_messages( Ok(()) => (), Err(error) => { tracing::warn!(?error, "Failed to flush outgoing messages"); - return ControlFlow::Break; + break; } } @@ -382,7 +389,7 @@ async fn send_messages( }; let Some(message) = message else { - forever().await + return; }; // unwrap is OK because serialization into a vec should never fail unless we have a bug @@ -391,13 +398,9 @@ async fn send_messages( match sink.feed(writer.get_mut().split().freeze()).await { Ok(()) => (), - Err(SendError::Exhausted) => { - tracing::debug!("Outgoing message nonce counter exhausted"); - return ControlFlow::Continue; - } - Err(SendError::Io(error)) => { + Err(error) => { tracing::warn!(?error, "Failed to send outgoing message"); - return ControlFlow::Break; + break; } } } @@ -408,16 +411,11 @@ async fn run_client( repo: Vault, message_tx: mpsc::UnboundedSender, response_rx: mpsc::Receiver, -) -> ControlFlow { +) { let mut client = Client::new(repo, message_tx, response_rx); let result = client.run().await; tracing::debug!("Client stopped running with result {:?}", result); - - match result { - Ok(()) => forever().await, - Err(_) => ControlFlow::Continue, - } } // Create and run server. Returns only on error. @@ -426,25 +424,10 @@ async fn run_server( message_tx: mpsc::UnboundedSender, request_rx: mpsc::Receiver, response_limiter: Arc, -) -> ControlFlow { +) { let mut server = Server::new(repo, message_tx, request_rx, response_limiter); let result = server.run().await; tracing::debug!("Server stopped running with result {:?}", result); - - match result { - Ok(()) => forever().await, - Err(_) => ControlFlow::Continue, - } -} - -async fn forever() -> ! 
{
-    future::pending::<()>().await;
-    unreachable!()
-}
-
-enum ControlFlow {
-    Continue,
-    Break,
-}
diff --git a/lib/src/network/message_dispatcher.rs b/lib/src/network/message_dispatcher.rs
index 72e710634..3afac93e3 100644
--- a/lib/src/network/message_dispatcher.rs
+++ b/lib/src/network/message_dispatcher.rs
@@ -12,6 +12,7 @@ use tokio_util::codec::{length_delimited, FramedRead, FramedWrite, LengthDelimit
 /// individual streams/sinks based on their topic ids (in the MessageDispatcher's and
 /// MessageBroker's contexts, there is a one-to-one relationship between the topic id and a
 /// repository id).
+#[derive(Clone)]
 pub(super) struct MessageDispatcher {
     bus: net::bus::Bus,
     total_counters: Arc<ByteCounters>,
diff --git a/net/src/bus.rs b/net/src/bus.rs
index 62642019a..5e6bce7e4 100644
--- a/net/src/bus.rs
+++ b/net/src/bus.rs
@@ -22,6 +22,7 @@ use worker::Command;
 /// underlying connection) number of independent streams, each bound to a specific topic. When the
 /// two peers create streams bound to the same topic, they can communicate on them with each
 /// other.
+#[derive(Clone)]
 pub struct Bus {
     command_tx: mpsc::UnboundedSender<Command>,
 }
@@ -162,6 +163,8 @@ mod tests {
     use crate::test_utils::{
         create_connected_connections, create_connected_peers, init_log, Proto,
     };
+    use assert_matches::assert_matches;
+    use futures_util::future;
     use tokio::io::{AsyncReadExt, AsyncWriteExt};
 
     #[tokio::test]
@@ -253,4 +256,67 @@ mod tests {
             _ => panic!("unexpected {:?}", (&buffer_0, &buffer_1)),
         }
     }
+
+    #[tokio::test]
+    async fn recreate_topic_tcp() {
+        recreate_topic_case(Proto::Tcp).await
+    }
+
+    #[tokio::test]
+    async fn recreate_topic_quic() {
+        recreate_topic_case(Proto::Quic).await
+    }
+
+    async fn recreate_topic_case(proto: Proto) {
+        init_log();
+
+        let (client, server) = create_connected_peers(proto);
+        let (client, server) = create_connected_connections(&client, &server).await;
+
+        let client = Bus::new(client);
+        let server = Bus::new(server);
+
+        let topic_id = TopicId::random();
+
+        let (mut client_send_stream, client_recv_stream) = client.create_topic(topic_id);
+        let (_server_send_stream, mut server_recv_stream) = server.create_topic(topic_id);
+
+        future::join(
+            async {
+                client_send_stream.write_all(b"ping 0").await.unwrap();
+            },
+            async {
+                let mut buffer = [0; 6];
+                server_recv_stream.read_exact(&mut buffer).await.unwrap();
+                assert_eq!(&buffer, b"ping 0");
+            },
+        )
+        .await;
+
+        // Close the client streams and create them again
+        drop(client_send_stream);
+        drop(client_recv_stream);
+
+        let (mut client_send_stream, _client_recv_stream) = client.create_topic(topic_id);
+
+        future::join(
+            async {
+                client_send_stream.write_all(b"ping 1").await.unwrap();
+            },
+            async {
+                // Reading from the stream fails because the client stream has been closed.
+                let mut buffer = [0; 6];
+                assert_matches!(
+                    server_recv_stream.read_exact(&mut buffer).await,
+                    Err(error) if error.kind() == io::ErrorKind::UnexpectedEof
+                );
+
+                // Recreate the stream and try again. This time it succeeds.
+ let (_server_send_stream, mut server_recv_stream) = server.create_topic(topic_id); + server_recv_stream.read_exact(&mut buffer).await.unwrap(); + assert_eq!(&buffer, b"ping 1"); + }, + ) + .await; + } } From ea3ec26af049194d0f25ccc00333fff786f884c5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Cig=C3=A1nek?= Date: Thu, 12 Sep 2024 16:03:22 +0200 Subject: [PATCH 23/55] Fix new warnings in rust 1.81 --- cli/tests/utils.rs | 4 ++-- lib/src/network/connection.rs | 29 --------------------------- lib/src/network/crypto.rs | 4 ++-- lib/src/network/message_dispatcher.rs | 4 ++-- lib/src/network/stats.rs | 4 ---- lib/src/versioned.rs | 1 + net/src/tcp.rs | 9 ++------- 7 files changed, 9 insertions(+), 46 deletions(-) diff --git a/cli/tests/utils.rs b/cli/tests/utils.rs index 2a96502b4..b76ac1cea 100644 --- a/cli/tests/utils.rs +++ b/cli/tests/utils.rs @@ -121,7 +121,7 @@ impl Bin { "", self.client_command() .arg("add-peers") - .arg(&format!("tcp/{}:{peer_port}", Ipv4Addr::LOCALHOST)) + .arg(format!("tcp/{}:{peer_port}", Ipv4Addr::LOCALHOST)) .output() .unwrap(), ) @@ -183,7 +183,7 @@ impl Bin { &self .client_command() .arg("bind-rpc") - .arg(&format!("{}:0", Ipv4Addr::LOCALHOST)) + .arg(format!("{}:0", Ipv4Addr::LOCALHOST)) .output() .unwrap() .stdout, diff --git a/lib/src/network/connection.rs b/lib/src/network/connection.rs index a71ea0c03..e718f1afe 100644 --- a/lib/src/network/connection.rs +++ b/lib/src/network/connection.rs @@ -237,35 +237,6 @@ impl ConnectionPermit { .unwrap_or_default() } - /// Dummy connection permit for tests. - #[cfg(test)] - pub fn dummy(dir: ConnectionDirection) -> Self { - use std::net::Ipv4Addr; - - let key = Key { - addr: PeerAddr::Tcp((Ipv4Addr::UNSPECIFIED, 0).into()), - dir, - }; - let id = ConnectionId::next(); - let source = match dir { - ConnectionDirection::Incoming => PeerSource::Listener, - ConnectionDirection::Outgoing => PeerSource::UserProvided, - }; - let data = Data { - id, - state: PeerState::Known, - source, - stats_tracker: StatsTracker::default(), - on_release: DropAwaitable::new(), - }; - - Self { - connections: watch::Sender::new([(key, data)].into()), - key, - id, - } - } - fn with(&self, f: F) -> Option where F: FnOnce(&Data) -> R, diff --git a/lib/src/network/crypto.rs b/lib/src/network/crypto.rs index 0ce94e89c..8d522d075 100644 --- a/lib/src/network/crypto.rs +++ b/lib/src/network/crypto.rs @@ -267,10 +267,10 @@ mod tests { let topic_id = TopicId::random(); let (mut client_sink, mut client_stream) = - client.open(topic_id, Arc::new(ByteCounters::new())); + client.open(topic_id, Arc::new(ByteCounters::default())); let (mut server_sink, mut server_stream) = - server.open(topic_id, Arc::new(ByteCounters::new())); + server.open(topic_id, Arc::new(ByteCounters::default())); let ((mut client_stream, mut client_sink), (mut server_stream, mut server_sink)) = future::try_join( diff --git a/lib/src/network/message_dispatcher.rs b/lib/src/network/message_dispatcher.rs index 3afac93e3..da3cbbae2 100644 --- a/lib/src/network/message_dispatcher.rs +++ b/lib/src/network/message_dispatcher.rs @@ -157,12 +157,12 @@ mod tests { let server_dispatcher = MessageDispatcher::builder(server).build(); let (_server_sink, mut server_stream) = - server_dispatcher.open(topic_id, Arc::new(ByteCounters::new())); + server_dispatcher.open(topic_id, Arc::new(ByteCounters::default())); let client_dispatcher = MessageDispatcher::builder(client).build(); let (mut client_sink, _client_stream) = - client_dispatcher.open(topic_id, Arc::new(ByteCounters::new())); + 
client_dispatcher.open(topic_id, Arc::new(ByteCounters::default())); client_sink .send(Bytes::from_static(send_content)) diff --git a/lib/src/network/stats.rs b/lib/src/network/stats.rs index 71ce5f54d..443f74fa6 100644 --- a/lib/src/network/stats.rs +++ b/lib/src/network/stats.rs @@ -64,10 +64,6 @@ pub(super) struct ByteCounters { } impl ByteCounters { - pub fn new() -> Self { - Self::default() - } - pub fn increment_tx(&self, by: u64) { self.tx.fetch_add(by, Ordering::Relaxed); } diff --git a/lib/src/versioned.rs b/lib/src/versioned.rs index d92c41982..d9f1381eb 100644 --- a/lib/src/versioned.rs +++ b/lib/src/versioned.rs @@ -60,6 +60,7 @@ impl Container for Vec { // but wouldn't compile without `Default` because `Container` requires it. #[allow(dead_code)] #[derive(Default)] +#[expect(dead_code)] // False positive? This is used in `keep_maximal`. pub(crate) struct Discard; impl Container for Discard { diff --git a/net/src/tcp.rs b/net/src/tcp.rs index 08be110e1..ed05dd00e 100644 --- a/net/src/tcp.rs +++ b/net/src/tcp.rs @@ -19,7 +19,7 @@ pub fn configure( bind_addr: SocketAddr, options: SocketOptions, ) -> Result<(Connector, Acceptor), Error> { - let listener = TcpListener::bind_with_options(bind_addr, options)?; + let listener = TcpListener::bind(bind_addr, options)?; let local_addr = listener.local_addr()?; Ok(( @@ -297,7 +297,7 @@ mod implementation { impl TcpListener { /// Configures a TCP socket with the given options and binds it to the given address. If the /// port is taken, uses a random one, - pub fn bind_with_options(addr: SocketAddr, options: SocketOptions) -> io::Result { + pub fn bind(addr: SocketAddr, options: SocketOptions) -> io::Result { let socket = Socket::new(Domain::for_address(addr), Type::STREAM, None)?; socket.set_nonblocking(true)?; @@ -318,11 +318,6 @@ mod implementation { Ok(Self(tokio::net::TcpListener::from_std(socket.into())?)) } - /// Binds TCP socket to the given address. 
If the port is taken, uses a random one, - pub fn bind(addr: SocketAddr) -> io::Result { - Self::bind_with_options(addr, SocketOptions::default()) - } - pub async fn accept(&self) -> io::Result<(TcpStream, SocketAddr)> { self.0 .accept() From bacec66df1cdaa8ae0f4dbdb44ba43955619de89 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Cig=C3=A1nek?= Date: Mon, 16 Sep 2024 10:29:33 +0200 Subject: [PATCH 24/55] Remove PendingDebugRequest,PendingDebugResponse (not needed) --- lib/src/network/client.rs | 12 ++--- lib/src/network/debug_payload.rs | 70 ++++++------------------------ lib/src/network/pending.rs | 14 +++--- lib/src/network/request_tracker.rs | 20 +++++++++ lib/src/network/server.rs | 29 ++++++------- lib/src/versioned.rs | 5 +-- 6 files changed, 60 insertions(+), 90 deletions(-) create mode 100644 lib/src/network/request_tracker.rs diff --git a/lib/src/network/client.rs b/lib/src/network/client.rs index a222b51af..9619072eb 100644 --- a/lib/src/network/client.rs +++ b/lib/src/network/client.rs @@ -1,6 +1,6 @@ use super::{ constants::RESPONSE_BATCH_SIZE, - debug_payload::{DebugResponse, PendingDebugRequest}, + debug_payload::{DebugRequest, DebugResponse}, message::{Message, Response, ResponseDisambiguator}, pending::{ EphemeralResponse, PendingRequest, PendingRequests, PersistableResponse, PreparedResponse, @@ -256,7 +256,7 @@ impl Inner { self.send_request(PendingRequest::ChildNodes( node.hash, ResponseDisambiguator::new(node.summary.block_presence), - debug_payload.clone().follow_up(), + debug_payload.follow_up(), )); } @@ -358,8 +358,7 @@ impl Inner { loop { let block_offer = block_offers.next().await; - let debug = PendingDebugRequest::start(); - self.send_request(PendingRequest::Block(block_offer, debug)); + self.send_request(PendingRequest::Block(block_offer, DebugRequest::start())); } } @@ -392,10 +391,7 @@ impl Inner { // requested as soon as possible. 
fn refresh_branches(&self, branches: impl IntoIterator) { for branch_id in branches { - self.send_request(PendingRequest::RootNode( - branch_id, - PendingDebugRequest::start(), - )); + self.send_request(PendingRequest::RootNode(branch_id, DebugRequest::start())); } } diff --git a/lib/src/network/debug_payload.rs b/lib/src/network/debug_payload.rs index db242e7b7..b172911e1 100644 --- a/lib/src/network/debug_payload.rs +++ b/lib/src/network/debug_payload.rs @@ -11,51 +11,25 @@ mod meaningful_data { static NEXT_ID: AtomicU64 = AtomicU64::new(0); - #[derive(Clone, Eq, PartialEq, Hash, Debug)] - pub(crate) struct PendingDebugRequest { - exchange_id: u64, - } - - impl PendingDebugRequest { - pub(crate) fn start() -> Self { - let exchange_id = NEXT_ID.fetch_add(1, Ordering::Relaxed); - Self { exchange_id } - } - - pub(crate) fn send(&self) -> DebugRequest { - DebugRequest { - exchange_id: self.exchange_id, - } - } - } - #[derive(Clone, Eq, PartialEq, Serialize, Deserialize, Debug)] pub(crate) struct DebugRequest { exchange_id: u64, } impl DebugRequest { - pub(crate) fn begin_reply(self) -> PendingDebugResponse { - PendingDebugResponse { - exchange_id: self.exchange_id, - } + pub(crate) fn start() -> Self { + let exchange_id = NEXT_ID.fetch_add(1, Ordering::Relaxed); + Self { exchange_id } } - } - - #[derive(Clone, Eq, PartialEq, Serialize, Deserialize, Debug)] - pub(crate) struct PendingDebugResponse { - exchange_id: u64, - } - impl PendingDebugResponse { - pub(crate) fn send(self) -> DebugResponse { + pub(crate) fn reply(&self) -> DebugResponse { DebugResponse { exchange_id: self.exchange_id, } } } - #[derive(Clone, Eq, PartialEq, Serialize, Deserialize, Debug)] + #[derive(Eq, PartialEq, Serialize, Deserialize, Debug)] pub(crate) struct DebugResponse { exchange_id: u64, } @@ -66,8 +40,8 @@ mod meaningful_data { Self { exchange_id } } - pub(crate) fn follow_up(self) -> PendingDebugRequest { - PendingDebugRequest { + pub(crate) fn follow_up(&self) -> DebugRequest { + DebugRequest { exchange_id: self.exchange_id, } } @@ -78,38 +52,20 @@ mod meaningful_data { mod dummy_data { use serde::{Deserialize, Serialize}; - #[derive(Clone, Eq, PartialEq, Hash, Debug)] - pub(crate) struct PendingDebugRequest {} - - impl PendingDebugRequest { - pub(crate) fn start() -> Self { - Self {} - } - - pub(crate) fn send(&self) -> DebugRequest { - DebugRequest {} - } - } - #[derive(Clone, Eq, PartialEq, Serialize, Deserialize, Debug)] pub(crate) struct DebugRequest {} impl DebugRequest { - pub(crate) fn begin_reply(self) -> PendingDebugResponse { - PendingDebugResponse {} + pub(crate) fn start() -> Self { + Self {} } - } - - #[derive(Clone, Eq, PartialEq, Serialize, Deserialize, Debug)] - pub(crate) struct PendingDebugResponse {} - impl PendingDebugResponse { - pub(crate) fn send(self) -> DebugResponse { + pub(crate) fn reply(&self) -> DebugResponse { DebugResponse {} } } - #[derive(Clone, Eq, PartialEq, Serialize, Deserialize, Debug)] + #[derive(Eq, PartialEq, Serialize, Deserialize, Debug)] pub(crate) struct DebugResponse {} impl DebugResponse { @@ -117,8 +73,8 @@ mod dummy_data { Self {} } - pub(crate) fn follow_up(self) -> PendingDebugRequest { - PendingDebugRequest {} + pub(crate) fn follow_up(&self) -> DebugRequest { + DebugRequest {} } } } diff --git a/lib/src/network/pending.rs b/lib/src/network/pending.rs index 7d9912d42..ce62c39ec 100644 --- a/lib/src/network/pending.rs +++ b/lib/src/network/pending.rs @@ -1,6 +1,6 @@ use super::{ constants::REQUEST_TIMEOUT, - debug_payload::{DebugResponse, 
PendingDebugRequest}, + debug_payload::{DebugRequest, DebugResponse}, message::{Request, Response, ResponseDisambiguator}, }; use crate::{ @@ -18,9 +18,9 @@ use std::{task::Poll, time::Instant}; use tokio::sync::Notify; pub(crate) enum PendingRequest { - RootNode(PublicKey, PendingDebugRequest), - ChildNodes(Hash, ResponseDisambiguator, PendingDebugRequest), - Block(BlockOffer, PendingDebugRequest), + RootNode(PublicKey, DebugRequest), + ChildNodes(Hash, ResponseDisambiguator, DebugRequest), + Block(BlockOffer, DebugRequest), } /// Response that's been prepared for processing. @@ -124,17 +124,17 @@ impl PendingRequests { PendingRequest::RootNode(writer_id, debug) => self .index .try_insert(IndexKey::RootNode(writer_id)) - .then(|| Request::RootNode(writer_id, debug.send()))?, + .then_some(Request::RootNode(writer_id, debug))?, PendingRequest::ChildNodes(hash, disambiguator, debug) => self .index .try_insert(IndexKey::ChildNodes(hash, disambiguator)) - .then(|| Request::ChildNodes(hash, disambiguator, debug.send()))?, + .then_some(Request::ChildNodes(hash, disambiguator, debug))?, PendingRequest::Block(block_offer, debug) => { let block_promise = block_offer.accept()?; let block_id = *block_promise.block_id(); self.block .try_insert(block_promise) - .then(|| Request::Block(block_id, debug.send()))? + .then_some(Request::Block(block_id, debug))? } }; diff --git a/lib/src/network/request_tracker.rs b/lib/src/network/request_tracker.rs new file mode 100644 index 000000000..a41bf0ca1 --- /dev/null +++ b/lib/src/network/request_tracker.rs @@ -0,0 +1,20 @@ +/* +pub(super) struct RequestTracker {} + +impl RequestTracker { + pub fn new() -> Self { + // ... + } +} + +pub(super) struct RequestTrackerClient { + +} + +pub(super) type TrackedRequests = ReceiverStream< + +} + + +struct State {} +*/ diff --git a/lib/src/network/server.rs b/lib/src/network/server.rs index a3a533762..630f0055f 100644 --- a/lib/src/network/server.rs +++ b/lib/src/network/server.rs @@ -105,8 +105,6 @@ impl Inner { #[instrument(skip(self, debug), err(Debug))] async fn handle_root_node(&self, writer_id: PublicKey, debug: DebugRequest) -> Result<()> { - let debug = debug.begin_reply(); - let root_node = self .vault .store() @@ -122,7 +120,7 @@ impl Inner { let response = Response::RootNode( node.proof.into(), node.summary.block_presence, - debug.send(), + debug.reply(), ); self.enqueue_response(response).await; @@ -130,12 +128,12 @@ impl Inner { } Err(store::Error::BranchNotFound) => { tracing::trace!("root node not found"); - self.enqueue_response(Response::RootNodeError(writer_id, debug.send())) + self.enqueue_response(Response::RootNodeError(writer_id, debug.reply())) .await; Ok(()) } Err(error) => { - self.enqueue_response(Response::RootNodeError(writer_id, debug.send())) + self.enqueue_response(Response::RootNodeError(writer_id, debug.reply())) .await; Err(error.into()) } @@ -149,8 +147,6 @@ impl Inner { disambiguator: ResponseDisambiguator, debug: DebugRequest, ) -> Result<()> { - let debug = debug.begin_reply(); - let mut reader = self.vault.store().acquire_read().await?; // At most one of these will be non-empty. 
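With the `Pending*` wrappers gone, a debug payload is just a value that makes a round trip:
`DebugRequest::start()` mints a fresh exchange id, `reply()` stamps that id onto the
corresponding response, and `follow_up()` carries it into the next request, so one exchange
id threads through an entire root-node, child-nodes, block chain. A rough sketch of the life
cycle using only the methods defined in this patch (the variable names are illustrative):

    // Client: begin a new exchange.
    let request = DebugRequest::start();

    // Server: echo the same exchange id back on the response.
    let response = request.reply();

    // Client: propagate the id into any follow-up request.
    let follow_up = response.follow_up();

The `dummy_data` variant in the same patch keeps an identical method surface, so the flow is
unchanged when the payloads carry no data.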
@@ -165,22 +161,26 @@ impl Inner { self.enqueue_response(Response::InnerNodes( inner_nodes, disambiguator, - debug.clone().send(), + debug.reply(), )) .await; } if !leaf_nodes.is_empty() { tracing::trace!("leaf nodes found"); - self.enqueue_response(Response::LeafNodes(leaf_nodes, disambiguator, debug.send())) - .await; + self.enqueue_response(Response::LeafNodes( + leaf_nodes, + disambiguator, + debug.reply(), + )) + .await; } } else { tracing::trace!("child nodes not found"); self.enqueue_response(Response::ChildNodesError( parent_hash, disambiguator, - debug.send(), + debug.reply(), )) .await; } @@ -190,7 +190,6 @@ impl Inner { #[instrument(skip(self, debug), err(Debug))] async fn handle_block(&self, block_id: BlockId, debug: DebugRequest) -> Result<()> { - let debug = debug.begin_reply(); let mut content = BlockContent::new(); let result = self .vault @@ -203,18 +202,18 @@ impl Inner { match result { Ok(nonce) => { tracing::trace!("block found"); - self.enqueue_response(Response::Block(content, nonce, debug.send())) + self.enqueue_response(Response::Block(content, nonce, debug.reply())) .await; Ok(()) } Err(store::Error::BlockNotFound) => { tracing::trace!("block not found"); - self.enqueue_response(Response::BlockError(block_id, debug.send())) + self.enqueue_response(Response::BlockError(block_id, debug.reply())) .await; Ok(()) } Err(error) => { - self.enqueue_response(Response::BlockError(block_id, debug.send())) + self.enqueue_response(Response::BlockError(block_id, debug.reply())) .await; Err(error.into()) } diff --git a/lib/src/versioned.rs b/lib/src/versioned.rs index d9f1381eb..b9bc65e4b 100644 --- a/lib/src/versioned.rs +++ b/lib/src/versioned.rs @@ -56,11 +56,10 @@ impl Container for Vec { } /// Container for outdated items that discards them. +#[derive(Default)] // TODO: Clippy complains that `Discard` is never constructed in rust 1.81.0 // but wouldn't compile without `Default` because `Container` requires it. -#[allow(dead_code)] -#[derive(Default)] -#[expect(dead_code)] // False positive? This is used in `keep_maximal`. +#[expect(dead_code)] pub(crate) struct Discard; impl Container for Discard { From a9b2a16f45c980c84a76ab02031c8c093f913c54 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Cig=C3=A1nek?= Date: Tue, 17 Sep 2024 14:53:48 +0200 Subject: [PATCH 25/55] Add initial implementation of RequestTracker --- lib/src/network/message.rs | 2 +- lib/src/network/mod.rs | 1 + lib/src/network/request_tracker.rs | 763 ++++++++++++++++++++++++++++- lib/src/protocol/test_utils.rs | 10 +- 4 files changed, 768 insertions(+), 8 deletions(-) diff --git a/lib/src/network/message.rs b/lib/src/network/message.rs index 478cac2b5..061b9bab0 100644 --- a/lib/src/network/message.rs +++ b/lib/src/network/message.rs @@ -11,7 +11,7 @@ use crate::{ }; use serde::{Deserialize, Serialize}; -#[derive(Clone, PartialEq, Serialize, Deserialize, Debug)] +#[derive(Clone, Eq, PartialEq, Serialize, Deserialize, Debug)] pub(crate) enum Request { /// Request the latest root node of the given writer. 
RootNode(PublicKey, DebugRequest),
diff --git a/lib/src/network/mod.rs b/lib/src/network/mod.rs
index af6cb346f..4c0d124fd 100644
--- a/lib/src/network/mod.rs
+++ b/lib/src/network/mod.rs
@@ -18,6 +18,7 @@ mod peer_source;
 mod peer_state;
 mod pending;
 mod protocol;
+mod request_tracker;
 mod runtime_id;
 mod seen_peers;
 mod server;
diff --git a/lib/src/network/request_tracker.rs b/lib/src/network/request_tracker.rs
index a41bf0ca1..4e88b5946 100644
--- a/lib/src/network/request_tracker.rs
+++ b/lib/src/network/request_tracker.rs
@@ -1,20 +1,771 @@
-/*
-pub(super) struct RequestTracker {}
+use super::{constants::REQUEST_TIMEOUT, message::Request};
+use crate::{
+    collections::{HashMap, HashSet},
+    crypto::{sign::PublicKey, Hash},
+    protocol::BlockId,
+};
+use std::{
+    collections::hash_map::Entry,
+    iter,
+    sync::atomic::{AtomicUsize, Ordering},
+};
+use tokio::{
+    select,
+    sync::{mpsc, oneshot},
+    task,
+};
+use tokio_stream::StreamExt;
+use tokio_util::time::{delay_queue, DelayQueue};
+use tracing::instrument;
+
+/// Keeps track of in-flight requests. Falls back on another peer in case the request failed (due to
+/// error response, timeout or disconnection). Evenly distributes the requests between the peers
+/// and ensures every request is only sent to one peer at a time.
+pub(super) struct RequestTracker {
+    command_tx: mpsc::UnboundedSender<Command>,
+}
 
 impl RequestTracker {
+    #[cfg_attr(not(test), expect(dead_code))]
     pub fn new() -> Self {
-        // ...
+        let (command_tx, command_rx) = mpsc::unbounded_channel();
+        let state = State::new();
+
+        task::spawn(run(state, command_rx));
+
+        Self { command_tx }
+    }
+
+    #[cfg_attr(not(test), expect(dead_code))]
+    pub fn new_client(&self) -> (RequestTrackerClient, mpsc::UnboundedReceiver<Request>) {
+        let client_id = ClientId::next();
+        let (request_tx, request_rx) = mpsc::unbounded_channel();
+
+        self.command_tx
+            .send(Command::InsertClient {
+                client_id,
+                request_tx,
+            })
+            .ok();
+
+        (
+            RequestTrackerClient {
+                client_id,
+                command_tx: self.command_tx.clone(),
+            },
+            request_rx,
+        )
+    }
+
+    // Waits until all previously invoked operations complete and returns the number of requests
+    // that are still being tracked. This is useful to ensure that if any of the operations caused
+    // a request to be scheduled for sending, that request has already been queued, thus calling
+    // `try_recv` on the corresponding receiver is guaranteed to return it. This is mostly useful
+    // for testing.
+    #[cfg_attr(not(test), expect(dead_code))]
+    pub async fn flush(&self) -> usize {
+        let (reply_tx, reply_rx) = oneshot::channel();
+        self.command_tx.send(Command::Flush { reply_tx }).ok();
+        reply_rx.await.ok().unwrap_or(0)
     }
 }
 
 pub(super) struct RequestTrackerClient {
+    client_id: ClientId,
+    command_tx: mpsc::UnboundedSender<Command>,
+}
+
+impl RequestTrackerClient {
+    /// Handle sending a request that does not follow from any previously received response.
+    #[expect(dead_code)]
+    pub fn initial(&self, request: Request) {
+        self.command_tx
+            .send(Command::HandleInitial {
+                client_id: self.client_id,
+                request,
+            })
+            .ok();
+    }
+
+    /// Handle sending requests that follow from a received success response.
+    #[cfg_attr(not(test), expect(dead_code))]
+    pub fn success(&self, response_key: MessageKey, requests: Vec<Request>) {
+        self.command_tx
+            .send(Command::HandleSuccess {
+                client_id: self.client_id,
+                response_key,
+                requests,
+            })
+            .ok();
+    }
+
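Taken together, the client-facing flow of the tracker looks roughly like this (a sketch
against the API above; `writer_id` and `follow_up_requests` are illustrative placeholders
and the surrounding wiring is hypothetical glue, not code from this patch):

    let tracker = RequestTracker::new();

    // One tracker client per connected peer; `request_rx` yields the requests
    // this peer has been chosen to actually send.
    let (client, request_rx) = tracker.new_client();

    // A request that does not follow from any prior response.
    client.initial(Request::RootNode(writer_id, DebugRequest::start()));

    // On a success response, report the follow-up requests it unlocks; the
    // tracker spreads them across all clients tracking the same key.
    client.success(MessageKey::RootNode(writer_id), follow_up_requests);

    // On an error response (or a timeout), the tracker falls back to another
    // client that is tracking the same request, if there is one.
    client.failure(MessageKey::RootNode(writer_id));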
+    /// Handle failure response.
+    #[cfg_attr(not(test), expect(dead_code))]
+    pub fn failure(&self, response_key: MessageKey) {
+        self.command_tx
+            .send(Command::HandleFailure {
+                client_id: self.client_id,
+                response_key,
+            })
+            .ok();
+    }
+}
+
+impl Drop for RequestTrackerClient {
+    fn drop(&mut self) {
+        self.command_tx
+            .send(Command::RemoveClient {
+                client_id: self.client_id,
+            })
+            .ok();
+    }
+}
+
+/// Key identifying a request and its corresponding response.
+#[derive(Clone, Copy, Eq, PartialEq, Hash, Debug)]
+pub(super) enum MessageKey {
+    RootNode(PublicKey),
+    ChildNodes(Hash),
+    Block(BlockId),
+}
+
+impl<'a> From<&'a Request> for MessageKey {
+    fn from(request: &'a Request) -> Self {
+        match request {
+            Request::RootNode(writer_id, _) => MessageKey::RootNode(*writer_id),
+            Request::ChildNodes(hash, _, _) => MessageKey::ChildNodes(*hash),
+            Request::Block(block_id, _) => MessageKey::Block(*block_id),
+        }
+    }
+}
+
+async fn run(mut state: State, mut command_rx: mpsc::UnboundedReceiver<Command>) {
+    loop {
+        let command = select! {
+            command = command_rx.recv() => command,
+            true = state.wait_for_timeout() => continue,
+        };
+
+        match command {
+            Some(Command::InsertClient {
+                client_id,
+                request_tx,
+            }) => {
+                state.insert_client(client_id, request_tx);
+            }
+            Some(Command::RemoveClient { client_id }) => {
+                state.remove_client(client_id);
+            }
+            Some(Command::HandleInitial { client_id, request }) => {
+                state.handle_initial(client_id, request);
+            }
+            Some(Command::HandleSuccess {
+                client_id,
+                response_key,
+                requests,
+            }) => {
+                state.handle_success(client_id, response_key, requests);
+            }
+            Some(Command::HandleFailure {
+                client_id,
+                response_key,
+            }) => {
+                state.handle_failure(client_id, response_key);
+            }
+            Some(Command::Flush { reply_tx }) => {
+                reply_tx.send(state.request_count()).ok();
+            }
+            None => break,
+        }
+    }
+}
+
+#[derive(Clone, Copy, Eq, PartialEq, Hash, Debug)]
+struct ClientId(usize);
+
+impl ClientId {
+    fn next() -> Self {
+        static NEXT: AtomicUsize = AtomicUsize::new(0);
+        Self(NEXT.fetch_add(1, Ordering::Relaxed))
+    }
+}
+
+enum Command {
+    InsertClient {
+        client_id: ClientId,
+        request_tx: mpsc::UnboundedSender<Request>,
+    },
+    RemoveClient {
+        client_id: ClientId,
+    },
+    HandleInitial {
+        client_id: ClientId,
+        request: Request,
+    },
+    HandleSuccess {
+        client_id: ClientId,
+        response_key: MessageKey,
+        requests: Vec<Request>,
+    },
+    HandleFailure {
+        client_id: ClientId,
+        response_key: MessageKey,
+    },
+    Flush {
+        reply_tx: oneshot::Sender<usize>,
+    },
+}
+
+struct State {
+    clients: HashMap<ClientId, ClientState>,
+    requests: HashMap<MessageKey, RequestState>,
+    timeouts: DelayQueue<(ClientId, MessageKey)>,
+}
+
+impl State {
+    fn new() -> Self {
+        Self {
+            clients: HashMap::default(),
+            requests: HashMap::default(),
+            timeouts: DelayQueue::new(),
+        }
+    }
+
+    #[instrument(skip(self, request_tx))]
+    fn insert_client(&mut self, client_id: ClientId, request_tx: mpsc::UnboundedSender<Request>) {
+        self.clients.insert(
+            client_id,
+            ClientState {
+                request_tx,
+                requests: HashSet::default(),
+            },
+        );
+    }
+
+    #[instrument(skip(self))]
+    fn remove_client(&mut self, client_id: ClientId) {
+        let Some(client_state) = self.clients.remove(&client_id) else {
+            return;
+        };
+
+        for message_key in client_state.requests {
+            let Entry::Occupied(mut entry) = self.requests.entry(message_key) else {
+                continue;
+            };
+
+            let Some(client_request_state) = entry.get_mut().clients.remove(&client_id) else {
+                continue;
+            };
+
+            if let Some(timeout_key) = client_request_state.timeout_key {
+                self.timeouts.try_remove(&timeout_key);
+            }
+
+            if
entry.get().clients.is_empty() {
+                entry.remove();
+                continue;
+            }
+
+            // TODO: fall back to another client, if any
+        }
+    }
+
+    #[instrument(skip(self))]
+    fn handle_initial(&mut self, client_id: ClientId, request: Request) {
+        let Some(client_state) = self.clients.get_mut(&client_id) else {
+            // client not inserted
+            return;
+        };
+
+        let request_key = MessageKey::from(&request);
+
+        client_state.requests.insert(request_key);
+
+        let request_state = self
+            .requests
+            .entry(request_key)
+            .or_insert_with(|| RequestState {
+                request,
+                clients: HashMap::default(),
+            });
+
+        let timeout_key = if request_state
+            .clients
+            .values()
+            .all(|state| state.timeout_key.is_none())
+        {
+            client_state
+                .request_tx
+                .send(request_state.request.clone())
+                .ok();
+
+            Some(
+                self.timeouts
+                    .insert((client_id, request_key), REQUEST_TIMEOUT),
+            )
+        } else {
+            None
+        };
+
+        request_state
+            .clients
+            .insert(client_id, RequestClientState { timeout_key });
+    }
+
+    #[instrument(skip(self))]
+    fn handle_success(
+        &mut self,
+        client_id: ClientId,
+        response_key: MessageKey,
+        requests: Vec<Request>,
+    ) {
+        if let Some(state) = self.clients.get_mut(&client_id) {
+            state.requests.remove(&response_key);
+        }
+
+        let mut followup_client_ids = if requests.is_empty() {
+            vec![]
+        } else {
+            vec![client_id]
+        };
+
+        if let Entry::Occupied(mut entry) = self.requests.entry(response_key) {
+            entry.get_mut().clients.retain(|other_client_id, state| {
+                // TODO: remove only those with the same or worse block presence.
+
+                if let Some(timeout_key) = state.timeout_key {
+                    self.timeouts.try_remove(&timeout_key);
+                }
+
+                if !requests.is_empty() && *other_client_id != client_id {
+                    followup_client_ids.push(*other_client_id);
+                }
+
+                false
+            });
+
+            if entry.get().clients.is_empty() {
+                entry.remove();
+            }
+        }
+
+        for request in requests {
+            for (client_id, request) in followup_client_ids
+                .iter()
+                .copied()
+                // TODO: use `repeat_n` once it gets stabilized.
+                .zip(iter::repeat(request))
+            {
+                self.handle_initial(client_id, request);
+            }
+
+            // round-robin the requests across the clients
+            followup_client_ids.rotate_right(1);
+        }
+    }
+
+    #[instrument(skip(self))]
+    fn handle_failure(&mut self, client_id: ClientId, response_key: MessageKey) {
+        if let Some(state) = self.clients.get_mut(&client_id) {
+            state.requests.remove(&response_key);
+        }
+
+        let Entry::Occupied(mut entry) = self.requests.entry(response_key) else {
+            return;
+        };
+
+        if let Some(state) = entry.get_mut().clients.remove(&client_id) {
+            if let Some(timeout_key) = state.timeout_key {
+                self.timeouts.try_remove(&timeout_key);
+            }
+        }
+
+        // TODO: prefer one with the same or better block presence as `client_id`.
+        if let Some((fallback_client_id, state)) = entry
+            .get_mut()
+            .clients
+            .iter_mut()
+            .find(|(_, state)| state.timeout_key.is_none())
+        {
+            state.timeout_key = Some(
+                self.timeouts
+                    .insert((*fallback_client_id, response_key), REQUEST_TIMEOUT),
+            );
+
+            // TODO: send the request
+        }
+
+        if entry.get().clients.is_empty() {
+            entry.remove();
+        }
+    }
+
+    /// Wait for the next timeout. Returns `true` if a request timed out, `false` if there are no
+    /// tracked requests.
+ async fn wait_for_timeout(&mut self) -> bool { + if let Some(expired) = self.timeouts.next().await { + let (client_id, request_key) = expired.into_inner(); + self.handle_failure(client_id, request_key); + + true + } else { + false + } + } + + fn request_count(&self) -> usize { + self.requests.len() + } +} + +struct ClientState { + request_tx: mpsc::UnboundedSender, + requests: HashSet, +} + +struct RequestState { + request: Request, + clients: HashMap, +} + +struct RequestClientState { + // disambiguator: ResponseDisambiguator, + timeout_key: Option, +} + +#[cfg(test)] +mod tests { + use super::{ + super::{debug_payload::DebugResponse, message::Response}, + *, + }; + use crate::{ + crypto::{sign::Keypair, Hashable}, + network::message::ResponseDisambiguator, + protocol::{test_utils::Snapshot, Block, MultiBlockPresence, Proof, UntrustedProof}, + test_utils, + version_vector::VersionVector, + }; + use rand::{rngs::StdRng, seq::SliceRandom, Rng, SeedableRng}; + use std::collections::VecDeque; + use test_strategy::proptest; + + #[proptest] + fn sanity_check( + #[strategy(test_utils::rng_seed_strategy())] seed: u64, + #[strategy(1usize..=32)] num_blocks: usize, + #[strategy(1usize..=3)] num_peers: usize, + ) { + test_utils::run(sanity_check_case(seed, num_blocks, num_peers)); + } + + async fn sanity_check_case(seed: u64, num_blocks: usize, num_peers: usize) { + test_utils::init_log(); + + let mut rng = StdRng::seed_from_u64(seed); + + let snapshot = Snapshot::generate(&mut rng, num_blocks); + let mut summary = Summary::default(); + + let tracker = RequestTracker::new(); + + let mut peers: Vec<_> = (0..num_peers) + .map(|_| { + let (tracker_client, tracker_request_rx) = tracker.new_client(); + let client = TestClient::new(tracker_client, tracker_request_rx); + + let writer_id = PublicKey::generate(&mut rng); + let write_keys = Keypair::generate(&mut rng); + let server = TestServer::new(writer_id, write_keys, &snapshot); + + (client, server) + }) + .collect(); + + while poll(&mut rng, &mut peers, &snapshot, &mut summary) { + tracker.flush().await; + } + + summary.verify(peers.len(), &snapshot); + + assert_eq!(tracker.flush().await, 0); + } + + #[derive(Default)] + struct Summary { + nodes: HashMap, + blocks: HashMap, + } + + impl Summary { + fn receive_node(&mut self, hash: Hash) { + *self.nodes.entry(hash).or_default() += 1; + } + + fn receive_block(&mut self, block_id: BlockId) { + *self.blocks.entry(block_id).or_default() += 1; + } + + fn verify(&mut self, num_peers: usize, snapshot: &Snapshot) { + assert_eq!( + self.nodes.remove(snapshot.root_hash()).unwrap_or(0), + num_peers, + "root node not received exactly {num_peers} times: {:?}", + snapshot.root_hash() + ); + + for hash in snapshot + .inner_nodes() + .map(|node| &node.hash) + .chain(snapshot.leaf_nodes().map(|node| &node.locator)) + { + assert_eq!( + self.nodes.remove(hash).unwrap_or(0), + 1, + "child node not received exactly once: {hash:?}" + ); + } + + for block_id in snapshot.blocks().keys() { + assert_eq!( + self.blocks.remove(block_id).unwrap_or(0), + 1, + "block not received exactly once: {block_id:?}" + ); + } + + // Verify we received only the expected nodes and blocks + assert!( + self.nodes.is_empty(), + "unexpected nodes received: {:?}", + self.nodes + ); + assert!( + self.blocks.is_empty(), + "unexpected blocks received: {:?}", + self.blocks + ); + } + } + + struct TestClient { + tracker_client: RequestTrackerClient, + tracker_request_rx: mpsc::UnboundedReceiver, + } + + impl TestClient { + fn new( + tracker_client: 
RequestTrackerClient, + tracker_request_rx: mpsc::UnboundedReceiver, + ) -> Self { + Self { + tracker_client, + tracker_request_rx, + } + } + + fn handle_response(&mut self, response: Response, summary: &mut Summary) { + match response { + Response::RootNode(proof, block_presence, debug_payload) => { + summary.receive_node(proof.hash); + + let requests = vec![Request::ChildNodes( + proof.hash, + ResponseDisambiguator::new(block_presence), + debug_payload.follow_up(), + )]; + + self.tracker_client + .success(MessageKey::RootNode(proof.writer_id), requests); + } + Response::InnerNodes(nodes, _disambiguator, debug_payload) => { + let parent_hash = nodes.hash(); + let requests: Vec<_> = nodes + .into_iter() + .map(|(_, node)| { + summary.receive_node(node.hash); + + Request::ChildNodes( + node.hash, + ResponseDisambiguator::new(node.summary.block_presence), + debug_payload.follow_up(), + ) + }) + .collect(); + + self.tracker_client + .success(MessageKey::ChildNodes(parent_hash), requests); + } + Response::LeafNodes(nodes, _disambiguator, debug_payload) => { + let parent_hash = nodes.hash(); + let requests = nodes + .into_iter() + .map(|node| { + summary.receive_node(node.locator); + + Request::Block(node.block_id, debug_payload.follow_up()) + }) + .collect(); + + self.tracker_client + .success(MessageKey::ChildNodes(parent_hash), requests); + } + Response::Block(content, nonce, _debug_payload) => { + let block = Block::new(content, nonce); + + summary.receive_block(block.id); + + self.tracker_client + .success(MessageKey::Block(block.id), vec![]); + } + Response::RootNodeError(writer_id, _debug_payload) => { + self.tracker_client.failure(MessageKey::RootNode(writer_id)); + } + Response::ChildNodesError(hash, _disambiguator, _debug_payload) => { + self.tracker_client.failure(MessageKey::ChildNodes(hash)); + } + Response::BlockError(block_id, _debug_payload) => { + self.tracker_client.failure(MessageKey::Block(block_id)); + } + Response::BlockOffer(_block_id, _debug_payload) => unimplemented!(), + }; + } + + fn poll_request(&mut self) -> Option { + self.tracker_request_rx.try_recv().ok() + } + } + + struct TestServer { + writer_id: PublicKey, + write_keys: Keypair, + outbox: VecDeque, + } + + impl TestServer { + fn new(writer_id: PublicKey, write_keys: Keypair, snapshot: &Snapshot) -> Self { + let proof = UntrustedProof::from(Proof::new( + writer_id, + VersionVector::first(writer_id), + *snapshot.root_hash(), + &write_keys, + )); + + let outbox = [Response::RootNode( + proof.clone(), + MultiBlockPresence::Full, + DebugResponse::unsolicited(), + )] + .into(); + + Self { + writer_id, + write_keys, + outbox, + } + } + + fn handle_request(&mut self, request: Request, snapshot: &Snapshot) { + match request { + Request::RootNode(writer_id, debug_payload) => { + if writer_id == self.writer_id { + let proof = Proof::new( + writer_id, + VersionVector::first(writer_id), + *snapshot.root_hash(), + &self.write_keys, + ); + + self.outbox.push_back(Response::RootNode( + proof.into(), + MultiBlockPresence::Full, + debug_payload.reply(), + )); + } else { + self.outbox + .push_back(Response::RootNodeError(writer_id, debug_payload.reply())); + } + } + Request::ChildNodes(hash, disambiguator, debug_payload) => { + if let Some(nodes) = snapshot + .inner_layers() + .flat_map(|layer| layer.inner_maps()) + .find_map(|(parent_hash, nodes)| (*parent_hash == hash).then_some(nodes)) + { + self.outbox.push_back(Response::InnerNodes( + nodes.clone(), + disambiguator, + debug_payload.reply(), + )); + } + + if let Some(nodes) 
= snapshot + .leaf_sets() + .find_map(|(parent_hash, nodes)| (*parent_hash == hash).then_some(nodes)) + { + self.outbox.push_back(Response::LeafNodes( + nodes.clone(), + disambiguator, + debug_payload.reply(), + )); + } + } + Request::Block(block_id, debug_payload) => { + if let Some(block) = snapshot.blocks().get(&block_id) { + self.outbox.push_back(Response::Block( + block.content.clone(), + block.nonce, + debug_payload.reply(), + )); + } + } + } + } + + fn poll_response(&mut self) -> Option { + self.outbox.pop_front() + } + } + + // Polls every client and server once, in random order + fn poll( + rng: &mut R, + peers: &mut [(TestClient, TestServer)], + snapshot: &Snapshot, + summary: &mut Summary, + ) -> bool { + enum Side { + Client, + Server, + } + + let mut order: Vec<_> = (0..peers.len()) + .flat_map(|index| [(Side::Client, index), (Side::Server, index)]) + .collect(); + + order.shuffle(rng); + + let mut changed = false; + + for (side, index) in order { + let (client, server) = &mut peers[index]; + + match side { + Side::Client => { + if let Some(request) = client.poll_request() { + server.handle_request(request, snapshot); + changed = true; + } + } + Side::Server => { + if let Some(response) = server.poll_response() { + client.handle_response(response, summary); + changed = true; + } + } + } + } + + changed + } +} diff --git a/lib/src/protocol/test_utils.rs b/lib/src/protocol/test_utils.rs index 03a7deefe..63c7dcf2c 100644 --- a/lib/src/protocol/test_utils.rs +++ b/lib/src/protocol/test_utils.rs @@ -103,6 +103,14 @@ impl Snapshot { (0..self.inners.len()).map(move |inner_layer| InnerLayer(self, inner_layer)) } + pub fn inner_nodes(&self) -> impl Iterator { + self.inners + .iter() + .flat_map(|layer| layer.values()) + .flatten() + .map(|(_, node)| node) + } + pub fn blocks(&self) -> &HashMap { &self.blocks } @@ -126,7 +134,7 @@ impl Snapshot { pub(crate) struct InnerLayer<'a>(&'a Snapshot, usize); impl<'a> InnerLayer<'a> { - pub fn inner_maps(&self) -> impl Iterator { + pub fn inner_maps(self) -> impl Iterator { self.0.inners[self.1].iter().map(move |(path, nodes)| { let parent_hash = self.0.parent_hash(self.1, path); (parent_hash, nodes) From 5b3538c5ba7d08ccb523ce52f0cbeeaf269e3e29 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Cig=C3=A1nek?= Date: Tue, 17 Sep 2024 16:00:15 +0200 Subject: [PATCH 26/55] Refactor RequestTracker for better testability --- lib/src/network/request_tracker.rs | 237 ++++++++++++++--------------- 1 file changed, 117 insertions(+), 120 deletions(-) diff --git a/lib/src/network/request_tracker.rs b/lib/src/network/request_tracker.rs index 4e88b5946..c7a180a6f 100644 --- a/lib/src/network/request_tracker.rs +++ b/lib/src/network/request_tracker.rs @@ -9,11 +9,7 @@ use std::{ iter, sync::atomic::{AtomicUsize, Ordering}, }; -use tokio::{ - select, - sync::{mpsc, oneshot}, - task, -}; +use tokio::{select, sync::mpsc, task}; use tokio_stream::StreamExt; use tokio_util::time::{delay_queue, DelayQueue}; use tracing::instrument; @@ -26,14 +22,11 @@ pub(super) struct RequestTracker { } impl RequestTracker { - #[cfg_attr(not(test), expect(dead_code))] + #[expect(dead_code)] pub fn new() -> Self { - let (command_tx, command_rx) = mpsc::unbounded_channel(); - let state = State::new(); - - task::spawn(run(state, command_rx)); - - Self { command_tx } + let (this, worker) = build(); + task::spawn(worker.run()); + this } #[cfg_attr(not(test), expect(dead_code))] @@ -56,18 +49,6 @@ impl RequestTracker { request_rx, ) } - - // Wait until all previously invoked operations 
complete and returns the number of request that
-    // are still being tracked. This is useful to ensure that if any of the operations caused a
-    // request to be scheduled for sending, that request has already been queued, thus calling
-    // `try_recv` on the corresponding receiver is guaranteed to return it. This is mostly useful
-    // for testing.
-    #[cfg_attr(not(test), expect(dead_code))]
-    pub async fn flush(&self) -> usize {
-        let (reply_tx, reply_rx) = oneshot::channel();
-        self.command_tx.send(Command::Flush { reply_tx }).ok();
-        reply_rx.await.ok().unwrap_or(0)
-    }
 }
 
 pub(super) struct RequestTrackerClient {
@@ -139,99 +120,92 @@ impl<'a> From<&'a Request> for MessageKey {
     }
 }
 
-async fn run(mut state: State, mut command_rx: mpsc::UnboundedReceiver<Command>) {
-    loop {
-        let command = select! {
-            command = command_rx.recv() => command,
-            true = state.wait_for_timeout() => continue,
-        };
+fn build() -> (RequestTracker, Worker) {
+    let (command_tx, command_rx) = mpsc::unbounded_channel();
+    (RequestTracker { command_tx }, Worker::new(command_rx))
+}
+
+struct Worker {
+    clients: HashMap<ClientId, ClientState>,
+    requests: HashMap<MessageKey, RequestState>,
+    timeouts: DelayQueue<(ClientId, MessageKey)>,
+    command_rx: mpsc::UnboundedReceiver<Command>,
+}
+
+impl Worker {
+    fn new(command_rx: mpsc::UnboundedReceiver<Command>) -> Self {
+        Self {
+            clients: HashMap::default(),
+            requests: HashMap::default(),
+            timeouts: DelayQueue::new(),
+            command_rx,
+        }
+    }
+
+    pub async fn run(mut self) {
+        loop {
+            select! {
+                command = self.command_rx.recv() => {
+                    if let Some(command) = command {
+                        self.handle_command(command);
+                    } else {
+                        break;
+                    }
+                }
+                Some(expired) = self.timeouts.next() => {
+                    let (client_id, request_key) = expired.into_inner();
+                    self.handle_failure(client_id, request_key);
+                    continue;
+                }
+            }
+        }
+    }
+
+    /// Process all currently queued commands.
+    #[cfg(test)]
+    pub fn step(&mut self) {
+        while let Ok(command) = self.command_rx.try_recv() {
+            self.handle_command(command);
+        }
+
+        // TODO: Check timeouts
+    }
+
+    #[cfg(test)]
+    pub fn request_count(&self) -> usize {
+        self.requests.len()
+    }
+
+    fn handle_command(&mut self, command: Command) {
         match command {
-            Some(Command::InsertClient {
+            Command::InsertClient {
                 client_id,
                 request_tx,
-            }) => {
-                state.insert_client(client_id, request_tx);
+            } => {
+                self.insert_client(client_id, request_tx);
             }
-            Some(Command::RemoveClient { client_id }) => {
-                state.remove_client(client_id);
+            Command::RemoveClient { client_id } => {
+                self.remove_client(client_id);
             }
-            Some(Command::HandleInitial { client_id, request }) => {
-                state.handle_initial(client_id, request);
+            Command::HandleInitial { client_id, request } => {
+                self.handle_initial(client_id, request);
             }
-            Some(Command::HandleSuccess {
+            Command::HandleSuccess {
                 client_id,
                 response_key,
                 requests,
-            }) => {
-                state.handle_success(client_id, response_key, requests);
+            } => {
+                self.handle_success(client_id, response_key, requests);
            }
-            Some(Command::HandleFailure {
+            Command::HandleFailure {
                 client_id,
                 response_key,
-            }) => {
-                state.handle_failure(client_id, response_key);
+            } => {
+                self.handle_failure(client_id, response_key);
            }
-            Some(Command::Flush { reply_tx }) => {
-                reply_tx.send(state.request_count()).ok();
-            }
-            None => break,
         }
     }
-}
-
-#[derive(Clone, Copy, Eq, PartialEq, Hash, Debug)]
-struct ClientId(usize);
-
-impl ClientId {
-    fn next() -> Self {
-        static NEXT: AtomicUsize = AtomicUsize::new(0);
-        Self(NEXT.fetch_add(1, Ordering::Relaxed))
-    }
-}
-
-enum Command {
-    InsertClient {
-        client_id: ClientId,
-        request_tx: mpsc::UnboundedSender<Request>,
-    },
-    RemoveClient {
-        client_id: ClientId,
-    },
-    HandleInitial {
-        client_id: ClientId,
-        request: Request,
-    },
-    HandleSuccess {
-        client_id: ClientId,
-        response_key: MessageKey,
-        requests: Vec<Request>,
-    },
-    HandleFailure {
-        client_id: ClientId,
-        response_key: MessageKey,
-    },
-    Flush {
-        reply_tx: oneshot::Sender<usize>,
-    },
-}
-
-struct State {
-    clients: HashMap<ClientId, ClientState>,
-    requests: HashMap<MessageKey, RequestState>,
-    timeouts: DelayQueue<(ClientId, MessageKey)>,
-}
-
-impl State {
-    fn new() -> Self {
-        Self {
-            clients: HashMap::default(),
-            requests: HashMap::default(),
-            timeouts: DelayQueue::new(),
-        }
-    }
-
-    #[instrument(skip(self, request_tx))]
     fn insert_client(&mut self, client_id: ClientId, request_tx: mpsc::UnboundedSender<Request>) {
         self.clients.insert(
@@ -399,23 +373,39 @@ impl State {
             entry.remove();
         }
     }
+}
 
-    /// Wait for the next timeout. Returns `true` if a request timed out, `false` if there are no
-    /// tracked requests.
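To illustrate the testing seam this refactor introduces, here is a minimal sketch (not part of the patch) of how a test can drive the worker without spawning it, assuming some `request: Request` value is in scope:

    let (tracker, mut worker) = build();
    let (client, mut request_rx) = tracker.new_client();

    client.initial(request);

    // `step` processes the queued commands synchronously, so by the time it
    // returns, any send the commands triggered has already happened and
    // `try_recv` on the request receiver cannot miss it.
    worker.step();
    assert!(request_rx.try_recv().is_ok());
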
- async fn wait_for_timeout(&mut self) -> bool { - if let Some(expired) = self.timeouts.next().await { - let (client_id, request_key) = expired.into_inner(); - self.handle_failure(client_id, request_key); +#[derive(Clone, Copy, Eq, PartialEq, Hash, Debug)] +struct ClientId(usize); - true - } else { - false - } +impl ClientId { + fn next() -> Self { + static NEXT: AtomicUsize = AtomicUsize::new(0); + Self(NEXT.fetch_add(1, Ordering::Relaxed)) } +} - fn request_count(&self) -> usize { - self.requests.len() - } +enum Command { + InsertClient { + client_id: ClientId, + request_tx: mpsc::UnboundedSender, + }, + RemoveClient { + client_id: ClientId, + }, + HandleInitial { + client_id: ClientId, + request: Request, + }, + HandleSuccess { + client_id: ClientId, + response_key: MessageKey, + requests: Vec, + }, + HandleFailure { + client_id: ClientId, + response_key: MessageKey, + }, } struct ClientState { @@ -456,18 +446,25 @@ mod tests { #[strategy(1usize..=32)] num_blocks: usize, #[strategy(1usize..=3)] num_peers: usize, ) { - test_utils::run(sanity_check_case(seed, num_blocks, num_peers)); + sanity_check_case(seed, num_blocks, num_peers); } - async fn sanity_check_case(seed: u64, num_blocks: usize, num_peers: usize) { + fn sanity_check_case(seed: u64, num_blocks: usize, num_peers: usize) { test_utils::init_log(); + // Tokio runtime needed for `DelayQueue`. + let _runtime_guard = tokio::runtime::Builder::new_current_thread() + .enable_time() + .build() + .unwrap() + .enter(); + let mut rng = StdRng::seed_from_u64(seed); let snapshot = Snapshot::generate(&mut rng, num_blocks); let mut summary = Summary::default(); - let tracker = RequestTracker::new(); + let (tracker, mut tracker_worker) = build(); let mut peers: Vec<_> = (0..num_peers) .map(|_| { @@ -482,13 +479,13 @@ mod tests { }) .collect(); - while poll(&mut rng, &mut peers, &snapshot, &mut summary) { - tracker.flush().await; + while poll_peers(&mut rng, &mut peers, &snapshot, &mut summary) { + tracker_worker.step(); } summary.verify(peers.len(), &snapshot); - assert_eq!(tracker.flush().await, 0); + assert_eq!(tracker_worker.request_count(), 0); } #[derive(Default)] @@ -728,7 +725,7 @@ mod tests { } // Polls every client and server once, in random order - fn poll( + fn poll_peers( rng: &mut R, peers: &mut [(TestClient, TestServer)], snapshot: &Snapshot, From 369b7ba5132c58865ac83c12ad66515608f5c83b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Cig=C3=A1nek?= Date: Wed, 18 Sep 2024 08:13:06 +0200 Subject: [PATCH 27/55] Proptests now supports async directly --- lib/src/blob/tests.rs | 138 +++++++++++++---------------- lib/src/network/request_tracker.rs | 12 +-- lib/src/network/tests.rs | 18 ++-- lib/src/repository/vault/tests.rs | 6 +- lib/src/store/index.rs | 12 +-- lib/src/store/root_node.rs | 6 +- lib/src/store/tests.rs | 12 +-- lib/src/test_utils.rs | 12 --- 8 files changed, 89 insertions(+), 127 deletions(-) diff --git a/lib/src/blob/tests.rs b/lib/src/blob/tests.rs index 935ba5ba1..26ad1bb6c 100644 --- a/lib/src/blob/tests.rs +++ b/lib/src/blob/tests.rs @@ -40,17 +40,15 @@ async fn empty_blob() { store.close().await.unwrap(); } -#[proptest] -fn write_and_read( +#[proptest(async = "tokio")] +async fn write_and_read( is_root: bool, #[strategy(1..3 * BLOCK_SIZE)] blob_len: usize, #[strategy(1..=#blob_len)] write_len: usize, #[strategy(1..=#blob_len + 1)] read_len: usize, #[strategy(test_utils::rng_seed_strategy())] rng_seed: u64, ) { - test_utils::run(write_and_read_case( - is_root, blob_len, write_len, read_len, rng_seed, - )) + 
write_and_read_case(is_root, blob_len, write_len, read_len, rng_seed).await } async fn write_and_read_case( @@ -106,66 +104,60 @@ async fn write_and_read_case( store.close().await.unwrap(); } -#[proptest] -fn len( +#[proptest(async = "tokio")] +async fn len( #[strategy(0..3 * BLOCK_SIZE)] content_len: usize, #[strategy(test_utils::rng_seed_strategy())] rng_seed: u64, ) { - test_utils::run(async { - let (rng, _base_dir, store, [branch]) = setup(rng_seed).await; - let mut tx = store.begin_write().await.unwrap(); - let mut changeset = Changeset::new(); + let (rng, _base_dir, store, [branch]) = setup(rng_seed).await; + let mut tx = store.begin_write().await.unwrap(); + let mut changeset = Changeset::new(); - let content = random_bytes(rng, content_len); + let content = random_bytes(rng, content_len); - let mut blob = Blob::create(branch.clone(), BlobId::ROOT); - blob.write_all(&mut tx, &mut changeset, &content[..]) - .await - .unwrap(); - assert_eq!(blob.len(), content_len as u64); + let mut blob = Blob::create(branch.clone(), BlobId::ROOT); + blob.write_all(&mut tx, &mut changeset, &content[..]) + .await + .unwrap(); + assert_eq!(blob.len(), content_len as u64); - blob.flush(&mut tx, &mut changeset).await.unwrap(); - changeset - .apply(&mut tx, branch.id(), branch.keys().write().unwrap()) - .await - .unwrap(); + blob.flush(&mut tx, &mut changeset).await.unwrap(); + changeset + .apply(&mut tx, branch.id(), branch.keys().write().unwrap()) + .await + .unwrap(); - assert_eq!(blob.len(), content_len as u64); + assert_eq!(blob.len(), content_len as u64); - let blob = Blob::open(&mut tx, branch, BlobId::ROOT).await.unwrap(); - assert_eq!(blob.len(), content_len as u64); + let blob = Blob::open(&mut tx, branch, BlobId::ROOT).await.unwrap(); + assert_eq!(blob.len(), content_len as u64); - drop(tx); - store.close().await.unwrap(); - }) + drop(tx); + store.close().await.unwrap(); } -#[proptest] -fn seek_from_start( +#[proptest(async = "tokio")] +async fn seek_from_start( #[strategy(0..2 * BLOCK_SIZE)] content_len: usize, #[strategy(0..=#content_len)] pos: usize, #[strategy(test_utils::rng_seed_strategy())] rng_seed: u64, ) { - test_utils::run(seek_from( - content_len, - SeekFrom::Start(pos as u64), - pos, - rng_seed, - )) + seek_from(content_len, SeekFrom::Start(pos as u64), pos, rng_seed).await } -#[proptest] -fn seek_from_end( +#[proptest(async = "tokio")] +async fn seek_from_end( #[strategy(0..2 * BLOCK_SIZE)] content_len: usize, #[strategy(0..=#content_len)] pos: usize, #[strategy(test_utils::rng_seed_strategy())] rng_seed: u64, ) { - test_utils::run(seek_from( + seek_from( content_len, SeekFrom::End(-((content_len - pos) as i64)), pos, rng_seed, - )) + ) + .await } async fn seek_from(content_len: usize, seek_from: SeekFrom, expected_pos: usize, rng_seed: u64) { @@ -196,44 +188,42 @@ async fn seek_from(content_len: usize, seek_from: SeekFrom, expected_pos: usize, store.close().await.unwrap(); } -#[proptest] -fn seek_from_current( +#[proptest(async = "tokio")] +async fn seek_from_current( #[strategy(1..2 * BLOCK_SIZE)] content_len: usize, #[strategy(vec(0..#content_len, 1..10))] positions: Vec, #[strategy(test_utils::rng_seed_strategy())] rng_seed: u64, ) { - test_utils::run(async { - let (rng, _base_dir, store, [branch]) = setup(rng_seed).await; - let mut tx = store.begin_write().await.unwrap(); - let mut changeset = Changeset::new(); + let (rng, _base_dir, store, [branch]) = setup(rng_seed).await; + let mut tx = store.begin_write().await.unwrap(); + let mut changeset = Changeset::new(); - let 
content = random_bytes(rng, content_len); + let content = random_bytes(rng, content_len); - let mut blob = Blob::create(branch.clone(), BlobId::ROOT); - blob.write_all(&mut tx, &mut changeset, &content[..]) - .await - .unwrap(); - blob.flush(&mut tx, &mut changeset).await.unwrap(); - changeset - .apply(&mut tx, branch.id(), branch.keys().write().unwrap()) - .await - .unwrap(); + let mut blob = Blob::create(branch.clone(), BlobId::ROOT); + blob.write_all(&mut tx, &mut changeset, &content[..]) + .await + .unwrap(); + blob.flush(&mut tx, &mut changeset).await.unwrap(); + changeset + .apply(&mut tx, branch.id(), branch.keys().write().unwrap()) + .await + .unwrap(); - blob.seek(SeekFrom::Start(0)); + blob.seek(SeekFrom::Start(0)); - let mut prev_pos = 0; - for pos in positions { - blob.seek(SeekFrom::Current(pos as i64 - prev_pos as i64)); - prev_pos = pos; - } + let mut prev_pos = 0; + for pos in positions { + blob.seek(SeekFrom::Current(pos as i64 - prev_pos as i64)); + prev_pos = pos; + } - let mut read_buffer = vec![0; content.len()]; - let len = blob.read_all(&mut tx, &mut read_buffer[..]).await.unwrap(); - assert_eq!(read_buffer[..len], content[prev_pos..]); + let mut read_buffer = vec![0; content.len()]; + let len = blob.read_all(&mut tx, &mut read_buffer[..]).await.unwrap(); + assert_eq!(read_buffer[..len], content[prev_pos..]); - drop(tx); - store.close().await.unwrap(); - }) + drop(tx); + store.close().await.unwrap(); } #[tokio::test(flavor = "multi_thread")] @@ -524,21 +514,15 @@ async fn write_reopen_and_read() { store.close().await.unwrap(); } -#[proptest] -fn fork_and_write( +#[proptest(async = "tokio")] +async fn fork_and_write( #[strategy(0..2 * BLOCK_SIZE)] src_len: usize, #[strategy(0..=#src_len)] seek_pos: usize, #[strategy(1..BLOCK_SIZE)] write_len: usize, src_locator_is_root: bool, #[strategy(test_utils::rng_seed_strategy())] rng_seed: u64, ) { - test_utils::run(fork_and_write_case( - src_len, - seek_pos, - write_len, - src_locator_is_root, - rng_seed, - )) + fork_and_write_case(src_len, seek_pos, write_len, src_locator_is_root, rng_seed).await } async fn fork_and_write_case( diff --git a/lib/src/network/request_tracker.rs b/lib/src/network/request_tracker.rs index c7a180a6f..3490ba57e 100644 --- a/lib/src/network/request_tracker.rs +++ b/lib/src/network/request_tracker.rs @@ -155,7 +155,6 @@ impl Worker { Some(expired) = self.timeouts.next() => { let (client_id, request_key) = expired.into_inner(); self.handle_failure(client_id, request_key); - continue; } } } @@ -440,8 +439,8 @@ mod tests { use std::collections::VecDeque; use test_strategy::proptest; - #[proptest] - fn sanity_check( + #[proptest(async = "tokio")] + async fn sanity_check( #[strategy(test_utils::rng_seed_strategy())] seed: u64, #[strategy(1usize..=32)] num_blocks: usize, #[strategy(1usize..=3)] num_peers: usize, @@ -452,13 +451,6 @@ mod tests { fn sanity_check_case(seed: u64, num_blocks: usize, num_peers: usize) { test_utils::init_log(); - // Tokio runtime needed for `DelayQueue`. - let _runtime_guard = tokio::runtime::Builder::new_current_thread() - .enable_time() - .build() - .unwrap() - .enter(); - let mut rng = StdRng::seed_from_u64(seed); let snapshot = Snapshot::generate(&mut rng, num_blocks); diff --git a/lib/src/network/tests.rs b/lib/src/network/tests.rs index 6cb92dbba..534a59357 100644 --- a/lib/src/network/tests.rs +++ b/lib/src/network/tests.rs @@ -42,19 +42,20 @@ const TIMEOUT: Duration = Duration::from_secs(60); // // NOTE: Reducing the number of cases otherwise this test is too slow. 
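The migration pattern applied throughout this patch, shown on a hypothetical minimal test (not from the codebase): `test-strategy` now accepts an async attribute argument, which replaces the hand-rolled runtime wrapper used before.

    use test_strategy::proptest;

    // Before: a sync test wrapping the async body in a manually built runtime:
    //     #[proptest]
    //     fn roundtrip(#[strategy(0usize..1024)] len: usize) {
    //         test_utils::run(roundtrip_case(len))
    //     }

    // After: the test body is async and awaited by the generated harness.
    #[proptest(async = "tokio")]
    async fn roundtrip(#[strategy(0usize..1024)] len: usize) {
        let buf = vec![0u8; len];
        assert_eq!(buf.len(), len);
    }
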
// TODO: Make it faster and increase the cases. -#[proptest(cases = 8)] -fn transfer_snapshot_between_two_replicas( +#[proptest(async = "tokio", cases = 8)] +async fn transfer_snapshot_between_two_replicas( #[strategy(0usize..32)] leaf_count: usize, #[strategy(0usize..2)] changeset_count: usize, #[strategy(1usize..4)] changeset_size: usize, #[strategy(test_utils::rng_seed_strategy())] rng_seed: u64, ) { - test_utils::run(transfer_snapshot_between_two_replicas_case( + transfer_snapshot_between_two_replicas_case( leaf_count, changeset_count, changeset_size, rng_seed, - )) + ) + .await } async fn transfer_snapshot_between_two_replicas_case( @@ -106,15 +107,12 @@ async fn transfer_snapshot_between_two_replicas_case( // NOTE: Reducing the number of cases otherwise this test is too slow. // TODO: Make it faster and increase the cases. -#[proptest(cases = 8)] -fn transfer_blocks_between_two_replicas( +#[proptest(async = "tokio", cases = 8)] +async fn transfer_blocks_between_two_replicas( #[strategy(1usize..32)] block_count: usize, #[strategy(test_utils::rng_seed_strategy())] rng_seed: u64, ) { - test_utils::run(transfer_blocks_between_two_replicas_case( - block_count, - rng_seed, - )) + transfer_blocks_between_two_replicas_case(block_count, rng_seed).await } // #[tokio::test] diff --git a/lib/src/repository/vault/tests.rs b/lib/src/repository/vault/tests.rs index 8c7ad26b2..6aa4b602c 100644 --- a/lib/src/repository/vault/tests.rs +++ b/lib/src/repository/vault/tests.rs @@ -474,13 +474,13 @@ async fn block_ids_pagination() { assert!(actual.is_empty()); } -#[proptest] -fn sync_progress( +#[proptest(async = "tokio")] +async fn sync_progress( #[strategy(1usize..16)] block_count: usize, #[strategy(1usize..5)] branch_count: usize, #[strategy(test_utils::rng_seed_strategy())] rng_seed: u64, ) { - test_utils::run(sync_progress_case(block_count, branch_count, rng_seed)) + sync_progress_case(block_count, branch_count, rng_seed).await } async fn sync_progress_case(block_count: usize, branch_count: usize, rng_seed: u64) { diff --git a/lib/src/store/index.rs b/lib/src/store/index.rs index daed667cf..072852c35 100644 --- a/lib/src/store/index.rs +++ b/lib/src/store/index.rs @@ -281,12 +281,12 @@ mod tests { assert_eq!(summary.block_presence, MultiBlockPresence::Full); } - #[proptest] - fn check_complete( + #[proptest(async = "tokio")] + async fn check_complete( #[strategy(0usize..=32)] leaf_count: usize, #[strategy(test_utils::rng_seed_strategy())] rng_seed: u64, ) { - test_utils::run(check_complete_case(leaf_count, rng_seed)) + check_complete_case(leaf_count, rng_seed).await } async fn check_complete_case(leaf_count: usize, rng_seed: u64) { @@ -368,12 +368,12 @@ mod tests { } } - #[proptest] - fn summary( + #[proptest(async = "tokio")] + async fn summary( #[strategy(0usize..=32)] leaf_count: usize, #[strategy(test_utils::rng_seed_strategy())] rng_seed: u64, ) { - test_utils::run(summary_case(leaf_count, rng_seed)) + summary_case(leaf_count, rng_seed).await } async fn summary_case(leaf_count: usize, rng_seed: u64) { diff --git a/lib/src/store/root_node.rs b/lib/src/store/root_node.rs index 8547524cc..10bb9434e 100644 --- a/lib/src/store/root_node.rs +++ b/lib/src/store/root_node.rs @@ -827,8 +827,8 @@ mod tests { use proptest::{arbitrary::any, collection::vec, sample::select, strategy::Strategy}; use test_strategy::proptest; - #[proptest] - fn proptest( + #[proptest(async = "tokio")] + async fn proptest( write_keys: Keypair, #[strategy(root_node_params_strategy())] input: Vec<( SnapshotId, @@ -837,7 +837,7 @@ 
mod tests { NodeState, )>, ) { - crate::test_utils::run(case(write_keys, input)) + case(write_keys, input).await } async fn case(write_keys: Keypair, input: Vec<(SnapshotId, PublicKey, Hash, NodeState)>) { diff --git a/lib/src/store/tests.rs b/lib/src/store/tests.rs index c6710c693..d8dff649d 100644 --- a/lib/src/store/tests.rs +++ b/lib/src/store/tests.rs @@ -304,12 +304,12 @@ async fn fallback() { .is_none()); } -#[proptest] -fn empty_nodes_are_not_stored( +#[proptest(async = "tokio")] +async fn empty_nodes_are_not_stored( #[strategy(1usize..32)] leaf_count: usize, #[strategy(test_utils::rng_seed_strategy())] rng_seed: u64, ) { - test_utils::run(empty_nodes_are_not_stored_case(leaf_count, rng_seed)) + empty_nodes_are_not_stored_case(leaf_count, rng_seed).await } async fn empty_nodes_are_not_stored_case(leaf_count: usize, rng_seed: u64) { @@ -354,12 +354,12 @@ async fn empty_nodes_are_not_stored_case(leaf_count: usize, rng_seed: u64) { } } -#[proptest] -fn prune( +#[proptest(async = "tokio")] +async fn prune( #[strategy(vec(any::(), 1..32))] ops: Vec, #[strategy(test_utils::rng_seed_strategy())] rng_seed: u64, ) { - test_utils::run(prune_case(ops, rng_seed)) + prune_case(ops, rng_seed).await } #[derive(Arbitrary, Debug)] diff --git a/lib/src/test_utils.rs b/lib/src/test_utils.rs index cc6c41fa0..0e600345b 100644 --- a/lib/src/test_utils.rs +++ b/lib/src/test_utils.rs @@ -1,16 +1,4 @@ use proptest::prelude::*; -use std::future::Future; - -// proptest doesn't work with the `#[tokio::test]` macro yet -// (see https://github.com/AltSysrq/proptest/issues/179). As a workaround, create the runtime -// manually. -pub(crate) fn run(future: F) -> F::Output { - tokio::runtime::Builder::new_current_thread() - .enable_time() - .build() - .unwrap() - .block_on(future) -} pub(crate) fn rng_seed_strategy() -> impl Strategy { any::().no_shrink() From ce803dcc27c038de18ee0a27be174c786575ad1f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Cig=C3=A1nek?= Date: Wed, 18 Sep 2024 13:15:07 +0200 Subject: [PATCH 28/55] Expand tests for RequestTracker --- Cargo.toml | 2 +- lib/src/network/request_tracker.rs | 474 +++++------------------ lib/src/network/request_tracker/tests.rs | 388 +++++++++++++++++++ lib/src/protocol/block.rs | 42 ++ lib/src/protocol/test_utils.rs | 12 +- 5 files changed, 546 insertions(+), 372 deletions(-) create mode 100644 lib/src/network/request_tracker/tests.rs diff --git a/Cargo.toml b/Cargo.toml index ce025eca7..4c41dfa0a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -50,7 +50,7 @@ metrics-util = { version = "0.16.0", default-features = false } num_enum = { version = "0.7.0", default-features = false } once_cell = "1.18.0" pin-project-lite = "0.2.13" -proptest = "1.0" +proptest = "1.5" rand = { package = "ouisync-rand", path = "rand" } rcgen = "0.13" rmp-serde = "1.1.0" diff --git a/lib/src/network/request_tracker.rs b/lib/src/network/request_tracker.rs index 3490ba57e..d1ba9014c 100644 --- a/lib/src/network/request_tracker.rs +++ b/lib/src/network/request_tracker.rs @@ -1,13 +1,21 @@ +#[cfg(test)] +mod tests; + use super::{constants::REQUEST_TIMEOUT, message::Request}; use crate::{ collections::{HashMap, HashSet}, crypto::{sign::PublicKey, Hash}, protocol::BlockId, }; +use futures_util::Stream; use std::{ collections::hash_map::Entry, iter, + ops::{Deref, DerefMut}, + pin::Pin, sync::atomic::{AtomicUsize, Ordering}, + task::{Context, Poll}, + time::Duration, }; use tokio::{select, sync::mpsc, task}; use tokio_stream::StreamExt; @@ -24,7 +32,7 @@ pub(super) struct RequestTracker { impl 
RequestTracker { #[expect(dead_code)] pub fn new() -> Self { - let (this, worker) = build(); + let (this, worker) = build(DelayQueue::new()); task::spawn(worker.run()); this } @@ -120,24 +128,27 @@ impl<'a> From<&'a Request> for MessageKey { } } -fn build() -> (RequestTracker, Worker) { +fn build(timer: T) -> (RequestTracker, Worker) { let (command_tx, command_rx) = mpsc::unbounded_channel(); - (RequestTracker { command_tx }, Worker::new(command_rx)) + ( + RequestTracker { command_tx }, + Worker::new(timer, command_rx), + ) } -struct Worker { +struct Worker { clients: HashMap, - requests: HashMap, - timeouts: DelayQueue<(ClientId, MessageKey)>, + requests: HashMap>, + timer: TimerStream, command_rx: mpsc::UnboundedReceiver, } -impl Worker { - fn new(command_rx: mpsc::UnboundedReceiver) -> Self { +impl Worker { + fn new(timer: T, command_rx: mpsc::UnboundedReceiver) -> Self { Self { clients: HashMap::default(), requests: HashMap::default(), - timeouts: DelayQueue::new(), + timer: TimerStream(timer), command_rx, } } @@ -152,8 +163,7 @@ impl Worker { break; } } - Some(expired) = self.timeouts.next() => { - let (client_id, request_key) = expired.into_inner(); + Some((client_id, request_key)) = self.timer.next() => { self.handle_failure(client_id, request_key); } } @@ -166,8 +176,6 @@ impl Worker { while let Ok(command) = self.command_rx.try_recv() { self.handle_command(command); } - - // TODO: Check timeouts } #[cfg(test)] @@ -226,15 +234,15 @@ impl Worker { continue; }; - let Some(client_request_state) = entry.get_mut().clients.remove(&client_id) else { + let Some(interest) = entry.get_mut().interests.remove(&client_id) else { continue; }; - if let Some(timeout_key) = client_request_state.timeout_key { - self.timeouts.try_remove(&timeout_key); + if let Some(timeout_key) = interest.timer_key { + self.timer.remove(&timeout_key); } - if entry.get().clients.is_empty() { + if entry.get().interests.is_empty() { entry.remove(); continue; } @@ -259,30 +267,27 @@ impl Worker { .entry(request_key) .or_insert_with(|| RequestState { request, - clients: HashMap::default(), + interests: HashMap::default(), }); - let timeout_key = if request_state - .clients + let timer_key = if request_state + .interests .values() - .all(|state| state.timeout_key.is_none()) + .all(|state| state.timer_key.is_none()) { client_state .request_tx .send(request_state.request.clone()) .ok(); - Some( - self.timeouts - .insert((client_id, request_key), REQUEST_TIMEOUT), - ) + Some(self.timer.insert(client_id, request_key, REQUEST_TIMEOUT)) } else { None }; request_state - .clients - .insert(client_id, RequestClientState { timeout_key }); + .interests + .insert(client_id, Interest { timer_key }); } #[instrument(skip(self))] @@ -303,21 +308,24 @@ impl Worker { }; if let Entry::Occupied(mut entry) = self.requests.entry(response_key) { - entry.get_mut().clients.retain(|other_client_id, state| { - // TODO: remove only those with the same or worse block presence. - - if let Some(timeout_key) = state.timeout_key { - self.timeouts.try_remove(&timeout_key); - } + entry + .get_mut() + .interests + .retain(|other_client_id, interest| { + // TODO: remove only those with the same or worse block presence. 
+ + if let Some(timer_key) = interest.timer_key { + self.timer.remove(&timer_key); + } - if !requests.is_empty() && *other_client_id != client_id { - followup_client_ids.push(*other_client_id); - } + if !requests.is_empty() && *other_client_id != client_id { + followup_client_ids.push(*other_client_id); + } - false - }); + false + }); - if entry.get().clients.is_empty() { + if entry.get().interests.is_empty() { entry.remove(); } } @@ -347,28 +355,29 @@ impl Worker { return; }; - if let Some(state) = entry.get_mut().clients.remove(&client_id) { - if let Some(timeout_key) = state.timeout_key { - self.timeouts.try_remove(&timeout_key); + if let Some(interest) = entry.get_mut().interests.remove(&client_id) { + if let Some(timer_key) = interest.timer_key { + self.timer.remove(&timer_key); } } // TODO: prefer one with the same or better block presence as `client_id`. - if let Some((fallback_client_id, state)) = entry + if let Some((fallback_client_id, interest)) = entry .get_mut() - .clients + .interests .iter_mut() - .find(|(_, state)| state.timeout_key.is_none()) + .find(|(_, interest)| interest.timer_key.is_none()) { - state.timeout_key = Some( - self.timeouts - .insert((*fallback_client_id, response_key), REQUEST_TIMEOUT), - ); + interest.timer_key = Some(self.timer.insert( + *fallback_client_id, + response_key, + REQUEST_TIMEOUT, + )); // TODO: send the request } - if entry.get().clients.is_empty() { + if entry.get().interests.is_empty() { entry.remove(); } } @@ -412,349 +421,74 @@ struct ClientState { requests: HashSet, } -struct RequestState { +struct RequestState { request: Request, - clients: HashMap, + interests: HashMap>, } -struct RequestClientState { +struct Interest { // disambiguator: ResponseDisambiguator, - timeout_key: Option, + timer_key: Option, } -#[cfg(test)] -mod tests { - use super::{ - super::{debug_payload::DebugResponse, message::Response}, - *, - }; - use crate::{ - crypto::{sign::Keypair, Hashable}, - network::message::ResponseDisambiguator, - protocol::{test_utils::Snapshot, Block, MultiBlockPresence, Proof, UntrustedProof}, - test_utils, - version_vector::VersionVector, - }; - use rand::{rngs::StdRng, seq::SliceRandom, Rng, SeedableRng}; - use std::collections::VecDeque; - use test_strategy::proptest; - - #[proptest(async = "tokio")] - async fn sanity_check( - #[strategy(test_utils::rng_seed_strategy())] seed: u64, - #[strategy(1usize..=32)] num_blocks: usize, - #[strategy(1usize..=3)] num_peers: usize, - ) { - sanity_check_case(seed, num_blocks, num_peers); - } - - fn sanity_check_case(seed: u64, num_blocks: usize, num_peers: usize) { - test_utils::init_log(); - - let mut rng = StdRng::seed_from_u64(seed); - - let snapshot = Snapshot::generate(&mut rng, num_blocks); - let mut summary = Summary::default(); - - let (tracker, mut tracker_worker) = build(); +/// Trait for timer to to track request timeouts. 
+trait Timer: Unpin { + type Key: Copy; - let mut peers: Vec<_> = (0..num_peers) - .map(|_| { - let (tracker_client, tracker_request_rx) = tracker.new_client(); - let client = TestClient::new(tracker_client, tracker_request_rx); + fn insert( + &mut self, + client_id: ClientId, + message_key: MessageKey, + timeout: Duration, + ) -> Self::Key; - let writer_id = PublicKey::generate(&mut rng); - let write_keys = Keypair::generate(&mut rng); - let server = TestServer::new(writer_id, write_keys, &snapshot); + fn remove(&mut self, key: &Self::Key); - (client, server) - }) - .collect(); + fn poll_expired(&mut self, cx: &mut Context<'_>) -> Poll>; +} - while poll_peers(&mut rng, &mut peers, &snapshot, &mut summary) { - tracker_worker.step(); - } +struct TimerStream(T); - summary.verify(peers.len(), &snapshot); +impl Deref for TimerStream { + type Target = T; - assert_eq!(tracker_worker.request_count(), 0); + fn deref(&self) -> &Self::Target { + &self.0 } +} - #[derive(Default)] - struct Summary { - nodes: HashMap, - blocks: HashMap, +impl DerefMut for TimerStream { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 } +} - impl Summary { - fn receive_node(&mut self, hash: Hash) { - *self.nodes.entry(hash).or_default() += 1; - } - - fn receive_block(&mut self, block_id: BlockId) { - *self.blocks.entry(block_id).or_default() += 1; - } - - fn verify(&mut self, num_peers: usize, snapshot: &Snapshot) { - assert_eq!( - self.nodes.remove(snapshot.root_hash()).unwrap_or(0), - num_peers, - "root node not received exactly {num_peers} times: {:?}", - snapshot.root_hash() - ); - - for hash in snapshot - .inner_nodes() - .map(|node| &node.hash) - .chain(snapshot.leaf_nodes().map(|node| &node.locator)) - { - assert_eq!( - self.nodes.remove(hash).unwrap_or(0), - 1, - "child node not received exactly once: {hash:?}" - ); - } - - for block_id in snapshot.blocks().keys() { - assert_eq!( - self.blocks.remove(block_id).unwrap_or(0), - 1, - "block not received exactly once: {block_id:?}" - ); - } - - // Verify we received only the expected nodes and blocks - assert!( - self.nodes.is_empty(), - "unexpected nodes received: {:?}", - self.nodes - ); - assert!( - self.blocks.is_empty(), - "unexpected blocks received: {:?}", - self.blocks - ); - } - } +impl Stream for TimerStream { + type Item = (ClientId, MessageKey); - struct TestClient { - tracker_client: RequestTrackerClient, - tracker_request_rx: mpsc::UnboundedReceiver, + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.get_mut().0.poll_expired(cx) } +} - impl TestClient { - fn new( - tracker_client: RequestTrackerClient, - tracker_request_rx: mpsc::UnboundedReceiver, - ) -> Self { - Self { - tracker_client, - tracker_request_rx, - } - } - - fn handle_response(&mut self, response: Response, summary: &mut Summary) { - match response { - Response::RootNode(proof, block_presence, debug_payload) => { - summary.receive_node(proof.hash); - - let requests = vec![Request::ChildNodes( - proof.hash, - ResponseDisambiguator::new(block_presence), - debug_payload.follow_up(), - )]; - - self.tracker_client - .success(MessageKey::RootNode(proof.writer_id), requests); - } - Response::InnerNodes(nodes, _disambiguator, debug_payload) => { - let parent_hash = nodes.hash(); - let requests: Vec<_> = nodes - .into_iter() - .map(|(_, node)| { - summary.receive_node(node.hash); - - Request::ChildNodes( - node.hash, - ResponseDisambiguator::new(node.summary.block_presence), - debug_payload.follow_up(), - ) - }) - .collect(); - - self.tracker_client - 
.success(MessageKey::ChildNodes(parent_hash), requests); - } - Response::LeafNodes(nodes, _disambiguator, debug_payload) => { - let parent_hash = nodes.hash(); - let requests = nodes - .into_iter() - .map(|node| { - summary.receive_node(node.locator); - - Request::Block(node.block_id, debug_payload.follow_up()) - }) - .collect(); - - self.tracker_client - .success(MessageKey::ChildNodes(parent_hash), requests); - } - Response::Block(content, nonce, _debug_payload) => { - let block = Block::new(content, nonce); - - summary.receive_block(block.id); - - self.tracker_client - .success(MessageKey::Block(block.id), vec![]); - } - Response::RootNodeError(writer_id, _debug_payload) => { - self.tracker_client.failure(MessageKey::RootNode(writer_id)); - } - Response::ChildNodesError(hash, _disambiguator, _debug_payload) => { - self.tracker_client.failure(MessageKey::ChildNodes(hash)); - } - Response::BlockError(block_id, _debug_payload) => { - self.tracker_client.failure(MessageKey::Block(block_id)); - } - Response::BlockOffer(_block_id, _debug_payload) => unimplemented!(), - }; - } - - fn poll_request(&mut self) -> Option { - self.tracker_request_rx.try_recv().ok() - } - } +impl Timer for DelayQueue<(ClientId, MessageKey)> { + type Key = delay_queue::Key; - struct TestServer { - writer_id: PublicKey, - write_keys: Keypair, - outbox: VecDeque, + fn insert( + &mut self, + client_id: ClientId, + message_key: MessageKey, + timeout: Duration, + ) -> Self::Key { + DelayQueue::insert(self, (client_id, message_key), timeout) } - impl TestServer { - fn new(writer_id: PublicKey, write_keys: Keypair, snapshot: &Snapshot) -> Self { - let proof = UntrustedProof::from(Proof::new( - writer_id, - VersionVector::first(writer_id), - *snapshot.root_hash(), - &write_keys, - )); - - let outbox = [Response::RootNode( - proof.clone(), - MultiBlockPresence::Full, - DebugResponse::unsolicited(), - )] - .into(); - - Self { - writer_id, - write_keys, - outbox, - } - } - - fn handle_request(&mut self, request: Request, snapshot: &Snapshot) { - match request { - Request::RootNode(writer_id, debug_payload) => { - if writer_id == self.writer_id { - let proof = Proof::new( - writer_id, - VersionVector::first(writer_id), - *snapshot.root_hash(), - &self.write_keys, - ); - - self.outbox.push_back(Response::RootNode( - proof.into(), - MultiBlockPresence::Full, - debug_payload.reply(), - )); - } else { - self.outbox - .push_back(Response::RootNodeError(writer_id, debug_payload.reply())); - } - } - Request::ChildNodes(hash, disambiguator, debug_payload) => { - if let Some(nodes) = snapshot - .inner_layers() - .flat_map(|layer| layer.inner_maps()) - .find_map(|(parent_hash, nodes)| (*parent_hash == hash).then_some(nodes)) - { - self.outbox.push_back(Response::InnerNodes( - nodes.clone(), - disambiguator, - debug_payload.reply(), - )); - } - - if let Some(nodes) = snapshot - .leaf_sets() - .find_map(|(parent_hash, nodes)| (*parent_hash == hash).then_some(nodes)) - { - self.outbox.push_back(Response::LeafNodes( - nodes.clone(), - disambiguator, - debug_payload.reply(), - )); - } - } - Request::Block(block_id, debug_payload) => { - if let Some(block) = snapshot.blocks().get(&block_id) { - self.outbox.push_back(Response::Block( - block.content.clone(), - block.nonce, - debug_payload.reply(), - )); - } - } - } - } - - fn poll_response(&mut self) -> Option { - self.outbox.pop_front() - } + fn remove(&mut self, key: &Self::Key) { + DelayQueue::try_remove(self, key); } - // Polls every client and server once, in random order - fn poll_peers( - 
rng: &mut R, - peers: &mut [(TestClient, TestServer)], - snapshot: &Snapshot, - summary: &mut Summary, - ) -> bool { - enum Side { - Client, - Server, - } - - let mut order: Vec<_> = (0..peers.len()) - .flat_map(|index| [(Side::Client, index), (Side::Server, index)]) - .collect(); - - order.shuffle(rng); - - let mut changed = false; - - for (side, index) in order { - let (client, server) = &mut peers[index]; - - match side { - Side::Client => { - if let Some(request) = client.poll_request() { - server.handle_request(request, snapshot); - changed = true; - } - } - Side::Server => { - if let Some(response) = server.poll_response() { - client.handle_response(response, summary); - changed = true; - } - } - } - } - - changed + fn poll_expired(&mut self, cx: &mut Context<'_>) -> Poll> { + self.poll_expired(cx) + .map(|expired| expired.map(|expired| expired.into_inner())) } } diff --git a/lib/src/network/request_tracker/tests.rs b/lib/src/network/request_tracker/tests.rs new file mode 100644 index 000000000..c43ac4cc0 --- /dev/null +++ b/lib/src/network/request_tracker/tests.rs @@ -0,0 +1,388 @@ +use super::{ + super::{debug_payload::DebugResponse, message::Response}, + *, +}; +use crate::{ + crypto::{sign::Keypair, Hashable}, + network::message::ResponseDisambiguator, + protocol::{test_utils::Snapshot, Block, MultiBlockPresence, Proof, UntrustedProof}, + test_utils, + version_vector::VersionVector, +}; +use rand::{rngs::StdRng, seq::SliceRandom, Rng, SeedableRng}; +use std::collections::VecDeque; + +#[test] +fn simulation() { + simulation_case(6045920800462135606, 1, 2, (1, 10), (1, 10)); + // let seed = rand::random(); + // simulation_case(seed, 32, 4, (1, 10), (1, 10)); +} + +fn simulation_case( + seed: u64, + max_blocks: usize, + max_peers: usize, + peer_insert_ratio: (u32, u32), + peer_remove_ratio: (u32, u32), +) { + test_utils::init_log(); + + tracing::info!( + seed, + max_blocks, + max_peers, + peer_insert_ratio = ?peer_insert_ratio, + peer_remove_ratio = ?peer_remove_ratio, + ); + + let mut rng = StdRng::seed_from_u64(seed); + + let (tracker, mut tracker_worker) = build(FakeTimer); + let mut summary = Summary::default(); + + let block_count = rng.gen_range(1..=max_blocks); + let snapshot = Snapshot::generate(&mut rng, block_count); + + tracing::info!(?snapshot); + + let mut peers = Vec::new(); + let mut total_peer_count = 0; + + loop { + let peers_len_before = peers.len(); + + if peers.is_empty() + || (peers.len() < max_peers && rng.gen_ratio(peer_insert_ratio.0, peer_insert_ratio.1)) + { + tracing::info!("insert peer"); + + let (tracker_client, tracker_request_rx) = tracker.new_client(); + let client = TestClient::new(tracker_client, tracker_request_rx); + + let writer_id = PublicKey::generate(&mut rng); + let write_keys = Keypair::generate(&mut rng); + let server = TestServer::new(writer_id, write_keys, &snapshot); + + peers.push((client, server)); + } + + if peers.len() > 1 && rng.gen_ratio(peer_remove_ratio.0, peer_remove_ratio.1) { + tracing::info!("remove peer"); + + let index = rng.gen_range(0..peers.len()); + peers.remove(index); + } + + // Note some peers might be inserted and removed in the same tick. Such peers are discounted + // from the total because they would not send/receive any messages. 
+ total_peer_count += peers.len().saturating_sub(peers_len_before); + + if poll_peers(&mut rng, &mut peers, &snapshot, &mut summary) { + tracker_worker.step(); + } else { + break; + } + } + + summary.verify(total_peer_count, &snapshot); + assert_eq!(tracker_worker.request_count(), 0); +} + +#[derive(Default)] +struct Summary { + nodes: HashMap, + blocks: HashMap, +} + +impl Summary { + fn receive_node(&mut self, hash: Hash) { + *self.nodes.entry(hash).or_default() += 1; + } + + fn receive_block(&mut self, block_id: BlockId) { + *self.blocks.entry(block_id).or_default() += 1; + } + + fn verify(&mut self, num_peers: usize, snapshot: &Snapshot) { + assert_eq!( + self.nodes.remove(snapshot.root_hash()).unwrap_or(0), + num_peers, + "root node not received exactly {num_peers} times: {:?}", + snapshot.root_hash() + ); + + for hash in snapshot + .inner_nodes() + .map(|node| &node.hash) + .chain(snapshot.leaf_nodes().map(|node| &node.locator)) + { + assert_eq!( + self.nodes.remove(hash).unwrap_or(0), + 1, + "child node not received exactly once: {hash:?}" + ); + } + + for block_id in snapshot.blocks().keys() { + assert_eq!( + self.blocks.remove(block_id).unwrap_or(0), + 1, + "block not received exactly once: {block_id:?}" + ); + } + + // Verify we received only the expected nodes and blocks + assert!( + self.nodes.is_empty(), + "unexpected nodes received: {:?}", + self.nodes + ); + assert!( + self.blocks.is_empty(), + "unexpected blocks received: {:?}", + self.blocks + ); + } +} + +struct TestClient { + tracker_client: RequestTrackerClient, + tracker_request_rx: mpsc::UnboundedReceiver, +} + +impl TestClient { + fn new( + tracker_client: RequestTrackerClient, + tracker_request_rx: mpsc::UnboundedReceiver, + ) -> Self { + Self { + tracker_client, + tracker_request_rx, + } + } + + fn handle_response(&mut self, response: Response, summary: &mut Summary) { + match response { + Response::RootNode(proof, block_presence, debug_payload) => { + summary.receive_node(proof.hash); + + let requests = vec![Request::ChildNodes( + proof.hash, + ResponseDisambiguator::new(block_presence), + debug_payload.follow_up(), + )]; + + self.tracker_client + .success(MessageKey::RootNode(proof.writer_id), requests); + } + Response::InnerNodes(nodes, _disambiguator, debug_payload) => { + let parent_hash = nodes.hash(); + let requests: Vec<_> = nodes + .into_iter() + .map(|(_, node)| { + summary.receive_node(node.hash); + + Request::ChildNodes( + node.hash, + ResponseDisambiguator::new(node.summary.block_presence), + debug_payload.follow_up(), + ) + }) + .collect(); + + self.tracker_client + .success(MessageKey::ChildNodes(parent_hash), requests); + } + Response::LeafNodes(nodes, _disambiguator, debug_payload) => { + let parent_hash = nodes.hash(); + let requests = nodes + .into_iter() + .map(|node| { + summary.receive_node(node.locator); + + Request::Block(node.block_id, debug_payload.follow_up()) + }) + .collect(); + + self.tracker_client + .success(MessageKey::ChildNodes(parent_hash), requests); + } + Response::Block(content, nonce, _debug_payload) => { + let block = Block::new(content, nonce); + + summary.receive_block(block.id); + + self.tracker_client + .success(MessageKey::Block(block.id), vec![]); + } + Response::RootNodeError(writer_id, _debug_payload) => { + self.tracker_client.failure(MessageKey::RootNode(writer_id)); + } + Response::ChildNodesError(hash, _disambiguator, _debug_payload) => { + self.tracker_client.failure(MessageKey::ChildNodes(hash)); + } + Response::BlockError(block_id, _debug_payload) => { + 
self.tracker_client.failure(MessageKey::Block(block_id)); + } + Response::BlockOffer(_block_id, _debug_payload) => unimplemented!(), + }; + } + + fn poll_request(&mut self) -> Option { + self.tracker_request_rx.try_recv().ok() + } +} + +struct TestServer { + writer_id: PublicKey, + write_keys: Keypair, + outbox: VecDeque, +} + +impl TestServer { + fn new(writer_id: PublicKey, write_keys: Keypair, snapshot: &Snapshot) -> Self { + let proof = UntrustedProof::from(Proof::new( + writer_id, + VersionVector::first(writer_id), + *snapshot.root_hash(), + &write_keys, + )); + + let outbox = [Response::RootNode( + proof.clone(), + MultiBlockPresence::Full, + DebugResponse::unsolicited(), + )] + .into(); + + Self { + writer_id, + write_keys, + outbox, + } + } + + fn handle_request(&mut self, request: Request, snapshot: &Snapshot) { + match request { + Request::RootNode(writer_id, debug_payload) => { + if writer_id == self.writer_id { + let proof = Proof::new( + writer_id, + VersionVector::first(writer_id), + *snapshot.root_hash(), + &self.write_keys, + ); + + self.outbox.push_back(Response::RootNode( + proof.into(), + MultiBlockPresence::Full, + debug_payload.reply(), + )); + } else { + self.outbox + .push_back(Response::RootNodeError(writer_id, debug_payload.reply())); + } + } + Request::ChildNodes(hash, disambiguator, debug_payload) => { + if let Some(nodes) = snapshot + .inner_layers() + .flat_map(|layer| layer.inner_maps()) + .find_map(|(parent_hash, nodes)| (*parent_hash == hash).then_some(nodes)) + { + self.outbox.push_back(Response::InnerNodes( + nodes.clone(), + disambiguator, + debug_payload.reply(), + )); + } + + if let Some(nodes) = snapshot + .leaf_sets() + .find_map(|(parent_hash, nodes)| (*parent_hash == hash).then_some(nodes)) + { + self.outbox.push_back(Response::LeafNodes( + nodes.clone(), + disambiguator, + debug_payload.reply(), + )); + } + } + Request::Block(block_id, debug_payload) => { + if let Some(block) = snapshot.blocks().get(&block_id) { + self.outbox.push_back(Response::Block( + block.content.clone(), + block.nonce, + debug_payload.reply(), + )); + } + } + } + } + + fn poll_response(&mut self) -> Option { + self.outbox.pop_front() + } +} + +// Polls every client and server once, in random order +fn poll_peers( + rng: &mut R, + peers: &mut [(TestClient, TestServer)], + snapshot: &Snapshot, + summary: &mut Summary, +) -> bool { + enum Side { + Client, + Server, + } + + let mut order: Vec<_> = (0..peers.len()) + .flat_map(|index| [(Side::Client, index), (Side::Server, index)]) + .collect(); + + order.shuffle(rng); + + let mut changed = false; + + for (side, index) in order { + let (client, server) = &mut peers[index]; + + match side { + Side::Client => { + if let Some(request) = client.poll_request() { + server.handle_request(request, snapshot); + changed = true; + } + } + Side::Server => { + if let Some(response) = server.poll_response() { + client.handle_response(response, summary); + changed = true; + } + } + } + } + + changed +} + +struct FakeTimer; + +impl Timer for FakeTimer { + type Key = (); + + fn insert( + &mut self, + _client_id: ClientId, + _message_key: MessageKey, + _timeout: Duration, + ) -> Self::Key { + } + + fn remove(&mut self, _key: &Self::Key) {} + + fn poll_expired(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(None) + } +} diff --git a/lib/src/protocol/block.rs b/lib/src/protocol/block.rs index 9693bb61b..f03f0d83d 100644 --- a/lib/src/protocol/block.rs +++ b/lib/src/protocol/block.rs @@ -89,6 +89,14 @@ impl Distribution for Standard { } } 
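For context on the hunk below: deriving `Debug` for `Block` would dump the entire block payload into logs, so the patch writes the impl by hand and prints only the id. A standalone sketch of the same pattern (illustrative names, not the patch's code):

    use std::fmt;

    struct Payload {
        id: u64,
        data: Vec<u8>, // potentially large; keep it out of debug output
    }

    impl fmt::Debug for Payload {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            // Renders as `Payload { id: 42, .. }`, signalling omitted fields
            // without printing `data`.
            f.debug_struct("Payload")
                .field("id", &self.id)
                .finish_non_exhaustive()
        }
    }
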
+impl fmt::Debug for Block { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Block") + .field("id", &self.id) + .finish_non_exhaustive() + } +} + #[derive(Clone, Serialize, Deserialize)] pub(crate) struct BlockContent(Box<[u8]>); @@ -174,3 +182,37 @@ impl fmt::Debug for BlockContent { write!(f, "{:<8}", hex_fmt::HexFmt(&self[..])) } } + +#[cfg(test)] +mod test_utils { + use super::{Block, BlockContent, BlockNonce, BLOCK_SIZE}; + use proptest::{ + arbitrary::{any, Arbitrary, StrategyFor}, + collection::{vec, VecStrategy}, + strategy::{Map, NoShrink, Strategy}, + }; + + impl Arbitrary for BlockContent { + type Parameters = (); + type Strategy = Map>>, fn(Vec) -> Self>; + + fn arbitrary_with(_: Self::Parameters) -> Self::Strategy { + vec(any::(), BLOCK_SIZE) + .no_shrink() + .prop_map(|bytes| Self(bytes.into_boxed_slice())) + } + } + + impl Arbitrary for Block { + type Parameters = (); + type Strategy = Map< + (StrategyFor, NoShrink>), + fn((BlockContent, BlockNonce)) -> Self, + >; + + fn arbitrary_with(_: Self::Parameters) -> Self::Strategy { + (any::(), any::().no_shrink()) + .prop_map(|(content, nonce)| Self::new(content, nonce)) + } + } +} diff --git a/lib/src/protocol/test_utils.rs b/lib/src/protocol/test_utils.rs index 63c7dcf2c..f9ac23580 100644 --- a/lib/src/protocol/test_utils.rs +++ b/lib/src/protocol/test_utils.rs @@ -7,9 +7,10 @@ use crate::{ }, }; use rand::{distributions::Standard, Rng}; -use std::mem; +use std::{fmt, mem}; // In-memory snapshot for testing purposes. +#[derive(Clone)] pub(crate) struct Snapshot { root_hash: Hash, inners: [HashMap; INNER_LAYER_COUNT], @@ -131,6 +132,15 @@ impl Snapshot { } } +impl fmt::Debug for Snapshot { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Snapshot") + .field("root_hash", &self.root_hash) + .field("num_blocks", &self.blocks.len()) + .finish_non_exhaustive() + } +} + pub(crate) struct InnerLayer<'a>(&'a Snapshot, usize); impl<'a> InnerLayer<'a> { From 11e9f19e26be77a564230e9b2615fde34fadcf0a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Cig=C3=A1nek?= Date: Thu, 19 Sep 2024 16:22:27 +0200 Subject: [PATCH 29/55] Track request parent-child relationships in RequestTracker --- lib/src/network/request_tracker.rs | 487 +++++++++++++---------- lib/src/network/request_tracker/graph.rs | 101 +++++ lib/src/network/request_tracker/tests.rs | 84 ++-- 3 files changed, 424 insertions(+), 248 deletions(-) create mode 100644 lib/src/network/request_tracker/graph.rs diff --git a/lib/src/network/request_tracker.rs b/lib/src/network/request_tracker.rs index d1ba9014c..2c54de5be 100644 --- a/lib/src/network/request_tracker.rs +++ b/lib/src/network/request_tracker.rs @@ -1,21 +1,18 @@ +mod graph; #[cfg(test)] mod tests; +use self::graph::{Entry as GraphEntry, Graph, Key as GraphKey}; use super::{constants::REQUEST_TIMEOUT, message::Request}; use crate::{ collections::{HashMap, HashSet}, crypto::{sign::PublicKey, Hash}, - protocol::BlockId, + protocol::{BlockId, MultiBlockPresence}, }; -use futures_util::Stream; use std::{ - collections::hash_map::Entry, - iter, - ops::{Deref, DerefMut}, - pin::Pin, + collections::VecDeque, + iter, mem, sync::atomic::{AtomicUsize, Ordering}, - task::{Context, Poll}, - time::Duration, }; use tokio::{select, sync::mpsc, task}; use tokio_stream::StreamExt; @@ -32,7 +29,7 @@ pub(super) struct RequestTracker { impl RequestTracker { #[expect(dead_code)] pub fn new() -> Self { - let (this, worker) = build(DelayQueue::new()); + let (this, worker) = build(); 
task::spawn(worker.run()); this } @@ -67,18 +64,19 @@ pub(super) struct RequestTrackerClient { impl RequestTrackerClient { /// Handle sending a request that does not follow from any previously received response. #[expect(dead_code)] - pub fn initial(&self, request: Request) { + pub fn initial(&self, request: Request, block_presence: MultiBlockPresence) { self.command_tx .send(Command::HandleInitial { client_id: self.client_id, request, + block_presence, }) .ok(); } /// Handle sending requests that follow from a received success response. #[cfg_attr(not(test), expect(dead_code))] - pub fn success(&self, response_key: MessageKey, requests: Vec) { + pub fn success(&self, response_key: MessageKey, requests: Vec<(Request, MultiBlockPresence)>) { self.command_tx .send(Command::HandleSuccess { client_id: self.client_id, @@ -128,28 +126,25 @@ impl<'a> From<&'a Request> for MessageKey { } } -fn build(timer: T) -> (RequestTracker, Worker) { +fn build() -> (RequestTracker, Worker) { let (command_tx, command_rx) = mpsc::unbounded_channel(); - ( - RequestTracker { command_tx }, - Worker::new(timer, command_rx), - ) + (RequestTracker { command_tx }, Worker::new(command_rx)) } -struct Worker { - clients: HashMap, - requests: HashMap>, - timer: TimerStream, +struct Worker { command_rx: mpsc::UnboundedReceiver, + clients: HashMap, + requests: Graph, + timer: DelayQueue<(ClientId, MessageKey)>, } -impl Worker { - fn new(timer: T, command_rx: mpsc::UnboundedReceiver) -> Self { +impl Worker { + fn new(command_rx: mpsc::UnboundedReceiver) -> Self { Self { - clients: HashMap::default(), - requests: HashMap::default(), - timer: TimerStream(timer), command_rx, + clients: HashMap::default(), + requests: Graph::new(), + timer: DelayQueue::new(), } } @@ -163,7 +158,8 @@ impl Worker { break; } } - Some((client_id, request_key)) = self.timer.next() => { + Some(expired) = self.timer.next() => { + let (client_id, request_key) = expired.into_inner(); self.handle_failure(client_id, request_key); } } @@ -194,8 +190,12 @@ impl Worker { Command::RemoveClient { client_id } => { self.remove_client(client_id); } - Command::HandleInitial { client_id, request } => { - self.handle_initial(client_id, request); + Command::HandleInitial { + client_id, + request, + block_presence, + } => { + self.handle_initial(client_id, request, block_presence); } Command::HandleSuccess { client_id, @@ -213,14 +213,9 @@ impl Worker { } } + #[instrument(skip(self, request_tx))] fn insert_client(&mut self, client_id: ClientId, request_tx: mpsc::UnboundedSender) { - self.clients.insert( - client_id, - ClientState { - request_tx, - requests: HashSet::default(), - }, - ); + self.clients.insert(client_id, ClientState::new(request_tx)); } #[instrument(skip(self))] @@ -229,65 +224,24 @@ impl Worker { return; }; - for message_key in client_state.requests { - let Entry::Occupied(mut entry) = self.requests.entry(message_key) else { - continue; - }; - - let Some(interest) = entry.get_mut().interests.remove(&client_id) else { - continue; - }; - - if let Some(timeout_key) = interest.timer_key { - self.timer.remove(&timeout_key); - } - - if entry.get().interests.is_empty() { - entry.remove(); - continue; - } - - // TODO: fallback to another client, if any + for (request_key, block_presence) in client_state.requests { + self.cancel_request(client_id, GraphKey(request_key, block_presence)); } } #[instrument(skip(self))] - fn handle_initial(&mut self, client_id: ClientId, request: Request) { - let Some(client_state) = self.clients.get_mut(&client_id) else { - // 
client not inserted - return; - }; - - let request_key = MessageKey::from(&request); - - client_state.requests.insert(request_key); - - let request_state = self - .requests - .entry(request_key) - .or_insert_with(|| RequestState { - request, - interests: HashMap::default(), - }); - - let timer_key = if request_state - .interests - .values() - .all(|state| state.timer_key.is_none()) - { - client_state - .request_tx - .send(request_state.request.clone()) - .ok(); - - Some(self.timer.insert(client_id, request_key, REQUEST_TIMEOUT)) - } else { - None - }; - - request_state - .interests - .insert(client_id, Interest { timer_key }); + fn handle_initial( + &mut self, + client_id: ClientId, + request: Request, + block_presence: MultiBlockPresence, + ) { + self.insert_request( + client_id, + GraphKey(MessageKey::from(&request), block_presence), + Some(request), + None, + ) } #[instrument(skip(self))] @@ -295,90 +249,240 @@ impl Worker { &mut self, client_id: ClientId, response_key: MessageKey, - requests: Vec, + requests: Vec<(Request, MultiBlockPresence)>, ) { - if let Some(state) = self.clients.get_mut(&client_id) { - state.requests.remove(&response_key); - } - - let mut followup_client_ids = if requests.is_empty() { - vec![] - } else { - vec![client_id] + let Some(block_presence) = self + .clients + .get_mut(&client_id) + .and_then(|client_state| client_state.requests.remove(&response_key)) + else { + return; }; - if let Entry::Occupied(mut entry) = self.requests.entry(response_key) { - entry - .get_mut() - .interests - .retain(|other_client_id, interest| { - // TODO: remove only those with the same or worse block presence. + let request_key = GraphKey(response_key, block_presence); - if let Some(timer_key) = interest.timer_key { - self.timer.remove(&timer_key); - } + let (parent_keys, mut client_ids) = match self.requests.entry(request_key) { + GraphEntry::Occupied(mut entry) => match entry.get_mut() { + RequestState::InFlight { + sender_client_id, + sender_timer_key, + waiting, + } if *sender_client_id == client_id => { + self.timer.try_remove(sender_timer_key); - if !requests.is_empty() && *other_client_id != client_id { - followup_client_ids.push(*other_client_id); + let waiting = mem::take(waiting); + + // Add child requests to this request. + for (request, block_presence) in &requests { + entry.insert_child(GraphKey(MessageKey::from(request), *block_presence)); } - false - }); + // If this request has children, mark it as complete, otherwise remove it. + let parent_key = if !entry.children().is_empty() { + *entry.get_mut() = RequestState::Complete; + HashSet::default() + } else { + entry.remove().parents + }; - if entry.get().interests.is_empty() { - entry.remove(); - } - } + (parent_key, waiting) + } + RequestState::InFlight { .. } + | RequestState::Complete + | RequestState::Cancelled => return, + }, + GraphEntry::Vacant(_) => return, + }; + + // Register the followup (child) requests with this client but also with all the clients that + // were waiting for the original request. + client_ids.push_front(client_id); - for request in requests { - for (client_id, request) in followup_client_ids - .iter() - .copied() + for (child_request, child_block_presence) in requests { + for (client_id, child_request) in // TODO: use `repeat_n` once it gets stabilized. 
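                 // (`iter::repeat_n(child_request, client_ids.len())` would also avoid
                 // cloning the request for the final client by moving it on the last
                 // iteration.)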
- .zip(iter::repeat(request)) + client_ids.iter().copied().zip(iter::repeat(child_request)) { - self.handle_initial(client_id, request); + self.insert_request( + client_id, + GraphKey(MessageKey::from(&child_request), child_block_presence), + Some(child_request.clone()), + Some(request_key), + ); } - // round-robin the requests across the clients - followup_client_ids.rotate_right(1); + // Round-robin the requests among the clients. + client_ids.rotate_left(1); + } + + for parent_key in parent_keys { + self.request_removed(request_key, parent_key); } } #[instrument(skip(self))] fn handle_failure(&mut self, client_id: ClientId, response_key: MessageKey) { - if let Some(state) = self.clients.get_mut(&client_id) { - state.requests.remove(&response_key); + let Some(client_state) = self.clients.get_mut(&client_id) else { + return; + }; + + let Some(block_presence) = client_state.requests.remove(&response_key) else { + return; + }; + + self.cancel_request(client_id, GraphKey(response_key, block_presence)); + } + + fn insert_request( + &mut self, + client_id: ClientId, + request_key: GraphKey, + request: Option, + parent_key: Option, + ) { + let Some(client_state) = self.clients.get_mut(&client_id) else { + return; + }; + + let (children, request, request_state) = match self.requests.entry(request_key) { + GraphEntry::Occupied(mut entry) => { + if let Some(parent_key) = parent_key { + entry.insert_parent(parent_key); + } + + ( + entry.children().iter().copied().collect(), + entry.request(), + entry.into_mut(), + ) + } + GraphEntry::Vacant(entry) => { + if let Some(request) = request { + let (request, request_state) = + entry.insert(request, parent_key, RequestState::Cancelled); + (Vec::new(), request, request_state) + } else { + return; + } + } + }; + + match request_state { + RequestState::InFlight { waiting, .. } => { + waiting.push_back(client_id); + client_state.requests.insert(request_key.0, request_key.1); + } + RequestState::Complete => (), + RequestState::Cancelled => { + let timer_key = self + .timer + .insert((client_id, request_key.0), REQUEST_TIMEOUT); + + *request_state = RequestState::InFlight { + sender_client_id: client_id, + sender_timer_key: timer_key, + waiting: VecDeque::new(), + }; + + client_state.requests.insert(request_key.0, request_key.1); + client_state.request_tx.send(request.clone()).ok(); + } } - let Entry::Occupied(mut entry) = self.requests.entry(response_key) else { + // NOTE: we are using recursion, but the graph is only a few layers deep (currently 5) so + // there is no danger of stack overflow. + for child_key in children { + self.insert_request(client_id, child_key, None, Some(request_key)); + } + } + + fn cancel_request(&mut self, client_id: ClientId, request_key: GraphKey) { + let GraphEntry::Occupied(mut entry) = self.requests.entry(request_key) else { return; }; - if let Some(interest) = entry.get_mut().interests.remove(&client_id) { - if let Some(timer_key) = interest.timer_key { - self.timer.remove(&timer_key); + let parent_keys = match entry.get_mut() { + RequestState::InFlight { + sender_client_id, + sender_timer_key, + waiting, + } => { + if *sender_client_id == client_id { + // The removed client is the current sender of this request. 
+ + // Remove the timeout for the previous sender + self.timer.try_remove(sender_timer_key); + + // Find a waiting client + let next_client = iter::from_fn(|| waiting.pop_front()).find_map(|client_id| { + self.clients + .get(&client_id) + .map(|client_state| (client_id, client_state)) + }); + + if let Some((next_client_id, next_client_state)) = next_client { + // Next waiting client found. Promote it to a sender. + + *sender_client_id = next_client_id; + *sender_timer_key = self + .timer + .insert((next_client_id, request_key.0), REQUEST_TIMEOUT); + + // Send the request to the new sender. + next_client_state + .request_tx + .send(entry.request().clone()) + .ok(); + + return; + } else { + // No waiting client found. If this request has no children, we can remove + // it, otherwise we mark it as cancelled. + if !entry.children().is_empty() { + *entry.get_mut() = RequestState::Cancelled; + return; + } else { + entry.remove().parents + } + } + } else { + // The removed client is one of the waiting clients - remove it from the + // waiting queue. + if let Some(index) = waiting + .iter() + .position(|next_client_id| *next_client_id == client_id) + { + waiting.remove(index); + } + + return; + } } + RequestState::Complete | RequestState::Cancelled => return, + }; + + for parent_key in parent_keys { + self.request_removed(request_key, parent_key); } + } - // TODO: prefer one with the same or better block presence as `client_id`. - if let Some((fallback_client_id, interest)) = entry - .get_mut() - .interests - .iter_mut() - .find(|(_, interest)| interest.timer_key.is_none()) - { - interest.timer_key = Some(self.timer.insert( - *fallback_client_id, - response_key, - REQUEST_TIMEOUT, - )); + // Remove the request from its parent request and if it was the last child of the parent, remove + // the parent as well, recursively. + fn request_removed(&mut self, request_key: GraphKey, parent_key: GraphKey) { + let mut entry = match self.requests.entry(parent_key) { + GraphEntry::Occupied(entry) => entry, + GraphEntry::Vacant(_) => return, + }; - // TODO: send the request + entry.remove_child(&request_key); + + if !entry.children().is_empty() { + return; } - if entry.get().interests.is_empty() { - entry.remove(); + let grandparent_keys: Vec<_> = entry.parents().iter().copied().collect(); + for grandparent_key in grandparent_keys { + self.request_removed(parent_key, grandparent_key); } } } @@ -404,11 +508,12 @@ enum Command { HandleInitial { client_id: ClientId, request: Request, + block_presence: MultiBlockPresence, }, HandleSuccess { client_id: ClientId, response_key: MessageKey, - requests: Vec, + requests: Vec<(Request, MultiBlockPresence)>, }, HandleFailure { client_id: ClientId, @@ -418,77 +523,31 @@ enum Command { struct ClientState { request_tx: mpsc::UnboundedSender, - requests: HashSet, + requests: HashMap, } -struct RequestState { - request: Request, - interests: HashMap>, -} - -struct Interest { - // disambiguator: ResponseDisambiguator, - timer_key: Option, -} - -/// Trait for timer to to track request timeouts. 
-trait Timer: Unpin {
-    type Key: Copy;
-
-    fn insert(
-        &mut self,
-        client_id: ClientId,
-        message_key: MessageKey,
-        timeout: Duration,
-    ) -> Self::Key;
-
-    fn remove(&mut self, key: &Self::Key);
-
-    fn poll_expired(&mut self, cx: &mut Context<'_>) -> Poll<Option<(ClientId, MessageKey)>>;
-}
-
-struct TimerStream<T>(T);
-
-impl<T> Deref for TimerStream<T> {
-    type Target = T;
-
-    fn deref(&self) -> &Self::Target {
-        &self.0
-    }
-}
-
-impl<T> DerefMut for TimerStream<T> {
-    fn deref_mut(&mut self) -> &mut Self::Target {
-        &mut self.0
-    }
-}
-
-impl<T: Timer> Stream for TimerStream<T> {
-    type Item = (ClientId, MessageKey);
-
-    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
-        self.get_mut().0.poll_expired(cx)
+impl ClientState {
+    fn new(request_tx: mpsc::UnboundedSender<Request>) -> Self {
+        Self {
+            request_tx,
+            requests: HashMap::default(),
+        }
     }
 }
 
-impl Timer for DelayQueue<(ClientId, MessageKey)> {
-    type Key = delay_queue::Key;
-
-    fn insert(
-        &mut self,
-        client_id: ClientId,
-        message_key: MessageKey,
-        timeout: Duration,
-    ) -> Self::Key {
-        DelayQueue::insert(self, (client_id, message_key), timeout)
-    }
-
-    fn remove(&mut self, key: &Self::Key) {
-        DelayQueue::try_remove(self, key);
-    }
-
-    fn poll_expired(&mut self, cx: &mut Context<'_>) -> Poll<Option<(ClientId, MessageKey)>> {
-        self.poll_expired(cx)
-            .map(|expired| expired.map(|expired| expired.into_inner()))
-    }
+enum RequestState {
+    /// This request is currently in flight
+    InFlight {
+        /// Client who's sending this request
+        sender_client_id: ClientId,
+        /// Timeout key for the request
+        sender_timer_key: delay_queue::Key,
+        /// Other clients interested in sending this request. If the current client fails or
+        /// times out, a new one will be picked from this list.
+        waiting: VecDeque<ClientId>,
+    },
+    /// The response to this request has already been received.
+    Complete,
+    /// The response for the current client failed and there are no more clients waiting.
+    Cancelled,
 }
diff --git a/lib/src/network/request_tracker/graph.rs b/lib/src/network/request_tracker/graph.rs
new file mode 100644
index 000000000..59af83f88
--- /dev/null
+++ b/lib/src/network/request_tracker/graph.rs
@@ -0,0 +1,101 @@
+use super::MessageKey;
+use crate::{collections::HashSet, network::message::Request, protocol::MultiBlockPresence};
+use std::marker::PhantomData;
+
+/// DAG for storing data for the request tracker.
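+/// Nodes are requests; an edge parent -> child means the child request was
+/// generated by the response to the parent request. A node can have multiple
+/// parents (the same child request can follow from more than one response),
+/// which is why this is a DAG rather than a tree.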
+pub(super) struct Graph<T> {
+    _todo: PhantomData<T>,
+}
+
+impl<T> Graph<T> {
+    pub fn new() -> Self {
+        todo!()
+    }
+
+    pub fn entry(&mut self, _key: Key) -> Entry<'_, T> {
+        todo!()
+    }
+
+    #[cfg_attr(not(test), expect(dead_code))]
+    pub fn len(&self) -> usize {
+        todo!()
+    }
+}
+
+#[derive(Clone, Copy, Eq, PartialEq, Hash)]
+pub(super) struct Key(pub MessageKey, pub MultiBlockPresence);
+
+pub(super) enum Entry<'a, T> {
+    #[expect(dead_code)]
+    Occupied(OccupiedEntry<'a, T>),
+    #[expect(dead_code)]
+    Vacant(VacantEntry<'a, T>),
+}
+
+pub(super) struct OccupiedEntry<'a, T> {
+    _todo: PhantomData<&'a mut T>,
+}
+
+impl<'a, T> OccupiedEntry<'a, T> {
+    #[expect(dead_code)]
+    pub fn get(&self) -> &T {
+        todo!()
+    }
+
+    pub fn get_mut(&mut self) -> &mut T {
+        todo!()
+    }
+
+    pub fn into_mut(self) -> &'a mut T {
+        todo!()
+    }
+
+    pub fn request(&self) -> &'a Request {
+        todo!()
+    }
+
+    pub fn parents(&self) -> &HashSet<Key> {
+        todo!()
+    }
+
+    pub fn insert_parent(&mut self, _parent: Key) {
+        todo!()
+    }
+
+    pub fn children(&self) -> &HashSet<Key> {
+        todo!();
+    }
+
+    pub fn insert_child(&mut self, _child: Key) {
+        todo!()
+    }
+
+    pub fn remove_child(&mut self, _child: &Key) {
+        todo!()
+    }
+
+    pub fn remove(self) -> RemovedEntry<T> {
+        todo!()
+    }
+}
+
+pub(super) struct VacantEntry<'a, T> {
+    _todo: PhantomData<&'a mut T>,
+}
+
+impl<'a, T> VacantEntry<'a, T> {
+    pub fn insert(
+        self,
+        _request: Request,
+        _parent: Option<Key>,
+        _value: T,
+    ) -> (&'a Request, &'a mut T) {
+        todo!()
+    }
+}
+
+pub(super) struct RemovedEntry<T> {
+    #[expect(dead_code)]
+    pub value: T,
+    pub parents: HashSet<Key>,
+}
diff --git a/lib/src/network/request_tracker/tests.rs b/lib/src/network/request_tracker/tests.rs
index c43ac4cc0..85f1c07f0 100644
--- a/lib/src/network/request_tracker/tests.rs
+++ b/lib/src/network/request_tracker/tests.rs
@@ -12,8 +12,10 @@ use crate::{
 use rand::{rngs::StdRng, seq::SliceRandom, Rng, SeedableRng};
 use std::collections::VecDeque;
 
-#[test]
-fn simulation() {
+// Note: We need `tokio::test` here because the `RequestTracker` uses `DelayQueue` internally, which
+// needs a tokio runtime.
+#[tokio::test] +async fn simulation() { simulation_case(6045920800462135606, 1, 2, (1, 10), (1, 10)); // let seed = rand::random(); // simulation_case(seed, 32, 4, (1, 10), (1, 10)); @@ -38,8 +40,7 @@ fn simulation_case( let mut rng = StdRng::seed_from_u64(seed); - let (tracker, mut tracker_worker) = build(FakeTimer); - let mut summary = Summary::default(); + let (tracker, mut tracker_worker) = build(); let block_count = rng.gen_range(1..=max_blocks); let snapshot = Snapshot::generate(&mut rng, block_count); @@ -48,6 +49,7 @@ fn simulation_case( let mut peers = Vec::new(); let mut total_peer_count = 0; + let mut summary = Summary::new(snapshot.blocks().len()); loop { let peers_len_before = peers.len(); @@ -89,18 +91,34 @@ fn simulation_case( assert_eq!(tracker_worker.request_count(), 0); } -#[derive(Default)] struct Summary { + expected_blocks: usize, nodes: HashMap, blocks: HashMap, } impl Summary { + fn new(expected_blocks: usize) -> Self { + Self { + expected_blocks, + nodes: HashMap::default(), + blocks: HashMap::default(), + } + } + fn receive_node(&mut self, hash: Hash) { + if self.blocks.len() >= self.expected_blocks { + return; + } + *self.nodes.entry(hash).or_default() += 1; } fn receive_block(&mut self, block_id: BlockId) { + if self.blocks.len() >= self.expected_blocks { + return; + } + *self.blocks.entry(block_id).or_default() += 1; } @@ -167,10 +185,13 @@ impl TestClient { Response::RootNode(proof, block_presence, debug_payload) => { summary.receive_node(proof.hash); - let requests = vec![Request::ChildNodes( - proof.hash, - ResponseDisambiguator::new(block_presence), - debug_payload.follow_up(), + let requests = vec![( + Request::ChildNodes( + proof.hash, + ResponseDisambiguator::new(block_presence), + debug_payload.follow_up(), + ), + block_presence, )]; self.tracker_client @@ -183,10 +204,13 @@ impl TestClient { .map(|(_, node)| { summary.receive_node(node.hash); - Request::ChildNodes( - node.hash, - ResponseDisambiguator::new(node.summary.block_presence), - debug_payload.follow_up(), + ( + Request::ChildNodes( + node.hash, + ResponseDisambiguator::new(node.summary.block_presence), + debug_payload.follow_up(), + ), + node.summary.block_presence, ) }) .collect(); @@ -201,7 +225,10 @@ impl TestClient { .map(|node| { summary.receive_node(node.locator); - Request::Block(node.block_id, debug_payload.follow_up()) + ( + Request::Block(node.block_id, debug_payload.follow_up()), + MultiBlockPresence::None, + ) }) .collect(); @@ -306,6 +333,12 @@ impl TestServer { disambiguator, debug_payload.reply(), )); + } else { + self.outbox.push_back(Response::ChildNodesError( + hash, + disambiguator, + debug_payload.reply(), + )); } } Request::Block(block_id, debug_payload) => { @@ -315,6 +348,9 @@ impl TestServer { block.nonce, debug_payload.reply(), )); + } else { + self.outbox + .push_back(Response::BlockError(block_id, debug_payload.reply())); } } } @@ -366,23 +402,3 @@ fn poll_peers( changed } - -struct FakeTimer; - -impl Timer for FakeTimer { - type Key = (); - - fn insert( - &mut self, - _client_id: ClientId, - _message_key: MessageKey, - _timeout: Duration, - ) -> Self::Key { - } - - fn remove(&mut self, _key: &Self::Key) {} - - fn poll_expired(&mut self, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(None) - } -} From 98865191b9dacf4d8b2f35859bd0ed35d741c4bc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Cig=C3=A1nek?= Date: Mon, 23 Sep 2024 16:27:07 +0200 Subject: [PATCH 30/55] Implement Graph for RequestTracker + fix test failures --- lib/src/network/request_tracker.rs | 268 
++++++++++++----------- lib/src/network/request_tracker/graph.rs | 233 ++++++++++++++------ lib/src/network/request_tracker/tests.rs | 156 ++++++------- lib/src/protocol/test_utils.rs | 8 + 4 files changed, 397 insertions(+), 268 deletions(-) diff --git a/lib/src/network/request_tracker.rs b/lib/src/network/request_tracker.rs index 2c54de5be..9148c3b9e 100644 --- a/lib/src/network/request_tracker.rs +++ b/lib/src/network/request_tracker.rs @@ -2,10 +2,10 @@ mod graph; #[cfg(test)] mod tests; -use self::graph::{Entry as GraphEntry, Graph, Key as GraphKey}; +use self::graph::{Graph, Key as GraphKey}; use super::{constants::REQUEST_TIMEOUT, message::Request}; use crate::{ - collections::{HashMap, HashSet}, + collections::HashMap, crypto::{sign::PublicKey, Hash}, protocol::{BlockId, MultiBlockPresence}, }; @@ -76,11 +76,11 @@ impl RequestTrackerClient { /// Handle sending requests that follow from a received success response. #[cfg_attr(not(test), expect(dead_code))] - pub fn success(&self, response_key: MessageKey, requests: Vec<(Request, MultiBlockPresence)>) { + pub fn success(&self, request_key: MessageKey, requests: Vec<(Request, MultiBlockPresence)>) { self.command_tx .send(Command::HandleSuccess { client_id: self.client_id, - response_key, + request_key, requests, }) .ok(); @@ -88,11 +88,11 @@ impl RequestTrackerClient { /// Handle failure response. #[cfg_attr(not(test), expect(dead_code))] - pub fn failure(&self, response_key: MessageKey) { + pub fn failure(&self, request_key: MessageKey) { self.command_tx .send(Command::HandleFailure { client_id: self.client_id, - response_key, + request_key, }) .ok(); } @@ -160,7 +160,7 @@ impl Worker { } Some(expired) = self.timer.next() => { let (client_id, request_key) = expired.into_inner(); - self.handle_failure(client_id, request_key); + self.handle_failure(client_id, request_key, FailureReason::Timeout); } } } @@ -199,33 +199,36 @@ impl Worker { } Command::HandleSuccess { client_id, - response_key, + request_key, requests, } => { - self.handle_success(client_id, response_key, requests); + self.handle_success(client_id, request_key, requests); } Command::HandleFailure { client_id, - response_key, + request_key, } => { - self.handle_failure(client_id, response_key); + self.handle_failure(client_id, request_key, FailureReason::Response); } } } #[instrument(skip(self, request_tx))] fn insert_client(&mut self, client_id: ClientId, request_tx: mpsc::UnboundedSender) { + // tracing::debug!("insert_client"); self.clients.insert(client_id, ClientState::new(request_tx)); } #[instrument(skip(self))] fn remove_client(&mut self, client_id: ClientId) { + // tracing::debug!("remove_client"); + let Some(client_state) = self.clients.remove(&client_id) else { return; }; - for (request_key, block_presence) in client_state.requests { - self.cancel_request(client_id, GraphKey(request_key, block_presence)); + for (_, node_key) in client_state.requests { + self.cancel_request(client_id, node_key); } } @@ -236,62 +239,61 @@ impl Worker { request: Request, block_presence: MultiBlockPresence, ) { - self.insert_request( - client_id, - GraphKey(MessageKey::from(&request), block_presence), - Some(request), - None, - ) + // tracing::debug!("handle_initial"); + self.insert_request(client_id, request, block_presence, None) } #[instrument(skip(self))] fn handle_success( &mut self, client_id: ClientId, - response_key: MessageKey, + request_key: MessageKey, requests: Vec<(Request, MultiBlockPresence)>, ) { - let Some(block_presence) = self + // 
tracing::debug!("handle_success"); + + let node_key = self .clients .get_mut(&client_id) - .and_then(|client_state| client_state.requests.remove(&response_key)) - else { - return; - }; - - let request_key = GraphKey(response_key, block_presence); - - let (parent_keys, mut client_ids) = match self.requests.entry(request_key) { - GraphEntry::Occupied(mut entry) => match entry.get_mut() { - RequestState::InFlight { - sender_client_id, - sender_timer_key, - waiting, - } if *sender_client_id == client_id => { - self.timer.try_remove(sender_timer_key); - - let waiting = mem::take(waiting); + .and_then(|client_state| client_state.requests.remove(&request_key)); + + let mut client_ids = if let Some(node_key) = node_key { + let (client_ids, remove) = match self.requests.get_mut(node_key) { + Some(node) => match node.value_mut() { + RequestState::InFlight { + sender_client_id, + sender_timer_key, + waiting, + } if *sender_client_id == client_id => { + self.timer.try_remove(sender_timer_key); + + let waiting = mem::take(waiting); + + // If this request has, or will have, children, mark it as complete, otherwise + // remove it. + let remove = if node.children().len() > 0 || !requests.is_empty() { + *node.value_mut() = RequestState::Complete; + false + } else { + true + }; - // Add child requests to this request. - for (request, block_presence) in &requests { - entry.insert_child(GraphKey(MessageKey::from(request), *block_presence)); + (waiting, remove) } + RequestState::InFlight { .. } + | RequestState::Complete + | RequestState::Cancelled => return, + }, + None => return, + }; + + if remove { + self.remove_request(node_key); + } - // If this request has children, mark it as complete, otherwise remove it. - let parent_key = if !entry.children().is_empty() { - *entry.get_mut() = RequestState::Complete; - HashSet::default() - } else { - entry.remove().parents - }; - - (parent_key, waiting) - } - RequestState::InFlight { .. } - | RequestState::Complete - | RequestState::Cancelled => return, - }, - GraphEntry::Vacant(_) => return, + client_ids + } else { + Default::default() }; // Register the followup (child) requests with this client but also with all the clients that @@ -305,103 +307,102 @@ impl Worker { { self.insert_request( client_id, - GraphKey(MessageKey::from(&child_request), child_block_presence), - Some(child_request.clone()), - Some(request_key), + child_request.clone(), + child_block_presence, + node_key, ); } // Round-robin the requests among the clients. 
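             // (Example: with client_ids = [a, b, c], `a` becomes the sender of the
             // first child request while `b` and `c` are registered as waiters; after
             // the rotation the order is [b, c, a], so `b` leads the next child
             // request, and so on.)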
client_ids.rotate_left(1); } - - for parent_key in parent_keys { - self.request_removed(request_key, parent_key); - } } #[instrument(skip(self))] - fn handle_failure(&mut self, client_id: ClientId, response_key: MessageKey) { + fn handle_failure( + &mut self, + client_id: ClientId, + request_key: MessageKey, + reason: FailureReason, + ) { + // tracing::debug!("handle_failure"); + let Some(client_state) = self.clients.get_mut(&client_id) else { return; }; - let Some(block_presence) = client_state.requests.remove(&response_key) else { + let Some(node_key) = client_state.requests.remove(&request_key) else { return; }; - self.cancel_request(client_id, GraphKey(response_key, block_presence)); + self.cancel_request(client_id, node_key); } fn insert_request( &mut self, client_id: ClientId, - request_key: GraphKey, - request: Option, + request: Request, + block_presence: MultiBlockPresence, parent_key: Option, ) { - let Some(client_state) = self.clients.get_mut(&client_id) else { + let node_key = self.requests.get_or_insert( + request, + block_presence, + parent_key, + RequestState::Cancelled, + ); + + self.update_request(client_id, node_key); + } + + fn update_request(&mut self, client_id: ClientId, node_key: GraphKey) { + let Some(node) = self.requests.get_mut(node_key) else { return; }; - let (children, request, request_state) = match self.requests.entry(request_key) { - GraphEntry::Occupied(mut entry) => { - if let Some(parent_key) = parent_key { - entry.insert_parent(parent_key); - } - - ( - entry.children().iter().copied().collect(), - entry.request(), - entry.into_mut(), - ) - } - GraphEntry::Vacant(entry) => { - if let Some(request) = request { - let (request, request_state) = - entry.insert(request, parent_key, RequestState::Cancelled); - (Vec::new(), request, request_state) - } else { - return; - } - } + let Some(client_state) = self.clients.get_mut(&client_id) else { + return; }; - match request_state { + let request_key = MessageKey::from(node.request()); + + match node.value_mut() { RequestState::InFlight { waiting, .. } => { waiting.push_back(client_id); - client_state.requests.insert(request_key.0, request_key.1); + client_state.requests.insert(request_key, node_key); } RequestState::Complete => (), RequestState::Cancelled => { - let timer_key = self - .timer - .insert((client_id, request_key.0), REQUEST_TIMEOUT); + let timer_key = self.timer.insert((client_id, request_key), REQUEST_TIMEOUT); - *request_state = RequestState::InFlight { + *node.value_mut() = RequestState::InFlight { sender_client_id: client_id, sender_timer_key: timer_key, waiting: VecDeque::new(), }; - client_state.requests.insert(request_key.0, request_key.1); - client_state.request_tx.send(request.clone()).ok(); + client_state.requests.insert(request_key, node_key); + client_state.request_tx.send(node.request().clone()).ok(); } } - // NOTE: we are using recursion, but the graph is only a few layers deep (currently 5) so + // Note: we are using recursion, but the graph is only a few layers deep (currently 5) so // there is no danger of stack overflow. 
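         // (The request chain is RootNode -> one ChildNodes per tree layer -> Block,
         // so the recursion depth is bounded by the fixed depth of the hash tree.)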
+ let children: Vec<_> = node.children().collect(); + for child_key in children { - self.insert_request(client_id, child_key, None, Some(request_key)); + self.update_request(client_id, child_key); } } - fn cancel_request(&mut self, client_id: ClientId, request_key: GraphKey) { - let GraphEntry::Occupied(mut entry) = self.requests.entry(request_key) else { + fn cancel_request(&mut self, client_id: ClientId, node_key: GraphKey) { + let Some(node) = self.requests.get_mut(node_key) else { return; }; - let parent_keys = match entry.get_mut() { + let (request, state) = node.request_and_value_mut(); + + let remove = match state { RequestState::InFlight { sender_client_id, sender_timer_key, @@ -426,23 +427,20 @@ impl Worker { *sender_client_id = next_client_id; *sender_timer_key = self .timer - .insert((next_client_id, request_key.0), REQUEST_TIMEOUT); + .insert((next_client_id, MessageKey::from(request)), REQUEST_TIMEOUT); // Send the request to the new sender. - next_client_state - .request_tx - .send(entry.request().clone()) - .ok(); + next_client_state.request_tx.send(request.clone()).ok(); - return; + false } else { // No waiting client found. If this request has no children, we can remove // it, otherwise we mark it as cancelled. - if !entry.children().is_empty() { - *entry.get_mut() = RequestState::Cancelled; - return; + if node.children().len() > 0 { + *node.value_mut() = RequestState::Cancelled; + false } else { - entry.remove().parents + true } } } else { @@ -455,34 +453,32 @@ impl Worker { waiting.remove(index); } - return; + false } } - RequestState::Complete | RequestState::Cancelled => return, + RequestState::Complete | RequestState::Cancelled => false, }; - for parent_key in parent_keys { - self.request_removed(request_key, parent_key); + if remove { + self.remove_request(node_key); } } - // Remove the request from its parent request and if it was the last child of the parent, remove - // the parent as well, recursively. - fn request_removed(&mut self, request_key: GraphKey, parent_key: GraphKey) { - let mut entry = match self.requests.entry(parent_key) { - GraphEntry::Occupied(entry) => entry, - GraphEntry::Vacant(_) => return, + fn remove_request(&mut self, node_key: GraphKey) { + let Some(node) = self.requests.remove(node_key) else { + return; }; - entry.remove_child(&request_key); + for parent_key in node.parents() { + let Some(parent_node) = self.requests.get(parent_key) else { + continue; + }; - if !entry.children().is_empty() { - return; - } + if parent_node.children().len() > 0 { + continue; + } - let grandparent_keys: Vec<_> = entry.parents().iter().copied().collect(); - for grandparent_key in grandparent_keys { - self.request_removed(parent_key, grandparent_key); + self.remove_request(parent_key); } } } @@ -512,18 +508,18 @@ enum Command { }, HandleSuccess { client_id: ClientId, - response_key: MessageKey, + request_key: MessageKey, requests: Vec<(Request, MultiBlockPresence)>, }, HandleFailure { client_id: ClientId, - response_key: MessageKey, + request_key: MessageKey, }, } struct ClientState { request_tx: mpsc::UnboundedSender, - requests: HashMap, + requests: HashMap, } impl ClientState { @@ -551,3 +547,9 @@ enum RequestState { /// The response for the current client failed and there are no more clients waiting. 
Cancelled, } + +#[derive(Debug)] +enum FailureReason { + Response, + Timeout, +} diff --git a/lib/src/network/request_tracker/graph.rs b/lib/src/network/request_tracker/graph.rs index 59af83f88..5ef3a8d64 100644 --- a/lib/src/network/request_tracker/graph.rs +++ b/lib/src/network/request_tracker/graph.rs @@ -1,101 +1,212 @@ use super::MessageKey; -use crate::{collections::HashSet, network::message::Request, protocol::MultiBlockPresence}; -use std::marker::PhantomData; +use crate::{ + collections::{HashMap, HashSet}, + network::message::Request, + protocol::MultiBlockPresence, +}; +use slab::Slab; +use std::collections::hash_map::Entry; /// DAG for storing data for the request tracker. pub(super) struct Graph { - _todo: PhantomData, + index: HashMap<(MessageKey, MultiBlockPresence), Key>, + nodes: Slab>, } impl Graph { pub fn new() -> Self { - todo!() + Self { + index: HashMap::default(), + nodes: Slab::new(), + } } - pub fn entry(&mut self, _key: Key) -> Entry<'_, T> { - todo!() + pub fn get_or_insert( + &mut self, + request: Request, + block_presence: MultiBlockPresence, + parent_key: Option, + value: T, + ) -> Key { + let entry = match self + .index + .entry((MessageKey::from(&request), block_presence)) + { + Entry::Occupied(entry) => return *entry.get(), + Entry::Vacant(entry) => entry, + }; + + let node_key = self.nodes.insert(Node { + request, + block_presence, + parents: parent_key.into_iter().collect(), + children: HashSet::default(), + value, + }); + let node_key = Key(node_key); + + entry.insert(node_key); + + if let Some(parent_key) = parent_key { + if let Some(parent_node) = self.nodes.get_mut(parent_key.0) { + parent_node.children.insert(node_key); + } + } + + node_key } - #[cfg_attr(not(test), expect(dead_code))] - pub fn len(&self) -> usize { - todo!() + pub fn get(&self, key: Key) -> Option<&Node> { + self.nodes.get(key.0) } -} -#[derive(Clone, Copy, Eq, PartialEq, Hash)] -pub(super) struct Key(pub MessageKey, pub MultiBlockPresence); + pub fn get_mut(&mut self, key: Key) -> Option<&mut Node> { + self.nodes.get_mut(key.0) + } -pub(super) enum Entry<'a, T> { - #[expect(dead_code)] - Occupied(OccupiedEntry<'a, T>), - #[expect(dead_code)] - Vacant(VacantEntry<'a, T>), -} + pub fn remove(&mut self, key: Key) -> Option> { + let node = self.nodes.try_remove(key.0)?; -pub(super) struct OccupiedEntry<'a, T> { - _todo: PhantomData<&'a mut T>, -} + self.index + .remove(&(MessageKey::from(&node.request), node.block_presence)); -impl<'a, T> OccupiedEntry<'a, T> { - #[expect(dead_code)] - pub fn get(&self) -> &T { - todo!() - } + for parent_key in &node.parents { + let Some(parent_node) = self.nodes.get_mut(parent_key.0) else { + continue; + }; - pub fn get_mut(&mut self) -> &mut T { - todo!() - } + parent_node.children.remove(&key); + } - pub fn into_mut(self) -> &'a mut T { - todo!() + Some(node) } - pub fn request(&self) -> &'a Request { - todo!() + #[cfg_attr(not(test), expect(dead_code))] + pub fn len(&self) -> usize { + self.nodes.len() } +} + +#[derive(Clone, Copy, Eq, PartialEq, Hash, Debug)] +pub(super) struct Key(usize); + +pub(super) struct Node { + request: Request, + block_presence: MultiBlockPresence, + parents: HashSet, + children: HashSet, + value: T, +} - pub fn parents(&self) -> &HashSet { - todo!() +impl Node { + #[cfg_attr(not(test), expect(dead_code))] + pub fn value(&self) -> &T { + &self.value } - pub fn insert_parent(&mut self, _parent: Key) { - todo!() + pub fn value_mut(&mut self) -> &mut T { + &mut self.value } - pub fn children(&self) -> &HashSet { - todo!(); 
+ pub fn request(&self) -> &Request { + &self.request } - pub fn insert_child(&mut self, _child: Key) { - todo!() + pub fn request_and_value_mut(&mut self) -> (&Request, &mut T) { + (&self.request, &mut self.value) } - pub fn remove_child(&mut self, _child: &Key) { - todo!() + pub fn parents(&self) -> impl ExactSizeIterator + '_ { + self.parents.iter().copied() } - pub fn remove(self) -> RemovedEntry { - todo!() + pub fn children(&self) -> impl ExactSizeIterator + '_ { + self.children.iter().copied() } } -pub(super) struct VacantEntry<'a, T> { - _todo: PhantomData<&'a mut T>, -} +#[cfg(test)] +mod tests { + use super::*; + use crate::network::{debug_payload::DebugRequest, message::ResponseDisambiguator}; + use rand::Rng; + + #[test] + fn child_request() { + let mut rng = rand::thread_rng(); + let mut graph = Graph::new(); + + assert_eq!(graph.len(), 0); -impl<'a, T> VacantEntry<'a, T> { - pub fn insert( - self, - _request: Request, - _parent: Option, - _value: T, - ) -> (&'a Request, &'a mut T) { - todo!() + let request0 = Request::ChildNodes( + rng.gen(), + ResponseDisambiguator::new(MultiBlockPresence::Full), + DebugRequest::start(), + ); + + let node_key0 = graph.get_or_insert(request0.clone(), MultiBlockPresence::Full, None, 1); + + assert_eq!(graph.len(), 1); + + let Some(node) = graph.get(node_key0) else { + unreachable!() + }; + + assert_eq!(*node.value(), 1); + assert_eq!(node.children().len(), 0); + assert_eq!(node.request(), &request0); + + let request1 = Request::ChildNodes( + rng.gen(), + ResponseDisambiguator::new(MultiBlockPresence::Full), + DebugRequest::start(), + ); + + let node_key1 = graph.get_or_insert( + request1.clone(), + MultiBlockPresence::Full, + Some(node_key0), + 2, + ); + + assert_eq!(graph.len(), 2); + + let Some(node) = graph.get(node_key1) else { + unreachable!() + }; + + assert_eq!(*node.value(), 2); + assert_eq!(node.children().len(), 0); + assert_eq!(node.request(), &request1); + + assert_eq!( + graph.get(node_key0).unwrap().children().collect::>(), + [node_key1] + ); + + graph.remove(node_key1); + + assert_eq!(graph.get(node_key0).unwrap().children().len(), 0); } -} -pub(super) struct RemovedEntry { - #[expect(dead_code)] - pub value: T, - pub parents: HashSet, + #[test] + fn duplicate_request() { + let mut rng = rand::thread_rng(); + let mut graph = Graph::new(); + + assert_eq!(graph.len(), 0); + + let request = Request::ChildNodes( + rng.gen(), + ResponseDisambiguator::new(MultiBlockPresence::Full), + DebugRequest::start(), + ); + + let node_key0 = graph.get_or_insert(request.clone(), MultiBlockPresence::Full, None, 1); + assert_eq!(graph.len(), 1); + + let node_key1 = graph.get_or_insert(request, MultiBlockPresence::Full, None, 1); + assert_eq!(graph.len(), 1); + assert_eq!(node_key0, node_key1); + } } diff --git a/lib/src/network/request_tracker/tests.rs b/lib/src/network/request_tracker/tests.rs index 85f1c07f0..8e1da5df5 100644 --- a/lib/src/network/request_tracker/tests.rs +++ b/lib/src/network/request_tracker/tests.rs @@ -16,27 +16,14 @@ use std::collections::VecDeque; // needs a tokio runtime. 
#[tokio::test] async fn simulation() { - simulation_case(6045920800462135606, 1, 2, (1, 10), (1, 10)); - // let seed = rand::random(); - // simulation_case(seed, 32, 4, (1, 10), (1, 10)); + let seed = rand::random(); + simulation_case(seed, 64, 4); } -fn simulation_case( - seed: u64, - max_blocks: usize, - max_peers: usize, - peer_insert_ratio: (u32, u32), - peer_remove_ratio: (u32, u32), -) { +fn simulation_case(seed: u64, max_blocks: usize, expected_peer_changes: usize) { test_utils::init_log(); - tracing::info!( - seed, - max_blocks, - max_peers, - peer_insert_ratio = ?peer_insert_ratio, - peer_remove_ratio = ?peer_remove_ratio, - ); + tracing::info!(seed, max_blocks, expected_peer_changes); let mut rng = StdRng::seed_from_u64(seed); @@ -44,43 +31,74 @@ fn simulation_case( let block_count = rng.gen_range(1..=max_blocks); let snapshot = Snapshot::generate(&mut rng, block_count); + let mut summary = Summary::new(snapshot.blocks().len()); tracing::info!(?snapshot); let mut peers = Vec::new(); - let mut total_peer_count = 0; - let mut summary = Summary::new(snapshot.blocks().len()); - - loop { - let peers_len_before = peers.len(); - - if peers.is_empty() - || (peers.len() < max_peers && rng.gen_ratio(peer_insert_ratio.0, peer_insert_ratio.1)) - { - tracing::info!("insert peer"); + let mut total_peer_count = 0; // total number of peers that participated in the simulation + + // Action to perform on the set of peers. + #[derive(Debug)] + enum Action { + // Insert a new peer + Insert, + // Remove a random peer + Remove, + // Keep the peer set intact + Keep, + } - let (tracker_client, tracker_request_rx) = tracker.new_client(); - let client = TestClient::new(tracker_client, tracker_request_rx); + // Total number of simulation steps is the number of index nodes plus the number of blocks in + // the snapshot. This is used to calculate the probability of the next action. + let steps = 1 + snapshot.inner_count() + snapshot.leaf_count() + snapshot.blocks().len(); + + for tick in 0.. { + let _enter = tracing::info_span!("tick", message = tick); + + // Generate the next action. The probability of `Insert` or `Remove` is chosen such that the + // expected number of such actions in the simulation is equal to `expected_peer_changes`. + // Both `Insert` and `Remove` have currently the same probability. + let action = if rng.gen_range(0..steps) < expected_peer_changes { + if rng.gen() { + Action::Insert + } else { + Action::Remove + } + } else { + Action::Keep + }; - let writer_id = PublicKey::generate(&mut rng); - let write_keys = Keypair::generate(&mut rng); - let server = TestServer::new(writer_id, write_keys, &snapshot); + match action { + Action::Insert => { + let (tracker_client, tracker_request_rx) = tracker.new_client(); + let client = TestClient::new(tracker_client, tracker_request_rx); - peers.push((client, server)); - } + let writer_id = PublicKey::generate(&mut rng); + let write_keys = Keypair::generate(&mut rng); + let server = TestServer::new(writer_id, write_keys, &snapshot); - if peers.len() > 1 && rng.gen_ratio(peer_remove_ratio.0, peer_remove_ratio.1) { - tracing::info!("remove peer"); + peers.push((client, server)); + total_peer_count += 1; + } + Action::Remove => { + if peers.len() < 2 { + continue; + } - let index = rng.gen_range(0..peers.len()); - peers.remove(index); + let index = rng.gen_range(0..peers.len()); + peers.remove(index); + } + Action::Keep => { + if peers.is_empty() { + continue; + } + } } - // Note some peers might be inserted and removed in the same tick. 
Such peers are discounted - // from the total because they would not send/receive any messages. - total_peer_count += peers.len().saturating_sub(peers_len_before); + let polled = poll_peers(&mut rng, &mut peers, &snapshot, &mut summary); - if poll_peers(&mut rng, &mut peers, &snapshot, &mut summary) { + if polled || matches!(action, Action::Remove) { tracker_worker.step(); } else { break; @@ -106,19 +124,12 @@ impl Summary { } } - fn receive_node(&mut self, hash: Hash) { - if self.blocks.len() >= self.expected_blocks { - return; - } - + fn receive_node(&mut self, hash: Hash) -> bool { *self.nodes.entry(hash).or_default() += 1; + self.blocks.len() < self.expected_blocks } fn receive_block(&mut self, block_id: BlockId) { - if self.blocks.len() >= self.expected_blocks { - return; - } - *self.blocks.entry(block_id).or_default() += 1; } @@ -183,16 +194,18 @@ impl TestClient { fn handle_response(&mut self, response: Response, summary: &mut Summary) { match response { Response::RootNode(proof, block_presence, debug_payload) => { - summary.receive_node(proof.hash); - - let requests = vec![( - Request::ChildNodes( - proof.hash, - ResponseDisambiguator::new(block_presence), - debug_payload.follow_up(), - ), - block_presence, - )]; + let requests = summary + .receive_node(proof.hash) + .then_some(( + Request::ChildNodes( + proof.hash, + ResponseDisambiguator::new(block_presence), + debug_payload.follow_up(), + ), + block_presence, + )) + .into_iter() + .collect(); self.tracker_client .success(MessageKey::RootNode(proof.writer_id), requests); @@ -201,17 +214,15 @@ impl TestClient { let parent_hash = nodes.hash(); let requests: Vec<_> = nodes .into_iter() - .map(|(_, node)| { - summary.receive_node(node.hash); - - ( + .filter_map(|(_, node)| { + summary.receive_node(node.hash).then_some(( Request::ChildNodes( node.hash, ResponseDisambiguator::new(node.summary.block_presence), debug_payload.follow_up(), ), node.summary.block_presence, - ) + )) }) .collect(); @@ -222,13 +233,11 @@ impl TestClient { let parent_hash = nodes.hash(); let requests = nodes .into_iter() - .map(|node| { - summary.receive_node(node.locator); - - ( + .filter_map(|node| { + summary.receive_node(node.locator).then_some(( Request::Block(node.block_id, debug_payload.follow_up()), MultiBlockPresence::None, - ) + )) }) .collect(); @@ -322,9 +331,7 @@ impl TestServer { disambiguator, debug_payload.reply(), )); - } - - if let Some(nodes) = snapshot + } else if let Some(nodes) = snapshot .leaf_sets() .find_map(|(parent_hash, nodes)| (*parent_hash == hash).then_some(nodes)) { @@ -368,6 +375,7 @@ fn poll_peers( snapshot: &Snapshot, summary: &mut Summary, ) -> bool { + #[derive(Debug)] enum Side { Client, Server, diff --git a/lib/src/protocol/test_utils.rs b/lib/src/protocol/test_utils.rs index f9ac23580..ab638dbfd 100644 --- a/lib/src/protocol/test_utils.rs +++ b/lib/src/protocol/test_utils.rs @@ -112,6 +112,14 @@ impl Snapshot { .map(|(_, node)| node) } + pub fn inner_count(&self) -> usize { + self.inners + .iter() + .flat_map(|layer| layer.values()) + .map(|nodes| nodes.len()) + .sum() + } + pub fn blocks(&self) -> &HashMap { &self.blocks } From f2d5d492ee3cff5b678d2c346ac55e2ad3e9d581 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Cig=C3=A1nek?= Date: Tue, 24 Sep 2024 08:45:38 +0200 Subject: [PATCH 31/55] Support missing blocks in test_utils::Snapshot --- lib/src/network/request_tracker/tests.rs | 3 + lib/src/protocol/test_utils.rs | 87 +++++++++++++++++++----- lib/src/repository/vault/tests.rs | 30 ++++---- lib/src/store/client.rs | 2 
+- 4 files changed, 90 insertions(+), 32 deletions(-) diff --git a/lib/src/network/request_tracker/tests.rs b/lib/src/network/request_tracker/tests.rs index 8e1da5df5..393f62d0c 100644 --- a/lib/src/network/request_tracker/tests.rs +++ b/lib/src/network/request_tracker/tests.rs @@ -109,6 +109,9 @@ fn simulation_case(seed: u64, max_blocks: usize, expected_peer_changes: usize) { assert_eq!(tracker_worker.request_count(), 0); } +// TODO: test multiple peers with different block summaries +// TODO: test failure/timeout + struct Summary { expected_blocks: usize, nodes: HashMap, diff --git a/lib/src/protocol/test_utils.rs b/lib/src/protocol/test_utils.rs index ab638dbfd..2e8cf7683 100644 --- a/lib/src/protocol/test_utils.rs +++ b/lib/src/protocol/test_utils.rs @@ -7,12 +7,13 @@ use crate::{ }, }; use rand::{distributions::Standard, Rng}; -use std::{fmt, mem}; +use std::{borrow::Cow, fmt, mem}; // In-memory snapshot for testing purposes. #[derive(Clone)] pub(crate) struct Snapshot { root_hash: Hash, + root_block_presence: MultiBlockPresence, inners: [HashMap; INNER_LAYER_COUNT], leaves: HashMap, blocks: HashMap, @@ -21,52 +22,78 @@ pub(crate) struct Snapshot { impl Snapshot { // Generate a random snapshot with the given number of blocks. pub fn generate(rng: &mut R, block_count: usize) -> Self { - Self::new(rng.sample_iter(Standard).take(block_count)) + Self::from_present_blocks(rng.sample_iter(Standard).take(block_count)) + } + + pub fn from_present_blocks( + locators_and_blocks: impl IntoIterator, + ) -> Self { + Self::from_blocks( + locators_and_blocks + .into_iter() + .map(|(locator, block)| (locator, BlockState::Present(block))), + ) } // Create snapshot given an iterator of blocks where each block is associated to its encoded // locator. - pub fn new(locators_and_blocks: impl IntoIterator) -> Self { + pub fn from_blocks(locators_and_blocks: impl IntoIterator) -> Self { let mut blocks = HashMap::default(); let mut leaves = HashMap::default(); for (locator, block) in locators_and_blocks { - let id = block.id; - blocks.insert(id, block); + let block_id = *block.id(); + let block_presence = block.presence(); + + match block { + BlockState::Present(block) => { + blocks.insert(block_id, block); + } + BlockState::Missing(_) => (), + } - let node = LeafNode::present(locator, id); leaves - .entry(BucketPath::new(&node.locator, INNER_LAYER_COUNT - 1)) + .entry(BucketPath::new(&locator, INNER_LAYER_COUNT - 1)) .or_insert_with(LeafNodes::default) - .insert(node.locator, node.block_id, SingleBlockPresence::Present); + .insert(locator, block_id, block_presence); } let mut inners: [HashMap<_, InnerNodes>; INNER_LAYER_COUNT] = Default::default(); - for (path, set) in &leaves { + for (path, nodes) in &leaves { add_inner_node( INNER_LAYER_COUNT - 1, &mut inners[INNER_LAYER_COUNT - 1], path, - set.hash(), + nodes.hash(), + Summary::from_leaves(nodes).block_presence, ); } for layer in (0..INNER_LAYER_COUNT - 1).rev() { let (lo, hi) = inners.split_at_mut(layer + 1); - for (path, map) in &hi[0] { - add_inner_node(layer, lo.last_mut().unwrap(), path, map.hash()); + for (path, nodes) in &hi[0] { + add_inner_node( + layer, + lo.last_mut().unwrap(), + path, + nodes.hash(), + Summary::from_inners(nodes).block_presence, + ); } } - let root_hash = inners[0] + let nodes = inners[0] .get(&BucketPath::default()) - .unwrap_or(&InnerNodes::default()) - .hash(); + .map(Cow::Borrowed) + .unwrap_or(Cow::Owned(InnerNodes::default())); + let root_hash = nodes.hash(); + let root_block_presence = 
Summary::from_inners(&nodes).block_presence; Self { root_hash, + root_block_presence, inners, leaves, blocks, @@ -77,6 +104,11 @@ impl Snapshot { &self.root_hash } + #[expect(dead_code)] + pub fn root_block_presence(&self) -> &MultiBlockPresence { + &self.root_block_presence + } + pub fn leaf_sets(&self) -> impl Iterator { self.leaves.iter().map(move |(path, nodes)| { let parent_hash = self.parent_hash(INNER_LAYER_COUNT, path); @@ -160,11 +192,34 @@ impl<'a> InnerLayer<'a> { } } +pub(crate) enum BlockState { + Present(Block), + #[expect(dead_code)] + Missing(BlockId), +} + +impl BlockState { + pub fn id(&self) -> &BlockId { + match self { + Self::Present(block) => &block.id, + Self::Missing(block_id) => block_id, + } + } + + pub fn presence(&self) -> SingleBlockPresence { + match self { + Self::Present(_) => SingleBlockPresence::Present, + Self::Missing(_) => SingleBlockPresence::Missing, + } + } +} + fn add_inner_node( inner_layer: usize, maps: &mut HashMap, path: &BucketPath, hash: Hash, + block_presence: MultiBlockPresence, ) { let (bucket, parent_path) = path.pop(inner_layer); maps.entry(parent_path).or_default().insert( @@ -173,7 +228,7 @@ fn add_inner_node( hash, Summary { state: NodeState::Complete, - block_presence: MultiBlockPresence::Full, + block_presence, }, ), ); diff --git a/lib/src/repository/vault/tests.rs b/lib/src/repository/vault/tests.rs index 6aa4b602c..8a4291bd1 100644 --- a/lib/src/repository/vault/tests.rs +++ b/lib/src/repository/vault/tests.rs @@ -87,14 +87,14 @@ async fn prune_snapshots_insert_present() { // snapshot 1 let mut blocks = vec![rng.gen()]; - let snapshot = Snapshot::new(blocks.clone()); + let snapshot = Snapshot::from_present_blocks(blocks.clone()); receive_snapshot(&vault, remote_id, &snapshot, &secrets.write_keys).await; receive_block(&vault, &blocks[0].1).await; // snapshot 2 (insert new block) blocks.push(rng.gen()); - let snapshot = Snapshot::new(blocks.clone()); + let snapshot = Snapshot::from_present_blocks(blocks.clone()); receive_snapshot(&vault, remote_id, &snapshot, &secrets.write_keys).await; receive_block(&vault, &blocks[1].1).await; @@ -115,14 +115,14 @@ async fn prune_snapshots_insert_missing() { // snapshot 1 let mut blocks = vec![rng.gen()]; - let snapshot = Snapshot::new(blocks.clone()); + let snapshot = Snapshot::from_present_blocks(blocks.clone()); receive_snapshot(&index, remote_id, &snapshot, &secrets.write_keys).await; receive_block(&index, &blocks[0].1).await; // snapshot 2 (insert new block) blocks.push(rng.gen()); - let snapshot = Snapshot::new(blocks.clone()); + let snapshot = Snapshot::from_present_blocks(blocks.clone()); receive_snapshot(&index, remote_id, &snapshot, &secrets.write_keys).await; // don't receive the new block @@ -145,14 +145,14 @@ async fn prune_snapshots_update_from_present_to_present() { // snapshot 1 let mut blocks = [rng.gen()]; - let snapshot = Snapshot::new(blocks.clone()); + let snapshot = Snapshot::from_present_blocks(blocks.clone()); receive_snapshot(&index, remote_id, &snapshot, &secrets.write_keys).await; receive_block(&index, &blocks[0].1).await; // snapshot 2 (update the first block) blocks[0].1 = rng.gen(); - let snapshot = Snapshot::new(blocks.clone()); + let snapshot = Snapshot::from_present_blocks(blocks.clone()); receive_snapshot(&index, remote_id, &snapshot, &secrets.write_keys).await; receive_block(&index, &blocks[0].1).await; @@ -173,14 +173,14 @@ async fn prune_snapshots_update_from_present_to_missing() { // snapshot 1 let mut blocks = [rng.gen()]; - let snapshot = 
Snapshot::new(blocks.clone()); + let snapshot = Snapshot::from_present_blocks(blocks.clone()); receive_snapshot(&index, remote_id, &snapshot, &secrets.write_keys).await; receive_block(&index, &blocks[0].1).await; // snapshot 2 (update the first block) blocks[0].1 = rng.gen(); - let snapshot = Snapshot::new(blocks); + let snapshot = Snapshot::from_present_blocks(blocks); receive_snapshot(&index, remote_id, &snapshot, &secrets.write_keys).await; // don't receive the new block @@ -204,14 +204,14 @@ async fn prune_snapshots_update_from_missing_to_missing() { // snapshot 1 let mut blocks = [rng.gen()]; - let snapshot = Snapshot::new(blocks.clone()); + let snapshot = Snapshot::from_present_blocks(blocks.clone()); receive_snapshot(&index, remote_id, &snapshot, &secrets.write_keys).await; // don't receive the block // snapshot 2 (update the first block) blocks[0].1 = rng.gen(); - let snapshot = Snapshot::new(blocks); + let snapshot = Snapshot::from_present_blocks(blocks); receive_snapshot(&index, remote_id, &snapshot, &secrets.write_keys).await; // don't receive the new block @@ -235,14 +235,14 @@ async fn prune_snapshots_keep_missing_and_insert_missing() { // snapshot 1 let mut blocks = vec![rng.gen()]; - let snapshot = Snapshot::new(blocks.clone()); + let snapshot = Snapshot::from_present_blocks(blocks.clone()); receive_snapshot(&index, remote_id, &snapshot, &secrets.write_keys).await; // don't receive the block // snapshot 2 (insert new block) blocks.push(rng.gen()); - let snapshot = Snapshot::new(blocks); + let snapshot = Snapshot::from_present_blocks(blocks); receive_snapshot(&index, remote_id, &snapshot, &secrets.write_keys).await; // don't receive the new block @@ -388,8 +388,8 @@ async fn block_ids_multiple_branches() { let blocks_0 = &all_blocks[..2]; let blocks_1 = &all_blocks[1..]; - let snapshot_0 = Snapshot::new(blocks_0.iter().cloned()); - let snapshot_1 = Snapshot::new(blocks_1.iter().cloned()); + let snapshot_0 = Snapshot::from_present_blocks(blocks_0.iter().cloned()); + let snapshot_1 = Snapshot::from_present_blocks(blocks_1.iter().cloned()); SnapshotWriter::begin(vault.store(), &snapshot_0) .await @@ -494,7 +494,7 @@ async fn sync_progress_case(block_count: usize, branch_count: usize, rng_seed: u .map(|_| { let block_count = rng.gen_range(0..block_count); let blocks = all_blocks.choose_multiple(&mut rng, block_count).cloned(); - let snapshot = Snapshot::new(blocks); + let snapshot = Snapshot::from_present_blocks(blocks); let branch_id = PublicKey::generate(&mut rng); (branch_id, snapshot) diff --git a/lib/src/store/client.rs b/lib/src/store/client.rs index ae1ee3856..6f1246568 100644 --- a/lib/src/store/client.rs +++ b/lib/src/store/client.rs @@ -1025,7 +1025,7 @@ mod tests { assert_eq!(node.summary.state, NodeState::Incomplete); // Remove the block - let snapshot = Snapshot::new( + let snapshot = Snapshot::from_present_blocks( snapshot .locators_and_blocks() .filter(|(_, block)| block.id != block_to_remove) From 9737d5886d51e5437ed163df475f8d0cef1a73f8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Cig=C3=A1nek?= Date: Tue, 24 Sep 2024 18:06:02 +0200 Subject: [PATCH 32/55] Add failing tests for handling missing blocks in RequestTracker --- lib/src/network/request_tracker.rs | 132 +++---- lib/src/network/request_tracker/graph.rs | 262 ++++++++++++-- lib/src/network/request_tracker/tests.rs | 418 ++++++++++++++++------- lib/src/protocol/test_utils.rs | 4 +- 4 files changed, 585 insertions(+), 231 deletions(-) diff --git a/lib/src/network/request_tracker.rs 
b/lib/src/network/request_tracker.rs index 9148c3b9e..2f05c724d 100644 --- a/lib/src/network/request_tracker.rs +++ b/lib/src/network/request_tracker.rs @@ -109,7 +109,7 @@ impl Drop for RequestTrackerClient { } /// Key identifying a request and its corresponding response. -#[derive(Clone, Copy, Eq, PartialEq, Hash, Debug)] +#[derive(Clone, Copy, Eq, PartialEq, Hash, Ord, PartialOrd, Debug)] pub(super) enum MessageKey { RootNode(PublicKey), ChildNodes(Hash), @@ -175,8 +175,8 @@ impl Worker { } #[cfg(test)] - pub fn request_count(&self) -> usize { - self.requests.len() + pub fn requests(&self) -> impl ExactSizeIterator { + self.requests.requests() } fn handle_command(&mut self, command: Command) { @@ -215,13 +215,13 @@ impl Worker { #[instrument(skip(self, request_tx))] fn insert_client(&mut self, client_id: ClientId, request_tx: mpsc::UnboundedSender) { - // tracing::debug!("insert_client"); + tracing::debug!("insert_client"); self.clients.insert(client_id, ClientState::new(request_tx)); } #[instrument(skip(self))] fn remove_client(&mut self, client_id: ClientId) { - // tracing::debug!("remove_client"); + tracing::debug!("remove_client"); let Some(client_state) = self.clients.remove(&client_id) else { return; @@ -250,52 +250,58 @@ impl Worker { request_key: MessageKey, requests: Vec<(Request, MultiBlockPresence)>, ) { - // tracing::debug!("handle_success"); + tracing::debug!("handle_success"); let node_key = self .clients .get_mut(&client_id) - .and_then(|client_state| client_state.requests.remove(&request_key)); - - let mut client_ids = if let Some(node_key) = node_key { - let (client_ids, remove) = match self.requests.get_mut(node_key) { - Some(node) => match node.value_mut() { - RequestState::InFlight { - sender_client_id, - sender_timer_key, - waiting, - } if *sender_client_id == client_id => { - self.timer.try_remove(sender_timer_key); - - let waiting = mem::take(waiting); - - // If this request has, or will have, children, mark it as complete, otherwise - // remove it. - let remove = if node.children().len() > 0 || !requests.is_empty() { - *node.value_mut() = RequestState::Complete; - false - } else { - true - }; - - (waiting, remove) - } - RequestState::InFlight { .. } - | RequestState::Complete - | RequestState::Cancelled => return, - }, - None => return, + .and_then(|state| state.requests.remove(&request_key)); + + let (mut client_ids, remove_key) = if let Some(node_key) = node_key { + let Some(node) = self.requests.get_mut(node_key) else { + return; }; - if remove { - self.remove_request(node_key); - } + match node.value_mut() { + RequestState::InFlight { + sender_client_id, + sender_timer_key, + waiters, + } if *sender_client_id == client_id => { + self.timer.try_remove(sender_timer_key); + + let waiters = mem::take(waiters); + let remove_key = if node.children().len() > 0 || !requests.is_empty() { + *node.value_mut() = RequestState::Complete; + None + } else { + Some(node_key) + }; - client_ids + (waiters, remove_key) + } + RequestState::InFlight { waiters, .. } => { + remove_from_queue(waiters, &client_id); + return; + } + RequestState::Complete | RequestState::Cancelled => return, + } } else { - Default::default() + (Default::default(), None) }; + // Remove the node from the other waiting clients, if any. + for client_id in &client_ids { + if let Some(state) = self.clients.get_mut(client_id) { + state.requests.remove(&request_key); + } + } + + // If the node has no children, remove it. 
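+        // (`remove_key` is `Some` only when the sender finished with no children and no
+        // follow-up requests, so removing the node here cannot orphan a child request.)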
+ if let Some(node_key) = remove_key { + self.remove_request(node_key); + } + // Register the followup (child) requests with this client but also with all the clients that // were waiting for the original request. client_ids.push_front(client_id); @@ -325,7 +331,7 @@ impl Worker { request_key: MessageKey, reason: FailureReason, ) { - // tracing::debug!("handle_failure"); + tracing::debug!("handle_failure"); let Some(client_state) = self.clients.get_mut(&client_id) else { return; @@ -367,8 +373,8 @@ impl Worker { let request_key = MessageKey::from(node.request()); match node.value_mut() { - RequestState::InFlight { waiting, .. } => { - waiting.push_back(client_id); + RequestState::InFlight { waiters, .. } => { + waiters.push_back(client_id); client_state.requests.insert(request_key, node_key); } RequestState::Complete => (), @@ -378,7 +384,7 @@ impl Worker { *node.value_mut() = RequestState::InFlight { sender_client_id: client_id, sender_timer_key: timer_key, - waiting: VecDeque::new(), + waiters: VecDeque::new(), }; client_state.requests.insert(request_key, node_key); @@ -402,11 +408,11 @@ impl Worker { let (request, state) = node.request_and_value_mut(); - let remove = match state { + match state { RequestState::InFlight { sender_client_id, sender_timer_key, - waiting, + waiters, } => { if *sender_client_id == client_id { // The removed client is the current sender of this request. @@ -415,7 +421,7 @@ impl Worker { self.timer.try_remove(sender_timer_key); // Find a waiting client - let next_client = iter::from_fn(|| waiting.pop_front()).find_map(|client_id| { + let next_client = iter::from_fn(|| waiters.pop_front()).find_map(|client_id| { self.clients .get(&client_id) .map(|client_state| (client_id, client_state)) @@ -432,36 +438,26 @@ impl Worker { // Send the request to the new sender. next_client_state.request_tx.send(request.clone()).ok(); - false + return; } else { // No waiting client found. If this request has no children, we can remove // it, otherwise we mark it as cancelled. if node.children().len() > 0 { *node.value_mut() = RequestState::Cancelled; - false - } else { - true + return; } } } else { // The removed client is one of the waiting clients - remove it from the // waiting queue. - if let Some(index) = waiting - .iter() - .position(|next_client_id| *next_client_id == client_id) - { - waiting.remove(index); - } - - false + remove_from_queue(waiters, &client_id); + return; } } - RequestState::Complete | RequestState::Cancelled => false, + RequestState::Complete | RequestState::Cancelled => return, }; - if remove { - self.remove_request(node_key); - } + self.remove_request(node_key); } fn remove_request(&mut self, node_key: GraphKey) { @@ -540,7 +536,7 @@ enum RequestState { sender_timer_key: delay_queue::Key, /// Other clients interested in sending this request. If the current client fails or /// timeouts, a new one will be picked from this list. - waiting: VecDeque, + waiters: VecDeque, }, /// The response to this request has already been received. 
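    /// (A request that produced, or is about to produce, child requests is marked
    /// `Complete` instead of being removed; see `handle_success`.)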
Complete, @@ -553,3 +549,9 @@ enum FailureReason { Response, Timeout, } + +fn remove_from_queue(queue: &mut VecDeque, item: &T) { + if let Some(index) = queue.iter().position(|other| other == item) { + queue.remove(index); + } +} diff --git a/lib/src/network/request_tracker/graph.rs b/lib/src/network/request_tracker/graph.rs index 5ef3a8d64..b09a3cc9c 100644 --- a/lib/src/network/request_tracker/graph.rs +++ b/lib/src/network/request_tracker/graph.rs @@ -28,25 +28,35 @@ impl Graph { parent_key: Option, value: T, ) -> Key { - let entry = match self + let node_key = match self .index .entry((MessageKey::from(&request), block_presence)) { - Entry::Occupied(entry) => return *entry.get(), - Entry::Vacant(entry) => entry, + Entry::Occupied(entry) => { + self.nodes + .get_mut(entry.get().0) + .expect("dangling index entry") + .parents + .extend(parent_key); + + *entry.get() + } + Entry::Vacant(entry) => { + let node_key = self.nodes.insert(Node { + request, + block_presence, + parents: parent_key.into_iter().collect(), + children: HashSet::default(), + value, + }); + let node_key = Key(node_key); + + entry.insert(node_key); + + node_key + } }; - let node_key = self.nodes.insert(Node { - request, - block_presence, - parents: parent_key.into_iter().collect(), - children: HashSet::default(), - value, - }); - let node_key = Key(node_key); - - entry.insert(node_key); - if let Some(parent_key) = parent_key { if let Some(parent_node) = self.nodes.get_mut(parent_key.0) { parent_node.children.insert(node_key); @@ -78,12 +88,20 @@ impl Graph { parent_node.children.remove(&key); } + for child_key in &node.children { + let Some(child_node) = self.nodes.get_mut(child_key.0) else { + continue; + }; + + child_node.parents.remove(&key); + } + Some(node) } #[cfg_attr(not(test), expect(dead_code))] - pub fn len(&self) -> usize { - self.nodes.len() + pub fn requests(&self) -> impl ExactSizeIterator { + self.nodes.iter().map(|(_, node)| &node.request) } } @@ -136,57 +154,62 @@ mod tests { let mut rng = rand::thread_rng(); let mut graph = Graph::new(); - assert_eq!(graph.len(), 0); + assert_eq!(graph.requests().len(), 0); - let request0 = Request::ChildNodes( + let parent_request = Request::ChildNodes( rng.gen(), ResponseDisambiguator::new(MultiBlockPresence::Full), DebugRequest::start(), ); - let node_key0 = graph.get_or_insert(request0.clone(), MultiBlockPresence::Full, None, 1); + let parent_node_key = + graph.get_or_insert(parent_request.clone(), MultiBlockPresence::Full, None, 1); - assert_eq!(graph.len(), 1); + assert_eq!(graph.requests().len(), 1); - let Some(node) = graph.get(node_key0) else { + let Some(node) = graph.get(parent_node_key) else { unreachable!() }; assert_eq!(*node.value(), 1); assert_eq!(node.children().len(), 0); - assert_eq!(node.request(), &request0); + assert_eq!(node.request(), &parent_request); - let request1 = Request::ChildNodes( + let child_request = Request::ChildNodes( rng.gen(), ResponseDisambiguator::new(MultiBlockPresence::Full), DebugRequest::start(), ); - let node_key1 = graph.get_or_insert( - request1.clone(), + let child_node_key = graph.get_or_insert( + child_request.clone(), MultiBlockPresence::Full, - Some(node_key0), + Some(parent_node_key), 2, ); - assert_eq!(graph.len(), 2); + assert_eq!(graph.requests().len(), 2); - let Some(node) = graph.get(node_key1) else { + let Some(node) = graph.get(child_node_key) else { unreachable!() }; assert_eq!(*node.value(), 2); assert_eq!(node.children().len(), 0); - assert_eq!(node.request(), &request1); + assert_eq!(node.request(), 
&child_request); assert_eq!( - graph.get(node_key0).unwrap().children().collect::>(), - [node_key1] + graph + .get(parent_node_key) + .unwrap() + .children() + .collect::>(), + [child_node_key] ); - graph.remove(node_key1); + graph.remove(child_node_key); - assert_eq!(graph.get(node_key0).unwrap().children().len(), 0); + assert_eq!(graph.get(parent_node_key).unwrap().children().len(), 0); } #[test] @@ -194,7 +217,7 @@ mod tests { let mut rng = rand::thread_rng(); let mut graph = Graph::new(); - assert_eq!(graph.len(), 0); + assert_eq!(graph.requests().len(), 0); let request = Request::ChildNodes( rng.gen(), @@ -203,10 +226,177 @@ mod tests { ); let node_key0 = graph.get_or_insert(request.clone(), MultiBlockPresence::Full, None, 1); - assert_eq!(graph.len(), 1); + assert_eq!(graph.requests().len(), 1); let node_key1 = graph.get_or_insert(request, MultiBlockPresence::Full, None, 1); - assert_eq!(graph.len(), 1); + assert_eq!(graph.requests().len(), 1); assert_eq!(node_key0, node_key1); } + + #[test] + fn multiple_parents() { + let mut rng = rand::thread_rng(); + let mut graph = Graph::new(); + + let hash = rng.gen(); + + let parent_block_presence_0 = MultiBlockPresence::None; + let parent_request_0 = Request::ChildNodes( + hash, + ResponseDisambiguator::new(parent_block_presence_0), + DebugRequest::start(), + ); + + let parent_block_presence_1 = MultiBlockPresence::Full; + let parent_request_1 = Request::ChildNodes( + hash, + ResponseDisambiguator::new(parent_block_presence_1), + DebugRequest::start(), + ); + + let child_request = Request::Block(rng.gen(), DebugRequest::start()); + + let parent_key_0 = graph.get_or_insert(parent_request_0, parent_block_presence_0, None, 0); + let parent_key_1 = graph.get_or_insert(parent_request_1, parent_block_presence_1, None, 1); + + let child_key_0 = graph.get_or_insert( + child_request.clone(), + MultiBlockPresence::None, + Some(parent_key_0), + 2, + ); + + let child_key_1 = graph.get_or_insert( + child_request, + MultiBlockPresence::None, + Some(parent_key_1), + 2, + ); + + assert_eq!(child_key_0, child_key_1); + + for parent_key in [parent_key_0, parent_key_1] { + assert_eq!( + graph + .get(parent_key) + .unwrap() + .children() + .collect::>(), + HashSet::from([child_key_0]) + ); + } + + assert_eq!( + graph + .get(child_key_0) + .unwrap() + .parents() + .collect::>(), + HashSet::from([parent_key_0, parent_key_1]) + ); + + graph.remove(parent_key_0); + + assert_eq!( + graph + .get(child_key_0) + .unwrap() + .parents() + .collect::>(), + HashSet::from([parent_key_1]) + ); + + graph.remove(parent_key_1); + + assert_eq!( + graph + .get(child_key_0) + .unwrap() + .parents() + .collect::>(), + HashSet::default(), + ); + } + + #[test] + fn multiple_children() { + let mut rng = rand::thread_rng(); + let mut graph = Graph::new(); + + let parent_request = Request::ChildNodes( + rng.gen(), + ResponseDisambiguator::new(MultiBlockPresence::Full), + DebugRequest::start(), + ); + + let child_request_0 = Request::ChildNodes( + rng.gen(), + ResponseDisambiguator::new(MultiBlockPresence::Full), + DebugRequest::start(), + ); + + let child_request_1 = Request::ChildNodes( + rng.gen(), + ResponseDisambiguator::new(MultiBlockPresence::Full), + DebugRequest::start(), + ); + + let parent_key = graph.get_or_insert(parent_request, MultiBlockPresence::Full, None, 0); + + let child_key_0 = graph.get_or_insert( + child_request_0, + MultiBlockPresence::Full, + Some(parent_key), + 1, + ); + + let child_key_1 = graph.get_or_insert( + child_request_1, + MultiBlockPresence::Full, + 
Some(parent_key), + 2, + ); + + assert_eq!( + graph + .get(parent_key) + .unwrap() + .children() + .collect::>(), + HashSet::from([child_key_0, child_key_1]) + ); + + for child_key in [child_key_0, child_key_1] { + assert_eq!( + graph + .get(child_key) + .unwrap() + .parents() + .collect::>(), + HashSet::from([parent_key]) + ); + } + + graph.remove(child_key_0); + + assert_eq!( + graph + .get(parent_key) + .unwrap() + .children() + .collect::>(), + HashSet::from([child_key_1]) + ); + + graph.remove(child_key_1); + + assert_eq!( + graph + .get(parent_key) + .unwrap() + .children() + .collect::>(), + HashSet::default() + ); + } } diff --git a/lib/src/network/request_tracker/tests.rs b/lib/src/network/request_tracker/tests.rs index 393f62d0c..16cf2f3ca 100644 --- a/lib/src/network/request_tracker/tests.rs +++ b/lib/src/network/request_tracker/tests.rs @@ -3,145 +3,214 @@ use super::{ *, }; use crate::{ + collections::HashSet, crypto::{sign::Keypair, Hashable}, network::message::ResponseDisambiguator, - protocol::{test_utils::Snapshot, Block, MultiBlockPresence, Proof, UntrustedProof}, - test_utils, + protocol::{ + test_utils::{BlockState, Snapshot}, + Block, MultiBlockPresence, Proof, SingleBlockPresence, UntrustedProof, + }, version_vector::VersionVector, }; -use rand::{rngs::StdRng, seq::SliceRandom, Rng, SeedableRng}; -use std::collections::VecDeque; +use rand::{ + distributions::{Bernoulli, Distribution, Standard}, + rngs::StdRng, + seq::SliceRandom, + CryptoRng, Rng, SeedableRng, +}; +use std::collections::{BTreeMap, VecDeque}; +// Test syncing while peers keep joining and leaving the swarm. +// // Note: We need `tokio::test` here because the `RequestTracker` uses `DelayQueue` internaly which // needs a tokio runtime. +#[ignore = "fails due to problems with the test setup"] #[tokio::test] -async fn simulation() { +async fn dynamic_swarm() { let seed = rand::random(); - simulation_case(seed, 64, 4); -} + case(seed, 64, 4); -fn simulation_case(seed: u64, max_blocks: usize, expected_peer_changes: usize) { - test_utils::init_log(); + fn case(seed: u64, max_blocks: usize, expected_peer_changes: usize) { + let mut rng = StdRng::seed_from_u64(seed); + let num_blocks = rng.gen_range(1..=max_blocks); + let snapshot = Snapshot::generate(&mut rng, num_blocks); + let mut summary = Summary::new(snapshot.blocks().len()); - tracing::info!(seed, max_blocks, expected_peer_changes); - - let mut rng = StdRng::seed_from_u64(seed); + println!( + "seed = {seed}, blocks = {}/{max_blocks}, expected_peer_changes = {expected_peer_changes}", + snapshot.blocks().len() + ); - let (tracker, mut tracker_worker) = build(); + let (tracker, mut tracker_worker) = build(); + let mut peers = Vec::new(); + + // Action to perform on the set of peers. + #[derive(Debug)] + enum Action { + // Insert a new peer + Insert, + // Remove a random peer + Remove, + // Keep the peer set intact + Keep, + } - let block_count = rng.gen_range(1..=max_blocks); - let snapshot = Snapshot::generate(&mut rng, block_count); - let mut summary = Summary::new(snapshot.blocks().len()); + // Total number of simulation steps is the number of index nodes plus the number of blocks + // in the snapshot. This is used to calculate the probability of the next action. + let steps = 1 + snapshot.inner_count() + snapshot.leaf_count() + snapshot.blocks().len(); - tracing::info!(?snapshot); + for tick in 0.. 
{ + let _enter = tracing::info_span!("tick", message = tick).entered(); - let mut peers = Vec::new(); - let mut total_peer_count = 0; // total number of peers that participated in the simulation + // Generate the next action. The probability of `Insert` or `Remove` is chosen such that + // the expected number of such actions in the simulation is equal to + // `expected_peer_changes`. Both `Insert` and `Remove` have currently the same + // probability. + let action = if rng.gen_range(0..steps) < expected_peer_changes { + if rng.gen() { + Action::Insert + } else { + Action::Remove + } + } else { + Action::Keep + }; - // Action to perform on the set of peers. - #[derive(Debug)] - enum Action { - // Insert a new peer - Insert, - // Remove a random peer - Remove, - // Keep the peer set intact - Keep, - } + match action { + Action::Insert => { + peers.push(make_peer(&mut rng, &tracker, snapshot.clone())); + } + Action::Remove => { + if peers.len() < 2 { + continue; + } - // Total number of simulation steps is the number of index nodes plus the number of blocks in - // the snapshot. This is used to calculate the probability of the next action. - let steps = 1 + snapshot.inner_count() + snapshot.leaf_count() + snapshot.blocks().len(); + let index = rng.gen_range(0..peers.len()); + peers.remove(index); + } + Action::Keep => { + if peers.is_empty() { + continue; + } + } + } - for tick in 0.. { - let _enter = tracing::info_span!("tick", message = tick); + let polled = poll_peers(&mut rng, &mut peers, &mut summary); - // Generate the next action. The probability of `Insert` or `Remove` is chosen such that the - // expected number of such actions in the simulation is equal to `expected_peer_changes`. - // Both `Insert` and `Remove` have currently the same probability. - let action = if rng.gen_range(0..steps) < expected_peer_changes { - if rng.gen() { - Action::Insert + if polled || matches!(action, Action::Remove) { + tracker_worker.step(); } else { - Action::Remove + break; } - } else { - Action::Keep - }; + } - match action { - Action::Insert => { - let (tracker_client, tracker_request_rx) = tracker.new_client(); - let client = TestClient::new(tracker_client, tracker_request_rx); + summary.verify(&snapshot); + assert_eq!(tracker_worker.requests().len(), 0); + } +} - let writer_id = PublicKey::generate(&mut rng); - let write_keys = Keypair::generate(&mut rng); - let server = TestServer::new(writer_id, write_keys, &snapshot); +// Test syncing with multiple peers where no peer has all the blocks but every block is present in +// at least one peer. 
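+//
+// Each peer is seeded with one of the snapshots produced by
+// `generate_snapshots_with_missing_blocks` below, so the union of the peers' present blocks
+// equals the master snapshot that the result is verified against.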
+#[ignore = "fails due to problems with the test setup"] +#[tokio::test] +async fn missing_blocks() { + // let seed = rand::random(); + let seed = 830380000365750606; + case(seed, 8, 2); + + fn case(seed: u64, max_blocks: usize, max_peers: usize) { + crate::test_utils::init_log(); + + let mut rng = StdRng::seed_from_u64(seed); + let num_blocks = rng.gen_range(2..=max_blocks); + let num_peers = rng.gen_range(2..=max_peers); + let (master_snapshot, peer_snapshots) = + generate_snapshots_with_missing_blocks(&mut rng, num_peers, num_blocks); + let mut summary = Summary::new(master_snapshot.blocks().len()); + + println!( + "seed = {seed}, blocks = {num_blocks}/{max_blocks}, peers = {num_peers}/{max_peers}" + ); - peers.push((client, server)); - total_peer_count += 1; - } - Action::Remove => { - if peers.len() < 2 { - continue; - } + let (tracker, mut tracker_worker) = build(); + let mut peers: Vec<_> = peer_snapshots + .into_iter() + .map(|snapshot| make_peer(&mut rng, &tracker, snapshot)) + .collect(); - let index = rng.gen_range(0..peers.len()); - peers.remove(index); - } - Action::Keep => { - if peers.is_empty() { - continue; - } + for tick in 0.. { + let _enter = tracing::info_span!("tick", message = tick).entered(); + + if poll_peers(&mut rng, &mut peers, &mut summary) { + tracker_worker.step(); + } else { + break; } } - let polled = poll_peers(&mut rng, &mut peers, &snapshot, &mut summary); - - if polled || matches!(action, Action::Remove) { - tracker_worker.step(); - } else { - break; - } + summary.verify(&master_snapshot); + assert_eq!(tracker_worker.requests().cloned().collect::>(), []); } - - summary.verify(total_peer_count, &snapshot); - assert_eq!(tracker_worker.request_count(), 0); } -// TODO: test multiple peers with different block summaries // TODO: test failure/timeout struct Summary { - expected_blocks: usize, - nodes: HashMap, - blocks: HashMap, + expected_block_count: usize, + + // Using `BTreeMap` so any potential failures are printed in the same order in different test + // runs. 
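+    // (This is why `MessageKey` gained `Ord` / `PartialOrd` derives in this patch:
+    // `BTreeMap` keys must be `Ord`.)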
+ requests: BTreeMap, + + nodes: HashMap>, + blocks: HashSet, + + node_failures: HashMap, + block_failures: HashMap, } impl Summary { - fn new(expected_blocks: usize) -> Self { + fn new(expected_block_count: usize) -> Self { Self { - expected_blocks, + expected_block_count, + requests: BTreeMap::default(), nodes: HashMap::default(), - blocks: HashMap::default(), + blocks: HashSet::default(), + node_failures: HashMap::default(), + block_failures: HashMap::default(), } } - fn receive_node(&mut self, hash: Hash) -> bool { - *self.nodes.entry(hash).or_default() += 1; - self.blocks.len() < self.expected_blocks + fn send_request(&mut self, request: &Request) { + *self.requests.entry(MessageKey::from(request)).or_default() += 1; + } + + fn receive_node(&mut self, hash: Hash, block_presence: MultiBlockPresence) -> bool { + self.nodes.entry(hash).or_default().insert(block_presence); + self.blocks.len() < self.expected_block_count + } + + fn receive_node_failure(&mut self, hash: Hash) { + *self.node_failures.entry(hash).or_default() += 1; } fn receive_block(&mut self, block_id: BlockId) { - *self.blocks.entry(block_id).or_default() += 1; + self.blocks.insert(block_id); } - fn verify(&mut self, num_peers: usize, snapshot: &Snapshot) { - assert_eq!( - self.nodes.remove(snapshot.root_hash()).unwrap_or(0), - num_peers, - "root node not received exactly {num_peers} times: {:?}", - snapshot.root_hash() + fn receive_block_failure(&mut self, block_id: BlockId) { + *self.block_failures.entry(block_id).or_default() += 1; + } + + fn verify(self, snapshot: &Snapshot) { + assert!( + self.nodes + .get(snapshot.root_hash()) + .into_iter() + .flatten() + .count() + > 0, + "root node not received" ); for hash in snapshot @@ -149,32 +218,40 @@ impl Summary { .map(|node| &node.hash) .chain(snapshot.leaf_nodes().map(|node| &node.locator)) { - assert_eq!( - self.nodes.remove(hash).unwrap_or(0), - 1, - "child node not received exactly once: {hash:?}" + assert!( + self.nodes.get(hash).into_iter().flatten().count() > 0, + "child node not received: {hash:?}" ); } for block_id in snapshot.blocks().keys() { - assert_eq!( - self.blocks.remove(block_id).unwrap_or(0), - 1, - "block not received exactly once: {block_id:?}" + assert!( + self.blocks.contains(block_id), + "block not received: {block_id:?}" ); } - // Verify we received only the expected nodes and blocks - assert!( - self.nodes.is_empty(), - "unexpected nodes received: {:?}", - self.nodes - ); - assert!( - self.blocks.is_empty(), - "unexpected blocks received: {:?}", - self.blocks - ); + for (request, &actual_count) in &self.requests { + let expected_max = match request { + MessageKey::RootNode(_) => 0, + MessageKey::ChildNodes(hash) => { + self.nodes.get(hash).map(HashSet::len).unwrap_or(0) + + self.node_failures.get(hash).copied().unwrap_or(0) + } + MessageKey::Block(block_id) => { + (if self.blocks.contains(block_id) { 1 } else { 0 }) + + self.block_failures.get(block_id).copied().unwrap_or(0) + } + }; + + assert!( + actual_count <= expected_max, + "request sent too many times ({} instead of {}): {:?}", + actual_count, + expected_max, + request + ); + } } } @@ -198,7 +275,7 @@ impl TestClient { match response { Response::RootNode(proof, block_presence, debug_payload) => { let requests = summary - .receive_node(proof.hash) + .receive_node(proof.hash, block_presence) .then_some(( Request::ChildNodes( proof.hash, @@ -217,15 +294,18 @@ impl TestClient { let parent_hash = nodes.hash(); let requests: Vec<_> = nodes .into_iter() - .filter_map(|(_, node)| { - 
summary.receive_node(node.hash).then_some(( + .filter(|(_, node)| { + summary.receive_node(node.hash, node.summary.block_presence) + }) + .map(|(_, node)| { + ( Request::ChildNodes( node.hash, ResponseDisambiguator::new(node.summary.block_presence), debug_payload.follow_up(), ), node.summary.block_presence, - )) + ) }) .collect(); @@ -236,11 +316,21 @@ impl TestClient { let parent_hash = nodes.hash(); let requests = nodes .into_iter() - .filter_map(|node| { - summary.receive_node(node.locator).then_some(( + .filter(|node| { + summary.receive_node( + node.locator, + match node.block_presence { + SingleBlockPresence::Present => MultiBlockPresence::Full, + SingleBlockPresence::Missing => MultiBlockPresence::None, + SingleBlockPresence::Expired => unimplemented!(), + }, + ) + }) + .map(|node| { + ( Request::Block(node.block_id, debug_payload.follow_up()), MultiBlockPresence::None, - )) + ) }) .collect(); @@ -259,9 +349,11 @@ impl TestClient { self.tracker_client.failure(MessageKey::RootNode(writer_id)); } Response::ChildNodesError(hash, _disambiguator, _debug_payload) => { + summary.receive_node_failure(hash); self.tracker_client.failure(MessageKey::ChildNodes(hash)); } Response::BlockError(block_id, _debug_payload) => { + summary.receive_block_failure(block_id); self.tracker_client.failure(MessageKey::Block(block_id)); } Response::BlockOffer(_block_id, _debug_payload) => unimplemented!(), @@ -276,11 +368,12 @@ impl TestClient { struct TestServer { writer_id: PublicKey, write_keys: Keypair, + snapshot: Snapshot, outbox: VecDeque, } impl TestServer { - fn new(writer_id: PublicKey, write_keys: Keypair, snapshot: &Snapshot) -> Self { + fn new(writer_id: PublicKey, write_keys: Keypair, snapshot: Snapshot) -> Self { let proof = UntrustedProof::from(Proof::new( writer_id, VersionVector::first(writer_id), @@ -290,7 +383,7 @@ impl TestServer { let outbox = [Response::RootNode( proof.clone(), - MultiBlockPresence::Full, + *snapshot.root_block_presence(), DebugResponse::unsolicited(), )] .into(); @@ -298,24 +391,27 @@ impl TestServer { Self { writer_id, write_keys, + snapshot, outbox, } } - fn handle_request(&mut self, request: Request, snapshot: &Snapshot) { + fn handle_request(&mut self, request: Request, summary: &mut Summary) { + summary.send_request(&request); + match request { Request::RootNode(writer_id, debug_payload) => { if writer_id == self.writer_id { let proof = Proof::new( writer_id, VersionVector::first(writer_id), - *snapshot.root_hash(), + *self.snapshot.root_hash(), &self.write_keys, ); self.outbox.push_back(Response::RootNode( proof.into(), - MultiBlockPresence::Full, + *self.snapshot.root_block_presence(), debug_payload.reply(), )); } else { @@ -324,7 +420,8 @@ impl TestServer { } } Request::ChildNodes(hash, disambiguator, debug_payload) => { - if let Some(nodes) = snapshot + if let Some(nodes) = self + .snapshot .inner_layers() .flat_map(|layer| layer.inner_maps()) .find_map(|(parent_hash, nodes)| (*parent_hash == hash).then_some(nodes)) @@ -334,7 +431,8 @@ impl TestServer { disambiguator, debug_payload.reply(), )); - } else if let Some(nodes) = snapshot + } else if let Some(nodes) = self + .snapshot .leaf_sets() .find_map(|(parent_hash, nodes)| (*parent_hash == hash).then_some(nodes)) { @@ -352,7 +450,7 @@ impl TestServer { } } Request::Block(block_id, debug_payload) => { - if let Some(block) = snapshot.blocks().get(&block_id) { + if let Some(block) = self.snapshot.blocks().get(&block_id) { self.outbox.push_back(Response::Block( block.content.clone(), block.nonce, @@ -371,11 
+469,25 @@ impl TestServer {
     }
 }
 
+fn make_peer(
+    rng: &mut R,
+    tracker: &RequestTracker,
+    snapshot: Snapshot,
+) -> (TestClient, TestServer) {
+    let (tracker_client, tracker_request_rx) = tracker.new_client();
+    let client = TestClient::new(tracker_client, tracker_request_rx);
+
+    let writer_id = PublicKey::generate(rng);
+    let write_keys = Keypair::generate(rng);
+    let server = TestServer::new(writer_id, write_keys, snapshot);
+
+    (client, server)
+}
+
 // Polls every client and server once, in random order
 fn poll_peers(
     rng: &mut R,
     peers: &mut [(TestClient, TestServer)],
-    snapshot: &Snapshot,
     summary: &mut Summary,
 ) -> bool {
     #[derive(Debug)]
@@ -398,7 +510,7 @@ fn poll_peers(
         match side {
             Side::Client => {
                 if let Some(request) = client.poll_request() {
-                    server.handle_request(request, snapshot);
+                    server.handle_request(request, summary);
                     changed = true;
                 }
             }
@@ -413,3 +525,53 @@ fn poll_peers(
     changed
 }
+
+/// Generate `count + 1` copies of the same snapshot. The first one will have all the blocks
+/// present (the "master copy"). The remaining ones will have some blocks missing but in such a
+/// way that every block is present in at least one of the snapshots.
+fn generate_snapshots_with_missing_blocks(
+    mut rng: &mut impl Rng,
+    count: usize,
+    num_blocks: usize,
+) -> (Snapshot, Vec) {
+    let all_blocks: Vec<(Hash, Block)> = rng.sample_iter(Standard).take(num_blocks).collect();
+
+    let mut partial_block_sets = Vec::with_capacity(count);
+    partial_block_sets.resize_with(count, || Vec::with_capacity(num_blocks));
+
+    // Every block is present in one snapshot and has a 50% (1:2) chance of being present in
+    // each of the other snapshots.
+    let bernoulli = Bernoulli::from_ratio(1, 2).unwrap();
+
+    let mut batch = Vec::with_capacity(count);
+
+    for (locator, block) in &all_blocks {
+        // Poor man's Binomial distribution
+        let num_present = 1 + (1..count).filter(|_| bernoulli.sample(&mut rng)).count();
+        let num_missing = count - num_present;
+
+        batch.extend(
+            iter::repeat(block.clone())
+                .map(BlockState::Present)
+                .take(num_present)
+                .chain(
+                    iter::repeat(block.id)
+                        .map(BlockState::Missing)
+                        .take(num_missing),
+                ),
+        );
+        batch.shuffle(&mut rng);
+
+        for (index, block) in batch.drain(..).enumerate() {
+            partial_block_sets[index].push((*locator, block));
+        }
+    }
+
+    (
+        Snapshot::from_present_blocks(all_blocks),
+        partial_block_sets
+            .into_iter()
+            .map(Snapshot::from_blocks)
+            .collect(),
+    )
+}
diff --git a/lib/src/protocol/test_utils.rs b/lib/src/protocol/test_utils.rs
index 2e8cf7683..02fcb0b37 100644
--- a/lib/src/protocol/test_utils.rs
+++ b/lib/src/protocol/test_utils.rs
@@ -25,6 +25,7 @@ impl Snapshot {
         Self::from_present_blocks(rng.sample_iter(Standard).take(block_count))
     }
 
+    // Convenience alternative to `from_blocks` when all blocks are present.
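+    //
+    // A hedged usage sketch (any `Rng` works; the block count `4` is arbitrary):
+    //
+    //     let blocks: Vec<(Hash, Block)> = rng.sample_iter(Standard).take(4).collect();
+    //     let snapshot = Snapshot::from_present_blocks(blocks);
+    //
+    //     assert_eq!(snapshot.blocks().len(), 4);
+    //     assert!(snapshot
+    //         .leaf_nodes()
+    //         .all(|node| node.block_presence == SingleBlockPresence::Present));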
pub fn from_present_blocks( locators_and_blocks: impl IntoIterator, ) -> Self { @@ -104,7 +105,6 @@ impl Snapshot { &self.root_hash } - #[expect(dead_code)] pub fn root_block_presence(&self) -> &MultiBlockPresence { &self.root_block_presence } @@ -192,9 +192,9 @@ impl<'a> InnerLayer<'a> { } } +#[derive(Debug)] pub(crate) enum BlockState { Present(Block), - #[expect(dead_code)] Missing(BlockId), } From 6f1f375174dd9d99ee12fcb32026aa3b1d3ae692 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Cig=C3=A1nek?= Date: Wed, 25 Sep 2024 12:22:08 +0200 Subject: [PATCH 33/55] Support building test_utils::Snapshot incrementally --- lib/src/network/request_tracker/tests.rs | 17 +- lib/src/protocol/inner_node.rs | 2 +- lib/src/protocol/leaf_node.rs | 11 +- lib/src/protocol/test_utils.rs | 542 ++++++++++++++++++----- lib/src/repository/vault/tests.rs | 6 +- lib/src/store/client.rs | 68 ++- lib/src/store/index.rs | 26 +- lib/src/store/test_utils.rs | 12 +- 8 files changed, 493 insertions(+), 191 deletions(-) diff --git a/lib/src/network/request_tracker/tests.rs b/lib/src/network/request_tracker/tests.rs index 16cf2f3ca..0d5b97cbd 100644 --- a/lib/src/network/request_tracker/tests.rs +++ b/lib/src/network/request_tracker/tests.rs @@ -383,7 +383,7 @@ impl TestServer { let outbox = [Response::RootNode( proof.clone(), - *snapshot.root_block_presence(), + snapshot.root_summary().block_presence, DebugResponse::unsolicited(), )] .into(); @@ -411,7 +411,7 @@ impl TestServer { self.outbox.push_back(Response::RootNode( proof.into(), - *self.snapshot.root_block_presence(), + self.snapshot.root_summary().block_presence, debug_payload.reply(), )); } else { @@ -420,22 +420,13 @@ impl TestServer { } } Request::ChildNodes(hash, disambiguator, debug_payload) => { - if let Some(nodes) = self - .snapshot - .inner_layers() - .flat_map(|layer| layer.inner_maps()) - .find_map(|(parent_hash, nodes)| (*parent_hash == hash).then_some(nodes)) - { + if let Some(nodes) = self.snapshot.get_inner_set(&hash) { self.outbox.push_back(Response::InnerNodes( nodes.clone(), disambiguator, debug_payload.reply(), )); - } else if let Some(nodes) = self - .snapshot - .leaf_sets() - .find_map(|(parent_hash, nodes)| (*parent_hash == hash).then_some(nodes)) - { + } else if let Some(nodes) = self.snapshot.get_leaf_set(&hash) { self.outbox.push_back(Response::LeafNodes( nodes.clone(), disambiguator, diff --git a/lib/src/protocol/inner_node.rs b/lib/src/protocol/inner_node.rs index 48f09f29a..4270ce7f2 100644 --- a/lib/src/protocol/inner_node.rs +++ b/lib/src/protocol/inner_node.rs @@ -37,7 +37,7 @@ impl Hashable for InnerNode { } } -#[derive(Default, Clone, Debug, Serialize, Deserialize)] +#[derive(Default, Clone, Eq, PartialEq, Debug, Serialize, Deserialize)] pub struct InnerNodes(BTreeMap); impl InnerNodes { diff --git a/lib/src/protocol/leaf_node.rs b/lib/src/protocol/leaf_node.rs index fc2912f9e..59171a999 100644 --- a/lib/src/protocol/leaf_node.rs +++ b/lib/src/protocol/leaf_node.rs @@ -53,7 +53,7 @@ impl Hashable for LeafNode { } /// Collection that acts as a ordered set of `LeafNode`s -#[derive(Default, Clone, Debug, Serialize, Deserialize)] +#[derive(Default, Clone, Eq, PartialEq, Debug, Serialize, Deserialize)] pub struct LeafNodes(Vec); impl LeafNodes { @@ -161,6 +161,15 @@ impl<'a> IntoIterator for &'a LeafNodes { } } +impl<'a> IntoIterator for &'a mut LeafNodes { + type Item = &'a mut LeafNode; + type IntoIter = slice::IterMut<'a, LeafNode>; + + fn into_iter(self) -> Self::IntoIter { + self.0.iter_mut() + } +} + impl IntoIterator for 
LeafNodes { type Item = LeafNode; type IntoIter = vec::IntoIter; diff --git a/lib/src/protocol/test_utils.rs b/lib/src/protocol/test_utils.rs index 02fcb0b37..3fb8256bd 100644 --- a/lib/src/protocol/test_utils.rs +++ b/lib/src/protocol/test_utils.rs @@ -1,22 +1,24 @@ use super::{MultiBlockPresence, NodeState, SingleBlockPresence, Summary}; use crate::{ - collections::HashMap, crypto::{Hash, Hashable}, - protocol::{ - get_bucket, Block, BlockId, InnerNode, InnerNodes, LeafNode, LeafNodes, INNER_LAYER_COUNT, - }, + protocol::{Block, BlockId, InnerNode, InnerNodes, LeafNode, LeafNodes, INNER_LAYER_COUNT}, }; use rand::{distributions::Standard, Rng}; -use std::{borrow::Cow, fmt, mem}; +use std::{ + borrow::Cow, + collections::{btree_map::Entry, BTreeMap, VecDeque}, + fmt, mem, +}; // In-memory snapshot for testing purposes. #[derive(Clone)] pub(crate) struct Snapshot { root_hash: Hash, - root_block_presence: MultiBlockPresence, - inners: [HashMap; INNER_LAYER_COUNT], - leaves: HashMap, - blocks: HashMap, + root_summary: Summary, + // Using BTreeMap instead of HashMap for deterministic iteration order for repeatable tests. + inners: BTreeMap, + leaves: BTreeMap, + blocks: BTreeMap, } impl Snapshot { @@ -39,64 +41,79 @@ impl Snapshot { // Create snapshot given an iterator of blocks where each block is associated to its encoded // locator. pub fn from_blocks(locators_and_blocks: impl IntoIterator) -> Self { - let mut blocks = HashMap::default(); - let mut leaves = HashMap::default(); + let mut blocks = BTreeMap::default(); + let mut leaves = BTreeMap::default(); + let mut inners = Vec::new(); for (locator, block) in locators_and_blocks { let block_id = *block.id(); - let block_presence = block.presence(); - - match block { + let block_presence = match block { BlockState::Present(block) => { blocks.insert(block_id, block); + SingleBlockPresence::Present } - BlockState::Missing(_) => (), - } + BlockState::Missing(_) => SingleBlockPresence::Missing, + }; + + let path = BucketPath::leaf(&locator); leaves - .entry(BucketPath::new(&locator, INNER_LAYER_COUNT - 1)) + .entry(path) .or_insert_with(LeafNodes::default) .insert(locator, block_id, block_presence); } - let mut inners: [HashMap<_, InnerNodes>; INNER_LAYER_COUNT] = Default::default(); + let mut layer = BTreeMap::default(); - for (path, nodes) in &leaves { - add_inner_node( - INNER_LAYER_COUNT - 1, - &mut inners[INNER_LAYER_COUNT - 1], - path, - nodes.hash(), - Summary::from_leaves(nodes).block_presence, - ); + for (&path, child_nodes) in &leaves { + let parent_node = InnerNode::new(child_nodes.hash(), Summary::from_leaves(child_nodes)); + let (bucket, path) = path.pop(INNER_LAYER_COUNT - 1); + + layer + .entry(path) + .or_insert_with(InnerNodes::default) + .insert(bucket, parent_node); } - for layer in (0..INNER_LAYER_COUNT - 1).rev() { - let (lo, hi) = inners.split_at_mut(layer + 1); + inners.push(layer); + + for layer_index in (0..INNER_LAYER_COUNT - 1).rev() { + let mut layer = BTreeMap::default(); + + for (&path, child_nodes) in inners.last().unwrap() { + let parent_node = + InnerNode::new(child_nodes.hash(), Summary::from_inners(child_nodes)); + let (bucket, path) = path.pop(layer_index); - for (path, nodes) in &hi[0] { - add_inner_node( - layer, - lo.last_mut().unwrap(), - path, - nodes.hash(), - Summary::from_inners(nodes).block_presence, - ); + layer + .entry(path) + .or_insert_with(InnerNodes::default) + .insert(bucket, parent_node); } + + inners.push(layer); } - let nodes = inners[0] - .get(&BucketPath::default()) + let nodes = inners 
+ .last() + .and_then(|layer| layer.values().next()) .map(Cow::Borrowed) .unwrap_or(Cow::Owned(InnerNodes::default())); let root_hash = nodes.hash(); - let root_block_presence = Summary::from_inners(&nodes).block_presence; + let root_summary = Summary::from_inners(&nodes); Self { root_hash, - root_block_presence, - inners, - leaves, + root_summary, + inners: inners + .into_iter() + .flat_map(|layer| layer.into_values()) + .map(|nodes| (nodes.hash(), nodes)) + .collect(), + leaves: leaves + .into_values() + .map(|nodes| (nodes.hash(), nodes)) + .collect(), blocks, } } @@ -105,15 +122,12 @@ impl Snapshot { &self.root_hash } - pub fn root_block_presence(&self) -> &MultiBlockPresence { - &self.root_block_presence + pub fn root_summary(&self) -> &Summary { + &self.root_summary } pub fn leaf_sets(&self) -> impl Iterator { - self.leaves.iter().map(move |(path, nodes)| { - let parent_hash = self.parent_hash(INNER_LAYER_COUNT, path); - (parent_hash, nodes) - }) + self.leaves.iter() } pub fn leaf_nodes(&self) -> impl Iterator { @@ -132,46 +146,176 @@ impl Snapshot { self.leaves.values().map(|nodes| nodes.len()).sum() } - pub fn inner_layers(&self) -> impl Iterator { - (0..self.inners.len()).map(move |inner_layer| InnerLayer(self, inner_layer)) + pub fn get_leaf_set(&self, parent_hash: &Hash) -> Option<&LeafNodes> { + self.leaves.get(parent_hash) + } + + // Iterates the inner sets in topological order (parents before children) + pub fn inner_sets(&self) -> impl Iterator { + InnerSets { + inners: &self.inners, + queue: VecDeque::from([&self.root_hash]), + } } pub fn inner_nodes(&self) -> impl Iterator { - self.inners - .iter() - .flat_map(|layer| layer.values()) - .flatten() + self.inner_sets() + .flat_map(|(_, nodes)| nodes) .map(|(_, node)| node) } pub fn inner_count(&self) -> usize { - self.inners - .iter() - .flat_map(|layer| layer.values()) - .map(|nodes| nodes.len()) - .sum() + self.inners.values().map(|nodes| nodes.len()).sum() + } + + pub fn get_inner_set(&self, parent_hash: &Hash) -> Option<&InnerNodes> { + self.inners.get(parent_hash) } - pub fn blocks(&self) -> &HashMap { + pub fn blocks(&self) -> &BTreeMap { &self.blocks } - // Returns the parent hash of inner nodes at `inner_layer` with the specified bucket path. 
- fn parent_hash(&self, inner_layer: usize, path: &BucketPath) -> &Hash { - if inner_layer == 0 { - &self.root_hash + pub fn insert_root(&mut self, hash: Hash, block_presence: MultiBlockPresence) -> bool { + if self.root_hash == hash { + match self.root_summary.state { + NodeState::Incomplete => true, + NodeState::Complete | NodeState::Approved => self + .root_summary + .block_presence + .is_outdated(&block_presence), + NodeState::Rejected => unimplemented!(), + } } else { - let (bucket, parent_path) = path.pop(inner_layer - 1); - &self.inners[inner_layer - 1] - .get(&parent_path) - .unwrap() - .get(bucket) - .unwrap() - .hash + self.root_hash = hash; + // FIXME: if hash == EMPTY_INNER_HASH we should set this to `Complete`: + self.root_summary = Summary::INCOMPLETE; + + self.inners.clear(); + self.leaves.clear(); + self.blocks.clear(); + + true + } + } + + pub fn insert_inners(&mut self, nodes: InnerNodes) -> InnerNodes { + match self.inners.entry(nodes.hash()) { + Entry::Occupied(entry) => entry + .get() + .iter() + .zip(nodes) + .filter(|((_, old), (_, new))| old.summary.is_outdated(&new.summary)) + .map(|(_, new)| new) + .collect(), + Entry::Vacant(entry) => { + entry.insert(nodes.clone().into_incomplete()); + nodes + } + } + } + + pub fn insert_leaves(&mut self, nodes: LeafNodes) -> LeafNodes { + let (nodes, update) = match self.leaves.entry(nodes.hash()) { + Entry::Occupied(entry) => ( + entry + .get() + .iter() + .zip(nodes) + .filter( + |(old, new)| match (old.block_presence, new.block_presence) { + (SingleBlockPresence::Missing, SingleBlockPresence::Present) => true, + (SingleBlockPresence::Present, SingleBlockPresence::Present) + | (SingleBlockPresence::Present, SingleBlockPresence::Missing) + | (SingleBlockPresence::Missing, SingleBlockPresence::Missing) => false, + (SingleBlockPresence::Expired, _) + | (_, SingleBlockPresence::Expired) => unimplemented!(), + }, + ) + .map(|(_, new)| new) + .collect(), + false, + ), + Entry::Vacant(entry) => { + entry.insert(nodes.clone().into_missing()); + (nodes, true) + } + }; + + if update { + self.update_root_summary(); + } + + nodes + } + + pub fn insert_block(&mut self, block: Block) -> bool { + if self.blocks.insert(block.id, block).is_none() { + self.update_root_summary(); + true + } else { + false + } + } + + fn update_root_summary(&mut self) { + self.root_summary = self.update_summary(self.root_hash); + } + + fn update_summary(&mut self, hash: Hash) -> Summary { + // XXX: This is not very efficient but probably fine for tests. + + // Dancing around the borrow checker ... 
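+        // (Collecting the child hashes into a `Vec` first ends the borrow of `self.inners`,
+        // which lets the recursive `update_summary` calls below take `&mut self`.)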
+ let inner_inputs = self.inners.get(&hash).map(|nodes| { + nodes + .iter() + .filter(|(_, node)| { + node.summary.state == NodeState::Incomplete + || node.summary.block_presence != MultiBlockPresence::Full + }) + .map(|(bucket, node)| (bucket, node.hash)) + .collect::>() + }); + + if let Some(inputs) = inner_inputs { + let outputs: Vec<_> = inputs + .into_iter() + .map(|(bucket, hash)| (bucket, self.update_summary(hash))) + .collect(); + + let nodes = self.inners.get_mut(&hash).unwrap(); + + for (bucket, summary) in outputs { + nodes.get_mut(bucket).unwrap().summary = summary + } + + Summary::from_inners(nodes) + } else if let Some(nodes) = self.leaves.get_mut(&hash) { + for node in &mut *nodes { + match node.block_presence { + SingleBlockPresence::Present => continue, + SingleBlockPresence::Missing => { + if self.blocks.contains_key(&node.block_id) { + node.block_presence = SingleBlockPresence::Present; + } + } + SingleBlockPresence::Expired => unimplemented!(), + } + } + + Summary::from_leaves(nodes) + } else { + Summary::INCOMPLETE } } } +impl Default for Snapshot { + fn default() -> Self { + Self::from_blocks([]) + } +} + impl fmt::Debug for Snapshot { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Snapshot") @@ -181,14 +325,61 @@ impl fmt::Debug for Snapshot { } } -pub(crate) struct InnerLayer<'a>(&'a Snapshot, usize); +// Iterator that yields the inner sets in topological order. +struct InnerSets<'a> { + inners: &'a BTreeMap, + queue: VecDeque<&'a Hash>, +} -impl<'a> InnerLayer<'a> { - pub fn inner_maps(self) -> impl Iterator { - self.0.inners[self.1].iter().map(move |(path, nodes)| { - let parent_hash = self.0.parent_hash(self.1, path); - (parent_hash, nodes) - }) +impl<'a> Iterator for InnerSets<'a> { + type Item = (&'a Hash, &'a InnerNodes); + + fn next(&mut self) -> Option { + let parent_hash = self.queue.pop_front()?; + let nodes = self.inners.get(parent_hash)?; + + self.queue.extend(nodes.iter().map(|(_, node)| &node.hash)); + + Some((parent_hash, nodes)) + } +} + +#[track_caller] +pub(crate) fn assert_snapshots_equal(lhs: &Snapshot, rhs: &Snapshot) { + assert_eq!( + (lhs.root_hash(), lhs.root_summary()), + (rhs.root_hash(), rhs.root_summary()), + "root node mismatch" + ); + + assert_eq!( + lhs.inner_count(), + rhs.inner_count(), + "inner node count mismatch" + ); + + for (lhs, rhs) in lhs.inner_nodes().zip(rhs.inner_nodes()) { + assert_eq!(lhs, rhs, "inner node mismatch"); + } + + assert_eq!( + lhs.leaf_count(), + rhs.leaf_count(), + "leaf node count mismatch" + ); + + for (lhs, rhs) in lhs.leaf_nodes().zip(rhs.leaf_nodes()) { + assert_eq!(lhs, rhs, "leaf node mismatch"); + } + + assert_eq!( + lhs.blocks().len(), + rhs.blocks().len(), + "present block count mismatch" + ); + + for (lhs, rhs) in lhs.blocks().keys().zip(rhs.blocks().keys()) { + assert_eq!(lhs, rhs, "block mismatch"); } } @@ -205,50 +396,173 @@ impl BlockState { Self::Missing(block_id) => block_id, } } +} - pub fn presence(&self) -> SingleBlockPresence { - match self { - Self::Present(_) => SingleBlockPresence::Present, - Self::Missing(_) => SingleBlockPresence::Missing, - } +#[derive(Default, Copy, Clone, Eq, PartialEq, Ord, PartialOrd)] +struct BucketPath([u8; INNER_LAYER_COUNT]); + +impl BucketPath { + fn leaf(locator: &Hash) -> Self { + let mut path = [0; INNER_LAYER_COUNT]; + path.copy_from_slice(&locator.as_ref()[..INNER_LAYER_COUNT]); + + Self(path) } -} -fn add_inner_node( - inner_layer: usize, - maps: &mut HashMap, - path: &BucketPath, - hash: Hash, - block_presence: 
MultiBlockPresence, -) { - let (bucket, parent_path) = path.pop(inner_layer); - maps.entry(parent_path).or_default().insert( - bucket, - InnerNode::new( - hash, - Summary { - state: NodeState::Complete, - block_presence, - }, - ), - ); + fn pop(mut self, layer: usize) -> (u8, Self) { + let bucket = mem::replace(&mut self.0[layer], 0); + (bucket, self) + } } -#[derive(Default, Clone, Copy, Eq, PartialEq, Hash, Debug)] -struct BucketPath([u8; INNER_LAYER_COUNT]); +#[cfg(test)] +mod tests { + use super::*; + use crate::{collections::HashSet, protocol::EMPTY_INNER_HASH}; + use rand::{rngs::StdRng, SeedableRng}; + + #[test] + fn empty_snapshot() { + let s = Snapshot::default(); + + assert_eq!(*s.root_hash(), *EMPTY_INNER_HASH); + assert_eq!(s.root_summary().state, NodeState::Complete); + assert_eq!(s.root_summary().block_presence, MultiBlockPresence::None); + assert_eq!(s.inner_count(), 0); + assert_eq!(s.leaf_count(), 0); + assert!(s.blocks().is_empty()); + } -impl BucketPath { - fn new(locator: &Hash, inner_layer: usize) -> Self { - let mut path = Self(Default::default()); - for (layer, bucket) in path.0.iter_mut().enumerate().take(inner_layer + 1) { - *bucket = get_bucket(locator, layer) + #[test] + fn full_snapshot() { + case(rand::random(), 8); + + fn case(seed: u64, max_blocks: usize) { + println!("seed = {seed}, max_blocks = {max_blocks}"); + + let mut rng = StdRng::seed_from_u64(seed); + let num_blocks = rng.gen_range(1..=max_blocks); + let s = Snapshot::generate(&mut rng, num_blocks); + + assert_ne!(*s.root_hash(), *EMPTY_INNER_HASH); + assert_eq!(s.root_summary().state, NodeState::Complete); + assert_eq!(s.root_summary().block_presence, MultiBlockPresence::Full); + assert_eq!(s.leaf_count(), num_blocks); + assert_eq!(s.blocks().len(), num_blocks); + + for node in s.inner_nodes() { + assert_eq!(node.summary.state, NodeState::Complete); + assert_eq!(node.summary.block_presence, MultiBlockPresence::Full); + } + + for node in s.leaf_nodes() { + assert_eq!(node.block_presence, SingleBlockPresence::Present); + } + + // Verify that traversing the whole tree yields all the block ids. 
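+        // (The walk below resolves parent hashes with `get_inner_set` / `get_leaf_set`,
+        // the same lookups `TestServer::handle_request` uses to answer `ChildNodes`
+        // requests.)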
+ let mut hashes = vec![s.root_hash()]; + let mut expected_block_ids = HashSet::default(); + + while let Some(hash) = hashes.pop() { + if let Some(nodes) = s.get_inner_set(hash) { + assert_eq!(nodes.hash(), *hash); + hashes.extend(nodes.iter().map(|(_, node)| &node.hash)); + } else if let Some(nodes) = s.get_leaf_set(hash) { + assert_eq!(nodes.hash(), *hash); + expected_block_ids.extend(nodes.iter().map(|node| &node.block_id)); + } else { + panic!("nodes with parent hash {hash:?} not found"); + } + } + + let actual_block_ids: HashSet<_> = s.blocks().keys().collect(); + assert_eq!(actual_block_ids, expected_block_ids); } - path } - fn pop(&self, inner_layer: usize) -> (u8, Self) { - let mut popped = *self; - let bucket = mem::replace(&mut popped.0[inner_layer], 0); - (bucket, popped) + #[test] + fn snapshot_inner_sets_topological_order() { + case(rand::random(), 8); + + fn case(seed: u64, max_blocks: usize) { + println!("seed = {seed}, max_blocks = {max_blocks}"); + + let mut rng = StdRng::seed_from_u64(seed); + let num_blocks = rng.gen_range(1..=max_blocks); + let s = Snapshot::generate(&mut rng, num_blocks); + + let mut visited = HashSet::from([s.root_hash()]); + + for (parent_hash, nodes) in s.inner_sets() { + assert!(visited.contains(parent_hash)); + visited.extend(nodes.iter().map(|(_, node)| &node.hash)); + } + } + } + + #[test] + fn snapshot_sync() { + case(rand::random(), 8); + + fn case(seed: u64, max_blocks: usize) { + println!("seed = {seed}, max_blocks = {max_blocks}"); + + let mut rng = StdRng::seed_from_u64(seed); + let block_count = rng.gen_range(0..=max_blocks); + + let src = Snapshot::generate(&mut rng, block_count); + let mut dst = Snapshot::default(); + + let result = dst.insert_root(*src.root_hash(), src.root_summary().block_presence); + + let expected_initial_root_state = if block_count == 0 { + NodeState::Complete + } else { + NodeState::Incomplete + }; + + assert_eq!(result, block_count != 0); + assert_eq!( + *dst.root_summary(), + Summary { + state: expected_initial_root_state, + block_presence: MultiBlockPresence::None + } + ); + + for (_, nodes) in src.inner_sets() { + assert_eq!(dst.insert_inners(nodes.clone()), *nodes); + } + + assert_eq!(dst.root_summary().state, expected_initial_root_state); + assert_eq!(dst.root_summary().block_presence, MultiBlockPresence::None); + + for node in dst.inner_nodes() { + assert_eq!(node.summary.state, expected_initial_root_state); + assert_eq!(node.summary.block_presence, MultiBlockPresence::None); + } + + for (_, nodes) in src.leaf_sets() { + assert_eq!(dst.insert_leaves(nodes.clone()), *nodes); + } + + assert_eq!(dst.root_summary().state, NodeState::Complete); + assert_eq!(dst.root_summary().block_presence, MultiBlockPresence::None); + + for node in dst.inner_nodes() { + assert_eq!(node.summary.state, NodeState::Complete); + assert_eq!(node.summary.block_presence, MultiBlockPresence::None); + } + + for node in dst.leaf_nodes() { + assert_eq!(node.block_presence, SingleBlockPresence::Missing); + } + + for block in src.blocks().values() { + assert!(dst.insert_block(block.clone())); + } + + assert_snapshots_equal(&src, &dst); + } } } diff --git a/lib/src/repository/vault/tests.rs b/lib/src/repository/vault/tests.rs index 8a4291bd1..b18aef312 100644 --- a/lib/src/repository/vault/tests.rs +++ b/lib/src/repository/vault/tests.rs @@ -360,10 +360,8 @@ async fn block_ids_excludes_blocks_from_incomplete_snapshots() { .await .unwrap(); - for layer in snapshot.inner_layers() { - for (_, nodes) in layer.inner_maps() { - 
writer.save_inner_nodes(nodes.clone().into()).await.unwrap(); - } + for (_, nodes) in snapshot.inner_sets() { + writer.save_inner_nodes(nodes.clone().into()).await.unwrap(); } for (_, nodes) in snapshot.leaf_sets().take(1) { diff --git a/lib/src/store/client.rs b/lib/src/store/client.rs index 6f1246568..69515d650 100644 --- a/lib/src/store/client.rs +++ b/lib/src/store/client.rs @@ -834,24 +834,22 @@ mod tests { .unwrap(); writer.commit().await.unwrap(); - for layer in snapshot.inner_layers() { - for (hash, inner_nodes) in layer.inner_maps() { - let mut writer = store.begin_client_write().await.unwrap(); - writer - .save_inner_nodes(inner_nodes.clone().into()) - .await - .unwrap(); - writer.commit().await.unwrap(); + for (hash, inner_nodes) in snapshot.inner_sets() { + let mut writer = store.begin_client_write().await.unwrap(); + writer + .save_inner_nodes(inner_nodes.clone().into()) + .await + .unwrap(); + writer.commit().await.unwrap(); - assert!(!store - .acquire_read() - .await - .unwrap() - .load_inner_nodes(hash) - .await - .unwrap() - .is_empty()); - } + assert!(!store + .acquire_read() + .await + .unwrap() + .load_inner_nodes(hash) + .await + .unwrap() + .is_empty()); } for (hash, leaf_nodes) in snapshot.leaf_sets() { @@ -880,26 +878,24 @@ mod tests { let snapshot = Snapshot::generate(&mut rand::thread_rng(), 1); // Try to save the inner nodes - for layer in snapshot.inner_layers() { - let (hash, inner_nodes) = layer.inner_maps().next().unwrap(); - let mut writer = store.begin_client_write().await.unwrap(); - let status = writer - .save_inner_nodes(inner_nodes.clone().into()) - .await - .unwrap(); - assert!(status.new_children.is_empty()); - writer.commit().await.unwrap(); + let (hash, inner_nodes) = snapshot.inner_sets().next().unwrap(); + let mut writer = store.begin_client_write().await.unwrap(); + let status = writer + .save_inner_nodes(inner_nodes.clone().into()) + .await + .unwrap(); + assert!(status.new_children.is_empty()); + writer.commit().await.unwrap(); - // The orphaned inner nodes were not written to the db. - let inner_nodes = store - .acquire_read() - .await - .unwrap() - .load_inner_nodes(hash) - .await - .unwrap(); - assert!(inner_nodes.is_empty()); - } + // The orphaned inner nodes were not written to the db. + let inner_nodes = store + .acquire_read() + .await + .unwrap() + .load_inner_nodes(hash) + .await + .unwrap(); + assert!(inner_nodes.is_empty()); // Try to save the leaf nodes let (hash, leaf_nodes) = snapshot.leaf_sets().next().unwrap(); diff --git a/lib/src/store/index.rs b/lib/src/store/index.rs index 072852c35..03f02b23c 100644 --- a/lib/src/store/index.rs +++ b/lib/src/store/index.rs @@ -320,17 +320,15 @@ mod tests { // TODO: consider randomizing the order the nodes are saved so it's not always // breadth-first. 
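        // (`inner_sets()` iterates in topological order, parents before children, so saving
        // in iteration order still stores every parent node before its children.)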
- for layer in snapshot.inner_layers() { - for (parent_hash, nodes) in layer.inner_maps() { - inner_node::save_all(&mut tx, &nodes.clone().into_incomplete(), parent_hash) - .await - .unwrap(); + for (parent_hash, nodes) in snapshot.inner_sets() { + inner_node::save_all(&mut tx, &nodes.clone().into_incomplete(), parent_hash) + .await + .unwrap(); - update_summaries_and_approve(&mut tx, *parent_hash).await; + update_summaries_and_approve(&mut tx, *parent_hash).await; - reload_root_node(&mut tx, &mut root_node).await.unwrap(); - assert!(!root_node.summary.state.is_approved()); - } + reload_root_node(&mut tx, &mut root_node).await.unwrap(); + assert!(!root_node.summary.state.is_approved()); } let mut unsaved_leaves = snapshot.leaf_count(); @@ -406,12 +404,10 @@ mod tests { .unwrap(); } - for layer in snapshot.inner_layers() { - for (parent_hash, nodes) in layer.inner_maps() { - inner_node::save_all(&mut tx, &nodes.clone().into_incomplete(), parent_hash) - .await - .unwrap(); - } + for (parent_hash, nodes) in snapshot.inner_sets() { + inner_node::save_all(&mut tx, &nodes.clone().into_incomplete(), parent_hash) + .await + .unwrap(); } for (parent_hash, nodes) in snapshot.leaf_sets() { diff --git a/lib/src/store/test_utils.rs b/lib/src/store/test_utils.rs index 1547d254d..a0e1bca9b 100644 --- a/lib/src/store/test_utils.rs +++ b/lib/src/store/test_utils.rs @@ -51,13 +51,11 @@ impl<'a> SnapshotWriter<'a> { } pub async fn save_inner_nodes(mut self) -> Self { - for layer in self.snapshot.inner_layers() { - for (_, nodes) in layer.inner_maps() { - self.writer - .save_inner_nodes(nodes.clone().into()) - .await - .unwrap(); - } + for (_, nodes) in self.snapshot.inner_sets() { + self.writer + .save_inner_nodes(nodes.clone().into()) + .await + .unwrap(); } self From 4bc41cbaf6cdaf08ba5707dcbc2f93a920df3304 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Cig=C3=A1nek?= Date: Thu, 26 Sep 2024 09:36:56 +0200 Subject: [PATCH 34/55] Refactor RequestTracker test setup, fix one test --- lib/src/network/request_tracker/tests.rs | 293 +++++++++-------------- 1 file changed, 113 insertions(+), 180 deletions(-) diff --git a/lib/src/network/request_tracker/tests.rs b/lib/src/network/request_tracker/tests.rs index 0d5b97cbd..0a2b61d76 100644 --- a/lib/src/network/request_tracker/tests.rs +++ b/lib/src/network/request_tracker/tests.rs @@ -7,8 +7,8 @@ use crate::{ crypto::{sign::Keypair, Hashable}, network::message::ResponseDisambiguator, protocol::{ - test_utils::{BlockState, Snapshot}, - Block, MultiBlockPresence, Proof, SingleBlockPresence, UntrustedProof, + test_utils::{assert_snapshots_equal, BlockState, Snapshot}, + Block, MultiBlockPresence, Proof, UntrustedProof, }, version_vector::VersionVector, }; @@ -18,13 +18,12 @@ use rand::{ seq::SliceRandom, CryptoRng, Rng, SeedableRng, }; -use std::collections::{BTreeMap, VecDeque}; +use std::collections::VecDeque; // Test syncing while peers keep joining and leaving the swarm. // // Note: We need `tokio::test` here because the `RequestTracker` uses `DelayQueue` internaly which // needs a tokio runtime. 
-#[ignore = "fails due to problems with the test setup"] #[tokio::test] async fn dynamic_swarm() { let seed = rand::random(); @@ -32,9 +31,10 @@ async fn dynamic_swarm() { fn case(seed: u64, max_blocks: usize, expected_peer_changes: usize) { let mut rng = StdRng::seed_from_u64(seed); + let mut sim = Simulation::new(); + let num_blocks = rng.gen_range(1..=max_blocks); let snapshot = Snapshot::generate(&mut rng, num_blocks); - let mut summary = Summary::new(snapshot.blocks().len()); println!( "seed = {seed}, blocks = {}/{max_blocks}, expected_peer_changes = {expected_peer_changes}", @@ -42,7 +42,6 @@ async fn dynamic_swarm() { ); let (tracker, mut tracker_worker) = build(); - let mut peers = Vec::new(); // Action to perform on the set of peers. #[derive(Debug)] @@ -78,24 +77,23 @@ async fn dynamic_swarm() { match action { Action::Insert => { - peers.push(make_peer(&mut rng, &tracker, snapshot.clone())); + sim.insert_peer(&mut rng, &tracker, snapshot.clone()); } Action::Remove => { - if peers.len() < 2 { + if sim.peer_count() < 2 { continue; } - let index = rng.gen_range(0..peers.len()); - peers.remove(index); + sim.remove_peer(&mut rng); } Action::Keep => { - if peers.is_empty() { + if sim.peer_count() == 0 { continue; } } } - let polled = poll_peers(&mut rng, &mut peers, &mut summary); + let polled = sim.poll(&mut rng); if polled || matches!(action, Action::Remove) { tracker_worker.step(); @@ -104,155 +102,160 @@ async fn dynamic_swarm() { } } - summary.verify(&snapshot); + sim.verify(&snapshot); assert_eq!(tracker_worker.requests().len(), 0); } } // Test syncing with multiple peers where no peer has all the blocks but every block is present in // at least one peer. -#[ignore = "fails due to problems with the test setup"] +#[ignore = "duplicate request send checking is too strict"] #[tokio::test] async fn missing_blocks() { + crate::test_utils::init_log(); + // let seed = rand::random(); let seed = 830380000365750606; case(seed, 8, 2); fn case(seed: u64, max_blocks: usize, max_peers: usize) { - crate::test_utils::init_log(); - let mut rng = StdRng::seed_from_u64(seed); + let mut sim = Simulation::new(); + let num_blocks = rng.gen_range(2..=max_blocks); let num_peers = rng.gen_range(2..=max_peers); let (master_snapshot, peer_snapshots) = generate_snapshots_with_missing_blocks(&mut rng, num_peers, num_blocks); - let mut summary = Summary::new(master_snapshot.blocks().len()); println!( "seed = {seed}, blocks = {num_blocks}/{max_blocks}, peers = {num_peers}/{max_peers}" ); let (tracker, mut tracker_worker) = build(); - let mut peers: Vec<_> = peer_snapshots - .into_iter() - .map(|snapshot| make_peer(&mut rng, &tracker, snapshot)) - .collect(); + for snapshot in peer_snapshots { + sim.insert_peer(&mut rng, &tracker, snapshot); + } for tick in 0.. { let _enter = tracing::info_span!("tick", message = tick).entered(); - if poll_peers(&mut rng, &mut peers, &mut summary) { + if sim.poll(&mut rng) { tracker_worker.step(); } else { break; } } - summary.verify(&master_snapshot); + sim.verify(&master_snapshot); assert_eq!(tracker_worker.requests().cloned().collect::>(), []); } } // TODO: test failure/timeout -struct Summary { - expected_block_count: usize, - - // Using `BTreeMap` so any potential failures are printed in the same order in different test - // runs. 
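An aside on the `BTreeMap` comment removed just above: `BTreeMap` iterates in key order, so assertion-failure output stays identical between runs, while `HashMap` order depends on a per-process randomized hasher. A self-contained illustration (not part of the patch):

    use std::collections::BTreeMap;

    fn main() {
        let map: BTreeMap<u32, &str> = [(3, "c"), (1, "a"), (2, "b")].into();
        // Always [1, 2, 3], run after run; a HashMap makes no such promise.
        assert_eq!(map.keys().copied().collect::<Vec<_>>(), vec![1, 2, 3]);
    }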
- requests: BTreeMap, - - nodes: HashMap>, - blocks: HashSet, - - node_failures: HashMap, - block_failures: HashMap, +struct Simulation { + peers: Vec, + requests: HashSet, + snapshot: Snapshot, } -impl Summary { - fn new(expected_block_count: usize) -> Self { +impl Simulation { + fn new() -> Self { Self { - expected_block_count, - requests: BTreeMap::default(), - nodes: HashMap::default(), - blocks: HashSet::default(), - node_failures: HashMap::default(), - block_failures: HashMap::default(), + peers: Vec::new(), + requests: HashSet::default(), + snapshot: Snapshot::default(), } } - fn send_request(&mut self, request: &Request) { - *self.requests.entry(MessageKey::from(request)).or_default() += 1; + fn peer_count(&self) -> usize { + self.peers.len() } - fn receive_node(&mut self, hash: Hash, block_presence: MultiBlockPresence) -> bool { - self.nodes.entry(hash).or_default().insert(block_presence); - self.blocks.len() < self.expected_block_count + fn insert_peer( + &mut self, + rng: &mut R, + tracker: &RequestTracker, + snapshot: Snapshot, + ) { + let (tracker_client, tracker_request_rx) = tracker.new_client(); + let client = TestClient::new(tracker_client, tracker_request_rx); + + let writer_id = PublicKey::generate(rng); + let write_keys = Keypair::generate(rng); + let server = TestServer::new(writer_id, write_keys, snapshot); + + self.peers.push(TestPeer { + client, + server, + requests: Vec::new(), + }); } - fn receive_node_failure(&mut self, hash: Hash) { - *self.node_failures.entry(hash).or_default() += 1; - } + fn remove_peer(&mut self, rng: &mut R) { + let index = rng.gen_range(0..self.peers.len()); + let peer = self.peers.remove(index); - fn receive_block(&mut self, block_id: BlockId) { - self.blocks.insert(block_id); + for key in peer.requests { + self.requests.remove(&key); + } } - fn receive_block_failure(&mut self, block_id: BlockId) { - *self.block_failures.entry(block_id).or_default() += 1; - } + // Polls random client or server once + #[track_caller] + fn poll(&mut self, rng: &mut R) -> bool { + enum Side { + Client, + Server, + } - fn verify(self, snapshot: &Snapshot) { - assert!( - self.nodes - .get(snapshot.root_hash()) - .into_iter() - .flatten() - .count() - > 0, - "root node not received" - ); + let mut order: Vec<_> = (0..self.peers.len()) + .flat_map(|index| [(Side::Client, index), (Side::Server, index)]) + .collect(); - for hash in snapshot - .inner_nodes() - .map(|node| &node.hash) - .chain(snapshot.leaf_nodes().map(|node| &node.locator)) - { - assert!( - self.nodes.get(hash).into_iter().flatten().count() > 0, - "child node not received: {hash:?}" - ); - } + order.shuffle(rng); - for block_id in snapshot.blocks().keys() { - assert!( - self.blocks.contains(block_id), - "block not received: {block_id:?}" - ); - } + for (side, index) in order { + let peer = &mut self.peers[index]; + + match side { + Side::Client => { + if let Some(request) = peer.client.poll_request() { + let key = MessageKey::from(&request); + + assert!( + self.requests.insert(key), + "request sent more than once: {request:?}" + ); + + peer.requests.push(key); + peer.server.handle_request(request); - for (request, &actual_count) in &self.requests { - let expected_max = match request { - MessageKey::RootNode(_) => 0, - MessageKey::ChildNodes(hash) => { - self.nodes.get(hash).map(HashSet::len).unwrap_or(0) - + self.node_failures.get(hash).copied().unwrap_or(0) + return true; + } } - MessageKey::Block(block_id) => { - (if self.blocks.contains(block_id) { 1 } else { 0 }) - + 
self.block_failures.get(block_id).copied().unwrap_or(0) + Side::Server => { + if let Some(response) = peer.server.poll_response() { + peer.client.handle_response(response, &mut self.snapshot); + return true; + } } - }; - - assert!( - actual_count <= expected_max, - "request sent too many times ({} instead of {}): {:?}", - actual_count, - expected_max, - request - ); + } } + + false } + + #[track_caller] + fn verify(&self, expected_snapshot: &Snapshot) { + assert_snapshots_equal(&self.snapshot, expected_snapshot) + } +} + +struct TestPeer { + client: TestClient, + server: TestServer, + requests: Vec, } struct TestClient { @@ -271,11 +274,11 @@ impl TestClient { } } - fn handle_response(&mut self, response: Response, summary: &mut Summary) { + fn handle_response(&mut self, response: Response, snapshot: &mut Snapshot) { match response { Response::RootNode(proof, block_presence, debug_payload) => { - let requests = summary - .receive_node(proof.hash, block_presence) + let requests = snapshot + .insert_root(proof.hash, block_presence) .then_some(( Request::ChildNodes( proof.hash, @@ -292,11 +295,10 @@ impl TestClient { } Response::InnerNodes(nodes, _disambiguator, debug_payload) => { let parent_hash = nodes.hash(); + let nodes = snapshot.insert_inners(nodes); + let requests: Vec<_> = nodes .into_iter() - .filter(|(_, node)| { - summary.receive_node(node.hash, node.summary.block_presence) - }) .map(|(_, node)| { ( Request::ChildNodes( @@ -314,18 +316,9 @@ impl TestClient { } Response::LeafNodes(nodes, _disambiguator, debug_payload) => { let parent_hash = nodes.hash(); + let nodes = snapshot.insert_leaves(nodes); let requests = nodes .into_iter() - .filter(|node| { - summary.receive_node( - node.locator, - match node.block_presence { - SingleBlockPresence::Present => MultiBlockPresence::Full, - SingleBlockPresence::Missing => MultiBlockPresence::None, - SingleBlockPresence::Expired => unimplemented!(), - }, - ) - }) .map(|node| { ( Request::Block(node.block_id, debug_payload.follow_up()), @@ -339,21 +332,20 @@ impl TestClient { } Response::Block(content, nonce, _debug_payload) => { let block = Block::new(content, nonce); + let block_id = block.id; - summary.receive_block(block.id); + snapshot.insert_block(block); self.tracker_client - .success(MessageKey::Block(block.id), vec![]); + .success(MessageKey::Block(block_id), vec![]); } Response::RootNodeError(writer_id, _debug_payload) => { self.tracker_client.failure(MessageKey::RootNode(writer_id)); } Response::ChildNodesError(hash, _disambiguator, _debug_payload) => { - summary.receive_node_failure(hash); self.tracker_client.failure(MessageKey::ChildNodes(hash)); } Response::BlockError(block_id, _debug_payload) => { - summary.receive_block_failure(block_id); self.tracker_client.failure(MessageKey::Block(block_id)); } Response::BlockOffer(_block_id, _debug_payload) => unimplemented!(), @@ -396,9 +388,7 @@ impl TestServer { } } - fn handle_request(&mut self, request: Request, summary: &mut Summary) { - summary.send_request(&request); - + fn handle_request(&mut self, request: Request) { match request { Request::RootNode(writer_id, debug_payload) => { if writer_id == self.writer_id { @@ -460,63 +450,6 @@ impl TestServer { } } -fn make_peer( - rng: &mut R, - tracker: &RequestTracker, - snapshot: Snapshot, -) -> (TestClient, TestServer) { - let (tracker_client, tracker_request_rx) = tracker.new_client(); - let client = TestClient::new(tracker_client, tracker_request_rx); - - let writer_id = PublicKey::generate(rng); - let write_keys = 
Keypair::generate(rng); - let server = TestServer::new(writer_id, write_keys, snapshot); - - (client, server) -} - -// Polls every client and server once, in random order -fn poll_peers( - rng: &mut R, - peers: &mut [(TestClient, TestServer)], - summary: &mut Summary, -) -> bool { - #[derive(Debug)] - enum Side { - Client, - Server, - } - - let mut order: Vec<_> = (0..peers.len()) - .flat_map(|index| [(Side::Client, index), (Side::Server, index)]) - .collect(); - - order.shuffle(rng); - - let mut changed = false; - - for (side, index) in order { - let (client, server) = &mut peers[index]; - - match side { - Side::Client => { - if let Some(request) = client.poll_request() { - server.handle_request(request, summary); - changed = true; - } - } - Side::Server => { - if let Some(response) = server.poll_response() { - client.handle_response(response, summary); - changed = true; - } - } - } - } - - changed -} - /// Generate `count + 1` copies of the same snapshot. The first one will have all the blocks /// present (the "master copy"). The remaining ones will have some blocks missing but in such a /// way that every block is present in at least one of the snapshots. From 383a34570eec675041a40a551fbb551a7f6c4fb8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Cig=C3=A1nek?= Date: Thu, 26 Sep 2024 13:28:37 +0200 Subject: [PATCH 35/55] Fix both RequestTracker tests --- lib/src/network/request_tracker.rs | 54 +++++++++++++--- lib/src/network/request_tracker/graph.rs | 8 ++- lib/src/network/request_tracker/tests.rs | 82 ++++++++++++++++++------ lib/src/protocol/summary.rs | 10 +-- lib/src/protocol/test_utils.rs | 12 +++- 5 files changed, 127 insertions(+), 39 deletions(-) diff --git a/lib/src/network/request_tracker.rs b/lib/src/network/request_tracker.rs index 2f05c724d..dd231f7a1 100644 --- a/lib/src/network/request_tracker.rs +++ b/lib/src/network/request_tracker.rs @@ -35,7 +35,7 @@ impl RequestTracker { } #[cfg_attr(not(test), expect(dead_code))] - pub fn new_client(&self) -> (RequestTrackerClient, mpsc::UnboundedReceiver) { + pub fn new_client(&self) -> (RequestTrackerClient, mpsc::UnboundedReceiver) { let client_id = ClientId::next(); let (request_tx, request_rx) = mpsc::unbounded_channel(); @@ -108,6 +108,17 @@ impl Drop for RequestTrackerClient { } } +/// Permit to send the specified request. Contains also the block presence as reported by the peer +/// who sent the response that triggered this request. That is mostly useful for diagnostics and +/// testing. +#[derive(Debug)] +pub(super) struct SendPermit { + #[cfg_attr(not(test), expect(dead_code))] + pub request: Request, + #[cfg_attr(not(test), expect(dead_code))] + pub block_presence: MultiBlockPresence, +} + /// Key identifying a request and its corresponding response. 
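The `SendPermit` introduced above reaches the client through its unbounded channel; a receiving loop would look roughly like this (`send_to_peer` is a hypothetical stand-in for the real transport):

    async fn drive(mut rx: tokio::sync::mpsc::UnboundedReceiver<SendPermit>) {
        while let Some(SendPermit { request, block_presence }) = rx.recv().await {
            // The block presence rides along only for diagnostics and tests.
            tracing::trace!(?block_presence, "sending {:?}", request);
            send_to_peer(request).await;
        }
    }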
#[derive(Clone, Copy, Eq, PartialEq, Hash, Ord, PartialOrd, Debug)] pub(super) enum MessageKey { @@ -214,13 +225,20 @@ impl Worker { } #[instrument(skip(self, request_tx))] - fn insert_client(&mut self, client_id: ClientId, request_tx: mpsc::UnboundedSender) { + fn insert_client( + &mut self, + client_id: ClientId, + request_tx: mpsc::UnboundedSender, + ) { + #[cfg(test)] tracing::debug!("insert_client"); + self.clients.insert(client_id, ClientState::new(request_tx)); } #[instrument(skip(self))] fn remove_client(&mut self, client_id: ClientId) { + #[cfg(test)] tracing::debug!("remove_client"); let Some(client_state) = self.clients.remove(&client_id) else { @@ -239,7 +257,9 @@ impl Worker { request: Request, block_presence: MultiBlockPresence, ) { - // tracing::debug!("handle_initial"); + #[cfg(test)] + tracing::debug!("handle_initial"); + self.insert_request(client_id, request, block_presence, None) } @@ -250,6 +270,7 @@ impl Worker { request_key: MessageKey, requests: Vec<(Request, MultiBlockPresence)>, ) { + #[cfg(test)] tracing::debug!("handle_success"); let node_key = self @@ -331,6 +352,7 @@ impl Worker { request_key: MessageKey, reason: FailureReason, ) { + #[cfg(test)] tracing::debug!("handle_failure"); let Some(client_state) = self.clients.get_mut(&client_id) else { @@ -388,7 +410,13 @@ impl Worker { }; client_state.requests.insert(request_key, node_key); - client_state.request_tx.send(node.request().clone()).ok(); + client_state + .request_tx + .send(SendPermit { + request: node.request().clone(), + block_presence: *node.block_presence(), + }) + .ok(); } } @@ -406,7 +434,7 @@ impl Worker { return; }; - let (request, state) = node.request_and_value_mut(); + let (request, &block_presence, state) = node.parts_mut(); match state { RequestState::InFlight { @@ -435,8 +463,14 @@ impl Worker { .timer .insert((next_client_id, MessageKey::from(request)), REQUEST_TIMEOUT); - // Send the request to the new sender. - next_client_state.request_tx.send(request.clone()).ok(); + // Send the permit to the new sender. 
+ next_client_state + .request_tx + .send(SendPermit { + request: request.clone(), + block_presence, + }) + .ok(); return; } else { @@ -492,7 +526,7 @@ impl ClientId { enum Command { InsertClient { client_id: ClientId, - request_tx: mpsc::UnboundedSender, + request_tx: mpsc::UnboundedSender, }, RemoveClient { client_id: ClientId, @@ -514,12 +548,12 @@ enum Command { } struct ClientState { - request_tx: mpsc::UnboundedSender, + request_tx: mpsc::UnboundedSender, requests: HashMap, } impl ClientState { - fn new(request_tx: mpsc::UnboundedSender) -> Self { + fn new(request_tx: mpsc::UnboundedSender) -> Self { Self { request_tx, requests: HashMap::default(), diff --git a/lib/src/network/request_tracker/graph.rs b/lib/src/network/request_tracker/graph.rs index b09a3cc9c..a8eb1cc61 100644 --- a/lib/src/network/request_tracker/graph.rs +++ b/lib/src/network/request_tracker/graph.rs @@ -130,8 +130,12 @@ impl Node { &self.request } - pub fn request_and_value_mut(&mut self) -> (&Request, &mut T) { - (&self.request, &mut self.value) + pub fn block_presence(&self) -> &MultiBlockPresence { + &self.block_presence + } + + pub fn parts_mut(&mut self) -> (&Request, &MultiBlockPresence, &mut T) { + (&self.request, &self.block_presence, &mut self.value) } pub fn parents(&self) -> impl ExactSizeIterator + '_ { diff --git a/lib/src/network/request_tracker/tests.rs b/lib/src/network/request_tracker/tests.rs index 0a2b61d76..562e10e26 100644 --- a/lib/src/network/request_tracker/tests.rs +++ b/lib/src/network/request_tracker/tests.rs @@ -18,7 +18,7 @@ use rand::{ seq::SliceRandom, CryptoRng, Rng, SeedableRng, }; -use std::collections::VecDeque; +use std::collections::{hash_map::Entry, VecDeque}; // Test syncing while peers keep joining and leaving the swarm. // @@ -109,14 +109,10 @@ async fn dynamic_swarm() { // Test syncing with multiple peers where no peer has all the blocks but every block is present in // at least one peer. -#[ignore = "duplicate request send checking is too strict"] #[tokio::test] async fn missing_blocks() { - crate::test_utils::init_log(); - - // let seed = rand::random(); - let seed = 830380000365750606; - case(seed, 8, 2); + let seed = rand::random(); + case(seed, 32, 4); fn case(seed: u64, max_blocks: usize, max_peers: usize) { let mut rng = StdRng::seed_from_u64(seed); @@ -155,7 +151,11 @@ async fn missing_blocks() { struct Simulation { peers: Vec, - requests: HashSet, + // All requests sent by live peers. This is used to verify that every request is sent only once + // unless the peer that sent it died or the request failed. In those cases the request may be + // sent by another peer. It's also allowed to sent the same request more than once as long as + // each one has a different block presence. 
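The bookkeeping described in the comment above reduces to a set of block presences per request key: a (key, presence) pair may be in flight at most once. As a sketch, using the types from this patch:

    use std::collections::{HashMap, HashSet};

    // Returns false if this (key, presence) pair is already in flight.
    fn record_send(
        requests: &mut HashMap<MessageKey, HashSet<MultiBlockPresence>>,
        key: MessageKey,
        presence: MultiBlockPresence,
    ) -> bool {
        requests.entry(key).or_default().insert(presence)
    }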
+ requests: HashMap>, snapshot: Snapshot, } @@ -163,7 +163,7 @@ impl Simulation { fn new() -> Self { Self { peers: Vec::new(), - requests: HashSet::default(), + requests: HashMap::default(), snapshot: Snapshot::default(), } } @@ -188,7 +188,7 @@ impl Simulation { self.peers.push(TestPeer { client, server, - requests: Vec::new(), + requests: HashMap::default(), }); } @@ -196,8 +196,8 @@ impl Simulation { let index = rng.gen_range(0..self.peers.len()); let peer = self.peers.remove(index); - for key in peer.requests { - self.requests.remove(&key); + for (key, block_presence) in peer.requests { + cancel_request(&mut self.requests, key, block_presence); } } @@ -220,15 +220,19 @@ impl Simulation { match side { Side::Client => { - if let Some(request) = peer.client.poll_request() { + if let Some(SendPermit { + request, + block_presence, + }) = peer.client.poll_request() + { let key = MessageKey::from(&request); assert!( - self.requests.insert(key), - "request sent more than once: {request:?}" + self.requests.entry(key).or_default().insert(block_presence), + "request sent more than once: {request:?} ({block_presence:?})" ); - peer.requests.push(key); + peer.requests.insert(key, block_presence); peer.server.handle_request(request); return true; @@ -236,6 +240,29 @@ impl Simulation { } Side::Server => { if let Some(response) = peer.server.poll_response() { + // In case of failure, cancel the request so it can be retried without it + // triggering assertion failure. + let key = match response { + Response::RootNodeError(writer_id, _) => { + Some(MessageKey::RootNode(writer_id)) + } + Response::ChildNodesError(hash, _, _) => { + Some(MessageKey::ChildNodes(hash)) + } + Response::BlockError(block_id, _) => Some(MessageKey::Block(block_id)), + Response::RootNode(..) + | Response::InnerNodes(..) + | Response::LeafNodes(..) + | Response::Block(..) + | Response::BlockOffer(..) => None, + }; + + if let Some(key) = key { + if let Some(block_presence) = peer.requests.get(&key) { + cancel_request(&mut self.requests, key, *block_presence); + } + } + peer.client.handle_response(response, &mut self.snapshot); return true; } @@ -252,21 +279,36 @@ impl Simulation { } } +fn cancel_request( + requests: &mut HashMap>, + key: MessageKey, + block_presence: MultiBlockPresence, +) { + if let Entry::Occupied(mut entry) = requests.entry(key) { + entry.get_mut().remove(&block_presence); + + if entry.get().is_empty() { + entry.remove(); + } + } +} + struct TestPeer { client: TestClient, server: TestServer, - requests: Vec, + // All requests sent by this peer. + requests: HashMap, } struct TestClient { tracker_client: RequestTrackerClient, - tracker_request_rx: mpsc::UnboundedReceiver, + tracker_request_rx: mpsc::UnboundedReceiver, } impl TestClient { fn new( tracker_client: RequestTrackerClient, - tracker_request_rx: mpsc::UnboundedReceiver, + tracker_request_rx: mpsc::UnboundedReceiver, ) -> Self { Self { tracker_client, @@ -352,7 +394,7 @@ impl TestClient { }; } - fn poll_request(&mut self) -> Option { + fn poll_request(&mut self) -> Option { self.tracker_request_rx.try_recv().ok() } } diff --git a/lib/src/protocol/summary.rs b/lib/src/protocol/summary.rs index 2fbd3c7f3..7ece94379 100644 --- a/lib/src/protocol/summary.rs +++ b/lib/src/protocol/summary.rs @@ -77,14 +77,16 @@ impl Summary { } } - /// Checks whether the subtree at `self` is outdated compared to the subtree at `other` in - /// terms of present blocks. That is, whether `other` has some blocks present that `self` is - /// missing. 
+ /// Checks whether the subtree at `self` is outdated compared to the subtree at `other` in terms + /// of completeness and block presence. That is, `self` is considered outdated if it's + /// incomplete (regardless of what `other` is) or if `other` has some blocks present that + /// `self` is missing. /// /// NOTE: This function is NOT antisymetric, that is, `is_outdated(A, B)` does not imply /// !is_outdated(B, A)` (and vice-versa). pub fn is_outdated(&self, other: &Self) -> bool { - self.block_presence.is_outdated(&other.block_presence) + self.state == NodeState::Incomplete + || self.block_presence.is_outdated(&other.block_presence) } pub fn with_state(self, state: NodeState) -> Self { diff --git a/lib/src/protocol/test_utils.rs b/lib/src/protocol/test_utils.rs index 3fb8256bd..5f6172745 100644 --- a/lib/src/protocol/test_utils.rs +++ b/lib/src/protocol/test_utils.rs @@ -1,4 +1,4 @@ -use super::{MultiBlockPresence, NodeState, SingleBlockPresence, Summary}; +use super::{MultiBlockPresence, NodeState, SingleBlockPresence, Summary, EMPTY_INNER_HASH}; use crate::{ crypto::{Hash, Hashable}, protocol::{Block, BlockId, InnerNode, InnerNodes, LeafNode, LeafNodes, INNER_LAYER_COUNT}, @@ -188,8 +188,14 @@ impl Snapshot { } } else { self.root_hash = hash; - // FIXME: if hash == EMPTY_INNER_HASH we should set this to `Complete`: - self.root_summary = Summary::INCOMPLETE; + self.root_summary = if hash == *EMPTY_INNER_HASH { + Summary { + state: NodeState::Complete, + block_presence: MultiBlockPresence::None, + } + } else { + Summary::INCOMPLETE + }; self.inners.clear(); self.leaves.clear(); From feb39e744b45051f2b351abaf76a529e1b3b8ca3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Cig=C3=A1nek?= Date: Thu, 26 Sep 2024 14:07:34 +0200 Subject: [PATCH 36/55] Extract Simulation into separate file --- lib/src/network/request_tracker.rs | 3 + lib/src/network/request_tracker/simulation.rs | 364 ++++++++++++++++++ lib/src/network/request_tracker/tests.rs | 363 +---------------- 3 files changed, 372 insertions(+), 358 deletions(-) create mode 100644 lib/src/network/request_tracker/simulation.rs diff --git a/lib/src/network/request_tracker.rs b/lib/src/network/request_tracker.rs index dd231f7a1..1dbf820d9 100644 --- a/lib/src/network/request_tracker.rs +++ b/lib/src/network/request_tracker.rs @@ -1,4 +1,7 @@ mod graph; + +#[cfg(test)] +mod simulation; #[cfg(test)] mod tests; diff --git a/lib/src/network/request_tracker/simulation.rs b/lib/src/network/request_tracker/simulation.rs new file mode 100644 index 000000000..ee5c8a63c --- /dev/null +++ b/lib/src/network/request_tracker/simulation.rs @@ -0,0 +1,364 @@ +use super::{ + super::message::{Request, Response, ResponseDisambiguator}, + MessageKey, RequestTracker, RequestTrackerClient, SendPermit, +}; +use crate::{ + collections::{HashMap, HashSet}, + crypto::{ + sign::{Keypair, PublicKey}, + Hashable, + }, + network::debug_payload::DebugResponse, + protocol::{ + test_utils::{assert_snapshots_equal, Snapshot}, + Block, MultiBlockPresence, Proof, UntrustedProof, + }, + version_vector::VersionVector, +}; +use rand::{seq::SliceRandom, CryptoRng, Rng}; +use std::collections::{hash_map::Entry, VecDeque}; +use tokio::sync::mpsc; + +/// Simple network simulation for testing `RequestTracker`. +pub(super) struct Simulation { + peers: Vec, + // All requests sent by live peers. This is used to verify that every request is sent only once + // unless the peer that sent it died or the request failed. In those cases the request may be + // sent by another peer. 
It's also allowed to sent the same request more than once as long as + // each one has a different block presence. + requests: HashMap>, + snapshot: Snapshot, +} + +impl Simulation { + pub fn new() -> Self { + Self { + peers: Vec::new(), + requests: HashMap::default(), + snapshot: Snapshot::default(), + } + } + + pub fn peer_count(&self) -> usize { + self.peers.len() + } + + pub fn insert_peer( + &mut self, + rng: &mut R, + tracker: &RequestTracker, + snapshot: Snapshot, + ) { + let (tracker_client, tracker_request_rx) = tracker.new_client(); + let client = TestClient::new(tracker_client, tracker_request_rx); + + let writer_id = PublicKey::generate(rng); + let write_keys = Keypair::generate(rng); + let server = TestServer::new(writer_id, write_keys, snapshot); + + self.peers.push(TestPeer { + client, + server, + requests: HashMap::default(), + }); + } + + pub fn remove_peer(&mut self, rng: &mut R) { + let index = rng.gen_range(0..self.peers.len()); + let peer = self.peers.remove(index); + + for (key, block_presence) in peer.requests { + cancel_request(&mut self.requests, key, block_presence); + } + } + + // Polls random client or server once + #[track_caller] + pub fn poll(&mut self, rng: &mut R) -> bool { + enum Side { + Client, + Server, + } + + let mut order: Vec<_> = (0..self.peers.len()) + .flat_map(|index| [(Side::Client, index), (Side::Server, index)]) + .collect(); + + order.shuffle(rng); + + for (side, index) in order { + let peer = &mut self.peers[index]; + + match side { + Side::Client => { + if let Some(SendPermit { + request, + block_presence, + }) = peer.client.poll_request() + { + let key = MessageKey::from(&request); + + assert!( + self.requests.entry(key).or_default().insert(block_presence), + "request sent more than once: {request:?} ({block_presence:?})" + ); + + peer.requests.insert(key, block_presence); + peer.server.handle_request(request); + + return true; + } + } + Side::Server => { + if let Some(response) = peer.server.poll_response() { + // In case of failure, cancel the request so it can be retried without it + // triggering assertion failure. + let key = match response { + Response::RootNodeError(writer_id, _) => { + Some(MessageKey::RootNode(writer_id)) + } + Response::ChildNodesError(hash, _, _) => { + Some(MessageKey::ChildNodes(hash)) + } + Response::BlockError(block_id, _) => Some(MessageKey::Block(block_id)), + Response::RootNode(..) + | Response::InnerNodes(..) + | Response::LeafNodes(..) + | Response::Block(..) + | Response::BlockOffer(..) => None, + }; + + if let Some(key) = key { + if let Some(block_presence) = peer.requests.get(&key) { + cancel_request(&mut self.requests, key, *block_presence); + } + } + + peer.client.handle_response(response, &mut self.snapshot); + return true; + } + } + } + } + + false + } + + #[track_caller] + pub fn verify(&self, expected_snapshot: &Snapshot) { + assert_snapshots_equal(&self.snapshot, expected_snapshot) + } +} + +fn cancel_request( + requests: &mut HashMap>, + key: MessageKey, + block_presence: MultiBlockPresence, +) { + if let Entry::Occupied(mut entry) = requests.entry(key) { + entry.get_mut().remove(&block_presence); + + if entry.get().is_empty() { + entry.remove(); + } + } +} + +struct TestPeer { + client: TestClient, + server: TestServer, + // All requests sent by this peer. 
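Stepping back, the tests drive `Simulation` in a simple poll loop; abridged from the `missing_blocks` test above:

    let mut rng = StdRng::seed_from_u64(seed);
    let mut sim = Simulation::new();
    let (tracker, mut worker) = build();

    sim.insert_peer(&mut rng, &tracker, snapshot.clone());

    // Interleave peer polls with tracker worker steps until quiescent.
    while sim.poll(&mut rng) {
        worker.step();
    }

    sim.verify(&snapshot);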
+ requests: HashMap, +} + +struct TestClient { + tracker_client: RequestTrackerClient, + tracker_request_rx: mpsc::UnboundedReceiver, +} + +impl TestClient { + fn new( + tracker_client: RequestTrackerClient, + tracker_request_rx: mpsc::UnboundedReceiver, + ) -> Self { + Self { + tracker_client, + tracker_request_rx, + } + } + + fn handle_response(&mut self, response: Response, snapshot: &mut Snapshot) { + match response { + Response::RootNode(proof, block_presence, debug_payload) => { + let requests = snapshot + .insert_root(proof.hash, block_presence) + .then_some(( + Request::ChildNodes( + proof.hash, + ResponseDisambiguator::new(block_presence), + debug_payload.follow_up(), + ), + block_presence, + )) + .into_iter() + .collect(); + + self.tracker_client + .success(MessageKey::RootNode(proof.writer_id), requests); + } + Response::InnerNodes(nodes, _disambiguator, debug_payload) => { + let parent_hash = nodes.hash(); + let nodes = snapshot.insert_inners(nodes); + + let requests: Vec<_> = nodes + .into_iter() + .map(|(_, node)| { + ( + Request::ChildNodes( + node.hash, + ResponseDisambiguator::new(node.summary.block_presence), + debug_payload.follow_up(), + ), + node.summary.block_presence, + ) + }) + .collect(); + + self.tracker_client + .success(MessageKey::ChildNodes(parent_hash), requests); + } + Response::LeafNodes(nodes, _disambiguator, debug_payload) => { + let parent_hash = nodes.hash(); + let nodes = snapshot.insert_leaves(nodes); + let requests = nodes + .into_iter() + .map(|node| { + ( + Request::Block(node.block_id, debug_payload.follow_up()), + MultiBlockPresence::None, + ) + }) + .collect(); + + self.tracker_client + .success(MessageKey::ChildNodes(parent_hash), requests); + } + Response::Block(content, nonce, _debug_payload) => { + let block = Block::new(content, nonce); + let block_id = block.id; + + snapshot.insert_block(block); + + self.tracker_client + .success(MessageKey::Block(block_id), vec![]); + } + Response::RootNodeError(writer_id, _debug_payload) => { + self.tracker_client.failure(MessageKey::RootNode(writer_id)); + } + Response::ChildNodesError(hash, _disambiguator, _debug_payload) => { + self.tracker_client.failure(MessageKey::ChildNodes(hash)); + } + Response::BlockError(block_id, _debug_payload) => { + self.tracker_client.failure(MessageKey::Block(block_id)); + } + Response::BlockOffer(_block_id, _debug_payload) => unimplemented!(), + }; + } + + fn poll_request(&mut self) -> Option { + self.tracker_request_rx.try_recv().ok() + } +} + +struct TestServer { + writer_id: PublicKey, + write_keys: Keypair, + snapshot: Snapshot, + outbox: VecDeque, +} + +impl TestServer { + fn new(writer_id: PublicKey, write_keys: Keypair, snapshot: Snapshot) -> Self { + let proof = UntrustedProof::from(Proof::new( + writer_id, + VersionVector::first(writer_id), + *snapshot.root_hash(), + &write_keys, + )); + + let outbox = [Response::RootNode( + proof.clone(), + snapshot.root_summary().block_presence, + DebugResponse::unsolicited(), + )] + .into(); + + Self { + writer_id, + write_keys, + snapshot, + outbox, + } + } + + fn handle_request(&mut self, request: Request) { + match request { + Request::RootNode(writer_id, debug_payload) => { + if writer_id == self.writer_id { + let proof = Proof::new( + writer_id, + VersionVector::first(writer_id), + *self.snapshot.root_hash(), + &self.write_keys, + ); + + self.outbox.push_back(Response::RootNode( + proof.into(), + self.snapshot.root_summary().block_presence, + debug_payload.reply(), + )); + } else { + self.outbox + 
.push_back(Response::RootNodeError(writer_id, debug_payload.reply())); + } + } + Request::ChildNodes(hash, disambiguator, debug_payload) => { + if let Some(nodes) = self.snapshot.get_inner_set(&hash) { + self.outbox.push_back(Response::InnerNodes( + nodes.clone(), + disambiguator, + debug_payload.reply(), + )); + } else if let Some(nodes) = self.snapshot.get_leaf_set(&hash) { + self.outbox.push_back(Response::LeafNodes( + nodes.clone(), + disambiguator, + debug_payload.reply(), + )); + } else { + self.outbox.push_back(Response::ChildNodesError( + hash, + disambiguator, + debug_payload.reply(), + )); + } + } + Request::Block(block_id, debug_payload) => { + if let Some(block) = self.snapshot.blocks().get(&block_id) { + self.outbox.push_back(Response::Block( + block.content.clone(), + block.nonce, + debug_payload.reply(), + )); + } else { + self.outbox + .push_back(Response::BlockError(block_id, debug_payload.reply())); + } + } + } + } + + fn poll_response(&mut self) -> Option { + self.outbox.pop_front() + } +} diff --git a/lib/src/network/request_tracker/tests.rs b/lib/src/network/request_tracker/tests.rs index 562e10e26..7f39899ea 100644 --- a/lib/src/network/request_tracker/tests.rs +++ b/lib/src/network/request_tracker/tests.rs @@ -1,24 +1,14 @@ -use super::{ - super::{debug_payload::DebugResponse, message::Response}, - *, -}; -use crate::{ - collections::HashSet, - crypto::{sign::Keypair, Hashable}, - network::message::ResponseDisambiguator, - protocol::{ - test_utils::{assert_snapshots_equal, BlockState, Snapshot}, - Block, MultiBlockPresence, Proof, UntrustedProof, - }, - version_vector::VersionVector, +use super::{simulation::Simulation, *}; +use crate::protocol::{ + test_utils::{BlockState, Snapshot}, + Block, }; use rand::{ distributions::{Bernoulli, Distribution, Standard}, rngs::StdRng, seq::SliceRandom, - CryptoRng, Rng, SeedableRng, + Rng, SeedableRng, }; -use std::collections::{hash_map::Entry, VecDeque}; // Test syncing while peers keep joining and leaving the swarm. // @@ -149,349 +139,6 @@ async fn missing_blocks() { // TODO: test failure/timeout -struct Simulation { - peers: Vec, - // All requests sent by live peers. This is used to verify that every request is sent only once - // unless the peer that sent it died or the request failed. In those cases the request may be - // sent by another peer. It's also allowed to sent the same request more than once as long as - // each one has a different block presence. 
- requests: HashMap>, - snapshot: Snapshot, -} - -impl Simulation { - fn new() -> Self { - Self { - peers: Vec::new(), - requests: HashMap::default(), - snapshot: Snapshot::default(), - } - } - - fn peer_count(&self) -> usize { - self.peers.len() - } - - fn insert_peer( - &mut self, - rng: &mut R, - tracker: &RequestTracker, - snapshot: Snapshot, - ) { - let (tracker_client, tracker_request_rx) = tracker.new_client(); - let client = TestClient::new(tracker_client, tracker_request_rx); - - let writer_id = PublicKey::generate(rng); - let write_keys = Keypair::generate(rng); - let server = TestServer::new(writer_id, write_keys, snapshot); - - self.peers.push(TestPeer { - client, - server, - requests: HashMap::default(), - }); - } - - fn remove_peer(&mut self, rng: &mut R) { - let index = rng.gen_range(0..self.peers.len()); - let peer = self.peers.remove(index); - - for (key, block_presence) in peer.requests { - cancel_request(&mut self.requests, key, block_presence); - } - } - - // Polls random client or server once - #[track_caller] - fn poll(&mut self, rng: &mut R) -> bool { - enum Side { - Client, - Server, - } - - let mut order: Vec<_> = (0..self.peers.len()) - .flat_map(|index| [(Side::Client, index), (Side::Server, index)]) - .collect(); - - order.shuffle(rng); - - for (side, index) in order { - let peer = &mut self.peers[index]; - - match side { - Side::Client => { - if let Some(SendPermit { - request, - block_presence, - }) = peer.client.poll_request() - { - let key = MessageKey::from(&request); - - assert!( - self.requests.entry(key).or_default().insert(block_presence), - "request sent more than once: {request:?} ({block_presence:?})" - ); - - peer.requests.insert(key, block_presence); - peer.server.handle_request(request); - - return true; - } - } - Side::Server => { - if let Some(response) = peer.server.poll_response() { - // In case of failure, cancel the request so it can be retried without it - // triggering assertion failure. - let key = match response { - Response::RootNodeError(writer_id, _) => { - Some(MessageKey::RootNode(writer_id)) - } - Response::ChildNodesError(hash, _, _) => { - Some(MessageKey::ChildNodes(hash)) - } - Response::BlockError(block_id, _) => Some(MessageKey::Block(block_id)), - Response::RootNode(..) - | Response::InnerNodes(..) - | Response::LeafNodes(..) - | Response::Block(..) - | Response::BlockOffer(..) => None, - }; - - if let Some(key) = key { - if let Some(block_presence) = peer.requests.get(&key) { - cancel_request(&mut self.requests, key, *block_presence); - } - } - - peer.client.handle_response(response, &mut self.snapshot); - return true; - } - } - } - } - - false - } - - #[track_caller] - fn verify(&self, expected_snapshot: &Snapshot) { - assert_snapshots_equal(&self.snapshot, expected_snapshot) - } -} - -fn cancel_request( - requests: &mut HashMap>, - key: MessageKey, - block_presence: MultiBlockPresence, -) { - if let Entry::Occupied(mut entry) = requests.entry(key) { - entry.get_mut().remove(&block_presence); - - if entry.get().is_empty() { - entry.remove(); - } - } -} - -struct TestPeer { - client: TestClient, - server: TestServer, - // All requests sent by this peer. 
- requests: HashMap, -} - -struct TestClient { - tracker_client: RequestTrackerClient, - tracker_request_rx: mpsc::UnboundedReceiver, -} - -impl TestClient { - fn new( - tracker_client: RequestTrackerClient, - tracker_request_rx: mpsc::UnboundedReceiver, - ) -> Self { - Self { - tracker_client, - tracker_request_rx, - } - } - - fn handle_response(&mut self, response: Response, snapshot: &mut Snapshot) { - match response { - Response::RootNode(proof, block_presence, debug_payload) => { - let requests = snapshot - .insert_root(proof.hash, block_presence) - .then_some(( - Request::ChildNodes( - proof.hash, - ResponseDisambiguator::new(block_presence), - debug_payload.follow_up(), - ), - block_presence, - )) - .into_iter() - .collect(); - - self.tracker_client - .success(MessageKey::RootNode(proof.writer_id), requests); - } - Response::InnerNodes(nodes, _disambiguator, debug_payload) => { - let parent_hash = nodes.hash(); - let nodes = snapshot.insert_inners(nodes); - - let requests: Vec<_> = nodes - .into_iter() - .map(|(_, node)| { - ( - Request::ChildNodes( - node.hash, - ResponseDisambiguator::new(node.summary.block_presence), - debug_payload.follow_up(), - ), - node.summary.block_presence, - ) - }) - .collect(); - - self.tracker_client - .success(MessageKey::ChildNodes(parent_hash), requests); - } - Response::LeafNodes(nodes, _disambiguator, debug_payload) => { - let parent_hash = nodes.hash(); - let nodes = snapshot.insert_leaves(nodes); - let requests = nodes - .into_iter() - .map(|node| { - ( - Request::Block(node.block_id, debug_payload.follow_up()), - MultiBlockPresence::None, - ) - }) - .collect(); - - self.tracker_client - .success(MessageKey::ChildNodes(parent_hash), requests); - } - Response::Block(content, nonce, _debug_payload) => { - let block = Block::new(content, nonce); - let block_id = block.id; - - snapshot.insert_block(block); - - self.tracker_client - .success(MessageKey::Block(block_id), vec![]); - } - Response::RootNodeError(writer_id, _debug_payload) => { - self.tracker_client.failure(MessageKey::RootNode(writer_id)); - } - Response::ChildNodesError(hash, _disambiguator, _debug_payload) => { - self.tracker_client.failure(MessageKey::ChildNodes(hash)); - } - Response::BlockError(block_id, _debug_payload) => { - self.tracker_client.failure(MessageKey::Block(block_id)); - } - Response::BlockOffer(_block_id, _debug_payload) => unimplemented!(), - }; - } - - fn poll_request(&mut self) -> Option { - self.tracker_request_rx.try_recv().ok() - } -} - -struct TestServer { - writer_id: PublicKey, - write_keys: Keypair, - snapshot: Snapshot, - outbox: VecDeque, -} - -impl TestServer { - fn new(writer_id: PublicKey, write_keys: Keypair, snapshot: Snapshot) -> Self { - let proof = UntrustedProof::from(Proof::new( - writer_id, - VersionVector::first(writer_id), - *snapshot.root_hash(), - &write_keys, - )); - - let outbox = [Response::RootNode( - proof.clone(), - snapshot.root_summary().block_presence, - DebugResponse::unsolicited(), - )] - .into(); - - Self { - writer_id, - write_keys, - snapshot, - outbox, - } - } - - fn handle_request(&mut self, request: Request) { - match request { - Request::RootNode(writer_id, debug_payload) => { - if writer_id == self.writer_id { - let proof = Proof::new( - writer_id, - VersionVector::first(writer_id), - *self.snapshot.root_hash(), - &self.write_keys, - ); - - self.outbox.push_back(Response::RootNode( - proof.into(), - self.snapshot.root_summary().block_presence, - debug_payload.reply(), - )); - } else { - self.outbox - 
.push_back(Response::RootNodeError(writer_id, debug_payload.reply())); - } - } - Request::ChildNodes(hash, disambiguator, debug_payload) => { - if let Some(nodes) = self.snapshot.get_inner_set(&hash) { - self.outbox.push_back(Response::InnerNodes( - nodes.clone(), - disambiguator, - debug_payload.reply(), - )); - } else if let Some(nodes) = self.snapshot.get_leaf_set(&hash) { - self.outbox.push_back(Response::LeafNodes( - nodes.clone(), - disambiguator, - debug_payload.reply(), - )); - } else { - self.outbox.push_back(Response::ChildNodesError( - hash, - disambiguator, - debug_payload.reply(), - )); - } - } - Request::Block(block_id, debug_payload) => { - if let Some(block) = self.snapshot.blocks().get(&block_id) { - self.outbox.push_back(Response::Block( - block.content.clone(), - block.nonce, - debug_payload.reply(), - )); - } else { - self.outbox - .push_back(Response::BlockError(block_id, debug_payload.reply())); - } - } - } - } - - fn poll_response(&mut self) -> Option { - self.outbox.pop_front() - } -} - /// Generate `count + 1` copies of the same snapshot. The first one will have all the blocks /// present (the "master copy"). The remaining ones will have some blocks missing but in such a /// way that every block is present in at least one of the snapshots. From a2074637299068cc96625e889863b980cca9b282 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Cig=C3=A1nek?= Date: Thu, 26 Sep 2024 14:29:20 +0200 Subject: [PATCH 37/55] Add test for RequestTracker timeout --- lib/src/network/request_tracker/tests.rs | 66 ++++++++++++++++++++++-- 1 file changed, 62 insertions(+), 4 deletions(-) diff --git a/lib/src/network/request_tracker/tests.rs b/lib/src/network/request_tracker/tests.rs index 7f39899ea..03b024da1 100644 --- a/lib/src/network/request_tracker/tests.rs +++ b/lib/src/network/request_tracker/tests.rs @@ -1,7 +1,10 @@ use super::{simulation::Simulation, *}; -use crate::protocol::{ - test_utils::{BlockState, Snapshot}, - Block, +use crate::{ + network::{debug_payload::DebugRequest, message::ResponseDisambiguator}, + protocol::{ + test_utils::{BlockState, Snapshot}, + Block, + }, }; use rand::{ distributions::{Bernoulli, Distribution, Standard}, @@ -9,6 +12,8 @@ use rand::{ seq::SliceRandom, Rng, SeedableRng, }; +use std::{pin::pin, time::Duration}; +use tokio::{sync::mpsc, time}; // Test syncing while peers keep joining and leaving the swarm. // @@ -137,7 +142,60 @@ async fn missing_blocks() { } } -// TODO: test failure/timeout +#[tokio::test(start_paused = true)] +async fn timeout() { + let mut rng = StdRng::seed_from_u64(0); + let (tracker, tracker_worker) = build(); + + let mut work = pin!(tracker_worker.run()); + + let (client_a, mut request_rx_a) = tracker.new_client(); + let (client_b, mut request_rx_b) = tracker.new_client(); + + let preceding_request_key = MessageKey::RootNode(PublicKey::generate(&mut rng)); + let request = Request::ChildNodes( + rng.gen(), + ResponseDisambiguator::new(MultiBlockPresence::Full), + DebugRequest::start(), + ); + + // Register the request with both clients. + client_a.success( + preceding_request_key, + vec![(request.clone(), MultiBlockPresence::Full)], + ); + + client_b.success( + preceding_request_key, + vec![(request.clone(), MultiBlockPresence::Full)], + ); + + time::timeout(Duration::from_millis(1), &mut work) + .await + .ok(); + + // Only the first client gets the send permit. 
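A note on the clock in this test: `#[tokio::test(start_paused = true)]` runs tokio's timers on a mock clock that auto-advances whenever the runtime is otherwise idle, so the `REQUEST_TIMEOUT` wait below completes without any real waiting. An independent illustration:

    #[tokio::test(start_paused = true)]
    async fn paused_clock() {
        let start = tokio::time::Instant::now();
        tokio::time::sleep(std::time::Duration::from_secs(3600)).await;
        // An hour of virtual time passed instantly in real time.
        assert!(start.elapsed() >= std::time::Duration::from_secs(3600));
    }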
+ assert_eq!( + request_rx_a.try_recv().map(|permit| permit.request), + Ok(request.clone()) + ); + + assert_eq!( + request_rx_b.try_recv().map(|permit| permit.request), + Err(mpsc::error::TryRecvError::Empty), + ); + + // Wait until the request timeout passes + time::timeout(REQUEST_TIMEOUT + Duration::from_millis(1), &mut work) + .await + .ok(); + + // The first client timeouted so the second client now gets the permit. + assert_eq!( + request_rx_b.try_recv().map(|permit| permit.request), + Ok(request.clone()) + ); +} /// Generate `count + 1` copies of the same snapshot. The first one will have all the blocks /// present (the "master copy"). The remaining ones will have some blocks missing but in such a From 430276f05acdcc5a3be8da0aa78c4831ce88933c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Cig=C3=A1nek?= Date: Thu, 26 Sep 2024 16:18:15 +0200 Subject: [PATCH 38/55] Integrate RequestTracker --- lib/src/blob/lock.rs | 4 +- lib/src/blob/mod.rs | 4 +- lib/src/block_tracker.rs | 1202 +++++++---------- lib/src/collections.rs | 1 - lib/src/network/client.rs | 212 ++- lib/src/network/connection.rs | 3 +- lib/src/network/local_discovery.rs | 2 +- lib/src/network/message_broker.rs | 26 +- lib/src/network/mod.rs | 13 +- lib/src/network/pending.rs | 303 ----- lib/src/network/request_tracker.rs | 160 ++- lib/src/network/request_tracker/graph.rs | 168 ++- lib/src/network/request_tracker/simulation.rs | 40 +- lib/src/network/request_tracker/tests.rs | 10 +- lib/src/network/seen_peers.rs | 4 +- lib/src/network/tests.rs | 95 +- lib/src/network/upnp.rs | 15 +- lib/src/repository/mod.rs | 12 +- lib/src/repository/vault/tests.rs | 6 +- lib/src/repository/worker.rs | 34 +- lib/src/store/block_expiration_tracker.rs | 10 +- lib/src/store/block_id_cache.rs | 9 +- lib/src/store/client.rs | 92 +- lib/src/store/test_utils.rs | 2 +- lib/src/sync.rs | 108 -- 25 files changed, 937 insertions(+), 1598 deletions(-) delete mode 100644 lib/src/network/pending.rs diff --git a/lib/src/blob/lock.rs b/lib/src/blob/lock.rs index 906a1f80c..2bb3fb8cb 100644 --- a/lib/src/blob/lock.rs +++ b/lib/src/blob/lock.rs @@ -2,12 +2,12 @@ use crate::{ blob::BlobId, - collections::{hash_map::Entry, HashMap}, + collections::HashMap, crypto::sign::PublicKey, sync::{AwaitDrop, DropAwaitable}, }; use deadlock::BlockingMutex; -use std::sync::Arc; +use std::{collections::hash_map::Entry, sync::Arc}; /// Container for blob locks in all branches. #[derive(Default, Clone)] diff --git a/lib/src/blob/mod.rs b/lib/src/blob/mod.rs index bf5473998..e45f4bf35 100644 --- a/lib/src/blob/mod.rs +++ b/lib/src/blob/mod.rs @@ -12,7 +12,7 @@ pub(crate) use self::{block_ids::BlockIds, id::BlobId}; use self::position::Position; use crate::{ branch::Branch, - collections::{hash_map::Entry, HashMap}, + collections::HashMap, crypto::{ cipher::{self, Nonce, SecretKey}, sign::{Keypair, PublicKey}, @@ -25,7 +25,7 @@ use crate::{ }, store::{self, Changeset, ReadTransaction}, }; -use std::{io::SeekFrom, iter, mem}; +use std::{collections::hash_map::Entry, io::SeekFrom, iter, mem}; use thiserror::Error; /// Size of the blob header in bytes. diff --git a/lib/src/block_tracker.rs b/lib/src/block_tracker.rs index 17461bae1..433fe7065 100644 --- a/lib/src/block_tracker.rs +++ b/lib/src/block_tracker.rs @@ -2,31 +2,29 @@ use crate::{ collections::{HashMap, HashSet}, protocol::BlockId, }; -use deadlock::BlockingMutex; -use std::{collections::hash_map::Entry, sync::Arc}; -use tokio::sync::watch; - -/// Helper for tracking required missing blocks. 
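The rewritten `BlockTracker` below adopts the same actor shape as `RequestTracker`: a cheaply cloneable handle pushes commands over an unbounded channel to a worker that solely owns the mutable state, replacing the previous mutex-guarded sharing. The skeleton of that pattern (a generic sketch, not the actual types):

    use tokio::sync::mpsc;

    enum Command {
        DoSomething,
    }

    #[derive(Clone)]
    struct Handle {
        command_tx: mpsc::UnboundedSender<Command>,
    }

    async fn run_worker(mut command_rx: mpsc::UnboundedReceiver<Command>) {
        // Sole owner of the state, so no locks are needed.
        while let Some(command) = command_rx.recv().await {
            match command {
                Command::DoSomething => { /* mutate private state here */ }
            }
        }
    }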
+use std::{ + collections::hash_map::Entry, + mem, + sync::atomic::{AtomicUsize, Ordering}, +}; +use tokio::{sync::mpsc, task}; + +/// Tracks blocks that are offered for requesting from some peers and blocks that are required +/// locally. If a block is both, a notification is triggered to prompt us to request the block from +/// the offering peers. +/// +/// Note the above applies only to read and write replicas. For blind replicas which blocks are +/// required is not known and so blocks are notified immediately once they've been offered. #[derive(Clone)] pub(crate) struct BlockTracker { - shared: Arc, + command_tx: mpsc::UnboundedSender, } impl BlockTracker { pub fn new() -> Self { - let (notify_tx, _) = watch::channel(()); - - Self { - shared: Arc::new(Shared { - inner: BlockingMutex::new(Inner { - missing_blocks: HashMap::default(), - clients: HashMap::default(), - next_client_id: 0, - request_mode: RequestMode::Greedy, - }), - notify_tx, - }), - } + let (this, worker) = build(); + task::spawn(worker.run()); + this } /// Set the mode in which blocks are requested: @@ -34,914 +32,634 @@ impl BlockTracker { /// - Lazy: block is requested when it's both offered and required /// - Greedy: block is requested as soon as it's offered. /// - /// Note: In `Greedy` mode calling `require` (or `require_batch`) is unnecessary. - pub fn set_request_mode(&self, mode: RequestMode) { - self.shared.inner.lock().unwrap().request_mode = mode; + /// Note: In `Greedy` mode calling `require` is unnecessary. + pub fn set_request_mode(&self, mode: BlockRequestMode) { + self.command_tx.send(Command::SetRequestMode { mode }).ok(); } - /// Marks the block with the given id as required. - pub fn require(&self, block_id: BlockId) { - if self.shared.inner.lock().unwrap().require(block_id) { - self.shared.notify() - } - } + pub fn new_client(&self) -> (BlockTrackerClient, mpsc::UnboundedReceiver) { + let (block_tx, block_rx) = mpsc::unbounded_channel(); + let client_id = ClientId::next(); - /// Marks multiple blocks as required. - pub fn require_batch(&self) -> RequireBatch<'_> { - RequireBatch { - shared: &self.shared, - notify: false, - } + self.command_tx + .send(Command::InsertClient { + client_id, + block_tx, + }) + .ok(); + + ( + BlockTrackerClient { + client_id, + command_tx: self.command_tx.clone(), + }, + block_rx, + ) } - /// Approve the block request if offered. This is called when `quota` is not `None`, otherwise - /// blocks are pre-approved from `TrackerClient::register(block_id, OfferState::Approved)`. - pub fn approve(&self, block_id: BlockId) { - let mut inner = self.shared.inner.lock().unwrap(); + pub fn require(&self, block_id: BlockId) { + self.command_tx + .send(Command::InsertRequired { block_id }) + .ok(); + } - let Some(missing_block) = inner.missing_blocks.get_mut(&block_id) else { - return; - }; + pub fn clear_required(&self) { + self.command_tx.send(Command::ClearRequired).ok(); + } +} - let required = match &mut missing_block.state { - State::Idle { approved: true, .. } | State::Accepted(_) => return, - State::Idle { approved, required } => { - *approved = true; - *required - } - }; +pub(crate) struct BlockTrackerClient { + client_id: ClientId, + command_tx: mpsc::UnboundedSender, +} - // If required and offered, notify the waiting acceptors. - if required && !missing_block.offers.is_empty() { - self.shared.notify(); - } +impl BlockTrackerClient { + /// Marks block as offered by this peer. 
+ pub fn offer(&self, block_id: BlockId, state: BlockOfferState) { + self.command_tx + .send(Command::InsertOffer { + client_id: self.client_id, + block_id, + state, + }) + .ok(); } - pub fn client(&self) -> TrackerClient { - let client_id = self.shared.inner.lock().unwrap().insert_client(); - let notify_rx = self.shared.notify_tx.subscribe(); - - TrackerClient { - shared: self.shared.clone(), - client_id, - notify_rx, - } + /// Marks a block offer as approved. Do this for a block offered (see [Self::offer]) previously + /// with a `Complete` or `Incomplete` root node state after that root node has become + /// `Approved`. + /// + /// Note: this marks all offers for the give block as approved, not just the ones from this + /// peer. + pub fn approve(&self, block_id: BlockId) { + self.command_tx + .send(Command::ApproveOffers { block_id }) + .ok(); } } -#[derive(Clone, Copy, Debug)] -pub(crate) enum OfferState { - Pending, - Approved, +impl Drop for BlockTrackerClient { + fn drop(&mut self) { + self.command_tx + .send(Command::RemoveClient { + client_id: self.client_id, + }) + .ok(); + } } #[derive(Clone, Copy)] -pub(crate) enum RequestMode { +pub(crate) enum BlockRequestMode { // Request only required blocks Lazy, // Request all blocks Greedy, } -pub(crate) struct RequireBatch<'a> { - shared: &'a Shared, - notify: bool, +#[derive(Clone, Copy)] +pub(crate) enum BlockOfferState { + Pending, + Approved, +} + +fn build() -> (BlockTracker, Worker) { + let (command_tx, command_rx) = mpsc::unbounded_channel(); + (BlockTracker { command_tx }, Worker::new(command_rx)) +} + +struct Worker { + clients: HashMap, + required_blocks: HashSet, + request_mode: BlockRequestMode, + command_rx: mpsc::UnboundedReceiver, } -impl RequireBatch<'_> { - pub fn add(&mut self, block_id: BlockId) { - if self.shared.inner.lock().unwrap().require(block_id) { - self.notify = true; +impl Worker { + fn new(command_rx: mpsc::UnboundedReceiver) -> Self { + Self { + clients: HashMap::default(), + required_blocks: HashSet::default(), + request_mode: BlockRequestMode::Greedy, + command_rx, } } - pub fn commit(&mut self) { - if self.notify { - self.shared.notify(); - self.notify = false; + async fn run(mut self) { + while let Some(command) = self.command_rx.recv().await { + self.handle_command(command); } } -} -impl Drop for RequireBatch<'_> { - fn drop(&mut self) { - self.commit() + /// Process all currently queued commands. + #[cfg(test)] + pub fn step(&mut self) { + while let Ok(command) = self.command_rx.try_recv() { + self.handle_command(command); + } } -} - -pub(crate) struct TrackerClient { - shared: Arc, - client_id: ClientId, - notify_rx: watch::Receiver<()>, -} -impl TrackerClient { - /// Returns a stream of offers for required blocks. 
- pub fn offers(&self) -> BlockOffers { - BlockOffers { - shared: self.shared.clone(), - client_id: self.client_id, - notify_rx: self.notify_rx.clone(), + fn handle_command(&mut self, command: Command) { + match command { + Command::InsertClient { + client_id, + block_tx, + } => self.handle_insert_client(client_id, block_tx), + Command::RemoveClient { client_id } => self.handle_remove_client(client_id), + Command::InsertOffer { + client_id, + block_id, + state, + } => self.handle_insert_offer(client_id, block_id, state), + Command::ApproveOffers { block_id } => self.handle_approve_offers(block_id), + Command::SetRequestMode { mode } => self.handle_set_request_mode(mode), + Command::InsertRequired { block_id } => self.handle_insert_required(block_id), + Command::ClearRequired => self.handle_clear_required(), } } - /// Registers an offer for a block with the given id. - /// Returns `true` if this block was offered for the first time (by any client) or `false` if - /// it's already been offered but not yet accepted or cancelled. - pub fn register(&self, block_id: BlockId, state: OfferState) -> bool { - let mut inner = self.shared.inner.lock().unwrap(); - - // unwrap is OK because if `self` exists the `inner.clients` entry must exists as well. - if !inner - .clients - .get_mut(&self.client_id) - .unwrap() - .insert(block_id) - { - // Already offered - return false; - } + fn handle_insert_client( + &mut self, + client_id: ClientId, + block_tx: mpsc::UnboundedSender, + ) { + self.clients.insert(client_id, ClientState::new(block_tx)); + } - let missing_block = inner - .missing_blocks - .entry(block_id) - .or_insert_with(|| MissingBlock { - offers: HashMap::default(), - state: State::Idle { - required: false, - approved: false, - }, - }); - - missing_block - .offers - .insert(self.client_id, Offer::Available); - - let mut notify = false; - - match &mut missing_block.state { - State::Idle { approved, .. 
} => { - match state { - OfferState::Approved => { - *approved = true; - } - OfferState::Pending => (), - } + fn handle_remove_client(&mut self, client_id: ClientId) { + self.clients.remove(&client_id); + } - if *approved { - notify = true; + fn handle_set_request_mode(&mut self, mode: BlockRequestMode) { + self.request_mode = mode; + + match mode { + BlockRequestMode::Greedy => { + for client_state in self.clients.values_mut() { + client_state.accept_all_approved(); } + + self.required_blocks.clear(); } - State::Accepted(_) => (), + BlockRequestMode::Lazy => (), } + } - match inner.request_mode { - RequestMode::Lazy => (), - RequestMode::Greedy => { - if inner.require(block_id) { - notify = true; - } - } + fn handle_insert_required(&mut self, block_id: BlockId) { + if !self.required_blocks.insert(block_id) { + // already required + return; } - if notify { - self.shared.notify(); + match self.request_mode { + BlockRequestMode::Greedy => { + for client_state in self.clients.values_mut() { + client_state.accept(block_id); + } + } + BlockRequestMode::Lazy => { + for client_state in self.clients.values_mut() { + client_state.accept_if_approved(block_id); + } + } } + } - true + fn handle_clear_required(&mut self) { + self.required_blocks.clear(); } -} -impl Drop for TrackerClient { - fn drop(&mut self) { - if self - .shared - .inner - .lock() - .unwrap() - .remove_client(self.client_id) - { - self.shared.notify(); + fn handle_insert_offer( + &mut self, + client_id: ClientId, + block_id: BlockId, + new_state: BlockOfferState, + ) { + let Some(client_state) = self.clients.get_mut(&client_id) else { + return; + }; + + let old_state = mem::replace( + client_state + .offers + .entry(block_id) + .or_insert(BlockOfferState::Pending), + new_state, + ); + + if matches!(new_state, BlockOfferState::Pending) { + return; } - } -} -/// Stream of offers for required blocks. -pub(crate) struct BlockOffers { - shared: Arc, - client_id: ClientId, - notify_rx: watch::Receiver<()>, -} + if matches!(old_state, BlockOfferState::Approved) { + return; + } -impl BlockOffers { - /// Returns the next offer, waiting for one to appear if necessary. - pub async fn next(&mut self) -> BlockOffer { - loop { - if let Some(offer) = self.try_next() { - return offer; + match self.request_mode { + BlockRequestMode::Greedy => { + client_state.block_tx.send(block_id).ok(); } - - // unwrap is ok because the sender exists in self.shared. - self.notify_rx.changed().await.unwrap(); + BlockRequestMode::Lazy if self.required_blocks.contains(&block_id) => { + client_state.block_tx.send(block_id).ok(); + } + BlockRequestMode::Lazy => (), } } - /// Returns the next offer or `None` if none exists currently. - pub fn try_next(&self) -> Option { - let block_id = self - .shared - .inner - .lock() - .unwrap() - .propose_offer(self.client_id)?; + fn handle_approve_offers(&mut self, block_id: BlockId) { + let accept = matches!(self.request_mode, BlockRequestMode::Greedy) + || self.required_blocks.contains(&block_id); - Some(BlockOffer { - shared: self.shared.clone(), - client_id: self.client_id, - block_id, - complete: false, - }) + for client_state in self.clients.values_mut() { + if accept { + client_state.accept(block_id); + } else { + client_state.approve(block_id); + } + } } } -/// Offer for a required block. 
-pub(crate) struct BlockOffer {
-    shared: Arc<Shared>,
-    client_id: ClientId,
-    block_id: BlockId,
-    complete: bool,
+struct ClientState {
+    offers: HashMap<BlockId, BlockOfferState>,
+    block_tx: mpsc::UnboundedSender<BlockId>,
 }
 
-impl BlockOffer {
-    #[cfg(test)]
-    pub fn block_id(&self) -> &BlockId {
-        &self.block_id
-    }
-
-    /// Accepts the offer. There can be multiple offers for the same block (each from a different
-    /// peer) but only one returns `Some` here. The returned `BlockPromise` is a commitment to send
-    /// the block request through this client.
-    pub fn accept(self) -> Option<BlockPromise> {
-        if self
-            .shared
-            .inner
-            .lock()
-            .unwrap()
-            .accept_offer(&self.block_id, self.client_id)
-        {
-            Some(BlockPromise(self))
-        } else {
-            None
+impl ClientState {
+    fn new(block_tx: mpsc::UnboundedSender<BlockId>) -> Self {
+        Self {
+            offers: HashMap::default(),
+            block_tx,
         }
     }
-}
 
-impl Drop for BlockOffer {
-    fn drop(&mut self) {
-        if self.complete {
-            return;
+    fn accept(&mut self, block_id: BlockId) {
+        if self.offers.remove(&block_id).is_some() {
+            self.block_tx.send(block_id).ok();
         }
+    }
+
+    fn accept_if_approved(&mut self, block_id: BlockId) {
+        let Entry::Occupied(entry) = self.offers.entry(block_id) else {
+            return;
+        };
 
-        if self
-            .shared
-            .inner
-            .lock()
-            .unwrap()
-            .cancel_offer(&self.block_id, self.client_id)
-        {
-            self.shared.notify();
+        match entry.get() {
+            BlockOfferState::Approved => {
+                entry.remove();
+                self.block_tx.send(block_id).ok();
+            }
+            BlockOfferState::Pending => (),
         }
     }
-}
-
-/// Accepted block offer.
-pub(crate) struct BlockPromise(BlockOffer);
 
-impl BlockPromise {
-    pub(crate) fn block_id(&self) -> &BlockId {
-        &self.0.block_id
+    fn accept_all_approved(&mut self) {
+        self.offers.retain(|block_id, state| match state {
+            BlockOfferState::Approved => {
+                self.block_tx.send(*block_id).ok();
+                false
+            }
+            BlockOfferState::Pending => true,
+        });
     }
 
-    /// Mark the block request as successfully completed.
-    pub fn complete(mut self) {
-        self.0.complete = true;
-        self.0
-            .shared
-            .inner
-            .lock()
-            .unwrap()
-            .complete(&self.0.block_id);
+    fn approve(&mut self, block_id: BlockId) {
+        if let Some(state) = self.offers.get_mut(&block_id) {
+            *state = BlockOfferState::Approved;
+        }
     }
 }
 
-struct Shared {
-    inner: BlockingMutex<Inner>,
-    notify_tx: watch::Sender<()>,
-}
+#[derive(Clone, Copy, Eq, PartialEq, Hash, Debug)]
+struct ClientId(usize);
 
-impl Shared {
-    fn notify(&self) {
-        self.notify_tx.send(()).unwrap_or(())
+impl ClientId {
+    fn next() -> Self {
+        static NEXT: AtomicUsize = AtomicUsize::new(0);
+        Self(NEXT.fetch_add(1, Ordering::Relaxed))
     }
 }
 
-// Invariant: for all `block_id` and `client_id` such that
-//
-//     missing_blocks[block_id].offers.contains_key(client_id)
-//
-// it must hold that
-//
-//     clients[client_id].contains(block_id)
-//
-// and vice-versa.
-struct Inner {
-    missing_blocks: HashMap<BlockId, MissingBlock>,
-    clients: HashMap<ClientId, HashSet<BlockId>>,
-    next_client_id: ClientId,
-    request_mode: RequestMode,
+enum Command {
+    InsertClient {
+        client_id: ClientId,
+        block_tx: mpsc::UnboundedSender<BlockId>,
+    },
+    RemoveClient {
+        client_id: ClientId,
+    },
+    InsertOffer {
+        client_id: ClientId,
+        block_id: BlockId,
+        state: BlockOfferState,
+    },
+    ApproveOffers {
+        block_id: BlockId,
+    },
+    SetRequestMode {
+        mode: BlockRequestMode,
+    },
+    InsertRequired {
+        block_id: BlockId,
+    },
+    ClearRequired,
 }
 
-impl Inner {
-    fn insert_client(&mut self) -> ClientId {
-        let client_id = self.next_client_id;
-        self.next_client_id = self
-            .next_client_id
-            .checked_add(1)
-            .expect("too many clients");
-        self.clients.insert(client_id, HashSet::new());
-        client_id
-    }
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use tokio::sync::mpsc::error::TryRecvError;
 
-    fn remove_client(&mut self, client_id: ClientId) -> bool {
-        // unwrap is ok because if `self` exists the `clients` entry must exists as well.
-        let block_ids = self.clients.remove(&client_id).unwrap();
-        let mut notify = false;
+    #[test]
+    fn greedy_mode_offer_pending_then_approve() {
+        let (tracker, mut worker) = build();
+        let block_id = BlockId::try_from([0; BlockId::SIZE].as_ref()).unwrap();
 
-        for block_id in block_ids {
-            // unwrap is ok because of the invariant in `Inner`
-            let missing_block = self.missing_blocks.get_mut(&block_id).unwrap();
+        // Note `Greedy` is the default mode
 
-            missing_block.offers.remove(&client_id);
+        let (client, mut block_rx) = tracker.new_client();
+        worker.step();
 
-            if missing_block.unaccept_by(client_id) {
-                notify = true;
-            }
+        assert_eq!(block_rx.try_recv(), Err(TryRecvError::Empty));
 
-            // TODO: if the block hasn't other offers and isn't required, remove it
-        }
+        client.offer(block_id, BlockOfferState::Pending);
+        worker.step();
 
-        notify
-    }
-
-    /// Mark the block with the given id as required. Returns true if the block wasn't already
-    /// required and if it has at least one offer. Otherwise returns false.
-    fn require(&mut self, block_id: BlockId) -> bool {
-        let missing_block = self
-            .missing_blocks
-            .entry(block_id)
-            .or_insert_with(|| MissingBlock {
-                offers: HashMap::default(),
-                state: State::Idle {
-                    required: false,
-                    approved: false,
-                },
-            });
-
-        match &mut missing_block.state {
-            State::Idle { required: true, .. } | State::Accepted(_) => false,
-            State::Idle { required, .. } => {
-                *required = true;
-                !missing_block.offers.is_empty()
-            }
-        }
-    }
+        assert_eq!(block_rx.try_recv(), Err(TryRecvError::Empty));
 
-    fn complete(&mut self, block_id: &BlockId) {
-        let Some(missing_block) = self.missing_blocks.remove(block_id) else {
-            return;
-        };
+        client.approve(block_id);
+        worker.step();
 
-        for (client_id, _) in missing_block.offers {
-            if let Some(block_ids) = self.clients.get_mut(&client_id) {
-                block_ids.remove(block_id);
-            }
-        }
+        assert_eq!(block_rx.try_recv(), Ok(block_id));
+        assert_eq!(block_rx.try_recv(), Err(TryRecvError::Empty));
     }
 
-    fn propose_offer(&mut self, client_id: ClientId) -> Option<BlockId> {
-        // TODO: OPTIMIZE (but profile first) this linear lookup
-        for block_id in self.clients.get(&client_id).into_iter().flatten() {
-            // unwrap is ok because of the invariant in `Inner`
-            let missing_block = self.missing_blocks.get_mut(block_id).unwrap();
+    #[test]
+    fn greedy_mode_offer_approved() {
+        let (tracker, mut worker) = build();
 
-            match missing_block.state {
-                State::Idle {
-                    required: true,
-                    approved: true,
-                } => (),
-                State::Idle { .. } | State::Accepted(_) => continue,
-            }
+        let block_id = BlockId::try_from([0; BlockId::SIZE].as_ref()).unwrap();
 
-            // unwrap is ok because of the invariant.
-            let offer = missing_block.offers.get_mut(&client_id).unwrap();
-            match offer {
-                Offer::Available => {
-                    *offer = Offer::Proposed;
-                }
-                Offer::Proposed | Offer::Accepted => continue,
-            }
+        let (client, mut block_rx) = tracker.new_client();
+        worker.step();
 
-            return Some(*block_id);
-        }
+        client.offer(block_id, BlockOfferState::Approved);
+        worker.step();
 
-        None
+        assert_eq!(block_rx.try_recv(), Ok(block_id));
     }
 
-    fn accept_offer(&mut self, block_id: &BlockId, client_id: ClientId) -> bool {
-        let Some(missing_block) = self.missing_blocks.get_mut(block_id) else {
-            return false;
-        };
+    #[test]
+    fn greedy_mode_approve_non_existing_offer() {
+        let (tracker, mut worker) = build();
+        let block_id = BlockId::try_from([0; BlockId::SIZE].as_ref()).unwrap();
 
-        match missing_block.state {
-            State::Idle {
-                required: true,
-                approved: true,
-            } => (),
-            State::Idle { .. } | State::Accepted(_) => return false,
-        }
+        let (client, mut block_rx) = tracker.new_client();
+        worker.step();
 
-        missing_block.state = State::Accepted(client_id);
-        missing_block.offers.insert(client_id, Offer::Accepted);
+        client.approve(block_id);
+        worker.step();
 
-        true
+        assert_eq!(block_rx.try_recv(), Err(TryRecvError::Empty));
     }
 
-    fn cancel_offer(&mut self, block_id: &BlockId, client_id: ClientId) -> bool {
-        let Some(missing_block) = self.missing_blocks.get_mut(block_id) else {
-            return false;
-        };
+    #[test]
+    fn lazy_mode_offer_pending_then_approve_then_require() {
+        let (tracker, mut worker) = build();
+        let block_id = BlockId::try_from([0; BlockId::SIZE].as_ref()).unwrap();
 
-        let Entry::Occupied(mut entry) = missing_block.offers.entry(client_id) else {
-            return false;
-        };
+        tracker.set_request_mode(BlockRequestMode::Lazy);
+        worker.step();
 
-        match entry.get() {
-            Offer::Proposed => {
-                entry.insert(Offer::Available);
-            }
-            Offer::Accepted => {
-                // Cancelling an accepted offer means the request either failed or timeouted so it's
-                // safe to remove it. If the peer sends us another leaf node response with the same
-                // block id, we register the offer again.
-                entry.remove();
-                // unwrap is ok because if the client has been already destroyed then
-                // `missing_block.offers[&self.client_id]` would not exists and this function would
-                // have exited earlier.
-                self.clients.get_mut(&client_id).unwrap().remove(block_id);
-            }
-            Offer::Available => unreachable!(),
-        }
+        let (client, mut block_rx) = tracker.new_client();
+        worker.step();
 
-        missing_block.unaccept_by(client_id)
-    }
-}
+        assert_eq!(block_rx.try_recv(), Err(TryRecvError::Empty));
 
-#[derive(Debug)]
-struct MissingBlock {
-    // Clients that offered this block.
-    offers: HashMap<ClientId, Offer>,
-    state: State,
-}
+        client.offer(block_id, BlockOfferState::Pending);
+        worker.step();
 
-impl MissingBlock {
-    fn unaccept_by(&mut self, client_id: ClientId) -> bool {
-        match self.state {
-            State::Accepted(other_client_id) if other_client_id == client_id => {
-                self.state = State::Idle {
-                    required: true,
-                    approved: true,
-                };
-                true
-            }
-            State::Accepted(_) | State::Idle { ..
} => false, - } - } -} + assert_eq!(block_rx.try_recv(), Err(TryRecvError::Empty)); -#[derive(Debug)] -enum State { - Idle { required: bool, approved: bool }, - Accepted(ClientId), -} + client.approve(block_id); + worker.step(); -#[derive(Debug)] -enum Offer { - Available, - Proposed, - Accepted, -} + assert_eq!(block_rx.try_recv(), Err(TryRecvError::Empty)); -type ClientId = usize; + tracker.require(block_id); + worker.step(); -#[cfg(test)] -mod tests { - use super::*; - use crate::{collections::HashSet, protocol::Block, test_utils}; - use futures_util::future; - use rand::{distributions::Standard, rngs::StdRng, seq::SliceRandom, Rng, SeedableRng}; - use std::{pin::pin, time::Duration}; - use test_strategy::proptest; - use tokio::{select, sync::mpsc, sync::Barrier, task, time}; - - #[test] - fn lazy() { - let tracker = BlockTracker::new(); - tracker.set_request_mode(RequestMode::Lazy); - - let client = tracker.client(); - - // Initially no blocks are returned - assert!(client.offers().try_next().is_none()); - - // Offered but not required blocks are not returned - let block0: Block = rand::random(); - client.register(block0.id, OfferState::Approved); - assert!(client.offers().try_next().is_none()); - - // Required but not offered blocks are not returned - let block1: Block = rand::random(); - tracker.require(block1.id); - assert!(client.offers().try_next().is_none()); - - // Required + offered blocks are returned... - tracker.require(block0.id); - assert_eq!( - client - .offers() - .try_next() - .and_then(BlockOffer::accept) - .as_ref() - .map(BlockPromise::block_id), - Some(&block0.id) - ); - - // ...but only once. - assert!(client.offers().try_next().is_none()); + assert_eq!(block_rx.try_recv(), Ok(block_id)); + assert_eq!(block_rx.try_recv(), Err(TryRecvError::Empty)); } #[test] - fn greedy() { - let tracker = BlockTracker::new(); - tracker.set_request_mode(RequestMode::Greedy); // greedy is the default, but let's be explicit here - - let client = tracker.client(); - - // Initially no blocks are returned - assert!(client.offers().try_next().is_none()); - - // Offered blocks are immediately returned - let block0: Block = rand::random(); - client.register(block0.id, OfferState::Approved); - assert_eq!( - client - .offers() - .try_next() - .and_then(BlockOffer::accept) - .as_ref() - .map(BlockPromise::block_id), - Some(&block0.id) - ); + fn lazy_mode_offer_pending_then_require_then_approve() { + let (tracker, mut worker) = build(); + let block_id = BlockId::try_from([0; BlockId::SIZE].as_ref()).unwrap(); - // ...but only once. - assert!(client.offers().try_next().is_none()); - } - - #[tokio::test(flavor = "multi_thread")] - async fn concurrent() { - let tracker = BlockTracker::new(); - tracker.set_request_mode(RequestMode::Lazy); - - let block: Block = rand::random(); - let client = tracker.client(); - let mut offers = client.offers(); + tracker.set_request_mode(BlockRequestMode::Lazy); + worker.step(); - tracker.require(block.id); + let (client, mut block_rx) = tracker.new_client(); + worker.step(); - let (tx, mut rx) = mpsc::channel(1); + assert_eq!(block_rx.try_recv(), Err(TryRecvError::Empty)); - let handle = tokio::task::spawn(async move { - let mut next = pin!(offers.next()); + client.offer(block_id, BlockOfferState::Pending); + worker.step(); - loop { - select! { - block_offer = &mut next => { - return *block_offer.block_id(); - }, - _ = tx.send(()) => {} - } - } - }); + assert_eq!(block_rx.try_recv(), Err(TryRecvError::Empty)); - // Make sure the task started. 
- rx.recv().await.unwrap(); + tracker.require(block_id); + worker.step(); - client.register(block.id, OfferState::Approved); + assert_eq!(block_rx.try_recv(), Err(TryRecvError::Empty)); - let offered_block_id = time::timeout(Duration::from_secs(5), handle) - .await - .expect("timeout") - .unwrap(); + client.approve(block_id); + worker.step(); - assert_eq!(block.id, offered_block_id); + assert_eq!(block_rx.try_recv(), Ok(block_id)); + assert_eq!(block_rx.try_recv(), Err(TryRecvError::Empty)); } #[test] - fn fallback_on_cancel_after_accept() { - let tracker = BlockTracker::new(); - tracker.set_request_mode(RequestMode::Lazy); + fn lazy_mode_require_then_offer_pending_then_approve() { + let (tracker, mut worker) = build(); + let block_id = BlockId::try_from([0; BlockId::SIZE].as_ref()).unwrap(); - let client0 = tracker.client(); - let client1 = tracker.client(); + tracker.set_request_mode(BlockRequestMode::Lazy); + worker.step(); - let block: Block = rand::random(); + let (client, mut block_rx) = tracker.new_client(); + worker.step(); - tracker.require(block.id); - client0.register(block.id, OfferState::Approved); - client1.register(block.id, OfferState::Approved); + assert_eq!(block_rx.try_recv(), Err(TryRecvError::Empty)); - let block_promise = client0.offers().try_next().and_then(BlockOffer::accept); - assert_eq!( - block_promise.as_ref().map(BlockPromise::block_id), - Some(&block.id) - ); - assert!(client1.offers().try_next().is_none()); - - drop(block_promise); - - assert!(client0.offers().try_next().is_none()); - assert_eq!( - client1 - .offers() - .try_next() - .and_then(BlockOffer::accept) - .as_ref() - .map(BlockPromise::block_id), - Some(&block.id) - ); - } - - #[test] - fn fallback_on_client_drop_after_require_before_accept() { - let tracker = BlockTracker::new(); - tracker.set_request_mode(RequestMode::Lazy); - - let client0 = tracker.client(); - let client1 = tracker.client(); + tracker.require(block_id); + worker.step(); - let block: Block = rand::random(); + assert_eq!(block_rx.try_recv(), Err(TryRecvError::Empty)); - client0.register(block.id, OfferState::Approved); - client1.register(block.id, OfferState::Approved); + client.offer(block_id, BlockOfferState::Pending); + worker.step(); - tracker.require(block.id); + assert_eq!(block_rx.try_recv(), Err(TryRecvError::Empty)); - drop(client0); + client.approve(block_id); + worker.step(); - assert_eq!( - client1 - .offers() - .try_next() - .and_then(BlockOffer::accept) - .as_ref() - .map(BlockPromise::block_id), - Some(&block.id) - ); + assert_eq!(block_rx.try_recv(), Ok(block_id)); + assert_eq!(block_rx.try_recv(), Err(TryRecvError::Empty)); } #[test] - fn fallback_on_client_drop_after_require_after_accept() { - let tracker = BlockTracker::new(); - tracker.set_request_mode(RequestMode::Lazy); + fn lazy_mode_offer_approved_then_require() { + let (tracker, mut worker) = build(); + let block_id = BlockId::try_from([0; BlockId::SIZE].as_ref()).unwrap(); - let client0 = tracker.client(); - let client1 = tracker.client(); + tracker.set_request_mode(BlockRequestMode::Lazy); + worker.step(); - let block: Block = rand::random(); + let (client, mut block_rx) = tracker.new_client(); + worker.step(); - client0.register(block.id, OfferState::Approved); - client1.register(block.id, OfferState::Approved); + assert_eq!(block_rx.try_recv(), Err(TryRecvError::Empty)); - tracker.require(block.id); + client.offer(block_id, BlockOfferState::Approved); + worker.step(); - let block_promise = client0.offers().try_next().and_then(BlockOffer::accept); + 
assert_eq!(block_rx.try_recv(), Err(TryRecvError::Empty)); - assert_eq!( - block_promise.as_ref().map(BlockPromise::block_id), - Some(&block.id) - ); - assert!(client1.offers().try_next().is_none()); - - drop(client0); - - assert_eq!( - client1 - .offers() - .try_next() - .and_then(BlockOffer::accept) - .as_ref() - .map(BlockPromise::block_id), - Some(&block.id) - ); + tracker.require(block_id); + worker.step(); + + assert_eq!(block_rx.try_recv(), Ok(block_id)); + assert_eq!(block_rx.try_recv(), Err(TryRecvError::Empty)); } #[test] - fn fallback_on_client_drop_before_require() { - let tracker = BlockTracker::new(); - tracker.set_request_mode(RequestMode::Lazy); + fn lazy_mode_require_then_offer_approved() { + let (tracker, mut worker) = build(); + let block_id = BlockId::try_from([0; BlockId::SIZE].as_ref()).unwrap(); - let client0 = tracker.client(); - let client1 = tracker.client(); + tracker.set_request_mode(BlockRequestMode::Lazy); + worker.step(); - let block: Block = rand::random(); + let (client, mut block_rx) = tracker.new_client(); + worker.step(); - client0.register(block.id, OfferState::Approved); - client1.register(block.id, OfferState::Approved); + assert_eq!(block_rx.try_recv(), Err(TryRecvError::Empty)); - drop(client0); + tracker.require(block_id); + worker.step(); - tracker.require(block.id); + assert_eq!(block_rx.try_recv(), Err(TryRecvError::Empty)); - assert_eq!( - client1 - .offers() - .try_next() - .and_then(BlockOffer::accept) - .as_ref() - .map(BlockPromise::block_id), - Some(&block.id) - ); - } + client.offer(block_id, BlockOfferState::Approved); + worker.step(); - #[test] - fn approve() { - let tracker = BlockTracker::new(); - let client = tracker.client(); - - let block: Block = rand::random(); - client.register(block.id, OfferState::Pending); - assert!(client.offers().try_next().is_none()); - - tracker.approve(block.id); - assert_eq!( - client - .offers() - .try_next() - .and_then(BlockOffer::accept) - .as_ref() - .map(BlockPromise::block_id), - Some(&block.id) - ); + assert_eq!(block_rx.try_recv(), Ok(block_id)); + assert_eq!(block_rx.try_recv(), Err(TryRecvError::Empty)); } #[test] - fn multiple_offers_from_different_clients() { - let tracker = BlockTracker::new(); - - let client0 = tracker.client(); - let client1 = tracker.client(); + fn lazy_mode_approve_non_existing_offer() { + let (tracker, mut worker) = build(); + let block_id = BlockId::try_from([0; BlockId::SIZE].as_ref()).unwrap(); - let block: Block = rand::random(); + tracker.set_request_mode(BlockRequestMode::Lazy); + worker.step(); - client0.register(block.id, OfferState::Approved); - client1.register(block.id, OfferState::Approved); + let (client, mut block_rx) = tracker.new_client(); + worker.step(); - let offer0 = client0.offers().try_next().unwrap(); - let offer1 = client1.offers().try_next().unwrap(); + tracker.require(block_id); + client.approve(block_id); + worker.step(); - assert_eq!(offer0.block_id(), offer1.block_id()); - - let promise0 = offer0.accept(); - let promise1 = offer1.accept(); - - assert!(promise0.is_some()); - assert!(promise1.is_none()); + assert_eq!(block_rx.try_recv(), Err(TryRecvError::Empty)); } #[test] - fn multiple_offers_from_same_client() { - let tracker = BlockTracker::new(); + fn lazy_mode_offer_then_drop_client_then_require() { + let (tracker, mut worker) = build(); + let block_id = BlockId::try_from([0; BlockId::SIZE].as_ref()).unwrap(); - let client = tracker.client(); + tracker.set_request_mode(BlockRequestMode::Lazy); + worker.step(); - let block0: Block = 
rand::random();
-        client.register(block0.id, OfferState::Approved);
+        let (client, mut block_rx) = tracker.new_client();
+        worker.step();
 
-        let block1: Block = rand::random();
-        client.register(block1.id, OfferState::Approved);
+        client.offer(block_id, BlockOfferState::Approved);
+        worker.step();
 
-        let offer0 = client.offers().try_next().unwrap();
-        let offer1 = client.offers().try_next().unwrap();
-        let offer2 = client.offers().try_next();
+        drop(client);
+        tracker.require(block_id);
+        worker.step();
 
-        assert_ne!(offer0.block_id(), offer1.block_id());
-        assert!(offer2.is_none());
+        assert_eq!(block_rx.try_recv(), Err(TryRecvError::Disconnected));
     }
 
-    #[tokio::test(flavor = "multi_thread")]
-    async fn race() {
-        let num_clients = 10;
-
-        let tracker = BlockTracker::new();
-        let clients: Vec<_> = (0..num_clients).map(|_| tracker.client()).collect();
+    #[test]
+    fn switch_lazy_mode_to_greedy_mode() {
+        let (tracker, mut worker) = build();
+        let block_id_0 = BlockId::try_from([0; BlockId::SIZE].as_ref()).unwrap();
+        let block_id_1 = BlockId::try_from([1; BlockId::SIZE].as_ref()).unwrap();
 
-        let block: Block = rand::random();
+        tracker.set_request_mode(BlockRequestMode::Lazy);
+        worker.step();
 
-        for client in &clients {
-            client.register(block.id, OfferState::Approved);
-        }
+        let (client, mut block_rx) = tracker.new_client();
+        client.offer(block_id_0, BlockOfferState::Pending);
+        client.offer(block_id_1, BlockOfferState::Approved);
+        worker.step();
 
-        // Make sure all clients stay alive until we are done so that any accepted requests are not
-        // released prematurely.
-        let barrier = Arc::new(Barrier::new(clients.len()));
-
-        // Run the clients in parallel
-        let handles = clients.into_iter().map(|client| {
-            task::spawn({
-                let barrier = barrier.clone();
-                async move {
-                    let block_promise = client.offers().try_next().and_then(BlockOffer::accept);
-                    let result = block_promise.as_ref().map(BlockPromise::block_id).cloned();
-                    barrier.wait().await;
-                    result
-                }
-            })
-        });
+        assert_eq!(block_rx.try_recv(), Err(TryRecvError::Empty));
 
-        let block_ids = future::try_join_all(handles).await.unwrap();
+        tracker.set_request_mode(BlockRequestMode::Greedy);
+        worker.step();
 
-        // Exactly one client gets the block id
-        let mut block_ids = block_ids.into_iter().flatten();
-        assert_eq!(block_ids.next(), Some(block.id));
-        assert_eq!(block_ids.next(), None);
+        assert_eq!(block_rx.try_recv(), Ok(block_id_1));
+        assert_eq!(block_rx.try_recv(), Err(TryRecvError::Empty));
     }
 
-    #[proptest]
-    fn stress(
-        #[strategy(1usize..100)] num_blocks: usize,
-        #[strategy(test_utils::rng_seed_strategy())] rng_seed: u64,
-    ) {
-        stress_case(num_blocks, rng_seed)
-    }
+    #[test]
+    fn approve_affects_all_clients() {
+        let (tracker, mut worker) = build();
+        let block_id = BlockId::try_from([0; BlockId::SIZE].as_ref()).unwrap();
 
-    fn stress_case(num_blocks: usize, rng_seed: u64) {
-        let mut rng = StdRng::seed_from_u64(rng_seed);
+        let (client_a, mut block_rx_a) = tracker.new_client();
+        let (client_b, mut block_rx_b) = tracker.new_client();
 
-        let tracker = BlockTracker::new();
-        tracker.set_request_mode(RequestMode::Lazy);
+        client_a.offer(block_id, BlockOfferState::Pending);
+        client_b.offer(block_id, BlockOfferState::Pending);
 
-        let client = tracker.client();
+        worker.step();
 
-        let block_ids: Vec<BlockId> = (&mut rng).sample_iter(Standard).take(num_blocks).collect();
+        assert_eq!(block_rx_a.try_recv(), Err(TryRecvError::Empty));
+        assert_eq!(block_rx_b.try_recv(), Err(TryRecvError::Empty));
 
-        enum Op {
-            Require,
-            Register,
-        }
+        client_a.approve(block_id);
+        worker.step();
 
-        let mut ops: Vec<_> = block_ids
-            .iter()
-            .map(|block_id| (Op::Require, *block_id))
-            .chain(block_ids.iter().map(|block_id| (Op::Register, *block_id)))
-            .collect();
-        ops.shuffle(&mut rng);
-
-        for (op, block_id) in ops {
-            match op {
-                Op::Require => {
-                    tracker.require(block_id);
-                }
-                Op::Register => {
-                    client.register(block_id, OfferState::Approved);
-                }
-            }
-        }
+        assert_eq!(block_rx_a.try_recv(), Ok(block_id));
+        assert_eq!(block_rx_b.try_recv(), Ok(block_id));
+    }
 
-        let mut block_promise = HashSet::with_capacity(block_ids.len());
+    #[test]
+    fn clear_required() {
+        let (tracker, mut worker) = build();
+        let block_id = BlockId::try_from([0; BlockId::SIZE].as_ref()).unwrap();
 
-        while let Some(block_id) = client
-            .offers()
-            .try_next()
-            .and_then(BlockOffer::accept)
-            .as_ref()
-            .map(BlockPromise::block_id)
-        {
-            block_promise.insert(*block_id);
-        }
+        tracker.set_request_mode(BlockRequestMode::Lazy);
+        tracker.require(block_id);
+        tracker.clear_required();
+        worker.step();
 
-        assert_eq!(block_promise.len(), block_ids.len());
+        let (client, mut block_rx) = tracker.new_client();
+        client.offer(block_id, BlockOfferState::Approved);
+        worker.step();
 
-        for block_id in &block_ids {
-            assert!(block_promise.contains(block_id));
-        }
+        assert_eq!(block_rx.try_recv(), Err(TryRecvError::Empty));
     }
 }
diff --git a/lib/src/collections.rs b/lib/src/collections.rs
index a35e0c094..0dbcb161a 100644
--- a/lib/src/collections.rs
+++ b/lib/src/collections.rs
@@ -6,7 +6,6 @@ pub(crate) use self::{hash_map::HashMap, hash_set::HashSet};
 pub(crate) mod hash_map {
     pub use rand::RandomState;
 
-    pub use std::collections::hash_map::{Entry, /* OccupiedEntry,*/ VacantEntry};
     pub type HashMap<K, V> = std::collections::HashMap<K, V, RandomState>;
 }
diff --git a/lib/src/network/client.rs b/lib/src/network/client.rs
index 9619072eb..7fc6cff81 100644
--- a/lib/src/network/client.rs
+++ b/lib/src/network/client.rs
@@ -2,18 +2,17 @@ use super::{
     constants::RESPONSE_BATCH_SIZE,
     debug_payload::{DebugRequest, DebugResponse},
     message::{Message, Response, ResponseDisambiguator},
-    pending::{
-        EphemeralResponse, PendingRequest, PendingRequests, PersistableResponse, PreparedResponse,
-    },
+    request_tracker::{PendingRequest, RequestTracker, RequestTrackerClient},
 };
 use crate::{
-    block_tracker::{BlockPromise, TrackerClient},
+    block_tracker::{BlockOfferState, BlockTrackerClient},
     crypto::{sign::PublicKey, CacheHash, Hashable},
     error::Result,
     event::Payload,
+    network::{message::Request, request_tracker::MessageKey},
    protocol::{
-        Block, BlockId, InnerNodes, LeafNodes, MultiBlockPresence, ProofError, RootNodeFilter,
-        UntrustedProof,
+        Block, BlockId, InnerNodes, LeafNodes, MultiBlockPresence, NodeState, ProofError,
+        RootNodeFilter, UntrustedProof,
     },
     repository::Vault,
     store::{ClientReader, ClientWriter},
@@ -29,7 +28,9 @@ mod future {
 
 pub(super) struct Client {
     inner: Inner,
+    request_rx: mpsc::UnboundedReceiver<PendingRequest>,
     response_rx: mpsc::Receiver<Response>,
+    block_rx: mpsc::UnboundedReceiver<BlockId>,
 }
 
 impl Client {
@@ -37,50 +38,65 @@ impl Client {
         vault: Vault,
         message_tx: mpsc::UnboundedSender<Message>,
         response_rx: mpsc::Receiver<Response>,
+        request_tracker: &RequestTracker,
     ) -> Self {
-        let pending_requests = PendingRequests::new(vault.monitor.clone());
-        let block_tracker = vault.block_tracker.client();
+        let (request_tracker, request_rx) = request_tracker.new_client();
+        let (block_tracker, block_rx) = vault.block_tracker.new_client();
 
         let inner = Inner {
             vault,
-            pending_requests,
+            request_tracker,
            block_tracker,
            message_tx,
        };
 
-        Self { inner, response_rx }
+        Self {
+            inner,
+            request_rx,
+            response_rx,
+            block_rx,
+        }
     }
 }
 
 impl Client {
     pub async fn run(&mut self) -> Result<()> {
-        let Self { inner, response_rx } = self;
-
-        inner.run(response_rx).await
+        let Self {
+            inner,
+            request_rx,
+            response_rx,
+            block_rx,
+        } = self;
+
+        inner.run(request_rx, response_rx, block_rx).await
     }
 }
 
 struct Inner {
     vault: Vault,
-    pending_requests: PendingRequests,
-    block_tracker: TrackerClient,
+    request_tracker: RequestTrackerClient,
+    block_tracker: BlockTrackerClient,
     message_tx: mpsc::UnboundedSender<Message>,
 }
 
 impl Inner {
-    async fn run(&mut self, response_rx: &mut mpsc::Receiver<Response>) -> Result<()> {
+    async fn run(
+        &mut self,
+        request_rx: &mut mpsc::UnboundedReceiver<PendingRequest>,
+        response_rx: &mut mpsc::Receiver<Response>,
+        block_rx: &mut mpsc::UnboundedReceiver<BlockId>,
+    ) -> Result<()> {
         select! {
             result = self.handle_responses(response_rx) => result,
-            _ = self.handle_available_block_offers() => Ok(()),
+            _ = self.send_requests(request_rx) => Ok(()),
+            _ = self.request_blocks(block_rx) => Ok(()),
             _ = self.handle_reload_index() => Ok(()),
         }
     }
 
-    fn send_request(&self, request: PendingRequest) {
-        if let Some(request) = self.pending_requests.insert(request) {
-            self.message_tx
-                .send(Message::Request(request))
-                .unwrap_or(());
+    async fn send_requests(&self, request_rx: &mut mpsc::UnboundedReceiver<PendingRequest>) {
+        while let Some(PendingRequest { request, .. }) = request_rx.recv().await {
+            self.message_tx.send(Message::Request(request)).ok();
         }
     }
 
@@ -92,31 +108,39 @@
         for response in recv_iter(rx).await {
             self.vault.monitor.responses_received.increment(1);
 
-            let response = self.pending_requests.remove(response);
-
             match response {
-                PreparedResponse::RootNode(proof, block_presence, debug) => {
+                Response::RootNode(proof, block_presence, debug) => {
                     persistable.push(PersistableResponse::RootNode(
                         proof,
                         block_presence,
                         debug,
                     ));
                 }
-                PreparedResponse::InnerNodes(nodes, _, debug) => {
-                    persistable.push(PersistableResponse::InnerNodes(nodes, debug));
+                Response::InnerNodes(nodes, _, debug) => {
+                    persistable.push(PersistableResponse::InnerNodes(nodes.into(), debug));
                 }
-                PreparedResponse::LeafNodes(nodes, _, debug) => {
-                    persistable.push(PersistableResponse::LeafNodes(nodes, debug));
+                Response::LeafNodes(nodes, _, debug) => {
+                    persistable.push(PersistableResponse::LeafNodes(nodes.into(), debug));
                 }
-                PreparedResponse::Block(block, block_promise, debug) => {
-                    persistable.push(PersistableResponse::Block(block, block_promise, debug));
+                Response::Block(block_content, block_nonce, debug) => {
+                    persistable.push(PersistableResponse::Block(
+                        Block::new(block_content, block_nonce),
+                        debug,
+                    ));
                 }
-                PreparedResponse::BlockOffer(block_id, debug) => {
+                Response::BlockOffer(block_id, debug) => {
                     ephemeral.push(EphemeralResponse::BlockOffer(block_id, debug));
                 }
-                PreparedResponse::RootNodeError(..)
-                | PreparedResponse::ChildNodesError(..)
-                | PreparedResponse::BlockError(..) => (),
+                Response::RootNodeError(writer_id, _) => {
+                    self.request_tracker
+                        .failure(MessageKey::RootNode(writer_id));
+                }
+                Response::ChildNodesError(hash, _, _) => {
+                    self.request_tracker.failure(MessageKey::ChildNodes(hash));
+                }
+                Response::BlockError(block_id, _) => {
+                    self.request_tracker.failure(MessageKey::Block(block_id));
+                }
             }
 
             if ephemeral.len() >= RESPONSE_BATCH_SIZE {
@@ -158,9 +182,8 @@
                 PersistableResponse::LeafNodes(nodes, debug) => {
                     self.handle_leaf_nodes(&mut writer, nodes, debug).await?;
                 }
-                PersistableResponse::Block(block, block_promise, debug) => {
-                    self.handle_block(&mut writer, block, block_promise, debug)
-                        .await?;
+                PersistableResponse::Block(block, debug) => {
+                    self.handle_block(&mut writer, block, debug).await?;
                 }
             }
         }
@@ -221,16 +244,23 @@
         }
 
         let hash = proof.hash;
+        let writer_id = proof.writer_id;
         let status = writer.save_root_node(proof, &block_presence).await?;
 
         tracing::debug!("Received root node - {status}");
 
         if status.request_children() {
-            self.send_request(PendingRequest::ChildNodes(
-                hash,
-                ResponseDisambiguator::new(block_presence),
-                debug_payload.follow_up(),
-            ));
+            self.request_tracker.success(
+                MessageKey::RootNode(writer_id),
+                vec![PendingRequest {
+                    request: Request::ChildNodes(
+                        hash,
+                        ResponseDisambiguator::new(block_presence),
+                        debug_payload.follow_up(),
+                    ),
+                    block_presence,
+                }],
+            );
         }
 
         Ok(())
@@ -243,6 +273,7 @@
         nodes: CacheHash<InnerNodes>,
         debug_payload: DebugResponse,
     ) -> Result<()> {
+        let hash = nodes.hash();
         let total = nodes.len();
         let status = writer.save_inner_nodes(nodes).await?;
 
@@ -252,13 +283,21 @@
             total
         );
 
-        for node in status.new_children {
-            self.send_request(PendingRequest::ChildNodes(
-                node.hash,
-                ResponseDisambiguator::new(node.summary.block_presence),
-                debug_payload.follow_up(),
-            ));
-        }
+        self.request_tracker.success(
+            MessageKey::ChildNodes(hash),
+            status
+                .new_children
+                .into_iter()
+                .map(|node| PendingRequest {
+                    request: Request::ChildNodes(
+                        node.hash,
+                        ResponseDisambiguator::new(node.summary.block_presence),
+                        debug_payload.follow_up(),
+                    ),
+                    block_presence: node.summary.block_presence,
+                })
+                .collect(),
+        );
 
         Ok(())
     }
@@ -270,6 +309,7 @@
         nodes: CacheHash<LeafNodes>,
         debug_payload: DebugResponse,
     ) -> Result<()> {
+        let hash = nodes.hash();
         let total = nodes.len();
         let status = writer.save_leaf_nodes(nodes).await?;
 
@@ -280,9 +320,14 @@
         );
 
         for (block_id, state) in status.new_block_offers {
-            self.block_tracker.register(block_id, state);
+            if let Some(state) = block_offer_state(state) {
+                self.block_tracker.offer(block_id, state);
+            }
         }
 
+        self.request_tracker
+            .success(MessageKey::ChildNodes(hash), Vec::new());
+
         Ok(())
     }
 
@@ -293,13 +338,13 @@
         block_id: BlockId,
         debug_payload: DebugResponse,
     ) -> Result<()> {
-        let Some(offer_state) = reader.load_block_offer_state(&block_id).await? else {
-            return Ok(());
-        };
-
-        tracing::trace!(?offer_state, "Received block offer");
+        let root_node_state = reader
+            .load_effective_root_node_state_for_block(&block_id)
+            .await?;
 
-        self.block_tracker.register(block_id, offer_state);
+        if let Some(offer_state) = block_offer_state(root_node_state) {
+            self.block_tracker.offer(block_id, offer_state);
+        }
 
         Ok(())
     }
@@ -309,13 +354,15 @@
         &self,
         writer: &mut ClientWriter,
         block: Block,
-        block_promise: Option<BlockPromise>,
         debug_payload: DebugResponse,
     ) -> Result<()> {
-        writer.save_block(&block, block_promise).await?;
+        writer.save_block(&block).await?;
 
         tracing::trace!("Received block");
 
+        self.request_tracker
+            .success(MessageKey::Block(block.id), vec![]);
+
         Ok(())
     }
 
@@ -344,7 +391,7 @@
 
         // Approve pending block offers referenced from the newly approved snapshots.
         for block_id in status.approved_missing_blocks {
-            self.vault.block_tracker.approve(block_id);
+            self.block_tracker.approve(block_id);
         }
 
         self.refresh_branches(status.approved_branches.iter().copied());
@@ -353,12 +400,10 @@
         Ok(())
     }
 
-    async fn handle_available_block_offers(&self) {
-        let mut block_offers = self.block_tracker.offers();
-
-        loop {
-            let block_offer = block_offers.next().await;
-            self.send_request(PendingRequest::Block(block_offer, DebugRequest::start()));
+    async fn request_blocks(&self, block_rx: &mut mpsc::UnboundedReceiver<BlockId>) {
+        while let Some(block_id) = block_rx.recv().await {
+            self.request_tracker
+                .initial(Request::Block(block_id, DebugRequest::start()));
         }
     }
 
@@ -391,7 +436,8 @@
     // requested as soon as possible.
     fn refresh_branches(&self, branches: impl IntoIterator<Item = PublicKey>) {
        for branch_id in branches {
-            self.send_request(PendingRequest::RootNode(branch_id, DebugRequest::start()));
+            self.request_tracker
+                .initial(Request::RootNode(branch_id, DebugRequest::start()));
         }
     }
 
@@ -436,6 +482,19 @@
     }
 }
 
+/// Response whose processing requires only read access to the store or no access at all.
+enum EphemeralResponse {
+    BlockOffer(BlockId, DebugResponse),
+}
+
+/// Response whose processing requires write access to the store.
+enum PersistableResponse {
+    RootNode(UntrustedProof, MultiBlockPresence, DebugResponse),
+    InnerNodes(CacheHash<InnerNodes>, DebugResponse),
+    LeafNodes(CacheHash<LeafNodes>, DebugResponse),
+    Block(Block, DebugResponse),
+}
+
 /// Waits for at least one item to become available (or the channel getting closed) and then yields
 /// all the buffered items from the channel.
async fn recv_iter(rx: &mut mpsc::Receiver<Response>) -> impl Iterator<Item = Response> + '_ {
@@ -445,12 +504,20 @@
         .chain(iter::from_fn(|| rx.try_recv().ok()))
 }
 
+fn block_offer_state(root_node_state: NodeState) -> Option<BlockOfferState> {
+    match root_node_state {
+        NodeState::Approved => Some(BlockOfferState::Approved),
+        NodeState::Complete | NodeState::Incomplete => Some(BlockOfferState::Pending),
+        NodeState::Rejected => None,
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
     use crate::{
         access_control::WriteSecrets,
-        block_tracker::RequestMode,
+        block_tracker::BlockRequestMode,
         crypto::sign::Keypair,
         db,
         event::EventSender,
@@ -545,16 +612,17 @@
             RepositoryMonitor::new(StateMonitor::make_root(), &NoopRecorder),
         );
 
-        vault.block_tracker.set_request_mode(RequestMode::Lazy);
+        vault.block_tracker.set_request_mode(BlockRequestMode::Lazy);
 
-        let pending_requests = PendingRequests::new(vault.monitor.clone());
-        let block_tracker = vault.block_tracker.client();
+        let request_tracker = RequestTracker::new();
+        let (request_tracker, _request_rx) = request_tracker.new_client();
+        let (block_tracker, _block_rx) = vault.block_tracker.new_client();
 
         let (message_tx, _message_rx) = mpsc::unbounded_channel();
 
         let inner = Inner {
             vault,
-            pending_requests,
+            request_tracker,
             block_tracker,
             message_tx,
         };
diff --git a/lib/src/network/connection.rs b/lib/src/network/connection.rs
index e718f1afe..700e6c4d6 100644
--- a/lib/src/network/connection.rs
+++ b/lib/src/network/connection.rs
@@ -7,11 +7,12 @@ use super::{
     stats::{ByteCounters, StatsTracker},
 };
 use crate::{
-    collections::{hash_map::Entry, HashMap},
+    collections::HashMap,
     sync::{AwaitDrop, DropAwaitable, WatchSenderExt},
 };
 use serde::Serialize;
 use std::{
+    collections::hash_map::Entry,
     fmt,
     sync::{
         atomic::{AtomicU64, Ordering},
diff --git a/lib/src/network/local_discovery.rs b/lib/src/network/local_discovery.rs
index a33ecba77..96cdf14c3 100644
--- a/lib/src/network/local_discovery.rs
+++ b/lib/src/network/local_discovery.rs
@@ -108,7 +108,7 @@ struct LocalDiscoveryInner {
 
 impl LocalDiscoveryInner {
     fn add(&mut self, interface: IpAddr, parent_monitor: &StateMonitor) {
-        use crate::collections::hash_map::Entry;
+        use std::collections::hash_map::Entry;
 
         if interface.is_loopback() {
             return;
diff --git a/lib/src/network/message_broker.rs b/lib/src/network/message_broker.rs
index 5daa3e906..bfc2a9b09 100644
--- a/lib/src/network/message_broker.rs
+++ b/lib/src/network/message_broker.rs
@@ -4,12 +4,13 @@ use super::{
     message::{Message, Request, Response},
     message_dispatcher::{MessageDispatcher, MessageSink, MessageStream},
     peer_exchange::{PexPeer, PexReceiver, PexRepository, PexSender},
+    request_tracker::RequestTracker,
     runtime_id::PublicRuntimeId,
     server::Server,
     stats::ByteCounters,
 };
 use crate::{
-    collections::{hash_map::Entry, HashMap},
+    collections::HashMap,
     crypto::Hashable,
     network::constants::{REQUEST_BUFFER_SIZE, RESPONSE_BUFFER_SIZE},
     protocol::RepositoryId,
@@ -20,7 +21,7 @@
 use bytes::{BufMut, BytesMut};
 use futures_util::{SinkExt, StreamExt};
 use net::{bus::TopicId, unified::Connection};
 use state_monitor::StateMonitor;
-use std::{sync::Arc, time::Instant};
+use std::{collections::hash_map::Entry, sync::Arc, time::Instant};
 use tokio::{
     select,
     sync::{
@@ -44,7 +45,6 @@ pub(super) struct MessageBroker {
 }
 
 impl MessageBroker {
-    #[allow(clippy::too_many_arguments)]
     pub fn new(
         this_runtime_id: PublicRuntimeId,
         that_runtime_id: PublicRuntimeId,
@@ -77,6 +77,7 @@ impl MessageBroker {
         &mut self,
vault: Vault,
         pex_repo: &PexRepository,
+        request_tracker: RequestTracker,
         response_limiter: Arc<Semaphore>,
         repo_counters: Arc<ByteCounters>,
     ) {
@@ -124,6 +125,7 @@
             topic_id,
             dispatcher: self.dispatcher.clone(),
             vault,
+            request_tracker,
             response_limiter,
             pex_tx,
             pex_rx,
@@ -199,6 +201,7 @@ struct Link {
     topic_id: TopicId,
     dispatcher: MessageDispatcher,
     vault: Vault,
+    request_tracker: RequestTracker,
     response_limiter: Arc<Semaphore>,
     pex_tx: PexSender,
     pex_rx: PexReceiver,
@@ -256,6 +259,7 @@ impl Link {
                 crypto_stream,
                 crypto_sink,
                 &self.vault,
+                &self.request_tracker,
                 self.response_limiter.clone(),
                 &mut self.pex_tx,
                 &mut self.pex_rx,
@@ -290,7 +294,8 @@ async fn establish_channel<'a>(
 async fn run_link(
     stream: DecryptingStream<'_>,
     sink: EncryptingSink<'_>,
-    repo: &Vault,
+    vault: &Vault,
+    request_tracker: &RequestTracker,
     response_limiter: Arc<Semaphore>,
     pex_tx: &mut PexSender,
     pex_rx: &mut PexReceiver,
@@ -305,8 +310,8 @@
     let _guard = LinkGuard::new();
 
     select! {
-        _ = run_client(repo.clone(), message_tx.clone(), response_rx) => (),
-        _ = run_server(repo.clone(), message_tx.clone(), request_rx, response_limiter) => (),
+        _ = run_client(vault.clone(), message_tx.clone(), response_rx, request_tracker) => (),
+        _ = run_server(vault.clone(), message_tx.clone(), request_rx, response_limiter) => (),
         _ = recv_messages(stream, request_tx, response_tx, pex_rx) => (),
         _ = send_messages(message_rx, sink) => (),
         _ = pex_tx.run(message_tx) => (),
@@ -408,11 +413,12 @@
 
 // Create and run client. Returns only on error.
 async fn run_client(
-    repo: Vault,
+    vault: Vault,
     message_tx: mpsc::UnboundedSender<Message>,
     response_rx: mpsc::Receiver<Response>,
+    request_tracker: &RequestTracker,
 ) {
-    let mut client = Client::new(repo, message_tx, response_rx);
+    let mut client = Client::new(vault, message_tx, response_rx, request_tracker);
 
     let result = client.run().await;
     tracing::debug!("Client stopped running with result {:?}", result);
 }
 
 // Create and run server. Returns only on error.
async fn run_server(
-    repo: Vault,
+    vault: Vault,
     message_tx: mpsc::UnboundedSender<Message>,
     request_rx: mpsc::Receiver<Request>,
     response_limiter: Arc<Semaphore>,
 ) {
-    let mut server = Server::new(repo, message_tx, request_rx, response_limiter);
+    let mut server = Server::new(vault, message_tx, request_rx, response_limiter);
 
     let result = server.run().await;
diff --git a/lib/src/network/mod.rs b/lib/src/network/mod.rs
index 4c0d124fd..92e956ab2 100644
--- a/lib/src/network/mod.rs
+++ b/lib/src/network/mod.rs
@@ -16,7 +16,6 @@ mod peer_exchange;
 mod peer_info;
 mod peer_source;
 mod peer_state;
-mod pending;
 mod protocol;
 mod request_tracker;
 mod runtime_id;
@@ -40,6 +39,7 @@ pub use self::{
     stats::Stats,
 };
 pub use net::stun::NatBehavior;
+use request_tracker::RequestTracker;
 
 use self::{
     connection::{ConnectionPermit, ConnectionSet, ReserveResult},
@@ -355,6 +355,7 @@ impl Network {
         let pex = self.inner.pex_discovery.new_repository();
         pex.set_enabled(pex_enabled);
 
+        let request_tracker = RequestTracker::new(); // TODO: This should be global, not per repo
         let response_limiter = Arc::new(Semaphore::new(MAX_UNCHOKED_COUNT));
         let stats_tracker = StatsTracker::default();
 
@@ -364,7 +365,8 @@
         registry.create_link(
             handle.vault.clone(),
             &pex,
-            response_limiter.clone(),
+            &request_tracker,
+            &response_limiter,
             stats_tracker.bytes.clone(),
         );
 
@@ -372,6 +374,7 @@
             vault: handle.vault,
             dht,
             pex,
+            request_tracker,
             response_limiter,
             stats_tracker,
         });
@@ -482,6 +485,7 @@ struct RegistrationHolder {
     vault: Vault,
     dht: Option,
     pex: PexRepository,
+    request_tracker: RequestTracker,
     response_limiter: Arc<Semaphore>,
     stats_tracker: StatsTracker,
 }
@@ -524,7 +528,8 @@ impl Registry {
         &mut self,
         repo: Vault,
         pex: &PexRepository,
-        response_limiter: Arc<Semaphore>,
+        request_tracker: &RequestTracker,
+        response_limiter: &Arc<Semaphore>,
         byte_counters: Arc<ByteCounters>,
     ) {
         if let Some(peers) = &mut self.peers {
@@ -532,6 +537,7 @@
             peer.create_link(
                 repo.clone(),
                 pex,
+                request_tracker.clone(),
                 response_limiter.clone(),
                 byte_counters.clone(),
             )
@@ -921,6 +927,7 @@ impl Inner {
             peer.create_link(
                 holder.vault.clone(),
                 &holder.pex,
+                holder.request_tracker.clone(),
                 holder.response_limiter.clone(),
                 holder.stats_tracker.bytes.clone(),
             );
diff --git a/lib/src/network/pending.rs b/lib/src/network/pending.rs
deleted file mode 100644
index ce62c39ec..000000000
--- a/lib/src/network/pending.rs
+++ /dev/null
@@ -1,303 +0,0 @@
-use super::{
-    constants::REQUEST_TIMEOUT,
-    debug_payload::{DebugRequest, DebugResponse},
-    message::{Request, Response, ResponseDisambiguator},
-};
-use crate::{
-    block_tracker::{BlockOffer, BlockPromise},
-    collections::HashMap,
-    crypto::{sign::PublicKey, CacheHash, Hash, Hashable},
-    protocol::{Block, BlockId, InnerNodes, LeafNodes, MultiBlockPresence, UntrustedProof},
-    repository::RepositoryMonitor,
-    sync::delay_map::DelayMap,
-};
-use deadlock::BlockingMutex;
-use scoped_task::ScopedJoinHandle;
-use std::{collections::hash_map::Entry, future, sync::Arc, task::ready};
-use std::{task::Poll, time::Instant};
-use tokio::sync::Notify;
-
-pub(crate) enum PendingRequest {
-    RootNode(PublicKey, DebugRequest),
-    ChildNodes(Hash, ResponseDisambiguator, DebugRequest),
-    Block(BlockOffer, DebugRequest),
-}
-
-/// Response that's been prepared for processing.
-pub(super) enum PreparedResponse {
-    RootNode(UntrustedProof, MultiBlockPresence, DebugResponse),
-    InnerNodes(CacheHash<InnerNodes>, ResponseDisambiguator, DebugResponse),
-    LeafNodes(CacheHash<LeafNodes>, ResponseDisambiguator, DebugResponse),
-    BlockOffer(BlockId, DebugResponse),
-    // The `BlockPromise` is `None` if the request timeouted but we still received the response
-    // afterwards.
-    Block(Block, Option<BlockPromise>, DebugResponse),
-    RootNodeError(PublicKey, DebugResponse),
-    ChildNodesError(Hash, ResponseDisambiguator, DebugResponse),
-    BlockError(BlockId, DebugResponse),
-}
-
-impl From<Response> for PreparedResponse {
-    fn from(response: Response) -> Self {
-        match response {
-            Response::RootNode(proof, block_presence, debug) => {
-                Self::RootNode(proof, block_presence, debug)
-            }
-            Response::InnerNodes(nodes, disambiguator, debug) => {
-                Self::InnerNodes(nodes.into(), disambiguator, debug)
-            }
-            Response::LeafNodes(nodes, disambiguator, debug) => {
-                Self::LeafNodes(nodes.into(), disambiguator, debug)
-            }
-            Response::BlockOffer(block_id, debug) => Self::BlockOffer(block_id, debug),
-            Response::Block(content, nonce, debug) => {
-                Self::Block(Block::new(content, nonce), None, debug)
-            }
-            Response::RootNodeError(writer_id, debug) => Self::RootNodeError(writer_id, debug),
-            Response::ChildNodesError(hash, disambiguator, debug) => {
-                Self::ChildNodesError(hash, disambiguator, debug)
-            }
-            Response::BlockError(block_id, debug) => Self::BlockError(block_id, debug),
-        }
-    }
-}
-
-/// Response whose processing requires only read access to the store or no access at all.
-pub(super) enum EphemeralResponse {
-    BlockOffer(BlockId, DebugResponse),
-}
-
-/// Response whose processing requires write access to the store.
-pub(super) enum PersistableResponse {
-    RootNode(UntrustedProof, MultiBlockPresence, DebugResponse),
-    InnerNodes(CacheHash<InnerNodes>, DebugResponse),
-    LeafNodes(CacheHash<LeafNodes>, DebugResponse),
-    Block(Block, Option<BlockPromise>, DebugResponse),
-}
-
-#[derive(Clone, Copy, Eq, PartialEq, Hash, Debug)]
-pub(crate) enum IndexKey {
-    RootNode(PublicKey),
-    ChildNodes(Hash, ResponseDisambiguator),
-}
-
-/// Tracks sent requests whose responses have not been received yet.
-///
-/// This has multiple purposes:
-///
-/// - To prevent sending duplicate requests.
-/// - To know how many requests are in flight which in turn is used to indicate activity.
-/// - To track round trip time / latency.
-/// - To timeout block requests so that we can send them to other peers instead.
-///
-/// Note that only block requests are currently timeouted. This is because we currently send block
-/// request to only one peer at a time. So if this peer was faulty, without the timeout it could
-/// prevent us from receiving that block indefinitely. Index requests, on the other hand, are
-/// currently sent to all peers so even with some of the peers being faulty, the responses from the
-/// non-faulty ones should eventually be received.
-pub(super) struct PendingRequests {
-    monitor: Arc<RepositoryMonitor>,
-    index: PendingIndexRequests,
-    block: Arc<PendingBlockRequests>,
-    // This is to ensure the `run_expiration_tracker` task is destroyed with PendingRequests (as
-    // opposed to the task being destroyed "sometime after"). This is important because the task
-    // holds an Arc to the RepositoryMonitor which must be destroyed prior to reimporting its
-    // corresponding repository if the user decides to do so.
- _expiration_tracker_task: ScopedJoinHandle<()>,
-}
-
-impl PendingRequests {
-    pub fn new(monitor: Arc<RepositoryMonitor>) -> Self {
-        let index = PendingIndexRequests::default();
-        let block = Arc::new(PendingBlockRequests::default());
-
-        Self {
-            monitor: monitor.clone(),
-            index,
-            block: block.clone(),
-            _expiration_tracker_task: scoped_task::spawn(run_expiration_tracker(monitor, block)),
-        }
-    }
-
-    pub fn insert(&self, pending_request: PendingRequest) -> Option<Request> {
-        let request = match pending_request {
-            PendingRequest::RootNode(writer_id, debug) => self
-                .index
-                .try_insert(IndexKey::RootNode(writer_id))
-                .then_some(Request::RootNode(writer_id, debug))?,
-            PendingRequest::ChildNodes(hash, disambiguator, debug) => self
-                .index
-                .try_insert(IndexKey::ChildNodes(hash, disambiguator))
-                .then_some(Request::ChildNodes(hash, disambiguator, debug))?,
-            PendingRequest::Block(block_offer, debug) => {
-                let block_promise = block_offer.accept()?;
-                let block_id = *block_promise.block_id();
-                self.block
-                    .try_insert(block_promise)
-                    .then_some(Request::Block(block_id, debug))?
-            }
-        };
-
-        match request {
-            Request::RootNode(..) | Request::ChildNodes(..) => {
-                self.monitor.index_requests_sent.increment(1);
-                self.monitor.index_requests_inflight.increment(1.0);
-            }
-            Request::Block(..) => {
-                self.monitor.block_requests_sent.increment(1);
-                self.monitor.block_requests_inflight.increment(1.0);
-            }
-        }
-
-        Some(request)
-    }
-
-    pub fn remove(&self, response: Response) -> PreparedResponse {
-        let mut response = PreparedResponse::from(response);
-
-        enum ResponseKind {
-            Index,
-            Block,
-        }
-
-        let status = match &mut response {
-            PreparedResponse::RootNode(proof, ..) => self
-                .index
-                .remove(&IndexKey::RootNode(proof.writer_id))
-                .map(|timestamp| (timestamp, ResponseKind::Index)),
-            PreparedResponse::RootNodeError(writer_id, ..) => self
-                .index
-                .remove(&IndexKey::RootNode(*writer_id))
-                .map(|timestamp| (timestamp, ResponseKind::Index)),
-            PreparedResponse::InnerNodes(nodes, disambiguator, ..) => self
-                .index
-                .remove(&IndexKey::ChildNodes(nodes.hash(), *disambiguator))
-                .map(|timestamp| (timestamp, ResponseKind::Index)),
-            PreparedResponse::LeafNodes(nodes, disambiguator, ..) => self
-                .index
-                .remove(&IndexKey::ChildNodes(nodes.hash(), *disambiguator))
-                .map(|timestamp| (timestamp, ResponseKind::Index)),
-            PreparedResponse::ChildNodesError(hash, disambiguator, ..) => self
-                .index
-                .remove(&IndexKey::ChildNodes(*hash, *disambiguator))
-                .map(|timestamp| (timestamp, ResponseKind::Index)),
-            PreparedResponse::Block(block, block_promise, ..) => {
-                self.block
-                    .remove(&block.id)
-                    .map(|(timestamp, new_block_promise)| {
-                        *block_promise = Some(new_block_promise);
-                        (timestamp, ResponseKind::Block)
-                    })
-            }
-            PreparedResponse::BlockError(block_id, ..) => self
-                .block
-                .remove(block_id)
-                .map(|(timestamp, _)| (timestamp, ResponseKind::Block)),
-            PreparedResponse::BlockOffer(..) => None,
-        };
-
-        if let Some((timestamp, kind)) = status {
-            self.monitor.request_latency.record(timestamp.elapsed());
-
-            match kind {
-                ResponseKind::Index => self.monitor.index_requests_inflight.decrement(1.0),
-                ResponseKind::Block => self.monitor.block_requests_inflight.decrement(1.0),
-            }
-        }
-
-        response
-    }
-}
-
-impl Drop for PendingRequests {
-    fn drop(&mut self) {
-        self.monitor
-            .index_requests_inflight
-            .decrement(self.index.map.lock().unwrap().len() as f64);
-        self.monitor
-            .block_requests_inflight
-            .decrement(self.block.map.lock().unwrap().len() as f64);
-    }
-}
-
-#[derive(Default)]
-struct PendingIndexRequests {
-    map: BlockingMutex<HashMap<IndexKey, Instant>>,
-}
-
-impl PendingIndexRequests {
-    fn try_insert(&self, key: IndexKey) -> bool {
-        match self.map.lock().unwrap().entry(key) {
-            Entry::Vacant(entry) => {
-                entry.insert(Instant::now());
-                true
-            }
-            Entry::Occupied(_) => false,
-        }
-    }
-
-    fn remove(&self, key: &IndexKey) -> Option<Instant> {
-        self.map.lock().unwrap().remove(key)
-    }
-}
-
-#[derive(Default)]
-struct PendingBlockRequests {
-    map: BlockingMutex<DelayMap<BlockId, (Instant, BlockPromise)>>,
-    // Notify when item is inserted into previously empty map. This restarts the expiration tracker
-    // task.
-    notify: Notify,
-}
-
-impl PendingBlockRequests {
-    fn try_insert(&self, block_promise: BlockPromise) -> bool {
-        let mut map = self.map.lock().unwrap();
-
-        if let Some(entry) = map.try_insert(*block_promise.block_id()) {
-            entry.insert((Instant::now(), block_promise), REQUEST_TIMEOUT);
-
-            if map.len() == 1 {
-                drop(map);
-                self.notify.notify_waiters();
-            }
-
-            true
-        } else {
-            false
-        }
-    }
-
-    fn remove(&self, block_id: &BlockId) -> Option<(Instant, BlockPromise)> {
-        self.map.lock().unwrap().remove(block_id)
-    }
-}
-
-async fn run_expiration_tracker(
-    monitor: Arc<RepositoryMonitor>,
-    pending: Arc<PendingBlockRequests>,
-) {
-    // NOTE: The `expired` fn does not always complete when the last item is removed from the
-    // DelayMap. There is an issue in the DelayQueue used by DelayMap, reported here:
-    // https://github.com/tokio-rs/tokio/issues/6751
-
-    loop {
-        let notified = pending.notify.notified();
-
-        while expired(&pending.map).await {
-            monitor.request_timeouts.increment(1);
-            monitor.block_requests_inflight.decrement(1.0);
-        }
-
-        // Last item removed from the map. Wait until new item added.
-        notified.await;
-    }
-}
-
-// Wait for the next expired request. This does not block the map so it can be inserted / removed
-// from while this is being awaited.
-// Returns `true` if a request expired and `false` if there are no more pending requests.
-async fn expired(map: &BlockingMutex<DelayMap<BlockId, (Instant, BlockPromise)>>) -> bool {
-    future::poll_fn(|cx| Poll::Ready(ready!(map.lock().unwrap().poll_expired(cx))))
-        .await
-        .is_some()
-}
diff --git a/lib/src/network/request_tracker.rs b/lib/src/network/request_tracker.rs
index 1dbf820d9..c298dd5ee 100644
--- a/lib/src/network/request_tracker.rs
+++ b/lib/src/network/request_tracker.rs
@@ -25,20 +25,24 @@
 /// Keeps track of in-flight requests. Falls back on another peer in case the request failed (due to
 /// error response, timeout or disconnection). Evenly distributes the requests between the peers
 /// and ensures every request is only sent to one peer at a time.
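+///
+/// A rough usage sketch (illustrative only; `request` stands for any `Request` value and
+/// `message_tx` for the peer's message channel, mirroring how `Client::send_requests` drains
+/// the receiver):
+///
+/// ```ignore
+/// let tracker = RequestTracker::new();
+/// let (client, mut request_rx) = tracker.new_client();
+///
+/// // Announce an initial request; the tracker decides when this client may send it.
+/// client.initial(request);
+///
+/// // Requests granted to this client arrive on the channel and get sent to the peer.
+/// while let Some(PendingRequest { request, .. }) = request_rx.recv().await {
+///     message_tx.send(Message::Request(request)).ok();
+/// }
+/// ```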
+#[derive(Clone)]
 pub(super) struct RequestTracker {
     command_tx: mpsc::UnboundedSender<Command>,
 }
 
 impl RequestTracker {
-    #[expect(dead_code)]
     pub fn new() -> Self {
         let (this, worker) = build();
         task::spawn(worker.run());
         this
     }
 
-    #[cfg_attr(not(test), expect(dead_code))]
-    pub fn new_client(&self) -> (RequestTrackerClient, mpsc::UnboundedReceiver<SendPermit>) {
+    pub fn new_client(
+        &self,
+    ) -> (
+        RequestTrackerClient,
+        mpsc::UnboundedReceiver<PendingRequest>,
+    ) {
         let client_id = ClientId::next();
         let (request_tx, request_rx) = mpsc::unbounded_channel();
 
@@ -66,20 +70,17 @@ pub(super) struct RequestTrackerClient {
 
 impl RequestTrackerClient {
     /// Handle sending a request that does not follow from any previously received response.
-    #[expect(dead_code)]
-    pub fn initial(&self, request: Request, block_presence: MultiBlockPresence) {
+    pub fn initial(&self, request: Request) {
         self.command_tx
             .send(Command::HandleInitial {
                 client_id: self.client_id,
                 request,
-                block_presence,
             })
             .ok();
     }
 
     /// Handle sending requests that follow from a received success response.
-    #[cfg_attr(not(test), expect(dead_code))]
-    pub fn success(&self, request_key: MessageKey, requests: Vec<(Request, MultiBlockPresence)>) {
+    pub fn success(&self, request_key: MessageKey, requests: Vec<PendingRequest>) {
         self.command_tx
             .send(Command::HandleSuccess {
                 client_id: self.client_id,
@@ -90,7 +91,6 @@
     }
 
     /// Handle failure response.
-    #[cfg_attr(not(test), expect(dead_code))]
     pub fn failure(&self, request_key: MessageKey) {
         self.command_tx
             .send(Command::HandleFailure {
@@ -99,6 +99,18 @@
             })
             .ok();
     }
+
+    /// Commit all successfully completed requests.
+    ///
+    /// If this client is dropped before this is called, all requests successfully completed by this
+    /// client will be considered failed and will be made available for retry by other clients.
+    pub fn commit(&self) {
+        self.command_tx
+            .send(Command::Commit {
+                client_id: self.client_id,
+            })
+            .ok();
+    }
 }
 
 impl Drop for RequestTrackerClient {
@@ -111,14 +123,13 @@ impl Drop for RequestTrackerClient {
     }
 }
 
-/// Permit to send the specified request. Contains also the block presence as reported by the peer
-/// who sent the response that triggered this request. That is mostly useful for diagnostics and
-/// testing.
-#[derive(Debug)]
-pub(super) struct SendPermit {
-    #[cfg_attr(not(test), expect(dead_code))]
+/// Request to be sent to the peer.
+///
+/// It also contains the block presence from the response that triggered this request. This is
+/// mostly useful for diagnostics and testing.
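+///
+/// For example, the follow-up requests produced from a root node response are built roughly
+/// like this (sketch; `hash`, `block_presence` and `debug_payload` are placeholders, cf.
+/// `Client::handle_root_node`):
+///
+/// ```ignore
+/// PendingRequest {
+///     request: Request::ChildNodes(
+///         hash,
+///         ResponseDisambiguator::new(block_presence),
+///         debug_payload.follow_up(),
+///     ),
+///     block_presence,
+/// }
+/// ```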
+#[derive(Clone, Debug)]
+pub(super) struct PendingRequest {
     pub request: Request,
-    #[cfg_attr(not(test), expect(dead_code))]
     pub block_presence: MultiBlockPresence,
 }
 
@@ -204,12 +215,8 @@ impl Worker {
             Command::RemoveClient { client_id } => {
                 self.remove_client(client_id);
             }
-            Command::HandleInitial {
-                client_id,
-                request,
-                block_presence,
-            } => {
-                self.handle_initial(client_id, request, block_presence);
+            Command::HandleInitial { client_id, request } => {
+                self.handle_initial(client_id, request);
             }
             Command::HandleSuccess {
                 client_id,
@@ -224,6 +231,9 @@
             } => {
                 self.handle_failure(client_id, request_key, FailureReason::Response);
             }
+            Command::Commit { client_id } => {
+                self.commit(client_id);
+            }
         }
     }
 
@@ -231,18 +241,16 @@
     fn insert_client(
         &mut self,
         client_id: ClientId,
-        request_tx: mpsc::UnboundedSender<SendPermit>,
+        request_tx: mpsc::UnboundedSender<PendingRequest>,
     ) {
-        #[cfg(test)]
-        tracing::debug!("insert_client");
+        tracing::trace!("insert_client");
 
         self.clients.insert(client_id, ClientState::new(request_tx));
     }
 
     #[instrument(skip(self))]
     fn remove_client(&mut self, client_id: ClientId) {
-        #[cfg(test)]
-        tracing::debug!("remove_client");
+        tracing::trace!("remove_client");
 
         let Some(client_state) = self.clients.remove(&client_id) else {
             return;
@@ -254,16 +262,17 @@
     }
 
     #[instrument(skip(self))]
-    fn handle_initial(
-        &mut self,
-        client_id: ClientId,
-        request: Request,
-        block_presence: MultiBlockPresence,
-    ) {
-        #[cfg(test)]
-        tracing::debug!("handle_initial");
+    fn handle_initial(&mut self, client_id: ClientId, request: Request) {
+        tracing::trace!("handle_initial");
 
-        self.insert_request(client_id, request, block_presence, None)
+        self.insert_request(
+            client_id,
+            PendingRequest {
+                request,
+                block_presence: MultiBlockPresence::None,
+            },
+            None,
+        )
     }
 
     #[instrument(skip(self))]
@@ -271,10 +280,9 @@
         &mut self,
         client_id: ClientId,
         request_key: MessageKey,
-        requests: Vec<(Request, MultiBlockPresence)>,
+        requests: Vec<PendingRequest>,
     ) {
-        #[cfg(test)]
-        tracing::debug!("handle_success");
+        tracing::trace!("handle_success");
 
         let node_key = self
             .clients
@@ -330,17 +338,12 @@
         // were waiting for the original request.
         client_ids.push_front(client_id);
 
-        for (child_request, child_block_presence) in requests {
+        for child_request in requests {
             for (client_id, child_request) in
                 // TODO: use `repeat_n` once it gets stabilized.
                 client_ids.iter().copied().zip(iter::repeat(child_request))
             {
-                self.insert_request(
-                    client_id,
-                    child_request.clone(),
-                    child_block_presence,
-                    node_key,
-                );
+                self.insert_request(client_id, child_request, node_key);
             }
 
             // Round-robin the requests among the clients.
@@ -355,8 +358,7 @@ impl Worker { request_key: MessageKey, reason: FailureReason, ) { - #[cfg(test)] - tracing::debug!("handle_failure"); + tracing::trace!("handle_failure"); let Some(client_state) = self.clients.get_mut(&client_id) else { return; @@ -369,19 +371,22 @@ impl Worker { self.cancel_request(client_id, node_key); } + #[instrument(skip(self))] + fn commit(&mut self, client_id: ClientId) { + tracing::trace!("commit_failure"); + + todo!() + } + fn insert_request( &mut self, client_id: ClientId, - request: Request, - block_presence: MultiBlockPresence, + request: PendingRequest, parent_key: Option, ) { - let node_key = self.requests.get_or_insert( - request, - block_presence, - parent_key, - RequestState::Cancelled, - ); + let node_key = self + .requests + .get_or_insert(request, parent_key, RequestState::Cancelled); self.update_request(client_id, node_key); } @@ -395,7 +400,7 @@ impl Worker { return; }; - let request_key = MessageKey::from(node.request()); + let request_key = MessageKey::from(&node.request().request); match node.value_mut() { RequestState::InFlight { waiters, .. } => { @@ -413,13 +418,7 @@ impl Worker { }; client_state.requests.insert(request_key, node_key); - client_state - .request_tx - .send(SendPermit { - request: node.request().clone(), - block_presence: *node.block_presence(), - }) - .ok(); + client_state.request_tx.send(node.request().clone()).ok(); } } @@ -437,7 +436,7 @@ impl Worker { return; }; - let (request, &block_presence, state) = node.parts_mut(); + let (request, state) = node.parts_mut(); match state { RequestState::InFlight { @@ -462,18 +461,13 @@ impl Worker { // Next waiting client found. Promote it to a sender. *sender_client_id = next_client_id; - *sender_timer_key = self - .timer - .insert((next_client_id, MessageKey::from(request)), REQUEST_TIMEOUT); - - // Send the permit to the new sender. - next_client_state - .request_tx - .send(SendPermit { - request: request.clone(), - block_presence, - }) - .ok(); + *sender_timer_key = self.timer.insert( + (next_client_id, MessageKey::from(&request.request)), + REQUEST_TIMEOUT, + ); + + // Send the request to the new sender. 
+ next_client_state.request_tx.send(request.clone()).ok(); return; } else { @@ -529,7 +523,7 @@ impl ClientId { enum Command { InsertClient { client_id: ClientId, - request_tx: mpsc::UnboundedSender, + request_tx: mpsc::UnboundedSender, }, RemoveClient { client_id: ClientId, @@ -537,26 +531,28 @@ enum Command { HandleInitial { client_id: ClientId, request: Request, - block_presence: MultiBlockPresence, }, HandleSuccess { client_id: ClientId, request_key: MessageKey, - requests: Vec<(Request, MultiBlockPresence)>, + requests: Vec, }, HandleFailure { client_id: ClientId, request_key: MessageKey, }, + Commit { + client_id: ClientId, + }, } struct ClientState { - request_tx: mpsc::UnboundedSender, + request_tx: mpsc::UnboundedSender, requests: HashMap, } impl ClientState { - fn new(request_tx: mpsc::UnboundedSender) -> Self { + fn new(request_tx: mpsc::UnboundedSender) -> Self { Self { request_tx, requests: HashMap::default(), diff --git a/lib/src/network/request_tracker/graph.rs b/lib/src/network/request_tracker/graph.rs index a8eb1cc61..c94661901 100644 --- a/lib/src/network/request_tracker/graph.rs +++ b/lib/src/network/request_tracker/graph.rs @@ -1,4 +1,4 @@ -use super::MessageKey; +use super::{MessageKey, PendingRequest}; use crate::{ collections::{HashMap, HashSet}, network::message::Request, @@ -23,14 +23,13 @@ impl Graph { pub fn get_or_insert( &mut self, - request: Request, - block_presence: MultiBlockPresence, + request: PendingRequest, parent_key: Option, value: T, ) -> Key { let node_key = match self .index - .entry((MessageKey::from(&request), block_presence)) + .entry((MessageKey::from(&request.request), request.block_presence)) { Entry::Occupied(entry) => { self.nodes @@ -44,7 +43,6 @@ impl Graph { Entry::Vacant(entry) => { let node_key = self.nodes.insert(Node { request, - block_presence, parents: parent_key.into_iter().collect(), children: HashSet::default(), value, @@ -77,8 +75,10 @@ impl Graph { pub fn remove(&mut self, key: Key) -> Option> { let node = self.nodes.try_remove(key.0)?; - self.index - .remove(&(MessageKey::from(&node.request), node.block_presence)); + self.index.remove(&( + MessageKey::from(&node.request.request), + node.request.block_presence, + )); for parent_key in &node.parents { let Some(parent_node) = self.nodes.get_mut(parent_key.0) else { @@ -101,7 +101,7 @@ impl Graph { #[cfg_attr(not(test), expect(dead_code))] pub fn requests(&self) -> impl ExactSizeIterator { - self.nodes.iter().map(|(_, node)| &node.request) + self.nodes.iter().map(|(_, node)| &node.request.request) } } @@ -109,8 +109,7 @@ impl Graph { pub(super) struct Key(usize); pub(super) struct Node { - request: Request, - block_presence: MultiBlockPresence, + request: PendingRequest, parents: HashSet, children: HashSet, value: T, @@ -126,16 +125,12 @@ impl Node { &mut self.value } - pub fn request(&self) -> &Request { + pub fn request(&self) -> &PendingRequest { &self.request } - pub fn block_presence(&self) -> &MultiBlockPresence { - &self.block_presence - } - - pub fn parts_mut(&mut self) -> (&Request, &MultiBlockPresence, &mut T) { - (&self.request, &self.block_presence, &mut self.value) + pub fn parts_mut(&mut self) -> (&PendingRequest, &mut T) { + (&self.request, &mut self.value) } pub fn parents(&self) -> impl ExactSizeIterator + '_ { @@ -166,8 +161,14 @@ mod tests { DebugRequest::start(), ); - let parent_node_key = - graph.get_or_insert(parent_request.clone(), MultiBlockPresence::Full, None, 1); + let parent_node_key = graph.get_or_insert( + PendingRequest { + request: 
parent_request.clone(), + block_presence: MultiBlockPresence::Full, + }, + None, + 1, + ); assert_eq!(graph.requests().len(), 1); @@ -177,7 +178,7 @@ mod tests { assert_eq!(*node.value(), 1); assert_eq!(node.children().len(), 0); - assert_eq!(node.request(), &parent_request); + assert_eq!(node.request().request, parent_request); let child_request = Request::ChildNodes( rng.gen(), @@ -186,8 +187,10 @@ mod tests { ); let child_node_key = graph.get_or_insert( - child_request.clone(), - MultiBlockPresence::Full, + PendingRequest { + request: child_request.clone(), + block_presence: MultiBlockPresence::Full, + }, Some(parent_node_key), 2, ); @@ -200,7 +203,7 @@ mod tests { assert_eq!(*node.value(), 2); assert_eq!(node.children().len(), 0); - assert_eq!(node.request(), &child_request); + assert_eq!(node.request().request, child_request); assert_eq!( graph @@ -229,10 +232,24 @@ mod tests { DebugRequest::start(), ); - let node_key0 = graph.get_or_insert(request.clone(), MultiBlockPresence::Full, None, 1); + let node_key0 = graph.get_or_insert( + PendingRequest { + request: request.clone(), + block_presence: MultiBlockPresence::Full, + }, + None, + 1, + ); assert_eq!(graph.requests().len(), 1); - let node_key1 = graph.get_or_insert(request, MultiBlockPresence::Full, None, 1); + let node_key1 = graph.get_or_insert( + PendingRequest { + request, + block_presence: MultiBlockPresence::Full, + }, + None, + 1, + ); assert_eq!(graph.requests().len(), 1); assert_eq!(node_key0, node_key1); } @@ -244,35 +261,43 @@ mod tests { let hash = rng.gen(); - let parent_block_presence_0 = MultiBlockPresence::None; - let parent_request_0 = Request::ChildNodes( - hash, - ResponseDisambiguator::new(parent_block_presence_0), - DebugRequest::start(), - ); + let parent_request_0 = PendingRequest { + request: Request::ChildNodes( + hash, + ResponseDisambiguator::new(MultiBlockPresence::None), + DebugRequest::start(), + ), + block_presence: MultiBlockPresence::None, + }; - let parent_block_presence_1 = MultiBlockPresence::Full; - let parent_request_1 = Request::ChildNodes( - hash, - ResponseDisambiguator::new(parent_block_presence_1), - DebugRequest::start(), - ); + let parent_request_1 = PendingRequest { + request: Request::ChildNodes( + hash, + ResponseDisambiguator::new(MultiBlockPresence::Full), + DebugRequest::start(), + ), + block_presence: MultiBlockPresence::Full, + }; let child_request = Request::Block(rng.gen(), DebugRequest::start()); - let parent_key_0 = graph.get_or_insert(parent_request_0, parent_block_presence_0, None, 0); - let parent_key_1 = graph.get_or_insert(parent_request_1, parent_block_presence_1, None, 1); + let parent_key_0 = graph.get_or_insert(parent_request_0, None, 0); + let parent_key_1 = graph.get_or_insert(parent_request_1, None, 1); let child_key_0 = graph.get_or_insert( - child_request.clone(), - MultiBlockPresence::None, + PendingRequest { + request: child_request.clone(), + block_presence: MultiBlockPresence::None, + }, Some(parent_key_0), 2, ); let child_key_1 = graph.get_or_insert( - child_request, - MultiBlockPresence::None, + PendingRequest { + request: child_request, + block_presence: MultiBlockPresence::None, + }, Some(parent_key_1), 2, ); @@ -327,39 +352,36 @@ mod tests { let mut rng = rand::thread_rng(); let mut graph = Graph::new(); - let parent_request = Request::ChildNodes( - rng.gen(), - ResponseDisambiguator::new(MultiBlockPresence::Full), - DebugRequest::start(), - ); - - let child_request_0 = Request::ChildNodes( - rng.gen(), - 
ResponseDisambiguator::new(MultiBlockPresence::Full), - DebugRequest::start(), - ); - - let child_request_1 = Request::ChildNodes( - rng.gen(), - ResponseDisambiguator::new(MultiBlockPresence::Full), - DebugRequest::start(), - ); + let parent_request = PendingRequest { + request: Request::ChildNodes( + rng.gen(), + ResponseDisambiguator::new(MultiBlockPresence::Full), + DebugRequest::start(), + ), + block_presence: MultiBlockPresence::Full, + }; - let parent_key = graph.get_or_insert(parent_request, MultiBlockPresence::Full, None, 0); + let child_request_0 = PendingRequest { + request: Request::ChildNodes( + rng.gen(), + ResponseDisambiguator::new(MultiBlockPresence::Full), + DebugRequest::start(), + ), + block_presence: MultiBlockPresence::Full, + }; - let child_key_0 = graph.get_or_insert( - child_request_0, - MultiBlockPresence::Full, - Some(parent_key), - 1, - ); + let child_request_1 = PendingRequest { + request: Request::ChildNodes( + rng.gen(), + ResponseDisambiguator::new(MultiBlockPresence::Full), + DebugRequest::start(), + ), + block_presence: MultiBlockPresence::Full, + }; - let child_key_1 = graph.get_or_insert( - child_request_1, - MultiBlockPresence::Full, - Some(parent_key), - 2, - ); + let parent_key = graph.get_or_insert(parent_request, None, 0); + let child_key_0 = graph.get_or_insert(child_request_0, Some(parent_key), 1); + let child_key_1 = graph.get_or_insert(child_request_1, Some(parent_key), 2); assert_eq!( graph diff --git a/lib/src/network/request_tracker/simulation.rs b/lib/src/network/request_tracker/simulation.rs index ee5c8a63c..d1d1cffcd 100644 --- a/lib/src/network/request_tracker/simulation.rs +++ b/lib/src/network/request_tracker/simulation.rs @@ -1,6 +1,6 @@ use super::{ super::message::{Request, Response, ResponseDisambiguator}, - MessageKey, RequestTracker, RequestTrackerClient, SendPermit, + MessageKey, PendingRequest, RequestTracker, RequestTrackerClient, }; use crate::{ collections::{HashMap, HashSet}, @@ -91,7 +91,7 @@ impl Simulation { match side { Side::Client => { - if let Some(SendPermit { + if let Some(PendingRequest { request, block_presence, }) = peer.client.poll_request() @@ -173,13 +173,13 @@ struct TestPeer { struct TestClient { tracker_client: RequestTrackerClient, - tracker_request_rx: mpsc::UnboundedReceiver, + tracker_request_rx: mpsc::UnboundedReceiver, } impl TestClient { fn new( tracker_client: RequestTrackerClient, - tracker_request_rx: mpsc::UnboundedReceiver, + tracker_request_rx: mpsc::UnboundedReceiver, ) -> Self { Self { tracker_client, @@ -192,14 +192,14 @@ impl TestClient { Response::RootNode(proof, block_presence, debug_payload) => { let requests = snapshot .insert_root(proof.hash, block_presence) - .then_some(( - Request::ChildNodes( + .then_some(PendingRequest { + request: Request::ChildNodes( proof.hash, ResponseDisambiguator::new(block_presence), debug_payload.follow_up(), ), block_presence, - )) + }) .into_iter() .collect(); @@ -212,15 +212,13 @@ impl TestClient { let requests: Vec<_> = nodes .into_iter() - .map(|(_, node)| { - ( - Request::ChildNodes( - node.hash, - ResponseDisambiguator::new(node.summary.block_presence), - debug_payload.follow_up(), - ), - node.summary.block_presence, - ) + .map(|(_, node)| PendingRequest { + request: Request::ChildNodes( + node.hash, + ResponseDisambiguator::new(node.summary.block_presence), + debug_payload.follow_up(), + ), + block_presence: node.summary.block_presence, }) .collect(); @@ -232,11 +230,9 @@ impl TestClient { let nodes = snapshot.insert_leaves(nodes); let requests = 
nodes .into_iter() - .map(|node| { - ( - Request::Block(node.block_id, debug_payload.follow_up()), - MultiBlockPresence::None, - ) + .map(|node| PendingRequest { + request: Request::Block(node.block_id, debug_payload.follow_up()), + block_presence: MultiBlockPresence::None, }) .collect(); @@ -265,7 +261,7 @@ impl TestClient { }; } - fn poll_request(&mut self) -> Option { + fn poll_request(&mut self) -> Option { self.tracker_request_rx.try_recv().ok() } } diff --git a/lib/src/network/request_tracker/tests.rs b/lib/src/network/request_tracker/tests.rs index 03b024da1..2fadf53dd 100644 --- a/lib/src/network/request_tracker/tests.rs +++ b/lib/src/network/request_tracker/tests.rs @@ -162,12 +162,18 @@ async fn timeout() { // Register the request with both clients. client_a.success( preceding_request_key, - vec![(request.clone(), MultiBlockPresence::Full)], + vec![PendingRequest { + request: request.clone(), + block_presence: MultiBlockPresence::Full, + }], ); client_b.success( preceding_request_key, - vec![(request.clone(), MultiBlockPresence::Full)], + vec![PendingRequest { + request: request.clone(), + block_presence: MultiBlockPresence::Full, + }], ); time::timeout(Duration::from_millis(1), &mut work) diff --git a/lib/src/network/seen_peers.rs b/lib/src/network/seen_peers.rs index c0c4941d7..5632a36b1 100644 --- a/lib/src/network/seen_peers.rs +++ b/lib/src/network/seen_peers.rs @@ -65,7 +65,7 @@ impl SeenPeersInner { } fn start_new_round(&mut self) { - use crate::collections::hash_map::Entry; + use std::collections::hash_map::Entry; self.current_round_id += 1; self.rounds.retain(|round, peers| { @@ -243,7 +243,7 @@ impl fmt::Debug for SeenPeer { impl Drop for SeenPeer { fn drop(&mut self) { - use crate::collections::hash_map::Entry; + use std::collections::hash_map::Entry; let mut seen_peers = self.seen_peers.write().unwrap(); diff --git a/lib/src/network/tests.rs b/lib/src/network/tests.rs index 534a59357..494f717a0 100644 --- a/lib/src/network/tests.rs +++ b/lib/src/network/tests.rs @@ -2,10 +2,10 @@ use super::{ client::Client, constants::MAX_UNCHOKED_COUNT, message::{Message, Request, Response}, + request_tracker::RequestTracker, server::Server, }; use crate::{ - block_tracker::OfferState, crypto::sign::{Keypair, PublicKey}, db, event::{Event, EventSender, Payload}, @@ -21,11 +21,11 @@ use futures_util::{future, TryStreamExt}; use metrics::NoopRecorder; use rand::prelude::*; use state_monitor::StateMonitor; -use std::{fmt, future::Future, sync::Arc}; +use std::{fmt, future::Future, pin::pin, sync::Arc}; use tempfile::TempDir; use test_strategy::proptest; use tokio::{ - pin, select, + select, sync::{ broadcast::{self, error::RecvError}, mpsc, Semaphore, @@ -69,8 +69,9 @@ async fn transfer_snapshot_between_two_replicas_case( let mut rng = StdRng::seed_from_u64(rng_seed); let write_keys = Keypair::generate(&mut rng); - let (_a_base_dir, a_vault, a_choker, a_id) = create_repository(&mut rng, &write_keys).await; - let (_b_base_dir, b_vault, _, _) = create_repository(&mut rng, &write_keys).await; + let (_a_base_dir, a_vault, _, a_choker, a_id) = create_repository(&mut rng, &write_keys).await; + let (_b_base_dir, b_vault, b_request_tracker, _, _) = + create_repository(&mut rng, &write_keys).await; let snapshot = Snapshot::generate(&mut rng, leaf_count); save_snapshot(&a_vault, a_id, &write_keys, &snapshot).await; @@ -79,7 +80,7 @@ async fn transfer_snapshot_between_two_replicas_case( assert!(load_latest_root_node(&b_vault, &a_id).await.is_none()); let mut server = 
create_server(a_vault.clone(), a_choker); - let mut client = create_client(b_vault.clone()); + let mut client = create_client(b_vault.clone(), &b_request_tracker); // Wait until replica B catches up to replica A, then have replica A perform a local change // and repeat. @@ -125,8 +126,9 @@ async fn transfer_blocks_between_two_replicas_case(block_count: usize, rng_seed: let mut rng = StdRng::seed_from_u64(rng_seed); let write_keys = Keypair::generate(&mut rng); - let (_a_base_dir, a_vault, a_choker, a_id) = create_repository(&mut rng, &write_keys).await; - let (_b_base_dir, b_vault, _, b_id) = create_repository(&mut rng, &write_keys).await; + let (_a_base_dir, a_vault, _, a_choker, a_id) = create_repository(&mut rng, &write_keys).await; + let (_b_base_dir, b_vault, b_request_tracker, _, b_id) = + create_repository(&mut rng, &write_keys).await; // Initially both replicas have the whole snapshot but no blocks. let snapshot = Snapshot::generate(&mut rng, block_count); @@ -134,9 +136,7 @@ async fn transfer_blocks_between_two_replicas_case(block_count: usize, rng_seed: save_snapshot(&b_vault, b_id, &write_keys, &snapshot).await; let mut server = create_server(a_vault.clone(), a_choker.clone()); - let mut client = create_client(b_vault.clone()); - - let a_block_tracker = a_vault.block_tracker.client(); + let mut client = create_client(b_vault.clone(), &b_request_tracker); // Receive the blocks by replica A and verify they get received by replica B as well. let drive = async { @@ -144,17 +144,7 @@ async fn transfer_blocks_between_two_replicas_case(block_count: usize, rng_seed: let mut writer = a_vault.store().begin_client_write().await.unwrap(); for (id, block) in snapshot.blocks() { - a_vault.block_tracker.require(*id); - a_block_tracker.register(*id, OfferState::Approved); - let promise = a_block_tracker - .offers() - .try_next() - .unwrap() - .accept() - .unwrap(); - - writer.save_block(block, Some(promise)).await.unwrap(); - + writer.save_block(block).await.unwrap(); tracing::info!(?id, "save block"); } @@ -189,15 +179,16 @@ async fn failed_block_only_peer() { let mut rng = StdRng::seed_from_u64(0); let write_keys = Keypair::generate(&mut rng); - let (_a_base_dir, a_vault, a_choker, a_id) = create_repository(&mut rng, &write_keys).await; - let (_a_base_dir, b_vault, _, _) = create_repository(&mut rng, &write_keys).await; + let (_a_base_dir, a_vault, _, a_choker, a_id) = create_repository(&mut rng, &write_keys).await; + let (_b_base_dir, b_vault, b_request_tracker, _, _) = + create_repository(&mut rng, &write_keys).await; let snapshot = Snapshot::generate(&mut rng, 1); save_snapshot(&a_vault, a_id, &write_keys, &snapshot).await; save_blocks(&a_vault, &snapshot).await; let mut server = create_server(a_vault.clone(), a_choker.clone()); - let mut client = create_client(b_vault.clone()); + let mut client = create_client(b_vault.clone(), &b_request_tracker); simulate_connection_until( &mut server, @@ -211,7 +202,7 @@ async fn failed_block_only_peer() { drop(client); let mut server = create_server(a_vault.clone(), a_choker); - let mut client = create_client(b_vault.clone()); + let mut client = create_client(b_vault.clone(), &b_request_tracker); simulate_connection_until(&mut server, &mut client, async { for id in snapshot.blocks().keys() { @@ -228,9 +219,10 @@ async fn failed_block_same_peer() { let mut rng = StdRng::seed_from_u64(0); let write_keys = Keypair::generate(&mut rng); - let (_a_base_dir, a_vault, a_choker, a_id) = create_repository(&mut rng, &write_keys).await; - let (_b_base_dir, 
b_vault, b_choker, _) = create_repository(&mut rng, &write_keys).await; - let (_c_base_dir, c_vault, _, _) = create_repository(&mut rng, &write_keys).await; + let (_a_base_dir, a_vault, _, a_choker, a_id) = create_repository(&mut rng, &write_keys).await; + let (_b_base_dir, b_vault, _, b_choker, _) = create_repository(&mut rng, &write_keys).await; + let (_c_base_dir, c_vault, c_request_tracker, _, _) = + create_repository(&mut rng, &write_keys).await; let snapshot = Snapshot::generate(&mut rng, 1); save_snapshot(&a_vault, a_id, &write_keys, &snapshot).await; @@ -247,18 +239,17 @@ async fn failed_block_same_peer() { // [B]-(server_bc)---+ let mut server_ac = create_server(a_vault.clone(), a_choker.clone()); - let mut client_ca = create_client(c_vault.clone()); + let mut client_ca = create_client(c_vault.clone(), &c_request_tracker); let mut server_bc = create_server(b_vault.clone(), b_choker); - let mut client_cb = create_client(c_vault.clone()); + let mut client_cb = create_client(c_vault.clone(), &c_request_tracker); // Run both connections in parallel until C syncs its index (but not blocks) with A let conn_ac = simulate_connection(&mut server_ac, &mut client_ca); let conn_ac = conn_ac.instrument(tracing::info_span!("AC1")); let conn_bc = simulate_connection(&mut server_bc, &mut client_cb); - let conn_bc = conn_bc.instrument(tracing::info_span!("BC")); - pin!(conn_bc); + let mut conn_bc = pin!(conn_bc.instrument(tracing::info_span!("BC"))); run_until( future::join(conn_ac, &mut conn_bc), @@ -271,7 +262,7 @@ async fn failed_block_same_peer() { drop(client_ca); let mut server_ac = create_server(a_vault.clone(), a_choker); - let mut client_ca = create_client(c_vault.clone()); + let mut client_ca = create_client(c_vault.clone(), &c_request_tracker); // Run the new A-C connection in parallel with the existing B-C connection until all blocks are // received. @@ -289,6 +280,7 @@ async fn failed_block_same_peer() { // This test verifies that when there are two peers that have a particular block, even when one of // them drops, we can still succeed in retrieving the block from the remaining peer. 
#[tokio::test] +#[ignore = "request tracker is not cancel safe"] async fn failed_block_other_peer() { test_utils::init_log(); @@ -298,9 +290,12 @@ async fn failed_block_other_peer() { let mut rng = StdRng::seed_from_u64(0); let write_keys = Keypair::generate(&mut rng); - let (_a_base_dir, a_vault, a_choker, a_id) = create_repository(&mut rng, &write_keys).await; - let (_b_base_dir, b_vault, b_choker, b_id) = create_repository(&mut rng, &write_keys).await; - let (_c_base_dir, c_vault, _, _) = create_repository(&mut rng, &write_keys).await; + let (_a_base_dir, a_vault, _, a_choker, a_id) = + create_repository(&mut rng, &write_keys).await; + let (_b_base_dir, b_vault, b_request_tracker, b_choker, b_id) = + create_repository(&mut rng, &write_keys).await; + let (_c_base_dir, c_vault, c_request_tracker, _, _) = + create_repository(&mut rng, &write_keys).await; // Create the snapshot by A let snapshot = Snapshot::generate(&mut rng, 1); @@ -309,7 +304,7 @@ async fn failed_block_other_peer() { // Sync B with A let mut server_ab = create_server(a_vault.clone(), a_choker.clone()); - let mut client_ba = create_client(b_vault.clone()); + let mut client_ba = create_client(b_vault.clone(), &b_request_tracker); simulate_connection_until(&mut server_ab, &mut client_ba, async { for id in snapshot.blocks().keys() { wait_until_block_exists(&b_vault, id).await; @@ -334,12 +329,12 @@ async fn failed_block_other_peer() { let enter = span_ac.enter(); let mut server_ac = create_server(a_vault.clone(), a_choker); - let mut client_ca = create_client(c_vault.clone()); + let mut client_ca = create_client(c_vault.clone(), &c_request_tracker); drop(enter); let enter = span_bc.enter(); let mut server_bc = create_server(b_vault.clone(), b_choker); - let mut client_cb = create_client(c_vault.clone()); + let mut client_cb = create_client(c_vault.clone(), &c_request_tracker); drop(enter); // Run the two connections in parallel until C syncs its index with both A and B. @@ -358,6 +353,7 @@ async fn failed_block_other_peer() { // Drop the A-C connection so C can't receive any blocks from A anymore. let enter = span_ac.enter(); + tracing::info!("dropping connection"); drop(server_ac); drop(client_ca); drop(enter); @@ -400,7 +396,7 @@ async fn failed_block_other_peer() { async fn create_repository( rng: &mut R, write_keys: &Keypair, -) -> (TempDir, Vault, Arc, PublicKey) { +) -> (TempDir, Vault, RequestTracker, Arc, PublicKey) { let (base_dir, db) = db::create_temp().await.unwrap(); let writer_id = PublicKey::generate(rng); let repository_id = RepositoryId::from(write_keys.public_key()); @@ -413,9 +409,16 @@ async fn create_repository( RepositoryMonitor::new(StateMonitor::make_root(), &NoopRecorder), ); + let request_tracker = RequestTracker::new(); let response_limiter = Arc::new(Semaphore::new(MAX_UNCHOKED_COUNT)); - (base_dir, state, response_limiter, writer_id) + ( + base_dir, + state, + request_tracker, + response_limiter, + writer_id, + ) } // Enough capacity to prevent deadlocks. 
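One wiring detail worth spelling out: `create_repository` now returns the repository's `RequestTracker` alongside the vault, and every simulated connection for the same replica is handed a reference to that one tracker. A condensed sketch of the pattern (mirroring the calls in the tests above; an illustration, not a complete compilable test):

    // One tracker per repository, shared by all of that repository's clients,
    // so a request left uncommitted by a dropped connection can be retried
    // over the remaining connection.
    let (_c_base_dir, c_vault, c_request_tracker, _, _) =
        create_repository(&mut rng, &write_keys).await;

    let mut client_ca = create_client(c_vault.clone(), &c_request_tracker);
    let mut client_cb = create_client(c_vault.clone(), &c_request_tracker);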
@@ -627,18 +630,18 @@ type ClientData = ( mpsc::Sender, ); -fn create_server(repo: Vault, response_limiter: Arc) -> ServerData { +fn create_server(vault: Vault, response_limiter: Arc) -> ServerData { let (send_tx, send_rx) = mpsc::unbounded_channel(); let (recv_tx, recv_rx) = mpsc::channel(CAPACITY); - let server = Server::new(repo, send_tx, recv_rx, response_limiter); + let server = Server::new(vault, send_tx, recv_rx, response_limiter); (server, send_rx, recv_tx) } -fn create_client(repo: Vault) -> ClientData { +fn create_client(vault: Vault, request_tracker: &RequestTracker) -> ClientData { let (send_tx, send_rx) = mpsc::unbounded_channel(); let (recv_tx, recv_rx) = mpsc::channel(CAPACITY); - let client = Client::new(repo, send_tx, recv_rx); + let client = Client::new(vault, send_tx, recv_rx, request_tracker); (client, send_rx, recv_tx) } diff --git a/lib/src/network/upnp.rs b/lib/src/network/upnp.rs index db2fcca2a..d39095158 100644 --- a/lib/src/network/upnp.rs +++ b/lib/src/network/upnp.rs @@ -1,5 +1,5 @@ use super::ip; -use crate::collections::{hash_map, HashMap}; +use crate::collections::HashMap; use chrono::{offset::Local, DateTime}; use deadlock::BlockingMutex; use futures_util::TryStreamExt; @@ -12,6 +12,7 @@ use rupnp::{ use scoped_task::ScopedJoinHandle; use state_monitor::StateMonitor; use std::{ + collections::hash_map::Entry, fmt, future::Future, io, net, @@ -62,11 +63,11 @@ impl PortForwarder { }; let is_new_mapping = match self.mappings.lock().unwrap().entry(data) { - hash_map::Entry::Occupied(mut entry) => { + Entry::Occupied(mut entry) => { *entry.get_mut() += 1; false } - hash_map::Entry::Vacant(entry) => { + Entry::Vacant(entry) => { tracing::info!( parent: &self.span, "UPnP starting port forwarding EXT:{} -> INT:{} ({})", @@ -263,8 +264,6 @@ impl PortForwarder { }) }; - use crate::collections::hash_map::Entry; - match job_handles.entry(device_url.clone()) { Entry::Occupied(mut entry) => { let dev = entry.get_mut(); @@ -314,7 +313,7 @@ impl Drop for Mapping { let mut mappings = self.mappings.lock().unwrap(); match mappings.entry(self.data) { - hash_map::Entry::Occupied(mut entry) => { + Entry::Occupied(mut entry) => { let refcount = entry.get_mut(); *refcount -= 1; @@ -323,7 +322,7 @@ impl Drop for Mapping { self.on_change_tx.send(()).unwrap_or(()); } } - hash_map::Entry::Vacant(_) => { + Entry::Vacant(_) => { unreachable!(); } } @@ -370,7 +369,7 @@ impl PerIGDPortForwarder { // Add to `active_mappings` those that are `active_mappings`. 
for k in new_mappings.keys() { - if let hash_map::Entry::Vacant(entry) = active_mappings.entry(*k) { + if let Entry::Vacant(entry) = active_mappings.entry(*k) { entry.insert(self.activate_mapping(*k, local_ip, &mappings_monitor)); } } diff --git a/lib/src/repository/mod.rs b/lib/src/repository/mod.rs index 8b2ba79ba..ff261022c 100644 --- a/lib/src/repository/mod.rs +++ b/lib/src/repository/mod.rs @@ -18,7 +18,7 @@ pub(crate) use self::{ use crate::{ access_control::{Access, AccessChange, AccessKeys, AccessMode, AccessSecrets, LocalSecret}, - block_tracker::RequestMode, + block_tracker::BlockRequestMode, branch::{Branch, BranchShared}, crypto::{sign::PublicKey, PasswordSalt}, db::{self, DatabaseId}, @@ -959,7 +959,7 @@ impl Repository { self.shared .vault .block_tracker - .set_request_mode(request_mode(&credentials.secrets)); + .set_request_mode(block_request_mode(&credentials.secrets)); *self.shared.credentials.write().unwrap() = credentials; *self.worker_handle.lock().unwrap() = Some(spawn_worker(self.shared.clone())); @@ -983,7 +983,7 @@ impl Shared { vault .block_tracker - .set_request_mode(request_mode(&credentials.secrets)); + .set_request_mode(block_request_mode(&credentials.secrets)); Self { vault, @@ -1075,10 +1075,10 @@ async fn report_sync_progress(vault: Vault) { } } -fn request_mode(secrets: &AccessSecrets) -> RequestMode { +fn block_request_mode(secrets: &AccessSecrets) -> BlockRequestMode { if secrets.can_read() { - RequestMode::Lazy + BlockRequestMode::Lazy } else { - RequestMode::Greedy + BlockRequestMode::Greedy } } diff --git a/lib/src/repository/vault/tests.rs b/lib/src/repository/vault/tests.rs index b18aef312..9c7654e2b 100644 --- a/lib/src/repository/vault/tests.rs +++ b/lib/src/repository/vault/tests.rs @@ -1,7 +1,7 @@ use super::{RepositoryId, RepositoryMonitor, Vault}; use crate::{ access_control::WriteSecrets, - block_tracker::RequestMode, + block_tracker::BlockRequestMode, collections::HashSet, crypto::{ sign::{Keypair, PublicKey}, @@ -569,7 +569,7 @@ async fn setup_with_rng(rng: &mut StdRng) -> (TempDir, Vault, WriteSecrets) { RepositoryMonitor::new(StateMonitor::make_root(), &NoopRecorder), ); - vault.block_tracker.set_request_mode(RequestMode::Lazy); + vault.block_tracker.set_request_mode(BlockRequestMode::Lazy); (base_dir, vault, secrets) } @@ -604,7 +604,7 @@ async fn receive_snapshot( async fn receive_block(vault: &Vault, block: &Block) { let mut writer = vault.store().begin_client_write().await.unwrap(); - writer.save_block(block, None).await.unwrap(); + writer.save_block(block).await.unwrap(); writer.commit().await.unwrap(); } diff --git a/lib/src/repository/worker.rs b/lib/src/repository/worker.rs index 4108fa60c..e717620f9 100644 --- a/lib/src/repository/worker.rs +++ b/lib/src/repository/worker.rs @@ -17,9 +17,6 @@ use tokio::select; #[cfg(test)] mod tests; -/// Notify the block tracker after marking this many blocks as required. 
-const BLOCK_REQUIRE_BATCH_SIZE: u32 = 1024; - /// Background worker to perform various jobs on the repository: /// - merge remote branches into the local one /// - remove outdated branches and snapshots @@ -201,6 +198,8 @@ mod scan { } async fn run_once(shared: &Shared) -> Result<()> { + shared.vault.block_tracker.clear_required(); + let branches = shared.load_branches().await?; let mut versions = Vec::with_capacity(branches.len()); @@ -279,36 +278,15 @@ mod scan { branch: &Branch, blob_id: BlobId, ) -> Result<()> { - let mut blob_block_ids = - BlockIds::open(branch.clone(), blob_id) - .await - .map_err(|error| { - tracing::trace!(?error, "open failed"); - error - })?; - let mut block_number = 0; - let mut require_batch = shared.vault.block_tracker.require_batch(); - - while let Some((block_id, block_presence)) = - blob_block_ids.try_next().await.map_err(|error| { - tracing::trace!(block_number, ?error, "try_next failed"); - error - })? - { + let mut blob_block_ids = BlockIds::open(branch.clone(), blob_id).await?; + + while let Some((block_id, block_presence)) = blob_block_ids.try_next().await? { match block_presence { SingleBlockPresence::Present => (), SingleBlockPresence::Missing | SingleBlockPresence::Expired => { - require_batch.add(block_id); + shared.vault.block_tracker.require(block_id); } } - - // Notify the block tracker after processing each batch of blocks (also notify on the - // first blocks so that it's requested first). - if block_number % BLOCK_REQUIRE_BATCH_SIZE == 0 { - require_batch.commit(); - } - - block_number = block_number.saturating_add(1); } Ok(()) diff --git a/lib/src/store/block_expiration_tracker.rs b/lib/src/store/block_expiration_tracker.rs index 32fdd124f..1bab8d956 100644 --- a/lib/src/store/block_expiration_tracker.rs +++ b/lib/src/store/block_expiration_tracker.rs @@ -1,7 +1,7 @@ use super::{block, error::Error, index, leaf_node, root_node}; use crate::{ block_tracker::BlockTracker as BlockDownloadTracker, - collections::{hash_map, HashMap, HashSet}, + collections::{HashMap, HashSet}, crypto::sign::PublicKey, db, future::TryStreamExt as _, @@ -13,7 +13,7 @@ use futures_util::{StreamExt, TryStreamExt}; use scoped_task::{self, ScopedJoinHandle}; use sqlx::Row; use std::{ - collections::{btree_map, BTreeMap}, + collections::{btree_map, hash_map::Entry, BTreeMap}, sync::Arc, time::{Duration, SystemTime}, }; @@ -198,7 +198,7 @@ impl Shared { fn insert_block(&mut self, block: &BlockId, ts: TimeUpdated) { // Asserts and unwraps are OK due to the `Shared` invariants defined above. 
match self.blocks_by_id.entry(*block) { - hash_map::Entry::Occupied(mut entry) => { + Entry::Occupied(mut entry) => { let old_ts = *entry.get(); if old_ts == ts { return; @@ -223,7 +223,7 @@ impl Shared { .or_default() .insert(*block)); } - hash_map::Entry::Vacant(entry) => { + Entry::Vacant(entry) => { assert!(self .blocks_by_expiration .entry(ts) @@ -576,7 +576,7 @@ mod test { match op { Op::Receive(block) => { let mut writer = store.begin_client_write().await.unwrap(); - writer.save_block(&block, None).await.unwrap(); + writer.save_block(&block).await.unwrap(); writer.commit().await.unwrap(); } Op::Remove(id) => { diff --git a/lib/src/store/block_id_cache.rs b/lib/src/store/block_id_cache.rs index cc84fee90..c117f4704 100644 --- a/lib/src/store/block_id_cache.rs +++ b/lib/src/store/block_id_cache.rs @@ -1,6 +1,6 @@ use super::Error; use crate::{ - collections::{hash_map::Entry, HashMap}, + collections::HashMap, crypto::Hash, db, protocol::{BlockId, RootNode, SingleBlockPresence}, @@ -10,6 +10,7 @@ use futures_util::TryStreamExt; use sqlx::Row; use std::{ cmp::Ordering, + collections::hash_map::Entry, sync::{Arc, Mutex}, }; use tokio::sync::Notify; @@ -232,11 +233,7 @@ mod tests { // ... make some of them as present. for node in snapshot.leaf_nodes().take(num_present) { let block = snapshot.blocks().get(&node.block_id).unwrap(); - writer - .client_writer() - .save_block(block, None) - .await - .unwrap(); + writer.client_writer().save_block(block).await.unwrap(); } writer.commit().await; diff --git a/lib/src/store/client.rs b/lib/src/store/client.rs index 69515d650..f7b0c66a2 100644 --- a/lib/src/store/client.rs +++ b/lib/src/store/client.rs @@ -10,7 +10,6 @@ use super::{ Error, }; use crate::{ - block_tracker::{BlockPromise, OfferState}, collections::HashSet, crypto::{sign::PublicKey, CacheHash, Hash, Hashable}, db, @@ -29,7 +28,7 @@ pub(crate) struct ClientWriter { block_expiration_tracker: Option>, quota: Option, summary_updates: Vec, - saved_blocks: Vec, + new_blocks: Vec, block_id_cache: BlockIdCache, block_id_cache_updates: Vec<(Hash, BlockId)>, } @@ -47,7 +46,7 @@ impl ClientWriter { block_expiration_tracker, quota, summary_updates: Vec::new(), - saved_blocks: Vec::new(), + new_blocks: Vec::new(), block_id_cache, block_id_cache_updates: Vec::new(), }) @@ -141,18 +140,16 @@ impl ClientWriter { match leaf_node::load_block_presence(&mut self.db, &node.block_id).await? { Some(SingleBlockPresence::Missing) | None => { // Missing, expired or not yet stored locally - let offer_state = if self.quota.is_some() { + let node_state = if self.quota.is_some() { // OPTIMIZE: the state is the same for all the nodes in `nodes`, so // it only needs to be loaded once. - load_block_offer_state_assuming_quota(&mut self.db, &node.block_id) + root_node::load_node_state_of_missing(&mut self.db, &node.block_id) .await? 
} else { - Some(OfferState::Approved) + NodeState::Approved }; - if let Some(offer_state) = offer_state { - new_block_offers.push((node.block_id, offer_state)); - } + new_block_offers.push((node.block_id, node_state)); } Some(SingleBlockPresence::Present | SingleBlockPresence::Expired) => (), } @@ -168,11 +165,7 @@ impl ClientWriter { Ok(LeafNodesStatus { new_block_offers }) } - pub async fn save_block( - &mut self, - block: &Block, - block_promise: Option, - ) -> Result<(), Error> { + pub async fn save_block(&mut self, block: &Block) -> Result<(), Error> { let updated = { let mut updated = false; let mut updates = leaf_node::set_present(&mut self.db, &block.id); @@ -193,12 +186,9 @@ impl ClientWriter { if let Some(tracker) = &self.block_expiration_tracker { tracker.handle_block_update(&block.id, false); } - } - self.saved_blocks.push(match block_promise { - Some(block_promise) => SavedBlock::WithPromise(block_promise), - None => SavedBlock::WithoutPromise(block.id), - }); + self.new_blocks.push(block.id); + } Ok(()) } @@ -219,15 +209,9 @@ impl ClientWriter { .load_approved_missing_blocks(&approved_branches) .await?; - let new_blocks = self - .saved_blocks - .iter() - .map(|saved_block| *saved_block.id()) - .collect(); - let Self { db, - saved_blocks, + new_blocks, block_id_cache, block_id_cache_updates, .. @@ -243,13 +227,6 @@ impl ClientWriter { let output = db .commit_and_then(move || { block_id_cache.set_present(&block_id_cache_updates); - - for saved_block in saved_blocks { - if let SavedBlock::WithPromise(promise) = saved_block { - promise.complete(); - } - } - f(status) }) .await?; @@ -350,20 +327,22 @@ impl ClientReader { Ok(Self { db, quota }) } - /// Returns the state (`Pending` or `Approved`) that the offer for the given block should be - /// registered with. If the block isn't referenced or isn't missing, returns `None`. - pub async fn load_block_offer_state( + /// Loads the root node state that should be used for offers for the given block. + /// + /// NOTE: If quota is enabled, returns the actual root node state. Otherwise returns an assumed + /// state that depends only on the block presence, for efficiency. + pub async fn load_effective_root_node_state_for_block( &mut self, block_id: &BlockId, - ) -> Result, Error> { + ) -> Result { if self.quota.is_some() { - load_block_offer_state_assuming_quota(&mut self.db, block_id).await + root_node::load_node_state_of_missing(&mut self.db, block_id).await } else { match leaf_node::load_block_presence(&mut self.db, block_id).await? { Some(SingleBlockPresence::Missing) | Some(SingleBlockPresence::Expired) => { - Ok(Some(OfferState::Approved)) + Ok(NodeState::Approved) } - Some(SingleBlockPresence::Present) | None => Ok(None), + Some(SingleBlockPresence::Present) | None => Ok(NodeState::Rejected), } } } @@ -377,8 +356,8 @@ pub(crate) struct InnerNodesStatus { #[derive(Default)] pub(crate) struct LeafNodesStatus { - /// Number of new blocks offered by the received nodes. - pub new_block_offers: Vec<(BlockId, OfferState)>, + /// New blocks offered by the received nodes together with their root node states. + pub new_block_offers: Vec<(BlockId, NodeState)>, } pub(crate) struct CommitStatus { @@ -397,31 +376,6 @@ struct FinalizeStatus { rejected_branches: Vec, } -async fn load_block_offer_state_assuming_quota( - conn: &mut db::Connection, - block_id: &BlockId, -) -> Result, Error> { - match root_node::load_node_state_of_missing(conn, block_id).await? 
{ - NodeState::Incomplete | NodeState::Complete => Ok(Some(OfferState::Pending)), - NodeState::Approved => Ok(Some(OfferState::Approved)), - NodeState::Rejected => Ok(None), - } -} - -enum SavedBlock { - WithPromise(BlockPromise), - WithoutPromise(BlockId), -} - -impl SavedBlock { - fn id(&self) -> &BlockId { - match self { - SavedBlock::WithPromise(promise) => promise.block_id(), - SavedBlock::WithoutPromise(block_id) => block_id, - } - } -} - #[cfg(test)] mod tests { use super::super::Changeset; @@ -733,7 +687,7 @@ mod tests { // Mark one of the missing block as present so the block presences are different (but still // `Some`). let mut writer = store.begin_client_write().await.unwrap(); - writer.save_block(&block_1, None).await.unwrap(); + writer.save_block(&block_1).await.unwrap(); writer.commit().await.unwrap(); // Receive the same node we already have. The hashes and version vectors are equal but the @@ -964,7 +918,7 @@ mod tests { let mut writer = store.begin_client_write().await.unwrap(); for block in snapshot.blocks().values() { - writer.save_block(block, None).await.unwrap(); + writer.save_block(block).await.unwrap(); } writer.commit().await.unwrap(); diff --git a/lib/src/store/test_utils.rs b/lib/src/store/test_utils.rs index a0e1bca9b..609b48816 100644 --- a/lib/src/store/test_utils.rs +++ b/lib/src/store/test_utils.rs @@ -74,7 +74,7 @@ impl<'a> SnapshotWriter<'a> { pub async fn save_blocks(mut self) -> Self { for block in self.snapshot.blocks().values() { - self.writer.save_block(block, None).await.unwrap(); + self.writer.save_block(block).await.unwrap(); } self } diff --git a/lib/src/sync.rs b/lib/src/sync.rs index 99f693c5a..f39704fae 100644 --- a/lib/src/sync.rs +++ b/lib/src/sync.rs @@ -581,114 +581,6 @@ pub(crate) mod stream { } } -/// A hash map whose entries expire after a timeout. -pub(crate) mod delay_map { - use crate::collections::{hash_map, HashMap}; - use std::{ - borrow::Borrow, - hash::Hash, - task::{ready, Context, Poll}, - }; - use tokio::time::Duration; - use tokio_util::time::delay_queue::{DelayQueue, Key}; - - pub struct DelayMap { - items: HashMap>, - delays: DelayQueue, - } - - impl DelayMap { - pub fn new() -> Self { - Self { - items: HashMap::default(), - delays: DelayQueue::default(), - } - } - } - - impl DelayMap - where - K: Eq + Hash + Clone, - { - // This is unused right now but keeping it around so we don't have to recreate when we need - // it. - #[allow(unused)] - pub fn insert(&mut self, key: K, value: V, timeout: Duration) -> Option { - let delay_key = self.delays.insert(key.clone(), timeout); - let old = self.items.insert(key, Item { value, delay_key })?; - - self.delays.remove(&old.delay_key); - - Some(old.value) - } - - pub fn try_insert(&mut self, key: K) -> Option> { - match self.items.entry(key) { - hash_map::Entry::Vacant(item_entry) => Some(VacantEntry { - item_entry, - delays: &mut self.delays, - }), - hash_map::Entry::Occupied(_) => None, - } - } - - pub fn remove(&mut self, key: &Q) -> Option - where - K: Borrow, - Q: Hash + Eq + ?Sized, - { - let item = self.items.remove(key)?; - self.delays.remove(&item.delay_key); - - Some(item.value) - } - - pub fn len(&self) -> usize { - self.items.len() - } - - /// Poll for the next expired item. This can be wrapped in `future::poll_fn` and awaited. - /// Returns `Poll::Ready(None)` if the map is empty. 
- pub fn poll_expired(&mut self, cx: &mut Context<'_>) -> Poll> { - if let Some(expired) = ready!(self.delays.poll_expired(cx)) { - let key = expired.into_inner(); - // unwrap is OK because an entry exists in `delays` iff it exists in `items`. - let item = self.items.remove(&key).unwrap(); - - Poll::Ready(Some((key, item.value))) - } else { - Poll::Ready(None) - } - } - } - - impl Default for DelayMap { - fn default() -> Self { - Self::new() - } - } - - pub struct VacantEntry<'a, K, V> { - item_entry: hash_map::VacantEntry<'a, K, Item>, - delays: &'a mut DelayQueue, - } - - impl<'a, K, V> VacantEntry<'a, K, V> - where - K: Clone, - { - pub fn insert(self, value: V, timeout: Duration) -> &'a V { - let delay_key = self.delays.insert(self.item_entry.key().clone(), timeout); - &mut self.item_entry.insert(Item { value, delay_key }).value - } - } - - struct Item { - value: V, - delay_key: Key, - } -} - /// Extensions for `tokio::sync::watch::Sender`. pub(crate) trait WatchSenderExt { // Like `send_modify` but allows returning a value from the closure. From 53709a346ddf3fa88b3c80fb573e00f5364e2776 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Cig=C3=A1nek?= Date: Mon, 30 Sep 2024 16:19:28 +0200 Subject: [PATCH 39/55] Implement transactional semantics for RequestTracker --- lib/src/block_tracker.rs | 33 +-- lib/src/network/client.rs | 36 +-- lib/src/network/request_tracker.rs | 251 +++++++++++------- lib/src/network/request_tracker/simulation.rs | 4 + lib/src/network/request_tracker/tests.rs | 2 + 5 files changed, 197 insertions(+), 129 deletions(-) diff --git a/lib/src/block_tracker.rs b/lib/src/block_tracker.rs index 433fe7065..544a7bbf9 100644 --- a/lib/src/block_tracker.rs +++ b/lib/src/block_tracker.rs @@ -163,33 +163,29 @@ impl Worker { Command::InsertClient { client_id, block_tx, - } => self.handle_insert_client(client_id, block_tx), - Command::RemoveClient { client_id } => self.handle_remove_client(client_id), + } => self.insert_client(client_id, block_tx), + Command::RemoveClient { client_id } => self.remove_client(client_id), Command::InsertOffer { client_id, block_id, state, - } => self.handle_insert_offer(client_id, block_id, state), + } => self.insert_offer(client_id, block_id, state), Command::ApproveOffers { block_id } => self.handle_approve_offers(block_id), - Command::SetRequestMode { mode } => self.handle_set_request_mode(mode), - Command::InsertRequired { block_id } => self.handle_insert_required(block_id), - Command::ClearRequired => self.handle_clear_required(), + Command::SetRequestMode { mode } => self.set_request_mode(mode), + Command::InsertRequired { block_id } => self.insert_required(block_id), + Command::ClearRequired => self.clear_required(), } } - fn handle_insert_client( - &mut self, - client_id: ClientId, - block_tx: mpsc::UnboundedSender, - ) { + fn insert_client(&mut self, client_id: ClientId, block_tx: mpsc::UnboundedSender) { self.clients.insert(client_id, ClientState::new(block_tx)); } - fn handle_remove_client(&mut self, client_id: ClientId) { + fn remove_client(&mut self, client_id: ClientId) { self.clients.remove(&client_id); } - fn handle_set_request_mode(&mut self, mode: BlockRequestMode) { + fn set_request_mode(&mut self, mode: BlockRequestMode) { self.request_mode = mode; match mode { @@ -204,7 +200,7 @@ impl Worker { } } - fn handle_insert_required(&mut self, block_id: BlockId) { + fn insert_required(&mut self, block_id: BlockId) { if !self.required_blocks.insert(block_id) { // already required return; @@ -224,16 +220,11 @@ impl Worker { } } - fn 
handle_clear_required(&mut self) { + fn clear_required(&mut self) { self.required_blocks.clear(); } - fn handle_insert_offer( - &mut self, - client_id: ClientId, - block_id: BlockId, - new_state: BlockOfferState, - ) { + fn insert_offer(&mut self, client_id: ClientId, block_id: BlockId, new_state: BlockOfferState) { let Some(client_state) = self.clients.get_mut(&client_id) else { return; }; diff --git a/lib/src/network/client.rs b/lib/src/network/client.rs index 7fc6cff81..28b55b599 100644 --- a/lib/src/network/client.rs +++ b/lib/src/network/client.rs @@ -367,25 +367,31 @@ impl Inner { } async fn commit_responses(&self, writer: ClientWriter) -> Result<()> { - let event_tx = self.vault.event_tx.clone(); let status = writer - .commit_and_then(move |status| { - // Notify about newly written blocks. - for block_id in &status.new_blocks { - event_tx.send(Payload::BlockReceived(*block_id)); - } + .commit_and_then({ + let committer = self.request_tracker.new_committer(); + let event_tx = self.vault.event_tx.clone(); - // Notify about newly approved snapshots - for branch_id in &status.approved_branches { - event_tx.send(Payload::SnapshotApproved(*branch_id)); - } + move |status| { + committer.commit(); - // Notify about newly rejected snapshots - for branch_id in &status.rejected_branches { - event_tx.send(Payload::SnapshotRejected(*branch_id)); - } + // Notify about newly written blocks. + for block_id in &status.new_blocks { + event_tx.send(Payload::BlockReceived(*block_id)); + } - status + // Notify about newly approved snapshots + for branch_id in &status.approved_branches { + event_tx.send(Payload::SnapshotApproved(*branch_id)); + } + + // Notify about newly rejected snapshots + for branch_id in &status.rejected_branches { + event_tx.send(Payload::SnapshotRejected(*branch_id)); + } + + status + } }) .await?; diff --git a/lib/src/network/request_tracker.rs b/lib/src/network/request_tracker.rs index c298dd5ee..eb74098fd 100644 --- a/lib/src/network/request_tracker.rs +++ b/lib/src/network/request_tracker.rs @@ -100,23 +100,40 @@ impl RequestTrackerClient { .ok(); } - /// Commit all successfully completed requests. - /// - /// If this client is dropped before this is called, all requests successfully completed by this - /// client will be considered failed and will be made available for retry by other clients. - pub fn commit(&self) { + /// Obtain a handle to commit all the requests that were successfully completed by this client. + /// The handle can be sent to other tasks/threads before invoking the commit. + pub fn new_committer(&self) -> RequestTrackerCommitter { + RequestTrackerCommitter { + client_id: self.client_id, + command_tx: self.command_tx.clone(), + } + } +} + +impl Drop for RequestTrackerClient { + fn drop(&mut self) { self.command_tx - .send(Command::Commit { + .send(Command::RemoveClient { client_id: self.client_id, }) .ok(); } } -impl Drop for RequestTrackerClient { - fn drop(&mut self) { +pub(crate) struct RequestTrackerCommitter { + client_id: ClientId, + command_tx: mpsc::UnboundedSender, +} + +impl RequestTrackerCommitter { + /// Commit all successfully completed requests. + /// + /// If the client associated with this committer is dropped before this is called, all requests + /// successfully completed by the client will be considered failed and will be made available + /// for retry by other clients. 
+ pub fn commit(self) { self.command_tx - .send(Command::RemoveClient { + .send(Command::Commit { client_id: self.client_id, }) .ok(); @@ -286,58 +303,49 @@ impl Worker { let node_key = self .clients - .get_mut(&client_id) - .and_then(|state| state.requests.remove(&request_key)); + .get(&client_id) + .and_then(|state| state.requests.get(&request_key)) + .copied(); - let (mut client_ids, remove_key) = if let Some(node_key) = node_key { + let mut client_ids = if let Some(node_key) = node_key { let Some(node) = self.requests.get_mut(node_key) else { return; }; - match node.value_mut() { + let (sender_timer_key, waiters) = match node.value_mut() { RequestState::InFlight { sender_client_id, sender_timer_key, waiters, - } if *sender_client_id == client_id => { - self.timer.try_remove(sender_timer_key); - - let waiters = mem::take(waiters); - let remove_key = if node.children().len() > 0 || !requests.is_empty() { - *node.value_mut() = RequestState::Complete; - None - } else { - Some(node_key) - }; + } if *sender_client_id == client_id => (*sender_timer_key, mem::take(waiters)), + RequestState::InFlight { .. } + | RequestState::Complete { .. } + | RequestState::Committed + | RequestState::Cancelled => return, + }; - (waiters, remove_key) - } - RequestState::InFlight { waiters, .. } => { - remove_from_queue(waiters, &client_id); - return; - } - RequestState::Complete | RequestState::Cancelled => return, - } - } else { - (Default::default(), None) - }; + let client_ids = if requests.is_empty() { + Vec::new() + } else { + iter::once(client_id) + .chain(waiters.iter().cloned()) + .collect() + }; - // Remove the node from the other waiting clients, if any. - for client_id in &client_ids { - if let Some(state) = self.clients.get_mut(client_id) { - state.requests.remove(&request_key); - } - } + *node.value_mut() = RequestState::Complete { + sender_client_id: client_id, + waiters, + }; - // If the node has no children, remove it. - if let Some(node_key) = remove_key { - self.remove_request(node_key); - } + self.timer.try_remove(&sender_timer_key); - // Register the followup (child) requests with this client but also with all the clients that - // were waiting for the original request. - client_ids.push_front(client_id); + client_ids + } else { + vec![client_id] + }; + // Register the followup (child) requests with this client but also with all the clients + // that were waiting for the original request. for child_request in requests { for (client_id, child_request) in // TODO: use `repeat_n` once it gets stabilized. @@ -373,9 +381,57 @@ impl Worker { #[instrument(skip(self))] fn commit(&mut self, client_id: ClientId) { - tracing::trace!("commit_failure"); + tracing::trace!("commit"); + + // Collect all requests completed by this client. + let requests: Vec<_> = self + .clients + .get_mut(&client_id) + .into_iter() + .flat_map(|client_state| client_state.requests.iter()) + .filter(|(_, node_key)| { + self.requests + .get(**node_key) + .map(|node| match node.value() { + RequestState::Complete { + sender_client_id, .. + } if *sender_client_id == client_id => true, + RequestState::Complete { .. } + | RequestState::InFlight { .. } + | RequestState::Committed + | RequestState::Cancelled => false, + }) + .unwrap_or(false) + }) + .map(|(request_key, node_key)| (*request_key, *node_key)) + .collect(); + + for (request_key, node_key) in requests { + let Some(node) = self.requests.get_mut(node_key) else { + unreachable!() + }; + + let waiters = match node.value_mut() { + RequestState::Complete { waiters, .. 
} => mem::take(waiters),
+                RequestState::InFlight { .. }
+                | RequestState::Committed
+                | RequestState::Cancelled => unreachable!(),
+            };
 
-        todo!()
+            // Remove the requests from this client and all the waiters
+            for client_id in iter::once(client_id).chain(waiters) {
+                if let Some(client_state) = self.clients.get_mut(&client_id) {
+                    client_state.requests.remove(&request_key);
+                }
+            }
+
+            // If the node has no children, remove it, otherwise mark it as committed.
+            if node.children().len() == 0 {
+                self.remove_request(node_key);
+            } else {
+                *node.value_mut() = RequestState::Committed;
+            }
+        }
     }
 
     fn insert_request(
@@ -403,11 +459,11 @@ impl Worker {
         let request_key = MessageKey::from(&node.request().request);
 
         match node.value_mut() {
-            RequestState::InFlight { waiters, .. } => {
+            RequestState::InFlight { waiters, .. } | RequestState::Complete { waiters, .. } => {
                 waiters.push_back(client_id);
                 client_state.requests.insert(request_key, node_key);
             }
-            RequestState::Complete => (),
+            RequestState::Committed => (),
             RequestState::Cancelled => {
                 let timer_key = self.timer.insert((client_id, request_key), REQUEST_TIMEOUT);
@@ -438,57 +494,57 @@ impl Worker {
 
         let (request, state) = node.parts_mut();
 
-        match state {
+        let (sender_client_id, sender_timer_key, waiters) = match state {
             RequestState::InFlight {
                 sender_client_id,
                 sender_timer_key,
                 waiters,
-            } => {
-                if *sender_client_id == client_id {
-                    // The removed client is the current sender of this request.
+            } => (sender_client_id, Some(sender_timer_key), waiters),
+            RequestState::Complete {
+                sender_client_id,
+                waiters,
+            } => (sender_client_id, None, waiters),
+            RequestState::Committed | RequestState::Cancelled => return,
+        };
 
-                    // Remove the timeout for the previous sender
-                    self.timer.try_remove(sender_timer_key);
+        if *sender_client_id != client_id {
+            remove_from_queue(waiters, &client_id);
+            return;
+        }
 
-                    // Find a waiting client
-                    let next_client = iter::from_fn(|| waiters.pop_front()).find_map(|client_id| {
-                        self.clients
-                            .get(&client_id)
-                            .map(|client_state| (client_id, client_state))
-                    });
+        if let Some(timer_key) = sender_timer_key {
+            self.timer.try_remove(timer_key);
+        }
 
-                    if let Some((next_client_id, next_client_state)) = next_client {
-                        // Next waiting client found. Promote it to a sender.
+        // Find next waiting client
+        let next_client = iter::from_fn(|| waiters.pop_front())
+            .find_map(|client_id| self.clients.get_key_value(&client_id));
 
-                        *sender_client_id = next_client_id;
-                        *sender_timer_key = self.timer.insert(
-                            (next_client_id, MessageKey::from(&request.request)),
-                            REQUEST_TIMEOUT,
-                        );
+        if let Some((&next_client_id, next_client_state)) = next_client {
+            // Next waiting client found. Promote it to the sender.
+            let sender_client_id = next_client_id;
+            let sender_timer_key = self.timer.insert(
+                (next_client_id, MessageKey::from(&request.request)),
+                REQUEST_TIMEOUT,
+            );
+            let waiters = mem::take(waiters);
 
-                        // Send the request to the new sender.
-                        next_client_state.request_tx.send(request.clone()).ok();
+            *state = RequestState::InFlight {
+                sender_client_id,
+                sender_timer_key,
+                waiters,
+            };
 
-                        return;
-                    } else {
-                        // No waiting client found. If this request has no children, we can remove
-                        // it, otherwise we mark it as cancelled.
-                        if node.children().len() > 0 {
-                            *node.value_mut() = RequestState::Cancelled;
-                            return;
-                        }
-                    }
-                } else {
-                    // The removed client is one of the waiting clients - remove it from the
-                    // waiting queue.
- remove_from_queue(waiters, &client_id); - return; - } + next_client_state.request_tx.send(request.clone()).ok(); + } else { + // No waiting client found. If this request has no children, we can remove + // it, otherwise we mark it as cancelled. + if node.children().len() > 0 { + *node.value_mut() = RequestState::Cancelled; + } else { + self.remove_request(node_key); } - RequestState::Complete | RequestState::Cancelled => return, - }; - - self.remove_request(node_key); + } } fn remove_request(&mut self, node_key: GraphKey) { @@ -571,8 +627,17 @@ enum RequestState { /// timeouts, a new one will be picked from this list. waiters: VecDeque, }, - /// The response to this request has already been received. - Complete, + /// The response to this request has already been received but the request hasn't been committed + /// because the response hasn't been fully processed yet. If the sender client is dropped + /// before this request gets committed, a new sender is picked and the request is switched back + /// to `InFlight`. + Complete { + sender_client_id: ClientId, + waiters: VecDeque, + }, + /// The response to this request has been received and fully processed. This request won't be + /// retried even when the sender client gets dropped. + Committed, /// The response for the current client failed and there are no more clients waiting. Cancelled, } diff --git a/lib/src/network/request_tracker/simulation.rs b/lib/src/network/request_tracker/simulation.rs index d1d1cffcd..c04e144b1 100644 --- a/lib/src/network/request_tracker/simulation.rs +++ b/lib/src/network/request_tracker/simulation.rs @@ -259,6 +259,10 @@ impl TestClient { } Response::BlockOffer(_block_id, _debug_payload) => unimplemented!(), }; + + // Note: for simplicity, in this simulation we `commit` after every operation. To test + // committing properly, separate tests not based on this simulation need to be used. + self.tracker_client.new_committer().commit(); } fn poll_request(&mut self) -> Option { diff --git a/lib/src/network/request_tracker/tests.rs b/lib/src/network/request_tracker/tests.rs index 2fadf53dd..593fc26ee 100644 --- a/lib/src/network/request_tracker/tests.rs +++ b/lib/src/network/request_tracker/tests.rs @@ -21,6 +21,8 @@ use tokio::{sync::mpsc, time}; // needs a tokio runtime. 
 #[tokio::test]
 async fn dynamic_swarm() {
+    // crate::test_utils::init_log();
+
     let seed = rand::random();
 
     case(seed, 64, 4);
 
From c91c08297a64aa04c19bf3748574afb8c5791cfa Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adam=20Cig=C3=A1nek?=
Date: Mon, 30 Sep 2024 16:44:02 +0200
Subject: [PATCH 40/55] Add test for dropping uncommitted RequestTracker
 client

---
 lib/src/network/request_tracker/graph.rs |  1 -
 lib/src/network/request_tracker/tests.rs | 73 +++++++++++++++++++++---
 2 files changed, 66 insertions(+), 8 deletions(-)

diff --git a/lib/src/network/request_tracker/graph.rs b/lib/src/network/request_tracker/graph.rs
index c94661901..1db4fe6d6 100644
--- a/lib/src/network/request_tracker/graph.rs
+++ b/lib/src/network/request_tracker/graph.rs
@@ -116,7 +116,6 @@ pub(super) struct Node<T> {
 }
 
 impl<T> Node<T> {
-    #[cfg_attr(not(test), expect(dead_code))]
     pub fn value(&self) -> &T {
         &self.value
     }
diff --git a/lib/src/network/request_tracker/tests.rs b/lib/src/network/request_tracker/tests.rs
index 593fc26ee..ee0a3401f 100644
--- a/lib/src/network/request_tracker/tests.rs
+++ b/lib/src/network/request_tracker/tests.rs
@@ -13,7 +13,7 @@ use rand::{
     Rng, SeedableRng,
 };
 use std::{pin::pin, time::Duration};
-use tokio::{sync::mpsc, time};
+use tokio::{sync::mpsc::error::TryRecvError, time};
 
 // Test syncing while peers keep joining and leaving the swarm.
 //
@@ -182,15 +182,15 @@ async fn timeout() {
         .await
         .ok();
 
-    // Only the first client gets the send permit.
+    // Only the first client gets the request.
     assert_eq!(
-        request_rx_a.try_recv().map(|permit| permit.request),
+        request_rx_a.try_recv().map(|r| r.request),
         Ok(request.clone())
     );
 
     assert_eq!(
-        request_rx_b.try_recv().map(|permit| permit.request),
-        Err(mpsc::error::TryRecvError::Empty),
+        request_rx_b.try_recv().map(|r| r.request),
+        Err(TryRecvError::Empty)
     );
 
     // Wait until the request timeout passes
@@ -198,13 +198,72 @@ async fn timeout() {
         .await
         .ok();
 
-    // The first client timeouted so the second client now gets the permit.
+    // The first client timed out so the second client now gets the request.
     assert_eq!(
-        request_rx_b.try_recv().map(|permit| permit.request),
+        request_rx_b.try_recv().map(|r| r.request),
         Ok(request.clone())
     );
 }
 
+#[tokio::test]
+async fn drop_uncommitted_client() {
+    let mut rng = StdRng::seed_from_u64(0);
+    let (tracker, mut tracker_worker) = build();
+
+    let (client_a, mut request_rx_a) = tracker.new_client();
+    let (client_b, mut request_rx_b) = tracker.new_client();
+
+    let preceding_request_key = MessageKey::RootNode(PublicKey::generate(&mut rng));
+    let request = Request::ChildNodes(
+        rng.gen(),
+        ResponseDisambiguator::new(MultiBlockPresence::Full),
+        DebugRequest::start(),
+    );
+    let request_key = MessageKey::from(&request);
+
+    for client in [&client_a, &client_b] {
+        client.success(
+            preceding_request_key,
+            vec![PendingRequest {
+                request: request.clone(),
+                block_presence: MultiBlockPresence::Full,
+            }],
+        );
+    }
+
+    tracker_worker.step();
+
+    assert_eq!(
+        request_rx_a.try_recv().map(|r| r.request),
+        Ok(request.clone())
+    );
+    assert_eq!(
+        request_rx_b.try_recv().map(|r| r.request),
+        Err(TryRecvError::Empty)
+    );
+
+    // Complete the request by the first client.
+    client_a.success(request_key, vec![]);
+    tracker_worker.step();
+
+    assert_eq!(
+        request_rx_a.try_recv().map(|r| r.request),
+        Err(TryRecvError::Empty)
+    );
+    assert_eq!(
+        request_rx_b.try_recv().map(|r| r.request),
+        Err(TryRecvError::Empty)
+    );
+
+    // Drop the first client without committing. 
+ drop(client_a); + tracker_worker.step(); + + // The request falls back to the other client because although the request was completed, it + // wasn't committed. + assert_eq!(request_rx_b.try_recv().map(|r| r.request), Ok(request)); +} + /// Generate `count + 1` copies of the same snapshot. The first one will have all the blocks /// present (the "master copy"). The remaining ones will have some blocks missing but in such a /// way that every block is present in at least one of the snapshots. From 0e59ac485b6b119c70ccc3803e7b02c1af1868d2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Cig=C3=A1nek?= Date: Tue, 1 Oct 2024 11:24:36 +0200 Subject: [PATCH 41/55] Fix BlockTracker suppressing repeated offers --- lib/src/block_tracker.rs | 80 +++++++++++++++-------- lib/src/store/block_expiration_tracker.rs | 3 + 2 files changed, 55 insertions(+), 28 deletions(-) diff --git a/lib/src/block_tracker.rs b/lib/src/block_tracker.rs index 544a7bbf9..682a5908d 100644 --- a/lib/src/block_tracker.rs +++ b/lib/src/block_tracker.rs @@ -4,10 +4,10 @@ use crate::{ }; use std::{ collections::hash_map::Entry, - mem, sync::atomic::{AtomicUsize, Ordering}, }; use tokio::{sync::mpsc, task}; +use tracing::{Instrument, Span}; /// Tracks blocks that are offered for requesting from some peers and blocks that are required /// locally. If a block is both, a notification is triggered to prompt us to request the block from @@ -23,7 +23,7 @@ pub(crate) struct BlockTracker { impl BlockTracker { pub fn new() -> Self { let (this, worker) = build(); - task::spawn(worker.run()); + task::spawn(worker.run().instrument(Span::current())); this } @@ -108,7 +108,7 @@ impl Drop for BlockTrackerClient { } } -#[derive(Clone, Copy)] +#[derive(Clone, Copy, Eq, PartialEq)] pub(crate) enum BlockRequestMode { // Request only required blocks Lazy, @@ -116,7 +116,7 @@ pub(crate) enum BlockRequestMode { Greedy, } -#[derive(Clone, Copy)] +#[derive(Clone, Copy, Debug)] pub(crate) enum BlockOfferState { Pending, Approved, @@ -229,30 +229,36 @@ impl Worker { return; }; - let old_state = mem::replace( - client_state - .offers - .entry(block_id) - .or_insert(BlockOfferState::Pending), - new_state, - ); - - if matches!(new_state, BlockOfferState::Pending) { - return; - } - - if matches!(old_state, BlockOfferState::Approved) { - return; - } - - match self.request_mode { - BlockRequestMode::Greedy => { - client_state.block_tx.send(block_id).ok(); - } - BlockRequestMode::Lazy if self.required_blocks.contains(&block_id) => { - client_state.block_tx.send(block_id).ok(); - } - BlockRequestMode::Lazy => (), + match client_state.offers.entry(block_id) { + Entry::Occupied(mut entry) => match (entry.get(), new_state) { + (BlockOfferState::Pending, BlockOfferState::Approved) => { + if self.request_mode == BlockRequestMode::Greedy + || self.required_blocks.contains(&block_id) + { + entry.remove(); + client_state.block_tx.send(block_id).ok(); + } else { + entry.insert(BlockOfferState::Approved); + } + } + (BlockOfferState::Pending, BlockOfferState::Pending) + | (BlockOfferState::Approved, BlockOfferState::Pending) + | (BlockOfferState::Approved, BlockOfferState::Approved) => (), + }, + Entry::Vacant(entry) => match new_state { + BlockOfferState::Pending => { + entry.insert(BlockOfferState::Pending); + } + BlockOfferState::Approved => { + if self.request_mode == BlockRequestMode::Greedy + || self.required_blocks.contains(&block_id) + { + client_state.block_tx.send(block_id).ok(); + } else { + entry.insert(BlockOfferState::Approved); + } + } + }, } } @@ -650,7 
+656,25 @@ mod tests {
         let (client, mut block_rx) = tracker.new_client();
 
         client.offer(block_id, BlockOfferState::Approved);
+        worker.step();
 
         assert_eq!(block_rx.try_recv(), Err(TryRecvError::Empty));
     }
+
+    #[test]
+    fn repeated_offer() {
+        let (tracker, mut worker) = build();
+        let block_id = BlockId::try_from([0; BlockId::SIZE].as_ref()).unwrap();
+
+        let (client, mut block_rx) = tracker.new_client();
+        client.offer(block_id, BlockOfferState::Approved);
+        worker.step();
+
+        assert_eq!(block_rx.try_recv(), Ok(block_id));
+
+        client.offer(block_id, BlockOfferState::Approved);
+        worker.step();
+
+        assert_eq!(block_rx.try_recv(), Ok(block_id));
+    }
 }
diff --git a/lib/src/store/block_expiration_tracker.rs b/lib/src/store/block_expiration_tracker.rs
index 1bab8d956..4eceab8a6 100644
--- a/lib/src/store/block_expiration_tracker.rs
+++ b/lib/src/store/block_expiration_tracker.rs
@@ -395,6 +395,9 @@ async fn set_as_missing_if_expired(
             continue;
         }
 
+        // TODO: Why is this here? For blind replicas requiring blocks is never necessary because
+        // the block tracker produces block ids immediately after they become offered by peers.
+        // For non-blind replicas blocks are already required in the `scan` job.
         block_download_tracker.require(*block_id);
 
         for (hash, _state) in index::update_summaries(&mut tx, parent_hashes).await? {
 
From 016b7c10f579fde1258b73bda1e19278d13714de Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adam=20Cig=C3=A1nek?=
Date: Tue, 1 Oct 2024 14:59:22 +0200
Subject: [PATCH 42/55] Do not suppress followups to duplicate responses

---
 lib/src/network/request_tracker.rs       | 27 +++++++------
 lib/src/network/request_tracker/tests.rs | 48 ++++++++++++++++++++++++
 lib/src/network/tests.rs                 |  7 +---
 3 files changed, 65 insertions(+), 17 deletions(-)

diff --git a/lib/src/network/request_tracker.rs b/lib/src/network/request_tracker.rs
index eb74098fd..b276633f0 100644
--- a/lib/src/network/request_tracker.rs
+++ b/lib/src/network/request_tracker.rs
@@ -312,32 +312,37 @@ impl Worker {
             return;
         };
 
-        let (sender_timer_key, waiters) = match node.value_mut() {
+        let waiters = match node.value_mut() {
             RequestState::InFlight {
                 sender_client_id,
                 sender_timer_key,
                 waiters,
-            } if *sender_client_id == client_id => (*sender_timer_key, mem::take(waiters)),
+            } if *sender_client_id == client_id => {
+                self.timer.try_remove(sender_timer_key);
+                Some(mem::take(waiters))
+            }
             RequestState::InFlight { .. }
             | RequestState::Complete { .. }
             | RequestState::Committed
-            | RequestState::Cancelled => return,
+            | RequestState::Cancelled => None,
         };
 
         let client_ids = if requests.is_empty() {
             Vec::new()
         } else {
             iter::once(client_id)
-                .chain(waiters.iter().cloned())
+                .chain(waiters.as_ref().into_iter().flatten().copied())
                 .collect()
         };
 
-        *node.value_mut() = RequestState::Complete {
-            sender_client_id: client_id,
-            waiters,
-        };
-
-        self.timer.try_remove(&sender_timer_key);
+        // If the request was `InFlight` from this client, switch it to `Complete`. Otherwise
+        // keep it as is. 
+        if let Some(waiters) = waiters {
+            *node.value_mut() = RequestState::Complete {
+                sender_client_id: client_id,
+                waiters,
+            };
+        }
 
         client_ids
     } else {
@@ -408,7 +413,7 @@ impl Worker {
 
         for (request_key, node_key) in requests {
             let Some(node) = self.requests.get_mut(node_key) else {
-                unreachable!()
+                continue;
             };
 
             let waiters = match node.value_mut() {
diff --git a/lib/src/network/request_tracker/tests.rs b/lib/src/network/request_tracker/tests.rs
index ee0a3401f..4e7683158 100644
--- a/lib/src/network/request_tracker/tests.rs
+++ b/lib/src/network/request_tracker/tests.rs
@@ -6,6 +6,7 @@ use crate::{
         Block,
     },
 };
+use assert_matches::assert_matches;
 use rand::{
     distributions::{Bernoulli, Distribution, Standard},
     rngs::StdRng,
@@ -264,6 +265,53 @@ async fn drop_uncommitted_client() {
     assert_eq!(request_rx_b.try_recv().map(|r| r.request), Ok(request));
 }
 
+#[tokio::test]
+async fn multiple_responses_to_identical_requests() {
+    let mut rng = StdRng::seed_from_u64(0);
+    let (tracker, mut worker) = build();
+    let (client, mut request_rx) = tracker.new_client();
+
+    let initial_request = Request::RootNode(PublicKey::generate(&mut rng), DebugRequest::start());
+    let followup_request = Request::ChildNodes(
+        rng.gen(),
+        ResponseDisambiguator::new(MultiBlockPresence::Full),
+        DebugRequest::start(),
+    );
+
+    // Send initial root node request
+    client.initial(initial_request.clone());
+    worker.step();
+
+    assert_matches!(request_rx.try_recv(), Ok(_));
+
+    // Receive response to it
+    client.success(MessageKey::from(&initial_request), vec![]);
+    worker.step();
+
+    // do not commit yet
+
+    // Receive another response, this time unsolicited, which has the same key but different
+    // followups than the one received previously.
+    client.success(
+        MessageKey::from(&initial_request),
+        vec![PendingRequest {
+            request: followup_request.clone(),
+            block_presence: MultiBlockPresence::Full,
+        }],
+    );
+    worker.step();
+
+    // The followup requests are sent even though the initial request was already completed.
+    assert_eq!(
+        request_rx.try_recv().map(|r| r.request),
+        Ok(followup_request)
+    );
+
+    // TODO: test these cases as well:
+    // - the initial request gets committed, but remains tracked because it has in-flight followups.
+    // - the responses are received by different clients
+}
+
 /// Generate `count + 1` copies of the same snapshot. The first one will have all the blocks
 /// present (the "master copy"). The remaining ones will have some blocks missing but in such a
 /// way that every block is present in at least one of the snapshots. 
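NOTE: The tests above exercise the two-phase completion introduced earlier in this series: a
response moves a request from `InFlight` to `Complete`, and only an explicit commit makes it
`Committed`. The following is a minimal standalone sketch of that distinction, using simplified
types only; it is not the actual code in request_tracker.rs, which additionally tracks timers
and a graph of follow-up requests.

use std::collections::VecDeque;

type ClientId = usize;

#[derive(Debug, PartialEq)]
enum State {
    // The request was sent by `sender`; a response is pending.
    InFlight { sender: ClientId, waiters: VecDeque<ClientId> },
    // A response arrived but wasn't committed yet; the sender and the waiting
    // clients stay tracked so the request can still be retried.
    Complete { sender: ClientId, waiters: VecDeque<ClientId> },
    // The response was fully processed; the request is never retried.
    Committed,
}

// Dropping the sender of an uncommitted request promotes the next waiter and
// switches the request back to in-flight; a committed request is unaffected.
// Returns `None` when nobody is left to retry with. (The real code also
// removes the dropped client from the waiter queues.)
fn on_sender_dropped(state: State, client: ClientId) -> Option<State> {
    match state {
        State::InFlight { sender, mut waiters } | State::Complete { sender, mut waiters }
            if sender == client =>
        {
            let next = waiters.pop_front()?;
            Some(State::InFlight { sender: next, waiters })
        }
        other => Some(other),
    }
}

fn main() {
    // Client 0 completed the request but was dropped before committing:
    // client 1 takes over and the request is re-sent, which is what the
    // `drop_uncommitted_client` test asserts.
    let state = State::Complete { sender: 0, waiters: VecDeque::from([1]) };
    assert_eq!(
        on_sender_dropped(state, 0),
        Some(State::InFlight { sender: 1, waiters: VecDeque::new() })
    );
}

The design choice here is that "received" and "processed" are distinct events: only the commit
(issued after the response was durably applied) ends the request's lifetime.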
diff --git a/lib/src/network/tests.rs b/lib/src/network/tests.rs index 494f717a0..cdddbd010 100644 --- a/lib/src/network/tests.rs +++ b/lib/src/network/tests.rs @@ -91,6 +91,7 @@ async fn transfer_snapshot_between_two_replicas_case( wait_until_snapshots_in_sync(&a_vault, a_id, &b_vault).await; if remaining_changesets > 0 { + tracing::info!("create changeset"); create_changeset(&mut rng, &a_vault, &a_id, &write_keys, changeset_size).await; remaining_changesets -= 1; } else { @@ -116,12 +117,6 @@ async fn transfer_blocks_between_two_replicas( transfer_blocks_between_two_replicas_case(block_count, rng_seed).await } -// #[tokio::test] -// async fn debug() { -// test_utils::init_log(); -// transfer_blocks_between_two_replicas_case(1, 0).await -// } - async fn transfer_blocks_between_two_replicas_case(block_count: usize, rng_seed: u64) { let mut rng = StdRng::seed_from_u64(rng_seed); From 1d7fd25773223c30d9a1178ea00814be33f57302 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Cig=C3=A1nek?= Date: Wed, 2 Oct 2024 12:05:38 +0200 Subject: [PATCH 43/55] Fix failed block requests not being retried in some cases --- lib/src/network/client.rs | 90 ++++-- lib/src/network/request_tracker.rs | 296 +++++++++++++----- lib/src/network/request_tracker/graph.rs | 21 ++ lib/src/network/request_tracker/simulation.rs | 28 +- lib/src/network/request_tracker/tests.rs | 51 ++- 5 files changed, 349 insertions(+), 137 deletions(-) diff --git a/lib/src/network/client.rs b/lib/src/network/client.rs index 28b55b599..30a4edeee 100644 --- a/lib/src/network/client.rs +++ b/lib/src/network/client.rs @@ -9,7 +9,10 @@ use crate::{ crypto::{sign::PublicKey, CacheHash, Hashable}, error::Result, event::Payload, - network::{message::Request, request_tracker::MessageKey}, + network::{ + message::Request, + request_tracker::{CandidateRequest, MessageKey}, + }, protocol::{ Block, BlockId, InnerNodes, LeafNodes, MultiBlockPresence, NodeState, ProofError, RootNodeFilter, UntrustedProof, @@ -249,19 +252,21 @@ impl Inner { tracing::debug!("Received root node - {status}"); - if status.request_children() { - self.request_tracker.success( - MessageKey::RootNode(writer_id), - vec![PendingRequest { - request: Request::ChildNodes( + self.request_tracker.success( + MessageKey::RootNode(writer_id), + status + .request_children() + .then_some( + CandidateRequest::new(Request::ChildNodes( hash, ResponseDisambiguator::new(block_presence), debug_payload.follow_up(), - ), - block_presence, - }], - ); - } + )) + .follow_up(block_presence), + ) + .into_iter() + .collect(), + ); Ok(()) } @@ -288,13 +293,13 @@ impl Inner { status .new_children .into_iter() - .map(|node| PendingRequest { - request: Request::ChildNodes( + .map(|node| { + CandidateRequest::new(Request::ChildNodes( node.hash, ResponseDisambiguator::new(node.summary.block_presence), debug_payload.follow_up(), - ), - block_presence: node.summary.block_presence, + )) + .follow_up(node.summary.block_presence) }) .collect(), ); @@ -319,14 +324,31 @@ impl Inner { total, ); - for (block_id, state) in status.new_block_offers { - if let Some(state) = block_offer_state(state) { - self.block_tracker.offer(block_id, state); - } - } + // IMPORTANT: Make sure the request tracker is processed before the block tracker to ensure + // the request is first inserted and only then resumed. 
-        self.request_tracker
-            .success(MessageKey::ChildNodes(hash), Vec::new());
+        let offers: Vec<_> = status
+            .new_block_offers
+            .into_iter()
+            .filter_map(|(block_id, root_node_state)| {
+                block_offer_state(root_node_state).map(move |offer_state| (block_id, offer_state))
+            })
+            .collect();
+
+        self.request_tracker.success(
+            MessageKey::ChildNodes(hash),
+            offers
+                .iter()
+                .map(|(block_id, _)| {
+                    CandidateRequest::new(Request::Block(*block_id, debug_payload.follow_up()))
+                        .suspended()
+                })
+                .collect(),
+        );
+
+        for (block_id, offer_state) in offers {
+            self.block_tracker.offer(block_id, offer_state);
+        }
 
         Ok(())
     }
@@ -342,9 +364,18 @@ impl Inner {
             .load_effective_root_node_state_for_block(&block_id)
             .await?;
 
-        if let Some(offer_state) = block_offer_state(root_node_state) {
-            self.block_tracker.offer(block_id, offer_state);
-        }
+        let Some(offer_state) = block_offer_state(root_node_state) else {
+            return Ok(());
+        };
+
+        // IMPORTANT: Make sure the request tracker is processed before the block tracker to ensure
+        // the request is first inserted and only then resumed.
+
+        self.request_tracker.initial(
+            CandidateRequest::new(Request::Block(block_id, debug_payload.follow_up())).suspended(),
+        );
+
+        self.block_tracker.offer(block_id, offer_state);
 
         Ok(())
     }
@@ -409,7 +440,7 @@ impl Inner {
     async fn request_blocks(&self, block_rx: &mut mpsc::UnboundedReceiver<BlockId>) {
         while let Some(block_id) = block_rx.recv().await {
             self.request_tracker
-                .initial(Request::Block(block_id, DebugRequest::start()));
+                .resume(MessageKey::Block(block_id), MultiBlockPresence::None);
         }
     }
 
@@ -441,9 +472,12 @@ impl Inner {
     // By requesting the root node again immediately, we ensure that the missing block is
     // requested as soon as possible.
     fn refresh_branches(&self, branches: impl IntoIterator<Item = PublicKey>) {
-        for branch_id in branches {
+        for writer_id in branches {
             self.request_tracker
-                .initial(Request::RootNode(branch_id, DebugRequest::start()));
+                .initial(CandidateRequest::new(Request::RootNode(
+                    writer_id,
+                    DebugRequest::start(),
+                )));
         }
     }
 
diff --git a/lib/src/network/request_tracker.rs b/lib/src/network/request_tracker.rs
index b276633f0..7003399c0 100644
--- a/lib/src/network/request_tracker.rs
+++ b/lib/src/network/request_tracker.rs
@@ -13,14 +13,14 @@ use crate::{
     protocol::{BlockId, MultiBlockPresence},
 };
 use std::{
-    collections::VecDeque,
+    collections::{hash_map::Entry, VecDeque},
     iter, mem,
     sync::atomic::{AtomicUsize, Ordering},
 };
 use tokio::{select, sync::mpsc, task};
 use tokio_stream::StreamExt;
 use tokio_util::time::{delay_queue, DelayQueue};
-use tracing::instrument;
+use tracing::{instrument, Instrument, Span};
 
 /// Keeps track of in-flight requests. Falls back on another peer in case the request failed (due to
 /// error response, timeout or disconnection). Evenly distributes the requests between the peers
@@ -31,9 +31,10 @@ pub(super) struct RequestTracker {
 }
 
 impl RequestTracker {
+    // TODO: Make request timeout configurable
    pub fn new() -> Self {
        let (this, worker) = build();
-        task::spawn(worker.run());
+        task::spawn(worker.run().instrument(Span::current()));
        this
    }
 
@@ -70,7 +71,7 @@ pub(super) struct RequestTrackerClient {
 
 impl RequestTrackerClient {
     /// Handle sending a request that does not follow from any previously received response. 
-    pub fn initial(&self, request: Request) {
+    pub fn initial(&self, request: CandidateRequest) {
         self.command_tx
             .send(Command::HandleInitial {
                 client_id: self.client_id,
@@ -80,7 +81,7 @@ impl RequestTrackerClient {
     }
 
     /// Handle sending requests that follow from a received success response.
-    pub fn success(&self, request_key: MessageKey, requests: Vec<PendingRequest>) {
+    pub fn success(&self, request_key: MessageKey, requests: Vec<CandidateRequest>) {
         self.command_tx
             .send(Command::HandleSuccess {
                 client_id: self.client_id,
@@ -100,6 +101,16 @@ impl RequestTrackerClient {
             .ok();
     }
 
+    /// Resume suspended request.
+    pub fn resume(&self, request_key: MessageKey, block_presence: MultiBlockPresence) {
+        self.command_tx
+            .send(Command::Resume {
+                request_key,
+                block_presence,
+            })
+            .ok();
+    }
+
     /// Obtain a handle to commit all the requests that were successfully completed by this client.
     /// The handle can be sent to other tasks/threads before invoking the commit.
     pub fn new_committer(&self) -> RequestTrackerCommitter {
@@ -140,7 +151,48 @@ impl RequestTrackerCommitter {
     }
 }
 
-/// Request to be sent to the peer.
+/// Request that we want to send to the peer.
+#[derive(Clone, Debug)]
+pub(super) struct CandidateRequest {
+    pub request: Request,
+    pub block_presence: MultiBlockPresence,
+    pub state: InitialRequestState,
+}
+
+impl CandidateRequest {
+    /// Create new candidate for the given request.
+    pub fn new(request: Request) -> Self {
+        Self {
+            request,
+            block_presence: MultiBlockPresence::None,
+            state: InitialRequestState::InFlight,
+        }
+    }
+
+    /// Set this candidate to a followup request to a response with the given block presence.
+    pub fn follow_up(self, block_presence: MultiBlockPresence) -> Self {
+        Self {
+            block_presence,
+            ..self
+        }
+    }
+
+    /// Start this request in suspended state.
+    pub fn suspended(self) -> Self {
+        Self {
+            state: InitialRequestState::Suspended,
+            ..self
+        }
+    }
+}
+
+#[derive(Clone, Copy, Debug)]
+pub(super) enum InitialRequestState {
+    InFlight,
+    Suspended,
+}
+
+/// Request that is ready to be sent to the peer.
 ///
 /// It also contains the block presence from the response that triggered this request. This is
 /// mostly useful for diagnostics and testing.
@@ -248,6 +300,10 @@ impl Worker {
                 } => {
                     self.handle_failure(client_id, request_key, FailureReason::Response);
                 }
+                Command::Resume {
+                    request_key,
+                    block_presence,
+                } => self.resume(request_key, block_presence),
                 Command::Commit { client_id } => {
                     self.commit(client_id);
                 }
@@ -279,17 +335,10 @@ impl Worker {
     }
 
     #[instrument(skip(self))]
-    fn handle_initial(&mut self, client_id: ClientId, request: Request) {
+    fn handle_initial(&mut self, client_id: ClientId, request: CandidateRequest) {
         tracing::trace!("handle_initial");
 
-        self.insert_request(
-            client_id,
-            PendingRequest {
-                request,
-                block_presence: MultiBlockPresence::None,
-            },
-            None,
-        )
+        self.insert_request(client_id, request, None)
     }
 
     #[instrument(skip(self))]
@@ -297,7 +346,7 @@ impl Worker {
         &mut self,
         client_id: ClientId,
         request_key: MessageKey,
-        requests: Vec<PendingRequest>,
+        requests: Vec<CandidateRequest>,
     ) {
         tracing::trace!("handle_success");
 
@@ -322,6 +371,7 @@ impl Worker {
                 Some(mem::take(waiters))
             }
             RequestState::InFlight { .. }
+            | RequestState::Suspended { .. }
             | RequestState::Complete { .. 
} | RequestState::Committed | RequestState::Cancelled => None, @@ -384,6 +434,45 @@ impl Worker { self.cancel_request(client_id, node_key); } + #[instrument(skip(self))] + fn resume(&mut self, request_key: MessageKey, block_presence: MultiBlockPresence) { + let Some(node) = self.requests.lookup_mut(request_key, block_presence) else { + return; + }; + + let (sender_client_id, sender_client_state, waiters) = match node.value_mut() { + RequestState::Suspended { waiters } => { + let Some(client_id) = waiters.pop_front() else { + return; + }; + + let Some(client_state) = self.clients.get(&client_id) else { + return; + }; + + (client_id, client_state, mem::take(waiters)) + } + RequestState::InFlight { .. } + | RequestState::Complete { .. } + | RequestState::Committed + | RequestState::Cancelled => return, + }; + + let sender_timer_key = self + .timer + .insert((sender_client_id, request_key), REQUEST_TIMEOUT); + sender_client_state + .request_tx + .send(node.request().clone()) + .ok(); + + *node.value_mut() = RequestState::InFlight { + sender_client_id, + sender_timer_key, + waiters, + }; + } + #[instrument(skip(self))] fn commit(&mut self, client_id: ClientId) { tracing::trace!("commit"); @@ -402,6 +491,7 @@ impl Worker { sender_client_id, .. } if *sender_client_id == client_id => true, RequestState::Complete { .. } + | RequestState::Suspended { .. } | RequestState::InFlight { .. } | RequestState::Committed | RequestState::Cancelled => false, @@ -418,7 +508,8 @@ impl Worker { let waiters = match node.value_mut() { RequestState::Complete { waiters, .. } => mem::take(waiters), - RequestState::InFlight { .. } + RequestState::Suspended { .. } + | RequestState::InFlight { .. } | RequestState::Committed | RequestState::Cancelled => unreachable!(), }; @@ -442,17 +533,27 @@ impl Worker { fn insert_request( &mut self, client_id: ClientId, - request: PendingRequest, + request: CandidateRequest, parent_key: Option, ) { - let node_key = self - .requests - .get_or_insert(request, parent_key, RequestState::Cancelled); + let node_key = self.requests.get_or_insert( + PendingRequest { + request: request.request, + block_presence: request.block_presence, + }, + parent_key, + RequestState::Cancelled, + ); - self.update_request(client_id, node_key); + self.update_request(client_id, node_key, request.state); } - fn update_request(&mut self, client_id: ClientId, node_key: GraphKey) { + fn update_request( + &mut self, + client_id: ClientId, + node_key: GraphKey, + initial_state: InitialRequestState, + ) { let Some(node) = self.requests.get_mut(node_key) else { return; }; @@ -463,23 +564,35 @@ impl Worker { let request_key = MessageKey::from(&node.request().request); - match node.value_mut() { - RequestState::InFlight { waiters, .. } | RequestState::Complete { waiters, .. } => { - waiters.push_back(client_id); - client_state.requests.insert(request_key, node_key); - } - RequestState::Committed => (), - RequestState::Cancelled => { - let timer_key = self.timer.insert((client_id, request_key), REQUEST_TIMEOUT); - - *node.value_mut() = RequestState::InFlight { - sender_client_id: client_id, - sender_timer_key: timer_key, - waiters: VecDeque::new(), - }; - - client_state.requests.insert(request_key, node_key); - client_state.request_tx.send(node.request().clone()).ok(); + if let Entry::Vacant(entry) = client_state.requests.entry(request_key) { + match node.value_mut() { + RequestState::Suspended { waiters } + | RequestState::InFlight { waiters, .. } + | RequestState::Complete { waiters, .. 
} => { + waiters.push_back(client_id); + entry.insert(node_key); + } + RequestState::Committed => (), + RequestState::Cancelled => { + *node.value_mut() = match initial_state { + InitialRequestState::InFlight => { + let timer_key = + self.timer.insert((client_id, request_key), REQUEST_TIMEOUT); + client_state.request_tx.send(node.request().clone()).ok(); + + RequestState::InFlight { + sender_client_id: client_id, + sender_timer_key: timer_key, + waiters: VecDeque::new(), + } + } + InitialRequestState::Suspended => RequestState::Suspended { + waiters: [client_id].into(), + }, + }; + + entry.insert(node_key); + } } } @@ -488,7 +601,7 @@ impl Worker { let children: Vec<_> = node.children().collect(); for child_key in children { - self.update_request(client_id, child_key); + self.update_request(client_id, child_key, initial_state); } } @@ -499,56 +612,75 @@ impl Worker { let (request, state) = node.parts_mut(); - let (sender_client_id, sender_timer_key, waiters) = match state { + let waiters = match state { + RequestState::Suspended { waiters } => { + remove_from_queue(waiters, &client_id); + + if !waiters.is_empty() { + return; + } + + None + } RequestState::InFlight { sender_client_id, sender_timer_key, waiters, - } => (sender_client_id, Some(sender_timer_key), waiters), + } => { + if *sender_client_id == client_id { + self.timer.try_remove(sender_timer_key); + Some(waiters) + } else { + remove_from_queue(waiters, &client_id); + return; + } + } RequestState::Complete { sender_client_id, waiters, - } => (sender_client_id, None, waiters), + } => { + if *sender_client_id == client_id { + Some(waiters) + } else { + remove_from_queue(waiters, &client_id); + return; + } + } RequestState::Committed | RequestState::Cancelled => return, }; - if *sender_client_id != client_id { - remove_from_queue(waiters, &client_id); - return; - } + // Find next waiting client + if let Some(waiters) = waiters { + let next_client = iter::from_fn(|| waiters.pop_front()) + .find_map(|client_id| self.clients.get_key_value(&client_id)); + + if let Some((&next_client_id, next_client_state)) = next_client { + // Next waiting client found. Promote it to the sender. + let sender_client_id = next_client_id; + let sender_timer_key = self.timer.insert( + (next_client_id, MessageKey::from(&request.request)), + REQUEST_TIMEOUT, + ); + let waiters = mem::take(waiters); + + *state = RequestState::InFlight { + sender_client_id, + sender_timer_key, + waiters, + }; - if let Some(timer_key) = sender_timer_key { - self.timer.try_remove(timer_key); - } + next_client_state.request_tx.send(request.clone()).ok(); - // Find next waiting client - let next_client = iter::from_fn(|| waiters.pop_front()) - .find_map(|client_id| self.clients.get_key_value(&client_id)); - - if let Some((&next_client_id, next_client_state)) = next_client { - // Next waiting client found. Promote it to the sender. - let sender_client_id = next_client_id; - let sender_timer_key = self.timer.insert( - (next_client_id, MessageKey::from(&request.request)), - REQUEST_TIMEOUT, - ); - let waiters = mem::take(waiters); - - *state = RequestState::InFlight { - sender_client_id, - sender_timer_key, - waiters, - }; + return; + } + } - next_client_state.request_tx.send(request.clone()).ok(); + // No waiting client found. If this request has no children, we can remove + // it, otherwise we mark it as cancelled. + if node.children().len() > 0 { + *node.value_mut() = RequestState::Cancelled; } else { - // No waiting client found. 
If this request has no children, we can remove - // it, otherwise we mark it as cancelled. - if node.children().len() > 0 { - *node.value_mut() = RequestState::Cancelled; - } else { - self.remove_request(node_key); - } + self.remove_request(node_key); } } @@ -591,17 +723,21 @@ enum Command { }, HandleInitial { client_id: ClientId, - request: Request, + request: CandidateRequest, }, HandleSuccess { client_id: ClientId, request_key: MessageKey, - requests: Vec, + requests: Vec, }, HandleFailure { client_id: ClientId, request_key: MessageKey, }, + Resume { + request_key: MessageKey, + block_presence: MultiBlockPresence, + }, Commit { client_id: ClientId, }, @@ -622,6 +758,8 @@ impl ClientState { } enum RequestState { + /// This request is ready to be sent. + Suspended { waiters: VecDeque }, /// This request is currently in flight InFlight { /// Client who's sending this request diff --git a/lib/src/network/request_tracker/graph.rs b/lib/src/network/request_tracker/graph.rs index 1db4fe6d6..e1010759d 100644 --- a/lib/src/network/request_tracker/graph.rs +++ b/lib/src/network/request_tracker/graph.rs @@ -72,6 +72,27 @@ impl Graph { self.nodes.get_mut(key.0) } + #[expect(unused)] + pub fn lookup( + &self, + request_key: MessageKey, + block_presence: MultiBlockPresence, + ) -> Option<&Node> { + self.index + .get(&(request_key, block_presence)) + .and_then(|key| self.nodes.get(key.0)) + } + + pub fn lookup_mut( + &mut self, + request_key: MessageKey, + block_presence: MultiBlockPresence, + ) -> Option<&mut Node> { + self.index + .get(&(request_key, block_presence)) + .and_then(|key| self.nodes.get_mut(key.0)) + } + pub fn remove(&mut self, key: Key) -> Option> { let node = self.nodes.try_remove(key.0)?; diff --git a/lib/src/network/request_tracker/simulation.rs b/lib/src/network/request_tracker/simulation.rs index c04e144b1..2bbdf19aa 100644 --- a/lib/src/network/request_tracker/simulation.rs +++ b/lib/src/network/request_tracker/simulation.rs @@ -1,6 +1,6 @@ use super::{ super::message::{Request, Response, ResponseDisambiguator}, - MessageKey, PendingRequest, RequestTracker, RequestTrackerClient, + CandidateRequest, MessageKey, PendingRequest, RequestTracker, RequestTrackerClient, }; use crate::{ collections::{HashMap, HashSet}, @@ -192,14 +192,14 @@ impl TestClient { Response::RootNode(proof, block_presence, debug_payload) => { let requests = snapshot .insert_root(proof.hash, block_presence) - .then_some(PendingRequest { - request: Request::ChildNodes( + .then_some( + CandidateRequest::new(Request::ChildNodes( proof.hash, ResponseDisambiguator::new(block_presence), debug_payload.follow_up(), - ), - block_presence, - }) + )) + .follow_up(block_presence), + ) .into_iter() .collect(); @@ -212,13 +212,13 @@ impl TestClient { let requests: Vec<_> = nodes .into_iter() - .map(|(_, node)| PendingRequest { - request: Request::ChildNodes( + .map(|(_, node)| { + CandidateRequest::new(Request::ChildNodes( node.hash, ResponseDisambiguator::new(node.summary.block_presence), debug_payload.follow_up(), - ), - block_presence: node.summary.block_presence, + )) + .follow_up(node.summary.block_presence) }) .collect(); @@ -230,9 +230,11 @@ impl TestClient { let nodes = snapshot.insert_leaves(nodes); let requests = nodes .into_iter() - .map(|node| PendingRequest { - request: Request::Block(node.block_id, debug_payload.follow_up()), - block_presence: MultiBlockPresence::None, + .map(|node| { + CandidateRequest::new(Request::Block( + node.block_id, + debug_payload.follow_up(), + )) }) .collect(); diff --git 
a/lib/src/network/request_tracker/tests.rs b/lib/src/network/request_tracker/tests.rs index 4e7683158..6e27981bb 100644 --- a/lib/src/network/request_tracker/tests.rs +++ b/lib/src/network/request_tracker/tests.rs @@ -165,18 +165,12 @@ async fn timeout() { // Register the request with both clients. client_a.success( preceding_request_key, - vec![PendingRequest { - request: request.clone(), - block_presence: MultiBlockPresence::Full, - }], + vec![CandidateRequest::new(request.clone()).follow_up(MultiBlockPresence::Full)], ); client_b.success( preceding_request_key, - vec![PendingRequest { - request: request.clone(), - block_presence: MultiBlockPresence::Full, - }], + vec![CandidateRequest::new(request.clone()).follow_up(MultiBlockPresence::Full)], ); time::timeout(Duration::from_millis(1), &mut work) @@ -225,10 +219,7 @@ async fn drop_uncommitted_client() { for client in [&client_a, &client_b] { client.success( preceding_request_key, - vec![PendingRequest { - request: request.clone(), - block_presence: MultiBlockPresence::Full, - }], + vec![CandidateRequest::new(request.clone()).follow_up(MultiBlockPresence::Full)], ); } @@ -279,7 +270,7 @@ async fn multiple_responses_to_identical_requests() { ); // Send initial root node request - client.initial(initial_request.clone()); + client.initial(CandidateRequest::new(initial_request.clone())); worker.step(); assert_matches!(request_rx.try_recv(), Ok(_)); @@ -294,10 +285,7 @@ async fn multiple_responses_to_identical_requests() { // followups than the one received previously. client.success( MessageKey::from(&initial_request), - vec![PendingRequest { - request: followup_request.clone(), - block_presence: MultiBlockPresence::Full, - }], + vec![CandidateRequest::new(followup_request.clone()).follow_up(MultiBlockPresence::Full)], ); worker.step(); @@ -312,6 +300,35 @@ async fn multiple_responses_to_identical_requests() { // - the responses are received by different clients } +#[tokio::test] +async fn suspend_resume() { + let mut rng = StdRng::seed_from_u64(0); + let (tracker, mut worker) = build(); + + let (client, mut request_rx) = tracker.new_client(); + worker.step(); + + let preceding_request_key = MessageKey::ChildNodes(rng.gen()); + let request = Request::Block(rng.gen(), DebugRequest::start()); + let request_key = MessageKey::from(&request); + + client.success( + preceding_request_key, + vec![CandidateRequest::new(request.clone()).suspended()], + ); + worker.step(); + + assert_eq!( + request_rx.try_recv().map(|r| r.request), + Err(TryRecvError::Empty) + ); + + client.resume(request_key, MultiBlockPresence::None); + worker.step(); + + assert_eq!(request_rx.try_recv().map(|r| r.request), Ok(request)); +} + /// Generate `count + 1` copies of the same snapshot. The first one will have all the blocks /// present (the "master copy"). The remaining ones will have some blocks missing but in such a /// way that every block is present in at least one of the snapshots. 
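NOTE: The patch above starts block requests in the suspended state and resumes them only after
the block tracker approves the offer, hence the two "IMPORTANT" comments about processing the
request tracker before the block tracker. Below is a minimal standalone sketch of that ordering
pitfall; `Tracker`, `Key` and `Payload` are hypothetical stand-ins, not the crate's actual types.

use std::collections::HashMap;

type Key = &'static str;
type Payload = &'static str;

#[derive(Default)]
struct Tracker {
    suspended: HashMap<Key, Payload>,
}

impl Tracker {
    // Register a request in the suspended state (cf. CandidateRequest::suspended).
    fn insert_suspended(&mut self, key: Key, payload: Payload) {
        self.suspended.insert(key, payload);
    }

    // Resume a suspended request, returning the payload to send.
    // Resuming a key that was never inserted is a silent no-op.
    fn resume(&mut self, key: Key) -> Option<Payload> {
        self.suspended.remove(key)
    }
}

fn main() {
    let mut tracker = Tracker::default();

    // Wrong order: the resume arrives before the insert and is lost for good,
    // so the block would never be requested.
    assert_eq!(tracker.resume("block-1"), None);

    // Right order: insert the suspended request first, then resume it.
    tracker.insert_suspended("block-1", "Request::Block(block-1)");
    assert_eq!(tracker.resume("block-1"), Some("Request::Block(block-1)"));
}

This is why handle_block_offer calls request_tracker.initial(...) before block_tracker.offer(...).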
From 60f2ae2087841dea941e35b7d502a71321c5ddde Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Cig=C3=A1nek?= Date: Wed, 2 Oct 2024 13:40:27 +0200 Subject: [PATCH 44/55] utils/stress: Run tests in debug mode by default --- utils/stress-test/src/main.rs | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/utils/stress-test/src/main.rs b/utils/stress-test/src/main.rs index c3e39de5c..d5e2ebf8d 100644 --- a/utils/stress-test/src/main.rs +++ b/utils/stress-test/src/main.rs @@ -106,6 +106,10 @@ struct Options { #[arg(short = 'F', long)] features: Vec, + /// Build package in release mode + #[arg(short, long)] + release: bool, + /// Test only this package's library #[arg(long)] lib: bool, @@ -134,10 +138,13 @@ fn build(options: &Options) -> Vec { .arg("--no-run") .arg("--package") .arg(&options.package) - .arg("--release") .arg("--message-format") .arg("json"); + if options.release { + command.arg("--release"); + } + for feature in &options.features { command.arg("--features").arg(feature); } From 327f32b5692f1760466eb73cc72bf1e2ba4bb981 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Cig=C3=A1nek?= Date: Wed, 2 Oct 2024 16:25:30 +0200 Subject: [PATCH 45/55] Fix requests not removed from clients in RequestTracker --- lib/src/network/request_tracker.rs | 62 +++++++++++++++++++++--------- lib/src/store/block.rs | 5 +-- lib/src/store/client.rs | 1 + 3 files changed, 45 insertions(+), 23 deletions(-) diff --git a/lib/src/network/request_tracker.rs b/lib/src/network/request_tracker.rs index 7003399c0..606e80de9 100644 --- a/lib/src/network/request_tracker.rs +++ b/lib/src/network/request_tracker.rs @@ -506,25 +506,11 @@ impl Worker { continue; }; - let waiters = match node.value_mut() { - RequestState::Complete { waiters, .. } => mem::take(waiters), - RequestState::Suspended { .. } - | RequestState::InFlight { .. } - | RequestState::Committed - | RequestState::Cancelled => unreachable!(), - }; - - // Remove the requests from this client and all the waiters - for client_id in iter::once(client_id).chain(waiters) { - if let Some(client_state) = self.clients.get_mut(&client_id) { - client_state.requests.remove(&request_key); - } - } - // If the node has no children, remove it, otherwise mark is as committed. if node.children().len() == 0 { self.remove_request(node_key); } else { + remove_request_from_clients(&mut self.clients, request_key, node.value()); *node.value_mut() = RequestState::Committed; } } @@ -649,6 +635,8 @@ impl Worker { RequestState::Committed | RequestState::Cancelled => return, }; + let request_key = MessageKey::from(&request.request); + // Find next waiting client if let Some(waiters) = waiters { let next_client = iter::from_fn(|| waiters.pop_front()) @@ -657,10 +645,9 @@ impl Worker { if let Some((&next_client_id, next_client_state)) = next_client { // Next waiting client found. Promote it to the sender. let sender_client_id = next_client_id; - let sender_timer_key = self.timer.insert( - (next_client_id, MessageKey::from(&request.request)), - REQUEST_TIMEOUT, - ); + let sender_timer_key = self + .timer + .insert((next_client_id, request_key), REQUEST_TIMEOUT); let waiters = mem::take(waiters); *state = RequestState::InFlight { @@ -678,6 +665,7 @@ impl Worker { // No waiting client found. If this request has no children, we can remove // it, otherwise we mark it as cancelled. 
if node.children().len() > 0 { + remove_request_from_clients(&mut self.clients, request_key, node.value()); *node.value_mut() = RequestState::Cancelled; } else { self.remove_request(node_key); @@ -689,6 +677,9 @@ impl Worker { return; }; + let request_key = MessageKey::from(&node.request().request); + remove_request_from_clients(&mut self.clients, request_key, node.value()); + for parent_key in node.parents() { let Some(parent_node) = self.requests.get(parent_key) else { continue; @@ -757,6 +748,7 @@ impl ClientState { } } +#[derive(Debug)] enum RequestState { /// This request is ready to be sent. Suspended { waiters: VecDeque }, @@ -785,6 +777,26 @@ enum RequestState { Cancelled, } +impl RequestState { + fn clients(&self) -> impl Iterator { + match self { + Self::Suspended { waiters } => Some((None, waiters)), + Self::InFlight { + sender_client_id, + waiters, + .. + } + | Self::Complete { + sender_client_id, + waiters, + } => Some((Some(sender_client_id), waiters)), + Self::Committed | Self::Cancelled => None, + } + .into_iter() + .flat_map(|(sender_client_id, waiters)| sender_client_id.into_iter().chain(waiters)) + } +} + #[derive(Debug)] enum FailureReason { Response, @@ -796,3 +808,15 @@ fn remove_from_queue(queue: &mut VecDeque, item: &T) { queue.remove(index); } } + +fn remove_request_from_clients( + clients: &mut HashMap, + request_key: MessageKey, + state: &RequestState, +) { + for client_id in state.clients() { + if let Some(client_state) = clients.get_mut(client_id) { + client_state.requests.remove(&request_key); + } + } +} diff --git a/lib/src/store/block.rs b/lib/src/store/block.rs index 0397cd6a9..7cd125e6e 100644 --- a/lib/src/store/block.rs +++ b/lib/src/store/block.rs @@ -24,10 +24,7 @@ pub(super) async fn read( .bind(id) .fetch_optional(conn) .await? 
- .ok_or_else(|| { - tracing::trace!(?id, "Block not found"); - Error::BlockNotFound - })?; + .ok_or(Error::BlockNotFound)?; let nonce: &[u8] = row.get(0); let nonce = BlockNonce::try_from(nonce).map_err(|_| Error::MalformedData)?; diff --git a/lib/src/store/client.rs b/lib/src/store/client.rs index f7b0c66a2..8bf133400 100644 --- a/lib/src/store/client.rs +++ b/lib/src/store/client.rs @@ -95,6 +95,7 @@ impl ClientWriter { } let local_node = inner_node::load(&mut self.db, &remote_node.hash).await?; + if local_node .map(|local_node| local_node.summary.is_outdated(&remote_node.summary)) .unwrap_or(true) From b19ecc8a445fbf80bc6b43209e0080b85e228977 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Cig=C3=A1nek?= Date: Thu, 3 Oct 2024 09:19:10 +0200 Subject: [PATCH 46/55] Use both local and remote block presence to derive request variant --- lib/src/network/client.rs | 36 ++++----- lib/src/network/request_tracker.rs | 72 +++++++++++------ lib/src/network/request_tracker/graph.rs | 78 +++++++++---------- lib/src/network/request_tracker/simulation.rs | 47 +++++------ lib/src/network/request_tracker/tests.rs | 32 ++++---- lib/src/protocol/summary.rs | 2 +- lib/src/store/client.rs | 38 +++++---- lib/src/store/root_node.rs | 51 +++++++----- lib/tests/sync.rs | 5 +- 9 files changed, 198 insertions(+), 163 deletions(-) diff --git a/lib/src/network/client.rs b/lib/src/network/client.rs index 30a4edeee..886c36170 100644 --- a/lib/src/network/client.rs +++ b/lib/src/network/client.rs @@ -11,7 +11,7 @@ use crate::{ event::Payload, network::{ message::Request, - request_tracker::{CandidateRequest, MessageKey}, + request_tracker::{CandidateRequest, MessageKey, RequestVariant}, }, protocol::{ Block, BlockId, InnerNodes, LeafNodes, MultiBlockPresence, NodeState, ProofError, @@ -98,8 +98,8 @@ impl Inner { } async fn send_requests(&self, request_rx: &mut mpsc::UnboundedReceiver) { - while let Some(PendingRequest { request, .. }) = request_rx.recv().await { - self.message_tx.send(Message::Request(request)).ok(); + while let Some(PendingRequest { payload, .. 
}) = request_rx.recv().await {
+            self.message_tx.send(Message::Request(payload)).ok();
         }
     }
 
@@ -256,14 +256,14 @@ impl Inner {
             MessageKey::RootNode(writer_id),
             status
                 .request_children()
-                .then_some(
+                .map(|local_block_presence| {
                     CandidateRequest::new(Request::ChildNodes(
                         hash,
                         ResponseDisambiguator::new(block_presence),
                         debug_payload.follow_up(),
                     ))
-                    .follow_up(block_presence),
-                )
+                    .variant(RequestVariant::new(local_block_presence, block_presence))
+                })
                 .into_iter()
                 .collect(),
         );
@@ -280,26 +280,24 @@ impl Inner {
     ) -> Result<()> {
         let hash = nodes.hash();
         let total = nodes.len();
-        let status = writer.save_inner_nodes(nodes).await?;
+        let statuses = writer.save_inner_nodes(nodes).await?;
 
-        tracing::trace!(
-            "Received {}/{} inner nodes",
-            status.new_children.len(),
-            total
-        );
+        tracing::trace!("Received {}/{} inner nodes", statuses.len(), total);
 
         self.request_tracker.success(
             MessageKey::ChildNodes(hash),
-            status
-                .new_children
+            statuses
                 .into_iter()
-                .map(|node| {
+                .map(|status| {
                     CandidateRequest::new(Request::ChildNodes(
-                        node.hash,
-                        ResponseDisambiguator::new(node.summary.block_presence),
+                        status.hash,
+                        ResponseDisambiguator::new(status.remote_block_presence),
                         debug_payload.follow_up(),
                     ))
-                    .follow_up(node.summary.block_presence)
+                    .variant(RequestVariant::new(
+                        status.local_block_presence,
+                        status.remote_block_presence,
+                    ))
                 })
                 .collect(),
         );
@@ -438,7 +438,7 @@ impl Inner {
     async fn request_blocks(&self, block_rx: &mut mpsc::UnboundedReceiver<BlockId>) {
         while let Some(block_id) = block_rx.recv().await {
             self.request_tracker
-                .resume(MessageKey::Block(block_id), MultiBlockPresence::None);
+                .resume(MessageKey::Block(block_id), RequestVariant::default());
         }
     }
 
diff --git a/lib/src/network/request_tracker.rs b/lib/src/network/request_tracker.rs
index 606e80de9..2aab3adab 100644
--- a/lib/src/network/request_tracker.rs
+++ b/lib/src/network/request_tracker.rs
@@ -14,6 +14,8 @@ use crate::{
     protocol::{BlockId, MultiBlockPresence},
 };
 use std::{
     collections::{hash_map::Entry, VecDeque},
+    fmt,
+    hash::Hasher,
     iter, mem,
     sync::atomic::{AtomicUsize, Ordering},
 };
@@ -21,6 +23,7 @@ use tokio::{select, sync::mpsc, task};
 use tokio_stream::StreamExt;
 use tokio_util::time::{delay_queue, DelayQueue};
 use tracing::{instrument, Instrument, Span};
+use twox_hash::xxh3::{Hash128, HasherExt};
 
 /// Keeps track of in-flight requests. Falls back on another peer in case the request failed (due to
 /// error response, timeout or disconnection). Evenly distributes the requests between the peers
@@ -102,11 +105,11 @@ impl RequestTrackerClient {
     }
 
     /// Resume suspended request.
-    pub fn resume(&self, request_key: MessageKey, block_presence: MultiBlockPresence) {
+    pub fn resume(&self, request_key: MessageKey, variant: RequestVariant) {
         self.command_tx
             .send(Command::Resume {
                 request_key,
-                block_presence,
+                variant,
             })
             .ok();
     }
@@ -154,27 +157,23 @@ impl RequestTrackerCommitter {
 /// Request that we want to send to the peer.
 #[derive(Clone, Debug)]
 pub(super) struct CandidateRequest {
-    pub request: Request,
-    pub block_presence: MultiBlockPresence,
+    pub payload: Request,
+    pub variant: RequestVariant,
     pub state: InitialRequestState,
 }
 
 impl CandidateRequest {
     /// Create new candidate for the given request.
-    pub fn new(request: Request) -> Self {
+    pub fn new(payload: Request) -> Self {
         Self {
-            request,
-            block_presence: MultiBlockPresence::None,
+            payload,
+            variant: RequestVariant::default(),
             state: InitialRequestState::InFlight,
         }
     }
 
     /// Set this candidate to a followup request to a response with the given block presence. 
- pub fn follow_up(self, block_presence: MultiBlockPresence) -> Self { - Self { - block_presence, - ..self - } + pub fn variant(self, variant: RequestVariant) -> Self { + Self { variant, ..self } } /// Start this request in suspended state. @@ -186,6 +185,29 @@ impl CandidateRequest { } } +#[derive(Default, Clone, Copy, Eq, PartialEq, Hash)] +pub(super) struct RequestVariant([u8; 16]); + +impl RequestVariant { + pub fn new( + local_block_presence: MultiBlockPresence, + remote_block_presence: MultiBlockPresence, + ) -> Self { + let mut hasher = Hash128::default(); + + hasher.write(local_block_presence.checksum()); + hasher.write(remote_block_presence.checksum()); + + Self(hasher.finish_ext().to_le_bytes()) + } +} + +impl fmt::Debug for RequestVariant { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{:<8x}", hex_fmt::HexFmt(&self.0)) + } +} + #[derive(Clone, Copy, Debug)] pub(super) enum InitialRequestState { InFlight, @@ -198,8 +220,8 @@ pub(super) enum InitialRequestState { /// mostly useful for diagnostics and testing. #[derive(Clone, Debug)] pub(super) struct PendingRequest { - pub request: Request, - pub block_presence: MultiBlockPresence, + pub payload: Request, + pub variant: RequestVariant, } /// Key identifying a request and its corresponding response. @@ -302,8 +324,8 @@ impl Worker { } Command::Resume { request_key, - block_presence, - } => self.resume(request_key, block_presence), + variant, + } => self.resume(request_key, variant), Command::Commit { client_id } => { self.commit(client_id); } @@ -435,8 +457,8 @@ impl Worker { } #[instrument(skip(self))] - fn resume(&mut self, request_key: MessageKey, block_presence: MultiBlockPresence) { - let Some(node) = self.requests.lookup_mut(request_key, block_presence) else { + fn resume(&mut self, request_key: MessageKey, variant: RequestVariant) { + let Some(node) = self.requests.lookup_mut(request_key, variant) else { return; }; @@ -524,8 +546,8 @@ impl Worker { ) { let node_key = self.requests.get_or_insert( PendingRequest { - request: request.request, - block_presence: request.block_presence, + payload: request.payload, + variant: request.variant, }, parent_key, RequestState::Cancelled, @@ -548,7 +570,7 @@ impl Worker { return; }; - let request_key = MessageKey::from(&node.request().request); + let request_key = MessageKey::from(&node.request().payload); if let Entry::Vacant(entry) = client_state.requests.entry(request_key) { match node.value_mut() { @@ -635,7 +657,7 @@ impl Worker { RequestState::Committed | RequestState::Cancelled => return, }; - let request_key = MessageKey::from(&request.request); + let request_key = MessageKey::from(&request.payload); // Find next waiting client if let Some(waiters) = waiters { @@ -677,7 +699,7 @@ impl Worker { return; }; - let request_key = MessageKey::from(&node.request().request); + let request_key = MessageKey::from(&node.request().payload); remove_request_from_clients(&mut self.clients, request_key, node.value()); for parent_key in node.parents() { @@ -727,7 +749,7 @@ enum Command { }, Resume { request_key: MessageKey, - block_presence: MultiBlockPresence, + variant: RequestVariant, }, Commit { client_id: ClientId, diff --git a/lib/src/network/request_tracker/graph.rs b/lib/src/network/request_tracker/graph.rs index e1010759d..cfa4e71d0 100644 --- a/lib/src/network/request_tracker/graph.rs +++ b/lib/src/network/request_tracker/graph.rs @@ -1,15 +1,14 @@ -use super::{MessageKey, PendingRequest}; +use super::{MessageKey, PendingRequest, RequestVariant}; use crate::{ 
collections::{HashMap, HashSet}, network::message::Request, - protocol::MultiBlockPresence, }; use slab::Slab; use std::collections::hash_map::Entry; /// DAG for storing data for the request tracker. pub(super) struct Graph { - index: HashMap<(MessageKey, MultiBlockPresence), Key>, + index: HashMap<(MessageKey, RequestVariant), Key>, nodes: Slab>, } @@ -29,7 +28,7 @@ impl Graph { ) -> Key { let node_key = match self .index - .entry((MessageKey::from(&request.request), request.block_presence)) + .entry((MessageKey::from(&request.payload), request.variant)) { Entry::Occupied(entry) => { self.nodes @@ -73,23 +72,19 @@ impl Graph { } #[expect(unused)] - pub fn lookup( - &self, - request_key: MessageKey, - block_presence: MultiBlockPresence, - ) -> Option<&Node> { + pub fn lookup(&self, request_key: MessageKey, variant: RequestVariant) -> Option<&Node> { self.index - .get(&(request_key, block_presence)) + .get(&(request_key, variant)) .and_then(|key| self.nodes.get(key.0)) } pub fn lookup_mut( &mut self, request_key: MessageKey, - block_presence: MultiBlockPresence, + variant: RequestVariant, ) -> Option<&mut Node> { self.index - .get(&(request_key, block_presence)) + .get(&(request_key, variant)) .and_then(|key| self.nodes.get_mut(key.0)) } @@ -97,8 +92,8 @@ impl Graph { let node = self.nodes.try_remove(key.0)?; self.index.remove(&( - MessageKey::from(&node.request.request), - node.request.block_presence, + MessageKey::from(&node.request.payload), + node.request.variant, )); for parent_key in &node.parents { @@ -122,7 +117,7 @@ impl Graph { #[cfg_attr(not(test), expect(dead_code))] pub fn requests(&self) -> impl ExactSizeIterator { - self.nodes.iter().map(|(_, node)| &node.request.request) + self.nodes.iter().map(|(_, node)| &node.request.payload) } } @@ -165,7 +160,10 @@ impl Node { #[cfg(test)] mod tests { use super::*; - use crate::network::{debug_payload::DebugRequest, message::ResponseDisambiguator}; + use crate::{ + network::{debug_payload::DebugRequest, message::ResponseDisambiguator}, + protocol::MultiBlockPresence, + }; use rand::Rng; #[test] @@ -183,8 +181,8 @@ mod tests { let parent_node_key = graph.get_or_insert( PendingRequest { - request: parent_request.clone(), - block_presence: MultiBlockPresence::Full, + payload: parent_request.clone(), + variant: RequestVariant::default(), }, None, 1, @@ -198,7 +196,7 @@ mod tests { assert_eq!(*node.value(), 1); assert_eq!(node.children().len(), 0); - assert_eq!(node.request().request, parent_request); + assert_eq!(node.request().payload, parent_request); let child_request = Request::ChildNodes( rng.gen(), @@ -208,8 +206,8 @@ mod tests { let child_node_key = graph.get_or_insert( PendingRequest { - request: child_request.clone(), - block_presence: MultiBlockPresence::Full, + payload: child_request.clone(), + variant: RequestVariant::default(), }, Some(parent_node_key), 2, @@ -223,7 +221,7 @@ mod tests { assert_eq!(*node.value(), 2); assert_eq!(node.children().len(), 0); - assert_eq!(node.request().request, child_request); + assert_eq!(node.request().payload, child_request); assert_eq!( graph @@ -254,8 +252,8 @@ mod tests { let node_key0 = graph.get_or_insert( PendingRequest { - request: request.clone(), - block_presence: MultiBlockPresence::Full, + payload: request.clone(), + variant: RequestVariant::default(), }, None, 1, @@ -264,8 +262,8 @@ mod tests { let node_key1 = graph.get_or_insert( PendingRequest { - request, - block_presence: MultiBlockPresence::Full, + payload: request, + variant: RequestVariant::default(), }, None, 1, @@ -282,21 
+280,21 @@ mod tests {
         let hash = rng.gen();
 
         let parent_request_0 = PendingRequest {
-            request: Request::ChildNodes(
+            payload: Request::ChildNodes(
                 hash,
                 ResponseDisambiguator::new(MultiBlockPresence::None),
                 DebugRequest::start(),
             ),
-            block_presence: MultiBlockPresence::None,
+            variant: RequestVariant::new(MultiBlockPresence::None, MultiBlockPresence::None),
         };
 
         let parent_request_1 = PendingRequest {
-            request: Request::ChildNodes(
+            payload: Request::ChildNodes(
                 hash,
                 ResponseDisambiguator::new(MultiBlockPresence::Full),
                 DebugRequest::start(),
             ),
-            block_presence: MultiBlockPresence::Full,
+            variant: RequestVariant::new(MultiBlockPresence::None, MultiBlockPresence::Full),
         };
 
         let child_request = Request::Block(rng.gen(), DebugRequest::start());
@@ -306,8 +304,8 @@ mod tests {
 
         let child_key_0 = graph.get_or_insert(
             PendingRequest {
-                request: child_request.clone(),
-                block_presence: MultiBlockPresence::None,
+                payload: child_request.clone(),
+                variant: RequestVariant::default(),
             },
             Some(parent_key_0),
             2,
@@ -315,8 +313,8 @@ mod tests {
 
         let child_key_1 = graph.get_or_insert(
             PendingRequest {
-                request: child_request,
-                block_presence: MultiBlockPresence::None,
+                payload: child_request,
+                variant: RequestVariant::default(),
             },
             Some(parent_key_1),
             2,
@@ -373,30 +371,30 @@ mod tests {
         let mut graph = Graph::new();
 
         let parent_request = PendingRequest {
-            request: Request::ChildNodes(
+            payload: Request::ChildNodes(
                 rng.gen(),
                 ResponseDisambiguator::new(MultiBlockPresence::Full),
                 DebugRequest::start(),
             ),
-            block_presence: MultiBlockPresence::Full,
+            variant: RequestVariant::default(),
         };
 
         let child_request_0 = PendingRequest {
-            request: Request::ChildNodes(
+            payload: Request::ChildNodes(
                 rng.gen(),
                 ResponseDisambiguator::new(MultiBlockPresence::Full),
                 DebugRequest::start(),
             ),
-            block_presence: MultiBlockPresence::Full,
+            variant: RequestVariant::default(),
         };
 
         let child_request_1 = PendingRequest {
-            request: Request::ChildNodes(
+            payload: Request::ChildNodes(
                 rng.gen(),
                 ResponseDisambiguator::new(MultiBlockPresence::Full),
                 DebugRequest::start(),
             ),
-            block_presence: MultiBlockPresence::Full,
+            variant: RequestVariant::default(),
         };
 
         let parent_key = graph.get_or_insert(parent_request, None, 0);
diff --git a/lib/src/network/request_tracker/simulation.rs b/lib/src/network/request_tracker/simulation.rs
index 2bbdf19aa..c2cbe9e2c 100644
--- a/lib/src/network/request_tracker/simulation.rs
+++ b/lib/src/network/request_tracker/simulation.rs
@@ -1,6 +1,7 @@
 use super::{
     super::message::{Request, Response, ResponseDisambiguator},
     CandidateRequest, MessageKey, PendingRequest, RequestTracker, RequestTrackerClient,
+    RequestVariant,
 };
 use crate::{
     collections::{HashMap, HashSet},
@@ -25,8 +26,8 @@ pub(super) struct Simulation {
     // All requests sent by live peers. This is used to verify that every request is sent only once
     // unless the peer that sent it died or the request failed. In those cases the request may be
    // sent by another peer. It's also allowed to send the same request more than once as long as
-    // each one has a different block presence. 
-    requests: HashMap>,
+    requests: HashMap<MessageKey, HashSet<RequestVariant>>,
     snapshot: Snapshot,
 }
@@ -67,8 +68,8 @@ impl Simulation {
         let index = rng.gen_range(0..self.peers.len());
         let peer = self.peers.remove(index);
 
-        for (key, block_presence) in peer.requests {
-            cancel_request(&mut self.requests, key, block_presence);
+        for (key, variant) in peer.requests {
+            cancel_request(&mut self.requests, key, variant);
         }
     }
 
@@ -91,20 +92,16 @@ impl Simulation {
         match side {
             Side::Client => {
-                if let Some(PendingRequest {
-                    request,
-                    block_presence,
-                }) = peer.client.poll_request()
-                {
-                    let key = MessageKey::from(&request);
+                if let Some(PendingRequest { payload, variant }) = peer.client.poll_request() {
+                    let key = MessageKey::from(&payload);
 
                     assert!(
-                        self.requests.entry(key).or_default().insert(block_presence),
-                        "request sent more than once: {request:?} ({block_presence:?})"
+                        self.requests.entry(key).or_default().insert(variant),
+                        "request sent more than once: {payload:?} ({variant:?})"
                     );
 
-                    peer.requests.insert(key, block_presence);
-                    peer.server.handle_request(request);
+                    peer.requests.insert(key, variant);
+                    peer.server.handle_request(payload);
 
                     return true;
                 }
@@ -129,8 +126,8 @@ impl Simulation {
                     };
 
                     if let Some(key) = key {
-                        if let Some(block_presence) = peer.requests.get(&key) {
-                            cancel_request(&mut self.requests, key, *block_presence);
+                        if let Some(variant) = peer.requests.get(&key) {
+                            cancel_request(&mut self.requests, key, *variant);
                         }
                     }
 
@@ -151,12 +148,12 @@ impl Simulation {
 }
 
 fn cancel_request(
-    requests: &mut HashMap<MessageKey, HashSet<MultiBlockPresence>>,
+    requests: &mut HashMap<MessageKey, HashSet<RequestVariant>>,
     key: MessageKey,
-    block_presence: MultiBlockPresence,
+    variant: RequestVariant,
 ) {
     if let Entry::Occupied(mut entry) = requests.entry(key) {
-        entry.get_mut().remove(&block_presence);
+        entry.get_mut().remove(&variant);
 
         if entry.get().is_empty() {
             entry.remove();
@@ -168,7 +165,7 @@ struct TestPeer {
     client: TestClient,
     server: TestServer,
     // All requests sent by this peer.
-    requests: HashMap<MessageKey, MultiBlockPresence>,
+    requests: HashMap<MessageKey, RequestVariant>,
 }
 
 struct TestClient {
@@ -198,7 +195,10 @@ impl TestClient {
                             ResponseDisambiguator::new(block_presence),
                             debug_payload.follow_up(),
                         ))
-                        .follow_up(block_presence),
+                        .variant(RequestVariant::new(
+                            MultiBlockPresence::None,
+                            block_presence,
+                        )),
                     )
                     .into_iter()
                     .collect();
@@ -218,7 +218,10 @@ impl TestClient {
                             ResponseDisambiguator::new(node.summary.block_presence),
                             debug_payload.follow_up(),
                         ))
-                        .follow_up(node.summary.block_presence)
+                        .variant(RequestVariant::new(
+                            MultiBlockPresence::None,
+                            node.summary.block_presence,
+                        ))
                     })
                     .collect();
diff --git a/lib/src/network/request_tracker/tests.rs b/lib/src/network/request_tracker/tests.rs
index 6e27981bb..f64687f2b 100644
--- a/lib/src/network/request_tracker/tests.rs
+++ b/lib/src/network/request_tracker/tests.rs
@@ -165,12 +165,12 @@ async fn timeout() {
     // Register the request with both clients.
     client_a.success(
         preceding_request_key,
-        vec![CandidateRequest::new(request.clone()).follow_up(MultiBlockPresence::Full)],
+        vec![CandidateRequest::new(request.clone())],
     );
 
     client_b.success(
         preceding_request_key,
-        vec![CandidateRequest::new(request.clone()).follow_up(MultiBlockPresence::Full)],
+        vec![CandidateRequest::new(request.clone())],
     );
 
     time::timeout(Duration::from_millis(1), &mut work)
         .await
         .ok();
 
     // Only the first client gets the request.
assert_eq!(
-        request_rx_a.try_recv().map(|r| r.request),
+        request_rx_a.try_recv().map(|r| r.payload),
         Ok(request.clone())
     );
 
     assert_eq!(
-        request_rx_b.try_recv().map(|r| r.request),
+        request_rx_b.try_recv().map(|r| r.payload),
         Err(TryRecvError::Empty)
     );
 
@@ -195,7 +195,7 @@ async fn timeout() {
 
     // The first client timed out so the second client now gets the request.
     assert_eq!(
-        request_rx_b.try_recv().map(|r| r.request),
+        request_rx_b.try_recv().map(|r| r.payload),
         Ok(request.clone())
     );
 }
@@ -219,18 +219,18 @@ async fn drop_uncommitted_client() {
     for client in [&client_a, &client_b] {
         client.success(
             preceding_request_key,
-            vec![CandidateRequest::new(request.clone()).follow_up(MultiBlockPresence::Full)],
+            vec![CandidateRequest::new(request.clone())],
         );
     }
 
     tracker_worker.step();
 
     assert_eq!(
-        request_rx_a.try_recv().map(|r| r.request),
+        request_rx_a.try_recv().map(|r| r.payload),
         Ok(request.clone())
     );
     assert_eq!(
-        request_rx_b.try_recv().map(|r| r.request),
+        request_rx_b.try_recv().map(|r| r.payload),
         Err(TryRecvError::Empty)
     );
 
@@ -239,11 +239,11 @@ async fn drop_uncommitted_client() {
     tracker_worker.step();
 
     assert_eq!(
-        request_rx_a.try_recv().map(|r| r.request),
+        request_rx_a.try_recv().map(|r| r.payload),
         Err(TryRecvError::Empty)
     );
     assert_eq!(
-        request_rx_b.try_recv().map(|r| r.request),
+        request_rx_b.try_recv().map(|r| r.payload),
         Err(TryRecvError::Empty)
     );
 
@@ -253,7 +253,7 @@ async fn drop_uncommitted_client() {
 
     // The request falls back to the other client because although the request was completed, it
     // wasn't committed.
-    assert_eq!(request_rx_b.try_recv().map(|r| r.request), Ok(request));
+    assert_eq!(request_rx_b.try_recv().map(|r| r.payload), Ok(request));
 }
 
 #[tokio::test]
@@ -285,13 +285,13 @@ async fn multiple_responses_to_identical_requests() {
     // followups than the one received previously.
     client.success(
         MessageKey::from(&initial_request),
-        vec![CandidateRequest::new(followup_request.clone()).follow_up(MultiBlockPresence::Full)],
+        vec![CandidateRequest::new(followup_request.clone())],
     );
 
     worker.step();
 
     // The followup requests are sent even though the initial request already received a response.
     assert_eq!(
-        request_rx.try_recv().map(|r| r.request),
+        request_rx.try_recv().map(|r| r.payload),
         Ok(followup_request)
     );
 
@@ -319,14 +319,14 @@ async fn suspend_resume() {
     worker.step();
 
     assert_eq!(
-        request_rx.try_recv().map(|r| r.request),
+        request_rx.try_recv().map(|r| r.payload),
         Err(TryRecvError::Empty)
     );
 
-    client.resume(request_key, MultiBlockPresence::None);
+    client.resume(request_key, RequestVariant::default());
 
     worker.step();
 
-    assert_eq!(request_rx.try_recv().map(|r| r.request), Ok(request));
+    assert_eq!(request_rx.try_recv().map(|r| r.payload), Ok(request));
 }
 
 /// Generate `count + 1` copies of the same snapshot.
The first one will have all the blocks
diff --git a/lib/src/protocol/summary.rs b/lib/src/protocol/summary.rs
index 7ece94379..1ab73bb12 100644
--- a/lib/src/protocol/summary.rs
+++ b/lib/src/protocol/summary.rs
@@ -263,7 +263,7 @@ impl MultiBlockPresence {
         }
     }
 
-    fn checksum(&self) -> &[u8] {
+    pub fn checksum(&self) -> &[u8] {
         match self {
             Self::None => NONE.as_slice(),
             Self::Some(checksum) => checksum.as_slice(),
diff --git a/lib/src/store/client.rs b/lib/src/store/client.rs
index 8bf133400..52c1f31c5 100644
--- a/lib/src/store/client.rs
+++ b/lib/src/store/client.rs
@@ -15,7 +15,7 @@ use crate::{
     db,
     future::TryStreamExt as _,
     protocol::{
-        Block, BlockId, InnerNode, InnerNodes, LeafNodes, MultiBlockPresence, NodeState, Proof,
+        Block, BlockId, InnerNodes, LeafNodes, MultiBlockPresence, NodeState, Proof,
         RootNodeFilter, SingleBlockPresence, Summary,
     },
     repository, StorageSize,
@@ -78,14 +78,14 @@ impl ClientWriter {
     pub async fn save_inner_nodes(
         &mut self,
         nodes: CacheHash<InnerNodes>,
-    ) -> Result<InnerNodesStatus, Error> {
+    ) -> Result<Vec<InnerNodeStatus>, Error> {
         let parent_hash = nodes.hash();
 
         if !index::parent_exists(&mut self.db, &parent_hash).await? {
-            return Ok(InnerNodesStatus::default());
+            return Ok(Vec::new());
         }
 
-        let mut new_children = Vec::with_capacity(nodes.len());
+        let mut statuses = Vec::with_capacity(nodes.len());
         let nodes = nodes.into_inner();
 
         for (_, remote_node) in &nodes {
@@ -95,12 +95,16 @@ impl ClientWriter {
             }
 
             let local_node = inner_node::load(&mut self.db, &remote_node.hash).await?;
-
-            if local_node
-                .map(|local_node| local_node.summary.is_outdated(&remote_node.summary))
-                .unwrap_or(true)
-            {
-                new_children.push(*remote_node);
+            let local_node_summary = local_node
+                .map(|node| node.summary)
+                .unwrap_or(Summary::INCOMPLETE);
+
+            if local_node_summary.is_outdated(&remote_node.summary) {
+                statuses.push(InnerNodeStatus {
+                    hash: remote_node.hash,
+                    local_block_presence: local_node_summary.block_presence,
+                    remote_block_presence: remote_node.summary.block_presence,
+                });
             }
         }
 
@@ -111,7 +115,7 @@ impl ClientWriter {
             self.summary_updates.push(parent_hash);
         }
 
-        Ok(InnerNodesStatus { new_children })
+        Ok(statuses)
     }
 
     pub async fn save_leaf_nodes(
         &mut self,
         nodes: CacheHash<LeafNodes>,
@@ -349,10 +353,10 @@ impl ClientReader {
     }
 }
 
-#[derive(Default)]
-pub(crate) struct InnerNodesStatus {
-    /// Which of the received nodes should we request the children of.
-    pub new_children: Vec<InnerNode>,
+pub(crate) struct InnerNodeStatus {
+    pub hash: Hash,
+    pub local_block_presence: MultiBlockPresence,
+    pub remote_block_presence: MultiBlockPresence,
 }
 
 #[derive(Default)]
@@ -835,11 +839,11 @@ mod tests {
         // Try to save the inner nodes
         let (hash, inner_nodes) = snapshot.inner_sets().next().unwrap();
         let mut writer = store.begin_client_write().await.unwrap();
-        let status = writer
+        let statuses = writer
             .save_inner_nodes(inner_nodes.clone().into())
             .await
             .unwrap();
-        assert!(status.new_children.is_empty());
+        assert!(statuses.is_empty());
         writer.commit().await.unwrap();
 
         // The orphaned inner nodes were not written to the db.
diff --git a/lib/src/store/root_node.rs b/lib/src/store/root_node.rs
index 10bb9434e..f4e90d2dc 100644
--- a/lib/src/store/root_node.rs
+++ b/lib/src/store/root_node.rs
@@ -16,27 +16,35 @@ use std::{cmp::Ordering, fmt, future};
 
 /// Status of receiving a root node
 #[derive(PartialEq, Eq, Debug)]
 pub(crate) enum RootNodeStatus {
-    /// The node represents a new snapshot - write it into the store and requests its children.
-    NewSnapshot,
-    /// We already have the node but its block presence indicated the peer potentially has some
-    /// blocks we don't have. Don't write it into the store but do request its children.
-    NewBlocks,
+    /// The incoming node is more up to date than the nodes we already have. Contains info about
+    /// which part of the node is more up to date and the block presence of the latest node we have
+    /// from the same branch as the incoming node.
+    Updated(RootNodeUpdated, MultiBlockPresence),
     /// The node is outdated - discard it.
     Outdated,
 }
 
+#[derive(PartialEq, Eq, Debug)]
+pub(crate) enum RootNodeUpdated {
+    /// The node represents a new snapshot - write it into the store and request its children.
+    Snapshot,
+    /// The node represents a snapshot we already have but its block presence indicates the peer
+    /// potentially has some blocks we don't have. Don't write it into the store but do request its
+    /// children.
+    Blocks,
+}
+
 impl RootNodeStatus {
-    pub fn request_children(&self) -> bool {
+    pub fn request_children(&self) -> Option<MultiBlockPresence> {
         match self {
-            Self::NewSnapshot | Self::NewBlocks => true,
-            Self::Outdated => false,
+            Self::Updated(_, block_presence) => Some(*block_presence),
+            Self::Outdated => None,
         }
     }
 
     pub fn write(&self) -> bool {
         match self {
-            Self::NewSnapshot => true,
-            Self::NewBlocks | Self::Outdated => false,
+            Self::Updated(RootNodeUpdated::Snapshot, _) => true,
+            Self::Updated(RootNodeUpdated::Blocks, _) | Self::Outdated => false,
         }
     }
 }
@@ -44,8 +52,8 @@ impl RootNodeStatus {
 impl fmt::Display for RootNodeStatus {
     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
         match self {
-            Self::NewSnapshot => write!(f, "new snapshot"),
-            Self::NewBlocks => write!(f, "new blocks"),
+            Self::Updated(RootNodeUpdated::Snapshot, _) => write!(f, "new snapshot"),
+            Self::Updated(RootNodeUpdated::Blocks, _) => write!(f, "new blocks"),
             Self::Outdated => write!(f, "outdated"),
         }
     }
@@ -552,7 +560,8 @@ pub(super) async fn status(
     new_proof: &Proof,
     new_block_presence: &MultiBlockPresence,
 ) -> Result<RootNodeStatus, Error> {
-    let mut status = RootNodeStatus::NewSnapshot;
+    let mut updated = RootNodeUpdated::Snapshot;
+    let mut block_presence = MultiBlockPresence::None;
 
     let mut old_nodes = load_all_latest(conn);
     while let Some(old_node) = old_nodes.try_next().await? {
@@ -563,7 +572,7 @@ pub(super) async fn status(
             Some(Ordering::Less) => {
                 // The incoming node is outdated compared to at least one existing node - discard
                 // it.
- status = RootNodeStatus::Outdated; + return Ok(RootNodeStatus::Outdated); } Some(Ordering::Equal) => { if new_proof.hash == old_node.proof.hash { @@ -577,9 +586,9 @@ pub(super) async fn status( .block_presence .is_outdated(new_block_presence) { - status = RootNodeStatus::NewBlocks; + updated = RootNodeUpdated::Blocks; } else { - status = RootNodeStatus::Outdated; + return Ok(RootNodeStatus::Outdated); } } else { tracing::warn!( @@ -591,7 +600,7 @@ pub(super) async fn status( "Received root node invalid - broken invariant: same vv but different hash" ); - status = RootNodeStatus::Outdated; + return Ok(RootNodeStatus::Outdated); } } Some(Ordering::Greater) => (), @@ -604,17 +613,17 @@ pub(super) async fn status( "Received root node invalid - broken invariant: concurrency within branch is not allowed" ); - status = RootNodeStatus::Outdated; + return Ok(RootNodeStatus::Outdated); } } } - if matches!(status, RootNodeStatus::Outdated) { - break; + if old_node.proof.writer_id == new_proof.writer_id { + block_presence = old_node.summary.block_presence; } } - Ok(status) + Ok(RootNodeStatus::Updated(updated, block_presence)) } pub(super) async fn debug_print(conn: &mut db::Connection, printer: DebugPrinter) { diff --git a/lib/tests/sync.rs b/lib/tests/sync.rs index 25d559113..a1432e7b3 100644 --- a/lib/tests/sync.rs +++ b/lib/tests/sync.rs @@ -970,7 +970,7 @@ fn redownload_expired_blocks() { let (finish_origin_tx, mut finish_origin_rx) = mpsc::channel(1); let (finish_cache_tx, mut finish_cache_rx) = mpsc::channel(1); - let test_content = Arc::new(common::random_bytes(2 * 1024 * 1024)); + let test_content = Arc::new(common::random_bytes(8 * BLOCK_SIZE - BLOB_HEADER_SIZE)); // Wait until the number of blocks is the `expected`. // @@ -982,7 +982,8 @@ fn redownload_expired_blocks() { .with_max_interval(Duration::from_millis(500)) .with_randomization_factor(0.0) .with_multiplier(2.0) - .with_max_elapsed_time(Some(Duration::from_secs(60))) + // .with_max_elapsed_time(Some(Duration::from_secs(60))) + .with_max_elapsed_time(Some(Duration::from_secs(10))) .build(); loop { From 7c355ca38debea099cbbfb27b1c5f1d897dfa53a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Cig=C3=A1nek?= Date: Tue, 1 Oct 2024 16:46:58 +0200 Subject: [PATCH 47/55] Add cookie to RootNode request/response --- lib/src/network/client.rs | 85 ++++++++++++++----- lib/src/network/message.rs | 24 +++++- lib/src/network/request_tracker.rs | 6 +- lib/src/network/request_tracker/simulation.rs | 61 ++++++++----- lib/src/network/request_tracker/tests.rs | 10 ++- lib/src/network/server.rs | 53 ++++++++---- 6 files changed, 168 insertions(+), 71 deletions(-) diff --git a/lib/src/network/client.rs b/lib/src/network/client.rs index 886c36170..317feaa8b 100644 --- a/lib/src/network/client.rs +++ b/lib/src/network/client.rs @@ -112,12 +112,18 @@ impl Inner { self.vault.monitor.responses_received.increment(1); match response { - Response::RootNode(proof, block_presence, debug) => { - persistable.push(PersistableResponse::RootNode( + Response::RootNode { + proof, + cookie, + block_presence, + debug, + } => { + persistable.push(PersistableResponse::RootNode { proof, + cookie, block_presence, debug, - )); + }); } Response::InnerNodes(nodes, _, debug) => { persistable.push(PersistableResponse::InnerNodes(nodes.into(), debug)); @@ -134,9 +140,11 @@ impl Inner { Response::BlockOffer(block_id, debug) => { ephemeral.push(EphemeralResponse::BlockOffer(block_id, debug)); } - Response::RootNodeError(writer_id, _) => { + Response::RootNodeError { + writer_id, 
cookie, ..
+            } => {
                 self.request_tracker
-                    .failure(MessageKey::RootNode(writer_id));
+                    .failure(MessageKey::RootNode(writer_id, cookie));
             }
             Response::ChildNodesError(hash, _, _) => {
                 self.request_tracker.failure(MessageKey::ChildNodes(hash));
             }
             Response::BlockError(block_id, _) => {
@@ -175,8 +183,13 @@ impl Inner {
 
         for response in batch.drain(..) {
             match response {
-                PersistableResponse::RootNode(proof, block_presence, debug) => {
-                    self.handle_root_node(&mut writer, proof, block_presence, debug)
+                PersistableResponse::RootNode {
+                    proof,
+                    cookie,
+                    block_presence,
+                    debug,
+                } => {
+                    self.handle_root_node(&mut writer, proof, cookie, block_presence, debug)
                         .await?;
                 }
                 PersistableResponse::InnerNodes(nodes, debug) => {
@@ -222,6 +235,7 @@ impl Inner {
             vv = ?proof.version_vector,
             hash = ?proof.hash,
             ?block_presence,
+            cookie = cookie,
             ?debug_payload,
         ),
         err(Debug)
@@ -230,6 +244,7 @@ impl Inner {
         &self,
         writer: &mut ClientWriter,
         proof: UntrustedProof,
+        cookie: u64,
         block_presence: MultiBlockPresence,
         debug_payload: DebugResponse,
     ) -> Result<()> {
@@ -253,7 +268,7 @@ impl Inner {
         tracing::debug!("Received root node - {status}");
 
         self.request_tracker.success(
-            MessageKey::RootNode(writer_id),
+            MessageKey::RootNode(writer_id, cookie),
             status
                 .request_children()
                 .map(|local_block_presence| {
@@ -467,15 +482,16 @@ impl Inner {
     // before the block is marked as offered and only then we proceed with requesting it. This
     // can take arbitrarily long (even indefinitely).
     //
-    // By requesting the root node again immediatelly, we ensure that the missing block is
+    // By requesting the root node again immediately, we ensure that the missing block is
    // requested as soon as possible.
     fn refresh_branches(&self, branches: impl IntoIterator<Item = PublicKey>) {
         for writer_id in branches {
             self.request_tracker
-                .initial(CandidateRequest::new(Request::RootNode(
+                .initial(CandidateRequest::new(Request::RootNode {
                     writer_id,
-                    DebugRequest::start(),
-                )));
+                    cookie: next_root_node_cookie(),
+                    debug: DebugRequest::start(),
+                }));
         }
     }
 
@@ -527,7 +543,12 @@ enum EphemeralResponse {
 
 /// Response whose processing requires write access to the store.
 enum PersistableResponse {
-    RootNode(UntrustedProof, MultiBlockPresence, DebugResponse),
+    RootNode {
+        proof: UntrustedProof,
+        cookie: u64,
+        block_presence: MultiBlockPresence,
+        debug: DebugResponse,
+    },
     InnerNodes(CacheHash<InnerNodes>, DebugResponse),
     LeafNodes(CacheHash<LeafNodes>, DebugResponse),
     Block(Block, DebugResponse),
@@ -550,6 +571,22 @@ fn block_offer_state(root_node_state: NodeState) -> Option<BlockOfferState> {
     }
 }
 
+// Generate a cookie for the next `RootNode` request. This value is guaranteed to be non-zero (zero
+// is used for unsolicited responses).
+fn next_root_node_cookie() -> u64 {
+    use std::sync::atomic::{AtomicU64, Ordering};
+
+    static NEXT: AtomicU64 = AtomicU64::new(1);
+
+    loop {
+        let next = NEXT.fetch_add(1, Ordering::Relaxed);
+
+        if next != 0 {
+            break next;
+        }
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -576,17 +613,18 @@ mod tests {
         // Receive invalid root node from the remote replica.
let invalid_write_keys = Keypair::random();
 
         inner
-            .handle_persistable_responses(&mut vec![PersistableResponse::RootNode(
-                Proof::new(
+            .handle_persistable_responses(&mut vec![PersistableResponse::RootNode {
+                proof: Proof::new(
                     remote_id,
                     VersionVector::first(remote_id),
                     *EMPTY_INNER_HASH,
                     &invalid_write_keys,
                 )
                 .into(),
-                MultiBlockPresence::None,
-                DebugResponse::unsolicited(),
-            )])
+                block_presence: MultiBlockPresence::None,
+                cookie: 0,
+                debug: DebugResponse::unsolicited(),
+            }])
             .await
             .unwrap();
 
@@ -610,17 +648,18 @@ mod tests {
         let remote_id = PublicKey::random();
 
         inner
-            .handle_persistable_responses(&mut vec![PersistableResponse::RootNode(
-                Proof::new(
+            .handle_persistable_responses(&mut vec![PersistableResponse::RootNode {
+                proof: Proof::new(
                     remote_id,
                     VersionVector::new(),
                     *EMPTY_INNER_HASH,
                     &secrets.write_keys,
                 )
                 .into(),
-                MultiBlockPresence::None,
-                DebugResponse::unsolicited(),
-            )])
+                block_presence: MultiBlockPresence::None,
+                cookie: 0,
+                debug: DebugResponse::unsolicited(),
+            }])
             .await
             .unwrap();
 
diff --git a/lib/src/network/message.rs b/lib/src/network/message.rs
index 061b9bab0..2de6c9b8a 100644
--- a/lib/src/network/message.rs
+++ b/lib/src/network/message.rs
@@ -14,7 +14,14 @@ use serde::{Deserialize, Serialize};
 #[derive(Clone, Eq, PartialEq, Serialize, Deserialize, Debug)]
 pub(crate) enum Request {
     /// Request the latest root node of the given writer.
-    RootNode(PublicKey, DebugRequest),
+    RootNode {
+        writer_id: PublicKey,
+        // This value is returned in the response without change. It can be used to distinguish
+        // multiple otherwise identical requests. This is useful because multiple identical
+        // `RootNode` requests can yield different responses.
+        cookie: u64,
+        debug: DebugRequest,
+    },
     /// Request child nodes of the given parent node.
     ChildNodes(Hash, ResponseDisambiguator, DebugRequest),
     /// Request block with the given id.
     Block(BlockId, DebugRequest),
 }
@@ -38,9 +45,20 @@ pub(crate) enum Response {
     /// Send the latest root node of this replica to another replica.
     /// NOTE: This is both a response and notification - the server sends this as a response to
     /// `Request::RootNode` but also on its own when it detects change in the repo.
-    RootNode(UntrustedProof, MultiBlockPresence, DebugResponse),
+    RootNode {
+        proof: UntrustedProof,
+        block_presence: MultiBlockPresence,
+        // If this is a response, the `cookie` value from the request. If this is a notification,
+        // zero.
+        cookie: u64,
+        debug: DebugResponse,
+    },
     /// Send that a RootNode request failed
-    RootNodeError(PublicKey, DebugResponse),
+    RootNodeError {
+        writer_id: PublicKey,
+        cookie: u64,
+        debug: DebugResponse,
+    },
     /// Send inner nodes.
     InnerNodes(InnerNodes, ResponseDisambiguator, DebugResponse),
     /// Send leaf nodes.
diff --git a/lib/src/network/request_tracker.rs b/lib/src/network/request_tracker.rs
index 2aab3adab..b22acc17f 100644
--- a/lib/src/network/request_tracker.rs
+++ b/lib/src/network/request_tracker.rs
@@ -227,7 +227,7 @@ pub(super) struct PendingRequest {
 /// Key identifying a request and its corresponding response.
 #[derive(Clone, Copy, Eq, PartialEq, Hash, Ord, PartialOrd, Debug)]
 pub(super) enum MessageKey {
-    RootNode(PublicKey),
+    RootNode(PublicKey, u64),
     ChildNodes(Hash),
     Block(BlockId),
 }
 
 impl<'a> From<&'a Request> for MessageKey {
     fn from(request: &'a Request) -> Self {
         match request {
-            Request::RootNode(writer_id, _) => MessageKey::RootNode(*writer_id),
+            Request::RootNode {
+                writer_id, cookie, ..
+ } => MessageKey::RootNode(*writer_id, *cookie), Request::ChildNodes(hash, _, _) => MessageKey::ChildNodes(*hash), Request::Block(block_id, _) => MessageKey::Block(*block_id), } diff --git a/lib/src/network/request_tracker/simulation.rs b/lib/src/network/request_tracker/simulation.rs index c2cbe9e2c..98d356a06 100644 --- a/lib/src/network/request_tracker/simulation.rs +++ b/lib/src/network/request_tracker/simulation.rs @@ -111,14 +111,14 @@ impl Simulation { // In case of failure, cancel the request so it can be retried without it // triggering assertion failure. let key = match response { - Response::RootNodeError(writer_id, _) => { - Some(MessageKey::RootNode(writer_id)) - } + Response::RootNodeError { + writer_id, cookie, .. + } => Some(MessageKey::RootNode(writer_id, cookie)), Response::ChildNodesError(hash, _, _) => { Some(MessageKey::ChildNodes(hash)) } Response::BlockError(block_id, _) => Some(MessageKey::Block(block_id)), - Response::RootNode(..) + Response::RootNode { .. } | Response::InnerNodes(..) | Response::LeafNodes(..) | Response::Block(..) @@ -186,14 +186,19 @@ impl TestClient { fn handle_response(&mut self, response: Response, snapshot: &mut Snapshot) { match response { - Response::RootNode(proof, block_presence, debug_payload) => { + Response::RootNode { + proof, + block_presence, + cookie, + debug, + } => { let requests = snapshot .insert_root(proof.hash, block_presence) .then_some( CandidateRequest::new(Request::ChildNodes( proof.hash, ResponseDisambiguator::new(block_presence), - debug_payload.follow_up(), + debug.follow_up(), )) .variant(RequestVariant::new( MultiBlockPresence::None, @@ -204,7 +209,7 @@ impl TestClient { .collect(); self.tracker_client - .success(MessageKey::RootNode(proof.writer_id), requests); + .success(MessageKey::RootNode(proof.writer_id, cookie), requests); } Response::InnerNodes(nodes, _disambiguator, debug_payload) => { let parent_hash = nodes.hash(); @@ -253,8 +258,11 @@ impl TestClient { self.tracker_client .success(MessageKey::Block(block_id), vec![]); } - Response::RootNodeError(writer_id, _debug_payload) => { - self.tracker_client.failure(MessageKey::RootNode(writer_id)); + Response::RootNodeError { + writer_id, cookie, .. 
+ } => { + self.tracker_client + .failure(MessageKey::RootNode(writer_id, cookie)); } Response::ChildNodesError(hash, _disambiguator, _debug_payload) => { self.tracker_client.failure(MessageKey::ChildNodes(hash)); @@ -291,11 +299,12 @@ impl TestServer { &write_keys, )); - let outbox = [Response::RootNode( - proof.clone(), - snapshot.root_summary().block_presence, - DebugResponse::unsolicited(), - )] + let outbox = [Response::RootNode { + proof: proof.clone(), + block_presence: snapshot.root_summary().block_presence, + cookie: 0, + debug: DebugResponse::unsolicited(), + }] .into(); Self { @@ -308,7 +317,11 @@ impl TestServer { fn handle_request(&mut self, request: Request) { match request { - Request::RootNode(writer_id, debug_payload) => { + Request::RootNode { + writer_id, + cookie, + debug, + } => { if writer_id == self.writer_id { let proof = Proof::new( writer_id, @@ -317,14 +330,18 @@ impl TestServer { &self.write_keys, ); - self.outbox.push_back(Response::RootNode( - proof.into(), - self.snapshot.root_summary().block_presence, - debug_payload.reply(), - )); + self.outbox.push_back(Response::RootNode { + proof: proof.into(), + block_presence: self.snapshot.root_summary().block_presence, + cookie, + debug: debug.reply(), + }); } else { - self.outbox - .push_back(Response::RootNodeError(writer_id, debug_payload.reply())); + self.outbox.push_back(Response::RootNodeError { + writer_id, + cookie, + debug: debug.reply(), + }); } } Request::ChildNodes(hash, disambiguator, debug_payload) => { diff --git a/lib/src/network/request_tracker/tests.rs b/lib/src/network/request_tracker/tests.rs index f64687f2b..b121f858c 100644 --- a/lib/src/network/request_tracker/tests.rs +++ b/lib/src/network/request_tracker/tests.rs @@ -155,7 +155,7 @@ async fn timeout() { let (client_a, mut request_rx_a) = tracker.new_client(); let (client_b, mut request_rx_b) = tracker.new_client(); - let preceding_request_key = MessageKey::RootNode(PublicKey::generate(&mut rng)); + let preceding_request_key = MessageKey::RootNode(PublicKey::generate(&mut rng), 0); let request = Request::ChildNodes( rng.gen(), ResponseDisambiguator::new(MultiBlockPresence::Full), @@ -208,7 +208,7 @@ async fn drop_uncommitted_client() { let (client_a, mut request_rx_a) = tracker.new_client(); let (client_b, mut request_rx_b) = tracker.new_client(); - let preceding_request_key = MessageKey::RootNode(PublicKey::generate(&mut rng)); + let preceding_request_key = MessageKey::RootNode(PublicKey::generate(&mut rng), 0); let request = Request::ChildNodes( rng.gen(), ResponseDisambiguator::new(MultiBlockPresence::Full), @@ -262,7 +262,11 @@ async fn multiple_responses_to_identical_requests() { let (tracker, mut worker) = build(); let (client, mut request_rx) = tracker.new_client(); - let initial_request = Request::RootNode(PublicKey::generate(&mut rng), DebugRequest::start()); + let initial_request = Request::RootNode { + writer_id: PublicKey::generate(&mut rng), + cookie: 0, + debug: DebugRequest::start(), + }; let followup_request = Request::ChildNodes( rng.gen(), ResponseDisambiguator::new(MultiBlockPresence::Full), diff --git a/lib/src/network/server.rs b/lib/src/network/server.rs index 630f0055f..9c7ac72c1 100644 --- a/lib/src/network/server.rs +++ b/lib/src/network/server.rs @@ -95,7 +95,11 @@ impl Inner { self.vault.monitor.requests_received.increment(1); match request { - Request::RootNode(public_key, debug) => self.handle_root_node(public_key, debug).await, + Request::RootNode { + writer_id, + cookie, + debug, + } => 
self.handle_root_node(writer_id, cookie, debug).await, Request::ChildNodes(hash, disambiguator, debug) => { self.handle_child_nodes(hash, disambiguator, debug).await } @@ -104,7 +108,12 @@ impl Inner { } #[instrument(skip(self, debug), err(Debug))] - async fn handle_root_node(&self, writer_id: PublicKey, debug: DebugRequest) -> Result<()> { + async fn handle_root_node( + &self, + writer_id: PublicKey, + cookie: u64, + debug: DebugRequest, + ) -> Result<()> { let root_node = self .vault .store() @@ -117,24 +126,33 @@ impl Inner { Ok(node) => { tracing::trace!("root node found"); - let response = Response::RootNode( - node.proof.into(), - node.summary.block_presence, - debug.reply(), - ); + let response = Response::RootNode { + proof: node.proof.into(), + block_presence: node.summary.block_presence, + cookie, + debug: debug.reply(), + }; self.enqueue_response(response).await; Ok(()) } Err(store::Error::BranchNotFound) => { tracing::trace!("root node not found"); - self.enqueue_response(Response::RootNodeError(writer_id, debug.reply())) - .await; + self.enqueue_response(Response::RootNodeError { + writer_id, + cookie, + debug: debug.reply(), + }) + .await; Ok(()) } Err(error) => { - self.enqueue_response(Response::RootNodeError(writer_id, debug.reply())) - .await; + self.enqueue_response(Response::RootNodeError { + writer_id, + cookie, + debug: debug.reply(), + }) + .await; Err(error.into()) } } @@ -289,14 +307,13 @@ impl Inner { "send_root_node", ); - let response = Response::RootNode( - root_node.proof.into(), - root_node.summary.block_presence, - DebugResponse::unsolicited(), - ); + let response = Response::RootNode { + proof: root_node.proof.into(), + block_presence: root_node.summary.block_presence, + cookie: 0, + debug: DebugResponse::unsolicited(), + }; - // TODO: maybe this should use different metric counter, to distinguish - // solicited/unsolicited responses? self.enqueue_response(response).await; Ok(()) From 32533c84f81d96d2b36c4c01de3b7b0d202fd922 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Cig=C3=A1nek?= Date: Thu, 3 Oct 2024 12:20:52 +0200 Subject: [PATCH 48/55] Replace twox-hash with xxhash-rust It's better maintained, possibly faster and doesn't allocate. 
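
For context (illustrative only, not part of the patch): the migration boils down to swapping the
hasher type and two method calls - `update` replaces twox-hash's `Hasher::write` and `digest128`
replaces `HasherExt::finish_ext`. A minimal sketch, assuming only the `xxh3` feature enabled in the
diff below; the helper function is a made-up example, not code from this repository. The same
two-call pattern appears in both `RequestVariant::new` and `MultiBlockPresenceBuilder` below.

    use xxhash_rust::xxh3::Xxh3Default;

    // Streaming 128-bit XXH3 checksum over several byte slices, with no
    // intermediate allocation.
    fn checksum128(parts: &[&[u8]]) -> u128 {
        let mut hasher = Xxh3Default::default();
        for part in parts {
            hasher.update(part); // replaces `Hasher::write`
        }
        hasher.digest128() // replaces `HasherExt::finish_ext`
    }
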
--- lib/Cargo.toml | 2 +- lib/src/network/request_tracker.rs | 18 ++++++++---------- lib/src/protocol/summary.rs | 12 ++++++------ 3 files changed, 15 insertions(+), 17 deletions(-) diff --git a/lib/Cargo.toml b/lib/Cargo.toml index 06810e84d..77b11fbfa 100644 --- a/lib/Cargo.toml +++ b/lib/Cargo.toml @@ -77,7 +77,7 @@ tokio-util = { workspace = true, features = ["time"] } tracing = { workspace = true } tracing-subscriber = { workspace = true, features = [ "env-filter" ] } turmoil = { workspace = true, optional = true } -twox-hash = { version = "1.6.3", default-features = false } +xxhash-rust = { version = "0.8.12", default-features = false, features = ["xxh3"] } urlencoding = "2.1.0" vint64 = "1.0.1" zeroize = "1.6.0" diff --git a/lib/src/network/request_tracker.rs b/lib/src/network/request_tracker.rs index b22acc17f..95915d84e 100644 --- a/lib/src/network/request_tracker.rs +++ b/lib/src/network/request_tracker.rs @@ -14,16 +14,14 @@ use crate::{ }; use std::{ collections::{hash_map::Entry, VecDeque}, - fmt, - hash::Hasher, - iter, mem, + fmt, iter, mem, sync::atomic::{AtomicUsize, Ordering}, }; use tokio::{select, sync::mpsc, task}; use tokio_stream::StreamExt; use tokio_util::time::{delay_queue, DelayQueue}; use tracing::{instrument, Instrument, Span}; -use twox_hash::xxh3::{Hash128, HasherExt}; +use xxhash_rust::xxh3::Xxh3Default; /// Keeps track of in-flight requests. Falls back on another peer in case the request failed (due to /// error response, timeout or disconnection). Evenly distributes the requests between the peers @@ -186,25 +184,25 @@ impl CandidateRequest { } #[derive(Default, Clone, Copy, Eq, PartialEq, Hash)] -pub(super) struct RequestVariant([u8; 16]); +pub(super) struct RequestVariant(u128); impl RequestVariant { pub fn new( local_block_presence: MultiBlockPresence, remote_block_presence: MultiBlockPresence, ) -> Self { - let mut hasher = Hash128::default(); + let mut hasher = Xxh3Default::default(); - hasher.write(local_block_presence.checksum()); - hasher.write(remote_block_presence.checksum()); + hasher.update(local_block_presence.checksum()); + hasher.update(remote_block_presence.checksum()); - Self(hasher.finish_ext().to_le_bytes()) + Self(hasher.digest128()) } } impl fmt::Debug for RequestVariant { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{:<8x}", hex_fmt::HexFmt(&self.0)) + write!(f, "{:<8x}", hex_fmt::HexFmt(&self.0.to_le_bytes())) } } diff --git a/lib/src/protocol/summary.rs b/lib/src/protocol/summary.rs index 1ab73bb12..95862ca11 100644 --- a/lib/src/protocol/summary.rs +++ b/lib/src/protocol/summary.rs @@ -6,9 +6,9 @@ use sqlx::{ sqlite::{SqliteArgumentValue, SqliteTypeInfo, SqliteValueRef}, Decode, Encode, Sqlite, Type, }; -use std::{fmt, hash::Hasher}; +use std::fmt; use thiserror::Error; -use twox_hash::xxh3::{Hash128, HasherExt}; +use xxhash_rust::xxh3::Xxh3Default; /// Summary info of a snapshot subtree. Contains whether the subtree has been completely downloaded /// and the number of missing blocks in the subtree. 
@@ -309,7 +309,7 @@ impl fmt::Debug for MultiBlockPresence { struct MultiBlockPresenceBuilder { state: BuilderState, - hasher: Hash128, + hasher: Xxh3Default, } #[derive(Copy, Clone, Debug)] @@ -324,12 +324,12 @@ impl MultiBlockPresenceBuilder { fn new() -> Self { Self { state: BuilderState::Init, - hasher: Hash128::default(), + hasher: Xxh3Default::default(), } } fn update(&mut self, p: MultiBlockPresence) { - self.hasher.write(p.checksum()); + self.hasher.update(p.checksum()); self.state = match (self.state, p) { (BuilderState::Init, MultiBlockPresence::None) => BuilderState::None, @@ -349,7 +349,7 @@ impl MultiBlockPresenceBuilder { match self.state { BuilderState::Init | BuilderState::None => MultiBlockPresence::None, BuilderState::Some => { - MultiBlockPresence::Some(clamp(self.hasher.finish_ext()).to_le_bytes()) + MultiBlockPresence::Some(clamp(self.hasher.digest128()).to_le_bytes()) } BuilderState::Full => MultiBlockPresence::Full, } From 22be5a469426fb4f0dd723f4d0412c252a1e57b1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Cig=C3=A1nek?= Date: Thu, 3 Oct 2024 12:34:02 +0200 Subject: [PATCH 49/55] Remove ResponseDisambiguator --- lib/src/network/client.rs | 17 +++--- lib/src/network/message.rs | 20 ++----- lib/src/network/request_tracker.rs | 2 +- lib/src/network/request_tracker/graph.rs | 53 ++++--------------- lib/src/network/request_tracker/simulation.rs | 48 ++++++----------- lib/src/network/request_tracker/tests.rs | 20 ++----- lib/src/network/server.rs | 37 ++++--------- 7 files changed, 50 insertions(+), 147 deletions(-) diff --git a/lib/src/network/client.rs b/lib/src/network/client.rs index 317feaa8b..bf744b29c 100644 --- a/lib/src/network/client.rs +++ b/lib/src/network/client.rs @@ -1,7 +1,7 @@ use super::{ constants::RESPONSE_BATCH_SIZE, debug_payload::{DebugRequest, DebugResponse}, - message::{Message, Response, ResponseDisambiguator}, + message::{Message, Response}, request_tracker::{PendingRequest, RequestTracker, RequestTrackerClient}, }; use crate::{ @@ -125,10 +125,10 @@ impl Inner { debug, }); } - Response::InnerNodes(nodes, _, debug) => { + Response::InnerNodes(nodes, debug) => { persistable.push(PersistableResponse::InnerNodes(nodes.into(), debug)); } - Response::LeafNodes(nodes, _, debug) => { + Response::LeafNodes(nodes, debug) => { persistable.push(PersistableResponse::LeafNodes(nodes.into(), debug)); } Response::Block(block_content, block_nonce, debug) => { @@ -146,7 +146,7 @@ impl Inner { self.request_tracker .failure(MessageKey::RootNode(writer_id, cookie)); } - Response::ChildNodesError(hash, _, _) => { + Response::ChildNodesError(hash, _) => { self.request_tracker.failure(MessageKey::ChildNodes(hash)); } Response::BlockError(block_id, _) => { @@ -272,12 +272,8 @@ impl Inner { status .request_children() .map(|local_block_presence| { - CandidateRequest::new(Request::ChildNodes( - hash, - ResponseDisambiguator::new(block_presence), - debug_payload.follow_up(), - )) - .variant(RequestVariant::new(local_block_presence, block_presence)) + CandidateRequest::new(Request::ChildNodes(hash, debug_payload.follow_up())) + .variant(RequestVariant::new(local_block_presence, block_presence)) }) .into_iter() .collect(), @@ -306,7 +302,6 @@ impl Inner { .map(|status| { CandidateRequest::new(Request::ChildNodes( status.hash, - ResponseDisambiguator::new(status.remote_block_presence), debug_payload.follow_up(), )) .variant(RequestVariant::new( diff --git a/lib/src/network/message.rs b/lib/src/network/message.rs index 2de6c9b8a..8c03efb8d 100644 --- 
a/lib/src/network/message.rs +++ b/lib/src/network/message.rs @@ -23,23 +23,11 @@ pub(crate) enum Request { debug: DebugRequest, }, /// Request child nodes of the given parent node. - ChildNodes(Hash, ResponseDisambiguator, DebugRequest), + ChildNodes(Hash, DebugRequest), /// Request block with the given id. Block(BlockId, DebugRequest), } -/// ResponseDisambiguator is used to uniquelly assign a response to a request. -/// What we want to avoid is that an outdated response clears out a newer pending request. -#[derive(Clone, Copy, Eq, PartialEq, Hash, Serialize, Deserialize, Debug)] -#[serde(transparent)] -pub(crate) struct ResponseDisambiguator(MultiBlockPresence); - -impl ResponseDisambiguator { - pub fn new(multi_block_presence: MultiBlockPresence) -> Self { - Self(multi_block_presence) - } -} - #[derive(Serialize, Deserialize, Debug)] pub(crate) enum Response { /// Send the latest root node of this replica to another replica. @@ -60,11 +48,11 @@ pub(crate) enum Response { debug: DebugResponse, }, /// Send inner nodes. - InnerNodes(InnerNodes, ResponseDisambiguator, DebugResponse), + InnerNodes(InnerNodes, DebugResponse), /// Send leaf nodes. - LeafNodes(LeafNodes, ResponseDisambiguator, DebugResponse), + LeafNodes(LeafNodes, DebugResponse), /// Send that a ChildNodes request failed - ChildNodesError(Hash, ResponseDisambiguator, DebugResponse), + ChildNodesError(Hash, DebugResponse), /// Send a notification that a block became available on this replica. /// NOTE: This is always unsolicited - the server sends it on its own when it detects a newly /// received block. diff --git a/lib/src/network/request_tracker.rs b/lib/src/network/request_tracker.rs index 95915d84e..113840a36 100644 --- a/lib/src/network/request_tracker.rs +++ b/lib/src/network/request_tracker.rs @@ -236,7 +236,7 @@ impl<'a> From<&'a Request> for MessageKey { Request::RootNode { writer_id, cookie, .. 
} => MessageKey::RootNode(*writer_id, *cookie), - Request::ChildNodes(hash, _, _) => MessageKey::ChildNodes(*hash), + Request::ChildNodes(hash, _) => MessageKey::ChildNodes(*hash), Request::Block(block_id, _) => MessageKey::Block(*block_id), } } diff --git a/lib/src/network/request_tracker/graph.rs b/lib/src/network/request_tracker/graph.rs index cfa4e71d0..589fe0291 100644 --- a/lib/src/network/request_tracker/graph.rs +++ b/lib/src/network/request_tracker/graph.rs @@ -160,10 +160,7 @@ impl Node { #[cfg(test)] mod tests { use super::*; - use crate::{ - network::{debug_payload::DebugRequest, message::ResponseDisambiguator}, - protocol::MultiBlockPresence, - }; + use crate::{network::debug_payload::DebugRequest, protocol::MultiBlockPresence}; use rand::Rng; #[test] @@ -173,11 +170,7 @@ mod tests { assert_eq!(graph.requests().len(), 0); - let parent_request = Request::ChildNodes( - rng.gen(), - ResponseDisambiguator::new(MultiBlockPresence::Full), - DebugRequest::start(), - ); + let parent_request = Request::ChildNodes(rng.gen(), DebugRequest::start()); let parent_node_key = graph.get_or_insert( PendingRequest { @@ -198,11 +191,7 @@ mod tests { assert_eq!(node.children().len(), 0); assert_eq!(node.request().payload, parent_request); - let child_request = Request::ChildNodes( - rng.gen(), - ResponseDisambiguator::new(MultiBlockPresence::Full), - DebugRequest::start(), - ); + let child_request = Request::ChildNodes(rng.gen(), DebugRequest::start()); let child_node_key = graph.get_or_insert( PendingRequest { @@ -244,11 +233,7 @@ mod tests { assert_eq!(graph.requests().len(), 0); - let request = Request::ChildNodes( - rng.gen(), - ResponseDisambiguator::new(MultiBlockPresence::Full), - DebugRequest::start(), - ); + let request = Request::ChildNodes(rng.gen(), DebugRequest::start()); let node_key0 = graph.get_or_insert( PendingRequest { @@ -280,20 +265,12 @@ mod tests { let hash = rng.gen(); let parent_request_0 = PendingRequest { - payload: Request::ChildNodes( - hash, - ResponseDisambiguator::new(MultiBlockPresence::None), - DebugRequest::start(), - ), + payload: Request::ChildNodes(hash, DebugRequest::start()), variant: RequestVariant::new(MultiBlockPresence::None, MultiBlockPresence::None), }; let parent_request_1 = PendingRequest { - payload: Request::ChildNodes( - hash, - ResponseDisambiguator::new(MultiBlockPresence::Full), - DebugRequest::start(), - ), + payload: Request::ChildNodes(hash, DebugRequest::start()), variant: RequestVariant::new(MultiBlockPresence::None, MultiBlockPresence::Full), }; @@ -371,29 +348,17 @@ mod tests { let mut graph = Graph::new(); let parent_request = PendingRequest { - payload: Request::ChildNodes( - rng.gen(), - ResponseDisambiguator::new(MultiBlockPresence::Full), - DebugRequest::start(), - ), + payload: Request::ChildNodes(rng.gen(), DebugRequest::start()), variant: RequestVariant::default(), }; let child_request_0 = PendingRequest { - payload: Request::ChildNodes( - rng.gen(), - ResponseDisambiguator::new(MultiBlockPresence::Full), - DebugRequest::start(), - ), + payload: Request::ChildNodes(rng.gen(), DebugRequest::start()), variant: RequestVariant::default(), }; let child_request_1 = PendingRequest { - payload: Request::ChildNodes( - rng.gen(), - ResponseDisambiguator::new(MultiBlockPresence::Full), - DebugRequest::start(), - ), + payload: Request::ChildNodes(rng.gen(), DebugRequest::start()), variant: RequestVariant::default(), }; diff --git a/lib/src/network/request_tracker/simulation.rs b/lib/src/network/request_tracker/simulation.rs index 
98d356a06..0194e936f 100644 --- a/lib/src/network/request_tracker/simulation.rs +++ b/lib/src/network/request_tracker/simulation.rs @@ -1,5 +1,5 @@ use super::{ - super::message::{Request, Response, ResponseDisambiguator}, + super::message::{Request, Response}, CandidateRequest, MessageKey, PendingRequest, RequestTracker, RequestTrackerClient, RequestVariant, }; @@ -114,7 +114,7 @@ impl Simulation { Response::RootNodeError { writer_id, cookie, .. } => Some(MessageKey::RootNode(writer_id, cookie)), - Response::ChildNodesError(hash, _, _) => { + Response::ChildNodesError(hash, _) => { Some(MessageKey::ChildNodes(hash)) } Response::BlockError(block_id, _) => Some(MessageKey::Block(block_id)), @@ -195,15 +195,11 @@ impl TestClient { let requests = snapshot .insert_root(proof.hash, block_presence) .then_some( - CandidateRequest::new(Request::ChildNodes( - proof.hash, - ResponseDisambiguator::new(block_presence), - debug.follow_up(), - )) - .variant(RequestVariant::new( - MultiBlockPresence::None, - block_presence, - )), + CandidateRequest::new(Request::ChildNodes(proof.hash, debug.follow_up())) + .variant(RequestVariant::new( + MultiBlockPresence::None, + block_presence, + )), ) .into_iter() .collect(); @@ -211,7 +207,7 @@ impl TestClient { self.tracker_client .success(MessageKey::RootNode(proof.writer_id, cookie), requests); } - Response::InnerNodes(nodes, _disambiguator, debug_payload) => { + Response::InnerNodes(nodes, debug_payload) => { let parent_hash = nodes.hash(); let nodes = snapshot.insert_inners(nodes); @@ -220,7 +216,6 @@ impl TestClient { .map(|(_, node)| { CandidateRequest::new(Request::ChildNodes( node.hash, - ResponseDisambiguator::new(node.summary.block_presence), debug_payload.follow_up(), )) .variant(RequestVariant::new( @@ -233,7 +228,7 @@ impl TestClient { self.tracker_client .success(MessageKey::ChildNodes(parent_hash), requests); } - Response::LeafNodes(nodes, _disambiguator, debug_payload) => { + Response::LeafNodes(nodes, debug_payload) => { let parent_hash = nodes.hash(); let nodes = snapshot.insert_leaves(nodes); let requests = nodes @@ -264,7 +259,7 @@ impl TestClient { self.tracker_client .failure(MessageKey::RootNode(writer_id, cookie)); } - Response::ChildNodesError(hash, _disambiguator, _debug_payload) => { + Response::ChildNodesError(hash, _debug_payload) => { self.tracker_client.failure(MessageKey::ChildNodes(hash)); } Response::BlockError(block_id, _debug_payload) => { @@ -344,25 +339,16 @@ impl TestServer { }); } } - Request::ChildNodes(hash, disambiguator, debug_payload) => { + Request::ChildNodes(hash, debug_payload) => { if let Some(nodes) = self.snapshot.get_inner_set(&hash) { - self.outbox.push_back(Response::InnerNodes( - nodes.clone(), - disambiguator, - debug_payload.reply(), - )); + self.outbox + .push_back(Response::InnerNodes(nodes.clone(), debug_payload.reply())); } else if let Some(nodes) = self.snapshot.get_leaf_set(&hash) { - self.outbox.push_back(Response::LeafNodes( - nodes.clone(), - disambiguator, - debug_payload.reply(), - )); + self.outbox + .push_back(Response::LeafNodes(nodes.clone(), debug_payload.reply())); } else { - self.outbox.push_back(Response::ChildNodesError( - hash, - disambiguator, - debug_payload.reply(), - )); + self.outbox + .push_back(Response::ChildNodesError(hash, debug_payload.reply())); } } Request::Block(block_id, debug_payload) => { diff --git a/lib/src/network/request_tracker/tests.rs b/lib/src/network/request_tracker/tests.rs index b121f858c..a25a112ae 100644 --- a/lib/src/network/request_tracker/tests.rs +++ 
b/lib/src/network/request_tracker/tests.rs @@ -1,6 +1,6 @@ use super::{simulation::Simulation, *}; use crate::{ - network::{debug_payload::DebugRequest, message::ResponseDisambiguator}, + network::debug_payload::DebugRequest, protocol::{ test_utils::{BlockState, Snapshot}, Block, @@ -156,11 +156,7 @@ async fn timeout() { let (client_b, mut request_rx_b) = tracker.new_client(); let preceding_request_key = MessageKey::RootNode(PublicKey::generate(&mut rng), 0); - let request = Request::ChildNodes( - rng.gen(), - ResponseDisambiguator::new(MultiBlockPresence::Full), - DebugRequest::start(), - ); + let request = Request::ChildNodes(rng.gen(), DebugRequest::start()); // Register the request with both clients. client_a.success( @@ -209,11 +205,7 @@ async fn drop_uncommitted_client() { let (client_b, mut request_rx_b) = tracker.new_client(); let preceding_request_key = MessageKey::RootNode(PublicKey::generate(&mut rng), 0); - let request = Request::ChildNodes( - rng.gen(), - ResponseDisambiguator::new(MultiBlockPresence::Full), - DebugRequest::start(), - ); + let request = Request::ChildNodes(rng.gen(), DebugRequest::start()); let request_key = MessageKey::from(&request); for client in [&client_a, &client_b] { @@ -267,11 +259,7 @@ async fn multiple_responses_to_identical_requests() { cookie: 0, debug: DebugRequest::start(), }; - let followup_request = Request::ChildNodes( - rng.gen(), - ResponseDisambiguator::new(MultiBlockPresence::Full), - DebugRequest::start(), - ); + let followup_request = Request::ChildNodes(rng.gen(), DebugRequest::start()); // Send initial root node request client.initial(CandidateRequest::new(initial_request.clone())); diff --git a/lib/src/network/server.rs b/lib/src/network/server.rs index 9c7ac72c1..cb3b62f5f 100644 --- a/lib/src/network/server.rs +++ b/lib/src/network/server.rs @@ -1,7 +1,7 @@ use super::{ constants::{INTEREST_TIMEOUT, MAX_UNCHOKED_DURATION}, debug_payload::{DebugRequest, DebugResponse}, - message::{Message, Request, Response, ResponseDisambiguator}, + message::{Message, Request, Response}, }; use crate::{ crypto::{sign::PublicKey, Hash}, @@ -100,9 +100,7 @@ impl Inner { cookie, debug, } => self.handle_root_node(writer_id, cookie, debug).await, - Request::ChildNodes(hash, disambiguator, debug) => { - self.handle_child_nodes(hash, disambiguator, debug).await - } + Request::ChildNodes(hash, debug) => self.handle_child_nodes(hash, debug).await, Request::Block(block_id, debug) => self.handle_block(block_id, debug).await, } } @@ -159,12 +157,7 @@ impl Inner { } #[instrument(skip(self, debug), err(Debug))] - async fn handle_child_nodes( - &self, - parent_hash: Hash, - disambiguator: ResponseDisambiguator, - debug: DebugRequest, - ) -> Result<()> { + async fn handle_child_nodes(&self, parent_hash: Hash, debug: DebugRequest) -> Result<()> { let mut reader = self.vault.store().acquire_read().await?; // At most one of these will be non-empty. 
@@ -176,31 +169,19 @@ impl Inner { if !inner_nodes.is_empty() || !leaf_nodes.is_empty() { if !inner_nodes.is_empty() { tracing::trace!("inner nodes found"); - self.enqueue_response(Response::InnerNodes( - inner_nodes, - disambiguator, - debug.reply(), - )) - .await; + self.enqueue_response(Response::InnerNodes(inner_nodes, debug.reply())) + .await; } if !leaf_nodes.is_empty() { tracing::trace!("leaf nodes found"); - self.enqueue_response(Response::LeafNodes( - leaf_nodes, - disambiguator, - debug.reply(), - )) - .await; + self.enqueue_response(Response::LeafNodes(leaf_nodes, debug.reply())) + .await; } } else { tracing::trace!("child nodes not found"); - self.enqueue_response(Response::ChildNodesError( - parent_hash, - disambiguator, - debug.reply(), - )) - .await; + self.enqueue_response(Response::ChildNodesError(parent_hash, debug.reply())) + .await; } Ok(()) From 82565f65927c109dfbb43d85c5b59f5324b0795a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Cig=C3=A1nek?= Date: Thu, 3 Oct 2024 12:56:06 +0200 Subject: [PATCH 50/55] Make request timeout configurable at runtime --- lib/src/network/mod.rs | 11 ++++++++++ lib/src/network/request_tracker.rs | 26 +++++++++++++++++++----- lib/src/network/request_tracker/tests.rs | 2 +- lib/tests/malice.rs | 7 +++++++ 4 files changed, 40 insertions(+), 6 deletions(-) diff --git a/lib/src/network/mod.rs b/lib/src/network/mod.rs index 92e956ab2..74e146439 100644 --- a/lib/src/network/mod.rs +++ b/lib/src/network/mod.rs @@ -38,6 +38,7 @@ pub use self::{ runtime_id::{PublicRuntimeId, SecretRuntimeId}, stats::Stats, }; +use constants::REQUEST_TIMEOUT; pub use net::stun::NatBehavior; use request_tracker::RequestTracker; @@ -356,6 +357,8 @@ impl Network { pex.set_enabled(pex_enabled); let request_tracker = RequestTracker::new(); + request_tracker.set_timeout(REQUEST_TIMEOUT); + // TODO: This should be global, not per repo let response_limiter = Arc::new(Semaphore::new(MAX_UNCHOKED_COUNT)); let stats_tracker = StatsTracker::default(); @@ -400,6 +403,14 @@ impl Network { shutdown_peers(peers).await; } + + /// Change the sync protocol request timeout. Useful mostly for testing and benchmarking as the + /// default value should be sufficient for most use cases. + pub fn set_request_timeout(&self, timeout: Duration) { + for (_, holder) in &self.inner.registry.lock().unwrap().repos { + holder.request_tracker.set_timeout(timeout); + } + } } pub struct Registration { diff --git a/lib/src/network/request_tracker.rs b/lib/src/network/request_tracker.rs index 113840a36..bbbc1c606 100644 --- a/lib/src/network/request_tracker.rs +++ b/lib/src/network/request_tracker.rs @@ -6,7 +6,7 @@ mod simulation; mod tests; use self::graph::{Graph, Key as GraphKey}; -use super::{constants::REQUEST_TIMEOUT, message::Request}; +use super::message::Request; use crate::{ collections::HashMap, crypto::{sign::PublicKey, Hash}, @@ -16,6 +16,7 @@ use std::{ collections::{hash_map::Entry, VecDeque}, fmt, iter, mem, sync::atomic::{AtomicUsize, Ordering}, + time::Duration, }; use tokio::{select, sync::mpsc, task}; use tokio_stream::StreamExt; @@ -23,6 +24,8 @@ use tokio_util::time::{delay_queue, DelayQueue}; use tracing::{instrument, Instrument, Span}; use xxhash_rust::xxh3::Xxh3Default; +const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30); + /// Keeps track of in-flight requests. Falls back on another peer in case the request failed (due to /// error response, timeout or disconnection). 
Evenly distributes the requests between the peers
/// and ensures every request is only sent to one peer at a time.
@@ -32,13 +35,16 @@ pub(super) struct RequestTracker {
 }
 
 impl RequestTracker {
-    // TODO: Make request timeout configurable
     pub fn new() -> Self {
         let (this, worker) = build();
         task::spawn(worker.run().instrument(Span::current()));
         this
     }
 
+    pub fn set_timeout(&self, timeout: Duration) {
+        self.command_tx.send(Command::SetTimeout { timeout }).ok();
+    }
+
     pub fn new_client(
         &self,
     ) -> (
@@ -252,6 +258,7 @@ struct Worker {
     clients: HashMap<ClientId, ClientState>,
     requests: Graph<RequestState>,
     timer: DelayQueue<(ClientId, MessageKey)>,
+    timeout: Duration,
 }
 
 impl Worker {
@@ -261,6 +268,7 @@ impl Worker {
             clients: HashMap::default(),
             requests: Graph::new(),
             timer: DelayQueue::new(),
+            timeout: DEFAULT_TIMEOUT,
         }
     }
 
@@ -306,6 +314,11 @@ impl Worker {
             Command::RemoveClient { client_id } => {
                 self.remove_client(client_id);
             }
+            Command::SetTimeout { timeout } => {
+                // Note: for simplicity, the new timeout is applied to future requests only, not
+                // the ones that have already been scheduled.
+                self.timeout = timeout;
+            }
             Command::HandleInitial { client_id, request } => {
                 self.handle_initial(client_id, request);
             }
@@ -482,7 +495,7 @@ impl Worker {
 
         let sender_timer_key = self
             .timer
-            .insert((sender_client_id, request_key), REQUEST_TIMEOUT);
+            .insert((sender_client_id, request_key), self.timeout);
         sender_client_state
             .request_tx
             .send(node.request().clone())
@@ -585,7 +598,7 @@ impl Worker {
         *node.value_mut() = match initial_state {
             InitialRequestState::InFlight => {
                 let timer_key =
-                    self.timer.insert((client_id, request_key), REQUEST_TIMEOUT);
+                    self.timer.insert((client_id, request_key), self.timeout);
                 client_state.request_tx.send(node.request().clone()).ok();
 
                 RequestState::InFlight {
@@ -669,7 +682,7 @@ impl Worker {
                 let sender_client_id = next_client_id;
                 let sender_timer_key = self
                     .timer
-                    .insert((next_client_id, request_key), REQUEST_TIMEOUT);
+                    .insert((next_client_id, request_key), self.timeout);
                 let waiters = mem::take(waiters);
 
                 *state = RequestState::InFlight {
@@ -734,6 +747,9 @@ enum Command {
     RemoveClient {
         client_id: ClientId,
     },
+    SetTimeout {
+        timeout: Duration,
+    },
     HandleInitial {
         client_id: ClientId,
         request: CandidateRequest,
diff --git a/lib/src/network/request_tracker/tests.rs b/lib/src/network/request_tracker/tests.rs
index a25a112ae..bc05cf539 100644
--- a/lib/src/network/request_tracker/tests.rs
+++ b/lib/src/network/request_tracker/tests.rs
@@ -185,7 +185,7 @@ async fn timeout() {
     );
 
     // Wait until the request timeout passes
-    time::timeout(REQUEST_TIMEOUT + Duration::from_millis(1), &mut work)
+    time::timeout(DEFAULT_TIMEOUT + Duration::from_millis(1), &mut work)
         .await
         .ok();
 
diff --git a/lib/tests/malice.rs b/lib/tests/malice.rs
index 0c0139e66..108e9b696 100644
--- a/lib/tests/malice.rs
+++ b/lib/tests/malice.rs
@@ -3,6 +3,8 @@
 #[macro_use]
 mod common;
 
+use std::time::Duration;
+
 use common::{actor, Env, Proto, DEFAULT_REPO};
 use ouisync::{AccessMode, Error, Repository, StoreError};
 use tokio::sync::mpsc;
@@ -45,6 +47,11 @@ fn block_nonce_tamper() {
     env.actor("bob", async move {
         let (network, repo, _reg) = actor::setup().await;
 
+        // Bob first sends the block requests to Mallory but never receives the correct responses.
+        // Those requests first need to time out before Bob retries them to Alice. By default that
+        // would make this test take too long. Decrease the timeout to make it faster.
+ network.set_request_timeout(Duration::from_secs(5)); + let (alice_id, alice_expected_vv) = mallory_to_bob_rx.recv().await.unwrap(); // Connect to Mallory and wait until index is synced (blocks should be rejected). From e004bf0fc4cbe7e245af941bf4eaaf820e1b49f2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Cig=C3=A1nek?= Date: Thu, 3 Oct 2024 15:14:40 +0200 Subject: [PATCH 51/55] Add monitoring to RequestTracker --- lib/src/network/client.rs | 15 +-- lib/src/network/mod.rs | 2 +- lib/src/network/request_tracker.rs | 76 ++++++++++-- lib/src/network/request_tracker/tests.rs | 13 +- lib/src/network/server.rs | 4 +- lib/src/network/tests.rs | 16 +-- lib/src/repository/mod.rs | 5 +- lib/src/repository/monitor.rs | 150 +++++++++++++++-------- lib/src/repository/params.rs | 2 +- lib/src/repository/vault.rs | 2 +- 10 files changed, 192 insertions(+), 93 deletions(-) diff --git a/lib/src/network/client.rs b/lib/src/network/client.rs index bf744b29c..c270c7d95 100644 --- a/lib/src/network/client.rs +++ b/lib/src/network/client.rs @@ -109,7 +109,7 @@ impl Inner { loop { for response in recv_iter(rx).await { - self.vault.monitor.responses_received.increment(1); + self.vault.monitor.traffic.responses_received.increment(1); match response { Response::RootNode { @@ -592,7 +592,7 @@ mod tests { db, event::EventSender, protocol::{Proof, RepositoryId, EMPTY_INNER_HASH}, - repository::RepositoryMonitor, + repository::monitor::RepositoryMonitor, version_vector::VersionVector, }; use futures_util::TryStreamExt; @@ -676,17 +676,14 @@ mod tests { let secrets = WriteSecrets::random(); let repository_id = RepositoryId::from(secrets.write_keys.public_key()); + let monitor = RepositoryMonitor::new(StateMonitor::make_root(), &NoopRecorder); + let traffic_monitor = monitor.traffic.clone(); - let vault = Vault::new( - repository_id, - EventSender::new(1), - pool, - RepositoryMonitor::new(StateMonitor::make_root(), &NoopRecorder), - ); + let vault = Vault::new(repository_id, EventSender::new(1), pool, monitor); vault.block_tracker.set_request_mode(BlockRequestMode::Lazy); - let request_tracker = RequestTracker::new(); + let request_tracker = RequestTracker::new(traffic_monitor); let (request_tracker, _request_rx) = request_tracker.new_client(); let (block_tracker, _block_rx) = vault.block_tracker.new_client(); diff --git a/lib/src/network/mod.rs b/lib/src/network/mod.rs index 74e146439..7fe95325c 100644 --- a/lib/src/network/mod.rs +++ b/lib/src/network/mod.rs @@ -356,7 +356,7 @@ impl Network { let pex = self.inner.pex_discovery.new_repository(); pex.set_enabled(pex_enabled); - let request_tracker = RequestTracker::new(); + let request_tracker = RequestTracker::new(handle.vault.monitor.traffic.clone()); request_tracker.set_timeout(REQUEST_TIMEOUT); // TODO: This should be global, not per repo diff --git a/lib/src/network/request_tracker.rs b/lib/src/network/request_tracker.rs index bbbc1c606..564dcd9fd 100644 --- a/lib/src/network/request_tracker.rs +++ b/lib/src/network/request_tracker.rs @@ -11,12 +11,13 @@ use crate::{ collections::HashMap, crypto::{sign::PublicKey, Hash}, protocol::{BlockId, MultiBlockPresence}, + repository::monitor::{RequestEvent, RequestKind, TrafficMonitor}, }; use std::{ collections::{hash_map::Entry, VecDeque}, fmt, iter, mem, sync::atomic::{AtomicUsize, Ordering}, - time::Duration, + time::{Duration, Instant}, }; use tokio::{select, sync::mpsc, task}; use tokio_stream::StreamExt; @@ -35,8 +36,8 @@ pub(super) struct RequestTracker { } impl RequestTracker { - pub fn new() -> Self { - let 
(this, worker) = build(); + pub fn new(monitor: TrafficMonitor) -> Self { + let (this, worker) = build(monitor); task::spawn(worker.run().instrument(Span::current())); this } @@ -236,6 +237,15 @@ pub(super) enum MessageKey { Block(BlockId), } +impl MessageKey { + pub fn kind(&self) -> RequestKind { + match self { + Self::RootNode(..) | Self::ChildNodes(..) => RequestKind::Index, + Self::Block(..) => RequestKind::Block, + } + } +} + impl<'a> From<&'a Request> for MessageKey { fn from(request: &'a Request) -> Self { match request { @@ -248,9 +258,12 @@ impl<'a> From<&'a Request> for MessageKey { } } -fn build() -> (RequestTracker, Worker) { +fn build(monitor: TrafficMonitor) -> (RequestTracker, Worker) { let (command_tx, command_rx) = mpsc::unbounded_channel(); - (RequestTracker { command_tx }, Worker::new(command_rx)) + ( + RequestTracker { command_tx }, + Worker::new(command_rx, monitor), + ) } struct Worker { @@ -259,16 +272,18 @@ struct Worker { requests: Graph, timer: DelayQueue<(ClientId, MessageKey)>, timeout: Duration, + monitor: TrafficMonitor, } impl Worker { - fn new(command_rx: mpsc::UnboundedReceiver) -> Self { + fn new(command_rx: mpsc::UnboundedReceiver, monitor: TrafficMonitor) -> Self { Self { command_rx, clients: HashMap::default(), requests: Graph::new(), timer: DelayQueue::new(), timeout: DEFAULT_TIMEOUT, + monitor, } } @@ -365,7 +380,7 @@ impl Worker { }; for (_, node_key) in client_state.requests { - self.cancel_request(client_id, node_key); + self.cancel_request(client_id, node_key, None); } } @@ -400,9 +415,18 @@ impl Worker { RequestState::InFlight { sender_client_id, sender_timer_key, + sent_at, waiters, } if *sender_client_id == client_id => { self.timer.try_remove(sender_timer_key); + + self.monitor.record( + RequestEvent::Success { + rtt: sent_at.elapsed(), + }, + request_key.kind(), + ); + Some(mem::take(waiters)) } RequestState::InFlight { .. 
}
@@ -466,7 +490,7 @@ impl Worker {
             return;
         };
 
-        self.cancel_request(client_id, node_key);
+        self.cancel_request(client_id, node_key, Some(reason));
     }
 
     #[instrument(skip(self))]
@@ -504,8 +528,11 @@ impl Worker {
         *node.value_mut() = RequestState::InFlight {
             sender_client_id,
             sender_timer_key,
+            sent_at: Instant::now(),
             waiters,
         };
+
+        self.monitor.record(RequestEvent::Send, request_key.kind());
     }
 
     #[instrument(skip(self))]
@@ -601,9 +628,12 @@
                     self.timer.insert((client_id, request_key), self.timeout);
                 client_state.request_tx.send(node.request().clone()).ok();
 
+                self.monitor.record(RequestEvent::Send, request_key.kind());
+
                 RequestState::InFlight {
                     sender_client_id: client_id,
                     sender_timer_key: timer_key,
+                    sent_at: Instant::now(),
                     waiters: VecDeque::new(),
                 }
             }
@@ -626,12 +656,18 @@ impl Worker {
         }
     }
 
-    fn cancel_request(&mut self, client_id: ClientId, node_key: GraphKey) {
+    fn cancel_request(
+        &mut self,
+        client_id: ClientId,
+        node_key: GraphKey,
+        failure_reason: Option<FailureReason>,
+    ) {
         let Some(node) = self.requests.get_mut(node_key) else {
             return;
         };
 
         let (request, state) = node.parts_mut();
+        let request_key = MessageKey::from(&request.payload);
 
         let waiters = match state {
             RequestState::Suspended { waiters } => {
@@ -646,10 +682,23 @@
             RequestState::InFlight {
                 sender_client_id,
                 sender_timer_key,
+                sent_at,
                 waiters,
             } => {
                 if *sender_client_id == client_id {
                     self.timer.try_remove(sender_timer_key);
+
+                    self.monitor.record(
+                        match failure_reason {
+                            Some(FailureReason::Response) => RequestEvent::Failure {
+                                rtt: sent_at.elapsed(),
+                            },
+                            Some(FailureReason::Timeout) => RequestEvent::Timeout,
+                            None => RequestEvent::Cancel,
+                        },
+                        request_key.kind(),
+                    );
+
                     Some(waiters)
                 } else {
                     remove_from_queue(waiters, &client_id);
@@ -670,8 +719,6 @@
             RequestState::Committed | RequestState::Cancelled => return,
         };
 
-        let request_key = MessageKey::from(&request.payload);
-
         // Find next waiting client
         if let Some(waiters) = waiters {
             let next_client = iter::from_fn(|| waiters.pop_front())
@@ -685,13 +732,16 @@
                     .insert((next_client_id, request_key), self.timeout);
                 let waiters = mem::take(waiters);
 
+                next_client_state.request_tx.send(request.clone()).ok();
+
                 *state = RequestState::InFlight {
                     sender_client_id,
                     sender_timer_key,
+                    sent_at: Instant::now(),
                     waiters,
                 };
 
-                next_client_state.request_tx.send(request.clone()).ok();
+                self.monitor.record(RequestEvent::Send, request_key.kind());
 
                 return;
             }
@@ -796,6 +846,8 @@ enum RequestState {
         sender_client_id: ClientId,
         /// Timeout key for the request
         sender_timer_key: delay_queue::Key,
+        /// When the request was sent.
+        sent_at: Instant,
         /// Other clients interested in sending this request. If the current client fails or
         /// timeouts, a new one will be picked from this list.
         waiters: VecDeque<ClientId>,
diff --git a/lib/src/network/request_tracker/tests.rs b/lib/src/network/request_tracker/tests.rs
index bc05cf539..c9d3382b7 100644
--- a/lib/src/network/request_tracker/tests.rs
+++ b/lib/src/network/request_tracker/tests.rs
@@ -7,6 +7,7 @@ use crate::{
     },
 };
 use assert_matches::assert_matches;
+use metrics::NoopRecorder;
 use rand::{
     distributions::{Bernoulli, Distribution, Standard},
     rngs::StdRng,
@@ -39,7 +40,7 @@ async fn dynamic_swarm() {
         snapshot.blocks().len()
     );
 
-    let (tracker, mut tracker_worker) = build();
+    let (tracker, mut tracker_worker) = build(TrafficMonitor::new(&NoopRecorder));
 
     // Action to perform on the set of peers.
#[derive(Debug)] @@ -125,7 +126,7 @@ async fn missing_blocks() { "seed = {seed}, blocks = {num_blocks}/{max_blocks}, peers = {num_peers}/{max_peers}" ); - let (tracker, mut tracker_worker) = build(); + let (tracker, mut tracker_worker) = build(TrafficMonitor::new(&NoopRecorder)); for snapshot in peer_snapshots { sim.insert_peer(&mut rng, &tracker, snapshot); } @@ -148,7 +149,7 @@ async fn missing_blocks() { #[tokio::test(start_paused = true)] async fn timeout() { let mut rng = StdRng::seed_from_u64(0); - let (tracker, tracker_worker) = build(); + let (tracker, tracker_worker) = build(TrafficMonitor::new(&NoopRecorder)); let mut work = pin!(tracker_worker.run()); @@ -199,7 +200,7 @@ async fn timeout() { #[tokio::test] async fn drop_uncommitted_client() { let mut rng = StdRng::seed_from_u64(0); - let (tracker, mut tracker_worker) = build(); + let (tracker, mut tracker_worker) = build(TrafficMonitor::new(&NoopRecorder)); let (client_a, mut request_rx_a) = tracker.new_client(); let (client_b, mut request_rx_b) = tracker.new_client(); @@ -251,7 +252,7 @@ async fn drop_uncommitted_client() { #[tokio::test] async fn multiple_responses_to_identical_requests() { let mut rng = StdRng::seed_from_u64(0); - let (tracker, mut worker) = build(); + let (tracker, mut worker) = build(TrafficMonitor::new(&NoopRecorder)); let (client, mut request_rx) = tracker.new_client(); let initial_request = Request::RootNode { @@ -295,7 +296,7 @@ async fn multiple_responses_to_identical_requests() { #[tokio::test] async fn suspend_resume() { let mut rng = StdRng::seed_from_u64(0); - let (tracker, mut worker) = build(); + let (tracker, mut worker) = build(TrafficMonitor::new(&NoopRecorder)); let (client, mut request_rx) = tracker.new_client(); worker.step(); diff --git a/lib/src/network/server.rs b/lib/src/network/server.rs index cb3b62f5f..19e7f8a09 100644 --- a/lib/src/network/server.rs +++ b/lib/src/network/server.rs @@ -92,7 +92,7 @@ impl Inner { } async fn handle_request(&self, request: Request) -> Result<()> { - self.vault.monitor.requests_received.increment(1); + self.vault.monitor.traffic.requests_received.increment(1); match request { Request::RootNode { @@ -359,7 +359,7 @@ impl Inner { fn send_response(&self, response: Response) { if self.message_tx.send(Message::Response(response)).is_ok() { - self.vault.monitor.responses_sent.increment(1); + self.vault.monitor.traffic.responses_sent.increment(1); } } } diff --git a/lib/src/network/tests.rs b/lib/src/network/tests.rs index cdddbd010..3605f29cd 100644 --- a/lib/src/network/tests.rs +++ b/lib/src/network/tests.rs @@ -12,7 +12,7 @@ use crate::{ protocol::{ test_utils::Snapshot, Block, BlockId, Bump, RepositoryId, RootNode, SingleBlockPresence, }, - repository::{RepositoryMonitor, Vault}, + repository::{monitor::RepositoryMonitor, Vault}, store::{Changeset, SnapshotWriter}, test_utils, version_vector::VersionVector, @@ -393,18 +393,14 @@ async fn create_repository( write_keys: &Keypair, ) -> (TempDir, Vault, RequestTracker, Arc, PublicKey) { let (base_dir, db) = db::create_temp().await.unwrap(); + let writer_id = PublicKey::generate(rng); let repository_id = RepositoryId::from(write_keys.public_key()); let event_tx = EventSender::new(1); - - let state = Vault::new( - repository_id, - event_tx, - db, - RepositoryMonitor::new(StateMonitor::make_root(), &NoopRecorder), - ); - - let request_tracker = RequestTracker::new(); + let monitor = RepositoryMonitor::new(StateMonitor::make_root(), &NoopRecorder); + let traffic_monitor = monitor.traffic.clone(); + let state = 
Vault::new(repository_id, event_tx, db, monitor); + let request_tracker = RequestTracker::new(traffic_monitor); let response_limiter = Arc::new(Semaphore::new(MAX_UNCHOKED_COUNT)); ( diff --git a/lib/src/repository/mod.rs b/lib/src/repository/mod.rs index ff261022c..4fa6c5c81 100644 --- a/lib/src/repository/mod.rs +++ b/lib/src/repository/mod.rs @@ -1,6 +1,7 @@ +pub(crate) mod monitor; + mod credentials; mod metadata; -mod monitor; mod params; mod vault; mod worker; @@ -12,10 +13,10 @@ pub use self::{credentials::Credentials, metadata::Metadata, params::RepositoryP pub(crate) use self::{ metadata::{data_version, quota}, - monitor::RepositoryMonitor, vault::Vault, }; +use self::monitor::RepositoryMonitor; use crate::{ access_control::{Access, AccessChange, AccessKeys, AccessMode, AccessSecrets, LocalSecret}, block_tracker::BlockRequestMode, diff --git a/lib/src/repository/monitor.rs b/lib/src/repository/monitor.rs index 37ff947a6..abda7c66f 100644 --- a/lib/src/repository/monitor.rs +++ b/lib/src/repository/monitor.rs @@ -19,27 +19,7 @@ use tracing::{Instrument, Span}; pub(crate) struct RepositoryMonitor { pub info_hash: MonitoredValue>, - - // Total number of index requests sent. - pub index_requests_sent: Counter, - // Current number of sent index request for which responses haven't been received yet. - pub index_requests_inflight: Gauge, - // Total number of block requests sent. - pub block_requests_sent: Counter, - // Current number of sent block request for which responses haven't been received yet. - pub block_requests_inflight: Gauge, - // Total number of received requests - pub requests_received: Counter, - // Time from sending a request to receiving its response. - pub request_latency: Histogram, - // Total number of timeouted requests. - pub request_timeouts: Counter, - - // Total number of responses sent. - pub responses_sent: Counter, - // Total number of responses received. 
-    pub responses_received: Counter,
-
+    pub traffic: TrafficMonitor,
     pub scan_job: JobMonitor,
     pub merge_job: JobMonitor,
     pub prune_job: JobMonitor,
@@ -57,21 +37,7 @@ impl RepositoryMonitor {
         let span = tracing::info_span!("repo", message = node.id().name());
 
         let info_hash = node.make_value("info-hash", None);
-
-        let index_requests_sent = create_counter(recorder, "index requests sent", Unit::Count);
-        let index_requests_inflight =
-            create_gauge(recorder, "index requests inflight", Unit::Count);
-        let block_requests_sent = create_counter(recorder, "block requests sent", Unit::Count);
-        let block_requests_inflight =
-            create_gauge(recorder, "block requests inflight", Unit::Count);
-
-        let requests_received = create_counter(recorder, "requests received", Unit::Count);
-        let request_latency = create_histogram(recorder, "request latency", Unit::Seconds);
-        let request_timeouts = create_counter(recorder, "request timeouts", Unit::Count);
-
-        let responses_sent = create_counter(recorder, "responses sent", Unit::Count);
-        let responses_received = create_counter(recorder, "responses received", Unit::Count);
-
+        let traffic = TrafficMonitor::new(recorder);
         let scan_job = JobMonitor::new(&node, recorder, "scan");
         let merge_job = JobMonitor::new(&node, recorder, "merge");
         let prune_job = JobMonitor::new(&node, recorder, "prune");
@@ -79,23 +45,11 @@ impl RepositoryMonitor {
 
         Self {
             info_hash,
-
-            index_requests_sent,
-            index_requests_inflight,
-            block_requests_sent,
-            block_requests_inflight,
-            requests_received,
-            request_latency,
-            request_timeouts,
-
-            responses_sent,
-            responses_received,
-
+            traffic,
             scan_job,
             merge_job,
             prune_job,
             trash_job,
-
             span,
             node,
         }
@@ -114,6 +68,104 @@ impl RepositoryMonitor {
     }
 }
 
+#[derive(Clone)]
+pub(crate) struct TrafficMonitor {
+    // Total number of index requests sent.
+    pub index_requests_sent: Counter,
+    // Current number of sent index requests for which responses haven't been received yet.
+    pub index_requests_inflight: Gauge,
+    // Total number of block requests sent.
+    pub block_requests_sent: Counter,
+    // Current number of sent block requests for which responses haven't been received yet.
+    pub block_requests_inflight: Gauge,
+    // Total number of received requests.
+    pub requests_received: Counter,
+    // Time from sending a request to receiving its response.
+    pub request_latency: Histogram,
+    // Total number of timed-out requests.
+    pub request_timeouts: Counter,
+
+    // Total number of responses sent.
+    pub responses_sent: Counter,
+    // Total number of responses received.
+ pub responses_received: Counter, +} + +impl TrafficMonitor { + pub fn new(recorder: &R) -> Self + where + R: Recorder + ?Sized, + { + Self { + index_requests_sent: create_counter(recorder, "index requests sent", Unit::Count), + index_requests_inflight: create_gauge(recorder, "index requests inflight", Unit::Count), + block_requests_sent: create_counter(recorder, "block requests sent", Unit::Count), + block_requests_inflight: create_gauge(recorder, "block requests inflight", Unit::Count), + requests_received: create_counter(recorder, "requests received", Unit::Count), + request_latency: create_histogram(recorder, "request latency", Unit::Seconds), + request_timeouts: create_counter(recorder, "request timeouts", Unit::Count), + responses_sent: create_counter(recorder, "responses sent", Unit::Count), + responses_received: create_counter(recorder, "responses received", Unit::Count), + } + } + + pub fn record(&self, event: RequestEvent, kind: RequestKind) { + match (event, kind) { + (RequestEvent::Send, RequestKind::Index) => { + self.index_requests_sent.increment(1); + self.index_requests_inflight.increment(1.0); + } + (RequestEvent::Send, RequestKind::Block) => { + self.block_requests_sent.increment(1); + self.block_requests_inflight.increment(1.0); + } + ( + RequestEvent::Success { .. } + | RequestEvent::Failure { .. } + | RequestEvent::Timeout + | RequestEvent::Cancel, + RequestKind::Index, + ) => { + self.index_requests_inflight.decrement(1.0); + } + ( + RequestEvent::Success { .. } + | RequestEvent::Failure { .. } + | RequestEvent::Timeout + | RequestEvent::Cancel, + RequestKind::Block, + ) => { + self.block_requests_inflight.decrement(1.0); + } + } + + match event { + RequestEvent::Success { rtt } | RequestEvent::Failure { rtt } => { + self.request_latency.record(rtt); + } + RequestEvent::Timeout => { + self.request_timeouts.increment(1); + } + RequestEvent::Send | RequestEvent::Cancel => (), + } + } +} + +#[derive(Clone, Copy, Eq, PartialEq, Debug)] +pub(crate) enum RequestKind { + Index, + Block, +} + +#[derive(Clone, Copy, Eq, PartialEq, Debug)] +pub(crate) enum RequestEvent { + Send, + Success { rtt: Duration }, + Failure { rtt: Duration }, + Timeout, + Cancel, +} + pub(crate) struct JobMonitor { name: String, count_running_tx: watch::Sender, diff --git a/lib/src/repository/params.rs b/lib/src/repository/params.rs index 2fe1e3a77..d42ee3c44 100644 --- a/lib/src/repository/params.rs +++ b/lib/src/repository/params.rs @@ -1,4 +1,4 @@ -use super::RepositoryMonitor; +use super::monitor::RepositoryMonitor; use crate::{db, device_id::DeviceId, error::Result}; use metrics::{NoopRecorder, Recorder}; use state_monitor::{metrics::MetricsRecorder, StateMonitor}; diff --git a/lib/src/repository/vault.rs b/lib/src/repository/vault.rs index b463a9f59..7deced3cc 100644 --- a/lib/src/repository/vault.rs +++ b/lib/src/repository/vault.rs @@ -3,7 +3,7 @@ #[cfg(test)] mod tests; -use super::{quota, Metadata, RepositoryMonitor}; +use super::{monitor::RepositoryMonitor, quota, Metadata}; use crate::{ block_tracker::BlockTracker, db, From ebf592d6d17a198afd796c0f72a58bda2a7ebcde Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Cig=C3=A1nek?= Date: Mon, 7 Oct 2024 08:05:28 +0200 Subject: [PATCH 52/55] Bump protocol to v13 --- lib/src/network/protocol.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/src/network/protocol.rs b/lib/src/network/protocol.rs index 5f94b20ef..45566308a 100644 --- a/lib/src/network/protocol.rs +++ b/lib/src/network/protocol.rs @@ -4,7 +4,7 @@ use 
tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
 
 // First string in a handshake, helps with weeding out connections with completely different
 // protocols on the other end.
 pub(super) const MAGIC: &[u8; 7] = b"OUISYNC";
-pub(super) const VERSION: Version = Version(12);
+pub(super) const VERSION: Version = Version(13);
 
 /// Protocol version
 #[derive(Clone, Copy, Eq, PartialEq, Ord, PartialOrd, Debug)]

From ad27bb11fd3d2a7058d4788cc01740fb1d3ad01d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adam=20Cig=C3=A1nek?=
Date: Mon, 7 Oct 2024 11:31:47 +0200
Subject: [PATCH 53/55] Fix sending duplicate requests with different variants
 to the same peer

---
 lib/src/network/client.rs                |  1 +
 lib/src/network/request_tracker.rs       | 56 ++++++++++++++---
 lib/src/network/request_tracker/tests.rs | 78 ++++++++++++++++++++++++
 lib/src/network/server.rs                |  3 +
 lib/src/network/tests.rs                 |  4 +-
 5 files changed, 132 insertions(+), 10 deletions(-)

diff --git a/lib/src/network/client.rs b/lib/src/network/client.rs
index c270c7d95..e9ee9e6f0 100644
--- a/lib/src/network/client.rs
+++ b/lib/src/network/client.rs
@@ -99,6 +99,7 @@ impl Inner {
 
     async fn send_requests(&self, request_rx: &mut mpsc::UnboundedReceiver<PendingRequest>) {
         while let Some(PendingRequest { payload, .. }) = request_rx.recv().await {
+            tracing::trace!(?payload, "sending request");
             self.message_tx.send(Message::Request(payload)).ok();
         }
     }
diff --git a/lib/src/network/request_tracker.rs b/lib/src/network/request_tracker.rs
index 564dcd9fd..47a3beb7e 100644
--- a/lib/src/network/request_tracker.rs
+++ b/lib/src/network/request_tracker.rs
@@ -223,7 +223,7 @@ pub(super) enum InitialRequestState {
 ///
 /// It also contains the block presence from the response that triggered this request. This is
 /// mostly useful for diagnostics and testing.
-#[derive(Clone, Debug)]
+#[derive(Clone, Eq, PartialEq, Debug)]
 pub(super) struct PendingRequest {
     pub payload: Request,
     pub variant: RequestVariant,
@@ -602,7 +602,18 @@ impl Worker {
         node_key: GraphKey,
         initial_state: InitialRequestState,
     ) {
-        let Some(node) = self.requests.get_mut(node_key) else {
+        let (request_key, add) = if let Some(node) = self.requests.get(node_key) {
+            let request_key = MessageKey::from(&node.request().payload);
+            let add = match node.value() {
+                RequestState::Suspended { .. }
+                | RequestState::InFlight { .. }
+                | RequestState::Complete { .. }
+                | RequestState::Cancelled => true,
+                RequestState::Committed => false,
+            };
+
+            (request_key, add)
+        } else {
             return;
         };
 
@@ -610,15 +621,42 @@ impl Worker {
             return;
         };
 
-        let request_key = MessageKey::from(&node.request().payload);
+        let (old_key, send) = match client_state.requests.entry(request_key) {
+            // The request is not yet tracked with this client: start tracking it.
+            Entry::Vacant(entry) => {
+                if add {
+                    entry.insert(node_key);
+                }
+
+                (None, true)
+            }
+            // The request with the same variant is already tracked with this client: do nothing.
+            Entry::Occupied(entry) if *entry.get() == node_key => (None, false),
+            // The request with a different variant is already tracked with this client: cancel the
+            // existing request and start tracking the new one.
+            Entry::Occupied(mut entry) => {
+                let old_key = *entry.get();
+
+                if add {
+                    entry.insert(node_key);
+                } else {
+                    entry.remove();
+                }
+
+                (Some(old_key), true)
+            }
+        };
 
-        if let Entry::Vacant(entry) = client_state.requests.entry(request_key) {
+        // unwrap is OK because we already handled the `None` case earlier.
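+        // (`self.requests.get(node_key)` returned `Some` at the top of this
+        // function and nothing in between removes nodes from the graph.)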
+ let node = self.requests.get_mut(node_key).unwrap(); + let children: Vec<_> = node.children().collect(); + + if send { match node.value_mut() { RequestState::Suspended { waiters } | RequestState::InFlight { waiters, .. } | RequestState::Complete { waiters, .. } => { waiters.push_back(client_id); - entry.insert(node_key); } RequestState::Committed => (), RequestState::Cancelled => { @@ -641,16 +679,16 @@ impl Worker { waiters: [client_id].into(), }, }; - - entry.insert(node_key); } } } + if let Some(old_key) = old_key { + self.cancel_request(client_id, old_key, None); + } + // Note: we are using recursion, but the graph is only a few layers deep (currently 5) so // there is no danger of stack overflow. - let children: Vec<_> = node.children().collect(); - for child_key in children { self.update_request(client_id, child_key, initial_state); } diff --git a/lib/src/network/request_tracker/tests.rs b/lib/src/network/request_tracker/tests.rs index c9d3382b7..c1688b593 100644 --- a/lib/src/network/request_tracker/tests.rs +++ b/lib/src/network/request_tracker/tests.rs @@ -322,6 +322,84 @@ async fn suspend_resume() { assert_eq!(request_rx.try_recv().map(|r| r.payload), Ok(request)); } +mod duplicate_request_with_different_variant_on_the_same_client { + use super::*; + + #[tokio::test] + async fn in_flight() { + case(|_client, _request_key| ()); + } + + #[tokio::test] + async fn complete() { + case(|client, request_key| { + client.success(request_key, vec![]); + }); + } + + #[tokio::test] + async fn committed() { + case(|client, request_key| { + client.success(request_key, vec![]); + client.new_committer().commit(); + }); + } + + #[tokio::test] + async fn cancelled() { + case(|client, request_key| { + client.failure(request_key); + }); + } + + fn case(step: F) + where + F: FnOnce(&RequestTrackerClient, MessageKey), + { + let mut rng = StdRng::seed_from_u64(0); + let (tracker, mut worker) = build(TrafficMonitor::new(&NoopRecorder)); + let (client, mut request_rx) = tracker.new_client(); + worker.step(); + + let preceding_request_key = MessageKey::RootNode(PublicKey::generate(&mut rng), 0); + + let request = Request::ChildNodes(rng.gen(), DebugRequest::start()); + let variant_0 = RequestVariant::new(MultiBlockPresence::None, MultiBlockPresence::None); + let variant_1 = RequestVariant::new(MultiBlockPresence::None, MultiBlockPresence::Full); + + client.success( + preceding_request_key, + vec![CandidateRequest::new(request.clone()).variant(variant_0)], + ); + worker.step(); + + assert_eq!( + request_rx.try_recv(), + Ok(PendingRequest { + payload: request.clone(), + variant: variant_0 + }), + ); + + step(&client, MessageKey::from(&request)); + worker.step(); + + client.success( + preceding_request_key, + vec![CandidateRequest::new(request.clone()).variant(variant_1)], + ); + worker.step(); + + assert_eq!( + request_rx.try_recv(), + Ok(PendingRequest { + payload: request.clone(), + variant: variant_1 + }), + ); + } +} + /// Generate `count + 1` copies of the same snapshot. The first one will have all the blocks /// present (the "master copy"). The remaining ones will have some blocks missing but in such a /// way that every block is present in at least one of the snapshots. 
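A minimal sketch of the bookkeeping behind this fix, using simplified stand-ins
for the `MessageKey` and `GraphKey` types from `request_tracker.rs` above (the
`Committed` special case is omitted): the per-client request map is keyed by
`MessageKey`, which ignores the `RequestVariant`, so a request that re-arrives
under a different variant hits an occupied entry, and the stale entry has to be
cancelled before the new one is tracked:

    use std::collections::{hash_map::Entry, HashMap};

    type MessageKey = u64; // request identity; ignores the variant
    type GraphKey = usize; // one graph node per (request, variant) pair

    /// Returns the stale node to cancel (if any) and whether to send the request.
    fn track(
        requests: &mut HashMap<MessageKey, GraphKey>,
        key: MessageKey,
        node: GraphKey,
    ) -> (Option<GraphKey>, bool) {
        match requests.entry(key) {
            // Not tracked with this client yet: track it and send it.
            Entry::Vacant(entry) => {
                entry.insert(node);
                (None, true)
            }
            // The same variant is already tracked: nothing to do.
            Entry::Occupied(entry) if *entry.get() == node => (None, false),
            // A different variant is already tracked: replace it, cancel the
            // old node and send the new one.
            Entry::Occupied(mut entry) => {
                let old = entry.insert(node);
                (Some(old), true)
            }
        }
    }

Calling `track(&mut map, key, 2)` after `track(&mut map, key, 1)` returns
`(Some(1), true)`: cancel node 1, then send the request for node 2, which is
the path exercised by the `duplicate_request_with_different_variant_on_the_same_client`
tests above.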
diff --git a/lib/src/network/server.rs b/lib/src/network/server.rs index 19e7f8a09..ca3b13101 100644 --- a/lib/src/network/server.rs +++ b/lib/src/network/server.rs @@ -241,6 +241,7 @@ impl Inner { } } + #[instrument(skip(self))] async fn handle_branch_changed_event(&self, branch_id: PublicKey) -> Result<()> { let root_node = match self.load_root_node(&branch_id).await { Ok(node) => node, @@ -254,12 +255,14 @@ impl Inner { self.send_root_node(root_node).await } + #[instrument(skip(self))] async fn handle_block_received_event(&self, block_id: BlockId) -> Result<()> { self.enqueue_response(Response::BlockOffer(block_id, DebugResponse::unsolicited())) .await; Ok(()) } + #[instrument(skip(self))] async fn handle_unknown_event(&self) -> Result<()> { let root_nodes = self.load_root_nodes().await?; for root_node in root_nodes { diff --git a/lib/src/network/tests.rs b/lib/src/network/tests.rs index 3605f29cd..2f5e32130 100644 --- a/lib/src/network/tests.rs +++ b/lib/src/network/tests.rs @@ -151,7 +151,9 @@ async fn transfer_blocks_between_two_replicas_case(block_count: usize, rng_seed: // Then wait until replica B receives and writes it too. for id in snapshot.blocks().keys() { + tracing::info!(?id, "waiting for block"); wait_until_block_exists(&b_vault, id).await; + tracing::info!(?id, "block received"); } }; @@ -396,7 +398,7 @@ async fn create_repository( let writer_id = PublicKey::generate(rng); let repository_id = RepositoryId::from(write_keys.public_key()); - let event_tx = EventSender::new(1); + let event_tx = EventSender::new(128); let monitor = RepositoryMonitor::new(StateMonitor::make_root(), &NoopRecorder); let traffic_monitor = monitor.traffic.clone(); let state = Vault::new(repository_id, event_tx, db, monitor); From 97ff991c355019c945d56d16bb66a4f4ebf72914 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Cig=C3=A1nek?= Date: Mon, 7 Oct 2024 11:41:50 +0200 Subject: [PATCH 54/55] Enable missing codec feature for tokio-util --- lib/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/Cargo.toml b/lib/Cargo.toml index 77b11fbfa..7a47861ae 100644 --- a/lib/Cargo.toml +++ b/lib/Cargo.toml @@ -73,7 +73,7 @@ subtle = { version = "2.5.0", default-features = false, features = ["core_hint_b thiserror = { workspace = true } tokio = { workspace = true } tokio-stream = { workspace = true, features = ["sync"] } -tokio-util = { workspace = true, features = ["time"] } +tokio-util = { workspace = true, features = ["time", "codec"] } tracing = { workspace = true } tracing-subscriber = { workspace = true, features = [ "env-filter" ] } turmoil = { workspace = true, optional = true } From aef01ee0d0c3764cd65c9d74c63aaa9409e46d6a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Cig=C3=A1nek?= Date: Mon, 7 Oct 2024 11:54:00 +0200 Subject: [PATCH 55/55] Minor refactoring --- lib/src/network/request_tracker.rs | 25 ++++++++++--------------- 1 file changed, 10 insertions(+), 15 deletions(-) diff --git a/lib/src/network/request_tracker.rs b/lib/src/network/request_tracker.rs index 47a3beb7e..a99185e54 100644 --- a/lib/src/network/request_tracker.rs +++ b/lib/src/network/request_tracker.rs @@ -299,7 +299,7 @@ impl Worker { } Some(expired) = self.timer.next() => { let (client_id, request_key) = expired.into_inner(); - self.handle_failure(client_id, request_key, FailureReason::Timeout); + self.failure(client_id, request_key, FailureReason::Timeout); } } } @@ -335,20 +335,20 @@ impl Worker { self.timeout = timeout; } Command::HandleInitial { client_id, request } => { - 
self.handle_initial(client_id, request); + self.initial(client_id, request); } Command::HandleSuccess { client_id, request_key, requests, } => { - self.handle_success(client_id, request_key, requests); + self.success(client_id, request_key, requests); } Command::HandleFailure { client_id, request_key, } => { - self.handle_failure(client_id, request_key, FailureReason::Response); + self.failure(client_id, request_key, FailureReason::Response); } Command::Resume { request_key, @@ -385,20 +385,20 @@ impl Worker { } #[instrument(skip(self))] - fn handle_initial(&mut self, client_id: ClientId, request: CandidateRequest) { - tracing::trace!("handle_initial"); + fn initial(&mut self, client_id: ClientId, request: CandidateRequest) { + tracing::trace!("initial"); self.insert_request(client_id, request, None) } #[instrument(skip(self))] - fn handle_success( + fn success( &mut self, client_id: ClientId, request_key: MessageKey, requests: Vec, ) { - tracing::trace!("handle_success"); + tracing::trace!("success"); let node_key = self .clients @@ -474,13 +474,8 @@ impl Worker { } #[instrument(skip(self))] - fn handle_failure( - &mut self, - client_id: ClientId, - request_key: MessageKey, - reason: FailureReason, - ) { - tracing::trace!("handle_failure"); + fn failure(&mut self, client_id: ClientId, request_key: MessageKey, reason: FailureReason) { + tracing::trace!("failure"); let Some(client_state) = self.clients.get_mut(&client_id) else { return;