diff --git a/src/net/mux.rs b/src/net/mux.rs index c28137d..4b072d6 100644 --- a/src/net/mux.rs +++ b/src/net/mux.rs @@ -5,15 +5,13 @@ // TODO: Add Multiplex trait with impls for Network and SplitChannel? -use std::{error::Error, vec::Drain}; -use std::{ops::RangeBounds, sync::Arc}; +use std::error::Error; +use std::sync::Arc; -use std::future::Future; use futures::future::join_all; -use itertools::{multiunzip, Itertools}; use thiserror::Error; -use tokio::{join, sync::{mpsc::{self, unbounded_channel, UnboundedSender, WeakUnboundedSender}, oneshot}}; +use tokio::sync::{mpsc::{self, unbounded_channel, UnboundedSender}, oneshot}; use tokio_util::bytes::{Buf, BufMut, Bytes, BytesMut}; use crate::{ @@ -133,7 +131,49 @@ impl SplitChannel for MuxConn { } } -struct GatewayInner +/// # Multiplexed Gateway Channel +/// +/// Enables splitting a channel into multiple multiplexed channels. +/// The multiplexed channels must be *driven* by the gateway +/// (see [Gateway::drive]) otherwise the multiplexed channels won't +/// be able to communicate. +/// +/// ## Example: +/// ``` +/// use caring::net::{connection::Connection, mux::Gateway, RecvBytes, SendBytes}; +/// # tokio_test::block_on(async { +/// let (c1,c2) = Connection::in_memory(); +/// +/// tokio::spawn(async {// party 1 +/// let con = c1; +/// let (mut gateway, mut m1) = Gateway::single(con); +/// let mut m2 = gateway.muxify(); +/// tokio::spawn(async move { +/// m1.send(&"Hello MUX1!".to_owned()).await.unwrap(); +/// }); +/// tokio::spawn(async move { +/// m2.send(&"Hello MUX2!".to_owned()).await.unwrap(); +/// }); +/// gateway.drive().await; +/// }); +/// +/// tokio::spawn( async {// party 2 +/// let con = c2; +/// let (mut gateway, mut m1) = Gateway::single(con); +/// let mut m2 = gateway.muxify(); +/// tokio::spawn(async move { +/// let msg : String = m1.recv().await.unwrap(); +/// assert_eq!(msg, "Hello MUX1!"); +/// }); +/// tokio::spawn(async move { +/// let msg : String = m2.recv().await.unwrap(); +/// assert_eq!(msg, "Hello MUX2!"); +/// }); +/// gateway.drive().await; +/// }); +/// }) +/// ``` +pub struct Gateway where C: SplitChannel, { @@ -144,104 +184,8 @@ where outbox: mpsc::WeakUnboundedSender } -/// Gateway channel for multiplexed connections/channels ([MuxConn]), -/// interally holding a [SplitChannel]. -/// -/// Constructed by [Gateway::multiplex] -pub struct Gateway { - inner: GatewayInner, - muxes: Vec, -} - impl Gateway { - /// Multiplex a channel to share it into `n` new connections. - /// - /// * `net`: Connection to use as a gateway for multiplexing - /// * `n`: Number of new connections to multiplex into - /// - /// Returns a gateway which the MuxConn communicate through, along with the MuxConn - /// - /// # Example - /// ``` - /// # use crate::caring::net::SendBytes; - /// # use caring::net::connection::Connection; - /// # use caring::net::mux::Gateway; - /// # tokio_test::block_on(async { - /// # let (c1, c2) = Connection::in_memory(); - /// # let first = async move { - /// # let mut con = c1; - /// use crate::caring::net::Channel; - /// use itertools::Itertools; - /// - /// let mut gateway = Gateway::multiplex(&mut con, 2); - /// let (mut m1, mut m2) = gateway.drain(..).collect_tuple().unwrap(); - /// let t1 = async move { - /// m1.send(&String::from("Hello")).await.unwrap(); - /// }; - /// let t2 = async move { - /// m2.send(&String::from("Friend")).await.unwrap(); - /// }; - /// futures::join!(t1, t2, gateway.drive()); // Gateway needs to be run aswell. - /// # }; - /// # - /// # use crate::caring::net::RecvBytes; - /// # use itertools::Itertools; - /// # use crate::caring::net::Channel; - /// # let second = async move { - /// # let mut con = c2; - /// # let mut gateway = Gateway::multiplex(&mut con, 2); - /// # let (mut m1, mut m2) = gateway.drain(..).collect_tuple().unwrap(); - /// # let t1 = async move { - /// # let _ : String = m1.recv().await.unwrap(); - /// # }; - /// # let t2 = async move { - /// # let _ : String = m2.recv().await.unwrap(); - /// # }; - /// # futures::join!(t1, t2, gateway.drive()); - /// # }; - /// # futures::join!(first, second) - /// # }); - /// - /// ``` - /// - pub fn multiplex(con: C, n: usize) -> Self { - let (mut gateway, con) = GatewayInner::new(con); - let mut muxes = vec![con]; - for _ in 1..n { - muxes.push(gateway.muxify()); - } - Self { - inner: gateway, - muxes, - } - } - - pub async fn map>(self, func: impl FnMut(MuxConn) -> F) -> Vec { - let res = join_all(self.muxes.into_iter().map(func)); - let (res, _) = join!(res, self.inner.run()); - res - } - - pub fn drain(&mut self, range: impl RangeBounds) -> Drain { - self.muxes.drain(range) - } - - /// Drive the gateway, allowing the multiplexed connections to run pub async fn drive(self) -> Self { - let muxes = self.muxes; - let inner = self.inner.run().await; - Self { inner, muxes } - } - - - pub fn destroy(self) -> C { - self.inner.channel - } - -} - -impl GatewayInner { - async fn run(self) -> Self { let mut gateway = self; { let (sending, recving) = gateway.channel.split(); @@ -283,7 +227,7 @@ impl GatewayInner { gateway } - pub fn new(channel: C) -> (Self, MuxConn) { + pub fn single(channel: C) -> (Self, MuxConn) { let (outbox, inbox) = unbounded_channel(); let gateway = outbox.clone(); let outbox= outbox.downgrade(); @@ -299,6 +243,26 @@ impl GatewayInner { } + pub fn destroy(self) -> C { + self.channel + } + + /// Multiplex a channel to share it into `n` new connections. + /// + /// * `net`: Connection to use as a gateway for multiplexing + /// * `n`: Number of new connections to multiplex into + /// + /// Returns a gateway which the MuxConn communicate through, along with the MuxConn + pub fn multiplex(con: C, n: usize) -> (Self, Vec) { + let (mut gateway, con) = Self::single(con); + let mut muxes = vec![con]; + for _ in 1..n { + muxes.push(gateway.muxify()); + } + (gateway, muxes) + } + + fn add_mux(&mut self, gateway: UnboundedSender) -> MuxConn { let id = self.mailboxes.len(); let (errors_coms1, error) = oneshot::channel(); @@ -341,8 +305,7 @@ where let mut matrix = Vec::new(); let index = net.index; for conn in net.connections { - let mut gateway = Gateway::multiplex(conn, n); - let muxes = gateway.drain(..).collect_vec(); + let (gateway, muxes) = Gateway::multiplex(conn, n); matrix.push(muxes); gateways.push(gateway); } @@ -376,10 +339,12 @@ where } pub fn new_mux(&mut self) -> MuxNet { - todo!() + let connections = self.gateways.iter_mut().map(|g| g.muxify() ).collect(); + MuxNet { connections, index: self.index } } } + #[cfg(test)] mod test { use std::time::Duration; @@ -404,14 +369,12 @@ mod test { msg } - // TODO: Better names for tests. - #[tokio::test] async fn sunshine() { let (c1, c2) = Connection::in_memory(); let p1 = async { - let mut gateway = Gateway::multiplex(c1, 3); - let (mut m1, mut m2, mut m3) = gateway.drain(..).collect_tuple().unwrap(); + let (gateway, mut muxes) = Gateway::multiplex(c1, 3); + let (mut m1, mut m2, mut m3) = muxes.drain(..).collect_tuple().unwrap(); let s = async move { let (s1, s2, s3) = futures::join!( @@ -422,15 +385,15 @@ mod test { s1 + &s2 + &s3 }; - let (s, mut gateway) = join!(s, gateway.inner.run()); + let (s, mut gateway) = join!(s, gateway.drive()); gateway.channel.send(&"bye".to_owned()).await.unwrap(); gateway.channel.shutdown().await.unwrap(); s }; let p2 = async { - let mut gateway = Gateway::multiplex(c2, 3); - let (mut m1, mut m2, mut m3) = gateway.drain(..).collect_tuple().unwrap(); + let (gateway, mut muxes) = Gateway::multiplex(c2, 3); + let (mut m1, mut m2, mut m3) = muxes.drain(..).collect_tuple().unwrap(); let s = async move { let (s1, s2, s3) = futures::join!( chat(&mut m1, "Hello, "), @@ -439,7 +402,7 @@ mod test { ); s1 + &s2 + &s3 }; - let (s, mut gateway) = (s, gateway.inner.run()).join().await; + let (s, mut gateway) = (s, gateway.drive()).join().await; let _ : String = gateway.channel.recv().await.unwrap(); gateway.channel.shutdown().await.unwrap(); s @@ -454,8 +417,8 @@ mod test { async fn moonshine() { let (c1, c2) = Connection::in_memory(); let p1 = async { - let mut gateway = Gateway::multiplex(c1, 3); - let (mut m1, mut m2, mut m3) = gateway.drain(..).collect_tuple().unwrap(); + let (gateway, mut muxes) = Gateway::multiplex(c1, 3); + let (mut m1, mut m2, mut m3) = muxes.drain(..).collect_tuple().unwrap(); let h = async { // Wait a little such the errors get time to propagate tokio::time::sleep(Duration::from_millis(5)).await; @@ -468,7 +431,7 @@ mod test { s2.expect_err("Should be closed"); s3.expect_err("Should be closed"); }; - join!(h, gateway.inner.run()) + join!(h, gateway.drive()) }; let p2 = async { diff --git a/src/net/network.rs b/src/net/network.rs index 6054650..6fe3632 100644 --- a/src/net/network.rs +++ b/src/net/network.rs @@ -298,7 +298,7 @@ impl Network { } - pub(crate) fn as_mut<'a>(&'a mut self) -> Network<&'a mut C> { + pub(crate) fn as_mut(&mut self) -> Network<&mut C> { let connections = self.connections.iter_mut().collect(); Network { connections, index: self.index } }