From 49ecd9b71a8214e1b799f778ff2b4d007308421e Mon Sep 17 00:00:00 2001
From: Daniel Cazzulino <daniel@cazzulino.com>
Date: Sat, 28 Dec 2024 02:34:06 -0300
Subject: [PATCH] Improve record analyzer performance

Switch to symbol action rather than syntax, to speed up analysis.

Partially fixes #60
---
 src/StructId.Analyzer/RecordAnalyzer.cs | 42 +++++++++++++++----------
 1 file changed, 26 insertions(+), 16 deletions(-)

diff --git a/src/StructId.Analyzer/RecordAnalyzer.cs b/src/StructId.Analyzer/RecordAnalyzer.cs
index c6678f6..86bcd51 100644
--- a/src/StructId.Analyzer/RecordAnalyzer.cs
+++ b/src/StructId.Analyzer/RecordAnalyzer.cs
@@ -22,31 +22,42 @@ public override void Initialize(AnalysisContext context)
         if (!Debugger.IsAttached)
             context.EnableConcurrentExecution();
 
-        context.RegisterSyntaxNodeAction(Analyze, SyntaxKind.ClassDeclaration);
-        context.RegisterSyntaxNodeAction(Analyze, SyntaxKind.StructDeclaration);
-        context.RegisterSyntaxNodeAction(Analyze, SyntaxKind.RecordDeclaration);
-        context.RegisterSyntaxNodeAction(Analyze, SyntaxKind.RecordStructDeclaration);
+        context.RegisterCompilationStartAction(start =>
+        {
+            var known = new KnownTypes(start.Compilation);
+            if (known.IStructId is null || known.IStructIdT is null)
+                return;
+
+            start.RegisterSymbolAction(AnalyzeSymbol, SymbolKind.NamedType);
+        });
     }
 
-    static void Analyze(SyntaxNodeAnalysisContext context)
+    static void AnalyzeSymbol(SymbolAnalysisContext context)
     {
+        if (context.Symbol is not INamedTypeSymbol symbol)
+            return;
+
         var known = new KnownTypes(context.Compilation);
 
-        if (context.Node is not TypeDeclarationSyntax typeDeclaration ||
-            known.IStructIdT is not { } structIdTypeOfT ||
-            known.IStructId is not { } structIdType)
+        // We only care about IStructId and IStructId<T>
+        if (!symbol.Is(known.IStructId) && !symbol.Is(known.IStructIdT))
             return;
 
-        var symbol = context.SemanticModel.GetDeclaredSymbol(typeDeclaration);
-        if (symbol is null)
+        // We can only analyze if there's a declaration in source.
+        if (symbol.DeclaringSyntaxReferences.Length == 0 ||
+            symbol.DeclaringSyntaxReferences
+            .Select(x => x.GetSyntax())
+            .OfType<TypeDeclarationSyntax>()
+            .FirstOrDefault() is not { } typeDeclaration)
             return;
 
-        if (!symbol.Is(structIdType) && !symbol.Is(structIdTypeOfT))
-            return;
+        // TODO: report or ignore if more than one declaration?
 
         // If there's only one declaration and it's not partial
-        var report = symbol.DeclaringSyntaxReferences.Length == 1 && !typeDeclaration.Modifiers.Any(SyntaxKind.PartialKeyword);
-        report |= !typeDeclaration.IsKind(SyntaxKind.RecordStructDeclaration) || !symbol.IsReadOnly;
+        var report = symbol.DeclaringSyntaxReferences.Length == 1 &&
+            !typeDeclaration.Modifiers.Any(SyntaxKind.PartialKeyword);
+
+        report |= !symbol.IsRecord || symbol.TypeKind != TypeKind.Struct || !symbol.IsReadOnly;
 
         if (report)
         {
@@ -55,7 +66,7 @@ known.IStructIdT is not { } structIdTypeOfT ||
             else if (typeDeclaration.BaseList?.Types.FirstOrDefault(t => t.Type is IdentifierNameSyntax { Identifier.Text: "IStructId" }) is { } implementation)
                 context.ReportDiagnostic(Diagnostic.Create(MustBeRecordStruct, implementation.GetLocation(), symbol.Name));
             else
-                context.ReportDiagnostic(Diagnostic.Create(MustBeRecordStruct, typeDeclaration.Identifier.GetLocation(), symbol.Name));
+                context.ReportDiagnostic(Diagnostic.Create(MustBeRecordStruct, symbol.Locations.FirstOrDefault(), symbol.Name));
         }
 
         if (typeDeclaration.ParameterList is null)
@@ -72,6 +83,5 @@ known.IStructIdT is not { } structIdTypeOfT ||
         var parameter = typeDeclaration.ParameterList.Parameters[0];
         if (parameter.Identifier.Text != "Value")
             context.ReportDiagnostic(Diagnostic.Create(MustHaveValueConstructor, parameter.Identifier.GetLocation(), symbol.Name));
-
     }
 }