Skip to content

Commit

Permalink
Merge pull request #15 from Mafii/partition-specialization
Browse files Browse the repository at this point in the history
Add option to generate specialized partition implementation
  • Loading branch information
bash authored Oct 1, 2024
2 parents 12b8047 + 93c1c07 commit 8d8d375
Show file tree
Hide file tree
Showing 46 changed files with 710 additions and 118 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
using System.Runtime.CompilerServices;

namespace Funcky.DiscriminatedUnion.SourceGeneration;

[InterpolatedStringHandler]
public readonly struct CollectingInterpolatedStringHandler
{
public CollectingInterpolatedStringHandler(int literalLength, int formattedCount)
{
_items = new List<object?>(formattedCount * 2);
}

public CollectingInterpolatedStringHandler()
{
_items = new List<object?>();
}

private readonly List<object?> _items;

public IEnumerable<object?> GetItems() => _items;

public void AppendLiteral(string s) => _items.Add(s);

public void AppendFormatted<T>(T t) => _items.Add(t);

public void AppendFormatted(CollectingInterpolatedStringHandler handler)
=> _items.AddRange(handler._items);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
namespace Funcky.DiscriminatedUnion.SourceGeneration;

public static class CollectingInterpolatedStringHandlerExtensions
{
public static CollectingInterpolatedStringHandler JoinToInterpolation<T>(
this IEnumerable<T> source,
string separator)
{
using var enumerator = source.GetEnumerator();

if (!enumerator.MoveNext())
{
return new CollectingInterpolatedStringHandler();
}

var result = new CollectingInterpolatedStringHandler();

result.AppendFormatted(enumerator.Current);

while (enumerator.MoveNext())
{
result.AppendLiteral(separator);
result.AppendFormatted(enumerator.Current);
}

return result;
}

public static CollectingInterpolatedStringHandler JoinToInterpolation<T>(
this IEnumerable<T> source,
Func<T, CollectingInterpolatedStringHandler> createPart,
string separator)
{
using var enumerator = source.GetEnumerator();

if (!enumerator.MoveNext())
{
return new CollectingInterpolatedStringHandler();
}

var result = new CollectingInterpolatedStringHandler();

result.AppendFormatted(createPart(enumerator.Current));

while (enumerator.MoveNext())
{
result.AppendLiteral(separator);
result.AppendFormatted(createPart(enumerator.Current));
}

return result;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@ internal sealed record DiscriminatedUnion(
TypeDeclarationSyntax Type,
IReadOnlyList<TypeDeclarationSyntax> ParentTypes,
string? Namespace,
string MethodVisibility,
string GeneratedMethodOrClassVisibility,
string MatchResultTypeName,
IReadOnlyList<DiscriminatedUnionVariant> Variants);
IReadOnlyList<DiscriminatedUnionVariant> Variants,
bool GeneratePartitionExtension);

internal sealed record DiscriminatedUnionVariant(
TypeDeclarationSyntax Type,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System.Collections.Immutable;
using System.Diagnostics.CodeAnalysis;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using static Funcky.DiscriminatedUnion.SourceGeneration.Emitter;
Expand Down Expand Up @@ -32,8 +33,7 @@ private static void AddSource(SourceProductionContext context, ImmutableArray<st
{
if (code.Any())
{
var combinedCode = $"{GeneratedFileHeadersSource}{Environment.NewLine}{Environment.NewLine}" +
$"{string.Join(Environment.NewLine, code)}";
var combinedCode = $"{GeneratedFileHeadersSource}\n\n{string.Join("\n", code)}";
context.AddSource("DiscriminatedUnionGenerator.g.cs", combinedCode);
}
}
Expand Down
135 changes: 106 additions & 29 deletions Funcky.DiscriminatedUnion.SourceGeneration/Emitter.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using System.CodeDom.Compiler;
Expand All @@ -18,30 +17,108 @@ public static string EmitDiscriminatedUnion(DiscriminatedUnion discriminatedUnio
{
WriteNamespace(writer, discriminatedUnion);

WriteParentTypes(writer, discriminatedUnion.ParentTypes);
WriteUnionType(discriminatedUnion, writer);

WriteJsonDerivedTypeAttributes(writer, discriminatedUnion);
writer.WriteLine(FormatPartialTypeDeclaration(discriminatedUnion.Type));
writer.OpenScope();

WriteGeneratedMethod(writer, $"{discriminatedUnion.MethodVisibility} abstract {FormatMatchMethodDeclaration(discriminatedUnion.MatchResultTypeName, discriminatedUnion.Variants)};");
writer.WriteLine();
WriteGeneratedMethod(writer, $"{discriminatedUnion.MethodVisibility} abstract {FormatSwitchMethodDeclaration(discriminatedUnion.Variants)};");

foreach (var variant in discriminatedUnion.Variants)
if (discriminatedUnion.GeneratePartitionExtension)
{
WriteVariant(writer, discriminatedUnion, variant);
writer.WriteLine();
WritePartitionExtensions(discriminatedUnion, writer);
}
}

return stringBuilder.ToString();
}

private static void WriteUnionType(DiscriminatedUnion discriminatedUnion, IndentedTextWriter writer)
{
using var scope = writer.AutoCloseScopes();

WriteParentTypes(writer, discriminatedUnion.ParentTypes);

WriteJsonDerivedTypeAttributes(writer, discriminatedUnion);
writer.WriteLineInterpolated(FormatPartialTypeDeclaration(discriminatedUnion.Type));
writer.OpenScope();

WriteGeneratedMethod(writer, $"{discriminatedUnion.GeneratedMethodOrClassVisibility} abstract {FormatMatchMethodDeclaration(discriminatedUnion.MatchResultTypeName, discriminatedUnion.Variants)};");
writer.WriteLine();
WriteGeneratedMethod(writer, $"{discriminatedUnion.GeneratedMethodOrClassVisibility} abstract {FormatSwitchMethodDeclaration(discriminatedUnion.Variants)};");

foreach (var variant in discriminatedUnion.Variants)
{
WriteVariant(writer, discriminatedUnion, variant);
}
}

private static void WritePartitionExtensions(DiscriminatedUnion discriminatedUnion, IndentedTextWriter writer)
{
using var scope = writer.AutoCloseScopes();

writer.WriteLine(GeneratedCodeAttributeSource);
writer.WriteLineInterpolated($"{discriminatedUnion.GeneratedMethodOrClassVisibility} static partial class {discriminatedUnion.Type.Identifier}EnumerableExtensions");
writer.OpenScope();

WriteTupleReturningPartitionExtension(discriminatedUnion, writer);
writer.WriteLine();
WritePartitionWithResultSelector(discriminatedUnion, writer);
}

private static void WriteTupleReturningPartitionExtension(DiscriminatedUnion discriminatedUnion, IndentedTextWriter writer)
{
using var methodScope = writer.AutoCloseScopes();

var namedResultPartitions = discriminatedUnion
.Variants
.JoinToInterpolation(
v => $"global::System.Collections.Generic.IReadOnlyList<{discriminatedUnion.Type.Identifier}.{v.LocalTypeName}> {v.ParameterName[..1].ToUpper()}{v.ParameterName[1..]}",
", ");

writer.WriteLineInterpolated($"public static ({namedResultPartitions}) Partition(this global::System.Collections.Generic.IEnumerable<{discriminatedUnion.Type.Identifier}> source)");
writer.OpenScope();

WritePartitioningIntoLists(discriminatedUnion, writer);

writer.WriteLineInterpolated($"return ({discriminatedUnion.Variants.JoinToInterpolation(v => $"{v.ParameterName}Items.AsReadOnly()", ", ")});");
}

private static void WritePartitionWithResultSelector(DiscriminatedUnion discriminatedUnion, IndentedTextWriter writer)
{
using var methodScope = writer.AutoCloseScopes();

writer.WriteInterpolated($"public static TResult Partition<{discriminatedUnion.MatchResultTypeName}>(this global::System.Collections.Generic.IEnumerable<{discriminatedUnion.Type.Identifier}> source, global::System.Func<");

foreach (var variant in discriminatedUnion.Variants)
{
writer.WriteInterpolated($"global::System.Collections.Generic.IReadOnlyList<{discriminatedUnion.Type.Identifier}.{variant.LocalTypeName}>, ");
}

writer.WriteLineInterpolated($"{discriminatedUnion.MatchResultTypeName}> resultSelector)");
writer.OpenScope();

WritePartitioningIntoLists(discriminatedUnion, writer);

writer.WriteLineInterpolated($"return resultSelector({discriminatedUnion.Variants.JoinToInterpolation(v => $"{v.ParameterName}Items.AsReadOnly()", ", ")});");
}

private static void WritePartitioningIntoLists(DiscriminatedUnion discriminatedUnion, IndentedTextWriter writer)
{
foreach (var variant in discriminatedUnion.Variants)
{
writer.WriteLineInterpolated($"var {variant.ParameterName}Items = new global::System.Collections.Generic.List<{discriminatedUnion.Type.Identifier}.{variant.LocalTypeName}>();");
}

using (writer.AutoCloseScopes())
{
writer.WriteLine("foreach (var item in source)");
writer.OpenScope();
writer.WriteLineInterpolated($"item.Switch({discriminatedUnion.Variants.JoinToInterpolation(v => $"{v.ParameterName}: {v.ParameterName}Items.Add", ", ")});");
}
}

private static void WriteNamespace(IndentedTextWriter writer, DiscriminatedUnion discriminatedUnion)
{
if (!string.IsNullOrEmpty(discriminatedUnion.Namespace))
{
writer.WriteLine($"namespace {discriminatedUnion.Namespace}");
writer.WriteLineInterpolated($"namespace {discriminatedUnion.Namespace}");
writer.OpenScope();
}
}
Expand All @@ -50,7 +127,7 @@ private static void WriteParentTypes(IndentedTextWriter writer, IEnumerable<Type
{
foreach (var parent in parents.Reverse())
{
writer.WriteLine(FormatPartialTypeDeclaration(parent));
writer.WriteLineInterpolated(FormatPartialTypeDeclaration(parent));
writer.OpenScope();
}
}
Expand All @@ -63,19 +140,19 @@ private static void WriteVariant(IndentedTextWriter writer, DiscriminatedUnion d

WriteParentTypes(writer, variant.ParentTypes);

writer.WriteLine(FormatPartialTypeDeclaration(variant.Type));
writer.WriteLineInterpolated(FormatPartialTypeDeclaration(variant.Type));
writer.OpenScope();

WriteGeneratedMethod(writer, $"{discriminatedUnion.MethodVisibility} override {FormatMatchMethodDeclaration(discriminatedUnion.MatchResultTypeName, discriminatedUnion.Variants)} => {FormatIdentifier(variant.ParameterName)}(this);");
WriteGeneratedMethod(writer, $"{discriminatedUnion.GeneratedMethodOrClassVisibility} override {FormatMatchMethodDeclaration(discriminatedUnion.MatchResultTypeName, discriminatedUnion.Variants)} => {FormatIdentifier(variant.ParameterName)}(this);");
writer.WriteLine();
WriteGeneratedMethod(writer, $"{discriminatedUnion.MethodVisibility} override {FormatSwitchMethodDeclaration(discriminatedUnion.Variants)} => {FormatIdentifier(variant.ParameterName)}(this);");
WriteGeneratedMethod(writer, $"{discriminatedUnion.GeneratedMethodOrClassVisibility} override {FormatSwitchMethodDeclaration(discriminatedUnion.Variants)} => {FormatIdentifier(variant.ParameterName)}(this);");
}
}

private static void WriteGeneratedMethod(IndentedTextWriter writer, string method)
private static void WriteGeneratedMethod(IndentedTextWriter writer, CollectingInterpolatedStringHandler method)
{
writer.WriteLine(GeneratedCodeAttributeSource);
writer.WriteLine(method);
writer.WriteLineInterpolated(method);
}

private static void WriteJsonDerivedTypeAttributes(IndentedTextWriter writer, DiscriminatedUnion discriminatedUnion)
Expand All @@ -90,26 +167,26 @@ private static void WriteJsonDerivedTypeAttribute(IndentedTextWriter writer, Dis
{
if (variant.GenerateJsonDerivedTypeAttribute)
{
writer.WriteLine($"[global::System.Text.Json.Serialization.JsonDerivedType(typeof({variant.TypeOfTypeName}), {SyntaxFactory.Literal(variant.JsonDerivedTypeDiscriminator)})]");
writer.WriteLineInterpolated($"[global::System.Text.Json.Serialization.JsonDerivedType(typeof({variant.TypeOfTypeName}), {SyntaxFactory.Literal(variant.JsonDerivedTypeDiscriminator)})]");
}
}

private static string FormatMatchMethodDeclaration(string genericTypeName, IEnumerable<DiscriminatedUnionVariant> variants)
=> $"{genericTypeName} Match<{genericTypeName}>({string.Join(", ", variants.Select(variant => $"global::System.Func<{variant.LocalTypeName}, {genericTypeName}> {FormatIdentifier(variant.ParameterName)}"))})";
private static CollectingInterpolatedStringHandler FormatMatchMethodDeclaration(string genericTypeName, IEnumerable<DiscriminatedUnionVariant> variants)
=> $"{genericTypeName} Match<{genericTypeName}>({variants.JoinToInterpolation(v => $"global::System.Func<{v.LocalTypeName}, {genericTypeName}> {FormatIdentifier(v.ParameterName)}", ", ")})";

private static string FormatSwitchMethodDeclaration(IEnumerable<DiscriminatedUnionVariant> variants)
=> $"void Switch({string.Join(", ", variants.Select(variant => $"global::System.Action<{variant.LocalTypeName}> {FormatIdentifier(variant.ParameterName)}"))})";
private static CollectingInterpolatedStringHandler FormatSwitchMethodDeclaration(IEnumerable<DiscriminatedUnionVariant> variants)
=> $"void Switch({variants.JoinToInterpolation(v => $"global::System.Action<{v.LocalTypeName}> {FormatIdentifier(v.ParameterName)}", ", ")})";

private static string FormatPartialTypeDeclaration(TypeDeclarationSyntax typeDeclaration)
private static CollectingInterpolatedStringHandler FormatPartialTypeDeclaration(TypeDeclarationSyntax typeDeclaration)
=> typeDeclaration is RecordDeclarationSyntax recordDeclaration
? CombineTokens("partial", typeDeclaration.Keyword, recordDeclaration.ClassOrStructKeyword, typeDeclaration.Identifier.ToString() + typeDeclaration.TypeParameterList, typeDeclaration.ConstraintClauses)
: CombineTokens("partial", typeDeclaration.Keyword, typeDeclaration.Identifier.ToString() + typeDeclaration.TypeParameterList, typeDeclaration.ConstraintClauses);

private static string CombineTokens(params object[] tokens)
=> string.Join(" ", tokens.Select(t => t.ToString()).Where(t => !string.IsNullOrEmpty(t)));
private static CollectingInterpolatedStringHandler CombineTokens(params object[] tokens)
=> tokens.Select(t => t.ToString()).Where(t => !string.IsNullOrEmpty(t)).JoinToInterpolation(" ");

private static string FormatIdentifier(string identifier)
=> IsIdentifier(identifier) ? '@' + identifier : identifier;
private static CollectingInterpolatedStringHandler FormatIdentifier(string identifier)
=> $"{(IsIdentifier(identifier) ? "@" : string.Empty)}{identifier}";

private static bool IsIdentifier(string identifier)
=> SyntaxFacts.GetKeywordKind(identifier) != SyntaxKind.None;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
<ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable>
<LangVersion>11.0</LangVersion>
<EnforceExtendedAnalyzerRules>true</EnforceExtendedAnalyzerRules>
</PropertyGroup>
<PropertyGroup Label="NuGet Metadata">
<Version>1.1.0</Version>
<Version>1.2.0</Version>
<PackageId>Funcky.DiscriminatedUnion</PackageId>
<Authors>Polyadic</Authors>
<PackageLicenseExpression>MIT OR Apache-2.0</PackageLicenseExpression>
Expand Down Expand Up @@ -40,6 +41,6 @@
<ItemGroup>
<PackageReference Include="Microsoft.CodeAnalysis.Analyzers" Version="3.3.3" PrivateAssets="all" />
<PackageReference Include="Microsoft.CodeAnalysis.CSharp" Version="4.0.1" />
<PackageReference Include="IsExternalInit" Version="1.0.1" PrivateAssets="all" />
<PackageReference Include="PolySharp" Version="1.14.1" PrivateAssets="all" />
</ItemGroup>
</Project>
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,20 @@ public static void OpenScope(this IndentedTextWriter writer)
writer.Indent++;
}

public static void WriteInterpolated(this IndentedTextWriter writer, CollectingInterpolatedStringHandler value)
{
foreach (var item in value.GetItems())
{
writer.Write(item?.ToString());
}
}

public static void WriteLineInterpolated(this IndentedTextWriter writer, CollectingInterpolatedStringHandler value)
{
writer.WriteInterpolated(value);
writer.WriteLine();
}

private static void CloseScope(this IndentedTextWriter writer)
{
writer.Indent--;
Expand Down
14 changes: 10 additions & 4 deletions Funcky.DiscriminatedUnion.SourceGeneration/Parser.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,16 @@ public static bool IsSyntaxTargetForGeneration(SyntaxNode node)
return null;
}

var (nonExhaustive, flatten, matchResultType) = ParseAttribute(typeSymbol);
var (nonExhaustive, flatten, matchResultType, generatePartitionExtension) = ParseAttribute(typeSymbol);
var isVariant = flatten ? IsVariantOfDiscriminatedUnionFlattened(typeSymbol, semanticModel) : IsVariantOfDiscriminatedUnion(typeSymbol, semanticModel);

return new DiscriminatedUnion(
Type: typeDeclaration,
ParentTypes: typeDeclaration.Ancestors().OfType<TypeDeclarationSyntax>().ToList(),
Namespace: FormatNamespace(typeSymbol),
MatchResultTypeName: matchResultType ?? "TResult",
MethodVisibility: nonExhaustive ? "internal" : "public",
GeneratedMethodOrClassVisibility: nonExhaustive ? "internal" : "public",
GeneratePartitionExtension: generatePartitionExtension,
Variants: GetVariantTypeDeclarations(typeDeclaration, isVariant)
.Select(GetDiscriminatedUnionVariant(typeDeclaration, semanticModel, GenerateJsonDerivedTypeAttribute(typeSymbol)))
.ToList());
Expand All @@ -50,7 +51,8 @@ private static DiscriminatedUnionAttributeData ParseAttribute(ITypeSymbol type)
var nonExhaustive = attribute.GetNamedArgumentOrDefault<bool>(AttributeProperties.NonExhaustive);
var flatten = attribute.GetNamedArgumentOrDefault<bool>(AttributeProperties.Flatten);
var matchResultType = attribute.GetNamedArgumentOrDefault<string>(AttributeProperties.MatchResultTypeName);
return new DiscriminatedUnionAttributeData(nonExhaustive, flatten, matchResultType);
var generatePartitionExtension = attribute.GetNamedArgumentOrDefault<bool>(AttributeProperties.GeneratePartitionExtension);
return new DiscriminatedUnionAttributeData(nonExhaustive, flatten, matchResultType, generatePartitionExtension);
}

private static string? FormatNamespace(INamedTypeSymbol typeSymbol)
Expand Down Expand Up @@ -119,7 +121,11 @@ private static Func<AttributeSyntax, bool> IsDiscriminatedUnionAttribute(Generat
=> context.SemanticModel.GetSymbolInfo(attribute, cancellationToken).Symbol is IMethodSymbol attributeSymbol
&& attributeSymbol.ContainingType.ToDisplayString() == AttributeFullName;

private sealed record DiscriminatedUnionAttributeData(bool NonExhaustive, bool Flatten, string? MatchResultType);
private sealed record DiscriminatedUnionAttributeData(
bool NonExhaustive,
bool Flatten,
string? MatchResultType,
bool GeneratePartitionExtension);

private sealed class VariantCollectingVisitor : CSharpSyntaxWalker
{
Expand Down
Loading

0 comments on commit 8d8d375

Please sign in to comment.