Skip to content

Commit

Permalink
Implement Deserialize using TypeInfo (#165)
Browse files Browse the repository at this point in the history
Adjusts the source generator to generate the Deserialize implementation
using TypeInfo instead of Visitors.
  • Loading branch information
agocke authored Jun 15, 2024
1 parent 4de6d0e commit 015ea73
Show file tree
Hide file tree
Showing 102 changed files with 2,272 additions and 3,456 deletions.
2 changes: 1 addition & 1 deletion perf/bench/SampleTypes.cs
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ public partial record LocationWrap : IDeserialize<Location>

var typeDeserialize = deserializer.DeserializeType(s_fieldMap);
int index;
while ((index = typeDeserialize.TryReadIndex(s_fieldMap)) != IDeserializeType.EndOfType)
while ((index = typeDeserialize.TryReadIndex(s_fieldMap, out _)) != IDeserializeType.EndOfType)
{
switch (index)
{
Expand Down
4 changes: 4 additions & 0 deletions src/generator/DataMemberSymbol.cs
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,10 @@ static bool IsNullable(ISymbol symbol)

public string Name => Symbol.Name;

public bool SkipDeserialize => _memberOptions.SkipDeserialize;

public bool SkipSerialize => _memberOptions.SkipSerialize;

public bool ThrowIfMissing => _memberOptions.ThrowIfMissing;

public bool ProvideAttributes => _memberOptions.ProvideAttributes;
Expand Down
174 changes: 162 additions & 12 deletions src/generator/Generator.Deserialize.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics;
using System.Linq;
using System.Text;
using Microsoft.CodeAnalysis;
Expand Down Expand Up @@ -31,15 +32,163 @@ internal static (MemberDeclarationSyntax[], BaseListSyntax) GenerateDeserializeI
));

// Generate members for ISerialize.Deserialize implementation
var method = GenerateDeserializeMethod(context, interfaceSyntax, receiverType);
var visitorType = GenerateVisitor(receiverType, typeSyntax, context, inProgress);
var members = new MemberDeclarationSyntax[] { method, visitorType };
MemberDeclarationSyntax[] members;
if (receiverType.TypeKind == TypeKind.Enum)
{
var method = GenerateOldDeserializeMethod(context, interfaceSyntax, receiverType);
var visitorType = GenerateVisitor(receiverType, typeSyntax, context, inProgress);
members = [ method, visitorType ];
}
else
{
var method = GenerateCustomDeserializeMethod(context, receiverType, typeSyntax, inProgress);
members = [ method ];
}
var baseList = BaseList(SeparatedList(new BaseTypeSyntax[] { SimpleBaseType(interfaceSyntax) }));
return (members, baseList);
}

