diff --git a/Sources/Engine/NativeEngine.swift b/Sources/Engine/NativeEngine.swift index 7294e364..15fd56da 100644 --- a/Sources/Engine/NativeEngine.swift +++ b/Sources/Engine/NativeEngine.swift @@ -11,8 +11,13 @@ import Foundation @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) public class NativeEngine: NSObject, Engine, URLSessionDataDelegate, URLSessionWebSocketDelegate { private var task: URLSessionWebSocketTask? + private var clientCredential: URLCredential? weak var delegate: EngineDelegate? + public init(clientCredential: URLCredential? = nil) { + self.clientCredential = clientCredential + } + public func register(delegate: EngineDelegate) { self.delegate = delegate } @@ -93,4 +98,17 @@ public class NativeEngine: NSObject, Engine, URLSessionDataDelegate, URLSessionW } broadcast(event: .disconnected(r, UInt16(closeCode.rawValue))) } + + public func urlSession(_ session: URLSession, didReceive challenge: URLAuthenticationChallenge, completionHandler: @escaping (URLSession.AuthChallengeDisposition, URLCredential?) -> Void) { + var credential: URLCredential? = nil + var disposition: URLSession.AuthChallengeDisposition = .performDefaultHandling + + let authMethod = challenge.protectionSpace.authenticationMethod + if authMethod == NSURLAuthenticationMethodClientCertificate && self.clientCredential != nil { + credential = self.clientCredential + disposition = .useCredential + } + + completionHandler(disposition, credential) + } } diff --git a/Sources/Engine/WSEngine.swift b/Sources/Engine/WSEngine.swift index decca641..e0872e2b 100644 --- a/Sources/Engine/WSEngine.swift +++ b/Sources/Engine/WSEngine.swift @@ -15,6 +15,7 @@ FrameCollectorDelegate, HTTPHandlerDelegate { private let httpHandler: HTTPHandler private let compressionHandler: CompressionHandler? private let certPinner: CertificatePinning? + private let clientCredential: URLCredential? private let headerChecker: HeaderValidator private var request: URLRequest! @@ -30,6 +31,7 @@ FrameCollectorDelegate, HTTPHandlerDelegate { public init(transport: Transport, certPinner: CertificatePinning? = nil, + clientCredential: URLCredential? = nil, headerValidator: HeaderValidator = FoundationSecurity(), httpHandler: HTTPHandler = FoundationHTTPHandler(), framer: Framer = WSFramer(), @@ -38,6 +40,7 @@ FrameCollectorDelegate, HTTPHandlerDelegate { self.framer = framer self.httpHandler = httpHandler self.certPinner = certPinner + self.clientCredential = clientCredential self.headerChecker = headerValidator self.compressionHandler = compressionHandler framer.updateCompression(supports: compressionHandler != nil) @@ -64,7 +67,7 @@ FrameCollectorDelegate, HTTPHandlerDelegate { guard let url = request.url else { return } - transport.connect(url: url, timeout: request.timeoutInterval, certificatePinning: certPinner) + transport.connect(url: url, timeout: request.timeoutInterval, certificatePinning: certPinner, clientCredential: clientCredential) } public func stop(closeCode: UInt16 = CloseCode.normal.rawValue) { diff --git a/Sources/Framer/HTTPHandler.swift b/Sources/Framer/HTTPHandler.swift index 70941e75..5e08bd77 100644 --- a/Sources/Framer/HTTPHandler.swift +++ b/Sources/Framer/HTTPHandler.swift @@ -78,7 +78,8 @@ public struct HTTPWSHeader { let val = "permessage-deflate; client_max_window_bits; server_max_window_bits=15" req.setValue(val, forHTTPHeaderField: HTTPWSHeader.extensionName) } - let hostValue = req.allHTTPHeaderFields?[HTTPWSHeader.hostName] ?? "\(parts.host):\(parts.port)" + let hostname = request.url?.port != nil ? "\(parts.host):\(parts.port)" : parts.host + let hostValue = req.allHTTPHeaderFields?[HTTPWSHeader.hostName] ?? hostname req.setValue(hostValue, forHTTPHeaderField: HTTPWSHeader.hostName) return req } diff --git a/Sources/Starscream/WebSocket.swift b/Sources/Starscream/WebSocket.swift index 1d3545c3..d348ab47 100644 --- a/Sources/Starscream/WebSocket.swift +++ b/Sources/Starscream/WebSocket.swift @@ -120,13 +120,13 @@ open class WebSocket: WebSocketClient, EngineDelegate { self.engine = engine } - public convenience init(request: URLRequest, certPinner: CertificatePinning? = FoundationSecurity(), compressionHandler: CompressionHandler? = nil, useCustomEngine: Bool = true) { + public convenience init(request: URLRequest, certPinner: CertificatePinning? = FoundationSecurity(), clientCredential: URLCredential? = nil, compressionHandler: CompressionHandler? = nil, useCustomEngine: Bool = true) { if #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *), !useCustomEngine { - self.init(request: request, engine: NativeEngine()) + self.init(request: request, engine: NativeEngine(clientCredential: clientCredential)) } else if #available(macOS 10.14, iOS 12.0, watchOS 5.0, tvOS 12.0, *) { - self.init(request: request, engine: WSEngine(transport: TCPTransport(), certPinner: certPinner, compressionHandler: compressionHandler)) + self.init(request: request, engine: WSEngine(transport: TCPTransport(), certPinner: certPinner, clientCredential: clientCredential, compressionHandler: compressionHandler)) } else { - self.init(request: request, engine: WSEngine(transport: FoundationTransport(), certPinner: certPinner, compressionHandler: compressionHandler)) + self.init(request: request, engine: WSEngine(transport: FoundationTransport(), certPinner: certPinner, clientCredential: clientCredential, compressionHandler: compressionHandler)) } } diff --git a/Sources/Transport/FoundationTransport.swift b/Sources/Transport/FoundationTransport.swift index 8d304f88..972b28b7 100644 --- a/Sources/Transport/FoundationTransport.swift +++ b/Sources/Transport/FoundationTransport.swift @@ -52,7 +52,7 @@ public class FoundationTransport: NSObject, Transport, StreamDelegate { outputStream?.delegate = nil } - public func connect(url: URL, timeout: Double = 10, certificatePinning: CertificatePinning? = nil) { + public func connect(url: URL, timeout: Double = 10, certificatePinning: CertificatePinning? = nil, clientCredential: URLCredential? = nil) { guard let parts = url.getParts() else { delegate?.connectionChanged(state: .failed(FoundationTransportError.invalidRequest)) return @@ -75,6 +75,14 @@ public class FoundationTransport: NSObject, Transport, StreamDelegate { let key = CFStreamPropertyKey(rawValue: kCFStreamPropertySocketSecurityLevel) CFReadStreamSetProperty(inStream, key, kCFStreamSocketSecurityLevelNegotiatedSSL) CFWriteStreamSetProperty(outStream, key, kCFStreamSocketSecurityLevelNegotiatedSSL) + + if let clientCredential = clientCredential { + let certificates = [clientCredential.identity] + clientCredential.certificates + let sslSettings = [kCFStreamSSLCertificates: certificates] as CFDictionary + let sslSettingsKey = CFStreamPropertyKey(rawValue: kCFStreamPropertySSLSettings) + CFReadStreamSetProperty(inStream, sslSettingsKey, sslSettings) + CFWriteStreamSetProperty(outStream, sslSettingsKey, sslSettings) + } } onConnect?(inStream, outStream) diff --git a/Sources/Transport/TCPTransport.swift b/Sources/Transport/TCPTransport.swift index 459cb2ed..1b8fe57b 100644 --- a/Sources/Transport/TCPTransport.swift +++ b/Sources/Transport/TCPTransport.swift @@ -49,7 +49,7 @@ public class TCPTransport: Transport { //normal connection, will use the "connect" method below } - public func connect(url: URL, timeout: Double = 10, certificatePinning: CertificatePinning? = nil) { + public func connect(url: URL, timeout: Double = 10, certificatePinning: CertificatePinning? = nil, clientCredential: URLCredential? = nil) { guard let parts = url.getParts() else { delegate?.connectionChanged(state: .failed(TCPTransportError.invalidRequest)) return @@ -75,6 +75,12 @@ public class TCPTransport: Transport { } }) }, queue) + + if let clientCredential = clientCredential { + sec_protocol_options_set_challenge_block(tlsOpts.securityProtocolOptions, { (_, completionHandler) in + completionHandler(sec_identity_create(clientCredential.identity!)!) + }, queue) + } } let parameters = NWParameters(tls: tlsOptions, tcp: options) let conn = NWConnection(host: NWEndpoint.Host.name(parts.host, nil), port: NWEndpoint.Port(rawValue: UInt16(parts.port))!, using: parameters) diff --git a/Sources/Transport/Transport.swift b/Sources/Transport/Transport.swift index e645651f..08c720f1 100644 --- a/Sources/Transport/Transport.swift +++ b/Sources/Transport/Transport.swift @@ -48,7 +48,7 @@ public protocol TransportEventClient: class { public protocol Transport: class { func register(delegate: TransportEventClient) - func connect(url: URL, timeout: Double, certificatePinning: CertificatePinning?) + func connect(url: URL, timeout: Double, certificatePinning: CertificatePinning?, clientCredential: URLCredential?) func disconnect() func write(data: Data, completion: @escaping ((Error?) -> ())) var usingTLS: Bool { get }