Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use provider discriminator values for OfType() #32878

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -1147,13 +1147,28 @@ SqlExpression GeneratePredicateTpt(StructuralTypeProjectionExpression entityProj
{
var concreteEntityTypes = derivedType.GetConcreteDerivedTypesInclusive().ToList();
var discriminatorColumn = BindProperty(typeReference, discriminatorProperty);

// Apply any value conversion to the discriminator values.
// Note that this is important also to get the correct SqlConstantExpression.Type, which needs to be the provider type
// rather than the model type; this is in line with how we translate constants everywhere else, and is important in order
// for comparison logic between constants to function correctly (see #32865).
var converter = discriminatorColumn.TypeMapping?.Converter?.ConvertToProvider;

return concreteEntityTypes.Count == 1
? _sqlExpressionFactory.Equal(
discriminatorColumn,
_sqlExpressionFactory.Constant(concreteEntityTypes[0].GetDiscriminatorValue()))
_sqlExpressionFactory.Constant(GetDiscriminatorValue(concreteEntityTypes[0])))
: _sqlExpressionFactory.In(
discriminatorColumn,
concreteEntityTypes.Select(et => _sqlExpressionFactory.Constant(et.GetDiscriminatorValue())).ToArray());
concreteEntityTypes.Select(et => _sqlExpressionFactory.Constant(GetDiscriminatorValue(et))).ToArray());

object? GetDiscriminatorValue(IEntityType entityType)
=> entityType.GetDiscriminatorValue() switch
{
object value when converter is not null => converter(value),
object value => value,
null => null
};
}

return _sqlExpressionFactory.Constant(true);
Expand Down
13 changes: 11 additions & 2 deletions src/EFCore.Relational/Query/SqlExpressionFactory.cs
Expand Up @@ -806,14 +806,23 @@ private void AddConditions(SelectExpression selectExpression, IEntityType entity
{
var discriminatorColumn = GetMappedProjection(selectExpression).BindProperty(discriminatorProperty);
var concreteEntityTypes = entityType.GetConcreteDerivedTypesInclusive().ToList();
var converter = discriminatorColumn.TypeMapping?.Converter?.ConvertToProvider;
var predicate = concreteEntityTypes.Count == 1
? (SqlExpression)Equal(discriminatorColumn, Constant(concreteEntityTypes[0].GetDiscriminatorValue()))
: In(discriminatorColumn, concreteEntityTypes.Select(et => Constant(et.GetDiscriminatorValue())).ToArray());
? (SqlExpression)Equal(discriminatorColumn, Constant(GetDiscriminatorValue(concreteEntityTypes[0])))
: In(discriminatorColumn, concreteEntityTypes.Select(et => Constant(GetDiscriminatorValue(et))).ToArray());

selectExpression.ApplyPredicate(predicate);

// If discriminator predicate is added then it will also serve as condition for existence of dependents in table sharing
return;

object? GetDiscriminatorValue(IEntityType entityType)
=> entityType.GetDiscriminatorValue() switch
{
object value when converter is not null => converter(value),
object value => value,
null => null
};
}

// Keyless entities cannot be table sharing
Expand Down
Expand Up @@ -226,6 +226,22 @@ public override async Task Can_use_of_type_rose(bool async)
""");
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual async Task OfType_on_enum_discriminator_with_Where_on_same_discriminator(bool async)
{
await AssertQuery(
async,
ss => ss.Set<Plant>().OfType<Rose>().Where(p => p.Genus == PlantGenus.Rose));

AssertSql(
"""
SELECT [p].[Species], [p].[CountryId], [p].[Genus], [p].[Name], [p].[HasThorns]
FROM [Plants] AS [p]
WHERE [p].[Genus] = 0
""");
}

public override async Task Can_query_all_animals(bool async)
{
await base.Can_query_all_animals(async);
Expand Down