/// <summary>
/// Generates
/// <code>
/// T IDeserialize&lt;T&gt;.Deserialize(IDeserializer deserializer)
/// {
/// var _local1 = default!;
/// ...
/// var _localN = default!;
///
/// var typeInfo = {typeName}SerdeTypeInfo.TypeInfo;
/// var typDeserializer = deserializer.DeserializeType(typeInfo);
/// int index;
/// while ((index = typeDeserialize.TryReadIndex(typeInfo)) != IDeserializeType.EndOfType)
/// {
/// switch (index)
/// {
/// }
/// }
/// }
/// </code>
/// </summary>
private static MethodDeclarationSyntax GenerateCustomDeserializeMethod(
GeneratorExecutionContext context,
ITypeSymbol type,
TypeSyntax typeSyntax,
ImmutableList<ITypeSymbol> inProgress)
{
Debug.Assert(type.TypeKind != TypeKind.Enum);

var members = SymbolUtilities.GetDataMembers(type, SerdeUsage.Both);
var typeFqn = typeSyntax.ToString();
var assignedVarType = members.Count switch {
<= 8 => "byte",
<= 16 => "ushort",
<= 32 => "uint",
<= 64 => "ulong",
_ => throw new InvalidOperationException("Too many members in type")
};
var (cases, locals, assignedMask) = InitCasesAndLocals();
string typeCreationExpr = GenerateTypeCreation(context, typeFqn, type, members);

const string typeInfoLocalName = "_l_typeInfo";
const string indexLocalName = "_l_index_";

var methodText = $$"""
static {{typeFqn}} Serde.IDeserialize<{{typeFqn}}>.Deserialize(IDeserializer deserializer)
{
{{locals}}
{{assignedVarType}} {{AssignedVarName}} = {{assignedMask}};

var {{typeInfoLocalName}} = {{type.Name}}SerdeTypeInfo.TypeInfo;
var typeDeserialize = deserializer.DeserializeType({{typeInfoLocalName}});
int {{indexLocalName}};
while (({{indexLocalName}} = typeDeserialize.TryReadIndex({{typeInfoLocalName}}, out var _l_errorName)) != IDeserializeType.EndOfType)
{
switch ({{indexLocalName}})
{
{{cases}}
}
}
{{typeCreationExpr}}
return newType;
}
""";
return (MethodDeclarationSyntax)ParseMemberDeclaration(methodText)!;

(string Cases, string Locals, string AssignedMask) InitCasesAndLocals()
{
var casesBuilder = new StringBuilder();
var localsBuilder = new StringBuilder();
long assignedMaskValue = 0;
for (int i = 0; i < members.Count; i++)
{
if (members[i].SkipDeserialize)
{
continue;
}

var m = members[i];
string wrapperName;
var memberType = m.Type.WithNullableAnnotation(m.NullableAnnotation).ToDisplayString();
if (TryGetExplicitWrapper(m, context, SerdeUsage.Deserialize, inProgress) is { } explicitWrap)
{
wrapperName = explicitWrap.ToString();
}
else if (ImplementsSerde(m.Type, context, SerdeUsage.Deserialize))
{
wrapperName = memberType;
}
else if (TryGetAnyWrapper(m.Type, context, SerdeUsage.Deserialize, inProgress) is { } wrap)
{
wrapperName = wrap.ToString();
}
else
{
// No built-in handling and doesn't implement IDeserialize, error
context.ReportDiagnostic(CreateDiagnostic(
DiagId.ERR_DoesntImplementInterface,
m.Locations[0],
m.Symbol,
memberType,
"Serde.IDeserialize"));
wrapperName = memberType;
}
var localName = GetLocalName(m);
localsBuilder.AppendLine($"{memberType} {localName} = default!;");
casesBuilder.AppendLine($"""
case {i}:
{localName} = typeDeserialize.ReadValue<{memberType}, {wrapperName}>({indexLocalName});
{AssignedVarName} |= (({assignedVarType})1) << {i};
break;
""");
if (m.IsNullable && !m.ThrowIfMissing)
{
assignedMaskValue |= 1L << i;
}
}
var unknownMemberBehavior = SymbolUtilities.GetTypeOptions(type).DenyUnknownMembers
? $"""
throw new InvalidDeserializeValueException("Unexpected field or property name in type {type.Name}: '" + _l_errorName + "'");
"""
: "break;";
casesBuilder.AppendLine($"""
case Serde.IDeserializeType.IndexNotFound:
{unknownMemberBehavior}
""");
casesBuilder.AppendLine($"""
default:
throw new InvalidOperationException("Unexpected index: " + {indexLocalName});
""");
return (casesBuilder.ToString(),
localsBuilder.ToString(),
"0b" + Convert.ToString(assignedMaskValue, 2));
}

}

// This is the old visitor-driven deserialization method. It is being replaced by the new
// TypeInfo-driven deserialization.
// Generate method `void ISerialize.Deserialize(IDeserializer deserializer) { ... }`
private static MethodDeclarationSyntax GenerateDeserializeMethod(
private static MethodDeclarationSyntax GenerateOldDeserializeMethod(
GeneratorExecutionContext context,
QualifiedNameSyntax interfaceSyntax,
ITypeSymbol typeSymbol)
Expand Down Expand Up @@ -302,10 +451,7 @@ private static MemberDeclarationSyntax GenerateCustomTypeVisitor(
<= 64 => "ulong",
_ => throw new InvalidOperationException("Too many members in type")
};
string cases;
string locals;
string assignedMask;
InitCasesAndLocals();
var (cases, locals, assignedMask) = InitCasesAndLocals();
string typeCreationExpr = GenerateTypeCreation(context, typeName, type, members);
var methodText = $$"""
{{typeName}} Serde.IDeserializeVisitor<{{typeName}}>.VisitDictionary<D>(ref D d)
Expand All @@ -325,7 +471,7 @@ private static MemberDeclarationSyntax GenerateCustomTypeVisitor(
""";
return ParseMemberDeclaration(methodText)!;

void InitCasesAndLocals()
(string Cases, string Locals, string AssignedMask) InitCasesAndLocals()
{
var casesBuilder = new StringBuilder();
var localsBuilder = new StringBuilder();
Expand Down Expand Up @@ -371,9 +517,9 @@ void InitCasesAndLocals()
assignedMaskValue |= 1L << i;
}
}
cases = casesBuilder.ToString();
locals = localsBuilder.ToString();
assignedMask = "0b" + Convert.ToString(assignedMaskValue, 2);
return (casesBuilder.ToString(),
localsBuilder.ToString(),
"0b" + Convert.ToString(assignedMaskValue, 2));
}
}

