diff --git a/ios/MullvadPostQuantum/PostQuantumKeyNegotiatior.swift b/ios/MullvadPostQuantum/PostQuantumKeyNegotiatior.swift index 861ecee3033b..9bdd8b41e838 100644 --- a/ios/MullvadPostQuantum/PostQuantumKeyNegotiatior.swift +++ b/ios/MullvadPostQuantum/PostQuantumKeyNegotiatior.swift @@ -12,9 +12,10 @@ import TalpidTunnelConfigClientProxy import WireGuardKitTypes public class PostQuantumKeyNegotiatior { - private var cancellationToken: UnsafeRawPointer? public init() {} + var cancelToken: PostQuantumCancelToken? + public func negotiateKey( gatewayIP: IPv4Address, devicePublicKey: PublicKey, @@ -24,24 +25,30 @@ public class PostQuantumKeyNegotiatior { ) { let packetTunnelPointer = Unmanaged.passUnretained(packetTunnel).toOpaque() let opaqueConnection = Unmanaged.passUnretained(tcpConnection).toOpaque() + var cancelToken = PostQuantumCancelToken() // TODO: Any non 0 return is considered a failure, and should be handled gracefully - let token = negotiate_post_quantum_key( + let result = negotiate_post_quantum_key( devicePublicKey.rawValue.map { $0 }, presharedKey.rawValue.map { $0 }, packetTunnelPointer, - opaqueConnection + opaqueConnection, + &cancelToken ) - guard let token else { + guard result == 0 else { // Handle failure here return } - - cancellationToken = token + self.cancelToken = cancelToken } public func cancelKeyNegotiation() { - guard let cancellationToken else { return } - cancel_post_quantum_key_exchange(cancellationToken) + guard var cancelToken else { return } + cancel_post_quantum_key_exchange(&cancelToken) + } + + deinit { + guard var cancelToken else { return } + drop_post_quantum_key_exchange_token(&cancelToken) } } diff --git a/ios/MullvadPostQuantum/talpid-tunnel-config-client/include/talpid_tunnel_config_client.h b/ios/MullvadPostQuantum/talpid-tunnel-config-client/include/talpid_tunnel_config_client.h index 5d4604a59e39..c14dda4e12ba 100644 --- a/ios/MullvadPostQuantum/talpid-tunnel-config-client/include/talpid_tunnel_config_client.h +++ b/ios/MullvadPostQuantum/talpid-tunnel-config-client/include/talpid_tunnel_config_client.h @@ -3,7 +3,13 @@ #include #include -void cancel_post_quantum_key_exchange(const void *sender); +typedef struct PostQuantumCancelToken { + void *context; +} PostQuantumCancelToken; + +void cancel_post_quantum_key_exchange(const struct PostQuantumCancelToken *sender); + +void drop_post_quantum_key_exchange_token(const struct PostQuantumCancelToken *sender); /** * Callback to call when the TCP connection has written data. @@ -21,10 +27,11 @@ void handle_recv(const uint8_t *data, uintptr_t data_len, const void *sender); * # Safety * This function is safe to call */ -const void *negotiate_post_quantum_key(const uint8_t *public_key, - const uint8_t *ephemeral_public_key, - const void *packet_tunnel, - const void *tcp_connection); +int32_t negotiate_post_quantum_key(const uint8_t *public_key, + const uint8_t *ephemeral_public_key, + const void *packet_tunnel, + const void *tcp_connection, + struct PostQuantumCancelToken *cancel_token); /** * Called when there is data to send on the TCP connection. diff --git a/ios/MullvadPostQuantum/talpid-tunnel-config-client/src/ios_ffi.rs b/ios/MullvadPostQuantum/talpid-tunnel-config-client/src/ios_ffi.rs index 78db55283f06..b69a1abdf4ec 100644 --- a/ios/MullvadPostQuantum/talpid-tunnel-config-client/src/ios_ffi.rs +++ b/ios/MullvadPostQuantum/talpid-tunnel-config-client/src/ios_ffi.rs @@ -1,22 +1,26 @@ use libc::c_void; use tokio::sync::mpsc; +use crate::PostQuantumCancelToken; + use super::run_ios_runtime; -use std::{rc::Weak, sync::Once}; +use std::sync::Once; static INIT_LOGGING: Once = Once::new(); -#[allow(clippy::let_underscore_future)] #[no_mangle] -pub unsafe extern "C" fn cancel_post_quantum_key_exchange(sender: *const c_void) { - // Try to take the value, if there is a value, we can safely send the message, otherwise, assume it has been dropped and nothing happens - let send_tx: Weak> = unsafe { Weak::from_raw(sender as _) }; - if let Some(tx) = send_tx.upgrade() { - // # Safety - // Clippy warns of a non-binding let on a future, this future is being awaited on. - _ = tx.send(()); - } +pub unsafe extern "C" fn cancel_post_quantum_key_exchange(sender: *const PostQuantumCancelToken) { + let sender = unsafe { &*sender }; + sender.cancel(); } + +#[no_mangle] +pub unsafe extern "C" fn drop_post_quantum_key_exchange_token( + sender: *const PostQuantumCancelToken, +) { + let _sender = unsafe { std::ptr::read(sender) }; +} + /// Callback to call when the TCP connection has written data. #[no_mangle] pub unsafe extern "C" fn handle_sent(bytes_sent: usize, sender: *const c_void) { @@ -45,7 +49,8 @@ pub unsafe extern "C" fn negotiate_post_quantum_key( ephemeral_public_key: *const u8, packet_tunnel: *const c_void, tcp_connection: *const c_void, -) -> *const c_void { + cancel_token: *mut PostQuantumCancelToken, +) -> i32 { INIT_LOGGING.call_once(|| { let _ = oslog::OsLogger::new("net.mullvad.MullvadVPN.TTCC") .level_filter(log::LevelFilter::Trace) @@ -56,10 +61,18 @@ pub unsafe extern "C" fn negotiate_post_quantum_key( let eph_pub_key_copy: [u8; 32] = unsafe { std::ptr::read(ephemeral_public_key as *const [u8; 32]) }; - run_ios_runtime( - pub_key_copy, - eph_pub_key_copy, - packet_tunnel, - tcp_connection, - ) + match unsafe { + run_ios_runtime( + pub_key_copy, + eph_pub_key_copy, + packet_tunnel, + tcp_connection, + ) + } { + Ok(token) => { + unsafe { std::ptr::write(cancel_token, token) }; + 0 + } + Err(err) => err, + } } diff --git a/ios/MullvadPostQuantum/talpid-tunnel-config-client/src/lib.rs b/ios/MullvadPostQuantum/talpid-tunnel-config-client/src/lib.rs index 17212a0a269f..079baf102708 100644 --- a/ios/MullvadPostQuantum/talpid-tunnel-config-client/src/lib.rs +++ b/ios/MullvadPostQuantum/talpid-tunnel-config-client/src/lib.rs @@ -18,24 +18,51 @@ use talpid_types::net::wireguard::PublicKey; use tonic::transport::Endpoint; use tower::service_fn; +#[repr(C)] +pub struct PostQuantumCancelToken { + // Must keep a pointer to a valid std::sync::Arc + pub context: *mut c_void, +} + +impl PostQuantumCancelToken { + /// #Safety + /// This function can only be called when the context pointer is valid. + unsafe fn cancel(&self) { + // Try to take the value, if there is a value, we can safely send the message, otherwise, assume it has been dropped and nothing happens + let send_tx: Arc> = unsafe { Arc::from_raw(self.context as _) }; + let _ = send_tx.send(()); + std::mem::forget(send_tx); + } +} + +impl Drop for PostQuantumCancelToken { + fn drop(&mut self) { + let _: Arc> = unsafe { Arc::from_raw(self.context as _) }; + } +} +unsafe impl Send for PostQuantumCancelToken {} + /// # Safety -/// This function is safe to call +/// packet_tunnel and tcp_connection must be valid pointers to a packet tunnel and a TCP connection instances. +/// pub unsafe fn run_ios_runtime( pub_key: [u8; 32], ephemeral_pub_key: [u8; 32], packet_tunnel: *const c_void, tcp_connection: *const c_void, -) -> *const c_void { - match IOSRuntime::new(pub_key, ephemeral_pub_key, packet_tunnel, tcp_connection) { +) -> Result { + match unsafe { IOSRuntime::new(pub_key, ephemeral_pub_key, packet_tunnel, tcp_connection) } { Ok(runtime) => { - let weak_cancel_token = Arc::downgrade(&runtime.cancel_token_tx); - let token = weak_cancel_token.into_raw() as _; + let token = runtime.cancel_token_tx.clone(); + runtime.run(); - token + Ok(PostQuantumCancelToken { + context: Arc::into_raw(token) as *mut _, + }) } Err(err) => { log::error!("Failed to create runtime {}", err); - std::ptr::null() + Err(-1) } } } diff --git a/ios/PacketTunnelCore/Actor/State+Extensions.swift b/ios/PacketTunnelCore/Actor/State+Extensions.swift index 5737aa413818..93fe29f3b731 100644 --- a/ios/PacketTunnelCore/Actor/State+Extensions.swift +++ b/ios/PacketTunnelCore/Actor/State+Extensions.swift @@ -91,6 +91,7 @@ extension State { let .connecting(connState), let .connected(connState), let .reconnecting(connState), + let .negotiatingPostQuantumKey(connState, _), let .disconnecting(connState): connState default: nil } @@ -121,6 +122,7 @@ extension State { case .connected: .connected(newValue) case .reconnecting: .reconnecting(newValue) case .disconnecting: .disconnecting(newValue) + case let .negotiatingPostQuantumKey(_, privateKey): .negotiatingPostQuantumKey(newValue, privateKey) default: self } }