From 751cbb82bd114da7f2460ed1e16183f6a1f04244 Mon Sep 17 00:00:00 2001 From: Dmitry Zolotukhin Date: Sun, 21 Jul 2024 21:49:38 +0200 Subject: [PATCH] Cleaned up PPP handling code. Now, PPP should work without peeking at data, and be much more resilient against cancellations. smoltcp's main loop should run faster without blocking. Added bridge/proxy cleanup code. --- src/fortivpn.rs | 308 +++++++++++++++++++++++++++++++++--------------- src/network.rs | 41 +++++-- src/ppp.rs | 36 +++--- 3 files changed, 259 insertions(+), 126 deletions(-) diff --git a/src/fortivpn.rs b/src/fortivpn.rs index 1ea9506..f94821c 100644 --- a/src/fortivpn.rs +++ b/src/fortivpn.rs @@ -128,6 +128,7 @@ pub struct FortiVPNTunnel { socket: FortiTlsStream, addr: IpAddr, mtu: usize, + ppp_state: PPPState, ppp_magic: u32, ppp_identifier: u8, } @@ -142,12 +143,15 @@ impl FortiVPNTunnel { let mut socket = FortiVPNTunnel::connect(&config.destination_hostport, domain).await?; let addr = FortiVPNTunnel::request_vpn_allocation(domain, &mut socket, &cookie).await?; FortiVPNTunnel::start_vpn_tunnel(domain, &mut socket, &cookie).await?; - let ppp_magic = FortiVPNTunnel::start_ppp(&mut socket).await?; - FortiVPNTunnel::start_ipcp(&mut socket, addr).await?; + + let mut ppp_state = PPPState::new(); + let ppp_magic = FortiVPNTunnel::start_ppp(&mut socket, &mut ppp_state).await?; + FortiVPNTunnel::start_ipcp(&mut socket, &mut ppp_state, addr).await?; Ok(FortiVPNTunnel { socket, addr, mtu: PPP_MTU as usize, + ppp_state, ppp_magic, ppp_identifier: 2, }) @@ -214,7 +218,10 @@ impl FortiVPNTunnel { Ok(socket.flush().await?) } - async fn start_ppp(socket: &mut FortiTlsStream) -> Result { + async fn start_ppp( + socket: &mut FortiTlsStream, + ppp_state: &mut PPPState, + ) -> Result { // Open PPP link; 200 bytes should fit any PPP packet. // This is an oversimplified implementation of the RFC 1661 state machine. let mut req = [0u8; 20]; @@ -236,8 +243,17 @@ impl FortiVPNTunnel { let mut local_acked = false; let mut remote_acked = false; while !(local_acked && remote_acked) { - let length = FortiVPNTunnel::read_ppp_packet(socket, &mut resp, true).await?; - let response = ppp::Packet::from_bytes(&resp[..length]).map_err(|err| { + ppp_state.read_header(socket).await.map_err(|err| { + debug!("Failed to read PPP header: {}", err); + "Failed to read PPP header" + })?; + let protocol = if let Some(protocol) = ppp_state.read_protocol() { + protocol + } else { + return Err("Unable to read PPP protocol".into()); + }; + let length = FortiVPNTunnel::read_ppp_packet(socket, ppp_state, &mut resp).await?; + let response = ppp::Packet::from_bytes(protocol, &resp[..length]).map_err(|err| { debug!("Failed to decode PPP packet: {}", err); "Failed to decode PPP packet" })?; @@ -338,12 +354,16 @@ impl FortiVPNTunnel { Ok(magic) } - async fn start_ipcp(socket: &mut FortiTlsStream, addr: IpAddr) -> Result<(), FortiError> { + async fn start_ipcp( + socket: &mut FortiTlsStream, + ppp_state: &mut PPPState, + addr: IpAddr, + ) -> Result<(), FortiError> { // Open IPCP link; 20 bytes should fit any IPCP packet. // This is an oversimplified implementation of the RFC 1661 state machine. let mut req = [0u8; 20]; let mut resp = [0u8; 200]; - let mut identifier = 1; + let identifier = 1; let addr = match addr { IpAddr::V4(addr) => addr, _ => return Ok(()), @@ -356,15 +376,24 @@ impl FortiVPNTunnel { "Failed to encode IPCP Configure-Request" })?; let mut opts = [0u8; 100]; - let mut opts_len = length - 4; + let opts_len = length - 4; opts[..opts_len].copy_from_slice(&req[4..length]); FortiVPNTunnel::send_ppp_packet(socket, ppp::Protocol::IPV4CP, &req[..length]).await?; let mut local_acked = false; let mut remote_acked = false; while !(local_acked && remote_acked) { - let length = FortiVPNTunnel::read_ppp_packet(socket, &mut resp, true).await?; - let response = ppp::Packet::from_bytes(&resp[..length]).map_err(|err| { + ppp_state.read_header(socket).await.map_err(|err| { + debug!("Failed to read PPP header: {}", err); + "Failed to read PPP header" + })?; + let protocol = if let Some(protocol) = ppp_state.read_protocol() { + protocol + } else { + return Err("Unable to read PPP protocol".into()); + }; + let length = FortiVPNTunnel::read_ppp_packet(socket, ppp_state, &mut resp).await?; + let response = ppp::Packet::from_bytes(protocol, &resp[..length]).map_err(|err| { debug!("Failed to decode PPP packet: {}", err); "Failed to decode PPP packet" })?; @@ -469,82 +498,31 @@ impl FortiVPNTunnel { async fn read_ppp_packet( socket: &mut FortiTlsStream, + state: &mut PPPState, dest: &mut [u8], - first_packet: bool, ) -> Result { - let mut packet_header = [0u8; 6]; - println!("Reading packet"); - // If no data is available, this will return immediately. - match tokio::time::timeout(Duration::from_millis(100), async { - loop { - match socket.read_peek(packet_header.len()).await { - Ok(header) => { - if header.len() >= packet_header.len() { - println!("Have header {} bytes", header.len()); - return; - } else { - println!("Have {} bytes", header.len()); - } - } - Err(err) => { - println!("Failed to read header {}", err); - return; - } - } - } - }) - .await - { - Ok(_) => {} - Err(_) => { - println!("Read timed out"); - return Ok(0); - } - } - println!("Packet ready"); - - if first_packet { - if let Err(err) = FortiVPNTunnel::validate_link(socket, &packet_header[..6]).await { - return Err(err); - } - } - - socket.read(&mut packet_header).await?; - let mut ppp_size = [0u8; 2]; - ppp_size.copy_from_slice(&packet_header[..2]); - let ppp_size = u16::from_be_bytes(ppp_size); - let mut data_size = [0u8; 2]; - data_size.copy_from_slice(&packet_header[4..6]); - let data_size = u16::from_be_bytes(data_size); - let magic = &packet_header[2..4]; - if ppp_size != data_size + 6 { - debug!( - "Conflicting packet size data: PPP packet size is {}, data size is {}", - ppp_size, data_size - ); - return Err("Header has conflicting length data".into()); - } - if magic != &[0x50, 0x50] { - debug!( - "Found {:x}{:x} instead of magic", - packet_header[2], packet_header[3] - ); - return Err("Magic not found".into()); - } - let data_size = data_size as usize; - if data_size > dest.len() { + state.read_header(socket).await.map_err(|err| { + debug!("Failed to read PPP header: {}", err); + "Failed to read PPP header" + })?; + if state.remaining_bytes() > dest.len() { debug!( "Destination buffer ({} bytes) is smaller than the traferred packet ({} bytes)", dest.len(), - data_size + state.remaining_bytes() ); return Err("Destination buffer not large enough to fit all data".into()); } - let mut received_data = 0usize; - while received_data < data_size { - received_data += socket.read(&mut dest[received_data..]).await?; + // TODO: check if perhaps sending partial data is acceptable? + let mut received_data = 0; + while state.remaining_bytes() > 0 { + let bytes_transferred = socket + .read(&mut dest[received_data..received_data + state.remaining_bytes()]) + .await?; + state.consume_bytes(bytes_transferred)?; + received_data += bytes_transferred; } - Ok(data_size) + Ok(received_data) } pub async fn send_echo(&mut self) -> Result<(), FortiError> { @@ -568,36 +546,170 @@ impl FortiVPNTunnel { FortiVPNTunnel::send_ppp_packet(&mut self.socket, ppp::Protocol::IPV4, data).await } - pub async fn read_packet(&mut self, dest: &mut [u8]) -> Result { - // TODO: handle async PPP packets. - let length = FortiVPNTunnel::read_ppp_packet(&mut self.socket, dest, false).await?; - if length == 0 { - return Ok(0); + pub async fn try_read_packet( + &mut self, + dest: &mut [u8], + timeout: Option, + ) -> Result { + // Peek header if not yet available - to get the protocol. + if let Some(timeout) = timeout { + match tokio::time::timeout(timeout, self.ppp_state.read_header(&mut self.socket)).await + { + Ok(res) => res, + Err(_) => return Ok(0), + } + } else { + self.ppp_state.read_header(&mut self.socket).await } - let packet = match ppp::Packet::from_bytes(&dest[..length]) { - Ok(packet) => packet, - Err(err) => { - debug!("Failed to decode PPP packet: {}", err); - return Err("Failed to decode PPP packet".into()); + .map_err(|err| { + debug!("Failed to read PPP header: {}", err); + "Failed to read PPP header" + })?; + let protocol = match self.ppp_state.read_protocol() { + Some(protocol) => protocol, + None => { + return Err("Unknown PPP protocol, possibly a framing error".into()); } }; - println!("Packet= {}", packet); + let length = + FortiVPNTunnel::read_ppp_packet(&mut self.socket, &mut self.ppp_state, dest).await?; + match protocol { + ppp::Protocol::LCP => { + // TODO: handle echo requests/responsed here. + let packet = match ppp::Packet::from_bytes(protocol, &dest[..length]) { + Ok(packet) => packet, + Err(err) => { + debug!("Failed to decode PPP packet: {}", err); + return Err("Failed to decode PPP packet".into()); + } + }; + println!("Packet= {}", packet); + Ok(0) + } + ppp::Protocol::IPV4 | ppp::Protocol::IPV6 => Ok(length), + _ => { + println!("Received unexpected PPP packet {}, ignoring", protocol); + Ok(0) + } + } + } +} + +struct PPPState { + ppp_header: [u8; 8], + ppp_header_length: usize, + bytes_remaining: usize, + first_packet: bool, +} + +impl PPPState { + fn new() -> PPPState { + PPPState { + ppp_header: [0u8; 8], + ppp_header_length: 0, + bytes_remaining: 0, + first_packet: true, + } + } + + async fn read_header(&mut self, socket: &mut FortiTlsStream) -> Result<(), FortiError> { + println!("Reading packet"); + if self.bytes_remaining > 0 { + return Ok(()); + } + // If no data is available, this will return immediately. + //match tokio::time::timeout(Duration::from_millis(100), async { + while self.ppp_header_length < self.ppp_header.len() { + match socket + .read(&mut self.ppp_header[self.ppp_header_length..]) + .await + { + Ok(bytes_read) => { + self.ppp_header_length += bytes_read; + if self.ppp_header_length >= self.ppp_header.len() { + println!("Have header {} bytes", self.ppp_header_length); + break; + } else { + println!("Have {} bytes", self.ppp_header_length); + } + } + Err(err) => { + debug!("Failed to read PPP header {}", err); + return Err("Failed to read PPP header".into()); + } + } + } + /* + .await + { + Ok(_) => {} + Err(_) => { + println!("Read timed out"); + return Ok(0); + } + } + */ + println!("Packet ready"); + if let Err(err) = self.validate_link(socket).await { + return Err(err); + } + + let mut ppp_size = [0u8; 2]; + ppp_size.copy_from_slice(&self.ppp_header[..2]); + let ppp_size = u16::from_be_bytes(ppp_size); + let mut data_size = [0u8; 2]; + data_size.copy_from_slice(&self.ppp_header[4..6]); + let data_size = u16::from_be_bytes(data_size); + let magic = &self.ppp_header[2..4]; + if ppp_size != data_size + 6 { + debug!( + "Conflicting packet size data: PPP packet size is {}, data size is {}", + ppp_size, data_size + ); + return Err("Header has conflicting length data".into()); + } + if magic != &[0x50, 0x50] { + debug!( + "Found {:x}{:x} instead of magic", + self.ppp_header[2], self.ppp_header[3] + ); + return Err("Magic not found".into()); + } + self.bytes_remaining = data_size as usize - 2; + Ok(()) + } + + fn remaining_bytes(&self) -> usize { + self.bytes_remaining + } - if packet.read_protocol() == ppp::Protocol::IPV4 { - // TODO: improve this - dest.copy_within(2..length, 0); - Ok(length - 2) + fn consume_bytes(&mut self, count: usize) -> Result<(), FortiError> { + if self.bytes_remaining < count { + Err("Consumed more bytes than were available".into()) } else { - Ok(0) + self.bytes_remaining -= count; + if self.bytes_remaining == 0 { + self.ppp_header_length = 0; + } + Ok(()) } } - async fn validate_link( - socket: &mut FortiTlsStream, - packet_header: &[u8], - ) -> Result<(), FortiError> { + fn read_protocol(&self) -> Option { + if self.ppp_header_length == 8 { + Some(ppp::Protocol::from_be_slice(&self.ppp_header[6..])) + } else { + None + } + } + + async fn validate_link(&mut self, socket: &mut FortiTlsStream) -> Result<(), FortiError> { const FALL_BACK_TO_HTTP: &[u8] = "HTTP/1".as_bytes(); - if packet_header == FALL_BACK_TO_HTTP { + if !self.first_packet { + return Ok(()); + } + self.first_packet = false; + if &self.ppp_header[..FALL_BACK_TO_HTTP.len()] == FALL_BACK_TO_HTTP { // FortiVPN will return an HTTP response if something goes wrong on setup. let headers = read_http_headers(socket).await?; debug!("Tunnel not active, error response: {}", headers); diff --git a/src/network.rs b/src/network.rs index a9ba27c..eab7bda 100644 --- a/src/network.rs +++ b/src/network.rs @@ -83,26 +83,37 @@ impl Network<'_> { } fn copy_all_data(&mut self) { - // This is not fancy or super efficient, but with a single-user (a few connections) should work OK. + // This is not fancy, but with a single-user (a few connections) should work OK. + // smoltcp's poll works exactly the same way and seems to show reasonable performance. // The alternative is using poll/waking and additional buffers (or perhaps a list of futures), which // doesn't work well with smoltcp - as smoltcp keeps ownership of most of its data, and any writes // need to be guarded. use socket::tcp; - for (handle, tunnel) in self.bridges.iter() { + self.bridges.iter().for_each(|(handle, tunnel)| { let socket = self.sockets.get_mut::(*handle); if socket.can_send() { let result = socket.send(|dest| match tunnel.reader.try_read(dest) { - Ok(bytes) => (bytes, Ok(())), + Ok(bytes) => { + if bytes > 0 && dest.len() > 0 { + (bytes, Ok::<(), NetworkError>(())) + } else { + // Zero bytes means the stream is closed. + (0, Err("Proxy reader is closed".into())) + } + } Err(err) => match err.kind() { io::ErrorKind::WouldBlock => (0, Ok(())), - _ => (0, Err(err)), + _ => (0, Err(err.into())), }, }); if let Ok(result) = result { if let Err(err) = result { warn!("Failed to read data from SOCKS socket: {}", err); + socket.close(); + return; } } else if let Err(err) = result { + // Not critical if socket is still opening. warn!("Failed to send data to virtual socket: {}", err); } } @@ -118,12 +129,23 @@ impl Network<'_> { if let Ok(result) = result { if let Err(err) = result { warn!("Failed to write data to SOCKS socket: {}", err); + socket.close(); + return; } } else if let Err(err) = result { warn!("Failed to read data from virtual socket: {}", err); } } - } + }); + + self.bridges.retain(|handle, _| { + let socket = self.sockets.get_mut::(*handle); + if socket.is_open() { + return true; + } + self.sockets.remove(*handle); + false + }); for (handle, response) in self.opening_connections.iter_mut() { let socket = self.sockets.get::(*handle); @@ -156,7 +178,9 @@ impl Network<'_> { } else { continue; }; - let _ = response.send(result); + if let Err(_err) = response.send(result) { + debug!("Proxy listener not listening for response"); + } } self.opening_connections .retain(|_, response| response.is_some()); @@ -258,7 +282,10 @@ impl VpnDevice { // Data is not consumed yet. return Ok(()); } - self.read_packet_size = self.vpn.read_packet(&mut self.read_packet).await?; + self.read_packet_size = self + .vpn + .try_read_packet(&mut self.read_packet, None) + .await?; Ok(()) } diff --git a/src/ppp.rs b/src/ppp.rs index 8ef5f35..b21440f 100644 --- a/src/ppp.rs +++ b/src/ppp.rs @@ -7,34 +7,23 @@ use log::debug; */ pub struct Packet<'a> { + protocol: Protocol, data: &'a [u8], } impl Packet<'_> { - pub fn from_bytes(b: &[u8]) -> Result { - if b.len() < 2 { - debug!("Not enough data in PPP packet"); - Err("Not enough data in PPP packet".into()) - } else { - let packet = Packet { data: b }; - packet.validate()?; - Ok(packet) - } + pub fn from_bytes(protocol: Protocol, data: &[u8]) -> Result { + let packet = Packet { protocol, data }; + packet.validate()?; + Ok(packet) } pub fn validate(&self) -> Result<(), FormatError> { - self.read_protocol().validate() - } - - pub fn read_protocol(&self) -> Protocol { - let mut result = [0u8; 2]; - result.copy_from_slice(&self.data[..2]); - Protocol::from_u16(u16::from_be_bytes(result)) + self.protocol.validate() } pub fn to_lcp(&self) -> Result { - let protocol = self.read_protocol(); - if protocol == Protocol::LCP { + if self.protocol == Protocol::LCP { LcpPacket::from_bytes(&self.data[2..]) } else { Err("Protocol type is not LCP".into()) @@ -42,8 +31,7 @@ impl Packet<'_> { } pub fn to_ipcp(&self) -> Result { - let protocol = self.read_protocol(); - if protocol == Protocol::IPV4CP { + if self.protocol == Protocol::IPV4CP { IpcpPacket::from_bytes(&self.data[2..]) } else { Err("Protocol type is not IPCP".into()) @@ -60,6 +48,12 @@ impl Protocol { pub const LCP: Protocol = Protocol(0xc021); pub const IPV4CP: Protocol = Protocol(0x8021); + pub fn from_be_slice(slice: &[u8]) -> Protocol { + let mut result = [0u8; 2]; + result.copy_from_slice(&slice); + Protocol::from_u16(u16::from_be_bytes(result)) + } + fn from_u16(value: u16) -> Protocol { Protocol(value) } @@ -656,7 +650,7 @@ fn fmt_slice_hex(data: &[u8], f: &mut dyn std::fmt::Write) -> std::fmt::Result { impl fmt::Display for Packet<'_> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "Protocol: {}", self.read_protocol())?; + write!(f, "Protocol: {}", self.protocol)?; if let Ok(lcp) = self.to_lcp() { write!(f, ", LCP code: {} id: {}", lcp.code(), lcp.identifier())?; for opt in lcp.iter_options() {