diff options
author | Brennan Conroy <brecon@microsoft.com> | 2022-08-09 19:45:19 +0300 |
---|---|---|
committer | Brennan Conroy <brecon@microsoft.com> | 2022-08-09 19:45:19 +0300 |
commit | d1cef746caa8af45fb7f40cfbc9b0172d0eeebd0 (patch) | |
tree | c77680231d15994b3782019cc56efec16b9c7cdf | |
parent | 72ee5732f00289191a7f89ce8000854e52de808c (diff) |
Add invocation queue to SignalR to avoid client results blocking receive loopbrecon/queue
3 files changed, 212 insertions, 69 deletions
diff --git a/src/SignalR/server/Core/src/HubConnectionContext.cs b/src/SignalR/server/Core/src/HubConnectionContext.cs index 04e211d74f..66361d3f92 100644 --- a/src/SignalR/server/Core/src/HubConnectionContext.cs +++ b/src/SignalR/server/Core/src/HubConnectionContext.cs @@ -7,6 +7,7 @@ using System.Diagnostics; using System.Diagnostics.CodeAnalysis; using System.IO.Pipelines; using System.Security.Claims; +using System.Threading.Channels; using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.Connections.Features; using Microsoft.AspNetCore.Http.Features; @@ -85,6 +86,16 @@ public partial class HubConnectionContext { ActiveInvocationLimit = new SemaphoreSlim(maxInvokeLimit, maxInvokeLimit); } + + // TODO: configurable capacity + PendingInvokes = Channel.CreateBounded<(HubMethodInvocationMessage, bool isStreamResponse, HubMethodDescriptor, object?[])> (new BoundedChannelOptions(5) + { + AllowSynchronousContinuations = false, + // We use TryWrite, Wait mode is the only option that returns false when the channel is full + FullMode = BoundedChannelFullMode.Wait, + SingleWriter = true, + SingleReader = true, + }); } internal StreamTracker StreamTracker @@ -101,6 +112,8 @@ public partial class HubConnectionContext } } + internal Channel<(HubMethodInvocationMessage, bool isStreamResponse, HubMethodDescriptor, object?[])> PendingInvokes { get; private set; } + internal HubCallerContext HubCallerContext { get; } internal HubCallerClients HubCallerClients { get; set; } = null!; @@ -108,6 +121,8 @@ public partial class HubConnectionContext internal SemaphoreSlim? ActiveInvocationLimit { get; } + internal Task? InvokeLoop { get; set; } + /// <summary> /// Gets a <see cref="CancellationToken"/> that notifies when the connection is aborted. /// </summary> @@ -677,6 +692,8 @@ public partial class HubConnectionContext { var connection = (HubConnectionContext)state!; + connection.PendingInvokes.Writer.TryComplete(); + try { connection._connectionAbortedTokenSource.Cancel(); diff --git a/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs b/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs index 59b06dbf10..7ea9da0a51 100644 --- a/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs +++ b/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs @@ -98,6 +98,37 @@ internal sealed partial class DefaultHubDispatcher<THub> : HubDispatcher<THub> w { hubActivator.Release(hub); } + + connection.InvokeLoop = ProcessInvokes(connection); + } + + private async Task ProcessInvokes(HubConnectionContext connection) + { + // TODO: exceptions would close connection previously, figure out if we want to do that here as well + + // TODO: not sure completing channel immediately ends the loop if items still in channel + while (await connection.PendingInvokes.Reader.WaitToReadAsync()) + { + while (connection.PendingInvokes.Reader.TryRead(out var invoke)) + { + var (hubMethodInvocationMessage, isStreamResponse, descriptor, arguments) = invoke; + bool isStreamCall = descriptor.StreamingParameters != null; + if (connection.ActiveInvocationLimit != null && !isStreamCall && !isStreamResponse) + { + await connection.ActiveInvocationLimit.RunAsync(static state => + { + var (dispatcher, descriptor, connection, invocationMessage, arguments) = state; + return dispatcher.Invoke(descriptor, connection, invocationMessage, arguments, isStreamResponse: false, isStreamCall: false); + }, (this, descriptor, connection, hubMethodInvocationMessage, arguments)); + } + else + { + await Invoke(descriptor, connection, hubMethodInvocationMessage, arguments, isStreamResponse, isStreamCall); + } + } + } + + // TODO: cleanup all pending invokes } public override async Task OnDisconnectedAsync(HubConnectionContext connection, Exception? exception) @@ -124,6 +155,8 @@ internal sealed partial class DefaultHubDispatcher<THub> : HubDispatcher<THub> w { hubActivator.Release(hub); } + + await connection.InvokeLoop!; } public override Task DispatchMessageAsync(HubConnectionContext connection, HubMessage hubMessage) @@ -248,31 +281,108 @@ internal sealed partial class DefaultHubDispatcher<THub> : HubDispatcher<THub> w return connection.WriteAsync(CompletionMessage.WithError( hubMethodInvocationMessage.InvocationId, $"Unknown hub method '{hubMethodInvocationMessage.Target}'")).AsTask(); } - else + return Task.CompletedTask; + } + + if (!ValidateInvocationMode(descriptor, isStreamResponse, hubMethodInvocationMessage, out var error)) + { + return connection.WriteAsync(CompletionMessage.WithError(hubMethodInvocationMessage.InvocationId!, error)).AsTask(); + } + + var arguments = hubMethodInvocationMessage.Arguments; + CancellationTokenSource? cts = null; + try + { + var clientStreamLength = hubMethodInvocationMessage.StreamIds?.Length ?? 0; + var serverStreamLength = descriptor.StreamingParameters?.Count ?? 0; + if (clientStreamLength != serverStreamLength) { - return Task.CompletedTask; + var ex = new HubException($"Client sent {clientStreamLength} stream(s), Hub method expects {serverStreamLength}."); + Log.InvalidHubParameters(_logger, hubMethodInvocationMessage.Target, ex); + return SendInvocationError(hubMethodInvocationMessage.InvocationId, connection, + ErrorMessageHelper.BuildErrorMessage($"An unexpected error occurred invoking '{hubMethodInvocationMessage.Target}' on the server.", ex, _enableDetailedErrors)); + } + + if (descriptor.HasSyntheticArguments) + { + bool isStreamCall = descriptor.StreamingParameters != null; + ReplaceArguments(descriptor, hubMethodInvocationMessage, isStreamCall, connection, ref arguments, out cts); + } + + if (isStreamResponse) + { + if (!connection.ActiveRequestCancellationSources.TryAdd(hubMethodInvocationMessage.InvocationId!, cts ?? CancellationTokenSource.CreateLinkedTokenSource(connection.ConnectionAborted))) + { + Log.InvocationIdInUse(_logger, hubMethodInvocationMessage.InvocationId!); + error = $"Invocation ID '{hubMethodInvocationMessage.InvocationId}' is already in use."; + + // TODO: tests for cleanup code, nothing fails currently with these 3 locations commented out + + //if (hubMethodInvocationMessage.StreamIds != null) + //{ + // foreach (var stream in hubMethodInvocationMessage.StreamIds) + // { + // connection.StreamTracker.TryRemove(CompletionMessage.Empty(stream)); + // } + //} + + //cts?.Dispose(); + //connection.ActiveRequestCancellationSources.TryRemove(hubMethodInvocationMessage.InvocationId!, out _); + + return connection.WriteAsync(CompletionMessage.WithError(hubMethodInvocationMessage.InvocationId!, error)).AsTask(); + } } } - else + catch { - bool isStreamCall = descriptor.StreamingParameters != null; - if (connection.ActiveInvocationLimit != null && !isStreamCall && !isStreamResponse) + //if (hubMethodInvocationMessage.StreamIds != null) + //{ + // foreach (var stream in hubMethodInvocationMessage.StreamIds) + // { + // connection.StreamTracker.TryRemove(CompletionMessage.Empty(stream)); + // } + //} + + //cts?.Dispose(); + //connection.ActiveRequestCancellationSources.TryRemove(hubMethodInvocationMessage.InvocationId!, out _); + + return Task.CompletedTask; + } + + if (!connection.PendingInvokes.Writer.TryWrite((hubMethodInvocationMessage, isStreamResponse, descriptor, arguments))) + { + // Log, dropped invoke call + + //if (hubMethodInvocationMessage.StreamIds != null) + //{ + // foreach (var stream in hubMethodInvocationMessage.StreamIds) + // { + // connection.StreamTracker.TryRemove(CompletionMessage.Empty(stream)); + // } + //} + + //cts?.Dispose(); + //connection.ActiveRequestCancellationSources.TryRemove(hubMethodInvocationMessage.InvocationId!, out _); + + if (!string.IsNullOrEmpty(hubMethodInvocationMessage.InvocationId)) { - return connection.ActiveInvocationLimit.RunAsync(static state => - { - var (dispatcher, descriptor, connection, invocationMessage) = state; - return dispatcher.Invoke(descriptor, connection, invocationMessage, isStreamResponse: false, isStreamCall: false); - }, (this, descriptor, connection, hubMethodInvocationMessage)); + // Send an error to the client. Then let the normal completion process occur + return connection.WriteAsync(CompletionMessage.WithError( + hubMethodInvocationMessage.InvocationId, "Invoke dropped.")).AsTask(); } else { - return Invoke(descriptor, connection, hubMethodInvocationMessage, isStreamResponse, isStreamCall); + return Task.CompletedTask; } } + else + { + return Task.CompletedTask; + } } private async Task Invoke(HubMethodDescriptor descriptor, HubConnectionContext connection, - HubMethodInvocationMessage hubMethodInvocationMessage, bool isStreamResponse, bool isStreamCall) + HubMethodInvocationMessage hubMethodInvocationMessage, object?[] arguments, bool isStreamResponse, bool isStreamCall) { var methodExecutor = descriptor.MethodExecutor; @@ -293,37 +403,19 @@ internal sealed partial class DefaultHubDispatcher<THub> : HubDispatcher<THub> w return; } - if (!await ValidateInvocationMode(descriptor, isStreamResponse, hubMethodInvocationMessage, connection)) - { - return; - } - try { - var clientStreamLength = hubMethodInvocationMessage.StreamIds?.Length ?? 0; - var serverStreamLength = descriptor.StreamingParameters?.Count ?? 0; - if (clientStreamLength != serverStreamLength) - { - var ex = new HubException($"Client sent {clientStreamLength} stream(s), Hub method expects {serverStreamLength}."); - Log.InvalidHubParameters(_logger, hubMethodInvocationMessage.Target, ex); - await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection, - ErrorMessageHelper.BuildErrorMessage($"An unexpected error occurred invoking '{hubMethodInvocationMessage.Target}' on the server.", ex, _enableDetailedErrors)); - return; - } - InitializeHub(hub, connection); Task? invocation = null; - var arguments = hubMethodInvocationMessage.Arguments; - CancellationTokenSource? cts = null; if (descriptor.HasSyntheticArguments) { - ReplaceArguments(descriptor, hubMethodInvocationMessage, isStreamCall, connection, scope, ref arguments, out cts); + ReplaceServiceArguments(descriptor, scope, ref arguments); } if (isStreamResponse) { - _ = StreamAsync(hubMethodInvocationMessage.InvocationId!, connection, arguments, scope, hubActivator, hub, cts, hubMethodInvocationMessage, descriptor); + _ = StreamAsync(hubMethodInvocationMessage.InvocationId!, connection, arguments, scope, hubActivator, hub, hubMethodInvocationMessage, descriptor); } else { @@ -429,20 +521,19 @@ internal sealed partial class DefaultHubDispatcher<THub> : HubDispatcher<THub> w } private async Task StreamAsync(string invocationId, HubConnectionContext connection, object?[] arguments, AsyncServiceScope scope, - IHubActivator<THub> hubActivator, THub hub, CancellationTokenSource? streamCts, HubMethodInvocationMessage hubMethodInvocationMessage, HubMethodDescriptor descriptor) + IHubActivator<THub> hubActivator, THub hub, HubMethodInvocationMessage hubMethodInvocationMessage, HubMethodDescriptor descriptor) { string? error = null; - - streamCts ??= CancellationTokenSource.CreateLinkedTokenSource(connection.ConnectionAborted); + if (!connection.ActiveRequestCancellationSources.TryGetValue(invocationId, out var streamCts)) + { + // should not happen (yet) + // maybe if we short-circuit canceled invocations then this might be able to occur + Debug.Assert(false, "stream cts not available when executing streaming hub method"); + return; + } try { - if (!connection.ActiveRequestCancellationSources.TryAdd(invocationId, streamCts)) - { - Log.InvocationIdInUse(_logger, invocationId); - error = $"Invocation ID '{invocationId}' is already in use."; - return; - } object? result; try @@ -585,17 +676,17 @@ internal sealed partial class DefaultHubDispatcher<THub> : HubDispatcher<THub> w return authorizationResult.Succeeded; } - private async Task<bool> ValidateInvocationMode(HubMethodDescriptor hubMethodDescriptor, bool isStreamResponse, - HubMethodInvocationMessage hubMethodInvocationMessage, HubConnectionContext connection) + private bool ValidateInvocationMode(HubMethodDescriptor hubMethodDescriptor, bool isStreamResponse, + HubMethodInvocationMessage hubMethodInvocationMessage, out string? error) { + error = null; if (hubMethodDescriptor.IsStreamResponse && !isStreamResponse) { // Non-null/empty InvocationId? Blocking if (!string.IsNullOrEmpty(hubMethodInvocationMessage.InvocationId)) { Log.StreamingMethodCalledWithInvoke(_logger, hubMethodInvocationMessage); - await connection.WriteAsync(CompletionMessage.WithError(hubMethodInvocationMessage.InvocationId, - $"The client attempted to invoke the streaming '{hubMethodInvocationMessage.Target}' method with a non-streaming invocation.")); + error = $"The client attempted to invoke the streaming '{hubMethodInvocationMessage.Target}' method with a non-streaming invocation."; } return false; @@ -604,8 +695,7 @@ internal sealed partial class DefaultHubDispatcher<THub> : HubDispatcher<THub> w if (!hubMethodDescriptor.IsStreamResponse && isStreamResponse) { Log.NonStreamingMethodCalledWithStream(_logger, hubMethodInvocationMessage); - await connection.WriteAsync(CompletionMessage.WithError(hubMethodInvocationMessage.InvocationId!, - $"The client attempted to invoke the non-streaming '{hubMethodInvocationMessage.Target}' method with a streaming invocation.")); + error = $"The client attempted to invoke the non-streaming '{hubMethodInvocationMessage.Target}' method with a streaming invocation."; return false; } @@ -613,8 +703,19 @@ internal sealed partial class DefaultHubDispatcher<THub> : HubDispatcher<THub> w return true; } + private static void ReplaceServiceArguments(HubMethodDescriptor descriptor, AsyncServiceScope scope, ref object?[] arguments) + { + for (var parameterPointer = 0; parameterPointer < arguments.Length; parameterPointer++) + { + if (descriptor.IsServiceArgument(parameterPointer)) + { + arguments[parameterPointer] = scope.ServiceProvider.GetRequiredService(descriptor.OriginalParameterTypes![parameterPointer]); + } + } + } + private void ReplaceArguments(HubMethodDescriptor descriptor, HubMethodInvocationMessage hubMethodInvocationMessage, bool isStreamCall, - HubConnectionContext connection, AsyncServiceScope scope, ref object?[] arguments, out CancellationTokenSource? cts) + HubConnectionContext connection, ref object?[] arguments, out CancellationTokenSource? cts) { cts = null; // In order to add the synthetic arguments we need a new array because the invocation array is too small (it doesn't know about synthetic arguments) @@ -641,7 +742,8 @@ internal sealed partial class DefaultHubDispatcher<THub> : HubDispatcher<THub> w } else if (descriptor.IsServiceArgument(parameterPointer)) { - arguments[parameterPointer] = scope.ServiceProvider.GetRequiredService(descriptor.OriginalParameterTypes[parameterPointer]); + // replaced later when we create the scope + arguments[parameterPointer] = null; } else if (isStreamCall && ReflectionHelper.IsStreamingType(descriptor.OriginalParameterTypes[parameterPointer], mustBeDirectType: true)) { diff --git a/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs b/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs index 8b35d46475..7c395b96c3 100644 --- a/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs +++ b/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs @@ -543,7 +543,11 @@ public partial class HubConnectionHandlerTests : VerifiableLoggedTest using (StartVerifiableLog()) { var connectionHandler = HubConnectionHandlerTestUtils.GetHubConnectionHandler(typeof(HubT), LoggerFactory, - services => services.AddSignalR().AddHubOptions<HubT>(o => o.MaximumReceiveMessageSize = maximumMessageSize)); + services => services.AddSignalR().AddHubOptions<HubT>(o => + { + o.MaximumReceiveMessageSize = maximumMessageSize; + o.EnableDetailedErrors = true; + })); using (var client = new TestClient()) { @@ -554,20 +558,33 @@ public partial class HubConnectionHandlerTests : VerifiableLoggedTest client.Connection.Application.Output.Write(payload3); await client.Connection.Application.Output.FlushAsync(); - // 2 invocations should be processed - var completionMessage = await client.ReadAsync().DefaultTimeout() as CompletionMessage; - Assert.NotNull(completionMessage); - Assert.Equal("1", completionMessage.InvocationId); - Assert.Equal("one", completionMessage.Result); - - completionMessage = await client.ReadAsync().DefaultTimeout() as CompletionMessage; - Assert.NotNull(completionMessage); - Assert.Equal("2", completionMessage.InvocationId); - Assert.Equal("two", completionMessage.Result); - - // We never receive the 3rd message since it was over the maximum message size - CloseMessage closeMessage = await client.ReadAsync().DefaultTimeout() as CloseMessage; - Assert.NotNull(closeMessage); + // 0-2 invocations may be processed, invocations are put in a queue to unblock the receive loop, so the processing of the queue can race with the reading of the bad message and closing of the connection + for (var i = 0; i < 3; ++i) + { + var hubMessage = await client.ReadAsync().DefaultTimeout(); + if (hubMessage is CloseMessage closeMessage) + { + Assert.Equal("Connection closed with an error. InvalidDataException: The maximum message size of 71B was exceeded. The message size can be configured in AddHubOptions.", closeMessage.Error); + break; + } + else if (hubMessage is CompletionMessage completionMessage) + { + if (i == 0) + { + Assert.Equal("1", completionMessage.InvocationId); + Assert.Equal("one", completionMessage.Result); + } + else + { + Assert.Equal("2", completionMessage.InvocationId); + Assert.Equal("two", completionMessage.Result); + } + } + else + { + Assert.True(false); + } + } client.Dispose(); @@ -2964,7 +2981,7 @@ public partial class HubConnectionHandlerTests : VerifiableLoggedTest public PipeWriter Output => _originalDuplexPipe.Output; } - [Fact] + [Fact(Skip = "Invokes no longer block receive loop, figure out what we want to test here")] public async Task HubMethodInvokeDoesNotCountTowardsClientTimeout() { using (StartVerifiableLog()) @@ -3242,6 +3259,7 @@ public partial class HubConnectionHandlerTests : VerifiableLoggedTest builder.AddSingleton(typeof(IHubActivator<>), typeof(CustomHubActivator<>)); }, LoggerFactory); var connectionHandler = serviceProvider.GetService<HubConnectionHandler<StreamingHub>>(); + var hubActivator = serviceProvider.GetService<IHubActivator<StreamingHub>>() as CustomHubActivator<StreamingHub>; using (var client = new TestClient()) { @@ -3250,20 +3268,24 @@ public partial class HubConnectionHandlerTests : VerifiableLoggedTest await client.Connected.DefaultTimeout(); await client.SendHubMessageAsync(new StreamInvocationMessage("123", nameof(StreamingHub.BlockingStream), Array.Empty<object>())).DefaultTimeout(); + await hubActivator.CreateTask.Task.DefaultTimeout(); await client.SendHubMessageAsync(new StreamInvocationMessage("123", nameof(StreamingHub.BlockingStream), Array.Empty<object>())).DefaultTimeout(); var completion = Assert.IsType<CompletionMessage>(await client.ReadAsync().DefaultTimeout()); Assert.Equal("Invocation ID '123' is already in use.", completion.Error); - var hubActivator = serviceProvider.GetService<IHubActivator<StreamingHub>>() as CustomHubActivator<StreamingHub>; - - // OnConnectedAsync and BlockingStream hubs have been disposed - Assert.Equal(2, hubActivator.ReleaseCount); + // OnConnectedAsync Hub has been disposed, BlockingStream still running + Assert.Equal(1, hubActivator.ReleaseCount); + Assert.Equal(2, hubActivator.CreatedCount); client.Dispose(); await connectionHandlerTask.DefaultTimeout(); + + // OnConnectedAsync, BlockingStream, and OnDisconnectedAsync hubs have been disposed + Assert.Equal(3, hubActivator.ReleaseCount); + Assert.Equal(3, hubActivator.CreatedCount); } } } @@ -4845,6 +4867,7 @@ public partial class HubConnectionHandlerTests : VerifiableLoggedTest private class CustomHubActivator<THub> : IHubActivator<THub> where THub : Hub { + public int CreatedCount; public int ReleaseCount; private readonly IServiceProvider _serviceProvider; public TaskCompletionSource ReleaseTask = new TaskCompletionSource(); @@ -4857,6 +4880,7 @@ public partial class HubConnectionHandlerTests : VerifiableLoggedTest public THub Create() { + CreatedCount++; ReleaseTask = new TaskCompletionSource(); var hub = new DefaultHubActivator<THub>(_serviceProvider).Create(); CreateTask.TrySetResult(); |