
Commit

Unroll SequenceEqual(ref byte, ref byte, nuint) in JIT (#83945)
Co-authored-by: Jakob Botsch Nielsen <Jakob.botsch.nielsen@gmail.com>
EgorBo and jakobbotsch committed Mar 29, 2023
1 parent 8ca896c commit dddf223
Showing 8 changed files with 266 additions and 22 deletions.
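
For context, a minimal C# sketch (not part of the commit; names are illustrative) of the call shape this change targets. SequenceEqual over bitwise-equatable elements funnels into SpanHelpers.SequenceEqual(ref byte, ref byte, nuint), and when the JIT can prove the byte count is a constant, for example after the length check against a UTF-8 literal of known size, lowering can now replace the helper call with a few loads and compares.

using System;

class UnrollDemo
{
    // "test"u8 has a compile-time length of 4, so on the path where the length
    // check passes, the byte count handed to SpanHelpers.SequenceEqual is the
    // constant 4 and the comparison can be unrolled.
    static bool IsTest(ReadOnlySpan<byte> span) => span.SequenceEqual("test"u8);

    static void Main()
    {
        Console.WriteLine(IsTest("test"u8)); // True
        Console.WriteLine(IsTest("te_t"u8)); // False
    }
}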
8 changes: 8 additions & 0 deletions src/coreclr/jit/importercalls.cpp
@@ -3782,6 +3782,7 @@ GenTree* Compiler::impIntrinsic(GenTree* newobjThis,
break;
}

case NI_System_SpanHelpers_SequenceEqual:
case NI_System_Buffer_Memmove:
{
// We'll try to unroll this in lower for constant input.
@@ -8139,6 +8140,13 @@ NamedIntrinsic Compiler::lookupNamedIntrinsic(CORINFO_METHOD_HANDLE method)
result = NI_System_Span_get_Length;
}
}
else if (strcmp(className, "SpanHelpers") == 0)
{
if (strcmp(methodName, "SequenceEqual") == 0)
{
result = NI_System_SpanHelpers_SequenceEqual;
}
}
else if (strcmp(className, "String") == 0)
{
if (strcmp(methodName, "Equals") == 0)
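
For contrast, a hedged example (illustrative names) of a case the importer still leaves as an ordinary call: when the byte count is not a JIT-time constant, LowerCallMemcmp below bails out and the vectorized managed implementation of SpanHelpers.SequenceEqual runs as before.

using System;

class NonConstantDemo
{
    // The spans' lengths are only known at run time, so no unrolling applies here.
    static bool SameBytes(ReadOnlySpan<byte> a, ReadOnlySpan<byte> b) => a.SequenceEqual(b);

    static void Main()
    {
        var x = new byte[1000];
        var y = new byte[1000];
        Console.WriteLine(SameBytes(x, y)); // True (both all zero)
    }
}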
202 changes: 194 additions & 8 deletions src/coreclr/jit/lower.cpp
@@ -1865,6 +1865,185 @@ GenTree* Lowering::LowerCallMemmove(GenTreeCall* call)
return nullptr;
}

//------------------------------------------------------------------------
// LowerCallMemcmp: Replace SpanHelpers.SequenceEqual(left, right, CNS_SIZE)
// with a series of merged comparisons (via GT_IND nodes)
//
// Arguments:
// tree - GenTreeCall node to unroll as memcmp
//
// Return Value:
// nullptr if no changes were made
//
GenTree* Lowering::LowerCallMemcmp(GenTreeCall* call)
{
JITDUMP("Considering Memcmp [%06d] for unrolling.. ", comp->dspTreeID(call))
assert(comp->lookupNamedIntrinsic(call->gtCallMethHnd) == NI_System_SpanHelpers_SequenceEqual);
assert(call->gtArgs.CountUserArgs() == 3);
assert(TARGET_POINTER_SIZE == 8);

if (!comp->opts.OptimizationEnabled())
{
JITDUMP("Optimizations aren't allowed - bail out.\n")
return nullptr;
}

if (comp->info.compHasNextCallRetAddr)
{
JITDUMP("compHasNextCallRetAddr=true so we won't be able to remove the call - bail out.\n")
return nullptr;
}

GenTree* lengthArg = call->gtArgs.GetUserArgByIndex(2)->GetNode();
if (lengthArg->IsIntegralConst())
{
ssize_t cnsSize = lengthArg->AsIntCon()->IconValue();
JITDUMP("Size=%ld.. ", (LONG)cnsSize);
// TODO-CQ: drop the whole thing in case of 0
if (cnsSize > 0)
{
GenTree* lArg = call->gtArgs.GetUserArgByIndex(0)->GetNode();
GenTree* rArg = call->gtArgs.GetUserArgByIndex(1)->GetNode();
// TODO: Add SIMD path for [16..128] via GT_HWINTRINSIC nodes
if (cnsSize <= 16)
{
unsigned loadWidth = 1 << BitOperations::Log2((unsigned)cnsSize);
var_types loadType;
if (loadWidth == 1)
{
loadType = TYP_UBYTE;
}
else if (loadWidth == 2)
{
loadType = TYP_USHORT;
}
else if (loadWidth == 4)
{
loadType = TYP_INT;
}
else if ((loadWidth == 8) || (loadWidth == 16))
{
loadWidth = 8;
loadType = TYP_LONG;
}
else
{
unreached();
}
var_types actualLoadType = genActualType(loadType);

GenTree* result = nullptr;

// loadWidth == cnsSize means a single load is enough for both args
if ((loadWidth == (unsigned)cnsSize) && (loadWidth <= 8))
{
// We're going to emit something like the following:
//
// bool result = *(int*)leftArg == *(int*)rightArg
//
// ^ in the given example we unroll for length=4
//
GenTree* lIndir = comp->gtNewIndir(loadType, lArg);
GenTree* rIndir = comp->gtNewIndir(loadType, rArg);
result = comp->gtNewOperNode(GT_EQ, TYP_INT, lIndir, rIndir);

BlockRange().InsertAfter(lArg, lIndir);
BlockRange().InsertAfter(rArg, rIndir);
BlockRange().InsertBefore(call, result);
}
else
{
// First, make both args multi-use:
LIR::Use lArgUse;
LIR::Use rArgUse;
bool lFoundUse = BlockRange().TryGetUse(lArg, &lArgUse);
bool rFoundUse = BlockRange().TryGetUse(rArg, &rArgUse);
assert(lFoundUse && rFoundUse);
GenTree* lArgClone = comp->gtNewLclvNode(lArgUse.ReplaceWithLclVar(comp), genActualType(lArg));
GenTree* rArgClone = comp->gtNewLclvNode(rArgUse.ReplaceWithLclVar(comp), genActualType(rArg));
BlockRange().InsertBefore(call, lArgClone, rArgClone);

// We're going to emit something like the following:
//
// bool result = ((*(int*)leftArg ^ *(int*)rightArg) |
// (*(int*)(leftArg + 1) ^ *((int*)(rightArg + 1)))) == 0;
//
// ^ in the given example we unroll for length=5
//
// In IR:
//
// * EQ int
// +--* OR int
// | +--* XOR int
// | | +--* IND int
// | | | \--* LCL_VAR byref V1
// | | \--* IND int
// | | \--* LCL_VAR byref V2
// | \--* XOR int
// | +--* IND int
// | | \--* ADD byref
// | | +--* LCL_VAR byref V1
// | | \--* CNS_INT int 1
// | \--* IND int
// | \--* ADD byref
// | +--* LCL_VAR byref V2
// | \--* CNS_INT int 1
// \--* CNS_INT int 0
//
GenTree* l1Indir = comp->gtNewIndir(loadType, lArgUse.Def());
GenTree* r1Indir = comp->gtNewIndir(loadType, rArgUse.Def());
GenTree* lXor = comp->gtNewOperNode(GT_XOR, actualLoadType, l1Indir, r1Indir);
GenTree* l2Offs = comp->gtNewIconNode(cnsSize - loadWidth, TYP_I_IMPL);
GenTree* l2AddOffs = comp->gtNewOperNode(GT_ADD, lArg->TypeGet(), lArgClone, l2Offs);
GenTree* l2Indir = comp->gtNewIndir(loadType, l2AddOffs);
GenTree* r2Offs = comp->gtCloneExpr(l2Offs); // offset is the same
GenTree* r2AddOffs = comp->gtNewOperNode(GT_ADD, rArg->TypeGet(), rArgClone, r2Offs);
GenTree* r2Indir = comp->gtNewIndir(loadType, r2AddOffs);
GenTree* rXor = comp->gtNewOperNode(GT_XOR, actualLoadType, l2Indir, r2Indir);
GenTree* resultOr = comp->gtNewOperNode(GT_OR, actualLoadType, lXor, rXor);
GenTree* zeroCns = comp->gtNewIconNode(0, actualLoadType);
result = comp->gtNewOperNode(GT_EQ, TYP_INT, resultOr, zeroCns);

BlockRange().InsertAfter(rArgClone, l1Indir, r1Indir, l2Offs, l2AddOffs);
BlockRange().InsertAfter(l2AddOffs, l2Indir, r2Offs, r2AddOffs, r2Indir);
BlockRange().InsertAfter(r2Indir, lXor, rXor, resultOr, zeroCns);
BlockRange().InsertAfter(zeroCns, result);
}

JITDUMP("\nUnrolled to:\n");
DISPTREE(result);

LIR::Use use;
if (BlockRange().TryGetUse(call, &use))
{
use.ReplaceWith(result);
}
BlockRange().Remove(lengthArg);
BlockRange().Remove(call);

// Remove all non-user args (e.g. r2r cell)
for (CallArg& arg : call->gtArgs.Args())
{
if (!arg.IsUserArg())
{
arg.GetNode()->SetUnusedValue();
}
}
return lArg;
}
}
else
{
JITDUMP("Size is either 0 or too big to unroll.\n")
}
}
else
{
JITDUMP("size is not a constant.\n")
}
return nullptr;
}
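
To make the two shapes above concrete, here is a hedged C# sketch (illustrative only, not what the JIT emits; names such as MemcmpSketch, Equal4 and Equal5 are invented for this note) of the value the unrolled IR computes: a single load per argument when the length is exactly 1, 2, 4 or 8 bytes, and two overlapping loads combined with XOR/OR for the remaining lengths up to 16.

using System;
using System.Runtime.CompilerServices;

static class MemcmpSketch
{
    // length == 4: one 4-byte load per argument, compared directly (the single-load case).
    static bool Equal4(ref byte left, ref byte right) =>
        Unsafe.ReadUnaligned<uint>(ref left) == Unsafe.ReadUnaligned<uint>(ref right);

    // length == 5: two overlapping 4-byte loads per argument, at offsets 0 and
    // cnsSize - loadWidth (= 1 here); any differing byte makes the OR of the XORs non-zero.
    static bool Equal5(ref byte left, ref byte right)
    {
        uint l0 = Unsafe.ReadUnaligned<uint>(ref left);
        uint r0 = Unsafe.ReadUnaligned<uint>(ref right);
        uint l1 = Unsafe.ReadUnaligned<uint>(ref Unsafe.Add(ref left, 1));
        uint r1 = Unsafe.ReadUnaligned<uint>(ref Unsafe.Add(ref right, 1));
        return ((l0 ^ r0) | (l1 ^ r1)) == 0;
    }

    static void Main()
    {
        byte[] a = { 1, 2, 3, 4, 5 };
        byte[] b = { 1, 2, 3, 4, 5 };
        byte[] c = { 1, 2, 3, 4, 9 };
        Console.WriteLine(Equal4(ref a[0], ref b[0])); // True
        Console.WriteLine(Equal5(ref a[0], ref c[0])); // False (last byte differs)
    }
}

The overlapping second load at offset cnsSize - loadWidth is what lets a fixed pair of loads cover any length up to twice the load width without a loop or a byte-wise tail compare.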

// do lowering steps for a call
// this includes:
// - adding the placement nodes (either stack or register variety) for arguments
@@ -1883,19 +2062,26 @@ GenTree* Lowering::LowerCall(GenTree* node)
// All runtime lookups are expected to be expanded in fgExpandRuntimeLookups
assert(!call->IsExpRuntimeLookup());

+ #if defined(TARGET_AMD64) || defined(TARGET_ARM64)
if (call->gtCallMoreFlags & GTF_CALL_M_SPECIAL_INTRINSIC)
{
- #if defined(TARGET_AMD64) || defined(TARGET_ARM64)
- if (comp->lookupNamedIntrinsic(call->gtCallMethHnd) == NI_System_Buffer_Memmove)
+ GenTree* newNode = nullptr;
+ NamedIntrinsic ni = comp->lookupNamedIntrinsic(call->gtCallMethHnd);
+ if (ni == NI_System_Buffer_Memmove)
{
- GenTree* newNode = LowerCallMemmove(call);
- if (newNode != nullptr)
- {
- return newNode->gtNext;
- }
+ newNode = LowerCallMemmove(call);
+ }
+ else if (ni == NI_System_SpanHelpers_SequenceEqual)
+ {
+ newNode = LowerCallMemcmp(call);
+ }
+
+ if (newNode != nullptr)
+ {
+ return newNode->gtNext;
+ }
- #endif
}
+ #endif

call->ClearOtherRegs();
LowerArgsForCall(call);
1 change: 1 addition & 0 deletions src/coreclr/jit/lower.h
@@ -128,6 +128,7 @@ class Lowering final : public Phase
// ------------------------------
GenTree* LowerCall(GenTree* call);
GenTree* LowerCallMemmove(GenTreeCall* call);
GenTree* LowerCallMemcmp(GenTreeCall* call);
void LowerCFGCall(GenTreeCall* call);
void MoveCFGCallArg(GenTreeCall* call, GenTree* node);
#ifndef TARGET_64BIT
1 change: 1 addition & 0 deletions src/coreclr/jit/namedintrinsiclist.h
@@ -104,6 +104,7 @@ enum NamedIntrinsic : unsigned short
NI_System_String_StartsWith,
NI_System_Span_get_Item,
NI_System_Span_get_Length,
NI_System_SpanHelpers_SequenceEqual,
NI_System_ReadOnlySpan_get_Item,
NI_System_ReadOnlySpan_get_Length,

@@ -1429,12 +1429,11 @@ private static void ThrowNullLowHighInclusive<T>(T? lowInclusive, T? highInclusi

if (RuntimeHelpers.IsBitwiseEquatable<T>())
{
- nuint size = (nuint)sizeof(T);
return length == other.Length &&
SpanHelpers.SequenceEqual(
ref Unsafe.As<T, byte>(ref MemoryMarshal.GetReference(span)),
ref Unsafe.As<T, byte>(ref MemoryMarshal.GetReference(other)),
- ((uint)length) * size); // If this multiplication overflows, the Span we got overflows the entire address range. There's no happy outcome for this api in such a case so we choose not to take the overhead of checking.
+ ((uint)length) * (nuint)sizeof(T)); // If this multiplication overflows, the Span we got overflows the entire address range. There's no happy outcome for this api in such a case so we choose not to take the overhead of checking.
}

return length == other.Length && SpanHelpers.SequenceEqual(ref MemoryMarshal.GetReference(span), ref MemoryMarshal.GetReference(other), length);
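
As an aside, the change at this call site (and the analogous ones below) only inlines the byte-count computation; the expression itself is unchanged. A small hedged illustration of that expression (ByteCount is a hypothetical helper, not library code): the element count is zero-extended and then scaled by the element size, and per the comment above no overflow check is taken because an overflowing product would imply a span covering the whole address space.

using System;

class ByteCountDemo
{
    // Same shape as the call sites above: zero-extend the length, then scale it.
    static nuint ByteCount(int length, nuint elementSize) => (uint)length * elementSize;

    static void Main()
    {
        Console.WriteLine(ByteCount(10, sizeof(int)));  // 40 bytes for 10 ints
        Console.WriteLine(ByteCount(3, sizeof(long)));  // 24 bytes for 3 longs
    }
}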
@@ -2164,12 +2163,11 @@ public static int SequenceCompareTo<T>(this Span<T> span, ReadOnlySpan<T> other)
int length = span.Length;
if (RuntimeHelpers.IsBitwiseEquatable<T>())
{
- nuint size = (nuint)sizeof(T);
return length == other.Length &&
SpanHelpers.SequenceEqual(
ref Unsafe.As<T, byte>(ref MemoryMarshal.GetReference(span)),
ref Unsafe.As<T, byte>(ref MemoryMarshal.GetReference(other)),
- ((uint)length) * size); // If this multiplication overflows, the Span we got overflows the entire address range. There's no happy outcome for this API in such a case so we choose not to take the overhead of checking.
+ ((uint)length) * (nuint)sizeof(T)); // If this multiplication overflows, the Span we got overflows the entire address range. There's no happy outcome for this API in such a case so we choose not to take the overhead of checking.
}

return length == other.Length && SpanHelpers.SequenceEqual(ref MemoryMarshal.GetReference(span), ref MemoryMarshal.GetReference(other), length);
@@ -2207,11 +2205,10 @@ public static unsafe bool SequenceEqual<T>(this ReadOnlySpan<T> span, ReadOnlySp
// If no comparer was supplied and the type is bitwise equatable, take the fast path doing a bitwise comparison.
if (RuntimeHelpers.IsBitwiseEquatable<T>())
{
- nuint size = (nuint)sizeof(T);
return SpanHelpers.SequenceEqual(
ref Unsafe.As<T, byte>(ref MemoryMarshal.GetReference(span)),
ref Unsafe.As<T, byte>(ref MemoryMarshal.GetReference(other)),
- ((uint)span.Length) * size); // If this multiplication overflows, the Span we got overflows the entire address range. There's no happy outcome for this API in such a case so we choose not to take the overhead of checking.
+ ((uint)span.Length) * (nuint)sizeof(T)); // If this multiplication overflows, the Span we got overflows the entire address range. There's no happy outcome for this API in such a case so we choose not to take the overhead of checking.
}

// Otherwise, compare each element using EqualityComparer<T>.Default.Equals in a way that will enable it to devirtualize.
@@ -2277,12 +2274,11 @@ public static unsafe int SequenceCompareTo<T>(this ReadOnlySpan<T> span, ReadOnl
int valueLength = value.Length;
if (RuntimeHelpers.IsBitwiseEquatable<T>())
{
- nuint size = (nuint)sizeof(T);
return valueLength <= span.Length &&
SpanHelpers.SequenceEqual(
ref Unsafe.As<T, byte>(ref MemoryMarshal.GetReference(span)),
ref Unsafe.As<T, byte>(ref MemoryMarshal.GetReference(value)),
- ((uint)valueLength) * size); // If this multiplication overflows, the Span we got overflows the entire address range. There's no happy outcome for this api in such a case so we choose not to take the overhead of checking.
+ ((uint)valueLength) * (nuint)sizeof(T)); // If this multiplication overflows, the Span we got overflows the entire address range. There's no happy outcome for this api in such a case so we choose not to take the overhead of checking.
}

return valueLength <= span.Length && SpanHelpers.SequenceEqual(ref MemoryMarshal.GetReference(span), ref MemoryMarshal.GetReference(value), valueLength);
@@ -2298,12 +2294,11 @@ public static unsafe int SequenceCompareTo<T>(this ReadOnlySpan<T> span, ReadOnl
int valueLength = value.Length;
if (RuntimeHelpers.IsBitwiseEquatable<T>())
{
- nuint size = (nuint)sizeof(T);
return valueLength <= span.Length &&
SpanHelpers.SequenceEqual(
ref Unsafe.As<T, byte>(ref MemoryMarshal.GetReference(span)),
ref Unsafe.As<T, byte>(ref MemoryMarshal.GetReference(value)),
- ((uint)valueLength) * size); // If this multiplication overflows, the Span we got overflows the entire address range. There's no happy outcome for this api in such a case so we choose not to take the overhead of checking.
+ ((uint)valueLength) * (nuint)sizeof(T)); // If this multiplication overflows, the Span we got overflows the entire address range. There's no happy outcome for this api in such a case so we choose not to take the overhead of checking.
}

return valueLength <= span.Length && SpanHelpers.SequenceEqual(ref MemoryMarshal.GetReference(span), ref MemoryMarshal.GetReference(value), valueLength);
@@ -2319,12 +2314,11 @@ public static unsafe int SequenceCompareTo<T>(this ReadOnlySpan<T> span, ReadOnl
int valueLength = value.Length;
if (RuntimeHelpers.IsBitwiseEquatable<T>())
{
- nuint size = (nuint)sizeof(T);
return valueLength <= spanLength &&
SpanHelpers.SequenceEqual(
ref Unsafe.As<T, byte>(ref Unsafe.Add(ref MemoryMarshal.GetReference(span), (nint)(uint)(spanLength - valueLength) /* force zero-extension */)),
ref Unsafe.As<T, byte>(ref MemoryMarshal.GetReference(value)),
- ((uint)valueLength) * size); // If this multiplication overflows, the Span we got overflows the entire address range. There's no happy outcome for this api in such a case so we choose not to take the overhead of checking.
+ ((uint)valueLength) * (nuint)sizeof(T)); // If this multiplication overflows, the Span we got overflows the entire address range. There's no happy outcome for this api in such a case so we choose not to take the overhead of checking.
}

return valueLength <= spanLength &&
@@ -2344,12 +2338,11 @@ public static unsafe int SequenceCompareTo<T>(this ReadOnlySpan<T> span, ReadOnl
int valueLength = value.Length;
if (RuntimeHelpers.IsBitwiseEquatable<T>())
{
- nuint size = (nuint)sizeof(T);
return valueLength <= spanLength &&
SpanHelpers.SequenceEqual(
ref Unsafe.As<T, byte>(ref Unsafe.Add(ref MemoryMarshal.GetReference(span), (nint)(uint)(spanLength - valueLength) /* force zero-extension */)),
ref Unsafe.As<T, byte>(ref MemoryMarshal.GetReference(value)),
- ((uint)valueLength) * size); // If this multiplication overflows, the Span we got overflows the entire address range. There's no happy outcome for this api in such a case so we choose not to take the overhead of checking.
+ ((uint)valueLength) * (nuint)sizeof(T)); // If this multiplication overflows, the Span we got overflows the entire address range. There's no happy outcome for this api in such a case so we choose not to take the overhead of checking.
}

return valueLength <= spanLength &&
@@ -566,6 +566,7 @@ internal static unsafe int IndexOfNullByte(byte* searchSpace)

// Optimized byte-based SequenceEquals. The "length" parameter for this one is declared a nuint rather than int as we also use it for types other than byte
// where the length can exceed 2Gb once scaled by sizeof(T).
[Intrinsic] // Unrolled for constant length
public static unsafe bool SequenceEqual(ref byte first, ref byte second, nuint length)
{
bool result;
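
The nuint length and the comment above exist because callers pass a byte count: for element types wider than byte, the count is the element count scaled by sizeof(T) and can exceed int.MaxValue. A hedged sketch of the same reinterpret-and-scale idea used by the call sites earlier in the diff (BitwiseEquals is an invented helper name):

using System;
using System.Runtime.InteropServices;

class NuintLengthDemo
{
    // Reinterpret both spans as raw bytes so the byte-based comparison
    // (ultimately the method annotated [Intrinsic] above) does the work.
    static bool BitwiseEquals<T>(ReadOnlySpan<T> a, ReadOnlySpan<T> b) where T : unmanaged =>
        a.Length == b.Length &&
        MemoryMarshal.AsBytes(a).SequenceEqual(MemoryMarshal.AsBytes(b));

    static void Main()
    {
        Console.WriteLine(BitwiseEquals<int>(new[] { 1, 2, 3 }, new[] { 1, 2, 3 })); // True
        Console.WriteLine(BitwiseEquals<int>(new[] { 1, 2, 3 }, new[] { 1, 2, 4 })); // False
    }
}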
