diff --git a/src/xunit.analyzers.tests/Analyzers/X2000/SetsMustBeComparedWithEqualityComparerTests.cs b/src/xunit.analyzers.tests/Analyzers/X2000/SetsMustBeComparedWithEqualityComparerTests.cs new file mode 100644 index 00000000..6cc6a857 --- /dev/null +++ b/src/xunit.analyzers.tests/Analyzers/X2000/SetsMustBeComparedWithEqualityComparerTests.cs @@ -0,0 +1,222 @@ +using Microsoft.CodeAnalysis.CSharp; +using Xunit; +using Verify = CSharpVerifier; + +public class SetsMustBeComparedWithEqualityComparerTests +{ + const string customSet = @" +using System.Collections; +using System.Collections.Generic; + +public class MySet : ISet { + public int Count => throw new System.NotImplementedException(); + public bool IsReadOnly => throw new System.NotImplementedException(); + + public bool Add(int item) => throw new System.NotImplementedException(); + public void Clear() => throw new System.NotImplementedException(); + public bool Contains(int item) => throw new System.NotImplementedException(); + public void CopyTo(int[] array, int arrayIndex) => throw new System.NotImplementedException(); + public void ExceptWith(IEnumerable other) => throw new System.NotImplementedException(); + public IEnumerator GetEnumerator() => throw new System.NotImplementedException(); + public void IntersectWith(IEnumerable other) => throw new System.NotImplementedException(); + public bool IsProperSubsetOf(IEnumerable other) => throw new System.NotImplementedException(); + public bool IsProperSupersetOf(IEnumerable other) => throw new System.NotImplementedException(); + public bool IsSubsetOf(IEnumerable other) => throw new System.NotImplementedException(); + public bool IsSupersetOf(IEnumerable other) => throw new System.NotImplementedException(); + public bool Overlaps(IEnumerable other) => throw new System.NotImplementedException(); + public bool Remove(int item) => throw new System.NotImplementedException(); + public bool SetEquals(IEnumerable other) => throw new System.NotImplementedException(); + public void SymmetricExceptWith(IEnumerable other) => throw new System.NotImplementedException(); + public void UnionWith(IEnumerable other) => throw new System.NotImplementedException(); + void ICollection.Add(int item) => throw new System.NotImplementedException(); + IEnumerator IEnumerable.GetEnumerator() => throw new System.NotImplementedException(); +}"; + + [Theory] + [InlineData("Equal", "List", "List")] + [InlineData("Equal", "HashSet", "List")] + [InlineData("Equal", "List", "HashSet")] + [InlineData("NotEqual", "List", "List")] + [InlineData("NotEqual", "HashSet", "List")] + [InlineData("NotEqual", "List", "HashSet")] + public async void ForSetWithNonSet_DoesNotTrigger( + string method, + string collection1Type, + string collection2Type) + { + var code = @$" +using Xunit; +using System.Collections.Generic; + +public class TestClass {{ + [Fact] + public void TestMethod() {{ + var collection1 = new {collection1Type}(); + var collection2 = new {collection2Type}(); + + Assert.{method}(collection1, collection2, (int e1, int e2) => true); + }} +}}"; + + await Verify.VerifyAnalyzer(code); + } + + public static MatrixTheoryData MethodWithCollectionCreationData => + new( + new[] { "Equal", "NotEqual" }, + new[] { "new HashSet()", "new HashSet().ToImmutableHashSet()", "new MySet()" }, + new[] { "new HashSet()", "new HashSet().ToImmutableHashSet()", "new MySet()" } + ); + + [Theory] + [MemberData(nameof(MethodWithCollectionCreationData))] + public async void WithCollectionComparer_DoesNotTrigger( + string method, + string collection1, + string collection2) + { + var code = @$" +using Xunit; +using System.Collections.Generic; +using System.Collections.Immutable; + +public class TestClass {{ + [Fact] + public void TestMethod() {{ + var collection1 = {collection1}; + var collection2 = {collection2}; + + Assert.{method}(collection1, collection2, (IEnumerable e1, IEnumerable e2) => true); + }} +}}"; + + await Verify.VerifyAnalyzer(LanguageVersion.CSharp7, new[] { code, customSet }); + } + + [Theory] + [MemberData(nameof(MethodWithCollectionCreationData))] + public async void WithEqualityComparer_DoesNotTrigger( + string method, + string collection1, + string collection2) + { + var code = @$" +using Xunit; +using System.Collections.Generic; +using System.Collections.Immutable; + +public class TestEqualityComparer : IEqualityComparer +{{ + public bool Equals(int x, int y) + {{ + return true; + }} + + public int GetHashCode(int obj) + {{ + return 0; + }} +}} + +public class TestClass {{ + [Fact] + public void TestMethod() {{ + var collection1 = {collection1}; + var collection2 = {collection2}; + + Assert.{method}(collection1, collection2, new TestEqualityComparer()); + }} +}}"; + + await Verify.VerifyAnalyzer(LanguageVersion.CSharp7, new[] { code, customSet }); + } + + [Theory] + [MemberData(nameof(MethodWithCollectionCreationData))] + public async void WithComparerLambda_Triggers( + string method, + string collection1, + string collection2) + { + var code = @$" +using Xunit; +using System.Collections.Generic; +using System.Collections.Immutable; + +public class TestClass {{ + [Fact] + public void TestMethod() {{ + var collection1 = {collection1}; + var collection2 = {collection2}; + + Assert.{method}(collection1, collection2, (int e1, int e2) => true); + }} +}}"; + + var expected = + Verify + .Diagnostic() + .WithSpan(12, 9, 12, 68 + method.Length) + .WithArguments(method); + + await Verify.VerifyAnalyzer(LanguageVersion.CSharp7, new[] { code, customSet }, expected); + } + +#if ROSLYN_4_2_OR_GREATER // No C# 10 in Roslyn 3.11, so no local functions + + public static MatrixTheoryData ComparerFunctionData() => + new( + new[] { "Equal", "NotEqual" }, + new[] { "(int e1, int e2) => true", "FuncComparer", "LocalFunc", "funcDelegate" }, + new[] { "new HashSet()", "new HashSet().ToImmutableHashSet()", "new MySet()" }, + new[] { "new HashSet()", "new HashSet().ToImmutableHashSet()", "new MySet()" } + ); + + [Theory] + [MemberData(nameof(ComparerFunctionData))] + public async void WithComparerFunction_Triggers( + string method, + string comparerFuncSyntax, + string collection1, + string collection2) + { + var code = @$" +using Xunit; +using System.Collections.Generic; +using System.Collections.Immutable; + +public class TestClass {{ + private bool FuncComparer(int obj1, int obj2) + {{ + return true; + }} + + private delegate bool FuncDelegate(int obj1, int obj2); + + [Fact] + public void TestMethod() {{ + var collection1 = {collection1}; + var collection2 = {collection2}; + + bool LocalFunc(int obj1, int obj2) + {{ + return true; + }} + + var funcDelegate = FuncComparer; + + Assert.{method}(collection1, collection2, {comparerFuncSyntax}); + }} +}}"; + + var expected = + Verify + .Diagnostic() + .WithSpan(26, 9, 26, 44 + method.Length + comparerFuncSyntax.Length) + .WithArguments(method); + + await Verify.VerifyAnalyzer(LanguageVersion.CSharp10, new[] { code, customSet }, expected); + } + +#endif +} diff --git a/src/xunit.analyzers/Utility/Descriptors.cs b/src/xunit.analyzers/Utility/Descriptors.cs index 5377c0ab..a6109015 100644 --- a/src/xunit.analyzers/Utility/Descriptors.cs +++ b/src/xunit.analyzers/Utility/Descriptors.cs @@ -663,7 +663,14 @@ static DiagnosticDescriptor Rule( "The use of Assert.{0} can be simplified to avoid using a boolean literal value in an equality test." ); - // Placeholder for rule X2026 + public static DiagnosticDescriptor X2026_SetsMustBeComparedWithEqualityComparer { get; } = + Rule( + "xUnit2026", + "Comparison of sets must be done with IEqualityComparer", + Assertions, + Warning, + "Comparison of two sets may produce inconsistent results if GetHashCode() is not overriden. Consider using Assert.{0}(IEnumerable, IEnumerable, IEqualityComparer) instead." + ); // Placeholder for rule X2027 diff --git a/src/xunit.analyzers/Utility/TypeSymbolFactory.cs b/src/xunit.analyzers/Utility/TypeSymbolFactory.cs index 87e005a0..cca61dae 100644 --- a/src/xunit.analyzers/Utility/TypeSymbolFactory.cs +++ b/src/xunit.analyzers/Utility/TypeSymbolFactory.cs @@ -102,6 +102,12 @@ public static INamedTypeSymbol IEnumerableOfObjectArray(Compilation compilation) public static INamedTypeSymbol IReadOnlyCollectionOfT(Compilation compilation) => Guard.ArgumentNotNull(compilation).GetSpecialType(SpecialType.System_Collections_Generic_IReadOnlyCollection_T); + public static INamedTypeSymbol? IReadOnlySetOfT(Compilation compilation) => + Guard.ArgumentNotNull(compilation).GetTypeByMetadataName("System.Collections.Generic.IReadOnlySet`1"); + + public static INamedTypeSymbol? ISetOfT(Compilation compilation) => + Guard.ArgumentNotNull(compilation).GetTypeByMetadataName("System.Collections.Generic.ISet`1"); + public static INamedTypeSymbol? ISourceInformation_V2(Compilation compilation) => Guard.ArgumentNotNull(compilation).GetTypeByMetadataName("Xunit.Abstractions.ISourceInformation"); diff --git a/src/xunit.analyzers/X2000/SetsMustBeComparedWithEqualityComparer.cs b/src/xunit.analyzers/X2000/SetsMustBeComparedWithEqualityComparer.cs new file mode 100644 index 00000000..43290835 --- /dev/null +++ b/src/xunit.analyzers/X2000/SetsMustBeComparedWithEqualityComparer.cs @@ -0,0 +1,93 @@ +#if ROSLYN_3_11 +#pragma warning disable RS1024 // Incorrectly triggered by Roslyn 3.11 +#endif + +using System.Collections.Generic; +using System.Linq; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.Diagnostics; +using Microsoft.CodeAnalysis.Operations; + +namespace Xunit.Analyzers; + +[DiagnosticAnalyzer(LanguageNames.CSharp)] +public class SetsMustBeComparedWithEqualityComparer : AssertUsageAnalyzerBase +{ + static readonly string[] targetMethods = + { + Constants.Asserts.Equal, + Constants.Asserts.NotEqual, + }; + + public SetsMustBeComparedWithEqualityComparer() + : base(Descriptors.X2026_SetsMustBeComparedWithEqualityComparer, targetMethods) + { } + + protected override void AnalyzeInvocation( + OperationAnalysisContext context, + XunitContext xunitContext, + IInvocationOperation invocationOperation, + IMethodSymbol method) + { + Guard.ArgumentNotNull(xunitContext); + Guard.ArgumentNotNull(invocationOperation); + Guard.ArgumentNotNull(method); + + var arguments = invocationOperation.Arguments; + if (arguments.Length != 3) + return; + + var semanticModel = context.Operation.SemanticModel; + if (semanticModel == null) + return; + + var setType = TypeSymbolFactory.ISetOfT(context.Compilation)?.ConstructUnboundGenericType(); + var readOnlySetType = TypeSymbolFactory.IReadOnlySetOfT(context.Compilation)?.ConstructUnboundGenericType(); + var setInterfaces = new HashSet(new[] { setType, readOnlySetType }.WhereNotNull(), SymbolEqualityComparer.Default); + + if (semanticModel.GetTypeInfo(arguments[0].Value.Syntax).Type is not INamedTypeSymbol collection0Type) + return; + var interface0Type = + collection0Type + .AllInterfaces + .Where(i => i.IsGenericType) + .FirstOrDefault(i => setInterfaces.Contains(i.ConstructUnboundGenericType())); + if (interface0Type is null) + return; + + if (semanticModel.GetTypeInfo(arguments[1].Value.Syntax).Type is not INamedTypeSymbol collection1Type) + return; + var interface1Type = + collection1Type + .AllInterfaces + .Where(i => i.IsGenericType) + .FirstOrDefault(i => setInterfaces.Contains(i.ConstructUnboundGenericType())); + if (interface1Type is null) + return; + + if (arguments[2].Value is not IDelegateCreationOperation && arguments[2].Value is not ILocalReferenceOperation) + return; + + if (arguments[2].Value.Type is not INamedTypeSymbol funcTypeSymbol || funcTypeSymbol.DelegateInvokeMethod == null) + return; + + var funcDelegate = funcTypeSymbol.DelegateInvokeMethod; + var isFuncOverload = + funcDelegate.ReturnType.SpecialType == SpecialType.System_Boolean && + funcDelegate.Parameters.Length == 2 && + funcDelegate.Parameters[0].Type.Equals(interface0Type.TypeArguments[0], SymbolEqualityComparer.Default) && + funcDelegate.Parameters[1].Type.Equals(interface1Type.TypeArguments[0], SymbolEqualityComparer.Default); + + // Wrong method overload + if (!isFuncOverload) + return; + + context.ReportDiagnostic( + Diagnostic.Create( + Descriptors.X2026_SetsMustBeComparedWithEqualityComparer, + invocationOperation.Syntax.GetLocation(), + method.Name + ) + ); + } +}