Skip to content

Commit

Permalink
Reconnect V2 protocol (#278)
Browse files Browse the repository at this point in the history
  • Loading branch information
jasongin authored Aug 6, 2023
1 parent d54b31b commit 2793e16
Show file tree
Hide file tree
Showing 13 changed files with 311 additions and 121 deletions.
2 changes: 1 addition & 1 deletion cs/build/build.props
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
<ReportGeneratorVersion>4.8.13</ReportGeneratorVersion>
<SystemTextEncodingsWebPackageVersion>4.7.2</SystemTextEncodingsWebPackageVersion>
<VisualStudioValidationVersion>15.5.31</VisualStudioValidationVersion>
<DevTunnelsSshPackageVersion>3.11.16</DevTunnelsSshPackageVersion>
<DevTunnelsSshPackageVersion>3.11.21</DevTunnelsSshPackageVersion>
<XunitRunnerVisualStudioVersion>2.4.0</XunitRunnerVisualStudioVersion>
<XunitVersion>2.4.0</XunitVersion>
</PropertyGroup>
Expand Down
132 changes: 116 additions & 16 deletions cs/src/Connections/TunnelClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ public abstract class TunnelClient : TunnelConnection, ITunnelClient
{
private bool acceptLocalConnectionsForForwardedPorts = true;
private IPAddress localForwardingHostAddress = IPAddress.Loopback;
private readonly Dictionary<int, List<SecureStream>> disconnectedStreams = new();

/// <summary>
/// Creates a new instance of the <see cref="TunnelClient" /> class.
Expand Down Expand Up @@ -185,8 +186,11 @@ protected async Task StartSshSessionAsync(Stream stream, CancellationToken cance
this.SshSession.Request -= OnRequest;
}

// Enable reconnect only if connector is set as reconnect depends on it.
var clientConfig = new SshSessionConfiguration(enableReconnect: this.connector != null);
// Enable V1 reconnect only if connector is set as reconnect depends on it.
// (V2 SSH reconnect is handled by the SecureStream class.)
var clientConfig = new SshSessionConfiguration(
enableReconnect: this.connector != null &&
ConnectionProtocol == TunnelRelayTunnelClient.WebSocketSubProtocol);

if (ConnectionProtocol == TunnelRelayTunnelClient.WebSocketSubProtocolV2)
{
Expand Down Expand Up @@ -254,6 +258,61 @@ private void ConfigurePortForwardingService()
{
pfs.MessageFactory = this;
pfs.ForwardedPortConnecting += OnForwardedPortConnecting;
pfs.RemoteForwardedPorts.PortAdded += (_, e) => OnForwardedPortAdded(pfs, e);
pfs.RemoteForwardedPorts.PortUpdated += (_, e) => OnForwardedPortAdded(pfs, e);
}
}

private void OnForwardedPortAdded(PortForwardingService pfs, ForwardedPortEventArgs e)
{
var port = e.Port.RemotePort;
if (!port.HasValue)
{
return;
}

List<SecureStream>? streams;
lock (this.disconnectedStreams)
{
// If there are disconnected streams for the port, re-connect them now.
if (!this.disconnectedStreams.TryGetValue(port.Value, out streams))
{
streams = null;
}
}

if (streams?.Count > 0)
{
this.Trace.Verbose(
$"Reconnecting {streams.Count} stream(s) to forwarded port {port}");

for (int i = streams.Count; i > 0; i--)
{
Task.Run(async () =>
{
try
{
await pfs.ConnectToForwardedPortAsync(port.Value, CancellationToken.None);
this.Trace.Verbose($"Reconnected stream to forwarded port {port}");
}
catch (Exception ex)
{
this.Trace.Warning(
$"Failed to reconnect to forwarded port {port}: {ex.Message}");
lock (this.disconnectedStreams)
{
// The host is no longer accepting connections on the forwarded port?
// Dispose and clear the list of disconnected streams for the port,
// because it seems it is no longer possible to reconnect them.
while (streams.Count > 0)
{
streams[0].Dispose();
streams.RemoveAt(0);
}
}
}
});
}
}
}

Expand All @@ -278,23 +337,64 @@ protected virtual void OnForwardedPortConnecting(
e.TransformTask = EncryptChannelAsync(e.Stream);
async Task<Stream?> EncryptChannelAsync(SshStream channelStream)
{
var secureStream = new SecureStream(
e.Stream,
clientCredentials,
false,
channel.Trace.WithName(channel.Trace.Name + "." + channel.ChannelId));
secureStream.Authenticating += OnHostAuthenticating;

// Do not pass the cancellation token from the connecting event,
// because the connection will outlive the event.
await secureStream.ConnectAsync();
SecureStream? secureStream = null;

// If there's a disconnected SecureStream for the port, try to reconnect it.
// If there are multiple, pick one and the host will match by SSH session ID.
lock (this.disconnectedStreams)
{
if (this.disconnectedStreams.TryGetValue(e.Port, out var streamsList) &&
streamsList.Count > 0)
{
secureStream = streamsList[0];
streamsList.RemoveAt(0);
}
}

var trace = channel.Trace.WithName(channel.Trace.Name + "." + channel.ChannelId);
if (secureStream != null)
{
trace.Verbose($"Reconnecting encrypted stream for port {e.Port}...");
await secureStream.ReconnectAsync(channelStream);
trace.Verbose($"Reconnecting encrypted stream for port {e.Port} succeeded.");
}
else
{
secureStream = new SecureStream(
e.Stream, clientCredentials, enableReconnect: true, trace);
secureStream.Authenticating += OnHostAuthenticating;
secureStream.Disconnected += (_, _) => OnSecureStreamDisconnected(
e.Port, secureStream, trace);

// Do not pass the cancellation token from the connecting event,
// because the connection will outlive the event.
await secureStream.ConnectAsync();
}

return secureStream;
}
}

this.ForwardedPortConnecting?.Invoke(this, e);
}

private void OnSecureStreamDisconnected(int port, SecureStream secureStream, TraceSource trace)
{
trace.Verbose($"Encrypted stream for port {port} disconnected.");

lock (this.disconnectedStreams)
{
if (this.disconnectedStreams.TryGetValue(port, out var streamsList))
{
streamsList.Add(secureStream);
}
else
{
this.disconnectedStreams.Add(port, new List<SecureStream> { secureStream });
}
}
}

private void OnHostAuthenticating(object? sender, SshAuthenticatingEventArgs e)
{
// If this method returns without assigning e.AuthenticationTask, the auth fails.
Expand Down Expand Up @@ -325,14 +425,14 @@ private void OnHostAuthenticating(object? sender, SshAuthenticatingEventArgs e)
}
else if (Tunnel != null && ManagementClient != null)
{
this.Trace.Verbose("Host public key verificiation failed. Refreshing tunnel.");
this.Trace.Verbose("Host public key verification failed. Refreshing tunnel.");
this.Trace.Verbose("Host key: " + hostKey);
this.Trace.Verbose("Expected key(s): " + string.Join(", ", this.HostPublicKeys));
e.AuthenticationTask = RefreshTunnelAndAuthenticateHostAsync(hostKey, DisposeToken);
}
else
{
this.Trace.Error("Host public key verificiation failed.");
this.Trace.Error("Host public key verification failed.");
this.Trace.Verbose("Host key: " + hostKey);
this.Trace.Verbose("Expected key(s): " + string.Join(", ", this.HostPublicKeys));
}
Expand All @@ -353,7 +453,7 @@ private void OnHostAuthenticating(object? sender, SshAuthenticatingEventArgs e)

if (Tunnel == null)
{
this.Trace.Warning("Host public key verificiation failed. Tunnel is not found.");
this.Trace.Warning("Host public key verification failed. Tunnel is not found.");
return null;
}

Expand All @@ -371,7 +471,7 @@ private void OnHostAuthenticating(object? sender, SshAuthenticatingEventArgs e)
return new ClaimsPrincipal();
}

