Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support custom handling for types #15

Merged
merged 5 commits into from
Dec 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 33 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,37 @@ It adds MediatR requests handlers, although you might need to add other types li
private static partial IServiceCollection AddRepositories(this IServiceCollection services);
```

### Add AspNetCore Minimal API endpoints
You can add custom type handler, if you need to do something non-trivial with that type. For example, you can automatically discover
and map Minimal API endpoints:
```csharp
public interface IEndpoint
{
abstract static void MapEndpoint(IEndpointRouteBuilder endpoints);
}

public class HelloWorldEndpoint : IEndpoint
{
public static void MapEndpoint(IEndpointRouteBuilder endpoints)
{
endpoints.MapGet("/", () => "Hello World!");
}
}

public static partial class ServiceCollectionExtensions
{
[GenerateServiceRegistrations(AssignableTo = typeof(IEndpoint), CustomHandler = nameof(MapEndpoint))]
public static partial IEndpointRouteBuilder MapEndpoints(this IEndpointRouteBuilder endpoints);

private static void MapEndpoint<T>(IEndpointRouteBuilder endpoints) where T : IEndpoint
{
T.MapEndpoint(endpoints);
}
}

```


## Parameters

`GenerateServiceRegistrations` attribute has the following properties:
Expand All @@ -82,6 +113,7 @@ private static partial IServiceCollection AddRepositories(this IServiceCollectio
| **AssignableTo** | Set the type that the registered types must be assignable to. Types will be registered with this type as the service type, unless `AsImplementedInterfaces` or `AsSelf` is set. |
| **Lifetime** | Set the lifetime of the registered services. `ServiceLifetime.Transient` is used if not specified. |
| **AsImplementedInterfaces** | If true, the registered types will be registered as implemented interfaces instead of their actual type. |
| **AsSelf** | If true, types will be registered with their actual type. It can be combined with `AsImplementedInterfaces`. In that case implemeted interfaces will be "forwarded" to an actual implementation type |
| **AsSelf** | If true, types will be registered with their actual type. It can be combined with `AsImplementedInterfaces`. In that case implemented interfaces will be "forwarded" to an actual implementation type |
| **TypeNameFilter** | Set this value to filter the types to register by their full name. You can use '*' wildcards. You can also use ',' to separate multiple filters. |
| **KeySelector** | Set this value to a static method name returning string. Returned value will be used as a key for the registration. Method should either be generic, or have a single parameter of type `Type`. |
| **CustomHandler** | Set this property to a static generic method name in the current class. Set this property to a static generic method name in the current class. This property is incompatible with 'Lifetime', 'AsImplementedInterfaces', 'AsSelf', 'KeySelector' properties. |
179 changes: 179 additions & 0 deletions ServiceScan.SourceGenerator.Tests/CustomHandlerTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.Extensions.DependencyInjection;
using Xunit;

namespace ServiceScan.SourceGenerator.Tests;

public class CustomHandlerTests
{
private readonly DependencyInjectionGenerator _generator = new();

[Fact]
public void CustomHandlerWithNoParameters()
{
var source = $$"""
using ServiceScan.SourceGenerator;

namespace GeneratorTests;

public static partial class ServicesExtensions
{
[GenerateServiceRegistrations(AssignableTo = typeof(IService), CustomHandler = nameof(HandleType))]
public static partial void ProcessServices();

private static void HandleType<T>() => System.Console.WriteLine(typeof(T).Name);
}
""";

var services =
"""
namespace GeneratorTests;

public interface IService { }
public class MyService1 : IService { }
public class MyService2 : IService { }
""";

var compilation = CreateCompilation(source, services);

var results = CSharpGeneratorDriver
.Create(_generator)
.RunGenerators(compilation)
.GetRunResult();

var expected = $$"""
namespace GeneratorTests;

public static partial class ServicesExtensions
{
public static partial void ProcessServices()
{
HandleType<GeneratorTests.MyService1>();
HandleType<GeneratorTests.MyService2>();
}
}
""";
Assert.Equal(expected, results.GeneratedTrees[1].ToString());
}

[Fact]
public void CustomHandlerWithParameters()
{
var source = $$"""
using ServiceScan.SourceGenerator;

namespace GeneratorTests;

public static partial class ServicesExtensions
{
[GenerateServiceRegistrations(TypeNameFilter = "*Service", CustomHandler = nameof(HandleType))]
public static partial void ProcessServices(string value, decimal number);

private static void HandleType<T>(string value, decimal number) => System.Console.WriteLine(value + number.ToString() + typeof(T).Name);
}
""";

var services =
"""
namespace GeneratorTests;

public class MyFirstService {}
public class MySecondService {}
public class ServiceWithNonMatchingName {}
""";

var compilation = CreateCompilation(source, services);

var results = CSharpGeneratorDriver
.Create(_generator)
.RunGenerators(compilation)
.GetRunResult();

var expected = $$"""
namespace GeneratorTests;

public static partial class ServicesExtensions
{
public static partial void ProcessServices( string value, decimal number)
{
HandleType<GeneratorTests.MyFirstService>(value, number);
HandleType<GeneratorTests.MySecondService>(value, number);
}
}
""";
Assert.Equal(expected, results.GeneratedTrees[1].ToString());
}

[Fact]
public void CustomHandlerExtensionMethod()
{
var source = $$"""
using ServiceScan.SourceGenerator;

namespace GeneratorTests;

public static partial class ServicesExtensions
{
[GenerateServiceRegistrations(AssignableTo = typeof(IService), CustomHandler = nameof(HandleType))]
public static partial IServices ProcessServices(this IServices services);

private static void HandleType<T>(IServices services) where T:IService, new() => services.Add(new T());
}
""";

var services =
"""
namespace GeneratorTests;

public interface IServices
{
void Add(IService service);
}

public interface IService { }
public class MyService1 : IService { }
public class MyService2 : IService { }
""";

var compilation = CreateCompilation(source, services);

var results = CSharpGeneratorDriver
.Create(_generator)
.RunGenerators(compilation)
.GetRunResult();

var expected = $$"""
namespace GeneratorTests;

public static partial class ServicesExtensions
{
public static partial GeneratorTests.IServices ProcessServices(this GeneratorTests.IServices services)
{
HandleType<GeneratorTests.MyService1>(services);
HandleType<GeneratorTests.MyService2>(services);
return services;
}
}
""";
Assert.Equal(expected, results.GeneratedTrees[1].ToString());
}

private static Compilation CreateCompilation(params string[] source)
{
var path = Path.GetDirectoryName(typeof(object).Assembly.Location)!;
var runtimeAssemblyPath = Path.Combine(path, "System.Runtime.dll");

var runtimeReference = MetadataReference.CreateFromFile(typeof(object).Assembly.Location);

return CSharpCompilation.Create("compilation",
source.Select(s => CSharpSyntaxTree.ParseText(s)),
[
MetadataReference.CreateFromFile(typeof(object).Assembly.Location),
MetadataReference.CreateFromFile(runtimeAssemblyPath),
MetadataReference.CreateFromFile(typeof(IServiceCollection).Assembly.Location),
MetadataReference.CreateFromFile(typeof(External.IExternalService).Assembly.Location),
],
new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary));
}
}
4 changes: 0 additions & 4 deletions ServiceScan.SourceGenerator.Tests/Sources.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,6 @@

public static class Sources
{
public const string Services = """

""";

public static string MethodWithAttribute(string attribute)
{
attribute = attribute.Replace("\n", "\n ");
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text.RegularExpressions;
using Microsoft.CodeAnalysis;
using ServiceScan.SourceGenerator.Model;

namespace ServiceScan.SourceGenerator;

public partial class DependencyInjectionGenerator
{
private static IEnumerable<(INamedTypeSymbol Type, INamedTypeSymbol? MatchedAssignableType)> FilterTypes
(Compilation compilation, AttributeModel attribute, INamedTypeSymbol containingType)
{
var assembly = (attribute.AssemblyOfTypeName is null
? containingType
: compilation.GetTypeByMetadataName(attribute.AssemblyOfTypeName)).ContainingAssembly;

var assignableToType = attribute.AssignableToTypeName is null
? null
: compilation.GetTypeByMetadataName(attribute.AssignableToTypeName);

if (assignableToType != null && attribute.AssignableToGenericArguments != null)
{
var typeArguments = attribute.AssignableToGenericArguments.Value.Select(t => compilation.GetTypeByMetadataName(t)).ToArray();
assignableToType = assignableToType.Construct(typeArguments);
}

foreach (var type in GetTypesFromAssembly(assembly))
{
if (type.IsAbstract || type.IsStatic || !type.CanBeReferencedByName || type.TypeKind != TypeKind.Class)
continue;

if (attribute.TypeNameFilter != null)
{
var regex = $"^({Regex.Escape(attribute.TypeNameFilter).Replace(@"\*", ".*").Replace(",", "|")})$";

if (!Regex.IsMatch(type.ToDisplayString(), regex))
continue;
}

INamedTypeSymbol matchedType = null;
if (assignableToType != null && !IsAssignableTo(type, assignableToType, out matchedType))
continue;

yield return (type, matchedType);
}
}

private static bool IsAssignableTo(INamedTypeSymbol type, INamedTypeSymbol assignableTo, out INamedTypeSymbol matchedType)
{
if (SymbolEqualityComparer.Default.Equals(type, assignableTo))
{
matchedType = type;
return true;
}

if (assignableTo.IsGenericType && assignableTo.IsDefinition)
{
if (assignableTo.TypeKind == TypeKind.Interface)
{
var matchingInterface = type.AllInterfaces.FirstOrDefault(i => i.IsGenericType && SymbolEqualityComparer.Default.Equals(i.OriginalDefinition, assignableTo));
matchedType = matchingInterface;
return matchingInterface != null;
}

var baseType = type.BaseType;
while (baseType != null)
{
if (baseType.IsGenericType && SymbolEqualityComparer.Default.Equals(baseType.OriginalDefinition, assignableTo))
{
matchedType = baseType;
return true;
}

baseType = baseType.BaseType;
}
}
else
{
if (assignableTo.TypeKind == TypeKind.Interface)
{
matchedType = assignableTo;
return type.AllInterfaces.Contains(assignableTo, SymbolEqualityComparer.Default);
}

var baseType = type.BaseType;
while (baseType != null)
{
if (SymbolEqualityComparer.Default.Equals(baseType, assignableTo))
{
matchedType = baseType;
return true;
}

baseType = baseType.BaseType;
}
}

matchedType = null;
return false;
}

private static IEnumerable<INamedTypeSymbol> GetTypesFromAssembly(IAssemblySymbol assemblySymbol)
{
var @namespace = assemblySymbol.GlobalNamespace;
return GetTypesFromNamespace(@namespace);

static IEnumerable<INamedTypeSymbol> GetTypesFromNamespace(INamespaceSymbol namespaceSymbol)
{
foreach (var member in namespaceSymbol.GetMembers())
{
if (member is INamedTypeSymbol namedType)
{
yield return namedType;
}
else if (member is INamespaceSymbol nestedNamespace)
{
foreach (var type in GetTypesFromNamespace(nestedNamespace))
{
yield return type;
}
}
}
}
}
}
Loading
Loading