diff options
author | Stephen Toub <stoub@microsoft.com> | 2017-10-26 01:55:40 +0300 |
---|---|---|
committer | Stephen Toub <stoub@microsoft.com> | 2017-10-26 02:45:02 +0300 |
commit | c8ec3a850bd05aa8d8a6067071b723366b5648a3 (patch) | |
tree | 26fe65cd10d28404199a538643f5f269c4f3df49 /src | |
parent | 789035424aa1d3940baa3acde7b78fd78641b5e4 (diff) |
Add System.Threading.Channels to corefx
Bring the source over from corefxlab, add a package, get everything building, etc.
Diffstat (limited to 'src')
38 files changed, 4160 insertions, 0 deletions
diff --git a/src/Common/src/System/Collections/Concurrent/SingleProducerConsumerQueue.cs b/src/Common/src/System/Collections/Concurrent/SingleProducerConsumerQueue.cs new file mode 100644 index 0000000000..c28db22d13 --- /dev/null +++ b/src/Common/src/System/Collections/Concurrent/SingleProducerConsumerQueue.cs @@ -0,0 +1,315 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Collections.Generic; +using System.Diagnostics; +using System.Runtime.InteropServices; +using System.Threading; + +namespace System.Collections.Concurrent +{ + /// <summary> + /// Provides a producer/consumer queue safe to be used by only one producer and one consumer concurrently. + /// </summary> + /// <typeparam name="T">Specifies the type of data contained in the queue.</typeparam> + [DebuggerDisplay("Count = {Count}")] + [DebuggerTypeProxy(typeof(SingleProducerSingleConsumerQueue<>.SingleProducerSingleConsumerQueue_DebugView))] + internal sealed class SingleProducerSingleConsumerQueue<T> : IEnumerable<T> + { + // Design: + // + // SingleProducerSingleConsumerQueue (SPSCQueue) is a concurrent queue designed to be used + // by one producer thread and one consumer thread. SPSCQueue does not work correctly when used by + // multiple producer threads concurrently or multiple consumer threads concurrently. + // + // SPSCQueue is based on segments that behave like circular buffers. Each circular buffer is represented + // as an array with two indexes: _first and _last. _first is the index of the array slot for the consumer + // to read next, and _last is the slot for the producer to write next. The circular buffer is empty when + // (_first == _last), and full when ((_last+1) % _array.Length == _first). + // + // Since _first is only ever modified by the consumer thread and _last by the producer, the two indices can + // be updated without interlocked operations. As long as the queue size fits inside a single circular buffer, + // enqueues and dequeues simply advance the corresponding indices around the circular buffer. If an enqueue finds + // that there is no room in the existing buffer, however, a new circular buffer is allocated that is twice as big + // as the old buffer. From then on, the producer will insert values into the new buffer. The consumer will first + // empty out the old buffer and only then follow the producer into the new (larger) buffer. + // + // As described above, the enqueue operation on the fast path only modifies the _first field of the current segment. + // However, it also needs to read _last in order to verify that there is room in the current segment. Similarly, the + // dequeue operation on the fast path only needs to modify _last, but also needs to read _first to verify that the + // queue is non-empty. This results in true cache line sharing between the producer and the consumer. + // + // The cache line sharing issue can be mitigating by having a possibly stale copy of _first that is owned by the producer, + // and a possibly stale copy of _last that is owned by the consumer. So, the consumer state is described using + // (_first, _lastCopy) and the producer state using (_firstCopy, _last). The consumer state is separated from + // the producer state by padding, which allows fast-path enqueues and dequeues from hitting shared cache lines. + // _lastCopy is the consumer's copy of _last. Whenever the consumer can tell that there is room in the buffer + // simply by observing _lastCopy, the consumer thread does not need to read _last and thus encounter a cache miss. Only + // when the buffer appears to be empty will the consumer refresh _lastCopy from _last. _firstCopy is used by the producer + // in the same way to avoid reading _first on the hot path. + + /// <summary>The initial size to use for segments (in number of elements).</summary> + private const int InitialSegmentSize = 32; // must be a power of 2 + /// <summary>The maximum size to use for segments (in number of elements).</summary> + private const int MaxSegmentSize = 0x1000000; // this could be made as large as Int32.MaxValue / 2 + + /// <summary>The head of the linked list of segments.</summary> + private volatile Segment _head; + /// <summary>The tail of the linked list of segments.</summary> + private volatile Segment _tail; + + /// <summary>Initializes the queue.</summary> + public SingleProducerSingleConsumerQueue() + { + // Validate constants in ctor rather than in an explicit cctor that would cause perf degradation + Debug.Assert(InitialSegmentSize > 0, "Initial segment size must be > 0."); + Debug.Assert((InitialSegmentSize & (InitialSegmentSize - 1)) == 0, "Initial segment size must be a power of 2"); + Debug.Assert(InitialSegmentSize <= MaxSegmentSize, "Initial segment size should be <= maximum."); + Debug.Assert(MaxSegmentSize < int.MaxValue / 2, "Max segment size * 2 must be < Int32.MaxValue, or else overflow could occur."); + + // Initialize the queue + _head = _tail = new Segment(InitialSegmentSize); + } + + /// <summary>Enqueues an item into the queue.</summary> + /// <param name="item">The item to enqueue.</param> + public void Enqueue(T item) + { + Segment segment = _tail; + T[] array = segment._array; + int last = segment._state._last; // local copy to avoid multiple volatile reads + + // Fast path: there's obviously room in the current segment + int tail2 = (last + 1) & (array.Length - 1); + if (tail2 != segment._state._firstCopy) + { + array[last] = item; + segment._state._last = tail2; + } + // Slow path: there may not be room in the current segment. + else EnqueueSlow(item, ref segment); + } + + /// <summary>Enqueues an item into the queue.</summary> + /// <param name="item">The item to enqueue.</param> + /// <param name="segment">The segment in which to first attempt to store the item.</param> + private void EnqueueSlow(T item, ref Segment segment) + { + Debug.Assert(segment != null, "Expected a non-null segment."); + + if (segment._state._firstCopy != segment._state._first) + { + segment._state._firstCopy = segment._state._first; + Enqueue(item); // will only recur once for this enqueue operation + return; + } + + int newSegmentSize = _tail._array.Length << 1; // double size + Debug.Assert(newSegmentSize > 0, "The max size should always be small enough that we don't overflow."); + if (newSegmentSize > MaxSegmentSize) newSegmentSize = MaxSegmentSize; + + var newSegment = new Segment(newSegmentSize); + newSegment._array[0] = item; + newSegment._state._last = 1; + newSegment._state._lastCopy = 1; + + try { } + finally + { + // Finally block to protect against corruption due to a thread abort + // between setting _next and setting _tail. + Volatile.Write(ref _tail._next, newSegment); // ensure segment not published until item is fully stored + _tail = newSegment; + } + } + + /// <summary>Attempts to dequeue an item from the queue.</summary> + /// <param name="result">The dequeued item.</param> + /// <returns>true if an item could be dequeued; otherwise, false.</returns> + public bool TryDequeue(out T result) + { + Segment segment = _head; + T[] array = segment._array; + int first = segment._state._first; // local copy to avoid multiple volatile reads + + // Fast path: there's obviously data available in the current segment + if (first != segment._state._lastCopy) + { + result = array[first]; + array[first] = default; // Clear the slot to release the element + segment._state._first = (first + 1) & (array.Length - 1); + return true; + } + // Slow path: there may not be data available in the current segment + else return TryDequeueSlow(ref segment, ref array, out result); + } + + /// <summary>Attempts to dequeue an item from the queue.</summary> + /// <param name="array">The array from which the item was dequeued.</param> + /// <param name="segment">The segment from which the item was dequeued.</param> + /// <param name="result">The dequeued item.</param> + /// <returns>true if an item could be dequeued; otherwise, false.</returns> + private bool TryDequeueSlow(ref Segment segment, ref T[] array, out T result) + { + Debug.Assert(segment != null, "Expected a non-null segment."); + Debug.Assert(array != null, "Expected a non-null item array."); + + if (segment._state._last != segment._state._lastCopy) + { + segment._state._lastCopy = segment._state._last; + return TryDequeue(out result); // will only recur once for this dequeue operation + } + + if (segment._next != null && segment._state._first == segment._state._last) + { + segment = segment._next; + array = segment._array; + _head = segment; + } + + int first = segment._state._first; // local copy to avoid extraneous volatile reads + + if (first == segment._state._last) + { + result = default; + return false; + } + + result = array[first]; + array[first] = default; // Clear the slot to release the element + segment._state._first = (first + 1) & (segment._array.Length - 1); + segment._state._lastCopy = segment._state._last; // Refresh _lastCopy to ensure that _first has not passed _lastCopy + + return true; + } + + /// <summary>Gets whether the collection is currently empty.</summary> + public bool IsEmpty + { + // This implementation is optimized for calls from the consumer. + get + { + Segment head = _head; + if (head._state._first != head._state._lastCopy) return false; // _first is volatile, so the read of _lastCopy cannot get reordered + if (head._state._first != head._state._last) return false; + return head._next == null; + } + } + + /// <summary>Gets an enumerable for the collection.</summary> + /// <remarks>This method is not safe to use concurrently with any other members that may mutate the collection.</remarks> + public IEnumerator<T> GetEnumerator() + { + for (Segment segment = _head; segment != null; segment = segment._next) + { + for (int pt = segment._state._first; + pt != segment._state._last; + pt = (pt + 1) & (segment._array.Length - 1)) + { + yield return segment._array[pt]; + } + } + } + + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + + /// <summary>Gets the number of items in the collection.</summary> + /// <remarks>This method is not safe to use concurrently with any other members that may mutate the collection.</remarks> + internal int Count + { + get + { + int count = 0; + for (Segment segment = _head; segment != null; segment = segment._next) + { + int arraySize = segment._array.Length; + int first, last; + while (true) // Count is not meant to be used concurrently, but this helps to avoid issues if it is + { + first = segment._state._first; + last = segment._state._last; + if (first == segment._state._first) break; + } + count += (last - first) & (arraySize - 1); + } + return count; + } + } + + /// <summary>A segment in the queue containing one or more items.</summary> + [StructLayout(LayoutKind.Sequential)] + private sealed class Segment + { + /// <summary>The next segment in the linked list of segments.</summary> + internal Segment _next; + /// <summary>The data stored in this segment.</summary> + internal readonly T[] _array; + /// <summary>Details about the segment.</summary> + internal SegmentState _state; // separated out to enable StructLayout attribute to take effect + + /// <summary>Initializes the segment.</summary> + /// <param name="size">The size to use for this segment.</param> + internal Segment(int size) + { + Debug.Assert((size & (size - 1)) == 0, "Size must be a power of 2"); + _array = new T[size]; + } + } + + /// <summary>Stores information about a segment.</summary> + [StructLayout(LayoutKind.Sequential)] // enforce layout so that padding reduces false sharing + private struct SegmentState + { + /// <summary>Padding to reduce false sharing between the segment's array and _first.</summary> + internal PaddingFor32 _pad0; + + /// <summary>The index of the current head in the segment.</summary> + internal volatile int _first; + /// <summary>A copy of the current tail index.</summary> + internal int _lastCopy; // not volatile as read and written by the producer, except for IsEmpty, and there _lastCopy is only read after reading the volatile _first + + /// <summary>Padding to reduce false sharing between the first and last.</summary> + internal PaddingFor32 _pad1; + + /// <summary>A copy of the current head index.</summary> + internal int _firstCopy; // not volatile as only read and written by the consumer thread + /// <summary>The index of the current tail in the segment.</summary> + internal volatile int _last; + + /// <summary>Padding to reduce false sharing with the last and what's after the segment.</summary> + internal PaddingFor32 _pad2; + } + + /// <summary>Debugger type proxy for a SingleProducerSingleConsumerQueue of T.</summary> + private sealed class SingleProducerSingleConsumerQueue_DebugView + { + /// <summary>The queue being visualized.</summary> + private readonly SingleProducerSingleConsumerQueue<T> _queue; + + /// <summary>Initializes the debug view.</summary> + /// <param name="queue">The queue being debugged.</param> + public SingleProducerSingleConsumerQueue_DebugView(SingleProducerSingleConsumerQueue<T> queue) + { + Debug.Assert(queue != null, "Expected a non-null queue."); + _queue = queue; + } + + /// <summary>Gets the contents of the list.</summary> + [DebuggerBrowsable(DebuggerBrowsableState.RootHidden)] + public T[] Items => new List<T>(_queue).ToArray(); + } + } + + + /// <summary>A placeholder class for common padding constants and eventually routines.</summary> + internal static class PaddingHelpers + { + /// <summary>A size greater than or equal to the size of the most common CPU cache lines.</summary> + internal const int CACHE_LINE_SIZE = 128; + } + + /// <summary>Padding structure used to minimize false sharing in SingleProducerSingleConsumerQueue{T}.</summary> + [StructLayout(LayoutKind.Explicit, Size = PaddingHelpers.CACHE_LINE_SIZE - sizeof(int))] // Based on common case of 64-byte cache lines + internal struct PaddingFor32 { } +} diff --git a/src/System.Threading.Channels/System.Threading.Channels.sln b/src/System.Threading.Channels/System.Threading.Channels.sln new file mode 100644 index 0000000000..7a2097ca9f --- /dev/null +++ b/src/System.Threading.Channels/System.Threading.Channels.sln @@ -0,0 +1,64 @@ +Microsoft Visual Studio Solution File, Format Version 12.00 +# Visual Studio 15 +VisualStudioVersion = 15.0.27019.1 +MinimumVisualStudioVersion = 10.0.40219.1 +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "System.Threading.Channels.Tests", "tests\System.Threading.Channels.Tests.csproj", "{95DFC527-4DC1-495E-97D7-E94EE1F7140D}" + ProjectSection(ProjectDependencies) = postProject + {1DD0FF15-6234-4BD6-850A-317F05479554} = {1DD0FF15-6234-4BD6-850A-317F05479554} + EndProjectSection +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "System.Threading.Channels", "src\System.Threading.Channels.csproj", "{1DD0FF15-6234-4BD6-850A-317F05479554}" + ProjectSection(ProjectDependencies) = postProject + {9C524CA0-92FF-437B-B568-BCE8A794A69A} = {9C524CA0-92FF-437B-B568-BCE8A794A69A} + EndProjectSection +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "System.Threading.Channels", "ref\System.Threading.Channels.csproj", "{9C524CA0-92FF-437B-B568-BCE8A794A69A}" +EndProject +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "tests", "tests", "{1A2F9F4A-A032-433E-B914-ADD5992BB178}" +EndProject +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "src", "src", "{E107E9C1-E893-4E87-987E-04EF0DCEAEFD}" +EndProject +Global + GlobalSection(SolutionConfigurationPlatforms) = preSolution + Debug|Any CPU = Debug|Any CPU + netstandard-Debug|Any CPU = netstandard-Debug|Any CPU + netstandard-Release|Any CPU = netstandard-Release|Any CPU + Release|Any CPU = Release|Any CPU + EndGlobalSection + GlobalSection(ProjectConfigurationPlatforms) = postSolution + {95DFC527-4DC1-495E-97D7-E94EE1F7140D}.Debug|Any CPU.ActiveCfg = netstandard-Debug|Any CPU + {95DFC527-4DC1-495E-97D7-E94EE1F7140D}.Debug|Any CPU.Build.0 = netstandard-Debug|Any CPU + {95DFC527-4DC1-495E-97D7-E94EE1F7140D}.netstandard-Debug|Any CPU.ActiveCfg = netstandard-Debug|Any CPU + {95DFC527-4DC1-495E-97D7-E94EE1F7140D}.netstandard-Debug|Any CPU.Build.0 = netstandard-Debug|Any CPU + {95DFC527-4DC1-495E-97D7-E94EE1F7140D}.netstandard-Release|Any CPU.ActiveCfg = netstandard-Release|Any CPU + {95DFC527-4DC1-495E-97D7-E94EE1F7140D}.netstandard-Release|Any CPU.Build.0 = netstandard-Release|Any CPU + {95DFC527-4DC1-495E-97D7-E94EE1F7140D}.Release|Any CPU.ActiveCfg = netstandard-Release|Any CPU + {95DFC527-4DC1-495E-97D7-E94EE1F7140D}.Release|Any CPU.Build.0 = netstandard-Release|Any CPU + {1DD0FF15-6234-4BD6-850A-317F05479554}.Debug|Any CPU.ActiveCfg = netstandard1.3-Debug|Any CPU + {1DD0FF15-6234-4BD6-850A-317F05479554}.Debug|Any CPU.Build.0 = netstandard1.3-Debug|Any CPU + {1DD0FF15-6234-4BD6-850A-317F05479554}.netstandard-Debug|Any CPU.ActiveCfg = netstandard1.3-Release|Any CPU + {1DD0FF15-6234-4BD6-850A-317F05479554}.netstandard-Debug|Any CPU.Build.0 = netstandard1.3-Release|Any CPU + {1DD0FF15-6234-4BD6-850A-317F05479554}.netstandard-Release|Any CPU.ActiveCfg = netstandard1.3-Release|Any CPU + {1DD0FF15-6234-4BD6-850A-317F05479554}.netstandard-Release|Any CPU.Build.0 = netstandard1.3-Release|Any CPU + {1DD0FF15-6234-4BD6-850A-317F05479554}.Release|Any CPU.ActiveCfg = netstandard1.3-Release|Any CPU + {1DD0FF15-6234-4BD6-850A-317F05479554}.Release|Any CPU.Build.0 = netstandard1.3-Release|Any CPU + {9C524CA0-92FF-437B-B568-BCE8A794A69A}.Debug|Any CPU.ActiveCfg = netstandard1.3-Debug|Any CPU + {9C524CA0-92FF-437B-B568-BCE8A794A69A}.Debug|Any CPU.Build.0 = netstandard1.3-Debug|Any CPU + {9C524CA0-92FF-437B-B568-BCE8A794A69A}.netstandard-Debug|Any CPU.ActiveCfg = netstandard1.3-Release|Any CPU + {9C524CA0-92FF-437B-B568-BCE8A794A69A}.netstandard-Debug|Any CPU.Build.0 = netstandard1.3-Release|Any CPU + {9C524CA0-92FF-437B-B568-BCE8A794A69A}.netstandard-Release|Any CPU.ActiveCfg = netstandard1.3-Release|Any CPU + {9C524CA0-92FF-437B-B568-BCE8A794A69A}.netstandard-Release|Any CPU.Build.0 = netstandard1.3-Release|Any CPU + {9C524CA0-92FF-437B-B568-BCE8A794A69A}.Release|Any CPU.ActiveCfg = netstandard1.3-Release|Any CPU + {9C524CA0-92FF-437B-B568-BCE8A794A69A}.Release|Any CPU.Build.0 = netstandard1.3-Release|Any CPU + EndGlobalSection + GlobalSection(SolutionProperties) = preSolution + HideSolutionNode = FALSE + EndGlobalSection + GlobalSection(NestedProjects) = preSolution + {95DFC527-4DC1-495E-97D7-E94EE1F7140D} = {1A2F9F4A-A032-433E-B914-ADD5992BB178} + {1DD0FF15-6234-4BD6-850A-317F05479554} = {E107E9C1-E893-4E87-987E-04EF0DCEAEFD} + EndGlobalSection + GlobalSection(ExtensibilityGlobals) = postSolution + SolutionGuid = {83C15975-72A6-4FC2-9694-46EF0F4C7A3D} + EndGlobalSection +EndGlobal diff --git a/src/System.Threading.Channels/dir.props b/src/System.Threading.Channels/dir.props new file mode 100644 index 0000000000..4356decc45 --- /dev/null +++ b/src/System.Threading.Channels/dir.props @@ -0,0 +1,8 @@ +<?xml version="1.0" encoding="utf-8"?> +<Project ToolsVersion="14.0" DefaultTargets="Build" xmlns="http://schemas.microsoft.com/developer/msbuild/2003"> + <Import Project="..\dir.props" /> + <PropertyGroup> + <AssemblyVersion>4.0.0.0</AssemblyVersion> + <AssemblyKey>Open</AssemblyKey> + </PropertyGroup> +</Project>
\ No newline at end of file diff --git a/src/System.Threading.Channels/pkg/System.Threading.Channels.pkgproj b/src/System.Threading.Channels/pkg/System.Threading.Channels.pkgproj new file mode 100644 index 0000000000..c96fb26052 --- /dev/null +++ b/src/System.Threading.Channels/pkg/System.Threading.Channels.pkgproj @@ -0,0 +1,14 @@ +<?xml version="1.0" encoding="utf-8"?> +<Project ToolsVersion="14.0" DefaultTargets="Build" xmlns="http://schemas.microsoft.com/developer/msbuild/2003"> + <Import Project="$([MSBuild]::GetDirectoryNameOfFileAbove($(MSBuildThisFileDirectory), dir.props))\dir.props" /> + <PropertyGroup> + <!-- we need to be supported on pre-nuget-3 platforms (Dev12, Dev11, etc) --> + <MinClientVersion>2.8.6</MinClientVersion> + </PropertyGroup> + <ItemGroup> + <ProjectReference Include="..\src\System.Threading.Channels.csproj"> + <SupportedFramework>netcoreapp2.0;net461;$(AllXamarinFrameworks)</SupportedFramework> + </ProjectReference> + </ItemGroup> + <Import Project="$([MSBuild]::GetDirectoryNameOfFileAbove($(MSBuildThisFileDirectory), dir.targets))\dir.targets" /> +</Project> diff --git a/src/System.Threading.Channels/ref/Configurations.props b/src/System.Threading.Channels/ref/Configurations.props new file mode 100644 index 0000000000..78953dfc88 --- /dev/null +++ b/src/System.Threading.Channels/ref/Configurations.props @@ -0,0 +1,8 @@ +<?xml version="1.0" encoding="utf-8"?> +<Project ToolsVersion="14.0" DefaultTargets="Build" xmlns="http://schemas.microsoft.com/developer/msbuild/2003"> + <PropertyGroup> + <BuildConfigurations> + netstandard; + </BuildConfigurations> + </PropertyGroup> +</Project> diff --git a/src/System.Threading.Channels/ref/System.Threading.Channels.cs b/src/System.Threading.Channels/ref/System.Threading.Channels.cs new file mode 100644 index 0000000000..c76c10a905 --- /dev/null +++ b/src/System.Threading.Channels/ref/System.Threading.Channels.cs @@ -0,0 +1,82 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. +// ------------------------------------------------------------------------------ +// Changes to this file must follow the http://aka.ms/api-review process. +// ------------------------------------------------------------------------------ + +namespace System.Threading.Channels +{ + public enum BoundedChannelFullMode + { + DropNewest = 1, + DropOldest = 2, + DropWrite = 3, + Wait = 0, + } + public sealed partial class BoundedChannelOptions : System.Threading.Channels.ChannelOptions + { + public BoundedChannelOptions(int capacity) { } + public int Capacity { get { throw null; } set { } } + public System.Threading.Channels.BoundedChannelFullMode FullMode { get { throw null; } set { } } + } + public static partial class Channel + { + public static System.Threading.Channels.Channel<T> CreateBounded<T>(int capacity) { throw null; } + public static System.Threading.Channels.Channel<T> CreateBounded<T>(System.Threading.Channels.BoundedChannelOptions options) { throw null; } + public static System.Threading.Channels.Channel<T> CreateUnbounded<T>() { throw null; } + public static System.Threading.Channels.Channel<T> CreateUnbounded<T>(System.Threading.Channels.UnboundedChannelOptions options) { throw null; } + public static System.Threading.Channels.Channel<T> CreateUnbuffered<T>() { throw null; } + public static System.Threading.Channels.Channel<T> CreateUnbuffered<T>(System.Threading.Channels.UnbufferedChannelOptions options) { throw null; } + } + public partial class ChannelClosedException : System.InvalidOperationException + { + public ChannelClosedException() { } + public ChannelClosedException(System.Exception innerException) { } + public ChannelClosedException(string message) { } + public ChannelClosedException(string message, System.Exception innerException) { } + } + public abstract partial class ChannelOptions + { + protected ChannelOptions() { } + public bool AllowSynchronousContinuations { get { throw null; } set { } } + public bool SingleReader { get { throw null; } set { } } + public bool SingleWriter { get { throw null; } set { } } + } + public abstract partial class ChannelReader<T> + { + protected ChannelReader() { } + public virtual System.Threading.Tasks.Task Completion { get { throw null; } } + public abstract bool TryRead(out T item); + public abstract System.Threading.Tasks.Task<bool> WaitToReadAsync(System.Threading.CancellationToken cancellationToken=default); + } + public abstract partial class ChannelWriter<T> + { + protected ChannelWriter() { } + public void Complete(System.Exception error=null) { } + public virtual bool TryComplete(System.Exception error=null) { throw null; } + public abstract bool TryWrite(T item); + public abstract System.Threading.Tasks.Task<bool> WaitToWriteAsync(System.Threading.CancellationToken cancellationToken=default); + public virtual System.Threading.Tasks.Task WriteAsync(T item, System.Threading.CancellationToken cancellationToken=default) { throw null; } + } + public abstract partial class Channel<T> : System.Threading.Channels.Channel<T, T> + { + protected Channel() { } + } + public abstract partial class Channel<TWrite, TRead> + { + protected Channel() { } + public System.Threading.Channels.ChannelReader<TRead> Reader { get { throw null; } protected set { } } + public System.Threading.Channels.ChannelWriter<TWrite> Writer { get { throw null; } protected set { } } + public static implicit operator System.Threading.Channels.ChannelReader<TRead> (System.Threading.Channels.Channel<TWrite, TRead> channel) { throw null; } + public static implicit operator System.Threading.Channels.ChannelWriter<TWrite> (System.Threading.Channels.Channel<TWrite, TRead> channel) { throw null; } + } + public sealed partial class UnboundedChannelOptions : System.Threading.Channels.ChannelOptions + { + public UnboundedChannelOptions() { } + } + public sealed partial class UnbufferedChannelOptions : System.Threading.Channels.ChannelOptions + { + public UnbufferedChannelOptions() { } + } +} diff --git a/src/System.Threading.Channels/ref/System.Threading.Channels.csproj b/src/System.Threading.Channels/ref/System.Threading.Channels.csproj new file mode 100644 index 0000000000..557c7f0f2d --- /dev/null +++ b/src/System.Threading.Channels/ref/System.Threading.Channels.csproj @@ -0,0 +1,17 @@ +<?xml version="1.0" encoding="utf-8"?> +<Project ToolsVersion="14.0" DefaultTargets="Build" xmlns="http://schemas.microsoft.com/developer/msbuild/2003"> + <Import Project="$([MSBuild]::GetDirectoryNameOfFileAbove($(MSBuildThisFileDirectory), dir.props))\dir.props" /> + <PropertyGroup> + <ProjectGuid>{9C524CA0-92FF-437B-B568-BCE8A794A69A}</ProjectGuid> + </PropertyGroup> + <PropertyGroup Condition="'$(Configuration)|$(Platform)' == 'netstandard1.3-Debug|AnyCPU'" /> + <PropertyGroup Condition="'$(Configuration)|$(Platform)' == 'netstandard1.3-Release|AnyCPU'" /> + <ItemGroup> + <Compile Include="System.Threading.Channels.cs" /> + </ItemGroup> + <ItemGroup> + <Reference Include="System.Runtime" /> + <Reference Include="System.Threading.Tasks" /> + </ItemGroup> + <Import Project="$([MSBuild]::GetDirectoryNameOfFileAbove($(MSBuildThisFileDirectory), dir.targets))\dir.targets" /> +</Project> diff --git a/src/System.Threading.Channels/src/Configurations.props b/src/System.Threading.Channels/src/Configurations.props new file mode 100644 index 0000000000..78953dfc88 --- /dev/null +++ b/src/System.Threading.Channels/src/Configurations.props @@ -0,0 +1,8 @@ +<?xml version="1.0" encoding="utf-8"?> +<Project ToolsVersion="14.0" DefaultTargets="Build" xmlns="http://schemas.microsoft.com/developer/msbuild/2003"> + <PropertyGroup> + <BuildConfigurations> + netstandard; + </BuildConfigurations> + </PropertyGroup> +</Project> diff --git a/src/System.Threading.Channels/src/Resources/Strings.resx b/src/System.Threading.Channels/src/Resources/Strings.resx new file mode 100644 index 0000000000..2beea8a357 --- /dev/null +++ b/src/System.Threading.Channels/src/Resources/Strings.resx @@ -0,0 +1,123 @@ +<?xml version="1.0" encoding="utf-8"?> +<root> + <!-- + Microsoft ResX Schema + + Version 2.0 + + The primary goals of this format is to allow a simple XML format + that is mostly human readable. The generation and parsing of the + various data types are done through the TypeConverter classes + associated with the data types. + + Example: + + ... ado.net/XML headers & schema ... + <resheader name="resmimetype">text/microsoft-resx</resheader> + <resheader name="version">2.0</resheader> + <resheader name="reader">System.Resources.ResXResourceReader, System.Windows.Forms, ...</resheader> + <resheader name="writer">System.Resources.ResXResourceWriter, System.Windows.Forms, ...</resheader> + <data name="Name1"><value>this is my long string</value><comment>this is a comment</comment></data> + <data name="Color1" type="System.Drawing.Color, System.Drawing">Blue</data> + <data name="Bitmap1" mimetype="application/x-microsoft.net.object.binary.base64"> + <value>[base64 mime encoded serialized .NET Framework object]</value> + </data> + <data name="Icon1" type="System.Drawing.Icon, System.Drawing" mimetype="application/x-microsoft.net.object.bytearray.base64"> + <value>[base64 mime encoded string representing a byte array form of the .NET Framework object]</value> + <comment>This is a comment</comment> + </data> + + There are any number of "resheader" rows that contain simple + name/value pairs. + + Each data row contains a name, and value. The row also contains a + type or mimetype. Type corresponds to a .NET class that support + text/value conversion through the TypeConverter architecture. + Classes that don't support this are serialized and stored with the + mimetype set. + + The mimetype is used for serialized objects, and tells the + ResXResourceReader how to depersist the object. This is currently not + extensible. For a given mimetype the value must be set accordingly: + + Note - application/x-microsoft.net.object.binary.base64 is the format + that the ResXResourceWriter will generate, however the reader can + read any of the formats listed below. + + mimetype: application/x-microsoft.net.object.binary.base64 + value : The object must be serialized with + : System.Runtime.Serialization.Formatters.Binary.BinaryFormatter + : and then encoded with base64 encoding. + + mimetype: application/x-microsoft.net.object.soap.base64 + value : The object must be serialized with + : System.Runtime.Serialization.Formatters.Soap.SoapFormatter + : and then encoded with base64 encoding. + + mimetype: application/x-microsoft.net.object.bytearray.base64 + value : The object must be serialized into a byte array + : using a System.ComponentModel.TypeConverter + : and then encoded with base64 encoding. + --> + <xsd:schema id="root" xmlns="" xmlns:xsd="http://www.w3.org/2001/XMLSchema" xmlns:msdata="urn:schemas-microsoft-com:xml-msdata"> + <xsd:import namespace="http://www.w3.org/XML/1998/namespace" /> + <xsd:element name="root" msdata:IsDataSet="true"> + <xsd:complexType> + <xsd:choice maxOccurs="unbounded"> + <xsd:element name="metadata"> + <xsd:complexType> + <xsd:sequence> + <xsd:element name="value" type="xsd:string" minOccurs="0" /> + </xsd:sequence> + <xsd:attribute name="name" use="required" type="xsd:string" /> + <xsd:attribute name="type" type="xsd:string" /> + <xsd:attribute name="mimetype" type="xsd:string" /> + <xsd:attribute ref="xml:space" /> + </xsd:complexType> + </xsd:element> + <xsd:element name="assembly"> + <xsd:complexType> + <xsd:attribute name="alias" type="xsd:string" /> + <xsd:attribute name="name" type="xsd:string" /> + </xsd:complexType> + </xsd:element> + <xsd:element name="data"> + <xsd:complexType> + <xsd:sequence> + <xsd:element name="value" type="xsd:string" minOccurs="0" msdata:Ordinal="1" /> + <xsd:element name="comment" type="xsd:string" minOccurs="0" msdata:Ordinal="2" /> + </xsd:sequence> + <xsd:attribute name="name" type="xsd:string" use="required" msdata:Ordinal="1" /> + <xsd:attribute name="type" type="xsd:string" msdata:Ordinal="3" /> + <xsd:attribute name="mimetype" type="xsd:string" msdata:Ordinal="4" /> + <xsd:attribute ref="xml:space" /> + </xsd:complexType> + </xsd:element> + <xsd:element name="resheader"> + <xsd:complexType> + <xsd:sequence> + <xsd:element name="value" type="xsd:string" minOccurs="0" msdata:Ordinal="1" /> + </xsd:sequence> + <xsd:attribute name="name" type="xsd:string" use="required" /> + </xsd:complexType> + </xsd:element> + </xsd:choice> + </xsd:complexType> + </xsd:element> + </xsd:schema> + <resheader name="resmimetype"> + <value>text/microsoft-resx</value> + </resheader> + <resheader name="version"> + <value>2.0</value> + </resheader> + <resheader name="reader"> + <value>System.Resources.ResXResourceReader, System.Windows.Forms, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089</value> + </resheader> + <resheader name="writer"> + <value>System.Resources.ResXResourceWriter, System.Windows.Forms, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089</value> + </resheader> + <data name="ChannelClosedException_DefaultMessage" xml:space="preserve"> + <value>The channel has been closed.</value> + </data> +</root>
\ No newline at end of file diff --git a/src/System.Threading.Channels/src/System.Threading.Channels.csproj b/src/System.Threading.Channels/src/System.Threading.Channels.csproj new file mode 100644 index 0000000000..deac429f30 --- /dev/null +++ b/src/System.Threading.Channels/src/System.Threading.Channels.csproj @@ -0,0 +1,45 @@ +<?xml version="1.0" encoding="utf-8"?> +<Project ToolsVersion="14.0" DefaultTargets="Build" xmlns="http://schemas.microsoft.com/developer/msbuild/2003"> + <Import Project="$([MSBuild]::GetDirectoryNameOfFileAbove($(MSBuildThisFileDirectory), dir.props))\dir.props" /> + <PropertyGroup> + <ProjectGuid>{1DD0FF15-6234-4BD6-850A-317F05479554}</ProjectGuid> + <RootNamespace>System.Threading.Channels</RootNamespace> + <DocumentationFile>$(OutputPath)$(MSBuildProjectName).xml</DocumentationFile> + </PropertyGroup> + <PropertyGroup Condition="'$(Configuration)|$(Platform)' == 'netstandard1.3-Debug|AnyCPU'" /> + <PropertyGroup Condition="'$(Configuration)|$(Platform)' == 'netstandard1.3-Release|AnyCPU'" /> + <ItemGroup> + <Compile Include="System\VoidResult.cs" /> + <Compile Include="System\Collections\Generic\Dequeue.cs" /> + <Compile Include="System\Threading\Channels\BoundedChannel.cs" /> + <Compile Include="System\Threading\Channels\BoundedChannelFullMode.cs" /> + <Compile Include="System\Threading\Channels\Channel.cs" /> + <Compile Include="System\Threading\Channels\ChannelClosedException.cs" /> + <Compile Include="System\Threading\Channels\ChannelOptions.cs" /> + <Compile Include="System\Threading\Channels\ChannelReader.cs" /> + <Compile Include="System\Threading\Channels\ChannelUtilities.cs" /> + <Compile Include="System\Threading\Channels\ChannelWriter.cs" /> + <Compile Include="System\Threading\Channels\Channel_1.cs" /> + <Compile Include="System\Threading\Channels\Channel_2.cs" /> + <Compile Include="System\Threading\Channels\IDebugEnumerator.cs" /> + <Compile Include="System\Threading\Channels\Interactor.cs" /> + <Compile Include="System\Threading\Channels\SingleConsumerUnboundedChannel.cs" /> + <Compile Include="System\Threading\Channels\UnboundedChannel.cs" /> + <Compile Include="System\Threading\Channels\UnbufferedChannel.cs" /> + <Compile Include="$(CommonPath)\System\Collections\Concurrent\SingleProducerConsumerQueue.cs"> + <Link>Common\System\Collections\Concurrent\SingleProducerConsumerQueue.cs</Link> + </Compile> + </ItemGroup> + <ItemGroup> + <Reference Include="System.Collections" /> + <Reference Include="System.Collections.Concurrent" /> + <Reference Include="System.Diagnostics.Debug" /> + <Reference Include="System.Resources.ResourceManager" /> + <Reference Include="System.Runtime" /> + <Reference Include="System.Runtime.Extensions" /> + <Reference Include="System.Threading" /> + <Reference Include="System.Threading.Tasks" /> + <Reference Include="System.Threading.Tasks.Extensions" /> + </ItemGroup> + <Import Project="$([MSBuild]::GetDirectoryNameOfFileAbove($(MSBuildThisFileDirectory), dir.targets))\dir.targets" /> +</Project> diff --git a/src/System.Threading.Channels/src/System/Collections/Generic/Dequeue.cs b/src/System.Threading.Channels/src/System/Collections/Generic/Dequeue.cs new file mode 100644 index 0000000000..6c44b73043 --- /dev/null +++ b/src/System.Threading.Channels/src/System/Collections/Generic/Dequeue.cs @@ -0,0 +1,124 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Diagnostics; + +namespace System.Collections.Generic +{ + /// <summary>Provides a double-ended queue data structure.</summary> + /// <typeparam name="T">Type of the data stored in the dequeue.</typeparam> + [DebuggerDisplay("Count = {_size}")] + internal sealed class Dequeue<T> + { + private T[] _array = Array.Empty<T>(); + private int _head; // First valid element in the queue + private int _tail; // First open slot in the dequeue, unless the dequeue is full + private int _size; // Number of elements. + + public int Count => _size; + + public bool IsEmpty => _size == 0; + + public void EnqueueTail(T item) + { + if (_size == _array.Length) + { + Grow(); + } + + _array[_tail] = item; + if (++_tail == _array.Length) + { + _tail = 0; + } + _size++; + } + + //// Uncomment if/when enqueueing at the head is needed + //public void EnqueueHead(T item) + //{ + // if (_size == _array.Length) + // { + // Grow(); + // } + // + // _head = (_head == 0 ? _array.Length : _head) - 1; + // _array[_head] = item; + // _size++; + //} + + public T DequeueHead() + { + Debug.Assert(!IsEmpty); // caller's responsibility to make sure there are elements remaining + + T item = _array[_head]; + _array[_head] = default; + + if (++_head == _array.Length) + { + _head = 0; + } + _size--; + + return item; + } + + public T DequeueTail() + { + Debug.Assert(!IsEmpty); // caller's responsibility to make sure there are elements remaining + + if (--_tail == -1) + { + _tail = _array.Length - 1; + } + + T item = _array[_tail]; + _array[_tail] = default; + + _size--; + return item; + } + + public IEnumerator<T> GetEnumerator() // meant for debug purposes only + { + int pos = _head; + int count = _size; + while (count-- > 0) + { + yield return _array[pos]; + pos = (pos + 1) % _array.Length; + } + } + + private void Grow() + { + Debug.Assert(_size == _array.Length); + Debug.Assert(_head == _tail); + + const int MinimumGrow = 4; + + int capacity = (int)(_array.Length * 2L); + if (capacity < _array.Length + MinimumGrow) + { + capacity = _array.Length + MinimumGrow; + } + + T[] newArray = new T[capacity]; + + if (_head == 0) + { + Array.Copy(_array, 0, newArray, 0, _size); + } + else + { + Array.Copy(_array, _head, newArray, 0, _array.Length - _head); + Array.Copy(_array, 0, newArray, _array.Length - _head, _tail); + } + + _array = newArray; + _head = 0; + _tail = _size; + } + } +} diff --git a/src/System.Threading.Channels/src/System/Threading/Channels/BoundedChannel.cs b/src/System.Threading.Channels/src/System/Threading/Channels/BoundedChannel.cs new file mode 100644 index 0000000000..0b7ea44064 --- /dev/null +++ b/src/System.Threading.Channels/src/System/Threading/Channels/BoundedChannel.cs @@ -0,0 +1,411 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Collections.Generic; +using System.Diagnostics; +using System.Threading.Tasks; + +namespace System.Threading.Channels +{ + /// <summary>Provides a channel with a bounded capacity.</summary> + [DebuggerDisplay("Items={ItemsCountForDebugger}, Capacity={_bufferedCapacity}")] + [DebuggerTypeProxy(typeof(DebugEnumeratorDebugView<>))] + internal sealed class BoundedChannel<T> : Channel<T>, IDebugEnumerable<T> + { + /// <summary>The mode used when the channel hits its bound.</summary> + private readonly BoundedChannelFullMode _mode; + /// <summary>Task signaled when the channel has completed.</summary> + private readonly TaskCompletionSource<VoidResult> _completion; + /// <summary>The maximum capacity of the channel.</summary> + private readonly int _bufferedCapacity; + /// <summary>Items currently stored in the channel waiting to be read.</summary> + private readonly Dequeue<T> _items = new Dequeue<T>(); + /// <summary>Writers waiting to write to the channel.</summary> + private readonly Dequeue<WriterInteractor<T>> _blockedWriters = new Dequeue<WriterInteractor<T>>(); + /// <summary>Task signaled when any WaitToReadAsync waiters should be woken up.</summary> + private ReaderInteractor<bool> _waitingReaders; + /// <summary>Task signaled when any WaitToWriteAsync waiters should be woken up.</summary> + private ReaderInteractor<bool> _waitingWriters; + /// <summary>Whether to force continuations to be executed asynchronously from producer writes.</summary> + private readonly bool _runContinuationsAsynchronously; + /// <summary>Set to non-null once Complete has been called.</summary> + private Exception _doneWriting; + /// <summary>Gets an object used to synchronize all state on the instance.</summary> + private object SyncObj => _items; + + /// <summary>Initializes the <see cref="BoundedChannel{T}"/>.</summary> + /// <param name="bufferedCapacity">The positive bounded capacity for the channel.</param> + /// <param name="mode">The mode used when writing to a full channel.</param> + /// <param name="runContinuationsAsynchronously">Whether to force continuations to be executed asynchronously.</param> + internal BoundedChannel(int bufferedCapacity, BoundedChannelFullMode mode, bool runContinuationsAsynchronously) + { + Debug.Assert(bufferedCapacity > 0); + _bufferedCapacity = bufferedCapacity; + _mode = mode; + _runContinuationsAsynchronously = runContinuationsAsynchronously; + _completion = new TaskCompletionSource<VoidResult>(runContinuationsAsynchronously ? TaskCreationOptions.RunContinuationsAsynchronously : TaskCreationOptions.None); + Reader = new BoundedChannelReader(this); + Writer = new BoundedChannelWriter(this); + } + + private sealed class BoundedChannelReader : ChannelReader<T> + { + internal readonly BoundedChannel<T> _parent; + internal BoundedChannelReader(BoundedChannel<T> parent) => _parent = parent; + + public override Task Completion => _parent._completion.Task; + + public override bool TryRead(out T item) + { + BoundedChannel<T> parent = _parent; + lock (parent.SyncObj) + { + parent.AssertInvariants(); + + // Get an item if there is one. + if (!parent._items.IsEmpty) + { + item = DequeueItemAndPostProcess(); + return true; + } + } + + item = default; + return false; + } + + public override Task<bool> WaitToReadAsync(CancellationToken cancellationToken) + { + if (cancellationToken.IsCancellationRequested) + { + return Task.FromCanceled<bool>(cancellationToken); + } + + BoundedChannel<T> parent = _parent; + lock (parent.SyncObj) + { + parent.AssertInvariants(); + + // If there are any items available, a read is possible. + if (!parent._items.IsEmpty) + { + return ChannelUtilities.s_trueTask; + } + + // There were no items available, so if we're done writing, a read will never be possible. + if (parent._doneWriting != null) + { + return parent._doneWriting != ChannelUtilities.s_doneWritingSentinel ? + Task.FromException<bool>(parent._doneWriting) : + ChannelUtilities.s_falseTask; + } + + // There were no items available, but there could be in the future, so ensure + // there's a blocked reader task and return it. + return ChannelUtilities.GetOrCreateWaiter(ref parent._waitingReaders, parent._runContinuationsAsynchronously, cancellationToken); + } + } + + /// <summary>Dequeues an item, and then fixes up our state around writers and completion.</summary> + /// <returns>The dequeued item.</returns> + private T DequeueItemAndPostProcess() + { + BoundedChannel<T> parent = _parent; + Debug.Assert(Monitor.IsEntered(parent.SyncObj)); + + // Dequeue an item. + T item = parent._items.DequeueHead(); + + // If we're now empty and we're done writing, complete the channel. + if (parent._doneWriting != null && parent._items.IsEmpty) + { + ChannelUtilities.Complete(parent._completion, parent._doneWriting); + } + + // If there are any writers blocked, there's now room for at least one + // to be promoted to have its item moved into the items queue. We need + // to loop while trying to complete the writer in order to find one that + // hasn't yet been canceled (canceled writers transition to canceled but + // remain in the physical queue). + while (!parent._blockedWriters.IsEmpty) + { + WriterInteractor<T> w = parent._blockedWriters.DequeueHead(); + if (w.Success(default(VoidResult))) + { + parent._items.EnqueueTail(w.Item); + return item; + } + } + + // There was no blocked writer, so see if there's a WaitToWriteAsync + // we should wake up. + ChannelUtilities.WakeUpWaiters(ref parent._waitingWriters, result: true); + + // Return the item + return item; + } + } + + private sealed class BoundedChannelWriter : ChannelWriter<T> + { + internal readonly BoundedChannel<T> _parent; + internal BoundedChannelWriter(BoundedChannel<T> parent) => _parent = parent; + + public override bool TryComplete(Exception error) + { + BoundedChannel<T> parent = _parent; + bool completeTask; + lock (parent.SyncObj) + { + parent.AssertInvariants(); + + // If we've already marked the channel as completed, bail. + if (parent._doneWriting != null) + { + return false; + } + + // Mark that we're done writing. + parent._doneWriting = error ?? ChannelUtilities.s_doneWritingSentinel; + completeTask = parent._items.IsEmpty; + } + + // If there are no items in the queue, complete the channel's task, + // as no more data can possibly arrive at this point. We do this outside + // of the lock in case we'll be running synchronous completions, and we + // do it before completing blocked/waiting readers, so that when they + // wake up they'll see the task as being completed. + if (completeTask) + { + ChannelUtilities.Complete(parent._completion, error); + } + + // At this point, _blockedWriters and _waitingReaders/Writers will not be mutated: + // they're only mutated by readers/writers while holding the lock, and only if _doneWriting is null. + // We also know that only one thread (this one) will ever get here, as only that thread + // will be the one to transition from _doneWriting false to true. As such, we can + // freely manipulate them without any concurrency concerns. + ChannelUtilities.FailInteractors<WriterInteractor<T>, VoidResult>(parent._blockedWriters, ChannelUtilities.CreateInvalidCompletionException(error)); + ChannelUtilities.WakeUpWaiters(ref parent._waitingReaders, result: false, error: error); + ChannelUtilities.WakeUpWaiters(ref parent._waitingWriters, result: false, error: error); + + // Successfully transitioned to completed. + return true; + } + + public override bool TryWrite(T item) + { + ReaderInteractor<bool> waitingReaders = null; + + BoundedChannel<T> parent = _parent; + lock (parent.SyncObj) + { + parent.AssertInvariants(); + + // If we're done writing, nothing more to do. + if (parent._doneWriting != null) + { + return false; + } + + // Get the number of items in the channel currently. + int count = parent._items.Count; + + if (count == 0) + { + // There are no items in the channel, which means we may have waiting readers. + // Store the item. + parent._items.EnqueueTail(item); + waitingReaders = parent._waitingReaders; + if (waitingReaders == null) + { + // If no one's waiting to be notified about a 0-to-1 transition, we're done. + return true; + } + parent._waitingReaders = null; + } + else if (count < parent._bufferedCapacity) + { + // There's room in the channel. Since we're not transitioning from 0-to-1 and + // since there's room, we can simply store the item and exit without having to + // worry about blocked/waiting readers. + parent._items.EnqueueTail(item); + return true; + } + else if (parent._mode == BoundedChannelFullMode.Wait) + { + // The channel is full and we're in a wait mode. + // Simply exit and let the caller know we didn't write the data. + return false; + } + else if (parent._mode == BoundedChannelFullMode.DropWrite) + { + // The channel is full. Just ignore the item being added + // but say we added it. + return true; + } + else + { + // The channel is full, and we're in a dropping mode. + // Drop either the oldest or the newest and write the new item. + T droppedItem = parent._mode == BoundedChannelFullMode.DropNewest ? + parent._items.DequeueTail() : + parent._items.DequeueHead(); + parent._items.EnqueueTail(item); + return true; + } + } + + // We stored an item bringing the count up from 0 to 1. Alert + // any waiting readers that there may be something for them to consume. + // Since we're no longer holding the lock, it's possible we'll end up + // waking readers that have since come in. + waitingReaders.Success(item: true); + return true; + } + + public override Task<bool> WaitToWriteAsync(CancellationToken cancellationToken) + { + if (cancellationToken.IsCancellationRequested) + { + return Task.FromCanceled<bool>(cancellationToken); + } + + BoundedChannel<T> parent = _parent; + lock (parent.SyncObj) + { + parent.AssertInvariants(); + + // If we're done writing, no writes will ever succeed. + if (parent._doneWriting != null) + { + return parent._doneWriting != ChannelUtilities.s_doneWritingSentinel ? + Task.FromException<bool>(parent._doneWriting) : + ChannelUtilities.s_falseTask; + } + + // If there's space to write, a write is possible. + // And if the mode involves dropping/ignoring, we can always write, as even if it's + // full we'll just drop an element to make room. + if (parent._items.Count < parent._bufferedCapacity || parent._mode != BoundedChannelFullMode.Wait) + { + return ChannelUtilities.s_trueTask; + } + + // We're still allowed to write, but there's no space, so ensure a waiter is queued and return it. + return ChannelUtilities.GetOrCreateWaiter(ref parent._waitingWriters, true, cancellationToken); + } + } + + public override Task WriteAsync(T item, CancellationToken cancellationToken) + { + if (cancellationToken.IsCancellationRequested) + { + return Task.FromCanceled(cancellationToken); + } + + ReaderInteractor<bool> waitingReaders = null; + + BoundedChannel<T> parent = _parent; + lock (parent.SyncObj) + { + parent.AssertInvariants(); + + // If we're done writing, trying to write is an error. + if (parent._doneWriting != null) + { + return Task.FromException(ChannelUtilities.CreateInvalidCompletionException(parent._doneWriting)); + } + + // Get the number of items in the channel currently. + int count = parent._items.Count; + + if (count == 0) + { + // There are no items in the channel, which means we may have waiting readers. + // Store the item. + parent._items.EnqueueTail(item); + waitingReaders = parent._waitingReaders; + if (waitingReaders == null) + { + // If no one's waiting to be notified about a 0-to-1 transition, we're done. + return ChannelUtilities.s_trueTask; + } + parent._waitingReaders = null; + } + else if (count < parent._bufferedCapacity) + { + // There's room in the channel. Since we're not transitioning from 0-to-1 and + // since there's room, we can simply store the item and exit without having to + // worry about blocked/waiting readers. + parent._items.EnqueueTail(item); + return ChannelUtilities.s_trueTask; + } + else if (parent._mode == BoundedChannelFullMode.Wait) + { + // The channel is full and we're in a wait mode. + // Queue the writer. + var writer = WriterInteractor<T>.Create(true, cancellationToken, item); + parent._blockedWriters.EnqueueTail(writer); + return writer.Task; + } + else if (parent._mode == BoundedChannelFullMode.DropWrite) + { + // The channel is full and we're in ignore mode. + // Ignore the item but say we accepted it. + return ChannelUtilities.s_trueTask; + } + else + { + // The channel is full, and we're in a dropping mode. + // Drop either the oldest or the newest and write the new item. + T droppedItem = parent._mode == BoundedChannelFullMode.DropNewest ? + parent._items.DequeueTail() : + parent._items.DequeueHead(); + parent._items.EnqueueTail(item); + return ChannelUtilities.s_trueTask; + } + } + + // We stored an item bringing the count up from 0 to 1. Alert + // any waiting readers that there may be something for them to consume. + // Since we're no longer holding the lock, it's possible we'll end up + // waking readers that have since come in. + waitingReaders.Success(item: true); + return ChannelUtilities.s_trueTask; + } + } + + [Conditional("DEBUG")] + private void AssertInvariants() + { + Debug.Assert(SyncObj != null, "The sync obj must not be null."); + Debug.Assert(Monitor.IsEntered(SyncObj), "Invariants can only be validated while holding the lock."); + + if (!_items.IsEmpty) + { + Debug.Assert(_waitingReaders == null, "There are items available, so there shouldn't be any waiting readers."); + } + if (_items.Count < _bufferedCapacity) + { + Debug.Assert(_blockedWriters.IsEmpty, "There's space available, so there shouldn't be any blocked writers."); + Debug.Assert(_waitingWriters == null, "There's space available, so there shouldn't be any waiting writers."); + } + if (!_blockedWriters.IsEmpty) + { + Debug.Assert(_items.Count == _bufferedCapacity, "We should have a full buffer if there's a blocked writer."); + } + if (_completion.Task.IsCompleted) + { + Debug.Assert(_doneWriting != null, "We can only complete if we're done writing."); + } + } + + /// <summary>Gets the number of items in the channel. This should only be used by the debugger.</summary> + private int ItemsCountForDebugger => _items.Count; + + /// <summary>Gets an enumerator the debugger can use to show the contents of the channel.</summary> + IEnumerator<T> IDebugEnumerable<T>.GetEnumerator() => _items.GetEnumerator(); + } +} diff --git a/src/System.Threading.Channels/src/System/Threading/Channels/BoundedChannelFullMode.cs b/src/System.Threading.Channels/src/System/Threading/Channels/BoundedChannelFullMode.cs new file mode 100644 index 0000000000..385256541e --- /dev/null +++ b/src/System.Threading.Channels/src/System/Threading/Channels/BoundedChannelFullMode.cs @@ -0,0 +1,19 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +namespace System.Threading.Channels +{ + /// <summary>Specifies the behavior to use when writing to a bounded channel that is already full.</summary> + public enum BoundedChannelFullMode + { + /// <summary>Wait for space to be available in order to complete the write operation.</summary> + Wait, + /// <summary>Remove and ignore the newest item in the channel in order to make room for the item being written.</summary> + DropNewest, + /// <summary>Remove and ignore the oldest item in the channel in order to make room for the item being written.</summary> + DropOldest, + /// <summary>Drop the item being written.</summary> + DropWrite + } +} diff --git a/src/System.Threading.Channels/src/System/Threading/Channels/Channel.cs b/src/System.Threading.Channels/src/System/Threading/Channels/Channel.cs new file mode 100644 index 0000000000..9dedd4ef9f --- /dev/null +++ b/src/System.Threading.Channels/src/System/Threading/Channels/Channel.cs @@ -0,0 +1,76 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +namespace System.Threading.Channels +{ + /// <summary>Provides static methods for creating channels.</summary> + public static class Channel + { + /// <summary>Creates an unbounded channel usable by any number of readers and writers concurrently.</summary> + /// <returns>The created channel.</returns> + public static Channel<T> CreateUnbounded<T>() => + new UnboundedChannel<T>(runContinuationsAsynchronously: true); + + /// <summary>Creates an unbounded channel subject to the provided options.</summary> + /// <typeparam name="T">Specifies the type of data in the channel.</typeparam> + /// <param name="options">Options that guide the behavior of the channel.</param> + /// <returns>The created channel.</returns> + public static Channel<T> CreateUnbounded<T>(UnboundedChannelOptions options) => + options == null ? throw new ArgumentOutOfRangeException(nameof(options)) : + options.SingleReader ? new SingleConsumerUnboundedChannel<T>(!options.AllowSynchronousContinuations) : + (Channel<T>)new UnboundedChannel<T>(!options.AllowSynchronousContinuations); + + /// <summary>Creates a channel with the specified maximum capacity.</summary> + /// <typeparam name="T">Specifies the type of data in the channel.</typeparam> + /// <param name="capacity">The maximum number of items the channel may store.</param> + /// <returns>The created channel.</returns> + /// <remarks> + /// Channels created with this method apply the <see cref="BoundedChannelFullMode.Wait"/> + /// behavior and prohibit continuations from running synchronously. + /// </remarks> + public static Channel<T> CreateBounded<T>(int capacity) + { + if (capacity < 1) + { + throw new ArgumentOutOfRangeException(nameof(capacity)); + } + + return new BoundedChannel<T>(capacity, BoundedChannelFullMode.Wait, runContinuationsAsynchronously: true); + } + + /// <summary>Creates a channel with the specified maximum capacity.</summary> + /// <typeparam name="T">Specifies the type of data in the channel.</typeparam> + /// <param name="options">Options that guide the behavior of the channel.</param> + /// <returns>The created channel.</returns> + public static Channel<T> CreateBounded<T>(BoundedChannelOptions options) + { + if (options == null) + { + throw new ArgumentOutOfRangeException(nameof(options)); + } + + return new BoundedChannel<T>(options.Capacity, options.FullMode, !options.AllowSynchronousContinuations); + } + + /// <summary>Creates a channel that doesn't buffer any items.</summary> + /// <typeparam name="T">Specifies the type of data in the channel.</typeparam> + /// <returns>The created channel.</returns> + public static Channel<T> CreateUnbuffered<T>() => + new UnbufferedChannel<T>(); + + /// <summary>Creates a channel that doesn't buffer any items.</summary> + /// <typeparam name="T">Specifies the type of data in the channel.</typeparam> + /// <param name="options">Options that guide the behavior of the channel.</param> + /// <returns>The created channel.</returns> + public static Channel<T> CreateUnbuffered<T>(UnbufferedChannelOptions options) + { + if (options == null) + { + throw new ArgumentOutOfRangeException(nameof(options)); + } + + return new UnbufferedChannel<T>(); + } + } +} diff --git a/src/System.Threading.Channels/src/System/Threading/Channels/ChannelClosedException.cs b/src/System.Threading.Channels/src/System/Threading/Channels/ChannelClosedException.cs new file mode 100644 index 0000000000..fe9efc615e --- /dev/null +++ b/src/System.Threading.Channels/src/System/Threading/Channels/ChannelClosedException.cs @@ -0,0 +1,28 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +namespace System.Threading.Channels +{ + /// <summary>Exception thrown when a channel is used after it's been closed.</summary> + public class ChannelClosedException : InvalidOperationException + { + /// <summary>Initializes a new instance of the <see cref="ChannelClosedException"/> class.</summary> + public ChannelClosedException() : + base(SR.ChannelClosedException_DefaultMessage) { } + + /// <summary>Initializes a new instance of the <see cref="ChannelClosedException"/> class.</summary> + /// <param name="message">The message that describes the error.</param> + public ChannelClosedException(string message) : base(message) { } + + /// <summary>Initializes a new instance of the <see cref="ChannelClosedException"/> class.</summary> + /// <param name="innerException">The exception that is the cause of this exception.</param> + public ChannelClosedException(Exception innerException) : + base(SR.ChannelClosedException_DefaultMessage, innerException) { } + + /// <summary>Initializes a new instance of the <see cref="ChannelClosedException"/> class.</summary> + /// <param name="message">The message that describes the error.</param> + /// <param name="innerException">The exception that is the cause of this exception.</param> + public ChannelClosedException(string message, Exception innerException) : base(message, innerException) { } + } +} diff --git a/src/System.Threading.Channels/src/System/Threading/Channels/ChannelOptions.cs b/src/System.Threading.Channels/src/System/Threading/Channels/ChannelOptions.cs new file mode 100644 index 0000000000..9172889de8 --- /dev/null +++ b/src/System.Threading.Channels/src/System/Threading/Channels/ChannelOptions.cs @@ -0,0 +1,107 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +namespace System.Threading.Channels +{ + /// <summary>Provides options that control the behavior of channel instances.</summary> + public abstract class ChannelOptions + { + /// <summary> + /// <code>true</code> if writers to the channel guarantee that there will only ever be at most one write operation + /// at a time; <code>false</code> if no such constraint is guaranteed. + /// </summary> + /// <remarks> + /// If true, the channel may be able to optimize certain operations based on knowing about the single-writer guarantee. + /// The default is false. + /// </remarks> + public bool SingleWriter { get; set; } + + /// <summary> + /// <code>true</code> readers from the channel guarantee that there will only ever be at most one read operation at a time; + /// <code>false</code> if no such constraint is guaranteed. + /// </summary> + /// <remarks> + /// If true, the channel may be able to optimize certain operations based on knowing about the single-reader guarantee. + /// The default is false. + /// </remarks> + public bool SingleReader { get; set; } + + /// <summary> + /// <code>true</code> if operations performed on a channel may synchronously invoke continuations subscribed to + /// notifications of pending async operations; <code>false</code> if all continuations should be invoked asynchronously. + /// </summary> + /// <remarks> + /// Setting this option to <code>true</code> can provide measurable throughput improvements by avoiding + /// scheduling additional work items. However, it may come at the cost of reduced parallelism, as for example a producer + /// may then be the one to execute work associated with a consumer, and if not done thoughtfully, this can lead + /// to unexpected interactions. The default is false. + /// </remarks> + public bool AllowSynchronousContinuations { get; set; } + } + + /// <summary>Provides options that control the behavior of <see cref="BoundedChannel{T}"/> instances.</summary> + public sealed class BoundedChannelOptions : ChannelOptions + { + /// <summary>The maximum number of items the bounded channel may store.</summary> + private int _capacity; + /// <summary>The behavior incurred by write operations when the channel is full.</summary> + private BoundedChannelFullMode _mode = BoundedChannelFullMode.Wait; + + /// <summary>Initializes the options.</summary> + /// <param name="capacity">The maximum number of items the bounded channel may store.</param> + public BoundedChannelOptions(int capacity) + { + if (capacity < 1) + { + throw new ArgumentOutOfRangeException(nameof(capacity)); + } + + Capacity = capacity; + } + + /// <summary>Gets or sets the maximum number of items the bounded channel may store.</summary> + public int Capacity + { + get => _capacity; + set + { + if (value < 1) + { + throw new ArgumentOutOfRangeException(nameof(value)); + } + _capacity = value; + } + } + + /// <summary>Gets or sets the behavior incurred by write operations when the channel is full.</summary> + public BoundedChannelFullMode FullMode + { + get => _mode; + set + { + switch (value) + { + case BoundedChannelFullMode.Wait: + case BoundedChannelFullMode.DropNewest: + case BoundedChannelFullMode.DropOldest: + case BoundedChannelFullMode.DropWrite: + _mode = value; + break; + default: + throw new ArgumentOutOfRangeException(nameof(value)); + } + } + } + } + + /// <summary>Provides options that control the behavior of <see cref="UnboundedChannel{T}"/> instances.</summary> + public sealed class UnboundedChannelOptions : ChannelOptions + { + } + + /// <summary>Provides options that control the behavior of <see cref="UnbufferedChannel{T}"/> instances.</summary> + public sealed class UnbufferedChannelOptions : ChannelOptions + { + } +} diff --git a/src/System.Threading.Channels/src/System/Threading/Channels/ChannelReader.cs b/src/System.Threading.Channels/src/System/Threading/Channels/ChannelReader.cs new file mode 100644 index 0000000000..a5d7d806ce --- /dev/null +++ b/src/System.Threading.Channels/src/System/Threading/Channels/ChannelReader.cs @@ -0,0 +1,34 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Threading.Tasks; + +namespace System.Threading.Channels +{ + /// <summary> + /// Provides a base class for reading from a channel. + /// </summary> + /// <typeparam name="T">Specifies the type of data that may be read from the channel.</typeparam> + public abstract class ChannelReader<T> + { + /// <summary> + /// Gets a <see cref="Task"/> that completes when no more data will ever + /// be available to be read from this channel. + /// </summary> + public virtual Task Completion => ChannelUtilities.s_neverCompletingTask; + + /// <summary>Attempts to read an item to the channel.</summary> + /// <param name="item">The read item, or a default value if no item could be read.</param> + /// <returns>true if an item was read; otherwise, false if no item was read.</returns> + public abstract bool TryRead(out T item); + + /// <summary>Returns a <see cref="Task{Boolean}"/> that will complete when data is available to read.</summary> + /// <param name="cancellationToken">A <see cref="CancellationToken"/> used to cancel the wait operation.</param> + /// <returns> + /// A <see cref="Task{Boolean}"/> that will complete with a <c>true</c> result when data is available to read + /// or with a <c>false</c> result when no further data will ever be available to be read. + /// </returns> + public abstract Task<bool> WaitToReadAsync(CancellationToken cancellationToken = default(CancellationToken)); + } +} diff --git a/src/System.Threading.Channels/src/System/Threading/Channels/ChannelUtilities.cs b/src/System.Threading.Channels/src/System/Threading/Channels/ChannelUtilities.cs new file mode 100644 index 0000000000..a7411c4fb2 --- /dev/null +++ b/src/System.Threading.Channels/src/System/Threading/Channels/ChannelUtilities.cs @@ -0,0 +1,138 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Collections.Generic; +using System.Diagnostics; +using System.Threading.Tasks; + +namespace System.Threading.Channels +{ + /// <summary>Provides internal helper methods for implementing channels.</summary> + internal static class ChannelUtilities + { + /// <summary>Sentinel object used to indicate being done writing.</summary> + internal static readonly Exception s_doneWritingSentinel = new Exception(nameof(s_doneWritingSentinel)); + /// <summary>A cached task with a Boolean true result.</summary> + internal static readonly Task<bool> s_trueTask = Task.FromResult(true); + /// <summary>A cached task with a Boolean false result.</summary> + internal static readonly Task<bool> s_falseTask = Task.FromResult(false); + /// <summary>A cached task that never completes.</summary> + internal static readonly Task s_neverCompletingTask = new TaskCompletionSource<bool>().Task; + + /// <summary>Completes the specified TaskCompletionSource.</summary> + /// <param name="tcs">The source to complete.</param> + /// <param name="error"> + /// The optional exception with which to complete. + /// If this is null or the DoneWritingSentinel, the source will be completed successfully. + /// If this is an OperationCanceledException, it'll be completed with the exception's token. + /// Otherwise, it'll be completed as faulted with the exception. + /// </param> + internal static void Complete(TaskCompletionSource<VoidResult> tcs, Exception error = null) + { + if (error is OperationCanceledException oce) + { + tcs.TrySetCanceled(oce.CancellationToken); + } + else if (error != null && error != s_doneWritingSentinel) + { + tcs.TrySetException(error); + } + else + { + tcs.TrySetResult(default(VoidResult)); + } + } + + /// <summary>Gets a value task representing an error.</summary> + /// <typeparam name="T">Specifies the type of the value that would have been returned.</typeparam> + /// <param name="error">The error. This may be <see cref="s_doneWritingSentinel"/>.</param> + /// <returns>The failed task.</returns> + internal static ValueTask<T> GetInvalidCompletionValueTask<T>(Exception error) + { + Debug.Assert(error != null); + + Task<T> t = + error == s_doneWritingSentinel ? Task.FromException<T>(CreateInvalidCompletionException()) : + error is OperationCanceledException oce ? Task.FromCanceled<T>(oce.CancellationToken.IsCancellationRequested ? oce.CancellationToken : new CancellationToken(true)) : + Task.FromException<T>(CreateInvalidCompletionException(error)); + + return new ValueTask<T>(t); + } + + /// <summary>Wake up all of the waiters and null out the field.</summary> + /// <param name="waiters">The waiters.</param> + /// <param name="result">The value with which to complete each waiter.</param> + internal static void WakeUpWaiters(ref ReaderInteractor<bool> waiters, bool result) + { + ReaderInteractor<bool> w = waiters; + if (w != null) + { + w.Success(result); + waiters = null; + } + } + + /// <summary>Wake up all of the waiters and null out the field.</summary> + /// <param name="waiters">The waiters.</param> + /// <param name="result">The success value with which to complete each waiter if <paramref name="error">error</paramref> is null.</param> + /// <param name="error">The failure with which to cmplete each waiter, if non-null.</param> + internal static void WakeUpWaiters(ref ReaderInteractor<bool> waiters, bool result, Exception error = null) + { + ReaderInteractor<bool> w = waiters; + if (w != null) + { + if (error != null) + { + w.Fail(error); + } + else + { + w.Success(result); + } + waiters = null; + } + } + + /// <summary>Removes all interactors from the queue, failing each.</summary> + /// <param name="interactors">The queue of interactors to complete.</param> + /// <param name="error">The error with which to complete each interactor.</param> + internal static void FailInteractors<T, TInner>(Dequeue<T> interactors, Exception error) where T : Interactor<TInner> + { + while (!interactors.IsEmpty) + { + interactors.DequeueHead().Fail(error ?? CreateInvalidCompletionException()); + } + } + + /// <summary>Gets or creates a "waiter" (e.g. WaitForRead/WriteAsync) interactor.</summary> + /// <param name="waiter">The field storing the waiter interactor.</param> + /// <param name="runContinuationsAsynchronously">true to force continuations to run asynchronously; otherwise, false.</param> + /// <param name="cancellationToken">The token to use to cancel the wait.</param> + internal static Task<bool> GetOrCreateWaiter(ref ReaderInteractor<bool> waiter, bool runContinuationsAsynchronously, CancellationToken cancellationToken) + { + // Get the existing waiters interactor. + ReaderInteractor<bool> w = waiter; + + // If there isn't one, create one. This explicitly does not include the cancellation token, + // as we reuse it for any number of waiters that overlap. + if (w == null) + { + waiter = w = ReaderInteractor<bool>.Create(runContinuationsAsynchronously); + } + + // If the cancellation token can't be canceled, then just return the waiter task. + // If it can, we need to return a task that will complete when the waiter task does but that can also be canceled. + // Easiest way to do that is with a cancelable continuation. + return cancellationToken.CanBeCanceled ? + w.Task.ContinueWith(t => t.Result, cancellationToken, TaskContinuationOptions.ExecuteSynchronously, TaskScheduler.Default) : + w.Task; + } + + /// <summary>Creates and returns an exception object to indicate that a channel has been closed.</summary> + internal static Exception CreateInvalidCompletionException(Exception inner = null) => + inner is OperationCanceledException ? inner : + inner != null && inner != s_doneWritingSentinel ? new ChannelClosedException(inner) : + new ChannelClosedException(); + } +} diff --git a/src/System.Threading.Channels/src/System/Threading/Channels/ChannelWriter.cs b/src/System.Threading.Channels/src/System/Threading/Channels/ChannelWriter.cs new file mode 100644 index 0000000000..e96c1fff75 --- /dev/null +++ b/src/System.Threading.Channels/src/System/Threading/Channels/ChannelWriter.cs @@ -0,0 +1,79 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Threading.Tasks; + +namespace System.Threading.Channels +{ + /// <summary> + /// Provides a base class for writing to a channel. + /// </summary> + /// <typeparam name="T">Specifies the type of data that may be written to the channel.</typeparam> + public abstract class ChannelWriter<T> + { + /// <summary>Attempts to mark the channel as being completed, meaning no more data will be written to it.</summary> + /// <param name="error">An <see cref="Exception"/> indicating the failure causing no more data to be written, or null for success.</param> + /// <returns> + /// true if this operation successfully completes the channel; otherwise, false if the channel could not be marked for completion, + /// for example due to having already been marked as such, or due to not supporting completion. + /// </returns> + public virtual bool TryComplete(Exception error = null) => false; + + /// <summary>Attempts to write the specified item to the channel.</summary> + /// <param name="item">The item to write.</param> + /// <returns>true if the item was written; otherwise, false if it wasn't written.</returns> + public abstract bool TryWrite(T item); + + /// <summary>Returns a <see cref="Task{Boolean}"/> that will complete when space is available to write an item.</summary> + /// <param name="cancellationToken">A <see cref="CancellationToken"/> used to cancel the wait operation.</param> + /// <returns> + /// A <see cref="Task{Boolean}"/> that will complete with a <c>true</c> result when space is available to write an item + /// or with a <c>false</c> result when no further writing will be permitted. + /// </returns> + public abstract Task<bool> WaitToWriteAsync(CancellationToken cancellationToken = default(CancellationToken)); + + /// <summary>Asynchronously writes an item to the channel.</summary> + /// <param name="item">The value to write to the channel.</param> + /// <param name="cancellationToken">A <see cref="CancellationToken"/> used to cancel the write operation.</param> + /// <returns>A <see cref="Task"/> that represents the asynchronous write operation.</returns> + public virtual Task WriteAsync(T item, CancellationToken cancellationToken = default(CancellationToken)) + { + try + { + return + cancellationToken.IsCancellationRequested ? Task.FromCanceled<T>(cancellationToken) : + TryWrite(item) ? Task.CompletedTask : + WriteAsyncCore(item, cancellationToken); + } + catch (Exception e) + { + return Task.FromException(e); + } + + async Task WriteAsyncCore(T innerItem, CancellationToken ct) + { + while (await WaitToWriteAsync(ct).ConfigureAwait(false)) + { + if (TryWrite(innerItem)) + { + return; + } + } + + throw ChannelUtilities.CreateInvalidCompletionException(); + } + } + + /// <summary>Mark the channel as being complete, meaning no more items will be written to it.</summary> + /// <param name="error">Optional Exception indicating a failure that's causing the channel to complete.</param> + /// <exception cref="InvalidOperationException">The channel has already been marked as complete.</exception> + public void Complete(Exception error = null) + { + if (!TryComplete(error)) + { + throw ChannelUtilities.CreateInvalidCompletionException(); + } + } + } +} diff --git a/src/System.Threading.Channels/src/System/Threading/Channels/Channel_1.cs b/src/System.Threading.Channels/src/System/Threading/Channels/Channel_1.cs new file mode 100644 index 0000000000..c10dea341d --- /dev/null +++ b/src/System.Threading.Channels/src/System/Threading/Channels/Channel_1.cs @@ -0,0 +1,10 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +namespace System.Threading.Channels +{ + /// <summary>Provides a base class for channels that support reading and writing elements of type <typeparamref name="T"/>.</summary> + /// <typeparam name="T">Specifies the type of data readable and writable in the channel.</typeparam> + public abstract class Channel<T> : Channel<T, T> { } +} diff --git a/src/System.Threading.Channels/src/System/Threading/Channels/Channel_2.cs b/src/System.Threading.Channels/src/System/Threading/Channels/Channel_2.cs new file mode 100644 index 0000000000..d8e2b28848 --- /dev/null +++ b/src/System.Threading.Channels/src/System/Threading/Channels/Channel_2.cs @@ -0,0 +1,29 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +namespace System.Threading.Channels +{ + /// <summary> + /// Provides a base class for channels that support reading elements of type <typeparamref name="TRead"/> + /// and writing elements of type <typeparamref name="TWrite"/>. + /// </summary> + /// <typeparam name="TWrite">Specifies the type of data that may be written to the channel.</typeparam> + /// <typeparam name="TRead">Specifies the type of data that may be read from the channel.</typeparam> + public abstract class Channel<TWrite, TRead> + { + /// <summary>Gets the readable half of this channel.</summary> + public ChannelReader<TRead> Reader { get; protected set; } + + /// <summary>Gets the writable half of this channel.</summary> + public ChannelWriter<TWrite> Writer { get; protected set; } + + /// <summary>Implicit cast from a channel to its readable half.</summary> + /// <param name="channel">The channel being cast.</param> + public static implicit operator ChannelReader<TRead>(Channel<TWrite, TRead> channel) => channel.Reader; + + /// <summary>Implicit cast from a channel to its writable half.</summary> + /// <param name="channel">The channel being cast.</param> + public static implicit operator ChannelWriter<TWrite>(Channel<TWrite, TRead> channel) => channel.Writer; + } +} diff --git a/src/System.Threading.Channels/src/System/Threading/Channels/IDebugEnumerator.cs b/src/System.Threading.Channels/src/System/Threading/Channels/IDebugEnumerator.cs new file mode 100644 index 0000000000..d83eb78a23 --- /dev/null +++ b/src/System.Threading.Channels/src/System/Threading/Channels/IDebugEnumerator.cs @@ -0,0 +1,30 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Collections.Generic; +using System.Diagnostics; + +namespace System.Threading.Channels +{ + interface IDebugEnumerable<T> + { + IEnumerator<T> GetEnumerator(); + } + + internal sealed class DebugEnumeratorDebugView<T> + { + public DebugEnumeratorDebugView(IDebugEnumerable<T> enumerable) + { + var list = new List<T>(); + foreach (T item in enumerable) + { + list.Add(item); + } + Items = list.ToArray(); + } + + [DebuggerBrowsable(DebuggerBrowsableState.RootHidden)] + public T[] Items { get; } + } +} diff --git a/src/System.Threading.Channels/src/System/Threading/Channels/Interactor.cs b/src/System.Threading.Channels/src/System/Threading/Channels/Interactor.cs new file mode 100644 index 0000000000..cac9aaa62e --- /dev/null +++ b/src/System.Threading.Channels/src/System/Threading/Channels/Interactor.cs @@ -0,0 +1,101 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Threading.Tasks; + +namespace System.Threading.Channels +{ + internal abstract class Interactor<T> : TaskCompletionSource<T> + { + protected Interactor(bool runContinuationsAsynchronously) : + base(runContinuationsAsynchronously ? TaskCreationOptions.RunContinuationsAsynchronously : TaskCreationOptions.None) { } + + internal bool Success(T item) + { + UnregisterCancellation(); + return TrySetResult(item); + } + + internal bool Fail(Exception exception) + { + UnregisterCancellation(); + return TrySetException(exception); + } + + internal virtual void UnregisterCancellation() { } + } + + internal class ReaderInteractor<T> : Interactor<T> + { + protected ReaderInteractor(bool runContinuationsAsynchronously) : base(runContinuationsAsynchronously) { } + + public static ReaderInteractor<T> Create(bool runContinuationsAsynchronously) => + new ReaderInteractor<T>(runContinuationsAsynchronously); + + public static ReaderInteractor<T> Create(bool runContinuationsAsynchronously, CancellationToken cancellationToken) => + cancellationToken.CanBeCanceled ? + new CancelableReaderInteractor<T>(runContinuationsAsynchronously, cancellationToken) : + new ReaderInteractor<T>(runContinuationsAsynchronously); + } + + internal class WriterInteractor<T> : Interactor<VoidResult> + { + protected WriterInteractor(bool runContinuationsAsynchronously) : base(runContinuationsAsynchronously) { } + + internal T Item { get; private set; } + + public static WriterInteractor<T> Create(bool runContinuationsAsynchronously, CancellationToken cancellationToken, T item) + { + WriterInteractor<T> w = cancellationToken.CanBeCanceled ? + new CancelableWriter<T>(runContinuationsAsynchronously, cancellationToken) : + new WriterInteractor<T>(runContinuationsAsynchronously); + w.Item = item; + return w; + } + } + + internal sealed class CancelableReaderInteractor<T> : ReaderInteractor<T> + { + private CancellationToken _token; + private CancellationTokenRegistration _registration; + + internal CancelableReaderInteractor(bool runContinuationsAsynchronously, CancellationToken cancellationToken) : base(runContinuationsAsynchronously) + { + _token = cancellationToken; + _registration = cancellationToken.Register(s => + { + var thisRef = (CancelableReaderInteractor<T>)s; + thisRef.TrySetCanceled(thisRef._token); + }, this); + } + + internal override void UnregisterCancellation() + { + _registration.Dispose(); + _registration = default(CancellationTokenRegistration); + } + } + + internal sealed class CancelableWriter<T> : WriterInteractor<T> + { + private CancellationToken _token; + private CancellationTokenRegistration _registration; + + internal CancelableWriter(bool runContinuationsAsynchronously, CancellationToken cancellationToken) : base(runContinuationsAsynchronously) + { + _token = cancellationToken; + _registration = cancellationToken.Register(s => + { + var thisRef = (CancelableWriter<T>)s; + thisRef.TrySetCanceled(thisRef._token); + }, this); + } + + internal override void UnregisterCancellation() + { + _registration.Dispose(); + _registration = default(CancellationTokenRegistration); + } + } +} diff --git a/src/System.Threading.Channels/src/System/Threading/Channels/SingleConsumerUnboundedChannel.cs b/src/System.Threading.Channels/src/System/Threading/Channels/SingleConsumerUnboundedChannel.cs new file mode 100644 index 0000000000..294b58859a --- /dev/null +++ b/src/System.Threading.Channels/src/System/Threading/Channels/SingleConsumerUnboundedChannel.cs @@ -0,0 +1,236 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Diagnostics; +using System.Threading.Tasks; + +namespace System.Threading.Channels +{ + /// <summary> + /// Provides a buffered channel of unbounded capacity for use by any number + /// of writers but at most a single reader at a time. + /// </summary> + [DebuggerDisplay("Items={ItemsCountForDebugger}")] + [DebuggerTypeProxy(typeof(DebugEnumeratorDebugView<>))] + internal sealed class SingleConsumerUnboundedChannel<T> : Channel<T>, IDebugEnumerable<T> + { + /// <summary>Task that indicates the channel has completed.</summary> + private readonly TaskCompletionSource<VoidResult> _completion; + /// <summary> + /// A concurrent queue to hold the items for this channel. The queue itself supports at most + /// one writer and one reader at a time; as a result, since this channel supports multiple writers, + /// all write access to the queue must be synchronized by the channel. + /// </summary> + private readonly SingleProducerSingleConsumerQueue<T> _items = new SingleProducerSingleConsumerQueue<T>(); + /// <summary>Whether to force continuations to be executed asynchronously from producer writes.</summary> + private readonly bool _runContinuationsAsynchronously; + + /// <summary>non-null if the channel has been marked as complete for writing.</summary> + private volatile Exception _doneWriting; + + /// <summary>A waiting reader (e.g. WaitForReadAsync) if there is one.</summary> + private ReaderInteractor<bool> _waitingReader; + + /// <summary>Initialize the channel.</summary> + /// <param name="runContinuationsAsynchronously">Whether to force continuations to be executed asynchronously.</param> + internal SingleConsumerUnboundedChannel(bool runContinuationsAsynchronously) + { + _runContinuationsAsynchronously = runContinuationsAsynchronously; + _completion = new TaskCompletionSource<VoidResult>(runContinuationsAsynchronously ? TaskCreationOptions.RunContinuationsAsynchronously : TaskCreationOptions.None); + + Reader = new UnboundedChannelReader(this); + Writer = new UnboundedChannelWriter(this); + } + + private sealed class UnboundedChannelReader : ChannelReader<T> + { + internal readonly SingleConsumerUnboundedChannel<T> _parent; + internal UnboundedChannelReader(SingleConsumerUnboundedChannel<T> parent) => _parent = parent; + + public override Task Completion => _parent._completion.Task; + + public override bool TryRead(out T item) + { + SingleConsumerUnboundedChannel<T> parent = _parent; + if (parent._items.TryDequeue(out item)) + { + if (parent._doneWriting != null && parent._items.IsEmpty) + { + ChannelUtilities.Complete(parent._completion, parent._doneWriting); + } + return true; + } + return false; + } + + public override Task<bool> WaitToReadAsync(CancellationToken cancellationToken) + { + // Outside of the lock, check if there are any items waiting to be read. If there are, we're done. + return !_parent._items.IsEmpty ? + ChannelUtilities.s_trueTask : + WaitToReadAsyncCore(cancellationToken); + + Task<bool> WaitToReadAsyncCore(CancellationToken ct) + { + // Now check for cancellation. + if (ct.IsCancellationRequested) + { + return Task.FromCanceled<bool>(ct); + } + + SingleConsumerUnboundedChannel<T> parent = _parent; + ReaderInteractor<bool> oldWaiter = null, newWaiter; + lock (parent.SyncObj) + { + // Again while holding the lock, check to see if there are any items available. + if (!parent._items.IsEmpty) + { + return ChannelUtilities.s_trueTask; + } + + // There aren't any items; if we're done writing, there never will be more items. + if (parent._doneWriting != null) + { + return parent._doneWriting != ChannelUtilities.s_doneWritingSentinel ? + Task.FromException<bool>(parent._doneWriting) : + ChannelUtilities.s_falseTask; + } + + // Create the new waiter. We're a bit more tolerant of a stray waiting reader + // than we are of a blocked reader, as with usage patterns it's easier to leave one + // behind, so we just cancel any that may have been waiting around. + oldWaiter = parent._waitingReader; + parent._waitingReader = newWaiter = ReaderInteractor<bool>.Create(parent._runContinuationsAsynchronously, ct); + } + + oldWaiter?.TrySetCanceled(); + return newWaiter.Task; + } + } + } + + private sealed class UnboundedChannelWriter : ChannelWriter<T> + { + internal readonly SingleConsumerUnboundedChannel<T> _parent; + internal UnboundedChannelWriter(SingleConsumerUnboundedChannel<T> parent) => _parent = parent; + + public override bool TryComplete(Exception error) + { + ReaderInteractor<bool> waitingReader = null; + bool completeTask = false; + + SingleConsumerUnboundedChannel<T> parent = _parent; + lock (parent.SyncObj) + { + // If we're already marked as complete, there's nothing more to do. + if (parent._doneWriting != null) + { + return false; + } + + // Mark as complete for writing. + parent._doneWriting = error ?? ChannelUtilities.s_doneWritingSentinel; + + // If we have no more items remaining, then the channel needs to be marked as completed + // and readers need to be informed they'll never get another item. All of that needs + // to happen outside of the lock to avoid invoking continuations under the lock. + if (parent._items.IsEmpty) + { + completeTask = true; + + if (parent._waitingReader != null) + { + waitingReader = parent._waitingReader; + parent._waitingReader = null; + } + } + } + + // Complete the channel task if necessary + if (completeTask) + { + ChannelUtilities.Complete(parent._completion, error); + } + + // Complete a waiting reader if necessary. + if (waitingReader != null) + { + if (error != null) + { + waitingReader.Fail(error); + } + else + { + waitingReader.Success(false); + } + } + + // Successfully completed the channel + return true; + } + + public override bool TryWrite(T item) + { + SingleConsumerUnboundedChannel<T> parent = _parent; + while (true) // in case a reader was canceled and we need to try again + { + ReaderInteractor<bool> waitingReader = null; + + lock (parent.SyncObj) + { + // If writing is completed, exit out without writing. + if (parent._doneWriting != null) + { + return false; + } + + // Queue the item being written; then if there's a waiting + // reader, store it for notification outside of the lock. + parent._items.Enqueue(item); + + waitingReader = parent._waitingReader; + if (waitingReader == null) + { + return true; + } + parent._waitingReader = null; + } + + // If we get here, we grabbed a waiting reader. + // Notify it that an item was written and exit. + Debug.Assert(waitingReader != null, "Expected a waiting reader"); + waitingReader.Success(true); + return true; + } + } + + public override Task<bool> WaitToWriteAsync(CancellationToken cancellationToken) + { + Exception doneWriting = _parent._doneWriting; + return + doneWriting == null ? ChannelUtilities.s_trueTask : + cancellationToken.IsCancellationRequested ? Task.FromCanceled<bool>(cancellationToken) : + doneWriting != ChannelUtilities.s_doneWritingSentinel ? Task.FromException<bool>(doneWriting) : + ChannelUtilities.s_falseTask; + } + + public override Task WriteAsync(T item, CancellationToken cancellationToken) => + // Writing always succeeds (unless we've already completed writing or cancellation has been requested), + // so just TryWrite and return a completed task. + TryWrite(item) ? Task.CompletedTask : + cancellationToken.IsCancellationRequested ? Task.FromCanceled(cancellationToken) : + Task.FromException(ChannelUtilities.CreateInvalidCompletionException(_parent._doneWriting)); + } + + private object SyncObj => _items; + + /// <summary>Gets the number of items in the channel. This should only be used by the debugger.</summary> + private int ItemsCountForDebugger => _items.Count; + + /// <summary>Gets an enumerator the debugger can use to show the contents of the channel.</summary> + IEnumerator<T> IDebugEnumerable<T>.GetEnumerator() => _items.GetEnumerator(); + } +} diff --git a/src/System.Threading.Channels/src/System/Threading/Channels/UnboundedChannel.cs b/src/System.Threading.Channels/src/System/Threading/Channels/UnboundedChannel.cs new file mode 100644 index 0000000000..4391367809 --- /dev/null +++ b/src/System.Threading.Channels/src/System/Threading/Channels/UnboundedChannel.cs @@ -0,0 +1,232 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Diagnostics; +using System.Threading.Tasks; + +namespace System.Threading.Channels +{ + /// <summary>Provides a buffered channel of unbounded capacity.</summary> + [DebuggerDisplay("Items={ItemsCountForDebugger}")] + [DebuggerTypeProxy(typeof(DebugEnumeratorDebugView<>))] + internal sealed class UnboundedChannel<T> : Channel<T>, IDebugEnumerable<T> + { + /// <summary>Task that indicates the channel has completed.</summary> + private readonly TaskCompletionSource<VoidResult> _completion; + /// <summary>The items in the channel.</summary> + private readonly ConcurrentQueue<T> _items = new ConcurrentQueue<T>(); + /// <summary>Whether to force continuations to be executed asynchronously from producer writes.</summary> + private readonly bool _runContinuationsAsynchronously; + + /// <summary>Readers waiting for a notification that data is available.</summary> + private ReaderInteractor<bool> _waitingReaders; + /// <summary>Set to non-null once Complete has been called.</summary> + private Exception _doneWriting; + + /// <summary>Initialize the channel.</summary> + internal UnboundedChannel(bool runContinuationsAsynchronously) + { + _runContinuationsAsynchronously = runContinuationsAsynchronously; + _completion = new TaskCompletionSource<VoidResult>(runContinuationsAsynchronously ? TaskCreationOptions.RunContinuationsAsynchronously : TaskCreationOptions.None); + base.Reader = new UnboundedChannelReader(this); + Writer = new UnboundedChannelWriter(this); + } + + private sealed class UnboundedChannelReader : ChannelReader<T> + { + internal readonly UnboundedChannel<T> _parent; + internal UnboundedChannelReader(UnboundedChannel<T> parent) => _parent = parent; + + public override Task Completion => _parent._completion.Task; + + public override bool TryRead(out T item) + { + UnboundedChannel<T> parent = _parent; + + // Dequeue an item if we can + if (parent._items.TryDequeue(out item)) + { + if (parent._doneWriting != null && parent._items.IsEmpty) + { + // If we've now emptied the items queue and we're not getting any more, complete. + ChannelUtilities.Complete(parent._completion, parent._doneWriting); + } + return true; + } + + item = default; + return false; + } + + public override Task<bool> WaitToReadAsync(CancellationToken cancellationToken) + { + // If there are any items, readers can try to get them. + return !_parent._items.IsEmpty ? + ChannelUtilities.s_trueTask : + WaitToReadAsyncCore(cancellationToken); + + Task<bool> WaitToReadAsyncCore(CancellationToken ct) + { + UnboundedChannel<T> parent = _parent; + + lock (parent.SyncObj) + { + parent.AssertInvariants(); + + // Try again to read now that we're synchronized with writers. + if (!parent._items.IsEmpty) + { + return ChannelUtilities.s_trueTask; + } + + // There are no items, so if we're done writing, there's never going to be data available. + if (parent._doneWriting != null) + { + return parent._doneWriting != ChannelUtilities.s_doneWritingSentinel ? + Task.FromException<bool>(parent._doneWriting) : + ChannelUtilities.s_falseTask; + } + + // Queue the waiter + return ChannelUtilities.GetOrCreateWaiter(ref parent._waitingReaders, parent._runContinuationsAsynchronously, ct); + } + } + } + } + + private sealed class UnboundedChannelWriter : ChannelWriter<T> + { + internal readonly UnboundedChannel<T> _parent; + internal UnboundedChannelWriter(UnboundedChannel<T> parent) => _parent = parent; + + public override bool TryComplete(Exception error) + { + UnboundedChannel<T> parent = _parent; + bool completeTask; + + lock (parent.SyncObj) + { + parent.AssertInvariants(); + + // If we've already marked the channel as completed, bail. + if (parent._doneWriting != null) + { + return false; + } + + // Mark that we're done writing. + parent._doneWriting = error ?? ChannelUtilities.s_doneWritingSentinel; + completeTask = parent._items.IsEmpty; + } + + // If there are no items in the queue, complete the channel's task, + // as no more data can possibly arrive at this point. We do this outside + // of the lock in case we'll be running synchronous completions, and we + // do it before completing blocked/waiting readers, so that when they + // wake up they'll see the task as being completed. + if (completeTask) + { + ChannelUtilities.Complete(parent._completion, error); + } + + // At this point, _waitingReaders will not be mutated: + // it's only mutated by readers while holding the lock, and only if _doneWriting is null. + // We also know that only one thread (this one) will ever get here, as only that thread + // will be the one to transition from _doneWriting false to true. As such, we can + // freely manipulate _waitingReaders without any concurrency concerns. + ChannelUtilities.WakeUpWaiters(ref parent._waitingReaders, result: false, error: error); + + // Successfully transitioned to completed. + return true; + } + + public override bool TryWrite(T item) + { + UnboundedChannel<T> parent = _parent; + while (true) + { + ReaderInteractor<bool> waitingReaders = null; + lock (parent.SyncObj) + { + // If writing has already been marked as done, fail the write. + parent.AssertInvariants(); + if (parent._doneWriting != null) + { + return false; + } + + // Add the data to the queue, and let any waiting readers know that they should try to read it. + // We can only complete such waiters here under the lock if they run continuations asynchronously + // (otherwise the synchronous continuations could be invoked under the lock). If we don't complete + // them here, we need to do so outside of the lock. + parent._items.Enqueue(item); + waitingReaders = parent._waitingReaders; + if (waitingReaders == null) + { + return true; + } + parent._waitingReaders = null; + } + + // Wake up all of the waiters. Since we've released the lock, it's possible + // we could cause some spurious wake-ups here, if we tell a waiter there's + // something available but all data has already been removed. It's a benign + // race condition, though, as consumers already need to account for such things. + waitingReaders.Success(item: true); + return true; + } + } + + public override Task<bool> WaitToWriteAsync(CancellationToken cancellationToken) + { + Exception doneWriting = _parent._doneWriting; + return + doneWriting == null ? ChannelUtilities.s_trueTask : // unbounded writing can always be done if we haven't completed + cancellationToken.IsCancellationRequested ? Task.FromCanceled<bool>(cancellationToken) : + doneWriting != ChannelUtilities.s_doneWritingSentinel ? Task.FromException<bool>(doneWriting) : + ChannelUtilities.s_falseTask; + } + + public override Task WriteAsync(T item, CancellationToken cancellationToken) => + TryWrite(item) ? ChannelUtilities.s_trueTask : + cancellationToken.IsCancellationRequested ? Task.FromCanceled(cancellationToken) : + Task.FromException(ChannelUtilities.CreateInvalidCompletionException(_parent._doneWriting)); + } + + /// <summary>Gets the object used to synchronize access to all state on this instance.</summary> + private object SyncObj => _items; + + [Conditional("DEBUG")] + private void AssertInvariants() + { + Debug.Assert(SyncObj != null, "The sync obj must not be null."); + Debug.Assert(Monitor.IsEntered(SyncObj), "Invariants can only be validated while holding the lock."); + + if (!_items.IsEmpty) + { + if (_runContinuationsAsynchronously) + { + Debug.Assert(_waitingReaders == null, "There's data available, so there shouldn't be any waiting readers."); + } + Debug.Assert(!_completion.Task.IsCompleted, "We still have data available, so shouldn't be completed."); + } + if (_waitingReaders != null && _runContinuationsAsynchronously) + { + Debug.Assert(_items.IsEmpty, "There are blocked/waiting readers, so there shouldn't be any data available."); + } + if (_completion.Task.IsCompleted) + { + Debug.Assert(_doneWriting != null, "We're completed, so we must be done writing."); + } + } + + /// <summary>Gets the number of items in the channel. This should only be used by the debugger.</summary> + private int ItemsCountForDebugger => _items.Count; + + /// <summary>Gets an enumerator the debugger can use to show the contents of the channel.</summary> + IEnumerator<T> IDebugEnumerable<T>.GetEnumerator() => _items.GetEnumerator(); + } +} diff --git a/src/System.Threading.Channels/src/System/Threading/Channels/UnbufferedChannel.cs b/src/System.Threading.Channels/src/System/Threading/Channels/UnbufferedChannel.cs new file mode 100644 index 0000000000..ba94e330eb --- /dev/null +++ b/src/System.Threading.Channels/src/System/Threading/Channels/UnbufferedChannel.cs @@ -0,0 +1,217 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Collections.Generic; +using System.Diagnostics; +using System.Threading.Tasks; + +namespace System.Threading.Channels +{ + /// <summary>Provides an unbuffered channel, such that a reader and a writer must rendezvous to succeed.</summary> + [DebuggerDisplay("Blocked Writers: {BlockedWritersCountForDebugger}, Waiting Readers: {WaitingReadersForDebugger}")] + [DebuggerTypeProxy(typeof(UnbufferedChannel<>.DebugView))] + internal sealed class UnbufferedChannel<T> : Channel<T> + { + /// <summary>Task that represents the completion of the channel.</summary> + private readonly TaskCompletionSource<VoidResult> _completion = new TaskCompletionSource<VoidResult>(TaskCreationOptions.RunContinuationsAsynchronously); + /// <summary>A queue of writers blocked waiting to be matched with a reader.</summary> + private readonly Dequeue<WriterInteractor<T>> _blockedWriters = new Dequeue<WriterInteractor<T>>(); + + /// <summary>Task signaled when any WaitToReadAsync waiters should be woken up.</summary> + private ReaderInteractor<bool> _waitingReaders; + + private sealed class UnbufferedChannelReader : ChannelReader<T> + { + internal readonly UnbufferedChannel<T> _parent; + internal UnbufferedChannelReader(UnbufferedChannel<T> parent) => _parent = parent; + + public override Task Completion => _parent._completion.Task; + + public override bool TryRead(out T item) + { + UnbufferedChannel<T> parent = _parent; + lock (parent.SyncObj) + { + parent.AssertInvariants(); + + // Try to find a writer to pair with + while (!parent._blockedWriters.IsEmpty) + { + WriterInteractor<T> w = parent._blockedWriters.DequeueHead(); + if (w.Success(default(VoidResult))) + { + item = w.Item; + return true; + } + } + } + + // None found + item = default; + return false; + } + + public override Task<bool> WaitToReadAsync(CancellationToken cancellationToken) + { + UnbufferedChannel<T> parent = _parent; + lock (parent.SyncObj) + { + // If we're done writing, fail. + if (parent._completion.Task.IsCompleted) + { + return parent._completion.Task.IsFaulted ? + Task.FromException<bool>(parent._completion.Task.Exception.InnerException) : + ChannelUtilities.s_falseTask; + } + + // If there's a blocked writer, we can read. + if (!parent._blockedWriters.IsEmpty) + { + return ChannelUtilities.s_trueTask; + } + + // Otherwise, queue the waiter. + return ChannelUtilities.GetOrCreateWaiter(ref parent._waitingReaders, true, cancellationToken); + } + } + } + + private sealed class UnbufferedChannelWriter : ChannelWriter<T> + { + internal readonly UnbufferedChannel<T> _parent; + internal UnbufferedChannelWriter(UnbufferedChannel<T> parent) => _parent = parent; + + public override bool TryComplete(Exception error) + { + UnbufferedChannel<T> parent = _parent; + lock (parent.SyncObj) + { + parent.AssertInvariants(); + + // Mark the channel as being done. Since there's no buffered data, we can complete immediately. + if (parent._completion.Task.IsCompleted) + { + return false; + } + ChannelUtilities.Complete(parent._completion, error); + + // Fail any blocked writers, as there will be no readers to pair them with. + if (parent._blockedWriters.Count > 0) + { + ChannelUtilities.FailInteractors<WriterInteractor<T>, VoidResult>(parent._blockedWriters, ChannelUtilities.CreateInvalidCompletionException(error)); + } + + // Let any waiting readers know there won't be any more data. + ChannelUtilities.WakeUpWaiters(ref parent._waitingReaders, result: false, error: error); + } + + return true; + } + + public override bool TryWrite(T item) + { + // TryWrite on an UnbufferedChannel can never succeed, as there aren't + // any readers that are able to wait-and-read atomically + return false; + } + + public override Task<bool> WaitToWriteAsync(CancellationToken cancellationToken) + { + UnbufferedChannel<T> parent = _parent; + + // If we're done writing, fail. + if (parent._completion.Task.IsCompleted) + { + return parent._completion.Task.IsFaulted ? + Task.FromException<bool>(parent._completion.Task.Exception.InnerException) : + ChannelUtilities.s_falseTask; + } + + // Otherwise, just return a task suggesting a write be attempted. + // Since there's no "ReadAsync", there's nothing to wait for. + return ChannelUtilities.s_trueTask; + } + + public override Task WriteAsync(T item, CancellationToken cancellationToken) + { + if (cancellationToken.IsCancellationRequested) + { + return Task.FromCanceled(cancellationToken); + } + + UnbufferedChannel<T> parent = _parent; + lock (parent.SyncObj) + { + // Fail if we've already completed. + if (parent._completion.Task.IsCompleted) + { + return + parent._completion.Task.IsCanceled ? Task.FromCanceled<T>(new CancellationToken(true)) : + Task.FromException<T>( + parent._completion.Task.IsFaulted ? + ChannelUtilities.CreateInvalidCompletionException(parent._completion.Task.Exception.InnerException) : + ChannelUtilities.CreateInvalidCompletionException()); + } + + // Queue the writer. + var w = WriterInteractor<T>.Create(true, cancellationToken, item); + parent._blockedWriters.EnqueueTail(w); + + // And let any waiting readers know it's their lucky day. + ChannelUtilities.WakeUpWaiters(ref parent._waitingReaders, result: true); + + return w.Task; + } + } + } + + /// <summary>Initialize the channel.</summary> + internal UnbufferedChannel() + { + base.Reader = new UnbufferedChannelReader(this); + Writer = new UnbufferedChannelWriter(this); + } + + /// <summary>Gets an object used to synchronize all state on the instance.</summary> + private object SyncObj => _completion; + + [Conditional("DEBUG")] + private void AssertInvariants() + { + Debug.Assert(SyncObj != null, "The sync obj must not be null."); + Debug.Assert(Monitor.IsEntered(SyncObj), "Invariants can only be validated while holding the lock."); + + if (_completion.Task.IsCompleted) + { + Debug.Assert(_blockedWriters.IsEmpty, "No writers can be blocked after we've completed."); + } + } + + /// <summary>Gets whether there are any waiting readers. This should only be used by the debugger.</summary> + private bool WaitingReadersForDebugger => _waitingReaders != null; + /// <summary>Gets the number of blocked writers. This should only be used by the debugger.</summary> + private int BlockedWritersCountForDebugger => _blockedWriters.Count; + + private sealed class DebugView + { + private readonly UnbufferedChannel<T> _channel; + + public DebugView(UnbufferedChannel<T> channel) => _channel = channel; + + public bool WaitingReaders => _channel._waitingReaders != null; + public T[] BlockedWriters + { + get + { + var items = new List<T>(); + foreach (WriterInteractor<T> blockedWriter in _channel._blockedWriters) + { + items.Add(blockedWriter.Item); + } + return items.ToArray(); + } + } + } + } +} diff --git a/src/System.Threading.Channels/src/System/VoidResult.cs b/src/System.Threading.Channels/src/System/VoidResult.cs new file mode 100644 index 0000000000..43e8f44d98 --- /dev/null +++ b/src/System.Threading.Channels/src/System/VoidResult.cs @@ -0,0 +1,9 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +namespace System +{ + /// <summary>An empty struct, used to represent void in generic types.</summary> + internal struct VoidResult { } +} diff --git a/src/System.Threading.Channels/tests/BoundedChannelTests.cs b/src/System.Threading.Channels/tests/BoundedChannelTests.cs new file mode 100644 index 0000000000..b19bc8605a --- /dev/null +++ b/src/System.Threading.Channels/tests/BoundedChannelTests.cs @@ -0,0 +1,402 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System.Threading.Tasks; +using Xunit; + +namespace System.Threading.Channels.Tests +{ + public class BoundedChannelTests : ChannelTestBase + { + protected override Channel<int> CreateChannel() => Channel.CreateBounded<int>(1); + protected override Channel<int> CreateFullChannel() + { + var c = Channel.CreateBounded<int>(1); + c.Writer.WriteAsync(42).Wait(); + return c; + } + + [Theory] + [InlineData(1)] + [InlineData(10)] + [InlineData(10000)] + public void TryWrite_TryRead_Many_Wait(int bufferedCapacity) + { + var c = Channel.CreateBounded<int>(bufferedCapacity); + + for (int i = 0; i < bufferedCapacity; i++) + { + Assert.True(c.Writer.TryWrite(i)); + } + Assert.False(c.Writer.TryWrite(bufferedCapacity)); + + int result; + for (int i = 0; i < bufferedCapacity; i++) + { + Assert.True(c.Reader.TryRead(out result)); + Assert.Equal(i, result); + } + + Assert.False(c.Reader.TryRead(out result)); + Assert.Equal(0, result); + } + + [Theory] + [InlineData(1)] + [InlineData(10)] + [InlineData(10000)] + public void TryWrite_TryRead_Many_DropOldest(int bufferedCapacity) + { + var c = Channel.CreateBounded<int>(new BoundedChannelOptions(bufferedCapacity) { FullMode = BoundedChannelFullMode.DropOldest }); + + for (int i = 0; i < bufferedCapacity * 2; i++) + { + Assert.True(c.Writer.TryWrite(i)); + } + + int result; + for (int i = bufferedCapacity; i < bufferedCapacity * 2; i++) + { + Assert.True(c.Reader.TryRead(out result)); + Assert.Equal(i, result); + } + + Assert.False(c.Reader.TryRead(out result)); + Assert.Equal(0, result); + } + + [Theory] + [InlineData(1)] + [InlineData(10)] + [InlineData(10000)] + public void WriteAsync_TryRead_Many_DropOldest(int bufferedCapacity) + { + var c = Channel.CreateBounded<int>(new BoundedChannelOptions(bufferedCapacity) { FullMode = BoundedChannelFullMode.DropOldest }); + + for (int i = 0; i < bufferedCapacity * 2; i++) + { + AssertSynchronousSuccess(c.Writer.WriteAsync(i)); + } + + int result; + for (int i = bufferedCapacity; i < bufferedCapacity * 2; i++) + { + Assert.True(c.Reader.TryRead(out result)); + Assert.Equal(i, result); + } + + Assert.False(c.Reader.TryRead(out result)); + Assert.Equal(0, result); + } + + [Theory] + [InlineData(1)] + [InlineData(10)] + [InlineData(10000)] + public void TryWrite_TryRead_Many_DropNewest(int bufferedCapacity) + { + var c = Channel.CreateBounded<int>(new BoundedChannelOptions(bufferedCapacity) { FullMode = BoundedChannelFullMode.DropNewest }); + + for (int i = 0; i < bufferedCapacity * 2; i++) + { + Assert.True(c.Writer.TryWrite(i)); + } + + int result; + for (int i = 0; i < bufferedCapacity - 1; i++) + { + Assert.True(c.Reader.TryRead(out result)); + Assert.Equal(i, result); + } + Assert.True(c.Reader.TryRead(out result)); + Assert.Equal(bufferedCapacity * 2 - 1, result); + + Assert.False(c.Reader.TryRead(out result)); + Assert.Equal(0, result); + } + + [Theory] + [InlineData(1)] + [InlineData(10)] + [InlineData(10000)] + public void WriteAsync_TryRead_Many_DropNewest(int bufferedCapacity) + { + var c = Channel.CreateBounded<int>(new BoundedChannelOptions(bufferedCapacity) { FullMode = BoundedChannelFullMode.DropNewest }); + + for (int i = 0; i < bufferedCapacity * 2; i++) + { + AssertSynchronousSuccess(c.Writer.WriteAsync(i)); + } + + int result; + for (int i = 0; i < bufferedCapacity - 1; i++) + { + Assert.True(c.Reader.TryRead(out result)); + Assert.Equal(i, result); + } + Assert.True(c.Reader.TryRead(out result)); + Assert.Equal(bufferedCapacity * 2 - 1, result); + + Assert.False(c.Reader.TryRead(out result)); + Assert.Equal(0, result); + } + + [Fact] + public async Task TryWrite_DropNewest_WrappedAroundInternalQueue() + { + var c = Channel.CreateBounded<int>(new BoundedChannelOptions(3) { FullMode = BoundedChannelFullMode.DropNewest }); + + // Move head of dequeue beyond the beginning + Assert.True(c.Writer.TryWrite(1)); + Assert.True(c.Reader.TryRead(out int item)); + Assert.Equal(1, item); + + // Add items to fill the capacity and put the tail at 0 + Assert.True(c.Writer.TryWrite(2)); + Assert.True(c.Writer.TryWrite(3)); + Assert.True(c.Writer.TryWrite(4)); + + // Add an item to overwrite the newest + Assert.True(c.Writer.TryWrite(5)); + + // Verify current contents + Assert.Equal(2, await c.Reader.ReadAsync()); + Assert.Equal(3, await c.Reader.ReadAsync()); + Assert.Equal(5, await c.Reader.ReadAsync()); + } + + [Theory] + [InlineData(1)] + [InlineData(10)] + [InlineData(10000)] + public void TryWrite_TryRead_Many_Ignore(int bufferedCapacity) + { + var c = Channel.CreateBounded<int>(new BoundedChannelOptions(bufferedCapacity) { FullMode = BoundedChannelFullMode.DropWrite }); + + for (int i = 0; i < bufferedCapacity * 2; i++) + { + Assert.True(c.Writer.TryWrite(i)); + } + + int result; + for (int i = 0; i < bufferedCapacity; i++) + { + Assert.True(c.Reader.TryRead(out result)); + Assert.Equal(i, result); + } + + Assert.False(c.Reader.TryRead(out result)); + Assert.Equal(0, result); + } + + [Theory] + [InlineData(1)] + [InlineData(10)] + [InlineData(10000)] + public void WriteAsync_TryRead_Many_Ignore(int bufferedCapacity) + { + var c = Channel.CreateBounded<int>(new BoundedChannelOptions(bufferedCapacity) { FullMode = BoundedChannelFullMode.DropWrite }); + + for (int i = 0; i < bufferedCapacity * 2; i++) + { + AssertSynchronousSuccess(c.Writer.WriteAsync(i)); + } + + int result; + for (int i = 0; i < bufferedCapacity; i++) + { + Assert.True(c.Reader.TryRead(out result)); + Assert.Equal(i, result); + } + + Assert.False(c.Reader.TryRead(out result)); + Assert.Equal(0, result); + } + + [Fact] + public async Task CancelPendingWrite_Reading_DataTransferredFromCorrectWriter() + { + var c = Channel.CreateBounded<int>(1); + Assert.Equal(TaskStatus.RanToCompletion, c.Writer.WriteAsync(42).Status); + + var cts = new CancellationTokenSource(); + + Task write1 = c.Writer.WriteAsync(43, cts.Token); + Assert.Equal(TaskStatus.WaitingForActivation, write1.Status); + + cts.Cancel(); + + Task write2 = c.Writer.WriteAsync(44); + + Assert.Equal(42, await c.Reader.ReadAsync()); + Assert.Equal(44, await c.Reader.ReadAsync()); + + await AssertCanceled(write1, cts.Token); + await write2; + } + + [Theory] + [InlineData(1)] + [InlineData(10)] + [InlineData(10000)] + public void TryWrite_TryRead_OneAtATime(int bufferedCapacity) + { + var c = Channel.CreateBounded<int>(bufferedCapacity); + + const int NumItems = 100000; + for (int i = 0; i < NumItems; i++) + { + Assert.True(c.Writer.TryWrite(i)); + Assert.True(c.Reader.TryRead(out int result)); + Assert.Equal(i, result); + } + } + + [Theory] + [InlineData(1)] + [InlineData(10)] + [InlineData(10000)] + public void SingleProducerConsumer_ConcurrentReadWrite_WithBufferedCapacity_Success(int bufferedCapacity) + { + var c = Channel.CreateBounded<int>(bufferedCapacity); + + const int NumItems = 10000; + Task.WaitAll( + Task.Run(async () => + { + for (int i = 0; i < NumItems; i++) + { + await c.Writer.WriteAsync(i); + } + }), + Task.Run(async () => + { + for (int i = 0; i < NumItems; i++) + { + Assert.Equal(i, await c.Reader.ReadAsync()); + } + })); + } + + [Theory] + [InlineData(1)] + [InlineData(10)] + [InlineData(10000)] + public void ManyProducerConsumer_ConcurrentReadWrite_WithBufferedCapacity_Success(int bufferedCapacity) + { + var c = Channel.CreateBounded<int>(bufferedCapacity); + + const int NumWriters = 10; + const int NumReaders = 10; + const int NumItems = 10000; + + long readTotal = 0; + int remainingWriters = NumWriters; + int remainingItems = NumItems; + + Task[] tasks = new Task[NumWriters + NumReaders]; + + for (int i = 0; i < NumReaders; i++) + { + tasks[i] = Task.Run(async () => + { + try + { + while (true) + { + Interlocked.Add(ref readTotal, await c.Reader.ReadAsync()); + } + } + catch (ChannelClosedException) { } + }); + } + + for (int i = 0; i < NumWriters; i++) + { + tasks[NumReaders + i] = Task.Run(async () => + { + while (true) + { + int value = Interlocked.Decrement(ref remainingItems); + if (value < 0) + { + break; + } + await c.Writer.WriteAsync(value + 1); + } + if (Interlocked.Decrement(ref remainingWriters) == 0) + { + c.Writer.Complete(); + } + }); + } + + Task.WaitAll(tasks); + Assert.Equal((NumItems * (NumItems + 1L)) / 2, readTotal); + } + + [Fact] + public async Task WaitToWriteAsync_AfterFullThenRead_ReturnsTrue() + { + var c = Channel.CreateBounded<int>(1); + Assert.True(c.Writer.TryWrite(1)); + + Task<bool> write1 = c.Writer.WaitToWriteAsync(); + Assert.False(write1.IsCompleted); + + Task<bool> write2 = c.Writer.WaitToWriteAsync(); + Assert.False(write2.IsCompleted); + + Assert.Equal(1, await c.Reader.ReadAsync()); + + Assert.True(await write1); + Assert.True(await write2); + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public void AllowSynchronousContinuations_WaitToReadAsync_ContinuationsInvokedAccordingToSetting(bool allowSynchronousContinuations) + { + var c = Channel.CreateBounded<int>(new BoundedChannelOptions(1) { AllowSynchronousContinuations = allowSynchronousContinuations }); + + int expectedId = Environment.CurrentManagedThreadId; + Task r = c.Reader.WaitToReadAsync().ContinueWith(_ => + { + Assert.Equal(allowSynchronousContinuations, expectedId == Environment.CurrentManagedThreadId); + }, CancellationToken.None, TaskContinuationOptions.ExecuteSynchronously, TaskScheduler.Default); + + Assert.Equal(TaskStatus.RanToCompletion, c.Writer.WriteAsync(42).Status); + ((IAsyncResult)r).AsyncWaitHandle.WaitOne(); // avoid inlining the continuation + r.GetAwaiter().GetResult(); + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public void AllowSynchronousContinuations_CompletionTask_ContinuationsInvokedAccordingToSetting(bool allowSynchronousContinuations) + { + var c = Channel.CreateBounded<int>(new BoundedChannelOptions(1) { AllowSynchronousContinuations = allowSynchronousContinuations }); + + int expectedId = Environment.CurrentManagedThreadId; + Task r = c.Reader.Completion.ContinueWith(_ => + { + Assert.Equal(allowSynchronousContinuations, expectedId == Environment.CurrentManagedThreadId); + }, CancellationToken.None, TaskContinuationOptions.ExecuteSynchronously, TaskScheduler.Default); + + Assert.True(c.Writer.TryComplete()); + ((IAsyncResult)r).AsyncWaitHandle.WaitOne(); // avoid inlining the continuation + r.GetAwaiter().GetResult(); + } + + [Fact] + public void TryWrite_NoBlockedReaders_WaitingReader_WaiterNotifified() + { + Channel<int> c = CreateChannel(); + + Task<bool> r = c.Reader.WaitToReadAsync(); + Assert.True(c.Writer.TryWrite(42)); + AssertSynchronousTrue(r); + } + } +} diff --git a/src/System.Threading.Channels/tests/ChannelClosedExceptionTests.cs b/src/System.Threading.Channels/tests/ChannelClosedExceptionTests.cs new file mode 100644 index 0000000000..38a1e9bb50 --- /dev/null +++ b/src/System.Threading.Channels/tests/ChannelClosedExceptionTests.cs @@ -0,0 +1,27 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using Xunit; + +namespace System.Threading.Channels.Tests +{ + public class ChannelClosedExceptionTests + { + [Fact] + public void Ctors() + { + var e = new ChannelClosedException(); + Assert.NotEmpty(e.Message); + Assert.Null(e.InnerException); + + e = new ChannelClosedException("hello"); + Assert.Equal("hello", e.Message); + Assert.Null(e.InnerException); + + var inner = new FormatException(); + e = new ChannelClosedException("hello", inner); + Assert.Equal("hello", e.Message); + Assert.Same(inner, e.InnerException); + } + } +} diff --git a/src/System.Threading.Channels/tests/ChannelTestBase.cs b/src/System.Threading.Channels/tests/ChannelTestBase.cs new file mode 100644 index 0000000000..91b2d71408 --- /dev/null +++ b/src/System.Threading.Channels/tests/ChannelTestBase.cs @@ -0,0 +1,457 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System.Diagnostics; +using System.Threading.Tasks; +using Xunit; + +namespace System.Threading.Channels.Tests +{ + public abstract class ChannelTestBase : TestBase + { + protected abstract Channel<int> CreateChannel(); + protected abstract Channel<int> CreateFullChannel(); + + protected virtual bool RequiresSingleReader => false; + protected virtual bool RequiresSingleWriter => false; + + [Fact] + public void ValidateDebuggerAttributes() + { + Channel<int> c = CreateChannel(); + for (int i = 1; i <= 10; i++) + { + c.Writer.WriteAsync(i); + } + DebuggerAttributes.ValidateDebuggerDisplayReferences(c); + DebuggerAttributes.ValidateDebuggerTypeProxyProperties(c); + } + + [Fact] + public void Cast_MatchesInOut() + { + Channel<int> c = CreateChannel(); + ChannelReader<int> rc = c; + ChannelWriter<int> wc = c; + Assert.Same(rc, c.Reader); + Assert.Same(wc, c.Writer); + } + + [Fact] + public void Completion_Idempotent() + { + Channel<int> c = CreateChannel(); + + Task completion = c.Reader.Completion; + Assert.Equal(TaskStatus.WaitingForActivation, completion.Status); + + Assert.Same(completion, c.Reader.Completion); + c.Writer.Complete(); + Assert.Same(completion, c.Reader.Completion); + + Assert.Equal(TaskStatus.RanToCompletion, completion.Status); + } + + [Fact] + public async Task Complete_AfterEmpty_NoWaiters_TriggersCompletion() + { + Channel<int> c = CreateChannel(); + c.Writer.Complete(); + await c.Reader.Completion; + } + + [Fact] + public async Task Complete_AfterEmpty_WaitingReader_TriggersCompletion() + { + Channel<int> c = CreateChannel(); + Task<int> r = c.Reader.ReadAsync().AsTask(); + c.Writer.Complete(); + await c.Reader.Completion; + await Assert.ThrowsAnyAsync<InvalidOperationException>(() => r); + } + + [Fact] + public async Task Complete_BeforeEmpty_WaitingReaders_TriggersCompletion() + { + Channel<int> c = CreateChannel(); + Task<int> read = c.Reader.ReadAsync().AsTask(); + c.Writer.Complete(); + await c.Reader.Completion; + await Assert.ThrowsAnyAsync<InvalidOperationException>(() => read); + } + + [Fact] + public void Complete_Twice_ThrowsInvalidOperationException() + { + Channel<int> c = CreateChannel(); + c.Writer.Complete(); + Assert.ThrowsAny<InvalidOperationException>(() => c.Writer.Complete()); + } + + [Fact] + public void TryComplete_Twice_ReturnsTrueThenFalse() + { + Channel<int> c = CreateChannel(); + Assert.True(c.Writer.TryComplete()); + Assert.False(c.Writer.TryComplete()); + Assert.False(c.Writer.TryComplete()); + } + + [Fact] + public async Task TryComplete_ErrorsPropage() + { + Channel<int> c; + + // Success + c = CreateChannel(); + Assert.True(c.Writer.TryComplete()); + await c.Reader.Completion; + + // Error + c = CreateChannel(); + Assert.True(c.Writer.TryComplete(new FormatException())); + await Assert.ThrowsAsync<FormatException>(() => c.Reader.Completion); + + // Canceled + c = CreateChannel(); + var cts = new CancellationTokenSource(); + cts.Cancel(); + Assert.True(c.Writer.TryComplete(new OperationCanceledException(cts.Token))); + await AssertCanceled(c.Reader.Completion, cts.Token); + } + + [Fact] + public void SingleProducerConsumer_ConcurrentReadWrite_Success() + { + Channel<int> c = CreateChannel(); + + const int NumItems = 100000; + Task.WaitAll( + Task.Run(async () => + { + for (int i = 0; i < NumItems; i++) + { + await c.Writer.WriteAsync(i); + } + }), + Task.Run(async () => + { + for (int i = 0; i < NumItems; i++) + { + Assert.Equal(i, await c.Reader.ReadAsync()); + } + })); + } + + [Fact] + public void SingleProducerConsumer_PingPong_Success() + { + Channel<int> c1 = CreateChannel(); + Channel<int> c2 = CreateChannel(); + + const int NumItems = 100000; + Task.WaitAll( + Task.Run(async () => + { + for (int i = 0; i < NumItems; i++) + { + Assert.Equal(i, await c1.Reader.ReadAsync()); + await c2.Writer.WriteAsync(i); + } + }), + Task.Run(async () => + { + for (int i = 0; i < NumItems; i++) + { + await c1.Writer.WriteAsync(i); + Assert.Equal(i, await c2.Reader.ReadAsync()); + } + })); + } + + [Theory] + [InlineData(1, 1)] + [InlineData(1, 10)] + [InlineData(10, 1)] + [InlineData(10, 10)] + public void ManyProducerConsumer_ConcurrentReadWrite_Success(int numReaders, int numWriters) + { + if (RequiresSingleReader && numReaders > 1) + { + return; + } + + if (RequiresSingleWriter && numWriters > 1) + { + return; + } + + Channel<int> c = CreateChannel(); + + const int NumItems = 10000; + + long readTotal = 0; + int remainingWriters = numWriters; + int remainingItems = NumItems; + + Task[] tasks = new Task[numWriters + numReaders]; + + for (int i = 0; i < numReaders; i++) + { + tasks[i] = Task.Run(async () => + { + try + { + while (await c.Reader.WaitToReadAsync()) + { + if (c.Reader.TryRead(out int value)) + { + Interlocked.Add(ref readTotal, value); + } + } + } + catch (ChannelClosedException) { } + }); + } + + for (int i = 0; i < numWriters; i++) + { + tasks[numReaders + i] = Task.Run(async () => + { + while (true) + { + int value = Interlocked.Decrement(ref remainingItems); + if (value < 0) + { + break; + } + await c.Writer.WriteAsync(value + 1); + } + if (Interlocked.Decrement(ref remainingWriters) == 0) + { + c.Writer.Complete(); + } + }); + } + + Task.WaitAll(tasks); + Assert.Equal((NumItems * (NumItems + 1L)) / 2, readTotal); + } + + [Fact] + public void WaitToReadAsync_DataAvailableBefore_CompletesSynchronously() + { + Channel<int> c = CreateChannel(); + Task write = c.Writer.WriteAsync(42); + Task<bool> read = c.Reader.WaitToReadAsync(); + Assert.Equal(TaskStatus.RanToCompletion, read.Status); + } + + [Fact] + public void WaitToReadAsync_DataAvailableAfter_CompletesAsynchronously() + { + Channel<int> c = CreateChannel(); + Task<bool> read = c.Reader.WaitToReadAsync(); + Assert.False(read.IsCompleted); + Task write = c.Writer.WriteAsync(42); + Assert.True(read.Result); + } + + [Fact] + public void WaitToReadAsync_AfterComplete_SynchronouslyCompletes() + { + Channel<int> c = CreateChannel(); + c.Writer.Complete(); + Task<bool> read = c.Reader.WaitToReadAsync(); + Assert.Equal(TaskStatus.RanToCompletion, read.Status); + Assert.False(read.Result); + } + + [Fact] + public void WaitToReadAsync_BeforeComplete_AsynchronouslyCompletes() + { + Channel<int> c = CreateChannel(); + Task<bool> read = c.Reader.WaitToReadAsync(); + Assert.False(read.IsCompleted); + c.Writer.Complete(); + Assert.False(read.Result); + } + + [Fact] + public void WaitToWriteAsync_AfterComplete_SynchronouslyCompletes() + { + Channel<int> c = CreateChannel(); + c.Writer.Complete(); + Task<bool> write = c.Writer.WaitToWriteAsync(); + Assert.Equal(TaskStatus.RanToCompletion, write.Status); + Assert.False(write.Result); + } + + [Fact] + public void TryRead_DataAvailable_Success() + { + Channel<int> c = CreateChannel(); + Task write = c.Writer.WriteAsync(42); + Assert.True(c.Reader.TryRead(out int result)); + Assert.Equal(42, result); + } + + [Fact] + public void TryRead_AfterComplete_ReturnsFalse() + { + Channel<int> c = CreateChannel(); + c.Writer.Complete(); + Assert.False(c.Reader.TryRead(out int result)); + } + + [Fact] + public void TryWrite_AfterComplete_ReturnsFalse() + { + Channel<int> c = CreateChannel(); + c.Writer.Complete(); + Assert.False(c.Writer.TryWrite(42)); + } + + [Fact] + public async Task WriteAsync_AfterComplete_ThrowsException() + { + Channel<int> c = CreateChannel(); + c.Writer.Complete(); + await Assert.ThrowsAnyAsync<InvalidOperationException>(() => c.Writer.WriteAsync(42)); + } + + [Fact] + public async Task Complete_WithException_PropagatesToCompletion() + { + Channel<int> c = CreateChannel(); + var exc = new FormatException(); + c.Writer.Complete(exc); + Assert.Same(exc, await Assert.ThrowsAsync<FormatException>(() => c.Reader.Completion)); + } + + [Fact] + public async Task Complete_WithCancellationException_PropagatesToCompletion() + { + Channel<int> c = CreateChannel(); + var cts = new CancellationTokenSource(); + cts.Cancel(); + + Exception exc = null; + try { cts.Token.ThrowIfCancellationRequested(); } + catch (Exception e) { exc = e; } + + c.Writer.Complete(exc); + await AssertCanceled(c.Reader.Completion, cts.Token); + } + + [Fact] + public async Task Complete_WithException_PropagatesToExistingWriter() + { + Channel<int> c = CreateFullChannel(); + if (c != null) + { + Task write = c.Writer.WriteAsync(42); + var exc = new FormatException(); + c.Writer.Complete(exc); + Assert.Same(exc, (await Assert.ThrowsAsync<ChannelClosedException>(() => write)).InnerException); + } + } + + [Fact] + public async Task Complete_WithException_PropagatesToNewWriter() + { + Channel<int> c = CreateChannel(); + var exc = new FormatException(); + c.Writer.Complete(exc); + Task write = c.Writer.WriteAsync(42); + Assert.Same(exc, (await Assert.ThrowsAsync<ChannelClosedException>(() => write)).InnerException); + } + + [Fact] + public async Task Complete_WithException_PropagatesToExistingWaitingReader() + { + Channel<int> c = CreateChannel(); + Task<bool> read = c.Reader.WaitToReadAsync(); + var exc = new FormatException(); + c.Writer.Complete(exc); + await Assert.ThrowsAsync<FormatException>(() => read); + } + + [Fact] + public async Task Complete_WithException_PropagatesToNewWaitingReader() + { + Channel<int> c = CreateChannel(); + var exc = new FormatException(); + c.Writer.Complete(exc); + Task<bool> read = c.Reader.WaitToReadAsync(); + await Assert.ThrowsAsync<FormatException>(() => read); + } + + [Fact] + public async Task Complete_WithException_PropagatesToNewWaitingWriter() + { + Channel<int> c = CreateChannel(); + var exc = new FormatException(); + c.Writer.Complete(exc); + Task<bool> write = c.Writer.WaitToWriteAsync(); + await Assert.ThrowsAsync<FormatException>(() => write); + } + + [Theory] + [InlineData(0)] + [InlineData(1)] + public void ManyWriteAsync_ThenManyTryRead_Success(int readMode) + { + if (RequiresSingleReader || RequiresSingleWriter) + { + return; + } + + Channel<int> c = CreateChannel(); + + const int NumItems = 2000; + + Task[] writers = new Task[NumItems]; + for (int i = 0; i < writers.Length; i++) + { + writers[i] = c.Writer.WriteAsync(i); + } + + Task<int>[] readers = new Task<int>[NumItems]; + for (int i = 0; i < readers.Length; i++) + { + int result; + Assert.True(c.Reader.TryRead(out result)); + Assert.Equal(i, result); + } + + Assert.All(writers, w => Assert.True(w.IsCompleted)); + } + + [Fact] + public void Precancellation_Writing_ReturnsSuccessImmediately() + { + Channel<int> c = CreateChannel(); + var cts = new CancellationTokenSource(); + cts.Cancel(); + + Task writeTask = c.Writer.WriteAsync(42, cts.Token); + Assert.True(writeTask.Status == TaskStatus.Canceled || writeTask.Status == TaskStatus.RanToCompletion, $"Status == {writeTask.Status}"); + + Task<bool> waitTask = c.Writer.WaitToWriteAsync(cts.Token); + Assert.True(writeTask.Status == TaskStatus.Canceled || writeTask.Status == TaskStatus.RanToCompletion, $"Status == {writeTask.Status}"); + if (waitTask.Status == TaskStatus.RanToCompletion) + { + Assert.True(waitTask.Result); + } + } + + [Fact] + public void Write_WaitToReadAsync_CompletesSynchronously() + { + Channel<int> c = CreateChannel(); + c.Writer.WriteAsync(42); + AssertSynchronousTrue(c.Reader.WaitToReadAsync()); + } + } +} diff --git a/src/System.Threading.Channels/tests/ChannelTests.cs b/src/System.Threading.Channels/tests/ChannelTests.cs new file mode 100644 index 0000000000..ffaf46ea9f --- /dev/null +++ b/src/System.Threading.Channels/tests/ChannelTests.cs @@ -0,0 +1,145 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Collections.Generic; +using System.IO; +using System.Threading.Tasks; +using Xunit; + +namespace System.Threading.Channels.Tests +{ + public class ChannelTests + { + [Fact] + public void ChannelOptimizations_Properties_Roundtrip() + { + var co = new UnboundedChannelOptions(); + + Assert.False(co.SingleReader); + Assert.False(co.SingleWriter); + + co.SingleReader = true; + Assert.True(co.SingleReader); + Assert.False(co.SingleWriter); + co.SingleReader = false; + Assert.False(co.SingleReader); + + co.SingleWriter = true; + Assert.False(co.SingleReader); + Assert.True(co.SingleWriter); + co.SingleWriter = false; + Assert.False(co.SingleWriter); + + co.SingleReader = true; + co.SingleWriter = true; + Assert.True(co.SingleReader); + Assert.True(co.SingleWriter); + + Assert.False(co.AllowSynchronousContinuations); + co.AllowSynchronousContinuations = true; + Assert.True(co.AllowSynchronousContinuations); + co.AllowSynchronousContinuations = false; + Assert.False(co.AllowSynchronousContinuations); + } + + [Theory] + [InlineData(0)] + [InlineData(-2)] + public void CreateBounded_InvalidBufferSizes_ThrowArgumentExceptions(int capacity) + { + Assert.Throws<ArgumentOutOfRangeException>("capacity", () => Channel.CreateBounded<int>(capacity)); + Assert.Throws<ArgumentOutOfRangeException>("capacity", () => new BoundedChannelOptions(capacity)); + } + + [Theory] + [InlineData((BoundedChannelFullMode)(-1))] + [InlineData((BoundedChannelFullMode)(4))] + public void BoundedChannelOptions_InvalidModes_ThrowArgumentExceptions(BoundedChannelFullMode mode) => + Assert.Throws<ArgumentOutOfRangeException>("value", () => new BoundedChannelOptions(1) { FullMode = mode }); + + [Theory] + [InlineData(1)] + public void CreateBounded_ValidBufferSizes_Success(int bufferedCapacity) => + Assert.NotNull(Channel.CreateBounded<int>(bufferedCapacity)); + + [Fact] + public async Task DefaultWriteAsync_UsesWaitToWriteAsyncAndTryWrite() + { + var c = new TestChannelWriter<int>(10); + Assert.False(c.TryComplete()); + Assert.Equal(TaskStatus.Canceled, c.WriteAsync(42, new CancellationToken(true)).Status); + + int count = 0; + try + { + while (true) + { + await c.WriteAsync(count++); + } + } + catch (ChannelClosedException) { } + Assert.Equal(11, count); + } + + private sealed class TestChannelWriter<T> : ChannelWriter<T> + { + private readonly Random _rand = new Random(42); + private readonly int _max; + private int _count; + + public TestChannelWriter(int max) => _max = max; + + public override bool TryWrite(T item) => _rand.Next(0, 2) == 0 && _count++ < _max; // succeed if we're under our limit, and add random failures + + public override Task<bool> WaitToWriteAsync(CancellationToken cancellationToken) => + _count >= _max ? Task.FromResult(false) : + _rand.Next(0, 2) == 0 ? Task.Delay(1).ContinueWith(_ => true) : // randomly introduce delays + Task.FromResult(true); + } + + private sealed class TestChannelReader<T> : ChannelReader<T> + { + private Random _rand = new Random(42); + private IEnumerator<T> _enumerator; + private int _count; + private bool _closed; + + public TestChannelReader(IEnumerable<T> enumerable) => _enumerator = enumerable.GetEnumerator(); + + public override bool TryRead(out T item) + { + // Randomly fail to read + if (_rand.Next(0, 2) == 0) + { + item = default; + return false; + } + + // If the enumerable is closed, fail the read. + if (!_enumerator.MoveNext()) + { + _enumerator.Dispose(); + _closed = true; + item = default; + return false; + } + + // Otherwise return the next item. + _count++; + item = _enumerator.Current; + return true; + } + + public override Task<bool> WaitToReadAsync(CancellationToken cancellationToken) => + _closed ? Task.FromResult(false) : + _rand.Next(0, 2) == 0 ? Task.Delay(1).ContinueWith(_ => true) : // randomly introduce delays + Task.FromResult(true); + } + + private sealed class CanReadFalseStream : MemoryStream + { + public override bool CanRead => false; + } + } +} diff --git a/src/System.Threading.Channels/tests/Configurations.props b/src/System.Threading.Channels/tests/Configurations.props new file mode 100644 index 0000000000..78953dfc88 --- /dev/null +++ b/src/System.Threading.Channels/tests/Configurations.props @@ -0,0 +1,8 @@ +<?xml version="1.0" encoding="utf-8"?> +<Project ToolsVersion="14.0" DefaultTargets="Build" xmlns="http://schemas.microsoft.com/developer/msbuild/2003"> + <PropertyGroup> + <BuildConfigurations> + netstandard; + </BuildConfigurations> + </PropertyGroup> +</Project> diff --git a/src/System.Threading.Channels/tests/DebuggerAttributes.cs b/src/System.Threading.Channels/tests/DebuggerAttributes.cs new file mode 100644 index 0000000000..02e12925a7 --- /dev/null +++ b/src/System.Threading.Channels/tests/DebuggerAttributes.cs @@ -0,0 +1,145 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Collections.Generic; +using System.Linq; +using System.Reflection; + +namespace System.Diagnostics +{ + internal static class DebuggerAttributes + { + internal static object GetFieldValue(object obj, string fieldName) => GetField(obj, fieldName).GetValue(obj); + + internal static void ValidateDebuggerTypeProxyProperties(object obj) + { + // Get the DebuggerTypeProxyAttibute for obj + CustomAttributeData[] attrs = + obj.GetType().GetTypeInfo().CustomAttributes + .Where(a => a.AttributeType == typeof(DebuggerTypeProxyAttribute)) + .ToArray(); + if (attrs.Length != 1) + { + throw new InvalidOperationException( + string.Format("Expected one DebuggerTypeProxyAttribute on {0}.", obj)); + } + CustomAttributeData cad = attrs[0]; + + // Get the proxy type. As written, this only works if the proxy and the target type + // have the same generic parameters, e.g. Dictionary<TKey,TValue> and Proxy<TKey,TValue>. + // It will not work with, for example, Dictionary<TKey,TValue>.Keys and Proxy<TKey>, + // as the former has two generic parameters and the latter only one. + Type proxyType = cad.ConstructorArguments[0].ArgumentType == typeof(Type) ? + (Type)cad.ConstructorArguments[0].Value : + Type.GetType((string)cad.ConstructorArguments[0].Value); + Type[] genericArguments = obj.GetType().GenericTypeArguments; + if (genericArguments.Length > 0) + { + proxyType = proxyType.MakeGenericType(genericArguments); + } + + // Create an instance of the proxy type, and make sure we can access all of the instance properties + // on the type without exception + object proxyInstance = Activator.CreateInstance(proxyType, obj); + foreach (PropertyInfo pi in proxyInstance.GetType().GetTypeInfo().DeclaredProperties) + { + pi.GetValue(proxyInstance, null); + } + } + + internal static void ValidateDebuggerDisplayReferences(object obj) + { + // Get the DebuggerDisplayAttribute for obj + CustomAttributeData[] attrs = + obj.GetType().GetTypeInfo().CustomAttributes + .Where(a => a.AttributeType == typeof(DebuggerDisplayAttribute)) + .ToArray(); + if (attrs.Length != 1) + { + throw new InvalidOperationException( + string.Format("Expected one DebuggerDisplayAttribute on {0}.", obj)); + } + CustomAttributeData cad = attrs[0]; + + // Get the text of the DebuggerDisplayAttribute + string attrText = (string)cad.ConstructorArguments[0].Value; + + // Parse the text for all expressions + var references = new List<string>(); + int pos = 0; + while (true) + { + int openBrace = attrText.IndexOf('{', pos); + if (openBrace < pos) + { + break; + } + + int closeBrace = attrText.IndexOf('}', openBrace); + if (closeBrace < openBrace) + { + break; + } + + string reference = attrText.Substring(openBrace + 1, closeBrace - openBrace - 1).Replace(",nq", ""); + pos = closeBrace + 1; + + references.Add(reference); + } + if (references.Count == 0) + { + throw new InvalidOperationException( + string.Format("The DebuggerDisplayAttribute for {0} doesn't reference any expressions.", obj)); + } + + // Make sure that each referenced expression is a simple field or property name, and that we can + // invoke the property's get accessor or read from the field. + foreach (string reference in references) + { + PropertyInfo pi = GetProperty(obj, reference); + if (pi != null) + { + object ignored = pi.GetValue(obj, null); + continue; + } + + FieldInfo fi = GetField(obj, reference); + if (fi != null) + { + object ignored = fi.GetValue(obj); + continue; + } + + throw new InvalidOperationException( + string.Format("The DebuggerDisplayAttribute for {0} contains the expression \"{1}\".", obj, reference)); + } + } + + private static FieldInfo GetField(object obj, string fieldName) + { + for (Type t = obj.GetType(); t != null; t = t.GetTypeInfo().BaseType) + { + FieldInfo fi = t.GetTypeInfo().GetDeclaredField(fieldName); + if (fi != null) + { + return fi; + } + } + return null; + } + + private static PropertyInfo GetProperty(object obj, string propertyName) + { + for (Type t = obj.GetType(); t != null; t = t.GetTypeInfo().BaseType) + { + PropertyInfo pi = t.GetTypeInfo().GetDeclaredProperty(propertyName); + if (pi != null) + { + return pi; + } + } + return null; + } + } +} diff --git a/src/System.Threading.Channels/tests/System.Threading.Channels.Tests.csproj b/src/System.Threading.Channels/tests/System.Threading.Channels.Tests.csproj new file mode 100644 index 0000000000..b34fb28fc4 --- /dev/null +++ b/src/System.Threading.Channels/tests/System.Threading.Channels.Tests.csproj @@ -0,0 +1,21 @@ +<?xml version="1.0" encoding="utf-8"?> +<Project ToolsVersion="14.0" DefaultTargets="Build" xmlns="http://schemas.microsoft.com/developer/msbuild/2003"> + <Import Project="$([MSBuild]::GetDirectoryNameOfFileAbove($(MSBuildThisFileDirectory), dir.props))\dir.props" /> + <PropertyGroup> + <ProjectGuid>{95DFC527-4DC1-495E-97D7-E94EE1F7140D}</ProjectGuid> + </PropertyGroup> + <PropertyGroup Condition="'$(Configuration)|$(Platform)' == 'netstandard-Debug|AnyCPU'" /> + <PropertyGroup Condition="'$(Configuration)|$(Platform)' == 'netstandard-Release|AnyCPU'" /> + <ItemGroup> + <Compile Include="BoundedChannelTests.cs" /> + <Compile Include="ChannelClosedExceptionTests.cs" /> + <Compile Include="ChannelTestBase.cs" /> + <Compile Include="ChannelTests.cs" /> + <Compile Include="DebuggerAttributes.cs" /> + <Compile Include="TestBase.cs" /> + <Compile Include="TestExtensions.cs" /> + <Compile Include="UnboundedChannelTests.cs" /> + <Compile Include="UnbufferedChannelTests.cs" /> + </ItemGroup> + <Import Project="$([MSBuild]::GetDirectoryNameOfFileAbove($(MSBuildThisFileDirectory), dir.targets))\dir.targets" /> +</Project> diff --git a/src/System.Threading.Channels/tests/TestBase.cs b/src/System.Threading.Channels/tests/TestBase.cs new file mode 100644 index 0000000000..d3af1ba7eb --- /dev/null +++ b/src/System.Threading.Channels/tests/TestBase.cs @@ -0,0 +1,48 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Threading.Tasks; +using Xunit; + +#pragma warning disable 0649 // unused fields there for future testing needs + +namespace System.Threading.Channels.Tests +{ + public abstract class TestBase + { + protected void AssertSynchronouslyCanceled(Task task, CancellationToken token) + { + Assert.Equal(TaskStatus.Canceled, task.Status); + OperationCanceledException oce = Assert.ThrowsAny<OperationCanceledException>(() => task.GetAwaiter().GetResult()); + Assert.Equal(token, oce.CancellationToken); + } + + protected async Task AssertCanceled(Task task, CancellationToken token) + { + await Assert.ThrowsAnyAsync<OperationCanceledException>(() => task); + AssertSynchronouslyCanceled(task, token); + } + + protected void AssertSynchronousSuccess(Task task) => Assert.Equal(TaskStatus.RanToCompletion, task.Status); + + protected void AssertSynchronousTrue(Task<bool> task) + { + AssertSynchronousSuccess(task); + Assert.True(task.Result); + } + + internal sealed class DelegateObserver<T> : IObserver<T> + { + public Action<T> OnNextDelegate = null; + public Action<Exception> OnErrorDelegate = null; + public Action OnCompletedDelegate = null; + + void IObserver<T>.OnNext(T value) => OnNextDelegate?.Invoke(value); + + void IObserver<T>.OnError(Exception error) => OnErrorDelegate?.Invoke(error); + + void IObserver<T>.OnCompleted() => OnCompletedDelegate?.Invoke(); + } + } +} diff --git a/src/System.Threading.Channels/tests/TestExtensions.cs b/src/System.Threading.Channels/tests/TestExtensions.cs new file mode 100644 index 0000000000..5cf8d40bf1 --- /dev/null +++ b/src/System.Threading.Channels/tests/TestExtensions.cs @@ -0,0 +1,36 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Threading.Tasks; + +namespace System.Threading.Channels.Tests +{ + internal static class TestExtensions + { + public static async ValueTask<T> ReadAsync<T>(this ChannelReader<T> reader, CancellationToken cancellationToken = default) + { + cancellationToken.ThrowIfCancellationRequested(); + + try + { + while (true) + { + if (!await reader.WaitToReadAsync(cancellationToken)) + { + throw new ChannelClosedException(); + } + + if (reader.TryRead(out T item)) + { + return item; + } + } + } + catch (Exception exc) when (!(exc is ChannelClosedException)) + { + throw new ChannelClosedException(exc); + } + } + } +} diff --git a/src/System.Threading.Channels/tests/UnboundedChannelTests.cs b/src/System.Threading.Channels/tests/UnboundedChannelTests.cs new file mode 100644 index 0000000000..f3d6b67127 --- /dev/null +++ b/src/System.Threading.Channels/tests/UnboundedChannelTests.cs @@ -0,0 +1,213 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Diagnostics; +using System.Threading.Tasks; +using Xunit; + +namespace System.Threading.Channels.Tests +{ + public abstract class UnboundedChannelTests : ChannelTestBase + { + protected abstract bool AllowSynchronousContinuations { get; } + protected override Channel<int> CreateChannel() => Channel.CreateUnbounded<int>( + new UnboundedChannelOptions + { + SingleReader = RequiresSingleReader, + AllowSynchronousContinuations = AllowSynchronousContinuations + }); + protected override Channel<int> CreateFullChannel() => null; + + [Fact] + public async Task Complete_BeforeEmpty_NoWaiters_TriggersCompletion() + { + Channel<int> c = CreateChannel(); + Assert.True(c.Writer.TryWrite(42)); + c.Writer.Complete(); + Assert.False(c.Reader.Completion.IsCompleted); + Assert.Equal(42, await c.Reader.ReadAsync()); + await c.Reader.Completion; + } + + [Fact] + public void TryWrite_TryRead_Many() + { + Channel<int> c = CreateChannel(); + + const int NumItems = 100000; + for (int i = 0; i < NumItems; i++) + { + Assert.True(c.Writer.TryWrite(i)); + } + for (int i = 0; i < NumItems; i++) + { + Assert.True(c.Reader.TryRead(out int result)); + Assert.Equal(i, result); + } + } + + [Fact] + public void TryWrite_TryRead_OneAtATime() + { + Channel<int> c = CreateChannel(); + + for (int i = 0; i < 10; i++) + { + Assert.True(c.Writer.TryWrite(i)); + Assert.True(c.Reader.TryRead(out int result)); + Assert.Equal(i, result); + } + } + + [Fact] + public void WaitForReadAsync_DataAvailable_CompletesSynchronously() + { + Channel<int> c = CreateChannel(); + Assert.True(c.Writer.TryWrite(42)); + AssertSynchronousTrue(c.Reader.WaitToReadAsync()); + } + + [Theory] + [InlineData(0)] + [InlineData(1)] + public async Task WriteMany_ThenComplete_SuccessfullyReadAll(int readMode) + { + Channel<int> c = CreateChannel(); + for (int i = 0; i < 10; i++) + { + Assert.True(c.Writer.TryWrite(i)); + } + + c.Writer.Complete(); + Assert.False(c.Reader.Completion.IsCompleted); + + for (int i = 0; i < 10; i++) + { + Assert.False(c.Reader.Completion.IsCompleted); + switch (readMode) + { + case 0: + int result; + Assert.True(c.Reader.TryRead(out result)); + Assert.Equal(i, result); + break; + case 1: + Assert.Equal(i, await c.Reader.ReadAsync()); + break; + } + } + + await c.Reader.Completion; + } + + [Fact] + public void AllowSynchronousContinuations_WaitToReadAsync_ContinuationsInvokedAccordingToSetting() + { + Channel<int> c = CreateChannel(); + + int expectedId = Environment.CurrentManagedThreadId; + Task r = c.Reader.WaitToReadAsync().ContinueWith(_ => + { + Assert.Equal(AllowSynchronousContinuations, expectedId == Environment.CurrentManagedThreadId); + }, CancellationToken.None, TaskContinuationOptions.ExecuteSynchronously, TaskScheduler.Default); + + Assert.Equal(TaskStatus.RanToCompletion, c.Writer.WriteAsync(42).Status); + ((IAsyncResult)r).AsyncWaitHandle.WaitOne(); // avoid inlining the continuation + r.GetAwaiter().GetResult(); + } + + [Fact] + public void AllowSynchronousContinuations_CompletionTask_ContinuationsInvokedAccordingToSetting() + { + Channel<int> c = CreateChannel(); + + int expectedId = Environment.CurrentManagedThreadId; + Task r = c.Reader.Completion.ContinueWith(_ => + { + Assert.Equal(AllowSynchronousContinuations, expectedId == Environment.CurrentManagedThreadId); + }, CancellationToken.None, TaskContinuationOptions.ExecuteSynchronously, TaskScheduler.Default); + + Assert.True(c.Writer.TryComplete()); + ((IAsyncResult)r).AsyncWaitHandle.WaitOne(); // avoid inlining the continuation + r.GetAwaiter().GetResult(); + } + } + + public abstract class SingleReaderUnboundedChannelTests : UnboundedChannelTests + { + protected override bool RequiresSingleReader => true; + + [Fact] + public void ValidateInternalDebuggerAttributes() + { + Channel<int> c = CreateChannel(); + Assert.True(c.Writer.TryWrite(1)); + Assert.True(c.Writer.TryWrite(2)); + + object queue = DebuggerAttributes.GetFieldValue(c, "_items"); + DebuggerAttributes.ValidateDebuggerDisplayReferences(queue); + DebuggerAttributes.ValidateDebuggerTypeProxyProperties(queue); + } + + [Fact] + public async Task MultipleWaiters_CancelsPreviousWaiter() + { + Channel<int> c = CreateChannel(); + Task<bool> t1 = c.Reader.WaitToReadAsync(); + Task<bool> t2 = c.Reader.WaitToReadAsync(); + await Assert.ThrowsAnyAsync<OperationCanceledException>(() => t1); + Assert.True(c.Writer.TryWrite(42)); + Assert.True(await t2); + } + + [Fact] + public void Stress_TryWrite_TryRead() + { + const int NumItems = 3000000; + Channel<int> c = CreateChannel(); + + Task.WaitAll( + Task.Run(async () => + { + int received = 0; + while (await c.Reader.WaitToReadAsync()) + { + while (c.Reader.TryRead(out int i)) + { + Assert.Equal(received, i); + received++; + } + } + }), + Task.Run(() => + { + for (int i = 0; i < NumItems; i++) + { + Assert.True(c.Writer.TryWrite(i)); + } + c.Writer.Complete(); + })); + } + } + + public sealed class SyncMultiReaderUnboundedChannelTests : UnboundedChannelTests + { + protected override bool AllowSynchronousContinuations => true; + } + + public sealed class AsyncMultiReaderUnboundedChannelTests : UnboundedChannelTests + { + protected override bool AllowSynchronousContinuations => false; + } + + public sealed class SyncSingleReaderUnboundedChannelTests : SingleReaderUnboundedChannelTests + { + protected override bool AllowSynchronousContinuations => true; + } + + public sealed class AsyncSingleReaderUnboundedChannelTests : SingleReaderUnboundedChannelTests + { + protected override bool AllowSynchronousContinuations => false; + } +} diff --git a/src/System.Threading.Channels/tests/UnbufferedChannelTests.cs b/src/System.Threading.Channels/tests/UnbufferedChannelTests.cs new file mode 100644 index 0000000000..e7e681de69 --- /dev/null +++ b/src/System.Threading.Channels/tests/UnbufferedChannelTests.cs @@ -0,0 +1,94 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Linq; +using System.Threading.Tasks; +using Xunit; + +namespace System.Threading.Channels.Tests +{ + public class UnbufferedChannelTests : ChannelTestBase + { + protected override Channel<int> CreateChannel() => Channel.CreateUnbuffered<int>(); + protected override Channel<int> CreateFullChannel() => CreateChannel(); + + [Fact] + public async Task Complete_BeforeEmpty_WaitingWriters_TriggersCompletion() + { + Channel<int> c = CreateChannel(); + Task write1 = c.Writer.WriteAsync(42); + Task write2 = c.Writer.WriteAsync(43); + c.Writer.Complete(); + await c.Reader.Completion; + await Assert.ThrowsAnyAsync<InvalidOperationException>(() => write1); + await Assert.ThrowsAnyAsync<InvalidOperationException>(() => write2); + } + + [Fact] + public void TryReadWrite_NoPartner_Fail() + { + Channel<int> c = CreateChannel(); + Assert.False(c.Writer.TryWrite(42)); + Assert.False(c.Reader.TryRead(out int result)); + Assert.Equal(result, 0); + } + + [Fact] + public void TryRead_WriteAsync_Success() + { + Channel<int> c = CreateChannel(); + Task w = c.Writer.WriteAsync(42); + Assert.False(w.IsCompleted); + Assert.True(c.Reader.TryRead(out int result)); + Assert.Equal(42, result); + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task Read_MultipleUnpartneredWrites_CancelSome_ReadSucceeds(bool useReadAsync) + { + Channel<int> c = CreateChannel(); + var cts = new CancellationTokenSource(); + + Task[] cancelableWrites = (from i in Enumerable.Range(0, 10) select c.Writer.WriteAsync(42, cts.Token)).ToArray(); + Assert.All(cancelableWrites, cw => Assert.Equal(TaskStatus.WaitingForActivation, cw.Status)); + + Task w = c.Writer.WriteAsync(84); + + cts.Cancel(); + foreach (Task t in cancelableWrites) + { + await AssertCanceled(t, cts.Token); + } + + if (useReadAsync) + { + Assert.True(c.Reader.TryRead(out int result)); + Assert.Equal(84, result); + } + else + { + Assert.Equal(84, await c.Reader.ReadAsync()); + } + } + + [Fact] + public async Task Cancel_PartneredWrite_Success() + { + Channel<int> c = CreateChannel(); + var cts = new CancellationTokenSource(); + + Task w = c.Writer.WriteAsync(42, cts.Token); + Assert.False(w.IsCompleted); + + ValueTask<int> r = c.Reader.ReadAsync(); + Assert.True(r.IsCompletedSuccessfully); + + cts.Cancel(); + await w; // no throw + } + + } +} |