From 414662027cdbf9c06c97aab61a528bc731fded66 Mon Sep 17 00:00:00 2001 From: Ilya Biryukov Date: Thu, 15 Aug 2024 17:15:33 -0700 Subject: [PATCH] Insert SSH 3.11.41, add more unit tests and logging to tests Fix for https://github.com/devdiv-microsoft/basis-planning/issues/1618 Insert SSH 3.11.41 that has the fix. Add a unit test for client connecting to host when the tunnel has multiple port. Add more logging to unit tests. --- cs/build/build.props | 2 +- cs/test/TunnelsSDK.Test/TcpListeners.cs | 163 ++++++++++++++++++ cs/test/TunnelsSDK.Test/TcpUtils.cs | 19 +- .../TunnelHostAndClientTests.cs | 77 ++++++++- cs/test/TunnelsSDK.Test/XunitTraceListener.cs | 32 ++++ 5 files changed, 281 insertions(+), 12 deletions(-) create mode 100644 cs/test/TunnelsSDK.Test/TcpListeners.cs create mode 100644 cs/test/TunnelsSDK.Test/XunitTraceListener.cs diff --git a/cs/build/build.props b/cs/build/build.props index 67d73351..12744ef0 100644 --- a/cs/build/build.props +++ b/cs/build/build.props @@ -48,7 +48,7 @@ 4.8.13 4.7.2 15.5.31 - 3.11.36 + 3.11.41 2.4.0 2.4.0 diff --git a/cs/test/TunnelsSDK.Test/TcpListeners.cs b/cs/test/TunnelsSDK.Test/TcpListeners.cs new file mode 100644 index 00000000..c638484b --- /dev/null +++ b/cs/test/TunnelsSDK.Test/TcpListeners.cs @@ -0,0 +1,163 @@ +using System.Diagnostics; +using System.Globalization; +using System.Net; +using System.Net.Sockets; +using System.Text; +using Microsoft.DevTunnels.Management; + +namespace Microsoft.DevTunnels.Test; +public sealed class TcpListeners : IAsyncDisposable +{ + private const int MaxAttempts = 10; + + private readonly TraceSource trace; + private readonly CancellationTokenSource cts = new(); + private readonly List listeners = new(); + private readonly List listenerTasks = new(); + + public TcpListeners(int count, TraceSource trace) + { + Requires.Argument(count > 0, nameof(count), "Count must be greater than 0."); + this.trace = trace.WithName("TcpListeners"); + Ports = new int[count]; + for (int index = 0; index < count; index++) + { + TcpListener listener = null; + int port; + int attempt = 0; + while (true) + { + try + { + port = TcpUtils.GetAvailableTcpPort(canReuseAddress: false); + listener = new TcpListener(IPAddress.Loopback, port); + listener.Start(); + break; + } + catch (SocketException ex) + { + listener?.Stop(); + if (++attempt >= MaxAttempts) + { + throw new InvalidOperationException("Failed to find available port", ex); + } + } + catch + { + listener?.Stop(); + throw; + } + } + + Ports[index] = port; + this.listeners.Add(listener); + this.listenerTasks.Add(AcceptConnectionsAsync(listener, port)); + } + + this.trace.Info("Listening on ports: {0}", string.Join(", ", Ports)); + } + + public int Port { get; } + + public int[] Ports { get; } + + public async ValueTask DisposeAsync() + { + cts.Cancel(); + StopListeners(); + await Task.WhenAll(this.listenerTasks); + this.listenerTasks.Clear(); + } + + private async Task AcceptConnectionsAsync(TcpListener listener, int port) + { + var tasks = new List(); + TaskCompletionSource allTasksCompleted = null; + try + { + while (!cts.IsCancellationRequested) + { + var tcpClient = await listener.AcceptTcpClientAsync(cts.Token); + var task = Task.Run(() => RunClientAsync(tcpClient, port)); + lock (tasks) + { + tasks.Add(task); + } + + _ = task.ContinueWith( + (t) => + { + lock (tasks) + { + tasks.Remove(t); + if (tasks.Count == 0) + { + allTasksCompleted?.TrySetResult(); + } + } + }); + } + } + catch (OperationCanceledException) when (this.cts.IsCancellationRequested) + { + // Ignore + } + catch (SocketException) when (this.cts.IsCancellationRequested) + { + // Ignore + } + catch (Exception ex) + { + this.trace.Error($"Error accepting TCP client for port {port}: ${ex}"); + } + + lock (tasks) + { + if (tasks.Count == 0) + { + return; + } + + allTasksCompleted = new TaskCompletionSource(); + } + + await allTasksCompleted.Task; + } + + private async Task RunClientAsync(TcpClient tcpClient, int port) + { + try + { + using var disposable = tcpClient; + + this.trace.Info($"Accepted client connection to TCP port {port}"); + await using var stream = tcpClient.GetStream(); + + var bytes = Encoding.UTF8.GetBytes(port.ToString(CultureInfo.InvariantCulture)); + await stream.WriteAsync(bytes); + + } + catch (OperationCanceledException) when (this.cts.IsCancellationRequested) + { + // Ignore + } + catch (SocketException) when (this.cts.IsCancellationRequested) + { + // Ignore + } + catch (Exception ex) + { + this.trace.Error($"Error handling TCP client on listener running on port {port}: ${ex}"); + } + } + + private void StopListeners() + { + foreach (var listener in this.listeners) + { + listener.Stop(); + } + + this.listeners.Clear(); + } +} diff --git a/cs/test/TunnelsSDK.Test/TcpUtils.cs b/cs/test/TunnelsSDK.Test/TcpUtils.cs index f02e82c8..79eb5942 100644 --- a/cs/test/TunnelsSDK.Test/TcpUtils.cs +++ b/cs/test/TunnelsSDK.Test/TcpUtils.cs @@ -1,17 +1,32 @@ -using System.Net; +using System.Globalization; +using System.Net; using System.Net.Sockets; +using System.Text; namespace Microsoft.DevTunnels.Test; internal static class TcpUtils { - public static int GetAvailableTcpPort() + public static int GetAvailableTcpPort(bool canReuseAddress = true) { // Get any available local tcp port var l = new TcpListener(IPAddress.Loopback, 0); + if (!canReuseAddress) + { + l.Server.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.ReuseAddress, false); + } + l.Start(); int port = ((IPEndPoint)l.LocalEndpoint).Port; l.Stop(); return port; } + + public static async Task ReadIntToEndAsync(this Stream stream, CancellationToken cancellation) + { + var buffer = new byte[1024]; + var length = await stream.ReadAsync(buffer, cancellation); + var text = Encoding.UTF8.GetString(buffer, 0, length); + return int.Parse(text, CultureInfo.InvariantCulture); + } } diff --git a/cs/test/TunnelsSDK.Test/TunnelHostAndClientTests.cs b/cs/test/TunnelsSDK.Test/TunnelHostAndClientTests.cs index cd7a25bd..855241bf 100644 --- a/cs/test/TunnelsSDK.Test/TunnelHostAndClientTests.cs +++ b/cs/test/TunnelsSDK.Test/TunnelHostAndClientTests.cs @@ -13,6 +13,7 @@ using Microsoft.DevTunnels.Test.Mocks; using Nerdbank.Streams; using Xunit; +using Xunit.Abstractions; using Xunit.Sdk; namespace Microsoft.DevTunnels.Test; @@ -24,9 +25,9 @@ public class TunnelHostAndClientTests : IClassFixture private const string MockHostRelayUri = "ws://localhost/tunnel/host"; private const string MockClientRelayUri = "ws://localhost/tunnel/client"; - private static readonly TraceSource TestTS = + private readonly TraceSource TestTS = new TraceSource(nameof(TunnelHostAndClientTests)); - private static readonly TimeSpan Timeout = Debugger.IsAttached ? TimeSpan.FromHours(1) : TimeSpan.FromSeconds(10); + private static readonly TimeSpan Timeout = Debugger.IsAttached ? TimeSpan.FromHours(1) : TimeSpan.FromSeconds(20); private readonly CancellationToken TimeoutToken = new CancellationTokenSource(Timeout).Token; private Stream serverStream; @@ -34,17 +35,14 @@ public class TunnelHostAndClientTests : IClassFixture private readonly IKeyPair serverSshKey; private readonly LocalPortsFixture localPortsFixture; - static TunnelHostAndClientTests() - { - // Enabling tracing to debug console. - TestTS.Switch.Level = SourceLevels.All; - } - - public TunnelHostAndClientTests(LocalPortsFixture localPortsFixture) + public TunnelHostAndClientTests(LocalPortsFixture localPortsFixture, ITestOutputHelper output) { (this.serverStream, this.clientStream) = FullDuplexStream.CreatePair(); this.serverSshKey = SshAlgorithms.PublicKey.ECDsaSha2Nistp384.GenerateKeyPair(); this.localPortsFixture = localPortsFixture; + + TestTS.Switch.Level = SourceLevels.All; + TestTS.Listeners.Add(new XunitTraceListener(output)); } private Tunnel CreateRelayTunnel(bool addClientEndpoint = true) => CreateRelayTunnel(addClientEndpoint, Enumerable.Empty()); @@ -1453,6 +1451,67 @@ public async Task ConnectRelayHostThenConnectRelayClientToForwardedPortStream() using var sshStream = await clientSshSession.ConnectToForwardedPortAsync(port, TimeoutToken); } + [Fact] + public async Task ConnectRelayHostThenConnectRelayClientsToForwardedPortStreamsThenSendData() + { + const int PortCount = 2; + const int ClientConnectionCount = 50; + + var managementClient = new MockTunnelManagementClient + { + HostRelayUri = MockHostRelayUri, + ClientRelayUri = MockClientRelayUri, + }; + + var relayHost = new TunnelRelayTunnelHost(managementClient, TestTS); + + await using var listeners = new TcpListeners(PortCount, TestTS); + var tunnel = CreateRelayTunnel(false, listeners.Ports); + + using var multiChannelStream = await ConnectRelayHostAsync(relayHost, tunnel); + Assert.Equal(ConnectionStatus.Connected, relayHost.ConnectionStatus); + + var clientStreamFactory = new MockTunnelRelayStreamFactory(TunnelRelayConnection.ClientWebSocketSubProtocol) + { + StreamFactory = async (accessToken) => + { + return await multiChannelStream.OpenStreamAsync(TunnelRelayTunnelHost.ClientStreamChannelType); + }, + }; + + for (int clientConnection = 0; clientConnection < ClientConnectionCount; clientConnection++) + { + foreach (var port in listeners.Ports) + { + TestTS.TraceInformation("Connecting client #{0} to port {1}", clientConnection, port); + + // Create and connect tunnel client + await using var relayClient = new TunnelRelayTunnelClient(TestTS) + { + AcceptLocalConnectionsForForwardedPorts = false, + StreamFactory = clientStreamFactory, + }; + + Assert.Equal(ConnectionStatus.None, relayClient.ConnectionStatus); + + await relayClient.ConnectAsync(tunnel, TimeoutToken); + Assert.Equal(ConnectionStatus.Connected, relayClient.ConnectionStatus); + + await relayClient.WaitForForwardedPortAsync(port, TimeoutToken); + using var stream = await relayClient.ConnectToForwardedPortAsync(port, TimeoutToken); + + var actualPort = await stream.ReadIntToEndAsync(TimeoutToken); + if (port != actualPort) + { + // Debugger.Launch(); + TestTS.TraceInformation("Client #{0} received unexpected port {1} instead of {2}", clientConnection, actualPort, port); + } + + Assert.Equal(port, actualPort); + } + } + } + [Fact] public async Task ConnectRelayHostThenConnectRelayClientToDifferentPort_Fails() { diff --git a/cs/test/TunnelsSDK.Test/XunitTraceListener.cs b/cs/test/TunnelsSDK.Test/XunitTraceListener.cs new file mode 100644 index 00000000..fbbe95f8 --- /dev/null +++ b/cs/test/TunnelsSDK.Test/XunitTraceListener.cs @@ -0,0 +1,32 @@ +using System.Diagnostics; +using System.Text; +using Xunit.Abstractions; + +namespace Microsoft.DevTunnels.Test; + +internal sealed class XunitTraceListener : TraceListener +{ + private readonly ITestOutputHelper output; + private readonly StringBuilder currentLine = new (); + private readonly DateTimeOffset loggingStart = DateTimeOffset.UtcNow; + private DateTimeOffset? messageStart; + + public XunitTraceListener(ITestOutputHelper output) + { + this.output = output; + } + + public override void Write(string message) + { + this.messageStart ??= DateTimeOffset.UtcNow; + this.currentLine.Append(message); + } + + public override void WriteLine(string message) + { + var messageTime = (this.messageStart ?? DateTimeOffset.UtcNow) - this.loggingStart; + this.output.WriteLine($"{messageTime} {this.currentLine}{message}"); + this.currentLine.Clear(); + this.messageStart = null; + } +} \ No newline at end of file