Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize ShardingSpherePreparedStatement for multi executionContext #28802

Merged
merged 2 commits into from
Oct 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.shardingsphere.driver.jdbc.core.statement;

import com.google.common.base.Preconditions;
import com.google.common.base.Strings;
import lombok.AccessLevel;
import lombok.Getter;
Expand Down Expand Up @@ -150,7 +151,7 @@ public final class ShardingSpherePreparedStatement extends AbstractPreparedState
@Getter
private final boolean selectContainsEnhancedTable;

private ExecutionContext executionContext;
private Collection<ExecutionContext> executionContexts;

private Map<String, Integer> columnLabelAndIndexMap;

Expand Down Expand Up @@ -244,14 +245,8 @@ public ResultSet executeQuery() throws SQLException {
if (useFederation) {
return executeFederationQuery(queryContext);
}
executionContext = createExecutionContext(queryContext);
List<QueryResult> queryResults = executeQuery0();
MergedResult mergedResult = mergeQuery(queryResults);
List<ResultSet> resultSets = getResultSets();
if (null == columnLabelAndIndexMap) {
columnLabelAndIndexMap = ShardingSphereResultSetUtils.createColumnLabelAndIndexMap(sqlStatementContext, selectContainsEnhancedTable, resultSets.get(0).getMetaData());
}
result = new ShardingSphereResultSet(resultSets, mergedResult, this, selectContainsEnhancedTable, executionContext, columnLabelAndIndexMap);
executionContexts = createExecutionContext(queryContext);
result = doExecuteQuery(executionContexts);
// CHECKSTYLE:OFF
} catch (final RuntimeException ex) {
// CHECKSTYLE:ON
Expand All @@ -264,6 +259,21 @@ public ResultSet executeQuery() throws SQLException {
return result;
}

private ShardingSphereResultSet doExecuteQuery(final Collection<ExecutionContext> executionContexts) throws SQLException {
ShardingSphereResultSet result = null;
// TODO support multi execution context, currently executionContexts.size() always equals 1
for (ExecutionContext each : executionContexts) {
List<QueryResult> queryResults = executeQuery0(each);
MergedResult mergedResult = mergeQuery(queryResults, each.getSqlStatementContext());
List<ResultSet> resultSets = getResultSets();
if (null == columnLabelAndIndexMap) {
columnLabelAndIndexMap = ShardingSphereResultSetUtils.createColumnLabelAndIndexMap(sqlStatementContext, selectContainsEnhancedTable, resultSets.get(0).getMetaData());
}
result = new ShardingSphereResultSet(resultSets, mergedResult, this, selectContainsEnhancedTable, each, columnLabelAndIndexMap);
}
return result;
}

private boolean decide(final QueryContext queryContext, final ShardingSphereDatabase database, final RuleMetaData globalRuleMetaData) {
return executor.getSqlFederationEngine().decide(queryContext.getSqlStatementContext(), queryContext.getParameters(), database, globalRuleMetaData);
}
Expand Down Expand Up @@ -309,12 +319,12 @@ private void resetParameters() throws SQLException {
replaySetParameter();
}

private List<QueryResult> executeQuery0() throws SQLException {
private List<QueryResult> executeQuery0(final ExecutionContext executionContext) throws SQLException {
if (hasRawExecutionRule()) {
return executor.getRawExecutor().execute(createRawExecutionGroupContext(),
return executor.getRawExecutor().execute(createRawExecutionGroupContext(executionContext),
executionContext.getQueryContext(), new RawSQLExecutorCallback()).stream().map(QueryResult.class::cast).collect(Collectors.toList());
}
ExecutionGroupContext<JDBCExecutionUnit> executionGroupContext = createExecutionGroupContext();
ExecutionGroupContext<JDBCExecutionUnit> executionGroupContext = createExecutionGroupContext(executionContext);
cacheStatements(executionGroupContext.getInputGroups());
return executor.getRegularExecutor().executeQuery(executionGroupContext, executionContext.getQueryContext(),
new PreparedStatementExecuteQueryCallback(metaDataContexts.getMetaData().getDatabase(databaseName).getProtocolType(),
Expand Down Expand Up @@ -351,12 +361,15 @@ public int executeUpdate() throws SQLException {
JDBCExecutionUnit executionUnit = createTrafficExecutionUnit(trafficInstanceId, queryContext);
return executor.getTrafficExecutor().execute(executionUnit, (statement, sql) -> ((PreparedStatement) statement).executeUpdate());
}
executionContext = createExecutionContext(queryContext);
executionContexts = createExecutionContext(queryContext);
if (hasRawExecutionRule()) {
Collection<ExecuteResult> executeResults = executor.getRawExecutor().execute(createRawExecutionGroupContext(), executionContext.getQueryContext(), new RawSQLExecutorCallback());
return accumulate(executeResults);
Collection<ExecuteResult> results = new LinkedList<>();
for (ExecutionContext each : executionContexts) {
results.addAll(executor.getRawExecutor().execute(createRawExecutionGroupContext(each), each.getQueryContext(), new RawSQLExecutorCallback()));
}
return accumulate(results);
}
return isNeedImplicitCommitTransaction(connection, Collections.singleton(executionContext)) ? executeUpdateWithImplicitCommitTransaction() : useDriverToExecuteUpdate();
return isNeedImplicitCommitTransaction(connection, executionContexts) ? executeUpdateWithImplicitCommitTransaction() : useDriverToExecuteUpdate();
// CHECKSTYLE:OFF
} catch (final RuntimeException ex) {
// CHECKSTYLE:ON
Expand All @@ -368,10 +381,16 @@ public int executeUpdate() throws SQLException {
}

private int useDriverToExecuteUpdate() throws SQLException {
ExecutionGroupContext<JDBCExecutionUnit> executionGroupContext = createExecutionGroupContext();
cacheStatements(executionGroupContext.getInputGroups());
return executor.getRegularExecutor().executeUpdate(executionGroupContext,
executionContext.getQueryContext(), executionContext.getRouteContext().getRouteUnits(), createExecuteUpdateCallback());
Integer result = null;
Preconditions.checkArgument(!executionContexts.isEmpty());
// TODO support multi execution context, currently executionContexts.size() always equals 1
for (ExecutionContext each : executionContexts) {
ExecutionGroupContext<JDBCExecutionUnit> executionGroupContext = createExecutionGroupContext(each);
cacheStatements(executionGroupContext.getInputGroups());
result = executor.getRegularExecutor().executeUpdate(executionGroupContext,
each.getQueryContext(), each.getRouteContext().getRouteUnits(), createExecuteUpdateCallback());
}
return result;
}

private int accumulate(final Collection<ExecuteResult> results) {
Expand Down Expand Up @@ -420,13 +439,16 @@ public boolean execute() throws SQLException {
ResultSet resultSet = executeFederationQuery(queryContext);
return null != resultSet;
}
executionContext = createExecutionContext(queryContext);
executionContexts = createExecutionContext(queryContext);
if (hasRawExecutionRule()) {
// TODO process getStatement
Collection<ExecuteResult> executeResults = executor.getRawExecutor().execute(createRawExecutionGroupContext(), executionContext.getQueryContext(), new RawSQLExecutorCallback());
return executeResults.iterator().next() instanceof QueryResult;
Collection<ExecuteResult> results = new LinkedList<>();
for (ExecutionContext each : executionContexts) {
// TODO process getStatement
results.addAll(executor.getRawExecutor().execute(createRawExecutionGroupContext(each), each.getQueryContext(), new RawSQLExecutorCallback()));
}
return results.iterator().next() instanceof QueryResult;
}
return isNeedImplicitCommitTransaction(connection, Collections.singleton(executionContext)) ? executeWithImplicitCommitTransaction() : useDriverToExecute();
return isNeedImplicitCommitTransaction(connection, executionContexts) ? executeWithImplicitCommitTransaction() : useDriverToExecute();
// CHECKSTYLE:OFF
} catch (final RuntimeException ex) {
// CHECKSTYLE:ON
Expand All @@ -446,7 +468,7 @@ private boolean hasRawExecutionRule() {
return false;
}

private ExecutionGroupContext<RawSQLExecutionUnit> createRawExecutionGroupContext() throws SQLException {
private ExecutionGroupContext<RawSQLExecutionUnit> createRawExecutionGroupContext(final ExecutionContext executionContext) throws SQLException {
int maxConnectionsSizePerQuery = metaDataContexts.getMetaData().getProps().<Integer>getValue(ConfigurationPropertyKey.MAX_CONNECTIONS_SIZE_PER_QUERY);
return new RawExecutionPrepareEngine(maxConnectionsSizePerQuery, metaDataContexts.getMetaData().getDatabase(databaseName).getRuleMetaData().getRules())
.prepare(executionContext.getRouteContext(), executionContext.getExecutionUnits(), new ExecutionGroupReportContext(databaseName));
Expand Down Expand Up @@ -487,10 +509,16 @@ private int executeUpdateWithImplicitCommitTransaction() throws SQLException {
}

private boolean useDriverToExecute() throws SQLException {
ExecutionGroupContext<JDBCExecutionUnit> executionGroupContext = createExecutionGroupContext();
cacheStatements(executionGroupContext.getInputGroups());
return executor.getRegularExecutor().execute(executionGroupContext,
executionContext.getQueryContext(), executionContext.getRouteContext().getRouteUnits(), createExecuteCallback());
Boolean result = null;
Preconditions.checkArgument(!executionContexts.isEmpty());
// TODO support multi execution context, currently executionContexts.size() always equals 1
for (ExecutionContext each : executionContexts) {
ExecutionGroupContext<JDBCExecutionUnit> executionGroupContext = createExecutionGroupContext(each);
cacheStatements(executionGroupContext.getInputGroups());
result = executor.getRegularExecutor().execute(executionGroupContext,
each.getQueryContext(), each.getRouteContext().getRouteUnits(), createExecuteCallback());
}
return result;
}

private JDBCExecutorCallback<Boolean> createExecuteCallback() {
Expand All @@ -510,7 +538,7 @@ protected Optional<Boolean> getSaneResult(final SQLStatement sqlStatement, final
};
}

private ExecutionGroupContext<JDBCExecutionUnit> createExecutionGroupContext() throws SQLException {
private ExecutionGroupContext<JDBCExecutionUnit> createExecutionGroupContext(final ExecutionContext executionContext) throws SQLException {
DriverExecutionPrepareEngine<JDBCExecutionUnit, Connection> prepareEngine = createDriverExecutionPrepareEngine();
return prepareEngine.prepare(executionContext.getRouteContext(), executionContext.getExecutionUnits(), new ExecutionGroupReportContext(databaseName));
}
Expand All @@ -526,16 +554,18 @@ public ResultSet getResultSet() throws SQLException {
if (useFederation) {
return executor.getSqlFederationEngine().getResultSet();
}
if (executionContext.getSqlStatementContext() instanceof SelectStatementContext || executionContext.getSqlStatementContext().getSqlStatement() instanceof DALStatement) {
if (executionContexts.iterator().next().getSqlStatementContext() instanceof SelectStatementContext
|| executionContexts.iterator().next().getSqlStatementContext().getSqlStatement() instanceof DALStatement) {
List<ResultSet> resultSets = getResultSets();
if (resultSets.isEmpty()) {
return currentResultSet;
}
MergedResult mergedResult = mergeQuery(getQueryResults(resultSets));
SQLStatementContext sqlStatementContext = executionContexts.iterator().next().getSqlStatementContext();
MergedResult mergedResult = mergeQuery(getQueryResults(resultSets), sqlStatementContext);
if (null == columnLabelAndIndexMap) {
columnLabelAndIndexMap = ShardingSphereResultSetUtils.createColumnLabelAndIndexMap(sqlStatementContext, selectContainsEnhancedTable, resultSets.get(0).getMetaData());
}
currentResultSet = new ShardingSphereResultSet(resultSets, mergedResult, this, selectContainsEnhancedTable, executionContext, columnLabelAndIndexMap);
currentResultSet = new ShardingSphereResultSet(resultSets, mergedResult, this, selectContainsEnhancedTable, executionContexts.iterator().next(), columnLabelAndIndexMap);
}
return currentResultSet;
}
Expand All @@ -560,19 +590,19 @@ private List<QueryResult> getQueryResults(final List<ResultSet> resultSets) thro
return result;
}

private ExecutionContext createExecutionContext(final QueryContext queryContext) {
private Collection<ExecutionContext> createExecutionContext(final QueryContext queryContext) {
RuleMetaData globalRuleMetaData = metaDataContexts.getMetaData().getGlobalRuleMetaData();
ShardingSphereDatabase currentDatabase = metaDataContexts.getMetaData().getDatabase(databaseName);
SQLAuditEngine.audit(queryContext.getSqlStatementContext(), queryContext.getParameters(), globalRuleMetaData, currentDatabase, null, queryContext.getHintValueContext());
ExecutionContext result = kernelProcessor.generateExecutionContext(
queryContext, currentDatabase, globalRuleMetaData, metaDataContexts.getMetaData().getProps(), connection.getDatabaseConnectionManager().getConnectionContext());
findGeneratedKey(result).ifPresent(optional -> generatedValues.addAll(optional.getGeneratedValues()));
return result;
return Collections.singleton(result);
}

private ExecutionContext createExecutionContext(final QueryContext queryContext, final String trafficInstanceId) {
private Collection<ExecutionContext> createExecutionContext(final QueryContext queryContext, final String trafficInstanceId) {
ExecutionUnit executionUnit = new ExecutionUnit(trafficInstanceId, new SQLUnit(queryContext.getSql(), queryContext.getParameters()));
return new ExecutionContext(queryContext, Collections.singletonList(executionUnit), new RouteContext());
return Collections.singleton(new ExecutionContext(queryContext, Collections.singletonList(executionUnit), new RouteContext()));
}

private QueryContext createQueryContext() {
Expand All @@ -583,10 +613,10 @@ private QueryContext createQueryContext() {
return new QueryContext(sqlStatementContext, sql, params, hintValueContext, true);
}

private MergedResult mergeQuery(final List<QueryResult> queryResults) throws SQLException {
private MergedResult mergeQuery(final List<QueryResult> queryResults, final SQLStatementContext sqlStatementContext) throws SQLException {
MergeEngine mergeEngine = new MergeEngine(metaDataContexts.getMetaData().getDatabase(databaseName),
metaDataContexts.getMetaData().getProps(), connection.getDatabaseConnectionManager().getConnectionContext());
return mergeEngine.merge(queryResults, executionContext.getSqlStatementContext());
return mergeEngine.merge(queryResults, sqlStatementContext);
}

private void cacheStatements(final Collection<ExecutionGroup<JDBCExecutionUnit>> executionGroups) throws SQLException {
Expand Down Expand Up @@ -629,7 +659,7 @@ public ResultSet getGeneratedKeys() throws SQLException {
if (null != currentBatchGeneratedKeysResultSet) {
return currentBatchGeneratedKeysResultSet;
}
Optional<GeneratedKeyContext> generatedKey = findGeneratedKey(executionContext);
Optional<GeneratedKeyContext> generatedKey = findGeneratedKey(executionContexts.iterator().next());
if (generatedKey.isPresent() && statementOption.isReturnGeneratedKeys() && !generatedValues.isEmpty()) {
return new GeneratedKeysResultSet(getGeneratedKeysColumnName(generatedKey.get().getColumnName()), generatedValues.iterator(), this);
}
Expand All @@ -652,8 +682,8 @@ public void addBatch() {
try {
QueryContext queryContext = createQueryContext();
trafficInstanceId = getInstanceIdAndSet(queryContext).orElse(null);
executionContext = null == trafficInstanceId ? createExecutionContext(queryContext) : createExecutionContext(queryContext, trafficInstanceId);
batchPreparedStatementExecutor.addBatchForExecutionUnits(executionContext.getExecutionUnits());
executionContexts = null == trafficInstanceId ? createExecutionContext(queryContext) : createExecutionContext(queryContext, trafficInstanceId);
batchPreparedStatementExecutor.addBatchForExecutionUnits(executionContexts.iterator().next().getExecutionUnits());
} finally {
currentResultSet = null;
clearParameters();
Expand All @@ -662,13 +692,13 @@ public void addBatch() {

@Override
public int[] executeBatch() throws SQLException {
if (null == executionContext) {
if (null == executionContexts || executionContexts.isEmpty()) {
return new int[0];
}
try {
// TODO add raw SQL executor
initBatchPreparedStatementExecutor();
int[] results = batchPreparedStatementExecutor.executeBatch(executionContext.getSqlStatementContext());
int[] results = batchPreparedStatementExecutor.executeBatch(executionContexts.iterator().next().getSqlStatementContext());
if (statementOption.isReturnGeneratedKeys() && generatedValues.isEmpty()) {
List<Statement> batchPreparedStatementExecutorStatements = batchPreparedStatementExecutor.getStatements();
for (Statement statement : batchPreparedStatementExecutorStatements) {
Expand Down Expand Up @@ -698,7 +728,7 @@ private void initBatchPreparedStatementExecutor() throws SQLException {
ExecutionUnit executionUnit = each.getExecutionUnit();
executionUnits.add(executionUnit);
}
batchPreparedStatementExecutor.init(prepareEngine.prepare(executionContext.getRouteContext(), executionUnits, new ExecutionGroupReportContext(databaseName)));
batchPreparedStatementExecutor.init(prepareEngine.prepare(executionContexts.iterator().next().getRouteContext(), executionUnits, new ExecutionGroupReportContext(databaseName)));
setBatchParametersForStatements();
}

Expand Down Expand Up @@ -739,7 +769,7 @@ public int getResultSetHoldability() {
@Override
public boolean isAccumulate() {
return metaDataContexts.getMetaData().getDatabase(databaseName).getRuleMetaData().findRules(DataNodeContainedRule.class).stream()
.anyMatch(each -> each.isNeedAccumulate(executionContext.getSqlStatementContext().getTablesContext().getTableNames()));
.anyMatch(each -> each.isNeedAccumulate(executionContexts.iterator().next().getSqlStatementContext().getTablesContext().getTableNames()));
}

@Override
Expand Down
Loading