Skip to content

Commit

Permalink
Optimize ShardingSpherePreparedStatement for multi executionContext
Browse files Browse the repository at this point in the history
  • Loading branch information
tuichenchuxin committed Oct 19, 2023
1 parent 299fecc commit e52bb9b
Showing 1 changed file with 82 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.shardingsphere.driver.jdbc.core.statement;

import com.google.common.base.Preconditions;
import com.google.common.base.Strings;
import lombok.AccessLevel;
import lombok.Getter;
Expand Down Expand Up @@ -150,7 +151,7 @@ public final class ShardingSpherePreparedStatement extends AbstractPreparedState
@Getter
private final boolean selectContainsEnhancedTable;

private ExecutionContext executionContext;
private Collection<ExecutionContext> executionContexts;

private Map<String, Integer> columnLabelAndIndexMap;

Expand Down Expand Up @@ -244,14 +245,8 @@ public ResultSet executeQuery() throws SQLException {
if (useFederation) {
return executeFederationQuery(queryContext);
}
executionContext = createExecutionContext(queryContext);
List<QueryResult> queryResults = executeQuery0();
MergedResult mergedResult = mergeQuery(queryResults);
List<ResultSet> resultSets = getResultSets();
if (null == columnLabelAndIndexMap) {
columnLabelAndIndexMap = ShardingSphereResultSetUtils.createColumnLabelAndIndexMap(sqlStatementContext, selectContainsEnhancedTable, resultSets.get(0).getMetaData());
}
result = new ShardingSphereResultSet(resultSets, mergedResult, this, selectContainsEnhancedTable, executionContext, columnLabelAndIndexMap);
executionContexts = createExecutionContext(queryContext);
result = doExecuteQuery(executionContexts);
// CHECKSTYLE:OFF
} catch (final RuntimeException ex) {
// CHECKSTYLE:ON
Expand All @@ -264,6 +259,22 @@ public ResultSet executeQuery() throws SQLException {
return result;
}

private ShardingSphereResultSet doExecuteQuery(final Collection<ExecutionContext> executionContexts) throws SQLException {
ShardingSphereResultSet result = null;
for (ExecutionContext each : executionContexts) {
List<QueryResult> queryResults = executeQuery0(each);
MergedResult mergedResult = mergeQuery(queryResults, each.getSqlStatementContext());
List<ResultSet> resultSets = getResultSets();
if (null == columnLabelAndIndexMap) {
columnLabelAndIndexMap = ShardingSphereResultSetUtils.createColumnLabelAndIndexMap(sqlStatementContext, selectContainsEnhancedTable, resultSets.get(0).getMetaData());
}
if (null == result) {
result = new ShardingSphereResultSet(resultSets, mergedResult, this, selectContainsEnhancedTable, each, columnLabelAndIndexMap);
}
}
return result;
}

private boolean decide(final QueryContext queryContext, final ShardingSphereDatabase database, final RuleMetaData globalRuleMetaData) {
return executor.getSqlFederationEngine().decide(queryContext.getSqlStatementContext(), queryContext.getParameters(), database, globalRuleMetaData);
}
Expand Down Expand Up @@ -309,12 +320,12 @@ private void resetParameters() throws SQLException {
replaySetParameter();
}

private List<QueryResult> executeQuery0() throws SQLException {
private List<QueryResult> executeQuery0(final ExecutionContext executionContext) throws SQLException {
if (hasRawExecutionRule()) {
return executor.getRawExecutor().execute(createRawExecutionGroupContext(),
return executor.getRawExecutor().execute(createRawExecutionGroupContext(executionContext),
executionContext.getQueryContext(), new RawSQLExecutorCallback()).stream().map(QueryResult.class::cast).collect(Collectors.toList());
}
ExecutionGroupContext<JDBCExecutionUnit> executionGroupContext = createExecutionGroupContext();
ExecutionGroupContext<JDBCExecutionUnit> executionGroupContext = createExecutionGroupContext(executionContext);
cacheStatements(executionGroupContext.getInputGroups());
return executor.getRegularExecutor().executeQuery(executionGroupContext, executionContext.getQueryContext(),
new PreparedStatementExecuteQueryCallback(metaDataContexts.getMetaData().getDatabase(databaseName).getProtocolType(),
Expand Down Expand Up @@ -351,12 +362,15 @@ public int executeUpdate() throws SQLException {
JDBCExecutionUnit executionUnit = createTrafficExecutionUnit(trafficInstanceId, queryContext);
return executor.getTrafficExecutor().execute(executionUnit, (statement, sql) -> ((PreparedStatement) statement).executeUpdate());
}
executionContext = createExecutionContext(queryContext);
executionContexts = createExecutionContext(queryContext);
if (hasRawExecutionRule()) {
Collection<ExecuteResult> executeResults = executor.getRawExecutor().execute(createRawExecutionGroupContext(), executionContext.getQueryContext(), new RawSQLExecutorCallback());
return accumulate(executeResults);
Collection<ExecuteResult> results = new LinkedList<>();
for (ExecutionContext each : executionContexts) {
results.addAll(executor.getRawExecutor().execute(createRawExecutionGroupContext(each), each.getQueryContext(), new RawSQLExecutorCallback()));
}
return accumulate(results);
}
return isNeedImplicitCommitTransaction(connection, Collections.singleton(executionContext)) ? executeUpdateWithImplicitCommitTransaction() : useDriverToExecuteUpdate();
return isNeedImplicitCommitTransaction(connection, executionContexts) ? executeUpdateWithImplicitCommitTransaction() : useDriverToExecuteUpdate();
// CHECKSTYLE:OFF
} catch (final RuntimeException ex) {
// CHECKSTYLE:ON
Expand All @@ -368,10 +382,18 @@ public int executeUpdate() throws SQLException {
}

private int useDriverToExecuteUpdate() throws SQLException {
ExecutionGroupContext<JDBCExecutionUnit> executionGroupContext = createExecutionGroupContext();
cacheStatements(executionGroupContext.getInputGroups());
return executor.getRegularExecutor().executeUpdate(executionGroupContext,
executionContext.getQueryContext(), executionContext.getRouteContext().getRouteUnits(), createExecuteUpdateCallback());
Integer result = null;
Preconditions.checkArgument(!executionContexts.isEmpty());
for (ExecutionContext each : executionContexts) {
ExecutionGroupContext<JDBCExecutionUnit> executionGroupContext = createExecutionGroupContext(each);
cacheStatements(executionGroupContext.getInputGroups());
int effectedCount = executor.getRegularExecutor().executeUpdate(executionGroupContext,
each.getQueryContext(), each.getRouteContext().getRouteUnits(), createExecuteUpdateCallback());
if (null == result) {
result = effectedCount;
}
}
return result;
}

private int accumulate(final Collection<ExecuteResult> results) {
Expand Down Expand Up @@ -420,13 +442,16 @@ public boolean execute() throws SQLException {
ResultSet resultSet = executeFederationQuery(queryContext);
return null != resultSet;
}
executionContext = createExecutionContext(queryContext);
executionContexts = createExecutionContext(queryContext);
if (hasRawExecutionRule()) {
// TODO process getStatement
Collection<ExecuteResult> executeResults = executor.getRawExecutor().execute(createRawExecutionGroupContext(), executionContext.getQueryContext(), new RawSQLExecutorCallback());
return executeResults.iterator().next() instanceof QueryResult;
Collection<ExecuteResult> results = new LinkedList<>();
for (ExecutionContext each : executionContexts) {
// TODO process getStatement
results.addAll(executor.getRawExecutor().execute(createRawExecutionGroupContext(each), each.getQueryContext(), new RawSQLExecutorCallback()));
}
return results.iterator().next() instanceof QueryResult;
}
return isNeedImplicitCommitTransaction(connection, Collections.singleton(executionContext)) ? executeWithImplicitCommitTransaction() : useDriverToExecute();
return isNeedImplicitCommitTransaction(connection, executionContexts) ? executeWithImplicitCommitTransaction() : useDriverToExecute();
// CHECKSTYLE:OFF
} catch (final RuntimeException ex) {
// CHECKSTYLE:ON
Expand All @@ -446,7 +471,7 @@ private boolean hasRawExecutionRule() {
return false;
}

private ExecutionGroupContext<RawSQLExecutionUnit> createRawExecutionGroupContext() throws SQLException {
private ExecutionGroupContext<RawSQLExecutionUnit> createRawExecutionGroupContext(final ExecutionContext executionContext) throws SQLException {
int maxConnectionsSizePerQuery = metaDataContexts.getMetaData().getProps().<Integer>getValue(ConfigurationPropertyKey.MAX_CONNECTIONS_SIZE_PER_QUERY);
return new RawExecutionPrepareEngine(maxConnectionsSizePerQuery, metaDataContexts.getMetaData().getDatabase(databaseName).getRuleMetaData().getRules())
.prepare(executionContext.getRouteContext(), executionContext.getExecutionUnits(), new ExecutionGroupReportContext(databaseName));
Expand Down Expand Up @@ -487,10 +512,18 @@ private int executeUpdateWithImplicitCommitTransaction() throws SQLException {
}

private boolean useDriverToExecute() throws SQLException {
ExecutionGroupContext<JDBCExecutionUnit> executionGroupContext = createExecutionGroupContext();
cacheStatements(executionGroupContext.getInputGroups());
return executor.getRegularExecutor().execute(executionGroupContext,
executionContext.getQueryContext(), executionContext.getRouteContext().getRouteUnits(), createExecuteCallback());
Boolean result = null;
Preconditions.checkArgument(!executionContexts.isEmpty());
for (ExecutionContext each : executionContexts) {
ExecutionGroupContext<JDBCExecutionUnit> executionGroupContext = createExecutionGroupContext(each);
cacheStatements(executionGroupContext.getInputGroups());
boolean isWrite = executor.getRegularExecutor().execute(executionGroupContext,
each.getQueryContext(), each.getRouteContext().getRouteUnits(), createExecuteCallback());
if (null == result) {
result = isWrite;
}
}
return result;
}

private JDBCExecutorCallback<Boolean> createExecuteCallback() {
Expand All @@ -510,7 +543,7 @@ protected Optional<Boolean> getSaneResult(final SQLStatement sqlStatement, final
};
}

private ExecutionGroupContext<JDBCExecutionUnit> createExecutionGroupContext() throws SQLException {
private ExecutionGroupContext<JDBCExecutionUnit> createExecutionGroupContext(final ExecutionContext executionContext) throws SQLException {
DriverExecutionPrepareEngine<JDBCExecutionUnit, Connection> prepareEngine = createDriverExecutionPrepareEngine();
return prepareEngine.prepare(executionContext.getRouteContext(), executionContext.getExecutionUnits(), new ExecutionGroupReportContext(databaseName));
}
Expand All @@ -526,16 +559,18 @@ public ResultSet getResultSet() throws SQLException {
if (useFederation) {
return executor.getSqlFederationEngine().getResultSet();
}
if (executionContext.getSqlStatementContext() instanceof SelectStatementContext || executionContext.getSqlStatementContext().getSqlStatement() instanceof DALStatement) {
if (executionContexts.iterator().next().getSqlStatementContext() instanceof SelectStatementContext
|| executionContexts.iterator().next().getSqlStatementContext().getSqlStatement() instanceof DALStatement) {
List<ResultSet> resultSets = getResultSets();
if (resultSets.isEmpty()) {
return currentResultSet;
}
MergedResult mergedResult = mergeQuery(getQueryResults(resultSets));
SQLStatementContext sqlStatementContext = executionContexts.iterator().next().getSqlStatementContext();
MergedResult mergedResult = mergeQuery(getQueryResults(resultSets), sqlStatementContext);
if (null == columnLabelAndIndexMap) {
columnLabelAndIndexMap = ShardingSphereResultSetUtils.createColumnLabelAndIndexMap(sqlStatementContext, selectContainsEnhancedTable, resultSets.get(0).getMetaData());
}
currentResultSet = new ShardingSphereResultSet(resultSets, mergedResult, this, selectContainsEnhancedTable, executionContext, columnLabelAndIndexMap);
currentResultSet = new ShardingSphereResultSet(resultSets, mergedResult, this, selectContainsEnhancedTable, executionContexts.iterator().next(), columnLabelAndIndexMap);
}
return currentResultSet;
}
Expand All @@ -560,19 +595,19 @@ private List<QueryResult> getQueryResults(final List<ResultSet> resultSets) thro
return result;
}

private ExecutionContext createExecutionContext(final QueryContext queryContext) {
private Collection<ExecutionContext> createExecutionContext(final QueryContext queryContext) {
RuleMetaData globalRuleMetaData = metaDataContexts.getMetaData().getGlobalRuleMetaData();
ShardingSphereDatabase currentDatabase = metaDataContexts.getMetaData().getDatabase(databaseName);
SQLAuditEngine.audit(queryContext.getSqlStatementContext(), queryContext.getParameters(), globalRuleMetaData, currentDatabase, null, queryContext.getHintValueContext());
ExecutionContext result = kernelProcessor.generateExecutionContext(
queryContext, currentDatabase, globalRuleMetaData, metaDataContexts.getMetaData().getProps(), connection.getDatabaseConnectionManager().getConnectionContext());
findGeneratedKey(result).ifPresent(optional -> generatedValues.addAll(optional.getGeneratedValues()));
return result;
return Collections.singleton(result);
}

private ExecutionContext createExecutionContext(final QueryContext queryContext, final String trafficInstanceId) {
private Collection<ExecutionContext> createExecutionContext(final QueryContext queryContext, final String trafficInstanceId) {
ExecutionUnit executionUnit = new ExecutionUnit(trafficInstanceId, new SQLUnit(queryContext.getSql(), queryContext.getParameters()));
return new ExecutionContext(queryContext, Collections.singletonList(executionUnit), new RouteContext());
return Collections.singleton(new ExecutionContext(queryContext, Collections.singletonList(executionUnit), new RouteContext()));
}

private QueryContext createQueryContext() {
Expand All @@ -583,10 +618,10 @@ private QueryContext createQueryContext() {
return new QueryContext(sqlStatementContext, sql, params, hintValueContext, true);
}

private MergedResult mergeQuery(final List<QueryResult> queryResults) throws SQLException {
private MergedResult mergeQuery(final List<QueryResult> queryResults, final SQLStatementContext sqlStatementContext) throws SQLException {
MergeEngine mergeEngine = new MergeEngine(metaDataContexts.getMetaData().getDatabase(databaseName),
metaDataContexts.getMetaData().getProps(), connection.getDatabaseConnectionManager().getConnectionContext());
return mergeEngine.merge(queryResults, executionContext.getSqlStatementContext());
return mergeEngine.merge(queryResults, sqlStatementContext);
}

private void cacheStatements(final Collection<ExecutionGroup<JDBCExecutionUnit>> executionGroups) throws SQLException {
Expand Down Expand Up @@ -629,7 +664,7 @@ public ResultSet getGeneratedKeys() throws SQLException {
if (null != currentBatchGeneratedKeysResultSet) {
return currentBatchGeneratedKeysResultSet;
}
Optional<GeneratedKeyContext> generatedKey = findGeneratedKey(executionContext);
Optional<GeneratedKeyContext> generatedKey = findGeneratedKey(executionContexts.iterator().next());
if (generatedKey.isPresent() && statementOption.isReturnGeneratedKeys() && !generatedValues.isEmpty()) {
return new GeneratedKeysResultSet(getGeneratedKeysColumnName(generatedKey.get().getColumnName()), generatedValues.iterator(), this);
}
Expand All @@ -652,8 +687,8 @@ public void addBatch() {
try {
QueryContext queryContext = createQueryContext();
trafficInstanceId = getInstanceIdAndSet(queryContext).orElse(null);
executionContext = null == trafficInstanceId ? createExecutionContext(queryContext) : createExecutionContext(queryContext, trafficInstanceId);
batchPreparedStatementExecutor.addBatchForExecutionUnits(executionContext.getExecutionUnits());
executionContexts = null == trafficInstanceId ? createExecutionContext(queryContext) : createExecutionContext(queryContext, trafficInstanceId);
batchPreparedStatementExecutor.addBatchForExecutionUnits(executionContexts.iterator().next().getExecutionUnits());
} finally {
currentResultSet = null;
clearParameters();
Expand All @@ -662,13 +697,13 @@ public void addBatch() {

@Override
public int[] executeBatch() throws SQLException {
if (null == executionContext) {
if (null == executionContexts || executionContexts.isEmpty()) {
return new int[0];
}
try {
// TODO add raw SQL executor
initBatchPreparedStatementExecutor();
int[] results = batchPreparedStatementExecutor.executeBatch(executionContext.getSqlStatementContext());
int[] results = batchPreparedStatementExecutor.executeBatch(executionContexts.iterator().next().getSqlStatementContext());
if (statementOption.isReturnGeneratedKeys() && generatedValues.isEmpty()) {
List<Statement> batchPreparedStatementExecutorStatements = batchPreparedStatementExecutor.getStatements();
for (Statement statement : batchPreparedStatementExecutorStatements) {
Expand Down Expand Up @@ -698,7 +733,7 @@ private void initBatchPreparedStatementExecutor() throws SQLException {
ExecutionUnit executionUnit = each.getExecutionUnit();
executionUnits.add(executionUnit);
}
batchPreparedStatementExecutor.init(prepareEngine.prepare(executionContext.getRouteContext(), executionUnits, new ExecutionGroupReportContext(databaseName)));
batchPreparedStatementExecutor.init(prepareEngine.prepare(executionContexts.iterator().next().getRouteContext(), executionUnits, new ExecutionGroupReportContext(databaseName)));
setBatchParametersForStatements();
}

Expand Down Expand Up @@ -739,7 +774,7 @@ public int getResultSetHoldability() {
@Override
public boolean isAccumulate() {
return metaDataContexts.getMetaData().getDatabase(databaseName).getRuleMetaData().findRules(DataNodeContainedRule.class).stream()
.anyMatch(each -> each.isNeedAccumulate(executionContext.getSqlStatementContext().getTablesContext().getTableNames()));
.anyMatch(each -> each.isNeedAccumulate(executionContexts.iterator().next().getSqlStatementContext().getTablesContext().getTableNames()));
}

@Override
Expand Down

0 comments on commit e52bb9b

Please sign in to comment.