From c37de929d0a46ea5d5c99481039e523f447cb131 Mon Sep 17 00:00:00 2001
From: Oleksandr Liakhevych <liakh.oleksandr32@gmail.com>
Date: Mon, 23 Dec 2024 17:50:07 +0200
Subject: [PATCH] Add nested types (#17)

---
 .../AddServicesTests.cs                       | 36 +++++++++++++++++++
 ...ependencyInjectionGenerator.FilterTypes.cs | 19 +++++-----
 2 files changed, 46 insertions(+), 9 deletions(-)

diff --git a/ServiceScan.SourceGenerator.Tests/AddServicesTests.cs b/ServiceScan.SourceGenerator.Tests/AddServicesTests.cs
index 3f3c70e..9c05990 100644
--- a/ServiceScan.SourceGenerator.Tests/AddServicesTests.cs
+++ b/ServiceScan.SourceGenerator.Tests/AddServicesTests.cs
@@ -462,6 +462,42 @@ public class MyService: IServiceA, IServiceB {}
         Assert.Equal(Sources.GetMethodImplementation(registrations), results.GeneratedTrees[1].ToString());
     }
 
+    [Fact]
+    public void AddNestedTypes()
+    {
+        var attribute = "[GenerateServiceRegistrations(AssignableTo = typeof(IService))]";
+        var compilation = CreateCompilation(Sources.MethodWithAttribute(attribute),
+            """
+            namespace GeneratorTests;
+
+            public interface IService { }
+            
+            public class ParentType1
+            {
+                public class MyService1 : IService { }
+                public class MyService2 : IService { }
+            }
+            
+            public class ParentType2
+            {
+                public class MyService1 : IService { }
+            }
+            """);
+
+        var results = CSharpGeneratorDriver
+            .Create(_generator)
+            .RunGenerators(compilation)
+            .GetRunResult();
+
+        var registrations = $"""
+            return services
+                .AddTransient<GeneratorTests.IService, GeneratorTests.ParentType1.MyService1>()
+                .AddTransient<GeneratorTests.IService, GeneratorTests.ParentType1.MyService2>()
+                .AddTransient<GeneratorTests.IService, GeneratorTests.ParentType2.MyService1>();
+            """;
+        Assert.Equal(Sources.GetMethodImplementation(registrations), results.GeneratedTrees[1].ToString());
+    }
+
     [Fact]
     public void AddAsKeyedServices_GenericMethod()
     {
diff --git a/ServiceScan.SourceGenerator/DependencyInjectionGenerator.FilterTypes.cs b/ServiceScan.SourceGenerator/DependencyInjectionGenerator.FilterTypes.cs
index 60cc5f8..64d9c84 100644
--- a/ServiceScan.SourceGenerator/DependencyInjectionGenerator.FilterTypes.cs
+++ b/ServiceScan.SourceGenerator/DependencyInjectionGenerator.FilterTypes.cs
@@ -104,19 +104,20 @@ private static bool IsAssignableTo(INamedTypeSymbol type, INamedTypeSymbol assig
     private static IEnumerable<INamedTypeSymbol> GetTypesFromAssembly(IAssemblySymbol assemblySymbol)
     {
         var @namespace = assemblySymbol.GlobalNamespace;
-        return GetTypesFromNamespace(@namespace);
+        return GetTypesFromNamespaceOrType(@namespace);
 
-        static IEnumerable<INamedTypeSymbol> GetTypesFromNamespace(INamespaceSymbol namespaceSymbol)
+        static IEnumerable<INamedTypeSymbol> GetTypesFromNamespaceOrType(INamespaceOrTypeSymbol symbol)
         {
-            foreach (var member in namespaceSymbol.GetMembers())
+            foreach (var member in symbol.GetMembers())
             {
-                if (member is INamedTypeSymbol namedType)
+                if (member is INamespaceOrTypeSymbol namespaceOrType)
                 {
-                    yield return namedType;
-                }
-                else if (member is INamespaceSymbol nestedNamespace)
-                {
-                    foreach (var type in GetTypesFromNamespace(nestedNamespace))
+                    if (member is INamedTypeSymbol namedType)
+                    {
+                        yield return namedType;
+                    }
+
+                    foreach (var type in GetTypesFromNamespaceOrType(namespaceOrType))
                     {
                         yield return type;
                     }