Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add sql bind logic for create table statement and check simple table binder #34074

Merged
merged 10 commits into from
Dec 17, 2024
Merged
1 change: 1 addition & 0 deletions RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
1. Proxy Native: Support Seata AT integration under Proxy Native in GraalVM Native Image - [#33889](https://github.com/apache/shardingsphere/pull/33889)
1. Agent: Simplify the use of Agent's Docker Image - [#33356](https://github.com/apache/shardingsphere/pull/33356)
1. Metadata: Add load-table-metadata-batch-size props to concurrent load table metadata - [#34009](https://github.com/apache/shardingsphere/pull/34009)
1. SQL Binder: Add sql bind logic for create table statement - [#34074](https://github.com/apache/shardingsphere/pull/34074)

### Bug Fixes

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ public Stream<? extends Arguments> provideArguments(final ExtensionContext exten
Arguments.of("update t_warehouse set warehouse_name = ? where id = ?", Arrays.asList("foo", 1), true, Collections.singletonList(1)),
Arguments.of("delete from t_warehouse where id = ?", Collections.singletonList(1), true, Collections.singletonList(0)));
Collection<? extends Arguments> nonCacheableCases = Arrays.asList(
Arguments.of("create table t_warehouse (id int4 not null primary key)", Collections.emptyList(), false, Collections.emptyList()),
Arguments.of("create table t_warehouse_for_create (id int4 not null primary key)", Collections.emptyList(), false, Collections.emptyList()),
Arguments.of("insert into t_warehouse (id) select warehouse_id from t_order", Collections.emptyList(), false, Collections.emptyList()),
Arguments.of("insert into t_warehouse (id) values (?), (?)", Arrays.asList(1, 2), false, Collections.emptyList()),
Arguments.of("insert into t_non_sharding_table (id) values (?)", Collections.singletonList(1), false, Collections.emptyList()),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ class ShardingCreateFunctionSupportedCheckerTest {
void assertCheckCreateFunctionForMySQL() {
MySQLSelectStatement selectStatement = new MySQLSelectStatement();
selectStatement.setFrom(new SimpleTableSegment(new TableNameSegment(0, 0, new IdentifierValue("t_order_item"))));
MySQLCreateTableStatement createTableStatement = new MySQLCreateTableStatement(false);
MySQLCreateTableStatement createTableStatement = new MySQLCreateTableStatement();
createTableStatement.setIfNotExists(false);
createTableStatement.setTable(new SimpleTableSegment(new TableNameSegment(0, 0, new IdentifierValue("t_order"))));
ValidStatementSegment validStatementSegment = new ValidStatementSegment(0, 0);
validStatementSegment.setSqlStatement(createTableStatement);
Expand Down Expand Up @@ -104,7 +105,8 @@ void assertCheckCreateFunctionWithNoSuchTableForMySQL() {

@Test
void assertCheckCreateFunctionWithTableExistsForMySQL() {
MySQLCreateTableStatement createTableStatement = new MySQLCreateTableStatement(false);
MySQLCreateTableStatement createTableStatement = new MySQLCreateTableStatement();
createTableStatement.setIfNotExists(false);
createTableStatement.setTable(new SimpleTableSegment(new TableNameSegment(0, 0, new IdentifierValue("t_order"))));
ValidStatementSegment validStatementSegment = new ValidStatementSegment(0, 0);
validStatementSegment.setSqlStatement(createTableStatement);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ class ShardingCreateProcedureSupportedCheckerTest {
void assertCheckForMySQL() {
MySQLSelectStatement selectStatement = new MySQLSelectStatement();
selectStatement.setFrom(new SimpleTableSegment(new TableNameSegment(0, 0, new IdentifierValue("t_order_item"))));
MySQLCreateTableStatement createTableStatement = new MySQLCreateTableStatement(false);
MySQLCreateTableStatement createTableStatement = new MySQLCreateTableStatement();
createTableStatement.setIfNotExists(false);
createTableStatement.setTable(new SimpleTableSegment(new TableNameSegment(0, 0, new IdentifierValue("t_order"))));
ValidStatementSegment validStatementSegment = new ValidStatementSegment(0, 0);
validStatementSegment.setSqlStatement(createTableStatement);
Expand Down Expand Up @@ -105,7 +106,8 @@ void assertCheckWithNoSuchTableForMySQL() {

@Test
void assertCheckWithTableExistsForMySQL() {
MySQLCreateTableStatement createTableStatement = new MySQLCreateTableStatement(false);
MySQLCreateTableStatement createTableStatement = new MySQLCreateTableStatement();
createTableStatement.setIfNotExists(false);
createTableStatement.setTable(new SimpleTableSegment(new TableNameSegment(0, 0, new IdentifierValue("t_order"))));
ValidStatementSegment validStatementSegment = new ValidStatementSegment(0, 0);
validStatementSegment.setSqlStatement(createTableStatement);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ class ShardingCreateTableSupportedCheckerTest {

@Test
void assertCheckForMySQL() {
MySQLCreateTableStatement sqlStatement = new MySQLCreateTableStatement(false);
MySQLCreateTableStatement sqlStatement = new MySQLCreateTableStatement();
sqlStatement.setIfNotExists(false);
sqlStatement.setTable(new SimpleTableSegment(new TableNameSegment(1, 2, new IdentifierValue("t_order"))));
assertThrows(TableExistsException.class, () -> assertCheck(sqlStatement));
}
Expand All @@ -63,7 +64,8 @@ void assertCheckForOracle() {

@Test
void assertCheckForPostgreSQL() {
PostgreSQLCreateTableStatement sqlStatement = new PostgreSQLCreateTableStatement(false);
PostgreSQLCreateTableStatement sqlStatement = new PostgreSQLCreateTableStatement();
sqlStatement.setIfNotExists(false);
sqlStatement.setTable(new SimpleTableSegment(new TableNameSegment(1, 2, new IdentifierValue("t_order"))));
assertThrows(TableExistsException.class, () -> assertCheck(sqlStatement));
}
Expand Down Expand Up @@ -92,14 +94,16 @@ private void assertCheck(final CreateTableStatement sqlStatement) {

@Test
void assertCheckIfNotExistsForMySQL() {
MySQLCreateTableStatement sqlStatement = new MySQLCreateTableStatement(true);
MySQLCreateTableStatement sqlStatement = new MySQLCreateTableStatement();
sqlStatement.setIfNotExists(true);
sqlStatement.setTable(new SimpleTableSegment(new TableNameSegment(1, 2, new IdentifierValue("t_order"))));
assertCheckIfNotExists(sqlStatement);
}

@Test
void assertCheckIfNotExistsForPostgreSQL() {
PostgreSQLCreateTableStatement sqlStatement = new PostgreSQLCreateTableStatement(true);
PostgreSQLCreateTableStatement sqlStatement = new PostgreSQLCreateTableStatement();
sqlStatement.setIfNotExists(true);
sqlStatement.setTable(new SimpleTableSegment(new TableNameSegment(1, 2, new IdentifierValue("t_order"))));
assertCheckIfNotExists(sqlStatement);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@ class ShardingCreateTableRouteContextCheckerTest {

@Test
void assertCheckWithSameRouteResultShardingTableForPostgreSQL() {
PostgreSQLCreateTableStatement sqlStatement = new PostgreSQLCreateTableStatement(false);
PostgreSQLCreateTableStatement sqlStatement = new PostgreSQLCreateTableStatement();
sqlStatement.setIfNotExists(false);
sqlStatement.setTable(new SimpleTableSegment(new TableNameSegment(0, 0, new IdentifierValue("t_order"))));
when(shardingRule.isShardingTable("t_order")).thenReturn(true);
when(shardingRule.getShardingTable("t_order")).thenReturn(new ShardingTable(Arrays.asList("ds_0", "ds_1"), "t_order"));
Expand All @@ -78,7 +79,8 @@ void assertCheckWithSameRouteResultShardingTableForPostgreSQL() {

@Test
void assertCheckWithDifferentRouteResultShardingTableForPostgreSQL() {
PostgreSQLCreateTableStatement sqlStatement = new PostgreSQLCreateTableStatement(false);
PostgreSQLCreateTableStatement sqlStatement = new PostgreSQLCreateTableStatement();
sqlStatement.setIfNotExists(false);
sqlStatement.setTable(new SimpleTableSegment(new TableNameSegment(0, 0, new IdentifierValue("t_order"))));
when(shardingRule.isShardingTable("t_order")).thenReturn(true);
when(shardingRule.getShardingTable("t_order")).thenReturn(new ShardingTable(Arrays.asList("ds_0", "ds_1"), "t_order"));
Expand All @@ -92,7 +94,8 @@ void assertCheckWithDifferentRouteResultShardingTableForPostgreSQL() {

@Test
void assertCheckWithSameRouteResultBroadcastTableForPostgreSQL() {
PostgreSQLCreateTableStatement sqlStatement = new PostgreSQLCreateTableStatement(false);
PostgreSQLCreateTableStatement sqlStatement = new PostgreSQLCreateTableStatement();
sqlStatement.setIfNotExists(false);
sqlStatement.setTable(new SimpleTableSegment(new TableNameSegment(0, 0, new IdentifierValue("t_config"))));
when(queryContext.getSqlStatementContext()).thenReturn(new CreateTableStatementContext(sqlStatement));
assertDoesNotThrow(() -> new ShardingCreateTableRouteContextChecker().check(shardingRule, queryContext, database, mock(ConfigurationProperties.class), routeContext));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,5 @@
*/
public enum SegmentType {

PROJECTION, PREDICATE, JOIN_ON, JOIN_USING, ORDER_BY, GROUP_BY, LOCK, SET_ASSIGNMENT, VALUES, INSERT_COLUMNS
PROJECTION, PREDICATE, JOIN_ON, JOIN_USING, ORDER_BY, GROUP_BY, LOCK, SET_ASSIGNMENT, VALUES, INSERT_COLUMNS, DEFINITION_COLUMNS
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.shardingsphere.infra.binder.engine.segment.column;

import com.cedarsoftware.util.CaseInsensitiveMap.CaseInsensitiveString;
import com.google.common.collect.LinkedHashMultimap;
import com.google.common.collect.Multimap;
import lombok.AccessLevel;
import lombok.NoArgsConstructor;
import org.apache.shardingsphere.infra.binder.engine.segment.SegmentType;
import org.apache.shardingsphere.infra.binder.engine.segment.expression.type.ColumnSegmentBinder;
import org.apache.shardingsphere.infra.binder.engine.segment.from.context.TableSegmentBinderContext;
import org.apache.shardingsphere.infra.binder.engine.segment.from.type.SimpleTableSegmentBinder;
import org.apache.shardingsphere.infra.binder.engine.statement.SQLStatementBinderContext;
import org.apache.shardingsphere.sql.parser.statement.core.segment.ddl.column.ColumnDefinitionSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.column.ColumnSegment;

/**
* Column definition segment binder.
*/
@NoArgsConstructor(access = AccessLevel.PRIVATE)
public final class ColumnDefinitionSegmentBinder {

/**
* Bind column definition segment.
*
* @param segment column definition segment
* @param binderContext SQL statement binder context
* @param tableBinderContexts table binder contexts
* @return bound column definition segment
*/
public static ColumnDefinitionSegment bind(final ColumnDefinitionSegment segment, final SQLStatementBinderContext binderContext,
final Multimap<CaseInsensitiveString, TableSegmentBinderContext> tableBinderContexts) {
ColumnSegment boundColumnSegment = ColumnSegmentBinder.bind(segment.getColumnName(), SegmentType.DEFINITION_COLUMNS, binderContext, tableBinderContexts, LinkedHashMultimap.create());
ColumnDefinitionSegment result =
new ColumnDefinitionSegment(segment.getStartIndex(), segment.getStopIndex(), boundColumnSegment, segment.getDataType(), segment.isPrimaryKey(), segment.isNotNull(), segment.getText());
copy(segment, result);
segment.getReferencedTables().forEach(each -> result.getReferencedTables().add(SimpleTableSegmentBinder.bind(each, binderContext, tableBinderContexts)));
return result;
}

private static void copy(final ColumnDefinitionSegment result, final ColumnDefinitionSegment segment) {
result.setAutoIncrement(segment.isAutoIncrement());
result.setRef(segment.isRef());
result.setCharsetName(segment.getCharsetName());
result.setCollateName(segment.getCollateName());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ public static CombineSegment bind(final CombineSegment segment, final SQLStateme
private static SubquerySegment bindSubquerySegment(final SubquerySegment segment, final SQLStatementBinderContext binderContext,
final Multimap<CaseInsensitiveMap.CaseInsensitiveString, TableSegmentBinderContext> outerTableBinderContexts) {
SubquerySegment result = new SubquerySegment(segment.getStartIndex(), segment.getStopIndex(), segment.getText());
SQLStatementBinderContext subqueryBinderContext = new SQLStatementBinderContext(segment.getSelect(), binderContext.getMetaData(), binderContext.getCurrentDatabaseName());
SQLStatementBinderContext subqueryBinderContext = new SQLStatementBinderContext(binderContext.getMetaData(), binderContext.getCurrentDatabaseName(), segment.getSelect());
subqueryBinderContext.getExternalTableBinderContexts().putAll(binderContext.getExternalTableBinderContexts());
result.setSelect(new SelectStatementBinder(outerTableBinderContexts).bind(segment.getSelect(), subqueryBinderContext));
return result;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ private static Optional<ColumnSegment> findInputColumnSegment(final ColumnSegmen
}
}
if (!isFindInputColumn) {
result = findInputColumnSegmentByVariables(segment, binderContext.getVariableNames()).orElse(null);
result = findInputColumnSegmentByVariables(segment, binderContext.getSqlStatement().getVariableNames()).orElse(null);
isFindInputColumn = null != result;
}
if (!isFindInputColumn) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ public final class SubquerySegmentBinder {
*/
public static SubquerySegment bind(final SubquerySegment segment, final SQLStatementBinderContext binderContext,
final Multimap<CaseInsensitiveString, TableSegmentBinderContext> outerTableBinderContexts) {
SQLStatementBinderContext selectBinderContext = new SQLStatementBinderContext(segment.getSelect(), binderContext.getMetaData(), binderContext.getCurrentDatabaseName());
SQLStatementBinderContext selectBinderContext = new SQLStatementBinderContext(binderContext.getMetaData(), binderContext.getCurrentDatabaseName(), segment.getSelect());
selectBinderContext.getExternalTableBinderContexts().putAll(binderContext.getExternalTableBinderContexts());
SelectStatement boundSelectStatement = new SelectStatementBinder(outerTableBinderContexts).bind(segment.getSelect(), selectBinderContext);
return new SubquerySegment(segment.getStartIndex(), segment.getStopIndex(), boundSelectStatement, segment.getText());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,8 @@ public static JoinTableSegment bind(final JoinTableSegment segment, final SQLSta
result.setDerivedUsing(bindUsingColumns(derivedUsingColumns, tableBinderContexts));
result.getDerivedUsing().forEach(each -> binderContext.getUsingColumnNames().add(each.getIdentifier().getValue()));
}
result.getDerivedJoinTableProjectionSegments().addAll(getDerivedJoinTableProjectionSegments(result, binderContext.getDatabaseType(), usingColumnsByNaturalJoin, tableBinderContexts));
result.getDerivedJoinTableProjectionSegments()
.addAll(getDerivedJoinTableProjectionSegments(result, binderContext.getSqlStatement().getDatabaseType(), usingColumnsByNaturalJoin, tableBinderContexts));
binderContext.getJoinTableProjectionSegments().addAll(result.getDerivedJoinTableProjectionSegments());
return result;
}
Expand Down
Loading
Loading