From 5ac6ba3f855cfbf1e96fbf1ea28d4638996c4a7d Mon Sep 17 00:00:00 2001 From: Bjarne Koll <git@lynxplay.dev> Date: Sat, 14 Dec 2024 18:29:05 +0100 Subject: [PATCH] Inheritance support --- .../restamp/recipe/MethodATMutator.java | 73 ++++++++++++++++--- .../restamp/RestampFunctionTestHelper.java | 15 ++-- .../restamp/at/InheritanceMethodATTest.java | 65 +++++++++++++++++ .../function/RestampClassFunctionTest.java | 12 ++- .../function/RestampFieldFunctionTest.java | 17 +++-- .../function/RestampMethodFunctionTest.java | 15 ++-- 6 files changed, 163 insertions(+), 34 deletions(-) create mode 100644 src/test/java/io/papermc/restamp/at/InheritanceMethodATTest.java diff --git a/src/main/java/io/papermc/restamp/recipe/MethodATMutator.java b/src/main/java/io/papermc/restamp/recipe/MethodATMutator.java index 84c3366..7c17469 100644 --- a/src/main/java/io/papermc/restamp/recipe/MethodATMutator.java +++ b/src/main/java/io/papermc/restamp/recipe/MethodATMutator.java @@ -12,12 +12,14 @@ import org.cadixdev.bombe.type.VoidType; import org.cadixdev.bombe.type.signature.MethodSignature; import org.jspecify.annotations.NullMarked; +import org.jspecify.annotations.Nullable; import org.openrewrite.ExecutionContext; import org.openrewrite.Recipe; import org.openrewrite.TreeVisitor; import org.openrewrite.java.JavaIsoVisitor; import org.openrewrite.java.tree.J; import org.openrewrite.java.tree.JavaType; +import org.openrewrite.java.tree.JavaType.FullyQualified; import org.openrewrite.java.tree.TypeTree; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -34,6 +36,7 @@ public class MethodATMutator extends Recipe { private static final Logger LOGGER = LoggerFactory.getLogger(MethodATMutator.class); private final AccessTransformSet atDictionary; + private final AccessTransformSet inheritanceAccessTransformAtDirectory; private final ModifierTransformer modifierTransformer; private final AccessTransformerTypeConverter atTypeConverter; @@ -43,6 +46,12 @@ public MethodATMutator(final AccessTransformSet atDictionary, this.atDictionary = atDictionary; this.modifierTransformer = modifierTransformer; this.atTypeConverter = atTypeConverter; + + // Create a copy of the atDirectory for inherited at lookups. + // Needed as the parent type may be processed first, removing its access transformer for tracking purposes. + // Child types hence lookup using this. + this.inheritanceAccessTransformAtDirectory = AccessTransformSet.create(); + this.inheritanceAccessTransformAtDirectory.merge(this.atDictionary); } @Override @@ -67,12 +76,6 @@ public J.MethodDeclaration visitMethodDeclaration(final J.MethodDeclaration unre if (parentClassDeclaration == null || parentClassDeclaration.getType() == null) return methodDeclaration; - // Find access transformers for class - final AccessTransformSet.Class transformerClass = atDictionary.getClass( - parentClassDeclaration.getType().getFullyQualifiedName() - ).orElse(null); - if (transformerClass == null) return methodDeclaration; - final String methodIdentifier = parentClassDeclaration.getType().getFullyQualifiedName() + "#" + methodDeclaration.getName(); if (methodDeclaration.getMethodType() == null) { @@ -80,7 +83,7 @@ public J.MethodDeclaration visitMethodDeclaration(final J.MethodDeclaration unre return methodDeclaration; } - // Fetch access transformer to apply to specific field. + // Fetch access transformer to apply to specific method. String atMethodName = methodDeclaration.getMethodType().getName(); Type returnType = atTypeConverter.convert(methodDeclaration.getMethodType().getReturnType(), () -> "Parsing return type " + methodDeclaration.getReturnTypeExpression().toString() + " of method " + methodIdentifier); @@ -101,10 +104,14 @@ public J.MethodDeclaration visitMethodDeclaration(final J.MethodDeclaration unre returnType = VoidType.INSTANCE; } - final AccessTransform accessTransform = transformerClass.replaceMethod(new MethodSignature( - atMethodName, new MethodDescriptor(parameterTypes, returnType) - ), AccessTransform.EMPTY); - if (accessTransform == null || accessTransform.isEmpty()) return methodDeclaration; + // Find access transformers for method + final AccessTransform accessTransform = findApplicableAccessTransformer( + parentClassDeclaration.getType(), + atMethodName, + returnType, + parameterTypes + ); + if (accessTransform == null) return methodDeclaration; final TypeTree returnTypeExpression = methodDeclaration.getReturnTypeExpression(); final ModifierTransformationResult transformationResult = modifierTransformer.transformModifiers( @@ -125,4 +132,48 @@ atMethodName, new MethodDescriptor(parameterTypes, returnType) }; } + /** + * Finds the applicable access transformer for a method and *optionally* removes it from the atDirectory. + * + * @param owningType the owning type of the method, e.g. the type it is defined in. + * @param atMethodName the method name. + * @param returnType the return type. + * @param parameterTypes the method parameters. + * + * @return the access transformer or null. + */ + @Nullable + private AccessTransform findApplicableAccessTransformer( + final FullyQualified owningType, + final String atMethodName, + final Type returnType, + final List<FieldType> parameterTypes + ) { + final MethodSignature methodSignature = new MethodSignature( + atMethodName, + new MethodDescriptor(parameterTypes, returnType) + ); + + for (FullyQualified currentCheckedType = owningType; currentCheckedType != null; currentCheckedType = currentCheckedType.getSupertype()) { + // The class at data from the copy of the at dir. + // Removal of these happens later but we need the original state to ensure overrides are updated. + final AccessTransformSet.Class transformerClass = inheritanceAccessTransformAtDirectory + .getClass(currentCheckedType.getFullyQualifiedName()) + .orElse(null); + if (transformerClass == null) continue; + + // Only get the method here. + final AccessTransform accessTransform = transformerClass.getMethod(methodSignature); + if (accessTransform == null || accessTransform.isEmpty()) continue; + + // If we *did* find an AT here and this *is* the direct owning type, remove it from the original atDirectory. + if (currentCheckedType == owningType) { + atDictionary.getClass(transformerClass.getName()).ifPresent(c -> c.replaceMethod(methodSignature, AccessTransform.EMPTY)); + } + return accessTransform; + } + + return null; // We did not find anything applicable. + } + } diff --git a/src/test/java/io/papermc/restamp/RestampFunctionTestHelper.java b/src/test/java/io/papermc/restamp/RestampFunctionTestHelper.java index b0f67ca..dc2d69a 100644 --- a/src/test/java/io/papermc/restamp/RestampFunctionTestHelper.java +++ b/src/test/java/io/papermc/restamp/RestampFunctionTestHelper.java @@ -19,6 +19,7 @@ import org.openrewrite.java.tree.Space; import org.openrewrite.marker.Markers; +import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.stream.Stream; @@ -33,16 +34,16 @@ public class RestampFunctionTestHelper { * Constructs a new restamp input object from a single java class' source in a string. * * @param accessTransformSet the access transformers to apply. - * @param javaClassSource the source code of a java class. + * @param javaClassesSource the source code of a java class. * * @return the constructed restamp input. */ public static RestampInput inputFromSourceString(final AccessTransformSet accessTransformSet, - final String javaClassSource) { + final String... javaClassesSource) { final Java21Parser javaParser = Java21Parser.builder().build(); final InMemoryExecutionContext executionContext = new InMemoryExecutionContext(t -> Assertions.fail("Failed to parse inputs", t)); final List<SourceFile> sourceFiles = javaParser.parseInputs( - List.of(Parser.Input.fromString(javaClassSource)), + Arrays.stream(javaClassesSource).map(Parser.Input::fromString).toList(), null, executionContext ).toList(); @@ -92,7 +93,7 @@ public static String accessChangeToModifierString(final AccessChange accessChang return stringBuilder.toString(); } - public record TestCodeStyle(boolean includesLeadingAnnotation) { + public record TestCodeStyle(boolean includesLeadingAnnotation, boolean leadingSpace) { } @@ -147,8 +148,10 @@ private static Object[] concat(final Object[] first, final Object... other) { @Override public Stream<? extends Arguments> provideArguments(final ExtensionContext context) { return CartesianVisibilityArgumentProvider.provideArguments().flatMap(arguments -> - Stream.of(true, false).map(includeAnnotation -> - Arguments.arguments(concat(arguments.get(), new TestCodeStyle(includeAnnotation))) + Stream.of(true, false).flatMap(includeAnnotation -> + Stream.of(true, false).map(leadingSpace -> + Arguments.arguments(concat(arguments.get(), new TestCodeStyle(includeAnnotation, leadingSpace))) + ) ) ); } diff --git a/src/test/java/io/papermc/restamp/at/InheritanceMethodATTest.java b/src/test/java/io/papermc/restamp/at/InheritanceMethodATTest.java new file mode 100644 index 0000000..0d1a0c9 --- /dev/null +++ b/src/test/java/io/papermc/restamp/at/InheritanceMethodATTest.java @@ -0,0 +1,65 @@ +package io.papermc.restamp.at; + +import io.papermc.restamp.Restamp; +import io.papermc.restamp.RestampFunctionTestHelper; +import io.papermc.restamp.RestampInput; +import org.cadixdev.at.AccessTransform; +import org.cadixdev.at.AccessTransformSet; +import org.cadixdev.bombe.type.signature.MethodSignature; +import org.jspecify.annotations.NullMarked; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.openrewrite.Result; + +import java.util.List; + +@NullMarked +public class InheritanceMethodATTest { + + @Test + public void testInheritedATs() { + final AccessTransformSet accessTransformSet = AccessTransformSet.create(); + accessTransformSet.getOrCreateClass("io.papermc.test.Test").replaceMethod( + MethodSignature.of("test", "(Ljava.lang.Object;)Ljava.lang.String;"), AccessTransform.PUBLIC + ); + + final RestampInput input = RestampFunctionTestHelper.inputFromSourceString( + accessTransformSet, + """ + package io.papermc.test; + + public class Test { + protected String test(final Object parameter) { + return "hi there"; + } + } + """, + """ + package io.papermc.test; + + public class SuperTest extends Test { + @Override + protected String test(final Object parameter) { + return "hi there but better"; + } + } + """ + ); + + final List<Result> results = Restamp.run(input).getAllResults(); + Assertions.assertEquals( + """ + package io.papermc.test; + + public class SuperTest extends Test { + @Override + public String test(final Object parameter) { + return "hi there but better"; + } + } + """, + results.get(1).getAfter().printAll() + ); + } + +} diff --git a/src/test/java/io/papermc/restamp/function/RestampClassFunctionTest.java b/src/test/java/io/papermc/restamp/function/RestampClassFunctionTest.java index e76c650..7b2cd0b 100644 --- a/src/test/java/io/papermc/restamp/function/RestampClassFunctionTest.java +++ b/src/test/java/io/papermc/restamp/function/RestampClassFunctionTest.java @@ -53,18 +53,22 @@ public void testAccessTransformerOnClass(final AccessTransform given, private String constructClassTest(String modifier, final RestampFunctionTestHelper.TestCodeStyle testCodeStyle) { if (!modifier.isEmpty()) modifier = modifier + " "; - return """ + final StringBuilder builder = new StringBuilder(); + builder.append(""" package io.papermc.test; /** * With javadocs! */ - """ + - (testCodeStyle.includesLeadingAnnotation() ? "@Experimental\n" : "") + + """); + if (testCodeStyle.includesLeadingAnnotation()) builder.append("@Experimental\n"); + if (testCodeStyle.leadingSpace()) builder.append(" /* leading space */ "); + builder.append( """ %sclass Test { - }""".formatted(modifier); + }""".formatted(modifier)); + return builder.toString(); } } diff --git a/src/test/java/io/papermc/restamp/function/RestampFieldFunctionTest.java b/src/test/java/io/papermc/restamp/function/RestampFieldFunctionTest.java index 74bfc3b..af69eef 100644 --- a/src/test/java/io/papermc/restamp/function/RestampFieldFunctionTest.java +++ b/src/test/java/io/papermc/restamp/function/RestampFieldFunctionTest.java @@ -53,17 +53,20 @@ public void testAccessTransformerOnField(final AccessTransform given, private String constructFieldTest(String modifier, final RestampFunctionTestHelper.TestCodeStyle testCodeStyle) { if (!modifier.isEmpty()) modifier = modifier + " "; - return """ + final StringBuilder builder = new StringBuilder(); + builder.append(""" package io.papermc.test; public class Test { /* Comment above */ - """ + - (testCodeStyle.includesLeadingAnnotation() ? " @Experimental\n" : "") + - """ - %sString passphrase = "Hello World"; - } - """.formatted(modifier); + """); + if (testCodeStyle.includesLeadingAnnotation()) builder.append(" @Experimental\n"); + if (testCodeStyle.leadingSpace()) builder.append(" /* leading space */ "); + builder.append(""" + %sString passphrase = "Hello World"; + } + """.formatted(modifier)); + return builder.toString(); } } diff --git a/src/test/java/io/papermc/restamp/function/RestampMethodFunctionTest.java b/src/test/java/io/papermc/restamp/function/RestampMethodFunctionTest.java index 41c7153..0549541 100644 --- a/src/test/java/io/papermc/restamp/function/RestampMethodFunctionTest.java +++ b/src/test/java/io/papermc/restamp/function/RestampMethodFunctionTest.java @@ -56,7 +56,8 @@ public void testAccessTransformerOnMethod(final AccessTransform given, private String constructMethodTest(String modifier, final RestampFunctionTestHelper.TestCodeStyle testCodeStyle) { if (!modifier.isEmpty()) modifier = modifier + " "; - return """ + final StringBuilder builder = new StringBuilder(); + builder.append(""" package io.papermc.test; public class Test { @@ -64,14 +65,16 @@ public class Test { /** * Javadocs */ - """ + - (testCodeStyle.includesLeadingAnnotation() ? " @Experimental\n" : "") + - """ - %sString /* Comment insiede */ test(Object input) { + """); + if (testCodeStyle.includesLeadingAnnotation()) builder.append(" @Experimental\n"); + if (testCodeStyle.leadingSpace()) builder.append(" /* leading space */ "); + builder.append(""" + %sString /* Comment insiede */ test(Object input) { return String.valueOf(input); } } - """.formatted(modifier); + """.formatted(modifier)); + return builder.toString(); } }