diff --git a/Schema Build Tests/readOnly/VarianceTypes.cs b/Schema Build Tests/readOnly/VarianceTypes.cs index 9282438..45553c4 100644 --- a/Schema Build Tests/readOnly/VarianceTypes.cs +++ b/Schema Build Tests/readOnly/VarianceTypes.cs @@ -15,4 +15,15 @@ public partial interface IOkayToAddVariance { [Const] public void PassGetSet(ISet list); } + + + [GenerateReadOnly] + public partial interface IFinCollection; + + [GenerateReadOnly] + public partial interface ISubTypeDictionary + : IFinCollection<(TKey Key, TValue Value)> { + [Const] + TValueSub Get(TKey key) where TValueSub : TValue; + } } \ No newline at end of file diff --git a/Schema Tests/binary/generator/VarianceGeneratorTests.cs b/Schema Tests/binary/generator/VarianceGeneratorTests.cs index 55a70b9..8b2db71 100644 --- a/Schema Tests/binary/generator/VarianceGeneratorTests.cs +++ b/Schema Tests/binary/generator/VarianceGeneratorTests.cs @@ -1,5 +1,7 @@ using NUnit.Framework; +using schema.readOnly; + namespace schema.binary.text { internal class VarianceGeneratorTests { diff --git a/Schema Tests/readOnly/BasicReadOnlyGeneratorTests.cs b/Schema Tests/readOnly/BasicReadOnlyGeneratorTests.cs index 36d267d..d2bdc80 100644 --- a/Schema Tests/readOnly/BasicReadOnlyGeneratorTests.cs +++ b/Schema Tests/readOnly/BasicReadOnlyGeneratorTests.cs @@ -94,7 +94,7 @@ public partial class SimpleGenerics : IReadOnlySimpleGenerics { T1 IReadOnlySimpleGenerics.Foo(T1 t1, T2 t2, T3 t3, T4 t4) => Foo(t1, t2, t3, t4); } - public interface IReadOnlySimpleGenerics { + public interface IReadOnlySimpleGenerics { public T1 Foo(T1 t1, T2 t2, T3 t3, T4 t4); } } diff --git a/Schema Tests/readOnly/ConstraintTests.cs b/Schema Tests/readOnly/ConstraintTests.cs index f6a0d50..d6fd1d1 100644 --- a/Schema Tests/readOnly/ConstraintTests.cs +++ b/Schema Tests/readOnly/ConstraintTests.cs @@ -66,7 +66,7 @@ public partial interface ICircular : IReadOnlyCircul TMutable IReadOnlyCircular.Foo(in TImpl other) => Foo(in other); } - public interface IReadOnlyCircular where TMutable : ICircular, TReadOnly where TReadOnly : IReadOnlyCircular { + public interface IReadOnlyCircular where TMutable : ICircular, TReadOnly where TReadOnly : IReadOnlyCircular { public TMutable Foo(TReadOnly other); public TMutable Foo(in TImpl other); } @@ -98,7 +98,7 @@ public partial class SubConstraint : IReadOnlySubConstraint { T2 IReadOnlySubConstraint.Bar => Bar; } - public interface IReadOnlySubConstraint where T2 : T1 { + public interface IReadOnlySubConstraint where T2 : T1 { public T1 Foo(S s) where S : T1; public T2 Bar { get; } } diff --git a/Schema Tests/readOnly/VarianceTests.cs b/Schema Tests/readOnly/VarianceTests.cs index f9014de..5d8f6bd 100644 --- a/Schema Tests/readOnly/VarianceTests.cs +++ b/Schema Tests/readOnly/VarianceTests.cs @@ -216,17 +216,17 @@ public interface IReadOnlyWrapper { public void TestDoesNotAddContravarianceForSet() { ReadOnlyGeneratorTestUtil.AssertGenerated( """ - using schema.readOnly; - using System.Collections.Generic; + using schema.readOnly; + using System.Collections.Generic; - namespace foo.bar { - [GenerateReadOnly] - public partial interface IWrapper { - [Const] - public void Method(ISet foo); - } + namespace foo.bar { + [GenerateReadOnly] + public partial interface IWrapper { + [Const] + public void Method(ISet foo); } - """, + } + """, """ namespace foo.bar { public partial interface IWrapper : IReadOnlyWrapper { @@ -240,5 +240,61 @@ public interface IReadOnlyWrapper { """); } + + [Test] + public void TestDoesNotAddVarianceWhenUsedAsTypeConstraint() { + ReadOnlyGeneratorTestUtil.AssertGenerated( + """ + using schema.readOnly; + + namespace foo.bar { + public partial interface IFinCollection; + + [GenerateReadOnly] + public partial interface ISubTypeDictionary + : IFinCollection + where T2 : T1; + } + """, + """ + namespace foo.bar { + public partial interface ISubTypeDictionary : IReadOnlySubTypeDictionary; + + public interface IReadOnlySubTypeDictionary : IFinCollection where T2 : T1; + } + + """); + } + + [Test] + public void TestDoesNotAddVarianceWhenUsedAsMethodConstraint() { + ReadOnlyGeneratorTestUtil.AssertGenerated( + """ + using schema.readOnly; + + namespace foo.bar { + public partial interface IFinCollection; + + [GenerateReadOnly] + public partial interface ISubTypeDictionary + : IFinCollection<(TKey Key, TValue Value)> { + [Const] + TValueSub Get(TKey key) where TValueSub : TValue; + } + } + """, + """ + namespace foo.bar { + public partial interface ISubTypeDictionary : IReadOnlySubTypeDictionary { + TValueSub IReadOnlySubTypeDictionary.Get(TKey key) => Get(key); + } + + public interface IReadOnlySubTypeDictionary : IFinCollection<(TKey Key, TValue Value)> { + public TValueSub Get(TKey key) where TValueSub : TValue; + } + } + + """); + } } } \ No newline at end of file diff --git a/Schema/src/readOnly/ReadOnlyTypeGenerator.cs b/Schema/src/readOnly/ReadOnlyTypeGenerator.cs index 0c88049..0adcf1d 100644 --- a/Schema/src/readOnly/ReadOnlyTypeGenerator.cs +++ b/Schema/src/readOnly/ReadOnlyTypeGenerator.cs @@ -615,6 +615,11 @@ public static string GetGenericParametersWithVarianceForReadOnlyVersion( var allParentTypes = symbol.GetBaseTypes().Concat(symbol.AllInterfaces).ToArray(); + var set = new TypeParameterSymbolVarianceSet( + typeParameters, + allParentTypes, + constMembers); + var sb = new StringBuilder(); sb.Append("<"); for (var i = 0; i < typeParameters.Length; ++i) { @@ -626,9 +631,7 @@ var allParentTypes var variance = typeParameter.Variance; if (variance == VarianceKind.None) { - variance = typeParameter.GetExpandedVarianceForReadonlyVersion( - allParentTypes, - constMembers); + variance = set.AllowedVariance(typeParameter); } sb.Append(variance switch { @@ -643,90 +646,6 @@ var allParentTypes return sb.ToString(); } - public static VarianceKind GetExpandedVarianceForReadonlyVersion( - this ITypeParameterSymbol typeParameterSymbol, - IReadOnlyList allParentTypes, - IReadOnlyList constMembers) { - var matchingTypeArguments - = allParentTypes - .SelectMany(NamedTypeSymbolUtil.GetTypeParamsAndArgs) - .Where(paramAndArg => paramAndArg.typeArgumentSymbol.Name == - typeParameterSymbol.Name) - .ToArray(); - - if (matchingTypeArguments.Length == 0 || - matchingTypeArguments.All( - paramAndArg => paramAndArg.typeParameterSymbol.Variance == - VarianceKind.Out)) { - return constMembers.Any( - constMember - => constMember.Parameters.Any( - p => p.Type.DependsOn(typeParameterSymbol)) || - constMember.ReturnType.DependsOnButHasWrongVariance( - typeParameterSymbol, - VarianceKind.Out)) - ? VarianceKind.None - : VarianceKind.Out; - } - - if (matchingTypeArguments.Length == 0 || - matchingTypeArguments.All( - paramAndArg => paramAndArg.typeParameterSymbol.Variance == - VarianceKind.In)) { - return constMembers.Any( - constMember - => constMember.ReturnType.DependsOn(typeParameterSymbol) || - constMember.Parameters.Any( - p => p.Type.DependsOnButHasWrongVariance( - typeParameterSymbol, - VarianceKind.In))) - ? VarianceKind.None - : VarianceKind.In; - } - - return VarianceKind.None; - } - - public static bool DependsOn(this ITypeSymbol typeSymbol, - ITypeParameterSymbol typeParameterSymbol) - => typeSymbol.DependsOnImpl_(null, typeParameterSymbol, out _); - - private static bool DependsOnImpl_( - this ITypeSymbol typeSymbol, - ITypeParameterSymbol? thisTypeParameterSymbol, - ITypeParameterSymbol otherTypeParameterSymbol, - out ITypeParameterSymbol? match) { - if (typeSymbol.IsSameAs(otherTypeParameterSymbol)) { - match = thisTypeParameterSymbol; - return true; - } - - if (typeSymbol.IsGenericZipped(out var typeParamsAndArgs)) { - foreach (var (typeParam, typeArg) in typeParamsAndArgs) { - if (typeArg.DependsOnImpl_(typeParam, - otherTypeParameterSymbol, - out match)) { - return true; - } - } - } - - match = null; - return false; - } - - public static bool DependsOnButHasWrongVariance( - this ITypeSymbol typeSymbol, - ITypeParameterSymbol otherTypeSymbol, - VarianceKind expectedVarianceKind) { - if (typeSymbol.DependsOnImpl_(null, otherTypeSymbol, out var match) && - match != null) { - return match.Variance != expectedVarianceKind; - } - - return false; - } - public static string GetCStyleCastToReadOnlyIfNeeded( this ITypeSymbol sourceSymbol, ITypeSymbol symbol, diff --git a/Schema/src/util/symbols/TypeParameterSymbolVarianceSet.cs b/Schema/src/util/symbols/TypeParameterSymbolVarianceSet.cs new file mode 100644 index 0000000..66bd2b6 --- /dev/null +++ b/Schema/src/util/symbols/TypeParameterSymbolVarianceSet.cs @@ -0,0 +1,206 @@ +using System; +using System.Collections.Generic; +using System.Linq; + +using Microsoft.CodeAnalysis; + +namespace schema.util.symbols { + public interface ITypeParameterSymbolVarianceSet { + VarianceKind AllowedVariance(ITypeParameterSymbol typeParameterSymbol); + } + + [Flags] + public enum AllowedVarianceType { + NONE = 0, + OUT = 1, + IN = 2, + + BOTH = OUT | IN + } + + public static class AllowedVarianceTypeExtensions { + public static bool AllowsOut(this AllowedVarianceType allowedVariance) + => (allowedVariance & AllowedVarianceType.OUT) != 0; + + public static bool AllowsIn(this AllowedVarianceType allowedVariance) + => (allowedVariance & AllowedVarianceType.IN) != 0; + + public static VarianceKind ToVarianceKind( + this AllowedVarianceType allowedVariance) + => allowedVariance.AllowsOut() ? VarianceKind.Out : + allowedVariance.AllowsIn() ? VarianceKind.In : VarianceKind.None; + + public static AllowedVarianceType Intersection( + this AllowedVarianceType allowedVariance, + VarianceKind varianceKind) + => varianceKind == VarianceKind.In && allowedVariance.AllowsIn() + ? AllowedVarianceType.IN + : varianceKind == VarianceKind.Out && allowedVariance.AllowsOut() + ? AllowedVarianceType.OUT + : AllowedVarianceType.NONE; + } + + public class TypeParameterSymbolVarianceSet + : ITypeParameterSymbolVarianceSet { + private readonly IDictionary impl_ + = new Dictionary( + SymbolEqualityComparer.Default); + + public TypeParameterSymbolVarianceSet( + IEnumerable containerTypeParameterSymbols, + IEnumerable parentTypes, + IReadOnlyList constMembers) { + var knownContainerTypeParameterSymbols + = new HashSet(containerTypeParameterSymbols, + SymbolEqualityComparer.Default); + + foreach (var containerTypeParameter in + knownContainerTypeParameterSymbols) { + this.impl_[containerTypeParameter] = AllowedVarianceType.BOTH; + } + + { + var visitedParentTypeSymbols + = new HashSet(SymbolEqualityComparer.Default); + foreach (var parentType in parentTypes) { + this.VisitParentTypeSymbol_(parentType, + visitedParentTypeSymbols, + knownContainerTypeParameterSymbols); + } + } + + var visitedReturnTypeSymbols + = new HashSet(SymbolEqualityComparer.Default); + var visitedParameterTypeSymbols + = new HashSet(SymbolEqualityComparer.Default); + foreach (var constMember in constMembers) { + this.VisitReturnTypeSymbol_(constMember.ReturnType, + visitedReturnTypeSymbols, + knownContainerTypeParameterSymbols); + + foreach (var parameter in constMember.Parameters) { + this.VisitParameterTypeSymbol_(parameter.Type, + visitedParameterTypeSymbols, + knownContainerTypeParameterSymbols); + } + } + } + + public VarianceKind AllowedVariance( + ITypeParameterSymbol typeParameterSymbol) + => this.impl_[typeParameterSymbol].ToVarianceKind(); + + private void VisitParentTypeSymbol_( + ITypeSymbol parentTypeSymbol, + ISet visitedParentTypeSymbols, + ISet knownContainerTypeParameterSymbols) { + ForEachNewlyVisited_( + parentTypeSymbol, + visitedParentTypeSymbols, + knownContainerTypeParameterSymbols, + (typeSymbol, typeParam) => { + if (typeParam != null) { + this.impl_[typeSymbol] = + this.impl_[typeSymbol].Intersection(typeParam.Variance); + } + }); + } + + private void VisitReturnTypeSymbol_( + ITypeSymbol returnTypeSymbol, + ISet visitedReturnTypeSymbols, + ISet knownContainerTypeParameterSymbols) { + ForEachNewlyVisited_( + returnTypeSymbol, + visitedReturnTypeSymbols, + knownContainerTypeParameterSymbols, + (typeSymbol, typeParam) => { + this.impl_[typeSymbol] = + this.impl_[typeSymbol].Intersection(VarianceKind.Out); + + if (typeParam != null) { + this.impl_[typeSymbol] = + this.impl_[typeSymbol].Intersection(typeParam.Variance); + } + }); + } + + private void VisitParameterTypeSymbol_( + ITypeSymbol parameterTypeSymbol, + ISet visitedParameterTypeSymbols, + ISet knownContainerTypeParameterSymbols) { + ForEachNewlyVisited_( + parameterTypeSymbol, + visitedParameterTypeSymbols, + knownContainerTypeParameterSymbols, + (typeSymbol, typeParam) => { + this.impl_[typeSymbol] = + this.impl_[typeSymbol].Intersection(VarianceKind.In); + + if (typeParam != null) { + this.impl_[typeSymbol] = + this.impl_[typeSymbol].Intersection(typeParam.Variance); + } + }); + } + + private delegate void ForEachSymbolDelegate( + ITypeSymbol typeSymbol, + ITypeParameterSymbol? typeParam); + + private static void ForEachNewlyVisited_( + ITypeSymbol currentTypeSymbol, + ISet visitedTypeSymbols, + ISet symbolsToMatchAgainst, + ForEachSymbolDelegate matchHandler) + => ForEachNewlyVisitedImpl_(currentTypeSymbol, + null, + visitedTypeSymbols, + symbolsToMatchAgainst, + matchHandler); + + private static void ForEachNewlyVisitedImpl_( + ITypeSymbol currentTypeSymbol, + ITypeParameterSymbol? currentTypeParameterSymbol, + ISet visitedTypeSymbols, + ISet symbolsToMatchAgainst, + ForEachSymbolDelegate handler) { + if (!visitedTypeSymbols.Add(currentTypeSymbol)) { + return; + } + + if (symbolsToMatchAgainst.Contains(currentTypeSymbol)) { + handler(currentTypeSymbol, currentTypeParameterSymbol); + } + + if (currentTypeSymbol.IsGenericTypeParameter( + out var asTypeParameterSymbol)) { + foreach (var constraintType in asTypeParameterSymbol.ConstraintTypes) { + ForEachNewlyVisitedImpl_(constraintType, + asTypeParameterSymbol, + visitedTypeSymbols, + symbolsToMatchAgainst, + handler); + } + } + + if (currentTypeSymbol.IsGenericZipped(out var typeParamsAndArgs)) { + foreach (var (typeParam, typeArg) in typeParamsAndArgs) { + foreach (var constraintType in typeParam.ConstraintTypes) { + ForEachNewlyVisitedImpl_(constraintType, + typeParam, + visitedTypeSymbols, + symbolsToMatchAgainst, + handler); + } + + ForEachNewlyVisitedImpl_(typeArg, + typeParam, + visitedTypeSymbols, + symbolsToMatchAgainst, + handler); + } + } + } + } +} \ No newline at end of file