Skip to content

Commit

Permalink
Use provider discriminator values for OfType()
Browse files Browse the repository at this point in the history
  • Loading branch information
roji committed Jan 22, 2024
1 parent af05058 commit dcc4889
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 2 deletions.
Expand Up @@ -1147,13 +1147,28 @@ SqlExpression GeneratePredicateTpt(StructuralTypeProjectionExpression entityProj
{
var concreteEntityTypes = derivedType.GetConcreteDerivedTypesInclusive().ToList();
var discriminatorColumn = BindProperty(typeReference, discriminatorProperty);

// Apply any value conversion to the discriminator values.
// Note that this is important also to get the correct SqlConstantExpression.Type, which needs to be the provider type
// rather than the model type; this is in line with how we translate constants everywhere else, and is important in order
// for comparison logic between constants to function correctly (see #32865).
var converter = discriminatorColumn.TypeMapping?.Converter?.ConvertToProvider;

return concreteEntityTypes.Count == 1
? _sqlExpressionFactory.Equal(
discriminatorColumn,
_sqlExpressionFactory.Constant(concreteEntityTypes[0].GetDiscriminatorValue()))
_sqlExpressionFactory.Constant(GetDiscriminatorValue(concreteEntityTypes[0])))
: _sqlExpressionFactory.In(
discriminatorColumn,
concreteEntityTypes.Select(et => _sqlExpressionFactory.Constant(et.GetDiscriminatorValue())).ToArray());
concreteEntityTypes.Select(et => _sqlExpressionFactory.Constant(GetDiscriminatorValue(et))).ToArray());

object? GetDiscriminatorValue(IEntityType entityType)
=> entityType.GetDiscriminatorValue() switch
{
object value when converter is not null => converter(value),
object value => value,
null => null
};
}

return _sqlExpressionFactory.Constant(true);
Expand Down
Expand Up @@ -2382,4 +2382,42 @@ public class CompanyDto : ICompanyDto
}

#endregion

#region // #32865

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual async Task OfType_on_enum_discriminator_with_Where_on_same_discriminator(bool async)
{
var contextFactory = await InitializeAsync<Context32865>();
await using var context = contextFactory.CreateContext();

context.Add(new Context32865.SpecialBlog());
await context.SaveChangesAsync();

Assert.Equal(1, await context.Set<Context32865.Blog>()
.OfType<Context32865.SpecialBlog>()
.CountAsync(b => b.Type == Context32865.BlogType.Special));
}

private class Context32865(DbContextOptions options) : DbContext(options)
{
protected override void OnModelCreating(ModelBuilder modelBuilder)
=> modelBuilder.Entity<Blog>()
.HasDiscriminator(e => e.Type)
.HasValue<Blog>(BlogType.Regular)
.HasValue<SpecialBlog>(BlogType.Special);

public class Blog
{
public int Id { get; set; }
public BlogType Type { get; set; }
}

public class SpecialBlog : Blog;

public enum BlogType { Regular, Special }
}

#endregion // #32865
}
Expand Up @@ -2317,6 +2317,28 @@ WHERE CASE
WHEN [c0].[Id] IS NOT NULL THEN [c1].[CountryName]
ELSE NULL
END = N'COUNTRY'
""");
}

public override async Task OfType_on_enum_discriminator_with_Where_on_same_discriminator(bool async)
{
await base.OfType_on_enum_discriminator_with_Where_on_same_discriminator(async);

AssertSql(
"""
@p0='1'

SET IMPLICIT_TRANSACTIONS OFF;
SET NOCOUNT ON;
INSERT INTO [Blog] ([Type])
OUTPUT INSERTED.[Id]
VALUES (@p0);
""",
//
"""
SELECT COUNT(*)
FROM [Blog] AS [b]
WHERE [b].[Type] = 1
""");
}
}

0 comments on commit dcc4889

Please sign in to comment.