Skip to content

Commit

Permalink
Move scanning types to separate step
Browse files Browse the repository at this point in the history
  • Loading branch information
Dreamescaper committed Jun 2, 2024
1 parent fcc2d75 commit 220a9a3
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 72 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

namespace DependencyInjection.SourceGenerator;

using CombinedModel = (MethodModel Model, Compilation Compilation);
using CombinedModel = (MethodWithAttributesModel Model, Compilation Compilation);

// We only compare Model here and ignore Compilation, as I don't want to run it on every input.
internal class CombinedProviderComparer : IEqualityComparer<CombinedModel>
Expand Down
139 changes: 75 additions & 64 deletions DependencyInjection.SourceGenerator/DependencyInjectionGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -46,111 +46,122 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
return null;

var attributeData = context.Attributes.Select(AttributeModel.Create);
var model = MethodModel.Create(method, attributeData);
var model = MethodModel.Create(method);
return new MethodWithAttributesModel(model, new EquatableArray<AttributeModel>(attributeData.ToArray()));

//if (Previous != null && !model.Equals(Previous))
// System.Diagnostics.Debugger.Launch();

//Previous = model;

return model;
})
.Where(method => method != null);

var combinedProvider = methodProvider.Combine(context.CompilationProvider)
.WithComparer(CombinedProviderComparer.Instance);

// We require all matching type symbols, and create the generated files.
context.RegisterImplementationSourceOutput(combinedProvider,
static (context, src) =>
var methodImplementationsProvider = combinedProvider.Select(static (context, ct) =>
{
var ((method, attributes), compilation) = context;

var registrations = new List<ServiceRegistrationModel>();

foreach (var attribute in attributes)
{
var (model, compilation) = src;
var assembly = compilation.GetTypeByMetadataName(attribute.AssemblyOfTypeName ?? method.TypeMetadataName).ContainingAssembly;

//var sw = System.Diagnostics.Stopwatch.StartNew();
var assignableToType = attribute.AssignableToTypeName is null
? null
: compilation.GetTypeByMetadataName(attribute.AssignableToTypeName);

var sb = new StringBuilder();
var attributes = model.Attributes;
var types = GetTypesFromAssembly(assembly)
.Where(t => !t.IsAbstract && !t.IsStatic && t.TypeKind == TypeKind.Class);

if (attribute.TypeNameFilter != null)
{
var regex = $"^{Regex.Escape(attribute.TypeNameFilter).Replace(@"\*", ".*")}$";
types = types.Where(t => Regex.IsMatch(t.ToDisplayString(), regex));
}

foreach (var attribute in attributes)
foreach (var t in types)
{
var assembly = compilation.GetTypeByMetadataName(attribute.AssemblyOfTypeName ?? model.TypeMetadataName).ContainingAssembly;
var implementationType = t;

var assignableToType = attribute.AssignableToTypeName is null
? null
: compilation.GetTypeByMetadataName(attribute.AssignableToTypeName);
INamedTypeSymbol matchedType = null;
if (assignableToType != null && !IsAssignableTo(implementationType, assignableToType, out matchedType))
continue;

var types = GetTypesFromAssembly(assembly)
.Where(t => !t.IsAbstract && !t.IsStatic && t.TypeKind == TypeKind.Class);
IEnumerable<INamedTypeSymbol> serviceTypes = null;

if (attribute.TypeNameFilter != null)
if (matchedType != null)
{
var regex = $"^{Regex.Escape(attribute.TypeNameFilter).Replace(@"\*", ".*")}$";
types = types.Where(t => Regex.IsMatch(t.ToDisplayString(), regex));
serviceTypes = [matchedType];
}

bool anyFound = false;

foreach (var t in types)
else
{
var implementationType = t;

INamedTypeSymbol matchedType = null;
if (assignableToType != null && !IsAssignableTo(implementationType, assignableToType, out matchedType))
continue;

anyFound = true;

IEnumerable<INamedTypeSymbol> serviceTypes = null;
serviceTypes = attribute.AsImplementedInterfaces
? implementationType.AllInterfaces
: [implementationType];
}

if (matchedType != null)
foreach (var serviceType in serviceTypes)
{
if (implementationType.IsGenericType)
{
serviceTypes = [matchedType];
var implementationTypeName = implementationType.ConstructUnboundGenericType().ToDisplayString();
var serviceTypeName = serviceType.IsGenericType
? serviceType.ConstructUnboundGenericType().ToDisplayString()
: serviceType.ToDisplayString();

var registration = new ServiceRegistrationModel(attribute.Lifetime, serviceTypeName, implementationTypeName, true);
registrations.Add(registration);
}
else
{
serviceTypes = attribute.AsImplementedInterfaces
? implementationType.AllInterfaces
: [implementationType];
}

foreach (var serviceType in serviceTypes)
{
if (implementationType.IsGenericType)
{
var implementationTypeName = implementationType.ConstructUnboundGenericType().ToDisplayString();
var serviceTypeName = serviceType.IsGenericType
? serviceType.ConstructUnboundGenericType().ToDisplayString()
: serviceType.ToDisplayString();

sb.AppendLine($" .Add{attribute.Lifetime}(typeof({serviceTypeName}), typeof({implementationTypeName}))");
}
else
{
sb.AppendLine($" .Add{attribute.Lifetime}<{serviceType.ToDisplayString()}, {implementationType.ToDisplayString()}>()");
}
var registration = new ServiceRegistrationModel(attribute.Lifetime, serviceType.ToDisplayString(), implementationType.ToDisplayString(), false);
registrations.Add(registration);
}
}
}
}

return new MethodImplementationModel(method, new EquatableArray<ServiceRegistrationModel>([.. registrations]));
});

// We require all matching type symbols, and create the generated files.
context.RegisterImplementationSourceOutput(methodImplementationsProvider,
static (context, src) =>
{
var (method, registrations) = src;

//var sw = System.Diagnostics.Stopwatch.StartNew();

var sb = new StringBuilder();

if (!anyFound)
foreach (var registration in registrations)
{
if (registration.IsOpenGeneric)
{
sb.AppendLine($" .Add{registration.Lifetime}(typeof({registration.ServiceTypeName}), typeof({registration.ImplementationTypeName}))");
}
else
{
//context.ReportDiagnostic(Diagnostic.Create(NoMatchingTypesFound, method.Locations[0]));
return;
sb.AppendLine($" .Add{registration.Lifetime}<{registration.ServiceTypeName}, {registration.ImplementationTypeName}>()");
}
}

var returnType = model.ReturnsVoid ? "void" : "IServiceCollection";

var returnType = method.ReturnsVoid ? "void" : "IServiceCollection";

var source = $$"""
using Microsoft.Extensions.DependencyInjection;
namespace {{model.Namespace}};
namespace {{method.Namespace}};
{{model.TypeAccessModifier}} {{model.TypeStatic}} partial class {{model.TypeName}}
{{method.TypeAccessModifier}} {{method.TypeStatic}} partial class {{method.TypeName}}
{
{{model.MethodAccessModifier}} {{model.MethodStatic}} partial {{returnType}} {{model.MethodName}}({{(model.IsExtensionMethod ? "this" : "")}} IServiceCollection services)
{{method.MethodAccessModifier}} {{method.MethodStatic}} partial {{returnType}} {{method.MethodName}}({{(method.IsExtensionMethod ? "this" : "")}} IServiceCollection services)
{
{{(model.ReturnsVoid ? "" : "return ")}}services
{{(method.ReturnsVoid ? "" : "return ")}}services
{{sb.ToString().Trim()}};
}
}
Expand All @@ -163,7 +174,7 @@ namespace {{model.Namespace}};
//""" + source;
//Iteration++;

context.AddSource($"{model.TypeName}_{model.MethodName}.Generated.cs", SourceText.From(source, Encoding.UTF8));
context.AddSource($"{method.TypeName}_{method.MethodName}.Generated.cs", SourceText.From(source, Encoding.UTF8));
});
}

