Skip to content

Commit

Permalink
Fixing compilation errors and adding tests
Browse files Browse the repository at this point in the history
  • Loading branch information
lcavadas committed May 13, 2024
1 parent 71cd297 commit 95890d9
Show file tree
Hide file tree
Showing 7 changed files with 177 additions and 45 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,7 @@
import io.micronaut.data.runtime.operations.internal.OperationContext;
import io.micronaut.data.runtime.operations.internal.SyncCascadeOperations;
import io.micronaut.data.runtime.operations.internal.query.BindableParametersStoredQuery;
import io.micronaut.data.runtime.operations.internal.sql.AbstractSqlRepositoryOperations;
import io.micronaut.data.runtime.operations.internal.sql.SqlJsonColumnMapperProvider;
import io.micronaut.data.runtime.operations.internal.sql.SqlPreparedQuery;
import io.micronaut.data.runtime.operations.internal.sql.SqlStoredQuery;
import io.micronaut.data.runtime.operations.internal.sql.*;
import io.micronaut.data.runtime.support.AbstractConversionContext;
import io.micronaut.json.JsonMapper;
import io.micronaut.transaction.TransactionOperations;
Expand Down Expand Up @@ -198,7 +195,8 @@ public final class DefaultJdbcRepositoryOperations extends AbstractSqlRepository
JdbcSchemaHandler schemaHandler,
@Nullable JsonMapper jsonMapper,
SqlJsonColumnMapperProvider<ResultSet> sqlJsonColumnMapperProvider,
List<SqlExceptionMapper> sqlExceptionMapperList) {
List<SqlExceptionMapper> sqlExceptionMapperList,
List<SqlExecutionObserver> observers) {
super(
dataSourceName,
new ColumnNameResultSetReader(conversionService),
Expand All @@ -210,7 +208,8 @@ public final class DefaultJdbcRepositoryOperations extends AbstractSqlRepository
conversionService,
attributeConverterRegistry,
jsonMapper,
sqlJsonColumnMapperProvider);
sqlJsonColumnMapperProvider,
observers);
this.schemaTenantResolver = schemaTenantResolver;
this.schemaHandler = schemaHandler;
this.connectionOperations = connectionOperations;
Expand Down Expand Up @@ -538,8 +537,8 @@ public Optional<Number> executeUpdate(@NonNull PreparedQuery<?, Number> pq) {
try (PreparedStatement ps = prepareStatement(connection::prepareStatement, preparedQuery, true, false)) {
preparedQuery.bindParameters(new JdbcParameterBinder(connection, ps, preparedQuery));
int result = ps.executeUpdate();
if (QUERY_LOG.isTraceEnabled()) {
QUERY_LOG.trace("Update operation updated {} records", result);
for (SqlExecutionObserver observer : observers) {
observer.updatedRecords(result);
}
if (preparedQuery.isOptimisticLock()) {
checkOptimisticLocking(1, result);
Expand Down Expand Up @@ -847,8 +846,8 @@ public <R> R execute(@NonNull ConnectionCallback<R> callback) {
public <R> R prepareStatement(@NonNull String sql, @NonNull PreparedStatementCallback<R> callback) {
ArgumentUtils.requireNonNull("sql", sql);
ArgumentUtils.requireNonNull("callback", callback);
if (QUERY_LOG.isDebugEnabled()) {
QUERY_LOG.debug("Executing Query: {}", sql);
for (SqlExecutionObserver observer : observers) {
observer.query(sql);
}
ConnectionContext connectionCtx = getConnectionCtx();
try {
Expand Down Expand Up @@ -1169,8 +1168,8 @@ private PreparedStatement prepare(Connection connection, SqlStoredQuery<T, ?> st

@Override
protected void execute() throws SQLException {
if (QUERY_LOG.isDebugEnabled()) {
QUERY_LOG.debug("Executing SQL query: {}", storedQuery.getQuery());
for (SqlExecutionObserver observer : observers) {
observer.query(storedQuery.getQuery());
}
try {
if (storedQuery.getOperationType() == StoredQuery.OperationType.INSERT_RETURNING
Expand Down Expand Up @@ -1292,8 +1291,8 @@ private void setParameters(PreparedStatement stmt, SqlStoredQuery<T, ?> storedQu

@Override
protected void execute() {
if (QUERY_LOG.isDebugEnabled()) {
QUERY_LOG.debug("Executing SQL query: {}", storedQuery.getQuery());
for (SqlExecutionObserver observer : observers) {
observer.query(storedQuery.getQuery());
}
if (storedQuery.getOperationType() == StoredQuery.OperationType.INSERT_RETURNING
|| storedQuery.getOperationType() == StoredQuery.OperationType.UPDATE_RETURNING) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,7 @@
import io.micronaut.data.runtime.operations.internal.OperationContext;
import io.micronaut.data.runtime.operations.internal.ReactiveCascadeOperations;
import io.micronaut.data.runtime.operations.internal.query.BindableParametersStoredQuery;
import io.micronaut.data.runtime.operations.internal.sql.AbstractSqlRepositoryOperations;
import io.micronaut.data.runtime.operations.internal.sql.SqlJsonColumnMapperProvider;
import io.micronaut.data.runtime.operations.internal.sql.SqlPreparedQuery;
import io.micronaut.data.runtime.operations.internal.sql.SqlStoredQuery;
import io.micronaut.data.runtime.operations.internal.sql.*;
import io.micronaut.data.runtime.support.AbstractConversionContext;
import io.micronaut.json.JsonMapper;
import io.micronaut.transaction.exceptions.TransactionSystemException;
Expand Down Expand Up @@ -188,7 +185,8 @@ final class DefaultR2dbcRepositoryOperations extends AbstractSqlRepositoryOperat
SqlJsonColumnMapperProvider<Row> sqlJsonColumnMapperProvider,
List<R2dbcExceptionMapper> r2dbcExceptionMapperList,
@Parameter R2dbcReactorTransactionOperations transactionOperations,
@Parameter ReactorConnectionOperations<Connection> connectionOperations) {
@Parameter ReactorConnectionOperations<Connection> connectionOperations,
List<SqlExecutionObserver> observers) {
super(
dataSourceName,
new ColumnNameR2dbcResultReader(conversionService),
Expand All @@ -200,7 +198,8 @@ final class DefaultR2dbcRepositoryOperations extends AbstractSqlRepositoryOperat
conversionService,
attributeConverterRegistry,
jsonMapper,
sqlJsonColumnMapperProvider);
sqlJsonColumnMapperProvider,
observers);
this.connectionFactory = connectionFactory;
this.ioExecutorService = executorService;
this.schemaTenantResolver = schemaTenantResolver;
Expand Down Expand Up @@ -545,8 +544,8 @@ public Mono<Number> executeUpdate(@NonNull PreparedQuery<?, Number> pq) {
preparedQuery.bindParameters(new R2dbcParameterBinder(connection, statement, preparedQuery));
return executeAndGetRowsUpdatedSingle(statement, dialect)
.flatMap((Number rowsUpdated) -> {
if (QUERY_LOG.isTraceEnabled()) {
QUERY_LOG.trace("Update operation updated {} records", rowsUpdated);
for (SqlExecutionObserver observer : observers) {
observer.updatedRecords(rowsUpdated);
}
if (preparedQuery.isOptimisticLock()) {
checkOptimisticLocking(1, rowsUpdated);
Expand Down Expand Up @@ -950,8 +949,8 @@ private <T> Mono<T> executeAndMapEachRowSingle(Statement statement, Dialect dial

@Override
protected void execute() throws RuntimeException {
if (QUERY_LOG.isDebugEnabled()) {
QUERY_LOG.debug("Executing SQL query: {}", storedQuery.getQuery());
for (SqlExecutionObserver observer : observers) {
observer.query(storedQuery.getQuery());
}
Statement statement = prepare(ctx.connection);
setParameters(statement, storedQuery);
Expand Down Expand Up @@ -1045,8 +1044,8 @@ private void setParameters(Statement stmt, SqlStoredQuery<T, ?> storedQuery) {

@Override
protected void execute() throws RuntimeException {
if (QUERY_LOG.isDebugEnabled()) {
QUERY_LOG.debug("Executing SQL query: {}", storedQuery.getQuery());
for (SqlExecutionObserver observer : observers) {
observer.query(storedQuery.getQuery());
}
Statement statement;
if (hasGeneratedId) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ public abstract class AbstractSqlRepositoryOperations<RS, PS, Exc extends Except
private final Map<QueryKey, SqlStoredQuery> entityInserts = new ConcurrentHashMap<>(10);
private final Map<QueryKey, SqlStoredQuery> entityUpdates = new ConcurrentHashMap<>(10);
private final Map<Association, String> associationInserts = new ConcurrentHashMap<>(10);
private final List<SqlExecutionObserver> listeners;
protected final List<SqlExecutionObserver> observers;

/**
* Default constructor.
Expand Down Expand Up @@ -145,15 +145,15 @@ protected AbstractSqlRepositoryOperations(
AttributeConverterRegistry attributeConverterRegistry,
JsonMapper jsonMapper,
SqlJsonColumnMapperProvider<RS> sqlJsonColumnMapperProvider,
List<SqlExecutionObserver> listeners) {
List<SqlExecutionObserver> observers) {
super(dateTimeProvider, runtimeEntityRegistry, conversionService, attributeConverterRegistry);
this.dataSourceName = dataSourceName;
this.columnNameResultSetReader = columnNameResultSetReader;
this.columnIndexResultSetReader = columnIndexResultSetReader;
this.preparedStatementWriter = preparedStatementWriter;
this.jsonMapper = jsonMapper;
this.sqlJsonColumnMapperProvider = sqlJsonColumnMapperProvider;
this.listeners = listeners;
this.observers = observers;
Collection<BeanDefinition<Object>> beanDefinitions = beanContext
.getBeanDefinitions(Object.class, Qualifiers.byStereotype(Repository.class));
for (BeanDefinition<Object> beanDefinition : beanDefinitions) {
Expand Down Expand Up @@ -203,7 +203,7 @@ protected <T, R> PS prepareStatement(StatementSupplier<PS> statementFunction,
}

String query = sqlPreparedQuery.getQuery();
listeners.forEach(listener -> listener.query(query));
observers.forEach(listener -> listener.query(query));
final PS ps;
try {
ps = statementFunction.create(query);
Expand Down Expand Up @@ -251,7 +251,7 @@ protected void setStatementParameter(PS preparedStatement, int index, DataType d

dataType = dialect.getDataType(dataType);

for (SqlExecutionObserver listener : listeners) {
for (SqlExecutionObserver listener : observers) {
listener.parameter(index, value, dataType);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,11 @@ public void parameter(int index, Object value, DataType dataType) {
QUERY_LOG.trace("Binding parameter at position {} to value {} with data type: {}", index, value, dataType);
}
}

@Override
public void updatedRecords(Number result) {
if (QUERY_LOG.isTraceEnabled()) {
QUERY_LOG.trace("Update operation updated {} records", result);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,6 @@ public interface SqlExecutionObserver {
void query(String query);

void parameter(int index, Object value, DataType datatype);

void updatedRecords(Number result);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package io.micronaut.data.tck

import io.micronaut.data.model.DataType
import io.micronaut.data.runtime.operations.internal.sql.SqlExecutionObserver
import jakarta.inject.Singleton

@Singleton
class TestSqlExecutionObserver implements SqlExecutionObserver {
public List<Invocation> invocations = new ArrayList<>()

@Override
void query(String query) {
invocations.add(new Invocation(query))
}

@Override
void parameter(int index, Object value, DataType datatype) {
invocations.last().parameters[index] = value
}

@Override
void updatedRecords(Number result) {
invocations.last().affected = result
}

void clear() {
invocations.clear()
}

class Invocation {
String query
Map<Integer, Object> parameters = [:]
Number affected

Invocation(String query) {
this.query = query
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import io.micronaut.data.repository.jpa.criteria.DeleteSpecification
import io.micronaut.data.repository.jpa.criteria.PredicateSpecification
import io.micronaut.data.repository.jpa.criteria.QuerySpecification
import io.micronaut.data.repository.jpa.criteria.UpdateSpecification
import io.micronaut.data.tck.TestSqlExecutionObserver
import io.micronaut.data.tck.entities.Author
import io.micronaut.data.tck.entities.AuthorBooksDto
import io.micronaut.data.tck.entities.AuthorDtoWithBookDtos
Expand Down Expand Up @@ -60,27 +61,16 @@ import jakarta.persistence.criteria.CriteriaBuilder
import jakarta.persistence.criteria.CriteriaUpdate
import jakarta.persistence.criteria.Predicate
import jakarta.persistence.criteria.Root
import spock.lang.AutoCleanup
import spock.lang.IgnoreIf
import spock.lang.Shared
import spock.lang.Specification
import spock.lang.Unroll
import spock.lang.*

import java.sql.Connection
import java.time.LocalDate
import java.time.ZoneId
import java.util.concurrent.CountDownLatch
import java.util.concurrent.TimeUnit
import java.util.stream.Collectors

import static io.micronaut.data.tck.repositories.BookSpecifications.hasChapter
import static io.micronaut.data.tck.repositories.BookSpecifications.titleEquals
import static io.micronaut.data.tck.repositories.BookSpecifications.titleEqualsWithJoin
import static io.micronaut.data.tck.repositories.PersonRepository.Specifications.distinct
import static io.micronaut.data.tck.repositories.PersonRepository.Specifications.idsIn
import static io.micronaut.data.tck.repositories.PersonRepository.Specifications.nameEquals
import static io.micronaut.data.tck.repositories.PersonRepository.Specifications.setIncome
import static io.micronaut.data.tck.repositories.PersonRepository.Specifications.setName
import static io.micronaut.data.tck.repositories.BookSpecifications.*
import static io.micronaut.data.tck.repositories.PersonRepository.Specifications.*

abstract class AbstractRepositorySpec extends Specification {

Expand Down Expand Up @@ -118,6 +108,9 @@ abstract class AbstractRepositorySpec extends Specification {
@Shared
Optional<SynchronousTransactionManager<Connection>> transactionManager = context.findBean(SynchronousTransactionManager)

@Shared
TestSqlExecutionObserver observer = context.getBean(TestSqlExecutionObserver)

ApplicationContext getApplicationContext() {
return context
}
Expand Down Expand Up @@ -2802,6 +2795,99 @@ abstract class AbstractRepositorySpec extends Specification {
entityWithIdClass2Repository.deleteAll()
}

void "observer receives inserts"() {
given:
observer.clear()

when:
bookRepository.save(new Book(title: "Anonymous", totalPages: 400))

then:
observer.invocations.size() == 1
observer.invocations.get(0).query =~ /(?i)insert\s+into\s+.*/
observer.invocations.get(0).parameters[3] == "Anonymous"
observer.invocations.get(0).parameters[4] == 400

cleanup:
cleanupData()
}

void "observer receives query"() {
given:
observer.clear()

when:
bookRepository.findById(1)

then:
observer.invocations.size() == 1
observer.invocations.get(0).query =~ /(?i)select\s+.*\s+from\s+.book.\s+.*/
observer.invocations.get(0).parameters == [1: 1]


cleanup:
cleanupData()
}

void "observer receives update"() {
given:
setupBooks()
def book = bookRepository.findAllByTitleStartingWith("Along Came a Spider").first()
def author = authorRepository.searchByName("Stephen King")
observer.clear()

when:
bookRepository.updateAuthor(book.id, author)

then:
observer.invocations.size() == 1
observer.invocations[0].query =~ /(?i)update\s+.book.\s+.*/
observer.invocations[0].parameters[1] == author.id
observer.invocations[0].parameters[3] == book.id
observer.invocations[0].affected == 1

cleanup:
cleanupData()
}

void "observer receives delete"() {
given:
setupBooks()
def book = bookRepository.findAllByTitleStartingWith("Along Came a Spider").first()
observer.clear()

when:
bookRepository.delete(book)

then:
observer.invocations.size() == 1
observer.invocations[0].query =~ /(?i)delete\s+from\s+.book.\s+.*/
observer.invocations[0].parameters == [1: book.id]

cleanup:
cleanupData()
}

void "observer receives @Query"(){
given:
saveSampleBooks()
observer.clear()

when:
def book = bookDtoRepository.findByTitleWithQuery("The Stand")

then:
book.isPresent()
book.get().title == "The Stand"

observer.invocations.size() == 1
observer.invocations[0].query =~ /select \* from book b where b.title = .*/
observer.invocations[0].parameters == [1: "The Stand"]

cleanup:
cleanupData()
}

private GregorianCalendar getYearMonthDay(Date dateCreated) {
def cal = dateCreated.toCalendar()
def localDate = LocalDate.of(cal.get(Calendar.YEAR), cal.get(Calendar.MONTH) + 1, cal.get(Calendar.DAY_OF_MONTH))
Expand Down

0 comments on commit 95890d9

Please sign in to comment.