Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Insert SSH 3.11.41, add more unit tests and logging to tests #473

Merged
merged 1 commit into from
Aug 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cs/build/build.props
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
<ReportGeneratorVersion>4.8.13</ReportGeneratorVersion>
<SystemTextEncodingsWebPackageVersion>4.7.2</SystemTextEncodingsWebPackageVersion>
<VisualStudioValidationVersion>15.5.31</VisualStudioValidationVersion>
<DevTunnelsSshPackageVersion>3.11.36</DevTunnelsSshPackageVersion>
<DevTunnelsSshPackageVersion>3.11.41</DevTunnelsSshPackageVersion>
<XunitRunnerVisualStudioVersion>2.4.0</XunitRunnerVisualStudioVersion>
<XunitVersion>2.4.0</XunitVersion>
</PropertyGroup>
Expand Down
163 changes: 163 additions & 0 deletions cs/test/TunnelsSDK.Test/TcpListeners.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
using System.Diagnostics;
using System.Globalization;
using System.Net;
using System.Net.Sockets;
using System.Text;
using Microsoft.DevTunnels.Management;

namespace Microsoft.DevTunnels.Test;
public sealed class TcpListeners : IAsyncDisposable
{
private const int MaxAttempts = 10;

private readonly TraceSource trace;
private readonly CancellationTokenSource cts = new();
private readonly List<TcpListener> listeners = new();
private readonly List<Task> listenerTasks = new();

public TcpListeners(int count, TraceSource trace)
{
Requires.Argument(count > 0, nameof(count), "Count must be greater than 0.");
this.trace = trace.WithName("TcpListeners");
Ports = new int[count];
for (int index = 0; index < count; index++)
{
TcpListener listener = null;
int port;
int attempt = 0;
while (true)
{
try
{
port = TcpUtils.GetAvailableTcpPort(canReuseAddress: false);
listener = new TcpListener(IPAddress.Loopback, port);
listener.Start();
break;
}
catch (SocketException ex)
{
listener?.Stop();
if (++attempt >= MaxAttempts)
{
throw new InvalidOperationException("Failed to find available port", ex);
}
}
catch
{
listener?.Stop();
throw;
}
}

Ports[index] = port;
this.listeners.Add(listener);
this.listenerTasks.Add(AcceptConnectionsAsync(listener, port));
}

this.trace.Info("Listening on ports: {0}", string.Join(", ", Ports));
}

public int Port { get; }

public int[] Ports { get; }

public async ValueTask DisposeAsync()
{
cts.Cancel();
StopListeners();
await Task.WhenAll(this.listenerTasks);
this.listenerTasks.Clear();
}

private async Task AcceptConnectionsAsync(TcpListener listener, int port)
{
var tasks = new List<Task>();
TaskCompletionSource allTasksCompleted = null;
try
{
while (!cts.IsCancellationRequested)
{
var tcpClient = await listener.AcceptTcpClientAsync(cts.Token);
var task = Task.Run(() => RunClientAsync(tcpClient, port));
lock (tasks)
{
tasks.Add(task);
}

_ = task.ContinueWith(
(t) =>
{
lock (tasks)
{
tasks.Remove(t);
if (tasks.Count == 0)
{
allTasksCompleted?.TrySetResult();
}
}
});
}
}
catch (OperationCanceledException) when (this.cts.IsCancellationRequested)
{
// Ignore
}
catch (SocketException) when (this.cts.IsCancellationRequested)
{
// Ignore
}
catch (Exception ex)
{
this.trace.Error($"Error accepting TCP client for port {port}: ${ex}");
}

lock (tasks)
{
if (tasks.Count == 0)
{
return;
}

allTasksCompleted = new TaskCompletionSource();
}

await allTasksCompleted.Task;
}

private async Task RunClientAsync(TcpClient tcpClient, int port)
{
try
{
using var disposable = tcpClient;

this.trace.Info($"Accepted client connection to TCP port {port}");
await using var stream = tcpClient.GetStream();

var bytes = Encoding.UTF8.GetBytes(port.ToString(CultureInfo.InvariantCulture));
await stream.WriteAsync(bytes);

}
catch (OperationCanceledException) when (this.cts.IsCancellationRequested)
{
// Ignore
}
catch (SocketException) when (this.cts.IsCancellationRequested)
{
// Ignore
}
catch (Exception ex)
{
this.trace.Error($"Error handling TCP client on listener running on port {port}: ${ex}");
}
}

private void StopListeners()
{
foreach (var listener in this.listeners)
{
listener.Stop();
}

this.listeners.Clear();
}
}
19 changes: 17 additions & 2 deletions cs/test/TunnelsSDK.Test/TcpUtils.cs
Original file line number Diff line number Diff line change
@@ -1,17 +1,32 @@
using System.Net;
using System.Globalization;
using System.Net;
using System.Net.Sockets;
using System.Text;

namespace Microsoft.DevTunnels.Test;

internal static class TcpUtils
{
public static int GetAvailableTcpPort()
public static int GetAvailableTcpPort(bool canReuseAddress = true)
{
// Get any available local tcp port
var l = new TcpListener(IPAddress.Loopback, 0);
if (!canReuseAddress)
{
l.Server.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.ReuseAddress, false);
}

l.Start();
int port = ((IPEndPoint)l.LocalEndpoint).Port;
l.Stop();
return port;
}

public static async Task<int> ReadIntToEndAsync(this Stream stream, CancellationToken cancellation)
{
var buffer = new byte[1024];
var length = await stream.ReadAsync(buffer, cancellation);
var text = Encoding.UTF8.GetString(buffer, 0, length);
return int.Parse(text, CultureInfo.InvariantCulture);
}
}
77 changes: 68 additions & 9 deletions cs/test/TunnelsSDK.Test/TunnelHostAndClientTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
using Microsoft.DevTunnels.Test.Mocks;
using Nerdbank.Streams;
using Xunit;
using Xunit.Abstractions;
using Xunit.Sdk;

