diff --git a/jdbc/core/src/main/java/org/apache/shardingsphere/driver/jdbc/core/statement/ShardingSpherePreparedStatement.java b/jdbc/core/src/main/java/org/apache/shardingsphere/driver/jdbc/core/statement/ShardingSpherePreparedStatement.java index 228e07a9bce72..7ec266085a627 100644 --- a/jdbc/core/src/main/java/org/apache/shardingsphere/driver/jdbc/core/statement/ShardingSpherePreparedStatement.java +++ b/jdbc/core/src/main/java/org/apache/shardingsphere/driver/jdbc/core/statement/ShardingSpherePreparedStatement.java @@ -17,6 +17,7 @@ package org.apache.shardingsphere.driver.jdbc.core.statement; +import com.google.common.base.Preconditions; import com.google.common.base.Strings; import lombok.AccessLevel; import lombok.Getter; @@ -150,7 +151,7 @@ public final class ShardingSpherePreparedStatement extends AbstractPreparedState @Getter private final boolean selectContainsEnhancedTable; - private ExecutionContext executionContext; + private Collection executionContexts; private Map columnLabelAndIndexMap; @@ -244,14 +245,8 @@ public ResultSet executeQuery() throws SQLException { if (useFederation) { return executeFederationQuery(queryContext); } - executionContext = createExecutionContext(queryContext); - List queryResults = executeQuery0(); - MergedResult mergedResult = mergeQuery(queryResults); - List resultSets = getResultSets(); - if (null == columnLabelAndIndexMap) { - columnLabelAndIndexMap = ShardingSphereResultSetUtils.createColumnLabelAndIndexMap(sqlStatementContext, selectContainsEnhancedTable, resultSets.get(0).getMetaData()); - } - result = new ShardingSphereResultSet(resultSets, mergedResult, this, selectContainsEnhancedTable, executionContext, columnLabelAndIndexMap); + executionContexts = createExecutionContext(queryContext); + result = doExecuteQuery(executionContexts); // CHECKSTYLE:OFF } catch (final RuntimeException ex) { // CHECKSTYLE:ON @@ -264,6 +259,22 @@ public ResultSet executeQuery() throws SQLException { return result; } + private ShardingSphereResultSet doExecuteQuery(final Collection executionContexts) throws SQLException { + ShardingSphereResultSet result = null; + for (ExecutionContext each : executionContexts) { + List queryResults = executeQuery0(each); + MergedResult mergedResult = mergeQuery(queryResults, each.getSqlStatementContext()); + List resultSets = getResultSets(); + if (null == columnLabelAndIndexMap) { + columnLabelAndIndexMap = ShardingSphereResultSetUtils.createColumnLabelAndIndexMap(sqlStatementContext, selectContainsEnhancedTable, resultSets.get(0).getMetaData()); + } + if (null == result) { + result = new ShardingSphereResultSet(resultSets, mergedResult, this, selectContainsEnhancedTable, each, columnLabelAndIndexMap); + } + } + return result; + } + private boolean decide(final QueryContext queryContext, final ShardingSphereDatabase database, final RuleMetaData globalRuleMetaData) { return executor.getSqlFederationEngine().decide(queryContext.getSqlStatementContext(), queryContext.getParameters(), database, globalRuleMetaData); } @@ -309,12 +320,12 @@ private void resetParameters() throws SQLException { replaySetParameter(); } - private List executeQuery0() throws SQLException { + private List executeQuery0(final ExecutionContext executionContext) throws SQLException { if (hasRawExecutionRule()) { - return executor.getRawExecutor().execute(createRawExecutionGroupContext(), + return executor.getRawExecutor().execute(createRawExecutionGroupContext(executionContext), executionContext.getQueryContext(), new RawSQLExecutorCallback()).stream().map(QueryResult.class::cast).collect(Collectors.toList()); } - ExecutionGroupContext executionGroupContext = createExecutionGroupContext(); + ExecutionGroupContext executionGroupContext = createExecutionGroupContext(executionContext); cacheStatements(executionGroupContext.getInputGroups()); return executor.getRegularExecutor().executeQuery(executionGroupContext, executionContext.getQueryContext(), new PreparedStatementExecuteQueryCallback(metaDataContexts.getMetaData().getDatabase(databaseName).getProtocolType(), @@ -351,12 +362,15 @@ public int executeUpdate() throws SQLException { JDBCExecutionUnit executionUnit = createTrafficExecutionUnit(trafficInstanceId, queryContext); return executor.getTrafficExecutor().execute(executionUnit, (statement, sql) -> ((PreparedStatement) statement).executeUpdate()); } - executionContext = createExecutionContext(queryContext); + executionContexts = createExecutionContext(queryContext); if (hasRawExecutionRule()) { - Collection executeResults = executor.getRawExecutor().execute(createRawExecutionGroupContext(), executionContext.getQueryContext(), new RawSQLExecutorCallback()); - return accumulate(executeResults); + Collection results = new LinkedList<>(); + for (ExecutionContext each : executionContexts) { + results.addAll(executor.getRawExecutor().execute(createRawExecutionGroupContext(each), each.getQueryContext(), new RawSQLExecutorCallback())); + } + return accumulate(results); } - return isNeedImplicitCommitTransaction(connection, Collections.singleton(executionContext)) ? executeUpdateWithImplicitCommitTransaction() : useDriverToExecuteUpdate(); + return isNeedImplicitCommitTransaction(connection, executionContexts) ? executeUpdateWithImplicitCommitTransaction() : useDriverToExecuteUpdate(); // CHECKSTYLE:OFF } catch (final RuntimeException ex) { // CHECKSTYLE:ON @@ -368,10 +382,18 @@ public int executeUpdate() throws SQLException { } private int useDriverToExecuteUpdate() throws SQLException { - ExecutionGroupContext executionGroupContext = createExecutionGroupContext(); - cacheStatements(executionGroupContext.getInputGroups()); - return executor.getRegularExecutor().executeUpdate(executionGroupContext, - executionContext.getQueryContext(), executionContext.getRouteContext().getRouteUnits(), createExecuteUpdateCallback()); + Integer result = null; + Preconditions.checkArgument(!executionContexts.isEmpty()); + for (ExecutionContext each : executionContexts) { + ExecutionGroupContext executionGroupContext = createExecutionGroupContext(each); + cacheStatements(executionGroupContext.getInputGroups()); + int effectedCount = executor.getRegularExecutor().executeUpdate(executionGroupContext, + each.getQueryContext(), each.getRouteContext().getRouteUnits(), createExecuteUpdateCallback()); + if (null == result) { + result = effectedCount; + } + } + return result; } private int accumulate(final Collection results) { @@ -420,13 +442,16 @@ public boolean execute() throws SQLException { ResultSet resultSet = executeFederationQuery(queryContext); return null != resultSet; } - executionContext = createExecutionContext(queryContext); + executionContexts = createExecutionContext(queryContext); if (hasRawExecutionRule()) { - // TODO process getStatement - Collection executeResults = executor.getRawExecutor().execute(createRawExecutionGroupContext(), executionContext.getQueryContext(), new RawSQLExecutorCallback()); - return executeResults.iterator().next() instanceof QueryResult; + Collection results = new LinkedList<>(); + for (ExecutionContext each : executionContexts) { + // TODO process getStatement + results.addAll(executor.getRawExecutor().execute(createRawExecutionGroupContext(each), each.getQueryContext(), new RawSQLExecutorCallback())); + } + return results.iterator().next() instanceof QueryResult; } - return isNeedImplicitCommitTransaction(connection, Collections.singleton(executionContext)) ? executeWithImplicitCommitTransaction() : useDriverToExecute(); + return isNeedImplicitCommitTransaction(connection, executionContexts) ? executeWithImplicitCommitTransaction() : useDriverToExecute(); // CHECKSTYLE:OFF } catch (final RuntimeException ex) { // CHECKSTYLE:ON @@ -446,7 +471,7 @@ private boolean hasRawExecutionRule() { return false; } - private ExecutionGroupContext createRawExecutionGroupContext() throws SQLException { + private ExecutionGroupContext createRawExecutionGroupContext(final ExecutionContext executionContext) throws SQLException { int maxConnectionsSizePerQuery = metaDataContexts.getMetaData().getProps().getValue(ConfigurationPropertyKey.MAX_CONNECTIONS_SIZE_PER_QUERY); return new RawExecutionPrepareEngine(maxConnectionsSizePerQuery, metaDataContexts.getMetaData().getDatabase(databaseName).getRuleMetaData().getRules()) .prepare(executionContext.getRouteContext(), executionContext.getExecutionUnits(), new ExecutionGroupReportContext(databaseName)); @@ -487,10 +512,18 @@ private int executeUpdateWithImplicitCommitTransaction() throws SQLException { } private boolean useDriverToExecute() throws SQLException { - ExecutionGroupContext executionGroupContext = createExecutionGroupContext(); - cacheStatements(executionGroupContext.getInputGroups()); - return executor.getRegularExecutor().execute(executionGroupContext, - executionContext.getQueryContext(), executionContext.getRouteContext().getRouteUnits(), createExecuteCallback()); + Boolean result = null; + Preconditions.checkArgument(!executionContexts.isEmpty()); + for (ExecutionContext each : executionContexts) { + ExecutionGroupContext executionGroupContext = createExecutionGroupContext(each); + cacheStatements(executionGroupContext.getInputGroups()); + boolean isWrite = executor.getRegularExecutor().execute(executionGroupContext, + each.getQueryContext(), each.getRouteContext().getRouteUnits(), createExecuteCallback()); + if (null == result) { + result = isWrite; + } + } + return result; } private JDBCExecutorCallback createExecuteCallback() { @@ -510,7 +543,7 @@ protected Optional getSaneResult(final SQLStatement sqlStatement, final }; } - private ExecutionGroupContext createExecutionGroupContext() throws SQLException { + private ExecutionGroupContext createExecutionGroupContext(final ExecutionContext executionContext) throws SQLException { DriverExecutionPrepareEngine prepareEngine = createDriverExecutionPrepareEngine(); return prepareEngine.prepare(executionContext.getRouteContext(), executionContext.getExecutionUnits(), new ExecutionGroupReportContext(databaseName)); } @@ -526,16 +559,18 @@ public ResultSet getResultSet() throws SQLException { if (useFederation) { return executor.getSqlFederationEngine().getResultSet(); } - if (executionContext.getSqlStatementContext() instanceof SelectStatementContext || executionContext.getSqlStatementContext().getSqlStatement() instanceof DALStatement) { + if (executionContexts.iterator().next().getSqlStatementContext() instanceof SelectStatementContext + || executionContexts.iterator().next().getSqlStatementContext().getSqlStatement() instanceof DALStatement) { List resultSets = getResultSets(); if (resultSets.isEmpty()) { return currentResultSet; } - MergedResult mergedResult = mergeQuery(getQueryResults(resultSets)); + SQLStatementContext sqlStatementContext = executionContexts.iterator().next().getSqlStatementContext(); + MergedResult mergedResult = mergeQuery(getQueryResults(resultSets), sqlStatementContext); if (null == columnLabelAndIndexMap) { columnLabelAndIndexMap = ShardingSphereResultSetUtils.createColumnLabelAndIndexMap(sqlStatementContext, selectContainsEnhancedTable, resultSets.get(0).getMetaData()); } - currentResultSet = new ShardingSphereResultSet(resultSets, mergedResult, this, selectContainsEnhancedTable, executionContext, columnLabelAndIndexMap); + currentResultSet = new ShardingSphereResultSet(resultSets, mergedResult, this, selectContainsEnhancedTable, executionContexts.iterator().next(), columnLabelAndIndexMap); } return currentResultSet; } @@ -560,19 +595,19 @@ private List getQueryResults(final List resultSets) thro return result; } - private ExecutionContext createExecutionContext(final QueryContext queryContext) { + private Collection createExecutionContext(final QueryContext queryContext) { RuleMetaData globalRuleMetaData = metaDataContexts.getMetaData().getGlobalRuleMetaData(); ShardingSphereDatabase currentDatabase = metaDataContexts.getMetaData().getDatabase(databaseName); SQLAuditEngine.audit(queryContext.getSqlStatementContext(), queryContext.getParameters(), globalRuleMetaData, currentDatabase, null, queryContext.getHintValueContext()); ExecutionContext result = kernelProcessor.generateExecutionContext( queryContext, currentDatabase, globalRuleMetaData, metaDataContexts.getMetaData().getProps(), connection.getDatabaseConnectionManager().getConnectionContext()); findGeneratedKey(result).ifPresent(optional -> generatedValues.addAll(optional.getGeneratedValues())); - return result; + return Collections.singleton(result); } - private ExecutionContext createExecutionContext(final QueryContext queryContext, final String trafficInstanceId) { + private Collection createExecutionContext(final QueryContext queryContext, final String trafficInstanceId) { ExecutionUnit executionUnit = new ExecutionUnit(trafficInstanceId, new SQLUnit(queryContext.getSql(), queryContext.getParameters())); - return new ExecutionContext(queryContext, Collections.singletonList(executionUnit), new RouteContext()); + return Collections.singleton(new ExecutionContext(queryContext, Collections.singletonList(executionUnit), new RouteContext())); } private QueryContext createQueryContext() { @@ -583,10 +618,10 @@ private QueryContext createQueryContext() { return new QueryContext(sqlStatementContext, sql, params, hintValueContext, true); } - private MergedResult mergeQuery(final List queryResults) throws SQLException { + private MergedResult mergeQuery(final List queryResults, final SQLStatementContext sqlStatementContext) throws SQLException { MergeEngine mergeEngine = new MergeEngine(metaDataContexts.getMetaData().getDatabase(databaseName), metaDataContexts.getMetaData().getProps(), connection.getDatabaseConnectionManager().getConnectionContext()); - return mergeEngine.merge(queryResults, executionContext.getSqlStatementContext()); + return mergeEngine.merge(queryResults, sqlStatementContext); } private void cacheStatements(final Collection> executionGroups) throws SQLException { @@ -629,7 +664,7 @@ public ResultSet getGeneratedKeys() throws SQLException { if (null != currentBatchGeneratedKeysResultSet) { return currentBatchGeneratedKeysResultSet; } - Optional generatedKey = findGeneratedKey(executionContext); + Optional generatedKey = findGeneratedKey(executionContexts.iterator().next()); if (generatedKey.isPresent() && statementOption.isReturnGeneratedKeys() && !generatedValues.isEmpty()) { return new GeneratedKeysResultSet(getGeneratedKeysColumnName(generatedKey.get().getColumnName()), generatedValues.iterator(), this); } @@ -652,8 +687,8 @@ public void addBatch() { try { QueryContext queryContext = createQueryContext(); trafficInstanceId = getInstanceIdAndSet(queryContext).orElse(null); - executionContext = null == trafficInstanceId ? createExecutionContext(queryContext) : createExecutionContext(queryContext, trafficInstanceId); - batchPreparedStatementExecutor.addBatchForExecutionUnits(executionContext.getExecutionUnits()); + executionContexts = null == trafficInstanceId ? createExecutionContext(queryContext) : createExecutionContext(queryContext, trafficInstanceId); + batchPreparedStatementExecutor.addBatchForExecutionUnits(executionContexts.iterator().next().getExecutionUnits()); } finally { currentResultSet = null; clearParameters(); @@ -662,13 +697,13 @@ public void addBatch() { @Override public int[] executeBatch() throws SQLException { - if (null == executionContext) { + if (null == executionContexts || executionContexts.isEmpty()) { return new int[0]; } try { // TODO add raw SQL executor initBatchPreparedStatementExecutor(); - int[] results = batchPreparedStatementExecutor.executeBatch(executionContext.getSqlStatementContext()); + int[] results = batchPreparedStatementExecutor.executeBatch(executionContexts.iterator().next().getSqlStatementContext()); if (statementOption.isReturnGeneratedKeys() && generatedValues.isEmpty()) { List batchPreparedStatementExecutorStatements = batchPreparedStatementExecutor.getStatements(); for (Statement statement : batchPreparedStatementExecutorStatements) { @@ -698,7 +733,7 @@ private void initBatchPreparedStatementExecutor() throws SQLException { ExecutionUnit executionUnit = each.getExecutionUnit(); executionUnits.add(executionUnit); } - batchPreparedStatementExecutor.init(prepareEngine.prepare(executionContext.getRouteContext(), executionUnits, new ExecutionGroupReportContext(databaseName))); + batchPreparedStatementExecutor.init(prepareEngine.prepare(executionContexts.iterator().next().getRouteContext(), executionUnits, new ExecutionGroupReportContext(databaseName))); setBatchParametersForStatements(); } @@ -739,7 +774,7 @@ public int getResultSetHoldability() { @Override public boolean isAccumulate() { return metaDataContexts.getMetaData().getDatabase(databaseName).getRuleMetaData().findRules(DataNodeContainedRule.class).stream() - .anyMatch(each -> each.isNeedAccumulate(executionContext.getSqlStatementContext().getTablesContext().getTableNames())); + .anyMatch(each -> each.isNeedAccumulate(executionContexts.iterator().next().getSqlStatementContext().getTablesContext().getTableNames())); } @Override