Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor ShardingSphereStatement #31395

Merged
merged 3 commits into from
May 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ public abstract class AbstractStatementAdapter extends AbstractUnsupportedOperat

private boolean closed;

private boolean closeOnCompletion;

protected final boolean isNeedImplicitCommitTransaction(final ShardingSphereConnection connection, final SQLStatement sqlStatement, final boolean multiExecutionUnits) {
if (!connection.getAutoCommit()) {
return false;
Expand Down Expand Up @@ -229,6 +231,28 @@ public final SQLWarning getWarnings() {
public final void clearWarnings() {
}

@Override
public void closeOnCompletion() {
closeOnCompletion = true;
}

@Override
public boolean isCloseOnCompletion() {
return closeOnCompletion;
}

@Override
public void setCursorName(final String name) throws SQLException {
if (isTransparent()) {
getRoutedStatements().iterator().next().setCursorName(name);
}
super.setCursorName(name);
}

private boolean isTransparent() {
return 1 == getRoutedStatements().size();
}

@SuppressWarnings({"unchecked", "rawtypes"})
@Override
public final void cancel() throws SQLException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,8 @@ public final class ShardingSphereStatement extends AbstractStatementAdapter {

private final MetaDataContexts metaDataContexts;

private String databaseName;

private final List<Statement> statements;

private final StatementOption statementOption;
Expand All @@ -130,12 +132,6 @@ public final class ShardingSphereStatement extends AbstractStatementAdapter {

private ResultSet currentResultSet;

private String trafficInstanceId;

private boolean useFederation;

private String databaseName;

public ShardingSphereStatement(final ShardingSphereConnection connection) {
this(connection, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY, ResultSet.HOLD_CURSORS_OVER_COMMIT);
}
Expand All @@ -147,14 +143,14 @@ public ShardingSphereStatement(final ShardingSphereConnection connection, final
public ShardingSphereStatement(final ShardingSphereConnection connection, final int resultSetType, final int resultSetConcurrency, final int resultSetHoldability) {
this.connection = connection;
metaDataContexts = connection.getContextManager().getMetaDataContexts();
databaseName = connection.getDatabaseName();
statements = new LinkedList<>();
statementOption = new StatementOption(resultSetType, resultSetConcurrency, resultSetHoldability);
executor = new DriverExecutor(connection);
kernelProcessor = new KernelProcessor();
trafficRule = metaDataContexts.getMetaData().getGlobalRuleMetaData().getSingleRule(TrafficRule.class);
statementManager = new StatementManager();
batchStatementExecutor = new BatchStatementExecutor(this);
databaseName = connection.getDatabaseName();
}

@Override
Expand All @@ -166,13 +162,16 @@ public ResultSet executeQuery(final String sql) throws SQLException {
handleAutoCommit(queryContext);
databaseName = queryContext.getDatabaseNameFromSQLStatement().orElse(connection.getDatabaseName());
connection.getDatabaseConnectionManager().getConnectionContext().setCurrentDatabase(databaseName);
trafficInstanceId = getInstanceIdAndSet(queryContext).orElse(null);
String trafficInstanceId = getInstanceIdAndSet(queryContext).orElse(null);
if (null != trafficInstanceId) {
return executor.getTrafficExecutor().execute(createTrafficExecutionUnit(trafficInstanceId, queryContext), Statement::executeQuery);
result = executor.getTrafficExecutor().execute(createTrafficExecutionUnit(trafficInstanceId, queryContext), Statement::executeQuery);
currentResultSet = result;
return result;
}
useFederation = decide(queryContext, metaDataContexts.getMetaData().getDatabase(databaseName), metaDataContexts.getMetaData().getGlobalRuleMetaData());
if (useFederation) {
return executeFederationQuery(queryContext);
if (decide(queryContext, metaDataContexts.getMetaData().getDatabase(databaseName), metaDataContexts.getMetaData().getGlobalRuleMetaData())) {
result = executeFederationQuery(queryContext);
currentResultSet = result;
return result;
}
executionContext = createExecutionContext(queryContext);
result = doExecuteQuery(executionContext);
Expand All @@ -181,8 +180,6 @@ public ResultSet executeQuery(final String sql) throws SQLException {
// CHECKSTYLE:ON
handleExceptionInTransaction(connection, metaDataContexts);
throw SQLExceptionTransformEngine.toSQLException(ex, metaDataContexts.getMetaData().getDatabase(databaseName).getProtocolType());
} finally {
currentResultSet = null;
}
currentResultSet = result;
return result;
Expand Down Expand Up @@ -320,7 +317,7 @@ private int executeUpdate0(final String sql, final ExecuteUpdateCallback updateC
handleAutoCommit(queryContext);
databaseName = queryContext.getDatabaseNameFromSQLStatement().orElse(connection.getDatabaseName());
connection.getDatabaseConnectionManager().getConnectionContext().setCurrentDatabase(databaseName);
trafficInstanceId = getInstanceIdAndSet(queryContext).orElse(null);
String trafficInstanceId = getInstanceIdAndSet(queryContext).orElse(null);
if (null != trafficInstanceId) {
JDBCExecutionUnit executionUnit = createTrafficExecutionUnit(trafficInstanceId, queryContext);
return executor.getTrafficExecutor().execute(executionUnit, trafficCallback);
Expand Down Expand Up @@ -421,30 +418,29 @@ public boolean execute(final String sql, final String[] columnNames) throws SQLE
}

private boolean execute0(final String sql, final ExecuteCallback executeCallback, final TrafficExecutorCallback<Boolean> trafficCallback) throws SQLException {
try {
QueryContext queryContext = createQueryContext(sql);
handleAutoCommit(queryContext);
databaseName = queryContext.getDatabaseNameFromSQLStatement().orElse(connection.getDatabaseName());
connection.getDatabaseConnectionManager().getConnectionContext().setCurrentDatabase(databaseName);
trafficInstanceId = getInstanceIdAndSet(queryContext).orElse(null);
if (null != trafficInstanceId) {
JDBCExecutionUnit executionUnit = createTrafficExecutionUnit(trafficInstanceId, queryContext);
return executor.getTrafficExecutor().execute(executionUnit, trafficCallback);
}
useFederation = decide(queryContext, metaDataContexts.getMetaData().getDatabase(databaseName), metaDataContexts.getMetaData().getGlobalRuleMetaData());
if (useFederation) {
ResultSet resultSet = executeFederationQuery(queryContext);
return null != resultSet;
}
executionContext = createExecutionContext(queryContext);
if (!metaDataContexts.getMetaData().getDatabase(databaseName).getRuleMetaData().getAttributes(RawExecutionRuleAttribute.class).isEmpty()) {
Collection<ExecuteResult> results = executor.getRawExecutor().execute(createRawExecutionContext(executionContext), executionContext.getQueryContext(), new RawSQLExecutorCallback());
return results.iterator().next() instanceof QueryResult;
}
return executeWithExecutionContext(executeCallback, executionContext);
} finally {
currentResultSet = null;
QueryContext queryContext = createQueryContext(sql);
handleAutoCommit(queryContext);
databaseName = queryContext.getDatabaseNameFromSQLStatement().orElse(connection.getDatabaseName());
connection.getDatabaseConnectionManager().getConnectionContext().setCurrentDatabase(databaseName);
String trafficInstanceId = getInstanceIdAndSet(queryContext).orElse(null);
if (null != trafficInstanceId) {
JDBCExecutionUnit executionUnit = createTrafficExecutionUnit(trafficInstanceId, queryContext);
boolean result = executor.getTrafficExecutor().execute(executionUnit, trafficCallback);
currentResultSet = executor.getTrafficExecutor().getResultSet();
return result;
}
if (decide(queryContext, metaDataContexts.getMetaData().getDatabase(databaseName), metaDataContexts.getMetaData().getGlobalRuleMetaData())) {
ResultSet resultSet = executeFederationQuery(queryContext);
currentResultSet = resultSet;
return null != resultSet;
}
currentResultSet = null;
executionContext = createExecutionContext(queryContext);
if (!metaDataContexts.getMetaData().getDatabase(databaseName).getRuleMetaData().getAttributes(RawExecutionRuleAttribute.class).isEmpty()) {
Collection<ExecuteResult> results = executor.getRawExecutor().execute(createRawExecutionContext(executionContext), executionContext.getQueryContext(), new RawSQLExecutorCallback());
return results.iterator().next() instanceof QueryResult;
}
return executeWithExecutionContext(executeCallback, executionContext);
}

private void handleAutoCommit(final QueryContext queryContext) throws SQLException {
Expand Down Expand Up @@ -563,14 +559,7 @@ public ResultSet getResultSet() throws SQLException {
if (null != currentResultSet) {
return currentResultSet;
}
if (null != trafficInstanceId) {
return executor.getTrafficExecutor().getResultSet();
}
if (useFederation) {
return executor.getSqlFederationEngine().getResultSet();
}
if (executionContext.getSqlStatementContext() instanceof SelectStatementContext
|| executionContext.getSqlStatementContext().getSqlStatement() instanceof DALStatement) {
if (executionContext.getSqlStatementContext() instanceof SelectStatementContext || executionContext.getSqlStatementContext().getSqlStatement() instanceof DALStatement) {
List<ResultSet> resultSets = getResultSets();
if (resultSets.isEmpty()) {
return currentResultSet;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,17 @@
public abstract class AbstractUnsupportedOperationStatement extends WrapperAdapter implements Statement {

@Override
public final void closeOnCompletion() throws SQLException {
public void closeOnCompletion() throws SQLException {
throw new SQLFeatureNotSupportedException("closeOnCompletion");
}

@Override
public final boolean isCloseOnCompletion() throws SQLException {
public boolean isCloseOnCompletion() throws SQLException {
throw new SQLFeatureNotSupportedException("isCloseOnCompletion");
}

@Override
public final void setCursorName(final String name) throws SQLException {
public void setCursorName(final String name) throws SQLException {
throw new SQLFeatureNotSupportedException("setCursorName");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -60,16 +60,6 @@ void setUp() {
shardingSphereStatement = new ShardingSphereStatement(connection);
}

@Test
void assertCloseOnCompletion() {
assertThrows(SQLFeatureNotSupportedException.class, () -> shardingSphereStatement.closeOnCompletion());
}

@Test
void assertIsCloseOnCompletion() {
assertThrows(SQLFeatureNotSupportedException.class, () -> shardingSphereStatement.isCloseOnCompletion());
}

@Test
void assertSetCursorName() {
assertThrows(SQLFeatureNotSupportedException.class, () -> shardingSphereStatement.setCursorName("cursorName"));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.shardingsphere.traffic.executor;

import lombok.Getter;
import org.apache.shardingsphere.infra.executor.sql.context.SQLUnit;
import org.apache.shardingsphere.infra.executor.sql.execute.engine.driver.jdbc.JDBCExecutionUnit;

Expand All @@ -33,6 +34,9 @@ public final class TrafficExecutor implements AutoCloseable {

private Statement statement;

@Getter
private ResultSet resultSet;

/**
* Execute.
*
Expand All @@ -45,7 +49,9 @@ public final class TrafficExecutor implements AutoCloseable {
public <T> T execute(final JDBCExecutionUnit executionUnit, final TrafficExecutorCallback<T> callback) throws SQLException {
SQLUnit sqlUnit = executionUnit.getExecutionUnit().getSqlUnit();
cacheStatement(sqlUnit.getParameters(), executionUnit.getStorageResource());
return callback.execute(statement, sqlUnit.getSql());
T result = callback.execute(statement, sqlUnit.getSql());
resultSet = statement.getResultSet();
return result;
}

private void cacheStatement(final List<Object> params, final Statement statement) throws SQLException {
Expand All @@ -63,16 +69,6 @@ private void setParameters(final Statement statement, final List<Object> params)
}
}

/**
* Get result set.
*
* @return result set
* @throws SQLException SQL exception
*/
public ResultSet getResultSet() throws SQLException {
return statement.getResultSet();
}

@Override
public void close() throws SQLException {
if (null != statement) {
Expand Down