Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add option to generate specialized partition implementation #15

Merged
merged 19 commits into from
Oct 1, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ internal sealed record DiscriminatedUnion(
string? Namespace,
string MethodVisibility,
string MatchResultTypeName,
IReadOnlyList<DiscriminatedUnionVariant> Variants);
IReadOnlyList<DiscriminatedUnionVariant> Variants,
bool GeneratePartitionExtension);

internal sealed record DiscriminatedUnionVariant(
TypeDeclarationSyntax Type,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System.Collections.Immutable;
using System.Diagnostics.CodeAnalysis;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using static Funcky.DiscriminatedUnion.SourceGeneration.Emitter;
Expand Down Expand Up @@ -32,8 +33,12 @@ private static void AddSource(SourceProductionContext context, ImmutableArray<st
{
if (code.Any())
{
var combinedCode = $"{GeneratedFileHeadersSource}{Environment.NewLine}{Environment.NewLine}" +
$"{string.Join(Environment.NewLine, code)}";
const string newline = """
bash marked this conversation as resolved.
Show resolved Hide resolved


""";

var combinedCode = $"{GeneratedFileHeadersSource}{newline}{newline}{string.Join(newline, code)}";
context.AddSource("DiscriminatedUnionGenerator.g.cs", combinedCode);
}
}
Expand Down
79 changes: 67 additions & 12 deletions Funcky.DiscriminatedUnion.SourceGeneration/Emitter.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using System.CodeDom.Compiler;
Expand All @@ -18,25 +17,81 @@ public static string EmitDiscriminatedUnion(DiscriminatedUnion discriminatedUnio
{
WriteNamespace(writer, discriminatedUnion);

WriteParentTypes(writer, discriminatedUnion.ParentTypes);
WriteUnionType(discriminatedUnion, writer);

WriteJsonDerivedTypeAttributes(writer, discriminatedUnion);
writer.WriteLine(FormatPartialTypeDeclaration(discriminatedUnion.Type));
writer.OpenScope();

WriteGeneratedMethod(writer, $"{discriminatedUnion.MethodVisibility} abstract {FormatMatchMethodDeclaration(discriminatedUnion.MatchResultTypeName, discriminatedUnion.Variants)};");
writer.WriteLine();
WriteGeneratedMethod(writer, $"{discriminatedUnion.MethodVisibility} abstract {FormatSwitchMethodDeclaration(discriminatedUnion.Variants)};");

foreach (var variant in discriminatedUnion.Variants)
if (discriminatedUnion.GeneratePartitionExtension)
{
WriteVariant(writer, discriminatedUnion, variant);
writer.WriteLine();
WritePartitionExtension(discriminatedUnion, writer);
}
}

return stringBuilder.ToString();
}

private static void WriteUnionType(DiscriminatedUnion discriminatedUnion, IndentedTextWriter writer)
{
using var scope = writer.AutoCloseScopes();

WriteParentTypes(writer, discriminatedUnion.ParentTypes);

WriteJsonDerivedTypeAttributes(writer, discriminatedUnion);
writer.WriteLine(FormatPartialTypeDeclaration(discriminatedUnion.Type));
writer.OpenScope();

WriteGeneratedMethod(writer, $"{discriminatedUnion.MethodVisibility} abstract {FormatMatchMethodDeclaration(discriminatedUnion.MatchResultTypeName, discriminatedUnion.Variants)};");
writer.WriteLine();
WriteGeneratedMethod(writer, $"{discriminatedUnion.MethodVisibility} abstract {FormatSwitchMethodDeclaration(discriminatedUnion.Variants)};");

foreach (var variant in discriminatedUnion.Variants)
{
WriteVariant(writer, discriminatedUnion, variant);
}
}

private static void WritePartitionExtension(DiscriminatedUnion discriminatedUnion, IndentedTextWriter writer)
{
using var scope = writer.AutoCloseScopes();

writer.WriteLine($"{discriminatedUnion.MethodVisibility} static class {discriminatedUnion.Type.Identifier}EnumerableExtensions");
writer.OpenScope();

writer.Write("public record struct Partitions(");
var partitionVariants = discriminatedUnion
.Variants
.Select(v => $"System.Collections.Generic.IReadOnlyList<{discriminatedUnion.Type.Identifier}.{v.LocalTypeName}> {v.ParameterName}");
writer.Write(string.Join(", ", partitionVariants));
writer.WriteLine(");");

writer.WriteLine();

writer.WriteLine($"public static Partitions Partition(this System.Collections.Generic.IEnumerable<{discriminatedUnion.Type.Identifier}> source)");
writer.OpenScope();

foreach (var variant in discriminatedUnion.Variants)
{
writer.WriteLine($"var {variant.ParameterName}Items = System.Collections.Immutable.ImmutableList.CreateBuilder<{discriminatedUnion.Type.Identifier}.{variant.LocalTypeName}>();");
}
bash marked this conversation as resolved.
Show resolved Hide resolved

using (writer.AutoCloseScopes())
{
writer.WriteLine("foreach (var item in source)");
writer.OpenScope();
writer.Write("item.Switch(");

var assignmentVariants = discriminatedUnion
.Variants
.Select(v => $"{v.ParameterName}: {v.ParameterName}Items.Add");

writer.Write(string.Join(", ", assignmentVariants));

writer.WriteLine(");");
}

var items = discriminatedUnion.Variants.Select(v => $"{v.ParameterName}Items.ToImmutable()");
writer.WriteLine($"return new({string.Join(", ", items)});");
}

private static void WriteNamespace(IndentedTextWriter writer, DiscriminatedUnion discriminatedUnion)
{
if (!string.IsNullOrEmpty(discriminatedUnion.Namespace))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
<ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable>
<LangVersion>11.0</LangVersion>
<EnforceExtendedAnalyzerRules>true</EnforceExtendedAnalyzerRules>
</PropertyGroup>
<PropertyGroup Label="NuGet Metadata">
<Version>1.1.0</Version>
<Version>1.2.0</Version>
<PackageId>Funcky.DiscriminatedUnion</PackageId>
<Authors>Polyadic</Authors>
<PackageLicenseExpression>MIT OR Apache-2.0</PackageLicenseExpression>
Expand Down Expand Up @@ -38,8 +39,8 @@
<PackageReference Include="Microsoft.SourceLink.GitHub" Version="1.0.0" PrivateAssets="All" />
</ItemGroup>
<ItemGroup>
<PackageReference Include="Microsoft.CodeAnalysis.Analyzers" Version="3.3.3" PrivateAssets="all" />
<PackageReference Include="Microsoft.CodeAnalysis.CSharp" Version="4.0.1" />
<PackageReference Include="Microsoft.CodeAnalysis.Analyzers" Version="3.3.4" PrivateAssets="all" />
<PackageReference Include="Microsoft.CodeAnalysis.CSharp" Version="4.11.0" />
bash marked this conversation as resolved.
Show resolved Hide resolved
<PackageReference Include="IsExternalInit" Version="1.0.1" PrivateAssets="all" />
</ItemGroup>
</Project>
12 changes: 9 additions & 3 deletions Funcky.DiscriminatedUnion.SourceGeneration/Parser.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ public static bool IsSyntaxTargetForGeneration(SyntaxNode node)
return null;
}

var (nonExhaustive, flatten, matchResultType) = ParseAttribute(typeSymbol);
var (nonExhaustive, flatten, matchResultType, generatePartitionExtension) = ParseAttribute(typeSymbol);
var isVariant = flatten ? IsVariantOfDiscriminatedUnionFlattened(typeSymbol, semanticModel) : IsVariantOfDiscriminatedUnion(typeSymbol, semanticModel);

return new DiscriminatedUnion(
Expand All @@ -39,6 +39,7 @@ public static bool IsSyntaxTargetForGeneration(SyntaxNode node)
Namespace: FormatNamespace(typeSymbol),
MatchResultTypeName: matchResultType ?? "TResult",
MethodVisibility: nonExhaustive ? "internal" : "public",
GeneratePartitionExtension: generatePartitionExtension,
Variants: GetVariantTypeDeclarations(typeDeclaration, isVariant)
.Select(GetDiscriminatedUnionVariant(typeDeclaration, semanticModel, GenerateJsonDerivedTypeAttribute(typeSymbol)))
.ToList());
Expand All @@ -50,7 +51,8 @@ private static DiscriminatedUnionAttributeData ParseAttribute(ITypeSymbol type)
var nonExhaustive = attribute.GetNamedArgumentOrDefault<bool>(AttributeProperties.NonExhaustive);
var flatten = attribute.GetNamedArgumentOrDefault<bool>(AttributeProperties.Flatten);
var matchResultType = attribute.GetNamedArgumentOrDefault<string>(AttributeProperties.MatchResultTypeName);
return new DiscriminatedUnionAttributeData(nonExhaustive, flatten, matchResultType);
var generatePartitionExtension = attribute.GetNamedArgumentOrDefault<bool>(AttributeProperties.GeneratePartitionExtension);
return new DiscriminatedUnionAttributeData(nonExhaustive, flatten, matchResultType, generatePartitionExtension);
}

private static string? FormatNamespace(INamedTypeSymbol typeSymbol)
Expand Down Expand Up @@ -119,7 +121,11 @@ private static Func<AttributeSyntax, bool> IsDiscriminatedUnionAttribute(Generat
=> context.SemanticModel.GetSymbolInfo(attribute, cancellationToken).Symbol is IMethodSymbol attributeSymbol
&& attributeSymbol.ContainingType.ToDisplayString() == AttributeFullName;

private sealed record DiscriminatedUnionAttributeData(bool NonExhaustive, bool Flatten, string? MatchResultType);
private sealed record DiscriminatedUnionAttributeData(
bool NonExhaustive,
bool Flatten,
string? MatchResultType,
bool GeneratePartitionExtension);

private sealed class VariantCollectingVisitor : CSharpSyntaxWalker
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ internal sealed class DiscriminatedUnionAttribute : global::System.Attribute
/// <summary>Generates exhaustive <c>Match</c> and <c>Switch</c> methods for the entire type hierarchy.</summary>
public bool {{AttributeProperties.Flatten}} { get; set; }

/// <summary>If a specialized partition extension method for <c>IEnumerable<YourType></c> should be generated. Defaults to false.</summary>
public bool {{AttributeProperties.GeneratePartitionExtension}} { get; set; }

/// <summary>Customized the generic type name used for the result in the generated <c>Match</c> methods. Defaults to <c>TResult</c>.</summary>
public string? {{AttributeProperties.MatchResultTypeName}} { get; set; }
}
Expand All @@ -46,6 +49,7 @@ public static class AttributeProperties
{
public const string NonExhaustive = "NonExhaustive";
public const string Flatten = "Flatten";
public const string GeneratePartitionExtension = "GeneratePartitionExtension";
public const string MatchResultTypeName = "MatchResultTypeName";
}
}
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<TargetFramework>net7.0</TargetFramework>
<TargetFramework>net8.0</TargetFramework>
<ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable>
<LangVersion>11.0</LangVersion>
<LangVersion>12.0</LangVersion>
<IsPackable>false</IsPackable>
</PropertyGroup>
<ItemGroup>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ internal sealed class DiscriminatedUnionAttribute : global::System.Attribute
/// <summary>Generates exhaustive <c>Match</c> and <c>Switch</c> methods for the entire type hierarchy.</summary>
public bool Flatten { get; set; }

