From 5ac6ba3f855cfbf1e96fbf1ea28d4638996c4a7d Mon Sep 17 00:00:00 2001
From: Bjarne Koll <git@lynxplay.dev>
Date: Sat, 14 Dec 2024 18:29:05 +0100
Subject: [PATCH] Inheritance support

---
 .../restamp/recipe/MethodATMutator.java       | 73 ++++++++++++++++---
 .../restamp/RestampFunctionTestHelper.java    | 15 ++--
 .../restamp/at/InheritanceMethodATTest.java   | 65 +++++++++++++++++
 .../function/RestampClassFunctionTest.java    | 12 ++-
 .../function/RestampFieldFunctionTest.java    | 17 +++--
 .../function/RestampMethodFunctionTest.java   | 15 ++--
 6 files changed, 163 insertions(+), 34 deletions(-)
 create mode 100644 src/test/java/io/papermc/restamp/at/InheritanceMethodATTest.java

diff --git a/src/main/java/io/papermc/restamp/recipe/MethodATMutator.java b/src/main/java/io/papermc/restamp/recipe/MethodATMutator.java
index 84c3366..7c17469 100644
--- a/src/main/java/io/papermc/restamp/recipe/MethodATMutator.java
+++ b/src/main/java/io/papermc/restamp/recipe/MethodATMutator.java
@@ -12,12 +12,14 @@
 import org.cadixdev.bombe.type.VoidType;
 import org.cadixdev.bombe.type.signature.MethodSignature;
 import org.jspecify.annotations.NullMarked;
+import org.jspecify.annotations.Nullable;
 import org.openrewrite.ExecutionContext;
 import org.openrewrite.Recipe;
 import org.openrewrite.TreeVisitor;
 import org.openrewrite.java.JavaIsoVisitor;
 import org.openrewrite.java.tree.J;
 import org.openrewrite.java.tree.JavaType;
