Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improved IPC support #536

Merged
merged 12 commits into from
Oct 22, 2023
Merged
3 changes: 3 additions & 0 deletions Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,9 @@ let package = Package(
.target(name: "WireGuardC"),
.product(name: "WireGuard", package: "wireguard-apple"),
"Common"
],
swiftSettings: [
.define("DEBUG", .when(configuration: .debug))
]),
.target(
name: "SecureStorage",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
<key>_XCCurrentVersionName</key>
<string>HTTPSUpgrade 3.xcdatamodel</string>
</dict>
<dict/>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interesting, any reason why this changed?

</plist>
4 changes: 4 additions & 0 deletions Sources/NetworkProtection/Controllers/TunnelController.swift
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,8 @@ public protocol TunnelController {
/// Stops the VPN connection used for Network Protection
///
func stop() async

/// Whether the tunnel is connected
///
var isConnected: Bool { get async }
}
25 changes: 25 additions & 0 deletions Sources/NetworkProtection/ExtensionMessage/ExtensionMessage.swift
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@ public enum ExtensionMessage: RawRepresentable {
public typealias RawValue = Data

enum Name: UInt8 {
// This is actually an improved way to send messages.
// Please avoid adding new messages to this enum, and instead
// add them to `ExtensionRequest`
case request = 255

case resetAllState = 0
case getRuntimeConfiguration
case getLastErrorMessage
Expand All @@ -40,6 +45,11 @@ public enum ExtensionMessage: RawRepresentable {
case simulateConnectionInterruption
}

// This is actually an improved way to send messages.
// Please avoid adding new messages to this enum, and instead
// add them to `ExtensionRequest`
case request(_ request: ExtensionRequest)
Copy link
Contributor Author

@diegoreymendez diegoreymendez Oct 17, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comment has it... ExtensionRequest is meant to replace ExtensionMessage.


// important: Preserve this order because Message Name is represented by Int value
case resetAllState
case getRuntimeConfiguration
Expand All @@ -62,6 +72,12 @@ public enum ExtensionMessage: RawRepresentable {
public init?(rawValue data: Data) {
let name = data.first.flatMap(Name.init(rawValue:))
switch name {
case .request:
guard let request = try? JSONDecoder().decode(ExtensionRequest.self, from: data[1...]) else {
return nil
}

self = .request(request)
case .resetAllState:
self = .resetAllState
case .getRuntimeConfiguration:
Expand Down Expand Up @@ -127,6 +143,7 @@ public enum ExtensionMessage: RawRepresentable {
// TO BE: Replaced with auto case name generating Macro when Xcode 15
private var name: Name {
switch self {
case .request: return .request
case .resetAllState: return .resetAllState
case .getRuntimeConfiguration: return .getRuntimeConfiguration
case .getLastErrorMessage: return .getLastErrorMessage
Expand All @@ -149,6 +166,14 @@ public enum ExtensionMessage: RawRepresentable {
public var rawValue: Data {
var encoder: (inout Data) -> Void = { _ in }
switch self {
case .request(let request):
encoder = {
do {
try $0.append(JSONEncoder().encode(request))
} catch {
assertionFailure("could not encode request: \(error)")
}
}
case .setSelectedServer(.some(let serverName)):
encoder = {
$0.append(ExtensionMessageString(serverName).rawValue)
Expand Down
29 changes: 29 additions & 0 deletions Sources/NetworkProtection/ExtensionMessage/ExtensionRequest.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
//
// ExtensionRequest.swift
//
// Copyright © 2023 DuckDuckGo. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//

import Foundation

public enum DebugCommand: Codable {
case expireRegistrationKey
case sendTestNotification
}

public enum ExtensionRequest: Codable {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The advantage of using this is two folded:

  • Much simpler: adding a case is enough to be able to receive it in the extension, there's no need to add code to encode the data.
  • Less likely to cause trouble: because the app and extension use this enum, adding a new case will present an error if the case handling is missing in either side.

case changeTunnelSetting(_ change: TunnelSettings.Change)
case debugCommand(_ command: DebugCommand)
}
115 changes: 74 additions & 41 deletions Sources/NetworkProtection/PacketTunnelProvider.swift
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,11 @@ open class PacketTunnelProvider: NEPacketTunnelProvider {
return self.protocolConfiguration.enforceRoutes || self.protocolConfiguration.includeAllNetworks
}

// MARK: - Server Selection
// MARK: - Tunnel Settings

private let settings = TunnelSettings(defaults: .standard)

let selectedServerStore = NetworkProtectionSelectedServerUserDefaultsStore()
// MARK: - Server Selection

public var lastSelectedServerInfo: NetworkProtectionServerInfo? {
didSet {
Expand All @@ -137,25 +139,6 @@ open class PacketTunnelProvider: NEPacketTunnelProvider {

private let tokenStore: NetworkProtectionTokenStore

/// This is for overriding the defaults. A `nil` value means NetP will just use the defaults.
///
private var keyValidity: TimeInterval?

private static let defaultRetryInterval: TimeInterval = .minutes(1)

/// Normally we'll retry using the default interval, but since we can override the key validity interval for testing purposes
/// we'll retry sooner if it's been overridden with values lower than the default retry interval.
///
/// In practical terms this means that if the validity interval is 15 secs, the retry will also be 15 secs instead of 1 minute.
///
private var retryInterval: TimeInterval {
guard let keyValidity = keyValidity else {
return Self.defaultRetryInterval
}

return keyValidity > Self.defaultRetryInterval ? Self.defaultRetryInterval : keyValidity
}

private func resetRegistrationKey() {
os_log("Resetting the current registration key", log: .networkProtectionKeyManagement)
keyStore.resetCurrentKeyPair()
Expand All @@ -182,27 +165,27 @@ open class PacketTunnelProvider: NEPacketTunnelProvider {
self.resetRegistrationKey()

do {
try await updateTunnelConfiguration(selectedServer: selectedServerStore.selectedServer, reassert: false)
try await updateTunnelConfiguration(selectedServer: settings.selectedServer, reassert: false)
} catch {
os_log("Rekey attempt failed. This is not an error if you're using debug Key Management options: %{public}@", log: .networkProtectionKeyManagement, type: .error, String(describing: error))
}
}

private func setKeyValidity(_ interval: TimeInterval?) {
guard keyValidity != interval else {
return
}

if let interval {
let firstExpirationDate = Date().addingTimeInterval(interval)

os_log("Setting key validity interval to %{public}@ seconds (next expiration date %{public}@)",
log: .networkProtectionKeyManagement,
String(describing: interval),
String(describing: firstExpirationDate))

settings.registrationKeyValidity = .custom(interval)
} else {
os_log("Resetting key validity interval",
log: .networkProtectionKeyManagement)

settings.registrationKeyValidity = .automatic
}

keyStore.setValidityInterval(interval)
Expand Down Expand Up @@ -387,11 +370,11 @@ open class PacketTunnelProvider: NEPacketTunnelProvider {
private func loadSelectedServer(from options: StartupOptions) {
switch options.selectedServer {
case .set(let selectedServer):
selectedServerStore.selectedServer = selectedServer
settings.selectedServer = selectedServer
case .useExisting:
break
case .reset:
selectedServerStore.selectedServer = .automatic
settings.selectedServer = .automatic
}
}

Expand Down Expand Up @@ -491,15 +474,15 @@ open class PacketTunnelProvider: NEPacketTunnelProvider {
let onDemand = options.startupMethod == .automaticOnDemand

os_log("Starting tunnel %{public}@", log: .networkProtection, options.startupMethod.debugDescription)
startTunnel(selectedServer: selectedServerStore.selectedServer, onDemand: onDemand, completionHandler: completionHandler)
startTunnel(selectedServer: settings.selectedServer, onDemand: onDemand, completionHandler: completionHandler)
}

private func startTunnel(selectedServer: SelectedNetworkProtectionServer, onDemand: Bool, completionHandler: @escaping (Error?) -> Void) {
private func startTunnel(selectedServer: TunnelSettings.SelectedServer, onDemand: Bool, completionHandler: @escaping (Error?) -> Void) {

Task {
let serverSelectionMethod: NetworkProtectionServerSelectionMethod

switch selectedServerStore.selectedServer {
switch settings.selectedServer {
case .automatic:
serverSelectionMethod = .automatic
case .endpoint(let serverName):
Expand Down Expand Up @@ -633,10 +616,10 @@ open class PacketTunnelProvider: NEPacketTunnelProvider {

// MARK: - Tunnel Configuration

public func updateTunnelConfiguration(selectedServer: SelectedNetworkProtectionServer, reassert: Bool = true) async throws {
public func updateTunnelConfiguration(selectedServer: TunnelSettings.SelectedServer, reassert: Bool = true) async throws {
let serverSelectionMethod: NetworkProtectionServerSelectionMethod

switch selectedServerStore.selectedServer {
switch settings.selectedServer {
case .automatic:
serverSelectionMethod = .automatic
case .endpoint(let serverName):
Expand Down Expand Up @@ -717,6 +700,8 @@ open class PacketTunnelProvider: NEPacketTunnelProvider {
}

switch message {
case .request(let request):
handleRequest(request)
case .expireRegistrationKey:
handleExpireRegistrationKey(completionHandler: completionHandler)
case .getLastErrorMessage:
Expand All @@ -736,7 +721,7 @@ open class PacketTunnelProvider: NEPacketTunnelProvider {
case .resetAllState:
handleResetAllState(completionHandler: completionHandler)
case .triggerTestNotification:
handleTriggerTestNotification(completionHandler: completionHandler)
handleSendTestNotification(completionHandler: completionHandler)
case .setExcludedRoutes(let excludedRoutes):
setExcludedRoutes(excludedRoutes, completionHandler: completionHandler)
case .setIncludedRoutes(let includedRoutes):
Expand All @@ -752,6 +737,54 @@ open class PacketTunnelProvider: NEPacketTunnelProvider {
}
}

// MARK: - App Requests: Handling

private func handleRequest(_ request: ExtensionRequest, completionHandler: ((Data?) -> Void)? = nil) {
switch request {
case .changeTunnelSetting(let change):
handleSettingsChange(change, completionHandler: completionHandler)
case .debugCommand(let command):
handleDebugCommand(command, completionHandler: completionHandler)
}
}

private func handleSettingsChange(_ change: TunnelSettings.Change, completionHandler: ((Data?) -> Void)? = nil) {

settings.apply(change: change)

switch change {
case .setSelectedServer(let selectedServer):
let serverSelectionMethod: NetworkProtectionServerSelectionMethod

switch selectedServer {
case .automatic:
serverSelectionMethod = .automatic
case .endpoint(let serverName):
serverSelectionMethod = .preferredServer(serverName: serverName)
}

Task {
try? await updateTunnelConfiguration(serverSelectionMethod: serverSelectionMethod)
completionHandler?(nil)
}
case .setIncludeAllNetworks,
.setEnforceRoutes,
.setExcludeLocalNetworks,
.setRegistrationKeyValidity:
// Intentional no-op, as some setting changes don't require any further operation
break
}
}

private func handleDebugCommand(_ command: DebugCommand, completionHandler: ((Data?) -> Void)? = nil) {
switch command {
case .expireRegistrationKey:
handleExpireRegistrationKey(completionHandler: completionHandler)
case .sendTestNotification:
handleSendTestNotification(completionHandler: completionHandler)
}
}

// MARK: - App Messages: Handling

private func handleExpireRegistrationKey(completionHandler: ((Data?) -> Void)? = nil) {
Expand Down Expand Up @@ -794,20 +827,20 @@ open class PacketTunnelProvider: NEPacketTunnelProvider {
private func handleSetSelectedServer(_ serverName: String?, completionHandler: ((Data?) -> Void)? = nil) {
Task {
guard let serverName else {
if case .endpoint = selectedServerStore.selectedServer {
selectedServerStore.selectedServer = .automatic
if case .endpoint = settings.selectedServer {
settings.selectedServer = .automatic
try? await updateTunnelConfiguration(serverSelectionMethod: .automatic)
}
completionHandler?(nil)
return
}

guard selectedServerStore.selectedServer.stringValue != serverName else {
guard settings.selectedServer.stringValue != serverName else {
completionHandler?(nil)
return
}

selectedServerStore.selectedServer = .endpoint(serverName)
settings.selectedServer = .endpoint(serverName)
try? await updateTunnelConfiguration(serverSelectionMethod: .preferredServer(serverName: serverName))
completionHandler?(nil)
}
Expand All @@ -830,23 +863,23 @@ open class PacketTunnelProvider: NEPacketTunnelProvider {
}
}

private func handleTriggerTestNotification(completionHandler: ((Data?) -> Void)? = nil) {
private func handleSendTestNotification(completionHandler: ((Data?) -> Void)? = nil) {
notificationsPresenter.showTestNotification()
completionHandler?(nil)
}

private func setExcludedRoutes(_ excludedRoutes: [IPAddressRange], completionHandler: ((Data?) -> Void)? = nil) {
Task {
self.excludedRoutes = excludedRoutes
try? await updateTunnelConfiguration(selectedServer: selectedServerStore.selectedServer, reassert: false)
try? await updateTunnelConfiguration(selectedServer: settings.selectedServer, reassert: false)
completionHandler?(nil)
}
}

private func setIncludedRoutes(_ includedRoutes: [IPAddressRange], completionHandler: ((Data?) -> Void)? = nil) {
Task {
self.includedRoutes = includedRoutes
try? await updateTunnelConfiguration(selectedServer: selectedServerStore.selectedServer, reassert: false)
try? await updateTunnelConfiguration(selectedServer: settings.selectedServer, reassert: false)
completionHandler?(nil)
}
}
Expand Down
Loading