From 220a9a3e8d099e5ad3e0812acd88295ae84b3e73 Mon Sep 17 00:00:00 2001 From: Oleksandr Liakhevych Date: Sun, 2 Jun 2024 16:36:20 +0300 Subject: [PATCH] Move scanning types to separate step --- .../CombinedProviderComparer.cs | 2 +- .../DependencyInjectionGenerator.cs | 139 ++++++++++-------- DependencyInjection.SourceGenerator/Model.cs | 21 ++- 3 files changed, 90 insertions(+), 72 deletions(-) diff --git a/DependencyInjection.SourceGenerator/CombinedProviderComparer.cs b/DependencyInjection.SourceGenerator/CombinedProviderComparer.cs index 904a19a..6f729fb 100644 --- a/DependencyInjection.SourceGenerator/CombinedProviderComparer.cs +++ b/DependencyInjection.SourceGenerator/CombinedProviderComparer.cs @@ -3,7 +3,7 @@ namespace DependencyInjection.SourceGenerator; -using CombinedModel = (MethodModel Model, Compilation Compilation); +using CombinedModel = (MethodWithAttributesModel Model, Compilation Compilation); // We only compare Model here and ignore Compilation, as I don't want to run it on every input. internal class CombinedProviderComparer : IEqualityComparer diff --git a/DependencyInjection.SourceGenerator/DependencyInjectionGenerator.cs b/DependencyInjection.SourceGenerator/DependencyInjectionGenerator.cs index bff5b96..297b67e 100644 --- a/DependencyInjection.SourceGenerator/DependencyInjectionGenerator.cs +++ b/DependencyInjection.SourceGenerator/DependencyInjectionGenerator.cs @@ -46,111 +46,122 @@ public void Initialize(IncrementalGeneratorInitializationContext context) return null; var attributeData = context.Attributes.Select(AttributeModel.Create); - var model = MethodModel.Create(method, attributeData); + var model = MethodModel.Create(method); + return new MethodWithAttributesModel(model, new EquatableArray(attributeData.ToArray())); //if (Previous != null && !model.Equals(Previous)) // System.Diagnostics.Debugger.Launch(); //Previous = model; - - return model; }) .Where(method => method != null); var combinedProvider = methodProvider.Combine(context.CompilationProvider) .WithComparer(CombinedProviderComparer.Instance); - // We require all matching type symbols, and create the generated files. - context.RegisterImplementationSourceOutput(combinedProvider, - static (context, src) => + var methodImplementationsProvider = combinedProvider.Select(static (context, ct) => + { + var ((method, attributes), compilation) = context; + + var registrations = new List(); + + foreach (var attribute in attributes) { - var (model, compilation) = src; + var assembly = compilation.GetTypeByMetadataName(attribute.AssemblyOfTypeName ?? method.TypeMetadataName).ContainingAssembly; - //var sw = System.Diagnostics.Stopwatch.StartNew(); + var assignableToType = attribute.AssignableToTypeName is null + ? null + : compilation.GetTypeByMetadataName(attribute.AssignableToTypeName); - var sb = new StringBuilder(); - var attributes = model.Attributes; + var types = GetTypesFromAssembly(assembly) + .Where(t => !t.IsAbstract && !t.IsStatic && t.TypeKind == TypeKind.Class); + + if (attribute.TypeNameFilter != null) + { + var regex = $"^{Regex.Escape(attribute.TypeNameFilter).Replace(@"\*", ".*")}$"; + types = types.Where(t => Regex.IsMatch(t.ToDisplayString(), regex)); + } - foreach (var attribute in attributes) + foreach (var t in types) { - var assembly = compilation.GetTypeByMetadataName(attribute.AssemblyOfTypeName ?? model.TypeMetadataName).ContainingAssembly; + var implementationType = t; - var assignableToType = attribute.AssignableToTypeName is null - ? null - : compilation.GetTypeByMetadataName(attribute.AssignableToTypeName); + INamedTypeSymbol matchedType = null; + if (assignableToType != null && !IsAssignableTo(implementationType, assignableToType, out matchedType)) + continue; - var types = GetTypesFromAssembly(assembly) - .Where(t => !t.IsAbstract && !t.IsStatic && t.TypeKind == TypeKind.Class); + IEnumerable serviceTypes = null; - if (attribute.TypeNameFilter != null) + if (matchedType != null) { - var regex = $"^{Regex.Escape(attribute.TypeNameFilter).Replace(@"\*", ".*")}$"; - types = types.Where(t => Regex.IsMatch(t.ToDisplayString(), regex)); + serviceTypes = [matchedType]; } - - bool anyFound = false; - - foreach (var t in types) + else { - var implementationType = t; - - INamedTypeSymbol matchedType = null; - if (assignableToType != null && !IsAssignableTo(implementationType, assignableToType, out matchedType)) - continue; - - anyFound = true; - - IEnumerable serviceTypes = null; + serviceTypes = attribute.AsImplementedInterfaces + ? implementationType.AllInterfaces + : [implementationType]; + } - if (matchedType != null) + foreach (var serviceType in serviceTypes) + { + if (implementationType.IsGenericType) { - serviceTypes = [matchedType]; + var implementationTypeName = implementationType.ConstructUnboundGenericType().ToDisplayString(); + var serviceTypeName = serviceType.IsGenericType + ? serviceType.ConstructUnboundGenericType().ToDisplayString() + : serviceType.ToDisplayString(); + + var registration = new ServiceRegistrationModel(attribute.Lifetime, serviceTypeName, implementationTypeName, true); + registrations.Add(registration); } else { - serviceTypes = attribute.AsImplementedInterfaces - ? implementationType.AllInterfaces - : [implementationType]; - } - foreach (var serviceType in serviceTypes) - { - if (implementationType.IsGenericType) - { - var implementationTypeName = implementationType.ConstructUnboundGenericType().ToDisplayString(); - var serviceTypeName = serviceType.IsGenericType - ? serviceType.ConstructUnboundGenericType().ToDisplayString() - : serviceType.ToDisplayString(); - - sb.AppendLine($" .Add{attribute.Lifetime}(typeof({serviceTypeName}), typeof({implementationTypeName}))"); - } - else - { - sb.AppendLine($" .Add{attribute.Lifetime}<{serviceType.ToDisplayString()}, {implementationType.ToDisplayString()}>()"); - } + var registration = new ServiceRegistrationModel(attribute.Lifetime, serviceType.ToDisplayString(), implementationType.ToDisplayString(), false); + registrations.Add(registration); } } + } + } + + return new MethodImplementationModel(method, new EquatableArray([.. registrations])); + }); + + // We require all matching type symbols, and create the generated files. + context.RegisterImplementationSourceOutput(methodImplementationsProvider, + static (context, src) => + { + var (method, registrations) = src; + + //var sw = System.Diagnostics.Stopwatch.StartNew(); + + var sb = new StringBuilder(); - if (!anyFound) + foreach (var registration in registrations) + { + if (registration.IsOpenGeneric) + { + sb.AppendLine($" .Add{registration.Lifetime}(typeof({registration.ServiceTypeName}), typeof({registration.ImplementationTypeName}))"); + } + else { - //context.ReportDiagnostic(Diagnostic.Create(NoMatchingTypesFound, method.Locations[0])); - return; + sb.AppendLine($" .Add{registration.Lifetime}<{registration.ServiceTypeName}, {registration.ImplementationTypeName}>()"); } } - var returnType = model.ReturnsVoid ? "void" : "IServiceCollection"; - + var returnType = method.ReturnsVoid ? "void" : "IServiceCollection"; var source = $$""" using Microsoft.Extensions.DependencyInjection; - namespace {{model.Namespace}}; + namespace {{method.Namespace}}; - {{model.TypeAccessModifier}} {{model.TypeStatic}} partial class {{model.TypeName}} + {{method.TypeAccessModifier}} {{method.TypeStatic}} partial class {{method.TypeName}} { - {{model.MethodAccessModifier}} {{model.MethodStatic}} partial {{returnType}} {{model.MethodName}}({{(model.IsExtensionMethod ? "this" : "")}} IServiceCollection services) + {{method.MethodAccessModifier}} {{method.MethodStatic}} partial {{returnType}} {{method.MethodName}}({{(method.IsExtensionMethod ? "this" : "")}} IServiceCollection services) { - {{(model.ReturnsVoid ? "" : "return ")}}services + {{(method.ReturnsVoid ? "" : "return ")}}services {{sb.ToString().Trim()}}; } } @@ -163,7 +174,7 @@ namespace {{model.Namespace}}; //""" + source; //Iteration++; - context.AddSource($"{model.TypeName}_{model.MethodName}.Generated.cs", SourceText.From(source, Encoding.UTF8)); + context.AddSource($"{method.TypeName}_{method.MethodName}.Generated.cs", SourceText.From(source, Encoding.UTF8)); }); } diff --git a/DependencyInjection.SourceGenerator/Model.cs b/DependencyInjection.SourceGenerator/Model.cs index 6fe6841..86993b0 100644 --- a/DependencyInjection.SourceGenerator/Model.cs +++ b/DependencyInjection.SourceGenerator/Model.cs @@ -1,9 +1,18 @@ -using System.Collections.Generic; -using System.Linq; +using System.Linq; using Microsoft.CodeAnalysis; namespace DependencyInjection.SourceGenerator; +record MethodImplementationModel(MethodModel Method, EquatableArray Registrations); + +record ServiceRegistrationModel( + string Lifetime, + string ServiceTypeName, + string ImplementationTypeName, + bool IsOpenGeneric); + +record MethodWithAttributesModel(MethodModel Method, EquatableArray Attributes); + record MethodModel( string Namespace, string TypeName, @@ -14,10 +23,9 @@ record MethodModel( string MethodAccessModifier, string MethodStatic, bool IsExtensionMethod, - bool ReturnsVoid, - EquatableArray Attributes) + bool ReturnsVoid) { - public static MethodModel Create(IMethodSymbol method, IEnumerable attributes) + public static MethodModel Create(IMethodSymbol method) { return new MethodModel( Namespace: method.ContainingNamespace.ToDisplayString(), @@ -29,8 +37,7 @@ public static MethodModel Create(IMethodSymbol method, IEnumerable(attributes.ToArray())); + ReturnsVoid: method.ReturnsVoid); } private static string IsStatic(ISymbol symbol)