diff options
8 files changed, 97 insertions, 264 deletions
diff --git a/src/System.Net.Security/ref/System.Net.Security.cs b/src/System.Net.Security/ref/System.Net.Security.cs index aea8fb45bf..17891ac318 100644 --- a/src/System.Net.Security/ref/System.Net.Security.cs +++ b/src/System.Net.Security/ref/System.Net.Security.cs @@ -176,7 +176,6 @@ namespace System.Net.Security #endif public void Write(byte[] buffer) { } public override void Write(byte[] buffer, int offset, int count) { } - public override System.Threading.Tasks.Task WriteAsync(byte[] buffer, int offset, int count, System.Threading.CancellationToken cancellationToken) { throw null; } } } namespace System.Security.Authentication diff --git a/src/System.Net.Security/src/System/Net/FixedSizeReader.cs b/src/System.Net.Security/src/System/Net/FixedSizeReader.cs index c4eb5cb05f..d1ffb829b6 100644 --- a/src/System.Net.Security/src/System/Net/FixedSizeReader.cs +++ b/src/System.Net.Security/src/System/Net/FixedSizeReader.cs @@ -4,6 +4,7 @@ using System.Diagnostics; using System.IO; +using System.Threading.Tasks; namespace System.Net { @@ -70,17 +71,26 @@ namespace System.Net { while (true) { - IAsyncResult ar = _transport.BeginRead(_request.Buffer, _request.Offset + _totalRead, _request.Count - _totalRead, s_readCallback, this); - if (!ar.CompletedSynchronously) + int bytes; + + Task<int> t = _transport.ReadAsync(_request.Buffer, _request.Offset + _totalRead, _request.Count - _totalRead); + if (t.IsCompleted) + { + bytes = t.GetAwaiter().GetResult(); + } + else { + IAsyncResult ar = TaskToApm.Begin(t, s_readCallback, this); + if (!ar.CompletedSynchronously) + { #if DEBUG - _request._DebugAsyncChain = ar; + _request._DebugAsyncChain = ar; #endif - break; + break; + } + bytes = TaskToApm.End<int>(ar); } - int bytes = _transport.EndRead(ar); - if (CheckCompletionBeforeNextRead(bytes)) { break; @@ -135,7 +145,7 @@ namespace System.Net // Async completion. try { - int bytes = reader._transport.EndRead(transportResult); + int bytes = TaskToApm.End<int>(transportResult); if (reader.CheckCompletionBeforeNextRead(bytes)) { diff --git a/src/System.Net.Security/src/System/Net/Security/InternalNegotiateStream.cs b/src/System.Net.Security/src/System/Net/Security/InternalNegotiateStream.cs index c57008bb9c..b23fc4e27c 100644 --- a/src/System.Net.Security/src/System/Net/Security/InternalNegotiateStream.cs +++ b/src/System.Net.Security/src/System/Net/Security/InternalNegotiateStream.cs @@ -5,6 +5,7 @@ using System.Diagnostics; using System.IO; using System.Threading; +using System.Threading.Tasks; namespace System.Net.Security { @@ -167,13 +168,20 @@ namespace System.Net.Security { // prepare for the next request asyncRequest.SetNextRequest(buffer, offset + chunkBytes, count - chunkBytes, null); - IAsyncResult ar = InnerStream.BeginWrite(outBuffer, 0, encryptedBytes, s_writeCallback, asyncRequest); - if (!ar.CompletedSynchronously) + Task t = InnerStream.WriteAsync(outBuffer, 0, encryptedBytes); + if (t.IsCompleted) { - return; + t.GetAwaiter().GetResult(); + } + else + { + IAsyncResult ar = TaskToApm.Begin(t, s_writeCallback, asyncRequest); + if (!ar.CompletedSynchronously) + { + return; + } + TaskToApm.End(ar); } - - InnerStream.EndWrite(ar); } else { @@ -384,7 +392,7 @@ namespace System.Net.Security try { NegotiateStream negoStream = (NegotiateStream)asyncRequest.AsyncObject; - negoStream.InnerStream.EndWrite(transportResult); + TaskToApm.End(transportResult); if (asyncRequest.Count == 0) { // This was the last chunk. diff --git a/src/System.Net.Security/src/System/Net/Security/NegotiateStream.cs b/src/System.Net.Security/src/System/Net/Security/NegotiateStream.cs index 4ad3718532..e07a61c578 100644 --- a/src/System.Net.Security/src/System/Net/Security/NegotiateStream.cs +++ b/src/System.Net.Security/src/System/Net/Security/NegotiateStream.cs @@ -575,7 +575,7 @@ namespace System.Net.Security if (!_negoState.CanGetSecureStream) { - return InnerStream.BeginRead(buffer, offset, count, asyncCallback, asyncState); + return TaskToApm.Begin(InnerStream.ReadAsync(buffer, offset, count), asyncCallback, asyncState); } BufferAsyncResult bufferResult = new BufferAsyncResult(this, buffer, offset, count, asyncState, asyncCallback); @@ -597,7 +597,7 @@ namespace System.Net.Security if (!_negoState.CanGetSecureStream) { - return InnerStream.EndRead(asyncResult); + return TaskToApm.End<int>(asyncResult); } @@ -647,7 +647,7 @@ namespace System.Net.Security if (!_negoState.CanGetSecureStream) { - return InnerStream.BeginWrite(buffer, offset, count, asyncCallback, asyncState); + return TaskToApm.Begin(InnerStream.WriteAsync(buffer, offset, count), asyncCallback, asyncState); } BufferAsyncResult bufferResult = new BufferAsyncResult(this, buffer, offset, count, asyncState, asyncCallback); @@ -670,7 +670,7 @@ namespace System.Net.Security if (!_negoState.CanGetSecureStream) { - InnerStream.EndWrite(asyncResult); + TaskToApm.End(asyncResult); return; } @@ -706,67 +706,5 @@ namespace System.Net.Security } #endif } - - // ReadAsync - provide async read functionality. - // - // This method provides async read functionality. All we do is - // call through to the Begin/EndRead methods. - // - // Input: - // - // buffer - Buffer to read into. - // offset - Offset into the buffer where we're to read. - // size - Number of bytes to read. - // cancellationToken - Token used to request cancellation of the operation - // - // Returns: - // - // A Task<int> representing the read. - public override Task<int> ReadAsync(byte[] buffer, int offset, int size, CancellationToken cancellationToken) - { - if (cancellationToken.IsCancellationRequested) - { - return Task.FromCanceled<int>(cancellationToken); - } - - return Task.Factory.FromAsync( - (bufferArg, offsetArg, sizeArg, callback, state) => ((NegotiateStream)state).BeginRead(bufferArg, offsetArg, sizeArg, callback, state), - iar => ((NegotiateStream)iar.AsyncState).EndRead(iar), - buffer, - offset, - size, - this); - } - - // WriteAsync - provide async write functionality. - // - // This method provides async write functionality. All we do is - // call through to the Begin/EndWrite methods. - // - // Input: - // - // buffer - Buffer to write into. - // offset - Offset into the buffer where we're to write. - // size - Number of bytes to write. - // cancellationToken - Token used to request cancellation of the operation - // - // Returns: - // - // A Task representing the write. - public override Task WriteAsync(byte[] buffer, int offset, int size, CancellationToken cancellationToken) - { - if (cancellationToken.IsCancellationRequested) - { - return Task.FromCanceled<int>(cancellationToken); - } - - return Task.Factory.FromAsync( - (bufferArg, offsetArg, sizeArg, callback, state) => ((NegotiateStream)state).BeginWrite(bufferArg, offsetArg, sizeArg, callback, state), - iar => ((NegotiateStream)iar.AsyncState).EndWrite(iar), - buffer, - offset, - size, - this); - } } } diff --git a/src/System.Net.Security/src/System/Net/Security/SslState.cs b/src/System.Net.Security/src/System/Net/Security/SslState.cs index 7f89383269..cdc7f2f2ba 100644 --- a/src/System.Net.Security/src/System/Net/Security/SslState.cs +++ b/src/System.Net.Security/src/System/Net/Security/SslState.cs @@ -790,16 +790,23 @@ namespace System.Net.Security else { asyncRequest.AsyncState = message; - IAsyncResult ar = InnerStream.BeginWrite(message.Payload, 0, message.Size, s_writeCallback, asyncRequest); - if (!ar.CompletedSynchronously) + Task t = InnerStream.WriteAsync(message.Payload, 0, message.Size); + if (t.IsCompleted) + { + t.GetAwaiter().GetResult(); + } + else { + IAsyncResult ar = TaskToApm.Begin(t, s_writeCallback, asyncRequest); + if (!ar.CompletedSynchronously) + { #if DEBUG - asyncRequest._DebugAsyncChain = ar; + asyncRequest._DebugAsyncChain = ar; #endif - return; + return; + } + TaskToApm.End(ar); } - - InnerStream.EndWrite(ar); } } @@ -996,12 +1003,20 @@ namespace System.Net.Security else { asyncRequest.AsyncState = exception; - IAsyncResult ar = InnerStream.BeginWrite(message.Payload, 0, message.Size, s_writeCallback, asyncRequest); - if (!ar.CompletedSynchronously) + Task t = InnerStream.WriteAsync(message.Payload, 0, message.Size); + if (t.IsCompleted) { - return; + t.GetAwaiter().GetResult(); + } + else + { + IAsyncResult ar = TaskToApm.Begin(t, s_writeCallback, asyncRequest); + if (!ar.CompletedSynchronously) + { + return; + } + TaskToApm.End(ar); } - InnerStream.EndWrite(ar); } exception.Throw(); @@ -1065,7 +1080,7 @@ namespace System.Net.Security // Async completion. try { - sslState.InnerStream.EndWrite(transportResult); + TaskToApm.End(transportResult); // Special case for an error notification. object asyncState = asyncRequest.AsyncState; @@ -1833,14 +1848,14 @@ namespace System.Net.Security CheckThrow(authSuccessCheck:true, shutdownCheck:true); ProtocolToken message = Context.CreateShutdownToken(); - return InnerStream.BeginWrite(message.Payload, 0, message.Payload.Length, asyncCallback, asyncState); + return TaskToApm.Begin(InnerStream.WriteAsync(message.Payload, 0, message.Payload.Length), asyncCallback, asyncState); } internal void EndShutdown(IAsyncResult result) { CheckThrow(authSuccessCheck: true, shutdownCheck:true); - InnerStream.EndWrite(result); + TaskToApm.End(result); _shutdown = true; } } diff --git a/src/System.Net.Security/src/System/Net/Security/SslStream.cs b/src/System.Net.Security/src/System/Net/Security/SslStream.cs index 8457125c01..0f55e769b8 100644 --- a/src/System.Net.Security/src/System/Net/Security/SslStream.cs +++ b/src/System.Net.Security/src/System/Net/Security/SslStream.cs @@ -527,11 +527,6 @@ namespace System.Net.Security _sslState.SecureStream.Write(buffer, offset, count); } - public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) - { - return _sslState.SecureStream.WriteAsync(buffer, offset, count, cancellationToken); - } - public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback asyncCallback, object asyncState) { return _sslState.SecureStream.BeginRead(buffer, offset, count, asyncCallback, asyncState); diff --git a/src/System.Net.Security/src/System/Net/Security/SslStreamInternal.cs b/src/System.Net.Security/src/System/Net/Security/SslStreamInternal.cs index c946a2ea35..6e459ec398 100644 --- a/src/System.Net.Security/src/System/Net/Security/SslStreamInternal.cs +++ b/src/System.Net.Security/src/System/Net/Security/SslStreamInternal.cs @@ -31,8 +31,6 @@ namespace System.Net.Security private AsyncProtocolRequest _readProtocolRequest; // cached, reusable AsyncProtocolRequest used for read operations private AsyncProtocolRequest _writeProtocolRequest; // cached, reusable AsyncProtocolRequest used for write operations - private SemaphoreSlim _asyncWriteActiveSemaphore; - // Never updated directly, special properties are used. This is the read buffer. private byte[] _internalBuffer; private bool _internalBufferFromPinnableCache; @@ -56,13 +54,6 @@ namespace System.Net.Security _reader = new FixedSizeReader(_sslState.InnerStream); } - internal SemaphoreSlim EnsureAsyncActiveWriteSemaphoreInitialized() - { - // Lazily-initialize _asyncWriteActiveSemaphore. As we're never accessing the SemaphoreSlim's - // WaitHandle, we don't need to worry about Disposing it. - return LazyInitializer.EnsureInitialized(ref _asyncWriteActiveSemaphore, () => new SemaphoreSlim(1, 1)); - } - // If we have a read buffer from the pinnable cache, return it. private void FreeReadBuffer() { @@ -141,55 +132,6 @@ namespace System.Net.Security ProcessWrite(buffer, offset, count, null); } - internal Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) - { - if (!_sslState.InnerStream.CanWrite) return Task.FromException(new NotSupportedException(SR.NotSupported_UnwritableStream)); - // If cancellation was requested, bail early with an already completed task. - // Otherwise, return a task that represents the Begin/End methods. - return cancellationToken.IsCancellationRequested - ? Task.FromCanceled(cancellationToken) - : WriteAsyncImpl(buffer, offset, count, cancellationToken); - } - - private async Task WriteAsyncImpl(byte[] buffer, int offset, int count, CancellationToken cancellationToken) - { - var semaphore = EnsureAsyncActiveWriteSemaphoreInitialized(); - await semaphore.WaitAsync(cancellationToken).ConfigureAwait(false); - - InitiateWrite(buffer, offset, count, null); - - bool failed = false; - try - { - await StartWritingAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false); - - if (Interlocked.Exchange(ref _nestedWrite, 0) == 0) - { - throw new InvalidOperationException(SR.Format(SR.net_io_invalidendcall, "WriteAsyncImpl")); - } - } - catch (Exception e) - { - _sslState.FinishWrite(); - - failed = true; - if (e is IOException) - { - throw; - } - - throw new IOException(SR.net_io_write, e); - } - finally - { - semaphore.Release(); - if (failed) - { - _nestedWrite = 0; - } - } - } - internal IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback asyncCallback, object asyncState) { var bufferResult = new BufferAsyncResult(this, buffer, offset, count, asyncState, asyncCallback); @@ -404,7 +346,17 @@ namespace System.Net.Security // private void ProcessWrite(byte[] buffer, int offset, int count, LazyAsyncResult asyncResult) { - var asyncRequest = InitiateWrite(buffer, offset, count, asyncResult); + _sslState.CheckThrow(authSuccessCheck:true, shutdownCheck:true); + ValidateParameters(buffer, offset, count); + + if (Interlocked.Exchange(ref _nestedWrite, 1) == 1) + { + throw new NotSupportedException(SR.Format(SR.net_io_invalidnestedcall, "Write", "write")); + } + + // If this is an async operation, get the AsyncProtocolRequest to use. + // We do this only after we verify we're the sole write operation in flight. + AsyncProtocolRequest asyncRequest = GetOrCreateProtocolRequest(ref _writeProtocolRequest, asyncResult); bool failed = false; @@ -433,22 +385,6 @@ namespace System.Net.Security } } - private AsyncProtocolRequest InitiateWrite(byte[] buffer, int offset, int count, LazyAsyncResult asyncResult) - { - _sslState.CheckThrow(authSuccessCheck: true, shutdownCheck: true); - ValidateParameters(buffer, offset, count); - - if (Interlocked.Exchange(ref _nestedWrite, 1) == 1) - { - throw new NotSupportedException(SR.Format(SR.net_io_invalidnestedcall, "Write", "write")); - } - - // If this is an async operation, get the AsyncProtocolRequest to use. - // We do this only after we verify we're the sole write operation in flight. - AsyncProtocolRequest asyncRequest = GetOrCreateProtocolRequest(ref _writeProtocolRequest, asyncResult); - return asyncRequest; - } - private void StartWriting(byte[] buffer, int offset, int count, AsyncProtocolRequest asyncRequest) { if (asyncRequest != null) @@ -519,14 +455,20 @@ namespace System.Net.Security { // Prepare for the next request. asyncRequest.SetNextRequest(buffer, offset + chunkBytes, count - chunkBytes, s_resumeAsyncWriteCallback); - IAsyncResult ar = _sslState.InnerStream.BeginWrite(outBuffer, 0, encryptedBytes, s_writeCallback, asyncRequest); - if (!ar.CompletedSynchronously) + Task t = _sslState.InnerStream.WriteAsync(outBuffer, 0, encryptedBytes); + if (t.IsCompleted) { - return; + t.GetAwaiter().GetResult(); + } + else + { + IAsyncResult ar = TaskToApm.Begin(t, s_writeCallback, asyncRequest); + if (!ar.CompletedSynchronously) + { + return; + } + TaskToApm.End(ar); } - - _sslState.InnerStream.EndWrite(ar); - } else { @@ -557,81 +499,6 @@ namespace System.Net.Security } } - private async Task StartWritingAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) - { - // We loop to this method from the callback. - // If the last chunk was just completed from async callback (count < 0), we complete user request. - if (count >= 0 ) - { - byte[] outBuffer = null; - if (_pinnableOutputBufferInUse == null) - { - if (_pinnableOutputBuffer == null) - { - _pinnableOutputBuffer = s_PinnableWriteBufferCache.AllocateBuffer(); - } - - _pinnableOutputBufferInUse = buffer; - outBuffer = _pinnableOutputBuffer; - if (PinnableBufferCacheEventSource.Log.IsEnabled()) - { - PinnableBufferCacheEventSource.Log.DebugMessage3("In System.Net._SslStream.StartWritingAsync Trying Pinnable", this.GetHashCode(), count, PinnableBufferCacheEventSource.AddressOfByteArray(outBuffer)); - } - } - else - { - if (PinnableBufferCacheEventSource.Log.IsEnabled()) - { - PinnableBufferCacheEventSource.Log.DebugMessage2("In System.Net._SslStream.StartWritingAsync BufferInUse", this.GetHashCode(), count); - } - } - - do - { - if (count == 0 && !SslStreamPal.CanEncryptEmptyMessage) - { - // If it's an empty message and the PAL doesn't support that, - // we're done. - break; - } - - int chunkBytes = Math.Min(count, _sslState.MaxDataSize); - int encryptedBytes; - SecurityStatusPal status = _sslState.EncryptData(buffer, offset, chunkBytes, ref outBuffer, out encryptedBytes); - if (status.ErrorCode != SecurityStatusPalErrorCode.OK) - { - // Re-handshake status is not supported. - ProtocolToken message = new ProtocolToken(null, status); - throw new IOException(SR.net_io_encrypt, message.GetException()); - } - - if (PinnableBufferCacheEventSource.Log.IsEnabled()) - { - PinnableBufferCacheEventSource.Log.DebugMessage3("In System.Net._SslStream.StartWritingAsync Got Encrypted Buffer", - this.GetHashCode(), encryptedBytes, PinnableBufferCacheEventSource.AddressOfByteArray(outBuffer)); - } - - await _sslState.InnerStream.WriteAsync(outBuffer, 0, encryptedBytes, cancellationToken); - - offset += chunkBytes; - count -= chunkBytes; - - // Release write IO slot. - _sslState.FinishWrite(); - - } while (count != 0); - } - - if (buffer == _pinnableOutputBufferInUse) - { - _pinnableOutputBufferInUse = null; - if (PinnableBufferCacheEventSource.Log.IsEnabled()) - { - PinnableBufferCacheEventSource.Log.DebugMessage1("In System.Net._SslStream.StartWritingAsync Freeing buffer.", this.GetHashCode()); - } - } - } - // // Combined sync/async read method. For sync request asyncRequest==null. // @@ -907,7 +774,7 @@ namespace System.Net.Security try { - sslStream._sslState.InnerStream.EndWrite(transportResult); + TaskToApm.End(transportResult); sslStream._sslState.FinishWrite(); if (asyncRequest.Count == 0) diff --git a/src/System.Net.Security/src/System/Net/StreamFramer.cs b/src/System.Net.Security/src/System/Net/StreamFramer.cs index 3aac3ef885..8123c48235 100644 --- a/src/System.Net.Security/src/System/Net/StreamFramer.cs +++ b/src/System.Net.Security/src/System/Net/StreamFramer.cs @@ -5,6 +5,7 @@ using System.Diagnostics; using System.IO; using System.Globalization; +using System.Threading.Tasks; namespace System.Net { @@ -136,7 +137,7 @@ namespace System.Net _readHeaderBuffer, 0, _readHeaderBuffer.Length); - IAsyncResult result = _transport.BeginRead(_readHeaderBuffer, 0, _readHeaderBuffer.Length, + IAsyncResult result = TaskToApm.Begin(_transport.ReadAsync(_readHeaderBuffer, 0, _readHeaderBuffer.Length), _readFrameCallback, workerResult); if (result.CompletedSynchronously) @@ -199,7 +200,7 @@ namespace System.Net WorkerAsyncResult workerResult = (WorkerAsyncResult)transportResult.AsyncState; - int bytesRead = _transport.EndRead(transportResult); + int bytesRead = TaskToApm.End<int>(transportResult); workerResult.Offset += bytesRead; if (!(workerResult.Offset <= workerResult.End)) @@ -260,7 +261,7 @@ namespace System.Net workerResult.End = frame.Length; workerResult.Offset = 0; - // Transport.BeginRead below will pickup those changes. + // Transport.ReadAsync below will pickup those changes. } else { @@ -271,7 +272,7 @@ namespace System.Net } // This means we need more data to complete the data block. - transportResult = _transport.BeginRead(workerResult.Buffer, workerResult.Offset, workerResult.End - workerResult.Offset, + transportResult = TaskToApm.Begin(_transport.ReadAsync(workerResult.Buffer, workerResult.Offset, workerResult.End - workerResult.Offset), _readFrameCallback, workerResult); } while (transportResult.CompletedSynchronously); } @@ -352,7 +353,7 @@ namespace System.Net if (message.Length == 0) { - return _transport.BeginWrite(_writeHeaderBuffer, 0, _writeHeaderBuffer.Length, + return TaskToApm.Begin(_transport.WriteAsync(_writeHeaderBuffer, 0, _writeHeaderBuffer.Length), asyncCallback, stateObject); } @@ -361,7 +362,7 @@ namespace System.Net message, 0, message.Length); // Charge the first: - IAsyncResult result = _transport.BeginWrite(_writeHeaderBuffer, 0, _writeHeaderBuffer.Length, + IAsyncResult result = TaskToApm.Begin(_transport.WriteAsync(_writeHeaderBuffer, 0, _writeHeaderBuffer.Length), _beginWriteCallback, workerResult); if (result.CompletedSynchronously) @@ -412,7 +413,7 @@ namespace System.Net WorkerAsyncResult workerResult = (WorkerAsyncResult)transportResult.AsyncState; // First, complete the previous portion write. - _transport.EndWrite(transportResult); + TaskToApm.End(transportResult); // Check on exit criterion. if (workerResult.Offset == workerResult.End) @@ -425,7 +426,7 @@ namespace System.Net workerResult.Offset = workerResult.End; // Write next portion (frame body) using Async IO. - transportResult = _transport.BeginWrite(workerResult.Buffer, 0, workerResult.End, + transportResult = TaskToApm.Begin(_transport.WriteAsync(workerResult.Buffer, 0, workerResult.End), _beginWriteCallback, workerResult); } while (transportResult.CompletedSynchronously); @@ -454,7 +455,7 @@ namespace System.Net } else { - _transport.EndWrite(asyncResult); + TaskToApm.End(asyncResult); } } } |