Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor BaseDMLE2EIT and insert select statement parse logic #28457

Merged
merged 1 commit into from
Sep 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ public boolean isGenerateSQLToken(final SQLStatementContext sqlStatementContext)

@Override
public UseDefaultInsertColumnsToken generateSQLToken(final InsertStatementContext insertStatementContext) {
String tableName = insertStatementContext.getSqlStatement().getTable().getTableName().getIdentifier().getValue();
String tableName = Optional.ofNullable(insertStatementContext.getSqlStatement().getTable()).map(optional -> optional.getTableName().getIdentifier().getValue()).orElse("");
Optional<UseDefaultInsertColumnsToken> previousSQLToken = findInsertColumnsToken();
if (previousSQLToken.isPresent()) {
processPreviousSQLToken(previousSQLToken.get(), insertStatementContext, tableName);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ public final class GeneratedKeyContextEngine {
* @return generate key context
*/
public Optional<GeneratedKeyContext> createGenerateKeyContext(final List<String> insertColumnNames, final List<List<ExpressionSegment>> valueExpressions, final List<Object> params) {
String tableName = insertStatement.getTable().getTableName().getIdentifier().getValue();
String tableName = Optional.ofNullable(insertStatement.getTable()).map(optional -> optional.getTableName().getIdentifier().getValue()).orElse("");
return findGenerateKeyColumn(tableName).map(optional -> containsGenerateKey(insertColumnNames, optional)
? findGeneratedKey(insertColumnNames, valueExpressions, params, optional)
: new GeneratedKeyContext(optional, true));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,8 @@ public InsertStatementContext(final ShardingSphereMetaData metaData, final List<
onDuplicateKeyUpdateValueContext = getOnDuplicateKeyUpdateValueContext(params, parametersOffset).orElse(null);
tablesContext = new TablesContext(getAllSimpleTableSegments(), getDatabaseType());
ShardingSphereSchema schema = getSchema(metaData, defaultDatabaseName);
columnNames = containsInsertColumns() ? insertColumnNames : schema.getVisibleColumnNames(sqlStatement.getTable().getTableName().getIdentifier().getValue().toLowerCase());
columnNames = containsInsertColumns() ? insertColumnNames
: Optional.ofNullable(sqlStatement.getTable()).map(optional -> schema.getVisibleColumnNames(optional.getTableName().getIdentifier().getValue())).orElseGet(Collections::emptyList);
generatedKeyContext = new GeneratedKeyContextEngine(sqlStatement, schema).createGenerateKeyContext(insertColumnNames, getAllValueExpressions(sqlStatement), params).orElse(null);
}

Expand Down Expand Up @@ -166,7 +167,7 @@ public List<List<Object>> getGroupedParameters() {
for (InsertValueContext each : insertValueContexts) {
result.add(each.getParameters());
}
if (null != insertSelectContext) {
if (null != insertSelectContext && !insertSelectContext.getParameters().isEmpty()) {
result.add(insertSelectContext.getParameters());
}
return result;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.Map;
import java.util.Optional;

/**
* Select statement binder.
Expand All @@ -54,7 +55,7 @@ private InsertStatement bind(final InsertStatement sqlStatement, final ShardingS
SQLStatementBinderContext statementBinderContext = new SQLStatementBinderContext(metaData, defaultDatabaseName, sqlStatement.getDatabaseType(), sqlStatement.getVariableNames());
statementBinderContext.getExternalTableBinderContexts().putAll(externalTableBinderContexts);
Map<String, TableSegmentBinderContext> tableBinderContexts = new LinkedHashMap<>();
result.setTable(SimpleTableSegmentBinder.bind(sqlStatement.getTable(), statementBinderContext, tableBinderContexts));
Optional.ofNullable(sqlStatement.getTable()).ifPresent(optional -> result.setTable(SimpleTableSegmentBinder.bind(optional, statementBinderContext, tableBinderContexts)));
if (sqlStatement.getInsertColumns().isPresent() && !sqlStatement.getInsertColumns().get().getColumns().isEmpty()) {
result.setInsertColumns(InsertColumnsSegmentBinder.bind(sqlStatement.getInsertColumns().get(), statementBinderContext, tableBinderContexts));
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.apache.shardingsphere.infra.metadata.database.resource.ResourceMetaData;
import org.apache.shardingsphere.infra.metadata.database.rule.RuleMetaData;
import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereSchema;
import org.apache.shardingsphere.sql.parser.sql.common.enums.ParameterMarkerType;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.assignment.AssignmentSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.assignment.ColumnAssignmentSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.assignment.InsertValuesSegment;
Expand Down Expand Up @@ -161,15 +162,16 @@ void assertGetGroupedParametersWithOnDuplicateParameters() {
void assertInsertSelect() {
InsertStatement insertStatement = new MySQLInsertStatement();
SelectStatement selectStatement = new MySQLSelectStatement();
selectStatement.addParameterMarkerSegments(Collections.singleton(new ParameterMarkerExpressionSegment(0, 0, 0, ParameterMarkerType.QUESTION)));
selectStatement.setProjections(new ProjectionsSegment(0, 0));
SubquerySegment insertSelect = new SubquerySegment(0, 0, selectStatement);
insertStatement.setInsertSelect(insertSelect);
insertStatement.setTable(new SimpleTableSegment(new TableNameSegment(0, 0, new IdentifierValue("tbl"))));
InsertStatementContext actual = createInsertStatementContext(Collections.singletonList("param"), insertStatement);
actual.setUpParameters(Collections.singletonList("param"));
assertThat(actual.getInsertSelectContext().getParameterCount(), is(0));
assertThat(actual.getInsertSelectContext().getParameterCount(), is(1));
assertThat(actual.getGroupedParameters().size(), is(1));
assertThat(actual.getGroupedParameters().iterator().next(), is(Collections.emptyList()));
assertThat(actual.getGroupedParameters().iterator().next(), is(Collections.singletonList("param")));
}

private void setUpInsertValues(final InsertStatement insertStatement) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,18 @@ public SQLRewriteContext(final ShardingSphereDatabase database, final SQLStateme
if (!hintValueContext.isSkipSQLRewrite()) {
addSQLTokenGenerators(new DefaultTokenGeneratorBuilder(sqlStatementContext).getSQLTokenGenerators());
}
parameterBuilder = sqlStatementContext instanceof InsertStatementContext && null == ((InsertStatementContext) sqlStatementContext).getInsertSelectContext()
? new GroupedParameterBuilder(
((InsertStatementContext) sqlStatementContext).getGroupedParameters(), ((InsertStatementContext) sqlStatementContext).getOnDuplicateKeyUpdateParameters())
parameterBuilder = containsInsertValues(sqlStatementContext)
? new GroupedParameterBuilder(((InsertStatementContext) sqlStatementContext).getGroupedParameters(), ((InsertStatementContext) sqlStatementContext).getOnDuplicateKeyUpdateParameters())
: new StandardParameterBuilder(params);
}

private boolean containsInsertValues(final SQLStatementContext sqlStatementContext) {
if (!(sqlStatementContext instanceof InsertStatementContext)) {
return false;
}
return null == ((InsertStatementContext) sqlStatementContext).getInsertSelectContext();
}

/**
* Add SQL token generators.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,9 @@ void assertRewriteWithStandardParameterBuilderWhenNeedAggregateRewrite() {
void assertRewriteWithGroupedParameterBuilderForBroadcast() {
InsertStatementContext statementContext = mock(InsertStatementContext.class, RETURNS_DEEP_STUBS);
when(((TableAvailable) statementContext).getTablesContext().getDatabaseName().isPresent()).thenReturn(false);
when(statementContext.getInsertSelectContext()).thenReturn(null);
when(statementContext.getGroupedParameters()).thenReturn(Collections.singletonList(Collections.singletonList(1)));
when(statementContext.getOnDuplicateKeyUpdateParameters()).thenReturn(Collections.emptyList());
SQLRewriteContext sqlRewriteContext =
new SQLRewriteContext(mockDatabase(), statementContext, "INSERT INTO tbl VALUES (?)", Collections.singletonList(1), mock(ConnectionContext.class), new HintValueContext());
RouteUnit routeUnit = new RouteUnit(new RouteMapper("ds", "ds_0"), Collections.singletonList(new RouteMapper("tbl", "tbl_0")));
Expand All @@ -107,7 +109,9 @@ void assertRewriteWithGroupedParameterBuilderForBroadcast() {
void assertRewriteWithGroupedParameterBuilderForRouteWithSameDataNode() {
InsertStatementContext statementContext = mock(InsertStatementContext.class, RETURNS_DEEP_STUBS);
when(((TableAvailable) statementContext).getTablesContext().getDatabaseName().isPresent()).thenReturn(false);
when(statementContext.getInsertSelectContext()).thenReturn(null);
when(statementContext.getGroupedParameters()).thenReturn(Collections.singletonList(Collections.singletonList(1)));
when(statementContext.getOnDuplicateKeyUpdateParameters()).thenReturn(Collections.emptyList());
SQLRewriteContext sqlRewriteContext =
new SQLRewriteContext(mockDatabase(), statementContext, "INSERT INTO tbl VALUES (?)", Collections.singletonList(1), mock(ConnectionContext.class), new HintValueContext());
RouteUnit routeUnit = new RouteUnit(new RouteMapper("ds", "ds_0"), Collections.singletonList(new RouteMapper("tbl", "tbl_0")));
Expand All @@ -127,7 +131,9 @@ void assertRewriteWithGroupedParameterBuilderForRouteWithSameDataNode() {
void assertRewriteWithGroupedParameterBuilderForRouteWithEmptyDataNode() {
InsertStatementContext statementContext = mock(InsertStatementContext.class, RETURNS_DEEP_STUBS);
when(((TableAvailable) statementContext).getTablesContext().getDatabaseName().isPresent()).thenReturn(false);
when(statementContext.getInsertSelectContext()).thenReturn(null);
when(statementContext.getGroupedParameters()).thenReturn(Collections.singletonList(Collections.singletonList(1)));
when(statementContext.getOnDuplicateKeyUpdateParameters()).thenReturn(Collections.emptyList());
SQLRewriteContext sqlRewriteContext =
new SQLRewriteContext(mockDatabase(), statementContext, "INSERT INTO tbl VALUES (?)", Collections.singletonList(1), mock(ConnectionContext.class), new HintValueContext());
RouteUnit routeUnit = new RouteUnit(new RouteMapper("ds", "ds_0"), Collections.singletonList(new RouteMapper("tbl", "tbl_0")));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -177,9 +177,9 @@
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ListExpression;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.MatchAgainstExpression;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.NotExpression;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ValuesExpression;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.RowExpression;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.UnaryOperationExpression;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ValuesExpression;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.complex.CommonExpressionSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.simple.LiteralExpressionSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.simple.ParameterMarkerExpressionSegment;
Expand Down Expand Up @@ -1361,6 +1361,7 @@ public ASTNode visitInsert(final InsertContext ctx) {
@Override
public ASTNode visitInsertSelectClause(final InsertSelectClauseContext ctx) {
MySQLInsertStatement result = new MySQLInsertStatement();
result.setInsertSelect(createInsertSelectSegment(ctx));
if (null != ctx.LP_()) {
if (null != ctx.fields()) {
result.setInsertColumns(new InsertColumnsSegment(ctx.LP_().getSymbol().getStartIndex(), ctx.RP_().getSymbol().getStopIndex(), createInsertColumns(ctx.fields())));
Expand All @@ -1370,12 +1371,12 @@ public ASTNode visitInsertSelectClause(final InsertSelectClauseContext ctx) {
} else {
result.setInsertColumns(new InsertColumnsSegment(ctx.start.getStartIndex() - 1, ctx.start.getStartIndex() - 1, Collections.emptyList()));
}
result.setInsertSelect(createInsertSelectSegment(ctx));
return result;
}

private SubquerySegment createInsertSelectSegment(final InsertSelectClauseContext ctx) {
MySQLSelectStatement selectStatement = (MySQLSelectStatement) visit(ctx.select());
selectStatement.getParameterMarkerSegments().addAll(getParameterMarkerSegments());
return new SubquerySegment(ctx.select().start.getStartIndex(), ctx.select().stop.getStopIndex(), selectStatement);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -751,6 +751,13 @@ public ASTNode visitQualifiedName(final QualifiedNameContext ctx) {
@Override
public ASTNode visitInsertRest(final InsertRestContext ctx) {
OpenGaussInsertStatement result = new OpenGaussInsertStatement();
ValuesClauseContext valuesClause = ctx.select().selectNoParens().selectClauseN().simpleSelect().valuesClause();
if (null == valuesClause) {
OpenGaussSelectStatement selectStatement = (OpenGaussSelectStatement) visit(ctx.select());
result.setInsertSelect(new SubquerySegment(ctx.select().start.getStartIndex(), ctx.select().stop.getStopIndex(), selectStatement));
} else {
result.getValues().addAll(createInsertValuesSegments(valuesClause));
}
if (null == ctx.insertColumnList()) {
result.setInsertColumns(new InsertColumnsSegment(ctx.start.getStartIndex() - 1, ctx.start.getStartIndex() - 1, Collections.emptyList()));
} else {
Expand All @@ -759,13 +766,6 @@ public ASTNode visitInsertRest(final InsertRestContext ctx) {
InsertColumnsSegment insertColumnsSegment = new InsertColumnsSegment(insertColumns.start.getStartIndex() - 1, insertColumns.stop.getStopIndex() + 1, columns.getValue());
result.setInsertColumns(insertColumnsSegment);
}
ValuesClauseContext valuesClause = ctx.select().selectNoParens().selectClauseN().simpleSelect().valuesClause();
if (null == valuesClause) {
OpenGaussSelectStatement selectStatement = (OpenGaussSelectStatement) visit(ctx.select());
result.setInsertSelect(new SubquerySegment(ctx.select().start.getStartIndex(), ctx.select().stop.getStopIndex(), selectStatement));
} else {
result.getValues().addAll(createInsertValuesSegments(valuesClause));
}
return result;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,7 @@ private Collection<InsertValuesSegment> createInsertValuesSegments(final Assignm
@Override
public ASTNode visitInsertMultiTable(final InsertMultiTableContext ctx) {
OracleInsertStatement result = new OracleInsertStatement();
result.setInsertSelect(new SubquerySegment(ctx.selectSubquery().start.getStartIndex(), ctx.selectSubquery().stop.getStopIndex(), (OracleSelectStatement) visit(ctx.selectSubquery())));
result.setMultiTableInsertType(null != ctx.conditionalInsertClause() && null != ctx.conditionalInsertClause().FIRST() ? MultiTableInsertType.FIRST : MultiTableInsertType.ALL);
List<MultiTableElementContext> multiTableElementContexts = ctx.multiTableElement();
if (null != multiTableElementContexts && !multiTableElementContexts.isEmpty()) {
Expand All @@ -336,9 +337,6 @@ public ASTNode visitInsertMultiTable(final InsertMultiTableContext ctx) {
} else {
result.setMultiTableConditionalIntoSegment((MultiTableConditionalIntoSegment) visit(ctx.conditionalInsertClause()));
}
OracleSelectStatement subquery = (OracleSelectStatement) visit(ctx.selectSubquery());
SubquerySegment subquerySegment = new SubquerySegment(ctx.selectSubquery().start.getStartIndex(), ctx.selectSubquery().stop.getStopIndex(), subquery);
result.setInsertSelect(subquerySegment);
result.addParameterMarkerSegments(getParameterMarkerSegments());
return result;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -756,6 +756,13 @@ public ASTNode visitQualifiedName(final QualifiedNameContext ctx) {
@Override
public ASTNode visitInsertRest(final InsertRestContext ctx) {
PostgreSQLInsertStatement result = new PostgreSQLInsertStatement();
ValuesClauseContext valuesClause = ctx.select().selectNoParens().selectClauseN().simpleSelect().valuesClause();
if (null == valuesClause) {
PostgreSQLSelectStatement selectStatement = (PostgreSQLSelectStatement) visit(ctx.select());
result.setInsertSelect(new SubquerySegment(ctx.select().start.getStartIndex(), ctx.select().stop.getStopIndex(), selectStatement));
} else {
result.getValues().addAll(createInsertValuesSegments(valuesClause));
}
if (null == ctx.insertColumnList()) {
result.setInsertColumns(new InsertColumnsSegment(ctx.start.getStartIndex() - 1, ctx.start.getStartIndex() - 1, Collections.emptyList()));
} else {
Expand All @@ -764,13 +771,6 @@ public ASTNode visitInsertRest(final InsertRestContext ctx) {
InsertColumnsSegment insertColumnsSegment = new InsertColumnsSegment(insertColumns.start.getStartIndex() - 1, insertColumns.stop.getStopIndex() + 1, columns.getValue());
result.setInsertColumns(insertColumnsSegment);
}
ValuesClauseContext valuesClause = ctx.select().selectNoParens().selectClauseN().simpleSelect().valuesClause();
if (null == valuesClause) {
PostgreSQLSelectStatement selectStatement = (PostgreSQLSelectStatement) visit(ctx.select());
result.setInsertSelect(new SubquerySegment(ctx.select().start.getStartIndex(), ctx.select().stop.getStopIndex(), selectStatement));
} else {
result.getValues().addAll(createInsertValuesSegments(valuesClause));
}
return result;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,13 @@ void tearDown() {
}

protected final void assertDataSet(final AssertionTestParameter testParam, final SingleE2EContainerComposer containerComposer, final int actualUpdateCount) throws SQLException {
assertThat("Only support single table for DML.", containerComposer.getDataSet().getMetaDataList().size(), is(1));
assertThat(actualUpdateCount, is(containerComposer.getDataSet().getUpdateCount()));
DataSetMetaData expectedDataSetMetaData = containerComposer.getDataSet().getMetaDataList().get(0);
for (DataSetMetaData each : containerComposer.getDataSet().getMetaDataList()) {
assertDataSet(testParam, containerComposer, each);
}
}

private void assertDataSet(final AssertionTestParameter testParam, final SingleE2EContainerComposer containerComposer, final DataSetMetaData expectedDataSetMetaData) throws SQLException {
for (String each : InlineExpressionParserFactory.newInstance().splitAndEvaluate(expectedDataSetMetaData.getDataNodes())) {
DataNode dataNode = new DataNode(each);
DataSource dataSource = containerComposer.getActualDataSourceMap().get(dataNode.getDataSourceName());
Expand Down
Loading