Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Translate non-aggregate string.Join to CONCAT_WS on SQL Server #28900

Draft
wants to merge 1 commit into
base: release/7.0
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -43,6 +43,9 @@ public class SqlServerSqlTranslatingExpressionVisitor : RelationalSqlTranslating
ExpressionType.Modulo
};

private static readonly MethodInfo StringJoinMethodInfo
= typeof(string).GetRuntimeMethod(nameof(string.Join), new[] { typeof(string), typeof(string[]) })!;

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
Expand Down Expand Up @@ -152,6 +155,97 @@ protected override Expression VisitUnary(UnaryExpression unaryExpression)
return base.VisitUnary(unaryExpression);
}

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression)
{
var translation = base.VisitMethodCall(methodCallExpression);

if (translation != QueryCompilationContext.NotTranslatedExpression)
{
return translation;
}

if (methodCallExpression.Method == StringJoinMethodInfo)
{
if (methodCallExpression.Arguments[1] is not NewArrayExpression newArrayExpression)
{
return QueryCompilationContext.NotTranslatedExpression;
}

var sqlArguments = new SqlExpression[newArrayExpression.Expressions.Count + 1];

if (TranslationFailed(methodCallExpression.Arguments[0], Visit(methodCallExpression.Arguments[0]), out var sqlDelimiter))
{
return QueryCompilationContext.NotTranslatedExpression;
}

sqlArguments[0] = sqlDelimiter!;

var isUnicode = sqlDelimiter!.TypeMapping?.IsUnicode == true;

for (var i = 0; i < newArrayExpression.Expressions.Count; i++)
{
var argument = newArrayExpression.Expressions[i];
if (TranslationFailed(argument, Visit(argument), out var sqlArgument))
{
return QueryCompilationContext.NotTranslatedExpression;
}

// CONCAT_WS returns a type with a length that varies based on actual inputs (i.e. the sum of all argument lengths, plus
// the length needed for the delimiters). We don't know column values (or even parameter values, so we always return max.
// We do vary return varchar(max) or nvarchar(max) based on whether we saw any nvarchar mapping.
if (sqlArgument!.TypeMapping?.IsUnicode == true)
{
isUnicode = true;
}

// CONCAT_WS filters out nulls, but string.Join treats them as empty strings; coalesce unless we know we have a non-nullable
// argument.
sqlArguments[i + 1] = sqlArgument switch
{
ColumnExpression { IsNullable: false } => sqlArgument,
SqlConstantExpression constantExpression => constantExpression.Value is null
? new SqlConstantExpression(Expression.Constant(string.Empty, typeof(string)), null)
: constantExpression,
_ => Dependencies.SqlExpressionFactory.Coalesce(
sqlArgument,
Dependencies.SqlExpressionFactory.Constant(string.Empty, typeof(string)))
};
}

// CONCAT_WS never returns null; a null delimiter is interpreted as an empty string, and null arguments are skipped
// (but we coalesce them above in any case).
return Dependencies.SqlExpressionFactory.Function(
"CONCAT_WS",
sqlArguments,
nullable: false,
argumentsPropagateNullability: new bool[sqlArguments.Length],
methodCallExpression.Method.ReturnType,
Dependencies.TypeMappingSource.FindMapping(isUnicode ? "nvarchar(max)" : "varchar(max)"));
}

return QueryCompilationContext.NotTranslatedExpression;
}

private static string? GetProviderType(SqlExpression expression)
=> expression.TypeMapping?.StoreType;

[DebuggerStepThrough]
private static bool TranslationFailed(Expression? original, Expression? translation, out SqlExpression? castTranslation)
{
if (original != null
&& translation is not SqlExpression)
{
castTranslation = null;
return true;
}

castTranslation = translation as SqlExpression;
return false;
}
}
Expand Up @@ -1459,6 +1459,9 @@ public override Task String_Join_with_ordering(bool async)
public override Task String_Join_over_nullable_column(bool async)
=> AssertTranslationFailed(() => base.String_Join_over_nullable_column(async));

public override Task String_Join_non_aggregate(bool async)
=> AssertTranslationFailed(() => base.String_Join_non_aggregate(async));

public override Task String_Concat(bool async)
=> AssertTranslationFailed(() => base.String_Concat(async));

Expand Down
Expand Up @@ -244,6 +244,18 @@ public virtual Task String_Join_over_nullable_column(bool async)
a.Regions.Split("|").OrderBy(id => id).ToArray());
});

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task String_Join_non_aggregate(bool async)
{
var foo = "foo";

return AssertQuery(
async,
ss => ss.Set<Customer>().Where(c => string.Join("|", c.CompanyName, foo, "bar") == "Around the Horn|foo|bar"),
entryCount: 1);
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task String_Concat(bool async)
Expand Down
Expand Up @@ -284,6 +284,19 @@ public override async Task String_Join_with_ordering(bool async)
GROUP BY [c].[City]");
}

[SqlServerCondition(SqlServerCondition.SupportsFunctions2017)]
public override async Task String_Join_non_aggregate(bool async)
{
await base.String_Join_non_aggregate(async);

AssertSql(
@"@__foo_0='foo' (Size = 4000)

SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region]
FROM [Customers] AS [c]
WHERE CONCAT_WS(N'|', COALESCE([c].[CompanyName], N''), COALESCE(@__foo_0, N''), N'bar') = N'Around the Horn|foo|bar'");
}

[SqlServerCondition(SqlServerCondition.SupportsFunctions2017)]
public override async Task String_Concat(bool async)
{
Expand Down
Expand Up @@ -372,6 +372,9 @@ public override async Task String_Join_with_ordering(bool async)
ORDER BY ""t"".""City"", ""c0"".""CustomerID"" DESC");
}

public override Task String_Join_non_aggregate(bool async)
=> AssertTranslationFailed(() => base.String_Join_non_aggregate(async));

public override async Task String_Concat(bool async)
{
await base.String_Concat(async);
Expand Down