Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support of customSqlUpdate to support FROM #28866

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.shardingsphere.sqlfederation.optimizer.converter.statement.update;

import org.apache.calcite.sql.SqlCall;
import org.apache.calcite.sql.SqlIdentifier;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlNodeList;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.SqlSelect;
import org.apache.calcite.sql.SqlSpecialOperator;
import org.apache.calcite.sql.SqlWriter;
import org.apache.calcite.sql.SqlWriter.Frame;
import org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.calcite.util.ImmutableNullableList;
import org.apache.calcite.util.Pair;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.checkerframework.dataflow.qual.Pure;

import java.util.Iterator;
import java.util.List;

public class CustomSqlUpdate extends SqlCall {
kanha-gupta marked this conversation as resolved.
Show resolved Hide resolved

public static final SqlSpecialOperator OPERATOR;

private SqlNode targetTable;

private SqlNodeList targetColumnList;

private SqlNodeList sourceExpressionList;

@Nullable
private SqlNode condition;

@Nullable
private SqlSelect sourceSelect;

@Nullable
private SqlIdentifier alias;

@Nullable
private SqlNode from;

public CustomSqlUpdate(final SqlParserPos pos, final SqlNode targetTable, final SqlNodeList targetColumnList, final SqlNodeList sourceExpressionList, final @Nullable SqlNode condition,
final @Nullable SqlSelect sourceSelect, final @Nullable SqlIdentifier alias, final @Nullable SqlNode from) {
super(pos);
this.targetTable = targetTable;
this.targetColumnList = targetColumnList;
this.sourceExpressionList = sourceExpressionList;
this.condition = condition;
this.sourceSelect = sourceSelect;
this.from = from;
assert sourceExpressionList.size() == targetColumnList.size();

this.alias = alias;
}

public SqlKind getKind() {
return SqlKind.UPDATE;
}

public SqlOperator getOperator() {
return OPERATOR;
}

public List<@Nullable SqlNode> getOperandList() {
return ImmutableNullableList.of(this.targetTable, this.targetColumnList, this.sourceExpressionList, this.condition, this.alias, this.from);
}

@Override
public void setOperand(final int i, final @Nullable SqlNode operand) {
switch (i) {
case 0:
assert operand instanceof SqlIdentifier;

this.targetTable = operand;
break;
case 1:
this.targetColumnList = (SqlNodeList) operand;
break;
case 2:
this.sourceExpressionList = (SqlNodeList) operand;
break;
case 3:
this.condition = operand;
break;
case 4:
this.sourceExpressionList = (SqlNodeList) operand;
break;
case 5:
this.alias = (SqlIdentifier) operand;
break;
case 6:
this.from = operand;
break;
default:
throw new AssertionError(i);
}

}

public SqlNode getTargetTable() {
return this.targetTable;
}

@Pure
public final @Nullable SqlNode getFrom() {
return this.from;
}

public void setFrom(final @Nullable SqlNode from) {
this.from = from;
}

@Pure
public @Nullable SqlIdentifier getAlias() {
return this.alias;
}

public void setAlias(final SqlIdentifier alias) {
this.alias = alias;
}

public SqlNodeList getTargetColumnList() {
return this.targetColumnList;
}

public SqlNodeList getSourceExpressionList() {
return this.sourceExpressionList;
}

public @Nullable SqlNode getCondition() {
return this.condition;
}

public @Nullable SqlSelect getSourceSelect() {
return this.sourceSelect;
}

public void setSourceSelect(final SqlSelect sourceSelect) {
this.sourceSelect = sourceSelect;
}

@Override
public void unparse(final SqlWriter writer, final int leftPrec, final int rightPrec) {
final Frame frame = writer.startList(SqlWriter.FrameTypeEnum.SELECT, "UPDATE", "");
int opLeft = this.getOperator().getLeftPrec();
int opRight = this.getOperator().getRightPrec();
this.targetTable.unparse(writer, opLeft, opRight);
SqlIdentifier alias = this.alias;
if (alias != null) {
writer.keyword("AS");
alias.unparse(writer, opLeft, opRight);
}
SqlWriter.Frame setFrame = writer.startList(SqlWriter.FrameTypeEnum.UPDATE_SET_LIST, "SET", "");
Iterator var9 = Pair.zip(this.getTargetColumnList(), this.getSourceExpressionList()).iterator();
while (var9.hasNext()) {
Pair<SqlNode, SqlNode> pair = (Pair) var9.next();
writer.sep(",");
SqlIdentifier id = (SqlIdentifier) pair.left;
id.unparse(writer, opLeft, opRight);
writer.keyword("=");
SqlNode sourceExp = (SqlNode) pair.right;
sourceExp.unparse(writer, opLeft, opRight);
}
writer.endList(setFrame);
SqlNode from = this.from;
if (from != null) {
writer.sep("FROM");
from.unparse(writer, opLeft, opRight);
}
SqlNode condition = this.condition;
if (condition != null) {
writer.sep("WHERE");
condition.unparse(writer, opLeft, opRight);
}
writer.endList(frame);
}

static {
OPERATOR = new SqlSpecialOperator("UPDATE", SqlKind.UPDATE);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,12 @@
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlNodeList;
import org.apache.calcite.sql.SqlOrderBy;
import org.apache.calcite.sql.SqlUpdate;
import org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.assignment.AssignmentSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.ColumnSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExpressionSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.pagination.limit.LimitSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.TableSegment;
import org.apache.shardingsphere.sql.parser.sql.common.statement.dml.UpdateStatement;
import org.apache.shardingsphere.sql.parser.sql.dialect.handler.dml.UpdateStatementHandler;
import org.apache.shardingsphere.sqlfederation.optimizer.converter.segment.expression.ExpressionConverter;
Expand All @@ -34,6 +34,7 @@
import org.apache.shardingsphere.sqlfederation.optimizer.converter.segment.limit.PaginationValueSQLConverter;
import org.apache.shardingsphere.sqlfederation.optimizer.converter.segment.orderby.OrderByConverter;
import org.apache.shardingsphere.sqlfederation.optimizer.converter.segment.where.WhereConverter;
import org.apache.shardingsphere.sqlfederation.optimizer.converter.segment.with.WithConverter;
import org.apache.shardingsphere.sqlfederation.optimizer.converter.statement.SQLStatementConverter;

import java.util.List;
Expand All @@ -47,7 +48,7 @@ public final class UpdateStatementConverter implements SQLStatementConverter<Upd

@Override
public SqlNode convert(final UpdateStatement updateStatement) {
SqlUpdate sqlUpdate = convertUpdate(updateStatement);
SqlNode sqlUpdate = convertUpdate(updateStatement);
SqlNodeList orderBy = UpdateStatementHandler.getOrderBySegment(updateStatement).flatMap(OrderByConverter::convert).orElse(SqlNodeList.EMPTY);
Optional<LimitSegment> limit = UpdateStatementHandler.getLimitSegment(updateStatement);
if (limit.isPresent()) {
Expand All @@ -58,16 +59,18 @@ public SqlNode convert(final UpdateStatement updateStatement) {
return orderBy.isEmpty() ? sqlUpdate : new SqlOrderBy(SqlParserPos.ZERO, sqlUpdate, orderBy, null, null);
}

private SqlUpdate convertUpdate(final UpdateStatement updateStatement) {
private SqlNode convertUpdate(final UpdateStatement updateStatement) {
SqlNode table = TableConverter.convert(updateStatement.getTable()).orElseThrow(IllegalStateException::new);
SqlNode from = convertTable(updateStatement.getAssignmentSegment().orElse(null).getFrom());
SqlNode condition = updateStatement.getWhere().flatMap(WhereConverter::convert).orElse(null);
SqlNodeList columns = new SqlNodeList(SqlParserPos.ZERO);
SqlNodeList expressions = new SqlNodeList(SqlParserPos.ZERO);
for (AssignmentSegment each : updateStatement.getAssignmentSegment().orElseThrow(IllegalStateException::new).getAssignments()) {
columns.addAll(convertColumn(each.getColumns()));
expressions.add(convertExpression(each.getValue()));
}
return new SqlUpdate(SqlParserPos.ZERO, table, columns, expressions, condition, null, null);
CustomSqlUpdate sqlUpdate = new CustomSqlUpdate(SqlParserPos.ZERO, table, columns, expressions, condition, null, null, from);
return UpdateStatementHandler.getWithSegment(updateStatement).flatMap(optional -> WithConverter.convert(optional, sqlUpdate)).orElse(sqlUpdate);
}

private List<SqlNode> convertColumn(final List<ColumnSegment> columnSegments) {
Expand All @@ -77,4 +80,8 @@ private List<SqlNode> convertColumn(final List<ColumnSegment> columnSegments) {
private SqlNode convertExpression(final ExpressionSegment expressionSegment) {
return ExpressionConverter.convert(expressionSegment).orElseThrow(IllegalStateException::new);
}

private SqlNode convertTable(final TableSegment tableSegment) {
return TableConverter.convert(tableSegment).orElse(null);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -995,9 +995,14 @@ public ASTNode visitUpdate(final UpdateContext ctx) {
@Override
public ASTNode visitSetAssignmentsClause(final SetAssignmentsClauseContext ctx) {
Collection<AssignmentSegment> assignments = new LinkedList<>();
TableSegment from;
for (AssignmentContext each : ctx.assignment()) {
assignments.add((AssignmentSegment) visit(each));
}
if (null != ctx.fromClause()) {
from = (TableSegment) visit(ctx.fromClause().tableReferences());
return new SetAssignmentSegment(ctx.getStart().getStartIndex(), ctx.getStop().getStopIndex(), assignments, from);
}
return new SetAssignmentSegment(ctx.getStart().getStartIndex(), ctx.getStop().getStopIndex(), assignments);
}

Expand All @@ -1017,7 +1022,6 @@ public ASTNode visitAssignment(final AssignmentContext ctx) {
columnSegments.add(column);
ExpressionSegment value = (ExpressionSegment) visit(ctx.assignmentValue());
AssignmentSegment result = new ColumnAssignmentSegment(ctx.getStart().getStartIndex(), ctx.getStop().getStopIndex(), columnSegments, value);
result.getColumns().add(column);
return result;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import org.apache.shardingsphere.sql.parser.sql.common.segment.SQLSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.TableSegment;

import java.util.Collection;

Expand All @@ -35,4 +36,17 @@ public final class SetAssignmentSegment implements SQLSegment {
private final int stopIndex;

private final Collection<AssignmentSegment> assignments;

private TableSegment from;

public SetAssignmentSegment(final int startIndex, final int stopIndex, final Collection<AssignmentSegment> assignments, final TableSegment from) {
this.startIndex = startIndex;
this.stopIndex = stopIndex;
this.assignments = assignments;
this.from = from;
}

public TableSegment getFrom() {
return from;
}
}
2 changes: 2 additions & 0 deletions test/it/optimizer/src/test/resources/converter/update.xml
Original file line number Diff line number Diff line change
Expand Up @@ -62,4 +62,6 @@
<!--<test-cases sql-case-id="update_with_translate_function" expected-sql="UPDATE &quot;translate_tab&quot; SET &quot;char_col&quot; = TRANSLATE(&quot;nchar_col&quot; USING 'CHAR_CS')" db-types="Oracle" />-->
<test-cases sql-case-id="update_with_dot_column_name" expected-sql="UPDATE &quot;employees&quot; SET &quot;salary&quot; = &quot;salary&quot; + 10 WHERE &quot;employee_id&quot; BETWEEN ASYMMETRIC 1 AND 10" db-types="Oracle" sql-case-types="LITERAL" />
<test-cases sql-case-id="update_with_dot_column_name" expected-sql="UPDATE &quot;employees&quot; SET &quot;salary&quot; = &quot;salary&quot; + ? WHERE &quot;employee_id&quot; BETWEEN ASYMMETRIC ? AND ?" db-types="Oracle" sql-case-types="PLACEHOLDER" />
<test-cases sql-case-id="update_with_with_clause" expected-sql="(WITH [cte] ([order_id], [user_id], [status]) AS (SELECT [order_id], [user_id], [status] FROM [t_order]) UPDATE [t_order] SET [status] = 1 FROM [t_order] AS [t] INNER JOIN [cte] AS [c] ON [t].[order_id] = [c].[order_id] WHERE [c].[order_id] = 1)" db-types="SQLServer" sql-case-types="LITERAL" />
<test-cases sql-case-id="update_with_with_clause" expected-sql="(WITH [cte] ([order_id], [user_id], [status]) AS (SELECT [order_id], [user_id], [status] FROM [t_order]) UPDATE [t_order] SET [status] = ? FROM [t_order] AS [t] INNER JOIN [cte] AS [c] ON [t].[order_id] = [c].[order_id] WHERE [c].[order_id] = ?)" db-types="SQLServer" sql-case-types="PLACEHOLDER" />
</sql-node-converter-test-cases>