/// <summary>If a specialized partition extension method for <c>IEnumerable<YourType></c> should be generated. Defaults to false.</summary>
bash marked this conversation as resolved.
Show resolved Hide resolved
public bool GeneratePartitionExtension { get; set; }

/// <summary>Customized the generic type name used for the result in the generated <c>Match</c> methods. Defaults to <c>TResult</c>.</summary>
public string? MatchResultTypeName { get; set; }
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,18 @@ partial class StaticClass
{
partial record NestedUnion
{
[global::System.CodeDom.Compiler.GeneratedCode("Funcky.DiscriminatedUnion.SourceGeneration", "1.1.0.0")]
[global::System.CodeDom.Compiler.GeneratedCode("Funcky.DiscriminatedUnion.SourceGeneration", "1.2.0.0")]
bash marked this conversation as resolved.
Show resolved Hide resolved
public abstract TResult Match<TResult>(global::System.Func<Variant, TResult> variant);

[global::System.CodeDom.Compiler.GeneratedCode("Funcky.DiscriminatedUnion.SourceGeneration", "1.1.0.0")]
[global::System.CodeDom.Compiler.GeneratedCode("Funcky.DiscriminatedUnion.SourceGeneration", "1.2.0.0")]
public abstract void Switch(global::System.Action<Variant> variant);

partial record Variant
{
[global::System.CodeDom.Compiler.GeneratedCode("Funcky.DiscriminatedUnion.SourceGeneration", "1.1.0.0")]
[global::System.CodeDom.Compiler.GeneratedCode("Funcky.DiscriminatedUnion.SourceGeneration", "1.2.0.0")]
public override TResult Match<TResult>(global::System.Func<Variant, TResult> variant) => variant(this);

[global::System.CodeDom.Compiler.GeneratedCode("Funcky.DiscriminatedUnion.SourceGeneration", "1.1.0.0")]
[global::System.CodeDom.Compiler.GeneratedCode("Funcky.DiscriminatedUnion.SourceGeneration", "1.2.0.0")]
public override void Switch(global::System.Action<Variant> variant) => variant(this);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ internal sealed class DiscriminatedUnionAttribute : global::System.Attribute
/// <summary>Generates exhaustive <c>Match</c> and <c>Switch</c> methods for the entire type hierarchy.</summary>
public bool Flatten { get; set; }

/// <summary>If a specialized partition extension method for <c>IEnumerable<YourType></c> should be generated. Defaults to false.</summary>
public bool GeneratePartitionExtension { get; set; }

/// <summary>Customized the generic type name used for the result in the generated <c>Match</c> methods. Defaults to <c>TResult</c>.</summary>
public string? MatchResultTypeName { get; set; }
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@

partial record EmptyUnion
{
[global::System.CodeDom.Compiler.GeneratedCode("Funcky.DiscriminatedUnion.SourceGeneration", "1.1.0.0")]
[global::System.CodeDom.Compiler.GeneratedCode("Funcky.DiscriminatedUnion.SourceGeneration", "1.2.0.0")]
public abstract TResult Match<TResult>();

[global::System.CodeDom.Compiler.GeneratedCode("Funcky.DiscriminatedUnion.SourceGeneration", "1.1.0.0")]
[global::System.CodeDom.Compiler.GeneratedCode("Funcky.DiscriminatedUnion.SourceGeneration", "1.2.0.0")]
public abstract void Switch();
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ internal sealed class DiscriminatedUnionAttribute : global::System.Attribute
/// <summary>Generates exhaustive <c>Match</c> and <c>Switch</c> methods for the entire type hierarchy.</summary>
public bool Flatten { get; set; }

/// <summary>If a specialized partition extension method for <c>IEnumerable<YourType></c> should be generated. Defaults to false.</summary>
public bool GeneratePartitionExtension { get; set; }

/// <summary>Customized the generic type name used for the result in the generated <c>Match</c> methods. Defaults to <c>TResult</c>.</summary>
public string? MatchResultTypeName { get; set; }
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,27 +6,27 @@ namespace Funcky.DiscriminatedUnion.Test
{
partial record Result<T> where T : notnull
{
[global::System.CodeDom.Compiler.GeneratedCode("Funcky.DiscriminatedUnion.SourceGeneration", "1.1.0.0")]
[global::System.CodeDom.Compiler.GeneratedCode("Funcky.DiscriminatedUnion.SourceGeneration", "1.2.0.0")]
internal abstract TResult Match<TResult>(global::System.Func<Ok, TResult> ok, global::System.Func<Error, TResult> error);

[global::System.CodeDom.Compiler.GeneratedCode("Funcky.DiscriminatedUnion.SourceGeneration", "1.1.0.0")]
[global::System.CodeDom.Compiler.GeneratedCode("Funcky.DiscriminatedUnion.SourceGeneration", "1.2.0.0")]
internal abstract void Switch(global::System.Action<Ok> ok, global::System.Action<Error> error);

partial record Ok
{
[global::System.CodeDom.Compiler.GeneratedCode("Funcky.DiscriminatedUnion.SourceGeneration", "1.1.0.0")]
[global::System.CodeDom.Compiler.GeneratedCode("Funcky.DiscriminatedUnion.SourceGeneration", "1.2.0.0")]
internal override TResult Match<TResult>(global::System.Func<Ok, TResult> ok, global::System.Func<Error, TResult> error) => ok(this);

[global::System.CodeDom.Compiler.GeneratedCode("Funcky.DiscriminatedUnion.SourceGeneration", "1.1.0.0")]
[global::System.CodeDom.Compiler.GeneratedCode("Funcky.DiscriminatedUnion.SourceGeneration", "1.2.0.0")]
internal override void Switch(global::System.Action<Ok> ok, global::System.Action<Error> error) => ok(this);
}

partial record Error
{
[global::System.CodeDom.Compiler.GeneratedCode("Funcky.DiscriminatedUnion.SourceGeneration", "1.1.0.0")]
[global::System.CodeDom.Compiler.GeneratedCode("Funcky.DiscriminatedUnion.SourceGeneration", "1.2.0.0")]
internal override TResult Match<TResult>(global::System.Func<Ok, TResult> ok, global::System.Func<Error, TResult> error) => error(this);

[global::System.CodeDom.Compiler.GeneratedCode("Funcky.DiscriminatedUnion.SourceGeneration", "1.1.0.0")]
[global::System.CodeDom.Compiler.GeneratedCode("Funcky.DiscriminatedUnion.SourceGeneration", "1.2.0.0")]
internal override void Switch(global::System.Action<Ok> ok, global::System.Action<Error> error) => error(this);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ internal sealed class DiscriminatedUnionAttribute : global::System.Attribute
/// <summary>Generates exhaustive <c>Match</c> and <c>Switch</c> methods for the entire type hierarchy.</summary>
public bool Flatten { get; set; }

/// <summary>If a specialized partition extension method for <c>IEnumerable<YourType></c> should be generated. Defaults to false.</summary>
public bool GeneratePartitionExtension { get; set; }

/// <summary>Customized the generic type name used for the result in the generated <c>Match</c> methods. Defaults to <c>TResult</c>.</summary>
public string? MatchResultTypeName { get; set; }
}
Expand Down
Loading
Loading