From e52bb9b28987dff47fc297078336c5baec307896 Mon Sep 17 00:00:00 2001 From: tuichenchuxin Date: Thu, 19 Oct 2023 15:47:55 +0800 Subject: [PATCH 1/2] Optimize ShardingSpherePreparedStatement for multi executionContext --- .../ShardingSpherePreparedStatement.java | 129 +++++++++++------- 1 file changed, 82 insertions(+), 47 deletions(-) diff --git a/jdbc/core/src/main/java/org/apache/shardingsphere/driver/jdbc/core/statement/ShardingSpherePreparedStatement.java b/jdbc/core/src/main/java/org/apache/shardingsphere/driver/jdbc/core/statement/ShardingSpherePreparedStatement.java index 228e07a9bce72..7ec266085a627 100644 --- a/jdbc/core/src/main/java/org/apache/shardingsphere/driver/jdbc/core/statement/ShardingSpherePreparedStatement.java +++ b/jdbc/core/src/main/java/org/apache/shardingsphere/driver/jdbc/core/statement/ShardingSpherePreparedStatement.java @@ -17,6 +17,7 @@ package org.apache.shardingsphere.driver.jdbc.core.statement; +import com.google.common.base.Preconditions; import com.google.common.base.Strings; import lombok.AccessLevel; import lombok.Getter; @@ -150,7 +151,7 @@ public final class ShardingSpherePreparedStatement extends AbstractPreparedState @Getter private final boolean selectContainsEnhancedTable; - private ExecutionContext executionContext; + private Collection executionContexts; private Map columnLabelAndIndexMap; @@ -244,14 +245,8 @@ public ResultSet executeQuery() throws SQLException { if (useFederation) { return executeFederationQuery(queryContext); } - executionContext = createExecutionContext(queryContext); - List queryResults = executeQuery0(); - MergedResult mergedResult = mergeQuery(queryResults); - List resultSets = getResultSets(); - if (null == columnLabelAndIndexMap) { - columnLabelAndIndexMap = ShardingSphereResultSetUtils.createColumnLabelAndIndexMap(sqlStatementContext, selectContainsEnhancedTable, resultSets.get(0).getMetaData()); - } - result = new ShardingSphereResultSet(resultSets, mergedResult, this, selectContainsEnhancedTable, executionContext, columnLabelAndIndexMap); + executionContexts = createExecutionContext(queryContext); + result = doExecuteQuery(executionContexts); // CHECKSTYLE:OFF } catch (final RuntimeException ex) { // CHECKSTYLE:ON @@ -264,6 +259,22 @@ public ResultSet executeQuery() throws SQLException { return result; } + private ShardingSphereResultSet doExecuteQuery(final Collection executionContexts) throws SQLException { + ShardingSphereResultSet result = null; + for (ExecutionContext each : executionContexts) { + List queryResults = executeQuery0(each); + MergedResult mergedResult = mergeQuery(queryResults, each.getSqlStatementContext()); + List resultSets = getResultSets(); + if (null == columnLabelAndIndexMap) { + columnLabelAndIndexMap = ShardingSphereResultSetUtils.createColumnLabelAndIndexMap(sqlStatementContext, selectContainsEnhancedTable, resultSets.get(0).getMetaData()); + } + if (null == result) { + result = new ShardingSphereResultSet(resultSets, mergedResult, this, selectContainsEnhancedTable, each, columnLabelAndIndexMap); + } + } + return result; + } + private boolean decide(final QueryContext queryContext, final ShardingSphereDatabase database, final RuleMetaData globalRuleMetaData) { return executor.getSqlFederationEngine().decide(queryContext.getSqlStatementContext(), queryContext.getParameters(), database, globalRuleMetaData); } @@ -309,12 +320,12 @@ private void resetParameters() throws SQLException { replaySetParameter(); } - private List executeQuery0() throws SQLException { + private List executeQuery0(final ExecutionContext executionContext) throws SQLException { if (hasRawExecutionRule()) { - return executor.getRawExecutor().execute(createRawExecutionGroupContext(), + return executor.getRawExecutor().execute(createRawExecutionGroupContext(executionContext), executionContext.getQueryContext(), new RawSQLExecutorCallback()).stream().map(QueryResult.class::cast).collect(Collectors.toList()); } - ExecutionGroupContext executionGroupContext = createExecutionGroupContext(); + ExecutionGroupContext executionGroupContext = createExecutionGroupContext(executionContext); cacheStatements(executionGroupContext.getInputGroups()); return executor.getRegularExecutor().executeQuery(executionGroupContext, executionContext.getQueryContext(), new PreparedStatementExecuteQueryCallback(metaDataContexts.getMetaData().getDatabase(databaseName).getProtocolType(), @@ -351,12 +362,15 @@ public int executeUpdate() throws SQLException { JDBCExecutionUnit executionUnit = createTrafficExecutionUnit(trafficInstanceId, queryContext); return executor.getTrafficExecutor().execute(executionUnit, (statement, sql) -> ((PreparedStatement) statement).executeUpdate()); } - executionContext = createExecutionContext(queryContext); + executionContexts = createExecutionContext(queryContext); if (hasRawExecutionRule()) { - Collection executeResults = executor.getRawExecutor().execute(createRawExecutionGroupContext(), executionContext.getQueryContext(), new RawSQLExecutorCallback()); - return accumulate(executeResults); + Collection results = new LinkedList<>(); + for (ExecutionContext each : executionContexts) { + results.addAll(executor.getRawExecutor().execute(createRawExecutionGroupContext(each), each.getQueryContext(), new RawSQLExecutorCallback())); + } + return accumulate(results); } - return isNeedImplicitCommitTransaction(connection, Collections.singleton(executionContext)) ? executeUpdateWithImplicitCommitTransaction() : useDriverToExecuteUpdate(); + return isNeedImplicitCommitTransaction(connection, executionContexts) ? executeUpdateWithImplicitCommitTransaction() : useDriverToExecuteUpdate(); // CHECKSTYLE:OFF } catch (final RuntimeException ex) { // CHECKSTYLE:ON @@ -368,10 +382,18 @@ public int executeUpdate() throws SQLException { } private int useDriverToExecuteUpdate() throws SQLException { - ExecutionGroupContext executionGroupContext = createExecutionGroupContext(); - cacheStatements(executionGroupContext.getInputGroups()); - return executor.getRegularExecutor().executeUpdate(executionGroupContext, - executionContext.getQueryContext(), executionContext.getRouteContext().getRouteUnits(), createExecuteUpdateCallback()); + Integer result = null; + Preconditions.checkArgument(!executionContexts.isEmpty()); + for (ExecutionContext each : executionContexts) { + ExecutionGroupContext executionGroupContext = createExecutionGroupContext(each); + cacheStatements(executionGroupContext.getInputGroups()); + int effectedCount = executor.getRegularExecutor().executeUpdate(executionGroupContext, + each.getQueryContext(), each.getRouteContext().getRouteUnits(), createExecuteUpdateCallback()); + if (null == result) { + result = effectedCount; + } + } + return result; } private int accumulate(final Collection results) { @@ -420,13 +442,16 @@ public boolean execute() throws SQLException { ResultSet resultSet = executeFederationQuery(queryContext); return null != resultSet; } - executionContext = createExecutionContext(queryContext); + executionContexts = createExecutionContext(queryContext); if (hasRawExecutionRule()) { - // TODO process getStatement - Collection executeResults = executor.getRawExecutor().execute(createRawExecutionGroupContext(), executionContext.getQueryContext(), new RawSQLExecutorCallback()); - return executeResults.iterator().next() instanceof QueryResult; + Collection results = new LinkedList<>(); + for (ExecutionContext each : executionContexts) { + // TODO process getStatement + results.addAll(executor.getRawExecutor().execute(createRawExecutionGroupContext(each), each.getQueryContext(), new RawSQLExecutorCallback())); + } + return results.iterator().next() instanceof QueryResult; } - return isNeedImplicitCommitTransaction(connection, Collections.singleton(executionContext)) ? executeWithImplicitCommitTransaction() : useDriverToExecute(); + return isNeedImplicitCommitTransaction(connection, executionContexts) ? executeWithImplicitCommitTransaction() : useDriverToExecute(); // CHECKSTYLE:OFF } catch (final RuntimeException ex) { // CHECKSTYLE:ON @@ -446,7 +471,7 @@ private boolean hasRawExecutionRule() { return false; } - private ExecutionGroupContext createRawExecutionGroupContext() throws SQLException { + private ExecutionGroupContext createRawExecutionGroupContext(final ExecutionContext executionContext) throws SQLException { int maxConnectionsSizePerQuery = metaDataContexts.getMetaData().getProps().getValue(ConfigurationPropertyKey.MAX_CONNECTIONS_SIZE_PER_QUERY); return new RawExecutionPrepareEngine(maxConnectionsSizePerQuery, metaDataContexts.getMetaData().getDatabase(databaseName).getRuleMetaData().getRules()) .prepare(executionContext.getRouteContext(), executionContext.getExecutionUnits(), new ExecutionGroupReportContext(databaseName)); @@ -487,10 +512,18 @@ private int executeUpdateWithImplicitCommitTransaction() throws SQLException { } private boolean useDriverToExecute() throws SQLException { - ExecutionGroupContext executionGroupContext = createExecutionGroupContext(); - cacheStatements(executionGroupContext.getInputGroups()); - return executor.getRegularExecutor().execute(executionGroupContext, - executionContext.getQueryContext(), executionContext.getRouteContext().getRouteUnits(), createExecuteCallback()); + Boolean result = null; + Preconditions.checkArgument(!executionContexts.isEmpty()); + for (ExecutionContext each : executionContexts) { + ExecutionGroupContext executionGroupContext = createExecutionGroupContext(each); + cacheStatements(executionGroupContext.getInputGroups()); + boolean isWrite = executor.getRegularExecutor().execute(executionGroupContext, + each.getQueryContext(), each.getRouteContext().getRouteUnits(), createExecuteCallback()); + if (null == result) { + result = isWrite; + } + } + return result; } private JDBCExecutorCallback createExecuteCallback() { @@ -510,7 +543,7 @@ protected Optional getSaneResult(final SQLStatement sqlStatement, final }; } - private ExecutionGroupContext createExecutionGroupContext() throws SQLException { + private ExecutionGroupContext createExecutionGroupContext(final ExecutionContext executionContext) throws SQLException { DriverExecutionPrepareEngine prepareEngine = createDriverExecutionPrepareEngine(); return prepareEngine.prepare(executionContext.getRouteContext(), executionContext.getExecutionUnits(), new ExecutionGroupReportContext(databaseName)); } @@ -526,16 +559,18 @@ public ResultSet getResultSet() throws SQLException { if (useFederation) { return executor.getSqlFederationEngine().getResultSet(); } - if (executionContext.getSqlStatementContext() instanceof SelectStatementContext || executionContext.getSqlStatementContext().getSqlStatement() instanceof DALStatement) { + if (executionContexts.iterator().next().getSqlStatementContext() instanceof SelectStatementContext + || executionContexts.iterator().next().getSqlStatementContext().getSqlStatement() instanceof DALStatement) { List resultSets = getResultSets(); if (resultSets.isEmpty()) { return currentResultSet; } - MergedResult mergedResult = mergeQuery(getQueryResults(resultSets)); + SQLStatementContext sqlStatementContext = executionContexts.iterator().next().getSqlStatementContext(); + MergedResult mergedResult = mergeQuery(getQueryResults(resultSets), sqlStatementContext); if (null == columnLabelAndIndexMap) { columnLabelAndIndexMap = ShardingSphereResultSetUtils.createColumnLabelAndIndexMap(sqlStatementContext, selectContainsEnhancedTable, resultSets.get(0).getMetaData()); } - currentResultSet = new ShardingSphereResultSet(resultSets, mergedResult, this, selectContainsEnhancedTable, executionContext, columnLabelAndIndexMap); + currentResultSet = new ShardingSphereResultSet(resultSets, mergedResult, this, selectContainsEnhancedTable, executionContexts.iterator().next(), columnLabelAndIndexMap); } return currentResultSet; } @@ -560,19 +595,19 @@ private List getQueryResults(final List resultSets) thro return result; } - private ExecutionContext createExecutionContext(final QueryContext queryContext) { + private Collection createExecutionContext(final QueryContext queryContext) { RuleMetaData globalRuleMetaData = metaDataContexts.getMetaData().getGlobalRuleMetaData(); ShardingSphereDatabase currentDatabase = metaDataContexts.getMetaData().getDatabase(databaseName); SQLAuditEngine.audit(queryContext.getSqlStatementContext(), queryContext.getParameters(), globalRuleMetaData, currentDatabase, null, queryContext.getHintValueContext()); ExecutionContext result = kernelProcessor.generateExecutionContext( queryContext, currentDatabase, globalRuleMetaData, metaDataContexts.getMetaData().getProps(), connection.getDatabaseConnectionManager().getConnectionContext()); findGeneratedKey(result).ifPresent(optional -> generatedValues.addAll(optional.getGeneratedValues())); - return result; + return Collections.singleton(result); } - private ExecutionContext createExecutionContext(final QueryContext queryContext, final String trafficInstanceId) { + private Collection createExecutionContext(final QueryContext queryContext, final String trafficInstanceId) { ExecutionUnit executionUnit = new ExecutionUnit(trafficInstanceId, new SQLUnit(queryContext.getSql(), queryContext.getParameters())); - return new ExecutionContext(queryContext, Collections.singletonList(executionUnit), new RouteContext()); + return Collections.singleton(new ExecutionContext(queryContext, Collections.singletonList(executionUnit), new RouteContext())); } private QueryContext createQueryContext() { @@ -583,10 +618,10 @@ private QueryContext createQueryContext() { return new QueryContext(sqlStatementContext, sql, params, hintValueContext, true); } - private MergedResult mergeQuery(final List queryResults) throws SQLException { + private MergedResult mergeQuery(final List queryResults, final SQLStatementContext sqlStatementContext) throws SQLException { MergeEngine mergeEngine = new MergeEngine(metaDataContexts.getMetaData().getDatabase(databaseName), metaDataContexts.getMetaData().getProps(), connection.getDatabaseConnectionManager().getConnectionContext()); - return mergeEngine.merge(queryResults, executionContext.getSqlStatementContext()); + return mergeEngine.merge(queryResults, sqlStatementContext); } private void cacheStatements(final Collection> executionGroups) throws SQLException { @@ -629,7 +664,7 @@ public ResultSet getGeneratedKeys() throws SQLException { if (null != currentBatchGeneratedKeysResultSet) { return currentBatchGeneratedKeysResultSet; } - Optional generatedKey = findGeneratedKey(executionContext); + Optional generatedKey = findGeneratedKey(executionContexts.iterator().next()); if (generatedKey.isPresent() && statementOption.isReturnGeneratedKeys() && !generatedValues.isEmpty()) { return new GeneratedKeysResultSet(getGeneratedKeysColumnName(generatedKey.get().getColumnName()), generatedValues.iterator(), this); } @@ -652,8 +687,8 @@ public void addBatch() { try { QueryContext queryContext = createQueryContext(); trafficInstanceId = getInstanceIdAndSet(queryContext).orElse(null); - executionContext = null == trafficInstanceId ? createExecutionContext(queryContext) : createExecutionContext(queryContext, trafficInstanceId); - batchPreparedStatementExecutor.addBatchForExecutionUnits(executionContext.getExecutionUnits()); + executionContexts = null == trafficInstanceId ? createExecutionContext(queryContext) : createExecutionContext(queryContext, trafficInstanceId); + batchPreparedStatementExecutor.addBatchForExecutionUnits(executionContexts.iterator().next().getExecutionUnits()); } finally { currentResultSet = null; clearParameters(); @@ -662,13 +697,13 @@ public void addBatch() { @Override public int[] executeBatch() throws SQLException { - if (null == executionContext) { + if (null == executionContexts || executionContexts.isEmpty()) { return new int[0]; } try { // TODO add raw SQL executor initBatchPreparedStatementExecutor(); - int[] results = batchPreparedStatementExecutor.executeBatch(executionContext.getSqlStatementContext()); + int[] results = batchPreparedStatementExecutor.executeBatch(executionContexts.iterator().next().getSqlStatementContext()); if (statementOption.isReturnGeneratedKeys() && generatedValues.isEmpty()) { List batchPreparedStatementExecutorStatements = batchPreparedStatementExecutor.getStatements(); for (Statement statement : batchPreparedStatementExecutorStatements) { @@ -698,7 +733,7 @@ private void initBatchPreparedStatementExecutor() throws SQLException { ExecutionUnit executionUnit = each.getExecutionUnit(); executionUnits.add(executionUnit); } - batchPreparedStatementExecutor.init(prepareEngine.prepare(executionContext.getRouteContext(), executionUnits, new ExecutionGroupReportContext(databaseName))); + batchPreparedStatementExecutor.init(prepareEngine.prepare(executionContexts.iterator().next().getRouteContext(), executionUnits, new ExecutionGroupReportContext(databaseName))); setBatchParametersForStatements(); } @@ -739,7 +774,7 @@ public int getResultSetHoldability() { @Override public boolean isAccumulate() { return metaDataContexts.getMetaData().getDatabase(databaseName).getRuleMetaData().findRules(DataNodeContainedRule.class).stream() - .anyMatch(each -> each.isNeedAccumulate(executionContext.getSqlStatementContext().getTablesContext().getTableNames())); + .anyMatch(each -> each.isNeedAccumulate(executionContexts.iterator().next().getSqlStatementContext().getTablesContext().getTableNames())); } @Override From d5fb2a43d92e0950648732af350cf149da3a7253 Mon Sep 17 00:00:00 2001 From: tuichenchuxin Date: Thu, 19 Oct 2023 16:56:51 +0800 Subject: [PATCH 2/2] Refactor ShardingSpherePreparedStatement for support multi executionContext. --- .../ShardingSpherePreparedStatement.java | 17 ++++++----------- .../statement/ShardingSphereStatement.java | 18 +++++++----------- .../backend/connector/DatabaseConnector.java | 6 ++---- 3 files changed, 15 insertions(+), 26 deletions(-) diff --git a/jdbc/core/src/main/java/org/apache/shardingsphere/driver/jdbc/core/statement/ShardingSpherePreparedStatement.java b/jdbc/core/src/main/java/org/apache/shardingsphere/driver/jdbc/core/statement/ShardingSpherePreparedStatement.java index 7ec266085a627..1fefc7e0380e7 100644 --- a/jdbc/core/src/main/java/org/apache/shardingsphere/driver/jdbc/core/statement/ShardingSpherePreparedStatement.java +++ b/jdbc/core/src/main/java/org/apache/shardingsphere/driver/jdbc/core/statement/ShardingSpherePreparedStatement.java @@ -261,6 +261,7 @@ public ResultSet executeQuery() throws SQLException { private ShardingSphereResultSet doExecuteQuery(final Collection executionContexts) throws SQLException { ShardingSphereResultSet result = null; + // TODO support multi execution context, currently executionContexts.size() always equals 1 for (ExecutionContext each : executionContexts) { List queryResults = executeQuery0(each); MergedResult mergedResult = mergeQuery(queryResults, each.getSqlStatementContext()); @@ -268,9 +269,7 @@ private ShardingSphereResultSet doExecuteQuery(final Collection executionGroupContext = createExecutionGroupContext(each); cacheStatements(executionGroupContext.getInputGroups()); - int effectedCount = executor.getRegularExecutor().executeUpdate(executionGroupContext, + result = executor.getRegularExecutor().executeUpdate(executionGroupContext, each.getQueryContext(), each.getRouteContext().getRouteUnits(), createExecuteUpdateCallback()); - if (null == result) { - result = effectedCount; - } } return result; } @@ -514,14 +511,12 @@ private int executeUpdateWithImplicitCommitTransaction() throws SQLException { private boolean useDriverToExecute() throws SQLException { Boolean result = null; Preconditions.checkArgument(!executionContexts.isEmpty()); + // TODO support multi execution context, currently executionContexts.size() always equals 1 for (ExecutionContext each : executionContexts) { ExecutionGroupContext executionGroupContext = createExecutionGroupContext(each); cacheStatements(executionGroupContext.getInputGroups()); - boolean isWrite = executor.getRegularExecutor().execute(executionGroupContext, + result = executor.getRegularExecutor().execute(executionGroupContext, each.getQueryContext(), each.getRouteContext().getRouteUnits(), createExecuteCallback()); - if (null == result) { - result = isWrite; - } } return result; } diff --git a/jdbc/core/src/main/java/org/apache/shardingsphere/driver/jdbc/core/statement/ShardingSphereStatement.java b/jdbc/core/src/main/java/org/apache/shardingsphere/driver/jdbc/core/statement/ShardingSphereStatement.java index d787b19f5599d..d675040734b27 100644 --- a/jdbc/core/src/main/java/org/apache/shardingsphere/driver/jdbc/core/statement/ShardingSphereStatement.java +++ b/jdbc/core/src/main/java/org/apache/shardingsphere/driver/jdbc/core/statement/ShardingSphereStatement.java @@ -191,14 +191,14 @@ public ResultSet executeQuery(final String sql) throws SQLException { private ShardingSphereResultSet doExecuteQuery(final Collection executionContexts) throws SQLException { ShardingSphereResultSet result = null; + // TODO support multi execution context, currently executionContexts.size() always equals 1 for (ExecutionContext each : executionContexts) { List queryResults = executeQuery0(each); MergedResult mergedResult = mergeQuery(queryResults, each.getSqlStatementContext()); boolean selectContainsEnhancedTable = each.getSqlStatementContext() instanceof SelectStatementContext && ((SelectStatementContext) each.getSqlStatementContext()).isContainsEnhancedTable(); - if (null == result) { - result = new ShardingSphereResultSet(getResultSets(), mergedResult, this, selectContainsEnhancedTable, each); - } + result = new ShardingSphereResultSet(getResultSets(), mergedResult, this, selectContainsEnhancedTable, each); + } return result; } @@ -363,15 +363,13 @@ private int useDriverToExecuteUpdate(final ExecuteUpdateCallback updateCallback, final Collection executionContexts) throws SQLException { Integer result = null; Preconditions.checkArgument(!executionContexts.isEmpty()); + // TODO support multi execution context, currently executionContexts.size() always equals 1 for (ExecutionContext each : executionContexts) { ExecutionGroupContext executionGroupContext = createExecutionGroupContext(each); cacheStatements(executionGroupContext.getInputGroups()); JDBCExecutorCallback callback = createExecuteUpdateCallback(updateCallback, sqlStatementContext); - int effectedCount = executor.getRegularExecutor().executeUpdate(executionGroupContext, + result = executor.getRegularExecutor().executeUpdate(executionGroupContext, each.getQueryContext(), each.getRouteContext().getRouteUnits(), callback); - if (null == result) { - result = effectedCount; - } } return result; } @@ -576,15 +574,13 @@ private boolean executeWithImplicitCommitTransaction(final ExecuteCallback callb private boolean useDriverToExecute(final ExecuteCallback callback, final Collection executionContexts) throws SQLException { Boolean result = null; Preconditions.checkArgument(!executionContexts.isEmpty()); + // TODO support multi execution context, currently executionContexts.size() always equals 1 for (ExecutionContext each : executionContexts) { ExecutionGroupContext executionGroupContext = createExecutionGroupContext(each); cacheStatements(executionGroupContext.getInputGroups()); JDBCExecutorCallback jdbcExecutorCallback = createExecuteCallback(callback, each.getSqlStatementContext().getSqlStatement()); - boolean isWrite = executor.getRegularExecutor().execute(executionGroupContext, + result = executor.getRegularExecutor().execute(executionGroupContext, each.getQueryContext(), each.getRouteContext().getRouteUnits(), jdbcExecutorCallback); - if (null == result) { - result = isWrite; - } } return result; } diff --git a/proxy/backend/core/src/main/java/org/apache/shardingsphere/proxy/backend/connector/DatabaseConnector.java b/proxy/backend/core/src/main/java/org/apache/shardingsphere/proxy/backend/connector/DatabaseConnector.java index 871853cec2590..ee51ee8e6dca6 100644 --- a/proxy/backend/core/src/main/java/org/apache/shardingsphere/proxy/backend/connector/DatabaseConnector.java +++ b/proxy/backend/core/src/main/java/org/apache/shardingsphere/proxy/backend/connector/DatabaseConnector.java @@ -222,11 +222,9 @@ private ResponseHeader doExecuteWithImplicitCommitTransaction(final Collection executionContexts) throws SQLException { ResponseHeader result = null; + // TODO support multi execution context, currently executionContexts.size() always equals 1 for (ExecutionContext each : executionContexts) { - ResponseHeader responseHeader = doExecute(each); - if (null == result) { - result = responseHeader; - } + result = doExecute(each); } return result; }