diff --git a/src/EFCore.Cosmos/ChangeTracking/Internal/StringDictionaryComparer.cs b/src/EFCore.Cosmos/ChangeTracking/Internal/StringDictionaryComparer.cs index aa4589d54b8..096ebc3b278 100644 --- a/src/EFCore.Cosmos/ChangeTracking/Internal/StringDictionaryComparer.cs +++ b/src/EFCore.Cosmos/ChangeTracking/Internal/StringDictionaryComparer.cs @@ -14,13 +14,25 @@ namespace Microsoft.EntityFrameworkCore.Cosmos.ChangeTracking.Internal; /// public sealed class StringDictionaryComparer : ValueComparer, IInfrastructure { + private static readonly bool UseOldBehavior35239 = + AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue35239", out var enabled35239) && enabled35239; + private static readonly MethodInfo CompareMethod = typeof(StringDictionaryComparer).GetMethod( + nameof(Compare), BindingFlags.Static | BindingFlags.NonPublic, [typeof(object), typeof(object), typeof(Func)])!; + + private static readonly MethodInfo LegacyCompareMethod = typeof(StringDictionaryComparer).GetMethod( nameof(Compare), BindingFlags.Static | BindingFlags.NonPublic, [typeof(object), typeof(object), typeof(ValueComparer)])!; private static readonly MethodInfo GetHashCodeMethod = typeof(StringDictionaryComparer).GetMethod( + nameof(GetHashCode), BindingFlags.Static | BindingFlags.NonPublic, [typeof(IEnumerable), typeof(Func)])!; + + private static readonly MethodInfo LegacyGetHashCodeMethod = typeof(StringDictionaryComparer).GetMethod( nameof(GetHashCode), BindingFlags.Static | BindingFlags.NonPublic, [typeof(IEnumerable), typeof(ValueComparer)])!; private static readonly MethodInfo SnapshotMethod = typeof(StringDictionaryComparer).GetMethod( + nameof(Snapshot), BindingFlags.Static | BindingFlags.NonPublic, [typeof(object), typeof(Func)])!; + + private static readonly MethodInfo LegacySnapshotMethod = typeof(StringDictionaryComparer).GetMethod( nameof(Snapshot), BindingFlags.Static | BindingFlags.NonPublic, [typeof(object), typeof(ValueComparer)])!; /// @@ -52,14 +64,56 @@ ValueComparer IInfrastructure.Instance var prm1 = Expression.Parameter(typeof(object), "a"); var prm2 = Expression.Parameter(typeof(object), "b"); + if (UseOldBehavior35239) + { + // (a, b) => Compare(a, b, new Comparer(...)) + return Expression.Lambda>( + Expression.Call( + LegacyCompareMethod, + prm1, + prm2, +#pragma warning disable EF9100 + elementComparer.ConstructorExpression), +#pragma warning restore EF9100 + prm1, + prm2); + } + + // we check the compatibility between element type we expect on the Equals methods + // vs what we actually get from the element comparer + // if the expected is assignable from actual we can just do simple call... + if (typeof(TElement).IsAssignableFrom(elementComparer.Type)) + { + // (a, b) => Compare(a, b, elementComparer.Equals) + return Expression.Lambda>( + Expression.Call( + CompareMethod, + prm1, + prm2, + elementComparer.EqualsExpression), + prm1, + prm2); + } + + // ...otherwise we need to rewrite the actual lambda (as we can't change the expected signature) + // in that case we are rewriting the inner lambda parameters to TElement and cast to the element comparer + // type argument in the body, so that semantics of the element comparison func don't change + var newInnerPrm1 = Expression.Parameter(typeof(TElement), "a"); + var newInnerPrm2 = Expression.Parameter(typeof(TElement), "b"); + + var newEqualsExpressionBody = elementComparer.ExtractEqualsBody( + Expression.Convert(newInnerPrm1, elementComparer.Type), + Expression.Convert(newInnerPrm2, elementComparer.Type)); + return Expression.Lambda>( Expression.Call( CompareMethod, prm1, prm2, -#pragma warning disable EF9100 - elementComparer.ConstructorExpression), -#pragma warning restore EF9100 + Expression.Lambda( + newEqualsExpressionBody, + newInnerPrm1, + newInnerPrm2)), prm1, prm2); } @@ -68,15 +122,50 @@ private static Expression> GetHashCodeLambda(ValueComparer ele { var prm = Expression.Parameter(typeof(object), "o"); + if (UseOldBehavior35239) + { + // o => GetHashCode((IEnumerable)o, new Comparer(...)) + return Expression.Lambda>( + Expression.Call( + LegacyGetHashCodeMethod, + Expression.Convert( + prm, + typeof(IEnumerable)), +#pragma warning disable EF9100 + elementComparer.ConstructorExpression), +#pragma warning restore EF9100 + prm); + } + + if (typeof(TElement).IsAssignableFrom(elementComparer.Type)) + { + // o => GetHashCode((IEnumerable)o, elementComparer.GetHashCode) + return Expression.Lambda>( + Expression.Call( + GetHashCodeMethod, + Expression.Convert( + prm, + typeof(IEnumerable)), + elementComparer.HashCodeExpression), + prm); + } + + var newInnerPrm = Expression.Parameter(typeof(TElement), "o"); + + var newInnerBody = elementComparer.ExtractHashCodeBody( + Expression.Convert( + newInnerPrm, + elementComparer.Type)); + return Expression.Lambda>( Expression.Call( GetHashCodeMethod, Expression.Convert( prm, typeof(IEnumerable)), -#pragma warning disable EF9100 - elementComparer.ConstructorExpression), -#pragma warning restore EF9100 + Expression.Lambda( + newInnerBody, + newInnerPrm)), prm); } @@ -84,16 +173,93 @@ private static Expression> SnapshotLambda(ValueComparer ele { var prm = Expression.Parameter(typeof(object), "source"); + if (UseOldBehavior35239) + { + // source => Snapshot(source, new Comparer(..)) + return Expression.Lambda>( + Expression.Call( + LegacySnapshotMethod, + prm, +#pragma warning disable EF9100 + elementComparer.ConstructorExpression), +#pragma warning restore EF9100 + prm); + } + + // TElement is both argument and return type so the types need to be the same + if (typeof(TElement) == elementComparer.Type) + { + // source => Snapshot(source, elementComparer.Snapshot) + return Expression.Lambda>( + Expression.Call( + SnapshotMethod, + prm, + elementComparer.SnapshotExpression), + prm); + } + + var newInnerPrm = Expression.Parameter(typeof(TElement), "source"); + + var newInnerBody = elementComparer.ExtractSnapshotBody( + Expression.Convert( + newInnerPrm, + elementComparer.Type)); + + // note we need to also convert the result of inner lambda back to TElement return Expression.Lambda>( Expression.Call( SnapshotMethod, prm, -#pragma warning disable EF9100 - elementComparer.ConstructorExpression), -#pragma warning restore EF9100 + Expression.Lambda( + Expression.Convert( + newInnerBody, + typeof(TElement)), + newInnerPrm)), prm); } + private static bool Compare(object? a, object? b, Func elementCompare) + { + if (ReferenceEquals(a, b)) + { + return true; + } + + if (a is null) + { + return b is null; + } + + if (b is null) + { + return false; + } + + if (a is IReadOnlyDictionary aDictionary && b is IReadOnlyDictionary bDictionary) + { + if (aDictionary.Count != bDictionary.Count) + { + return false; + } + + foreach (var pair in aDictionary) + { + if (!bDictionary.TryGetValue(pair.Key, out var bValue) + || !elementCompare(pair.Value, bValue)) + { + return false; + } + } + + return true; + } + + throw new InvalidOperationException( + CosmosStrings.BadDictionaryType( + (a is IDictionary ? b : a).GetType().ShortDisplayName(), + typeof(IDictionary<,>).MakeGenericType(typeof(string), typeof(TElement)).ShortDisplayName())); + } + private static bool Compare(object? a, object? b, ValueComparer elementComparer) { if (ReferenceEquals(a, b)) @@ -136,6 +302,27 @@ private static bool Compare(object? a, object? b, ValueComparer elementComparer) typeof(IDictionary<,>).MakeGenericType(typeof(string), elementComparer.Type).ShortDisplayName())); } + private static int GetHashCode(IEnumerable source, Func elementGetHashCode) + { + if (source is not IReadOnlyDictionary sourceDictionary) + { + throw new InvalidOperationException( + CosmosStrings.BadDictionaryType( + source.GetType().ShortDisplayName(), + typeof(IList<>).MakeGenericType(typeof(TElement)).ShortDisplayName())); + } + + var hash = new HashCode(); + + foreach (var pair in sourceDictionary) + { + hash.Add(pair.Key); + hash.Add(pair.Value == null ? 0 : elementGetHashCode(pair.Value)); + } + + return hash.ToHashCode(); + } + private static int GetHashCode(IEnumerable source, ValueComparer elementComparer) { if (source is not IReadOnlyDictionary sourceDictionary) @@ -157,6 +344,25 @@ private static int GetHashCode(IEnumerable source, ValueComparer elementComparer return hash.ToHashCode(); } + private static IReadOnlyDictionary Snapshot(object source, Func elementSnapshot) + { + if (source is not IReadOnlyDictionary sourceDictionary) + { + throw new InvalidOperationException( + CosmosStrings.BadDictionaryType( + source.GetType().ShortDisplayName(), + typeof(IDictionary<,>).MakeGenericType(typeof(string), typeof(TElement)).ShortDisplayName())); + } + + var snapshot = new Dictionary(); + foreach (var pair in sourceDictionary) + { + snapshot[pair.Key] = pair.Value == null ? default : (TElement?)elementSnapshot(pair.Value); + } + + return snapshot; + } + private static IReadOnlyDictionary Snapshot(object source, ValueComparer elementComparer) { if (source is not IReadOnlyDictionary sourceDictionary) diff --git a/src/EFCore/ChangeTracking/ListOfNullableValueTypesComparer.cs b/src/EFCore/ChangeTracking/ListOfNullableValueTypesComparer.cs index d95a40877db..a8c08979f18 100644 --- a/src/EFCore/ChangeTracking/ListOfNullableValueTypesComparer.cs +++ b/src/EFCore/ChangeTracking/ListOfNullableValueTypesComparer.cs @@ -25,6 +25,9 @@ public sealed class ListOfNullableValueTypesComparer : IInfrastructure where TElement : struct { + private static readonly bool UseOldBehavior35239 = + AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue35239", out var enabled35239) && enabled35239; + private static readonly bool IsArray = typeof(TConcreteList).IsArray; private static readonly bool IsReadOnly = IsArray @@ -32,14 +35,26 @@ public sealed class ListOfNullableValueTypesComparer : && typeof(TConcreteList).GetGenericTypeDefinition() == typeof(ReadOnlyCollection<>)); private static readonly MethodInfo CompareMethod = typeof(ListOfNullableValueTypesComparer).GetMethod( + nameof(Compare), BindingFlags.Static | BindingFlags.NonPublic, + [typeof(IEnumerable), typeof(IEnumerable), typeof(Func)])!; + + private static readonly MethodInfo LegacyCompareMethod = typeof(ListOfNullableValueTypesComparer).GetMethod( nameof(Compare), BindingFlags.Static | BindingFlags.NonPublic, [typeof(IEnumerable), typeof(IEnumerable), typeof(ValueComparer)])!; private static readonly MethodInfo GetHashCodeMethod = typeof(ListOfNullableValueTypesComparer).GetMethod( + nameof(GetHashCode), BindingFlags.Static | BindingFlags.NonPublic, + [typeof(IEnumerable), typeof(Func)])!; + + private static readonly MethodInfo LegacyGetHashCodeMethod = typeof(ListOfNullableValueTypesComparer).GetMethod( nameof(GetHashCode), BindingFlags.Static | BindingFlags.NonPublic, [typeof(IEnumerable), typeof(ValueComparer)])!; private static readonly MethodInfo SnapshotMethod = typeof(ListOfNullableValueTypesComparer).GetMethod( + nameof(Snapshot), BindingFlags.Static | BindingFlags.NonPublic, + [typeof(IEnumerable), typeof(Func)])!; + + private static readonly MethodInfo LegacySnapshotMethod = typeof(ListOfNullableValueTypesComparer).GetMethod( nameof(Snapshot), BindingFlags.Static | BindingFlags.NonPublic, [typeof(IEnumerable), typeof(ValueComparer)])!; @@ -67,10 +82,23 @@ ValueComparer IInfrastructure.Instance var prm1 = Expression.Parameter(typeof(IEnumerable), "a"); var prm2 = Expression.Parameter(typeof(IEnumerable), "b"); + if (elementComparer is ValueComparer && !UseOldBehavior35239) + { + //(a, b) => Compare(a, b, elementComparer.Equals) + return Expression.Lambda?, IEnumerable?, bool>>( + Expression.Call( + CompareMethod, + prm1, + prm2, + elementComparer.EqualsExpression), + prm1, + prm2); + } + //(a, b) => Compare(a, b, (ValueComparer)elementComparer) return Expression.Lambda?, IEnumerable?, bool>>( Expression.Call( - CompareMethod, + LegacyCompareMethod, prm1, prm2, Expression.Convert( @@ -84,10 +112,21 @@ ValueComparer IInfrastructure.Instance { var prm = Expression.Parameter(typeof(IEnumerable), "o"); + if (elementComparer is ValueComparer && !UseOldBehavior35239) + { + //o => GetHashCode(o, elementComparer.GetHashCode) + return Expression.Lambda, int>>( + Expression.Call( + GetHashCodeMethod, + prm, + elementComparer.HashCodeExpression), + prm); + } + //o => GetHashCode(o, (ValueComparer)elementComparer) return Expression.Lambda, int>>( Expression.Call( - GetHashCodeMethod, + LegacyGetHashCodeMethod, prm, Expression.Convert( elementComparer.ConstructorExpression, @@ -98,11 +137,21 @@ ValueComparer IInfrastructure.Instance private static Expression, IEnumerable>> SnapshotLambda(ValueComparer elementComparer) { var prm = Expression.Parameter(typeof(IEnumerable), "source"); + if (elementComparer is ValueComparer && !UseOldBehavior35239) + { + //source => Snapshot(source, elementComparer.Snapshot) + return Expression.Lambda, IEnumerable>>( + Expression.Call( + SnapshotMethod, + prm, + elementComparer.SnapshotExpression), + prm); + } //source => Snapshot(source, (ValueComparer)elementComparer) return Expression.Lambda, IEnumerable>>( Expression.Call( - SnapshotMethod, + LegacySnapshotMethod, prm, Expression.Convert( elementComparer.ConstructorExpression, @@ -110,6 +159,63 @@ ValueComparer IInfrastructure.Instance prm); } + private static bool Compare(IEnumerable? a, IEnumerable? b, Func elementCompare) + { + if (ReferenceEquals(a, b)) + { + return true; + } + + if (a is null) + { + return b is null; + } + + if (b is null) + { + return false; + } + + if (a is IList aList && b is IList bList) + { + if (aList.Count != bList.Count) + { + return false; + } + + for (var i = 0; i < aList.Count; i++) + { + var (el1, el2) = (aList[i], bList[i]); + if (el1 is null) + { + if (el2 is null) + { + continue; + } + + return false; + } + + if (el2 is null) + { + return false; + } + + if (!elementCompare(el1, el2)) + { + return false; + } + } + + return true; + } + + throw new InvalidOperationException( + CoreStrings.BadListType( + (a is IList ? b : a).GetType().ShortDisplayName(), + typeof(IList<>).MakeGenericType(typeof(TElement).MakeNullable()).ShortDisplayName())); + } + private static bool Compare(IEnumerable? a, IEnumerable? b, ValueComparer elementComparer) { if (ReferenceEquals(a, b)) @@ -167,6 +273,18 @@ private static bool Compare(IEnumerable? a, IEnumerable? b typeof(IList<>).MakeGenericType(elementComparer.Type.MakeNullable()).ShortDisplayName())); } + private static int GetHashCode(IEnumerable source, Func elementGetHashCode) + { + var hash = new HashCode(); + + foreach (var el in source) + { + hash.Add(el == null ? 0 : elementGetHashCode(el)); + } + + return hash.ToHashCode(); + } + private static int GetHashCode(IEnumerable source, ValueComparer elementComparer) { var hash = new HashCode(); @@ -179,6 +297,41 @@ private static int GetHashCode(IEnumerable source, ValueComparer Snapshot(IEnumerable source, Func elementSnapshot) + { + if (source is not IList sourceList) + { + throw new InvalidOperationException( + CoreStrings.BadListType( + source.GetType().ShortDisplayName(), + typeof(IList<>).MakeGenericType(typeof(TElement).MakeNullable()).ShortDisplayName())); + } + + if (IsArray) + { + var snapshot = new TElement?[sourceList.Count]; + for (var i = 0; i < sourceList.Count; i++) + { + var instance = sourceList[i]; + snapshot[i] = instance == null ? null : elementSnapshot(instance); + } + + return snapshot; + } + else + { + var snapshot = IsReadOnly ? new List() : (IList)Activator.CreateInstance()!; + foreach (var e in sourceList) + { + snapshot.Add(e == null ? null : elementSnapshot(e)); + } + + return IsReadOnly + ? (IList)Activator.CreateInstance(typeof(TConcreteList), snapshot)! + : snapshot; + } + } + private static IList Snapshot(IEnumerable source, ValueComparer elementComparer) { if (source is not IList sourceList) diff --git a/src/EFCore/ChangeTracking/ListOfReferenceTypesComparer.cs b/src/EFCore/ChangeTracking/ListOfReferenceTypesComparer.cs index e0e77528d87..f24a6f95aa2 100644 --- a/src/EFCore/ChangeTracking/ListOfReferenceTypesComparer.cs +++ b/src/EFCore/ChangeTracking/ListOfReferenceTypesComparer.cs @@ -23,6 +23,9 @@ namespace Microsoft.EntityFrameworkCore.ChangeTracking; public sealed class ListOfReferenceTypesComparer : ValueComparer, IInfrastructure where TElement : class { + private static readonly bool UseOldBehavior35239 = + AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue35239", out var enabled35239) && enabled35239; + private static readonly bool IsArray = typeof(TConcreteList).IsArray; private static readonly bool IsReadOnly = IsArray @@ -30,12 +33,21 @@ public sealed class ListOfReferenceTypesComparer : Valu && typeof(TConcreteList).GetGenericTypeDefinition() == typeof(ReadOnlyCollection<>)); private static readonly MethodInfo CompareMethod = typeof(ListOfReferenceTypesComparer).GetMethod( + nameof(Compare), BindingFlags.Static | BindingFlags.NonPublic, [typeof(object), typeof(object), typeof(Func)])!; + + private static readonly MethodInfo LegacyCompareMethod = typeof(ListOfReferenceTypesComparer).GetMethod( nameof(Compare), BindingFlags.Static | BindingFlags.NonPublic, [typeof(object), typeof(object), typeof(ValueComparer)])!; private static readonly MethodInfo GetHashCodeMethod = typeof(ListOfReferenceTypesComparer).GetMethod( + nameof(GetHashCode), BindingFlags.Static | BindingFlags.NonPublic, [typeof(IEnumerable), typeof(Func)])!; + + private static readonly MethodInfo LegacyGetHashCodeMethod = typeof(ListOfReferenceTypesComparer).GetMethod( nameof(GetHashCode), BindingFlags.Static | BindingFlags.NonPublic, [typeof(IEnumerable), typeof(ValueComparer)])!; private static readonly MethodInfo SnapshotMethod = typeof(ListOfReferenceTypesComparer).GetMethod( + nameof(Snapshot), BindingFlags.Static | BindingFlags.NonPublic, [typeof(object), typeof(Func)])!; + + private static readonly MethodInfo LegacySnapshotMethod = typeof(ListOfReferenceTypesComparer).GetMethod( nameof(Snapshot), BindingFlags.Static | BindingFlags.NonPublic, [typeof(object), typeof(ValueComparer)])!; /// @@ -62,10 +74,23 @@ ValueComparer IInfrastructure.Instance var prm1 = Expression.Parameter(typeof(object), "a"); var prm2 = Expression.Parameter(typeof(object), "b"); - // (a, b) => Compare(a, b, elementComparer) + if (elementComparer is ValueComparer && !UseOldBehavior35239) + { + // (a, b) => Compare(a, b, elementComparer.Equals, elementComparer.Type) + return Expression.Lambda>( + Expression.Call( + CompareMethod, + prm1, + prm2, + elementComparer.EqualsExpression), + prm1, + prm2); + } + + // (a, b) => Compare(a, b, new Comparer(...)) return Expression.Lambda>( Expression.Call( - CompareMethod, + LegacyCompareMethod, prm1, prm2, elementComparer.ConstructorExpression), @@ -77,10 +102,23 @@ private static Expression> GetHashCodeLambda(ValueComparer ele { var prm = Expression.Parameter(typeof(object), "o"); - //o => GetHashCode((IEnumerable)o, elementComparer) + if (elementComparer is ValueComparer && !UseOldBehavior35239) + { + // o => GetHashCode((IEnumerable)o, elementComparer.GetHashCode) + return Expression.Lambda>( + Expression.Call( + GetHashCodeMethod, + Expression.Convert( + prm, + typeof(IEnumerable)), + elementComparer.HashCodeExpression), + prm); + } + + // o => GetHashCode((IEnumerable)o, new Comparer(...)) return Expression.Lambda>( Expression.Call( - GetHashCodeMethod, + LegacyGetHashCodeMethod, Expression.Convert( prm, typeof(IEnumerable)), @@ -92,15 +130,83 @@ private static Expression> SnapshotLambda(ValueComparer ele { var prm = Expression.Parameter(typeof(object), "source"); - //source => Snapshot(source, elementComparer) + if (elementComparer is ValueComparer && !UseOldBehavior35239) + { + // source => Snapshot(source, elementComparer.Snapshot, elementComparer.Type) + return Expression.Lambda>( + Expression.Call( + SnapshotMethod, + prm, + elementComparer.SnapshotExpression), + prm); + } + + // source => Snapshot(source, new Comparer(..)) return Expression.Lambda>( Expression.Call( - SnapshotMethod, + LegacySnapshotMethod, prm, elementComparer.ConstructorExpression), prm); } + private static bool Compare(object? a, object? b, Func elementCompare) + { + if (ReferenceEquals(a, b)) + { + return true; + } + + if (a is null) + { + return b is null; + } + + if (b is null) + { + return false; + } + + if (a is IList aList && b is IList bList) + { + if (aList.Count != bList.Count) + { + return false; + } + + for (var i = 0; i < aList.Count; i++) + { + var (el1, el2) = (aList[i], bList[i]); + if (el1 is null) + { + if (el2 is null) + { + continue; + } + + return false; + } + + if (el2 is null) + { + return false; + } + + if (!elementCompare(el1, el2)) + { + return false; + } + } + + return true; + } + + throw new InvalidOperationException( + CoreStrings.BadListType( + (a is IList ? b : a).GetType().ShortDisplayName(), + typeof(IList<>).MakeGenericType(typeof(TElement)).ShortDisplayName())); + } + private static bool Compare(object? a, object? b, ValueComparer elementComparer) { if (ReferenceEquals(a, b)) @@ -158,6 +264,18 @@ private static bool Compare(object? a, object? b, ValueComparer elementComparer) typeof(IList<>).MakeGenericType(elementComparer.Type).ShortDisplayName())); } + private static int GetHashCode(IEnumerable source, Func elementGetHashCode) + { + var hash = new HashCode(); + + foreach (var el in source) + { + hash.Add(el == null ? 0 : elementGetHashCode((TElement)el)); + } + + return hash.ToHashCode(); + } + private static int GetHashCode(IEnumerable source, ValueComparer elementComparer) { var hash = new HashCode(); @@ -170,6 +288,41 @@ private static int GetHashCode(IEnumerable source, ValueComparer elementComparer return hash.ToHashCode(); } + private static IList Snapshot(object source, Func elementSnapshot) + { + if (source is not IList sourceList) + { + throw new InvalidOperationException( + CoreStrings.BadListType( + source.GetType().ShortDisplayName(), + typeof(IList<>).MakeGenericType(typeof(TElement)).ShortDisplayName())); + } + + if (IsArray) + { + var snapshot = new TElement?[sourceList.Count]; + for (var i = 0; i < sourceList.Count; i++) + { + var instance = sourceList[i]; + snapshot[i] = instance == null ? null : elementSnapshot(instance); + } + + return snapshot; + } + else + { + var snapshot = IsReadOnly ? new List() : (IList)Activator.CreateInstance()!; + foreach (var e in sourceList) + { + snapshot.Add(e == null ? null : elementSnapshot(e)); + } + + return IsReadOnly + ? (IList)Activator.CreateInstance(typeof(TConcreteList), snapshot)! + : snapshot; + } + } + private static IList Snapshot(object source, ValueComparer elementComparer) { if (source is not IList sourceList) diff --git a/src/EFCore/ChangeTracking/ListOfValueTypesComparer.cs b/src/EFCore/ChangeTracking/ListOfValueTypesComparer.cs index 19a3a8d4a2f..d0ae607c6b1 100644 --- a/src/EFCore/ChangeTracking/ListOfValueTypesComparer.cs +++ b/src/EFCore/ChangeTracking/ListOfValueTypesComparer.cs @@ -23,6 +23,9 @@ namespace Microsoft.EntityFrameworkCore.ChangeTracking; public sealed class ListOfValueTypesComparer : ValueComparer>, IInfrastructure where TElement : struct { + private static readonly bool UseOldBehavior35239 = + AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue35239", out var enabled35239) && enabled35239; + private static readonly bool IsArray = typeof(TConcreteList).IsArray; private static readonly bool IsReadOnly = IsArray @@ -30,14 +33,25 @@ public sealed class ListOfValueTypesComparer : ValueCom && typeof(TConcreteList).GetGenericTypeDefinition() == typeof(ReadOnlyCollection<>)); private static readonly MethodInfo CompareMethod = typeof(ListOfValueTypesComparer).GetMethod( + nameof(Compare), BindingFlags.Static | BindingFlags.NonPublic, + [typeof(IEnumerable), typeof(IEnumerable), typeof(Func)])!; + + private static readonly MethodInfo LegacyCompareMethod = typeof(ListOfValueTypesComparer).GetMethod( nameof(Compare), BindingFlags.Static | BindingFlags.NonPublic, [typeof(IEnumerable), typeof(IEnumerable), typeof(ValueComparer)])!; private static readonly MethodInfo GetHashCodeMethod = typeof(ListOfValueTypesComparer).GetMethod( + nameof(GetHashCode), BindingFlags.Static | BindingFlags.NonPublic, + [typeof(IEnumerable), typeof(Func)])!; + + private static readonly MethodInfo LegacyGetHashCodeMethod = typeof(ListOfValueTypesComparer).GetMethod( nameof(GetHashCode), BindingFlags.Static | BindingFlags.NonPublic, [typeof(IEnumerable), typeof(ValueComparer)])!; private static readonly MethodInfo SnapshotMethod = typeof(ListOfValueTypesComparer).GetMethod( + nameof(Snapshot), BindingFlags.Static | BindingFlags.NonPublic, [typeof(IEnumerable), typeof(Func)])!; + + private static readonly MethodInfo LegacySnapshotMethod = typeof(ListOfValueTypesComparer).GetMethod( nameof(Snapshot), BindingFlags.Static | BindingFlags.NonPublic, [typeof(IEnumerable), typeof(ValueComparer)])!; /// @@ -64,10 +78,23 @@ ValueComparer IInfrastructure.Instance var prm1 = Expression.Parameter(typeof(IEnumerable), "a"); var prm2 = Expression.Parameter(typeof(IEnumerable), "b"); + if (elementComparer is ValueComparer && !UseOldBehavior35239) + { + //(a, b) => Compare(a, b, elementComparer.Equals) + return Expression.Lambda?, IEnumerable?, bool>>( + Expression.Call( + CompareMethod, + prm1, + prm2, + elementComparer.EqualsExpression), + prm1, + prm2); + } + //(a, b) => Compare(a, b, (ValueComparer)elementComparer) return Expression.Lambda?, IEnumerable?, bool>>( Expression.Call( - CompareMethod, + LegacyCompareMethod, prm1, prm2, Expression.Convert( @@ -81,10 +108,21 @@ private static Expression, int>> GetHashCodeLambda(Va { var prm = Expression.Parameter(typeof(IEnumerable), "o"); + if (elementComparer is ValueComparer && !UseOldBehavior35239) + { + //o => GetHashCode(o, elementComparer.GetHashCode) + return Expression.Lambda, int>>( + Expression.Call( + GetHashCodeMethod, + prm, + elementComparer.HashCodeExpression), + prm); + } + //o => GetHashCode(o, (ValueComparer)elementComparer) return Expression.Lambda, int>>( Expression.Call( - GetHashCodeMethod, + LegacyGetHashCodeMethod, prm, Expression.Convert( elementComparer.ConstructorExpression, @@ -96,10 +134,21 @@ private static Expression, IEnumerable>> Sn { var prm = Expression.Parameter(typeof(IEnumerable), "source"); + if (elementComparer is ValueComparer && !UseOldBehavior35239) + { + //source => Snapshot(source, elementComparer.SnapShot) + return Expression.Lambda, IEnumerable>>( + Expression.Call( + SnapshotMethod, + prm, + elementComparer.SnapshotExpression), + prm); + } + //source => Snapshot(source, (ValueComparer)elementComparer) return Expression.Lambda, IEnumerable>>( Expression.Call( - SnapshotMethod, + LegacySnapshotMethod, prm, Expression.Convert( elementComparer.ConstructorExpression, @@ -107,6 +156,48 @@ private static Expression, IEnumerable>> Sn prm); } + private static bool Compare(IEnumerable? a, IEnumerable? b, Func elementCompare) + { + if (ReferenceEquals(a, b)) + { + return true; + } + + if (a is null) + { + return b is null; + } + + if (b is null) + { + return false; + } + + if (a is IList aList && b is IList bList) + { + if (aList.Count != bList.Count) + { + return false; + } + + for (var i = 0; i < aList.Count; i++) + { + var (el1, el2) = (aList[i], bList[i]); + if (!elementCompare(el1, el2)) + { + return false; + } + } + + return true; + } + + throw new InvalidOperationException( + CoreStrings.BadListType( + (a is IList ? b : a).GetType().ShortDisplayName(), + typeof(IList<>).MakeGenericType(typeof(TElement)).ShortDisplayName())); + } + private static bool Compare(IEnumerable? a, IEnumerable? b, ValueComparer elementComparer) { if (ReferenceEquals(a, b)) @@ -149,6 +240,18 @@ private static bool Compare(IEnumerable? a, IEnumerable? b, typeof(IList<>).MakeGenericType(elementComparer.Type).ShortDisplayName())); } + private static int GetHashCode(IEnumerable source, Func elementGetHashCode) + { + var hash = new HashCode(); + + foreach (var el in source) + { + hash.Add(elementGetHashCode(el)); + } + + return hash.ToHashCode(); + } + private static int GetHashCode(IEnumerable source, ValueComparer elementComparer) { var hash = new HashCode(); @@ -161,6 +264,41 @@ private static int GetHashCode(IEnumerable source, ValueComparer Snapshot(IEnumerable source, Func elementSnapshot) + { + if (source is not IList sourceList) + { + throw new InvalidOperationException( + CoreStrings.BadListType( + source.GetType().ShortDisplayName(), + typeof(IList<>).MakeGenericType(typeof(TElement).MakeNullable()).ShortDisplayName())); + } + + if (IsArray) + { + var snapshot = new TElement[sourceList.Count]; + for (var i = 0; i < sourceList.Count; i++) + { + var instance = sourceList[i]; + snapshot[i] = elementSnapshot(instance); + } + + return snapshot; + } + else + { + var snapshot = IsReadOnly ? new List() : (IList)Activator.CreateInstance()!; + foreach (var e in sourceList) + { + snapshot.Add(elementSnapshot(e)); + } + + return IsReadOnly + ? (IList)Activator.CreateInstance(typeof(TConcreteList), snapshot)! + : snapshot; + } + } + private static IList Snapshot(IEnumerable source, ValueComparer elementComparer) { if (source is not IList sourceList)