diff --git a/database/database-annotation-processor/src/main/java/ru/tinkoff/kora/database/annotation/processor/QueryMacrosParser.java b/database/database-annotation-processor/src/main/java/ru/tinkoff/kora/database/annotation/processor/QueryMacrosParser.java index 154bd3b6d..77e5bc807 100644 --- a/database/database-annotation-processor/src/main/java/ru/tinkoff/kora/database/annotation/processor/QueryMacrosParser.java +++ b/database/database-annotation-processor/src/main/java/ru/tinkoff/kora/database/annotation/processor/QueryMacrosParser.java @@ -30,7 +30,7 @@ final class QueryMacrosParser { this.types = types; } - record Field(Element field, String column, String path, boolean isId) {} + record Field(Element field, String column, String path, boolean isId, boolean isEmbedded) {} record Target(DeclaredType type, String name) {} @@ -40,7 +40,8 @@ public String parse(String sqlWithSyntax, DeclaredType repositoryType, Executabl while (true) { var cmdIndexStart = sqlWithSyntax.indexOf(MACROS_START, prevCmdIndex); if (cmdIndexStart == -1) { - return sqlBuilder.append(sqlWithSyntax.substring(prevCmdIndex)).toString(); + String sqlResult = sqlWithSyntax.substring(prevCmdIndex); + return sqlBuilder.append(sqlResult).toString(); } var cmdIndexEnd = sqlWithSyntax.indexOf(MACROS_END, cmdIndexStart); @@ -54,9 +55,24 @@ public String parse(String sqlWithSyntax, DeclaredType repositoryType, Executabl } private List getPathField(ExecutableElement method, DeclaredType target, String rootPath, String columnPrefix) { + var treatAsNativeParameterColumn = target.getAnnotationMirrors().stream() + .filter(a -> ClassName.get(a.getAnnotationType()).equals(DbUtils.COLUMN_ANNOTATION)) + .findFirst(); + if (treatAsNativeParameterColumn.isPresent()) { + Collection values = treatAsNativeParameterColumn.get().getElementValues().values(); + if (!values.isEmpty()) { + Object columnTypeUseName = values.iterator().next().getValue(); + if (columnTypeUseName != null) { + return List.of(new Field(target.asElement(), columnTypeUseName.toString(), rootPath, false, false)); + } + } + + throw new ProcessingErrorException("Can't treat argument '" + rootPath + "' as macros native cause failed to extract @Column value: " + target, method); + } + final JdbcNativeType nativeType = JdbcNativeTypes.findNativeType(TypeName.get(target)); if (nativeType != null) { - throw new ProcessingErrorException("Can't process argument '" + rootPath + "' as macros cause it is Native Type: " + target, method); + throw new ProcessingErrorException("Can't process argument '" + rootPath + "' as macros cause it is Native Type without @Column specified: " + target, method); } var result = new ArrayList(); @@ -70,15 +86,16 @@ private List getPathField(ExecutableElement method, DeclaredType target, if (embedded != null) { if (field.asType() instanceof DeclaredType dt) { var prefix = Objects.requireNonNullElse(AnnotationUtils.parseAnnotationValueWithoutDefault(embedded, "value"), ""); - for (var f : getPathField(method, dt, path, prefix)) { - result.add(new Field(f.field(), f.column(), f.path(), isId)); + List pathFields = getPathField(method, dt, path, prefix); + for (var f : pathFields) { + result.add(new Field(f.field(), f.column(), f.path(), isId, true)); } } else { throw new IllegalArgumentException("@Embedded annotation placed on field that can't be embedded: " + target); } } else { var columnName = getColumnName(target, field, columnPrefix); - result.add(new Field(field, columnName, path, isId)); + result.add(new Field(field, columnName, path, isId, false)); } } return result; diff --git a/database/database-annotation-processor/src/test/java/ru/tinkoff/kora/database/common/annotation/processor/jdbc/JdbcMacrosTest.java b/database/database-annotation-processor/src/test/java/ru/tinkoff/kora/database/common/annotation/processor/jdbc/JdbcMacrosTest.java index e047309fa..a27b25b30 100644 --- a/database/database-annotation-processor/src/test/java/ru/tinkoff/kora/database/common/annotation/processor/jdbc/JdbcMacrosTest.java +++ b/database/database-annotation-processor/src/test/java/ru/tinkoff/kora/database/common/annotation/processor/jdbc/JdbcMacrosTest.java @@ -4,6 +4,7 @@ import org.junit.jupiter.api.Test; import ru.tinkoff.kora.database.jdbc.mapper.parameter.JdbcParameterColumnMapper; import ru.tinkoff.kora.database.jdbc.mapper.result.JdbcResultColumnMapper; +import ru.tinkoff.kora.database.common.annotation.Column; import java.sql.PreparedStatement; import java.sql.ResultSet; @@ -14,24 +15,26 @@ import java.time.ZonedDateTime; import java.util.List; import java.util.concurrent.Executors; +import java.util.function.Consumer; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; -import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; final class JdbcMacrosTest extends AbstractJdbcRepositoryTest { + interface Some extends Consumer<@Column("id") String> { + + } + @Test void returnTable() throws SQLException { var repository = compileJdbc(List.of(newGeneratedObject("TestRowMapper")), """ @Repository public interface TestRepository extends JdbcRepository { - + @Table("entities") record Entity(@Id String id, @Column("value1") int field1, String value2, @Nullable String value3) {} - + @Query("SELECT * FROM %{return#table} WHERE id = :id") @Nullable Entity findById(String id); @@ -53,10 +56,10 @@ void returnSelectsAndTable() throws SQLException { var repository = compileJdbc(List.of(Executors.newCachedThreadPool(), newGeneratedObject("TestRowMapper")), """ @Repository public interface TestRepository extends JdbcRepository { - + @Table("entities") record Entity(@Id String id, @Column("value1") int field1, String value2, @Nullable String value3) {} - + @Query("SELECT %{return#selects} FROM %{return#table} WHERE id = :id") @Nullable CompletionStage findById(String id); @@ -78,7 +81,7 @@ void inserts() throws SQLException { var repository = compileJdbc(List.of(), """ @Repository public interface TestRepository extends JdbcRepository { - + @Query("INSERT INTO %{entity#inserts}") UpdateCount insert(Entity entity); } @@ -96,7 +99,7 @@ void insertBatch() throws SQLException { var repository = compileJdbc(List.of(), """ @Repository public interface TestRepository extends JdbcRepository { - + @Query("INSERT INTO %{entity#inserts}") UpdateCount insert(@Batch java.util.List entity); } @@ -105,7 +108,7 @@ public interface TestRepository extends JdbcRepository { record Entity(@Id String id, @Column("value1") int field1, String value2, @Nullable String value3) {} """); - when(executor.preparedStatement.executeLargeBatch()).thenReturn(new long[] {1L}); + when(executor.preparedStatement.executeLargeBatch()).thenReturn(new long[]{1L}); repository.invoke("insert", List.of(newGeneratedObject("Entity", "1", 1, "1", "1").get())); verify(executor.mockConnection).prepareStatement("INSERT INTO entities(id, value1, value2, value3) VALUES (?, ?, ?, ?)"); } @@ -115,7 +118,7 @@ void insertsWithoutId() throws SQLException { var repository = compileJdbc(List.of(), """ @Repository public interface TestRepository extends JdbcRepository { - + @Query("INSERT INTO %{entity#inserts -= @id}") UpdateCount insert(Entity entity); } @@ -133,7 +136,7 @@ void insertsExtended() throws SQLException { var repository = compileJdbc(List.of(), """ @Repository public interface TestRepository extends ParentRepository { - + } """, """ public interface ParentRepository extends JdbcRepository { @@ -155,7 +158,7 @@ void insertsWithoutField() throws SQLException { var repository = compileJdbc(List.of(), """ @Repository public interface TestRepository extends JdbcRepository { - + @Query("INSERT INTO %{entity#inserts -= field1}") UpdateCount insert(Entity entity); } @@ -173,7 +176,7 @@ void upsert() throws SQLException { var repository = compileJdbc(List.of(), """ @Repository public interface TestRepository extends JdbcRepository { - + @Query("INSERT INTO %{entity#inserts} ON CONFLICT (id) DO UPDATE SET %{entity#updates}") UpdateCount insert(Entity entity); } @@ -191,7 +194,7 @@ void upsertBatch() throws SQLException { var repository = compileJdbc(List.of(), """ @Repository public interface TestRepository extends JdbcRepository { - + @Query("INSERT INTO %{entity#inserts} ON CONFLICT (id) DO UPDATE SET %{entity#updates}") UpdateCount insert(@Batch java.util.List entity); } @@ -200,7 +203,7 @@ public interface TestRepository extends JdbcRepository { record Entity(@Id String id, @Column("value1") int field1, String value2, @Nullable String value3) {} """); - when(executor.preparedStatement.executeLargeBatch()).thenReturn(new long[] {1L}); + when(executor.preparedStatement.executeLargeBatch()).thenReturn(new long[]{1L}); repository.invoke("insert", List.of(newGeneratedObject("Entity", "1", 1, "1", "1").get())); verify(executor.mockConnection).prepareStatement("INSERT INTO entities(id, value1, value2, value3) VALUES (?, ?, ?, ?) ON CONFLICT (id) DO UPDATE SET value1 = ?, value2 = ?, value3 = ?"); } @@ -210,7 +213,7 @@ void entityTableAndUpdate() throws SQLException { var repository = compileJdbc(List.of(), """ @Repository public interface TestRepository extends JdbcRepository { - + @Query("UPDATE %{entity#table} SET %{entity#updates} WHERE %{entity#where = @id}") UpdateCount insert(Entity entity); } @@ -223,12 +226,70 @@ record Entity(@Id String id, @Column("value1") int field1, String value2, @Nulla verify(executor.mockConnection).prepareStatement("UPDATE entities SET value1 = ?, value2 = ?, value3 = ? WHERE id = ?"); } + @Test + void whereTypeUseParameter() throws SQLException { + var repository = compileJdbc(List.of(newGeneratedObject("TestRowMapper")), + """ + public interface AbstractJdbcRepository extends JdbcRepository { + + @Query("SELECT %{return#selects} FROM %{return#table} WHERE %{keyArg#where}") + @Nullable + V findById(K keyArg); + } + """, + """ + @Repository + public interface TestRepository extends AbstractJdbcRepository<@Column("id") String, TestRepository.Entity> { + + @Table("entities") + record Entity(@Id String id, @Column("value1") int field1, String value2, @Nullable String value3) {} + } + """, + """ + public class TestRowMapper implements JdbcResultSetMapper { + public TestRepository.Entity apply(ResultSet rs) { + return null; + } + } + """); + + repository.invoke("findById", "1"); + verify(executor.mockConnection).prepareStatement("SELECT id, value1, value2, value3 FROM entities WHERE id = ?"); + } + + @Test + void whereMethodArgumentParameter() throws SQLException { + var repository = compileJdbc(List.of(newGeneratedObject("TestRowMapper")), + """ + @Repository + public interface TestRepository extends JdbcRepository { + + @Table("entities") + record Entity(@Id String id, @Column("value1") int field1, String value2, @Nullable String value3) {} + + @Query("SELECT %{return#selects} FROM %{return#table} WHERE %{keyArg#where}") + @Nullable + Entity findById(@Column("id") String keyArg); + } + """, + """ + public class TestRowMapper implements JdbcResultSetMapper { + public TestRepository.Entity apply(ResultSet rs) { + return null; + } + } + """); + + repository.invoke("findById", "1"); + verify(executor.mockConnection).prepareStatement("SELECT id, value1, value2, value3 FROM entities WHERE id = ?"); + } + @Test void entityTableAndUpdateBatch() throws SQLException { var repository = compileJdbc(List.of(), """ @Repository public interface TestRepository extends JdbcRepository { - + @Query("UPDATE %{entity#table} SET %{entity#updates} WHERE %{entity#where = @id}") UpdateCount insert(@Batch java.util.List entity); } @@ -237,7 +298,7 @@ public interface TestRepository extends JdbcRepository { record Entity(@Id String id, @Column("value1") int field1, String value2, @Nullable String value3) {} """); - when(executor.preparedStatement.executeLargeBatch()).thenReturn(new long[] {1L}); + when(executor.preparedStatement.executeLargeBatch()).thenReturn(new long[]{1L}); repository.invoke("insert", List.of(newGeneratedObject("Entity", "1", 1, "1", "1").get())); verify(executor.mockConnection).prepareStatement("UPDATE entities SET value1 = ?, value2 = ?, value3 = ? WHERE id = ?"); } @@ -247,7 +308,7 @@ void entityTableAndUpdateWhereIdIsEmbedded() throws SQLException { var repository = compileJdbc(List.of(), """ @Repository public interface TestRepository extends JdbcRepository { - + @Query("UPDATE %{entity#table} SET %{entity#updates} WHERE %{entity#where = @id}") UpdateCount insert(Entity entity); } @@ -417,7 +478,7 @@ void entityTableAndUpdateExclude() throws SQLException { var repository = compileJdbc(List.of(), """ @Repository public interface TestRepository extends JdbcRepository { - + @Query("UPDATE %{entity#table} SET %{entity#updates -= field1} WHERE %{entity#where = @id}") UpdateCount insert(Entity entity); } @@ -435,7 +496,7 @@ void entityTableAndUpdateInclude() throws SQLException { var repository = compileJdbc(List.of(), """ @Repository public interface TestRepository extends JdbcRepository { - + @Query("UPDATE %{entity#table} SET %{entity#updates = field1} WHERE %{entity#where = @id}") UpdateCount insert(Entity entity); } diff --git a/database/database-common/src/main/java/ru/tinkoff/kora/database/common/annotation/Column.java b/database/database-common/src/main/java/ru/tinkoff/kora/database/common/annotation/Column.java index bb42ecbea..d2601aab2 100644 --- a/database/database-common/src/main/java/ru/tinkoff/kora/database/common/annotation/Column.java +++ b/database/database-common/src/main/java/ru/tinkoff/kora/database/common/annotation/Column.java @@ -21,7 +21,7 @@ * @see Table * @see Repository */ -@Target({ElementType.FIELD, ElementType.PARAMETER, ElementType.RECORD_COMPONENT}) +@Target({ElementType.FIELD, ElementType.PARAMETER, ElementType.RECORD_COMPONENT, ElementType.TYPE_USE}) @Retention(RetentionPolicy.CLASS) public @interface Column { diff --git a/database/database-symbol-processor/src/main/kotlin/ru/tinkoff/kora/database/symbol/processor/QueryMacrosParser.kt b/database/database-symbol-processor/src/main/kotlin/ru/tinkoff/kora/database/symbol/processor/QueryMacrosParser.kt index a152afb8f..f26423468 100644 --- a/database/database-symbol-processor/src/main/kotlin/ru/tinkoff/kora/database/symbol/processor/QueryMacrosParser.kt +++ b/database/database-symbol-processor/src/main/kotlin/ru/tinkoff/kora/database/symbol/processor/QueryMacrosParser.kt @@ -1,9 +1,6 @@ package ru.tinkoff.kora.database.symbol.processor -import com.google.devtools.ksp.symbol.KSClassDeclaration -import com.google.devtools.ksp.symbol.KSFunctionDeclaration -import com.google.devtools.ksp.symbol.KSPropertyDeclaration -import com.google.devtools.ksp.symbol.KSTypeReference +import com.google.devtools.ksp.symbol.* import com.squareup.kotlinpoet.ksp.toClassName import com.squareup.kotlinpoet.ksp.toTypeName import ru.tinkoff.kora.common.naming.SnakeCaseNameConverter @@ -28,8 +25,9 @@ class QueryMacrosParser { private const val SPECIAL_ID = "@id" } - data class Target(val type: KSClassDeclaration, val name: String) - data class Field(val field: KSPropertyDeclaration, val column: String, val path: String, val isId: Boolean) + data class Target(val typeRef: KSTypeReference, val type: KSClassDeclaration, val annotated: KSAnnotated, val name: String) + + data class Field(val column: String, val path: String, val isId: Boolean, val isEmbedded: Boolean) fun parse(sql: String, method: KSFunctionDeclaration): String { val sqlBuilder = StringBuilder() @@ -47,9 +45,39 @@ class QueryMacrosParser { } } - private fun getPathField(method: KSFunctionDeclaration, target: KSClassDeclaration, rootPath: String, columnPrefix: String): Sequence { + private fun getPathField( + method: KSFunctionDeclaration, + target: KSClassDeclaration, + targetRef: KSTypeReference, + targetAnnotated: KSAnnotated, + rootPath: String, + columnPrefix: String + ): Sequence { + val treatAsNativeParameterColumn = targetRef.annotations + .filter { a -> a.annotationType.resolve().toClassName() == DbUtils.columnAnnotation } + .firstOrNull() + ?: targetAnnotated.annotations + .filter { a -> a.annotationType.resolve().toClassName() == DbUtils.columnAnnotation } + .firstOrNull() + + if (treatAsNativeParameterColumn != null) { + val value = treatAsNativeParameterColumn.arguments[0].value + if (value != null) { + return sequenceOf( + Field( + value.toString(), + rootPath, + isId = false, + isEmbedded = false + ) + ) + } + + throw ProcessingErrorException("Can't treat argument '$rootPath' as macros native cause failed to extract @Column value: $target", method) + } + val nativeType = JdbcNativeTypes.findNativeType(target.toClassName()) - if(nativeType != null) { + if (nativeType != null) { throw ProcessingErrorException("Can't process argument '$rootPath' as macros cause it is Native Type: $target", method) } @@ -71,11 +99,12 @@ class QueryMacrosParser { } val prefix = isEmbedded.findValueNoDefault("value") ?: "" - return@flatMap getPathField(method, declaration, path, prefix) - .map { f -> Field(f.field, f.column, f.path, isId) } + val pathFields = getPathField(method, declaration, field.type, field, path, prefix) + return@flatMap pathFields + .map { f -> Field(f.column, f.path, isId, true) } } else { val columnName = getColumnName(target, field, columnPrefix) - return@flatMap sequenceOf(Field(field, columnName, path, isId)) + return@flatMap sequenceOf(Field(columnName, path, isId, false)) } } } @@ -159,9 +188,9 @@ class QueryMacrosParser { setOf() val fields = if (paths.isEmpty()) { - getPathField(method, target.type, target.name, "").toList() + getPathField(method, target.type, target.typeRef, target.annotated, target.name, "").toList() } else { - getPathField(method, target.type, target.name, "").filter { include == paths.contains(it.path) }.toList() + getPathField(method, target.type, target.typeRef, target.annotated, target.name, "").filter { include == paths.contains(it.path) }.toList() } val nameConverter = target.type.getNameConverter(SnakeCaseNameConverter.INSTANCE) @@ -192,7 +221,7 @@ class QueryMacrosParser { } private fun getTarget(targetName: String, method: KSFunctionDeclaration): Target { - val reference: KSTypeReference + val refPair: Pair if (TARGET_RETURN == targetName) { if (method.isVoid()) { @@ -200,7 +229,7 @@ class QueryMacrosParser { "Macros command specified 'return' target, but return value is type Void", method ) - } else if(method.returnType?.toTypeName() == DbUtils.updateCount) { + } else if (method.returnType?.toTypeName() == DbUtils.updateCount) { throw ProcessingErrorException( "Macros command specified 'return' target, but return value is type UpdateCount", method @@ -208,34 +237,32 @@ class QueryMacrosParser { } val resolved = method.returnType!!.resolve() - reference = if (method.isCompletionStage() || method.isMono() || method.isFlux() || resolved.isCollection()) { - resolved.arguments[0].type!! + refPair = if (method.isCompletionStage() || method.isMono() || method.isFlux() || resolved.isCollection()) { + Pair(resolved.arguments[0].type!!, method) } else { - method.returnType!! + Pair(method.returnType!!, method) } } else { - reference = method.parameters.stream() + refPair = method.parameters .filter { p -> p.name!!.asString().contentEquals(targetName) } - .findFirst() - .map { obj -> - val resolved = obj.type.resolve() + .map { param -> + val resolved = param.type.resolve() if (resolved.isCollection()) { - resolved.arguments[0].type!! + Pair(resolved.arguments[0].type!!, param as KSAnnotated) } else { - obj.type + Pair(param.type, param as KSAnnotated) } } - .orElseThrow { - ProcessingErrorException( - "Macros command unspecified target received: $targetName", - method - ) - } + .firstOrNull() + ?: throw ProcessingErrorException( + "Macros command unspecified target received: $targetName", + method + ) } - val resolved = reference.resolve().declaration + val resolved = refPair.first.resolve().declaration return if (resolved is KSClassDeclaration) { - Target(resolved, targetName) + Target(refPair.first, resolved, refPair.second, targetName) } else { throw ProcessingErrorException( "Macros command unprocessable target type: $targetName", diff --git a/database/database-symbol-processor/src/test/kotlin/ru/tinkoff/kora/database/symbol/processor/jdbc/JdbcMacrosTest.kt b/database/database-symbol-processor/src/test/kotlin/ru/tinkoff/kora/database/symbol/processor/jdbc/JdbcMacrosTest.kt index b474a331c..9d60151e1 100644 --- a/database/database-symbol-processor/src/test/kotlin/ru/tinkoff/kora/database/symbol/processor/jdbc/JdbcMacrosTest.kt +++ b/database/database-symbol-processor/src/test/kotlin/ru/tinkoff/kora/database/symbol/processor/jdbc/JdbcMacrosTest.kt @@ -6,7 +6,6 @@ import ru.tinkoff.kora.database.jdbc.mapper.parameter.JdbcParameterColumnMapper import ru.tinkoff.kora.database.jdbc.mapper.result.JdbcResultColumnMapper import java.sql.PreparedStatement import java.sql.ResultSet -import java.sql.SQLException import java.time.OffsetDateTime import java.util.concurrent.Executors @@ -15,7 +14,8 @@ class JdbcMacrosTest : AbstractJdbcRepositoryTest() { @Test fun returnTable() { val repository = compile( - listOf(newGenerated("TestRowMapper")), """ + listOf(newGenerated("TestRowMapper")), + """ @Repository interface TestRepository : JdbcRepository { @@ -29,14 +29,13 @@ class JdbcMacrosTest : AbstractJdbcRepositoryTest() { @Nullable fun findById(id: String): Entity? } - - """.trimIndent(), """ + """.trimIndent(), + """ class TestRowMapper : JdbcResultSetMapper { override fun apply(rs: ResultSet): TestRepository.Entity? { return null } } - """.trimIndent() ) repository.invoke("findById", "1") @@ -46,7 +45,8 @@ class JdbcMacrosTest : AbstractJdbcRepositoryTest() { @Test fun returnSelectsAndTable() { val repository = compile( - listOf(Executors.newSingleThreadExecutor(), newGenerated("TestRowMapper")), """ + listOf(Executors.newSingleThreadExecutor(), newGenerated("TestRowMapper")), + """ @Repository interface TestRepository : JdbcRepository { @@ -59,14 +59,13 @@ class JdbcMacrosTest : AbstractJdbcRepositoryTest() { @Query("SELECT %{return#selects} FROM %{return#table} WHERE id = :id") suspend fun findById(id: String): Entity? } - - """.trimIndent(), """ + """.trimIndent(), + """ class TestRowMapper : JdbcResultSetMapper { override fun apply(rs: ResultSet): TestRepository.Entity? { return null } } - """.trimIndent() ) repository.invoke("findById", "1") @@ -84,8 +83,8 @@ class JdbcMacrosTest : AbstractJdbcRepositoryTest() { @Query("INSERT INTO %{entity#inserts}") fun insert(entity: Entity): UpdateCount } - - """.trimIndent(), """ + """.trimIndent(), + """ @Table("entities") data class Entity(@field:Id val id: String, @field:Column("value1") val field1: Long, @@ -108,14 +107,13 @@ class JdbcMacrosTest : AbstractJdbcRepositoryTest() { @Query("INSERT INTO %{entity#inserts}") fun insert(@Batch entity: List): UpdateCount } - - """.trimIndent(), """ + """.trimIndent(), + """ @Table("entities") data class Entity(@field:Id val id: String, @field:Column("value1") val field1: Long, val value2: String, val value3: String?) - """.trimIndent() ) @@ -128,21 +126,21 @@ class JdbcMacrosTest : AbstractJdbcRepositoryTest() { @Test fun insertsWithoutId() { val repository = compile( - listOf(), """ + listOf(), + """ @Repository interface TestRepository : JdbcRepository { @Query("INSERT INTO %{entity#inserts -= @id}") fun insert(entity: Entity): UpdateCount } - - """.trimIndent(), """ + """.trimIndent(), + """ @Table("entities") data class Entity(@field:Id val id: String, @field:Column("value1") val field1: Long, val value2: String, val value3: String?) - """.trimIndent() ) repository.invoke("insert", newGenerated("Entity", "1", 1, "1", "1").invoke()) @@ -153,26 +151,24 @@ class JdbcMacrosTest : AbstractJdbcRepositoryTest() { @Test fun insertsExtended() { val repository = compile( - listOf(), """ + listOf(), + """ @Repository - interface TestRepository : ParentRepository { - - } - - """.trimIndent(), """ + interface TestRepository : ParentRepository + """.trimIndent(), + """ interface ParentRepository : JdbcRepository { @Query("INSERT INTO %{entity#inserts -= @id}") fun insert(entity: T): UpdateCount } - - """.trimIndent(), """ + """.trimIndent(), + """ @Table("entities") data class Entity(@field:Id val id: String, @field:Column("value1") val field1: Long, val value2: String, val value3: String?) - """.trimIndent() ) repository.invoke("insert", newGenerated("Entity", "1", 1, "1", "1").invoke()) @@ -183,21 +179,21 @@ class JdbcMacrosTest : AbstractJdbcRepositoryTest() { @Test fun insertsWithoutField() { val repository = compile( - listOf(), """ + listOf(), + """ @Repository interface TestRepository : JdbcRepository { @Query("INSERT INTO %{entity#inserts -= field1}") fun insert(entity: Entity): UpdateCount } - - """.trimIndent(), """ + """.trimIndent(), + """ @Table("entities") data class Entity(@field:Id val id: String, @field:Column("value1") val field1: Long, val value2: String, val value3: String?) - """.trimIndent() ) repository.invoke("insert", newGenerated("Entity", "1", 1, "1", "1").invoke()) @@ -208,21 +204,21 @@ class JdbcMacrosTest : AbstractJdbcRepositoryTest() { @Test fun upsert() { val repository = compile( - listOf(), """ + listOf(), + """ @Repository interface TestRepository : JdbcRepository { @Query("INSERT INTO %{entity#inserts} ON CONFLICT (id) DO UPDATE SET %{entity#updates}") fun upsert(entity: Entity): UpdateCount } - - """.trimIndent(), """ + """.trimIndent(), + """ @Table("entities") data class Entity(@field:Id val id: String, @field:Column("value1") val field1: Long, val value2: String, val value3: String?) - """.trimIndent() ) repository.invoke("upsert", newGenerated("Entity", "1", 1, "1", "1").invoke()) @@ -233,21 +229,21 @@ class JdbcMacrosTest : AbstractJdbcRepositoryTest() { @Test fun upsertBatch() { val repository = compile( - listOf(), """ + listOf(), + """ @Repository interface TestRepository : JdbcRepository { @Query("INSERT INTO %{entity#inserts} ON CONFLICT (id) DO UPDATE SET %{entity#updates}") fun upsert(@Batch entity: List): UpdateCount } - - """.trimIndent(), """ + """.trimIndent(), + """ @Table("entities") data class Entity(@field:Id val id: String, @field:Column("value1") val field1: Long, val value2: String, val value3: String?) - """.trimIndent() ) @@ -260,15 +256,16 @@ class JdbcMacrosTest : AbstractJdbcRepositoryTest() { @Test fun entityTableAndUpdate() { val repository = compile( - listOf(), """ + listOf(), + """ @Repository interface TestRepository : JdbcRepository { @Query("UPDATE %{entity#table} SET %{entity#updates} WHERE %{entity#where = @id}") fun insert(entity: Entity): UpdateCount } - - """.trimIndent(), """ + """.trimIndent(), + """ @Table("entities") data class Entity(@field:Id val id: String, @field:Column("value1") val field1: Long, @@ -282,6 +279,75 @@ class JdbcMacrosTest : AbstractJdbcRepositoryTest() { .prepareStatement("UPDATE entities SET value1 = ?, value2 = ?, value3 = ? WHERE id = ?") } + @Test + fun whereTypeUseParameter() { + val repository = compile( + listOf(newGenerated("TestRowMapper")), + """ + interface AbstractJdbcRepository : JdbcRepository { + + @Query("SELECT %{return#selects} FROM %{return#table} WHERE %{keyArg#where}") + fun findById(keyArg: K): V? + } + """.trimIndent(), + """ + @Repository + interface TestRepository : AbstractJdbcRepository<@Column("id") String, TestRepository.Entity> { + + @Table("entities") + data class Entity(@field:Id val id: String, + @field:Column("value1") val field1: Long, + val value2: String, + val value3: String?) + } + """.trimIndent(), + """ + class TestRowMapper : JdbcResultSetMapper { + override fun apply(rs: ResultSet): TestRepository.Entity? { + return null + } + } + """.trimIndent() + ) + + repository.invoke("findById", "1") + Mockito.verify(executor.mockConnection) + .prepareStatement("SELECT id, value1, value2, value3 FROM entities WHERE id = ?") + } + + @Test + fun whereMethodArgumentParameter() { + val repository = compile( + listOf(newGenerated("TestRowMapper")), + """ + @Repository + interface TestRepository : JdbcRepository { + + @Table("entities") + data class Entity(@field:Id val id: String, + @field:Column("value1") val field1: Long, + val value2: String, + val value3: String?) + + @Query("SELECT %{return#selects} FROM %{return#table} WHERE %{keyArg#where}") + @Nullable + fun findById(@Column("id") keyArg: String): Entity? + } + """.trimIndent(), + """ + class TestRowMapper : JdbcResultSetMapper { + override fun apply(rs: ResultSet): TestRepository.Entity? { + return null + } + } + """.trimIndent() + ) + + repository.invoke("findById", "1") + Mockito.verify(executor.mockConnection) + .prepareStatement("SELECT id, value1, value2, value3 FROM entities WHERE id = ?") + } + @Test fun entityTableAndUpdateBatch() { val repository = compile( @@ -292,14 +358,13 @@ class JdbcMacrosTest : AbstractJdbcRepositoryTest() { @Query("UPDATE %{entity#table} SET %{entity#updates} WHERE %{entity#where = @id}") fun insert(@Batch entity: List): UpdateCount } - - """.trimIndent(), """ + """.trimIndent(), + """ @Table("entities") data class Entity(@field:Id val id: String, @field:Column("value1") val field1: Long, val value2: String, val value3: String?) - """.trimIndent() )