diff --git a/src/ikev2/mod.rs b/src/ikev2/mod.rs index 99e10f0..662f853 100644 --- a/src/ikev2/mod.rs +++ b/src/ikev2/mod.rs @@ -9,7 +9,7 @@ use std::{ sync::Arc, time::{Duration, Instant}, }; -use tokio::{net::UdpSocket, runtime, sync::mpsc, task::JoinHandle}; +use tokio::{net::UdpSocket, runtime, sync::mpsc, task::JoinHandle, time}; mod crypto; mod message; @@ -32,6 +32,7 @@ const MAX_SIGNATURE_LENGTH: usize = 1 + 12 + 72; const IKE_INIT_SA_EXPIRATION: Duration = Duration::from_secs(15); const IKE_SESSION_EXPIRATION: Duration = Duration::from_secs(60 * 15); const IKE_RESPONSE_EXPIRATION: Duration = Duration::from_secs(60); +const IKE_RETRANSMISSIONS_LIMIT: usize = 5; pub struct Config { pub listen_ips: Vec, @@ -305,6 +306,7 @@ impl UdpDatagram { enum SessionMessage { UdpDatagram(UdpDatagram), + RetransmitRequest(SessionID, u32), CleanupTimer, Shutdown, } @@ -395,7 +397,9 @@ impl Sessions { } }); self.sessions.retain(|session_id, session| { - if session.last_update + IKE_SESSION_EXPIRATION < now { + if session.last_update + IKE_SESSION_EXPIRATION < now + || session.request_retransmit > IKE_RETRANSMISSIONS_LIMIT + { info!( "Deleting expired session with SPI {} {:?}", session_id, session.user_id @@ -412,18 +416,28 @@ impl Sessions { }); if self.shutdown { for (session_id, session) in self.sessions.iter_mut() { - if let Err(err) = session.start_request_delete_ike() { - warn!( - "Failed to prepare Delete request to session {}: {}", - session_id, err - ); + match session.sent_request { + Some(RequestContext::DeleteIKEv2) => continue, + _ => {} } - if let Err(err) = session.send_last_request(&self.sockets).await { + let message_id = match session.start_request_delete_ike() { + Ok(message_id) => message_id, + Err(err) => { + warn!( + "Failed to prepare Delete request to session {}: {}", + session_id, err + ); + continue; + } + }; + if let Err(err) = session.send_last_request(&self.sockets, message_id).await { warn!( - "Failed to prepare Delete request to session {}: {}", + "Failed to send Delete request to session {}: {}", session_id, err ); } + session.request_retransmit += 1; + Self::schedule_retransmission(self.tx.clone(), *session_id, message_id, 1).await } } } @@ -442,6 +456,9 @@ impl Sessions { SessionMessage::CleanupTimer => { self.cleanup().await; } + SessionMessage::RetransmitRequest(session_id, message_id) => { + self.retransmit_request(session_id, message_id).await; + } SessionMessage::Shutdown => { self.shutdown = true; self.cleanup().await; @@ -576,6 +593,48 @@ impl Sessions { } } + async fn retransmit_request(&mut self, session_id: SessionID, message_id: u32) { + let session = if let Some(session) = self.sessions.get_mut(&session_id) { + session + } else { + return; + }; + if let Err(err) = session.send_last_request(&self.sockets, message_id).await { + warn!( + "Failed to retransmit last reqeust to session {}: {}", + session_id, err + ); + } + session.request_retransmit += 1; + if session.request_retransmit > IKE_RETRANSMISSIONS_LIMIT { + warn!("Session {} reached retrasmission limit", session_id); + return; + } + let retransmit_counter = session.request_retransmit; + Self::schedule_retransmission(self.tx.clone(), session_id, message_id, retransmit_counter) + .await + } + + async fn schedule_retransmission( + tx: mpsc::Sender, + session_id: SessionID, + message_id: u32, + retransmit_counter: usize, + ) { + let next_retransmission = 3000 * retransmit_counter as u64; + let jitter = next_retransmission * 15 / 100; + let next_delay = rand::thread_rng().gen_range( + next_retransmission.saturating_sub(jitter)..=next_retransmission.saturating_add(jitter), + ); + let next_retransmission = time::Duration::from_millis(next_delay); + let rt = runtime::Handle::current(); + rt.spawn(async move { + time::sleep(next_retransmission).await; + tx.send(SessionMessage::RetransmitRequest(session_id, message_id)) + .await + }); + } + async fn process_esp_packet(&mut self, datagram: &mut UdpDatagram) -> Result<(), IKEv2Error> { let packet_bytes = datagram.request.as_mut_slice(); if packet_bytes == [0xff] { @@ -655,6 +714,7 @@ struct IKEv2Session { last_response: Option<([u8; MAX_DATAGRAM_SIZE], usize)>, last_request: Option<([u8; MAX_DATAGRAM_SIZE], usize)>, sent_request: Option, + request_retransmit: usize, pending_actions: Vec, } @@ -682,6 +742,7 @@ impl IKEv2Session { last_response: None, last_request: None, sent_request: None, + request_retransmit: 0, pending_actions: vec![], } } @@ -1621,7 +1682,7 @@ impl IKEv2Session { &mut self, exchange_type: message::ExchangeType, command_generator: impl FnOnce(&mut message::MessageWriter) -> Result<(), IKEv2Error>, - ) -> Result<(), IKEv2Error> { + ) -> Result { if self.sent_request.is_some() || self.last_request.is_some() { return Err("Already processing another command".into()); } @@ -1645,29 +1706,38 @@ impl IKEv2Session { let request_len = self.complete_encrypted_payload(&mut ikev2_request)?; self.last_request = Some((request_bytes, request_len + start_offset)); - Ok(()) + self.request_retransmit = 0; + Ok(self.local_message_id) } - fn start_request_delete_ike(&mut self) -> Result<(), IKEv2Error> { + fn start_request_delete_ike(&mut self) -> Result { match self.state { SessionState::Empty | SessionState::InitSA(_) | SessionState::Deleting => { - debug!("Received Delete request for a non-established session, ignoring"); - return Ok(()); + return Err( + "Received Delete request for a non-established session, ignoring".into(), + ); } SessionState::Established => {} } - self.start_request(message::ExchangeType::INFORMATIONAL, |writer| { + let message_id = self.start_request(message::ExchangeType::INFORMATIONAL, |writer| { Ok(writer.write_delete_payload(message::IPSecProtocolID::IKE, &[])?) })?; self.state = SessionState::Deleting; self.sent_request = Some(RequestContext::DeleteIKEv2); - Ok(()) + Ok(message_id) } - async fn send_last_request(&self, sockets: &Sockets) -> Result<(), IKEv2Error> { + async fn send_last_request( + &self, + sockets: &Sockets, + message_id: u32, + ) -> Result<(), IKEv2Error> { + if message_id != self.local_message_id { + return Ok(()); + } if let Some((request_bytes, request_len)) = self.last_request { debug!( - "Restransmitting request {} for session {}", + "Transmitting request {} for session {}", self.local_message_id, self.session_id ); sockets @@ -1740,6 +1810,7 @@ impl IKEv2Session { self.local_message_id += 1; self.last_request = None; self.sent_request = None; + self.request_retransmit = 0; // Update remote address if client changed IP or switched to another NAT port. self.remote_addr = remote_addr;