Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure atomicity between (re)connection attempts #5273

Merged
merged 1 commit into from
Oct 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 4 additions & 18 deletions ios/PacketTunnelCore/Actor/Actor+ConnectionMonitoring.swift
Original file line number Diff line number Diff line change
Expand Up @@ -47,28 +47,14 @@ extension PacketTunnelActor {
}
}

/// Increment connection attempt counter and reconnect the tunnel.
/// Tell the tunnel to reconnect providing the correct reason to ensure that the attempt counter is incremented before reconnect.
private func onHandleConnectionRecovery() async {
switch state {
case var .connecting(connState):
connState.incrementAttemptCount()
state = .connecting(connState)

case var .reconnecting(connState):
connState.incrementAttemptCount()
state = .reconnecting(connState)

case var .connected(connState):
connState.incrementAttemptCount()
state = .connected(connState)
case .connecting, .reconnecting, .connected:
commandChannel.send(.reconnect(.random, reason: .connectionLoss))

case .initial, .disconnected, .disconnecting, .error:
// Explicit return to prevent reconnecting the tunnel.
return
break
}

// Tunnel monitor should already be paused at this point so don't stop it to avoid a reset of its internal
// counters.
commandChannel.send(.reconnect(.random, stopTunnelMonitor: false))
}
}
46 changes: 35 additions & 11 deletions ios/PacketTunnelCore/Actor/Actor.swift
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,8 @@ public actor PacketTunnelActor {
case .stop:
await stop()

case let .reconnect(nextRelay, stopTunnelMonitor):
await reconnect(to: nextRelay, shouldStopTunnelMonitor: stopTunnelMonitor)
case let .reconnect(nextRelay, reason):
await reconnect(to: nextRelay, reason: reason)

case let .error(reason):
await setErrorStateInternal(with: reason)
Expand Down Expand Up @@ -173,16 +173,22 @@ extension PacketTunnelActor {

- Parameters:
- nextRelay: next relay to connect to
- shouldStopTunnelMonitor: whether tunnel monitor should be stopped
- reason: reason for reconnect
*/
private func reconnect(to nextRelay: NextRelay, shouldStopTunnelMonitor: Bool) async {
private func reconnect(to nextRelay: NextRelay, reason: ReconnectReason) async {
do {
switch state {
case .connecting, .connected, .reconnecting, .error:
if shouldStopTunnelMonitor {
switch reason {
case .connectionLoss:
// Tunnel monitor is already paused at this point. Avoid calling stop() to prevent the reset of
// internal state
break
case .userInitiated:
tunnelMonitor.stop()
}
try await tryStart(nextRelay: nextRelay)

try await tryStart(nextRelay: nextRelay, reason: reason)

case .disconnected, .disconnecting, .initial:
break
Expand All @@ -205,12 +211,14 @@ extension PacketTunnelActor {
- Start tunnel monitor.
- Reactivate default path observation (disabled when configuring tunnel adapter)

- Parameter nextRelay: which relay should be selected next.
- Parameters:
- nextRelay: which relay should be selected next.
- reason: reason for reconnect
*/
private func tryStart(nextRelay: NextRelay = .random) async throws {
private func tryStart(nextRelay: NextRelay = .random, reason: ReconnectReason = .userInitiated) async throws {
let settings: Settings = try settingsReader.read()

guard let connectionState = try makeConnectionState(nextRelay: nextRelay, settings: settings),
guard let connectionState = try makeConnectionState(nextRelay: nextRelay, settings: settings, reason: reason),
let targetState = state.targetStateForReconnect else { return }

let activeKey: PrivateKey
Expand Down Expand Up @@ -261,10 +269,15 @@ extension PacketTunnelActor {
- Parameters:
- nextRelay: relay preference that should be used when selecting next relay.
- settings: current settings
- reason: reason for reconnect

- Returns: New connection state or `nil` if current state is at or past `.disconnecting` phase.
*/
private func makeConnectionState(nextRelay: NextRelay, settings: Settings) throws -> ConnectionState? {
private func makeConnectionState(
nextRelay: NextRelay,
settings: Settings,
reason: ReconnectReason
) throws -> ConnectionState? {
let relayConstraints = settings.relayConstraints
let privateKey = settings.privateKey

Expand All @@ -284,7 +297,18 @@ extension PacketTunnelActor {
connectionAttemptCount: 0
)

case var .connecting(connState), var .connected(connState), var .reconnecting(connState):
case var .connecting(connState), var .reconnecting(connState):
switch reason {
case .connectionLoss:
// Increment attempt counter when reconnection is requested due to connectivity loss.
connState.incrementAttemptCount()
case .userInitiated:
break
}
// Explicit fallthrough
fallthrough

case var .connected(connState):
connState.selectedRelay = try selectRelay(
nextRelay: nextRelay,
relayConstraints: relayConstraints,
Expand Down
4 changes: 1 addition & 3 deletions ios/PacketTunnelCore/Actor/Command.swift
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,7 @@ enum Command {
case stop

/// Reconnect tunnel.
/// `stopTunnelMonitor = false` is only used when tunnel monitor is paused in response to connectivity loss and shouldn't be stopped explicitly,
/// as this would reset its internal counters.
case reconnect(NextRelay, stopTunnelMonitor: Bool = true)
case reconnect(NextRelay, reason: ReconnectReason = .userInitiated)

/// Enter blocked state.
case error(BlockedStateReason)
Expand Down
10 changes: 10 additions & 0 deletions ios/PacketTunnelCore/Actor/State.swift
Original file line number Diff line number Diff line change
Expand Up @@ -212,3 +212,13 @@ public enum NextRelay: Equatable, Codable {
/// Use pre-selected relay.
case preSelected(SelectedRelay)
}

/// Describes the reason for reconnection request.
public enum ReconnectReason {
/// Initiated by user.
case userInitiated

/// Initiated by tunnel monitor due to loss of connectivity.
/// Actor will increment the connection attempt counter before picking next relay.
case connectionLoss
}
94 changes: 94 additions & 0 deletions ios/PacketTunnelCoreTests/ActorTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -99,4 +99,98 @@ final class ActorTests: XCTestCase {

await fulfillment(of: allExpectations, timeout: 1, enforceOrder: true)
}

/**
Each subsequent connection attempt should produce a single change to `state` containing the incremented attempt counter and new relay.

.connecting (attempt: 0) → .connecting (attempt: 1) → .connecting (attempt: 2) → ...
*/
func testConnectionAttemptTransition() async throws {
let tunnelMonitor = TunnelMonitorStub { _, _ in }
let actor = PacketTunnelActor.mock(tunnelMonitor: tunnelMonitor)
let connectingStateExpectation = expectation(description: "Expect connecting state")
connectingStateExpectation.expectedFulfillmentCount = 5

var nextAttemptCount: UInt = 0
stateSink = await actor.$state
.receive(on: DispatchQueue.main)
.sink { newState in
switch newState {
case .initial:
break

case let .connecting(connState):
XCTAssertEqual(connState.connectionAttemptCount, nextAttemptCount)
nextAttemptCount += 1
connectingStateExpectation.fulfill()

if nextAttemptCount < connectingStateExpectation.expectedFulfillmentCount {
tunnelMonitor.dispatch(.connectionLost, after: .milliseconds(10))
}

default:
XCTFail("Received invalid state: \(newState.name).")
}
}

self.actor = actor

actor.start(options: StartOptions(launchSource: .app))

await fulfillment(of: [connectingStateExpectation], timeout: 1)
}

/**
Each subsequent re-connection attempt should produce a single change to `state` containing the incremented attempt counter and new relay.

.reconnecting (attempt: 0) → .reconnecting (attempt: 1) → .reconnecting (attempt: 2) → ...
*/
func testReconnectionAttemptTransition() async throws {
let tunnelMonitor = TunnelMonitorStub { _, _ in }
let actor = PacketTunnelActor.mock(tunnelMonitor: tunnelMonitor)
let connectingStateExpectation = expectation(description: "Expect connecting state")
let connectedStateExpectation = expectation(description: "Expect connected state")
let reconnectingStateExpectation = expectation(description: "Expect reconnecting state")
reconnectingStateExpectation.expectedFulfillmentCount = 5

var nextAttemptCount: UInt = 0
stateSink = await actor.$state
.receive(on: DispatchQueue.main)
.sink { newState in
switch newState {
case .initial:
break

case .connecting:
connectingStateExpectation.fulfill()
tunnelMonitor.dispatch(.connectionEstablished, after: .milliseconds(10))

case .connected:
connectedStateExpectation.fulfill()
tunnelMonitor.dispatch(.connectionLost, after: .milliseconds(10))

case let .reconnecting(connState):
XCTAssertEqual(connState.connectionAttemptCount, nextAttemptCount)
nextAttemptCount += 1
reconnectingStateExpectation.fulfill()

if nextAttemptCount < reconnectingStateExpectation.expectedFulfillmentCount {
tunnelMonitor.dispatch(.connectionLost, after: .milliseconds(10))
}

default:
XCTFail("Received invalid state: \(newState.name).")
}
}

self.actor = actor

actor.start(options: StartOptions(launchSource: .app))

await fulfillment(
of: [connectingStateExpectation, connectedStateExpectation, reconnectingStateExpectation],
timeout: 1,
enforceOrder: true
)
}
}
2 changes: 1 addition & 1 deletion ios/PacketTunnelCoreTests/Mocks/TunnelMonitorStub.swift
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ class TunnelMonitorStub: TunnelMonitorProtocol {

func onSleep() {}

private func dispatch(_ event: TunnelMonitorEvent, after delay: DispatchTimeInterval = .never) {
func dispatch(_ event: TunnelMonitorEvent, after delay: DispatchTimeInterval = .never) {
if case .never = delay {
onEvent?(event)
} else {
Expand Down
Loading