Skip to content

Repository inheritance scan #185

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package ru.tinkoff.kora.database.annotation.processor.extension;

import ru.tinkoff.kora.annotation.processor.common.CommonUtils;
import ru.tinkoff.kora.annotation.processor.common.ProcessingErrorException;
import ru.tinkoff.kora.database.annotation.processor.DbUtils;
import ru.tinkoff.kora.kora.app.annotation.processor.extension.ExtensionResult;
import ru.tinkoff.kora.kora.app.annotation.processor.extension.KoraExtension;
Expand All @@ -10,10 +11,12 @@
import javax.lang.model.element.ElementKind;
import javax.lang.model.element.Modifier;
import javax.lang.model.element.TypeElement;
import javax.lang.model.type.DeclaredType;
import javax.lang.model.type.TypeKind;
import javax.lang.model.type.TypeMirror;
import javax.lang.model.util.Elements;
import javax.lang.model.util.Types;
import java.util.Objects;

public class RepositoryKoraExtension implements KoraExtension {
private final Elements elements;
Expand All @@ -33,10 +36,45 @@ public KoraExtensionDependencyGenerator getDependencyGenerator(RoundEnvironment
if (element.getKind() != ElementKind.INTERFACE && (element.getKind() != ElementKind.CLASS || !element.getModifiers().contains(Modifier.ABSTRACT))) {
return null;
}
if (CommonUtils.findDirectAnnotation(element, DbUtils.REPOSITORY_ANNOTATION) == null) {
return null;

final TypeElement typeElement;
if (CommonUtils.findDirectAnnotation(element, DbUtils.REPOSITORY_ANNOTATION) != null) {
typeElement = (TypeElement) this.types.asElement(typeMirror);
} else {
var candidates = roundEnvironment.getRootElements().stream()
.filter(candidate -> candidate instanceof TypeElement)
.map(candidate -> {
if (types.isAssignable(candidate.asType(), typeMirror)
&& !types.isSameType(candidate.asType(), typeMirror)) {
if (CommonUtils.findDirectAnnotation(candidate, DbUtils.REPOSITORY_ANNOTATION) != null) {
return ((TypeElement) candidate);
} else {
return types.directSupertypes(candidate.asType()).stream()
.filter(parentType -> parentType instanceof DeclaredType)
.map(parentType -> ((DeclaredType) parentType))
.filter(parentType -> CommonUtils.findDirectAnnotation(parentType.asElement(), DbUtils.REPOSITORY_ANNOTATION) != null)
.findFirst()
.map(parentType -> ((TypeElement) parentType.asElement()))
.orElse(null);
}
} else {
return null;
}
})
.filter(Objects::nonNull)
.toList();

if (candidates.isEmpty()) {
return null;
}

if (candidates.size() > 1) {
throw new ProcessingErrorException("Found '%s' suitable candidates for: %s".formatted(candidates.size(), element), element);
} else {
typeElement = candidates.get(0);
}
}
var typeElement = (TypeElement) this.types.asElement(typeMirror);

var packageElement = this.elements.getPackageOf(typeElement);
var repositoryName = CommonUtils.getOuterClassesAsPrefix(typeElement) + typeElement.getSimpleName().toString() + "_Impl";
var packageName = packageElement.getQualifiedName();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@ protected TestRepository(Class<?> repositoryClass, Object repositoryObject) {

@SuppressWarnings("unchecked")
public <T> T invoke(String method, Object... args) {
for (var repositoryClassMethod : repositoryClass.getMethods()) {
for (var repositoryClassMethod : repositoryClass.getDeclaredMethods()) {
if (repositoryClassMethod.getName().equals(method) && repositoryClassMethod.getParameters().length == args.length) {
try {
repositoryClassMethod.setAccessible(true);
var result = repositoryClassMethod.invoke(this.repositoryObject, args);
if (result instanceof Mono<?> mono) {
return (T) mono.block();
Expand Down Expand Up @@ -81,4 +82,22 @@ protected TestRepository compile(Object connectionFactory, List<?> arguments, @L
throw new RuntimeException(e);
}
}

protected TestRepository compileForArgs(List<?> arguments, @Language("java") String... sources) {
var compileResult = compile(List.of(new RepositoryAnnotationProcessor()), sources);
if (compileResult.isFailed()) {
throw compileResult.compilationException();
}

Assertions.assertThat(compileResult.warnings()).hasSize(0);

try {
var repositoryClass = compileResult.loadClass("$TestRepository_Impl");
var realArgs = arguments.toArray();
var repository = repositoryClass.getConstructors()[0].newInstance(realArgs);
return new TestRepository(repositoryClass, repository);
} catch (InstantiationException | IllegalAccessException | InvocationTargetException e) {
throw new RuntimeException(e);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import ru.tinkoff.kora.database.common.annotation.processor.entity.TestEntityRecord;
import ru.tinkoff.kora.database.common.annotation.processor.jdbc.repository.AllowedParametersRepository;
import ru.tinkoff.kora.database.jdbc.JdbcConnectionFactory;
import ru.tinkoff.kora.database.jdbc.JdbcDatabaseConfig;
import ru.tinkoff.kora.database.jdbc.mapper.parameter.JdbcParameterColumnMapper;
import ru.tinkoff.kora.json.common.annotation.Json;

Expand Down Expand Up @@ -114,6 +115,32 @@ public interface TestRepository extends JdbcRepository {
verify(executor.preparedStatement).execute();
}

@Test
public void testAbstractClassRepository() throws SQLException {
var config = new JdbcDatabaseConfig("1", "2", "3", "4",
null, null, null, null, null, null, null, null, null);
var repository = compileForArgs(List.of(config, executor), """
import ru.tinkoff.kora.database.jdbc.JdbcDatabaseConfig;
@Repository
public abstract class TestRepository implements JdbcRepository {

private final JdbcDatabaseConfig config;

public TestRepository(JdbcDatabaseConfig config) {
this.config = config;
}

@Query("INSERT INTO test(test) VALUES ('test')")
abstract void testConnection(Connection connection);
}
""");

repository.invoke("testConnection", executor.mockConnection);

verify(executor.mockConnection).prepareStatement("INSERT INTO test(test) VALUES ('test')");
verify(executor.preparedStatement).execute();
}

@Test
public void testNativeParameter() throws SQLException {
var repository = compileJdbc(List.of(), """
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ abstract class AbstractRepositoryTest : AbstractSymbolProcessorTest() {
throw compileResult.compilationException()
}

val repositoryClass = compileResult.loadClass("\$TestRepository_Impl")
val realArgs = arrayOfNulls<Any>(arguments.size + 1)
realArgs[0] = connectionFactory
System.arraycopy(arguments.toTypedArray(), 0, realArgs, 1, arguments.size)
Expand All @@ -65,7 +64,20 @@ abstract class AbstractRepositoryTest : AbstractSymbolProcessorTest() {
realArgs[i] = arg.invoke()
}
}

val repositoryClass = compileResult.loadClass("\$TestRepository_Impl")
val repository = repositoryClass.constructors[0].newInstance(*realArgs)
return TestRepository(repositoryClass.kotlin, repository)
}

protected fun compileForArgs(arguments: Array<Any?>, @Language("kotlin") vararg sources: String): TestRepository {
val compileResult = compile(listOf(RepositorySymbolProcessorProvider()), *sources)
if (compileResult.isFailed()) {
throw compileResult.compilationException()
}

val repositoryClass = compileResult.loadClass("\$TestRepository_Impl")
val repository = repositoryClass.constructors[0].newInstance(*arguments)
return TestRepository(repositoryClass.kotlin, repository)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@ import org.mockito.ArgumentMatchers
import org.mockito.Mockito
import org.mockito.kotlin.verify
import ru.tinkoff.kora.common.Tag
import ru.tinkoff.kora.database.cassandra.CassandraConfig.Advanced.MetricsConfig.Config
import ru.tinkoff.kora.database.jdbc.JdbcDatabaseConfig
import ru.tinkoff.kora.database.jdbc.mapper.parameter.JdbcParameterColumnMapper
import java.time.Duration
import kotlin.reflect.full.findAnnotations
import kotlin.reflect.jvm.jvmErasure

Expand All @@ -29,6 +32,27 @@ class JdbcParametersTest : AbstractJdbcRepositoryTest() {
verify(executor.preparedStatement).updateCount
}

@Test
fun testAbstractClassRepository() {
val config = JdbcDatabaseConfig("1", "2", "3", "4",
null, null, null, null, null, null, null, null, null)
val repository = compileForArgs(
arrayOf(config, executor),
"""
@Repository
abstract class TestRepository(val config: JdbcDatabaseConfig) : JdbcRepository {
@Query("INSERT INTO test(test) VALUES ('test')")
abstract fun test(connection: Connection)
}
""".trimIndent()
)

repository.invoke<Any>("test", executor.mockConnection)

verify(executor.preparedStatement).execute()
verify(executor.preparedStatement).updateCount
}

@Test
fun testNativeParameter() {
val repository = compile(
Expand Down