Expand Down
21 changes: 14 additions & 7 deletions DependencyInjection.SourceGenerator/Model.cs
Original file line number Diff line number Diff line change
@@ -1,9 +1,18 @@
using System.Collections.Generic;
using System.Linq;
using System.Linq;
using Microsoft.CodeAnalysis;

namespace DependencyInjection.SourceGenerator;

record MethodImplementationModel(MethodModel Method, EquatableArray<ServiceRegistrationModel> Registrations);

record ServiceRegistrationModel(
string Lifetime,
string ServiceTypeName,
string ImplementationTypeName,
bool IsOpenGeneric);

record MethodWithAttributesModel(MethodModel Method, EquatableArray<AttributeModel> Attributes);

record MethodModel(
string Namespace,
string TypeName,
Expand All @@ -14,10 +23,9 @@ record MethodModel(
string MethodAccessModifier,
string MethodStatic,
bool IsExtensionMethod,
bool ReturnsVoid,
EquatableArray<AttributeModel> Attributes)
bool ReturnsVoid)
{
public static MethodModel Create(IMethodSymbol method, IEnumerable<AttributeModel> attributes)
public static MethodModel Create(IMethodSymbol method)
{
return new MethodModel(
Namespace: method.ContainingNamespace.ToDisplayString(),
Expand All @@ -29,8 +37,7 @@ public static MethodModel Create(IMethodSymbol method, IEnumerable<AttributeMode
MethodAccessModifier: GetAccessModifier(method),
MethodStatic: IsStatic(method),
IsExtensionMethod: method.IsExtensionMethod,
ReturnsVoid: method.ReturnsVoid,
Attributes: new EquatableArray<AttributeModel>(attributes.ToArray()));
ReturnsVoid: method.ReturnsVoid);
}

private static string IsStatic(ISymbol symbol)
Expand Down

0 comments on commit 220a9a3

Please sign in to comment.