Skip to content

Fix Equals and GetHashCode for types containing Lists and Arrays in C# #2710

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 20 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
331 changes: 282 additions & 49 deletions crates/bindings-csharp/BSATN.Codegen/Type.cs

Large diffs are not rendered by default.

282 changes: 270 additions & 12 deletions crates/bindings-csharp/BSATN.Runtime.Tests/Tests.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
namespace SpacetimeDB;

using System.Diagnostics.CodeAnalysis;
using CsCheck;
using SpacetimeDB.BSATN;
using Xunit;

public static partial class BSATNRuntimeTests
Expand Down Expand Up @@ -296,9 +298,77 @@ public BasicDataRecord((int x, string y, int? z, string? w) data)
(int X, string Y, int? Z, string? W) c2
)> GenTwoBasic = Gen.Select(GenBasic, GenBasic, (c1, c2) => (c1, c2));

/// <summary>
/// Count collisions when comparing hashcodes of non-equal structures.
/// </summary>
struct CollisionCounter
{
private uint Comparisons;
private uint Collisions;

public void Add(bool collides)
{
Comparisons += 1;
if (collides)
{
Collisions += 1;
}
}

public double CollisionFraction
{
get => (double)Collisions / (double)Comparisons;
}

public void AssertCollisionsLessThan(double fraction)
{
Assert.True(
CollisionFraction < fraction,
$"Expected {fraction} portion of collisions, but got {CollisionFraction} = {Collisions} / {Comparisons}"
);
}
}

static void TestRoundTrip<T, BSATN>(Gen<T> gen, BSATN serializer)
where BSATN : IReadWrite<T>
{
gen.Sample(
(value) =>
{
var stream = new MemoryStream();
var writer = new BinaryWriter(stream);
serializer.Write(writer, value);
stream.Seek(0, SeekOrigin.Begin);
var reader = new BinaryReader(stream);
var result = serializer.Read(reader);
Assert.Equal(value, result);
},
iter: 10_000
);
}

[Fact]
public static void GeneratedProductRoundTrip()
{
TestRoundTrip(
GenBasic.Select(value => new BasicDataClass(value)),
new BasicDataClass.BSATN()
);
TestRoundTrip(
GenBasic.Select(value => new BasicDataRecord(value)),
new BasicDataRecord.BSATN()
);
TestRoundTrip(
GenBasic.Select(value => new BasicDataStruct(value)),
new BasicDataStruct.BSATN()
);
}

[Fact]
public static void TestGeneratedEquals()
public static void GeneratedProductEqualsWorks()
{
CollisionCounter collisionCounter = new();

GenTwoBasic.Sample(
example =>
{
Expand Down Expand Up @@ -355,10 +425,13 @@ public static void TestGeneratedEquals()
// hash code should not depend on the type of object.
Assert.Equal(class1.GetHashCode(), record1.GetHashCode());
Assert.Equal(record1.GetHashCode(), struct1.GetHashCode());

collisionCounter.Add(class1.GetHashCode() == class2.GetHashCode());
}
},
iter: 10_000
);
collisionCounter.AssertCollisionsLessThan(0.05);
}

[Type]
Expand Down Expand Up @@ -395,22 +468,17 @@ BasicDataRecord W
(e1, e2) => (e1, e2)
);

[Type]
public partial class ContainsList
[Fact]
public static void GeneratedSumRoundTrip()
{
public List<BasicEnum?> TheList = [];

public ContainsList() { }

public ContainsList(List<BasicEnum?> theList)
{
TheList = theList;
}
TestRoundTrip(GenBasicEnum, new BasicEnum.BSATN());
}

[Fact]
public static void GeneratedEnumsWork()
public static void GeneratedSumEqualsWorks()
{
CollisionCounter collisionCounter = new();

GenTwoBasicEnum.Sample(
example =>
{
Expand Down Expand Up @@ -442,10 +510,186 @@ public static void GeneratedEnumsWork()
Assert.False(example.e1 == example.e2);
Assert.True(example.e1 != example.e2);
Assert.NotEqual(example.e1.ToString(), example.e2.ToString());
collisionCounter.Add(example.e1.GetHashCode() == example.e2.GetHashCode());
}
},
iter: 10_000
);
collisionCounter.AssertCollisionsLessThan(0.05);
}

