Skip to content

Commit

Permalink
Support for various query root scenarios + cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
roji committed Mar 25, 2024
1 parent ed42d8e commit d68d249
Show file tree
Hide file tree
Showing 5 changed files with 163 additions and 57 deletions.
36 changes: 24 additions & 12 deletions src/EFCore.Design/Query/Internal/CSharpToLinqTranslator.cs
Expand Up @@ -351,13 +351,19 @@ public override Expression VisitIdentifierName(IdentifierNameSyntax identifierNa

var symbol = _semanticModel.GetSymbolInfo(identifierName).Symbol;

ILocalSymbol localSymbol;
ITypeSymbol typeSymbol;
switch (symbol)
{
case INamedTypeSymbol typeSymbol:
return Constant(ResolveType(typeSymbol));
case ILocalSymbol ls:
localSymbol = ls;
case INamedTypeSymbol s:
return Constant(ResolveType(s));
case ILocalSymbol s:
typeSymbol = s.Type;
break;
case IFieldSymbol s:
typeSymbol = s.Type;
break;
case IPropertySymbol s:
typeSymbol = s.Type;
break;
case null:
throw new InvalidOperationException($"Identifier without symbol: {identifierName}");
Expand All @@ -366,24 +372,24 @@ public override Expression VisitIdentifierName(IdentifierNameSyntax identifierNa
}

// TODO: Separate out EF Core-specific logic (EF Core would extend this visitor)
if (localSymbol.Type.Name.Contains("DbSet"))
if (typeSymbol.Name.Contains("DbSet"))
{
throw new NotImplementedException("DbSet local symbol");
}

// We have an identifier which isn't in our parameters stack.

// First, if the identifier type is the user's DbContext type, return a constant over that.
if (localSymbol.Type.Equals(_userDbContextSymbol, SymbolEqualityComparer.Default))
// First, if the identifier type is the user's DbContext type (e.g. DbContext local variable, or field/property),
// return a constant over that.
if (typeSymbol.Equals(_userDbContextSymbol, SymbolEqualityComparer.Default))
{
// This is a local DbContext variable.
return Constant(_userDbContext);
}

// The Translate entry point into the translator uses Roslyn's data flow analysis to locate all captured variables, and populates
// the _capturedVariable dictionary with them (with null values).
// TODO: Test closure over class member (not local variable)
if (_capturedVariables.TryGetValue(localSymbol, out var memberExpression))
if (symbol is ILocalSymbol localSymbol && _capturedVariables.TryGetValue(localSymbol, out var memberExpression))
{
// The first time we see a captured variable, we create MemberExpression for it and cache it in _capturedVariables.
return memberExpression
Expand Down Expand Up @@ -450,6 +456,13 @@ public override Expression VisitInvocationExpression(InvocationExpressionSyntax
throw new InvalidOperationException("Could not find symbol for method invocation: " + invocation);
}

// First, if the identifier type is the user's DbContext type (e.g. DbContext local variable, or field/property),
// return a constant over that.
if (methodSymbol.ReturnType.Equals(_userDbContextSymbol, SymbolEqualityComparer.Default))
{
return Constant(_userDbContext);
}

var declaringType = ResolveType(methodSymbol.ContainingType);

Expression? instance = null;
Expand Down Expand Up @@ -557,12 +570,11 @@ public override Expression VisitInvocationExpression(InvocationExpressionSyntax
else
{
// Non-generic method

// TODO: private/internal binding flags
var reducedMethodSymbol = methodSymbol.ReducedFrom ?? methodSymbol;

methodInfo = declaringType.GetMethod(
methodSymbol.Name,
BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance | BindingFlags.Static,
reducedMethodSymbol.Parameters.Select(p => ResolveType(p.Type)).ToArray());

if (methodInfo is null)
Expand Down
44 changes: 20 additions & 24 deletions src/EFCore.Design/Query/Internal/PrecompiledQueryCodeGenerator.cs
Expand Up @@ -258,7 +258,11 @@ public PrecompiledQueryCodeGenerator()
{
ProcessQueryOperator(
_code, semanticModel, querySyntax, queryTree, queryNum + 1, operatorNum: out _, isTerminatingOperator: true,
cancellationToken, queryExecutor);
cancellationToken);

// After the terminating operator definition, generate the query's executor generator.
var variableNames = new HashSet<string>(); // TODO
GenerateQueryExecutor(_code, queryNum + 1, queryExecutor, _namespaces, _unsafeAccessors, variableNames);
}
finally
{
Expand Down Expand Up @@ -430,8 +434,7 @@ protected virtual Expression CompileQuery(MethodCallExpression terminatingOperat
int queryNum,
out int operatorNum,
bool isTerminatingOperator,
CancellationToken cancellationToken,
Expression? queryExecutor = null)
CancellationToken cancellationToken)
{
var memberAccess = (MemberAccessExpressionSyntax)operatorSyntax.Expression;

Expand Down Expand Up @@ -474,17 +477,17 @@ protected virtual Expression CompileQuery(MethodCallExpression terminatingOperat
// TODO: Validate the below, throw informative (e.g. top-level TVF fails here because non-generic)
var reducedOperatorSymbol = operatorSymbol.GetConstructedReducedFrom() ?? operatorSymbol;

var (sourceIdentifier, sourceType) = reducedOperatorSymbol.IsStatic
? (_g.IdentifierName(reducedOperatorSymbol.Parameters[0].Name), reducedOperatorSymbol.Parameters[0].Type)
: (_g.ThisExpression(), reducedOperatorSymbol.ReceiverType!);
var (sourceVariableName, sourceTypeSymbol) = reducedOperatorSymbol.IsStatic
? (reducedOperatorSymbol.Parameters[0].Name, reducedOperatorSymbol.Parameters[0].Type)
: ("this", reducedOperatorSymbol.ReceiverType!);

// var sourceParameter = reducedOperatorSymbol.IsStatic ? reducedOperatorSymbol.Parameters[0] : reducedOperatorSymbol.ReceiverType;
// var sourceParameterIdentifier = _g.IdentifierName(sourceParameter.Name);
if (sourceType is not INamedTypeSymbol { TypeArguments: [var sourceElementTypeSymbol]})
if (sourceTypeSymbol is not INamedTypeSymbol { TypeArguments: [var sourceElementTypeSymbol]})
{
throw new UnreachableException($"Non-IQueryable first parameter in LINQ operator '{operatorLinq.Method.Name}'");
}

var sourceElementTypeName = sourceElementTypeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat);

var returnTypeSymbol = reducedOperatorSymbol.ReturnType;

// Unwrap Task<T> to get the element type (e.g. Task<List<int>>)
Expand Down Expand Up @@ -530,8 +533,8 @@ when namedReturnType2.AllInterfaces.Prepend(namedReturnType2)
code.AppendLine(
"var precompiledQueryContext = "
+ (nestedOperatorSyntax is null
? $"new PrecompiledQueryContext<{sourceElementTypeSymbol.Name}>(((IDbContextContainer){sourceIdentifier}).DbContext);"
: $"(PrecompiledQueryContext<{sourceElementTypeSymbol.Name}>){sourceIdentifier};"));
? $"new PrecompiledQueryContext<{sourceElementTypeName}>(((IDbContextContainer){sourceVariableName}).DbContext);"
: $"(PrecompiledQueryContext<{sourceElementTypeName}>){sourceVariableName};"));

// Go over the operator's arguments (skipping the first, which is the source).
// For those which have captured variables, run them through our funcletizer, which will return code for extracting any captured
Expand Down Expand Up @@ -641,7 +644,7 @@ void GenerateCapturedVariableExtractors(string currentIdentifier, Type currentTy
{
// The query returns a scalar, not an enumerable (e.g. the terminating operator is Max()).
// The executor directly returns the needed result (e.g. int), so just return that.
var returnType = _g.TypeExpression(returnTypeSymbol).NormalizeWhitespace().ToFullString();
var returnType = returnTypeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat);
code.AppendLine($"return ((Func<QueryContext, {returnType}>)({executorFieldIdentifier}))(queryContext);");
}
else
Expand All @@ -663,8 +666,8 @@ void GenerateCapturedVariableExtractors(string currentIdentifier, Type currentTy
&& operatorLinq.Type.GetGenericTypeDefinition() == typeof(IQueryable<>);

var returnValue = isAsync
? $"IAsyncEnumerable<{sourceElementTypeSymbol.Name}>"
: $"IEnumerable<{sourceElementTypeSymbol.Name}>";
? $"IAsyncEnumerable<{sourceElementTypeName}>"
: $"IEnumerable<{sourceElementTypeName}>";

code.AppendLine(
$"var queryingEnumerable = ((Func<QueryContext, {returnValue}>)({executorFieldIdentifier}))(queryContext);");
Expand Down Expand Up @@ -692,7 +695,7 @@ void GenerateCapturedVariableExtractors(string currentIdentifier, Type currentTy
// TODO: This is an additional runtime allocation; if we had System.Linq.Async we wouldn't need this. We could
// have additional versions of all async terminating operators over IAsyncEnumerable<T> (effectively duplicating
// System.Linq.Async) as an alternative.
code.AppendLine($"var asyncQueryingEnumerable = new PrecompiledQueryableAsyncEnumerableAdapter<{sourceElementTypeSymbol}>(queryingEnumerable);");
code.AppendLine($"var asyncQueryingEnumerable = new PrecompiledQueryableAsyncEnumerableAdapter<{sourceElementTypeName}>(queryingEnumerable);");
code.Append("return asyncQueryingEnumerable");
}
else
Expand Down Expand Up @@ -725,24 +728,17 @@ void GenerateCapturedVariableExtractors(string currentIdentifier, Type currentTy
|| returnTypeSymbol.OriginalDefinition.Equals(_symbols.IOrderedQueryable, SymbolEqualityComparer.Default)
=> SymbolEqualityComparer.Default.Equals(sourceElementTypeSymbol, returnElementTypeSymbol)
? "return precompiledQueryContext;"
: $"return precompiledQueryContext.ToType<{returnElementTypeSymbol}>();",
: $"return precompiledQueryContext.ToType<{returnElementTypeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)}>();",

_ when returnTypeSymbol.OriginalDefinition.Equals(_symbols.IIncludableQueryable, SymbolEqualityComparer.Default)
&& returnTypeSymbol is INamedTypeSymbol { OriginalDefinition.TypeArguments: [_, var includedPropertySymbol] }
=> $"return precompiledQueryContext.ToIncludable<{includedPropertySymbol}>();",
=> $"return precompiledQueryContext.ToIncludable<{includedPropertySymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)}>();",

_ => throw new UnreachableException()
});
}

code.DecrementIndent().AppendLine("}").AppendLine();

// After the terminating operator definition, generate the query's executor generator.
if (isTerminatingOperator)
{
var variableNames = new HashSet<string>(); // TODO
GenerateQueryExecutor(code, queryNum, queryExecutor!, _namespaces, _unsafeAccessors, variableNames);
}
}

private void GenerateQueryExecutor(
Expand Down
Expand Up @@ -10,8 +10,6 @@

namespace Microsoft.EntityFrameworkCore.Query;

#nullable enable

// ReSharper disable InconsistentNaming

public class PrecompiledQueryRelationalTestBase
Expand Down Expand Up @@ -836,11 +834,47 @@ public virtual Task DbContext_as_local_variable()

[ConditionalFact]
public virtual Task DbContext_as_field()
=> throw new NotImplementedException();
=> FullSourceTest(
"""
public static class TestContainer
{
private static PrecompiledQueryContext _context;

public static async Task Test(DbContextOptions dbContextOptions)
{
using (_context = new PrecompiledQueryContext(dbContextOptions))
{
var blogs = await _context.Blogs.ToListAsync();
Assert.Collection(
blogs.OrderBy(b => b.Id),
b => Assert.Equal(8, b.Id),
b => Assert.Equal(9, b.Id));
}
}
}
""");

[ConditionalFact]
public virtual Task DbContext_as_property()
=> throw new NotImplementedException();
=> FullSourceTest(
"""
public static class TestContainer
{
private static PrecompiledQueryContext Context { get; set; }

public static async Task Test(DbContextOptions dbContextOptions)
{
using (Context = new PrecompiledQueryContext(dbContextOptions))
{
var blogs = await Context.Blogs.ToListAsync();
Assert.Collection(
blogs.OrderBy(b => b.Id),
b => Assert.Equal(8, b.Id),
b => Assert.Equal(9, b.Id));
}
}
}
""");

[ConditionalFact]
public virtual Task DbContext_as_captured_variable()
Expand All @@ -852,7 +886,28 @@ public virtual Task DbContext_as_captured_variable()

[ConditionalFact]
public virtual Task DbContext_as_method_invocation_result()
=> throw new NotImplementedException();
=> FullSourceTest(
"""
public static class TestContainer
{
private static PrecompiledQueryContext _context;

public static async Task Test(DbContextOptions dbContextOptions)
{
using (_context = new PrecompiledQueryContext(dbContextOptions))
{
var blogs = await GetContext().Blogs.ToListAsync();
Assert.Collection(
blogs.OrderBy(b => b.Id),
b => Assert.Equal(8, b.Id),
b => Assert.Equal(9, b.Id));
}
}

private static PrecompiledQueryContext GetContext()
=> _context;
}
""");

#endregion Different DbContext expressions

Expand Down Expand Up @@ -1037,6 +1092,21 @@ protected void AssertSql(params string[] expected)
AlwaysPrintGeneratedSources,
callerName);

protected virtual Task FullSourceTest(
string sourceCode,
Action<string>? interceptorCodeAsserter = null,
Action<List<PrecompiledQueryCodeGenerator.QueryPrecompilationError>>? errorAsserter = null,
[CallerMemberName] string callerName = "")
=> Fixture.PrecompiledQueryTestHelpers.FullSourceTest(
sourceCode,
Fixture.ServiceProvider.GetRequiredService<DbContextOptions>(),
typeof(PrecompiledQueryContext),
interceptorCodeAsserter,
errorAsserter,
TestOutputHelper,
AlwaysPrintGeneratedSources,
callerName);

protected virtual bool AlwaysPrintGeneratedSources
=> false;

Expand Down
Expand Up @@ -14,16 +14,38 @@

namespace Microsoft.EntityFrameworkCore.TestUtilities;

#nullable enable

public abstract class PrecompiledQueryTestHelpers
{
private readonly MetadataReference[] _metadataReferences;

protected PrecompiledQueryTestHelpers()
=> _metadataReferences = BuildMetadataReferences().ToArray();

public async Task Test(
public Task Test(
string sourceCode,
DbContextOptions dbContextOptions,
Type dbContextType,
Action<string>? interceptorCodeAsserter,
Action<List<PrecompiledQueryCodeGenerator.QueryPrecompilationError>>? errorAsserter,
ITestOutputHelper testOutputHelper,
bool alwaysPrintGeneratedSources,
string callerName)
{
var source = $$"""
public static class TestContainer
{
public static async Task Test(DbContextOptions dbContextOptions)
{
{{sourceCode}}
}
}
""";
return FullSourceTest(
source, dbContextOptions, dbContextType, interceptorCodeAsserter, errorAsserter, testOutputHelper, alwaysPrintGeneratedSources,
callerName);
}

public async Task FullSourceTest(
string sourceCode,
DbContextOptions dbContextOptions,
Type dbContextType,
Expand All @@ -40,7 +62,7 @@ protected PrecompiledQueryTestHelpers()
// EF LINQ queries.
// 3. Integrate the additional syntax trees into the compilation, and again, produce an assembly from it and load it.
// 4. Use reflection to find the EntryPoint (Main method) on this assembly, and invoke it.
var source = $$"""
var source = $"""
using System;
using System.Collections.Generic;
using System.Linq;
Expand All @@ -53,13 +75,7 @@ protected PrecompiledQueryTestHelpers()
using static Microsoft.EntityFrameworkCore.Query.PrecompiledQueryRelationalTestBase;
//using Microsoft.EntityFrameworkCore.PrecompiledQueryTest;

public static class TestContainer
{
public static async Task Test(DbContextOptions dbContextOptions)
{
{{sourceCode}}
}
}
{sourceCode}
""";

// This turns on the interceptors feature for the designated namespace(s).
Expand Down

0 comments on commit d68d249

Please sign in to comment.