diff --git a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/statement/dml/MergeStatementContext.java b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/statement/dml/MergeStatementContext.java index 40ab28d71d29d..58f6d5ff6a03a 100644 --- a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/statement/dml/MergeStatementContext.java +++ b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/statement/dml/MergeStatementContext.java @@ -22,7 +22,7 @@ import org.apache.shardingsphere.sql.parser.sql.dialect.statement.oracle.dml.OracleMergeStatement; /** - * Load xml statement context. + * Merge statement context. */ @Getter public final class MergeStatementContext extends CommonSQLStatementContext { diff --git a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/statement/dml/MergeStatementBinder.java b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/statement/dml/MergeStatementBinder.java index 7a38ca7e139fc..1cbbfc18dd01c 100644 --- a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/statement/dml/MergeStatementBinder.java +++ b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/statement/dml/MergeStatementBinder.java @@ -34,6 +34,7 @@ import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.assignment.SetAssignmentSegment; import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.ColumnSegment; import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExpressionSegment; +import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExpressionWithParamsSegment; import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.SimpleTableSegment; import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.TableSegment; import org.apache.shardingsphere.sql.parser.sql.common.statement.dml.InsertStatement; @@ -76,7 +77,12 @@ private MergeStatement bind(final MergeStatement sqlStatement, final ShardingSph Map tableBinderContexts = new LinkedHashMap<>(); tableBinderContexts.putAll(sourceTableBinderContexts); tableBinderContexts.putAll(targetTableBinderContexts); - result.setExpr(ExpressionSegmentBinder.bind(sqlStatement.getExpr(), SegmentType.JOIN_ON, statementBinderContext, tableBinderContexts, Collections.emptyMap())); + if (sqlStatement.getExpression() != null) { + ExpressionWithParamsSegment expression = new ExpressionWithParamsSegment(sqlStatement.getExpression().getStartIndex(), sqlStatement.getExpression().getStopIndex(), + ExpressionSegmentBinder.bind(sqlStatement.getExpression().getExpr(), SegmentType.JOIN_ON, statementBinderContext, tableBinderContexts, Collections.emptyMap())); + expression.getParameterMarkerSegments().addAll(sqlStatement.getExpression().getParameterMarkerSegments()); + result.setExpression(expression); + } result.setInsert(Optional.ofNullable(sqlStatement.getInsert()).map(optional -> bindMergeInsert(optional, (SimpleTableSegment) boundedTargetTableSegment, statementBinderContext, targetTableBinderContexts, sourceTableBinderContexts)).orElse(null)); result.setUpdate(Optional.ofNullable(sqlStatement.getUpdate()).map(optional -> bindMergeUpdate(optional, @@ -142,6 +148,8 @@ private UpdateStatement bindMergeUpdate(final UpdateStatement sqlStatement, fina SetAssignmentSegment setAssignmentSegment = new SetAssignmentSegment(sqlStatement.getSetAssignment().getStartIndex(), sqlStatement.getSetAssignment().getStopIndex(), assignments); result.setSetAssignment(setAssignmentSegment); sqlStatement.getWhere().ifPresent(optional -> result.setWhere(WhereSegmentBinder.bind(optional, updateStatementBinderContext, targetTableBinderContexts, Collections.emptyMap()))); + UpdateStatementHandler.getDeleteWhereSegment(sqlStatement).ifPresent(optional -> UpdateStatementHandler.setDeleteWhereSegment(result, + WhereSegmentBinder.bind(optional, updateStatementBinderContext, targetTableBinderContexts, Collections.emptyMap()))); UpdateStatementHandler.getOrderBySegment(sqlStatement).ifPresent(optional -> UpdateStatementHandler.setOrderBySegment(result, optional)); UpdateStatementHandler.getLimitSegment(sqlStatement).ifPresent(optional -> UpdateStatementHandler.setLimitSegment(result, optional)); UpdateStatementHandler.getWithSegment(sqlStatement).ifPresent(optional -> UpdateStatementHandler.setWithSegment(result, optional)); diff --git a/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/statement/MergeStatementBinderTest.java b/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/statement/MergeStatementBinderTest.java index e5671c97ad2d2..42216afd2f5db 100644 --- a/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/statement/MergeStatementBinderTest.java +++ b/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/statement/MergeStatementBinderTest.java @@ -26,6 +26,7 @@ import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.assignment.SetAssignmentSegment; import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.ColumnSegment; import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.BinaryOperationExpression; +import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExpressionWithParamsSegment; import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.simple.LiteralExpressionSegment; import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.subquery.SubquerySegment; import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.item.ExpressionProjectionSegment; @@ -67,8 +68,8 @@ void assertBind() { SimpleTableSegment sourceTable = new SimpleTableSegment(new TableNameSegment(0, 0, new IdentifierValue("t_order_item"))); sourceTable.setAlias(new AliasSegment(0, 0, new IdentifierValue("b"))); mergeStatement.setSource(sourceTable); - mergeStatement.setExpr(new BinaryOperationExpression(0, 0, new ColumnSegment(0, 0, new IdentifierValue("id")), - new ColumnSegment(0, 0, new IdentifierValue("order_id")), "=", "id = order_id")); + mergeStatement.setExpression(new ExpressionWithParamsSegment(0, 0, new BinaryOperationExpression(0, 0, new ColumnSegment(0, 0, new IdentifierValue("id")), + new ColumnSegment(0, 0, new IdentifierValue("order_id")), "=", "id = order_id"))); UpdateStatement updateStatement = new OracleUpdateStatement(); updateStatement.setTable(targetTable); ColumnSegment targetTableColumn = new ColumnSegment(0, 0, new IdentifierValue("status")); @@ -139,4 +140,33 @@ void assertBindWithSubQuery() { MergeStatement actual = new MergeStatementBinder().bind(mergeStatement, createMetaData(), DefaultDatabase.LOGIC_NAME); assertThat(actual, not(mergeStatement)); } + + @Test + void assertBindUpdateDeleteWhere() { + MergeStatement mergeStatement = new OracleMergeStatement(); + SimpleTableSegment targetTable = new SimpleTableSegment(new TableNameSegment(0, 0, new IdentifierValue("t_order"))); + targetTable.setAlias(new AliasSegment(0, 0, new IdentifierValue("a"))); + mergeStatement.setTarget(targetTable); + SimpleTableSegment sourceTable = new SimpleTableSegment(new TableNameSegment(0, 0, new IdentifierValue("t_order_item"))); + sourceTable.setAlias(new AliasSegment(0, 0, new IdentifierValue("b"))); + mergeStatement.setSource(sourceTable); + OracleUpdateStatement updateStatement = new OracleUpdateStatement(); + updateStatement.setTable(targetTable); + ColumnSegment targetTableColumn = new ColumnSegment(0, 0, new IdentifierValue("status")); + targetTableColumn.setOwner(new OwnerSegment(0, 0, new IdentifierValue("a"))); + ColumnSegment sourceTableColumn = new ColumnSegment(0, 0, new IdentifierValue("status")); + sourceTableColumn.setOwner(new OwnerSegment(0, 0, new IdentifierValue("b"))); + SetAssignmentSegment setAssignmentSegment = new SetAssignmentSegment(0, 0, + Collections.singletonList(new ColumnAssignmentSegment(0, 0, Collections.singletonList(targetTableColumn), sourceTableColumn))); + updateStatement.setSetAssignment(setAssignmentSegment); + updateStatement.setDeleteWhere(new WhereSegment(0, 0, new BinaryOperationExpression(0, 0, new ColumnSegment(0, 0, new IdentifierValue("item_id")), + new LiteralExpressionSegment(0, 0, 1), "=", "item_id = 1"))); + mergeStatement.setUpdate(updateStatement); + MergeStatement actual = new MergeStatementBinder().bind(mergeStatement, createMetaData(), DefaultDatabase.LOGIC_NAME); + assertThat(actual.getUpdate(), instanceOf(OracleUpdateStatement.class)); + assertThat(((OracleUpdateStatement) actual.getUpdate()).getDeleteWhere().getExpr(), instanceOf(BinaryOperationExpression.class)); + assertThat(((BinaryOperationExpression) ((OracleUpdateStatement) actual.getUpdate()).getDeleteWhere().getExpr()).getLeft(), instanceOf(ColumnSegment.class)); + assertThat(((ColumnSegment) ((BinaryOperationExpression) ((OracleUpdateStatement) actual.getUpdate()).getDeleteWhere().getExpr()).getLeft()) + .getColumnBoundedInfo().getOriginalTable().getValue(), is("t_order_item")); + } } diff --git a/parser/sql/dialect/oracle/src/main/java/org/apache/shardingsphere/sql/parser/oracle/visitor/statement/type/OracleDMLStatementVisitor.java b/parser/sql/dialect/oracle/src/main/java/org/apache/shardingsphere/sql/parser/oracle/visitor/statement/type/OracleDMLStatementVisitor.java index d97e9933e7626..5f4f95df0beef 100644 --- a/parser/sql/dialect/oracle/src/main/java/org/apache/shardingsphere/sql/parser/oracle/visitor/statement/type/OracleDMLStatementVisitor.java +++ b/parser/sql/dialect/oracle/src/main/java/org/apache/shardingsphere/sql/parser/oracle/visitor/statement/type/OracleDMLStatementVisitor.java @@ -123,6 +123,7 @@ import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.CollateExpression; import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.DatetimeExpression; import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExpressionSegment; +import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExpressionWithParamsSegment; import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.FunctionSegment; import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.InExpression; import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.MultisetExpression; @@ -1215,7 +1216,10 @@ public ASTNode visitMerge(final MergeContext ctx) { OracleMergeStatement result = new OracleMergeStatement(); result.setTarget((TableSegment) visit(ctx.intoClause())); result.setSource((TableSegment) visit(ctx.usingClause())); - result.setExpr((ExpressionSegment) visit(ctx.usingClause().expr())); + ExpressionWithParamsSegment onExpression = new ExpressionWithParamsSegment(ctx.usingClause().expr().start.getStartIndex(), ctx.usingClause().expr().stop.getStopIndex(), + (ExpressionSegment) visit(ctx.usingClause().expr())); + onExpression.getParameterMarkerSegments().addAll(popAllStatementParameterMarkerSegments()); + result.setExpression(onExpression); if (null != ctx.mergeUpdateClause()) { result.setUpdate((UpdateStatement) visitMergeUpdateClause(ctx.mergeUpdateClause())); } @@ -1238,6 +1242,7 @@ public ASTNode visitMergeInsertClause(final MergeInsertClauseContext ctx) { if (null != ctx.whereClause()) { result.setWhere((WhereSegment) visit(ctx.whereClause())); } + result.getParameterMarkerSegments().addAll(popAllStatementParameterMarkerSegments()); return result; } @@ -1259,7 +1264,7 @@ public ASTNode visitMergeColumnValue(final MergeColumnValueContext ctx) { for (ExprContext each : ctx.expr()) { segments.add(null == each ? new CommonExpressionSegment(ctx.getStart().getStartIndex(), ctx.getStop().getStopIndex(), ctx.getText()) : (ExpressionSegment) visit(each)); } - result.getValue().add(new InsertValuesSegment(ctx.getStart().getStartIndex(), ctx.getStop().getStopIndex(), segments)); + result.getValue().add(new InsertValuesSegment(ctx.LP_().getSymbol().getStartIndex(), ctx.RP_().getSymbol().getStopIndex(), segments)); return result; } @@ -1306,6 +1311,7 @@ public ASTNode visitUsingClause(final UsingClauseContext ctx) { } OracleSelectStatement subquery = (OracleSelectStatement) visit(ctx.subquery()); SubquerySegment subquerySegment = new SubquerySegment(ctx.subquery().start.getStartIndex(), ctx.subquery().stop.getStopIndex(), subquery); + subquerySegment.getSelect().getParameterMarkerSegments().addAll(popAllStatementParameterMarkerSegments()); SubqueryTableSegment result = new SubqueryTableSegment(subquerySegment); if (null != ctx.alias()) { result.setAlias((AliasSegment) visit(ctx.alias())); @@ -1323,6 +1329,7 @@ public ASTNode visitMergeUpdateClause(final MergeUpdateClauseContext ctx) { if (null != ctx.deleteWhereClause()) { result.setDeleteWhere((WhereSegment) visit(ctx.deleteWhereClause())); } + result.getParameterMarkerSegments().addAll(popAllStatementParameterMarkerSegments()); return result; } diff --git a/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/segment/dml/expr/ExpressionWithParamsSegment.java b/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/segment/dml/expr/ExpressionWithParamsSegment.java new file mode 100644 index 0000000000000..56f0b592f7dda --- /dev/null +++ b/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/segment/dml/expr/ExpressionWithParamsSegment.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr; + +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import org.apache.shardingsphere.sql.parser.sql.common.segment.SQLSegment; +import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.ParameterMarkerSegment; + +import java.util.Collection; +import java.util.LinkedList; + +/** + * Expression with parameters segment. + */ +@RequiredArgsConstructor +@Getter +public final class ExpressionWithParamsSegment implements SQLSegment { + + private final int startIndex; + + private final int stopIndex; + + private final ExpressionSegment expr; + + private final Collection parameterMarkerSegments = new LinkedList<>(); +} diff --git a/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/statement/dml/MergeStatement.java b/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/statement/dml/MergeStatement.java index c9dd114ae7bd0..e43a4fedc5aa6 100644 --- a/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/statement/dml/MergeStatement.java +++ b/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/statement/dml/MergeStatement.java @@ -19,7 +19,7 @@ import lombok.Getter; import lombok.Setter; -import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExpressionSegment; +import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExpressionWithParamsSegment; import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.TableSegment; import org.apache.shardingsphere.sql.parser.sql.common.statement.AbstractSQLStatement; @@ -34,7 +34,7 @@ public abstract class MergeStatement extends AbstractSQLStatement implements DML private TableSegment source; - private ExpressionSegment expr; + private ExpressionWithParamsSegment expression; private UpdateStatement update; diff --git a/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/dialect/handler/dml/UpdateStatementHandler.java b/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/dialect/handler/dml/UpdateStatementHandler.java index 3eeea17682a86..3ef7524e33560 100644 --- a/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/dialect/handler/dml/UpdateStatementHandler.java +++ b/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/dialect/handler/dml/UpdateStatementHandler.java @@ -21,11 +21,13 @@ import lombok.NoArgsConstructor; import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.order.OrderBySegment; import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.pagination.limit.LimitSegment; +import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.predicate.WhereSegment; import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.WithSegment; import org.apache.shardingsphere.sql.parser.sql.common.statement.dml.UpdateStatement; import org.apache.shardingsphere.sql.parser.sql.dialect.handler.SQLStatementHandler; import org.apache.shardingsphere.sql.parser.sql.dialect.statement.mysql.MySQLStatement; import org.apache.shardingsphere.sql.parser.sql.dialect.statement.mysql.dml.MySQLUpdateStatement; +import org.apache.shardingsphere.sql.parser.sql.dialect.statement.oracle.dml.OracleUpdateStatement; import org.apache.shardingsphere.sql.parser.sql.dialect.statement.sqlserver.SQLServerStatement; import org.apache.shardingsphere.sql.parser.sql.dialect.statement.sqlserver.dml.SQLServerUpdateStatement; @@ -76,6 +78,19 @@ public static Optional getWithSegment(final UpdateStatement updateS return Optional.empty(); } + /** + * Get delete where segment. + * + * @param updateStatement update statement + * @return delete where segment + */ + public static Optional getDeleteWhereSegment(final UpdateStatement updateStatement) { + if (updateStatement instanceof OracleUpdateStatement) { + return Optional.ofNullable(((OracleUpdateStatement) updateStatement).getDeleteWhere()); + } + return Optional.empty(); + } + /** * Set order by segment. * @@ -111,4 +126,16 @@ public static void setWithSegment(final UpdateStatement updateStatement, final W ((SQLServerUpdateStatement) updateStatement).setWithSegment(withSegment); } } + + /** + * Set delete where segment. + * + * @param updateStatement update statement + * @param deleteWhereSegment delete where segment + */ + public static void setDeleteWhereSegment(final UpdateStatement updateStatement, final WhereSegment deleteWhereSegment) { + if (updateStatement instanceof OracleUpdateStatement) { + ((OracleUpdateStatement) updateStatement).setDeleteWhere(deleteWhereSegment); + } + } } diff --git a/test/it/parser/src/main/java/org/apache/shardingsphere/test/it/sql/parser/internal/asserts/statement/dml/impl/MergeStatementAssert.java b/test/it/parser/src/main/java/org/apache/shardingsphere/test/it/sql/parser/internal/asserts/statement/dml/impl/MergeStatementAssert.java index d65dacf166650..26a01d1d48cbe 100644 --- a/test/it/parser/src/main/java/org/apache/shardingsphere/test/it/sql/parser/internal/asserts/statement/dml/impl/MergeStatementAssert.java +++ b/test/it/parser/src/main/java/org/apache/shardingsphere/test/it/sql/parser/internal/asserts/statement/dml/impl/MergeStatementAssert.java @@ -67,9 +67,9 @@ private static void assertTable(final SQLCaseAssertContext assertContext, final private static void assertExpression(final SQLCaseAssertContext assertContext, final MergeStatement actual, final MergeStatementTestCase expected) { if (null == expected.getExpr()) { - assertNull(actual.getExpr(), assertContext.getText("Actual expression should not exist.")); + assertNull(actual.getExpression(), assertContext.getText("Actual expression should not exist.")); } else { - ExpressionAssert.assertExpression(assertContext, actual.getExpr(), expected.getExpr()); + ExpressionAssert.assertExpression(assertContext, actual.getExpression().getExpr(), expected.getExpr()); } } diff --git a/test/it/parser/src/main/resources/case/dml/merge.xml b/test/it/parser/src/main/resources/case/dml/merge.xml index 005ea1d2703f8..809c3366df61c 100644 --- a/test/it/parser/src/main/resources/case/dml/merge.xml +++ b/test/it/parser/src/main/resources/case/dml/merge.xml @@ -296,6 +296,26 @@ + + + + + + + + + + + + + * + + + + + + +