From c0ce9627ef8754d09a9cfe9fcac4488db2a576ad Mon Sep 17 00:00:00 2001 From: niu niu Date: Sun, 29 Oct 2023 17:07:35 +0800 Subject: [PATCH] Support sql federation cte (#28888) * Support mysql cte sql parse * Add mysql cte sql parse test * Refactor sql federation WithConverter * Support sql federation SelectStatement with convert * Add sql federation cte execution plan test * Format test sql * Format parse code * Change SelectStatementHandler mysql test --- .../converter/segment/with/WithConverter.java | 32 +- .../select/SelectStatementConverter.java | 8 +- .../cases/federation-query-sql-cases.xml | 12 + .../statement/MySQLStatementVisitor.java | 29 +- .../handler/dml/SelectStatementHandler.java | 6 + .../mysql/dml/MySQLSelectStatement.java | 12 + .../dml/SelectStatementHandlerTest.java | 10 + .../src/test/resources/converter/delete.xml | 4 +- .../test/resources/converter/select-with.xml | 24 ++ .../main/resources/case/dml/select-with.xml | 279 ++++++++++++++++++ .../sql/supported/dml/select-with.xml | 4 + 11 files changed, 400 insertions(+), 20 deletions(-) create mode 100644 test/it/optimizer/src/test/resources/converter/select-with.xml diff --git a/kernel/sql-federation/optimizer/src/main/java/org/apache/shardingsphere/sqlfederation/optimizer/converter/segment/with/WithConverter.java b/kernel/sql-federation/optimizer/src/main/java/org/apache/shardingsphere/sqlfederation/optimizer/converter/segment/with/WithConverter.java index c27b9d0949b94..138a5a80ce5a3 100644 --- a/kernel/sql-federation/optimizer/src/main/java/org/apache/shardingsphere/sqlfederation/optimizer/converter/segment/with/WithConverter.java +++ b/kernel/sql-federation/optimizer/src/main/java/org/apache/shardingsphere/sqlfederation/optimizer/converter/segment/with/WithConverter.java @@ -26,8 +26,9 @@ import org.apache.calcite.sql.SqlWithItem; import org.apache.calcite.sql.parser.SqlParserPos; import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.ColumnSegment; +import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.complex.CommonTableExpressionSegment; import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.WithSegment; -import org.apache.shardingsphere.sqlfederation.optimizer.converter.segment.expression.ExpressionConverter; +import org.apache.shardingsphere.sqlfederation.optimizer.converter.segment.expression.impl.ColumnConverter; import org.apache.shardingsphere.sqlfederation.optimizer.converter.statement.select.SelectStatementConverter; import java.util.Collection; @@ -48,21 +49,20 @@ public final class WithConverter { * @return sql node list */ public static Optional convert(final WithSegment withSegment, final SqlNode sqlNode) { - SqlIdentifier name = new SqlIdentifier(withSegment.getCommonTableExpressions().iterator().next().getIdentifier().getValue(), SqlParserPos.ZERO); - SqlNode selectSubquery = new SelectStatementConverter().convert(withSegment.getCommonTableExpressions().iterator().next().getSubquery().getSelect()); - Collection collectionColumns = withSegment.getCommonTableExpressions().iterator().next().getColumns(); - Collection convertedColumns; - SqlNodeList columns = null; - if (!collectionColumns.isEmpty()) { - convertedColumns = collectionColumns.stream().map(ExpressionConverter::convert).filter(Optional::isPresent).map(Optional::get).collect(Collectors.toList()); - columns = new SqlNodeList(convertedColumns, SqlParserPos.ZERO); - } - SqlWithItem sqlWithItem = new SqlWithItem(SqlParserPos.ZERO, name, columns, selectSubquery); - SqlNodeList sqlWithItems = new SqlNodeList(SqlParserPos.ZERO); - sqlWithItems.add(sqlWithItem); - SqlWith sqlWith = new SqlWith(SqlParserPos.ZERO, sqlWithItems, sqlNode); + return Optional.of(new SqlWith(SqlParserPos.ZERO, convertWithItem(withSegment.getCommonTableExpressions()), sqlNode)); + } + + private static SqlNodeList convertWithItem(final Collection commonTableExpressionSegments) { SqlNodeList result = new SqlNodeList(SqlParserPos.ZERO); - result.add(sqlWith); - return Optional.of(result); + for (CommonTableExpressionSegment each : commonTableExpressionSegments) { + SqlIdentifier name = new SqlIdentifier(each.getIdentifier().getValue(), SqlParserPos.ZERO); + SqlNodeList columns = each.getColumns().isEmpty() ? null : convertColumns(each.getColumns()); + result.add(new SqlWithItem(SqlParserPos.ZERO, name, columns, new SelectStatementConverter().convert(each.getSubquery().getSelect()))); + } + return result; + } + + private static SqlNodeList convertColumns(final Collection columnSegments) { + return new SqlNodeList(columnSegments.stream().map(each -> ColumnConverter.convert(each).orElseThrow(IllegalStateException::new)).collect(Collectors.toList()), SqlParserPos.ZERO); } } diff --git a/kernel/sql-federation/optimizer/src/main/java/org/apache/shardingsphere/sqlfederation/optimizer/converter/statement/select/SelectStatementConverter.java b/kernel/sql-federation/optimizer/src/main/java/org/apache/shardingsphere/sqlfederation/optimizer/converter/statement/select/SelectStatementConverter.java index a96dc7428e83f..a975d0a023271 100644 --- a/kernel/sql-federation/optimizer/src/main/java/org/apache/shardingsphere/sqlfederation/optimizer/converter/statement/select/SelectStatementConverter.java +++ b/kernel/sql-federation/optimizer/src/main/java/org/apache/shardingsphere/sqlfederation/optimizer/converter/statement/select/SelectStatementConverter.java @@ -36,6 +36,7 @@ import org.apache.shardingsphere.sqlfederation.optimizer.converter.segment.projection.ProjectionsConverter; import org.apache.shardingsphere.sqlfederation.optimizer.converter.segment.where.WhereConverter; import org.apache.shardingsphere.sqlfederation.optimizer.converter.segment.window.WindowConverter; +import org.apache.shardingsphere.sqlfederation.optimizer.converter.segment.with.WithConverter; import org.apache.shardingsphere.sqlfederation.optimizer.converter.statement.SQLStatementConverter; import org.apache.shardingsphere.sqlfederation.optimizer.converter.type.CombineOperatorConverter; @@ -50,7 +51,8 @@ public final class SelectStatementConverter implements SQLStatementConverter limit = SelectStatementHandler.getLimitSegment(selectStatement); if (limit.isPresent()) { @@ -61,6 +63,10 @@ public SqlNode convert(final SelectStatement selectStatement) { return orderBy.isEmpty() ? sqlCombine : new SqlOrderBy(SqlParserPos.ZERO, sqlCombine, orderBy, null, null); } + private SqlNode convertWith(final SqlNode sqlSelect, final SelectStatement selectStatement) { + return SelectStatementHandler.getWithSegment(selectStatement).flatMap(segment -> WithConverter.convert(segment, sqlSelect)).orElse(null); + } + private SqlSelect convertSelect(final SelectStatement selectStatement) { SqlNodeList distinct = DistinctConverter.convert(selectStatement.getProjections()).orElse(null); SqlNodeList projection = ProjectionsConverter.convert(selectStatement.getProjections()).orElseThrow(IllegalStateException::new); diff --git a/kernel/sql-federation/optimizer/src/test/resources/cases/federation-query-sql-cases.xml b/kernel/sql-federation/optimizer/src/test/resources/cases/federation-query-sql-cases.xml index 0584a210361c3..fbfc970006cba 100644 --- a/kernel/sql-federation/optimizer/src/test/resources/cases/federation-query-sql-cases.xml +++ b/kernel/sql-federation/optimizer/src/test/resources/cases/federation-query-sql-cases.xml @@ -436,4 +436,16 @@ + + + + + + + + + + + + diff --git a/parser/sql/dialect/mysql/src/main/java/org/apache/shardingsphere/sql/parser/mysql/visitor/statement/MySQLStatementVisitor.java b/parser/sql/dialect/mysql/src/main/java/org/apache/shardingsphere/sql/parser/mysql/visitor/statement/MySQLStatementVisitor.java index 9789712de3707..821b8c117dc84 100644 --- a/parser/sql/dialect/mysql/src/main/java/org/apache/shardingsphere/sql/parser/mysql/visitor/statement/MySQLStatementVisitor.java +++ b/parser/sql/dialect/mysql/src/main/java/org/apache/shardingsphere/sql/parser/mysql/visitor/statement/MySQLStatementVisitor.java @@ -49,6 +49,7 @@ import org.apache.shardingsphere.sql.parser.autogen.MySQLStatementParser.ConstraintNameContext; import org.apache.shardingsphere.sql.parser.autogen.MySQLStatementParser.ConvertFunctionContext; import org.apache.shardingsphere.sql.parser.autogen.MySQLStatementParser.CurrentUserFunctionContext; +import org.apache.shardingsphere.sql.parser.autogen.MySQLStatementParser.CteClauseContext; import org.apache.shardingsphere.sql.parser.autogen.MySQLStatementParser.DataTypeContext; import org.apache.shardingsphere.sql.parser.autogen.MySQLStatementParser.DeleteContext; import org.apache.shardingsphere.sql.parser.autogen.MySQLStatementParser.DuplicateSpecificationContext; @@ -147,6 +148,7 @@ import org.apache.shardingsphere.sql.parser.autogen.MySQLStatementParser.WindowFunctionContext; import org.apache.shardingsphere.sql.parser.autogen.MySQLStatementParser.WindowItemContext; import org.apache.shardingsphere.sql.parser.autogen.MySQLStatementParser.WindowSpecificationContext; +import org.apache.shardingsphere.sql.parser.autogen.MySQLStatementParser.WithClauseContext; import org.apache.shardingsphere.sql.parser.sql.common.enums.AggregationType; import org.apache.shardingsphere.sql.parser.sql.common.enums.CombineType; import org.apache.shardingsphere.sql.parser.sql.common.enums.JoinType; @@ -181,6 +183,7 @@ import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.UnaryOperationExpression; import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ValuesExpression; import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.complex.CommonExpressionSegment; +import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.complex.CommonTableExpressionSegment; import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.simple.LiteralExpressionSegment; import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.simple.ParameterMarkerExpressionSegment; import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.simple.SimpleExpressionSegment; @@ -220,6 +223,7 @@ import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.SubqueryTableSegment; import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.TableNameSegment; import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.TableSegment; +import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.WithSegment; import org.apache.shardingsphere.sql.parser.sql.common.util.SQLUtils; import org.apache.shardingsphere.sql.parser.sql.common.value.collection.CollectionValue; import org.apache.shardingsphere.sql.parser.sql.common.value.identifier.IdentifierValue; @@ -712,6 +716,9 @@ public ASTNode visitQueryExpression(final QueryExpressionContext ctx) { if (null != ctx.limitClause()) { result.setLimit((LimitSegment) visit(ctx.limitClause())); } + if (null != result && null != ctx.withClause()) { + result.setWithSegment((WithSegment) visit(ctx.withClause())); + } return result; } @@ -727,6 +734,27 @@ public ASTNode visitSelectWithInto(final SelectWithIntoContext ctx) { return result; } + @Override + public ASTNode visitWithClause(final WithClauseContext ctx) { + Collection commonTableExpressions = new LinkedList<>(); + for (CteClauseContext each : ctx.cteClause()) { + commonTableExpressions.add((CommonTableExpressionSegment) visit(each)); + } + return new WithSegment(ctx.start.getStartIndex(), ctx.stop.getStopIndex(), commonTableExpressions); + } + + @SuppressWarnings("unchecked") + @Override + public ASTNode visitCteClause(final CteClauseContext ctx) { + CommonTableExpressionSegment result = new CommonTableExpressionSegment(ctx.getStart().getStartIndex(), ctx.getStop().getStopIndex(), (IdentifierValue) visit(ctx.identifier()), + new SubquerySegment(ctx.getStart().getStartIndex(), ctx.getStop().getStopIndex(), (MySQLSelectStatement) visit(ctx.subquery()), getOriginalText(ctx.subquery()))); + if (null != ctx.columnNames()) { + CollectionValue columns = (CollectionValue) visit(ctx.columnNames()); + result.getColumns().addAll(columns.getValue()); + } + return result; + } + @Override public ASTNode visitQueryExpressionBody(final QueryExpressionBodyContext ctx) { if (1 == ctx.getChildCount() && ctx.getChild(0) instanceof QueryPrimaryContext) { @@ -1592,7 +1620,6 @@ private List generateTablesFromTableAliasRefList(final Table @Override public ASTNode visitSelect(final SelectContext ctx) { - // TODO :Unsupported for withClause. MySQLSelectStatement result; if (null != ctx.queryExpression()) { result = (MySQLSelectStatement) visit(ctx.queryExpression()); diff --git a/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/dialect/handler/dml/SelectStatementHandler.java b/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/dialect/handler/dml/SelectStatementHandler.java index daeeaeffcde57..db1a92b0866cb 100644 --- a/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/dialect/handler/dml/SelectStatementHandler.java +++ b/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/dialect/handler/dml/SelectStatementHandler.java @@ -183,6 +183,9 @@ public static Optional getWithSegment(final SelectStatement selectS if (selectStatement instanceof SQLServerSelectStatement) { return ((SQLServerSelectStatement) selectStatement).getWithSegment(); } + if (selectStatement instanceof MySQLSelectStatement) { + return ((MySQLSelectStatement) selectStatement).getWithSegment(); + } return Optional.empty(); } @@ -199,6 +202,9 @@ public static void setWithSegment(final SelectStatement selectStatement, final W if (selectStatement instanceof SQLServerSelectStatement) { ((SQLServerSelectStatement) selectStatement).setWithSegment(withSegment); } + if (selectStatement instanceof MySQLSelectStatement) { + ((MySQLSelectStatement) selectStatement).setWithSegment(withSegment); + } } /** diff --git a/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/dialect/statement/mysql/dml/MySQLSelectStatement.java b/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/dialect/statement/mysql/dml/MySQLSelectStatement.java index 018bd27736bad..f6d35886c5b23 100644 --- a/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/dialect/statement/mysql/dml/MySQLSelectStatement.java +++ b/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/dialect/statement/mysql/dml/MySQLSelectStatement.java @@ -21,6 +21,7 @@ import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.pagination.limit.LimitSegment; import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.predicate.LockSegment; import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.WindowSegment; +import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.WithSegment; import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.SimpleTableSegment; import org.apache.shardingsphere.sql.parser.sql.common.statement.dml.SelectStatement; import org.apache.shardingsphere.sql.parser.sql.dialect.statement.mysql.MySQLStatement; @@ -41,6 +42,8 @@ public final class MySQLSelectStatement extends SelectStatement implements MySQL private WindowSegment window; + private WithSegment withSegment; + /** * Get order by segment. * @@ -76,4 +79,13 @@ public Optional getWindow() { public Optional getTable() { return Optional.ofNullable(table); } + + /** + * Get with segment. + * + * @return with segment. + */ + public Optional getWithSegment() { + return Optional.ofNullable(withSegment); + } } diff --git a/parser/sql/statement/src/test/java/org/apache/shardingsphere/sql/parser/sql/dialect/handler/dml/SelectStatementHandlerTest.java b/parser/sql/statement/src/test/java/org/apache/shardingsphere/sql/parser/sql/dialect/handler/dml/SelectStatementHandlerTest.java index 2b8abdc2f7444..f4b3ff013d002 100644 --- a/parser/sql/statement/src/test/java/org/apache/shardingsphere/sql/parser/sql/dialect/handler/dml/SelectStatementHandlerTest.java +++ b/parser/sql/statement/src/test/java/org/apache/shardingsphere/sql/parser/sql/dialect/handler/dml/SelectStatementHandlerTest.java @@ -187,6 +187,16 @@ void assertGetWithSegmentForOracle() { assertFalse(SelectStatementHandler.getWithSegment(new OracleSelectStatement()).isPresent()); } + @Test + void assertGetWithSegmentForMysql() { + MySQLSelectStatement selectStatement = new MySQLSelectStatement(); + selectStatement.setWithSegment(new WithSegment(0, 2, new LinkedList<>())); + Optional withSegment = SelectStatementHandler.getWithSegment(selectStatement); + assertTrue(withSegment.isPresent()); + assertThat(withSegment.get(), is(selectStatement.getWithSegment().get())); + assertFalse(SelectStatementHandler.getWithSegment(new MySQLSelectStatement()).isPresent()); + } + @Test void assertGetWithSegmentForSQLServer() { SQLServerSelectStatement selectStatement = new SQLServerSelectStatement(); diff --git a/test/it/optimizer/src/test/resources/converter/delete.xml b/test/it/optimizer/src/test/resources/converter/delete.xml index 1f320631d0bed..331174393e624 100644 --- a/test/it/optimizer/src/test/resources/converter/delete.xml +++ b/test/it/optimizer/src/test/resources/converter/delete.xml @@ -40,6 +40,6 @@ - - + + diff --git a/test/it/optimizer/src/test/resources/converter/select-with.xml b/test/it/optimizer/src/test/resources/converter/select-with.xml new file mode 100644 index 0000000000000..646c41e1793d6 --- /dev/null +++ b/test/it/optimizer/src/test/resources/converter/select-with.xml @@ -0,0 +1,24 @@ + + + + + + + + + diff --git a/test/it/parser/src/main/resources/case/dml/select-with.xml b/test/it/parser/src/main/resources/case/dml/select-with.xml index 1d579c105696f..fb1b2021e4acc 100644 --- a/test/it/parser/src/main/resources/case/dml/select-with.xml +++ b/test/it/parser/src/main/resources/case/dml/select-with.xml @@ -364,4 +364,283 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + = + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + = + + + + + + AND + + + + + + = + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + = + + + AND + + + + + + + + = + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/test/it/parser/src/main/resources/sql/supported/dml/select-with.xml b/test/it/parser/src/main/resources/sql/supported/dml/select-with.xml index 75a1f7cdd3e81..d6d7e917c3a70 100644 --- a/test/it/parser/src/main/resources/sql/supported/dml/select-with.xml +++ b/test/it/parser/src/main/resources/sql/supported/dml/select-with.xml @@ -23,4 +23,8 @@ + + + +