Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change sql federation insert statement conversion #28631

Merged
merged 3 commits into from
Oct 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,30 +21,22 @@
import org.apache.calcite.sql.SqlInsert;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlNodeList;
import org.apache.calcite.sql.SqlSelect;
import org.apache.calcite.sql.SqlValuesOperator;
import org.apache.calcite.sql.fun.SqlRowOperator;
import org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.assignment.InsertValuesSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.ColumnSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExpressionSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.subquery.SubquerySegment;
import org.apache.shardingsphere.sql.parser.sql.common.statement.dml.InsertStatement;
import org.apache.shardingsphere.sql.parser.sql.common.statement.dml.SelectStatement;
import org.apache.shardingsphere.sql.parser.sql.dialect.handler.dml.SelectStatementHandler;
import org.apache.shardingsphere.sqlfederation.compiler.converter.segment.expression.ExpressionConverter;
import org.apache.shardingsphere.sqlfederation.compiler.converter.segment.expression.impl.ColumnConverter;
import org.apache.shardingsphere.sqlfederation.compiler.converter.segment.from.TableConverter;
import org.apache.shardingsphere.sqlfederation.compiler.converter.segment.groupby.GroupByConverter;
import org.apache.shardingsphere.sqlfederation.compiler.converter.segment.groupby.HavingConverter;
import org.apache.shardingsphere.sqlfederation.compiler.converter.segment.projection.DistinctConverter;
import org.apache.shardingsphere.sqlfederation.compiler.converter.segment.projection.ProjectionsConverter;
import org.apache.shardingsphere.sqlfederation.compiler.converter.segment.where.WhereConverter;
import org.apache.shardingsphere.sqlfederation.compiler.converter.segment.window.WindowConverter;
import org.apache.shardingsphere.sqlfederation.compiler.converter.statement.SQLStatementConverter;
import org.apache.shardingsphere.sqlfederation.compiler.converter.statement.select.SelectStatementConverter;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;