namespace Microsoft.DevTunnels.Test;
Expand All @@ -24,27 +25,24 @@ public class TunnelHostAndClientTests : IClassFixture<LocalPortsFixture>
private const string MockHostRelayUri = "ws://localhost/tunnel/host";
private const string MockClientRelayUri = "ws://localhost/tunnel/client";

private static readonly TraceSource TestTS =
private readonly TraceSource TestTS =
new TraceSource(nameof(TunnelHostAndClientTests));
private static readonly TimeSpan Timeout = Debugger.IsAttached ? TimeSpan.FromHours(1) : TimeSpan.FromSeconds(10);
private static readonly TimeSpan Timeout = Debugger.IsAttached ? TimeSpan.FromHours(1) : TimeSpan.FromSeconds(20);
private readonly CancellationToken TimeoutToken = new CancellationTokenSource(Timeout).Token;

private Stream serverStream;
private Stream clientStream;
private readonly IKeyPair serverSshKey;
private readonly LocalPortsFixture localPortsFixture;

static TunnelHostAndClientTests()
{
// Enabling tracing to debug console.
TestTS.Switch.Level = SourceLevels.All;
}

public TunnelHostAndClientTests(LocalPortsFixture localPortsFixture)
public TunnelHostAndClientTests(LocalPortsFixture localPortsFixture, ITestOutputHelper output)
{
(this.serverStream, this.clientStream) = FullDuplexStream.CreatePair();
this.serverSshKey = SshAlgorithms.PublicKey.ECDsaSha2Nistp384.GenerateKeyPair();
this.localPortsFixture = localPortsFixture;

TestTS.Switch.Level = SourceLevels.All;
TestTS.Listeners.Add(new XunitTraceListener(output));
}

private Tunnel CreateRelayTunnel(bool addClientEndpoint = true) => CreateRelayTunnel(addClientEndpoint, Enumerable.Empty<int>());
Expand Down Expand Up @@ -1453,6 +1451,67 @@ public async Task ConnectRelayHostThenConnectRelayClientToForwardedPortStream()
using var sshStream = await clientSshSession.ConnectToForwardedPortAsync(port, TimeoutToken);
}

[Fact]
public async Task ConnectRelayHostThenConnectRelayClientsToForwardedPortStreamsThenSendData()
{
const int PortCount = 2;
const int ClientConnectionCount = 50;

var managementClient = new MockTunnelManagementClient
{
HostRelayUri = MockHostRelayUri,
ClientRelayUri = MockClientRelayUri,
};

var relayHost = new TunnelRelayTunnelHost(managementClient, TestTS);

await using var listeners = new TcpListeners(PortCount, TestTS);
var tunnel = CreateRelayTunnel(false, listeners.Ports);

using var multiChannelStream = await ConnectRelayHostAsync(relayHost, tunnel);
Assert.Equal(ConnectionStatus.Connected, relayHost.ConnectionStatus);

var clientStreamFactory = new MockTunnelRelayStreamFactory(TunnelRelayConnection.ClientWebSocketSubProtocol)
{
StreamFactory = async (accessToken) =>
{
return await multiChannelStream.OpenStreamAsync(TunnelRelayTunnelHost.ClientStreamChannelType);
},
};

for (int clientConnection = 0; clientConnection < ClientConnectionCount; clientConnection++)
{
foreach (var port in listeners.Ports)
{
TestTS.TraceInformation("Connecting client #{0} to port {1}", clientConnection, port);

// Create and connect tunnel client
await using var relayClient = new TunnelRelayTunnelClient(TestTS)
{
AcceptLocalConnectionsForForwardedPorts = false,
StreamFactory = clientStreamFactory,
};

Assert.Equal(ConnectionStatus.None, relayClient.ConnectionStatus);

await relayClient.ConnectAsync(tunnel, TimeoutToken);
Assert.Equal(ConnectionStatus.Connected, relayClient.ConnectionStatus);

await relayClient.WaitForForwardedPortAsync(port, TimeoutToken);
using var stream = await relayClient.ConnectToForwardedPortAsync(port, TimeoutToken);

var actualPort = await stream.ReadIntToEndAsync(TimeoutToken);
if (port != actualPort)
{
// Debugger.Launch();
TestTS.TraceInformation("Client #{0} received unexpected port {1} instead of {2}", clientConnection, actualPort, port);
}

Assert.Equal(port, actualPort);
}
}
}

[Fact]
public async Task ConnectRelayHostThenConnectRelayClientToDifferentPort_Fails()
{
Expand Down
32 changes: 32 additions & 0 deletions cs/test/TunnelsSDK.Test/XunitTraceListener.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
using System.Diagnostics;
using System.Text;
using Xunit.Abstractions;

namespace Microsoft.DevTunnels.Test;

internal sealed class XunitTraceListener : TraceListener
{
private readonly ITestOutputHelper output;
private readonly StringBuilder currentLine = new ();
private readonly DateTimeOffset loggingStart = DateTimeOffset.UtcNow;
private DateTimeOffset? messageStart;

public XunitTraceListener(ITestOutputHelper output)
{
this.output = output;
}

public override void Write(string message)
{
this.messageStart ??= DateTimeOffset.UtcNow;
this.currentLine.Append(message);
}

public override void WriteLine(string message)
{
var messageTime = (this.messageStart ?? DateTimeOffset.UtcNow) - this.loggingStart;
this.output.WriteLine($"{messageTime} {this.currentLine}{message}");
this.currentLine.Clear();
this.messageStart = null;
}
}
Loading