Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Also implement ISerializeWrap on enum wrappers #143

Merged
merged 1 commit into from
Oct 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 30 additions & 17 deletions src/generator/Generator.Impl.cs
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,13 @@ internal static void GenerateImpl(

ITypeSymbol receiverType;
ExpressionSyntax receiverExpr;
bool wrapper;
string? wrapperName;
string? wrappedName;
// If the Through property is set, then we are implementing a wrapper type
if (attributeData.NamedArguments is [ (nameof(GenerateSerialize.Through), { Value: string memberName }) ])
{
wrapper = true;
var members = model.LookupSymbols(typeDecl.SpanStart, typeSymbol, memberName);
if (members.Length != 1)
{
Expand All @@ -83,26 +87,24 @@ internal static void GenerateImpl(
}
receiverType = SymbolUtilities.GetSymbolType(members[0]);
receiverExpr = IdentifierName(memberName);

if (usage.HasFlag(SerdeUsage.Serialize))
{
// If we're implementing ISerialize, also implement ISerializeWrap
GenerateISerializeWrapImpl(
typeDecl.Identifier.ValueText,
receiverType.ToDisplayString(),
typeDecl,
context);
}
wrapperName = typeDecl.Identifier.ValueText;
wrappedName = receiverType.ToDisplayString();
}
// Enums are also always wrapped, but the attribute is on the enum itself
else if (typeDecl.IsKind(SyntaxKind.EnumDeclaration))
{
wrapper = true;
receiverType = typeSymbol;
receiverExpr = IdentifierName("Value");
wrappedName = typeDecl.Identifier.ValueText;
wrapperName = GetWrapperName(wrappedName);
}
// Just a normal interface implementation
else
{
wrapper = false;
wrapperName = null;
wrappedName = null;
if (!typeDecl.Modifiers.Any(tok => tok.IsKind(SyntaxKind.PartialKeyword)))
{
// Type must be partial
Expand All @@ -116,6 +118,16 @@ internal static void GenerateImpl(
receiverExpr = ThisExpression();
}

if (wrapper && usage.HasFlag(SerdeUsage.Serialize))
{
// If we're implementing ISerialize, also implement ISerializeWrap
GenerateISerializeWrapImpl(
wrapperName!,
wrappedName!,
typeDecl,
context);
}

GenerateImpl(
usage,
new TypeDeclContext(typeDecl),
Expand All @@ -141,7 +153,7 @@ private static void GenerateEnumWrapper(
var typeName = typeDeclContext.Name;
var wrapperName = GetWrapperName(typeName);
var newType = SyntaxFactory.ParseMemberDeclaration($"""
internal readonly partial record struct {wrapperName}({typeName} Value);
readonly partial record struct {wrapperName}({typeName} Value);
""")!;
newType = typeDeclContext.WrapNewType(newType);
string fullWrapperName = string.Join(".", typeDeclContext.NamespaceNames
Expand All @@ -166,9 +178,9 @@ private static void GenerateISerializeWrapImpl(
{
var typeDeclContext = new TypeDeclContext(typeDecl);
var newType = SyntaxFactory.ParseMemberDeclaration($$"""
partial record struct {{wrapperName}}({{wrappedName}} Value) : Serde.ISerializeWrap<{{wrappedName}}, {{wrapperName}}>
partial record struct {{wrapperName}} : Serde.ISerializeWrap<{{wrappedName}}, {{wrapperName}}>
{
{{wrapperName}} Serde.ISerializeWrap<{{wrappedName}}, {{wrapperName}}>.Wrap({{wrappedName}} value) => new(value);
static {{wrapperName}} Serde.ISerializeWrap<{{wrappedName}}, {{wrapperName}}>.Create({{wrappedName}} value) => new(value);
}
""")!;
newType = typeDeclContext.WrapNewType(newType);
Expand All @@ -195,10 +207,6 @@ private static void GenerateImpl(
ImmutableList<ITypeSymbol> inProgress)
{
var typeName = typeDeclContext.Name;
string fullTypeName = string.Join(".", typeDeclContext.NamespaceNames
.Concat(typeDeclContext.ParentTypeInfo.Select(x => x.Name))
.Concat(new[] { typeName }));


// Generate statements for the implementation
var (implMembers, baseList) = usage switch
Expand Down Expand Up @@ -228,6 +236,7 @@ private static void GenerateImpl(
members: List(implMembers),
closeBraceToken: Token(SyntaxKind.CloseBraceToken),
semicolonToken: default);
typeName = wrapperName;
}
else
{
Expand All @@ -252,6 +261,10 @@ private static void GenerateImpl(
closeBraceToken: Token(SyntaxKind.CloseBraceToken),
semicolonToken: default);
}
string fullTypeName = string.Join(".", typeDeclContext.NamespaceNames
.Concat(typeDeclContext.ParentTypeInfo.Select(x => x.Name))
.Concat(new[] { typeName }));

var srcName = fullTypeName + "." + usage.GetInterfaceName();

newType = typeDeclContext.WrapNewType(newType);
Expand Down
13 changes: 7 additions & 6 deletions test/Serde.Generation.Test/DeserializeTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,30 +18,31 @@ public Task NestedExplicitDeserializeWrapper()

using Serde;
using System.Collections.Immutable;
using System.Collections.Specialized;
using System.Runtime.InteropServices.ComTypes;

[GenerateDeserialize(Through = nameof(Value))]
readonly partial record struct SectionWrap(BitVector32.Section Value);
readonly partial record struct OptsWrap(BIND_OPTS Value);

[GenerateDeserialize]
partial struct S
{
[SerdeMemberOptions(WrapperDeserialize = typeof(ImmutableArrayWrap.DeserializeImpl<BitVector32.Section, SectionWrap>))]
public ImmutableArray<BitVector32.Section> Sections;
[SerdeMemberOptions(WrapperDeserialize = typeof(ImmutableArrayWrap.DeserializeImpl<BIND_OPTS, OptsWrap>))]
public ImmutableArray<BIND_OPTS> Opts;
}

""";
return VerifyMultiFile(src);
}

[Fact]
public Task DeserializeOnlyWrap()
{
var src = """
using Serde;
using System.Collections.Specialized;
using System.Runtime.InteropServices.ComTypes;

[GenerateDeserialize(Through = nameof(Value))]
readonly partial record struct SectionWrap(BitVector32.Section Value);
readonly partial record struct Wrap(BIND_OPTS Value);

""";
return VerifyDeserialize(src);
Expand Down
67 changes: 60 additions & 7 deletions test/Serde.Generation.Test/GeneratorTestUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
using System.Collections.Immutable;
using System.IO;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Text;
using System.Text.Encodings.Web;
using System.Text.Json;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
Expand Down Expand Up @@ -42,12 +45,12 @@ public static Task VerifyGeneratedCode(
return settings;
});

public static Task<VerifyResult> VerifyMultiFile(string src, MetadataReference[]? additionalRefs = null)
public static Task<VerifyResult[]> VerifyMultiFile(string src, MetadataReference[]? additionalRefs = null)
{
return VerifyGeneratedCode(src, s_cachedSettings.Value, additionalRefs);
}

public static Task<VerifyResult> VerifyGeneratedCode(
public static Task<VerifyResult[]> VerifyGeneratedCode(
string src,
string directoryName,
string testMethodName,
Expand All @@ -63,14 +66,64 @@ public static Task<VerifyResult> VerifyGeneratedCode(
return VerifyGeneratedCode(src, settings);
}

public static async Task<VerifyResult> VerifyGeneratedCode(string src, VerifySettings settings, MetadataReference[]? additionalRefs = null)
public static async Task<VerifyResult[]> VerifyGeneratedCode(
string src,
VerifySettings settings,
MetadataReference[]? additionalRefs = null)
{
var generatorInstance = new SerdeImplRoslynGenerator();
GeneratorDriver driver = CSharpGeneratorDriver.Create(generatorInstance);
var comp = await CreateCompilation(src, additionalRefs);
driver = driver.RunGenerators(comp);
var results = await Verifier.Verify(driver, settings);
return results;
Compilation comp = await CreateCompilation(src, additionalRefs);
driver = driver.RunGeneratorsAndUpdateCompilation(comp, out comp, out _);
var verify = Verifier.Verify(driver, settings);
var diags = comp.GetDiagnostics().Where(d => d.Severity == DiagnosticSeverity.Error).ToList();
if (diags.Any())
{
return new[] { await verify.AppendContentAsFile(SerializeDiagnostics(diags), name: "FinalDiagnostics") };
}
else
{
return new[] { await verify };
}
}

private static string SerializeDiagnostics(IEnumerable<Diagnostic> diags)
{
using var stream = new MemoryStream();
using var writer = new Utf8JsonWriter(stream, new JsonWriterOptions {
Indented = true,
Encoder = (JavaScriptEncoder?)JavaScriptEncoder.UnsafeRelaxedJsonEscaping });

writer.WriteStartArray();
foreach (var diag in diags)
{
writer.WriteStartObject();
writer.WriteString("Id", diag.Id);
var descriptor = diag.Descriptor;
writer.WriteString("Title", descriptor.Title.ToString());
writer.WriteString("Severity", diag.Severity.ToString());
writer.WriteString("WarningLevel", diag.WarningLevel.ToString());
writer.WriteString("Location", diag.Location.GetMappedLineSpan().ToString());
var description = descriptor.Description.ToString();
if (!string.IsNullOrWhiteSpace(description))
{
writer.WriteString("Description", description);
}

var help = descriptor.HelpLinkUri;
if (!string.IsNullOrWhiteSpace(help))
{
writer.WriteString("HelpLink", help);
}

writer.WriteString("MessageFormat", descriptor.MessageFormat.ToString());
writer.WriteString("Message", diag.GetMessage());
writer.WriteString("Category", descriptor.Category);
writer.WriteEndObject();
}
writer.WriteEndArray();
writer.Flush();
return Encoding.UTF8.GetString(stream.ToArray());
}

public static async Task<CSharpCompilation> CreateCompilation(string src, MetadataReference[]? additionalRefs = null)
Expand Down
32 changes: 26 additions & 6 deletions test/Serde.Generation.Test/SerializeTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,18 @@ public Task NestedExplicitSerializeWrapper()

using Serde;
using System.Collections.Immutable;
using System.Collections.Specialized;
using System.Runtime.InteropServices.ComTypes;

[GenerateSerialize(Through = nameof(Value))]
readonly partial record struct SectionWrap(BitVector32.Section Value);
[GenerateSerde(Through = nameof(Value))]
readonly partial record struct OPTSWrap(BIND_OPTS Value);

[GenerateSerialize]
[GenerateSerde]
partial struct S
{
[SerdeMemberOptions(WrapperSerialize = typeof(ImmutableArrayWrap.SerializeImpl<BitVector32.Section, SectionWrap>))]
public ImmutableArray<BitVector32.Section> Sections;
[SerdeMemberOptions(
WrapperSerialize = typeof(ImmutableArrayWrap.SerializeImpl<BIND_OPTS, OPTSWrap>),
WrapperDeserialize = typeof(ImmutableArrayWrap.DeserializeImpl<BIND_OPTS, OPTSWrap>))]
public ImmutableArray<BIND_OPTS> Opts;
}

""";
Expand Down Expand Up @@ -520,6 +522,24 @@ public enum ColorULong : ulong { Red = 3, Green = 5, Blue = 7 }
return VerifyMultiFile(src);
}

[Fact]
public Task NestedEnumWrapper()
{
var src = """
using Serde;

[GenerateSerialize]
partial class C
{
public Rgb? ColorOpt;
}

[GenerateSerialize]
public enum Rgb { Red, Green, Blue }
""";
return VerifyMultiFile(src);
}

private static Task VerifySerialize(
string src,
[CallerMemberName] string callerName = "")
Expand Down
29 changes: 15 additions & 14 deletions test/Serde.Generation.Test/WrapperTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,15 @@ public Task NestedExplicitWrapper()

partial class Outer
{
[GenerateSerde(Through = nameof(Value))]
readonly partial record struct SectionWrap(BitVector32.Section Value);
[GenerateSerialize(Through = nameof(Value))]
public readonly partial record struct SectionWrap(BitVector32.Section Value);
}

[GenerateSerde]
[GenerateSerialize]
partial struct S
{
[SerdeMemberOptions(
WrapperSerialize = typeof(ImmutableArrayWrap.SerializeImpl<BitVector32.Section, Outer.SectionWrap>),
WrapperDeserialize = typeof(ImmutableArrayWrap.DeserializeImpl<BitVector32.Section, Outer.SectionWrap>))]
WrapperSerialize = typeof(ImmutableArrayWrap.SerializeImpl<BitVector32.Section, Outer.SectionWrap>))]
public ImmutableArray<BitVector32.Section> Sections;
}
""";
Expand All @@ -42,11 +41,11 @@ partial struct S
public Task GenerateSerdeWrap()
{
var src = """
using System.Collections.Specialized;
using System.Runtime.InteropServices.ComTypes;
using Serde;

[GenerateSerde(Through = nameof(Value))]
readonly partial record struct SectionWrap(BitVector32.Section Value);
readonly partial record struct OPTSWrap(BIND_OPTS Value);

""";
return VerifyMultiFile(src);
Expand Down Expand Up @@ -113,16 +112,16 @@ public PointWrap(Point point)
public Task NestedDeserializeWrap()
{
var src = @"
using System.Collections.Specialized;
using System.Runtime.InteropServices.ComTypes;

[Serde.GenerateWrapper(nameof(Value))]
internal readonly partial record struct SectionWrap(BitVector32.Section Value);
internal readonly partial record struct OPTSWrap(BIND_OPTS Value);

[Serde.GenerateDeserialize]
partial class C
{
[Serde.SerdeWrap(typeof(SectionWrap))]
public BitVector32.Section S = new BitVector32.Section();
[Serde.SerdeWrap(typeof(OPTSWrap))]
public BIND_OPTS S = new BIND_OPTS();
}";
return VerifyMultiFile(src);
}
Expand Down Expand Up @@ -255,9 +254,11 @@ namespace Test;
internal partial record struct RecursiveWrap(Recursive Value);

[GenerateSerde]
public partial record Parent(
[property: SerdeWrap(typeof(RecursiveWrap))]
Recursive R);
public partial record Parent
{
[SerdeWrap(typeof(RecursiveWrap))]
public Recursive R { get; init; }
}
""";
await VerifyMultiFile(src, new[] { comp.EmitToImageReference() });
}
Expand Down
Loading
Loading