[Type]
public partial class ContainsList
{
public List<BasicEnum?>? TheList = [];

public ContainsList() { }

public ContainsList(List<BasicEnum?>? theList)
{
TheList = theList;
}
}

static readonly Gen<ContainsList> GenContainsList = GenBasicEnum
.Null()
.List[0, 2]
.Null()
.Select(list => new ContainsList(list));
static readonly Gen<(ContainsList e1, ContainsList e2)> GenTwoContainsList = Gen.Select(
GenContainsList,
GenContainsList,
(e1, e2) => (e1, e2)
);

[Fact]
public static void GeneratedListRoundTrip()
{
TestRoundTrip(GenContainsList, new ContainsList.BSATN());
}

[Fact]
public static void GeneratedListEqualsWorks()
{
CollisionCounter collisionCounter = new();
GenTwoContainsList.Sample(
example =>
{
var equal =
example.e1.TheList == null
? example.e2.TheList == null
: (
example.e2.TheList == null
? false
: example.e1.TheList.SequenceEqual(example.e2.TheList)
);

if (equal)
{
Assert.Equal(example.e1, example.e2);
Assert.True(example.e1 == example.e2);
Assert.False(example.e1 != example.e2);
Assert.Equal(example.e1.ToString(), example.e2.ToString());
Assert.Equal(example.e1.GetHashCode(), example.e2.GetHashCode());
}
else
{
Assert.NotEqual(example.e1, example.e2);
Assert.False(example.e1 == example.e2);
Assert.True(example.e1 != example.e2);
Assert.NotEqual(example.e1.ToString(), example.e2.ToString());
collisionCounter.Add(example.e1.GetHashCode() == example.e2.GetHashCode());
}
},
iter: 10_000
);
collisionCounter.AssertCollisionsLessThan(0.05);
}

[Type]
public partial class ContainsNestedList
{
public List<BasicEnum[][]> TheList = [];

public ContainsNestedList() { }

public ContainsNestedList(List<BasicEnum[][]> theList)
{
TheList = theList;
}
}

// For the serialization test, forbid nulls.
static readonly Gen<ContainsNestedList> GenContainsNestedListNoNulls = GenBasicEnum
.Array[0, 2]
.Array[0, 2]
.List[0, 2]
.Select(list => new ContainsNestedList(list));

[Fact]
public static void GeneratedNestedListRoundTrip()
{
TestRoundTrip(GenContainsNestedListNoNulls, new ContainsNestedList.BSATN());
}

// However, for the equals + hashcode test, throw in some nulls, just to be paranoid.
// The user might have constructed a bad one of these in-memory.

#pragma warning disable CS8620 // Argument cannot be used for parameter due to differences in the nullability of reference types.
static readonly Gen<ContainsNestedList> GenContainsNestedList = GenBasicEnum
.Null()
.Array[0, 2]
.Null()
.Array[0, 2]
.Null()
.List[0, 2]
.Select(list => new ContainsNestedList(list));
#pragma warning restore CS8620 // Argument cannot be used for parameter due to differences in the nullability of reference types.


static readonly Gen<(ContainsNestedList e1, ContainsNestedList e2)> GenTwoContainsNestedList =
Gen.Select(GenContainsNestedList, GenContainsNestedList, (e1, e2) => (e1, e2));

class EnumerableEqualityComparer<T> : EqualityComparer<IEnumerable<T>>
{
private readonly EqualityComparer<T> EqualityComparer;

public EnumerableEqualityComparer(EqualityComparer<T> equalityComparer)
{
EqualityComparer = equalityComparer;
}

public override bool Equals(IEnumerable<T>? x, IEnumerable<T>? y) =>
x == null ? y == null : (y == null ? false : x.SequenceEqual(y, EqualityComparer));

public override int GetHashCode([DisallowNull] IEnumerable<T> obj)
{
var hashCode = 0;
foreach (var item in obj)
{
if (item != null)
{
hashCode ^= EqualityComparer.GetHashCode(item);
}
}
return hashCode;
}
}

