Skip to content

Commit

Permalink
refactor to use only one path for collecting types from assembly or s…
Browse files Browse the repository at this point in the history
…ource
  • Loading branch information
dferretti committed Jun 20, 2024
1 parent 2688fb5 commit 193254b
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 166 deletions.
64 changes: 5 additions & 59 deletions ServiceScan.SourceGenerator/DependencyInjectionGenerator.cs
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
using System;
using System.Collections.Generic;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using System.Text;
using System.Text.RegularExpressions;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Text;
Expand Down Expand Up @@ -87,61 +85,6 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
});
}

private static List<ServiceRegistrationModel> GetRegistrationsFromSourceCode(AttributeModel attribute, ImmutableArray<TypeModel> sourceTypes)
{
var assignableToType = attribute.AssignableToType;

var types = sourceTypes
.GroupBy(t => t.DisplayString).Select(g => g.First()) // distinct-by fully qualified name to account for partial classes
.Where(t => !t.IsAbstract && !t.IsStatic && t.CanBeReferencedByName && t.TypeKind == TypeKind.Class);

if (attribute.TypeNameFilter != null)
{
var regex = $"^({Regex.Escape(attribute.TypeNameFilter).Replace(@"\*", ".*").Replace(",", "|")})$";
types = types.Where(t => Regex.IsMatch(t.DisplayString, regex));
}

var registrations = new List<ServiceRegistrationModel>();
foreach (var t in types)
{
var implementationType = t;

TypeModel? matchedType = null;
if (assignableToType != null && !SymbolExtensions.IsAssignableTo(implementationType, assignableToType, out matchedType))
continue;

IEnumerable<TypeModel> serviceTypes = (attribute.AsSelf, attribute.AsImplementedInterfaces) switch
{
(true, true) => new[] { implementationType }.Concat(implementationType.AllInterfaces),
(false, true) => implementationType.AllInterfaces,
(true, false) => [implementationType],
_ => [matchedType ?? implementationType]
};

foreach (var serviceType in serviceTypes)
{
if (implementationType.IsGenericType)
{
var implementationTypeName = implementationType.UnboundGenericDisplayString;
var serviceTypeName = serviceType.IsGenericType
? serviceType.UnboundGenericDisplayString
: serviceType.DisplayString;

var registration = new ServiceRegistrationModel(attribute.Lifetime, serviceTypeName, implementationTypeName, false, true);
registrations.Add(registration);
}
else
{
var shouldResolve = attribute.AsSelf && attribute.AsImplementedInterfaces && implementationType != serviceType;
var registration = new ServiceRegistrationModel(attribute.Lifetime, serviceType.DisplayString, implementationType.DisplayString, shouldResolve, false);
registrations.Add(registration);
}
}
}

return registrations;
}

