Skip to content

Commit

Permalink
Fix missing comparer when creating FrozenDictionary (#83651)
Browse files Browse the repository at this point in the history
  • Loading branch information
stephentoub committed Mar 19, 2023
1 parent 9c818a3 commit 9834339
Show file tree
Hide file tree
Showing 3 changed files with 139 additions and 90 deletions.
Expand Up @@ -23,7 +23,7 @@ internal sealed class WrappedDictionaryFrozenDictionary<TKey, TValue> :
private TValue[]? _values;

internal WrappedDictionaryFrozenDictionary(Dictionary<TKey, TValue> source, bool sourceIsCopy) : base(source.Comparer) =>
_source = sourceIsCopy ? source : new Dictionary<TKey, TValue>(source);
_source = sourceIsCopy ? source : new Dictionary<TKey, TValue>(source, source.Comparer);

/// <inheritdoc />
private protected sealed override TKey[] KeysCore =>
Expand Down
Expand Up @@ -17,20 +17,21 @@ public abstract class FrozenDictionary_Generic_Tests<TKey, TValue> : IDictionary
protected override bool Enumerator_Current_UndefinedOperation_Throws => true;
protected override Type ICollection_Generic_CopyTo_IndexLargerThanArrayCount_ThrowType => typeof(ArgumentOutOfRangeException);

protected override IDictionary<TKey, TValue> GenericIDictionaryFactory(int count) =>
GenericIDictionaryFactory(count, optimizeForReading: true);
public virtual bool OptimizeForReading => true;

protected virtual bool AllowVeryLargeSizes => true;

protected virtual IDictionary<TKey, TValue> GenericIDictionaryFactory(int count, bool optimizeForReading)
public virtual TKey GetEqualKey(TKey key) => key;

protected override IDictionary<TKey, TValue> GenericIDictionaryFactory(int count)
{
var d = new Dictionary<TKey, TValue>();
for (int i = 0; i < count; i++)
{
d.Add(CreateTKey(i), CreateTValue(i));
}
return optimizeForReading ?
d.ToFrozenDictionary(GetKeyIEqualityComparer(), true) :
return OptimizeForReading ?
d.ToFrozenDictionary(GetKeyIEqualityComparer(), optimizeForReading: true) :
d.ToFrozenDictionary(GetKeyIEqualityComparer());
}

Expand All @@ -46,13 +47,12 @@ public abstract class FrozenDictionary_Generic_Tests<TKey, TValue> : IDictionary
protected override EnumerableOrder Order => EnumerableOrder.Unspecified;

[Theory]
[InlineData(100_000, false)]
[InlineData(100_000, true)]
public virtual void CreateVeryLargeDictionary_Success(int largeCount, bool optimizeForReading)
[InlineData(100_000)]
public virtual void CreateVeryLargeDictionary_Success(int largeCount)
{
if (AllowVeryLargeSizes)
{
GenericIDictionaryFactory(largeCount, optimizeForReading);
GenericIDictionaryFactory(largeCount);
}
}

Expand Down Expand Up @@ -87,26 +87,22 @@ public void EmptySource_ProducedFrozenDictionaryEmpty()
Assert.Same(FrozenDictionary<TKey, TValue>.Empty, Enumerable.Empty<KeyValuePair<TKey, TValue>>().ToFrozenDictionary(comparer));
Assert.Same(FrozenDictionary<TKey, TValue>.Empty, Array.Empty<KeyValuePair<TKey, TValue>>().ToFrozenDictionary(comparer));
Assert.Same(FrozenDictionary<TKey, TValue>.Empty, new List<KeyValuePair<TKey, TValue>>().ToFrozenDictionary(comparer));
foreach (bool optimizeForReading in new[] { false, true })
{
Assert.Same(FrozenDictionary<TKey, TValue>.Empty, new Dictionary<TKey, TValue>().ToFrozenDictionary(comparer, optimizeForReading));
Assert.Same(FrozenDictionary<TKey, TValue>.Empty, Enumerable.Empty<KeyValuePair<TKey, TValue>>().ToFrozenDictionary(comparer));
Assert.Same(FrozenDictionary<TKey, TValue>.Empty, Array.Empty<KeyValuePair<TKey, TValue>>().ToFrozenDictionary(comparer));
Assert.Same(FrozenDictionary<TKey, TValue>.Empty, new List<KeyValuePair<TKey, TValue>>().ToFrozenDictionary(comparer));
}

Assert.Same(FrozenDictionary<TKey, TValue>.Empty, new Dictionary<TKey, TValue>().ToFrozenDictionary(comparer, OptimizeForReading));
Assert.Same(FrozenDictionary<TKey, TValue>.Empty, Enumerable.Empty<KeyValuePair<TKey, TValue>>().ToFrozenDictionary(comparer, OptimizeForReading));
Assert.Same(FrozenDictionary<TKey, TValue>.Empty, Array.Empty<KeyValuePair<TKey, TValue>>().ToFrozenDictionary(comparer, OptimizeForReading));
Assert.Same(FrozenDictionary<TKey, TValue>.Empty, new List<KeyValuePair<TKey, TValue>>().ToFrozenDictionary(comparer, OptimizeForReading));
}