+import org.openrewrite.java.tree.JavaType.FullyQualified;
 import org.openrewrite.java.tree.TypeTree;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -34,6 +36,7 @@ public class MethodATMutator extends Recipe {
     private static final Logger LOGGER = LoggerFactory.getLogger(MethodATMutator.class);
 
     private final AccessTransformSet atDictionary;
+    private final AccessTransformSet inheritanceAccessTransformAtDirectory;
     private final ModifierTransformer modifierTransformer;
     private final AccessTransformerTypeConverter atTypeConverter;
 
@@ -43,6 +46,12 @@ public MethodATMutator(final AccessTransformSet atDictionary,
         this.atDictionary = atDictionary;
         this.modifierTransformer = modifierTransformer;
         this.atTypeConverter = atTypeConverter;
+
+        // Create a copy of the atDirectory for inherited at lookups.
+        // Needed as the parent type may be processed first, removing its access transformer for tracking purposes.
+        // Child types hence lookup using this.
+        this.inheritanceAccessTransformAtDirectory = AccessTransformSet.create();
+        this.inheritanceAccessTransformAtDirectory.merge(this.atDictionary);
     }
 
     @Override
@@ -67,12 +76,6 @@ public J.MethodDeclaration visitMethodDeclaration(final J.MethodDeclaration unre
                 if (parentClassDeclaration == null || parentClassDeclaration.getType() == null)
                     return methodDeclaration;
 
-                // Find access transformers for class
-                final AccessTransformSet.Class transformerClass = atDictionary.getClass(
-                    parentClassDeclaration.getType().getFullyQualifiedName()
-                ).orElse(null);
-                if (transformerClass == null) return methodDeclaration;
-
                 final String methodIdentifier = parentClassDeclaration.getType().getFullyQualifiedName() + "#" + methodDeclaration.getName();
 
                 if (methodDeclaration.getMethodType() == null) {
@@ -80,7 +83,7 @@ public J.MethodDeclaration visitMethodDeclaration(final J.MethodDeclaration unre
                     return methodDeclaration;
                 }
 
-                // Fetch access transformer to apply to specific field.
+                // Fetch access transformer to apply to specific method.
                 String atMethodName = methodDeclaration.getMethodType().getName();
                 Type returnType = atTypeConverter.convert(methodDeclaration.getMethodType().getReturnType(),
                     () -> "Parsing return type " + methodDeclaration.getReturnTypeExpression().toString() + " of method " + methodIdentifier);
@@ -101,10 +104,14 @@ public J.MethodDeclaration visitMethodDeclaration(final J.MethodDeclaration unre
                     returnType = VoidType.INSTANCE;
                 }
 
-                final AccessTransform accessTransform = transformerClass.replaceMethod(new MethodSignature(
-                    atMethodName, new MethodDescriptor(parameterTypes, returnType)
-                ), AccessTransform.EMPTY);
-                if (accessTransform == null || accessTransform.isEmpty()) return methodDeclaration;
+                // Find access transformers for method
+                final AccessTransform accessTransform = findApplicableAccessTransformer(
+                    parentClassDeclaration.getType(),
+                    atMethodName,
+                    returnType,
+                    parameterTypes
+                );
+                if (accessTransform == null) return methodDeclaration;
 
                 final TypeTree returnTypeExpression = methodDeclaration.getReturnTypeExpression();
                 final ModifierTransformationResult transformationResult = modifierTransformer.transformModifiers(
@@ -125,4 +132,48 @@ atMethodName, new MethodDescriptor(parameterTypes, returnType)
         };
     }
 
+    /**
+     * Finds the applicable access transformer for a method and *optionally* removes it from the atDirectory.
+     *
+     * @param owningType     the owning type of the method, e.g. the type it is defined in.
+     * @param atMethodName   the method name.
+     * @param returnType     the return type.
+     * @param parameterTypes the method parameters.
+     *
+     * @return the access transformer or null.
+     */
+    @Nullable
+    private AccessTransform findApplicableAccessTransformer(
+        final FullyQualified owningType,
+        final String atMethodName,
+        final Type returnType,
+        final List<FieldType> parameterTypes
+    ) {
+        final MethodSignature methodSignature = new MethodSignature(
+            atMethodName,
+            new MethodDescriptor(parameterTypes, returnType)
+        );
+
+        for (FullyQualified currentCheckedType = owningType; currentCheckedType != null; currentCheckedType = currentCheckedType.getSupertype()) {
+            // The class at data from the copy of the at dir.
+            // Removal of these happens later but we need the original state to ensure overrides are updated.
+            final AccessTransformSet.Class transformerClass = inheritanceAccessTransformAtDirectory
+                .getClass(currentCheckedType.getFullyQualifiedName())
+                .orElse(null);
+            if (transformerClass == null) continue;
+
+            // Only get the method here.
+            final AccessTransform accessTransform = transformerClass.getMethod(methodSignature);
+            if (accessTransform == null || accessTransform.isEmpty()) continue;
+
+            // If we *did* find an AT here and this *is* the direct owning type, remove it from the original atDirectory.
+            if (currentCheckedType == owningType) {
+                atDictionary.getClass(transformerClass.getName()).ifPresent(c -> c.replaceMethod(methodSignature, AccessTransform.EMPTY));
+            }
+            return accessTransform;
+        }
+
+        return null; // We did not find anything applicable.
+    }
+
 }
diff --git a/src/test/java/io/papermc/restamp/RestampFunctionTestHelper.java b/src/test/java/io/papermc/restamp/RestampFunctionTestHelper.java
index b0f67ca..dc2d69a 100644
--- a/src/test/java/io/papermc/restamp/RestampFunctionTestHelper.java
+++ b/src/test/java/io/papermc/restamp/RestampFunctionTestHelper.java
@@ -19,6 +19,7 @@
 import org.openrewrite.java.tree.Space;
 import org.openrewrite.marker.Markers;
 
+import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
 import java.util.stream.Stream;
@@ -33,16 +34,16 @@ public class RestampFunctionTestHelper {
      * Constructs a new restamp input object from a single java class' source in a string.
      *
      * @param accessTransformSet the access transformers to apply.
-     * @param javaClassSource    the source code of a java class.
+     * @param javaClassesSource    the source code of a java class.
      *
      * @return the constructed restamp input.
      */
     public static RestampInput inputFromSourceString(final AccessTransformSet accessTransformSet,
-                                                     final String javaClassSource) {
+                                                     final String... javaClassesSource) {
         final Java21Parser javaParser = Java21Parser.builder().build();
         final InMemoryExecutionContext executionContext = new InMemoryExecutionContext(t -> Assertions.fail("Failed to parse inputs", t));
         final List<SourceFile> sourceFiles = javaParser.parseInputs(
-            List.of(Parser.Input.fromString(javaClassSource)),
+            Arrays.stream(javaClassesSource).map(Parser.Input::fromString).toList(),
             null,
             executionContext
         ).toList();
@@ -92,7 +93,7 @@ public static String accessChangeToModifierString(final AccessChange accessChang
         return stringBuilder.toString();
     }
 
-    public record TestCodeStyle(boolean includesLeadingAnnotation) {
+    public record TestCodeStyle(boolean includesLeadingAnnotation, boolean leadingSpace) {
 
     }
 
@@ -147,8 +148,10 @@ private static Object[] concat(final Object[] first, final Object... other) {
         @Override
         public Stream<? extends Arguments> provideArguments(final ExtensionContext context) {
             return CartesianVisibilityArgumentProvider.provideArguments().flatMap(arguments ->
-                Stream.of(true, false).map(includeAnnotation ->
-                    Arguments.arguments(concat(arguments.get(), new TestCodeStyle(includeAnnotation)))
+                Stream.of(true, false).flatMap(includeAnnotation ->
+                    Stream.of(true, false).map(leadingSpace ->
+                        Arguments.arguments(concat(arguments.get(), new TestCodeStyle(includeAnnotation, leadingSpace)))
+                    )
                 )
             );
         }
diff --git a/src/test/java/io/papermc/restamp/at/InheritanceMethodATTest.java b/src/test/java/io/papermc/restamp/at/InheritanceMethodATTest.java
new file mode 100644
index 0000000..0d1a0c9
--- /dev/null
+++ b/src/test/java/io/papermc/restamp/at/InheritanceMethodATTest.java
@@ -0,0 +1,65 @@
+package io.papermc.restamp.at;
+
+import io.papermc.restamp.Restamp;
+import io.papermc.restamp.RestampFunctionTestHelper;
+import io.papermc.restamp.RestampInput;
+import org.cadixdev.at.AccessTransform;
+import org.cadixdev.at.AccessTransformSet;
+import org.cadixdev.bombe.type.signature.MethodSignature;
+import org.jspecify.annotations.NullMarked;
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Test;
+import org.openrewrite.Result;
+
+import java.util.List;
+
+@NullMarked
+public class InheritanceMethodATTest {
+
+    @Test
+    public void testInheritedATs() {
+        final AccessTransformSet accessTransformSet = AccessTransformSet.create();
+        accessTransformSet.getOrCreateClass("io.papermc.test.Test").replaceMethod(
+            MethodSignature.of("test", "(Ljava.lang.Object;)Ljava.lang.String;"), AccessTransform.PUBLIC
+        );
+
+        final RestampInput input = RestampFunctionTestHelper.inputFromSourceString(
+            accessTransformSet,
+            """
+                package io.papermc.test;
+                
+                public class Test {
+                    protected String test(final Object parameter) {
+                        return "hi there";
+                    }
+                }
+                """,
+            """
+                package io.papermc.test;
+                
+                public class SuperTest extends Test {
+                    @Override
+                    protected String test(final Object parameter) {
+                        return "hi there but better";
+                    }
+                }
+                """
+        );
+
+        final List<Result> results = Restamp.run(input).getAllResults();
+        Assertions.assertEquals(
+            """
+                package io.papermc.test;
+                
+                public class SuperTest extends Test {
+                    @Override
+                    public String test(final Object parameter) {
+                        return "hi there but better";
+                    }
+                }
+                """,
+            results.get(1).getAfter().printAll()
+        );
+    }
+
+}
diff --git a/src/test/java/io/papermc/restamp/function/RestampClassFunctionTest.java b/src/test/java/io/papermc/restamp/function/RestampClassFunctionTest.java
index e76c650..7b2cd0b 100644
--- a/src/test/java/io/papermc/restamp/function/RestampClassFunctionTest.java
+++ b/src/test/java/io/papermc/restamp/function/RestampClassFunctionTest.java
@@ -53,18 +53,22 @@ public void testAccessTransformerOnClass(final AccessTransform given,
 
     private String constructClassTest(String modifier, final RestampFunctionTestHelper.TestCodeStyle testCodeStyle) {
         if (!modifier.isEmpty()) modifier = modifier + " ";
-        return """
+        final StringBuilder builder = new StringBuilder();
+        builder.append("""
             package io.papermc.test;
             
             /**
             * With javadocs!
             */
-            """ +
-            (testCodeStyle.includesLeadingAnnotation() ? "@Experimental\n" : "") +
+            """);
+        if (testCodeStyle.includesLeadingAnnotation()) builder.append("@Experimental\n");
+        if (testCodeStyle.leadingSpace()) builder.append(" /* leading space */ ");
+        builder.append(
             """
             %sclass Test {
             
-            }""".formatted(modifier);
+            }""".formatted(modifier));
+        return builder.toString();
     }
 
 }
diff --git a/src/test/java/io/papermc/restamp/function/RestampFieldFunctionTest.java b/src/test/java/io/papermc/restamp/function/RestampFieldFunctionTest.java
index 74bfc3b..af69eef 100644
--- a/src/test/java/io/papermc/restamp/function/RestampFieldFunctionTest.java
+++ b/src/test/java/io/papermc/restamp/function/RestampFieldFunctionTest.java
@@ -53,17 +53,20 @@ public void testAccessTransformerOnField(final AccessTransform given,
 
     private String constructFieldTest(String modifier, final RestampFunctionTestHelper.TestCodeStyle testCodeStyle) {
         if (!modifier.isEmpty()) modifier = modifier + " ";
-        return """
+        final StringBuilder builder = new StringBuilder();
+        builder.append("""
             package io.papermc.test;
             
             public class Test {
                 /* Comment above */
-            """ +
-            (testCodeStyle.includesLeadingAnnotation() ? "    @Experimental\n" : "") +
-            """
-                    %sString passphrase = "Hello World";
-                }
-                """.formatted(modifier);
+            """);
+        if (testCodeStyle.includesLeadingAnnotation()) builder.append("    @Experimental\n");
+        if (testCodeStyle.leadingSpace()) builder.append(" /* leading space */ ");
+        builder.append("""
+                %sString passphrase = "Hello World";
+            }
+            """.formatted(modifier));
+        return builder.toString();
     }
 
 }
diff --git a/src/test/java/io/papermc/restamp/function/RestampMethodFunctionTest.java b/src/test/java/io/papermc/restamp/function/RestampMethodFunctionTest.java
index 41c7153..0549541 100644
--- a/src/test/java/io/papermc/restamp/function/RestampMethodFunctionTest.java
+++ b/src/test/java/io/papermc/restamp/function/RestampMethodFunctionTest.java
@@ -56,7 +56,8 @@ public void testAccessTransformerOnMethod(final AccessTransform given,
 
     private String constructMethodTest(String modifier, final RestampFunctionTestHelper.TestCodeStyle testCodeStyle) {
         if (!modifier.isEmpty()) modifier = modifier + " ";
-        return """
+        final StringBuilder builder = new StringBuilder();
+        builder.append("""
             package io.papermc.test;
             
             public class Test {
@@ -64,14 +65,16 @@ public class Test {
                 /**
                 * Javadocs
                 */
-            """ +
-            (testCodeStyle.includesLeadingAnnotation() ? "    @Experimental\n" : "") +
-            """
-                                %sString /* Comment insiede */ test(Object input) {
+            """);
+        if (testCodeStyle.includesLeadingAnnotation()) builder.append("    @Experimental\n");
+        if (testCodeStyle.leadingSpace()) builder.append(" /* leading space */ ");
+        builder.append("""
+                %sString /* Comment insiede */ test(Object input) {
                     return String.valueOf(input);
                 }
             }
-            """.formatted(modifier);
+            """.formatted(modifier));
+        return builder.toString();
     }
 
 }