diff --git a/ios/MullvadRustRuntime/EphemeralPeerExchangeActor.swift b/ios/MullvadRustRuntime/EphemeralPeerExchangeActor.swift index 397b656d612b..fc7ff34529c2 100644 --- a/ios/MullvadRustRuntime/EphemeralPeerExchangeActor.swift +++ b/ios/MullvadRustRuntime/EphemeralPeerExchangeActor.swift @@ -21,13 +21,9 @@ public protocol EphemeralPeerExchangeActorProtocol { public class EphemeralPeerExchangeActor: EphemeralPeerExchangeActorProtocol { struct Negotiation { var negotiator: EphemeralPeerNegotiating - var inTunnelTCPConnection: NWTCPConnection - var tcpConnectionObserver: NSKeyValueObservation func cancel() { negotiator.cancelKeyNegotiation() - tcpConnectionObserver.invalidate() - inTunnelTCPConnection.cancel() } } @@ -54,15 +50,6 @@ public class EphemeralPeerExchangeActor: EphemeralPeerExchangeActorProtocol { self.keyExchangeRetriesIterator = iteratorProvider() } - private func createTCPConnection(_ gatewayEndpoint: NWHostEndpoint) -> NWTCPConnection { - self.packetTunnel.createTCPConnectionThroughTunnel( - to: gatewayEndpoint, - enableTLS: false, - tlsParameters: nil, - delegate: nil - ) - } - /// Starts a new key exchange. /// /// Any ongoing key negotiation is stopped before starting a new one. @@ -75,49 +62,46 @@ public class EphemeralPeerExchangeActor: EphemeralPeerExchangeActorProtocol { endCurrentNegotiation() let negotiator = negotiationProvider.init() - let gatewayAddress = LocalNetworkIPs.gatewayAddress.rawValue - let IPv4Gateway = IPv4Address(gatewayAddress)! - let endpoint = NWHostEndpoint(hostname: gatewayAddress, port: "\(CONFIG_SERVICE_PORT)") - let inTunnelTCPConnection = createTCPConnection(endpoint) - // This will become the new private key of the device let ephemeralSharedKey = PrivateKey() let tcpConnectionTimeout = keyExchangeRetriesIterator.next() ?? .seconds(10) // If the connection never becomes viable, force a reconnection after 10 seconds - scheduleInTunnelConnectionTimeout(startTime: .now() + tcpConnectionTimeout) - - let tcpConnectionObserver = inTunnelTCPConnection.observe(\.isViable, options: [ - .initial, - .new, - ]) { [weak self] observedConnection, _ in - guard let self, observedConnection.isViable else { return } - self.negotiation?.tcpConnectionObserver.invalidate() - self.timer?.cancel() - - if !negotiator.startNegotiation( - gatewayIP: IPv4Gateway, - devicePublicKey: privateKey.publicKey, - presharedKey: ephemeralSharedKey, - peerReceiver: packetTunnel, - tcpConnection: inTunnelTCPConnection, - peerExchangeTimeout: tcpConnectionTimeout, - enablePostQuantum: enablePostQuantum, - enableDaita: enableDaita - ) { - // Cancel the negotiation to shut down any remaining use of the TCP connection on the Rust side - self.negotiation?.cancel() - self.negotiation = nil - self.onFailure() - } + let peerParameters = EphemeralPeerParameters( + peer_exchange_timeout: UInt64(tcpConnectionTimeout.timeInterval), + enable_post_quantum: enablePostQuantum, + enable_daita: enableDaita, + funcs: mapWgFuncs(funcs: packetTunnel.wgFuncs()) + ) + + if !negotiator.startNegotiation( + devicePublicKey: privateKey.publicKey, + presharedKey: ephemeralSharedKey, + peerReceiver: packetTunnel, + ephemeralPeerParams: peerParameters + ) { + // Cancel the negotiation to shut down any remaining use of the TCP connection on the Rust side + self.negotiation?.cancel() + self.negotiation = nil + self.onFailure() } + negotiation = Negotiation( - negotiator: negotiator, - inTunnelTCPConnection: inTunnelTCPConnection, - tcpConnectionObserver: tcpConnectionObserver + negotiator: negotiator ) } + private func mapWgFuncs(funcs: WgFuncPointers) -> WgTcpConnectionFuncs { + var mappedFuncs = WgTcpConnectionFuncs() + + mappedFuncs.close_fn = funcs.close + mappedFuncs.open_fn = funcs.open + mappedFuncs.send_fn = funcs.send + mappedFuncs.recv_fn = funcs.receive + + return mappedFuncs + } + /// Cancels the ongoing key exchange. public func endCurrentNegotiation() { negotiation?.cancel() @@ -129,19 +113,4 @@ public class EphemeralPeerExchangeActor: EphemeralPeerExchangeActorProtocol { keyExchangeRetriesIterator = iteratorProvider() endCurrentNegotiation() } - - private func scheduleInTunnelConnectionTimeout(startTime: DispatchWallTime) { - let newTimer = DispatchSource.makeTimerSource() - - newTimer.setEventHandler { [weak self] in - self?.onFailure() - self?.timer?.cancel() - } - - newTimer.schedule(wallDeadline: startTime) - newTimer.activate() - - timer?.cancel() - timer = newTimer - } } diff --git a/ios/MullvadRustRuntime/EphemeralPeerNegotiator.swift b/ios/MullvadRustRuntime/EphemeralPeerNegotiator.swift index ffc0dc15b394..8346b2686d45 100644 --- a/ios/MullvadRustRuntime/EphemeralPeerNegotiator.swift +++ b/ios/MullvadRustRuntime/EphemeralPeerNegotiator.swift @@ -14,14 +14,10 @@ import WireGuardKitTypes // swiftlint:disable function_parameter_count public protocol EphemeralPeerNegotiating { func startNegotiation( - gatewayIP: IPv4Address, devicePublicKey: PublicKey, presharedKey: PrivateKey, peerReceiver: any TunnelProvider, - tcpConnection: NWTCPConnection, - peerExchangeTimeout: Duration, - enablePostQuantum: Bool, - enableDaita: Bool + ephemeralPeerParams: EphemeralPeerParameters ) -> Bool func cancelKeyNegotiation() @@ -33,35 +29,30 @@ public protocol EphemeralPeerNegotiating { public class EphemeralPeerNegotiator: EphemeralPeerNegotiating { required public init() {} - var cancelToken: EphemeralPeerCancelToken? + var cancelToken: OpaquePointer? public func startNegotiation( - gatewayIP: IPv4Address, devicePublicKey: PublicKey, presharedKey: PrivateKey, peerReceiver: any TunnelProvider, - tcpConnection: NWTCPConnection, - peerExchangeTimeout: Duration, - enablePostQuantum: Bool, - enableDaita: Bool + ephemeralPeerParams: EphemeralPeerParameters ) -> Bool { // swiftlint:disable:next force_cast let ephemeralPeerReceiver = Unmanaged.passUnretained(peerReceiver as! EphemeralPeerReceiver) .toOpaque() - let opaqueConnection = Unmanaged.passUnretained(tcpConnection).toOpaque() - var cancelToken = EphemeralPeerCancelToken() - let result = request_ephemeral_peer( + guard let tunnelHandle = try? peerReceiver.tunnelHandle() else { + return false + } + + let cancelToken = request_ephemeral_peer( devicePublicKey.rawValue.map { $0 }, presharedKey.rawValue.map { $0 }, ephemeralPeerReceiver, - opaqueConnection, - &cancelToken, - UInt64(peerExchangeTimeout.timeInterval), - enablePostQuantum, - enableDaita + tunnelHandle, + ephemeralPeerParams ) - guard result == 0 else { + guard let cancelToken else { return false } self.cancelToken = cancelToken @@ -69,13 +60,14 @@ public class EphemeralPeerNegotiator: EphemeralPeerNegotiating { } public func cancelKeyNegotiation() { - guard var cancelToken else { return } - cancel_ephemeral_peer_exchange(&cancelToken) + guard let cancelToken else { return } + cancel_ephemeral_peer_exchange(cancelToken) + self.cancelToken = nil } deinit { - guard var cancelToken else { return } - drop_ephemeral_peer_exchange_token(&cancelToken) + guard let cancelToken else { return } + drop_ephemeral_peer_exchange_token(cancelToken) } } diff --git a/ios/MullvadRustRuntime/PacketTunnelProvider+TCPConnection.swift b/ios/MullvadRustRuntime/PacketTunnelProvider+TCPConnection.swift index e19750d4dcae..290a71b2cd67 100644 --- a/ios/MullvadRustRuntime/PacketTunnelProvider+TCPConnection.swift +++ b/ios/MullvadRustRuntime/PacketTunnelProvider+TCPConnection.swift @@ -12,72 +12,6 @@ import MullvadTypes import NetworkExtension import WireGuardKitTypes -/// Writes data to the in-tunnel TCP connection -/// -/// This FFI function is called by Rust whenever there is data to be written to the in-tunnel TCP connection when exchanging -/// quantum-resistant pre shared keys. -/// -/// Whenever the flow control is given back from the connection, acknowledge that data was written using `rawWriteAcknowledgement`. -/// - Parameters: -/// - rawConnection: A raw pointer to the in-tunnel TCP connection -/// - rawData: A raw pointer to the data to write in the connection -/// - dataLength: The length of data to write in the connection -/// - rawWriteAcknowledgement: An opaque pointer needed for write acknowledgement -@_cdecl("swift_nw_tcp_connection_send") -func tcpConnectionSend( - rawConnection: UnsafeMutableRawPointer?, - rawData: UnsafeMutableRawPointer, - dataLength: UInt, - rawWriteAcknowledgement: UnsafeMutableRawPointer? -) { - guard let rawConnection, let rawWriteAcknowledgement else { - handle_sent(0, rawWriteAcknowledgement) - return - } - let tcpConnection = Unmanaged.fromOpaque(rawConnection).takeUnretainedValue() - let data = Data(bytes: rawData, count: Int(dataLength)) - - // The guarantee that all writes are sequential is done by virtue of not returning the execution context - // to Rust before this closure is done executing. - tcpConnection.write(data, completionHandler: { maybeError in - if maybeError != nil { - handle_sent(0, rawWriteAcknowledgement) - } else { - handle_sent(dataLength, rawWriteAcknowledgement) - } - }) -} - -/// Reads data to the in-tunnel TCP connection -/// -/// This FFI function is called by Rust whenever there is data to be read from the in-tunnel TCP connection when exchanging -/// quantum-resistant pre shared keys. -/// -/// Whenever the flow control is given back from the connection, acknowledge that data was read using `rawReadAcknowledgement`. -/// - Parameters: -/// - rawConnection: A raw pointer to the in-tunnel TCP connection -/// - rawReadAcknowledgement: An opaque pointer needed for read acknowledgement -@_cdecl("swift_nw_tcp_connection_read") -func tcpConnectionReceive( - rawConnection: UnsafeMutableRawPointer?, - rawReadAcknowledgement: UnsafeMutableRawPointer? -) { - guard let rawConnection, let rawReadAcknowledgement else { - handle_recv(nil, 0, rawReadAcknowledgement) - return - } - let tcpConnection = Unmanaged.fromOpaque(rawConnection).takeUnretainedValue() - tcpConnection.readMinimumLength(0, maximumLength: Int(UInt16.max)) { data, maybeError in - if let data { - if maybeError != nil { - handle_recv(nil, 0, rawReadAcknowledgement) - } else { - handle_recv(data.map { $0 }, UInt(data.count), rawReadAcknowledgement) - } - } - } -} - /// End sequence of an ephemeral peer exchange. /// /// This FFI function is called by Rust when an ephemeral peer negotiation succeeded or failed. diff --git a/ios/MullvadRustRuntime/include/mullvad_rust_runtime.h b/ios/MullvadRustRuntime/include/mullvad_rust_runtime.h index a45c6ed6c328..2cf3da7e9271 100644 --- a/ios/MullvadRustRuntime/include/mullvad_rust_runtime.h +++ b/ios/MullvadRustRuntime/include/mullvad_rust_runtime.h @@ -20,14 +20,26 @@ typedef uint8_t TunnelObfuscatorProtocol; */ typedef struct EncryptedDnsProxyState EncryptedDnsProxyState; +typedef struct ExchangeCancelToken ExchangeCancelToken; + typedef struct ProxyHandle { void *context; uint16_t port; } ProxyHandle; -typedef struct EphemeralPeerCancelToken { - void *context; -} EphemeralPeerCancelToken; +typedef struct WgTcpConnectionFuncs { + int32_t (*open_fn)(int32_t tunnelHandle, const char *address, uint64_t timeout); + int32_t (*close_fn)(int32_t tunnelHandle, int32_t socketHandle); + int32_t (*recv_fn)(int32_t tunnelHandle, int32_t socketHandle, uint8_t *data, int32_t len); + int32_t (*send_fn)(int32_t tunnelHandle, int32_t socketHandle, const uint8_t *data, int32_t len); +} WgTcpConnectionFuncs; + +typedef struct EphemeralPeerParameters { + uint64_t peer_exchange_timeout; + bool enable_post_quantum; + bool enable_daita; + struct WgTcpConnectionFuncs funcs; +} EphemeralPeerParameters; extern const uint16_t CONFIG_SERVICE_PORT; @@ -84,43 +96,17 @@ int32_t encrypted_dns_proxy_stop(struct ProxyHandle *proxy_config); * `sender` must be pointing to a valid instance of a `EphemeralPeerCancelToken` created by the * `PacketTunnelProvider`. */ -void cancel_ephemeral_peer_exchange(const struct EphemeralPeerCancelToken *sender); +void cancel_ephemeral_peer_exchange(struct ExchangeCancelToken *sender); /** - * Called by the Swift side to signal that the Rust `EphemeralPeerCancelToken` can be safely dropped - * from memory. + * Called by the Swift side to signal that the Rust `EphemeralPeerCancelToken` can be safely + * dropped from memory. * * # Safety * `sender` must be pointing to a valid instance of a `EphemeralPeerCancelToken` created by the * `PacketTunnelProvider`. */ -void drop_ephemeral_peer_exchange_token(const struct EphemeralPeerCancelToken *sender); - -/** - * Called by Swift whenever data has been written to the in-tunnel TCP connection when exchanging - * quantum-resistant pre shared keys, or ephemeral peers. - * - * If `bytes_sent` is 0, this indicates that the connection was closed or that an error occurred. - * - * # Safety - * `sender` must be pointing to a valid instance of a `write_tx` created by the `IosTcpProvider` - * Callback to call when the TCP connection has written data. - */ -void handle_sent(uintptr_t bytes_sent, const void *sender); - -/** - * Called by Swift whenever data has been read from the in-tunnel TCP connection when exchanging - * quantum-resistant pre shared keys, or ephemeral peers. - * - * If `data` is null or empty, this indicates that the connection was closed or that an error - * occurred. An empty buffer is sent to the underlying reader to signal EOF. - * - * # Safety - * `sender` must be pointing to a valid instance of a `read_tx` created by the `IosTcpProvider` - * - * Callback to call when the TCP connection has received data. - */ -void handle_recv(const uint8_t *data, uintptr_t data_len, const void *sender); +void drop_ephemeral_peer_exchange_token(struct ExchangeCancelToken *sender); /** * Entry point for requesting ephemeral peers on iOS. @@ -128,33 +114,15 @@ void handle_recv(const uint8_t *data, uintptr_t data_len, const void *sender); * # Safety * `public_key` and `ephemeral_key` must be valid respective `PublicKey` and `PrivateKey` types. * They will not be valid after this function is called, and thus must be copied here. - * `packet_tunnel` and `tcp_connection` must be valid pointers to a packet tunnel and a TCP - * connection instances. - * `cancel_token` should be owned by the caller of this function. - */ -int32_t request_ephemeral_peer(const uint8_t *public_key, - const uint8_t *ephemeral_key, - const void *packet_tunnel, - const void *tcp_connection, - struct EphemeralPeerCancelToken *cancel_token, - uint64_t peer_exchange_timeout, - bool enable_post_quantum, - bool enable_daita); - -/** - * Called when there is data to send on the TCP connection. - * The TCP connection must write data on the wire, then call the `handle_sent` function. - */ -extern void swift_nw_tcp_connection_send(const void *connection, - const void *data, - uintptr_t data_len, - const void *sender); - -/** - * Called when there is data to read on the TCP connection. - * The TCP connection must read data from the wire, then call the `handle_read` function. - */ -extern void swift_nw_tcp_connection_read(const void *connection, const void *sender); + * `packet_tunnel` must be valid pointers to a packet tunnel, the packet tunnel pointer must + * outlive the ephemeral peer exchange. `cancel_token` should be owned by the caller of this + * function. + */ +struct ExchangeCancelToken *request_ephemeral_peer(const uint8_t *public_key, + const uint8_t *ephemeral_key, + const void *packet_tunnel, + int32_t tunnel_handle, + struct EphemeralPeerParameters peer_parameters); /** * Called when the preshared post quantum key is ready, diff --git a/ios/MullvadRustRuntimeTests/MullvadPostQuantum+Stubs.swift b/ios/MullvadRustRuntimeTests/MullvadPostQuantum+Stubs.swift index 683e1ab8de61..284588624d55 100644 --- a/ios/MullvadRustRuntimeTests/MullvadPostQuantum+Stubs.swift +++ b/ios/MullvadRustRuntimeTests/MullvadPostQuantum+Stubs.swift @@ -27,6 +27,19 @@ class NWTCPConnectionStub: NWTCPConnection { } class TunnelProviderStub: TunnelProvider { + func tunnelHandle() throws -> Int32 { + 0 + } + + func wgFuncs() -> MullvadTypes.WgFuncPointers { + return MullvadTypes.WgFuncPointers( + open: { _, _, _ in return 0 }, + close: { _, _ in return 0 }, + receive: { _, _, _, _ in return 0 }, + send: { _, _, _, _ in return 0 } + ) + } + let tcpConnection: NWTCPConnectionStub init(tcpConnection: NWTCPConnectionStub) { @@ -55,15 +68,13 @@ class FailedNegotiatorStub: EphemeralPeerNegotiating { } func startNegotiation( - gatewayIP: IPv4Address, devicePublicKey: WireGuardKitTypes.PublicKey, presharedKey: WireGuardKitTypes.PrivateKey, - peerReceiver packetTunnel: any MullvadTypes.TunnelProvider, - tcpConnection: NWTCPConnection, - peerExchangeTimeout: MullvadTypes.Duration, - enablePostQuantum: Bool, - enableDaita: Bool - ) -> Bool { false } + peerReceiver: any MullvadTypes.TunnelProvider, + ephemeralPeerParams: EphemeralPeerParameters + ) -> Bool { + false + } func cancelKeyNegotiation() { onCancelKeyNegotiation?() @@ -81,15 +92,13 @@ class SuccessfulNegotiatorStub: EphemeralPeerNegotiating { } func startNegotiation( - gatewayIP: IPv4Address, devicePublicKey: WireGuardKitTypes.PublicKey, presharedKey: WireGuardKitTypes.PrivateKey, - peerReceiver packetTunnel: any MullvadTypes.TunnelProvider, - tcpConnection: NWTCPConnection, - peerExchangeTimeout: MullvadTypes.Duration, - enablePostQuantum: Bool, - enableDaita: Bool - ) -> Bool { true } + peerReceiver: any MullvadTypes.TunnelProvider, + ephemeralPeerParams: EphemeralPeerParameters + ) -> Bool { + true + } func cancelKeyNegotiation() { onCancelKeyNegotiation?() diff --git a/ios/MullvadTypes/Promise.swift b/ios/MullvadTypes/Promise.swift index 886f00f63354..48a12f818230 100644 --- a/ios/MullvadTypes/Promise.swift +++ b/ios/MullvadTypes/Promise.swift @@ -47,3 +47,19 @@ public final class Promise { } } } + + +public struct OneshotChannel { + private let semaphore = DispatchSemaphore(value: 0) + + public init() { + } + + public mutating func send() { + semaphore.signal() + } + + public func receive() { + semaphore.wait() + } +} diff --git a/ios/MullvadTypes/Protocols/EphemeralPeerReceiver.swift b/ios/MullvadTypes/Protocols/EphemeralPeerReceiver.swift index e5fc68f68a47..d56693eb8265 100644 --- a/ios/MullvadTypes/Protocols/EphemeralPeerReceiver.swift +++ b/ios/MullvadTypes/Protocols/EphemeralPeerReceiver.swift @@ -11,10 +11,20 @@ import NetworkExtension import WireGuardKitTypes public class EphemeralPeerReceiver: EphemeralPeerReceiving, TunnelProvider { - unowned let tunnelProvider: NEPacketTunnelProvider + public func tunnelHandle() throws -> Int32 { + try tunnelProvider.tunnelHandle() + } + + public func wgFuncs() -> WgFuncPointers { + tunnelProvider.wgFuncs() + } + + unowned let tunnelProvider: any TunnelProvider + let keyReceiver: any EphemeralPeerReceiving - public init(tunnelProvider: NEPacketTunnelProvider) { + public init(tunnelProvider: TunnelProvider, keyReceiver: any EphemeralPeerReceiving) { self.tunnelProvider = tunnelProvider + self.keyReceiver = keyReceiver } // MARK: - EphemeralPeerReceiving @@ -30,23 +40,6 @@ public class EphemeralPeerReceiver: EphemeralPeerReceiving, TunnelProvider { } public func ephemeralPeerExchangeFailed() { - guard let receiver = tunnelProvider as? EphemeralPeerReceiving else { return } - receiver.ephemeralPeerExchangeFailed() - } - - // MARK: - TunnelProvider - - public func createTCPConnectionThroughTunnel( - to remoteEndpoint: NWEndpoint, - enableTLS: Bool, - tlsParameters TLSParameters: NWTLSParameters?, - delegate: Any? - ) -> NWTCPConnection { - tunnelProvider.createTCPConnectionThroughTunnel( - to: remoteEndpoint, - enableTLS: enableTLS, - tlsParameters: TLSParameters, - delegate: delegate - ) + keyReceiver.ephemeralPeerExchangeFailed() } } diff --git a/ios/MullvadTypes/Protocols/TunnelProvider.swift b/ios/MullvadTypes/Protocols/TunnelProvider.swift index 61aa99ba7da8..78e312fa33d4 100644 --- a/ios/MullvadTypes/Protocols/TunnelProvider.swift +++ b/ios/MullvadTypes/Protocols/TunnelProvider.swift @@ -10,12 +10,25 @@ import Foundation import NetworkExtension public protocol TunnelProvider: AnyObject { - func createTCPConnectionThroughTunnel( - to remoteEndpoint: NWEndpoint, - enableTLS: Bool, - tlsParameters TLSParameters: NWTLSParameters?, - delegate: Any? - ) -> NWTCPConnection + func tunnelHandle() throws -> Int32 + func wgFuncs() -> WgFuncPointers } -extension NEPacketTunnelProvider: TunnelProvider {} +public typealias TcpOpenFunc = @convention(c) (Int32, UnsafePointer?, UInt64) -> Int32 +public typealias TcpCloseFunc = @convention(c) (Int32, Int32) -> Int32 +public typealias TcpSendFunc = @convention(c) (Int32, Int32, UnsafePointer?, Int32) -> Int32 +public typealias TcpRecvFunc = @convention(c) (Int32, Int32, UnsafeMutablePointer?, Int32) -> Int32 + +public struct WgFuncPointers { + public let open: TcpOpenFunc + public let close: TcpCloseFunc + public let receive: TcpRecvFunc + public let send: TcpSendFunc + + public init(open: TcpOpenFunc, close: TcpCloseFunc, receive: TcpRecvFunc, send: TcpSendFunc) { + self.open = open + self.close = close + self.receive = receive + self.send = send + } +} diff --git a/ios/PacketTunnel/PacketTunnelProvider/PacketTunnelProvider.swift b/ios/PacketTunnel/PacketTunnelProvider/PacketTunnelProvider.swift index 84ff437b2010..3216773eb56b 100644 --- a/ios/PacketTunnel/PacketTunnelProvider/PacketTunnelProvider.swift +++ b/ios/PacketTunnel/PacketTunnelProvider/PacketTunnelProvider.swift @@ -34,7 +34,7 @@ class PacketTunnelProvider: NEPacketTunnelProvider { private let tunnelSettingsListener = TunnelSettingsListener() private lazy var ephemeralPeerReceiver = { - EphemeralPeerReceiver(tunnelProvider: self) + EphemeralPeerReceiver(tunnelProvider: adapter, keyReceiver: self) }() // swiftlint:disable:next function_body_length @@ -110,7 +110,9 @@ class PacketTunnelProvider: NEPacketTunnelProvider { iteratorProvider: { REST.RetryStrategy.postQuantumKeyExchange.makeDelayIterator() } ), onUpdateConfiguration: { [unowned self] configuration in - actor.changeEphemeralPeerNegotiationState(configuration: configuration) + let channel = OneshotChannel() + actor.changeEphemeralPeerNegotiationState(configuration: configuration, reconfigurationSemaphore: channel) + channel.receive() }, onFinish: { [unowned self] in actor.notifyEphemeralPeerNegotiated() } diff --git a/ios/PacketTunnel/PostQuantum/MultiHopEphemeralPeerExchanger.swift b/ios/PacketTunnel/PostQuantum/MultiHopEphemeralPeerExchanger.swift index ee1a3afe6c42..0e2bc97778a1 100644 --- a/ios/PacketTunnel/PostQuantum/MultiHopEphemeralPeerExchanger.swift +++ b/ios/PacketTunnel/PostQuantum/MultiHopEphemeralPeerExchanger.swift @@ -19,7 +19,7 @@ final class MultiHopEphemeralPeerExchanger: EphemeralPeerExchangingProtocol { let keyExchanger: EphemeralPeerExchangeActorProtocol let devicePrivateKey: PrivateKey let onFinish: () -> Void - let onUpdateConfiguration: (EphemeralPeerNegotiationState) -> Void + let onUpdateConfiguration: (EphemeralPeerNegotiationState) async -> Void let enablePostQuantum: Bool let enableDaita: Bool @@ -48,7 +48,7 @@ final class MultiHopEphemeralPeerExchanger: EphemeralPeerExchangingProtocol { keyExchanger: EphemeralPeerExchangeActorProtocol, enablePostQuantum: Bool, enableDaita: Bool, - onUpdateConfiguration: @escaping (EphemeralPeerNegotiationState) -> Void, + onUpdateConfiguration: @escaping (EphemeralPeerNegotiationState) async -> Void, onFinish: @escaping () -> Void ) { self.entry = entry @@ -61,37 +61,37 @@ final class MultiHopEphemeralPeerExchanger: EphemeralPeerExchangingProtocol { self.onFinish = onFinish } - func start() { + func start() async { guard state == .initial else { return } - negotiateWithEntry() + await negotiateWithEntry() } - public func receiveEphemeralPeerPrivateKey(_ ephemeralPeerPrivateKey: PrivateKey) { + public func receiveEphemeralPeerPrivateKey(_ ephemeralPeerPrivateKey: PrivateKey) async { if state == .negotiatingWithEntry { entryPeerKey = EphemeralPeerKey(ephemeralKey: ephemeralPeerPrivateKey) - negotiateBetweenEntryAndExit() + await negotiateBetweenEntryAndExit() } else if state == .negotiatingBetweenEntryAndExit { exitPeerKey = EphemeralPeerKey(ephemeralKey: ephemeralPeerPrivateKey) - makeConnection() + await makeConnection() } } func receivePostQuantumKey( _ preSharedKey: PreSharedKey, ephemeralKey: PrivateKey - ) { + ) async { if state == .negotiatingWithEntry { entryPeerKey = EphemeralPeerKey(preSharedKey: preSharedKey, ephemeralKey: ephemeralKey) - negotiateBetweenEntryAndExit() + await negotiateBetweenEntryAndExit() } else if state == .negotiatingBetweenEntryAndExit { exitPeerKey = EphemeralPeerKey(preSharedKey: preSharedKey, ephemeralKey: ephemeralKey) - makeConnection() + await makeConnection() } } - private func negotiateWithEntry() { + private func negotiateWithEntry() async { state = .negotiatingWithEntry - onUpdateConfiguration(.single(EphemeralPeerRelayConfiguration( + await onUpdateConfiguration(.single(EphemeralPeerRelayConfiguration( relay: entry, configuration: EphemeralPeerConfiguration( privateKey: devicePrivateKey, @@ -105,9 +105,9 @@ final class MultiHopEphemeralPeerExchanger: EphemeralPeerExchangingProtocol { ) } - private func negotiateBetweenEntryAndExit() { + private func negotiateBetweenEntryAndExit() async { state = .negotiatingBetweenEntryAndExit - onUpdateConfiguration(.multi( + await onUpdateConfiguration(.multi( entry: EphemeralPeerRelayConfiguration( relay: entry, configuration: EphemeralPeerConfiguration( @@ -132,9 +132,9 @@ final class MultiHopEphemeralPeerExchanger: EphemeralPeerExchangingProtocol { ) } - private func makeConnection() { + private func makeConnection() async { state = .makeConnection - onUpdateConfiguration(.multi( + await onUpdateConfiguration(.multi( entry: EphemeralPeerRelayConfiguration( relay: entry, configuration: EphemeralPeerConfiguration( diff --git a/ios/PacketTunnel/WireGuardAdapter/WgAdapter.swift b/ios/PacketTunnel/WireGuardAdapter/WgAdapter.swift index 4bfd9b809103..f02a64d18e1d 100644 --- a/ios/PacketTunnel/WireGuardAdapter/WgAdapter.swift +++ b/ios/PacketTunnel/WireGuardAdapter/WgAdapter.swift @@ -13,7 +13,7 @@ import NetworkExtension import PacketTunnelCore import WireGuardKit -struct WgAdapter: TunnelAdapterProtocol { +class WgAdapter: TunnelAdapterProtocol { let logger = Logger(label: "WgAdapter") let adapter: WireGuardAdapter @@ -212,3 +212,18 @@ private extension WgStats { return UInt64(value) } + +extension WgAdapter: TunnelProvider { + public func tunnelHandle() throws -> Int32 { + return try self.adapter.tunnelHandle() + } + + public func wgFuncs() -> WgFuncPointers { + WgFuncPointers( + open: adapter.inTunnelTcpOpen, + close: adapter.inTunnelTcpClose, + receive: adapter.inTunnelTcpRecv, + send: adapter.inTunnelTcpSend + ) + } +} diff --git a/ios/PacketTunnelCore/Actor/PacketTunnelActor+Public.swift b/ios/PacketTunnelCore/Actor/PacketTunnelActor+Public.swift index 05b69deb35d2..9a9fb531c3ae 100644 --- a/ios/PacketTunnelCore/Actor/PacketTunnelActor+Public.swift +++ b/ios/PacketTunnelCore/Actor/PacketTunnelActor+Public.swift @@ -7,6 +7,7 @@ // import Foundation +import MullvadTypes import WireGuardKitTypes /** @@ -64,8 +65,11 @@ extension PacketTunnelActor { - Parameter key: the new key */ - nonisolated public func changeEphemeralPeerNegotiationState(configuration: EphemeralPeerNegotiationState) { - eventChannel.send(.ephemeralPeerNegotiationStateChanged(configuration)) + nonisolated public func changeEphemeralPeerNegotiationState( + configuration: EphemeralPeerNegotiationState, + reconfigurationSemaphore: OneshotChannel + ) { + eventChannel.send(.ephemeralPeerNegotiationStateChanged(configuration, reconfigurationSemaphore)) } /** diff --git a/ios/PacketTunnelCore/Actor/PacketTunnelActor.swift b/ios/PacketTunnelCore/Actor/PacketTunnelActor.swift index 74f8a5a0e262..1a81d99a74bc 100644 --- a/ios/PacketTunnelCore/Actor/PacketTunnelActor.swift +++ b/ios/PacketTunnelCore/Actor/PacketTunnelActor.swift @@ -48,6 +48,7 @@ public actor PacketTunnelActor { public let relaySelector: RelaySelectorProtocol let settingsReader: SettingsReaderProtocol let protocolObfuscator: ProtocolObfuscation +// private var ephemeralPeerExchangingPipeline: EphemeralPeerExchangingPipeline! nonisolated let eventChannel = EventChannel() @@ -139,13 +140,14 @@ public actor PacketTunnelActor { case let .cacheActiveKey(lastKeyRotation): cacheActiveKey(lastKeyRotation: lastKeyRotation) - case let .reconfigureForEphemeralPeer(configuration): + case let .reconfigureForEphemeralPeer(configuration, configurationSemaphore): do { try await updateEphemeralPeerNegotiationState(configuration: configuration) } catch { logger.error(error: error, message: "Failed to reconfigure tunnel after each hop negotiation.") await setErrorStateInternal(with: error) } + configurationSemaphore.send() case .connectWithEphemeralPeer: await connectWithEphemeralPeer() case .setDisconnectedState: diff --git a/ios/PacketTunnelCore/Actor/PacketTunnelActorCommand.swift b/ios/PacketTunnelCore/Actor/PacketTunnelActorCommand.swift index b677986d041b..3e7ab37e540b 100644 --- a/ios/PacketTunnelCore/Actor/PacketTunnelActorCommand.swift +++ b/ios/PacketTunnelCore/Actor/PacketTunnelActorCommand.swift @@ -7,6 +7,7 @@ // import Foundation +import MullvadTypes import WireGuardKitTypes extension PacketTunnelActor { @@ -37,7 +38,7 @@ extension PacketTunnelActor { case networkReachability(NetworkPath) /// Update the device private key, as per post-quantum protocols - case ephemeralPeerNegotiationStateChanged(EphemeralPeerNegotiationState) + case ephemeralPeerNegotiationStateChanged(EphemeralPeerNegotiationState, OneshotChannel) /// Notify that an ephemeral peer exchanging took place case notifyEphemeralPeerNegotiated diff --git a/ios/PacketTunnelCore/Actor/PacketTunnelActorReducer.swift b/ios/PacketTunnelCore/Actor/PacketTunnelActorReducer.swift index 3382baf2092a..5b01778f2d7a 100644 --- a/ios/PacketTunnelCore/Actor/PacketTunnelActorReducer.swift +++ b/ios/PacketTunnelCore/Actor/PacketTunnelActorReducer.swift @@ -7,6 +7,7 @@ // import Foundation +import MullvadTypes import WireGuardKitTypes extension PacketTunnelActor { @@ -25,7 +26,7 @@ extension PacketTunnelActor { case stopTunnelAdapter case configureForErrorState(BlockedStateReason) case cacheActiveKey(Date?) - case reconfigureForEphemeralPeer(EphemeralPeerNegotiationState) + case reconfigureForEphemeralPeer(EphemeralPeerNegotiationState, OneshotChannel) case connectWithEphemeralPeer // acknowledge that the disconnection process has concluded, go to .disconnected. @@ -45,7 +46,7 @@ extension PacketTunnelActor { case (.stopTunnelAdapter, .stopTunnelAdapter): true case let (.configureForErrorState(r0), .configureForErrorState(r1)): r0 == r1 case let (.cacheActiveKey(d0), .cacheActiveKey(d1)): d0 == d1 - case let (.reconfigureForEphemeralPeer(eph0), .reconfigureForEphemeralPeer(eph1)): eph0 == eph1 + case let (.reconfigureForEphemeralPeer(eph0, _), .reconfigureForEphemeralPeer(eph1, _)): eph0 == eph1 case (.connectWithEphemeralPeer, .connectWithEphemeralPeer): true case (.setDisconnectedState, .setDisconnectedState): true default: false @@ -89,8 +90,8 @@ extension PacketTunnelActor { state.mutateAssociatedData { $0.networkReachability = newReachability } return [.updateTunnelMonitorPath(defaultPath)] - case let .ephemeralPeerNegotiationStateChanged(configuration): - return [.reconfigureForEphemeralPeer(configuration)] + case let .ephemeralPeerNegotiationStateChanged(configuration, reconfigurationSemaphore): + return [.reconfigureForEphemeralPeer(configuration, reconfigurationSemaphore)] case .notifyEphemeralPeerNegotiated: return [.connectWithEphemeralPeer] diff --git a/ios/PacketTunnelCoreTests/EphemeralPeerExchangingPipelineTests.swift b/ios/PacketTunnelCoreTests/EphemeralPeerExchangingPipelineTests.swift index 2af86eedfe42..bbce8c8b44f7 100644 --- a/ios/PacketTunnelCoreTests/EphemeralPeerExchangingPipelineTests.swift +++ b/ios/PacketTunnelCoreTests/EphemeralPeerExchangingPipelineTests.swift @@ -60,7 +60,7 @@ final class EphemeralPeerExchangingPipelineTests: XCTestCase { ) } - func testSingleHopPostQuantumKeyExchange() throws { + func testSingleHopPostQuantumKeyExchange() async throws { let reconfigurationExpectation = expectation(description: "Tunnel reconfiguration took place") reconfigurationExpectation.expectedFulfillmentCount = 2 @@ -78,11 +78,11 @@ final class EphemeralPeerExchangingPipelineTests: XCTestCase { } keyExchangeActor.delegate = KeyExchangingResultStub(onReceivePostQuantumKey: { preSharedKey, privateKey in - postQuantumKeyExchangingPipeline.receivePostQuantumKey(preSharedKey, ephemeralKey: privateKey) + await postQuantumKeyExchangingPipeline.receivePostQuantumKey(preSharedKey, ephemeralKey: privateKey) }) let connectionState = stubConnectionState(enableMultiHop: false, enablePostQuantum: true, enableDaita: false) - postQuantumKeyExchangingPipeline.startNegotiation(connectionState, privateKey: PrivateKey()) + await postQuantumKeyExchangingPipeline.startNegotiation(connectionState, privateKey: PrivateKey()) wait( for: [reconfigurationExpectation, negotiationSuccessful], @@ -90,7 +90,7 @@ final class EphemeralPeerExchangingPipelineTests: XCTestCase { ) } - func testSingleHopDaitaPeerExchange() throws { + func testSingleHopDaitaPeerExchange() async throws { let reconfigurationExpectation = expectation(description: "Tunnel reconfiguration took place") reconfigurationExpectation.expectedFulfillmentCount = 2 @@ -108,11 +108,11 @@ final class EphemeralPeerExchangingPipelineTests: XCTestCase { } keyExchangeActor.delegate = KeyExchangingResultStub(onReceiveEphemeralPeerPrivateKey: { privateKey in - postQuantumKeyExchangingPipeline.receiveEphemeralPeerPrivateKey(privateKey) + await postQuantumKeyExchangingPipeline.receiveEphemeralPeerPrivateKey(privateKey) }) let connectionState = stubConnectionState(enableMultiHop: false, enablePostQuantum: false, enableDaita: true) - postQuantumKeyExchangingPipeline.startNegotiation(connectionState, privateKey: PrivateKey()) + await postQuantumKeyExchangingPipeline.startNegotiation(connectionState, privateKey: PrivateKey()) wait( for: [reconfigurationExpectation, negotiationSuccessful], @@ -120,7 +120,7 @@ final class EphemeralPeerExchangingPipelineTests: XCTestCase { ) } - func testMultiHopPostQuantumKeyExchange() throws { + func testMultiHopPostQuantumKeyExchange() async throws { let reconfigurationExpectation = expectation(description: "Tunnel reconfiguration took place") reconfigurationExpectation.expectedFulfillmentCount = 3 @@ -138,11 +138,11 @@ final class EphemeralPeerExchangingPipelineTests: XCTestCase { } keyExchangeActor.delegate = KeyExchangingResultStub(onReceivePostQuantumKey: { preSharedKey, privateKey in - postQuantumKeyExchangingPipeline.receivePostQuantumKey(preSharedKey, ephemeralKey: privateKey) + await postQuantumKeyExchangingPipeline.receivePostQuantumKey(preSharedKey, ephemeralKey: privateKey) }) let connectionState = stubConnectionState(enableMultiHop: true, enablePostQuantum: true, enableDaita: false) - postQuantumKeyExchangingPipeline.startNegotiation(connectionState, privateKey: PrivateKey()) + await postQuantumKeyExchangingPipeline.startNegotiation(connectionState, privateKey: PrivateKey()) wait( for: [reconfigurationExpectation, negotiationSuccessful], @@ -150,7 +150,7 @@ final class EphemeralPeerExchangingPipelineTests: XCTestCase { ) } - func testMultiHopDaitaExchange() throws { + func testMultiHopDaitaExchange() async throws { let reconfigurationExpectation = expectation(description: "Tunnel reconfiguration took place") reconfigurationExpectation.expectedFulfillmentCount = 3 @@ -168,11 +168,11 @@ final class EphemeralPeerExchangingPipelineTests: XCTestCase { } keyExchangeActor.delegate = KeyExchangingResultStub(onReceiveEphemeralPeerPrivateKey: { privateKey in - postQuantumKeyExchangingPipeline.receiveEphemeralPeerPrivateKey(privateKey) + await postQuantumKeyExchangingPipeline.receiveEphemeralPeerPrivateKey(privateKey) }) let connectionState = stubConnectionState(enableMultiHop: true, enablePostQuantum: false, enableDaita: true) - postQuantumKeyExchangingPipeline.startNegotiation(connectionState, privateKey: PrivateKey()) + await postQuantumKeyExchangingPipeline.startNegotiation(connectionState, privateKey: PrivateKey()) wait( for: [reconfigurationExpectation, negotiationSuccessful], diff --git a/ios/PacketTunnelCoreTests/Mocks/EphemeralPeerExchangeActorStub.swift b/ios/PacketTunnelCoreTests/Mocks/EphemeralPeerExchangeActorStub.swift index 7f17af56bec8..2f217488200e 100644 --- a/ios/PacketTunnelCoreTests/Mocks/EphemeralPeerExchangeActorStub.swift +++ b/ios/PacketTunnelCoreTests/Mocks/EphemeralPeerExchangeActorStub.swift @@ -22,9 +22,9 @@ final class EphemeralPeerExchangeActorStub: EphemeralPeerExchangeActorProtocol { switch result { case let .success((preSharedKey, ephemeralKey)): if enablePostQuantum { - delegate?.receivePostQuantumKey(preSharedKey, ephemeralKey: ephemeralKey) + Task { await delegate?.receivePostQuantumKey(preSharedKey, ephemeralKey: ephemeralKey) } } else { - delegate?.receiveEphemeralPeerPrivateKey(ephemeralKey) + Task { await delegate?.receiveEphemeralPeerPrivateKey(ephemeralKey) } } case .failure: delegate?.ephemeralPeerExchangeFailed() diff --git a/ios/PacketTunnelCoreTests/Mocks/KeyExchangingResultStub.swift b/ios/PacketTunnelCoreTests/Mocks/KeyExchangingResultStub.swift index 9dc9dca58c82..250524ec0677 100644 --- a/ios/PacketTunnelCoreTests/Mocks/KeyExchangingResultStub.swift +++ b/ios/PacketTunnelCoreTests/Mocks/KeyExchangingResultStub.swift @@ -12,15 +12,15 @@ struct KeyExchangingResultStub: EphemeralPeerReceiving { var onFailure: (() -> Void)? - var onReceivePostQuantumKey: ((PreSharedKey, PrivateKey) -> Void)? - var onReceiveEphemeralPeerPrivateKey: ((PrivateKey) -> Void)? + var onReceivePostQuantumKey: ((PreSharedKey, PrivateKey) async -> Void)? + var onReceiveEphemeralPeerPrivateKey: ((PrivateKey) async -> Void)? - func receivePostQuantumKey(_ key: PreSharedKey, ephemeralKey: PrivateKey) { - onReceivePostQuantumKey?(key, ephemeralKey) + func receivePostQuantumKey(_ key: PreSharedKey, ephemeralKey: PrivateKey) async { + await onReceivePostQuantumKey?(key, ephemeralKey) } - public func receiveEphemeralPeerPrivateKey(_ ephemeralPeerPrivateKey: PrivateKey) { - onReceiveEphemeralPeerPrivateKey?(ephemeralPeerPrivateKey) + public func receiveEphemeralPeerPrivateKey(_ ephemeralPeerPrivateKey: PrivateKey) async { + await onReceiveEphemeralPeerPrivateKey?(ephemeralPeerPrivateKey) } func ephemeralPeerExchangeFailed() { diff --git a/ios/PacketTunnelCoreTests/MultiHopEphemeralPeerExchangerTests.swift b/ios/PacketTunnelCoreTests/MultiHopEphemeralPeerExchangerTests.swift index a4c1d09155ee..c55f5b4d6566 100644 --- a/ios/PacketTunnelCoreTests/MultiHopEphemeralPeerExchangerTests.swift +++ b/ios/PacketTunnelCoreTests/MultiHopEphemeralPeerExchangerTests.swift @@ -59,7 +59,7 @@ final class MultiHopEphemeralPeerExchangerTests: XCTestCase { ) } - func testEphemeralPeerExchangeFailsWhenNegotiationCannotStart() { + func testEphemeralPeerExchangeFailsWhenNegotiationCannotStart() async { let expectedNegotiationFailure = expectation(description: "Negotiation failed.") let reconfigurationExpectation = expectation(description: "Tunnel reconfiguration took place") @@ -88,7 +88,7 @@ final class MultiHopEphemeralPeerExchangerTests: XCTestCase { expectedNegotiationFailure.fulfill() } - multiHopExchanger.start() + await multiHopExchanger.start() wait( for: [expectedNegotiationFailure, reconfigurationExpectation, negotiationSuccessful], @@ -96,7 +96,7 @@ final class MultiHopEphemeralPeerExchangerTests: XCTestCase { ) } - func testEphemeralPeerExchangeSuccessWhenPostQuantumNegotiationStarts() throws { + func testEphemeralPeerExchangeSuccessWhenPostQuantumNegotiationStarts() async throws { let unexpectedNegotiationFailure = expectation(description: "Negotiation failed.") unexpectedNegotiationFailure.isInverted = true @@ -124,9 +124,9 @@ final class MultiHopEphemeralPeerExchangerTests: XCTestCase { } peerExchangeActor.delegate = KeyExchangingResultStub(onReceivePostQuantumKey: { preSharedKey, ephemeralKey in - multiHopPeerExchanger.receivePostQuantumKey(preSharedKey, ephemeralKey: ephemeralKey) + await multiHopPeerExchanger.receivePostQuantumKey(preSharedKey, ephemeralKey: ephemeralKey) }) - multiHopPeerExchanger.start() + await multiHopPeerExchanger.start() wait( for: [unexpectedNegotiationFailure, reconfigurationExpectation, negotiationSuccessful], @@ -134,7 +134,7 @@ final class MultiHopEphemeralPeerExchangerTests: XCTestCase { ) } - func testEphemeralPeerExchangeSuccessWhenDaitaNegotiationStarts() throws { + func testEphemeralPeerExchangeSuccessWhenDaitaNegotiationStarts() async throws { let unexpectedNegotiationFailure = expectation(description: "Negotiation failed.") unexpectedNegotiationFailure.isInverted = true @@ -162,9 +162,9 @@ final class MultiHopEphemeralPeerExchangerTests: XCTestCase { } peerExchangeActor.delegate = KeyExchangingResultStub(onReceiveEphemeralPeerPrivateKey: { ephemeralKey in - multiHopPeerExchanger.receiveEphemeralPeerPrivateKey(ephemeralKey) + await multiHopPeerExchanger.receiveEphemeralPeerPrivateKey(ephemeralKey) }) - multiHopPeerExchanger.start() + await multiHopPeerExchanger.start() wait( for: [unexpectedNegotiationFailure, reconfigurationExpectation, negotiationSuccessful], diff --git a/ios/PacketTunnelCoreTests/SingleHopEphemeralPeerExchangerTests.swift b/ios/PacketTunnelCoreTests/SingleHopEphemeralPeerExchangerTests.swift index 2ce3558fbac6..deb402abeb4a 100644 --- a/ios/PacketTunnelCoreTests/SingleHopEphemeralPeerExchangerTests.swift +++ b/ios/PacketTunnelCoreTests/SingleHopEphemeralPeerExchangerTests.swift @@ -38,7 +38,7 @@ final class SingleHopEphemeralPeerExchangerTests: XCTestCase { exitRelay = SelectedRelay(endpoint: match.endpoint, hostname: match.relay.hostname, location: match.location) } - func testEphemeralPeerExchangeFailsWhenNegotiationCannotStart() { + func testEphemeralPeerExchangeFailsWhenNegotiationCannotStart() async { let expectedNegotiationFailure = expectation(description: "Negotiation failed.") let reconfigurationExpectation = expectation(description: "Tunnel reconfiguration took place") @@ -66,7 +66,7 @@ final class SingleHopEphemeralPeerExchangerTests: XCTestCase { expectedNegotiationFailure.fulfill() } - singleHopPostQuantumKeyExchanging.start() + await singleHopPostQuantumKeyExchanging.start() wait( for: [expectedNegotiationFailure, reconfigurationExpectation, negotiationSuccessful], @@ -74,7 +74,7 @@ final class SingleHopEphemeralPeerExchangerTests: XCTestCase { ) } - func testEphemeralPeerExchangeSuccessWhenPostQuantumNegotiationStarts() throws { + func testEphemeralPeerExchangeSuccessWhenPostQuantumNegotiationStarts() async throws { let unexpectedNegotiationFailure = expectation(description: "Negotiation failed.") unexpectedNegotiationFailure.isInverted = true @@ -101,9 +101,9 @@ final class SingleHopEphemeralPeerExchangerTests: XCTestCase { } keyExchangeActor.delegate = KeyExchangingResultStub(onReceivePostQuantumKey: { preSharedKey, ephemeralKey in - singleHopPostQuantumKeyExchanging.receivePostQuantumKey(preSharedKey, ephemeralKey: ephemeralKey) + await singleHopPostQuantumKeyExchanging.receivePostQuantumKey(preSharedKey, ephemeralKey: ephemeralKey) }) - singleHopPostQuantumKeyExchanging.start() + await singleHopPostQuantumKeyExchanging.start() wait( for: [unexpectedNegotiationFailure, reconfigurationExpectation, negotiationSuccessful], @@ -111,7 +111,7 @@ final class SingleHopEphemeralPeerExchangerTests: XCTestCase { ) } - func testEphemeralPeerExchangeSuccessWhenDaitaNegotiationStarts() throws { + func testEphemeralPeerExchangeSuccessWhenDaitaNegotiationStarts() async throws { let unexpectedNegotiationFailure = expectation(description: "Negotiation failed.") unexpectedNegotiationFailure.isInverted = true @@ -138,9 +138,9 @@ final class SingleHopEphemeralPeerExchangerTests: XCTestCase { } peerExchangeActor.delegate = KeyExchangingResultStub(onReceiveEphemeralPeerPrivateKey: { ephemeralKey in - multiHopPeerExchanger.receiveEphemeralPeerPrivateKey(ephemeralKey) + await multiHopPeerExchanger.receiveEphemeralPeerPrivateKey(ephemeralKey) }) - multiHopPeerExchanger.start() + await multiHopPeerExchanger.start() wait( for: [unexpectedNegotiationFailure, reconfigurationExpectation, negotiationSuccessful], diff --git a/mullvad-ios/src/encrypted_dns_proxy.rs b/mullvad-ios/src/encrypted_dns_proxy.rs index 2aa83d833dc0..f23482f3550c 100644 --- a/mullvad-ios/src/encrypted_dns_proxy.rs +++ b/mullvad-ios/src/encrypted_dns_proxy.rs @@ -1,8 +1,10 @@ use crate::ProxyHandle; use libc::c_char; -use mullvad_encrypted_dns_proxy::state::{EncryptedDnsProxyState as State, FetchConfigError}; -use mullvad_encrypted_dns_proxy::Forwarder; +use mullvad_encrypted_dns_proxy::{ + state::{EncryptedDnsProxyState as State, FetchConfigError}, + Forwarder, +}; use std::{ io, mem, net::{Ipv4Addr, SocketAddr}, diff --git a/mullvad-ios/src/ephemeral_peer_proxy/ios_runtime.rs b/mullvad-ios/src/ephemeral_peer_proxy/ios_runtime.rs deleted file mode 100644 index 19107689ab29..000000000000 --- a/mullvad-ios/src/ephemeral_peer_proxy/ios_runtime.rs +++ /dev/null @@ -1,188 +0,0 @@ -use super::{ - ios_tcp_connection::*, EphemeralPeerCancelToken, EphemeralPeerParameters, PacketTunnelBridge, -}; -use libc::c_void; -use std::{ - io, ptr, - sync::{Arc, Mutex}, -}; -use talpid_tunnel_config_client::{request_ephemeral_peer_with, Error, RelayConfigService}; -use talpid_types::net::wireguard::{PrivateKey, PublicKey}; -use tokio::runtime::Handle as TokioHandle; -use tonic::transport::channel::Endpoint; -use tower::util::service_fn; - -/// # Safety -/// packet_tunnel and tcp_connection must be valid pointers to a packet tunnel and a TCP connection -/// instances. -pub unsafe fn run_ephemeral_peer_exchange( - pub_key: [u8; 32], - ephemeral_key: [u8; 32], - packet_tunnel_bridge: PacketTunnelBridge, - peer_parameters: EphemeralPeerParameters, - tokio_handle: TokioHandle, -) -> Result { - match unsafe { - IOSRuntime::new( - pub_key, - ephemeral_key, - packet_tunnel_bridge, - peer_parameters, - ) - } { - Ok(runtime) => { - let token = runtime.packet_tunnel.tcp_connection.clone(); - runtime.run(tokio_handle); - Ok(EphemeralPeerCancelToken { - context: Arc::into_raw(token) as *mut _, - }) - } - Err(err) => { - log::error!("Failed to create runtime {}", err); - Err(Error::UnableToCreateRuntime) - } - } -} - -#[derive(Clone)] -pub struct SwiftContext { - pub packet_tunnel: *const c_void, - pub tcp_connection: Arc>, -} - -unsafe impl Send for SwiftContext {} -unsafe impl Sync for SwiftContext {} - -struct IOSRuntime { - pub_key: [u8; 32], - ephemeral_key: [u8; 32], - packet_tunnel: SwiftContext, - peer_parameters: EphemeralPeerParameters, -} - -impl IOSRuntime { - pub unsafe fn new( - pub_key: [u8; 32], - ephemeral_key: [u8; 32], - packet_tunnel_bridge: PacketTunnelBridge, - peer_parameters: EphemeralPeerParameters, - ) -> io::Result { - let context = SwiftContext { - packet_tunnel: packet_tunnel_bridge.packet_tunnel, - tcp_connection: Arc::new(Mutex::new(ConnectionContext::new( - packet_tunnel_bridge.tcp_connection, - ))), - }; - - Ok(Self { - pub_key, - ephemeral_key, - packet_tunnel: context, - peer_parameters, - }) - } - - pub fn run(self, handle: TokioHandle) { - handle.spawn(async move { - self.run_service_inner().await; - }); - } - /// Creates a `RelayConfigService` using the in-tunnel TCP Connection provided by the Packet - /// Tunnel Provider - /// - /// ## Safety - /// It is unsafe to call this with an already used `SwiftContext` - async unsafe fn ios_tcp_client( - ctx: SwiftContext, - ) -> Result<(RelayConfigService, IosTcpShutdownHandle), Error> { - let endpoint = Endpoint::from_static("tcp://0.0.0.0:0"); - - let (tcp_provider, conn_handle) = unsafe { IosTcpProvider::new(ctx.tcp_connection) }; - // One (1) TCP connection - let mut one_tcp_connection = Some(tcp_provider); - let conn = endpoint - .connect_with_connector(service_fn(move |_| { - let connection = one_tcp_connection - .take() - .map(hyper_util::rt::tokio::TokioIo::new) - .ok_or(Error::TcpConnectionExpired); - async { connection } - })) - .await - .map_err(Error::GrpcConnectError)?; - - Ok((RelayConfigService::new(conn), conn_handle)) - } - - async fn run_service_inner(self) { - let (async_provider, shutdown_handle) = unsafe { - match Self::ios_tcp_client(self.packet_tunnel.clone()).await { - Ok(result) => result, - Err(error) => { - log::error!("Failed to create iOS TCP client: {error}"); - swift_ephemeral_peer_ready( - self.packet_tunnel.packet_tunnel, - ptr::null(), - ptr::null(), - ); - return; - } - } - }; - // Use `self.ephemeral_key` as the new private key when no PQ but yes DAITA - let ephemeral_pub_key = PrivateKey::from(self.ephemeral_key).public_key(); - - tokio::select! { - ephemeral_peer = request_ephemeral_peer_with( - async_provider, - PublicKey::from(self.pub_key), - ephemeral_pub_key, - self.peer_parameters.enable_post_quantum, - self.peer_parameters.enable_daita, - ) => { - shutdown_handle.shutdown(); - if let Ok(mut connection) = self.packet_tunnel.tcp_connection.lock() { - connection.shutdown(); - } - match ephemeral_peer { - Ok(peer) => { - match peer.psk { - Some(preshared_key) => unsafe { - let preshared_key_bytes = preshared_key.as_bytes(); - swift_ephemeral_peer_ready(self.packet_tunnel.packet_tunnel, - preshared_key_bytes.as_ptr(), - self.ephemeral_key.as_ptr()); - }, - None => { - // Daita peer was requested, but without enabling post quantum keys - unsafe { - swift_ephemeral_peer_ready(self.packet_tunnel.packet_tunnel, - ptr::null(), - self.ephemeral_key.as_ptr()); - } - } - } - }, - Err(error) => { - log::error!("Key exchange failed {}", error); - unsafe { - swift_ephemeral_peer_ready(self.packet_tunnel.packet_tunnel, - ptr::null(), - ptr::null()); - } - } - } - } - - _ = tokio::time::sleep(std::time::Duration::from_secs(self.peer_parameters.peer_exchange_timeout)) => { - if let Ok(mut connection) = self.packet_tunnel.tcp_connection.lock() { - connection.shutdown(); - }; - shutdown_handle.shutdown(); - unsafe { swift_ephemeral_peer_ready(self.packet_tunnel.packet_tunnel, - ptr::null(), - ptr::null()); } - } - } - } -} diff --git a/mullvad-ios/src/ephemeral_peer_proxy/ios_tcp_connection.rs b/mullvad-ios/src/ephemeral_peer_proxy/ios_tcp_connection.rs index d91081fe576d..0a9420308fea 100644 --- a/mullvad-ios/src/ephemeral_peer_proxy/ios_tcp_connection.rs +++ b/mullvad-ios/src/ephemeral_peer_proxy/ios_tcp_connection.rs @@ -1,35 +1,77 @@ use libc::c_void; use std::{ - io::{self, Result}, - sync::{Arc, Mutex, MutexGuard, Weak}, - task::{Poll, Waker}, -}; -use tokio::{ - io::{AsyncRead, AsyncWrite}, - sync::mpsc, + ffi::CStr, + future::Future, + io::{self}, + pin::Pin, + task::{ready, Poll}, + time::Duration, }; +use tokio::io::{AsyncRead, AsyncWrite}; + +use super::EphemeralPeerParameters; fn connection_closed_err() -> io::Error { io::Error::new(io::ErrorKind::BrokenPipe, "TCP connection closed") } -extern "C" { - /// Called when there is data to send on the TCP connection. - /// The TCP connection must write data on the wire, then call the `handle_sent` function. - pub fn swift_nw_tcp_connection_send( - connection: *const libc::c_void, - data: *const libc::c_void, - data_len: usize, - sender: *const libc::c_void, - ); +#[derive(Clone, Copy)] +#[repr(C)] +pub struct WgTcpConnectionFuncs { + pub open_fn: + unsafe extern "C" fn(tunnelHandle: i32, address: *const libc::c_char, timeout: u64) -> i32, + pub close_fn: unsafe extern "C" fn(tunnelHandle: i32, socketHandle: i32) -> i32, + pub recv_fn: + unsafe extern "C" fn(tunnelHandle: i32, socketHandle: i32, data: *mut u8, len: i32) -> i32, + pub send_fn: unsafe extern "C" fn( + tunnelHandle: i32, + socketHandle: i32, + data: *const u8, + len: i32, + ) -> i32, +} - /// Called when there is data to read on the TCP connection. - /// The TCP connection must read data from the wire, then call the `handle_read` function. - pub fn swift_nw_tcp_connection_read( - connection: *const libc::c_void, - sender: *const libc::c_void, - ); +impl WgTcpConnectionFuncs { + /// # Safety + /// This function is safe to call so long as the function pointer is valid for its declared + /// signature. + pub unsafe fn open(&self, tunnel_handle: i32, address: *const u8, timeout: u64) -> i32 { + unsafe { (self.open_fn)(tunnel_handle, address.cast(), timeout) } + } + + /// # Safety + /// This function is safe to call so long as the function pointer is valid for its declared + /// signature. + pub unsafe fn close(&self, tunnel_handle: i32, socket_handle: i32) -> i32 { + unsafe { (self.close_fn)(tunnel_handle, socket_handle) } + } + /// # Safety + /// This function is safe to call so long as the function pointer is valid for its declared + /// signature. + pub unsafe fn receive(&self, tunnel_handle: i32, socket_handle: i32, data: &mut [u8]) -> i32 { + let ptr = data.as_mut_ptr(); + let len = data + .len() + .try_into() + .expect("Cannot receive a buffer larger than 2GiB"); + unsafe { (self.recv_fn)(tunnel_handle, socket_handle, ptr.cast(), len) } + } + + /// # Safety + /// This function is safe to call so long as the function pointer is valid for its declared + /// signature. + pub unsafe fn send(&self, tunnel_handle: i32, socket_handle: i32, data: &[u8]) -> i32 { + let ptr = data.as_ptr(); + let len = data + .len() + .try_into() + .expect("Cannot send a buffer larger than 2GiB"); + unsafe { (self.send_fn)(tunnel_handle, socket_handle, ptr.cast(), len) } + } +} + +extern "C" { /// Called when the preshared post quantum key is ready, /// or when a Daita peer has been successfully requested. /// `raw_preshared_key` will be NULL if: @@ -40,181 +82,188 @@ extern "C" { raw_preshared_key: *const u8, raw_ephemeral_private_key: *const u8, ); -} -unsafe impl Send for IosTcpProvider {} +} +#[derive(Clone)] pub struct IosTcpProvider { - write_tx: Arc>, - write_rx: mpsc::UnboundedReceiver, - read_tx: Arc>>, - read_rx: mpsc::UnboundedReceiver>, - tcp_connection: Arc>, - read_in_progress: bool, - write_in_progress: bool, + tunnel_handle: i32, + timeout: Duration, + funcs: WgTcpConnectionFuncs, } -pub struct IosTcpShutdownHandle { - context: Arc>, -} +type InFlightIoTask = Option>>>>>; -pub struct ConnectionContext { - waker: Option, - tcp_connection: Option<*const c_void>, +pub struct IosTcpConnection { + tunnel_handle: i32, + socket_handle: i32, + funcs: WgTcpConnectionFuncs, + in_flight_read: InFlightIoTask, + in_flight_write: InFlightIoTask, } -unsafe impl Send for ConnectionContext {} +#[derive(Debug)] +pub enum WgTcpError { + /// Failed to open the socket + Open, + /// Panicked during opening of the socket + Panic, +} impl IosTcpProvider { - /// # Safety - /// `connection` must be pointing to a valid instance of a `NWTCPConnection`, created by the - /// `PacketTunnelProvider` - pub unsafe fn new(connection: Arc>) -> (Self, IosTcpShutdownHandle) { - let (tx, rx) = mpsc::unbounded_channel(); - let (recv_tx, recv_rx) = mpsc::unbounded_channel(); - - ( - Self { - write_tx: Arc::new(tx), - write_rx: rx, - read_tx: Arc::new(recv_tx), - read_rx: recv_rx, - tcp_connection: connection.clone(), - read_in_progress: false, - write_in_progress: false, - }, - IosTcpShutdownHandle { - context: connection, - }, - ) + pub fn new(tunnel_handle: i32, params: EphemeralPeerParameters) -> Self { + Self { + tunnel_handle, + timeout: Duration::from_secs(params.peer_exchange_timeout), + funcs: params.funcs, + } } - fn maybe_set_waker(new_waker: Waker, connection: &mut MutexGuard<'_, ConnectionContext>) { - connection.waker = Some(new_waker); + pub async fn connect(&self, address: &'static CStr) -> Result { + let tunnel_handle = self.tunnel_handle; + let timeout = self.timeout.as_secs(); + let funcs = self.funcs; + let result = tokio::task::spawn_blocking(move || unsafe { + // SAFETY + // The `open_fn` function pointer in `funcs` must be valid. + funcs.open(tunnel_handle, address.as_ptr() as *const _, timeout) + }) + .await + .map_err(|_| WgTcpError::Panic)?; + + if result < 0 { + return Err(WgTcpError::Open); + } + + Ok(IosTcpConnection { + tunnel_handle, + socket_handle: result, + funcs: self.funcs, + in_flight_read: None, + in_flight_write: None, + }) } } -impl IosTcpShutdownHandle { - pub fn shutdown(self) { - let Ok(mut context) = self.context.lock() else { - return; - }; - - context.tcp_connection = None; - if let Some(waker) = context.waker.take() { - waker.wake(); - } - std::mem::drop(context); +impl Drop for IosTcpConnection { + fn drop(&mut self) { + // Safety + // `funcs.close_fn` must be a valid function pointer. + unsafe { self.funcs.close(self.tunnel_handle, self.socket_handle) }; } } -impl AsyncWrite for IosTcpProvider { +impl AsyncWrite for IosTcpConnection { fn poll_write( mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, buf: &[u8], - ) -> std::task::Poll> { - let connection_lock = self.tcp_connection.clone(); - let Ok(mut connection) = connection_lock.lock() else { - return Poll::Ready(Err(connection_closed_err())); - }; - let Some(tcp_ptr) = connection.tcp_connection else { - return Poll::Ready(Err(connection_closed_err())); - }; - Self::maybe_set_waker(cx.waker().clone(), &mut connection); - - match self.write_rx.poll_recv(cx) { - std::task::Poll::Ready(Some(bytes_sent)) => { - self.write_in_progress = false; - Poll::Ready(Ok(bytes_sent)) - } - std::task::Poll::Ready(None) => { - self.write_in_progress = false; - Poll::Ready(Err(connection_closed_err())) - } - std::task::Poll::Pending => { - if !self.write_in_progress { - let raw_sender = Weak::into_raw(Arc::downgrade(&self.write_tx)); - unsafe { - swift_nw_tcp_connection_send( - tcp_ptr, - buf.as_ptr() as _, - buf.len(), - raw_sender as _, - ); - } - self.write_in_progress = true; + ) -> std::task::Poll> { + // If task is already spawned, poll it + if let Some(handle) = &mut self.in_flight_write { + let result = match ready!(handle.as_mut().poll(cx)) { + Ok(Ok(written)) => Ok(written.len()), + Ok(Err(e)) => Err(e), + Err(_) => Err(io::Error::new(io::ErrorKind::Other, "Write task panicked")), + }; + // important to clear the in flight write here. + self.in_flight_write = None; + Poll::Ready(result) + } else { + // if no write task has been spawned, spawn one + let tunnel_handle = self.tunnel_handle; + let socket_handle = self.socket_handle; + // The data has to be cloned, since it will be moved into another thread and it has to + // outlive this function call. + let data = buf.to_vec(); + let funcs = self.funcs; + let task = tokio::task::spawn_blocking(move || { + // Safety + // `funcs.send_fn` must be a valid function pointer. + let result = unsafe { funcs.send(tunnel_handle, socket_handle, data.as_slice()) }; + if result < 0 { + Err(io::Error::new( + io::ErrorKind::Other, + format!("Write error: {}", result), + )) + } else { + Ok(data[..result as usize].to_vec()) } - std::task::Poll::Pending - } + }); + + self.in_flight_write = Some(Box::pin(task)); + cx.waker().wake_by_ref(); + Poll::Pending } } fn poll_flush( self: std::pin::Pin<&mut Self>, _: &mut std::task::Context<'_>, - ) -> std::task::Poll> { + ) -> std::task::Poll> { std::task::Poll::Ready(Ok(())) } fn poll_shutdown( self: std::pin::Pin<&mut Self>, _: &mut std::task::Context<'_>, - ) -> std::task::Poll> { + ) -> std::task::Poll> { std::task::Poll::Ready(Ok(())) } } -impl AsyncRead for IosTcpProvider { + +impl AsyncRead for IosTcpConnection { fn poll_read( mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, buf: &mut tokio::io::ReadBuf<'_>, - ) -> std::task::Poll> { - let connection_lock = self.tcp_connection.clone(); - let Ok(mut connection) = connection_lock.lock() else { - return Poll::Ready(Err(connection_closed_err())); - }; - let Some(tcp_ptr) = connection.tcp_connection else { - return Poll::Ready(Err(connection_closed_err())); - }; - Self::maybe_set_waker(cx.waker().clone(), &mut connection); - - match self.read_rx.poll_recv(cx) { - std::task::Poll::Ready(Some(data)) => { - buf.put_slice(&data); - self.read_in_progress = false; - Poll::Ready(Ok(())) - } - std::task::Poll::Ready(None) => { - self.read_in_progress = false; - Poll::Ready(Err(connection_closed_err())) - } - std::task::Poll::Pending => { - if !self.read_in_progress { - let raw_sender = Weak::into_raw(Arc::downgrade(&self.read_tx)); - unsafe { - swift_nw_tcp_connection_read(tcp_ptr, raw_sender as _); - } - self.read_in_progress = true; + ) -> std::task::Poll> { + // If task is already spawned, poll it + if let Some(handle) = &mut self.in_flight_read { + let result = match ready!(handle.as_mut().poll(cx)) { + Ok(Ok(data)) => { + // We are assuming that the buffer has not been used for anything else between + // spawning the task and writing to it now, since we expect `buf.remaining()` + // to return the same value between those two points in time. + let len = data.len().min(buf.remaining()); + buf.put_slice(&data[..len]); + Ok(()) } - Poll::Pending - } - } - } -} + Ok(Err(e)) => Err(e), + Err(_) => Err(io::Error::new(io::ErrorKind::Other, "Read task panicked")), + }; + // Clear the in-flight read, since the read task finished + self.in_flight_read = None; + Poll::Ready(result) + } else { + // If no read task has been spawned, spawn one + let tunnel_handle = self.tunnel_handle; + let socket_handle = self.socket_handle; + let funcs = self.funcs; + let mut buffer = vec![0u8; buf.remaining()]; + let task = tokio::task::spawn_blocking(move || { + // Safety + // `funcs.receive_fn` must be a valid function pointer. + let result = + unsafe { funcs.receive(tunnel_handle, socket_handle, buffer.as_mut_slice()) }; + match result { + size @ 1.. => { + buffer.truncate(size as usize); + Ok(buffer) + } -impl ConnectionContext { - pub fn new(tcp_connection: *const c_void) -> Self { - Self { - tcp_connection: Some(tcp_connection), - waker: None, - } - } + errval @ ..0 => Err(io::Error::new( + io::ErrorKind::Other, + format!("Read error: {}", errval), + )), + + 0 => Err(connection_closed_err()), + } + }); - pub fn shutdown(&mut self) { - self.tcp_connection = None; - if let Some(waker) = self.waker.take() { - waker.wake(); + self.in_flight_read = Some(Box::pin(task)); + cx.waker().wake_by_ref(); + Poll::Pending } } } diff --git a/mullvad-ios/src/ephemeral_peer_proxy/mod.rs b/mullvad-ios/src/ephemeral_peer_proxy/mod.rs index c69b7d6b3b80..95287b8cda77 100644 --- a/mullvad-ios/src/ephemeral_peer_proxy/mod.rs +++ b/mullvad-ios/src/ephemeral_peer_proxy/mod.rs @@ -1,58 +1,46 @@ #![cfg(target_os = "ios")] -pub mod ios_runtime; pub mod ios_tcp_connection; +pub mod peer_exchange; -use ios_runtime::run_ephemeral_peer_exchange; -use ios_tcp_connection::ConnectionContext; +use ios_tcp_connection::swift_ephemeral_peer_ready; use libc::c_void; -use std::sync::{Arc, Mutex, Weak}; -use tokio::sync::mpsc; +use peer_exchange::EphemeralPeerExchange; -use std::sync::Once; +use std::{ptr, sync::Once}; static INIT_LOGGING: Once = Once::new(); -#[repr(C)] -pub struct EphemeralPeerCancelToken { - // Must keep a pointer to a valid std::sync::Arc - pub context: *mut c_void, -} - +#[derive(Clone)] pub struct PacketTunnelBridge { pub packet_tunnel: *const c_void, - pub tcp_connection: *const c_void, + pub tunnel_handle: i32, } -pub struct EphemeralPeerParameters { - pub peer_exchange_timeout: u64, - pub enable_post_quantum: bool, - pub enable_daita: bool, -} +impl PacketTunnelBridge { + fn fail_exchange(self) { + unsafe { swift_ephemeral_peer_ready(self.packet_tunnel, ptr::null(), ptr::null()) }; + } -impl EphemeralPeerCancelToken { - /// # Safety - /// This function can only be called when the context pointer is valid. - unsafe fn cancel(&self) { - // # Safety - // Try to take the value, if there is a value, we can safely send the message, otherwise, - // assume it has been dropped and nothing happens - let connection_context: Arc> = - unsafe { Arc::from_raw(self.context as _) }; - if let Ok(mut connection) = connection_context.lock() { - connection.shutdown(); - } + fn succeed_exchange(self, ephemeral_key: [u8; 32], preshared_key: Option<[u8; 32]>) { + let ephemeral_ptr = ephemeral_key.as_ptr(); + let preshared_ptr = preshared_key + .as_ref() + .map(|key| key.as_ptr()) + .unwrap_or(ptr::null()); - // Call std::mem::forget here to avoid dropping the channel. - std::mem::forget(connection_context); + unsafe { swift_ephemeral_peer_ready(self.packet_tunnel, preshared_ptr, ephemeral_ptr) }; } } -impl Drop for EphemeralPeerCancelToken { - fn drop(&mut self) { - let _: Arc> = unsafe { Arc::from_raw(self.context as _) }; - } -} +unsafe impl Send for PacketTunnelBridge {} -unsafe impl Send for EphemeralPeerCancelToken {} +#[repr(C)] +#[derive(Clone, Copy)] +pub struct EphemeralPeerParameters { + pub peer_exchange_timeout: u64, + pub enable_post_quantum: bool, + pub enable_daita: bool, + pub funcs: ios_tcp_connection::WgTcpConnectionFuncs, +} /// Called by the Swift side to signal that the ephemeral peer exchange should be cancelled. /// After this call, the cancel token is no longer valid. @@ -61,64 +49,25 @@ unsafe impl Send for EphemeralPeerCancelToken {} /// `sender` must be pointing to a valid instance of a `EphemeralPeerCancelToken` created by the /// `PacketTunnelProvider`. #[no_mangle] -pub unsafe extern "C" fn cancel_ephemeral_peer_exchange(sender: *const EphemeralPeerCancelToken) { - let sender = unsafe { &*sender }; +pub unsafe extern "C" fn cancel_ephemeral_peer_exchange( + sender: *mut peer_exchange::ExchangeCancelToken, +) { + let sender = unsafe { Box::from_raw(sender) }; sender.cancel(); } -/// Called by the Swift side to signal that the Rust `EphemeralPeerCancelToken` can be safely dropped -/// from memory. +/// Called by the Swift side to signal that the Rust `EphemeralPeerCancelToken` can be safely +/// dropped from memory. /// /// # Safety /// `sender` must be pointing to a valid instance of a `EphemeralPeerCancelToken` created by the /// `PacketTunnelProvider`. #[no_mangle] pub unsafe extern "C" fn drop_ephemeral_peer_exchange_token( - sender: *const EphemeralPeerCancelToken, + sender: *mut peer_exchange::ExchangeCancelToken, ) { - let _sender = unsafe { std::ptr::read(sender) }; -} - -/// Called by Swift whenever data has been written to the in-tunnel TCP connection when exchanging -/// quantum-resistant pre shared keys, or ephemeral peers. -/// -/// If `bytes_sent` is 0, this indicates that the connection was closed or that an error occurred. -/// -/// # Safety -/// `sender` must be pointing to a valid instance of a `write_tx` created by the `IosTcpProvider` -/// Callback to call when the TCP connection has written data. -#[no_mangle] -pub unsafe extern "C" fn handle_sent(bytes_sent: usize, sender: *const c_void) { - let weak_tx: Weak> = unsafe { Weak::from_raw(sender as _) }; - if let Some(send_tx) = weak_tx.upgrade() { - _ = send_tx.send(bytes_sent); - } -} - -/// Called by Swift whenever data has been read from the in-tunnel TCP connection when exchanging -/// quantum-resistant pre shared keys, or ephemeral peers. -/// -/// If `data` is null or empty, this indicates that the connection was closed or that an error -/// occurred. An empty buffer is sent to the underlying reader to signal EOF. -/// -/// # Safety -/// `sender` must be pointing to a valid instance of a `read_tx` created by the `IosTcpProvider` -/// -/// Callback to call when the TCP connection has received data. -#[no_mangle] -pub unsafe extern "C" fn handle_recv(data: *const u8, mut data_len: usize, sender: *const c_void) { - let weak_tx: Weak>> = unsafe { Weak::from_raw(sender as _) }; - - if data.is_null() { - data_len = 0; - } - let mut bytes = vec![0u8; data_len]; - if !data.is_null() { - std::ptr::copy_nonoverlapping(data, bytes.as_mut_ptr(), data_len); - } - if let Some(read_tx) = weak_tx.upgrade() { - _ = read_tx.send(bytes.into_boxed_slice()); - } + // drop the cancel token + let _sender = unsafe { Box::from_raw(sender) }; } /// Entry point for requesting ephemeral peers on iOS. @@ -126,61 +75,43 @@ pub unsafe extern "C" fn handle_recv(data: *const u8, mut data_len: usize, sende /// # Safety /// `public_key` and `ephemeral_key` must be valid respective `PublicKey` and `PrivateKey` types. /// They will not be valid after this function is called, and thus must be copied here. -/// `packet_tunnel` and `tcp_connection` must be valid pointers to a packet tunnel and a TCP -/// connection instances. -/// `cancel_token` should be owned by the caller of this function. +/// `packet_tunnel` must be valid pointers to a packet tunnel, the packet tunnel pointer must +/// outlive the ephemeral peer exchange. `cancel_token` should be owned by the caller of this +/// function. #[no_mangle] pub unsafe extern "C" fn request_ephemeral_peer( public_key: *const u8, ephemeral_key: *const u8, packet_tunnel: *const c_void, - tcp_connection: *const c_void, - cancel_token: *mut EphemeralPeerCancelToken, - peer_exchange_timeout: u64, - enable_post_quantum: bool, - enable_daita: bool, -) -> i32 { + tunnel_handle: i32, + peer_parameters: EphemeralPeerParameters, +) -> *mut peer_exchange::ExchangeCancelToken { INIT_LOGGING.call_once(|| { let _ = oslog::OsLogger::new("net.mullvad.MullvadVPN.TTCC") .level_filter(log::LevelFilter::Debug) .init(); }); - let pub_key: [u8; 32] = unsafe { std::ptr::read(public_key as *const [u8; 32]) }; - let eph_key: [u8; 32] = unsafe { std::ptr::read(ephemeral_key as *const [u8; 32]) }; + let pub_key: [u8; 32] = unsafe { ptr::read(public_key as *const [u8; 32]) }; + let eph_key: [u8; 32] = unsafe { ptr::read(ephemeral_key as *const [u8; 32]) }; let handle = match crate::mullvad_ios_runtime() { Ok(handle) => handle, Err(err) => { log::error!("Failed to obtain a handle to a tokio runtime: {err}"); - return -1; + return ptr::null_mut(); } }; let packet_tunnel_bridge = PacketTunnelBridge { packet_tunnel, - tcp_connection, - }; - let peer_parameters = EphemeralPeerParameters { - peer_exchange_timeout, - enable_post_quantum, - enable_daita, + tunnel_handle, }; - match unsafe { - run_ephemeral_peer_exchange( - pub_key, - eph_key, - packet_tunnel_bridge, - peer_parameters, - handle, - ) - } { - Ok(token) => { - unsafe { std::ptr::write(cancel_token, token) }; - 0 - } - Err(_) => -1, - } + let cancel_token = + EphemeralPeerExchange::new(pub_key, eph_key, packet_tunnel_bridge, peer_parameters) + .run(handle); + + Box::into_raw(Box::new(cancel_token)) } diff --git a/mullvad-ios/src/ephemeral_peer_proxy/peer_exchange.rs b/mullvad-ios/src/ephemeral_peer_proxy/peer_exchange.rs new file mode 100644 index 000000000000..ee9a512777e8 --- /dev/null +++ b/mullvad-ios/src/ephemeral_peer_proxy/peer_exchange.rs @@ -0,0 +1,173 @@ +use super::{ios_tcp_connection::*, EphemeralPeerParameters, PacketTunnelBridge}; +use std::{ffi::CStr, sync::Mutex, thread}; +use talpid_tunnel_config_client::{request_ephemeral_peer_with, Error, RelayConfigService}; +use talpid_types::net::wireguard::{PrivateKey, PublicKey}; +use tokio::{runtime::Handle as TokioHandle, task::JoinHandle}; +use tonic::transport::channel::Endpoint; +use tower::util::service_fn; + +const GRPC_HOST_CSTR: &CStr = c"10.64.0.1:1337"; + +pub struct ExchangeCancelToken { + inner: Mutex, +} + +impl ExchangeCancelToken { + fn new(tokio_handle: TokioHandle, task: JoinHandle<()>) -> Self { + let inner = CancelToken { + tokio_handle, + task: Some(task), + }; + Self { + inner: Mutex::new(inner), + } + } + + /// Blocks until the associated ephemeral peer exchange task is finished. + pub fn cancel(&self) { + if let Ok(mut inner) = self.inner.lock() { + if let Some(task) = inner.task.take() { + task.abort(); + let _ = inner.tokio_handle.block_on(task); + } + } + } +} + +struct CancelToken { + tokio_handle: TokioHandle, + task: Option>, +} + +pub struct EphemeralPeerExchange { + pub_key: [u8; 32], + ephemeral_key: [u8; 32], + packet_tunnel: PacketTunnelBridge, + peer_parameters: EphemeralPeerParameters, +} + +// # Safety +// This is safe because the void pointer in PacketTunnelBridge is valid for the lifetime of the +// process where this type is intended to be used. +unsafe impl Send for EphemeralPeerExchange {} + +impl EphemeralPeerExchange { + pub fn new( + pub_key: [u8; 32], + ephemeral_key: [u8; 32], + packet_tunnel: PacketTunnelBridge, + peer_parameters: EphemeralPeerParameters, + ) -> EphemeralPeerExchange { + Self { + pub_key, + ephemeral_key, + packet_tunnel, + peer_parameters, + } + } + + pub fn run(self, tokio: TokioHandle) -> ExchangeCancelToken { + let task = tokio.spawn(async move { + self.run_service_inner().await; + }); + + ExchangeCancelToken::new(tokio, task) + } + + /// Creates a `RelayConfigService` using the in-tunnel TCP Connection provided by the Packet + /// Tunnel Provider + async fn ios_tcp_client( + tunnel_handle: i32, + peer_parameters: EphemeralPeerParameters, + ) -> Result { + let endpoint = Endpoint::from_static("tcp://0.0.0.0:0"); + + let tcp_provider = IosTcpProvider::new(tunnel_handle, peer_parameters); + + let conn = endpoint + // it is assumend that the service function will only be called once. + // Yet, by its signature, it is forced to be callable multiple times. + // The tcp_provider appeases this constraint, maybe we should rewrite this back to + // explicitly only allow a single invocation? It is due to this mismatch between how we + // use it and what the interface expects that we are using a oneshot channel to + // transfer the shutdown handle. + .connect_with_connector(service_fn(move |_| { + let provider = tcp_provider.clone(); + async move { + provider + .connect(GRPC_HOST_CSTR) + .await + .map(hyper_util::rt::tokio::TokioIo::new) + .map_err(|_| Error::TcpConnectionOpen) + } + })) + .await + .map_err(Error::GrpcConnectError)?; + + Ok(RelayConfigService::new(conn)) + } + + fn report_failure(self) { + thread::spawn(move || { + self.packet_tunnel.fail_exchange(); + }); + } + + async fn run_service_inner(self) { + let async_provider = match Self::ios_tcp_client( + self.packet_tunnel.tunnel_handle, + self.peer_parameters, + ) + .await + { + Ok(result) => result, + Err(error) => { + log::error!("Failed to create iOS TCP client: {error}"); + self.report_failure(); + return; + } + }; + // Use `self.ephemeral_key` as the new private key when no PQ but yes DAITA + let ephemeral_pub_key = PrivateKey::from(self.ephemeral_key).public_key(); + + tokio::select! { + ephemeral_peer = request_ephemeral_peer_with( + async_provider, + PublicKey::from(self.pub_key), + ephemeral_pub_key, + self.peer_parameters.enable_post_quantum, + self.peer_parameters.enable_daita, + ) => { + match ephemeral_peer { + Ok(peer) => { + match peer.psk { + Some(preshared_key) => { + let preshared_key_bytes = *preshared_key.as_bytes(); + thread::spawn(move || { + let Self{ ephemeral_key, packet_tunnel, .. } = self; + packet_tunnel.succeed_exchange(ephemeral_key, Some(preshared_key_bytes)); + }); + + }, + None => { + // Daita peer was requested, but without enabling post quantum keys + thread::spawn(move || { + let Self{ ephemeral_key, packet_tunnel, .. } = self; + packet_tunnel.succeed_exchange(ephemeral_key, None); + }); + } + } + }, + Err(error) => { + log::error!("Key exchange failed {}", error); + self.report_failure(); + } + } + } + + _ = tokio::time::sleep(std::time::Duration::from_secs(self.peer_parameters.peer_exchange_timeout)) => { + self.report_failure(); + } + } + } +} diff --git a/talpid-tunnel-config-client/src/lib.rs b/talpid-tunnel-config-client/src/lib.rs index f7d559f6417d..7c80d4f3e5fa 100644 --- a/talpid-tunnel-config-client/src/lib.rs +++ b/talpid-tunnel-config-client/src/lib.rs @@ -40,7 +40,7 @@ pub enum Error { }, MissingDaitaResponse, #[cfg(target_os = "ios")] - TcpConnectionExpired, + TcpConnectionOpen, #[cfg(target_os = "ios")] UnableToCreateRuntime, } @@ -49,7 +49,7 @@ impl std::fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { use Error::*; match self { - GrpcConnectError(_) => "Failed to connect to config service".fmt(f), + GrpcConnectError(err) => write!(f, "Failed to connect to config service: {err:?}"), GrpcError(status) => write!(f, "RPC failed: {status}"), MissingCiphertexts => write!(f, "Found no ciphertexts in response"), InvalidCiphertextLength { @@ -65,7 +65,7 @@ impl std::fmt::Display for Error { } MissingDaitaResponse => "Expected DAITA configuration in response".fmt(f), #[cfg(target_os = "ios")] - TcpConnectionExpired => "TCP connection is already shut down".fmt(f), + TcpConnectionOpen => "Failed to open TCP connection".fmt(f), #[cfg(target_os = "ios")] UnableToCreateRuntime => "Unable to create iOS PQ PSK runtime".fmt(f), }