Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor merge statement. #28486

Merged
merged 1 commit into from
Sep 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import org.apache.shardingsphere.sql.parser.sql.dialect.statement.oracle.dml.OracleMergeStatement;

/**
* Load xml statement context.
* Merge statement context.
*/
@Getter
public final class MergeStatementContext extends CommonSQLStatementContext {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.assignment.SetAssignmentSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.ColumnSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExpressionSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExpressionWithParamsSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.SimpleTableSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.TableSegment;
import org.apache.shardingsphere.sql.parser.sql.common.statement.dml.InsertStatement;
Expand Down Expand Up @@ -76,7 +77,12 @@ private MergeStatement bind(final MergeStatement sqlStatement, final ShardingSph
Map<String, TableSegmentBinderContext> tableBinderContexts = new LinkedHashMap<>();
tableBinderContexts.putAll(sourceTableBinderContexts);
tableBinderContexts.putAll(targetTableBinderContexts);
result.setExpr(ExpressionSegmentBinder.bind(sqlStatement.getExpr(), SegmentType.JOIN_ON, statementBinderContext, tableBinderContexts, Collections.emptyMap()));
if (sqlStatement.getExpression() != null) {
ExpressionWithParamsSegment expression = new ExpressionWithParamsSegment(sqlStatement.getExpression().getStartIndex(), sqlStatement.getExpression().getStopIndex(),
ExpressionSegmentBinder.bind(sqlStatement.getExpression().getExpr(), SegmentType.JOIN_ON, statementBinderContext, tableBinderContexts, Collections.emptyMap()));
expression.getParameterMarkerSegments().addAll(sqlStatement.getExpression().getParameterMarkerSegments());
result.setExpression(expression);
}
result.setInsert(Optional.ofNullable(sqlStatement.getInsert()).map(optional -> bindMergeInsert(optional,
(SimpleTableSegment) boundedTargetTableSegment, statementBinderContext, targetTableBinderContexts, sourceTableBinderContexts)).orElse(null));
result.setUpdate(Optional.ofNullable(sqlStatement.getUpdate()).map(optional -> bindMergeUpdate(optional,
Expand Down Expand Up @@ -142,6 +148,8 @@ private UpdateStatement bindMergeUpdate(final UpdateStatement sqlStatement, fina
SetAssignmentSegment setAssignmentSegment = new SetAssignmentSegment(sqlStatement.getSetAssignment().getStartIndex(), sqlStatement.getSetAssignment().getStopIndex(), assignments);
result.setSetAssignment(setAssignmentSegment);
sqlStatement.getWhere().ifPresent(optional -> result.setWhere(WhereSegmentBinder.bind(optional, updateStatementBinderContext, targetTableBinderContexts, Collections.emptyMap())));
UpdateStatementHandler.getDeleteWhereSegment(sqlStatement).ifPresent(optional -> UpdateStatementHandler.setDeleteWhereSegment(result,
WhereSegmentBinder.bind(optional, updateStatementBinderContext, targetTableBinderContexts, Collections.emptyMap())));
UpdateStatementHandler.getOrderBySegment(sqlStatement).ifPresent(optional -> UpdateStatementHandler.setOrderBySegment(result, optional));
UpdateStatementHandler.getLimitSegment(sqlStatement).ifPresent(optional -> UpdateStatementHandler.setLimitSegment(result, optional));
UpdateStatementHandler.getWithSegment(sqlStatement).ifPresent(optional -> UpdateStatementHandler.setWithSegment(result, optional));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.assignment.SetAssignmentSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.ColumnSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.BinaryOperationExpression;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExpressionWithParamsSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.simple.LiteralExpressionSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.subquery.SubquerySegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.item.ExpressionProjectionSegment;
Expand Down Expand Up @@ -67,8 +68,8 @@ void assertBind() {
SimpleTableSegment sourceTable = new SimpleTableSegment(new TableNameSegment(0, 0, new IdentifierValue("t_order_item")));
sourceTable.setAlias(new AliasSegment(0, 0, new IdentifierValue("b")));
mergeStatement.setSource(sourceTable);
mergeStatement.setExpr(new BinaryOperationExpression(0, 0, new ColumnSegment(0, 0, new IdentifierValue("id")),
new ColumnSegment(0, 0, new IdentifierValue("order_id")), "=", "id = order_id"));
mergeStatement.setExpression(new ExpressionWithParamsSegment(0, 0, new BinaryOperationExpression(0, 0, new ColumnSegment(0, 0, new IdentifierValue("id")),
new ColumnSegment(0, 0, new IdentifierValue("order_id")), "=", "id = order_id")));
UpdateStatement updateStatement = new OracleUpdateStatement();
updateStatement.setTable(targetTable);
ColumnSegment targetTableColumn = new ColumnSegment(0, 0, new IdentifierValue("status"));
Expand Down Expand Up @@ -139,4 +140,33 @@ void assertBindWithSubQuery() {
MergeStatement actual = new MergeStatementBinder().bind(mergeStatement, createMetaData(), DefaultDatabase.LOGIC_NAME);
assertThat(actual, not(mergeStatement));
}

@Test
void assertBindUpdateDeleteWhere() {
MergeStatement mergeStatement = new OracleMergeStatement();
SimpleTableSegment targetTable = new SimpleTableSegment(new TableNameSegment(0, 0, new IdentifierValue("t_order")));
targetTable.setAlias(new AliasSegment(0, 0, new IdentifierValue("a")));
mergeStatement.setTarget(targetTable);
SimpleTableSegment sourceTable = new SimpleTableSegment(new TableNameSegment(0, 0, new IdentifierValue("t_order_item")));
sourceTable.setAlias(new AliasSegment(0, 0, new IdentifierValue("b")));
mergeStatement.setSource(sourceTable);
OracleUpdateStatement updateStatement = new OracleUpdateStatement();
updateStatement.setTable(targetTable);
ColumnSegment targetTableColumn = new ColumnSegment(0, 0, new IdentifierValue("status"));
targetTableColumn.setOwner(new OwnerSegment(0, 0, new IdentifierValue("a")));
ColumnSegment sourceTableColumn = new ColumnSegment(0, 0, new IdentifierValue("status"));
sourceTableColumn.setOwner(new OwnerSegment(0, 0, new IdentifierValue("b")));
SetAssignmentSegment setAssignmentSegment = new SetAssignmentSegment(0, 0,
Collections.singletonList(new ColumnAssignmentSegment(0, 0, Collections.singletonList(targetTableColumn), sourceTableColumn)));
updateStatement.setSetAssignment(setAssignmentSegment);
updateStatement.setDeleteWhere(new WhereSegment(0, 0, new BinaryOperationExpression(0, 0, new ColumnSegment(0, 0, new IdentifierValue("item_id")),
new LiteralExpressionSegment(0, 0, 1), "=", "item_id = 1")));
mergeStatement.setUpdate(updateStatement);
MergeStatement actual = new MergeStatementBinder().bind(mergeStatement, createMetaData(), DefaultDatabase.LOGIC_NAME);
assertThat(actual.getUpdate(), instanceOf(OracleUpdateStatement.class));
assertThat(((OracleUpdateStatement) actual.getUpdate()).getDeleteWhere().getExpr(), instanceOf(BinaryOperationExpression.class));
assertThat(((BinaryOperationExpression) ((OracleUpdateStatement) actual.getUpdate()).getDeleteWhere().getExpr()).getLeft(), instanceOf(ColumnSegment.class));
assertThat(((ColumnSegment) ((BinaryOperationExpression) ((OracleUpdateStatement) actual.getUpdate()).getDeleteWhere().getExpr()).getLeft())
.getColumnBoundedInfo().getOriginalTable().getValue(), is("t_order_item"));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.CollateExpression;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.DatetimeExpression;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExpressionSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExpressionWithParamsSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.FunctionSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.InExpression;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.MultisetExpression;
Expand Down Expand Up @@ -1215,7 +1216,10 @@ public ASTNode visitMerge(final MergeContext ctx) {
OracleMergeStatement result = new OracleMergeStatement();
result.setTarget((TableSegment) visit(ctx.intoClause()));
result.setSource((TableSegment) visit(ctx.usingClause()));
result.setExpr((ExpressionSegment) visit(ctx.usingClause().expr()));
ExpressionWithParamsSegment onExpression = new ExpressionWithParamsSegment(ctx.usingClause().expr().start.getStartIndex(), ctx.usingClause().expr().stop.getStopIndex(),
(ExpressionSegment) visit(ctx.usingClause().expr()));
onExpression.getParameterMarkerSegments().addAll(popAllStatementParameterMarkerSegments());
result.setExpression(onExpression);
if (null != ctx.mergeUpdateClause()) {
result.setUpdate((UpdateStatement) visitMergeUpdateClause(ctx.mergeUpdateClause()));
}
Expand All @@ -1238,6 +1242,7 @@ public ASTNode visitMergeInsertClause(final MergeInsertClauseContext ctx) {
if (null != ctx.whereClause()) {
result.setWhere((WhereSegment) visit(ctx.whereClause()));
}
result.getParameterMarkerSegments().addAll(popAllStatementParameterMarkerSegments());
return result;
}

Expand All @@ -1259,7 +1264,7 @@ public ASTNode visitMergeColumnValue(final MergeColumnValueContext ctx) {
for (ExprContext each : ctx.expr()) {
segments.add(null == each ? new CommonExpressionSegment(ctx.getStart().getStartIndex(), ctx.getStop().getStopIndex(), ctx.getText()) : (ExpressionSegment) visit(each));
}
result.getValue().add(new InsertValuesSegment(ctx.getStart().getStartIndex(), ctx.getStop().getStopIndex(), segments));
result.getValue().add(new InsertValuesSegment(ctx.LP_().getSymbol().getStartIndex(), ctx.RP_().getSymbol().getStopIndex(), segments));
return result;
}

Expand Down Expand Up @@ -1306,6 +1311,7 @@ public ASTNode visitUsingClause(final UsingClauseContext ctx) {
}
OracleSelectStatement subquery = (OracleSelectStatement) visit(ctx.subquery());
SubquerySegment subquerySegment = new SubquerySegment(ctx.subquery().start.getStartIndex(), ctx.subquery().stop.getStopIndex(), subquery);
subquerySegment.getSelect().getParameterMarkerSegments().addAll(popAllStatementParameterMarkerSegments());
SubqueryTableSegment result = new SubqueryTableSegment(subquerySegment);
if (null != ctx.alias()) {
result.setAlias((AliasSegment) visit(ctx.alias()));
Expand All @@ -1323,6 +1329,7 @@ public ASTNode visitMergeUpdateClause(final MergeUpdateClauseContext ctx) {
if (null != ctx.deleteWhereClause()) {
result.setDeleteWhere((WhereSegment) visit(ctx.deleteWhereClause()));
}
result.getParameterMarkerSegments().addAll(popAllStatementParameterMarkerSegments());
return result;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr;

import lombok.Getter;
import lombok.RequiredArgsConstructor;
import org.apache.shardingsphere.sql.parser.sql.common.segment.SQLSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.ParameterMarkerSegment;

import java.util.Collection;
import java.util.LinkedList;

/**
* Expression with parameters segment.
*/
@RequiredArgsConstructor
@Getter
public final class ExpressionWithParamsSegment implements SQLSegment {

private final int startIndex;

private final int stopIndex;

private final ExpressionSegment expr;

private final Collection<ParameterMarkerSegment> parameterMarkerSegments = new LinkedList<>();
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

import lombok.Getter;
import lombok.Setter;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExpressionSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExpressionWithParamsSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.TableSegment;
import org.apache.shardingsphere.sql.parser.sql.common.statement.AbstractSQLStatement;

Expand All @@ -34,7 +34,7 @@ public abstract class MergeStatement extends AbstractSQLStatement implements DML

private TableSegment source;

private ExpressionSegment expr;
private ExpressionWithParamsSegment expression;

private UpdateStatement update;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,13 @@
import lombok.NoArgsConstructor;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.order.OrderBySegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.pagination.limit.LimitSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.predicate.WhereSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.WithSegment;
import org.apache.shardingsphere.sql.parser.sql.common.statement.dml.UpdateStatement;
import org.apache.shardingsphere.sql.parser.sql.dialect.handler.SQLStatementHandler;
import org.apache.shardingsphere.sql.parser.sql.dialect.statement.mysql.MySQLStatement;
import org.apache.shardingsphere.sql.parser.sql.dialect.statement.mysql.dml.MySQLUpdateStatement;
import org.apache.shardingsphere.sql.parser.sql.dialect.statement.oracle.dml.OracleUpdateStatement;
import org.apache.shardingsphere.sql.parser.sql.dialect.statement.sqlserver.SQLServerStatement;
import org.apache.shardingsphere.sql.parser.sql.dialect.statement.sqlserver.dml.SQLServerUpdateStatement;

Expand Down Expand Up @@ -76,6 +78,19 @@ public static Optional<WithSegment> getWithSegment(final UpdateStatement updateS
return Optional.empty();
}

/**
* Get delete where segment.
*
* @param updateStatement update statement
* @return delete where segment
*/
public static Optional<WhereSegment> getDeleteWhereSegment(final UpdateStatement updateStatement) {
if (updateStatement instanceof OracleUpdateStatement) {
return Optional.ofNullable(((OracleUpdateStatement) updateStatement).getDeleteWhere());
}
return Optional.empty();
}

/**
* Set order by segment.
*
Expand Down Expand Up @@ -111,4 +126,16 @@ public static void setWithSegment(final UpdateStatement updateStatement, final W
((SQLServerUpdateStatement) updateStatement).setWithSegment(withSegment);
}
}

/**
* Set delete where segment.
*
* @param updateStatement update statement
* @param deleteWhereSegment delete where segment
*/
public static void setDeleteWhereSegment(final UpdateStatement updateStatement, final WhereSegment deleteWhereSegment) {
if (updateStatement instanceof OracleUpdateStatement) {
((OracleUpdateStatement) updateStatement).setDeleteWhere(deleteWhereSegment);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,9 @@ private static void assertTable(final SQLCaseAssertContext assertContext, final

private static void assertExpression(final SQLCaseAssertContext assertContext, final MergeStatement actual, final MergeStatementTestCase expected) {
if (null == expected.getExpr()) {
assertNull(actual.getExpr(), assertContext.getText("Actual expression should not exist."));
assertNull(actual.getExpression(), assertContext.getText("Actual expression should not exist."));
} else {
ExpressionAssert.assertExpression(assertContext, actual.getExpr(), expected.getExpr());
ExpressionAssert.assertExpression(assertContext, actual.getExpression().getExpr(), expected.getExpr());
}
}

Expand Down
20 changes: 20 additions & 0 deletions test/it/parser/src/main/resources/case/dml/merge.xml
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,26 @@
</binary-operation-expression>
</expr>
<insert>
<values>
<value>
<assignment-value>
<column name="employee_id" start-index="331" stop-index="343">
<owner name="S" start-index="331" stop-index="331" />
</column>
<binary-operation-expression start-index="67" stop-index="74">
<left>
<column name="salary" start-index="346" stop-index="353">
<owner name="S" start-index="346" stop-index="346" />
</column>
</left>
<operator>*</operator>
<right>
<literal-expression value=".01" start-index="355" stop-index="357" />
</right>
</binary-operation-expression>
</assignment-value>
</value>
</values>
<where start-index="365" stop-index="388" literal-start-index="365" literal-stop-index="388">
<expr>
<binary-operation-expression start-index="372" stop-index="387" literal-start-index="372" literal-stop-index="387">
Expand Down