private static DiagnosticModel<MethodImplementationModel> FindServicesToRegister((DiagnosticModel<MethodWithAttributesModel>, ImmutableArray<TypeModel>) context)
{
var (diagnosticModel, sourceTypes) = context;
Expand All @@ -152,13 +95,16 @@ private static DiagnosticModel<MethodImplementationModel> FindServicesToRegister

var (method, attributes) = diagnosticModel.Model;

var types = sourceTypes
.GroupBy(t => t.DisplayString).Select(g => g.First()); // distinct-by fully qualified name to account for partial classes

var registrations = new List<ServiceRegistrationModel>();

foreach (var attribute in attributes)
{
// get registrations from the assembly specified in the attribute or from source code
var regs = attribute.RegistrationsFromAssembly?.ToList()
?? GetRegistrationsFromSourceCode(attribute, sourceTypes);
?? SymbolExtensions.GetRegistrations(types, attribute.AssignableToType, attribute.TypeNameFilter, attribute.AsSelf, attribute.AsImplementedInterfaces, attribute.Lifetime);

if (!regs.Any())
diagnostic ??= Diagnostic.Create(NoMatchingTypesFound, attribute.Location);
Expand Down
74 changes: 11 additions & 63 deletions ServiceScan.SourceGenerator/Model/AttributeModel.cs
Original file line number Diff line number Diff line change
@@ -1,21 +1,17 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text.RegularExpressions;
using System.Linq;
using Microsoft.CodeAnalysis;

namespace ServiceScan.SourceGenerator.Model;

record AttributeModel(
EquatableArray<string>? AssignableToGenericArguments,
TypeModel? AssignableToType,
EquatableArray<ServiceRegistrationModel>? RegistrationsFromAssembly, // if null, use types found from source code
string Lifetime,
string? TypeNameFilter,
bool AsImplementedInterfaces,
bool AsSelf,
Location Location,
bool HasErrors,
TypeModel? AssignableToType)
bool HasErrors)
{
public bool HasSearchCriteria => TypeNameFilter != null || AssignableToType != null;

Expand All @@ -26,10 +22,6 @@ public static AttributeModel Create(AttributeData attribute, Compilation compila
var asImplementedInterfaces = attribute.NamedArguments.FirstOrDefault(a => a.Key == "AsImplementedInterfaces").Value.Value is true;
var asSelf = attribute.NamedArguments.FirstOrDefault(a => a.Key == "AsSelf").Value.Value is true;
var typeNameFilter = attribute.NamedArguments.FirstOrDefault(a => a.Key == "TypeNameFilter").Value.Value as string;

EquatableArray<string>? assignableToGenericArguments = assignableTo != null && assignableTo.IsGenericType && !assignableTo.IsUnboundGenericType
? new EquatableArray<string>([.. assignableTo?.TypeArguments.Select(t => t.ToFullMetadataName())])
: null;

if (string.IsNullOrWhiteSpace(typeNameFilter))
typeNameFilter = null;
Expand All @@ -41,17 +33,16 @@ public static AttributeModel Create(AttributeData attribute, Compilation compila
_ => "Transient"
};

var registrations = GetRegistrationsFromAssembly(assemblyType, compilation, typeNameFilter, asSelf, asImplementedInterfaces, assignableTo, assignableToGenericArguments, lifetime);

var typeModel = assignableTo is not null ? TypeModel.Create(assignableTo) : null;
var assignableToTypeModel = assignableTo is not null ? TypeModel.Create(assignableTo) : null;
var registrations = GetRegistrationsFromAssembly(assemblyType, compilation, typeNameFilter, asSelf, asImplementedInterfaces, assignableToTypeModel, lifetime);

var syntax = attribute.ApplicationSyntaxReference.SyntaxTree;
var textSpan = attribute.ApplicationSyntaxReference.Span;
var location = Location.Create(syntax, textSpan);

var hasError = assemblyType is { TypeKind: TypeKind.Error } || assignableTo is { TypeKind: TypeKind.Error };

return new(assignableToGenericArguments, registrations, lifetime, typeNameFilter, asImplementedInterfaces, asSelf, location, hasError, typeModel);
return new(assignableToTypeModel, registrations, lifetime, typeNameFilter, asImplementedInterfaces, asSelf, location, hasError);
}

private static EquatableArray<ServiceRegistrationModel>? GetRegistrationsFromAssembly(
Expand All @@ -60,8 +51,7 @@ public static AttributeModel Create(AttributeData attribute, Compilation compila
string? typeNameFilter,
bool asSelf,
bool asImplementedInterfaces,
INamedTypeSymbol? assignableToType,
EquatableArray<string>? assignableToGenericArguments,
TypeModel? assignableToType,
string lifetime)
{
if (fromAssemblyOf is null)
Expand All @@ -71,53 +61,11 @@ public static AttributeModel Create(AttributeData attribute, Compilation compila
if (SymbolEqualityComparer.Default.Equals(fromAssemblyOf.ContainingAssembly, compilation.Assembly))
return null;

var registrations = new List<ServiceRegistrationModel>();

var types = fromAssemblyOf.ContainingAssembly.GetTypesFromAssembly()
.Where(t => !t.IsAbstract && !t.IsStatic && t.CanBeReferencedByName && t.TypeKind == TypeKind.Class);

if (typeNameFilter != null)
{
var regex = $"^({Regex.Escape(typeNameFilter).Replace(@"\*", ".*").Replace(",", "|")})$";
types = types.Where(t => Regex.IsMatch(t.ToDisplayString(), regex));
}

foreach (var t in types)
{
var implementationType = t;

INamedTypeSymbol matchedType = null;
if (assignableToType != null && !SymbolExtensions.IsAssignableTo(implementationType, assignableToType, out matchedType))
continue;

IEnumerable<INamedTypeSymbol> serviceTypes = (asSelf, asImplementedInterfaces) switch
{
(true, true) => new[] { implementationType }.Concat(implementationType.AllInterfaces),
(false, true) => implementationType.AllInterfaces,
(true, false) => [implementationType],
_ => [matchedType ?? implementationType]
};

foreach (var serviceType in serviceTypes)
{
if (implementationType.IsGenericType)
{
var implementationTypeName = implementationType.ConstructUnboundGenericType().ToDisplayString();
var serviceTypeName = serviceType.IsGenericType
? serviceType.ConstructUnboundGenericType().ToDisplayString()
: serviceType.ToDisplayString();
var types = fromAssemblyOf.ContainingAssembly
.GetTypesFromAssembly()
.Select(TypeModel.Create);

var registration = new ServiceRegistrationModel(lifetime, serviceTypeName, implementationTypeName, false, true);
registrations.Add(registration);
}
else
{
var shouldResolve = asSelf && asImplementedInterfaces && !SymbolEqualityComparer.Default.Equals(implementationType, serviceType);
var registration = new ServiceRegistrationModel(lifetime, serviceType.ToDisplayString(), implementationType.ToDisplayString(), shouldResolve, false);
registrations.Add(registration);
}
}
}
var registrations = SymbolExtensions.GetRegistrations(types, assignableToType, typeNameFilter, asSelf, asImplementedInterfaces, lifetime);

return new(registrations.ToArray());
}
Expand Down
11 changes: 4 additions & 7 deletions ServiceScan.SourceGenerator/Model/TypeModel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,10 @@ sealed record TypeModel

public required bool IsUnboundGenericType { get; init; }

//public required bool IsDefinition { get; init; }

public required EquatableArray<TypeModel> AllInterfaces { get; init; }

// OriginalDefinition is null for unbound generic types. For bound generic types, OriginalDefinition is the TypeModel representing the unbound generic type.
// OriginalDefinition is null for unbound generic types (to avoid recursion). For bound generic types, OriginalDefinition is the TypeModel representing the unbound generic type.
// For non-generic types, OriginalDefinition is null.
public required TypeModel? OriginalDefinition { get; init; }

public required TypeModel? BaseType { get; init; }
Expand All @@ -37,22 +36,20 @@ sealed record TypeModel
if (node is not TypeDeclarationSyntax typeDeclaration)
return null;

var symbol = semanticModel.GetDeclaredSymbol(typeDeclaration, cancel) as INamedTypeSymbol;
if (symbol is null)
if (semanticModel.GetDeclaredSymbol(typeDeclaration, cancel) is not INamedTypeSymbol symbol)
return null;

return Create(symbol);
}

public static TypeModel? Create(INamedTypeSymbol symbol)
public static TypeModel Create(INamedTypeSymbol symbol)
{
return new TypeModel
{
IsAbstract = symbol.IsAbstract,
IsStatic = symbol.IsStatic,
IsGenericType = symbol.IsGenericType,
IsUnboundGenericType = symbol.IsUnboundGenericType,
//IsDefinition = symbol.IsDefinition,
CanBeReferencedByName = symbol.CanBeReferencedByName,
TypeKind = symbol.TypeKind,
DisplayString = symbol.ToDisplayString(),
Expand Down
79 changes: 42 additions & 37 deletions ServiceScan.SourceGenerator/SymbolExtensions.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using System.Collections.Generic;
using System.Linq;
using System.Text.RegularExpressions;
using Microsoft.CodeAnalysis;
using ServiceScan.SourceGenerator.Model;

Expand Down Expand Up @@ -38,58 +39,62 @@ static IEnumerable<INamedTypeSymbol> GetTypesFromNamespace(INamespaceSymbol name
}
}

public static bool IsAssignableTo(INamedTypeSymbol type, INamedTypeSymbol assignableTo, out INamedTypeSymbol matchedType)
public static List<ServiceRegistrationModel> GetRegistrations(
IEnumerable<TypeModel> types,
TypeModel? assignableToType,
string? typeNameFilter,
bool asSelf,
bool asImplementedInterfaces,
string lifetime)
{
if (SymbolEqualityComparer.Default.Equals(type, assignableTo))
var registrations = new List<ServiceRegistrationModel>();

types = types.Where(t => !t.IsAbstract && !t.IsStatic && t.CanBeReferencedByName && t.TypeKind == TypeKind.Class);

if (typeNameFilter != null)
{
matchedType = type;
return true;
var regex = $"^({Regex.Escape(typeNameFilter).Replace(@"\*", ".*").Replace(",", "|")})$";
types = types.Where(t => Regex.IsMatch(t.DisplayString, regex));
}

if (assignableTo.IsGenericType && assignableTo.IsDefinition)
foreach (var t in types)
{
if (assignableTo.TypeKind == TypeKind.Interface)
{
var matchingInterface = type.AllInterfaces.FirstOrDefault(i => i.IsGenericType && SymbolEqualityComparer.Default.Equals(i.OriginalDefinition, assignableTo));
matchedType = matchingInterface;
return matchingInterface != null;
}
var implementationType = t;

var baseType = type.BaseType;
while (baseType != null)
{
if (baseType.IsGenericType && SymbolEqualityComparer.Default.Equals(baseType.OriginalDefinition, assignableTo))
{
matchedType = baseType;
return true;
}
TypeModel? matchedType = null;
if (assignableToType != null && !IsAssignableTo(implementationType, assignableToType, out matchedType))
continue;

baseType = baseType.BaseType;
}
}
else
{
if (assignableTo.TypeKind == TypeKind.Interface)
IEnumerable<TypeModel> serviceTypes = (asSelf, asImplementedInterfaces) switch
{
matchedType = assignableTo;
return type.AllInterfaces.Contains(assignableTo, SymbolEqualityComparer.Default);
}
(true, true) => new[] { implementationType }.Concat(implementationType.AllInterfaces),
(false, true) => implementationType.AllInterfaces,
(true, false) => [implementationType],
_ => [matchedType ?? implementationType]
};

var baseType = type.BaseType;
while (baseType != null)
foreach (var serviceType in serviceTypes)
{
if (SymbolEqualityComparer.Default.Equals(baseType, assignableTo))
if (implementationType.IsGenericType)
{
matchedType = baseType;
return true;
}
var implementationTypeName = implementationType.UnboundGenericDisplayString;
var serviceTypeName = serviceType.IsGenericType
? serviceType.UnboundGenericDisplayString
: serviceType.DisplayString;

baseType = baseType.BaseType;
var registration = new ServiceRegistrationModel(lifetime, serviceTypeName, implementationTypeName, false, true);
registrations.Add(registration);
}
else
{
var shouldResolve = asSelf && asImplementedInterfaces && implementationType != serviceType;
var registration = new ServiceRegistrationModel(lifetime, serviceType.DisplayString, implementationType.DisplayString, shouldResolve, false);
registrations.Add(registration);
}
}
}

matchedType = null;
return false;
return registrations;
}

public static bool IsAssignableTo(TypeModel type, TypeModel assignableTo, out TypeModel matchedType)
Expand Down

0 comments on commit 193254b

Please sign in to comment.