Assert.NotSame(FrozenDictionary<TKey, TValue>.Empty, new Dictionary<TKey, TValue>().ToFrozenDictionary(NonDefaultEqualityComparer<TKey>.Instance));
Assert.NotSame(FrozenDictionary<TKey, TValue>.Empty, Enumerable.Empty<KeyValuePair<TKey, TValue>>().ToFrozenDictionary(NonDefaultEqualityComparer<TKey>.Instance));
Assert.NotSame(FrozenDictionary<TKey, TValue>.Empty, Array.Empty<KeyValuePair<TKey, TValue>>().ToFrozenDictionary(NonDefaultEqualityComparer<TKey>.Instance));
Assert.NotSame(FrozenDictionary<TKey, TValue>.Empty, new List<KeyValuePair<TKey, TValue>>().ToFrozenDictionary(NonDefaultEqualityComparer<TKey>.Instance));
foreach (bool optimizeForReading in new[] { false, true })
{
Assert.NotSame(FrozenDictionary<TKey, TValue>.Empty, new Dictionary<TKey, TValue>().ToFrozenDictionary(NonDefaultEqualityComparer<TKey>.Instance, optimizeForReading));
Assert.NotSame(FrozenDictionary<TKey, TValue>.Empty, Enumerable.Empty<KeyValuePair<TKey, TValue>>().ToFrozenDictionary(NonDefaultEqualityComparer<TKey>.Instance, optimizeForReading));
Assert.NotSame(FrozenDictionary<TKey, TValue>.Empty, Array.Empty<KeyValuePair<TKey, TValue>>().ToFrozenDictionary(NonDefaultEqualityComparer<TKey>.Instance, optimizeForReading));
Assert.NotSame(FrozenDictionary<TKey, TValue>.Empty, new List<KeyValuePair<TKey, TValue>>().ToFrozenDictionary(NonDefaultEqualityComparer<TKey>.Instance, optimizeForReading));
}

Assert.NotSame(FrozenDictionary<TKey, TValue>.Empty, new Dictionary<TKey, TValue>().ToFrozenDictionary(NonDefaultEqualityComparer<TKey>.Instance, OptimizeForReading));
Assert.NotSame(FrozenDictionary<TKey, TValue>.Empty, Enumerable.Empty<KeyValuePair<TKey, TValue>>().ToFrozenDictionary(NonDefaultEqualityComparer<TKey>.Instance, OptimizeForReading));
Assert.NotSame(FrozenDictionary<TKey, TValue>.Empty, Array.Empty<KeyValuePair<TKey, TValue>>().ToFrozenDictionary(NonDefaultEqualityComparer<TKey>.Instance, OptimizeForReading));
Assert.NotSame(FrozenDictionary<TKey, TValue>.Empty, new List<KeyValuePair<TKey, TValue>>().ToFrozenDictionary(NonDefaultEqualityComparer<TKey>.Instance, OptimizeForReading));
}

