Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ final class QueryMacrosParser {
this.types = types;
}

record Field(Element field, String column, String path, boolean isId) {}
record Field(Element field, String column, String path, boolean isId, boolean isEmbedded) {}

record Target(DeclaredType type, String name) {}

Expand All @@ -40,7 +40,8 @@ public String parse(String sqlWithSyntax, DeclaredType repositoryType, Executabl
while (true) {
var cmdIndexStart = sqlWithSyntax.indexOf(MACROS_START, prevCmdIndex);
if (cmdIndexStart == -1) {
return sqlBuilder.append(sqlWithSyntax.substring(prevCmdIndex)).toString();
String sqlResult = sqlWithSyntax.substring(prevCmdIndex);
return sqlBuilder.append(sqlResult).toString();
}

var cmdIndexEnd = sqlWithSyntax.indexOf(MACROS_END, cmdIndexStart);
Expand All @@ -54,9 +55,24 @@ public String parse(String sqlWithSyntax, DeclaredType repositoryType, Executabl
}

private List<Field> getPathField(ExecutableElement method, DeclaredType target, String rootPath, String columnPrefix) {
var treatAsNativeParameterColumn = target.getAnnotationMirrors().stream()
.filter(a -> ClassName.get(a.getAnnotationType()).equals(DbUtils.COLUMN_ANNOTATION))
.findFirst();
if (treatAsNativeParameterColumn.isPresent()) {
Collection<? extends AnnotationValue> values = treatAsNativeParameterColumn.get().getElementValues().values();
if (!values.isEmpty()) {
Object columnTypeUseName = values.iterator().next().getValue();
if (columnTypeUseName != null) {
return List.of(new Field(target.asElement(), columnTypeUseName.toString(), rootPath, false, false));
}
}

throw new ProcessingErrorException("Can't treat argument '" + rootPath + "' as macros native cause failed to extract @Column value: " + target, method);
}

final JdbcNativeType nativeType = JdbcNativeTypes.findNativeType(TypeName.get(target));
if (nativeType != null) {
throw new ProcessingErrorException("Can't process argument '" + rootPath + "' as macros cause it is Native Type: " + target, method);
throw new ProcessingErrorException("Can't process argument '" + rootPath + "' as macros cause it is Native Type without @Column specified: " + target, method);
}

var result = new ArrayList<Field>();
Expand All @@ -70,15 +86,16 @@ private List<Field> getPathField(ExecutableElement method, DeclaredType target,
if (embedded != null) {
if (field.asType() instanceof DeclaredType dt) {
var prefix = Objects.requireNonNullElse(AnnotationUtils.parseAnnotationValueWithoutDefault(embedded, "value"), "");
for (var f : getPathField(method, dt, path, prefix)) {
result.add(new Field(f.field(), f.column(), f.path(), isId));
List<Field> pathFields = getPathField(method, dt, path, prefix);
for (var f : pathFields) {
result.add(new Field(f.field(), f.column(), f.path(), isId, true));
}
} else {
throw new IllegalArgumentException("@Embedded annotation placed on field that can't be embedded: " + target);
}
} else {
var columnName = getColumnName(target, field, columnPrefix);
result.add(new Field(field, columnName, path, isId));
result.add(new Field(field, columnName, path, isId, false));
}
}
return result;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import org.junit.jupiter.api.Test;
import ru.tinkoff.kora.database.jdbc.mapper.parameter.JdbcParameterColumnMapper;
import ru.tinkoff.kora.database.jdbc.mapper.result.JdbcResultColumnMapper;
import ru.tinkoff.kora.database.common.annotation.Column;

import java.sql.PreparedStatement;
import java.sql.ResultSet;
Expand All @@ -14,24 +15,26 @@
import java.time.ZonedDateTime;
import java.util.List;
import java.util.concurrent.Executors;
import java.util.function.Consumer;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

final class JdbcMacrosTest extends AbstractJdbcRepositoryTest {

interface Some extends Consumer<@Column("id") String> {

}

@Test
void returnTable() throws SQLException {
var repository = compileJdbc(List.of(newGeneratedObject("TestRowMapper")), """
@Repository
public interface TestRepository extends JdbcRepository {

@Table("entities")
record Entity(@Id String id, @Column("value1") int field1, String value2, @Nullable String value3) {}

@Query("SELECT * FROM %{return#table} WHERE id = :id")
@Nullable
Entity findById(String id);
Expand All @@ -53,10 +56,10 @@ void returnSelectsAndTable() throws SQLException {
var repository = compileJdbc(List.of(Executors.newCachedThreadPool(), newGeneratedObject("TestRowMapper")), """
@Repository
public interface TestRepository extends JdbcRepository {

@Table("entities")
record Entity(@Id String id, @Column("value1") int field1, String value2, @Nullable String value3) {}

@Query("SELECT %{return#selects} FROM %{return#table} WHERE id = :id")
@Nullable
CompletionStage<Entity> findById(String id);
Expand All @@ -78,7 +81,7 @@ void inserts() throws SQLException {
var repository = compileJdbc(List.of(), """
@Repository
public interface TestRepository extends JdbcRepository {

@Query("INSERT INTO %{entity#inserts}")
UpdateCount insert(Entity entity);
}
Expand All @@ -96,7 +99,7 @@ void insertBatch() throws SQLException {
var repository = compileJdbc(List.of(), """
@Repository
public interface TestRepository extends JdbcRepository {

@Query("INSERT INTO %{entity#inserts}")
UpdateCount insert(@Batch java.util.List<Entity> entity);
}
Expand All @@ -105,7 +108,7 @@ public interface TestRepository extends JdbcRepository {
record Entity(@Id String id, @Column("value1") int field1, String value2, @Nullable String value3) {}
""");

when(executor.preparedStatement.executeLargeBatch()).thenReturn(new long[] {1L});
when(executor.preparedStatement.executeLargeBatch()).thenReturn(new long[]{1L});
repository.invoke("insert", List.of(newGeneratedObject("Entity", "1", 1, "1", "1").get()));
verify(executor.mockConnection).prepareStatement("INSERT INTO entities(id, value1, value2, value3) VALUES (?, ?, ?, ?)");
}
Expand All @@ -115,7 +118,7 @@ void insertsWithoutId() throws SQLException {
var repository = compileJdbc(List.of(), """
@Repository
public interface TestRepository extends JdbcRepository {

@Query("INSERT INTO %{entity#inserts -= @id}")
UpdateCount insert(Entity entity);
}
Expand All @@ -133,7 +136,7 @@ void insertsExtended() throws SQLException {
var repository = compileJdbc(List.of(), """
@Repository
public interface TestRepository extends ParentRepository<Entity> {

}
""", """
public interface ParentRepository<T> extends JdbcRepository {
Expand All @@ -155,7 +158,7 @@ void insertsWithoutField() throws SQLException {
var repository = compileJdbc(List.of(), """
@Repository
public interface TestRepository extends JdbcRepository {

@Query("INSERT INTO %{entity#inserts -= field1}")
UpdateCount insert(Entity entity);
}
Expand All @@ -173,7 +176,7 @@ void upsert() throws SQLException {
var repository = compileJdbc(List.of(), """
@Repository
public interface TestRepository extends JdbcRepository {

@Query("INSERT INTO %{entity#inserts} ON CONFLICT (id) DO UPDATE SET %{entity#updates}")
UpdateCount insert(Entity entity);
}
Expand All @@ -191,7 +194,7 @@ void upsertBatch() throws SQLException {
var repository = compileJdbc(List.of(), """
@Repository
public interface TestRepository extends JdbcRepository {

@Query("INSERT INTO %{entity#inserts} ON CONFLICT (id) DO UPDATE SET %{entity#updates}")
UpdateCount insert(@Batch java.util.List<Entity> entity);
}
Expand All @@ -200,7 +203,7 @@ public interface TestRepository extends JdbcRepository {
record Entity(@Id String id, @Column("value1") int field1, String value2, @Nullable String value3) {}
""");

when(executor.preparedStatement.executeLargeBatch()).thenReturn(new long[] {1L});
when(executor.preparedStatement.executeLargeBatch()).thenReturn(new long[]{1L});
repository.invoke("insert", List.of(newGeneratedObject("Entity", "1", 1, "1", "1").get()));
verify(executor.mockConnection).prepareStatement("INSERT INTO entities(id, value1, value2, value3) VALUES (?, ?, ?, ?) ON CONFLICT (id) DO UPDATE SET value1 = ?, value2 = ?, value3 = ?");
}
Expand All @@ -210,7 +213,7 @@ void entityTableAndUpdate() throws SQLException {
var repository = compileJdbc(List.of(), """
@Repository
public interface TestRepository extends JdbcRepository {

@Query("UPDATE %{entity#table} SET %{entity#updates} WHERE %{entity#where = @id}")
UpdateCount insert(Entity entity);
}
Expand All @@ -223,12 +226,70 @@ record Entity(@Id String id, @Column("value1") int field1, String value2, @Nulla
verify(executor.mockConnection).prepareStatement("UPDATE entities SET value1 = ?, value2 = ?, value3 = ? WHERE id = ?");
}

@Test
void whereTypeUseParameter() throws SQLException {
var repository = compileJdbc(List.of(newGeneratedObject("TestRowMapper")),
"""
public interface AbstractJdbcRepository<K, V> extends JdbcRepository {

@Query("SELECT %{return#selects} FROM %{return#table} WHERE %{keyArg#where}")
@Nullable
V findById(K keyArg);
}
""",
"""
@Repository
public interface TestRepository extends AbstractJdbcRepository<@Column("id") String, TestRepository.Entity> {

@Table("entities")
record Entity(@Id String id, @Column("value1") int field1, String value2, @Nullable String value3) {}
}
""",
"""
public class TestRowMapper implements JdbcResultSetMapper<TestRepository.Entity> {
public TestRepository.Entity apply(ResultSet rs) {
return null;
}
}
""");

repository.invoke("findById", "1");
verify(executor.mockConnection).prepareStatement("SELECT id, value1, value2, value3 FROM entities WHERE id = ?");
}

@Test
void whereMethodArgumentParameter() throws SQLException {
var repository = compileJdbc(List.of(newGeneratedObject("TestRowMapper")),
"""
@Repository
public interface TestRepository extends JdbcRepository {

@Table("entities")
record Entity(@Id String id, @Column("value1") int field1, String value2, @Nullable String value3) {}

@Query("SELECT %{return#selects} FROM %{return#table} WHERE %{keyArg#where}")
@Nullable
Entity findById(@Column("id") String keyArg);
}
""",
"""
public class TestRowMapper implements JdbcResultSetMapper<TestRepository.Entity> {
public TestRepository.Entity apply(ResultSet rs) {
return null;
}
}
""");

repository.invoke("findById", "1");
verify(executor.mockConnection).prepareStatement("SELECT id, value1, value2, value3 FROM entities WHERE id = ?");
}

@Test
void entityTableAndUpdateBatch() throws SQLException {
var repository = compileJdbc(List.of(), """
@Repository
public interface TestRepository extends JdbcRepository {

@Query("UPDATE %{entity#table} SET %{entity#updates} WHERE %{entity#where = @id}")
UpdateCount insert(@Batch java.util.List<Entity> entity);
}
Expand All @@ -237,7 +298,7 @@ public interface TestRepository extends JdbcRepository {
record Entity(@Id String id, @Column("value1") int field1, String value2, @Nullable String value3) {}
""");

when(executor.preparedStatement.executeLargeBatch()).thenReturn(new long[] {1L});
when(executor.preparedStatement.executeLargeBatch()).thenReturn(new long[]{1L});
repository.invoke("insert", List.of(newGeneratedObject("Entity", "1", 1, "1", "1").get()));
verify(executor.mockConnection).prepareStatement("UPDATE entities SET value1 = ?, value2 = ?, value3 = ? WHERE id = ?");
}
Expand All @@ -247,7 +308,7 @@ void entityTableAndUpdateWhereIdIsEmbedded() throws SQLException {
var repository = compileJdbc(List.of(), """
@Repository
public interface TestRepository extends JdbcRepository {

@Query("UPDATE %{entity#table} SET %{entity#updates} WHERE %{entity#where = @id}")
UpdateCount insert(Entity entity);
}
Expand Down Expand Up @@ -417,7 +478,7 @@ void entityTableAndUpdateExclude() throws SQLException {
var repository = compileJdbc(List.of(), """
@Repository
public interface TestRepository extends JdbcRepository {

@Query("UPDATE %{entity#table} SET %{entity#updates -= field1} WHERE %{entity#where = @id}")
UpdateCount insert(Entity entity);
}
Expand All @@ -435,7 +496,7 @@ void entityTableAndUpdateInclude() throws SQLException {
var repository = compileJdbc(List.of(), """
@Repository
public interface TestRepository extends JdbcRepository {

@Query("UPDATE %{entity#table} SET %{entity#updates = field1} WHERE %{entity#where = @id}")
UpdateCount insert(Entity entity);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
* @see Table
* @see Repository
*/
@Target({ElementType.FIELD, ElementType.PARAMETER, ElementType.RECORD_COMPONENT})
@Target({ElementType.FIELD, ElementType.PARAMETER, ElementType.RECORD_COMPONENT, ElementType.TYPE_USE})
@Retention(RetentionPolicy.CLASS)
public @interface Column {

Expand Down
Loading