From dddf223348481d2e52e709c555f907adfae8a96d Mon Sep 17 00:00:00 2001 From: Egor Bogatov Date: Wed, 29 Mar 2023 17:17:08 +0200 Subject: [PATCH] Unroll SequenceEqual(ref byte, ref byte, nuint) in JIT (#83945) Co-authored-by: Jakob Botsch Nielsen --- src/coreclr/jit/importercalls.cpp | 8 + src/coreclr/jit/lower.cpp | 202 +++++++++++++++++- src/coreclr/jit/lower.h | 1 + src/coreclr/jit/namedintrinsiclist.h | 1 + .../src/System/MemoryExtensions.cs | 21 +- .../src/System/SpanHelpers.Byte.cs | 1 + .../SpanHelpers_SequenceEqual.cs | 45 ++++ .../SpanHelpers_SequenceEqual.csproj | 9 + 8 files changed, 266 insertions(+), 22 deletions(-) create mode 100644 src/tests/JIT/opt/Vectorization/SpanHelpers_SequenceEqual.cs create mode 100644 src/tests/JIT/opt/Vectorization/SpanHelpers_SequenceEqual.csproj diff --git a/src/coreclr/jit/importercalls.cpp b/src/coreclr/jit/importercalls.cpp index 09ab191eff40..4c8d1da92ca5 100644 --- a/src/coreclr/jit/importercalls.cpp +++ b/src/coreclr/jit/importercalls.cpp @@ -3782,6 +3782,7 @@ GenTree* Compiler::impIntrinsic(GenTree* newobjThis, break; } + case NI_System_SpanHelpers_SequenceEqual: case NI_System_Buffer_Memmove: { // We'll try to unroll this in lower for constant input. @@ -8139,6 +8140,13 @@ NamedIntrinsic Compiler::lookupNamedIntrinsic(CORINFO_METHOD_HANDLE method) result = NI_System_Span_get_Length; } } + else if (strcmp(className, "SpanHelpers") == 0) + { + if (strcmp(methodName, "SequenceEqual") == 0) + { + result = NI_System_SpanHelpers_SequenceEqual; + } + } else if (strcmp(className, "String") == 0) { if (strcmp(methodName, "Equals") == 0) diff --git a/src/coreclr/jit/lower.cpp b/src/coreclr/jit/lower.cpp index decaed31eeb8..ea5709b0cde5 100644 --- a/src/coreclr/jit/lower.cpp +++ b/src/coreclr/jit/lower.cpp @@ -1865,6 +1865,185 @@ GenTree* Lowering::LowerCallMemmove(GenTreeCall* call) return nullptr; } +//------------------------------------------------------------------------ +// LowerCallMemcmp: Replace SpanHelpers.SequenceEqual)(left, right, CNS_SIZE) +// with a series of merged comparisons (via GT_IND nodes) +// +// Arguments: +// tree - GenTreeCall node to unroll as memcmp +// +// Return Value: +// nullptr if no changes were made +// +GenTree* Lowering::LowerCallMemcmp(GenTreeCall* call) +{ + JITDUMP("Considering Memcmp [%06d] for unrolling.. ", comp->dspTreeID(call)) + assert(comp->lookupNamedIntrinsic(call->gtCallMethHnd) == NI_System_SpanHelpers_SequenceEqual); + assert(call->gtArgs.CountUserArgs() == 3); + assert(TARGET_POINTER_SIZE == 8); + + if (!comp->opts.OptimizationEnabled()) + { + JITDUMP("Optimizations aren't allowed - bail out.\n") + return nullptr; + } + + if (comp->info.compHasNextCallRetAddr) + { + JITDUMP("compHasNextCallRetAddr=true so we won't be able to remove the call - bail out.\n") + return nullptr; + } + + GenTree* lengthArg = call->gtArgs.GetUserArgByIndex(2)->GetNode(); + if (lengthArg->IsIntegralConst()) + { + ssize_t cnsSize = lengthArg->AsIntCon()->IconValue(); + JITDUMP("Size=%ld.. ", (LONG)cnsSize); + // TODO-CQ: drop the whole thing in case of 0 + if (cnsSize > 0) + { + GenTree* lArg = call->gtArgs.GetUserArgByIndex(0)->GetNode(); + GenTree* rArg = call->gtArgs.GetUserArgByIndex(1)->GetNode(); + // TODO: Add SIMD path for [16..128] via GT_HWINTRINSIC nodes + if (cnsSize <= 16) + { + unsigned loadWidth = 1 << BitOperations::Log2((unsigned)cnsSize); + var_types loadType; + if (loadWidth == 1) + { + loadType = TYP_UBYTE; + } + else if (loadWidth == 2) + { + loadType = TYP_USHORT; + } + else if (loadWidth == 4) + { + loadType = TYP_INT; + } + else if ((loadWidth == 8) || (loadWidth == 16)) + { + loadWidth = 8; + loadType = TYP_LONG; + } + else + { + unreached(); + } + var_types actualLoadType = genActualType(loadType); + + GenTree* result = nullptr; + + // loadWidth == cnsSize means a single load is enough for both args + if ((loadWidth == (unsigned)cnsSize) && (loadWidth <= 8)) + { + // We're going to emit something like the following: + // + // bool result = *(int*)leftArg == *(int*)rightArg + // + // ^ in the given example we unroll for length=4 + // + GenTree* lIndir = comp->gtNewIndir(loadType, lArg); + GenTree* rIndir = comp->gtNewIndir(loadType, rArg); + result = comp->gtNewOperNode(GT_EQ, TYP_INT, lIndir, rIndir); + + BlockRange().InsertAfter(lArg, lIndir); + BlockRange().InsertAfter(rArg, rIndir); + BlockRange().InsertBefore(call, result); + } + else + { + // First, make both args multi-use: + LIR::Use lArgUse; + LIR::Use rArgUse; + bool lFoundUse = BlockRange().TryGetUse(lArg, &lArgUse); + bool rFoundUse = BlockRange().TryGetUse(rArg, &rArgUse); + assert(lFoundUse && rFoundUse); + GenTree* lArgClone = comp->gtNewLclvNode(lArgUse.ReplaceWithLclVar(comp), genActualType(lArg)); + GenTree* rArgClone = comp->gtNewLclvNode(rArgUse.ReplaceWithLclVar(comp), genActualType(rArg)); + BlockRange().InsertBefore(call, lArgClone, rArgClone); + + // We're going to emit something like the following: + // + // bool result = ((*(int*)leftArg ^ *(int*)rightArg) | + // (*(int*)(leftArg + 1) ^ *((int*)(rightArg + 1)))) == 0; + // + // ^ in the given example we unroll for length=5 + // + // In IR: + // + // * EQ int + // +--* OR int + // | +--* XOR int + // | | +--* IND int + // | | | \--* LCL_VAR byref V1 + // | | \--* IND int + // | | \--* LCL_VAR byref V2 + // | \--* XOR int + // | +--* IND int + // | | \--* ADD byref + // | | +--* LCL_VAR byref V1 + // | | \--* CNS_INT int 1 + // | \--* IND int + // | \--* ADD byref + // | +--* LCL_VAR byref V2 + // | \--* CNS_INT int 1 + // \--* CNS_INT int 0 + // + GenTree* l1Indir = comp->gtNewIndir(loadType, lArgUse.Def()); + GenTree* r1Indir = comp->gtNewIndir(loadType, rArgUse.Def()); + GenTree* lXor = comp->gtNewOperNode(GT_XOR, actualLoadType, l1Indir, r1Indir); + GenTree* l2Offs = comp->gtNewIconNode(cnsSize - loadWidth, TYP_I_IMPL); + GenTree* l2AddOffs = comp->gtNewOperNode(GT_ADD, lArg->TypeGet(), lArgClone, l2Offs); + GenTree* l2Indir = comp->gtNewIndir(loadType, l2AddOffs); + GenTree* r2Offs = comp->gtCloneExpr(l2Offs); // offset is the same + GenTree* r2AddOffs = comp->gtNewOperNode(GT_ADD, rArg->TypeGet(), rArgClone, r2Offs); + GenTree* r2Indir = comp->gtNewIndir(loadType, r2AddOffs); + GenTree* rXor = comp->gtNewOperNode(GT_XOR, actualLoadType, l2Indir, r2Indir); + GenTree* resultOr = comp->gtNewOperNode(GT_OR, actualLoadType, lXor, rXor); + GenTree* zeroCns = comp->gtNewIconNode(0, actualLoadType); + result = comp->gtNewOperNode(GT_EQ, TYP_INT, resultOr, zeroCns); + + BlockRange().InsertAfter(rArgClone, l1Indir, r1Indir, l2Offs, l2AddOffs); + BlockRange().InsertAfter(l2AddOffs, l2Indir, r2Offs, r2AddOffs, r2Indir); + BlockRange().InsertAfter(r2Indir, lXor, rXor, resultOr, zeroCns); + BlockRange().InsertAfter(zeroCns, result); + } + + JITDUMP("\nUnrolled to:\n"); + DISPTREE(result); + + LIR::Use use; + if (BlockRange().TryGetUse(call, &use)) + { + use.ReplaceWith(result); + } + BlockRange().Remove(lengthArg); + BlockRange().Remove(call); + + // Remove all non-user args (e.g. r2r cell) + for (CallArg& arg : call->gtArgs.Args()) + { + if (!arg.IsUserArg()) + { + arg.GetNode()->SetUnusedValue(); + } + } + return lArg; + } + } + else + { + JITDUMP("Size is either 0 or too big to unroll.\n") + } + } + else + { + JITDUMP("size is not a constant.\n") + } + return nullptr; +} + // do lowering steps for a call // this includes: // - adding the placement nodes (either stack or register variety) for arguments @@ -1883,19 +2062,26 @@ GenTree* Lowering::LowerCall(GenTree* node) // All runtime lookups are expected to be expanded in fgExpandRuntimeLookups assert(!call->IsExpRuntimeLookup()); +#if defined(TARGET_AMD64) || defined(TARGET_ARM64) if (call->gtCallMoreFlags & GTF_CALL_M_SPECIAL_INTRINSIC) { -#if defined(TARGET_AMD64) || defined(TARGET_ARM64) - if (comp->lookupNamedIntrinsic(call->gtCallMethHnd) == NI_System_Buffer_Memmove) + GenTree* newNode = nullptr; + NamedIntrinsic ni = comp->lookupNamedIntrinsic(call->gtCallMethHnd); + if (ni == NI_System_Buffer_Memmove) { - GenTree* newNode = LowerCallMemmove(call); - if (newNode != nullptr) - { - return newNode->gtNext; - } + newNode = LowerCallMemmove(call); + } + else if (ni == NI_System_SpanHelpers_SequenceEqual) + { + newNode = LowerCallMemcmp(call); + } + + if (newNode != nullptr) + { + return newNode->gtNext; } -#endif } +#endif call->ClearOtherRegs(); LowerArgsForCall(call); diff --git a/src/coreclr/jit/lower.h b/src/coreclr/jit/lower.h index 6937d2e3c043..34b5f76f6b6b 100644 --- a/src/coreclr/jit/lower.h +++ b/src/coreclr/jit/lower.h @@ -128,6 +128,7 @@ class Lowering final : public Phase // ------------------------------ GenTree* LowerCall(GenTree* call); GenTree* LowerCallMemmove(GenTreeCall* call); + GenTree* LowerCallMemcmp(GenTreeCall* call); void LowerCFGCall(GenTreeCall* call); void MoveCFGCallArg(GenTreeCall* call, GenTree* node); #ifndef TARGET_64BIT diff --git a/src/coreclr/jit/namedintrinsiclist.h b/src/coreclr/jit/namedintrinsiclist.h index ff17b3d6c770..2e3f1f0a013b 100644 --- a/src/coreclr/jit/namedintrinsiclist.h +++ b/src/coreclr/jit/namedintrinsiclist.h @@ -104,6 +104,7 @@ enum NamedIntrinsic : unsigned short NI_System_String_StartsWith, NI_System_Span_get_Item, NI_System_Span_get_Length, + NI_System_SpanHelpers_SequenceEqual, NI_System_ReadOnlySpan_get_Item, NI_System_ReadOnlySpan_get_Length, diff --git a/src/libraries/System.Private.CoreLib/src/System/MemoryExtensions.cs b/src/libraries/System.Private.CoreLib/src/System/MemoryExtensions.cs index 7db3907fba28..bdeebbf560b7 100644 --- a/src/libraries/System.Private.CoreLib/src/System/MemoryExtensions.cs +++ b/src/libraries/System.Private.CoreLib/src/System/MemoryExtensions.cs @@ -1429,12 +1429,11 @@ private static void ThrowNullLowHighInclusive(T? lowInclusive, T? highInclusi if (RuntimeHelpers.IsBitwiseEquatable()) { - nuint size = (nuint)sizeof(T); return length == other.Length && SpanHelpers.SequenceEqual( ref Unsafe.As(ref MemoryMarshal.GetReference(span)), ref Unsafe.As(ref MemoryMarshal.GetReference(other)), - ((uint)length) * size); // If this multiplication overflows, the Span we got overflows the entire address range. There's no happy outcome for this api in such a case so we choose not to take the overhead of checking. + ((uint)length) * (nuint)sizeof(T)); // If this multiplication overflows, the Span we got overflows the entire address range. There's no happy outcome for this api in such a case so we choose not to take the overhead of checking. } return length == other.Length && SpanHelpers.SequenceEqual(ref MemoryMarshal.GetReference(span), ref MemoryMarshal.GetReference(other), length); @@ -2164,12 +2163,11 @@ public static int SequenceCompareTo(this Span span, ReadOnlySpan other) int length = span.Length; if (RuntimeHelpers.IsBitwiseEquatable()) { - nuint size = (nuint)sizeof(T); return length == other.Length && SpanHelpers.SequenceEqual( ref Unsafe.As(ref MemoryMarshal.GetReference(span)), ref Unsafe.As(ref MemoryMarshal.GetReference(other)), - ((uint)length) * size); // If this multiplication overflows, the Span we got overflows the entire address range. There's no happy outcome for this API in such a case so we choose not to take the overhead of checking. + ((uint)length) * (nuint)sizeof(T)); // If this multiplication overflows, the Span we got overflows the entire address range. There's no happy outcome for this API in such a case so we choose not to take the overhead of checking. } return length == other.Length && SpanHelpers.SequenceEqual(ref MemoryMarshal.GetReference(span), ref MemoryMarshal.GetReference(other), length); @@ -2207,11 +2205,10 @@ public static unsafe bool SequenceEqual(this ReadOnlySpan span, ReadOnlySp // If no comparer was supplied and the type is bitwise equatable, take the fast path doing a bitwise comparison. if (RuntimeHelpers.IsBitwiseEquatable()) { - nuint size = (nuint)sizeof(T); return SpanHelpers.SequenceEqual( ref Unsafe.As(ref MemoryMarshal.GetReference(span)), ref Unsafe.As(ref MemoryMarshal.GetReference(other)), - ((uint)span.Length) * size); // If this multiplication overflows, the Span we got overflows the entire address range. There's no happy outcome for this API in such a case so we choose not to take the overhead of checking. + ((uint)span.Length) * (nuint)sizeof(T)); // If this multiplication overflows, the Span we got overflows the entire address range. There's no happy outcome for this API in such a case so we choose not to take the overhead of checking. } // Otherwise, compare each element using EqualityComparer.Default.Equals in a way that will enable it to devirtualize. @@ -2277,12 +2274,11 @@ public static unsafe int SequenceCompareTo(this ReadOnlySpan span, ReadOnl int valueLength = value.Length; if (RuntimeHelpers.IsBitwiseEquatable()) { - nuint size = (nuint)sizeof(T); return valueLength <= span.Length && SpanHelpers.SequenceEqual( ref Unsafe.As(ref MemoryMarshal.GetReference(span)), ref Unsafe.As(ref MemoryMarshal.GetReference(value)), - ((uint)valueLength) * size); // If this multiplication overflows, the Span we got overflows the entire address range. There's no happy outcome for this api in such a case so we choose not to take the overhead of checking. + ((uint)valueLength) * (nuint)sizeof(T)); // If this multiplication overflows, the Span we got overflows the entire address range. There's no happy outcome for this api in such a case so we choose not to take the overhead of checking. } return valueLength <= span.Length && SpanHelpers.SequenceEqual(ref MemoryMarshal.GetReference(span), ref MemoryMarshal.GetReference(value), valueLength); @@ -2298,12 +2294,11 @@ public static unsafe int SequenceCompareTo(this ReadOnlySpan span, ReadOnl int valueLength = value.Length; if (RuntimeHelpers.IsBitwiseEquatable()) { - nuint size = (nuint)sizeof(T); return valueLength <= span.Length && SpanHelpers.SequenceEqual( ref Unsafe.As(ref MemoryMarshal.GetReference(span)), ref Unsafe.As(ref MemoryMarshal.GetReference(value)), - ((uint)valueLength) * size); // If this multiplication overflows, the Span we got overflows the entire address range. There's no happy outcome for this api in such a case so we choose not to take the overhead of checking. + ((uint)valueLength) * (nuint)sizeof(T)); // If this multiplication overflows, the Span we got overflows the entire address range. There's no happy outcome for this api in such a case so we choose not to take the overhead of checking. } return valueLength <= span.Length && SpanHelpers.SequenceEqual(ref MemoryMarshal.GetReference(span), ref MemoryMarshal.GetReference(value), valueLength); @@ -2319,12 +2314,11 @@ public static unsafe int SequenceCompareTo(this ReadOnlySpan span, ReadOnl int valueLength = value.Length; if (RuntimeHelpers.IsBitwiseEquatable()) { - nuint size = (nuint)sizeof(T); return valueLength <= spanLength && SpanHelpers.SequenceEqual( ref Unsafe.As(ref Unsafe.Add(ref MemoryMarshal.GetReference(span), (nint)(uint)(spanLength - valueLength) /* force zero-extension */)), ref Unsafe.As(ref MemoryMarshal.GetReference(value)), - ((uint)valueLength) * size); // If this multiplication overflows, the Span we got overflows the entire address range. There's no happy outcome for this api in such a case so we choose not to take the overhead of checking. + ((uint)valueLength) * (nuint)sizeof(T)); // If this multiplication overflows, the Span we got overflows the entire address range. There's no happy outcome for this api in such a case so we choose not to take the overhead of checking. } return valueLength <= spanLength && @@ -2344,12 +2338,11 @@ public static unsafe int SequenceCompareTo(this ReadOnlySpan span, ReadOnl int valueLength = value.Length; if (RuntimeHelpers.IsBitwiseEquatable()) { - nuint size = (nuint)sizeof(T); return valueLength <= spanLength && SpanHelpers.SequenceEqual( ref Unsafe.As(ref Unsafe.Add(ref MemoryMarshal.GetReference(span), (nint)(uint)(spanLength - valueLength) /* force zero-extension */)), ref Unsafe.As(ref MemoryMarshal.GetReference(value)), - ((uint)valueLength) * size); // If this multiplication overflows, the Span we got overflows the entire address range. There's no happy outcome for this api in such a case so we choose not to take the overhead of checking. + ((uint)valueLength) * (nuint)sizeof(T)); // If this multiplication overflows, the Span we got overflows the entire address range. There's no happy outcome for this api in such a case so we choose not to take the overhead of checking. } return valueLength <= spanLength && diff --git a/src/libraries/System.Private.CoreLib/src/System/SpanHelpers.Byte.cs b/src/libraries/System.Private.CoreLib/src/System/SpanHelpers.Byte.cs index 0a41cc7b8beb..c3d7cd2cd507 100644 --- a/src/libraries/System.Private.CoreLib/src/System/SpanHelpers.Byte.cs +++ b/src/libraries/System.Private.CoreLib/src/System/SpanHelpers.Byte.cs @@ -566,6 +566,7 @@ internal static unsafe int IndexOfNullByte(byte* searchSpace) // Optimized byte-based SequenceEquals. The "length" parameter for this one is declared a nuint rather than int as we also use it for types other than byte // where the length can exceed 2Gb once scaled by sizeof(T). + [Intrinsic] // Unrolled for constant length public static unsafe bool SequenceEqual(ref byte first, ref byte second, nuint length) { bool result; diff --git a/src/tests/JIT/opt/Vectorization/SpanHelpers_SequenceEqual.cs b/src/tests/JIT/opt/Vectorization/SpanHelpers_SequenceEqual.cs new file mode 100644 index 000000000000..722977a59d37 --- /dev/null +++ b/src/tests/JIT/opt/Vectorization/SpanHelpers_SequenceEqual.cs @@ -0,0 +1,45 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Linq; +using System.Reflection; +using System.Runtime.CompilerServices; + +public class UnrollSequenceEqualTests +{ + public static int Main() + { + var testMethods = typeof(UnrollSequenceEqualTests) + .GetMethods(BindingFlags.Static | BindingFlags.NonPublic) + .Where(m => m.Name.StartsWith("Test")); + + foreach (MethodInfo testMethod in testMethods) + if (!(bool)testMethod.Invoke(null, new object[] { "0123456789ABCDEF0"u8.ToArray() })) + throw new InvalidOperationException($"{testMethod.Name} returned false."); + + foreach (MethodInfo testMethod in testMethods) + if ((bool)testMethod.Invoke(null, new object[] { "123456789ABCDEF01"u8.ToArray() })) + throw new InvalidOperationException($"{testMethod.Name} returned true."); + + return 100; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] static bool Test1(byte[] data) => data.AsSpan().StartsWith("0"u8); + [MethodImpl(MethodImplOptions.AggressiveInlining)] static bool Test2(byte[] data) => data.AsSpan().StartsWith("01"u8); + [MethodImpl(MethodImplOptions.AggressiveInlining)] static bool Test3(byte[] data) => data.AsSpan().StartsWith("012"u8); + [MethodImpl(MethodImplOptions.AggressiveInlining)] static bool Test4(byte[] data) => data.AsSpan().StartsWith("0123"u8); + [MethodImpl(MethodImplOptions.AggressiveInlining)] static bool Test5(byte[] data) => data.AsSpan().StartsWith("01234"u8); + [MethodImpl(MethodImplOptions.AggressiveInlining)] static bool Test6(byte[] data) => data.AsSpan().StartsWith("012345"u8); + [MethodImpl(MethodImplOptions.AggressiveInlining)] static bool Test7(byte[] data) => data.AsSpan().StartsWith("0123456"u8); + [MethodImpl(MethodImplOptions.AggressiveInlining)] static bool Test8(byte[] data) => data.AsSpan().StartsWith("01234567"u8); + [MethodImpl(MethodImplOptions.AggressiveInlining)] static bool Test9(byte[] data) => data.AsSpan().StartsWith("012345678"u8); + [MethodImpl(MethodImplOptions.AggressiveInlining)] static bool Test10(byte[] data) => data.AsSpan().StartsWith("0123456789"u8); + [MethodImpl(MethodImplOptions.AggressiveInlining)] static bool Test11(byte[] data) => data.AsSpan().StartsWith("0123456789A"u8); + [MethodImpl(MethodImplOptions.AggressiveInlining)] static bool Test12(byte[] data) => data.AsSpan().StartsWith("0123456789AB"u8); + [MethodImpl(MethodImplOptions.AggressiveInlining)] static bool Test13(byte[] data) => data.AsSpan().StartsWith("0123456789ABC"u8); + [MethodImpl(MethodImplOptions.AggressiveInlining)] static bool Test14(byte[] data) => data.AsSpan().StartsWith("0123456789ABCD"u8); + [MethodImpl(MethodImplOptions.AggressiveInlining)] static bool Test15(byte[] data) => data.AsSpan().StartsWith("0123456789ABCDE"u8); + [MethodImpl(MethodImplOptions.AggressiveInlining)] static bool Test16(byte[] data) => data.AsSpan().StartsWith("0123456789ABCDEF"u8); + [MethodImpl(MethodImplOptions.AggressiveInlining)] static bool Test17(byte[] data) => data.AsSpan().StartsWith("0123456789ABCDEF0"u8); +} diff --git a/src/tests/JIT/opt/Vectorization/SpanHelpers_SequenceEqual.csproj b/src/tests/JIT/opt/Vectorization/SpanHelpers_SequenceEqual.csproj new file mode 100644 index 000000000000..6946bed81bfd --- /dev/null +++ b/src/tests/JIT/opt/Vectorization/SpanHelpers_SequenceEqual.csproj @@ -0,0 +1,9 @@ + + + Exe + True + + + + +