[Fact]
Expand Down Expand Up @@ -170,14 +166,12 @@ public void FrozenDictionary_ToFrozenDictionary_Idempotent()
Assert.NotSame(frozen, frozen.ToFrozenDictionary(NonDefaultEqualityComparer<TKey>.Instance));
}

[Theory]
[InlineData(false)]
[InlineData(true)]
public void ToFrozenDictionary_BoolArg_UsesDefaultComparer(bool optimizeForReading)
[Fact]
public void ToFrozenDictionary_BoolArg_UsesDefaultComparer()
{
Dictionary<TKey, TValue> source = Enumerable.Range(0, 4).ToDictionary(CreateTKey, CreateTValue);

FrozenDictionary<TKey, TValue> frozen1 = source.ToFrozenDictionary(optimizeForReading);
FrozenDictionary<TKey, TValue> frozen1 = source.ToFrozenDictionary(OptimizeForReading);

Assert.Same(EqualityComparer<TKey>.Default, frozen1.Comparer);
}
Expand Down Expand Up @@ -215,20 +209,19 @@ public void ToFrozenDictionary_KeySelectorAndValueSelector_ResultsAreUsed()
from size in new[] { 0, 1, 2, 10, 999, 1024 }
from comparer in new IEqualityComparer<TKey>[] { null, EqualityComparer<TKey>.Default, NonDefaultEqualityComparer<TKey>.Instance }
from specifySameComparer in new[] { false, true }
from optimizeForReading in new[] { false, true }
select new object[] { size, comparer, specifySameComparer, optimizeForReading };
select new object[] { size, comparer, specifySameComparer };

