Skip to content

Latest commit

 

History

History
202 lines (179 loc) · 5.97 KB

final_server_code.md

File metadata and controls

202 lines (179 loc) · 5.97 KB

# Final Server Code

The final code looks like this:

use std::{
    collections::hash_map::{Entry, HashMap},
    future::Future,
    sync::Arc,
};

use tokio::{
    io::{AsyncBufReadExt, AsyncWriteExt, BufReader},
    net::{tcp::OwnedWriteHalf, TcpListener, TcpStream, ToSocketAddrs},
    sync::{mpsc, oneshot, Notify},
    task,
};

type Result<T> = std::result::Result<T, Box<dyn std::error::Error + Send + Sync>>;
type Sender<T> = mpsc::UnboundedSender<T>;
type Receiver<T> = mpsc::UnboundedReceiver<T>;

/// Starts the chat server on localhost, port 8080, and runs it until Ctrl-C.
#[tokio::main]
pub(crate) async fn main() -> Result<()> {
    let listen_addr = "127.0.0.1:8080";
    accept_loop(listen_addr).await
}

/// Accepts TCP connections on `addr` until Ctrl-C, then shuts the server down.
///
/// Each accepted connection is served by its own `connection_loop` task, and one
/// broker task routes messages between peers. On Ctrl-C this function:
/// 1. notifies every connection task through the shared `Notify`;
/// 2. drops its own broker sender;
/// 3. waits for the broker task to finish.
///
/// # Errors
/// Returns an error if binding the listener fails or if the broker task panicked.
/// A failure to accept an individual connection is logged and does not stop the
/// server.
async fn accept_loop(addr: impl ToSocketAddrs) -> Result<()> {
    let listener = TcpListener::bind(addr).await?;

    let (broker_sender, broker_receiver) = mpsc::unbounded_channel();
    let broker = task::spawn(broker_loop(broker_receiver));
    let shutdown_notification = Arc::new(Notify::new());

    // Create the Ctrl-C future once, outside the loop, so a signal arriving while a
    // connection is being handled is not lost between iterations.
    let ctrl_c = tokio::signal::ctrl_c();
    tokio::pin!(ctrl_c);

    loop {
        tokio::select! {
            // Match the whole `Result` rather than using an `Ok(..)` pattern. With a
            // pattern, a failed `accept` would disable this branch, and the select would
            // then wait on Ctrl-C alone, so no further connections would be accepted.
            accepted = listener.accept() => match accepted {
                Ok((stream, peer_addr)) => {
                    // `accept` already reports the peer address. Calling
                    // `stream.peer_addr()?` could fail for a client that disconnected
                    // immediately, and that would end the whole server.
                    println!("Accepting from: {}", peer_addr);
                    spawn_and_log_error(connection_loop(
                        broker_sender.clone(),
                        stream,
                        Arc::clone(&shutdown_notification),
                    ));
                }
                // A failed accept (e.g. a connection reset before it was accepted) only
                // affects that one attempt; keep serving everyone else.
                Err(e) => eprintln!("failed to accept connection: {}", e),
            },
            _ = &mut ctrl_c => break,
        }
    }
    println!("Shutting down!");
    shutdown_notification.notify_waiters();
    // The broker ends once every `Sender<Event>` is gone. Each connection task drops
    // its clone when it exits; this is the last one we hold.
    drop(broker_sender);
    broker.await?;
    Ok(())
}

/// Handles the reading side of one client connection.
///
/// The first line the client sends is used as its user name and registered with the
/// broker via `Event::NewPeer`. Each following line of the form
/// `dest1, dest2: message` is forwarded to the broker as an `Event::Message`. Lines
/// without a `:` are ignored.
///
/// The loop ends when any of these happens:
/// - the client closes the connection;
/// - a read fails;
/// - `shutdown` is notified.
///
/// On exit, dropping `shutdown_sender` completes the peer writer task's oneshot
/// receiver, so the writer stops and the broker removes the peer.
///
/// # Errors
/// Returns an error in any of these cases:
/// - the client disconnects before sending a name;
/// - reading from the socket fails;
/// - the broker has already shut down.
async fn connection_loop(broker: Sender<Event>, stream: TcpStream, shutdown: Arc<Notify>) -> Result<()> {
    // Register for the shutdown notification before doing anything else. A `Notified`
    // future receives `notify_waiters` from the moment it is created, even before it is
    // polled. Creating it once and reusing it means a shutdown is not missed while we
    // wait for the user name or process a line.
    // NOTE(review): a task that has not started running when `notify_waiters` fires can
    // still miss it.
    let shutdown_signal = shutdown.notified();
    tokio::pin!(shutdown_signal);

    let (reader, writer) = stream.into_split();
    let mut lines = BufReader::new(reader).lines();
    let (shutdown_sender, shutdown_receiver) = oneshot::channel::<()>();

    let name = tokio::select! {
        line = lines.next_line() => match line? {
            Some(line) => line,
            None => return Err("peer disconnected immediately".into()),
        },
        // Server is shutting down before this client finished its handshake.
        _ = &mut shutdown_signal => return Ok(()),
    };

    println!("user {} connected", name);

    broker
        .send(Event::NewPeer {
            name: name.clone(),
            stream: writer,
            shutdown: shutdown_receiver,
        })
        .map_err(|_| "broker has shut down")?;

    // Match the whole read result instead of an `Ok(Some(line))` pattern. A failed
    // pattern would disable the read branch: after the client disconnected, this task
    // would then wait for server shutdown and never release the peer.
    let outcome: Result<()> = loop {
        tokio::select! {
            line = lines.next_line() => match line {
                Ok(Some(line)) => {
                    let (dest, msg) = match line.split_once(':') {
                        None => continue,
                        Some(parts) => parts,
                    };
                    let to: Vec<String> = dest
                        .split(',')
                        .map(|recipient| recipient.trim().to_string())
                        .collect();
                    let msg = msg.trim().to_string();

                    let event = Event::Message {
                        from: name.clone(),
                        to,
                        msg,
                    };
                    if broker.send(event).is_err() {
                        break Err("broker has shut down".into());
                    }
                }
                // The client closed its side of the connection.
                Ok(None) => break Ok(()),
                Err(e) => break Err(e.into()),
            },
            _ = &mut shutdown_signal => break Ok(()),
        }
    };
    println!("Closing connection loop!");
    // Completes the writer task's `shutdown` receiver so it stops as well.
    drop(shutdown_sender);

    outcome
}

