diff --git a/README.md b/README.md index fe9817e..9e0efb6 100644 --- a/README.md +++ b/README.md @@ -73,6 +73,37 @@ It adds MediatR requests handlers, although you might need to add other types li private static partial IServiceCollection AddRepositories(this IServiceCollection services); ``` +### Add AspNetCore Minimal API endpoints +You can add custom type handler, if you need to do something non-trivial with that type. For example, you can automatically discover +and map Minimal API endpoints: +```csharp +public interface IEndpoint +{ + abstract static void MapEndpoint(IEndpointRouteBuilder endpoints); +} + +public class HelloWorldEndpoint : IEndpoint +{ + public static void MapEndpoint(IEndpointRouteBuilder endpoints) + { + endpoints.MapGet("/", () => "Hello World!"); + } +} + +public static partial class ServiceCollectionExtensions +{ + [GenerateServiceRegistrations(AssignableTo = typeof(IEndpoint), CustomHandler = nameof(MapEndpoint))] + public static partial IEndpointRouteBuilder MapEndpoints(this IEndpointRouteBuilder endpoints); + + private static void MapEndpoint(IEndpointRouteBuilder endpoints) where T : IEndpoint + { + T.MapEndpoint(endpoints); + } +} + +``` + + ## Parameters `GenerateServiceRegistrations` attribute has the following properties: @@ -82,6 +113,7 @@ private static partial IServiceCollection AddRepositories(this IServiceCollectio | **AssignableTo** | Set the type that the registered types must be assignable to. Types will be registered with this type as the service type, unless `AsImplementedInterfaces` or `AsSelf` is set. | | **Lifetime** | Set the lifetime of the registered services. `ServiceLifetime.Transient` is used if not specified. | | **AsImplementedInterfaces** | If true, the registered types will be registered as implemented interfaces instead of their actual type. | -| **AsSelf** | If true, types will be registered with their actual type. It can be combined with `AsImplementedInterfaces`. In that case implemeted interfaces will be "forwarded" to an actual implementation type | +| **AsSelf** | If true, types will be registered with their actual type. It can be combined with `AsImplementedInterfaces`. In that case implemented interfaces will be "forwarded" to an actual implementation type | | **TypeNameFilter** | Set this value to filter the types to register by their full name. You can use '*' wildcards. You can also use ',' to separate multiple filters. | | **KeySelector** | Set this value to a static method name returning string. Returned value will be used as a key for the registration. Method should either be generic, or have a single parameter of type `Type`. | +| **CustomHandler** | Set this property to a static generic method name in the current class. Set this property to a static generic method name in the current class. This property is incompatible with 'Lifetime', 'AsImplementedInterfaces', 'AsSelf', 'KeySelector' properties. | \ No newline at end of file diff --git a/ServiceScan.SourceGenerator.Tests/CustomHandlerTests.cs b/ServiceScan.SourceGenerator.Tests/CustomHandlerTests.cs new file mode 100644 index 0000000..2663e06 --- /dev/null +++ b/ServiceScan.SourceGenerator.Tests/CustomHandlerTests.cs @@ -0,0 +1,179 @@ +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.Extensions.DependencyInjection; +using Xunit; + +namespace ServiceScan.SourceGenerator.Tests; + +public class CustomHandlerTests +{ + private readonly DependencyInjectionGenerator _generator = new(); + + [Fact] + public void CustomHandlerWithNoParameters() + { + var source = $$""" + using ServiceScan.SourceGenerator; + + namespace GeneratorTests; + + public static partial class ServicesExtensions + { + [GenerateServiceRegistrations(AssignableTo = typeof(IService), CustomHandler = nameof(HandleType))] + public static partial void ProcessServices(); + + private static void HandleType() => System.Console.WriteLine(typeof(T).Name); + } + """; + + var services = + """ + namespace GeneratorTests; + + public interface IService { } + public class MyService1 : IService { } + public class MyService2 : IService { } + """; + + var compilation = CreateCompilation(source, services); + + var results = CSharpGeneratorDriver + .Create(_generator) + .RunGenerators(compilation) + .GetRunResult(); + + var expected = $$""" + namespace GeneratorTests; + + public static partial class ServicesExtensions + { + public static partial void ProcessServices() + { + HandleType(); + HandleType(); + } + } + """; + Assert.Equal(expected, results.GeneratedTrees[1].ToString()); + } + + [Fact] + public void CustomHandlerWithParameters() + { + var source = $$""" + using ServiceScan.SourceGenerator; + + namespace GeneratorTests; + + public static partial class ServicesExtensions + { + [GenerateServiceRegistrations(TypeNameFilter = "*Service", CustomHandler = nameof(HandleType))] + public static partial void ProcessServices(string value, decimal number); + + private static void HandleType(string value, decimal number) => System.Console.WriteLine(value + number.ToString() + typeof(T).Name); + } + """; + + var services = + """ + namespace GeneratorTests; + + public class MyFirstService {} + public class MySecondService {} + public class ServiceWithNonMatchingName {} + """; + + var compilation = CreateCompilation(source, services); + + var results = CSharpGeneratorDriver + .Create(_generator) + .RunGenerators(compilation) + .GetRunResult(); + + var expected = $$""" + namespace GeneratorTests; + + public static partial class ServicesExtensions + { + public static partial void ProcessServices( string value, decimal number) + { + HandleType(value, number); + HandleType(value, number); + } + } + """; + Assert.Equal(expected, results.GeneratedTrees[1].ToString()); + } + + [Fact] + public void CustomHandlerExtensionMethod() + { + var source = $$""" + using ServiceScan.SourceGenerator; + + namespace GeneratorTests; + + public static partial class ServicesExtensions + { + [GenerateServiceRegistrations(AssignableTo = typeof(IService), CustomHandler = nameof(HandleType))] + public static partial IServices ProcessServices(this IServices services); + + private static void HandleType(IServices services) where T:IService, new() => services.Add(new T()); + } + """; + + var services = + """ + namespace GeneratorTests; + + public interface IServices + { + void Add(IService service); + } + + public interface IService { } + public class MyService1 : IService { } + public class MyService2 : IService { } + """; + + var compilation = CreateCompilation(source, services); + + var results = CSharpGeneratorDriver + .Create(_generator) + .RunGenerators(compilation) + .GetRunResult(); + + var expected = $$""" + namespace GeneratorTests; + + public static partial class ServicesExtensions + { + public static partial GeneratorTests.IServices ProcessServices(this GeneratorTests.IServices services) + { + HandleType(services); + HandleType(services); + return services; + } + } + """; + Assert.Equal(expected, results.GeneratedTrees[1].ToString()); + } + + private static Compilation CreateCompilation(params string[] source) + { + var path = Path.GetDirectoryName(typeof(object).Assembly.Location)!; + var runtimeAssemblyPath = Path.Combine(path, "System.Runtime.dll"); + + var runtimeReference = MetadataReference.CreateFromFile(typeof(object).Assembly.Location); + + return CSharpCompilation.Create("compilation", + source.Select(s => CSharpSyntaxTree.ParseText(s)), + [ + MetadataReference.CreateFromFile(typeof(object).Assembly.Location), + MetadataReference.CreateFromFile(runtimeAssemblyPath), + MetadataReference.CreateFromFile(typeof(IServiceCollection).Assembly.Location), + MetadataReference.CreateFromFile(typeof(External.IExternalService).Assembly.Location), + ], + new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary)); + } +} diff --git a/ServiceScan.SourceGenerator.Tests/Sources.cs b/ServiceScan.SourceGenerator.Tests/Sources.cs index 23c73c7..3fb6c63 100644 --- a/ServiceScan.SourceGenerator.Tests/Sources.cs +++ b/ServiceScan.SourceGenerator.Tests/Sources.cs @@ -2,10 +2,6 @@ public static class Sources { - public const string Services = """ - - """; - public static string MethodWithAttribute(string attribute) { attribute = attribute.Replace("\n", "\n "); diff --git a/ServiceScan.SourceGenerator/DependencyInjectionGenerator.FilterTypes.cs b/ServiceScan.SourceGenerator/DependencyInjectionGenerator.FilterTypes.cs new file mode 100644 index 0000000..60cc5f8 --- /dev/null +++ b/ServiceScan.SourceGenerator/DependencyInjectionGenerator.FilterTypes.cs @@ -0,0 +1,127 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text.RegularExpressions; +using Microsoft.CodeAnalysis; +using ServiceScan.SourceGenerator.Model; + +namespace ServiceScan.SourceGenerator; + +public partial class DependencyInjectionGenerator +{ + private static IEnumerable<(INamedTypeSymbol Type, INamedTypeSymbol? MatchedAssignableType)> FilterTypes + (Compilation compilation, AttributeModel attribute, INamedTypeSymbol containingType) + { + var assembly = (attribute.AssemblyOfTypeName is null + ? containingType + : compilation.GetTypeByMetadataName(attribute.AssemblyOfTypeName)).ContainingAssembly; + + var assignableToType = attribute.AssignableToTypeName is null + ? null + : compilation.GetTypeByMetadataName(attribute.AssignableToTypeName); + + if (assignableToType != null && attribute.AssignableToGenericArguments != null) + { + var typeArguments = attribute.AssignableToGenericArguments.Value.Select(t => compilation.GetTypeByMetadataName(t)).ToArray(); + assignableToType = assignableToType.Construct(typeArguments); + } + + foreach (var type in GetTypesFromAssembly(assembly)) + { + if (type.IsAbstract || type.IsStatic || !type.CanBeReferencedByName || type.TypeKind != TypeKind.Class) + continue; + + if (attribute.TypeNameFilter != null) + { + var regex = $"^({Regex.Escape(attribute.TypeNameFilter).Replace(@"\*", ".*").Replace(",", "|")})$"; + + if (!Regex.IsMatch(type.ToDisplayString(), regex)) + continue; + } + + INamedTypeSymbol matchedType = null; + if (assignableToType != null && !IsAssignableTo(type, assignableToType, out matchedType)) + continue; + + yield return (type, matchedType); + } + } + + private static bool IsAssignableTo(INamedTypeSymbol type, INamedTypeSymbol assignableTo, out INamedTypeSymbol matchedType) + { + if (SymbolEqualityComparer.Default.Equals(type, assignableTo)) + { + matchedType = type; + return true; + } + + if (assignableTo.IsGenericType && assignableTo.IsDefinition) + { + if (assignableTo.TypeKind == TypeKind.Interface) + { + var matchingInterface = type.AllInterfaces.FirstOrDefault(i => i.IsGenericType && SymbolEqualityComparer.Default.Equals(i.OriginalDefinition, assignableTo)); + matchedType = matchingInterface; + return matchingInterface != null; + } + + var baseType = type.BaseType; + while (baseType != null) + { + if (baseType.IsGenericType && SymbolEqualityComparer.Default.Equals(baseType.OriginalDefinition, assignableTo)) + { + matchedType = baseType; + return true; + } + + baseType = baseType.BaseType; + } + } + else + { + if (assignableTo.TypeKind == TypeKind.Interface) + { + matchedType = assignableTo; + return type.AllInterfaces.Contains(assignableTo, SymbolEqualityComparer.Default); + } + + var baseType = type.BaseType; + while (baseType != null) + { + if (SymbolEqualityComparer.Default.Equals(baseType, assignableTo)) + { + matchedType = baseType; + return true; + } + + baseType = baseType.BaseType; + } + } + + matchedType = null; + return false; + } + + private static IEnumerable GetTypesFromAssembly(IAssemblySymbol assemblySymbol) + { + var @namespace = assemblySymbol.GlobalNamespace; + return GetTypesFromNamespace(@namespace); + + static IEnumerable GetTypesFromNamespace(INamespaceSymbol namespaceSymbol) + { + foreach (var member in namespaceSymbol.GetMembers()) + { + if (member is INamedTypeSymbol namedType) + { + yield return namedType; + } + else if (member is INamespaceSymbol nestedNamespace) + { + foreach (var type in GetTypesFromNamespace(nestedNamespace)) + { + yield return type; + } + } + } + } + } +} diff --git a/ServiceScan.SourceGenerator/DependencyInjectionGenerator.FindServicesToRegister.cs b/ServiceScan.SourceGenerator/DependencyInjectionGenerator.FindServicesToRegister.cs index 6ab62de..29c14f7 100644 --- a/ServiceScan.SourceGenerator/DependencyInjectionGenerator.FindServicesToRegister.cs +++ b/ServiceScan.SourceGenerator/DependencyInjectionGenerator.FindServicesToRegister.cs @@ -1,7 +1,5 @@ -using System; -using System.Collections.Generic; +using System.Collections.Generic; using System.Linq; -using System.Text.RegularExpressions; using Microsoft.CodeAnalysis; using ServiceScan.SourceGenerator.Model; using static ServiceScan.SourceGenerator.DiagnosticDescriptors; @@ -21,6 +19,7 @@ private static DiagnosticModel FindServicesToRegister var (method, attributes) = diagnosticModel.Model; var registrations = new List(); + var customHandlers = new List(); foreach (var attribute in attributes) { @@ -28,100 +27,59 @@ private static DiagnosticModel FindServicesToRegister var containingType = compilation.GetTypeByMetadataName(method.TypeMetadataName); - var assembly = (attribute.AssemblyOfTypeName is null - ? containingType - : compilation.GetTypeByMetadataName(attribute.AssemblyOfTypeName)).ContainingAssembly; - - var assignableToType = attribute.AssignableToTypeName is null - ? null - : compilation.GetTypeByMetadataName(attribute.AssignableToTypeName); - - var keySelectorMethod = attribute.KeySelector is null - ? null - : containingType.GetMembers().OfType().FirstOrDefault(m => - m.IsStatic && m.Name == attribute.KeySelector); - - if (attribute.KeySelector != null) - { - if (keySelectorMethod is null) - return Diagnostic.Create(KeySelectorMethodNotFound, attribute.Location); - - if (keySelectorMethod.ReturnsVoid) - return Diagnostic.Create(KeySelectorMethodHasIncorrectSignature, attribute.Location); - - var validGenericKeySelector = keySelectorMethod.TypeArguments.Length == 1 && keySelectorMethod.Parameters.Length == 0; - var validNonGenericKeySelector = !keySelectorMethod.IsGenericMethod && keySelectorMethod.Parameters is [{ Type.Name: nameof(Type) }]; - - if (!validGenericKeySelector && !validNonGenericKeySelector) - return Diagnostic.Create(KeySelectorMethodHasIncorrectSignature, attribute.Location); - } - - if (assignableToType != null && attribute.AssignableToGenericArguments != null) - { - var typeArguments = attribute.AssignableToGenericArguments.Value.Select(t => compilation.GetTypeByMetadataName(t)).ToArray(); - assignableToType = assignableToType.Construct(typeArguments); - } - - var types = GetTypesFromAssembly(assembly) - .Where(t => !t.IsAbstract && !t.IsStatic && t.CanBeReferencedByName && t.TypeKind == TypeKind.Class); - - if (attribute.TypeNameFilter != null) - { - var regex = $"^({Regex.Escape(attribute.TypeNameFilter).Replace(@"\*", ".*").Replace(",", "|")})$"; - types = types.Where(t => Regex.IsMatch(t.ToDisplayString(), regex)); - } - - foreach (var t in types) + foreach (var (implementationType, matchedType) in FilterTypes(compilation, attribute, containingType)) { - var implementationType = t; + typesFound = true; - INamedTypeSymbol matchedType = null; - if (assignableToType != null && !IsAssignableTo(implementationType, assignableToType, out matchedType)) - continue; - - IEnumerable serviceTypes = (attribute.AsSelf, attribute.AsImplementedInterfaces) switch + if (attribute.CustomHandler != null) { - (true, true) => new[] { implementationType }.Concat(implementationType.AllInterfaces), - (false, true) => implementationType.AllInterfaces, - (true, false) => [implementationType], - _ => [matchedType ?? implementationType] - }; - - foreach (var serviceType in serviceTypes) + customHandlers.Add(new CustomHandlerModel(attribute.CustomHandler, implementationType.ToDisplayString())); + } + else { - if (implementationType.IsGenericType) + IEnumerable serviceTypes = (attribute.AsSelf, attribute.AsImplementedInterfaces) switch { - var implementationTypeName = implementationType.ConstructUnboundGenericType().ToDisplayString(); - var serviceTypeName = serviceType.IsGenericType - ? serviceType.ConstructUnboundGenericType().ToDisplayString() - : serviceType.ToDisplayString(); - - var registration = new ServiceRegistrationModel( - attribute.Lifetime, - serviceTypeName, - implementationTypeName, - false, - true, - keySelectorMethod?.Name, - keySelectorMethod?.IsGenericMethod); + (true, true) => new[] { implementationType }.Concat(implementationType.AllInterfaces), + (false, true) => implementationType.AllInterfaces, + (true, false) => [implementationType], + _ => [matchedType ?? implementationType] + }; - registrations.Add(registration); - } - else + foreach (var serviceType in serviceTypes) { - var shouldResolve = attribute.AsSelf && attribute.AsImplementedInterfaces && !SymbolEqualityComparer.Default.Equals(implementationType, serviceType); - var registration = new ServiceRegistrationModel( - attribute.Lifetime, - serviceType.ToDisplayString(), - implementationType.ToDisplayString(), - shouldResolve, - false, - keySelectorMethod?.Name, - keySelectorMethod?.IsGenericMethod); - registrations.Add(registration); + if (implementationType.IsGenericType) + { + var implementationTypeName = implementationType.ConstructUnboundGenericType().ToDisplayString(); + var serviceTypeName = serviceType.IsGenericType + ? serviceType.ConstructUnboundGenericType().ToDisplayString() + : serviceType.ToDisplayString(); + + var registration = new ServiceRegistrationModel( + attribute.Lifetime, + serviceTypeName, + implementationTypeName, + false, + true, + attribute.KeySelector, + attribute.KeySelectorGeneric); + + registrations.Add(registration); + } + else + { + var shouldResolve = attribute.AsSelf && attribute.AsImplementedInterfaces && !SymbolEqualityComparer.Default.Equals(implementationType, serviceType); + var registration = new ServiceRegistrationModel( + attribute.Lifetime, + serviceType.ToDisplayString(), + implementationType.ToDisplayString(), + shouldResolve, + false, + attribute.KeySelector, + attribute.KeySelectorGeneric); + + registrations.Add(registration); + } } - - typesFound = true; } } @@ -129,85 +87,7 @@ private static DiagnosticModel FindServicesToRegister diagnostic ??= Diagnostic.Create(NoMatchingTypesFound, attribute.Location); } - var implementationModel = new MethodImplementationModel(method, new EquatableArray([.. registrations])); + var implementationModel = new MethodImplementationModel(method, [.. registrations], [.. customHandlers]); return new(diagnostic, implementationModel); } - - private static bool IsAssignableTo(INamedTypeSymbol type, INamedTypeSymbol assignableTo, out INamedTypeSymbol matchedType) - { - if (SymbolEqualityComparer.Default.Equals(type, assignableTo)) - { - matchedType = type; - return true; - } - - if (assignableTo.IsGenericType && assignableTo.IsDefinition) - { - if (assignableTo.TypeKind == TypeKind.Interface) - { - var matchingInterface = type.AllInterfaces.FirstOrDefault(i => i.IsGenericType && SymbolEqualityComparer.Default.Equals(i.OriginalDefinition, assignableTo)); - matchedType = matchingInterface; - return matchingInterface != null; - } - - var baseType = type.BaseType; - while (baseType != null) - { - if (baseType.IsGenericType && SymbolEqualityComparer.Default.Equals(baseType.OriginalDefinition, assignableTo)) - { - matchedType = baseType; - return true; - } - - baseType = baseType.BaseType; - } - } - else - { - if (assignableTo.TypeKind == TypeKind.Interface) - { - matchedType = assignableTo; - return type.AllInterfaces.Contains(assignableTo, SymbolEqualityComparer.Default); - } - - var baseType = type.BaseType; - while (baseType != null) - { - if (SymbolEqualityComparer.Default.Equals(baseType, assignableTo)) - { - matchedType = baseType; - return true; - } - - baseType = baseType.BaseType; - } - } - - matchedType = null; - return false; - } - - private static IEnumerable GetTypesFromAssembly(IAssemblySymbol assemblySymbol) - { - var @namespace = assemblySymbol.GlobalNamespace; - return GetTypesFromNamespace(@namespace); - - static IEnumerable GetTypesFromNamespace(INamespaceSymbol namespaceSymbol) - { - foreach (var member in namespaceSymbol.GetMembers()) - { - if (member is INamedTypeSymbol namedType) - { - yield return namedType; - } - else if (member is INamespaceSymbol nestedNamespace) - { - foreach (var type in GetTypesFromNamespace(nestedNamespace)) - { - yield return type; - } - } - } - } - } } diff --git a/ServiceScan.SourceGenerator/DependencyInjectionGenerator.ParseMethodModel.cs b/ServiceScan.SourceGenerator/DependencyInjectionGenerator.ParseMethodModel.cs index acfdc14..6f58f0b 100644 --- a/ServiceScan.SourceGenerator/DependencyInjectionGenerator.ParseMethodModel.cs +++ b/ServiceScan.SourceGenerator/DependencyInjectionGenerator.ParseMethodModel.cs @@ -1,4 +1,6 @@ -using Microsoft.CodeAnalysis; +using System; +using System.Linq; +using Microsoft.CodeAnalysis; using ServiceScan.SourceGenerator.Model; using static ServiceScan.SourceGenerator.DiagnosticDescriptors; @@ -6,7 +8,7 @@ namespace ServiceScan.SourceGenerator; public partial class DependencyInjectionGenerator { - private static DiagnosticModel ParseMethodModel(GeneratorAttributeSyntaxContext context) + private static DiagnosticModel ParseRegisterMethodModel(GeneratorAttributeSyntaxContext context) { if (context.TargetSymbol is not IMethodSymbol method) return null; @@ -14,27 +16,82 @@ private static DiagnosticModel ParseMethodModel(Gener if (!method.IsPartialDefinition) return Diagnostic.Create(NotPartialDefinition, method.Locations[0]); - var serviceCollectionType = context.SemanticModel.Compilation.GetTypeByMetadataName("Microsoft.Extensions.DependencyInjection.IServiceCollection"); - - if (!method.ReturnsVoid && !SymbolEqualityComparer.Default.Equals(method.ReturnType, serviceCollectionType)) - return Diagnostic.Create(WrongReturnType, method.Locations[0]); - - if (method.Parameters.Length != 1 || !SymbolEqualityComparer.Default.Equals(method.Parameters[0].Type, serviceCollectionType)) - return Diagnostic.Create(WrongMethodParameters, method.Locations[0]); - + var hasCustomHandler = false; var attributeData = new AttributeModel[context.Attributes.Length]; for (var i = 0; i < context.Attributes.Length; i++) { - attributeData[i] = AttributeModel.Create(context.Attributes[i]); + var attribute = AttributeModel.Create(context.Attributes[i], method); + attributeData[i] = attribute; + + if (!attribute.HasSearchCriteria) + return Diagnostic.Create(MissingSearchCriteria, attribute.Location); + + hasCustomHandler |= attribute.CustomHandler != null; + if (hasCustomHandler && context.Attributes.Length != 1) + return Diagnostic.Create(OnlyOneCustomHandlerAllowed, attribute.Location); + + if (attribute.KeySelector != null) + { + var keySelectorMethod = method.ContainingType.GetMembers().OfType() + .FirstOrDefault(m => m.IsStatic && m.Name == attribute.KeySelector); + + if (keySelectorMethod is null) + return Diagnostic.Create(KeySelectorMethodNotFound, attribute.Location); + + if (keySelectorMethod.ReturnsVoid) + return Diagnostic.Create(KeySelectorMethodHasIncorrectSignature, attribute.Location); + + var validGenericKeySelector = keySelectorMethod.TypeArguments.Length == 1 && keySelectorMethod.Parameters.Length == 0; + var validNonGenericKeySelector = !keySelectorMethod.IsGenericMethod && keySelectorMethod.Parameters is [{ Type.Name: nameof(Type) }]; + + if (!validGenericKeySelector && !validNonGenericKeySelector) + return Diagnostic.Create(KeySelectorMethodHasIncorrectSignature, attribute.Location); + } + + if (attribute.CustomHandler != null) + { + var customHandlerMethod = method.ContainingType.GetMembers().OfType() + .FirstOrDefault(m => m.IsStatic && m.Name == attribute.CustomHandler); - if (!attributeData[i].HasSearchCriteria) - return Diagnostic.Create(MissingSearchCriteria, attributeData[i].Location); + if (customHandlerMethod is null) + return Diagnostic.Create(CustomHandlerMethodNotFound, attribute.Location); + + if (!customHandlerMethod.IsGenericMethod) + return Diagnostic.Create(CustomHandlerMethodHasIncorrectSignature, attribute.Location); + + var typesMatch = Enumerable.SequenceEqual( + method.Parameters.Select(p => p.Type), + customHandlerMethod.Parameters.Select(p => p.Type), + SymbolEqualityComparer.Default); + + if (!typesMatch) + return Diagnostic.Create(CustomHandlerMethodHasIncorrectSignature, attribute.Location); + } if (attributeData[i].HasErrors) return null; } + if (!hasCustomHandler) + { + var serviceCollectionType = context.SemanticModel.Compilation.GetTypeByMetadataName("Microsoft.Extensions.DependencyInjection.IServiceCollection"); + + if (!method.ReturnsVoid && !SymbolEqualityComparer.Default.Equals(method.ReturnType, serviceCollectionType)) + return Diagnostic.Create(WrongReturnType, method.Locations[0]); + + if (method.Parameters.Length != 1 || !SymbolEqualityComparer.Default.Equals(method.Parameters[0].Type, serviceCollectionType)) + return Diagnostic.Create(WrongMethodParameters, method.Locations[0]); + } + else + { + if (method.IsExtensionMethod && !method.ReturnsVoid && + (method.Parameters.Length == 0 || !SymbolEqualityComparer.Default.Equals(method.Parameters[0].Type, method.ReturnType))) + { + return Diagnostic.Create(WrongReturnTypeForCustomHandler, method.Locations[0]); + } + } + var model = MethodModel.Create(method, context.TargetNode); - return new MethodWithAttributesModel(model, new EquatableArray(attributeData)); + return new MethodWithAttributesModel(model, [.. attributeData]); } } diff --git a/ServiceScan.SourceGenerator/DependencyInjectionGenerator.cs b/ServiceScan.SourceGenerator/DependencyInjectionGenerator.cs index c9f1fe8..654cfb2 100644 --- a/ServiceScan.SourceGenerator/DependencyInjectionGenerator.cs +++ b/ServiceScan.SourceGenerator/DependencyInjectionGenerator.cs @@ -3,6 +3,7 @@ using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp.Syntax; using Microsoft.CodeAnalysis.Text; +using ServiceScan.SourceGenerator.Extensions; using ServiceScan.SourceGenerator.Model; namespace ServiceScan.SourceGenerator; @@ -12,12 +13,15 @@ public partial class DependencyInjectionGenerator : IIncrementalGenerator { public void Initialize(IncrementalGeneratorInitializationContext context) { - context.RegisterPostInitializationOutput(context => context.AddSource("GenerateServiceRegistrationsAttribute.Generated.cs", SourceText.From(GenerateAttributeSource.Source, Encoding.UTF8))); + context.RegisterPostInitializationOutput(context => + { + context.AddSource("ServiceScanAttributes.Generated.cs", SourceText.From(GenerateAttributeSource.Source, Encoding.UTF8)); + }); var methodProvider = context.SyntaxProvider.ForAttributeWithMetadataName( "ServiceScan.SourceGenerator.GenerateServiceRegistrationsAttribute", predicate: static (syntaxNode, ct) => syntaxNode is MethodDeclarationSyntax methodSyntax, - transform: static (context, ct) => ParseMethodModel(context)) + transform: static (context, ct) => ParseRegisterMethodModel(context)) .Where(method => method != null); var combinedProvider = methodProvider.Combine(context.CompilationProvider) @@ -35,28 +39,30 @@ public void Initialize(IncrementalGeneratorInitializationContext context) if (src.Model == null) return; - var (method, registrations) = src.Model; - string source = GenerateSource(method, registrations); + var (method, registrations, customHandling) = src.Model; + string source = customHandling.Count > 0 + ? GenerateCustomHandlingSource(method, customHandling) + : GenerateRegistrationsSource(method, registrations); + + source = source.ReplaceLineEndings(); context.AddSource($"{method.TypeName}_{method.MethodName}.Generated.cs", SourceText.From(source, Encoding.UTF8)); }); } - private static string GenerateSource(MethodModel method, EquatableArray registrations) + private static string GenerateRegistrationsSource(MethodModel method, EquatableArray registrations) { - var sb = new StringBuilder(); - - foreach (var registration in registrations) + var registrationsCode = string.Join("\n", registrations.Select(registration => { if (registration.IsOpenGeneric) { - sb.AppendLine($" .Add{registration.Lifetime}(typeof({registration.ServiceTypeName}), typeof({registration.ImplementationTypeName}))"); + return $" .Add{registration.Lifetime}(typeof({registration.ServiceTypeName}), typeof({registration.ImplementationTypeName}))"; } else { if (registration.ResolveImplementation) { - sb.AppendLine($" .Add{registration.Lifetime}<{registration.ServiceTypeName}>(s => s.GetRequiredService<{registration.ImplementationTypeName}>())"); + return $" .Add{registration.Lifetime}<{registration.ServiceTypeName}>(s => s.GetRequiredService<{registration.ImplementationTypeName}>())"; } else { @@ -70,10 +76,11 @@ private static string GenerateSource(MethodModel method, EquatableArray $"{registration.KeySelectorMethodName}(typeof({registration.ImplementationTypeName}))", null => null }; - sb.AppendLine($" .{addMethod}<{registration.ServiceTypeName}, {registration.ImplementationTypeName}>({keyMethodInvocation})"); + + return $" .{addMethod}<{registration.ServiceTypeName}, {registration.ImplementationTypeName}>({keyMethodInvocation})"; } } - } + })); var returnType = method.ReturnsVoid ? "void" : "IServiceCollection"; @@ -89,7 +96,36 @@ private static string GenerateSource(MethodModel method, EquatableArray customHandlers) + { + var invocations = string.Join("\n", customHandlers.Select(h => + $" {h.HandlerMethodName}<{h.TypeName}>({string.Join(", ", method.Parameters.Select(p => p.Name))});")); + + var namespaceDeclaration = method.Namespace is null ? "" : $"namespace {method.Namespace};"; + var parameters = string.Join(",", method.Parameters.Select((p, i) => + $"{(i == 0 && method.IsExtensionMethod ? "this" : "")} {p.Type} {p.Name}")); + + var methodBody = $$""" + {{invocations.Trim()}} + {{(method.ReturnsVoid ? "" : $"return {method.ParameterName};")}} + """; + + var source = $$""" + {{namespaceDeclaration}} + + {{method.TypeModifiers}} class {{method.TypeName}} + { + {{method.MethodModifiers}} {{method.ReturnType}} {{method.MethodName}}({{parameters}}) + { + {{methodBody.Trim()}} } } """; diff --git a/ServiceScan.SourceGenerator/DiagnosticDescriptors.cs b/ServiceScan.SourceGenerator/DiagnosticDescriptors.cs index ee67ebe..7f45897 100644 --- a/ServiceScan.SourceGenerator/DiagnosticDescriptors.cs +++ b/ServiceScan.SourceGenerator/DiagnosticDescriptors.cs @@ -52,4 +52,32 @@ public static class DiagnosticDescriptors "Usage", DiagnosticSeverity.Error, true); + + public static readonly DiagnosticDescriptor OnlyOneCustomHandlerAllowed = new("DI0008", + "Only one GenerateServiceRegistrations attribute is allowed when CustomHandler used", + "Only one GenerateServiceRegistrations attribute is allowed when CustomHandler used", + "Usage", + DiagnosticSeverity.Error, + true); + + public static readonly DiagnosticDescriptor WrongReturnTypeForCustomHandler = new("DI0009", + "Wrong return type", + "Method with CustomHandler must return void or 'this' parameter type", + "Usage", + DiagnosticSeverity.Error, + true); + + public static readonly DiagnosticDescriptor CustomHandlerMethodNotFound = new("DI0012", + "Provided CustomHandler method is not found", + "CustomHandler parameter should point to a static method in the class", + "Usage", + DiagnosticSeverity.Error, + true); + + public static readonly DiagnosticDescriptor CustomHandlerMethodHasIncorrectSignature = new("DI0011", + "Provided CustomHandler method has incorrect signature", + "CustomHandler method must be generic, and must have the same parameters as the method with an attribute", + "Usage", + DiagnosticSeverity.Error, + true); } diff --git a/ServiceScan.SourceGenerator/EquatableArray.cs b/ServiceScan.SourceGenerator/EquatableArray.cs index 77972d0..753a55b 100644 --- a/ServiceScan.SourceGenerator/EquatableArray.cs +++ b/ServiceScan.SourceGenerator/EquatableArray.cs @@ -2,6 +2,7 @@ using System.Collections; using System.Collections.Generic; using System.Collections.Immutable; +using System.Runtime.CompilerServices; namespace ServiceScan.SourceGenerator; @@ -9,6 +10,7 @@ namespace ServiceScan.SourceGenerator; /// Creates a new instance. /// /// The input to wrap. +[CollectionBuilder(typeof(EquatableArrayBuilder), nameof(EquatableArrayBuilder.Create))] internal readonly struct EquatableArray(T[] array) : IEquatable>, IEnumerable where T : IEquatable { public static readonly EquatableArray Empty = new([]); @@ -80,4 +82,9 @@ IEnumerator IEnumerable.GetEnumerator() { return !left.Equals(right); } +} + +file static class EquatableArrayBuilder +{ + public static EquatableArray Create(ReadOnlySpan values) where T : IEquatable => new(values.ToArray()); } \ No newline at end of file diff --git a/ServiceScan.SourceGenerator/Extensions/StringExtensions.cs b/ServiceScan.SourceGenerator/Extensions/StringExtensions.cs new file mode 100644 index 0000000..5602699 --- /dev/null +++ b/ServiceScan.SourceGenerator/Extensions/StringExtensions.cs @@ -0,0 +1,32 @@ +using System; + +namespace ServiceScan.SourceGenerator.Extensions; + +internal static class StringExtensions +{ + public static string ReplaceLineEndings(this string input) + { +#if NET6_0_OR_GREATER + return input.ReplaceLineEndings(); +#else +#pragma warning disable RS1035 // Do not use APIs banned for analyzers + return ReplaceLineEndings(input, Environment.NewLine); +#pragma warning restore RS1035 // Do not use APIs banned for analyzers +#endif + } + + public static string ReplaceLineEndings(this string input, string replacementText) + { +#if NET6_0_OR_GREATER + return input.ReplaceLineEndings(replacementText); +#else + // First normalize to LF + var lineFeedInput = input + .Replace("\r\n", "\n") + .Replace("\r", "\n"); + + // Then normalize to the replacement text + return lineFeedInput.Replace("\n", replacementText); +#endif + } +} diff --git a/ServiceScan.SourceGenerator/GenerateAttributeSource.cs b/ServiceScan.SourceGenerator/GenerateAttributeSource.cs index 62042d4..bc324aa 100644 --- a/ServiceScan.SourceGenerator/GenerateAttributeSource.cs +++ b/ServiceScan.SourceGenerator/GenerateAttributeSource.cs @@ -41,7 +41,7 @@ internal class GenerateServiceRegistrationsAttribute : Attribute /// /// If set to true, types will be registered with their actual type. - /// It can be combined with , in that case implemeted interfaces will be + /// It can be combined with , in that case implemented interfaces will be /// "forwarded" to "self" implementation. /// public bool AsSelf { get; set; } @@ -56,12 +56,20 @@ internal class GenerateServiceRegistrationsAttribute : Attribute public string? TypeNameFilter { get; set; } /// - /// Set this value to a static method name returning string. + /// Set this property to a static method name returning string. /// Returned value will be used as a key for the registration. /// Method should either be generic, or have a single parameter of type . /// /// nameof(GetKey) public string? KeySelector { get; set; } + + /// + /// Set this property to a static generic method name in the current class. + /// This method will be invoked for each type found by the filter instead of regular registration logic. + /// This property is incompatible with , , , + /// properties. + /// + public string? CustomHandler { get; set; } } """; } \ No newline at end of file diff --git a/ServiceScan.SourceGenerator/Model/AttributeModel.cs b/ServiceScan.SourceGenerator/Model/AttributeModel.cs index c437a74..40ba27a 100644 --- a/ServiceScan.SourceGenerator/Model/AttributeModel.cs +++ b/ServiceScan.SourceGenerator/Model/AttributeModel.cs @@ -10,6 +10,8 @@ record AttributeModel( string Lifetime, string? TypeNameFilter, string? KeySelector, + bool? KeySelectorGeneric, + string? CustomHandler, bool AsImplementedInterfaces, bool AsSelf, Location Location, @@ -17,7 +19,7 @@ record AttributeModel( { public bool HasSearchCriteria => TypeNameFilter != null || AssignableToTypeName != null; - public static AttributeModel Create(AttributeData attribute) + public static AttributeModel Create(AttributeData attribute, IMethodSymbol method) { var assemblyType = attribute.NamedArguments.FirstOrDefault(a => a.Key == "FromAssemblyOf").Value.Value as INamedTypeSymbol; var assignableTo = attribute.NamedArguments.FirstOrDefault(a => a.Key == "AssignableTo").Value.Value as INamedTypeSymbol; @@ -25,6 +27,20 @@ public static AttributeModel Create(AttributeData attribute) var asSelf = attribute.NamedArguments.FirstOrDefault(a => a.Key == "AsSelf").Value.Value is true; var typeNameFilter = attribute.NamedArguments.FirstOrDefault(a => a.Key == "TypeNameFilter").Value.Value as string; var keySelector = attribute.NamedArguments.FirstOrDefault(a => a.Key == "KeySelector").Value.Value as string; + var customHandler = attribute.NamedArguments.FirstOrDefault(a => a.Key == "CustomHandler").Value.Value as string; + + bool? keySelectorGeneric = null; + if (keySelector != null) + { + var keySelectorMethod = method.ContainingType.GetMembers() + .OfType() + .FirstOrDefault(m => m.IsStatic && m.Name == keySelector); + + if (keySelectorMethod != null) + { + keySelectorGeneric = keySelectorMethod.IsGenericMethod; + } + } if (string.IsNullOrWhiteSpace(typeNameFilter)) typeNameFilter = null; @@ -32,7 +48,7 @@ public static AttributeModel Create(AttributeData attribute) var assemblyOfTypeName = assemblyType?.ToFullMetadataName(); var assignableToTypeName = assignableTo?.ToFullMetadataName(); EquatableArray? assignableToGenericArguments = assignableTo != null && assignableTo.IsGenericType && !assignableTo.IsUnboundGenericType - ? new EquatableArray([.. assignableTo?.TypeArguments.Select(t => t.ToFullMetadataName())]) + ? [.. assignableTo?.TypeArguments.Select(t => t.ToFullMetadataName())] : null; var lifetime = attribute.NamedArguments.FirstOrDefault(a => a.Key == "Lifetime").Value.Value as int? switch @@ -48,12 +64,15 @@ public static AttributeModel Create(AttributeData attribute) var hasError = assemblyType is { TypeKind: TypeKind.Error } || assignableTo is { TypeKind: TypeKind.Error }; - return new(assignableToTypeName, + return new( + assignableToTypeName, assignableToGenericArguments, assemblyOfTypeName, lifetime, typeNameFilter, keySelector, + keySelectorGeneric, + customHandler, asImplementedInterfaces, asSelf, location, diff --git a/ServiceScan.SourceGenerator/Model/MethodImplementationModel.cs b/ServiceScan.SourceGenerator/Model/MethodImplementationModel.cs index 322c716..b372cc8 100644 --- a/ServiceScan.SourceGenerator/Model/MethodImplementationModel.cs +++ b/ServiceScan.SourceGenerator/Model/MethodImplementationModel.cs @@ -2,4 +2,5 @@ record MethodImplementationModel( MethodModel Method, - EquatableArray Registrations); + EquatableArray Registrations, + EquatableArray CustomHandlers); diff --git a/ServiceScan.SourceGenerator/Model/MethodModel.cs b/ServiceScan.SourceGenerator/Model/MethodModel.cs index 9c74501..ab17367 100644 --- a/ServiceScan.SourceGenerator/Model/MethodModel.cs +++ b/ServiceScan.SourceGenerator/Model/MethodModel.cs @@ -1,8 +1,11 @@ -using Microsoft.CodeAnalysis; +using System.Linq; +using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp.Syntax; namespace ServiceScan.SourceGenerator.Model; +record ParameterModel(string Type, string Name); + record MethodModel( string Namespace, string TypeName, @@ -10,36 +13,30 @@ record MethodModel( string TypeModifiers, string MethodName, string MethodModifiers, - string ParameterName, + EquatableArray Parameters, bool IsExtensionMethod, - bool ReturnsVoid) + bool ReturnsVoid, + string ReturnType) { + public string ParameterName => Parameters.Single().Name; + public static MethodModel Create(IMethodSymbol method, SyntaxNode syntax) { - var parameterName = method.Parameters[0].Name; + EquatableArray parameters = [.. method.Parameters.Select(p => new ParameterModel(p.Type.ToDisplayString(), p.Name))]; + + var typeSyntax = syntax.FirstAncestorOrSelf(); return new MethodModel( Namespace: method.ContainingNamespace.IsGlobalNamespace ? null : method.ContainingNamespace.ToDisplayString(), TypeName: method.ContainingType.Name, TypeMetadataName: method.ContainingType.ToFullMetadataName(), - TypeModifiers: GetModifiers(GetTypeSyntax(syntax)), + TypeModifiers: GetModifiers(typeSyntax), MethodName: method.Name, MethodModifiers: GetModifiers(syntax), - ParameterName: parameterName, + Parameters: parameters, IsExtensionMethod: method.IsExtensionMethod, - ReturnsVoid: method.ReturnsVoid); - } - - private static TypeDeclarationSyntax GetTypeSyntax(SyntaxNode node) - { - var parent = node.Parent; - while (parent != null) - { - if (parent is TypeDeclarationSyntax t) - return t; - parent = parent.Parent; - } - return null; + ReturnsVoid: method.ReturnsVoid, + ReturnType: method.ReturnType.ToDisplayString()); } private static string GetModifiers(SyntaxNode syntax) diff --git a/ServiceScan.SourceGenerator/Model/ServiceRegistrationModel.cs b/ServiceScan.SourceGenerator/Model/ServiceRegistrationModel.cs index 8a95a34..5b7b3e2 100644 --- a/ServiceScan.SourceGenerator/Model/ServiceRegistrationModel.cs +++ b/ServiceScan.SourceGenerator/Model/ServiceRegistrationModel.cs @@ -8,3 +8,7 @@ record ServiceRegistrationModel( bool IsOpenGeneric, string? KeySelectorMethodName, bool? KeySelectorMethodGeneric); + +record CustomHandlerModel( + string HandlerMethodName, + string TypeName); diff --git a/ServiceScan.SourceGenerator/ServiceScan.SourceGenerator.csproj b/ServiceScan.SourceGenerator/ServiceScan.SourceGenerator.csproj index 361147b..530db09 100644 --- a/ServiceScan.SourceGenerator/ServiceScan.SourceGenerator.csproj +++ b/ServiceScan.SourceGenerator/ServiceScan.SourceGenerator.csproj @@ -28,7 +28,7 @@ - + all runtime; build; native; contentfiles; analyzers; buildtransitive