Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support client certificate authentication #871

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
18 changes: 18 additions & 0 deletions Sources/Engine/NativeEngine.swift
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,13 @@ import Foundation
@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *)
public class NativeEngine: NSObject, Engine, URLSessionDataDelegate, URLSessionWebSocketDelegate {
private var task: URLSessionWebSocketTask?
private var clientCredential: URLCredential?
weak var delegate: EngineDelegate?

public init(clientCredential: URLCredential? = nil) {
self.clientCredential = clientCredential
}

public func register(delegate: EngineDelegate) {
self.delegate = delegate
}
Expand Down Expand Up @@ -93,4 +98,17 @@ public class NativeEngine: NSObject, Engine, URLSessionDataDelegate, URLSessionW
}
broadcast(event: .disconnected(r, UInt16(closeCode.rawValue)))
}

public func urlSession(_ session: URLSession, didReceive challenge: URLAuthenticationChallenge, completionHandler: @escaping (URLSession.AuthChallengeDisposition, URLCredential?) -> Void) {
var credential: URLCredential? = nil
var disposition: URLSession.AuthChallengeDisposition = .performDefaultHandling

let authMethod = challenge.protectionSpace.authenticationMethod
if authMethod == NSURLAuthenticationMethodClientCertificate && self.clientCredential != nil {
credential = self.clientCredential
disposition = .useCredential
}

completionHandler(disposition, credential)
}
}
5 changes: 4 additions & 1 deletion Sources/Engine/WSEngine.swift
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ FrameCollectorDelegate, HTTPHandlerDelegate {
private let httpHandler: HTTPHandler
private let compressionHandler: CompressionHandler?
private let certPinner: CertificatePinning?
private let clientCredential: URLCredential?
private let headerChecker: HeaderValidator
private var request: URLRequest!

Expand All @@ -30,6 +31,7 @@ FrameCollectorDelegate, HTTPHandlerDelegate {

public init(transport: Transport,
certPinner: CertificatePinning? = nil,
clientCredential: URLCredential? = nil,
headerValidator: HeaderValidator = FoundationSecurity(),
httpHandler: HTTPHandler = FoundationHTTPHandler(),
framer: Framer = WSFramer(),
Expand All @@ -38,6 +40,7 @@ FrameCollectorDelegate, HTTPHandlerDelegate {
self.framer = framer
self.httpHandler = httpHandler
self.certPinner = certPinner
self.clientCredential = clientCredential
self.headerChecker = headerValidator
self.compressionHandler = compressionHandler
framer.updateCompression(supports: compressionHandler != nil)
Expand All @@ -64,7 +67,7 @@ FrameCollectorDelegate, HTTPHandlerDelegate {
guard let url = request.url else {
return
}
transport.connect(url: url, timeout: request.timeoutInterval, certificatePinning: certPinner)
transport.connect(url: url, timeout: request.timeoutInterval, certificatePinning: certPinner, clientCredential: clientCredential)
}

public func stop(closeCode: UInt16 = CloseCode.normal.rawValue) {
Expand Down
3 changes: 2 additions & 1 deletion Sources/Framer/HTTPHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@ public struct HTTPWSHeader {
let val = "permessage-deflate; client_max_window_bits; server_max_window_bits=15"
req.setValue(val, forHTTPHeaderField: HTTPWSHeader.extensionName)
}
let hostValue = req.allHTTPHeaderFields?[HTTPWSHeader.hostName] ?? "\(parts.host):\(parts.port)"
let hostname = request.url?.port != nil ? "\(parts.host):\(parts.port)" : parts.host
let hostValue = req.allHTTPHeaderFields?[HTTPWSHeader.hostName] ?? hostname
req.setValue(hostValue, forHTTPHeaderField: HTTPWSHeader.hostName)
return req
}
Expand Down
8 changes: 4 additions & 4 deletions Sources/Starscream/WebSocket.swift
Original file line number Diff line number Diff line change
Expand Up @@ -120,13 +120,13 @@ open class WebSocket: WebSocketClient, EngineDelegate {
self.engine = engine
}

public convenience init(request: URLRequest, certPinner: CertificatePinning? = FoundationSecurity(), compressionHandler: CompressionHandler? = nil, useCustomEngine: Bool = true) {
public convenience init(request: URLRequest, certPinner: CertificatePinning? = FoundationSecurity(), clientCredential: URLCredential? = nil, compressionHandler: CompressionHandler? = nil, useCustomEngine: Bool = true) {
if #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *), !useCustomEngine {
self.init(request: request, engine: NativeEngine())
self.init(request: request, engine: NativeEngine(clientCredential: clientCredential))
} else if #available(macOS 10.14, iOS 12.0, watchOS 5.0, tvOS 12.0, *) {
self.init(request: request, engine: WSEngine(transport: TCPTransport(), certPinner: certPinner, compressionHandler: compressionHandler))
self.init(request: request, engine: WSEngine(transport: TCPTransport(), certPinner: certPinner, clientCredential: clientCredential, compressionHandler: compressionHandler))
} else {
self.init(request: request, engine: WSEngine(transport: FoundationTransport(), certPinner: certPinner, compressionHandler: compressionHandler))
self.init(request: request, engine: WSEngine(transport: FoundationTransport(), certPinner: certPinner, clientCredential: clientCredential, compressionHandler: compressionHandler))
}
}

Expand Down
10 changes: 9 additions & 1 deletion Sources/Transport/FoundationTransport.swift
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ public class FoundationTransport: NSObject, Transport, StreamDelegate {
outputStream?.delegate = nil
}

public func connect(url: URL, timeout: Double = 10, certificatePinning: CertificatePinning? = nil) {
public func connect(url: URL, timeout: Double = 10, certificatePinning: CertificatePinning? = nil, clientCredential: URLCredential? = nil) {
guard let parts = url.getParts() else {
delegate?.connectionChanged(state: .failed(FoundationTransportError.invalidRequest))
return
Expand All @@ -75,6 +75,14 @@ public class FoundationTransport: NSObject, Transport, StreamDelegate {
let key = CFStreamPropertyKey(rawValue: kCFStreamPropertySocketSecurityLevel)
CFReadStreamSetProperty(inStream, key, kCFStreamSocketSecurityLevelNegotiatedSSL)
CFWriteStreamSetProperty(outStream, key, kCFStreamSocketSecurityLevelNegotiatedSSL)

if let clientCredential = clientCredential {
let certificates = [clientCredential.identity] + clientCredential.certificates
let sslSettings = [kCFStreamSSLCertificates: certificates] as CFDictionary
let sslSettingsKey = CFStreamPropertyKey(rawValue: kCFStreamPropertySSLSettings)
CFReadStreamSetProperty(inStream, sslSettingsKey, sslSettings)
CFWriteStreamSetProperty(outStream, sslSettingsKey, sslSettings)
}
}

onConnect?(inStream, outStream)
Expand Down
8 changes: 7 additions & 1 deletion Sources/Transport/TCPTransport.swift
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ public class TCPTransport: Transport {
//normal connection, will use the "connect" method below
}

public func connect(url: URL, timeout: Double = 10, certificatePinning: CertificatePinning? = nil) {
public func connect(url: URL, timeout: Double = 10, certificatePinning: CertificatePinning? = nil, clientCredential: URLCredential? = nil) {
guard let parts = url.getParts() else {
delegate?.connectionChanged(state: .failed(TCPTransportError.invalidRequest))
return
Expand All @@ -75,6 +75,12 @@ public class TCPTransport: Transport {
}
})
}, queue)

if let clientCredential = clientCredential {
sec_protocol_options_set_challenge_block(tlsOpts.securityProtocolOptions, { (_, completionHandler) in
completionHandler(sec_identity_create(clientCredential.identity!)!)
}, queue)
}
}
let parameters = NWParameters(tls: tlsOptions, tcp: options)
let conn = NWConnection(host: NWEndpoint.Host.name(parts.host, nil), port: NWEndpoint.Port(rawValue: UInt16(parts.port))!, using: parameters)
Expand Down
2 changes: 1 addition & 1 deletion Sources/Transport/Transport.swift
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ public protocol TransportEventClient: class {

public protocol Transport: class {
func register(delegate: TransportEventClient)
func connect(url: URL, timeout: Double, certificatePinning: CertificatePinning?)
func connect(url: URL, timeout: Double, certificatePinning: CertificatePinning?, clientCredential: URLCredential?)
func disconnect()
func write(data: Data, completion: @escaping ((Error?) -> ()))
var usingTLS: Bool { get }
Expand Down