diff --git a/src/client.rs b/src/client.rs index 95a5a749..10f741a1 100644 --- a/src/client.rs +++ b/src/client.rs @@ -441,11 +441,21 @@ impl ControlChannel { // Read ack debug!("Reading ack"); - match read_ack(&mut conn).await? { - Ack::Ok => {} - v => { - return Err(anyhow!("{}", v)) - .with_context(|| format!("Authentication failed: {}", self.service.name)); + for _ in 0..2 { + match read_ack(&mut conn).await? { + Ack::Ok => break, + Ack::RequireServiceConfig => { + debug!("Sending client service config"); + let s = toml::to_string(&self.service).unwrap(); + let buf = s.as_bytes(); + conn.write_u32(buf.len() as u32).await?; + conn.write_all(buf).await?; + conn.flush().await?; + } + v => { + return Err(anyhow!("{}", v)) + .with_context(|| format!("Authentication failed: {}", self.service.name)); + } } } diff --git a/src/config.rs b/src/config.rs index ca85fc20..89571efb 100644 --- a/src/config.rs +++ b/src/config.rs @@ -63,6 +63,7 @@ pub struct ClientServiceConfig { #[serde(skip)] pub name: String, pub local_addr: String, + pub recommend_bind_addr: Option, pub token: Option, pub nodelay: Option, pub retry_interval: Option, @@ -214,11 +215,17 @@ fn default_heartbeat_interval() -> u64 { DEFAULT_HEARTBEAT_INTERVAL_SECS } +fn default_accept_client_recommend_service() -> bool { + false +} + #[derive(Debug, Serialize, Deserialize, Default, PartialEq, Eq, Clone)] #[serde(deny_unknown_fields)] pub struct ServerConfig { pub bind_addr: String, pub default_token: Option, + #[serde(default = "default_accept_client_recommend_service")] + pub accept_client_recommend_service: bool, pub services: HashMap, #[serde(default)] pub transport: TransportConfig, diff --git a/src/protocol.rs b/src/protocol.rs index 577c7323..057aaa2b 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -8,6 +8,8 @@ use std::net::SocketAddr; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tracing::trace; +use crate::config::{ClientServiceConfig, ServerServiceConfig}; + type ProtocolVersion = u8; const _PROTO_V0: u8 = 0u8; const PROTO_V1: u8 = 1u8; @@ -30,6 +32,7 @@ pub enum Ack { Ok, ServiceNotExist, AuthFailed, + RequireServiceConfig, } impl std::fmt::Display for Ack { @@ -41,6 +44,7 @@ impl std::fmt::Display for Ack { Ack::Ok => "Ok", Ack::ServiceNotExist => "Service not exist", Ack::AuthFailed => "Incorrect token", + Ack::RequireServiceConfig => "Try to use service config defined in client", } ) } @@ -112,8 +116,7 @@ impl UdpTraffic { } pub async fn read(reader: &mut T, hdr_len: u8) -> Result { - let mut buf = Vec::new(); - buf.resize(hdr_len as usize, 0); + let mut buf = vec![0; hdr_len as usize]; reader .read_exact(&mut buf) .await @@ -207,6 +210,34 @@ pub async fn read_hello(conn: &mut T) -> Resu Ok(hello) } +pub async fn read_server_service_config_from_client(conn: &mut T) -> Result { + conn.write_all(&bincode::serialize(&Ack::RequireServiceConfig).unwrap()) + .await?; + conn.flush().await?; + + let n = conn.read_u32() + .await + .with_context(|| "Failed to read client service config")? as usize; + let mut buf = vec![0u8; n]; + conn.read_exact(&mut buf) + .await + .with_context(|| "Failed to read client service config")?; + + let config: ClientServiceConfig = toml::from_str(&String::from_utf8(buf)?[..]).with_context(|| "Failed to parse the config")?; + Ok( + ServerServiceConfig{ + bind_addr: match config.recommend_bind_addr { + Some(bind_addr) => bind_addr, + None => return Err(anyhow::anyhow!(format!("Expect 'recommend_bind_addr' in {}", config.name))), + }, + service_type: config.service_type, + name: config.name, + nodelay: config.nodelay, + token: config.token, + } + ) +} + pub async fn read_auth(conn: &mut T) -> Result { let mut buf = vec![0u8; PACKET_LEN.auth]; conn.read_exact(&mut buf) diff --git a/src/server.rs b/src/server.rs index a36e3c21..2d8810a4 100644 --- a/src/server.rs +++ b/src/server.rs @@ -6,7 +6,7 @@ use crate::multi_map::MultiMap; use crate::protocol::Hello::{ControlChannelHello, DataChannelHello}; use crate::protocol::{ self, read_auth, read_hello, Ack, ControlChannelCmd, DataChannelCmd, Hello, UdpTraffic, - HASH_WIDTH_IN_BYTES, + HASH_WIDTH_IN_BYTES, read_server_service_config_from_client, }; use crate::transport::{SocketOpts, TcpTransport, Transport}; use anyhow::{anyhow, bail, Context, Result}; @@ -297,16 +297,34 @@ async fn do_control_channel_handshake( .await?; conn.flush().await?; + + // Read auth + let protocol::Auth(d) = read_auth(&mut conn).await?; + // Lookup the service let service_config = match services.read().await.get(&service_digest) { - Some(v) => v, + Some(v) => v.clone(), None => { - conn.write_all(&bincode::serialize(&Ack::ServiceNotExist).unwrap()) - .await?; - bail!("No such a service {}", hex::encode(service_digest)); + let op = match server_config.accept_client_recommend_service { + true => { + match read_server_service_config_from_client(&mut conn).await { // Send ACK::RequireServiceConfig + Ok(config) => Some(config), + Err(_) => None, + } + }, + false => None, + }; + + match op { + Some(config) => config, + None => { + conn.write_all(&bincode::serialize(&Ack::ServiceNotExist).unwrap()) + .await?; + bail!("No such a service {}", hex::encode(service_digest)); + } + } } - } - .to_owned(); + }; let service_name = &service_config.name; @@ -314,8 +332,6 @@ async fn do_control_channel_handshake( let mut concat = Vec::from(service_config.token.as_ref().unwrap().as_bytes()); concat.append(&mut nonce); - // Read auth - let protocol::Auth(d) = read_auth(&mut conn).await?; // Validate let session_key = protocol::digest(&concat);