Skip to content

Commit

Permalink
Batch DML support (#107)
Browse files Browse the repository at this point in the history
* batch dml support; fixes #92
  • Loading branch information
dmitry-s authored Jun 11, 2019
1 parent b320ec4 commit 3aa0278
Show file tree
Hide file tree
Showing 5 changed files with 133 additions and 53 deletions.
42 changes: 23 additions & 19 deletions src/main/java/com/google/cloud/spanner/r2dbc/SpannerStatement.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import com.google.cloud.spanner.r2dbc.statement.TypedNull;
import com.google.cloud.spanner.r2dbc.util.Assert;
import com.google.protobuf.Struct;
import com.google.spanner.v1.ExecuteBatchDmlResponse;
import com.google.spanner.v1.PartialResultSet;
import com.google.spanner.v1.Session;
import io.r2dbc.spi.Result;
Expand Down Expand Up @@ -52,6 +53,8 @@ public class SpannerStatement implements Statement {

private StatementBindings statementBindings;

private StatementType statementType;

/**
* Creates a Spanner statement for a given SQL statement.
*
Expand All @@ -73,6 +76,7 @@ public SpannerStatement(
this.transaction = transaction;
this.sql = Assert.requireNonNull(sql, "SQL string can not be null");
this.statementBindings = new StatementBindings();
this.statementType = StatementParser.getStatementType(this.sql);
}

@Override
Expand Down Expand Up @@ -108,34 +112,34 @@ public Statement bindNull(int i, Class<?> type) {

@Override
public Publisher<? extends Result> execute() {
Flux<Struct> structFlux = Flux.fromIterable(this.statementBindings.getBindings());
StatementType statementType = StatementParser.getStatementType(this.sql);

if (statementType == StatementType.SELECT) {
return structFlux.flatMap(struct -> runSingleStatement(struct, statementType));
switch (this.statementType) {
case DML:
return this.client
.executeBatchDml(this.session, this.transaction, this.sql,
this.statementBindings.getBindings(),
this.statementBindings.getTypes())
.flatMapIterable(ExecuteBatchDmlResponse::getResultSetsList)
.map(resultSet -> new SpannerResult(Flux.empty(),
Mono.just(Math.toIntExact(resultSet.getStats().getRowCountExact()))));
case SELECT:
Flux<Struct> structFlux = Flux.fromIterable(this.statementBindings.getBindings());
return structFlux.flatMap(this::runSelectStatement);
default:
throw new UnsupportedOperationException("Unsupported statement type " + this.statementType);
}
// DML statements have to be executed sequentially because they need seqNo to be in order
return structFlux.concatMapDelayError(struct -> runSingleStatement(struct, statementType));
}

private Mono<? extends Result> runSingleStatement(Struct params, StatementType statementType) {
private Mono<? extends Result> runSelectStatement(Struct params) {
PartialResultRowExtractor partialResultRowExtractor = new PartialResultRowExtractor();

Flux<PartialResultSet> resultSetFlux =
this.client.executeStreamingSql(
this.session, this.transaction, this.sql, params, this.statementBindings.getTypes());

if (statementType == StatementType.SELECT) {
return resultSetFlux
.flatMapIterable(partialResultRowExtractor, getPartialResultSetFetchSize())
.transform(result -> Mono.just(new SpannerResult(result, Mono.just(0))))
.next();
} else {
return resultSetFlux
.last()
.map(partialResultSet -> Math.toIntExact(partialResultSet.getStats().getRowCountExact()))
.map(rowCount -> new SpannerResult(Flux.empty(), Mono.just(rowCount)));
}
return resultSetFlux
.flatMapIterable(partialResultRowExtractor, getPartialResultSetFetchSize())
.transform(result -> Mono.just(new SpannerResult(result, Mono.just(0))))
.next();
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@
import com.google.cloud.spanner.r2dbc.SpannerTransactionContext;
import com.google.protobuf.Struct;
import com.google.spanner.v1.CommitResponse;
import com.google.spanner.v1.ExecuteBatchDmlResponse;
import com.google.spanner.v1.PartialResultSet;
import com.google.spanner.v1.Session;
import com.google.spanner.v1.Transaction;
import com.google.spanner.v1.Type;
import java.util.List;
import java.util.Map;
import javax.annotation.Nullable;
import reactor.core.publisher.Flux;
Expand Down Expand Up @@ -83,6 +85,13 @@ default Flux<PartialResultSet> executeStreamingSql(
return executeStreamingSql(session, transaction, sql, null, null);
}

/**
* Execute DML batch.
*/
Mono<ExecuteBatchDmlResponse> executeBatchDml(Session session,
@Nullable SpannerTransactionContext transactionContext, String sql,
List<Struct> params, Map<String, Type> types);

/**
* Release any resources held by the {@link Client}.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
import com.google.spanner.v1.CommitResponse;
import com.google.spanner.v1.CreateSessionRequest;
import com.google.spanner.v1.DeleteSessionRequest;
import com.google.spanner.v1.ExecuteBatchDmlRequest;
import com.google.spanner.v1.ExecuteBatchDmlResponse;
import com.google.spanner.v1.ExecuteSqlRequest;
import com.google.spanner.v1.PartialResultSet;
import com.google.spanner.v1.RollbackRequest;
Expand All @@ -42,6 +44,7 @@
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import io.grpc.auth.MoreCallCredentials;
import java.util.List;
import java.util.Map;
import javax.annotation.Nullable;
import reactor.core.publisher.Flux;
Expand Down Expand Up @@ -159,6 +162,31 @@ public Mono<Void> deleteSession(Session session) {
});
}

@Override
public Mono<ExecuteBatchDmlResponse> executeBatchDml(Session session,
@Nullable SpannerTransactionContext transactionContext, String sql,
List<Struct> params, Map<String, Type> types) {

ExecuteBatchDmlRequest.Builder request = ExecuteBatchDmlRequest.newBuilder()
.setSession(session.getName());
if (transactionContext != null && transactionContext.getTransaction() != null) {
request.setTransaction(
TransactionSelector.newBuilder().setId(transactionContext.getTransaction().getId())
.build())
.setSeqno(transactionContext.nextSeqNum());

}
for (Struct paramsStruct : params) {
ExecuteBatchDmlRequest.Statement statement = ExecuteBatchDmlRequest.Statement.newBuilder()
.setSql(sql).setParams(paramsStruct).putAllParamTypes(types)
.build();
request.addStatements(statement);
}

return ObservableReactiveUtil
.unaryCall(obs -> this.spanner.executeBatchDml(request.build(), obs));
}

@Override
public Flux<PartialResultSet> executeStreamingSql(
Session session, @Nullable SpannerTransactionContext transactionContext, String sql,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@
import com.google.cloud.spanner.r2dbc.client.Client;
import com.google.protobuf.Struct;
import com.google.protobuf.Value;
import com.google.spanner.v1.ExecuteBatchDmlResponse;
import com.google.spanner.v1.PartialResultSet;
import com.google.spanner.v1.ResultSet;
import com.google.spanner.v1.ResultSetMetadata;
import com.google.spanner.v1.ResultSetStats;
import com.google.spanner.v1.Session;
Expand All @@ -37,6 +39,7 @@
import com.google.spanner.v1.Type;
import com.google.spanner.v1.TypeCode;
import io.r2dbc.spi.Result;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
Expand Down Expand Up @@ -195,23 +198,28 @@ public void readMultiResultSetQueryTest() {

when(this.mockClient.executeStreamingSql(any(), any(), any(), any(), any())).thenReturn(inputs);

StepVerifier.create(Flux.from(new SpannerStatement(this.mockClient, null, null, "").execute())
StepVerifier
.create(Flux.from(new SpannerStatement(this.mockClient, null, null, "SELECT").execute())
.flatMap(r -> Mono.from(r.getRowsUpdated())))
.expectNext(0)
.verifyComplete();
}

@Test
public void readDmlQueryTest() {
PartialResultSet p1 = PartialResultSet.newBuilder().setStats(
ResultSetStats.newBuilder().setRowCountExact(555).build()
).build();
ResultSet resultSet = ResultSet.newBuilder()
.setStats(ResultSetStats.newBuilder().setRowCountExact(555).build())
.build();

Flux<PartialResultSet> inputs = Flux.just(p1);
ExecuteBatchDmlResponse executeBatchDmlResponse = ExecuteBatchDmlResponse.newBuilder()
.addResultSets(resultSet)
.build();

when(this.mockClient.executeStreamingSql(any(), any(), any(), any(), any())).thenReturn(inputs);
when(this.mockClient.executeBatchDml(any(), any(), any(), any(), any()))
.thenReturn(Mono.just(executeBatchDmlResponse));

StepVerifier.create(Flux.from(new SpannerStatement(this.mockClient, null, null, "").execute())
StepVerifier.create(
Flux.from(new SpannerStatement(this.mockClient, null, null, "Insert into books").execute())
.flatMap(r -> Mono.from(r.getRowsUpdated())))
.expectNext(555)
.verifyComplete();
Expand All @@ -221,13 +229,19 @@ public void readDmlQueryTest() {
public void noopMapOnUpdateQueriesWhenNoRowsAffected() {
Client mockClient = mock(Client.class);
String sql = "delete from Books where true";
PartialResultSet partialResultSet = PartialResultSet.newBuilder()

ResultSet resultSet = ResultSet.newBuilder()
.setMetadata(ResultSetMetadata.getDefaultInstance())
.setStats(ResultSetStats.getDefaultInstance())
.build();
when(mockClient.executeStreamingSql(TEST_SESSION, null, sql,
Struct.newBuilder().build(), Collections.EMPTY_MAP))
.thenReturn(Flux.just(partialResultSet));

ExecuteBatchDmlResponse executeBatchDmlResponse = ExecuteBatchDmlResponse.newBuilder()
.addResultSets(resultSet)
.build();

when(mockClient.executeBatchDml(TEST_SESSION, null, sql,
Arrays.asList(Struct.newBuilder().build()), Collections.EMPTY_MAP))
.thenReturn(Mono.just(executeBatchDmlResponse));

SpannerStatement statement
= new SpannerStatement(mockClient, TEST_SESSION, null, sql);
Expand All @@ -244,7 +258,7 @@ public void noopMapOnUpdateQueriesWhenNoRowsAffected() {
.expectNext(0)
.verifyComplete();

verify(mockClient, times(2)).executeStreamingSql(TEST_SESSION, null, sql,
Struct.newBuilder().build(), Collections.EMPTY_MAP);
verify(mockClient, times(1)).executeBatchDml(TEST_SESSION, null, sql,
Arrays.asList(Struct.newBuilder().build()), Collections.EMPTY_MAP);
}
}
67 changes: 46 additions & 21 deletions src/test/java/com/google/cloud/spanner/r2dbc/it/SpannerIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;

/**
* Integration test for connecting to a real Spanner instance.
Expand Down Expand Up @@ -252,25 +253,48 @@ public void testQuerying() {

Mono.from(this.connectionFactory.create())
.delayUntil(c -> c.beginTransaction())
.delayUntil(c -> Flux.from(c.createStatement(
"INSERT BOOKS (UUID, TITLE, AUTHOR, CATEGORY, FICTION, PUBLISHED, WORDS_PER_SENTENCE)"
+ " VALUES (@uuid, @title, @author, @category, @fiction, @published, @wps);")
.bind("uuid", "2b2cbd78-ecd8-430e-b685-fa7910f8a4c7")
.bind("author", "Douglas Crockford")
.bind("category", 100L)
.bind("title", "JavaScript: The Good Parts")
.bind("fiction", true)
.bind("published", LocalDate.of(2008, 5, 1))
.bind("wps", 20.8)
.add()
.bind("uuid", "df0e3d06-2743-4691-8e51-6d33d90c5cb9")
.bind("author", "Joshua Bloch")
.bind("category", 100L)
.bind("title", "Effective Java")
.bind("fiction", false)
.bind("published", LocalDate.of(2018, 1, 6))
.bind("wps", 15.1)
.execute()).flatMapSequential(r -> Mono.from(r.getRowsUpdated())))
.delayUntil(c ->
Mono.fromRunnable(() ->
StepVerifier.create(Flux.from(c.createStatement(
"INSERT BOOKS "
+ "(UUID, TITLE, AUTHOR, CATEGORY, FICTION, PUBLISHED, WORDS_PER_SENTENCE)"
+ " VALUES "
+ "(@uuid, @title, @author, @category, @fiction, @published, @wps);")
.bind("uuid", "2b2cbd78-ecd8-430e-b685-fa7910f8a4c7")
.bind("author", "Douglas Crockford")
.bind("category", 100L)
.bind("title", "JavaScript: The Good Parts")
.bind("fiction", true)
.bind("published", LocalDate.of(2008, 5, 1))
.bind("wps", 20.8)
.add()
.bind("uuid", "df0e3d06-2743-4691-8e51-6d33d90c5cb9")
.bind("author", "Joshua Bloch")
.bind("category", 100L)
.bind("title", "Effective Java")
.bind("fiction", false)
.bind("published", LocalDate.of(2018, 1, 6))
.bind("wps", 15.1)
.execute())
.flatMapSequential(r -> Mono.from(r.getRowsUpdated())))
.expectNext(1).expectNext(1).verifyComplete())
)
.delayUntil(c -> c.commitTransaction())
.block();

Mono.from(this.connectionFactory.create())
.delayUntil(c -> c.beginTransaction())
.delayUntil(c ->
Mono.fromRunnable(() ->
StepVerifier
.create(Flux.from(c.createStatement(
"UPDATE BOOKS SET CATEGORY = @new_cat WHERE CATEGORY = @old_cat")
.bind("new_cat", 101L)
.bind("old_cat", 100L)
.execute())
.flatMap(r -> Mono.from(r.getRowsUpdated())))
.expectNext(2).verifyComplete())
)
.delayUntil(c -> c.commitTransaction())
.block();

Expand Down Expand Up @@ -353,12 +377,13 @@ private int executeDmlQuery(String sql) {
Connection connection = Mono.from(connectionFactory.create()).block();

Mono.from(connection.beginTransaction()).block();
int rowsUpdated = Mono.from(connection.createStatement(sql).execute())
List<Integer> rowsUpdatedPerStatement = Flux.from(connection.createStatement(sql).execute())
.flatMap(result -> Mono.from(result.getRowsUpdated()))
.collectList()
.block();
Mono.from(connection.commitTransaction()).block();

return rowsUpdated;
return rowsUpdatedPerStatement.get(0);
}

/**
Expand Down

0 comments on commit 3aa0278

Please sign in to comment.