Expand Down Expand Up @@ -448,6 +594,10 @@ private static string GenerateTypeCreation(GeneratorExecutionContext context, st

foreach (var m in assignmentMembers)
{
if (m.SkipDeserialize)
{
continue;
}
assignments.AppendLine($"{m.Name} = {GetLocalName(m)},");
}
var mask = new string('1', members.Count);
Expand Down
68 changes: 63 additions & 5 deletions src/generator/Generator.SerdeTypeInfo.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;
using static Serde.Diagnostics;

namespace Serde;

Expand Down Expand Up @@ -35,15 +36,70 @@ public static void GenerateTypeInfo(
{
return;
}
var receiverType = typeSymbol;

INamedTypeSymbol receiverType;
ExpressionSyntax receiverExpr;
string? wrapperName;
string? wrappedName;
// If the Through property is set, then we are implementing a wrapper type
if (attributeData.NamedArguments is [ (nameof(GenerateSerialize.Through), { Value: string memberName }) ])
{
var members = model.LookupSymbols(typeDecl.SpanStart, typeSymbol, memberName);
if (members.Length != 1)
{
// TODO: Error about bad lookup
return;
}
receiverType = (INamedTypeSymbol)SymbolUtilities.GetSymbolType(members[0]);
receiverExpr = IdentifierName(memberName);
wrapperName = typeDecl.Identifier.ValueText;
wrappedName = receiverType.ToDisplayString();
}
// Enums are also always wrapped, but the attribute is on the enum itself
else if (typeDecl.IsKind(SyntaxKind.EnumDeclaration))
{
receiverType = typeSymbol;
receiverExpr = IdentifierName("Value");
wrappedName = typeDecl.Identifier.ValueText;
wrapperName = GetWrapperName(wrappedName);
}
// Just a normal interface implementation
else
{
wrapperName = null;
wrappedName = null;
if (!typeDecl.Modifiers.Any(tok => tok.IsKind(SyntaxKind.PartialKeyword)))
{
// Type must be partial
context.ReportDiagnostic(CreateDiagnostic(
DiagId.ERR_TypeNotPartial,
typeDecl.Identifier.GetLocation(),
typeDecl.Identifier.ValueText));
return;
}
receiverType = typeSymbol;
receiverExpr = ThisExpression();
}

GenerateTypeInfo(typeDecl, receiverType, context);
}

public static void GenerateTypeInfo(
BaseTypeDeclarationSyntax typeDecl,
INamedTypeSymbol receiverType,
GeneratorExecutionContext context)
{

var statements = new List<StatementSyntax>();
var fieldsAndProps = SymbolUtilities.GetDataMembers(receiverType, SerdeUsage.Both);
var typeDeclContext = new TypeDeclContext(typeDecl);
var typeName = typeDeclContext.Name;
var typeString = receiverType.IsGenericType
? receiverType.Name + "<" + new string(',', receiverType.TypeParameters.Length - 1) + ">"
: receiverType.Name;
var typeName = receiverType.Name;
var typeString = receiverType.ToDisplayString();
if (typeString.IndexOf('<') is var index && index != -1)
{
typeString = typeString[..index];
typeString = typeString + "<" + new string(',', receiverType.TypeParameters.Length - 1) + ">";
}
var newType = $$"""
internal static class {{typeName}}SerdeTypeInfo
{
Expand All @@ -61,4 +117,6 @@ internal static class {{typeName}}SerdeTypeInfo

context.AddSource(fullTypeName, newType);
}

private static string GetWrapperName(string typeName) => typeName + "Wrap";
}
1 change: 1 addition & 0 deletions src/generator/Generator.Wrapper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ internal static void GenerateWrapper(

var inProgress = ImmutableList.Create(receiverType);

SerdeTypeInfoGenerator.GenerateTypeInfo(typeDecl, (INamedTypeSymbol)receiverType, context);
GenerateImpl(SerdeUsage.Serialize, new TypeDeclContext(typeDecl), receiverType, receiverExpr, context, inProgress);
SerializeImplRoslynGenerator.GenerateImpl(SerdeUsage.Serialize, new TypeDeclContext(typeDecl), receiverType, receiverExpr, context, inProgress);
GenerateImpl(SerdeUsage.Deserialize, new TypeDeclContext(typeDecl), receiverType, receiverExpr, context, inProgress);
Expand Down
23 changes: 17 additions & 6 deletions src/serde/IDeserialize.cs
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,15 @@ public interface IDeserializeType
public const int EndOfType = -1;
public const int IndexNotFound = -2;

int TryReadIndex(TypeInfo map);
/// <summary>
/// Try to read the index of the next field in the type. If the index is found, the method
/// should return the index and set <paramref name="errorName" /> to null. If the end of the
/// type is reached, the method should return <see cref="EndOfType" /> and set <paramref
/// name="errorName" /> to null. If the field is not found, the method should return <see
/// cref="IndexNotFound" /> and set <paramref name="errorName" /> to the name of the missing
/// field, or the best-possible user-facing name.
/// </summary>
int TryReadIndex(TypeInfo map, out string? errorName);

V ReadValue<V, D>(int index) where D : IDeserialize<V>;
}
Expand All @@ -96,17 +104,20 @@ public sealed class TypeInfo
{
// The field names are sorted by the Utf8 representation of the field name.
private readonly ImmutableArray<(ReadOnlyMemory<byte> Utf8Name, int Index)> _nameToIndex;
private readonly ImmutableArray<FieldInfo> _indexToInfo;
private readonly ImmutableArray<PrivateFieldInfo> _indexToInfo;

private TypeInfo(
ImmutableArray<(ReadOnlyMemory<byte>, int)> nameToIndex,
ImmutableArray<FieldInfo> indexToInfo)
ImmutableArray<PrivateFieldInfo> indexToInfo)
{
_nameToIndex = nameToIndex;
_indexToInfo = indexToInfo;
}

private readonly record struct FieldInfo(IList<CustomAttributeData> CustomAttributesData);
/// <summary>
/// Holds information for a field or property in the given type.
/// </summary>
private readonly record struct PrivateFieldInfo(IList<CustomAttributeData> CustomAttributesData);


private static readonly UTF8Encoding s_utf8 = new UTF8Encoding(encoderShouldEmitUTF8Identifier: false);
Expand All @@ -119,7 +130,7 @@ public static TypeInfo Create(
ReadOnlySpan<(string SerializeName, MemberInfo MemberInfo)> fields)
{
var nameToIndexBuilder = ImmutableArray.CreateBuilder<(ReadOnlyMemory<byte> Utf8Name, int Index)>(fields.Length);
var indexToInfoBuilder = ImmutableArray.CreateBuilder<FieldInfo>(fields.Length);
var indexToInfoBuilder = ImmutableArray.CreateBuilder<PrivateFieldInfo>(fields.Length);
for (int index = 0; index < fields.Length; index++)
{
var (serializeName, memberInfo) = fields[index];
Expand All @@ -129,7 +140,7 @@ public static TypeInfo Create(
}

nameToIndexBuilder.Add((s_utf8.GetBytes(serializeName), index));
var fieldInfo = new FieldInfo(memberInfo.GetCustomAttributesData());
var fieldInfo = new PrivateFieldInfo(memberInfo.GetCustomAttributesData());
indexToInfoBuilder.Add(fieldInfo);
}

Expand Down
Loading

0 comments on commit 015ea73

Please sign in to comment.