this.Trace.Error("Host public key verificiation failed.");
this.Trace.Error("Host public key verification failed.");
this.Trace.Verbose("Host key: " + hostKey);
this.Trace.Verbose("Expected key(s): " + string.Join(", ", this.HostPublicKeys));
return null;
Expand Down
4 changes: 3 additions & 1 deletion cs/src/Connections/TunnelConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ protected set
if (value != previousConnectionStatus)
{
// If there were temporary connection issue, DisconnectException may be not null.
// Since we have successfuly connected after all, clean it up.
// Since we have successfully connected after all, clean it up.
if (value == ConnectionStatus.Connected)
{
DisconnectException = null;
Expand Down Expand Up @@ -392,6 +392,8 @@ protected void StartReconnectTaskIfNotDisposed()
{
lock (DisposeLock)
{
ConnectionStatus = ConnectionStatus.Disconnected;

if (!this.disposeCts.IsCancellationRequested &&
this.reconnectTask == null &&
this.connector != null) // The connector may be null if the tunnel client/host was created directly from a stream.
Expand Down
23 changes: 12 additions & 11 deletions cs/src/Connections/TunnelRelayTunnelHost.cs
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ protected override void OnForwardedPortConnecting(
var secureStream = new SecureStream(
e.Stream,
serverCredentials,
null,
this.reconnectableSessions,
channel.Trace.WithName(channel.Trace.Name + "." + channel.ChannelId));

// The client was already authenticated by the relay.
Expand Down Expand Up @@ -450,7 +450,7 @@ private async Task ConnectAndRunClientSessionAsync(Stream stream, CancellationTo
session.Reconnected += OnSshClientReconnected;
session.Request += OnClientSessionRequest;
session.ChannelOpening += OnSshChannelOpening;
session.Closed += Session_Closed;
session.Closed += OnClientSessionClosed;

try
{
Expand All @@ -477,7 +477,7 @@ private async Task ConnectAndRunClientSessionAsync(Stream stream, CancellationTo
session.ClientAuthenticated -= OnSshClientAuthenticated;
session.Reconnected -= OnSshClientReconnected;
session.ChannelOpening -= OnSshChannelOpening;
session.Closed -= Session_Closed;
session.Closed -= OnClientSessionClosed;

RemoveClientSshSession(session);
}
Expand All @@ -488,24 +488,25 @@ private async Task ConnectAndRunClientSessionAsync(Stream stream, CancellationTo
throw;
}

void Session_Closed(object? sender, SshSessionClosedEventArgs e)
void OnClientSessionClosed(object? sender, SshSessionClosedEventArgs e)
{
// Reconnecting client session may cause the new session close with 'None' reason and null exception.
TraceSource trace = ((SshSession)sender!).Trace;

// Reconnecting client session may cause the new session to close with 'None' reason.
if (cancellation.IsCancellationRequested)
{
Trace.WithName("ClientSSH").Verbose("Session cancelled.");
trace.Verbose("Session cancelled.");
}
else if (e.Reason == SshDisconnectReason.ByApplication)
{
Trace.WithName("ClientSSH").Verbose("Session closed.");

trace.Verbose("Session closed.");
}
else if (e.Reason != SshDisconnectReason.None || e.Exception != null)
else if (e.Reason != SshDisconnectReason.None)
{
Trace.TraceEvent(
trace.TraceEvent(
TraceEventType.Error,
0,
"Client ssh session closed unexpectely due to {0}, \"{1}\"\n{2}",
"Session closed unexpectedly due to {0}, \"{1}\"\n{2}",
e.Reason,
e.Message,
e.Exception);
Expand Down
24 changes: 16 additions & 8 deletions cs/test/TunnelsSDK.Test/TunnelHostAndClientTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -674,25 +674,26 @@ await managementClient.CreateTunnelPortAsync(
Assert.Equal(this.localPortsFixture.Port, await clientPortAdded.Task);

// Reconnect the tunnel client
var relayClientDisconnected = new TaskCompletionSource();
var relayClientReconnected = new TaskCompletionSource();
relayClient.ConnectionStatusChanged += (sender, args) =>
{
switch (args.Status)
{
case ConnectionStatus.Connected:
relayClientReconnected.TrySetResult();
break;
case ConnectionStatus.Disconnected:
relayClientReconnected.TrySetException(args.DisconnectException ?? new Exception("Unexpected disconnection"));
relayClientDisconnected.TrySetResult();
break;
case ConnectionStatus.Connected:
relayClientReconnected.TrySetResult();
break;
}
};

await clientSshStream.Channel.CloseAsync();

await relayClientReconnected.Task;
await relayClientDisconnected.Task.WithTimeout(Timeout);
await relayClientReconnected.Task.WithTimeout(Timeout);

clientPortAdded = new TaskCompletionSource<int?>();
await managementClient.CreateTunnelPortAsync(
Expand Down Expand Up @@ -761,6 +762,7 @@ await managementClient.CreateTunnelPortAsync(
Assert.Equal(this.localPortsFixture.Port, await clientPortAdded.Task);

// Expect disconnection
bool reconnectStarted = false;
var relayClientDisconnected = new TaskCompletionSource<Exception>();
relayClient.ConnectionStatusChanged += (sender, args) =>
{
Expand All @@ -770,10 +772,16 @@ await managementClient.CreateTunnelPortAsync(
relayClientDisconnected.TrySetException(new Exception("Unexpected reconnection"));
break;
case ConnectionStatus.Disconnected:
relayClientDisconnected.TrySetResult(args.DisconnectException);
case ConnectionStatus.Connecting:
reconnectStarted = true;
break;
case ConnectionStatus.Disconnected:
if (reconnectStarted)
{
relayClientDisconnected.TrySetResult(args.DisconnectException);
}
break;
}
};

Expand Down
16 changes: 8 additions & 8 deletions ts/package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions ts/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
"build-pack-publish": "npm run build && npm run pack && npm run publish"
},
"dependencies": {
"@microsoft/dev-tunnels-ssh": "^3.11.16",
"@microsoft/dev-tunnels-ssh-tcp": "^3.11.16",
"@microsoft/dev-tunnels-ssh": "^3.11.21",
"@microsoft/dev-tunnels-ssh-tcp": "^3.11.21",
"await-semaphore": "^0.1.3",
"axios": "^0.21.1",
"buffer": "^5.2.1",
Expand Down
Loading

0 comments on commit 2793e16

Please sign in to comment.