[Fact]
public static void GeneratedNestedListEqualsWorks()
{
var equalityComparer = new EnumerableEqualityComparer<IEnumerable<IEnumerable<BasicEnum>>>(
new EnumerableEqualityComparer<IEnumerable<BasicEnum>>(
new EnumerableEqualityComparer<BasicEnum>(EqualityComparer<BasicEnum>.Default)
)
);
CollisionCounter collisionCounter = new();
GenTwoContainsNestedList.Sample(
example =>
{
var equal = equalityComparer.Equals(example.e1.TheList, example.e2.TheList);

if (equal)
{
Assert.Equal(example.e1, example.e2);
Assert.True(example.e1 == example.e2);
Assert.False(example.e1 != example.e2);
Assert.Equal(example.e1.ToString(), example.e2.ToString());
Assert.Equal(example.e1.GetHashCode(), example.e2.GetHashCode());
}
else
{
Assert.NotEqual(example.e1, example.e2);
Assert.False(example.e1 == example.e2);
Assert.True(example.e1 != example.e2);
Assert.NotEqual(example.e1.ToString(), example.e2.ToString());
collisionCounter.Add(example.e1.GetHashCode() == example.e2.GetHashCode());
}
},
iter: 10_000
);
collisionCounter.AssertCollisionsLessThan(0.05);
}

[Fact]
Expand Down Expand Up @@ -516,5 +760,19 @@ public static void GeneratedToString()
]
).ToString()
);
#pragma warning disable CS8625 // Cannot convert null literal to non-nullable reference type.
Assert.Equal(
"ContainsNestedList { TheList = [ [ [ X(1), null ], null ], null ] }",
new ContainsNestedList(
[
[
[new BasicEnum.X(1), null],
null,
],
null,
]
).ToString()
);
#pragma warning restore CS8625 // Cannot convert null literal to non-nullable reference type.
}
}
2 changes: 1 addition & 1 deletion crates/bindings-csharp/Codegen.Tests/Tests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ static IEnumerable<Diagnostic> GetCompilationErrors(Compilation compilation)
.Emit(Stream.Null)
.Diagnostics.Where(diag => diag.Severity != DiagnosticSeverity.Hidden)
// The order of diagnostics is not predictable, sort them by location to make the test deterministic.
.OrderBy(diag => diag.Location.ToString());
.OrderBy(diag => diag.GetMessage() + diag.Location.ToString());
}

[Fact]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,20 @@ SpacetimeDB.BSATN.ITypeRegistrar registrar

public override int GetHashCode()
{
return IntField.GetHashCode() ^ StringField.GetHashCode();
var ___hashIntField = IntField.GetHashCode();
var ___hashStringField = StringField == null ? 0 : StringField.GetHashCode();
return ___hashIntField ^ ___hashStringField;
}

#nullable enable
public bool Equals(CustomClass that)
{
return IntField.Equals(that.IntField) && StringField.Equals(that.StringField);
var ___eqIntField = this.IntField.Equals(that.IntField);
var ___eqStringField =
this.StringField == null
? that.StringField == null
: this.StringField.Equals(that.StringField);
return ___eqIntField && ___eqStringField;
}

public override bool Equals(object? that)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,20 @@ SpacetimeDB.BSATN.ITypeRegistrar registrar

public override int GetHashCode()
{
return IntField.GetHashCode() ^ StringField.GetHashCode();
var ___hashIntField = IntField.GetHashCode();
var ___hashStringField = StringField == null ? 0 : StringField.GetHashCode();
return ___hashIntField ^ ___hashStringField;
}

#nullable enable
public bool Equals(CustomStruct that)
{
return IntField.Equals(that.IntField) && StringField.Equals(that.StringField);
var ___eqIntField = this.IntField.Equals(that.IntField);
var ___eqStringField =
this.StringField == null
? that.StringField == null
: this.StringField.Equals(that.StringField);
return ___eqIntField && ___eqStringField;
}

public override bool Equals(object? that)
Expand Down
Loading
Loading