Skip to content

Commit

Permalink
Analyzer for equality checks with hash sets (#176)
Browse files Browse the repository at this point in the history
  • Loading branch information
etherfield authored Dec 24, 2023
1 parent 4ee532e commit 362e7a0
Show file tree
Hide file tree
Showing 4 changed files with 329 additions and 1 deletion.
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
using Microsoft.CodeAnalysis.CSharp;
using Xunit;
using Verify = CSharpVerifier<Xunit.Analyzers.SetsMustBeComparedWithEqualityComparer>;

public class SetsMustBeComparedWithEqualityComparerTests
{
const string customSet = @"
using System.Collections;
using System.Collections.Generic;
public class MySet : ISet<int> {
public int Count => throw new System.NotImplementedException();
public bool IsReadOnly => throw new System.NotImplementedException();
public bool Add(int item) => throw new System.NotImplementedException();
public void Clear() => throw new System.NotImplementedException();
public bool Contains(int item) => throw new System.NotImplementedException();
public void CopyTo(int[] array, int arrayIndex) => throw new System.NotImplementedException();
public void ExceptWith(IEnumerable<int> other) => throw new System.NotImplementedException();
public IEnumerator<int> GetEnumerator() => throw new System.NotImplementedException();
public void IntersectWith(IEnumerable<int> other) => throw new System.NotImplementedException();
public bool IsProperSubsetOf(IEnumerable<int> other) => throw new System.NotImplementedException();
public bool IsProperSupersetOf(IEnumerable<int> other) => throw new System.NotImplementedException();
public bool IsSubsetOf(IEnumerable<int> other) => throw new System.NotImplementedException();
public bool IsSupersetOf(IEnumerable<int> other) => throw new System.NotImplementedException();
public bool Overlaps(IEnumerable<int> other) => throw new System.NotImplementedException();
public bool Remove(int item) => throw new System.NotImplementedException();
public bool SetEquals(IEnumerable<int> other) => throw new System.NotImplementedException();
public void SymmetricExceptWith(IEnumerable<int> other) => throw new System.NotImplementedException();
public void UnionWith(IEnumerable<int> other) => throw new System.NotImplementedException();
void ICollection<int>.Add(int item) => throw new System.NotImplementedException();
IEnumerator IEnumerable.GetEnumerator() => throw new System.NotImplementedException();
}";

[Theory]
[InlineData("Equal", "List", "List")]
[InlineData("Equal", "HashSet", "List")]
[InlineData("Equal", "List", "HashSet")]
[InlineData("NotEqual", "List", "List")]
[InlineData("NotEqual", "HashSet", "List")]
[InlineData("NotEqual", "List", "HashSet")]
public async void ForSetWithNonSet_DoesNotTrigger(
string method,
string collection1Type,
string collection2Type)
{
var code = @$"
using Xunit;
using System.Collections.Generic;
public class TestClass {{
[Fact]
public void TestMethod() {{
var collection1 = new {collection1Type}<int>();
var collection2 = new {collection2Type}<int>();
Assert.{method}(collection1, collection2, (int e1, int e2) => true);
}}
}}";

await Verify.VerifyAnalyzer(code);
}

public static MatrixTheoryData<string, string, string> MethodWithCollectionCreationData =>
new(
new[] { "Equal", "NotEqual" },
new[] { "new HashSet<int>()", "new HashSet<int>().ToImmutableHashSet()", "new MySet()" },
new[] { "new HashSet<int>()", "new HashSet<int>().ToImmutableHashSet()", "new MySet()" }
);

[Theory]
[MemberData(nameof(MethodWithCollectionCreationData))]
public async void WithCollectionComparer_DoesNotTrigger(
string method,
string collection1,
string collection2)
{
var code = @$"
using Xunit;
using System.Collections.Generic;
using System.Collections.Immutable;
public class TestClass {{
[Fact]
public void TestMethod() {{
var collection1 = {collection1};
var collection2 = {collection2};
Assert.{method}(collection1, collection2, (IEnumerable<int> e1, IEnumerable<int> e2) => true);
}}
}}";

await Verify.VerifyAnalyzer(LanguageVersion.CSharp7, new[] { code, customSet });
}

[Theory]
[MemberData(nameof(MethodWithCollectionCreationData))]
public async void WithEqualityComparer_DoesNotTrigger(
string method,
string collection1,
string collection2)
{
var code = @$"
using Xunit;
using System.Collections.Generic;
using System.Collections.Immutable;
public class TestEqualityComparer : IEqualityComparer<int>
{{
public bool Equals(int x, int y)
{{
return true;
}}
public int GetHashCode(int obj)
{{
return 0;
}}
}}
public class TestClass {{
[Fact]
public void TestMethod() {{
var collection1 = {collection1};
var collection2 = {collection2};
Assert.{method}(collection1, collection2, new TestEqualityComparer());
}}
}}";

await Verify.VerifyAnalyzer(LanguageVersion.CSharp7, new[] { code, customSet });
}

[Theory]
[MemberData(nameof(MethodWithCollectionCreationData))]
public async void WithComparerLambda_Triggers(
string method,
string collection1,
string collection2)
{
var code = @$"
using Xunit;
using System.Collections.Generic;
using System.Collections.Immutable;
public class TestClass {{
[Fact]
public void TestMethod() {{
var collection1 = {collection1};
var collection2 = {collection2};
Assert.{method}(collection1, collection2, (int e1, int e2) => true);
}}
}}";

var expected =
Verify
.Diagnostic()
.WithSpan(12, 9, 12, 68 + method.Length)
.WithArguments(method);

await Verify.VerifyAnalyzer(LanguageVersion.CSharp7, new[] { code, customSet }, expected);
}

#if ROSLYN_4_2_OR_GREATER // No C# 10 in Roslyn 3.11, so no local functions

public static MatrixTheoryData<string, string, string, string> ComparerFunctionData() =>
new(
new[] { "Equal", "NotEqual" },
new[] { "(int e1, int e2) => true", "FuncComparer", "LocalFunc", "funcDelegate" },
new[] { "new HashSet<int>()", "new HashSet<int>().ToImmutableHashSet()", "new MySet()" },
new[] { "new HashSet<int>()", "new HashSet<int>().ToImmutableHashSet()", "new MySet()" }
);

[Theory]
[MemberData(nameof(ComparerFunctionData))]
public async void WithComparerFunction_Triggers(
string method,
string comparerFuncSyntax,
string collection1,
string collection2)
{
var code = @$"
using Xunit;
using System.Collections.Generic;
using System.Collections.Immutable;
public class TestClass {{
private bool FuncComparer(int obj1, int obj2)
{{
return true;
}}
private delegate bool FuncDelegate(int obj1, int obj2);
[Fact]
public void TestMethod() {{
var collection1 = {collection1};
var collection2 = {collection2};
bool LocalFunc(int obj1, int obj2)
{{
return true;
}}
var funcDelegate = FuncComparer;
Assert.{method}(collection1, collection2, {comparerFuncSyntax});
}}
}}";

var expected =
Verify
.Diagnostic()
.WithSpan(26, 9, 26, 44 + method.Length + comparerFuncSyntax.Length)
.WithArguments(method);

await Verify.VerifyAnalyzer(LanguageVersion.CSharp10, new[] { code, customSet }, expected);
}

#endif
}
9 changes: 8 additions & 1 deletion src/xunit.analyzers/Utility/Descriptors.cs
Original file line number Diff line number Diff line change
Expand Up @@ -663,7 +663,14 @@ static DiagnosticDescriptor Rule(
"The use of Assert.{0} can be simplified to avoid using a boolean literal value in an equality test."
);

// Placeholder for rule X2026
public static DiagnosticDescriptor X2026_SetsMustBeComparedWithEqualityComparer { get; } =
Rule(
"xUnit2026",
"Comparison of sets must be done with IEqualityComparer",
Assertions,
Warning,
"Comparison of two sets may produce inconsistent results if GetHashCode() is not overriden. Consider using Assert.{0}(IEnumerable<T>, IEnumerable<T>, IEqualityComparer<T>) instead."
);

// Placeholder for rule X2027

Expand Down
6 changes: 6 additions & 0 deletions src/xunit.analyzers/Utility/TypeSymbolFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,12 @@ public static INamedTypeSymbol IEnumerableOfObjectArray(Compilation compilation)
public static INamedTypeSymbol IReadOnlyCollectionOfT(Compilation compilation) =>
Guard.ArgumentNotNull(compilation).GetSpecialType(SpecialType.System_Collections_Generic_IReadOnlyCollection_T);

public static INamedTypeSymbol? IReadOnlySetOfT(Compilation compilation) =>
Guard.ArgumentNotNull(compilation).GetTypeByMetadataName("System.Collections.Generic.IReadOnlySet`1");

public static INamedTypeSymbol? ISetOfT(Compilation compilation) =>
Guard.ArgumentNotNull(compilation).GetTypeByMetadataName("System.Collections.Generic.ISet`1");

public static INamedTypeSymbol? ISourceInformation_V2(Compilation compilation) =>
Guard.ArgumentNotNull(compilation).GetTypeByMetadataName("Xunit.Abstractions.ISourceInformation");

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
#if ROSLYN_3_11
#pragma warning disable RS1024 // Incorrectly triggered by Roslyn 3.11
#endif

using System.Collections.Generic;
using System.Linq;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.Diagnostics;
using Microsoft.CodeAnalysis.Operations;

namespace Xunit.Analyzers;

[DiagnosticAnalyzer(LanguageNames.CSharp)]
public class SetsMustBeComparedWithEqualityComparer : AssertUsageAnalyzerBase
{
static readonly string[] targetMethods =
{
Constants.Asserts.Equal,
Constants.Asserts.NotEqual,
};

public SetsMustBeComparedWithEqualityComparer()
: base(Descriptors.X2026_SetsMustBeComparedWithEqualityComparer, targetMethods)
{ }

protected override void AnalyzeInvocation(
OperationAnalysisContext context,
XunitContext xunitContext,
IInvocationOperation invocationOperation,
IMethodSymbol method)
{
Guard.ArgumentNotNull(xunitContext);
Guard.ArgumentNotNull(invocationOperation);
Guard.ArgumentNotNull(method);

var arguments = invocationOperation.Arguments;
if (arguments.Length != 3)
return;

var semanticModel = context.Operation.SemanticModel;
if (semanticModel == null)
return;

var setType = TypeSymbolFactory.ISetOfT(context.Compilation)?.ConstructUnboundGenericType();
var readOnlySetType = TypeSymbolFactory.IReadOnlySetOfT(context.Compilation)?.ConstructUnboundGenericType();
var setInterfaces = new HashSet<INamedTypeSymbol>(new[] { setType, readOnlySetType }.WhereNotNull(), SymbolEqualityComparer.Default);

if (semanticModel.GetTypeInfo(arguments[0].Value.Syntax).Type is not INamedTypeSymbol collection0Type)
return;
var interface0Type =
collection0Type
.AllInterfaces
.Where(i => i.IsGenericType)
.FirstOrDefault(i => setInterfaces.Contains(i.ConstructUnboundGenericType()));
if (interface0Type is null)
return;

if (semanticModel.GetTypeInfo(arguments[1].Value.Syntax).Type is not INamedTypeSymbol collection1Type)
return;
var interface1Type =
collection1Type
.AllInterfaces
.Where(i => i.IsGenericType)
.FirstOrDefault(i => setInterfaces.Contains(i.ConstructUnboundGenericType()));
if (interface1Type is null)
return;

if (arguments[2].Value is not IDelegateCreationOperation && arguments[2].Value is not ILocalReferenceOperation)
return;

if (arguments[2].Value.Type is not INamedTypeSymbol funcTypeSymbol || funcTypeSymbol.DelegateInvokeMethod == null)
return;

var funcDelegate = funcTypeSymbol.DelegateInvokeMethod;
var isFuncOverload =
funcDelegate.ReturnType.SpecialType == SpecialType.System_Boolean &&
funcDelegate.Parameters.Length == 2 &&
funcDelegate.Parameters[0].Type.Equals(interface0Type.TypeArguments[0], SymbolEqualityComparer.Default) &&
funcDelegate.Parameters[1].Type.Equals(interface1Type.TypeArguments[0], SymbolEqualityComparer.Default);

// Wrong method overload
if (!isFuncOverload)
return;

context.ReportDiagnostic(
Diagnostic.Create(
Descriptors.X2026_SetsMustBeComparedWithEqualityComparer,
invocationOperation.Syntax.GetLocation(),
method.Name
)
);
}
}

0 comments on commit 362e7a0

Please sign in to comment.