Skip to content

Commit

Permalink
Rewrite the PQ cancellation token in a safer way
Browse files Browse the repository at this point in the history
  • Loading branch information
buggmagnet committed Apr 8, 2024
1 parent 6291038 commit 6d03a7e
Show file tree
Hide file tree
Showing 5 changed files with 93 additions and 37 deletions.
23 changes: 15 additions & 8 deletions ios/MullvadPostQuantum/PostQuantumKeyNegotiatior.swift
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@ import TalpidTunnelConfigClientProxy
import WireGuardKitTypes

public class PostQuantumKeyNegotiatior {
private var cancellationToken: UnsafeRawPointer?
public init() {}

var cancelToken: PostQuantumCancelToken?

public func negotiateKey(
gatewayIP: IPv4Address,
devicePublicKey: PublicKey,
Expand All @@ -24,24 +25,30 @@ public class PostQuantumKeyNegotiatior {
) {
let packetTunnelPointer = Unmanaged.passUnretained(packetTunnel).toOpaque()
let opaqueConnection = Unmanaged.passUnretained(tcpConnection).toOpaque()
var cancelToken = PostQuantumCancelToken()

// TODO: Any non 0 return is considered a failure, and should be handled gracefully
let token = negotiate_post_quantum_key(
let result = negotiate_post_quantum_key(
devicePublicKey.rawValue.map { $0 },
presharedKey.rawValue.map { $0 },
packetTunnelPointer,
opaqueConnection
opaqueConnection,
&cancelToken
)
guard let token else {
guard result == 0 else {
// Handle failure here
return
}

cancellationToken = token
self.cancelToken = cancelToken
}

public func cancelKeyNegotiation() {
guard let cancellationToken else { return }
cancel_post_quantum_key_exchange(cancellationToken)
guard var cancelToken else { return }
cancel_post_quantum_key_exchange(&cancelToken)
}

deinit {
guard var cancelToken else { return }
drop_post_quantum_key_exchange_token(&cancelToken)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,13 @@
#include <stdint.h>
#include <stdlib.h>

void cancel_post_quantum_key_exchange(const void *sender);
typedef struct PostQuantumCancelToken {
void *context;
} PostQuantumCancelToken;

void cancel_post_quantum_key_exchange(const struct PostQuantumCancelToken *sender);

void drop_post_quantum_key_exchange_token(const struct PostQuantumCancelToken *sender);

/**
* Callback to call when the TCP connection has written data.
Expand All @@ -21,10 +27,11 @@ void handle_recv(const uint8_t *data, uintptr_t data_len, const void *sender);
* # Safety
* This function is safe to call
*/
const void *negotiate_post_quantum_key(const uint8_t *public_key,
const uint8_t *ephemeral_public_key,
const void *packet_tunnel,
const void *tcp_connection);
int32_t negotiate_post_quantum_key(const uint8_t *public_key,
const uint8_t *ephemeral_public_key,
const void *packet_tunnel,
const void *tcp_connection,
struct PostQuantumCancelToken *cancel_token);

/**
* Called when there is data to send on the TCP connection.
Expand Down
47 changes: 30 additions & 17 deletions ios/MullvadPostQuantum/talpid-tunnel-config-client/src/ios_ffi.rs
Original file line number Diff line number Diff line change
@@ -1,22 +1,26 @@
use libc::c_void;
use tokio::sync::mpsc;

use crate::PostQuantumCancelToken;

use super::run_ios_runtime;

use std::{rc::Weak, sync::Once};
use std::sync::Once;
static INIT_LOGGING: Once = Once::new();

#[allow(clippy::let_underscore_future)]
#[no_mangle]
pub unsafe extern "C" fn cancel_post_quantum_key_exchange(sender: *const c_void) {
// Try to take the value, if there is a value, we can safely send the message, otherwise, assume it has been dropped and nothing happens
let send_tx: Weak<mpsc::Sender<()>> = unsafe { Weak::from_raw(sender as _) };
if let Some(tx) = send_tx.upgrade() {
// # Safety
// Clippy warns of a non-binding let on a future, this future is being awaited on.
_ = tx.send(());
}
pub unsafe extern "C" fn cancel_post_quantum_key_exchange(sender: *const PostQuantumCancelToken) {
let sender = unsafe { &*sender };
sender.cancel();
}

#[no_mangle]
pub unsafe extern "C" fn drop_post_quantum_key_exchange_token(
sender: *const PostQuantumCancelToken,
) {
let _sender = unsafe { std::ptr::read(sender) };
}

/// Callback to call when the TCP connection has written data.
#[no_mangle]
pub unsafe extern "C" fn handle_sent(bytes_sent: usize, sender: *const c_void) {
Expand Down Expand Up @@ -45,7 +49,8 @@ pub unsafe extern "C" fn negotiate_post_quantum_key(
ephemeral_public_key: *const u8,
packet_tunnel: *const c_void,
tcp_connection: *const c_void,
) -> *const c_void {
cancel_token: *mut PostQuantumCancelToken,
) -> i32 {
INIT_LOGGING.call_once(|| {
let _ = oslog::OsLogger::new("net.mullvad.MullvadVPN.TTCC")
.level_filter(log::LevelFilter::Trace)
Expand All @@ -56,10 +61,18 @@ pub unsafe extern "C" fn negotiate_post_quantum_key(
let eph_pub_key_copy: [u8; 32] =
unsafe { std::ptr::read(ephemeral_public_key as *const [u8; 32]) };

run_ios_runtime(
pub_key_copy,
eph_pub_key_copy,
packet_tunnel,
tcp_connection,
)
match unsafe {
run_ios_runtime(
pub_key_copy,
eph_pub_key_copy,
packet_tunnel,
tcp_connection,
)
} {
Ok(token) => {
unsafe { std::ptr::write(cancel_token, token) };
0
}
Err(err) => err,
}
}
41 changes: 34 additions & 7 deletions ios/MullvadPostQuantum/talpid-tunnel-config-client/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,24 +18,51 @@ use talpid_types::net::wireguard::PublicKey;
use tonic::transport::Endpoint;
use tower::service_fn;

#[repr(C)]
pub struct PostQuantumCancelToken {
// Must keep a pointer to a valid std::sync::Arc<tokio::mpsc::UnboundedSender>
pub context: *mut c_void,
}

impl PostQuantumCancelToken {
/// #Safety
/// This function can only be called when the context pointer is valid.
unsafe fn cancel(&self) {
// Try to take the value, if there is a value, we can safely send the message, otherwise, assume it has been dropped and nothing happens
let send_tx: Arc<mpsc::UnboundedSender<()>> = unsafe { Arc::from_raw(self.context as _) };
let _ = send_tx.send(());
std::mem::forget(send_tx);
}
}

impl Drop for PostQuantumCancelToken {
fn drop(&mut self) {
let _: Arc<mpsc::UnboundedSender<()>> = unsafe { Arc::from_raw(self.context as _) };
}
}
unsafe impl Send for PostQuantumCancelToken {}

/// # Safety
/// This function is safe to call
/// packet_tunnel and tcp_connection must be valid pointers to a packet tunnel and a TCP connection instances.
///
pub unsafe fn run_ios_runtime(
pub_key: [u8; 32],
ephemeral_pub_key: [u8; 32],
packet_tunnel: *const c_void,
tcp_connection: *const c_void,
) -> *const c_void {
match IOSRuntime::new(pub_key, ephemeral_pub_key, packet_tunnel, tcp_connection) {
) -> Result<PostQuantumCancelToken, i32> {
match unsafe { IOSRuntime::new(pub_key, ephemeral_pub_key, packet_tunnel, tcp_connection) } {
Ok(runtime) => {
let weak_cancel_token = Arc::downgrade(&runtime.cancel_token_tx);
let token = weak_cancel_token.into_raw() as _;
let token = runtime.cancel_token_tx.clone();

runtime.run();
token
Ok(PostQuantumCancelToken {
context: Arc::into_raw(token) as *mut _,
})
}
Err(err) => {
log::error!("Failed to create runtime {}", err);
std::ptr::null()
Err(-1)
}
}
}
Expand Down
2 changes: 2 additions & 0 deletions ios/PacketTunnelCore/Actor/State+Extensions.swift
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ extension State {
let .connecting(connState),
let .connected(connState),
let .reconnecting(connState),
let .negotiatingPostQuantumKey(connState, _),
let .disconnecting(connState): connState
default: nil
}
Expand Down Expand Up @@ -121,6 +122,7 @@ extension State {
case .connected: .connected(newValue)
case .reconnecting: .reconnecting(newValue)
case .disconnecting: .disconnecting(newValue)
case let .negotiatingPostQuantumKey(_, privateKey): .negotiatingPostQuantumKey(newValue, privateKey)
default: self
}
}
Expand Down

0 comments on commit 6d03a7e

Please sign in to comment.