Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/dotnet/aspnetcore.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBrennan Conroy <brecon@microsoft.com>2022-08-09 19:45:19 +0300
committerBrennan Conroy <brecon@microsoft.com>2022-08-09 19:45:19 +0300
commitd1cef746caa8af45fb7f40cfbc9b0172d0eeebd0 (patch)
treec77680231d15994b3782019cc56efec16b9c7cdf
parent72ee5732f00289191a7f89ce8000854e52de808c (diff)
Add invocation queue to SignalR to avoid client results blocking receive loopbrecon/queue
-rw-r--r--src/SignalR/server/Core/src/HubConnectionContext.cs17
-rw-r--r--src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs200
-rw-r--r--src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs64
3 files changed, 212 insertions, 69 deletions
diff --git a/src/SignalR/server/Core/src/HubConnectionContext.cs b/src/SignalR/server/Core/src/HubConnectionContext.cs
index 04e211d74f..66361d3f92 100644
--- a/src/SignalR/server/Core/src/HubConnectionContext.cs
+++ b/src/SignalR/server/Core/src/HubConnectionContext.cs
@@ -7,6 +7,7 @@ using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.IO.Pipelines;
using System.Security.Claims;
+using System.Threading.Channels;
using Microsoft.AspNetCore.Connections;
using Microsoft.AspNetCore.Connections.Features;
using Microsoft.AspNetCore.Http.Features;
@@ -85,6 +86,16 @@ public partial class HubConnectionContext
{
ActiveInvocationLimit = new SemaphoreSlim(maxInvokeLimit, maxInvokeLimit);
}
+
+ // TODO: configurable capacity
+ PendingInvokes = Channel.CreateBounded<(HubMethodInvocationMessage, bool isStreamResponse, HubMethodDescriptor, object?[])> (new BoundedChannelOptions(5)
+ {
+ AllowSynchronousContinuations = false,
+ // We use TryWrite, Wait mode is the only option that returns false when the channel is full
+ FullMode = BoundedChannelFullMode.Wait,
+ SingleWriter = true,
+ SingleReader = true,
+ });
}
internal StreamTracker StreamTracker
@@ -101,6 +112,8 @@ public partial class HubConnectionContext
}
}
+ internal Channel<(HubMethodInvocationMessage, bool isStreamResponse, HubMethodDescriptor, object?[])> PendingInvokes { get; private set; }
+
internal HubCallerContext HubCallerContext { get; }
internal HubCallerClients HubCallerClients { get; set; } = null!;
@@ -108,6 +121,8 @@ public partial class HubConnectionContext
internal SemaphoreSlim? ActiveInvocationLimit { get; }
+ internal Task? InvokeLoop { get; set; }
+
/// <summary>
/// Gets a <see cref="CancellationToken"/> that notifies when the connection is aborted.
/// </summary>
@@ -677,6 +692,8 @@ public partial class HubConnectionContext
{
var connection = (HubConnectionContext)state!;
+ connection.PendingInvokes.Writer.TryComplete();
+
try
{
connection._connectionAbortedTokenSource.Cancel();
diff --git a/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs b/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs
index 59b06dbf10..7ea9da0a51 100644
--- a/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs
+++ b/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs
@@ -98,6 +98,37 @@ internal sealed partial class DefaultHubDispatcher<THub> : HubDispatcher<THub> w
{
hubActivator.Release(hub);
}
+
+ connection.InvokeLoop = ProcessInvokes(connection);
+ }
+
+ private async Task ProcessInvokes(HubConnectionContext connection)
+ {
+ // TODO: exceptions would close connection previously, figure out if we want to do that here as well
+
+ // TODO: not sure completing channel immediately ends the loop if items still in channel
+ while (await connection.PendingInvokes.Reader.WaitToReadAsync())
+ {
+ while (connection.PendingInvokes.Reader.TryRead(out var invoke))
+ {
+ var (hubMethodInvocationMessage, isStreamResponse, descriptor, arguments) = invoke;
+ bool isStreamCall = descriptor.StreamingParameters != null;
+ if (connection.ActiveInvocationLimit != null && !isStreamCall && !isStreamResponse)
+ {
+ await connection.ActiveInvocationLimit.RunAsync(static state =>
+ {
+ var (dispatcher, descriptor, connection, invocationMessage, arguments) = state;
+ return dispatcher.Invoke(descriptor, connection, invocationMessage, arguments, isStreamResponse: false, isStreamCall: false);
+ }, (this, descriptor, connection, hubMethodInvocationMessage, arguments));
+ }
+ else
+ {
+ await Invoke(descriptor, connection, hubMethodInvocationMessage, arguments, isStreamResponse, isStreamCall);
+ }
+ }
+ }
+
+ // TODO: cleanup all pending invokes
}
public override async Task OnDisconnectedAsync(HubConnectionContext connection, Exception? exception)
@@ -124,6 +155,8 @@ internal sealed partial class DefaultHubDispatcher<THub> : HubDispatcher<THub> w
{
hubActivator.Release(hub);
}
+
+ await connection.InvokeLoop!;
}
public override Task DispatchMessageAsync(HubConnectionContext connection, HubMessage hubMessage)
@@ -248,31 +281,108 @@ internal sealed partial class DefaultHubDispatcher<THub> : HubDispatcher<THub> w
return connection.WriteAsync(CompletionMessage.WithError(
hubMethodInvocationMessage.InvocationId, $"Unknown hub method '{hubMethodInvocationMessage.Target}'")).AsTask();
}
- else
+ return Task.CompletedTask;
+ }
+
+ if (!ValidateInvocationMode(descriptor, isStreamResponse, hubMethodInvocationMessage, out var error))
+ {
+ return connection.WriteAsync(CompletionMessage.WithError(hubMethodInvocationMessage.InvocationId!, error)).AsTask();
+ }
+
+ var arguments = hubMethodInvocationMessage.Arguments;
+ CancellationTokenSource? cts = null;
+ try
+ {
+ var clientStreamLength = hubMethodInvocationMessage.StreamIds?.Length ?? 0;
+ var serverStreamLength = descriptor.StreamingParameters?.Count ?? 0;
+ if (clientStreamLength != serverStreamLength)
{
- return Task.CompletedTask;
+ var ex = new HubException($"Client sent {clientStreamLength} stream(s), Hub method expects {serverStreamLength}.");
+ Log.InvalidHubParameters(_logger, hubMethodInvocationMessage.Target, ex);
+ return SendInvocationError(hubMethodInvocationMessage.InvocationId, connection,
+ ErrorMessageHelper.BuildErrorMessage($"An unexpected error occurred invoking '{hubMethodInvocationMessage.Target}' on the server.", ex, _enableDetailedErrors));
+ }
+
+ if (descriptor.HasSyntheticArguments)
+ {
+ bool isStreamCall = descriptor.StreamingParameters != null;
+ ReplaceArguments(descriptor, hubMethodInvocationMessage, isStreamCall, connection, ref arguments, out cts);
+ }
+
+ if (isStreamResponse)
+ {
+ if (!connection.ActiveRequestCancellationSources.TryAdd(hubMethodInvocationMessage.InvocationId!, cts ?? CancellationTokenSource.CreateLinkedTokenSource(connection.ConnectionAborted)))
+ {
+ Log.InvocationIdInUse(_logger, hubMethodInvocationMessage.InvocationId!);
+ error = $"Invocation ID '{hubMethodInvocationMessage.InvocationId}' is already in use.";
+
+ // TODO: tests for cleanup code, nothing fails currently with these 3 locations commented out
+
+ //if (hubMethodInvocationMessage.StreamIds != null)
+ //{
+ // foreach (var stream in hubMethodInvocationMessage.StreamIds)
+ // {
+ // connection.StreamTracker.TryRemove(CompletionMessage.Empty(stream));
+ // }
+ //}
+
+ //cts?.Dispose();
+ //connection.ActiveRequestCancellationSources.TryRemove(hubMethodInvocationMessage.InvocationId!, out _);
+
+ return connection.WriteAsync(CompletionMessage.WithError(hubMethodInvocationMessage.InvocationId!, error)).AsTask();
+ }
}
}
- else
+ catch
{
- bool isStreamCall = descriptor.StreamingParameters != null;
- if (connection.ActiveInvocationLimit != null && !isStreamCall && !isStreamResponse)
+ //if (hubMethodInvocationMessage.StreamIds != null)
+ //{
+ // foreach (var stream in hubMethodInvocationMessage.StreamIds)
+ // {
+ // connection.StreamTracker.TryRemove(CompletionMessage.Empty(stream));
+ // }
+ //}
+
+ //cts?.Dispose();
+ //connection.ActiveRequestCancellationSources.TryRemove(hubMethodInvocationMessage.InvocationId!, out _);
+
+ return Task.CompletedTask;
+ }
+
+ if (!connection.PendingInvokes.Writer.TryWrite((hubMethodInvocationMessage, isStreamResponse, descriptor, arguments)))
+ {
+ // Log, dropped invoke call
+
+ //if (hubMethodInvocationMessage.StreamIds != null)
+ //{
+ // foreach (var stream in hubMethodInvocationMessage.StreamIds)
+ // {
+ // connection.StreamTracker.TryRemove(CompletionMessage.Empty(stream));
+ // }
+ //}
+
+ //cts?.Dispose();
+ //connection.ActiveRequestCancellationSources.TryRemove(hubMethodInvocationMessage.InvocationId!, out _);
+
+ if (!string.IsNullOrEmpty(hubMethodInvocationMessage.InvocationId))
{
- return connection.ActiveInvocationLimit.RunAsync(static state =>
- {
- var (dispatcher, descriptor, connection, invocationMessage) = state;
- return dispatcher.Invoke(descriptor, connection, invocationMessage, isStreamResponse: false, isStreamCall: false);
- }, (this, descriptor, connection, hubMethodInvocationMessage));
+ // Send an error to the client. Then let the normal completion process occur
+ return connection.WriteAsync(CompletionMessage.WithError(
+ hubMethodInvocationMessage.InvocationId, "Invoke dropped.")).AsTask();
}
else
{
- return Invoke(descriptor, connection, hubMethodInvocationMessage, isStreamResponse, isStreamCall);
+ return Task.CompletedTask;
}
}
+ else
+ {
+ return Task.CompletedTask;
+ }
}
private async Task Invoke(HubMethodDescriptor descriptor, HubConnectionContext connection,
- HubMethodInvocationMessage hubMethodInvocationMessage, bool isStreamResponse, bool isStreamCall)
+ HubMethodInvocationMessage hubMethodInvocationMessage, object?[] arguments, bool isStreamResponse, bool isStreamCall)
{
var methodExecutor = descriptor.MethodExecutor;
@@ -293,37 +403,19 @@ internal sealed partial class DefaultHubDispatcher<THub> : HubDispatcher<THub> w
return;
}
- if (!await ValidateInvocationMode(descriptor, isStreamResponse, hubMethodInvocationMessage, connection))
- {
- return;
- }
-
try
{
- var clientStreamLength = hubMethodInvocationMessage.StreamIds?.Length ?? 0;
- var serverStreamLength = descriptor.StreamingParameters?.Count ?? 0;
- if (clientStreamLength != serverStreamLength)
- {
- var ex = new HubException($"Client sent {clientStreamLength} stream(s), Hub method expects {serverStreamLength}.");
- Log.InvalidHubParameters(_logger, hubMethodInvocationMessage.Target, ex);
- await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection,
- ErrorMessageHelper.BuildErrorMessage($"An unexpected error occurred invoking '{hubMethodInvocationMessage.Target}' on the server.", ex, _enableDetailedErrors));
- return;
- }
-
InitializeHub(hub, connection);
Task? invocation = null;
- var arguments = hubMethodInvocationMessage.Arguments;
- CancellationTokenSource? cts = null;
if (descriptor.HasSyntheticArguments)
{
- ReplaceArguments(descriptor, hubMethodInvocationMessage, isStreamCall, connection, scope, ref arguments, out cts);
+ ReplaceServiceArguments(descriptor, scope, ref arguments);
}
if (isStreamResponse)
{
- _ = StreamAsync(hubMethodInvocationMessage.InvocationId!, connection, arguments, scope, hubActivator, hub, cts, hubMethodInvocationMessage, descriptor);
+ _ = StreamAsync(hubMethodInvocationMessage.InvocationId!, connection, arguments, scope, hubActivator, hub, hubMethodInvocationMessage, descriptor);
}
else
{
@@ -429,20 +521,19 @@ internal sealed partial class DefaultHubDispatcher<THub> : HubDispatcher<THub> w
}
private async Task StreamAsync(string invocationId, HubConnectionContext connection, object?[] arguments, AsyncServiceScope scope,
- IHubActivator<THub> hubActivator, THub hub, CancellationTokenSource? streamCts, HubMethodInvocationMessage hubMethodInvocationMessage, HubMethodDescriptor descriptor)
+ IHubActivator<THub> hubActivator, THub hub, HubMethodInvocationMessage hubMethodInvocationMessage, HubMethodDescriptor descriptor)
{
string? error = null;
-
- streamCts ??= CancellationTokenSource.CreateLinkedTokenSource(connection.ConnectionAborted);
+ if (!connection.ActiveRequestCancellationSources.TryGetValue(invocationId, out var streamCts))
+ {
+ // should not happen (yet)
+ // maybe if we short-circuit canceled invocations then this might be able to occur
+ Debug.Assert(false, "stream cts not available when executing streaming hub method");
+ return;
+ }
try
{
- if (!connection.ActiveRequestCancellationSources.TryAdd(invocationId, streamCts))
- {
- Log.InvocationIdInUse(_logger, invocationId);
- error = $"Invocation ID '{invocationId}' is already in use.";
- return;
- }
object? result;
try
@@ -585,17 +676,17 @@ internal sealed partial class DefaultHubDispatcher<THub> : HubDispatcher<THub> w
return authorizationResult.Succeeded;
}
- private async Task<bool> ValidateInvocationMode(HubMethodDescriptor hubMethodDescriptor, bool isStreamResponse,
- HubMethodInvocationMessage hubMethodInvocationMessage, HubConnectionContext connection)
+ private bool ValidateInvocationMode(HubMethodDescriptor hubMethodDescriptor, bool isStreamResponse,
+ HubMethodInvocationMessage hubMethodInvocationMessage, out string? error)
{
+ error = null;
if (hubMethodDescriptor.IsStreamResponse && !isStreamResponse)
{
// Non-null/empty InvocationId? Blocking
if (!string.IsNullOrEmpty(hubMethodInvocationMessage.InvocationId))
{
Log.StreamingMethodCalledWithInvoke(_logger, hubMethodInvocationMessage);
- await connection.WriteAsync(CompletionMessage.WithError(hubMethodInvocationMessage.InvocationId,
- $"The client attempted to invoke the streaming '{hubMethodInvocationMessage.Target}' method with a non-streaming invocation."));
+ error = $"The client attempted to invoke the streaming '{hubMethodInvocationMessage.Target}' method with a non-streaming invocation.";
}
return false;
@@ -604,8 +695,7 @@ internal sealed partial class DefaultHubDispatcher<THub> : HubDispatcher<THub> w
if (!hubMethodDescriptor.IsStreamResponse && isStreamResponse)
{
Log.NonStreamingMethodCalledWithStream(_logger, hubMethodInvocationMessage);
- await connection.WriteAsync(CompletionMessage.WithError(hubMethodInvocationMessage.InvocationId!,
- $"The client attempted to invoke the non-streaming '{hubMethodInvocationMessage.Target}' method with a streaming invocation."));
+ error = $"The client attempted to invoke the non-streaming '{hubMethodInvocationMessage.Target}' method with a streaming invocation.";
return false;
}
@@ -613,8 +703,19 @@ internal sealed partial class DefaultHubDispatcher<THub> : HubDispatcher<THub> w
return true;
}
+ private static void ReplaceServiceArguments(HubMethodDescriptor descriptor, AsyncServiceScope scope, ref object?[] arguments)
+ {
+ for (var parameterPointer = 0; parameterPointer < arguments.Length; parameterPointer++)
+ {
+ if (descriptor.IsServiceArgument(parameterPointer))
+ {
+ arguments[parameterPointer] = scope.ServiceProvider.GetRequiredService(descriptor.OriginalParameterTypes![parameterPointer]);
+ }
+ }
+ }
+
private void ReplaceArguments(HubMethodDescriptor descriptor, HubMethodInvocationMessage hubMethodInvocationMessage, bool isStreamCall,
- HubConnectionContext connection, AsyncServiceScope scope, ref object?[] arguments, out CancellationTokenSource? cts)
+ HubConnectionContext connection, ref object?[] arguments, out CancellationTokenSource? cts)
{
cts = null;
// In order to add the synthetic arguments we need a new array because the invocation array is too small (it doesn't know about synthetic arguments)
@@ -641,7 +742,8 @@ internal sealed partial class DefaultHubDispatcher<THub> : HubDispatcher<THub> w
}
else if (descriptor.IsServiceArgument(parameterPointer))
{
- arguments[parameterPointer] = scope.ServiceProvider.GetRequiredService(descriptor.OriginalParameterTypes[parameterPointer]);
+ // replaced later when we create the scope
+ arguments[parameterPointer] = null;
}
else if (isStreamCall && ReflectionHelper.IsStreamingType(descriptor.OriginalParameterTypes[parameterPointer], mustBeDirectType: true))
{
diff --git a/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs b/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs
index 8b35d46475..7c395b96c3 100644
--- a/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs
+++ b/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs
@@ -543,7 +543,11 @@ public partial class HubConnectionHandlerTests : VerifiableLoggedTest
using (StartVerifiableLog())
{
var connectionHandler = HubConnectionHandlerTestUtils.GetHubConnectionHandler(typeof(HubT), LoggerFactory,
- services => services.AddSignalR().AddHubOptions<HubT>(o => o.MaximumReceiveMessageSize = maximumMessageSize));
+ services => services.AddSignalR().AddHubOptions<HubT>(o =>
+ {
+ o.MaximumReceiveMessageSize = maximumMessageSize;
+ o.EnableDetailedErrors = true;
+ }));
using (var client = new TestClient())
{
@@ -554,20 +558,33 @@ public partial class HubConnectionHandlerTests : VerifiableLoggedTest
client.Connection.Application.Output.Write(payload3);
await client.Connection.Application.Output.FlushAsync();
- // 2 invocations should be processed
- var completionMessage = await client.ReadAsync().DefaultTimeout() as CompletionMessage;
- Assert.NotNull(completionMessage);
- Assert.Equal("1", completionMessage.InvocationId);
- Assert.Equal("one", completionMessage.Result);
-
- completionMessage = await client.ReadAsync().DefaultTimeout() as CompletionMessage;
- Assert.NotNull(completionMessage);
- Assert.Equal("2", completionMessage.InvocationId);
- Assert.Equal("two", completionMessage.Result);
-
- // We never receive the 3rd message since it was over the maximum message size
- CloseMessage closeMessage = await client.ReadAsync().DefaultTimeout() as CloseMessage;
- Assert.NotNull(closeMessage);
+ // 0-2 invocations may be processed, invocations are put in a queue to unblock the receive loop, so the processing of the queue can race with the reading of the bad message and closing of the connection
+ for (var i = 0; i < 3; ++i)
+ {
+ var hubMessage = await client.ReadAsync().DefaultTimeout();
+ if (hubMessage is CloseMessage closeMessage)
+ {
+ Assert.Equal("Connection closed with an error. InvalidDataException: The maximum message size of 71B was exceeded. The message size can be configured in AddHubOptions.", closeMessage.Error);
+ break;
+ }
+ else if (hubMessage is CompletionMessage completionMessage)
+ {
+ if (i == 0)
+ {
+ Assert.Equal("1", completionMessage.InvocationId);
+ Assert.Equal("one", completionMessage.Result);
+ }
+ else
+ {
+ Assert.Equal("2", completionMessage.InvocationId);
+ Assert.Equal("two", completionMessage.Result);
+ }
+ }
+ else
+ {
+ Assert.True(false);
+ }
+ }
client.Dispose();
@@ -2964,7 +2981,7 @@ public partial class HubConnectionHandlerTests : VerifiableLoggedTest
public PipeWriter Output => _originalDuplexPipe.Output;
}
- [Fact]
+ [Fact(Skip = "Invokes no longer block receive loop, figure out what we want to test here")]
public async Task HubMethodInvokeDoesNotCountTowardsClientTimeout()
{
using (StartVerifiableLog())
@@ -3242,6 +3259,7 @@ public partial class HubConnectionHandlerTests : VerifiableLoggedTest
builder.AddSingleton(typeof(IHubActivator<>), typeof(CustomHubActivator<>));
}, LoggerFactory);
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<StreamingHub>>();
+ var hubActivator = serviceProvider.GetService<IHubActivator<StreamingHub>>() as CustomHubActivator<StreamingHub>;
using (var client = new TestClient())
{
@@ -3250,20 +3268,24 @@ public partial class HubConnectionHandlerTests : VerifiableLoggedTest
await client.Connected.DefaultTimeout();
await client.SendHubMessageAsync(new StreamInvocationMessage("123", nameof(StreamingHub.BlockingStream), Array.Empty<object>())).DefaultTimeout();
+ await hubActivator.CreateTask.Task.DefaultTimeout();
await client.SendHubMessageAsync(new StreamInvocationMessage("123", nameof(StreamingHub.BlockingStream), Array.Empty<object>())).DefaultTimeout();
var completion = Assert.IsType<CompletionMessage>(await client.ReadAsync().DefaultTimeout());
Assert.Equal("Invocation ID '123' is already in use.", completion.Error);
- var hubActivator = serviceProvider.GetService<IHubActivator<StreamingHub>>() as CustomHubActivator<StreamingHub>;
-
- // OnConnectedAsync and BlockingStream hubs have been disposed
- Assert.Equal(2, hubActivator.ReleaseCount);
+ // OnConnectedAsync Hub has been disposed, BlockingStream still running
+ Assert.Equal(1, hubActivator.ReleaseCount);
+ Assert.Equal(2, hubActivator.CreatedCount);
client.Dispose();
await connectionHandlerTask.DefaultTimeout();
+
+ // OnConnectedAsync, BlockingStream, and OnDisconnectedAsync hubs have been disposed
+ Assert.Equal(3, hubActivator.ReleaseCount);
+ Assert.Equal(3, hubActivator.CreatedCount);
}
}
}
@@ -4845,6 +4867,7 @@ public partial class HubConnectionHandlerTests : VerifiableLoggedTest
private class CustomHubActivator<THub> : IHubActivator<THub> where THub : Hub
{
+ public int CreatedCount;
public int ReleaseCount;
private readonly IServiceProvider _serviceProvider;
public TaskCompletionSource ReleaseTask = new TaskCompletionSource();
@@ -4857,6 +4880,7 @@ public partial class HubConnectionHandlerTests : VerifiableLoggedTest
public THub Create()
{
+ CreatedCount++;
ReleaseTask = new TaskCompletionSource();
var hub = new DefaultHubActivator<THub>(_serviceProvider).Create();
CreateTask.TrySetResult();