From 1d88c6cf5e0b62050f4b17431b9a34fa9e26070a Mon Sep 17 00:00:00 2001 From: James Newton-King Date: Fri, 2 Aug 2024 15:14:32 +0800 Subject: [PATCH] Propagate trace parent to SignalR hub invocations (#57049) --- .../Internal/HostingApplicationDiagnostics.cs | 77 +----- .../src/Microsoft.AspNetCore.Hosting.csproj | 1 + src/Shared/Diagnostics/ActivityCreator.cs | 123 ++++++++++ src/Shared/SignalR/InProcessTestServer.cs | 4 + .../csharp/Client.Core/src/HubConnection.cs | 21 ++ .../FunctionalTests/HubConnectionTests.cs | 220 +++++++++++++++++- .../UnitTests/HubConnectionTests.Protocol.cs | 65 +++++- .../Tests.Utils/ChannelExtensions.cs | 30 +++ .../testassets/Tests.Utils/TestClient.cs | 25 +- .../server/Core/src/HubConnectionHandler.cs | 4 - .../Core/src/Internal/DefaultHubDispatcher.cs | 111 ++++++--- .../Internal/SignalRServerActivitySource.cs | 2 + .../Microsoft.AspNetCore.SignalR.Core.csproj | 2 + .../HubConnectionHandlerTests.Activity.cs | 182 +++++++++++++-- 14 files changed, 742 insertions(+), 125 deletions(-) create mode 100644 src/Shared/Diagnostics/ActivityCreator.cs diff --git a/src/Hosting/Hosting/src/Internal/HostingApplicationDiagnostics.cs b/src/Hosting/Hosting/src/Internal/HostingApplicationDiagnostics.cs index a641d23e6c4e..1a344203727c 100644 --- a/src/Hosting/Hosting/src/Internal/HostingApplicationDiagnostics.cs +++ b/src/Hosting/Hosting/src/Internal/HostingApplicationDiagnostics.cs @@ -9,6 +9,7 @@ using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.Http.Metadata; +using Microsoft.AspNetCore.Shared; using Microsoft.Extensions.Logging; namespace Microsoft.AspNetCore.Hosting; @@ -389,80 +390,24 @@ private void RecordRequestStartMetrics(HttpContext httpContext) hasDiagnosticListener = false; var headers = httpContext.Request.Headers; - _propagator.ExtractTraceIdAndState(headers, + var activity = ActivityCreator.CreateFromRemote( + _activitySource, + _propagator, + headers, static (object? carrier, string fieldName, out string? fieldValue, out IEnumerable? fieldValues) => { fieldValues = default; var headers = (IHeaderDictionary)carrier!; fieldValue = headers[fieldName]; }, - out var requestId, - out var traceState); - - Activity? activity = null; - if (_activitySource.HasListeners()) - { - if (ActivityContext.TryParse(requestId, traceState, isRemote: true, out ActivityContext context)) - { - // The requestId used the W3C ID format. Unfortunately, the ActivitySource.CreateActivity overload that - // takes a string parentId never sets HasRemoteParent to true. We work around that by calling the - // ActivityContext overload instead which sets HasRemoteParent to parentContext.IsRemote. - // https://github.com/dotnet/aspnetcore/pull/41568#discussion_r868733305 - activity = _activitySource.CreateActivity(ActivityName, ActivityKind.Server, context); - } - else - { - // Pass in the ID we got from the headers if there was one. - activity = _activitySource.CreateActivity(ActivityName, ActivityKind.Server, string.IsNullOrEmpty(requestId) ? null! : requestId); - } - } - + ActivityName, + ActivityKind.Server, + tags: null, + links: null, + loggingEnabled || diagnosticListenerActivityCreationEnabled); if (activity is null) { - // CreateActivity didn't create an Activity (this is an optimization for the - // case when there are no listeners). Let's create it here if needed. - if (loggingEnabled || diagnosticListenerActivityCreationEnabled) - { - activity = new Activity(ActivityName); - if (!string.IsNullOrEmpty(requestId)) - { - activity.SetParentId(requestId); - } - } - else - { - return null; - } - } - - // The trace id was successfully extracted, so we can set the trace state - // https://www.w3.org/TR/trace-context/#tracestate-header - if (!string.IsNullOrEmpty(requestId)) - { - if (!string.IsNullOrEmpty(traceState)) - { - activity.TraceStateString = traceState; - } - } - - // Baggage can be used regardless of whether a distributed trace id was present on the inbound request. - // https://www.w3.org/TR/baggage/#abstract - var baggage = _propagator.ExtractBaggage(headers, static (object? carrier, string fieldName, out string? fieldValue, out IEnumerable? fieldValues) => - { - fieldValues = default; - var headers = (IHeaderDictionary)carrier!; - fieldValue = headers[fieldName]; - }); - - // AddBaggage adds items at the beginning of the list, so we need to add them in reverse to keep the same order as the client - // By contract, the propagator has already reversed the order of items so we need not reverse it again - // Order could be important if baggage has two items with the same key (that is allowed by the contract) - if (baggage is not null) - { - foreach (var baggageItem in baggage) - { - activity.AddBaggage(baggageItem.Key, baggageItem.Value); - } + return null; } _diagnosticListener.OnActivityImport(activity, httpContext); diff --git a/src/Hosting/Hosting/src/Microsoft.AspNetCore.Hosting.csproj b/src/Hosting/Hosting/src/Microsoft.AspNetCore.Hosting.csproj index 80948eacb708..e91f7124f54d 100644 --- a/src/Hosting/Hosting/src/Microsoft.AspNetCore.Hosting.csproj +++ b/src/Hosting/Hosting/src/Microsoft.AspNetCore.Hosting.csproj @@ -18,6 +18,7 @@ + diff --git a/src/Shared/Diagnostics/ActivityCreator.cs b/src/Shared/Diagnostics/ActivityCreator.cs new file mode 100644 index 000000000000..170e9cff267d --- /dev/null +++ b/src/Shared/Diagnostics/ActivityCreator.cs @@ -0,0 +1,123 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Diagnostics; + +namespace Microsoft.AspNetCore.Shared; + +internal static class ActivityCreator +{ + /// + /// Create an activity with details received from a remote source. + /// + public static Activity? CreateFromRemote( + ActivitySource activitySource, + DistributedContextPropagator propagator, + object distributedContextCarrier, + DistributedContextPropagator.PropagatorGetterCallback propagatorGetter, + string activityName, + ActivityKind kind, + IEnumerable>? tags, + IEnumerable? links, + bool diagnosticsOrLoggingEnabled) + { + Activity? activity = null; + string? requestId = null; + string? traceState = null; + + if (activitySource.HasListeners()) + { + propagator.ExtractTraceIdAndState( + distributedContextCarrier, + propagatorGetter, + out requestId, + out traceState); + + if (ActivityContext.TryParse(requestId, traceState, isRemote: true, out ActivityContext context)) + { + // The requestId used the W3C ID format. Unfortunately, the ActivitySource.CreateActivity overload that + // takes a string parentId never sets HasRemoteParent to true. We work around that by calling the + // ActivityContext overload instead which sets HasRemoteParent to parentContext.IsRemote. + // https://github.com/dotnet/aspnetcore/pull/41568#discussion_r868733305 + activity = activitySource.CreateActivity(activityName, kind, context, tags: tags, links: links); + } + else + { + // Pass in the ID we got from the headers if there was one. + activity = activitySource.CreateActivity(activityName, kind, string.IsNullOrEmpty(requestId) ? null : requestId, tags: tags, links: links); + } + } + + if (activity is null) + { + // CreateActivity didn't create an Activity (this is an optimization for the + // case when there are no listeners). Let's create it here if needed. + if (diagnosticsOrLoggingEnabled) + { + // Note that there is a very small chance that propagator has already been called. + // Requires that the activity source had listened, but it didn't create an activity. + // Can only happen if there is a race between HasListeners and CreateActivity calls, + // and someone removing the listener. + // + // The only negative of calling the propagator twice is a small performance hit. + // It's small and unlikely so it's not worth trying to optimize. + propagator.ExtractTraceIdAndState( + distributedContextCarrier, + propagatorGetter, + out requestId, + out traceState); + + activity = new Activity(activityName); + if (!string.IsNullOrEmpty(requestId)) + { + activity.SetParentId(requestId); + } + if (tags != null) + { + foreach (var tag in tags) + { + activity.AddTag(tag.Key, tag.Value); + } + } + if (links != null) + { + foreach (var link in links) + { + activity.AddLink(link); + } + } + } + else + { + return null; + } + } + + // The trace id was successfully extracted, so we can set the trace state + // https://www.w3.org/TR/trace-context/#tracestate-header + if (!string.IsNullOrEmpty(requestId)) + { + if (!string.IsNullOrEmpty(traceState)) + { + activity.TraceStateString = traceState; + } + } + + // Baggage can be used regardless of whether a distributed trace id was present on the inbound request. + // https://www.w3.org/TR/baggage/#abstract + var baggage = propagator.ExtractBaggage(distributedContextCarrier, propagatorGetter); + + // AddBaggage adds items at the beginning of the list, so we need to add them in reverse to keep the same order as the client + // By contract, the propagator has already reversed the order of items so we need not reverse it again + // Order could be important if baggage has two items with the same key (that is allowed by the contract) + if (baggage is not null) + { + foreach (var baggageItem in baggage) + { + activity.AddBaggage(baggageItem.Key, baggageItem.Value); + } + } + + return activity; + } +} diff --git a/src/Shared/SignalR/InProcessTestServer.cs b/src/Shared/SignalR/InProcessTestServer.cs index 533709b20ff1..f2fe303e8fca 100644 --- a/src/Shared/SignalR/InProcessTestServer.cs +++ b/src/Shared/SignalR/InProcessTestServer.cs @@ -27,6 +27,8 @@ public abstract class InProcessTestServer : IAsyncDisposable public abstract string Url { get; } + public abstract IServiceProvider Services { get; } + public abstract ValueTask DisposeAsync(); } @@ -54,6 +56,8 @@ internal override event Action ServerLogged public override string Url => _url; + public override IServiceProvider Services => _host.Services; + public static async Task> StartServer(ILoggerFactory loggerFactory, Action configureKestrelServerOptions = null, IDisposable disposable = null) { var server = new InProcessTestServer(loggerFactory, configureKestrelServerOptions, disposable); diff --git a/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs b/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs index e6d343cd7465..a7c1dd6abfff 100644 --- a/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs +++ b/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs @@ -1037,6 +1037,7 @@ private async Task InvokeCore(ConnectionState connectionState, string methodName // Client invocations are always blocking var invocationMessage = new InvocationMessage(irq.InvocationId, methodName, args, streams); + InjectHeaders(invocationMessage); Log.RegisteringInvocation(_logger, irq.InvocationId); connectionState.AddInvocation(irq); @@ -1063,6 +1064,7 @@ private async Task InvokeStreamCore(ConnectionState connectionState, string meth Log.PreparingStreamingInvocation(_logger, irq.InvocationId, methodName, irq.ResultType.FullName!, args.Length); var invocationMessage = new StreamInvocationMessage(irq.InvocationId, methodName, args, streams); + InjectHeaders(invocationMessage); Log.RegisteringInvocation(_logger, irq.InvocationId); @@ -1083,6 +1085,25 @@ private async Task InvokeStreamCore(ConnectionState connectionState, string meth } } + private static void InjectHeaders(HubInvocationMessage invocationMessage) + { + // TODO: Change when SignalR client has an activity. + // This sends info about the current activity, regardless of the activity source, to the SignalR server. + // When SignalR client supports client activities this logic should be updated to only send headers + // if the SignalR client activity is created. The goal is to match the behavior of distributed tracing in HttpClient. + if (Activity.Current is { } currentActivity) + { + DistributedContextPropagator.Current.Inject(currentActivity, invocationMessage, static (carrier, key, value) => + { + if (carrier is HubInvocationMessage invocationMessage) + { + invocationMessage.Headers ??= new Dictionary(); + invocationMessage.Headers[key] = value; + } + }); + } + } + private async Task SendHubMessage(ConnectionState connectionState, HubMessage hubMessage, CancellationToken cancellationToken = default) { _state.AssertConnectionValid(); diff --git a/src/SignalR/clients/csharp/Client/test/FunctionalTests/HubConnectionTests.cs b/src/SignalR/clients/csharp/Client/test/FunctionalTests/HubConnectionTests.cs index f656136b81ed..f36614e184a8 100644 --- a/src/SignalR/clients/csharp/Client/test/FunctionalTests/HubConnectionTests.cs +++ b/src/SignalR/clients/csharp/Client/test/FunctionalTests/HubConnectionTests.cs @@ -1,6 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Diagnostics; using System.Net; using System.Net.Http; using System.Net.WebSockets; @@ -10,10 +11,11 @@ using Microsoft.AspNetCore.Hosting; using Microsoft.AspNetCore.Http.Connections; using Microsoft.AspNetCore.Http.Connections.Client; +using Microsoft.AspNetCore.InternalTesting; +using Microsoft.AspNetCore.SignalR.Internal; using Microsoft.AspNetCore.SignalR.Protocol; using Microsoft.AspNetCore.SignalR.Test.Internal; using Microsoft.AspNetCore.SignalR.Tests; -using Microsoft.AspNetCore.InternalTesting; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Testing; @@ -113,6 +115,131 @@ public async Task CheckFixedMessage(string protocolName, HttpTransportType trans } } + [Theory] + [MemberData(nameof(HubProtocolsAndTransportsAndHubPaths))] + public async Task InvokeAsync_SendTraceHeader(string protocolName, HttpTransportType transportType, string path) + { + var protocol = HubProtocols[protocolName]; + await using (var server = await StartServer()) + { + var channel = Channel.CreateUnbounded(); + var serverSource = server.Services.GetRequiredService().ActivitySource; + + using var listener = new ActivityListener + { + ShouldListenTo = activitySource => ReferenceEquals(activitySource, serverSource), + Sample = (ref ActivityCreationOptions _) => ActivitySamplingResult.AllData, + ActivityStarted = activity => channel.Writer.TryWrite(activity) + }; + ActivitySource.AddActivityListener(listener); + + var connectionBuilder = new HubConnectionBuilder() + .WithLoggerFactory(LoggerFactory) + .WithUrl(server.Url + path, transportType); + connectionBuilder.Services.AddSingleton(protocol); + + var connection = connectionBuilder.Build(); + + Activity clientActivity1 = null; + Activity clientActivity2 = null; + try + { + await connection.StartAsync().DefaultTimeout(); + + // Invocation 1 + try + { + clientActivity1 = new Activity("ClientActivity1"); + clientActivity1.AddBaggage("baggage-1", "value-1"); + clientActivity1.Start(); + + var result = await connection.InvokeAsync(nameof(TestHub.HelloWorld)).DefaultTimeout(); + + Assert.Equal("Hello World!", result); + } + finally + { + clientActivity1?.Stop(); + } + + // Invocation 2 + try + { + clientActivity2 = new Activity("ClientActivity2"); + clientActivity2.AddBaggage("baggage-2", "value-2"); + clientActivity2.Start(); + + var result = await connection.InvokeAsync(nameof(TestHub.HelloWorld)).DefaultTimeout(); + + Assert.Equal("Hello World!", result); + } + finally + { + clientActivity2?.Stop(); + } + } + catch (Exception ex) + { + LoggerFactory.CreateLogger().LogError(ex, "{ExceptionType} from test", ex.GetType().FullName); + throw; + } + finally + { + await connection.DisposeAsync().DefaultTimeout(); + } + + var activities = await channel.Reader.ReadAtLeastAsync(minimumCount: 4).DefaultTimeout(); + + var hubName = path switch + { + "/default" => typeof(TestHub).FullName, + "/hubT" => typeof(TestHubT).FullName, + "/dynamic" => typeof(DynamicTestHub).FullName, + _ => throw new InvalidOperationException("Unexpected path: " + path) + }; + + Assert.Collection(activities, + a => + { + Assert.Equal(SignalRServerActivitySource.OnConnected, a.OperationName); + Assert.Equal("Microsoft.AspNetCore.Hosting.HttpRequestIn", a.Parent.OperationName); + Assert.False(a.HasRemoteParent); + Assert.Empty(a.Baggage); + }, + a => + { + Assert.Equal(SignalRServerActivitySource.InvocationIn, a.OperationName); + Assert.Equal(clientActivity1.Id, a.ParentId); + Assert.True(a.HasRemoteParent); + Assert.Collection(a.Baggage, + b => + { + Assert.Equal("baggage-1", b.Key); + Assert.Equal("value-1", b.Value); + }); + }, + a => + { + Assert.Equal(SignalRServerActivitySource.InvocationIn, a.OperationName); + Assert.Equal(clientActivity2.Id, a.ParentId); + Assert.True(a.HasRemoteParent); + Assert.Collection(a.Baggage, + b => + { + Assert.Equal("baggage-2", b.Key); + Assert.Equal("value-2", b.Value); + }); + }, + a => + { + Assert.Equal(SignalRServerActivitySource.OnDisconnected, a.OperationName); + Assert.Equal("Microsoft.AspNetCore.Hosting.HttpRequestIn", a.Parent.OperationName); + Assert.False(a.HasRemoteParent); + Assert.Empty(a.Baggage); + }); + } + } + [Fact] public async Task ServerRejectsClientWithOldProtocol() { @@ -469,6 +596,97 @@ public async Task StreamAsyncCoreTest(string protocolName, HttpTransportType tra } } + [Theory] + [MemberData(nameof(HubProtocolsAndTransportsAndHubPaths))] + public async Task StreamAsyncCore_SendTraceHeader(string protocolName, HttpTransportType transportType, string path) + { + var protocol = HubProtocols[protocolName]; + await using (var server = await StartServer()) + { + var channel = Channel.CreateUnbounded(); + var serverSource = server.Services.GetRequiredService().ActivitySource; + + using var listener = new ActivityListener + { + ShouldListenTo = activitySource => ReferenceEquals(activitySource, serverSource), + Sample = (ref ActivityCreationOptions _) => ActivitySamplingResult.AllData, + ActivityStarted = activity => channel.Writer.TryWrite(activity) + }; + ActivitySource.AddActivityListener(listener); + + var connection = CreateHubConnection(server.Url, path, transportType, protocol, LoggerFactory); + + Activity clientActivity = null; + try + { + await connection.StartAsync().DefaultTimeout(); + + clientActivity = new Activity("ClientActivity"); + clientActivity.AddBaggage("baggage-1", "value-1"); + clientActivity.Start(); + + var expectedValue = 0; + var streamTo = 5; + var asyncEnumerable = connection.StreamAsyncCore("Stream", new object[] { streamTo }); + await foreach (var streamValue in asyncEnumerable) + { + Assert.Equal(expectedValue, streamValue); + expectedValue++; + } + + Assert.Equal(streamTo, expectedValue); + } + catch (Exception ex) + { + LoggerFactory.CreateLogger().LogError(ex, "{ExceptionType} from test", ex.GetType().FullName); + throw; + } + finally + { + clientActivity?.Stop(); + await connection.DisposeAsync().DefaultTimeout(); + } + + var activities = await channel.Reader.ReadAtLeastAsync(minimumCount: 3).DefaultTimeout(); + + var hubName = path switch + { + "/default" => typeof(TestHub).FullName, + "/hubT" => typeof(TestHubT).FullName, + "/dynamic" => typeof(DynamicTestHub).FullName, + _ => throw new InvalidOperationException("Unexpected path: " + path) + }; + + Assert.Collection(activities, + a => + { + Assert.Equal(SignalRServerActivitySource.OnConnected, a.OperationName); + Assert.Equal("Microsoft.AspNetCore.Hosting.HttpRequestIn", a.Parent.OperationName); + Assert.False(a.HasRemoteParent); + Assert.Empty(a.Baggage); + }, + a => + { + Assert.Equal(SignalRServerActivitySource.InvocationIn, a.OperationName); + Assert.Equal(clientActivity.Id, a.ParentId); + Assert.True(a.HasRemoteParent); + Assert.Collection(a.Baggage, + b => + { + Assert.Equal("baggage-1", b.Key); + Assert.Equal("value-1", b.Value); + }); + }, + a => + { + Assert.Equal(SignalRServerActivitySource.OnDisconnected, a.OperationName); + Assert.Equal("Microsoft.AspNetCore.Hosting.HttpRequestIn", a.Parent.OperationName); + Assert.False(a.HasRemoteParent); + Assert.Empty(a.Baggage); + }); + } + } + [Theory] [InlineData("json")] [InlineData("messagepack")] diff --git a/src/SignalR/clients/csharp/Client/test/UnitTests/HubConnectionTests.Protocol.cs b/src/SignalR/clients/csharp/Client/test/UnitTests/HubConnectionTests.Protocol.cs index 142b83c353ba..738762bcac45 100644 --- a/src/SignalR/clients/csharp/Client/test/UnitTests/HubConnectionTests.Protocol.cs +++ b/src/SignalR/clients/csharp/Client/test/UnitTests/HubConnectionTests.Protocol.cs @@ -1,13 +1,10 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -using System; -using System.IO; +using System.Diagnostics; using System.Threading.Channels; -using System.Threading.Tasks; -using Microsoft.AspNetCore.SignalR.Tests; using Microsoft.AspNetCore.InternalTesting; -using Xunit; +using Microsoft.AspNetCore.SignalR.Tests; namespace Microsoft.AspNetCore.SignalR.Client.Tests; @@ -143,6 +140,34 @@ public async Task InvokeSendsAnInvocationMessage() } } + [Fact] + public async Task InvokeSendsAnInvocationMessage_SendTraceHeaders() + { + var connection = new TestConnection(); + var hubConnection = CreateHubConnection(connection); + try + { + await hubConnection.StartAsync().DefaultTimeout(); + + using var clientActivity = new Activity("ClientActivity"); + clientActivity.Start(); + + var invokeTask = hubConnection.InvokeAsync("Foo"); + + var invokeMessage = await connection.ReadSentJsonAsync().DefaultTimeout(); + var traceParent = (string)invokeMessage["headers"]["traceparent"]; + + Assert.Equal(clientActivity.Id, traceParent); + + Assert.Equal(TaskStatus.WaitingForActivation, invokeTask.Status); + } + finally + { + await hubConnection.DisposeAsync().DefaultTimeout(); + await connection.DisposeAsync().DefaultTimeout(); + } + } + [Fact] public async Task ReceiveCloseMessageWithoutErrorWillCloseHubConnection() { @@ -228,6 +253,36 @@ public async Task StreamSendsAnInvocationMessage() } } + [Fact] + public async Task StreamSendsAnInvocationMessage_SendTraceHeaders() + { + var connection = new TestConnection(); + var hubConnection = CreateHubConnection(connection); + try + { + await hubConnection.StartAsync().DefaultTimeout(); + + using var clientActivity = new Activity("ClientActivity"); + clientActivity.Start(); + + var channel = await hubConnection.StreamAsChannelAsync("Foo").DefaultTimeout(); + + var invokeMessage = await connection.ReadSentJsonAsync().DefaultTimeout(); + var traceParent = (string)invokeMessage["headers"]["traceparent"]; + + Assert.Equal(clientActivity.Id, traceParent); + + // Complete the channel + await connection.ReceiveJsonMessage(new { invocationId = "1", type = 3 }).DefaultTimeout(); + await channel.Completion.DefaultTimeout(); + } + finally + { + await hubConnection.DisposeAsync().DefaultTimeout(); + await connection.DisposeAsync().DefaultTimeout(); + } + } + [Fact] public async Task InvokeCompletedWhenCompletionMessageReceived() { diff --git a/src/SignalR/common/testassets/Tests.Utils/ChannelExtensions.cs b/src/SignalR/common/testassets/Tests.Utils/ChannelExtensions.cs index f52828ab1d1d..1afb51b88e3d 100644 --- a/src/SignalR/common/testassets/Tests.Utils/ChannelExtensions.cs +++ b/src/SignalR/common/testassets/Tests.Utils/ChannelExtensions.cs @@ -31,4 +31,34 @@ public static async Task> ReadAndCollectAllAsync(this ChannelReader> ReadAtLeastAsync(this ChannelReader reader, int minimumCount, CancellationToken cancellationToken = default) + { + if (minimumCount <= 0) + { + throw new ArgumentOutOfRangeException(nameof(minimumCount), "minimumCount must be greater than zero."); + } + + var items = new List(); + + while (items.Count < minimumCount && !cancellationToken.IsCancellationRequested) + { + while (reader.TryRead(out var item)) + { + items.Add(item); + if (items.Count >= minimumCount) + { + return items; + } + } + + var readTask = reader.WaitToReadAsync(cancellationToken).AsTask(); + if (!await readTask.ConfigureAwait(false)) + { + throw new InvalidOperationException($"Channel ended after writing {items.Count} items."); + } + } + + return items; + } } diff --git a/src/SignalR/common/testassets/Tests.Utils/TestClient.cs b/src/SignalR/common/testassets/Tests.Utils/TestClient.cs index 3cce78132fa5..69c45deccd43 100644 --- a/src/SignalR/common/testassets/Tests.Utils/TestClient.cs +++ b/src/SignalR/common/testassets/Tests.Utils/TestClient.cs @@ -99,6 +99,12 @@ public async Task> StreamAsync(string methodName, string[] str return await ListenAllAsync(invocationId); } + public async Task> StreamAsync(string methodName, string[] streamIds, IDictionary headers, params object[] args) + { + var invocationId = await SendStreamInvocationAsync(methodName, streamIds, headers, args); + return await ListenAllAsync(invocationId); + } + public async Task> ListenAllAsync(string invocationId) { var result = new List(); @@ -185,10 +191,20 @@ public Task SendInvocationAsync(string methodName, params object[] args) return SendInvocationAsync(methodName, nonBlocking: false, args: args); } + public Task SendInvocationAsync(string methodName, IDictionary headers, params object[] args) + { + return SendInvocationAsync(methodName, nonBlocking: false, headers: headers, args: args); + } + public Task SendInvocationAsync(string methodName, bool nonBlocking, params object[] args) + { + return SendInvocationAsync(methodName, nonBlocking: nonBlocking, headers: null, args: args); + } + + public Task SendInvocationAsync(string methodName, bool nonBlocking, IDictionary headers, params object[] args) { var invocationId = nonBlocking ? null : GetInvocationId(); - return SendHubMessageAsync(new InvocationMessage(invocationId, methodName, args)); + return SendHubMessageAsync(new InvocationMessage(invocationId, methodName, args) { Headers = headers }); } public Task SendStreamInvocationAsync(string methodName, params object[] args) @@ -197,9 +213,14 @@ public Task SendStreamInvocationAsync(string methodName, params object[] } public Task SendStreamInvocationAsync(string methodName, string[] streamIds, params object[] args) + { + return SendStreamInvocationAsync(methodName, streamIds: streamIds, headers: null, args); + } + + public Task SendStreamInvocationAsync(string methodName, string[] streamIds, IDictionary headers, params object[] args) { var invocationId = GetInvocationId(); - return SendHubMessageAsync(new StreamInvocationMessage(invocationId, methodName, args, streamIds)); + return SendHubMessageAsync(new StreamInvocationMessage(invocationId, methodName, args, streamIds) { Headers = headers }); } public Task BeginUploadStreamAsync(string invocationId, string methodName, string[] streamIds, params object[] args) diff --git a/src/SignalR/server/Core/src/HubConnectionHandler.cs b/src/SignalR/server/Core/src/HubConnectionHandler.cs index a56cd9a90835..ce95fb48192b 100644 --- a/src/SignalR/server/Core/src/HubConnectionHandler.cs +++ b/src/SignalR/server/Core/src/HubConnectionHandler.cs @@ -136,10 +136,6 @@ public override async Task OnConnectedAsync(ConnectionContext connection) OriginalActivity = Activity.Current, }; - // Get off the parent span. - // This is likely the Http Request span and we want Hub method invocations to not be collected under a long running span. - Activity.Current = null; - var resolvedSupportedProtocols = (supportedProtocols as IReadOnlyList) ?? supportedProtocols.ToList(); if (!await connectionContext.HandshakeAsync(handshakeTimeout, resolvedSupportedProtocols, _protocolResolver, _userIdProvider, _enableDetailedErrors)) { diff --git a/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs b/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs index c511d5ac3765..24c675a5aa8f 100644 --- a/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs +++ b/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs @@ -9,6 +9,7 @@ using System.Threading.Channels; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Internal; +using Microsoft.AspNetCore.Shared; using Microsoft.AspNetCore.SignalR.Protocol; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Internal; @@ -91,7 +92,7 @@ public override async Task OnConnectedAsync(HubConnectionContext connection) // OnConnectedAsync won't work with client results (ISingleClientProxy.InvokeAsync) InitializeHub(hub, connection, invokeAllowed: false); - activity = StartActivity(connection, scope.ServiceProvider, nameof(hub.OnConnectedAsync)); + activity = StartActivity(SignalRServerActivitySource.OnConnected, ActivityKind.Internal, linkedActivity: null, scope.ServiceProvider, nameof(hub.OnConnectedAsync), headers: null, _logger); if (_onConnectedMiddleware != null) { @@ -126,7 +127,7 @@ public override async Task OnDisconnectedAsync(HubConnectionContext connection, { InitializeHub(hub, connection); - activity = StartActivity(connection, scope.ServiceProvider, nameof(hub.OnDisconnectedAsync)); + activity = StartActivity(SignalRServerActivitySource.OnDisconnected, ActivityKind.Internal, linkedActivity: null, scope.ServiceProvider, nameof(hub.OnDisconnectedAsync), headers: null, _logger); if (_onDisconnectedMiddleware != null) { @@ -394,9 +395,16 @@ static async Task ExecuteInvocation(DefaultHubDispatcher dispatcher, var logger = dispatcher._logger; var enableDetailedErrors = dispatcher._enableDetailedErrors; + // Hub invocation gets its parent from a remote source. Clear any current activity and restore it later. + var previousActivity = Activity.Current; + if (previousActivity != null) + { + Activity.Current = null; + } + // Use hubMethodInvocationMessage.Target instead of methodExecutor.MethodInfo.Name // We want to take HubMethodNameAttribute into account which will be the same as what the invocation target is - var activity = StartActivity(connection, scope.ServiceProvider, hubMethodInvocationMessage.Target); + var activity = StartActivity(SignalRServerActivitySource.InvocationIn, ActivityKind.Server, connection.OriginalActivity, scope.ServiceProvider, hubMethodInvocationMessage.Target, hubMethodInvocationMessage.Headers, logger); object? result; try @@ -417,6 +425,11 @@ await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection, { activity?.Stop(); + if (Activity.Current != previousActivity) + { + Activity.Current = previousActivity; + } + // Stream response handles cleanup in StreamResultsAsync // And normal invocations handle cleanup below in the finally if (isStreamCall) @@ -502,7 +515,14 @@ private async Task StreamAsync(string invocationId, HubConnectionContext connect streamCts ??= CancellationTokenSource.CreateLinkedTokenSource(connection.ConnectionAborted); - var activity = StartActivity(connection, scope.ServiceProvider, hubMethodInvocationMessage.Target); + // Hub invocation gets its parent from a remote source. Clear any current activity and restore it later. + var previousActivity = Activity.Current; + if (previousActivity != null) + { + Activity.Current = null; + } + + var activity = StartActivity(SignalRServerActivitySource.InvocationIn, ActivityKind.Server, connection.OriginalActivity, scope.ServiceProvider, hubMethodInvocationMessage.Target, hubMethodInvocationMessage.Headers, _logger); try { @@ -569,6 +589,11 @@ private async Task StreamAsync(string invocationId, HubConnectionContext connect { activity?.Stop(); + if (Activity.Current != previousActivity) + { + Activity.Current = previousActivity; + } + await CleanupInvocation(connection, hubMethodInvocationMessage, hubActivator, hub, scope); streamCts.Dispose(); @@ -806,34 +831,64 @@ public override IReadOnlyList GetParameterTypes(string methodName) // Starts an Activity for a Hub method invocation and sets up all the tags and other state. // Make sure to call Activity.Stop() once the Hub method completes, and consider calling SetActivityError on exception. - private static Activity? StartActivity(HubConnectionContext connectionContext, IServiceProvider serviceProvider, string methodName) + private static Activity? StartActivity(string operationName, ActivityKind kind, Activity? linkedActivity, IServiceProvider serviceProvider, string methodName, IDictionary? headers, ILogger logger) { - if (serviceProvider.GetService() is SignalRServerActivitySource signalRActivitySource - && signalRActivitySource.ActivitySource.HasListeners()) - { - var requestContext = connectionContext.OriginalActivity?.Context; - - var activity = signalRActivitySource.ActivitySource.CreateActivity(SignalRServerActivitySource.InvocationIn, ActivityKind.Server, parentId: null, - // https://github.com/open-telemetry/semantic-conventions/blob/main/docs/rpc/rpc-spans.md#server-attributes - tags: [ - new("rpc.method", methodName), - new("rpc.system", "signalr"), - new("rpc.service", _fullHubName), - // See https://github.com/dotnet/aspnetcore/blob/027c60168383421750f01e427e4f749d0684bc02/src/Servers/Kestrel/Core/src/Internal/Infrastructure/KestrelMetrics.cs#L308 - // And https://github.com/dotnet/aspnetcore/issues/43786 - //new("server.address", ...), - ], - links: requestContext.HasValue ? [new ActivityLink(requestContext.Value)] : null); - if (activity != null) - { - activity.DisplayName = $"{_fullHubName}/{methodName}"; - activity.Start(); - } + var activitySource = serviceProvider.GetService()?.ActivitySource; + if (activitySource is null) + { + return null; + } - return activity; + var loggingEnabled = logger.IsEnabled(LogLevel.Critical); + if (!activitySource.HasListeners() && !loggingEnabled) + { + return null; } - return null; + IEnumerable> tags = + [ + new("rpc.method", methodName), + new("rpc.system", "signalr"), + new("rpc.service", _fullHubName), + // See https://github.com/dotnet/aspnetcore/blob/027c60168383421750f01e427e4f749d0684bc02/src/Servers/Kestrel/Core/src/Internal/Infrastructure/KestrelMetrics.cs#L308 + // And https://github.com/dotnet/aspnetcore/issues/43786 + //new("server.address", ...), + ]; + IEnumerable? links = (linkedActivity is not null) ? [new ActivityLink(linkedActivity.Context)] : null; + + Activity? activity; + if (headers != null) + { + var propagator = serviceProvider.GetService() ?? DistributedContextPropagator.Current; + + activity = ActivityCreator.CreateFromRemote( + activitySource, + propagator, + headers, + static (object? carrier, string fieldName, out string? fieldValue, out IEnumerable? fieldValues) => + { + fieldValues = default; + var headers = (IDictionary)carrier!; + headers.TryGetValue(fieldName, out fieldValue); + }, + operationName, + kind, + tags, + links, + loggingEnabled); + } + else + { + activity = activitySource.CreateActivity(operationName, kind, parentId: null, tags: tags, links: links); + } + + if (activity is not null) + { + activity.DisplayName = $"{_fullHubName}/{methodName}"; + activity.Start(); + } + + return activity; } private static void SetActivityError(Activity? activity, Exception ex) diff --git a/src/SignalR/server/Core/src/Internal/SignalRServerActivitySource.cs b/src/SignalR/server/Core/src/Internal/SignalRServerActivitySource.cs index 606f36316911..3e4034313bea 100644 --- a/src/SignalR/server/Core/src/Internal/SignalRServerActivitySource.cs +++ b/src/SignalR/server/Core/src/Internal/SignalRServerActivitySource.cs @@ -12,6 +12,8 @@ internal sealed class SignalRServerActivitySource { internal const string Name = "Microsoft.AspNetCore.SignalR.Server"; internal const string InvocationIn = $"{Name}.InvocationIn"; + internal const string OnConnected = $"{Name}.OnConnected"; + internal const string OnDisconnected = $"{Name}.OnDisconnected"; public ActivitySource ActivitySource { get; } = new ActivitySource(Name); } diff --git a/src/SignalR/server/Core/src/Microsoft.AspNetCore.SignalR.Core.csproj b/src/SignalR/server/Core/src/Microsoft.AspNetCore.SignalR.Core.csproj index f4276f93dcdb..f5016566f977 100644 --- a/src/SignalR/server/Core/src/Microsoft.AspNetCore.SignalR.Core.csproj +++ b/src/SignalR/server/Core/src/Microsoft.AspNetCore.SignalR.Core.csproj @@ -21,6 +21,7 @@ + @@ -36,6 +37,7 @@ + diff --git a/src/SignalR/server/SignalR/test/Microsoft.AspNetCore.SignalR.Tests/HubConnectionHandlerTests.Activity.cs b/src/SignalR/server/SignalR/test/Microsoft.AspNetCore.SignalR.Tests/HubConnectionHandlerTests.Activity.cs index acd845ff1a24..18af7a5e82e7 100644 --- a/src/SignalR/server/SignalR/test/Microsoft.AspNetCore.SignalR.Tests/HubConnectionHandlerTests.Activity.cs +++ b/src/SignalR/server/SignalR/test/Microsoft.AspNetCore.SignalR.Tests/HubConnectionHandlerTests.Activity.cs @@ -51,7 +51,7 @@ public async Task HubMethodInvokesCreateActivities() var connectionHandlerTask = await client.ConnectAsync(connectionHandler).DefaultTimeout(); var activity = Assert.Single(activities); - AssertHubMethodActivity(activity, nameof(MethodHub.OnConnectedAsync), mockHttpRequestActivity); + AssertHubMethodActivity(activity, mockHttpRequestActivity, nameof(MethodHub.OnConnectedAsync), linkedActivity: null, activityName: SignalRServerActivitySource.OnConnected); await client.SendInvocationAsync(nameof(MethodHub.Echo), "test").DefaultTimeout(); @@ -60,19 +60,19 @@ public async Task HubMethodInvokesCreateActivities() Assert.Equal("test", res); Assert.Equal(2, activities.Count); - AssertHubMethodActivity(activities[1], nameof(MethodHub.Echo), mockHttpRequestActivity); + AssertHubMethodActivity(activities[1], parent: null, nameof(MethodHub.Echo), mockHttpRequestActivity); await client.SendInvocationAsync("RenamedMethod").DefaultTimeout(); Assert.IsType(await client.ReadAsync().DefaultTimeout()); Assert.Equal(3, activities.Count); - AssertHubMethodActivity(activities[2], "RenamedMethod", mockHttpRequestActivity); + AssertHubMethodActivity(activities[2], parent: null, "RenamedMethod", mockHttpRequestActivity); await client.SendInvocationAsync(nameof(MethodHub.ActivityMethod)).DefaultTimeout(); Assert.IsType(await client.ReadAsync().DefaultTimeout()); Assert.Equal(5, activities.Count); - AssertHubMethodActivity(activities[3], nameof(MethodHub.ActivityMethod), mockHttpRequestActivity); + AssertHubMethodActivity(activities[3], parent: null, nameof(MethodHub.ActivityMethod), mockHttpRequestActivity); Assert.NotNull(activities[4].Parent); Assert.Equal("inner", activities[4].OperationName); Assert.Equal(activities[3], activities[4].Parent); @@ -83,7 +83,80 @@ public async Task HubMethodInvokesCreateActivities() } Assert.Equal(6, activities.Count); - AssertHubMethodActivity(activities[5], nameof(MethodHub.OnDisconnectedAsync), mockHttpRequestActivity); + AssertHubMethodActivity(activities[5], mockHttpRequestActivity, nameof(MethodHub.OnDisconnectedAsync), linkedActivity: null, activityName: SignalRServerActivitySource.OnDisconnected); + } + } + + [Fact] + public async Task HubMethodInvokesCreateActivities_ReadTraceHeaders() + { + using (StartVerifiableLog()) + { + var activities = new List(); + var testSource = new ActivitySource("test_source"); + var hubMethodTestSource = new TestActivitySource() { ActivitySource = new ActivitySource("test_custom") }; + + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(builder => + { + builder.AddSingleton(hubMethodTestSource); + + // Provided by hosting layer normally + builder.AddSingleton(testSource); + }, LoggerFactory); + var signalrSource = serviceProvider.GetRequiredService().ActivitySource; + + using var listener = new ActivityListener + { + ShouldListenTo = activitySource => (ReferenceEquals(activitySource, testSource)) + || ReferenceEquals(activitySource, hubMethodTestSource.ActivitySource) || ReferenceEquals(activitySource, signalrSource), + Sample = (ref ActivityCreationOptions _) => ActivitySamplingResult.AllData, + ActivityStarted = activities.Add + }; + ActivitySource.AddActivityListener(listener); + + var mockHttpRequestActivity = new Activity("HttpRequest"); + mockHttpRequestActivity.Start(); + Activity.Current = mockHttpRequestActivity; + + var connectionHandler = serviceProvider.GetService>(); + + using (var client = new TestClient()) + { + var connectionHandlerTask = await client.ConnectAsync(connectionHandler).DefaultTimeout(); + + var activity = Assert.Single(activities); + AssertHubMethodActivity(activity, mockHttpRequestActivity, nameof(MethodHub.OnConnectedAsync), linkedActivity: null, activityName: SignalRServerActivitySource.OnConnected); + + var headers = new Dictionary + { + {"traceparent", "00-0123456789abcdef0123456789abcdef-0123456789abcdef-01"}, + {"tracestate", "TraceState1"}, + {"baggage", "Key1=value1, Key2=value2"} + }; + + await client.SendInvocationAsync(nameof(MethodHub.Echo), headers, "test").DefaultTimeout(); + + var completionMessage = Assert.IsType(await client.ReadAsync().DefaultTimeout()); + var res = (string)completionMessage.Result; + Assert.Equal("test", res); + + Assert.Equal(2, activities.Count); + var invocationActivity = activities[1]; + AssertHubMethodActivity(invocationActivity, parent: null, nameof(MethodHub.Echo), mockHttpRequestActivity); + + Assert.True(invocationActivity.HasRemoteParent); + Assert.Equal(ActivityIdFormat.W3C, invocationActivity.IdFormat); + Assert.Equal("0123456789abcdef0123456789abcdef", invocationActivity.TraceId.ToHexString()); + Assert.Equal("0123456789abcdef", invocationActivity.ParentSpanId.ToHexString()); + Assert.Equal("TraceState1", invocationActivity.TraceStateString); + + client.Dispose(); + + await connectionHandlerTask; + } + + Assert.Equal(3, activities.Count); + AssertHubMethodActivity(activities[2], mockHttpRequestActivity, nameof(MethodHub.OnDisconnectedAsync), linkedActivity: null, activityName: SignalRServerActivitySource.OnDisconnected); } } @@ -124,17 +197,17 @@ public async Task StreamingHubMethodCreatesActivities() var connectionHandlerTask = await client.ConnectAsync(connectionHandler).DefaultTimeout(); var activity = Assert.Single(activities); - AssertHubMethodActivity(activity, nameof(StreamingHub.OnConnectedAsync), mockHttpRequestActivity); + AssertHubMethodActivity(activity, mockHttpRequestActivity, nameof(StreamingHub.OnConnectedAsync), linkedActivity: null, activityName: SignalRServerActivitySource.OnConnected); _ = await client.StreamAsync(nameof(StreamingHub.CounterChannel), 3).DefaultTimeout(); Assert.Equal(2, activities.Count); - AssertHubMethodActivity(activities[1], nameof(StreamingHub.CounterChannel), mockHttpRequestActivity); + AssertHubMethodActivity(activities[1], parent: null, nameof(StreamingHub.CounterChannel), mockHttpRequestActivity); _ = await client.StreamAsync("RenamedCounterChannel", 3).DefaultTimeout(); Assert.Equal(3, activities.Count); - AssertHubMethodActivity(activities[2], "RenamedCounterChannel", mockHttpRequestActivity); + AssertHubMethodActivity(activities[2], parent: null, "RenamedCounterChannel", mockHttpRequestActivity); client.Dispose(); @@ -142,7 +215,75 @@ public async Task StreamingHubMethodCreatesActivities() } Assert.Equal(4, activities.Count); - AssertHubMethodActivity(activities[3], nameof(StreamingHub.OnDisconnectedAsync), mockHttpRequestActivity); + AssertHubMethodActivity(activities[3], mockHttpRequestActivity, nameof(StreamingHub.OnDisconnectedAsync), linkedActivity: null, activityName: SignalRServerActivitySource.OnDisconnected); + } + } + + [Fact] + public async Task StreamingHubMethodCreatesActivities_ReadTraceHeaders() + { + using (StartVerifiableLog()) + { + var activities = new List(); + var testSource = new ActivitySource("test_source"); + + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(builder => + { + // Provided by hosting layer normally + builder.AddSingleton(testSource); + }, LoggerFactory); + var signalrSource = serviceProvider.GetRequiredService().ActivitySource; + + using var listener = new ActivityListener + { + ShouldListenTo = activitySource => (ReferenceEquals(activitySource, testSource)) + || ReferenceEquals(activitySource, signalrSource), + Sample = (ref ActivityCreationOptions _) => ActivitySamplingResult.AllData, + ActivityStarted = activities.Add + }; + ActivitySource.AddActivityListener(listener); + + var mockHttpRequestActivity = new Activity("HttpRequest"); + mockHttpRequestActivity.Start(); + Activity.Current = mockHttpRequestActivity; + + var connectionHandler = serviceProvider.GetService>(); + Mock invocationBinder = new Mock(); + invocationBinder.Setup(b => b.GetStreamItemType(It.IsAny())).Returns(typeof(int)); + + using (var client = new TestClient(invocationBinder: invocationBinder.Object)) + { + var connectionHandlerTask = await client.ConnectAsync(connectionHandler).DefaultTimeout(); + + var activity = Assert.Single(activities); + AssertHubMethodActivity(activity, mockHttpRequestActivity, nameof(StreamingHub.OnConnectedAsync), linkedActivity: null, activityName: SignalRServerActivitySource.OnConnected); + + var headers = new Dictionary + { + {"traceparent", "00-0123456789abcdef0123456789abcdef-0123456789abcdef-01"}, + {"tracestate", "TraceState1"}, + {"baggage", "Key1=value1, Key2=value2"} + }; + + _ = await client.StreamAsync(nameof(StreamingHub.CounterChannel), streamIds: null, headers: headers, 3).DefaultTimeout(); + + Assert.Equal(2, activities.Count); + var invocationActivity = activities[1]; + AssertHubMethodActivity(invocationActivity, parent: null, nameof(StreamingHub.CounterChannel), mockHttpRequestActivity); + + Assert.True(invocationActivity.HasRemoteParent); + Assert.Equal(ActivityIdFormat.W3C, invocationActivity.IdFormat); + Assert.Equal("0123456789abcdef0123456789abcdef", invocationActivity.TraceId.ToHexString()); + Assert.Equal("0123456789abcdef", invocationActivity.ParentSpanId.ToHexString()); + Assert.Equal("TraceState1", invocationActivity.TraceStateString); + + client.Dispose(); + + await connectionHandlerTask; + } + + Assert.Equal(3, activities.Count); + AssertHubMethodActivity(activities[2], mockHttpRequestActivity, nameof(StreamingHub.OnDisconnectedAsync), linkedActivity: null, activityName: SignalRServerActivitySource.OnDisconnected); } } @@ -187,8 +328,8 @@ bool ExpectedErrors(WriteContext writeContext) var connectionHandlerTask = await client.ConnectAsync(connectionHandler).DefaultTimeout(); var activity = Assert.Single(activities); - AssertHubMethodActivity(activity, nameof(OnConnectedThrowsHub.OnConnectedAsync), - mockHttpRequestActivity, exceptionType: typeof(InvalidOperationException)); + AssertHubMethodActivity(activity, mockHttpRequestActivity, nameof(OnConnectedThrowsHub.OnConnectedAsync), + linkedActivity: null, exceptionType: typeof(InvalidOperationException), activityName: SignalRServerActivitySource.OnConnected); } } } @@ -239,8 +380,8 @@ bool ExpectedErrors(WriteContext writeContext) Assert.Equal(2, activities.Count); var activity = activities[1]; - AssertHubMethodActivity(activity, nameof(OnDisconnectedThrowsHub.OnDisconnectedAsync), - mockHttpRequestActivity, exceptionType: typeof(InvalidOperationException)); + AssertHubMethodActivity(activity, mockHttpRequestActivity, nameof(OnDisconnectedThrowsHub.OnDisconnectedAsync), + linkedActivity: null, exceptionType: typeof(InvalidOperationException), activityName: SignalRServerActivitySource.OnDisconnected); } } @@ -288,7 +429,7 @@ bool ExpectedErrors(WriteContext writeContext) Assert.Equal(2, activities.Count); var activity = activities[1]; - AssertHubMethodActivity(activity, nameof(StreamingHub.ExceptionAsyncEnumerable), + AssertHubMethodActivity(activity, parent: null, nameof(StreamingHub.ExceptionAsyncEnumerable), mockHttpRequestActivity, exceptionType: typeof(Exception)); } } @@ -338,18 +479,18 @@ bool ExpectedErrors(WriteContext writeContext) Assert.Equal(2, activities.Count); var activity = activities[1]; - AssertHubMethodActivity(activity, nameof(MethodHub.MethodThatThrows), + AssertHubMethodActivity(activity, parent: null, nameof(MethodHub.MethodThatThrows), mockHttpRequestActivity, exceptionType: typeof(InvalidOperationException)); } } } - private static void AssertHubMethodActivity(Activity activity, string methodName, Activity httpActivity, Type exceptionType = null) + private static void AssertHubMethodActivity(Activity activity, Activity parent, string methodName, Activity linkedActivity, Type exceptionType = null, string activityName = null) { - Assert.Null(activity.Parent); + Assert.Equal(parent, activity.Parent); Assert.True(activity.IsStopped); Assert.Equal(SignalRServerActivitySource.Name, activity.Source.Name); - Assert.Equal(SignalRServerActivitySource.InvocationIn, activity.OperationName); + Assert.Equal(activityName ?? SignalRServerActivitySource.InvocationIn, activity.OperationName); Assert.Equal($"{typeof(THub).FullName}/{methodName}", activity.DisplayName); var tags = activity.Tags.ToArray(); @@ -374,6 +515,9 @@ private static void AssertHubMethodActivity(Activity activity, string meth Assert.Equal(typeof(THub).FullName, tags[2].Value); // Linked to original http request span - Assert.Equal(httpActivity.SpanId, Assert.Single(activity.Links).Context.SpanId); + if (linkedActivity != null) + { + Assert.Equal(linkedActivity.SpanId, Assert.Single(activity.Links).Context.SpanId); + } } }