diff --git a/kernel/sql-federation/optimizer/src/main/java/org/apache/shardingsphere/sqlfederation/optimizer/converter/statement/update/CustomSqlUpdate.java b/kernel/sql-federation/optimizer/src/main/java/org/apache/shardingsphere/sqlfederation/optimizer/converter/statement/update/CustomSqlUpdate.java new file mode 100644 index 0000000000000..1334cfb195ea7 --- /dev/null +++ b/kernel/sql-federation/optimizer/src/main/java/org/apache/shardingsphere/sqlfederation/optimizer/converter/statement/update/CustomSqlUpdate.java @@ -0,0 +1,200 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.shardingsphere.sqlfederation.optimizer.converter.statement.update; + +import org.apache.calcite.sql.SqlCall; +import org.apache.calcite.sql.SqlIdentifier; +import org.apache.calcite.sql.SqlKind; +import org.apache.calcite.sql.SqlNode; +import org.apache.calcite.sql.SqlNodeList; +import org.apache.calcite.sql.SqlOperator; +import org.apache.calcite.sql.SqlSelect; +import org.apache.calcite.sql.SqlSpecialOperator; +import org.apache.calcite.sql.SqlWriter; +import org.apache.calcite.sql.SqlWriter.Frame; +import org.apache.calcite.sql.parser.SqlParserPos; +import org.apache.calcite.util.ImmutableNullableList; +import org.apache.calcite.util.Pair; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.checkerframework.dataflow.qual.Pure; + +import java.util.Iterator; +import java.util.List; + +public class CustomSqlUpdate extends SqlCall { + + public static final SqlSpecialOperator OPERATOR; + + private SqlNode targetTable; + + private SqlNodeList targetColumnList; + + private SqlNodeList sourceExpressionList; + + @Nullable + private SqlNode condition; + + @Nullable + private SqlSelect sourceSelect; + + @Nullable + private SqlIdentifier alias; + + @Nullable + private SqlNode from; + + public CustomSqlUpdate(final SqlParserPos pos, final SqlNode targetTable, final SqlNodeList targetColumnList, final SqlNodeList sourceExpressionList, final @Nullable SqlNode condition, + final @Nullable SqlSelect sourceSelect, final @Nullable SqlIdentifier alias, final @Nullable SqlNode from) { + super(pos); + this.targetTable = targetTable; + this.targetColumnList = targetColumnList; + this.sourceExpressionList = sourceExpressionList; + this.condition = condition; + this.sourceSelect = sourceSelect; + this.from = from; + assert sourceExpressionList.size() == targetColumnList.size(); + + this.alias = alias; + } + + public SqlKind getKind() { + return SqlKind.UPDATE; + } + + public SqlOperator getOperator() { + return OPERATOR; + } + + public List<@Nullable SqlNode> getOperandList() { + return ImmutableNullableList.of(this.targetTable, this.targetColumnList, this.sourceExpressionList, this.condition, this.alias, this.from); + } + + @Override + public void setOperand(final int i, final @Nullable SqlNode operand) { + switch (i) { + case 0: + assert operand instanceof SqlIdentifier; + + this.targetTable = operand; + break; + case 1: + this.targetColumnList = (SqlNodeList) operand; + break; + case 2: + this.sourceExpressionList = (SqlNodeList) operand; + break; + case 3: + this.condition = operand; + break; + case 4: + this.sourceExpressionList = (SqlNodeList) operand; + break; + case 5: + this.alias = (SqlIdentifier) operand; + break; + case 6: + this.from = operand; + break; + default: + throw new AssertionError(i); + } + + } + + public SqlNode getTargetTable() { + return this.targetTable; + } + + @Pure + public final @Nullable SqlNode getFrom() { + return this.from; + } + + public void setFrom(final @Nullable SqlNode from) { + this.from = from; + } + + @Pure + public @Nullable SqlIdentifier getAlias() { + return this.alias; + } + + public void setAlias(final SqlIdentifier alias) { + this.alias = alias; + } + + public SqlNodeList getTargetColumnList() { + return this.targetColumnList; + } + + public SqlNodeList getSourceExpressionList() { + return this.sourceExpressionList; + } + + public @Nullable SqlNode getCondition() { + return this.condition; + } + + public @Nullable SqlSelect getSourceSelect() { + return this.sourceSelect; + } + + public void setSourceSelect(final SqlSelect sourceSelect) { + this.sourceSelect = sourceSelect; + } + + @Override + public void unparse(final SqlWriter writer, final int leftPrec, final int rightPrec) { + final Frame frame = writer.startList(SqlWriter.FrameTypeEnum.SELECT, "UPDATE", ""); + int opLeft = this.getOperator().getLeftPrec(); + int opRight = this.getOperator().getRightPrec(); + this.targetTable.unparse(writer, opLeft, opRight); + SqlIdentifier alias = this.alias; + if (alias != null) { + writer.keyword("AS"); + alias.unparse(writer, opLeft, opRight); + } + SqlWriter.Frame setFrame = writer.startList(SqlWriter.FrameTypeEnum.UPDATE_SET_LIST, "SET", ""); + Iterator var9 = Pair.zip(this.getTargetColumnList(), this.getSourceExpressionList()).iterator(); + while (var9.hasNext()) { + Pair pair = (Pair) var9.next(); + writer.sep(","); + SqlIdentifier id = (SqlIdentifier) pair.left; + id.unparse(writer, opLeft, opRight); + writer.keyword("="); + SqlNode sourceExp = (SqlNode) pair.right; + sourceExp.unparse(writer, opLeft, opRight); + } + writer.endList(setFrame); + SqlNode from = this.from; + if (from != null) { + writer.sep("FROM"); + from.unparse(writer, opLeft, opRight); + } + SqlNode condition = this.condition; + if (condition != null) { + writer.sep("WHERE"); + condition.unparse(writer, opLeft, opRight); + } + writer.endList(frame); + } + + static { + OPERATOR = new SqlSpecialOperator("UPDATE", SqlKind.UPDATE); + } +} diff --git a/kernel/sql-federation/optimizer/src/main/java/org/apache/shardingsphere/sqlfederation/optimizer/converter/statement/update/UpdateStatementConverter.java b/kernel/sql-federation/optimizer/src/main/java/org/apache/shardingsphere/sqlfederation/optimizer/converter/statement/update/UpdateStatementConverter.java index 479beaf6bdfe3..55cd0495a0f66 100644 --- a/kernel/sql-federation/optimizer/src/main/java/org/apache/shardingsphere/sqlfederation/optimizer/converter/statement/update/UpdateStatementConverter.java +++ b/kernel/sql-federation/optimizer/src/main/java/org/apache/shardingsphere/sqlfederation/optimizer/converter/statement/update/UpdateStatementConverter.java @@ -20,12 +20,12 @@ import org.apache.calcite.sql.SqlNode; import org.apache.calcite.sql.SqlNodeList; import org.apache.calcite.sql.SqlOrderBy; -import org.apache.calcite.sql.SqlUpdate; import org.apache.calcite.sql.parser.SqlParserPos; import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.assignment.AssignmentSegment; import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.ColumnSegment; import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExpressionSegment; import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.pagination.limit.LimitSegment; +import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.TableSegment; import org.apache.shardingsphere.sql.parser.sql.common.statement.dml.UpdateStatement; import org.apache.shardingsphere.sql.parser.sql.dialect.handler.dml.UpdateStatementHandler; import org.apache.shardingsphere.sqlfederation.optimizer.converter.segment.expression.ExpressionConverter; @@ -34,6 +34,7 @@ import org.apache.shardingsphere.sqlfederation.optimizer.converter.segment.limit.PaginationValueSQLConverter; import org.apache.shardingsphere.sqlfederation.optimizer.converter.segment.orderby.OrderByConverter; import org.apache.shardingsphere.sqlfederation.optimizer.converter.segment.where.WhereConverter; +import org.apache.shardingsphere.sqlfederation.optimizer.converter.segment.with.WithConverter; import org.apache.shardingsphere.sqlfederation.optimizer.converter.statement.SQLStatementConverter; import java.util.List; @@ -47,7 +48,7 @@ public final class UpdateStatementConverter implements SQLStatementConverter limit = UpdateStatementHandler.getLimitSegment(updateStatement); if (limit.isPresent()) { @@ -58,8 +59,9 @@ public SqlNode convert(final UpdateStatement updateStatement) { return orderBy.isEmpty() ? sqlUpdate : new SqlOrderBy(SqlParserPos.ZERO, sqlUpdate, orderBy, null, null); } - private SqlUpdate convertUpdate(final UpdateStatement updateStatement) { + private SqlNode convertUpdate(final UpdateStatement updateStatement) { SqlNode table = TableConverter.convert(updateStatement.getTable()).orElseThrow(IllegalStateException::new); + SqlNode from = convertTable(updateStatement.getAssignmentSegment().orElse(null).getFrom()); SqlNode condition = updateStatement.getWhere().flatMap(WhereConverter::convert).orElse(null); SqlNodeList columns = new SqlNodeList(SqlParserPos.ZERO); SqlNodeList expressions = new SqlNodeList(SqlParserPos.ZERO); @@ -67,7 +69,8 @@ private SqlUpdate convertUpdate(final UpdateStatement updateStatement) { columns.addAll(convertColumn(each.getColumns())); expressions.add(convertExpression(each.getValue())); } - return new SqlUpdate(SqlParserPos.ZERO, table, columns, expressions, condition, null, null); + CustomSqlUpdate sqlUpdate = new CustomSqlUpdate(SqlParserPos.ZERO, table, columns, expressions, condition, null, null, from); + return UpdateStatementHandler.getWithSegment(updateStatement).flatMap(optional -> WithConverter.convert(optional, sqlUpdate)).orElse(sqlUpdate); } private List convertColumn(final List columnSegments) { @@ -77,4 +80,8 @@ private List convertColumn(final List columnSegments) { private SqlNode convertExpression(final ExpressionSegment expressionSegment) { return ExpressionConverter.convert(expressionSegment).orElseThrow(IllegalStateException::new); } + + private SqlNode convertTable(final TableSegment tableSegment) { + return TableConverter.convert(tableSegment).orElse(null); + } } diff --git a/parser/sql/dialect/sqlserver/src/main/java/org/apache/shardingsphere/sql/parser/sqlserver/visitor/statement/SQLServerStatementVisitor.java b/parser/sql/dialect/sqlserver/src/main/java/org/apache/shardingsphere/sql/parser/sqlserver/visitor/statement/SQLServerStatementVisitor.java index 4147471516b77..9a5a2a3afad55 100644 --- a/parser/sql/dialect/sqlserver/src/main/java/org/apache/shardingsphere/sql/parser/sqlserver/visitor/statement/SQLServerStatementVisitor.java +++ b/parser/sql/dialect/sqlserver/src/main/java/org/apache/shardingsphere/sql/parser/sqlserver/visitor/statement/SQLServerStatementVisitor.java @@ -995,9 +995,14 @@ public ASTNode visitUpdate(final UpdateContext ctx) { @Override public ASTNode visitSetAssignmentsClause(final SetAssignmentsClauseContext ctx) { Collection assignments = new LinkedList<>(); + TableSegment from; for (AssignmentContext each : ctx.assignment()) { assignments.add((AssignmentSegment) visit(each)); } + if (null != ctx.fromClause()) { + from = (TableSegment) visit(ctx.fromClause().tableReferences()); + return new SetAssignmentSegment(ctx.getStart().getStartIndex(), ctx.getStop().getStopIndex(), assignments, from); + } return new SetAssignmentSegment(ctx.getStart().getStartIndex(), ctx.getStop().getStopIndex(), assignments); } @@ -1017,7 +1022,6 @@ public ASTNode visitAssignment(final AssignmentContext ctx) { columnSegments.add(column); ExpressionSegment value = (ExpressionSegment) visit(ctx.assignmentValue()); AssignmentSegment result = new ColumnAssignmentSegment(ctx.getStart().getStartIndex(), ctx.getStop().getStopIndex(), columnSegments, value); - result.getColumns().add(column); return result; } diff --git a/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/segment/dml/assignment/SetAssignmentSegment.java b/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/segment/dml/assignment/SetAssignmentSegment.java index c8b2cd08b45b6..cbc37cc87d497 100644 --- a/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/segment/dml/assignment/SetAssignmentSegment.java +++ b/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/segment/dml/assignment/SetAssignmentSegment.java @@ -20,6 +20,7 @@ import lombok.Getter; import lombok.RequiredArgsConstructor; import org.apache.shardingsphere.sql.parser.sql.common.segment.SQLSegment; +import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.TableSegment; import java.util.Collection; @@ -35,4 +36,17 @@ public final class SetAssignmentSegment implements SQLSegment { private final int stopIndex; private final Collection assignments; + + private TableSegment from; + + public SetAssignmentSegment(final int startIndex, final int stopIndex, final Collection assignments, final TableSegment from) { + this.startIndex = startIndex; + this.stopIndex = stopIndex; + this.assignments = assignments; + this.from = from; + } + + public TableSegment getFrom() { + return from; + } } diff --git a/test/it/optimizer/src/test/resources/converter/update.xml b/test/it/optimizer/src/test/resources/converter/update.xml index 53569a080d66c..bdcb07da04cdd 100644 --- a/test/it/optimizer/src/test/resources/converter/update.xml +++ b/test/it/optimizer/src/test/resources/converter/update.xml @@ -62,4 +62,6 @@ + +