Skip to content

Commit

Permalink
Refactor merge statement. (#28486)
Browse files Browse the repository at this point in the history
  • Loading branch information
tuichenchuxin authored Sep 21, 2023
1 parent 1aa5887 commit 00e4ffb
Show file tree
Hide file tree
Showing 9 changed files with 144 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import org.apache.shardingsphere.sql.parser.sql.dialect.statement.oracle.dml.OracleMergeStatement;

/**
* Load xml statement context.
* Merge statement context.
*/
@Getter
public final class MergeStatementContext extends CommonSQLStatementContext {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.assignment.SetAssignmentSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.ColumnSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExpressionSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExpressionWithParamsSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.SimpleTableSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.TableSegment;
import org.apache.shardingsphere.sql.parser.sql.common.statement.dml.InsertStatement;
Expand Down Expand Up @@ -76,7 +77,12 @@ private MergeStatement bind(final MergeStatement sqlStatement, final ShardingSph
Map<String, TableSegmentBinderContext> tableBinderContexts = new LinkedHashMap<>();
tableBinderContexts.putAll(sourceTableBinderContexts);
tableBinderContexts.putAll(targetTableBinderContexts);
result.setExpr(ExpressionSegmentBinder.bind(sqlStatement.getExpr(), SegmentType.JOIN_ON, statementBinderContext, tableBinderContexts, Collections.emptyMap()));
if (sqlStatement.getExpression() != null) {
ExpressionWithParamsSegment expression = new ExpressionWithParamsSegment(sqlStatement.getExpression().getStartIndex(), sqlStatement.getExpression().getStopIndex(),
ExpressionSegmentBinder.bind(sqlStatement.getExpression().getExpr(), SegmentType.JOIN_ON, statementBinderContext, tableBinderContexts, Collections.emptyMap()));
expression.getParameterMarkerSegments().addAll(sqlStatement.getExpression().getParameterMarkerSegments());
result.setExpression(expression);
}
result.setInsert(Optional.ofNullable(sqlStatement.getInsert()).map(optional -> bindMergeInsert(optional,
(SimpleTableSegment) boundedTargetTableSegment, statementBinderContext, targetTableBinderContexts, sourceTableBinderContexts)).orElse(null));
result.setUpdate(Optional.ofNullable(sqlStatement.getUpdate()).map(optional -> bindMergeUpdate(optional,
Expand Down Expand Up @@ -142,6 +148,8 @@ private UpdateStatement bindMergeUpdate(final UpdateStatement sqlStatement, fina
SetAssignmentSegment setAssignmentSegment = new SetAssignmentSegment(sqlStatement.getSetAssignment().getStartIndex(), sqlStatement.getSetAssignment().getStopIndex(), assignments);
result.setSetAssignment(setAssignmentSegment);
sqlStatement.getWhere().ifPresent(optional -> result.setWhere(WhereSegmentBinder.bind(optional, updateStatementBinderContext, targetTableBinderContexts, Collections.emptyMap())));
UpdateStatementHandler.getDeleteWhereSegment(sqlStatement).ifPresent(optional -> UpdateStatementHandler.setDeleteWhereSegment(result,
WhereSegmentBinder.bind(optional, updateStatementBinderContext, targetTableBinderContexts, Collections.emptyMap())));
UpdateStatementHandler.getOrderBySegment(sqlStatement).ifPresent(optional -> UpdateStatementHandler.setOrderBySegment(result, optional));
UpdateStatementHandler.getLimitSegment(sqlStatement).ifPresent(optional -> UpdateStatementHandler.setLimitSegment(result, optional));
UpdateStatementHandler.getWithSegment(sqlStatement).ifPresent(optional -> UpdateStatementHandler.setWithSegment(result, optional));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.assignment.SetAssignmentSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.ColumnSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.BinaryOperationExpression;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExpressionWithParamsSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.simple.LiteralExpressionSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.subquery.SubquerySegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.item.ExpressionProjectionSegment;
Expand Down Expand Up @@ -67,8 +68,8 @@ void assertBind() {
SimpleTableSegment sourceTable = new SimpleTableSegment(new TableNameSegment(0, 0, new IdentifierValue("t_order_item")));
sourceTable.setAlias(new AliasSegment(0, 0, new IdentifierValue("b")));
mergeStatement.setSource(sourceTable);
mergeStatement.setExpr(new BinaryOperationExpression(0, 0, new ColumnSegment(0, 0, new IdentifierValue("id")),
new ColumnSegment(0, 0, new IdentifierValue("order_id")), "=", "id = order_id"));
mergeStatement.setExpression(new ExpressionWithParamsSegment(0, 0, new BinaryOperationExpression(0, 0, new ColumnSegment(0, 0, new IdentifierValue("id")),
new ColumnSegment(0, 0, new IdentifierValue("order_id")), "=", "id = order_id")));
UpdateStatement updateStatement = new OracleUpdateStatement();
updateStatement.setTable(targetTable);
ColumnSegment targetTableColumn = new ColumnSegment(0, 0, new IdentifierValue("status"));
Expand Down Expand Up @@ -139,4 +140,33 @@ void assertBindWithSubQuery() {
MergeStatement actual = new MergeStatementBinder().bind(mergeStatement, createMetaData(), DefaultDatabase.LOGIC_NAME);
assertThat(actual, not(mergeStatement));
}

@Test
void assertBindUpdateDeleteWhere() {
MergeStatement mergeStatement = new OracleMergeStatement();
SimpleTableSegment targetTable = new SimpleTableSegment(new TableNameSegment(0, 0, new IdentifierValue("t_order")));
targetTable.setAlias(new AliasSegment(0, 0, new IdentifierValue("a")));
mergeStatement.setTarget(targetTable);
SimpleTableSegment sourceTable = new SimpleTableSegment(new TableNameSegment(0, 0, new IdentifierValue("t_order_item")));
sourceTable.setAlias(new AliasSegment(0, 0, new IdentifierValue("b")));
mergeStatement.setSource(sourceTable);
OracleUpdateStatement updateStatement = new OracleUpdateStatement();
updateStatement.setTable(targetTable);
ColumnSegment targetTableColumn = new ColumnSegment(0, 0, new IdentifierValue("status"));
targetTableColumn.setOwner(new OwnerSegment(0, 0, new IdentifierValue("a")));
ColumnSegment sourceTableColumn = new ColumnSegment(0, 0, new IdentifierValue("status"));
sourceTableColumn.setOwner(new OwnerSegment(0, 0, new IdentifierValue("b")));
SetAssignmentSegment setAssignmentSegment = new SetAssignmentSegment(0, 0,
Collections.singletonList(new ColumnAssignmentSegment(0, 0, Collections.singletonList(targetTableColumn), sourceTableColumn)));
updateStatement.setSetAssignment(setAssignmentSegment);
updateStatement.setDeleteWhere(new WhereSegment(0, 0, new BinaryOperationExpression(0, 0, new ColumnSegment(0, 0, new IdentifierValue("item_id")),
new LiteralExpressionSegment(0, 0, 1), "=", "item_id = 1")));
mergeStatement.setUpdate(updateStatement);
MergeStatement actual = new MergeStatementBinder().bind(mergeStatement, createMetaData(), DefaultDatabase.LOGIC_NAME);
assertThat(actual.getUpdate(), instanceOf(OracleUpdateStatement.class));
assertThat(((OracleUpdateStatement) actual.getUpdate()).getDeleteWhere().getExpr(), instanceOf(BinaryOperationExpression.class));
assertThat(((BinaryOperationExpression) ((OracleUpdateStatement) actual.getUpdate()).getDeleteWhere().getExpr()).getLeft(), instanceOf(ColumnSegment.class));
assertThat(((ColumnSegment) ((BinaryOperationExpression) ((OracleUpdateStatement) actual.getUpdate()).getDeleteWhere().getExpr()).getLeft())
.getColumnBoundedInfo().getOriginalTable().getValue(), is("t_order_item"));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.CollateExpression;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.DatetimeExpression;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExpressionSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExpressionWithParamsSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.FunctionSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.InExpression;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.MultisetExpression;
Expand Down Expand Up @@ -1215,7 +1216,10 @@ public ASTNode visitMerge(final MergeContext ctx) {
OracleMergeStatement result = new OracleMergeStatement();
result.setTarget((TableSegment) visit(ctx.intoClause()));
result.setSource((TableSegment) visit(ctx.usingClause()));
result.setExpr((ExpressionSegment) visit(ctx.usingClause().expr()));
ExpressionWithParamsSegment onExpression = new ExpressionWithParamsSegment(ctx.usingClause().expr().start.getStartIndex(), ctx.usingClause().expr().stop.getStopIndex(),
(ExpressionSegment) visit(ctx.usingClause().expr()));
onExpression.getParameterMarkerSegments().addAll(popAllStatementParameterMarkerSegments());
result.setExpression(onExpression);
if (null != ctx.mergeUpdateClause()) {
result.setUpdate((UpdateStatement) visitMergeUpdateClause(ctx.mergeUpdateClause()));
}
Expand All @@ -1238,6 +1242,7 @@ public ASTNode visitMergeInsertClause(final MergeInsertClauseContext ctx) {
if (null != ctx.whereClause()) {
result.setWhere((WhereSegment) visit(ctx.whereClause()));
}
result.getParameterMarkerSegments().addAll(popAllStatementParameterMarkerSegments());
return result;
}

Expand All @@ -1259,7 +1264,7 @@ public ASTNode visitMergeColumnValue(final MergeColumnValueContext ctx) {
for (ExprContext each : ctx.expr()) {
segments.add(null == each ? new CommonExpressionSegment(ctx.getStart().getStartIndex(), ctx.getStop().getStopIndex(), ctx.getText()) : (ExpressionSegment) visit(each));
}
result.getValue().add(new InsertValuesSegment(ctx.getStart().getStartIndex(), ctx.getStop().getStopIndex(), segments));
result.getValue().add(new InsertValuesSegment(ctx.LP_().getSymbol().getStartIndex(), ctx.RP_().getSymbol().getStopIndex(), segments));
return result;
}

Expand Down Expand Up @@ -1306,6 +1311,7 @@ public ASTNode visitUsingClause(final UsingClauseContext ctx) {
}
OracleSelectStatement subquery = (OracleSelectStatement) visit(ctx.subquery());
SubquerySegment subquerySegment = new SubquerySegment(ctx.subquery().start.getStartIndex(), ctx.subquery().stop.getStopIndex(), subquery);
subquerySegment.getSelect().getParameterMarkerSegments().addAll(popAllStatementParameterMarkerSegments());
SubqueryTableSegment result = new SubqueryTableSegment(subquerySegment);
if (null != ctx.alias()) {
result.setAlias((AliasSegment) visit(ctx.alias()));
Expand All @@ -1323,6 +1329,7 @@ public ASTNode visitMergeUpdateClause(final MergeUpdateClauseContext ctx) {
if (null != ctx.deleteWhereClause()) {
result.setDeleteWhere((WhereSegment) visit(ctx.deleteWhereClause()));
}
result.getParameterMarkerSegments().addAll(popAllStatementParameterMarkerSegments());
return result;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr;

import lombok.Getter;
import lombok.RequiredArgsConstructor;
import org.apache.shardingsphere.sql.parser.sql.common.segment.SQLSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.ParameterMarkerSegment;

import java.util.Collection;
import java.util.LinkedList;

/**
* Expression with parameters segment.
*/
@RequiredArgsConstructor
@Getter
public final class ExpressionWithParamsSegment implements SQLSegment {

private final int startIndex;

private final int stopIndex;

private final ExpressionSegment expr;

private final Collection<ParameterMarkerSegment> parameterMarkerSegments = new LinkedList<>();
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

import lombok.Getter;
import lombok.Setter;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExpressionSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExpressionWithParamsSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.TableSegment;
import org.apache.shardingsphere.sql.parser.sql.common.statement.AbstractSQLStatement;

Expand All @@ -34,7 +34,7 @@ public abstract class MergeStatement extends AbstractSQLStatement implements DML

private TableSegment source;

private ExpressionSegment expr;
private ExpressionWithParamsSegment expression;

private UpdateStatement update;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,13 @@
import lombok.NoArgsConstructor;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.order.OrderBySegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.pagination.limit.LimitSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.predicate.WhereSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.WithSegment;
import org.apache.shardingsphere.sql.parser.sql.common.statement.dml.UpdateStatement;
import org.apache.shardingsphere.sql.parser.sql.dialect.handler.SQLStatementHandler;
import org.apache.shardingsphere.sql.parser.sql.dialect.statement.mysql.MySQLStatement;
import org.apache.shardingsphere.sql.parser.sql.dialect.statement.mysql.dml.MySQLUpdateStatement;
import org.apache.shardingsphere.sql.parser.sql.dialect.statement.oracle.dml.OracleUpdateStatement;
import org.apache.shardingsphere.sql.parser.sql.dialect.statement.sqlserver.SQLServerStatement;
import org.apache.shardingsphere.sql.parser.sql.dialect.statement.sqlserver.dml.SQLServerUpdateStatement;

Expand Down Expand Up @@ -76,6 +78,19 @@ public static Optional<WithSegment> getWithSegment(final UpdateStatement updateS
return Optional.empty();
}

/**
* Get delete where segment.
*
* @param updateStatement update statement
* @return delete where segment
*/
public static Optional<WhereSegment> getDeleteWhereSegment(final UpdateStatement updateStatement) {
if (updateStatement instanceof OracleUpdateStatement) {
return Optional.ofNullable(((OracleUpdateStatement) updateStatement).getDeleteWhere());
}
return Optional.empty();
}

/**
* Set order by segment.
*
Expand Down Expand Up @@ -111,4 +126,16 @@ public static void setWithSegment(final UpdateStatement updateStatement, final W
((SQLServerUpdateStatement) updateStatement).setWithSegment(withSegment);
}
}

/**
* Set delete where segment.
*
* @param updateStatement update statement
* @param deleteWhereSegment delete where segment
*/
public static void setDeleteWhereSegment(final UpdateStatement updateStatement, final WhereSegment deleteWhereSegment) {
if (updateStatement instanceof OracleUpdateStatement) {
((OracleUpdateStatement) updateStatement).setDeleteWhere(deleteWhereSegment);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,9 @@ private static void assertTable(final SQLCaseAssertContext assertContext, final

private static void assertExpression(final SQLCaseAssertContext assertContext, final MergeStatement actual, final MergeStatementTestCase expected) {
if (null == expected.getExpr()) {
assertNull(actual.getExpr(), assertContext.getText("Actual expression should not exist."));
assertNull(actual.getExpression(), assertContext.getText("Actual expression should not exist."));
} else {
ExpressionAssert.assertExpression(assertContext, actual.getExpr(), expected.getExpr());
ExpressionAssert.assertExpression(assertContext, actual.getExpression().getExpr(), expected.getExpr());
}
}

Expand Down
20 changes: 20 additions & 0 deletions test/it/parser/src/main/resources/case/dml/merge.xml
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,26 @@
</binary-operation-expression>
</expr>
<insert>
<values>
<value>
<assignment-value>
<column name="employee_id" start-index="331" stop-index="343">
<owner name="S" start-index="331" stop-index="331" />
</column>
<binary-operation-expression start-index="67" stop-index="74">
<left>
<column name="salary" start-index="346" stop-index="353">
<owner name="S" start-index="346" stop-index="346" />
</column>
</left>
<operator>*</operator>
<right>
<literal-expression value=".01" start-index="355" stop-index="357" />
</right>
</binary-operation-expression>
</assignment-value>
</value>
</values>
<where start-index="365" stop-index="388" literal-start-index="365" literal-stop-index="388">
<expr>
<binary-operation-expression start-index="372" stop-index="387" literal-start-index="372" literal-stop-index="387">
Expand Down

0 comments on commit 00e4ffb

Please sign in to comment.