[Theory]
[MemberData(nameof(LookupItems_AllItemsFoundAsExpected_MemberData))]
public void LookupItems_AllItemsFoundAsExpected(int size, IEqualityComparer<TKey> comparer, bool specifySameComparer, bool optimizeForReading)
public void LookupItems_AllItemsFoundAsExpected(int size, IEqualityComparer<TKey> comparer, bool specifySameComparer)
{
Dictionary<TKey, TValue> original =
Enumerable.Range(0, size)
.Select(i => new KeyValuePair<TKey, TValue>(CreateTKey(i), CreateTValue(i)))
.ToDictionary(p => p.Key, p => p.Value, comparer);
KeyValuePair<TKey, TValue>[] originalPairs = original.ToArray();

FrozenDictionary<TKey, TValue> frozen = (specifySameComparer, optimizeForReading) switch
FrozenDictionary<TKey, TValue> frozen = (specifySameComparer, OptimizeForReading) switch
{
(false, false) => original.ToFrozenDictionary(),
(false, true) => original.ToFrozenDictionary(null, true),
Expand Down Expand Up @@ -275,6 +268,31 @@ public void LookupItems_AllItemsFoundAsExpected(int size, IEqualityComparer<TKey
}
}

[Theory]
[InlineData(false)]
[InlineData(true)]
public void EqualButPossiblyDifferentKeys_Found(bool fromDictionary)
{
Dictionary<TKey, TValue> original =
Enumerable.Range(0, 50)
.Select(i => new KeyValuePair<TKey, TValue>(CreateTKey(i), CreateTValue(i)))
.ToDictionary(p => p.Key, p => p.Value, GetKeyIEqualityComparer());

FrozenDictionary<TKey, TValue> frozen = fromDictionary ?
original.ToFrozenDictionary(GetKeyIEqualityComparer()) :
original.Select(k => k).ToFrozenDictionary(GetKeyIEqualityComparer());

foreach (TKey key in original.Keys)
{
Assert.True(original.ContainsKey(key));
Assert.True(frozen.ContainsKey(key));

TKey equalKey = GetEqualKey(key);
Assert.True(original.ContainsKey(equalKey));
Assert.True(frozen.ContainsKey(equalKey));
}
}

[Fact]
public void MultipleValuesSameKey_LastInSourceWins()
{
Expand All @@ -301,15 +319,12 @@ public void MultipleValuesSameKey_LastInSourceWins()
}

[Theory]
[InlineData(0, false)]
[InlineData(1, false)]
[InlineData(75, false)]
[InlineData(0, true)]
[InlineData(1, true)]
[InlineData(75, true)]
public void IReadOnlyDictionary_Generic_Keys_ContainsAllCorrectKeys(int count, bool optimizeForReading)
[InlineData(0)]
[InlineData(1)]
[InlineData(75)]
public void IReadOnlyDictionary_Generic_Keys_ContainsAllCorrectKeys(int count)
{
IDictionary<TKey, TValue> dictionary = GenericIDictionaryFactory(count, optimizeForReading);
IDictionary<TKey, TValue> dictionary = GenericIDictionaryFactory(count);
IEnumerable<TKey> expected = dictionary.Select((pair) => pair.Key);

IReadOnlyDictionary<TKey, TValue> rod = (IReadOnlyDictionary<TKey, TValue>)dictionary;
Expand All @@ -318,15 +333,12 @@ public void IReadOnlyDictionary_Generic_Keys_ContainsAllCorrectKeys(int count, b
}

[Theory]
[InlineData(0, false)]
[InlineData(1, false)]
[InlineData(75, false)]
[InlineData(0, true)]
[InlineData(1, true)]
[InlineData(75, true)]
public void IReadOnlyDictionary_Generic_Values_ContainsAllCorrectValues(int count, bool optimizeForReading)
[InlineData(0)]
[InlineData(1)]
[InlineData(75)]
public void IReadOnlyDictionary_Generic_Values_ContainsAllCorrectValues(int count)
{
IDictionary<TKey, TValue> dictionary = GenericIDictionaryFactory(count, optimizeForReading);
IDictionary<TKey, TValue> dictionary = GenericIDictionaryFactory(count);
IEnumerable<TValue> expected = dictionary.Select((pair) => pair.Value);

IReadOnlyDictionary<TKey, TValue> rod = (IReadOnlyDictionary<TKey, TValue>)dictionary;
Expand Down Expand Up @@ -374,9 +386,16 @@ public class FrozenDictionary_Generic_Tests_string_string_Ordinal : FrozenDictio
public override IEqualityComparer<string> GetKeyIEqualityComparer() => StringComparer.Ordinal;
}

public class FrozenDictionary_Generic_Tests_string_string_OrdinalIgnoreCase_ReadingUnoptimized : FrozenDictionary_Generic_Tests_string_string_OrdinalIgnoreCase
{
public override bool OptimizeForReading => false;
}

public class FrozenDictionary_Generic_Tests_string_string_OrdinalIgnoreCase : FrozenDictionary_Generic_Tests_string_string
{
public override IEqualityComparer<string> GetKeyIEqualityComparer() => StringComparer.OrdinalIgnoreCase;

public override string GetEqualKey(string key) => key.ToLowerInvariant();
}

public class FrozenDictionary_Generic_Tests_string_string_NonDefault : FrozenDictionary_Generic_Tests_string_string
Expand Down Expand Up @@ -407,13 +426,12 @@ protected override ulong CreateTKey(int seed)

[OuterLoop("Takes several seconds")]
[Theory]
[InlineData(8_000_000, false)]
[InlineData(8_000_000, true)]
public void CreateHugeDictionary_Success(int largeCount, bool optimizeForReading)
[InlineData(8_000_000)]
public void CreateHugeDictionary_Success(int largeCount)
{
if (AllowVeryLargeSizes)
{
GenericIDictionaryFactory(largeCount, optimizeForReading);
GenericIDictionaryFactory(largeCount);
}
}
}
Expand All @@ -433,6 +451,11 @@ public class FrozenDictionary_Generic_Tests_int_int : FrozenDictionary_Generic_T
protected override int CreateTValue(int seed) => CreateTKey(seed);
}

public class FrozenDictionary_Generic_Tests_int_int_ReadingUnoptimized : FrozenDictionary_Generic_Tests_int_int
{
public override bool OptimizeForReading => false;
}

public class FrozenDictionary_Generic_Tests_SimpleClass_SimpleClass : FrozenDictionary_Generic_Tests<SimpleClass, SimpleClass>
{
protected override KeyValuePair<SimpleClass, SimpleClass> CreateT(int seed)
Expand Down

0 comments on commit 9834339

Please sign in to comment.