Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add option to generate specialized partition implementation #15

Merged
merged 19 commits into from
Oct 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
using System.Runtime.CompilerServices;

namespace Funcky.DiscriminatedUnion.SourceGeneration;

[InterpolatedStringHandler]
public readonly struct CollectingInterpolatedStringHandler
{
public CollectingInterpolatedStringHandler(int literalLength, int formattedCount)
{
_items = new List<object?>(formattedCount * 2);
}

public CollectingInterpolatedStringHandler()
{
_items = new List<object?>();
}

private readonly List<object?> _items;

public IEnumerable<object?> GetItems() => _items;

public void AppendLiteral(string s) => _items.Add(s);

public void AppendFormatted<T>(T t) => _items.Add(t);

public void AppendFormatted(CollectingInterpolatedStringHandler handler)
=> _items.AddRange(handler._items);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
namespace Funcky.DiscriminatedUnion.SourceGeneration;

public static class CollectingInterpolatedStringHandlerExtensions
{
public static CollectingInterpolatedStringHandler JoinToInterpolation<T>(
this IEnumerable<T> source,
string separator)
{
using var enumerator = source.GetEnumerator();

if (!enumerator.MoveNext())
{
return new CollectingInterpolatedStringHandler();
}

var result = new CollectingInterpolatedStringHandler();

result.AppendFormatted(enumerator.Current);

while (enumerator.MoveNext())
{
result.AppendLiteral(separator);
result.AppendFormatted(enumerator.Current);
}

return result;
}

public static CollectingInterpolatedStringHandler JoinToInterpolation<T>(
this IEnumerable<T> source,
Func<T, CollectingInterpolatedStringHandler> createPart,
string separator)
{
using var enumerator = source.GetEnumerator();

if (!enumerator.MoveNext())
{
return new CollectingInterpolatedStringHandler();
}

var result = new CollectingInterpolatedStringHandler();

result.AppendFormatted(createPart(enumerator.Current));

while (enumerator.MoveNext())
{
result.AppendLiteral(separator);
result.AppendFormatted(createPart(enumerator.Current));
}

return result;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@ internal sealed record DiscriminatedUnion(
TypeDeclarationSyntax Type,
IReadOnlyList<TypeDeclarationSyntax> ParentTypes,
string? Namespace,
string MethodVisibility,
string GeneratedMethodOrClassVisibility,
string MatchResultTypeName,
IReadOnlyList<DiscriminatedUnionVariant> Variants);
IReadOnlyList<DiscriminatedUnionVariant> Variants,
bool GeneratePartitionExtension);

internal sealed record DiscriminatedUnionVariant(
TypeDeclarationSyntax Type,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System.Collections.Immutable;
using System.Diagnostics.CodeAnalysis;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using static Funcky.DiscriminatedUnion.SourceGeneration.Emitter;
Expand Down Expand Up @@ -32,8 +33,7 @@ private static void AddSource(SourceProductionContext context, ImmutableArray<st
{
if (code.Any())
{
var combinedCode = $"{GeneratedFileHeadersSource}{Environment.NewLine}{Environment.NewLine}" +
$"{string.Join(Environment.NewLine, code)}";
var combinedCode = $"{GeneratedFileHeadersSource}\n\n{string.Join("\n", code)}";
context.AddSource("DiscriminatedUnionGenerator.g.cs", combinedCode);
}
}
Expand Down
135 changes: 106 additions & 29 deletions Funcky.DiscriminatedUnion.SourceGeneration/Emitter.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using System.CodeDom.Compiler;
Expand All @@ -18,30 +17,108 @@ public static string EmitDiscriminatedUnion(DiscriminatedUnion discriminatedUnio
{
WriteNamespace(writer, discriminatedUnion);

WriteParentTypes(writer, discriminatedUnion.ParentTypes);
WriteUnionType(discriminatedUnion, writer);

WriteJsonDerivedTypeAttributes(writer, discriminatedUnion);
writer.WriteLine(FormatPartialTypeDeclaration(discriminatedUnion.Type));
writer.OpenScope();

WriteGeneratedMethod(writer, $"{discriminatedUnion.MethodVisibility} abstract {FormatMatchMethodDeclaration(discriminatedUnion.MatchResultTypeName, discriminatedUnion.Variants)};");
writer.WriteLine();
WriteGeneratedMethod(writer, $"{discriminatedUnion.MethodVisibility} abstract {FormatSwitchMethodDeclaration(discriminatedUnion.Variants)};");

foreach (var variant in discriminatedUnion.Variants)
if (discriminatedUnion.GeneratePartitionExtension)
{
WriteVariant(writer, discriminatedUnion, variant);
writer.WriteLine();
WritePartitionExtensions(discriminatedUnion, writer);
}
}

return stringBuilder.ToString();
}

private static void WriteUnionType(DiscriminatedUnion discriminatedUnion, IndentedTextWriter writer)
{
using var scope = writer.AutoCloseScopes();

WriteParentTypes(writer, discriminatedUnion.ParentTypes);

WriteJsonDerivedTypeAttributes(writer, discriminatedUnion);
writer.WriteLineInterpolated(FormatPartialTypeDeclaration(discriminatedUnion.Type));
writer.OpenScope();

WriteGeneratedMethod(writer, $"{discriminatedUnion.GeneratedMethodOrClassVisibility} abstract {FormatMatchMethodDeclaration(discriminatedUnion.MatchResultTypeName, discriminatedUnion.Variants)};");
writer.WriteLine();
WriteGeneratedMethod(writer, $"{discriminatedUnion.GeneratedMethodOrClassVisibility} abstract {FormatSwitchMethodDeclaration(discriminatedUnion.Variants)};");

foreach (var variant in discriminatedUnion.Variants)
{
WriteVariant(writer, discriminatedUnion, variant);
}
}

private static void WritePartitionExtensions(DiscriminatedUnion discriminatedUnion, IndentedTextWriter writer)
{
using var scope = writer.AutoCloseScopes();

writer.WriteLine(GeneratedCodeAttributeSource);
writer.WriteLineInterpolated($"{discriminatedUnion.GeneratedMethodOrClassVisibility} static partial class {discriminatedUnion.Type.Identifier}EnumerableExtensions");
writer.OpenScope();

WriteTupleReturningPartitionExtension(discriminatedUnion, writer);
writer.WriteLine();
WritePartitionWithResultSelector(discriminatedUnion, writer);
}

private static void WriteTupleReturningPartitionExtension(DiscriminatedUnion discriminatedUnion, IndentedTextWriter writer)
{
using var methodScope = writer.AutoCloseScopes();

var namedResultPartitions = discriminatedUnion
.Variants
.JoinToInterpolation(
v => $"global::System.Collections.Generic.IReadOnlyList<{discriminatedUnion.Type.Identifier}.{v.LocalTypeName}> {v.ParameterName[..1].ToUpper()}{v.ParameterName[1..]}",
", ");

writer.WriteLineInterpolated($"public static ({namedResultPartitions}) Partition(this global::System.Collections.Generic.IEnumerable<{discriminatedUnion.Type.Identifier}> source)");
writer.OpenScope();

WritePartitioningIntoLists(discriminatedUnion, writer);

writer.WriteLineInterpolated($"return ({discriminatedUnion.Variants.JoinToInterpolation(v => $"{v.ParameterName}Items.AsReadOnly()", ", ")});");
}

private static void WritePartitionWithResultSelector(DiscriminatedUnion discriminatedUnion, IndentedTextWriter writer)
{
using var methodScope = writer.AutoCloseScopes();

writer.WriteInterpolated($"public static TResult Partition<{discriminatedUnion.MatchResultTypeName}>(this global::System.Collections.Generic.IEnumerable<{discriminatedUnion.Type.Identifier}> source, global::System.Func<");

foreach (var variant in discriminatedUnion.Variants)
{
writer.WriteInterpolated($"global::System.Collections.Generic.IReadOnlyList<{discriminatedUnion.Type.Identifier}.{variant.LocalTypeName}>, ");
}

writer.WriteLineInterpolated($"{discriminatedUnion.MatchResultTypeName}> resultSelector)");
writer.OpenScope();

WritePartitioningIntoLists(discriminatedUnion, writer);

writer.WriteLineInterpolated($"return resultSelector({discriminatedUnion.Variants.JoinToInterpolation(v => $"{v.ParameterName}Items.AsReadOnly()", ", ")});");
}

private static void WritePartitioningIntoLists(DiscriminatedUnion discriminatedUnion, IndentedTextWriter writer)
{
foreach (var variant in discriminatedUnion.Variants)
{
writer.WriteLineInterpolated($"var {variant.ParameterName}Items = new global::System.Collections.Generic.List<{discriminatedUnion.Type.Identifier}.{variant.LocalTypeName}>();");
}

using (writer.AutoCloseScopes())
{
writer.WriteLine("foreach (var item in source)");
writer.OpenScope();
writer.WriteLineInterpolated($"item.Switch({discriminatedUnion.Variants.JoinToInterpolation(v => $"{v.ParameterName}: {v.ParameterName}Items.Add", ", ")});");
}
}

private static void WriteNamespace(IndentedTextWriter writer, DiscriminatedUnion discriminatedUnion)
{
if (!string.IsNullOrEmpty(discriminatedUnion.Namespace))
{
writer.WriteLine($"namespace {discriminatedUnion.Namespace}");
writer.WriteLineInterpolated($"namespace {discriminatedUnion.Namespace}");
writer.OpenScope();
}
}
Expand All @@ -50,7 +127,7 @@ private static void WriteParentTypes(IndentedTextWriter writer, IEnumerable<Type
{
foreach (var parent in parents.Reverse())
{
writer.WriteLine(FormatPartialTypeDeclaration(parent));
writer.WriteLineInterpolated(FormatPartialTypeDeclaration(parent));
writer.OpenScope();
}
}
Expand All @@ -63,19 +140,19 @@ private static void WriteVariant(IndentedTextWriter writer, DiscriminatedUnion d

WriteParentTypes(writer, variant.ParentTypes);

writer.WriteLine(FormatPartialTypeDeclaration(variant.Type));
writer.WriteLineInterpolated(FormatPartialTypeDeclaration(variant.Type));
writer.OpenScope();

WriteGeneratedMethod(writer, $"{discriminatedUnion.MethodVisibility} override {FormatMatchMethodDeclaration(discriminatedUnion.MatchResultTypeName, discriminatedUnion.Variants)} => {FormatIdentifier(variant.ParameterName)}(this);");
WriteGeneratedMethod(writer, $"{discriminatedUnion.GeneratedMethodOrClassVisibility} override {FormatMatchMethodDeclaration(discriminatedUnion.MatchResultTypeName, discriminatedUnion.Variants)} => {FormatIdentifier(variant.ParameterName)}(this);");
writer.WriteLine();
WriteGeneratedMethod(writer, $"{discriminatedUnion.MethodVisibility} override {FormatSwitchMethodDeclaration(discriminatedUnion.Variants)} => {FormatIdentifier(variant.ParameterName)}(this);");
WriteGeneratedMethod(writer, $"{discriminatedUnion.GeneratedMethodOrClassVisibility} override {FormatSwitchMethodDeclaration(discriminatedUnion.Variants)} => {FormatIdentifier(variant.ParameterName)}(this);");
}
}

private static void WriteGeneratedMethod(IndentedTextWriter writer, string method)
private static void WriteGeneratedMethod(IndentedTextWriter writer, CollectingInterpolatedStringHandler method)
{
writer.WriteLine(GeneratedCodeAttributeSource);
writer.WriteLine(method);
writer.WriteLineInterpolated(method);
}

private static void WriteJsonDerivedTypeAttributes(IndentedTextWriter writer, DiscriminatedUnion discriminatedUnion)
Expand All @@ -90,26 +167,26 @@ private static void WriteJsonDerivedTypeAttribute(IndentedTextWriter writer, Dis
{
if (variant.GenerateJsonDerivedTypeAttribute)
{
writer.WriteLine($"[global::System.Text.Json.Serialization.JsonDerivedType(typeof({variant.TypeOfTypeName}), {SyntaxFactory.Literal(variant.JsonDerivedTypeDiscriminator)})]");
writer.WriteLineInterpolated($"[global::System.Text.Json.Serialization.JsonDerivedType(typeof({variant.TypeOfTypeName}), {SyntaxFactory.Literal(variant.JsonDerivedTypeDiscriminator)})]");
}
}

private static string FormatMatchMethodDeclaration(string genericTypeName, IEnumerable<DiscriminatedUnionVariant> variants)
=> $"{genericTypeName} Match<{genericTypeName}>({string.Join(", ", variants.Select(variant => $"global::System.Func<{variant.LocalTypeName}, {genericTypeName}> {FormatIdentifier(variant.ParameterName)}"))})";
private static CollectingInterpolatedStringHandler FormatMatchMethodDeclaration(string genericTypeName, IEnumerable<DiscriminatedUnionVariant> variants)
=> $"{genericTypeName} Match<{genericTypeName}>({variants.JoinToInterpolation(v => $"global::System.Func<{v.LocalTypeName}, {genericTypeName}> {FormatIdentifier(v.ParameterName)}", ", ")})";

private static string FormatSwitchMethodDeclaration(IEnumerable<DiscriminatedUnionVariant> variants)
=> $"void Switch({string.Join(", ", variants.Select(variant => $"global::System.Action<{variant.LocalTypeName}> {FormatIdentifier(variant.ParameterName)}"))})";
private static CollectingInterpolatedStringHandler FormatSwitchMethodDeclaration(IEnumerable<DiscriminatedUnionVariant> variants)
=> $"void Switch({variants.JoinToInterpolation(v => $"global::System.Action<{v.LocalTypeName}> {FormatIdentifier(v.ParameterName)}", ", ")})";

private static string FormatPartialTypeDeclaration(TypeDeclarationSyntax typeDeclaration)
private static CollectingInterpolatedStringHandler FormatPartialTypeDeclaration(TypeDeclarationSyntax typeDeclaration)
=> typeDeclaration is RecordDeclarationSyntax recordDeclaration
? CombineTokens("partial", typeDeclaration.Keyword, recordDeclaration.ClassOrStructKeyword, typeDeclaration.Identifier.ToString() + typeDeclaration.TypeParameterList, typeDeclaration.ConstraintClauses)
: CombineTokens("partial", typeDeclaration.Keyword, typeDeclaration.Identifier.ToString() + typeDeclaration.TypeParameterList, typeDeclaration.ConstraintClauses);

private static string CombineTokens(params object[] tokens)
=> string.Join(" ", tokens.Select(t => t.ToString()).Where(t => !string.IsNullOrEmpty(t)));
private static CollectingInterpolatedStringHandler CombineTokens(params object[] tokens)
=> tokens.Select(t => t.ToString()).Where(t => !string.IsNullOrEmpty(t)).JoinToInterpolation(" ");

private static string FormatIdentifier(string identifier)
=> IsIdentifier(identifier) ? '@' + identifier : identifier;
private static CollectingInterpolatedStringHandler FormatIdentifier(string identifier)
=> $"{(IsIdentifier(identifier) ? "@" : string.Empty)}{identifier}";

private static bool IsIdentifier(string identifier)
=> SyntaxFacts.GetKeywordKind(identifier) != SyntaxKind.None;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
<ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable>
<LangVersion>11.0</LangVersion>
<EnforceExtendedAnalyzerRules>true</EnforceExtendedAnalyzerRules>
</PropertyGroup>
<PropertyGroup Label="NuGet Metadata">
<Version>1.1.0</Version>
<Version>1.2.0</Version>
<PackageId>Funcky.DiscriminatedUnion</PackageId>
<Authors>Polyadic</Authors>
<PackageLicenseExpression>MIT OR Apache-2.0</PackageLicenseExpression>
Expand Down Expand Up @@ -40,6 +41,6 @@
<ItemGroup>
<PackageReference Include="Microsoft.CodeAnalysis.Analyzers" Version="3.3.3" PrivateAssets="all" />
<PackageReference Include="Microsoft.CodeAnalysis.CSharp" Version="4.0.1" />
<PackageReference Include="IsExternalInit" Version="1.0.1" PrivateAssets="all" />
<PackageReference Include="PolySharp" Version="1.14.1" PrivateAssets="all" />
</ItemGroup>
</Project>
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,20 @@ public static void OpenScope(this IndentedTextWriter writer)
writer.Indent++;
}

public static void WriteInterpolated(this IndentedTextWriter writer, CollectingInterpolatedStringHandler value)
{
foreach (var item in value.GetItems())
{
writer.Write(item?.ToString());
}
}

public static void WriteLineInterpolated(this IndentedTextWriter writer, CollectingInterpolatedStringHandler value)
{
writer.WriteInterpolated(value);
writer.WriteLine();
}

private static void CloseScope(this IndentedTextWriter writer)
{
writer.Indent--;
Expand Down
14 changes: 10 additions & 4 deletions Funcky.DiscriminatedUnion.SourceGeneration/Parser.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,16 @@ public static bool IsSyntaxTargetForGeneration(SyntaxNode node)
return null;
}

var (nonExhaustive, flatten, matchResultType) = ParseAttribute(typeSymbol);
var (nonExhaustive, flatten, matchResultType, generatePartitionExtension) = ParseAttribute(typeSymbol);
var isVariant = flatten ? IsVariantOfDiscriminatedUnionFlattened(typeSymbol, semanticModel) : IsVariantOfDiscriminatedUnion(typeSymbol, semanticModel);

return new DiscriminatedUnion(
Type: typeDeclaration,
ParentTypes: typeDeclaration.Ancestors().OfType<TypeDeclarationSyntax>().ToList(),
Namespace: FormatNamespace(typeSymbol),
MatchResultTypeName: matchResultType ?? "TResult",
MethodVisibility: nonExhaustive ? "internal" : "public",
GeneratedMethodOrClassVisibility: nonExhaustive ? "internal" : "public",
GeneratePartitionExtension: generatePartitionExtension,
Variants: GetVariantTypeDeclarations(typeDeclaration, isVariant)
.Select(GetDiscriminatedUnionVariant(typeDeclaration, semanticModel, GenerateJsonDerivedTypeAttribute(typeSymbol)))
.ToList());
Expand All @@ -50,7 +51,8 @@ private static DiscriminatedUnionAttributeData ParseAttribute(ITypeSymbol type)
var nonExhaustive = attribute.GetNamedArgumentOrDefault<bool>(AttributeProperties.NonExhaustive);
var flatten = attribute.GetNamedArgumentOrDefault<bool>(AttributeProperties.Flatten);
var matchResultType = attribute.GetNamedArgumentOrDefault<string>(AttributeProperties.MatchResultTypeName);
return new DiscriminatedUnionAttributeData(nonExhaustive, flatten, matchResultType);
var generatePartitionExtension = attribute.GetNamedArgumentOrDefault<bool>(AttributeProperties.GeneratePartitionExtension);
return new DiscriminatedUnionAttributeData(nonExhaustive, flatten, matchResultType, generatePartitionExtension);
}

private static string? FormatNamespace(INamedTypeSymbol typeSymbol)
Expand Down Expand Up @@ -119,7 +121,11 @@ private static Func<AttributeSyntax, bool> IsDiscriminatedUnionAttribute(Generat
=> context.SemanticModel.GetSymbolInfo(attribute, cancellationToken).Symbol is IMethodSymbol attributeSymbol
&& attributeSymbol.ContainingType.ToDisplayString() == AttributeFullName;

private sealed record DiscriminatedUnionAttributeData(bool NonExhaustive, bool Flatten, string? MatchResultType);
private sealed record DiscriminatedUnionAttributeData(
bool NonExhaustive,
bool Flatten,
string? MatchResultType,
bool GeneratePartitionExtension);

private sealed class VariantCollectingVisitor : CSharpSyntaxWalker
{
Expand Down
Loading