diff options
author | Stephen Toub <stoub@microsoft.com> | 2017-10-30 22:02:27 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2017-10-30 22:02:27 +0300 |
commit | 0e517d15d2a90aceffdd6bbc6c1de80b8582a154 (patch) | |
tree | 3cc0ce767799ed14612d07821a0d1eeea5045049 /src | |
parent | 9a90c2e1fcfb3276275a39f6cdc383e8a3012df0 (diff) |
Add ManagedHandler support for cancelling connect operations (#24873)
Diffstat (limited to 'src')
7 files changed, 235 insertions, 44 deletions
diff --git a/src/System.Net.Http/src/System/Net/Http/Managed/ConnectHelper.cs b/src/System.Net.Http/src/System/Net/Http/Managed/ConnectHelper.cs index d96ff6547d..46774e6eb2 100644 --- a/src/System.Net.Http/src/System/Net/Http/Managed/ConnectHelper.cs +++ b/src/System.Net.Http/src/System/Net/Http/Managed/ConnectHelper.cs @@ -2,31 +2,93 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System.Diagnostics; using System.IO; using System.Net.Sockets; +using System.Runtime.CompilerServices; +using System.Threading; using System.Threading.Tasks; namespace System.Net.Http { internal static class ConnectHelper { - public static async ValueTask<Stream> ConnectAsync(string host, int port) + public static async ValueTask<Stream> ConnectAsync(string host, int port, CancellationToken cancellationToken) { - var socket = new Socket(SocketType.Stream, ProtocolType.Tcp) { NoDelay = true }; try { - // TODO #23151: cancellation support? - await (IPAddress.TryParse(host, out IPAddress address) ? - socket.ConnectAsync(address, port) : - socket.ConnectAsync(host, port)).ConfigureAwait(false); + // Rather than creating a new Socket and calling ConnectAsync on it, we use the static + // Socket.ConnectAsync with a SocketAsyncEventArgs, as we can then use Socket.CancelConnectAsync + // to cancel it if needed. + using (var saea = new BuilderAndCancellationTokenSocketAsyncEventArgs(cancellationToken)) + { + // Configure which server to which to connect. + saea.RemoteEndPoint = IPAddress.TryParse(host, out IPAddress address) ? + (EndPoint)new IPEndPoint(address, port) : + new DnsEndPoint(host, port); + + // Hook up a callback that'll complete the Task when the operation completes. + saea.Completed += (s, e) => + { + var csaea = (BuilderAndCancellationTokenSocketAsyncEventArgs)e; + switch (e.SocketError) + { + case SocketError.Success: + csaea.Builder.SetResult(); + break; + case SocketError.OperationAborted: + case SocketError.ConnectionAborted: + if (cancellationToken.IsCancellationRequested) + { + csaea.Builder.SetException(new OperationCanceledException(csaea.CancellationToken)); + break; + } + goto default; + default: + csaea.Builder.SetException(new SocketException((int)e.SocketError)); + break; + } + }; + + // Initiate the connection. + if (Socket.ConnectAsync(SocketType.Stream, ProtocolType.Tcp, saea)) + { + // If it didn't complete synchronously, enable it to be canceled and wait for it. + using (cancellationToken.Register(s => Socket.CancelConnectAsync((SocketAsyncEventArgs)s), saea)) + { + await saea.Builder.Task.ConfigureAwait(false); + } + } + + Debug.Assert(saea.ConnectSocket != null, "Expected non-null socket"); + Debug.Assert(saea.ConnectSocket.Connected, "Expected socket to be connected"); + + // Configure the socket and return a stream for it. + Socket socket = saea.ConnectSocket; + socket.NoDelay = true; + return new NetworkStream(socket, ownsSocket: true); + } } catch (SocketException se) { - socket.Dispose(); throw new HttpRequestException(se.Message, se); } + } - return new NetworkStream(socket, ownsSocket: true); + /// <summary>SocketAsyncEventArgs that carries with it additional state for a Task builder and a CancellationToken.</summary> + private sealed class BuilderAndCancellationTokenSocketAsyncEventArgs : SocketAsyncEventArgs + { + public AsyncTaskMethodBuilder Builder { get; } + public CancellationToken CancellationToken { get; } + + public BuilderAndCancellationTokenSocketAsyncEventArgs(CancellationToken cancellationToken) + { + var b = new AsyncTaskMethodBuilder(); + var ignored = b.Task; // force initialization + Builder = b; + + CancellationToken = cancellationToken; + } } } } diff --git a/src/System.Net.Http/src/System/Net/Http/Managed/HttpConnectionHandler.cs b/src/System.Net.Http/src/System/Net/Http/Managed/HttpConnectionHandler.cs index 6a5a0723cf..c7963208e2 100644 --- a/src/System.Net.Http/src/System/Net/Http/Managed/HttpConnectionHandler.cs +++ b/src/System.Net.Http/src/System/Net/Http/Managed/HttpConnectionHandler.cs @@ -28,8 +28,8 @@ namespace System.Net.Http HttpConnectionPool pool = _connectionPools.GetOrAddPool(key); ValueTask<HttpConnection> connectionTask = pool.GetConnectionAsync( - state => state.handler.CreateConnection(state.request, state.key, state.pool), - (handler: this, request: request, key: key, pool: pool)); + (state, ct) => state.handler.CreateConnection(state.request, state.key, state.pool, ct), + (handler: this, request: request, key: key, pool: pool), cancellationToken); return connectionTask.IsCompletedSuccessfully ? connectionTask.Result.SendAsync(request, cancellationToken) : @@ -43,7 +43,7 @@ namespace System.Net.Http return await connection.SendAsync(request, cancellationToken).ConfigureAwait(false); } - private async ValueTask<SslStream> EstablishSslConnection(string host, HttpRequestMessage request, Stream stream) + private async ValueTask<SslStream> EstablishSslConnection(string host, HttpRequestMessage request, Stream stream, CancellationToken cancellationToken) { RemoteCertificateValidationCallback callback = null; if (_settings._serverCertificateCustomValidationCallback != null) @@ -61,12 +61,18 @@ namespace System.Net.Http }; } - SslStream sslStream = new SslStream(stream, false, callback); + var sslStream = new SslStream(stream); try { - // TODO https://github.com/dotnet/corefx/issues/23077#issuecomment-321807131: No cancellationToken? - await sslStream.AuthenticateAsClientAsync(host, _settings._clientCertificates, _settings._sslProtocols, _settings._checkCertificateRevocationList).ConfigureAwait(false); + await sslStream.AuthenticateAsClientAsync(new SslClientAuthenticationOptions + { + TargetHost = host, + ClientCertificates = _settings._clientCertificates, + EnabledSslProtocols = _settings._sslProtocols, + CertificateRevocationCheckMode = _settings._checkCertificateRevocationList ? X509RevocationMode.Online : X509RevocationMode.NoCheck, + RemoteCertificateValidationCallback = callback + }, cancellationToken).ConfigureAwait(false); } catch (Exception e) { @@ -81,11 +87,12 @@ namespace System.Net.Http return sslStream; } - private async ValueTask<HttpConnection> CreateConnection(HttpRequestMessage request, HttpConnectionKey key, HttpConnectionPool pool) + private async ValueTask<HttpConnection> CreateConnection( + HttpRequestMessage request, HttpConnectionKey key, HttpConnectionPool pool, CancellationToken cancellationToken) { Uri uri = request.RequestUri; - Stream stream = await ConnectHelper.ConnectAsync(uri.IdnHost, uri.Port).ConfigureAwait(false); + Stream stream = await ConnectHelper.ConnectAsync(uri.IdnHost, uri.Port, cancellationToken).ConfigureAwait(false); TransportContext transportContext = null; @@ -126,7 +133,7 @@ namespace System.Net.Http } // Establish the connection using the parsed host name. - SslStream sslStream = await EstablishSslConnection(host, request, stream).ConfigureAwait(false); + SslStream sslStream = await EstablishSslConnection(host, request, stream, cancellationToken).ConfigureAwait(false); stream = sslStream; transportContext = sslStream.TransportContext; } diff --git a/src/System.Net.Http/src/System/Net/Http/Managed/HttpConnectionPool.cs b/src/System.Net.Http/src/System/Net/Http/Managed/HttpConnectionPool.cs index c2529d2f61..eb3143530f 100644 --- a/src/System.Net.Http/src/System/Net/Http/Managed/HttpConnectionPool.cs +++ b/src/System.Net.Http/src/System/Net/Http/Managed/HttpConnectionPool.cs @@ -21,8 +21,11 @@ namespace System.Net.Http private readonly List<CachedConnection> _idleConnections = new List<CachedConnection>(); /// <summary>The maximum number of connections allowed to be associated with the pool.</summary> private readonly int _maxConnections; - /// <summary>A queue of waiters waiting for a connection. This will be null if there's no maximum set.</summary> - private readonly Queue<ConnectionWaiter> _waiters; + + /// <summary>The head of a list of waiters waiting for a connection. Null if no one's waiting.</summary> + private ConnectionWaiter _waitersHead; + /// <summary>The tail of a list of waiters waiting for a connection. Null if no one's waiting.</summary> + private ConnectionWaiter _waitersTail; /// <summary>The number of connections associated with the pool. Some of these may be in <see cref="_idleConnections"/>, others may be in use.</summary> private int _associatedConnectionCount; @@ -36,17 +39,19 @@ namespace System.Net.Http public HttpConnectionPool(int maxConnections = int.MaxValue) // int.MaxValue treated as infinite { _maxConnections = maxConnections; - if (maxConnections < int.MaxValue) - { - _waiters = new Queue<ConnectionWaiter>(); - } } /// <summary>Object used to synchronize access to state in the pool.</summary> private object SyncObj => _idleConnections; - public ValueTask<HttpConnection> GetConnectionAsync<TState>(Func<TState, ValueTask<HttpConnection>> createConnection, TState state) + public ValueTask<HttpConnection> GetConnectionAsync<TState>( + Func<TState, CancellationToken, ValueTask<HttpConnection>> createConnection, TState state, CancellationToken cancellationToken) { + if (cancellationToken.IsCancellationRequested) + { + return new ValueTask<HttpConnection>(Task.FromCanceled<HttpConnection>(cancellationToken)); + } + List<CachedConnection> list = _idleConnections; lock (SyncObj) { @@ -76,11 +81,11 @@ namespace System.Net.Http // there's no limit on the number of connections associated with this // pool, or if we haven't reached such a limit, simply create a new // connection. - if (_waiters == null || _associatedConnectionCount < _maxConnections) + if (_associatedConnectionCount < _maxConnections) { if (NetEventSource.IsEnabled) Trace("Creating new connection for pool."); IncrementConnectionCountNoLock(); - return WaitForCreatedConnectionAsync(createConnection(state)); + return WaitForCreatedConnectionAsync(createConnection(state, cancellationToken)); } else { @@ -92,8 +97,27 @@ namespace System.Net.Http // space is available and the provided creation func has successfully // created the connection to be used. if (NetEventSource.IsEnabled) Trace("Limit reached. Waiting to create new connection."); - var waiter = new ConnectionWaiter<TState>(this, createConnection, state); - _waiters.Enqueue(waiter); + var waiter = new ConnectionWaiter<TState>(this, createConnection, state, cancellationToken); + EnqueueWaiter(waiter); + if (cancellationToken.CanBeCanceled) + { + // If cancellation could be requested, register a callback for it that'll cancel + // the waiter and remove the waiter from the queue. Note that this registration needs + // to happen under the reentrant lock and after enqueueing the waiter. + waiter._cancellationTokenRegistration = cancellationToken.Register(s => + { + var innerWaiter = (ConnectionWaiter)s; + lock (innerWaiter._pool.SyncObj) + { + // If it's in the list, remove it and cancel it. + if (innerWaiter._pool.RemoveWaiterForCancellation(innerWaiter)) + { + bool canceled = innerWaiter.TrySetCanceled(innerWaiter._cancellationToken); + Debug.Assert(canceled); + } + } + }, waiter); + } return new ValueTask<HttpConnection>(waiter.Task); } @@ -108,6 +132,87 @@ namespace System.Net.Http } } + /// <summary>Enqueues a waiter to the waiters list.</summary> + /// <param name="waiter">The waiter to add.</param> + private void EnqueueWaiter(ConnectionWaiter waiter) + { + Debug.Assert(Monitor.IsEntered(SyncObj)); + Debug.Assert(waiter != null); + Debug.Assert(waiter._next == null); + Debug.Assert(waiter._prev == null); + + waiter._next = _waitersHead; + if (_waitersHead != null) + { + _waitersHead._prev = waiter; + } + else + { + Debug.Assert(_waitersTail == null); + _waitersTail = waiter; + } + _waitersHead = waiter; + } + + /// <summary>Dequeues a waiter from the waiters list. The list must not be empty.</summary> + /// <returns>The dequeued waiter.</returns> + private ConnectionWaiter DequeueWaiter() + { + Debug.Assert(Monitor.IsEntered(SyncObj)); + Debug.Assert(_waitersTail != null); + + ConnectionWaiter waiter = _waitersTail; + _waitersTail = waiter._prev; + + if (_waitersTail != null) + { + _waitersTail._next = null; + } + else + { + Debug.Assert(_waitersHead == waiter); + _waitersHead = null; + } + + waiter._next = null; + waiter._prev = null; + + return waiter; + } + + /// <summary>Removes the specified waiter from the waiters list as part of a cancellation request.</summary> + /// <param name="waiter">The waiter to remove.</param> + /// <returns>true if the waiter was in the list; otherwise, false.</returns> + private bool RemoveWaiterForCancellation(ConnectionWaiter waiter) + { + Debug.Assert(Monitor.IsEntered(SyncObj)); + Debug.Assert(waiter != null); + Debug.Assert(waiter._cancellationToken.IsCancellationRequested); + + bool inList = waiter._next != null || waiter._prev != null || _waitersHead == waiter || _waitersTail == waiter; + + if (waiter._next != null) waiter._next._prev = waiter._prev; + if (waiter._prev != null) waiter._prev._next = waiter._next; + + if (_waitersHead == waiter && _waitersTail == waiter) + { + _waitersHead = _waitersTail = null; + } + else if (_waitersHead == waiter) + { + _waitersHead = waiter._next; + } + else if (_waitersTail == waiter) + { + _waitersTail = waiter._prev; + } + + waiter._next = null; + waiter._prev = null; + + return inList; + } + /// <summary>Waits for and returns the created connection, decrementing the associated connection count if it fails.</summary> private async ValueTask<HttpConnection> WaitForCreatedConnectionAsync(ValueTask<HttpConnection> creationTask) { @@ -160,7 +265,7 @@ namespace System.Net.Http // Mark the pool as not being stale. _usedSinceLastCleanup = true; - if (_waiters == null || _waiters.Count == 0) + if (_waitersHead == null) { // There are no waiters to which the count should logically be transferred, // so simply decrement the count. @@ -171,9 +276,10 @@ namespace System.Net.Http // There's at least one waiter to which we should try to logically transfer // the associated count. Get the waiter. Debug.Assert(_idleConnections.Count == 0, $"With {_idleConnections} connections, we shouldn't have a waiter."); - ConnectionWaiter waiter = _waiters.Dequeue(); + ConnectionWaiter waiter = DequeueWaiter(); Debug.Assert(waiter != null, "Expected non-null waiter"); Debug.Assert(waiter.Task.Status == TaskStatus.WaitingForActivation, $"Expected {waiter.Task.Status} == {nameof(TaskStatus.WaitingForActivation)}"); + waiter._cancellationTokenRegistration.Dispose(); // Having a waiter means there must not be any idle connections, so we need to create // one, and we do so using the logic associated with the waiter. @@ -231,10 +337,14 @@ namespace System.Net.Http // If there's someone waiting for a connection, simply // transfer this one to them rather than pooling it. - if (_waiters != null && _waiters.TryDequeue(out ConnectionWaiter waiter)) + if (_waitersTail != null) { + ConnectionWaiter waiter = DequeueWaiter(); + waiter._cancellationTokenRegistration.Dispose(); + if (NetEventSource.IsEnabled) connection.Trace("Transferring connection returned to pool."); waiter.SetResult(connection); + return; } @@ -419,14 +529,15 @@ namespace System.Net.Http private sealed class ConnectionWaiter<TState> : ConnectionWaiter { /// <summary>The function to invoke if/when <see cref="CreateConnectionAsync"/> is invoked.</summary> - private readonly Func<TState, ValueTask<HttpConnection>> _createConnectionAsync; + private readonly Func<TState, CancellationToken, ValueTask<HttpConnection>> _createConnectionAsync; /// <summary>The state to pass to <paramref name="func"/> when it's invoked.</summary> private readonly TState _state; /// <summary>Initializes the waiter.</summary> /// <param name="func">The function to invoke if/when <see cref="CreateConnectionAsync"/> is invoked.</param> /// <param name="state">The state to pass to <paramref name="func"/> when it's invoked.</param> - public ConnectionWaiter(HttpConnectionPool pool, Func<TState, ValueTask<HttpConnection>> func, TState state) : base(pool) + public ConnectionWaiter(HttpConnectionPool pool, Func<TState, CancellationToken, ValueTask<HttpConnection>> func, TState state, CancellationToken cancellationToken) : + base(pool, cancellationToken) { _createConnectionAsync = func; _state = state; @@ -437,7 +548,7 @@ namespace System.Net.Http { try { - return _createConnectionAsync(_state); + return _createConnectionAsync(_state, _cancellationToken); } catch (Exception e) { @@ -462,12 +573,21 @@ namespace System.Net.Http { /// <summary>The pool with which this waiter is associated.</summary> internal readonly HttpConnectionPool _pool; + /// <summary>Cancellation token for the waiter.</summary> + internal readonly CancellationToken _cancellationToken; + /// <summary>Registration that removes the waiter from the registration list.</summary> + internal CancellationTokenRegistration _cancellationTokenRegistration; + /// <summary>Next waiter in the list.</summary> + internal ConnectionWaiter _next; + /// <summary>Previous waiter in the list.</summary> + internal ConnectionWaiter _prev; /// <summary>Initializes the waiter.</summary> - public ConnectionWaiter(HttpConnectionPool pool) : base(TaskCreationOptions.RunContinuationsAsynchronously) + public ConnectionWaiter(HttpConnectionPool pool, CancellationToken cancellationToken) : base(TaskCreationOptions.RunContinuationsAsynchronously) { Debug.Assert(pool != null, "Expected non-null pool"); _pool = pool; + _cancellationToken = cancellationToken; } /// <summary>Creates a connection.</summary> diff --git a/src/System.Net.Http/src/System/Net/Http/Managed/HttpProxyConnectionHandler.cs b/src/System.Net.Http/src/System/Net/Http/Managed/HttpProxyConnectionHandler.cs index db807d407f..16ea967510 100644 --- a/src/System.Net.Http/src/System/Net/Http/Managed/HttpProxyConnectionHandler.cs +++ b/src/System.Net.Http/src/System/Net/Http/Managed/HttpProxyConnectionHandler.cs @@ -65,7 +65,7 @@ namespace System.Net.Http throw new NotImplementedException("no support for SSL tunneling through proxy"); } - HttpConnection connection = await GetOrCreateConnection(request, proxyUri).ConfigureAwait(false); + HttpConnection connection = await GetOrCreateConnection(request, proxyUri, cancellationToken).ConfigureAwait(false); HttpResponseMessage response = await connection.SendAsync(request, cancellationToken).ConfigureAwait(false); @@ -88,7 +88,7 @@ namespace System.Net.Http request.Headers.ProxyAuthorization = new AuthenticationHeaderValue(AuthenticationHelper.Basic, AuthenticationHelper.GetBasicTokenForCredential(credential)); - connection = await GetOrCreateConnection(request, proxyUri).ConfigureAwait(false); + connection = await GetOrCreateConnection(request, proxyUri, cancellationToken).ConfigureAwait(false); response = await connection.SendAsync(request, cancellationToken).ConfigureAwait(false); } @@ -140,15 +140,15 @@ namespace System.Net.Http return response; } - private ValueTask<HttpConnection> GetOrCreateConnection(HttpRequestMessage request, Uri proxyUri) + private ValueTask<HttpConnection> GetOrCreateConnection(HttpRequestMessage request, Uri proxyUri, CancellationToken cancellationToken) { var key = new HttpConnectionKey(proxyUri); HttpConnectionPool pool = _connectionPools.GetOrAddPool(key); - return pool.GetConnectionAsync(async state => + return pool.GetConnectionAsync(async (state, ct) => { - Stream stream = await ConnectHelper.ConnectAsync(state.proxyUri.IdnHost, state.proxyUri.Port).ConfigureAwait(false); + Stream stream = await ConnectHelper.ConnectAsync(state.proxyUri.IdnHost, state.proxyUri.Port, ct).ConfigureAwait(false); return new HttpConnection(state.pool, state.key, null, stream, null, true); - }, (pool: pool, key: key, request: request, proxyUri: proxyUri)); + }, (pool: pool, key: key, request: request, proxyUri: proxyUri), cancellationToken); } protected override void Dispose(bool disposing) diff --git a/src/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs b/src/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs index 56af69daa4..dc4d0ac270 100644 --- a/src/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs +++ b/src/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs @@ -138,7 +138,11 @@ namespace System.Net.WebSockets } // Issue the request. The response must be status code 101. - HttpResponseMessage response = await handler.SendAsync(request, cancellationToken).ConfigureAwait(false); + HttpResponseMessage response; + using (var externalAndAbortCancellation = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, _abortSource.Token)) + { + response = await handler.SendAsync(request, externalAndAbortCancellation.Token).ConfigureAwait(false); + } if (response.StatusCode != HttpStatusCode.SwitchingProtocols) { throw new WebSocketException(SR.net_webstatus_ConnectFailure); diff --git a/src/System.Net.WebSockets.Client/tests/AbortTest.cs b/src/System.Net.WebSockets.Client/tests/AbortTest.cs index ff6ac0076d..e68c61ee7c 100644 --- a/src/System.Net.WebSockets.Client/tests/AbortTest.cs +++ b/src/System.Net.WebSockets.Client/tests/AbortTest.cs @@ -16,7 +16,6 @@ namespace System.Net.WebSockets.Client.Tests { public AbortTest(ITestOutputHelper output) : base(output) { } - [ActiveIssue(23151, TestPlatforms.AnyUnix)] // need ManagedHandler support for canceling a ConnectAsync operation [OuterLoop] // TODO: Issue #11345 [ConditionalTheory(nameof(WebSocketsSupported)), MemberData(nameof(EchoServers))] public async Task Abort_ConnectAndAbort_ThrowsWebSocketExceptionWithmessage(Uri server) diff --git a/src/System.Net.WebSockets.Client/tests/CancelTest.cs b/src/System.Net.WebSockets.Client/tests/CancelTest.cs index 183289829d..9bc3015aab 100644 --- a/src/System.Net.WebSockets.Client/tests/CancelTest.cs +++ b/src/System.Net.WebSockets.Client/tests/CancelTest.cs @@ -14,7 +14,6 @@ namespace System.Net.WebSockets.Client.Tests { public CancelTest(ITestOutputHelper output) : base(output) { } - [ActiveIssue(23151, TestPlatforms.AnyUnix)] // connection opening currently can't be canceled on ManagedHandler [OuterLoop] // TODO: Issue #11345 [ConditionalTheory(nameof(WebSocketsSupported)), MemberData(nameof(EchoServers))] public async Task ConnectAsync_Cancel_ThrowsWebSocketExceptionWithMessage(Uri server) |