diff options
author | Günther Foidl <gue@korporal.at> | 2022-11-11 23:50:56 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-11-11 23:50:56 +0300 |
commit | 005e2802d5ccce2e1755aa7d46cf03f2271ab203 (patch) | |
tree | ec769e4359e392efefbc5dc534d3fbe3b0202000 | |
parent | 68b828489c1c219f3514a95803f99ba450260d0f (diff) |
MemoryExtensions.Replace(Span<T>, T, T) implemented (#76337)
* Defined API
* Tests
* Scalar implementation
* Use EqualityComparer<T>.Default instead
* Delegation to SpanHelpers.Replace
* ReplaceValueType implemented
* Use ushort instead of short, as it doesn't sign-extend for broadcast and in the scalar loop
* Forward string.Replace(char, char) to SpanHelpers.ReplaceValueType
* Process remainder vectorized only when not done already and with max width available
* Split into inlineable scalar path and non-inlineable vectorized path
* Replaced open coded loops with Replace
* Don't use EqualityComparer<T>.Default
Cf. https://github.com/dotnet/runtime/pull/76337#discussion_r982886319
* Remove guards for remainder
Cf. https://github.com/dotnet/runtime/pull/76337#discussion_r983448480
* Don't split method into scalar and vectorized and don't force inlining of scalar-part
* Fixed assert
ReplaceValueType is also called from string.Replace(char, char), so the Debug.Assert was in the wrong position: platforms without hardware acceleration are allowed to call the method, so the assert must not sit at method entry.
* Better handling of remainder from the vectorized loop(s)
Intentionally leave the last iteration out of the vectorized loop, as the remaining elements are processed by a final vector operation anyway. This eliminates the less probable case (cf. https://github.com/dotnet/runtime/pull/76337#discussion_r983448480) in which the last vector is processed twice.
* PR feedback
10 files changed, 326 insertions, 66 deletions
diff --git a/src/libraries/Common/src/System/IO/Archiving.Utils.Unix.cs b/src/libraries/Common/src/System/IO/Archiving.Utils.Unix.cs index 18ab525f919..ce0aaf63e98 100644 --- a/src/libraries/Common/src/System/IO/Archiving.Utils.Unix.cs +++ b/src/libraries/Common/src/System/IO/Archiving.Utils.Unix.cs @@ -11,7 +11,7 @@ namespace System.IO { // Remove leading separators. int nonSlash = path.IndexOfAnyExcept('/'); - if (nonSlash == -1) + if (nonSlash < 0) { nonSlash = path.Length; } diff --git a/src/libraries/Common/src/System/IO/Archiving.Utils.Windows.cs b/src/libraries/Common/src/System/IO/Archiving.Utils.Windows.cs index 412563966fc..beceebc5604 100644 --- a/src/libraries/Common/src/System/IO/Archiving.Utils.Windows.cs +++ b/src/libraries/Common/src/System/IO/Archiving.Utils.Windows.cs @@ -48,7 +48,7 @@ namespace System.IO { // Remove leading separators. int nonSlash = path.IndexOfAnyExcept('/', '\\'); - if (nonSlash == -1) + if (nonSlash < 0) { nonSlash = path.Length; } @@ -76,12 +76,7 @@ namespace System.IO // To ensure tar files remain compatible with Unix, and per the ZIP File Format Specification 4.4.17.1, // all slashes should be forward slashes. 
- int pos; - while ((pos = dest.IndexOf('\\')) >= 0) - { - dest[pos] = '/'; - dest = dest.Slice(pos + 1); - } + dest.Replace('\\', '/'); }); } } diff --git a/src/libraries/System.Memory/ref/System.Memory.cs b/src/libraries/System.Memory/ref/System.Memory.cs index 9e0c91d131d..6c8f777bcb2 100644 --- a/src/libraries/System.Memory/ref/System.Memory.cs +++ b/src/libraries/System.Memory/ref/System.Memory.cs @@ -293,6 +293,7 @@ namespace System public static bool Overlaps<T>(this System.ReadOnlySpan<T> span, System.ReadOnlySpan<T> other, out int elementOffset) { throw null; } public static bool Overlaps<T>(this System.Span<T> span, System.ReadOnlySpan<T> other) { throw null; } public static bool Overlaps<T>(this System.Span<T> span, System.ReadOnlySpan<T> other, out int elementOffset) { throw null; } + public static void Replace<T>(this System.Span<T> span, T oldValue, T newValue) where T : System.IEquatable<T>? { } public static void Reverse<T>(this System.Span<T> span) { } public static int SequenceCompareTo<T>(this System.ReadOnlySpan<T> span, System.ReadOnlySpan<T> other) where T : System.IComparable<T>? { throw null; } public static int SequenceCompareTo<T>(this System.Span<T> span, System.ReadOnlySpan<T> other) where T : System.IComparable<T>? { throw null; } diff --git a/src/libraries/System.Memory/tests/Span/Replace.T.cs b/src/libraries/System.Memory/tests/Span/Replace.T.cs new file mode 100644 index 00000000000..92b4c16b367 --- /dev/null +++ b/src/libraries/System.Memory/tests/Span/Replace.T.cs @@ -0,0 +1,149 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. 
+ +using System.Collections.Generic; +using Xunit; + +namespace System.SpanTests +{ + public class ReplaceTests_Byte : ReplaceTests<byte> { protected override byte Create(int value) => (byte)value; } + public class ReplaceTests_Int16 : ReplaceTests<short> { protected override short Create(int value) => (short)value; } + public class ReplaceTests_Int32 : ReplaceTests<int> { protected override int Create(int value) => value; } + public class ReplaceTests_Int64 : ReplaceTests<long> { protected override long Create(int value) => value; } + public class ReplaceTests_Char : ReplaceTests<char> { protected override char Create(int value) => (char)value; } + public class ReplaceTests_Double : ReplaceTests<double> { protected override double Create(int value) => (double)value; } + public class ReplaceTests_Record : ReplaceTests<SimpleRecord> { protected override SimpleRecord Create(int value) => new SimpleRecord(value); } + public class ReplaceTests_CustomEquatable : ReplaceTests<CustomEquatable> { protected override CustomEquatable Create(int value) => new CustomEquatable((byte)value); } + + public readonly struct CustomEquatable : IEquatable<CustomEquatable> + { + public byte Value { get; } + + public CustomEquatable(byte value) => Value = value; + + public bool Equals(CustomEquatable other) => other.Value == Value; + } + + public abstract class ReplaceTests<T> where T : IEquatable<T> + { + private readonly T _oldValue; + private readonly T _newValue; + + protected ReplaceTests() + { + _oldValue = Create('a'); + _newValue = Create('b'); + } + + [Fact] + public void ZeroLengthSpan() + { + Exception actual = Record.Exception(() => Span<T>.Empty.Replace(_oldValue, _newValue)); + + Assert.Null(actual); + } + + [Theory] + [MemberData(nameof(Length_MemberData))] + public void AllElementsNeedToBeReplaced(int length) + { + Span<T> span = CreateArray(length, _oldValue); + T[] expected = CreateArray(length, _newValue); + + span.Replace(_oldValue, _newValue); + T[] actual = 
span.ToArray(); + + Assert.Equal(expected, actual); + } + + [Theory] + [MemberData(nameof(Length_MemberData))] + public void DefaultToBeReplaced(int length) + { + Span<T> span = CreateArray(length); + T[] expected = CreateArray(length, _newValue); + + span.Replace(default, _newValue); + T[] actual = span.ToArray(); + + Assert.Equal(expected, actual); + } + + [Theory] + [MemberData(nameof(Length_MemberData))] + public void NoElementsNeedToBeReplaced(int length) + { + T[] values = { Create('0'), Create('1') }; + + Span<T> span = CreateArray(length, values); + T[] expected = span.ToArray(); + + span.Replace(_oldValue, _newValue); + T[] actual = span.ToArray(); + + Assert.Equal(expected, actual); + } + + [Theory] + [MemberData(nameof(Length_MemberData))] + public void SomeElementsNeedToBeReplaced(int length) + { + T[] values = { Create('0'), Create('1') }; + + Span<T> span = CreateArray(length, values); + span[0] = _oldValue; + span[^1] = _oldValue; + + T[] expected = CreateArray(length, values); + expected[0] = _newValue; + expected[^1] = _newValue; + + span.Replace(_oldValue, _newValue); + T[] actual = span.ToArray(); + + Assert.Equal(expected, actual); + } + + [Theory] + [MemberData(nameof(Length_MemberData))] + public void OldAndNewValueAreSame(int length) + { + T[] values = { Create('0'), Create('1') }; + + Span<T> span = CreateArray(length, values); + span[0] = _oldValue; + span[^1] = _oldValue; + T[] expected = span.ToArray(); + + span.Replace(_oldValue, _oldValue); + T[] actual = span.ToArray(); + + Assert.Equal(expected, actual); + } + + public static IEnumerable<object[]> Length_MemberData() + { + foreach (int length in new[] { 1, 2, 4, 7, 15, 16, 17, 31, 32, 33, 100 }) + { + yield return new object[] { length }; + } + } + + protected abstract T Create(int value); + + private T[] CreateArray(int length, params T[] values) + { + var arr = new T[length]; + + if (values.Length > 0) + { + for (int i = 0; i < arr.Length; i++) + { + arr[i] = values[i % 
values.Length]; + } + } + + return arr; + } + } +} diff --git a/src/libraries/System.Memory/tests/System.Memory.Tests.csproj b/src/libraries/System.Memory/tests/System.Memory.Tests.csproj index fa93b4b6934..e7758562dcb 100644 --- a/src/libraries/System.Memory/tests/System.Memory.Tests.csproj +++ b/src/libraries/System.Memory/tests/System.Memory.Tests.csproj @@ -100,6 +100,7 @@ <Compile Include="Span\LastIndexOfSequence.T.cs" /> <Compile Include="Span\Overflow.cs" /> <Compile Include="Span\Overlaps.cs" /> + <Compile Include="Span\Replace.T.cs" /> <Compile Include="Span\Reverse.cs" /> <Compile Include="Span\SequenceCompareTo.bool.cs" /> <Compile Include="Span\SequenceCompareTo.byte.cs" /> diff --git a/src/libraries/System.Private.CoreLib/src/System/MemoryExtensions.cs b/src/libraries/System.Private.CoreLib/src/System/MemoryExtensions.cs index cf80b57c079..9bb51d939f2 100644 --- a/src/libraries/System.Private.CoreLib/src/System/MemoryExtensions.cs +++ b/src/libraries/System.Private.CoreLib/src/System/MemoryExtensions.cs @@ -2923,6 +2923,68 @@ namespace System } } + /// <summary> + /// Replaces all occurrences of <paramref name="oldValue"/> with <paramref name="newValue"/>. + /// </summary> + /// <typeparam name="T">The type of the elements in the span.</typeparam> + /// <param name="span">The span in which the elements should be replaced.</param> + /// <param name="oldValue">The value to be replaced with <paramref name="newValue"/>.</param> + /// <param name="newValue">The value to replace all occurrences of <paramref name="oldValue"/>.</param> + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void Replace<T>(this Span<T> span, T oldValue, T newValue) where T : IEquatable<T>? 
+ { + if (SpanHelpers.CanVectorizeAndBenefit<T>(span.Length)) + { + nuint length = (uint)span.Length; + + if (Unsafe.SizeOf<T>() == sizeof(byte)) + { + ref byte src = ref Unsafe.As<T, byte>(ref MemoryMarshal.GetReference(span)); + SpanHelpers.ReplaceValueType( + ref src, + ref src, + Unsafe.As<T, byte>(ref oldValue), + Unsafe.As<T, byte>(ref newValue), + length); + } + else if (Unsafe.SizeOf<T>() == sizeof(ushort)) + { + // Use ushort rather than short, as this avoids a sign-extending move. + ref ushort src = ref Unsafe.As<T, ushort>(ref MemoryMarshal.GetReference(span)); + SpanHelpers.ReplaceValueType( + ref src, + ref src, + Unsafe.As<T, ushort>(ref oldValue), + Unsafe.As<T, ushort>(ref newValue), + length); + } + else if (Unsafe.SizeOf<T>() == sizeof(int)) + { + ref int src = ref Unsafe.As<T, int>(ref MemoryMarshal.GetReference(span)); + SpanHelpers.ReplaceValueType( + ref src, + ref src, + Unsafe.As<T, int>(ref oldValue), + Unsafe.As<T, int>(ref newValue), + length); + } + else + { + Debug.Assert(Unsafe.SizeOf<T>() == sizeof(long)); + + ref long src = ref Unsafe.As<T, long>(ref MemoryMarshal.GetReference(span)); + SpanHelpers.ReplaceValueType( + ref src, + ref src, + Unsafe.As<T, long>(ref oldValue), + Unsafe.As<T, long>(ref newValue), + length); + } + } + + SpanHelpers.Replace(span, oldValue, newValue); + } + /// <summary>Finds the length of any common prefix shared between <paramref name="span"/> and <paramref name="other"/>.</summary> /// <typeparam name="T">The type of the elements in the spans.</typeparam> /// <param name="span">The first sequence to compare.</param> diff --git a/src/libraries/System.Private.CoreLib/src/System/SpanHelpers.T.cs b/src/libraries/System.Private.CoreLib/src/System/SpanHelpers.T.cs index e3bef2bfe47..ea734fd5e28 100644 --- a/src/libraries/System.Private.CoreLib/src/System/SpanHelpers.T.cs +++ b/src/libraries/System.Private.CoreLib/src/System/SpanHelpers.T.cs @@ -2620,6 +2620,103 @@ namespace System return -1; } + public static 
void Replace<T>(Span<T> span, T oldValue, T newValue) where T : IEquatable<T>? + { + if (default(T) is not null || oldValue is not null) + { + Debug.Assert(oldValue is not null); + + for (int i = 0; i < span.Length; ++i) + { + ref T val = ref span[i]; + if (oldValue.Equals(val)) + { + val = newValue; + } + } + } + else + { + for (int i = 0; i < span.Length; ++i) + { + ref T val = ref span[i]; + val ??= newValue; + } + } + } + + public static void ReplaceValueType<T>(ref T src, ref T dst, T oldValue, T newValue, nuint length) where T : struct + { + if (!Vector128.IsHardwareAccelerated || length < (uint)Vector128<T>.Count) + { + for (nuint idx = 0; idx < length; ++idx) + { + T original = Unsafe.Add(ref src, idx); + Unsafe.Add(ref dst, idx) = EqualityComparer<T>.Default.Equals(original, oldValue) ? newValue : original; + } + } + else + { + Debug.Assert(Vector128.IsHardwareAccelerated && Vector128<T>.IsSupported, "Vector128 is not HW-accelerated or not supported"); + + nuint idx = 0; + + if (!Vector256.IsHardwareAccelerated || length < (uint)Vector256<T>.Count) + { + nuint lastVectorIndex = length - (uint)Vector128<T>.Count; + Vector128<T> oldValues = Vector128.Create(oldValue); + Vector128<T> newValues = Vector128.Create(newValue); + Vector128<T> original, mask, result; + + do + { + original = Vector128.LoadUnsafe(ref src, idx); + mask = Vector128.Equals(oldValues, original); + result = Vector128.ConditionalSelect(mask, newValues, original); + result.StoreUnsafe(ref dst, idx); + + idx += (uint)Vector128<T>.Count; + } + while (idx < lastVectorIndex); + + // There are (0, Vector128<T>.Count] elements remaining now. + // As the operation is idempotent, and we know that in total there are at least Vector128<T>.Count + // elements available, we read a vector from the very end, perform the replace and write to the + // the resulting vector at the very end. + // Thus we can eliminate the scalar processing of the remaining elements. 
+ original = Vector128.LoadUnsafe(ref src, lastVectorIndex); + mask = Vector128.Equals(oldValues, original); + result = Vector128.ConditionalSelect(mask, newValues, original); + result.StoreUnsafe(ref dst, lastVectorIndex); + } + else + { + Debug.Assert(Vector256.IsHardwareAccelerated && Vector256<T>.IsSupported, "Vector256 is not HW-accelerated or not supported"); + + nuint lastVectorIndex = length - (uint)Vector256<T>.Count; + Vector256<T> oldValues = Vector256.Create(oldValue); + Vector256<T> newValues = Vector256.Create(newValue); + Vector256<T> original, mask, result; + + do + { + original = Vector256.LoadUnsafe(ref src, idx); + mask = Vector256.Equals(oldValues, original); + result = Vector256.ConditionalSelect(mask, newValues, original); + result.StoreUnsafe(ref dst, idx); + + idx += (uint)Vector256<T>.Count; + } + while (idx < lastVectorIndex); + + original = Vector256.LoadUnsafe(ref src, lastVectorIndex); + mask = Vector256.Equals(oldValues, original); + result = Vector256.ConditionalSelect(mask, newValues, original); + result.StoreUnsafe(ref dst, lastVectorIndex); + } + } + } + [MethodImpl(MethodImplOptions.AggressiveInlining)] private static int ComputeFirstIndex<T>(ref T searchSpace, ref T current, Vector128<T> equals) where T : struct { diff --git a/src/libraries/System.Private.CoreLib/src/System/String.Manipulation.cs b/src/libraries/System.Private.CoreLib/src/System/String.Manipulation.cs index 50b30cb5b01..a2524d0d343 100644 --- a/src/libraries/System.Private.CoreLib/src/System/String.Manipulation.cs +++ b/src/libraries/System.Private.CoreLib/src/System/String.Manipulation.cs @@ -1008,56 +1008,21 @@ namespace System // Copy the remaining characters, doing the replacement as we go. 
ref ushort pSrc = ref Unsafe.Add(ref GetRawStringDataAsUInt16(), (uint)copyLength); ref ushort pDst = ref Unsafe.Add(ref result.GetRawStringDataAsUInt16(), (uint)copyLength); - nuint i = 0; - if (Vector.IsHardwareAccelerated && Length >= Vector<ushort>.Count) + // If the string is long enough for vectorization to kick in, we'd like to + // process the remaining elements vectorized too. + // Thus we adjust the pointers so that at least one full vector from the end can be processed. + nuint length = (uint)Length; + if (Vector128.IsHardwareAccelerated && length >= (uint)Vector128<ushort>.Count) { - Vector<ushort> oldChars = new(oldChar); - Vector<ushort> newChars = new(newChar); - - Vector<ushort> original; - Vector<ushort> equals; - Vector<ushort> results; - - if (remainingLength > (nuint)Vector<ushort>.Count) - { - nuint lengthToExamine = remainingLength - (nuint)Vector<ushort>.Count; - - do - { - original = Vector.LoadUnsafe(ref pSrc, i); - equals = Vector.Equals(original, oldChars); - results = Vector.ConditionalSelect(equals, newChars, original); - results.StoreUnsafe(ref pDst, i); - - i += (nuint)Vector<ushort>.Count; - } - while (i < lengthToExamine); - } - - // There are [0, Vector<ushort>.Count) elements remaining now. - // As the operation is idempotent, and we know that in total there are at least Vector<ushort>.Count - // elements available, we read a vector from the very end of the string, perform the replace - // and write to the destination at the very end. - // Thus we can eliminate the scalar processing of the remaining elements. - // We perform this operation even if there are 0 elements remaining, as it is cheaper than the - // additional check which would introduce a branch here. 
- - i = (uint)(Length - Vector<ushort>.Count); - original = Vector.LoadUnsafe(ref GetRawStringDataAsUInt16(), i); - equals = Vector.Equals(original, oldChars); - results = Vector.ConditionalSelect(equals, newChars, original); - results.StoreUnsafe(ref result.GetRawStringDataAsUInt16(), i); - } - else - { - for (; i < remainingLength; ++i) - { - ushort currentChar = Unsafe.Add(ref pSrc, i); - Unsafe.Add(ref pDst, i) = currentChar == oldChar ? newChar : currentChar; - } + nuint adjust = (length - remainingLength) & ((uint)Vector128<ushort>.Count - 1); + pSrc = ref Unsafe.Subtract(ref pSrc, adjust); + pDst = ref Unsafe.Subtract(ref pDst, adjust); + remainingLength += adjust; } + SpanHelpers.ReplaceValueType(ref pSrc, ref pDst, oldChar, newChar, remainingLength); + return result; } diff --git a/src/libraries/System.Private.CoreLib/src/System/Text/StringBuilder.cs b/src/libraries/System.Private.CoreLib/src/System/Text/StringBuilder.cs index 1c0c6f493b0..65a228fa5fe 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Text/StringBuilder.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Text/StringBuilder.cs @@ -1965,12 +1965,7 @@ namespace System.Text int endInChunk = Math.Min(chunk.m_ChunkLength, endIndexInChunk); Span<char> span = chunk.m_ChunkChars.AsSpan(curInChunk, endInChunk - curInChunk); - int i; - while ((i = span.IndexOf(oldChar)) >= 0) - { - span[i] = newChar; - span = span.Slice(i + 1); - } + span.Replace(oldChar, newChar); } if (startIndexInChunk >= 0) diff --git a/src/libraries/System.Private.Uri/src/System/Uri.cs b/src/libraries/System.Private.Uri/src/System/Uri.cs index 634f4cf2eab..b24d4592da5 100644 --- a/src/libraries/System.Private.Uri/src/System/Uri.cs +++ b/src/libraries/System.Private.Uri/src/System/Uri.cs @@ -1035,12 +1035,7 @@ namespace System // Plus going through Compress will turn them into / anyway // Converting / back into \ Span<char> slashSpan = result.AsSpan(0, count); - int slashPos; - while ((slashPos = 
slashSpan.IndexOf('/')) >= 0) - { - slashSpan[slashPos] = '\\'; - slashSpan = slashSpan.Slice(slashPos + 1); - } + slashSpan.Replace('/', '\\'); return new string(result, 0, count); } |