async fn connection_writer_loop(
    messages: &mut Receiver<String>,
    stream: &mut OwnedWriteHalf,
    mut shutdown: oneshot::Receiver<()>,
) -> Result<()> {
    loop {
        tokio::select! {
            msg = messages.recv() => match msg {
                Some(msg) => stream.write_all(msg.as_bytes()).await?,
                None => break,
            },
            _ = &mut shutdown => break
        }
    }

    println!("Closing connection_writer loop!");

    Ok(())
}

/// Events sent from connection tasks to the broker.
#[derive(Debug)]
enum Event {
    /// A client sent its user name and asks to be registered as a peer.
    NewPeer {
        /// User name, taken from the first line the client sent.
        name: String,
        /// Write half of the client's socket; the broker hands it to a writer task.
        stream: OwnedWriteHalf,
        /// Completes when the client's connection task drops its oneshot sender,
        /// which tells the writer task to stop.
        shutdown: oneshot::Receiver<()>,
    },
    /// A chat message to forward to other peers.
    Message {
        /// Name of the sending user.
        from: String,
        /// Names of the recipients; names not currently registered are skipped.
        to: Vec<String>,
        /// Message text.
        msg: String,
    },
}

/// Central message router: owns the map from user name to that user's outbound queue.
///
/// Processes `Event`s from connection tasks until every `Sender<Event>` has been
/// dropped. For each new peer it spawns a writer task running
/// `connection_writer_loop`. When that task ends, it reports back over an internal
/// disconnect channel so the peer is removed from the map.
///
/// On exit it drops all peer queues, which ends the writer tasks, and waits until
/// every writer task has reported its disconnect.
///
/// # Panics
/// Panics if a disconnect arrives for a name that is not registered, which would be
/// an internal bookkeeping bug.
async fn broker_loop(mut events: Receiver<Event>) {
    // Writer tasks send back their name *and* their message receiver. Keeping that
    // receiver alive until the broker has removed the peer guarantees that `send` to
    // any peer still in `peers` cannot fail.
    let (disconnect_sender, mut disconnect_receiver) =
        mpsc::unbounded_channel::<(String, Receiver<String>)>();
    let mut peers: HashMap<String, Sender<String>> = HashMap::new();

    loop {
        let event = tokio::select! {
            event = events.recv() => match event {
                None => break,
                Some(event) => event,
            },
            disconnect = disconnect_receiver.recv() => {
                let (name, _pending_messages) = disconnect
                    .expect("broker holds a disconnect sender, so the channel cannot close");
                assert!(peers.remove(&name).is_some());
                // The writer task already logged this disconnect; logging it here too
                // printed every disconnect twice.
                continue;
            },
        };
        match event {
            Event::Message { from, to, msg } => {
                // Every recipient gets the same text, so format it once.
                let msg = format!("from {}: {}\n", from, msg);
                for addr in to {
                    if let Some(peer) = peers.get(&addr) {
                        peer.send(msg.clone())
                            .expect("receiver stays alive while the peer is registered");
                    }
                }
            }
            Event::NewPeer {
                name,
                mut stream,
                shutdown,
            } => match peers.entry(name.clone()) {
                // NOTE(review): a second connection with an already-registered name is
                // not given a writer, but its connection task keeps forwarding messages
                // under that name.
                Entry::Occupied(..) => (),
                Entry::Vacant(entry) => {
                    let (client_sender, mut client_receiver) = mpsc::unbounded_channel();
                    entry.insert(client_sender);
                    let disconnect_sender = disconnect_sender.clone();
                    spawn_and_log_error(async move {
                        let res =
                            connection_writer_loop(&mut client_receiver, &mut stream, shutdown)
                                .await;
                        println!("user {} disconnected", name);
                        disconnect_sender
                            .send((name, client_receiver))
                            .expect("broker keeps the disconnect receiver until all writers report");
                        res
                    });
                }
            },
        }
    }
    // Closing every peer queue makes each writer loop see `None` and exit.
    drop(peers);
    // Drop our own sender so the drain below ends after the last writer reports.
    drop(disconnect_sender);
    while let Some((_name, _pending_messages)) = disconnect_receiver.recv().await {}
}

/// Spawns `fut` on the tokio runtime and prints its error, if any, to stderr.
///
/// The returned handle completes once the future has finished and any error has
/// been printed.
fn spawn_and_log_error<F>(fut: F) -> task::JoinHandle<()>
where
    F: Future<Output = Result<()>> + Send + 'static,
{
    let logged = async move {
        match fut.await {
            Ok(()) => {}
            Err(err) => eprintln!("{}", err),
        }
    };
    task::spawn(logged)
}