Expand All @@ -60,28 +52,18 @@ public SqlNode convert(final InsertStatement insertStatement) {

private SqlInsert convertInsert(final InsertStatement insertStatement) {
SqlNode table = new TableConverter().convert(insertStatement.getTable()).orElseThrow(IllegalStateException::new);
SqlParserPos position = SqlParserPos.ZERO;
SqlNodeList keywords = new SqlNodeList(position);
SqlNode source;
if (insertStatement.getInsertSelect().isPresent()) {
source = convertSelect(insertStatement.getInsertSelect().get());
} else {
source = convertValues(insertStatement.getValues());
}
SqlNodeList keywords = new SqlNodeList(SqlParserPos.ZERO);
SqlNode source = convertSource(insertStatement);
SqlNodeList columnList = convertColumn(insertStatement.getColumns());
return new SqlInsert(SqlParserPos.ZERO, keywords, table, source, columnList);
}

private SqlNode convertSelect(final SubquerySegment subquerySegment) {
SelectStatement selectStatement = subquerySegment.getSelect();
SqlNodeList distinct = new DistinctConverter().convert(selectStatement.getProjections()).orElse(null);
SqlNodeList projection = new ProjectionsConverter().convert(selectStatement.getProjections()).orElseThrow(IllegalStateException::new);
SqlNode from = new TableConverter().convert(selectStatement.getFrom()).orElse(null);
SqlNode where = selectStatement.getWhere().flatMap(optional -> new WhereConverter().convert(optional)).orElse(null);
SqlNodeList groupBy = selectStatement.getGroupBy().flatMap(optional -> new GroupByConverter().convert(optional)).orElse(null);
SqlNode having = selectStatement.getHaving().flatMap(optional -> new HavingConverter().convert(optional)).orElse(null);
SqlNodeList window = SelectStatementHandler.getWindowSegment(selectStatement).flatMap(new WindowConverter()::convert).orElse(SqlNodeList.EMPTY);
return new SqlSelect(SqlParserPos.ZERO, distinct, projection, from, where, groupBy, having, window, null, null, null, null, SqlNodeList.EMPTY);
private SqlNode convertSource(final InsertStatement insertStatement) {
if (insertStatement.getInsertSelect().isPresent()) {
return new SelectStatementConverter().convert(insertStatement.getInsertSelect().get().getSelect());
} else {
return convertValues(insertStatement.getValues());
}
}

private SqlNode convertValues(final Collection<InsertValuesSegment> insertValuesSegments) {
Expand All @@ -91,17 +73,13 @@ private SqlNode convertValues(final Collection<InsertValuesSegment> insertValues
values.add(convertExpression(value));
}
}
List<SqlNode> operands = new ArrayList<>();
operands.add(new SqlBasicCall(new SqlRowOperator("ROW"), values, SqlParserPos.ZERO));
List<SqlNode> operands = Collections.singletonList(new SqlBasicCall(new SqlRowOperator("ROW"), values, SqlParserPos.ZERO));
return new SqlBasicCall(new SqlValuesOperator(), operands, SqlParserPos.ZERO);
}

private SqlNodeList convertColumn(final Collection<ColumnSegment> columnSegments) {
List<SqlNode> columns = columnSegments.stream().map(each -> new ColumnConverter().convert(each).orElseThrow(IllegalStateException::new)).collect(Collectors.toList());
if (columns.isEmpty()) {
return SqlNodeList.EMPTY;
}
return new SqlNodeList(columns, SqlParserPos.ZERO);
return columns.isEmpty() ? null : new SqlNodeList(columns, SqlParserPos.ZERO);
}

private SqlNode convertExpression(final ExpressionSegment expressionSegment) {
Expand Down
9 changes: 6 additions & 3 deletions test/it/optimizer/src/test/resources/converter/insert.xml
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,11 @@
<test-cases sql-case-id="insert_with_batch_and_irregular_parameters" expected-sql="INSERT INTO `t_order` (`order_id`, `user_id`, `status`) VALUES (?, 1, 'insert', ?, ?, ?)" db-types="MySQL" sql-case-types="PLACEHOLDER" />
<test-cases sql-case-id="insert_with_batch_and_composite_expression" expected-sql="INSERT INTO `t_order` (`order_id`, `user_id`, `status`) VALUES (?, ?, `SUBSTR`(?, 1), ?, ?, `SUBSTR`(?, 1))" db-types="H2,MySQL" sql-case-types="PLACEHOLDER" />
<test-cases sql-case-id="insert_with_batch_and_with_generate_key_column" expected-sql="INSERT INTO `t_order_item` (`item_id`, `order_id`, `user_id`, `status`, `creation_date`) VALUES (?, ?, ?, 'insert', '2017-08-08', ?, ?, ?, 'insert', '2017-08-08')" db-types="MySQL" sql-case-types="PLACEHOLDER" />
<test-cases sql-case-id="insert_with_batch_and_with_generate_key_column" expected-sql="INSERT INTO &quot;t_order_item&quot; (&quot;item_id&quot;, &quot;order_id&quot;, &quot;user_id&quot;, &quot;status&quot;, &quot;creation_date&quot;) VALUES (?, ?, ?, 'insert', '2017-08-08', ?, ?, ?, 'insert', '2017-08-08')" db-types="SQLServer, PostgreSQL,openGauss" sql-case-types="PLACEHOLDER" />
<test-cases sql-case-id="insert_with_batch_and_with_generate_key_column" expected-sql="INSERT INTO [t_order_item] ([item_id], [order_id], [user_id], [status], [creation_date]) VALUES (?, ?, ?, 'insert', '2017-08-08', ?, ?, ?, 'insert', '2017-08-08')" db-types="SQLServer" sql-case-types="PLACEHOLDER" />
<test-cases sql-case-id="insert_with_batch_and_with_generate_key_column" expected-sql="INSERT INTO &quot;t_order_item&quot; (&quot;item_id&quot;, &quot;order_id&quot;, &quot;user_id&quot;, &quot;status&quot;, &quot;creation_date&quot;) VALUES (?, ?, ?, 'insert', '2017-08-08', ?, ?, ?, 'insert', '2017-08-08')" db-types="PostgreSQL,openGauss" sql-case-types="PLACEHOLDER" />
<test-cases sql-case-id="insert_with_batch_and_without_generate_key_column" expected-sql="INSERT INTO `t_order_item` (`order_id`, `user_id`, `status`, `creation_date`) VALUES (?, ?, 'insert', '2017-08-08', ?, ?, 'insert', '2017-08-08')" db-types="MySQL" sql-case-types="PLACEHOLDER" />
<test-cases sql-case-id="insert_with_batch_and_without_generate_key_column" expected-sql="INSERT INTO &quot;t_order_item&quot; (&quot;order_id&quot;, &quot;user_id&quot;, &quot;status&quot;, &quot;creation_date&quot;) VALUES (?, ?, 'insert', '2017-08-08', ?, ?, 'insert', '2017-08-08')" db-types="SQLServer, PostgreSQL,openGauss" sql-case-types="PLACEHOLDER" />
<test-cases sql-case-id="insert_with_batch_and_without_generate_key_column" expected-sql="INSERT INTO &quot;t_order_item&quot; (&quot;order_id&quot;, &quot;user_id&quot;, &quot;status&quot;, &quot;creation_date&quot;) VALUES (?, ?, 'insert', '2017-08-08', ?, ?, 'insert', '2017-08-08')" db-types="PostgreSQL,openGauss" sql-case-types="PLACEHOLDER" />
<test-cases sql-case-id="insert_with_batch_and_without_generate_key_column" expected-sql="INSERT INTO [t_order_item] ([order_id], [user_id], [status], [creation_date]) VALUES (?, ?, 'insert', '2017-08-08', ?, ?, 'insert', '2017-08-08')" db-types="SQLServer" sql-case-types="PLACEHOLDER" />
<test-cases sql-case-id="insert_with_multiple_values" expected-sql="INSERT INTO `t_order` (`order_id`, `user_id`, `status`) VALUES (1, 1, 'insert', 2, 2, 'insert2')" db-types="MySQL" sql-case-types="LITERAL" />
<test-cases sql-case-id="insert_with_one_auto_increment_column" expected-sql="INSERT INTO `t_auto_increment_table` VALUES ()" db-types="MySQL" sql-case-types="LITERAL" />
<test-cases sql-case-id="insert_with_double_value" expected-sql="INSERT INTO `t_double_test` (`col1`) VALUES (1.22)" db-types="MySQL" sql-case-types="LITERAL" />
Expand All @@ -69,7 +71,8 @@
<test-cases sql-case-id="insert_with_schema" expected-sql="INSERT INTO &quot;db1&quot;.&quot;t_order&quot; VALUES (1, 2, 3)" db-types="PostgreSQL,openGauss,Oracle" sql-case-types="LITERAL" />
<test-cases sql-case-id="insert_with_negative_value" expected-sql="INSERT INTO `t_order` (`order_id`, `user_id`, `status`) VALUES (?, ?, ?)" db-types="MySQL" sql-case-types="PLACEHOLDER" />
<test-cases sql-case-id="insert_with_negative_value" expected-sql="INSERT INTO &quot;t_order&quot; (&quot;order_id&quot;, &quot;user_id&quot;, &quot;status&quot;) VALUES (?, ?, ?)" db-types="PostgreSQL,openGauss,Oracle" sql-case-types="PLACEHOLDER" />
<test-cases sql-case-id="insert_datetime_literals" expected-sql="INSERT INTO &quot;date_tab&quot; VALUES ('1999-12-01 10:00:00', '1999-12-01 10:00:00', '1999-12-01 10:00:00')" db-types="Oracle" sql-case-types="LITERAL" />
<!-- FIXME -->
<!--<test-cases sql-case-id="insert_datetime_literals" expected-sql="INSERT INTO &quot;date_tab&quot; VALUES ('1999-12-01 10:00:00', '1999-12-01 10:00:00', '1999-12-01 10:00:00')" db-types="Oracle" sql-case-types="LITERAL" />-->
<test-cases sql-case-id="insert_with_content_keyword" expected-sql="INSERT INTO &quot;SYS_MQ_MSG&quot; (&quot;ID&quot;, &quot;CONTENT&quot;) VALUES (1, 'test')" db-types="Oracle" sql-case-types="LITERAL" />
<test-cases sql-case-id="insert_with_connect_by_and_prior" expected-sql="INSERT INTO &quot;t&quot; (&quot;c1&quot;, &quot;c2&quot;, &quot;c3&quot;, &quot;c4&quot;, &quot;c5&quot;) SELECT &quot;c1&quot;, &quot;c2&quot;, &quot;regexp_substr&quot;(&quot;c3&quot;, '[^,]+', 1, &quot;l&quot;) &quot;c3&quot;, &quot;c4&quot;, &quot;c5&quot; FROM &quot;t&quot; WHERE &quot;id&quot; = 1" db-types="Oracle" sql-case-types="LITERAL" />
<test-cases sql-case-id="insert_with_national_character_set" expected-sql="INSERT INTO &quot;customers&quot; VALUES (1000, &quot;TO_NCHAR&quot;('John Smith'), '''500 Oracle Parkway', &quot;sysdate&quot;)" db-types="Oracle" sql-case-types="LITERAL" />
Expand Down