Skip to content

Commit

Permalink
Add TableSegmentBoundInfo for TableNameSegment and modify TablesConte…
Browse files Browse the repository at this point in the history
…xt logic (#34026)

* Add TableSegmentBoundInfo for TableNameSegment and modify TablesContext logic

* fix unit test

* fix unit test

* fix unit test
  • Loading branch information
strongduanmu authored Dec 13, 2024
1 parent 9b9af8e commit f23d908
Show file tree
Hide file tree
Showing 108 changed files with 496 additions and 368 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,8 @@
import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
import org.apache.shardingsphere.infra.binder.context.statement.ddl.CloseStatementContext;
import org.apache.shardingsphere.infra.binder.context.type.TableAvailable;
import org.apache.shardingsphere.infra.database.core.type.DatabaseType;
import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
import org.apache.shardingsphere.infra.session.query.QueryContext;
import org.apache.shardingsphere.infra.spi.type.typed.TypedSPILoader;
import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.table.SimpleTableSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.table.TableNameSegment;
import org.apache.shardingsphere.sql.parser.statement.core.statement.dal.DALStatement;
Expand Down Expand Up @@ -137,7 +135,6 @@ void assertNewInstanceWithUpdateStatementAndIsAllBroadcastTables() {
}

private TablesContext createTablesContext() {
DatabaseType databaseType = TypedSPILoader.getService(DatabaseType.class, "FIXTURE");
return new TablesContext(Collections.singleton(new SimpleTableSegment(new TableNameSegment(0, 0, new IdentifierValue("foo_tbl")))), databaseType, null);
return new TablesContext(Collections.singleton(new SimpleTableSegment(new TableNameSegment(0, 0, new IdentifierValue("foo_tbl")))), null);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ private SelectStatementContext mockSelectStatementContext(final String tableName
when(result.getOrderByContext().getItems()).thenReturn(Collections.singleton(orderByItem));
when(result.getGroupByContext().getItems()).thenReturn(Collections.emptyList());
when(result.getSubqueryContexts().values()).thenReturn(Collections.emptyList());
when(result.getTablesContext()).thenReturn(new TablesContext(Collections.singleton(simpleTableSegment), databaseType, "foo_db"));
when(result.getTablesContext()).thenReturn(new TablesContext(Collections.singleton(simpleTableSegment)));
return result;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ public static UpdateStatementContext createUpdateStatementContext() {
updateStatement.setTable(new SimpleTableSegment(new TableNameSegment(0, 0, new IdentifierValue("t_user"))));
updateStatement.setWhere(createWhereSegment());
updateStatement.setSetAssignment(createSetAssignmentSegment());
return new UpdateStatementContext(updateStatement, "foo_db");
return new UpdateStatementContext(updateStatement);
}

private static WhereSegment createWhereSegment() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ void assertGenerateSQLTokensWhenOwnerMatchTableAlias() {
when(sqlStatementContext.getSqlStatement().getProjections()).thenReturn(projections);
when(sqlStatementContext.getSubqueryContexts().values()).thenReturn(Collections.emptyList());
SimpleTableSegment doctorOneTable = new SimpleTableSegment(new TableNameSegment(0, 0, new IdentifierValue("doctor1")));
when(sqlStatementContext.getTablesContext()).thenReturn(new TablesContext(Arrays.asList(doctorTable, doctorOneTable), databaseType, "foo_db"));
when(sqlStatementContext.getTablesContext()).thenReturn(new TablesContext(Arrays.asList(doctorTable, doctorOneTable)));
when(sqlStatementContext.getProjectionsContext().getProjections()).thenReturn(Collections.singleton(new ColumnProjection("a", "mobile", null, databaseType)));
Collection<SQLToken> actual = generator.generateSQLTokens(sqlStatementContext);
assertThat(actual.size(), is(1));
Expand All @@ -114,7 +114,7 @@ void assertGenerateSQLTokensWhenOwnerMatchTableAliasForSameTable() {
when(sqlStatementContext.getSqlStatement().getProjections()).thenReturn(projections);
when(sqlStatementContext.getSubqueryContexts().values()).thenReturn(Collections.emptyList());
SimpleTableSegment sameDoctorTable = new SimpleTableSegment(new TableNameSegment(0, 0, new IdentifierValue("doctor")));
when(sqlStatementContext.getTablesContext()).thenReturn(new TablesContext(Arrays.asList(doctorTable, sameDoctorTable), databaseType, "foo_db"));
when(sqlStatementContext.getTablesContext()).thenReturn(new TablesContext(Arrays.asList(doctorTable, sameDoctorTable)));
when(sqlStatementContext.getProjectionsContext().getProjections()).thenReturn(Collections.singleton(new ColumnProjection("a", "mobile", null, databaseType)));
Collection<SQLToken> actual = generator.generateSQLTokens(sqlStatementContext);
assertThat(actual.size(), is(1));
Expand All @@ -133,7 +133,7 @@ void assertGenerateSQLTokensWhenOwnerMatchTableName() {
when(sqlStatementContext.getSubqueryContexts().values()).thenReturn(Collections.emptyList());
SimpleTableSegment doctorTable = new SimpleTableSegment(new TableNameSegment(0, 0, new IdentifierValue("doctor")));
SimpleTableSegment doctorOneTable = new SimpleTableSegment(new TableNameSegment(0, 0, new IdentifierValue("doctor1")));
when(sqlStatementContext.getTablesContext()).thenReturn(new TablesContext(Arrays.asList(doctorTable, doctorOneTable), databaseType, "foo_db"));
when(sqlStatementContext.getTablesContext()).thenReturn(new TablesContext(Arrays.asList(doctorTable, doctorOneTable)));
when(sqlStatementContext.getProjectionsContext().getProjections()).thenReturn(Collections.singleton(new ColumnProjection("doctor", "mobile", null, databaseType)));
Collection<SQLToken> actual = generator.generateSQLTokens(sqlStatementContext);
assertThat(actual.size(), is(1));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ private SelectStatementContext buildSelectStatementContext() {
OrderByItem orderByItem = new OrderByItem(columnOrderByItemSegment);
when(result.getGroupByContext().getItems()).thenReturn(Collections.singleton(orderByItem));
when(result.getSubqueryContexts().values()).thenReturn(Collections.emptyList());
when(result.getTablesContext()).thenReturn(new TablesContext(Collections.singleton(simpleTableSegment), databaseType, "foo_db"));
when(result.getTablesContext()).thenReturn(new TablesContext(Collections.singleton(simpleTableSegment)));
return result;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,14 @@
@HighFrequencyInvocation
public final class ShadowDMLStatementDataSourceMappingsRetriever implements ShadowDataSourceMappingsRetriever {

private final Map<String, String> tableAliasNameMap;
private final Collection<String> tableNames;

private final ShadowTableHintDataSourceMappingsRetriever tableHintDataSourceMappingsRetriever;

private final ShadowColumnDataSourceMappingsRetriever shadowColumnDataSourceMappingsRetriever;

public ShadowDMLStatementDataSourceMappingsRetriever(final QueryContext queryContext, final ShadowOperationType operationType) {
tableAliasNameMap = ((TableAvailable) queryContext.getSqlStatementContext()).getTablesContext().getTableAliasNameMap();
tableNames = ((TableAvailable) queryContext.getSqlStatementContext()).getTablesContext().getTableNames();
tableHintDataSourceMappingsRetriever = new ShadowTableHintDataSourceMappingsRetriever(operationType, queryContext.getHintValueContext().isShadow());
shadowColumnDataSourceMappingsRetriever = createShadowDataSourceMappingsRetriever(queryContext);
}
Expand All @@ -73,7 +73,7 @@ private ShadowColumnDataSourceMappingsRetriever createShadowDataSourceMappingsRe

@Override
public Map<String, String> retrieve(final ShadowRule rule) {
Collection<String> shadowTables = rule.filterShadowTables(tableAliasNameMap.values());
Collection<String> shadowTables = rule.filterShadowTables(tableNames);
Map<String, String> result = tableHintDataSourceMappingsRetriever.retrieve(rule, shadowTables);
return result.isEmpty() && null != shadowColumnDataSourceMappingsRetriever ? shadowColumnDataSourceMappingsRetriever.retrieve(rule, shadowTables) : result;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,16 @@
import org.apache.shardingsphere.shadow.route.retriever.dml.table.column.ShadowColumnDataSourceMappingsRetriever;
import org.apache.shardingsphere.shadow.route.util.ShadowExtractor;
import org.apache.shardingsphere.shadow.spi.ShadowOperationType;
import org.apache.shardingsphere.sql.parser.statement.core.extractor.ColumnExtractor;
import org.apache.shardingsphere.sql.parser.statement.core.extractor.ExpressionExtractor;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.column.ColumnSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.ExpressionSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.predicate.AndPredicate;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.predicate.WhereSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.OwnerSegment;
import org.apache.shardingsphere.sql.parser.statement.core.extractor.ColumnExtractor;
import org.apache.shardingsphere.sql.parser.statement.core.extractor.ExpressionExtractor;

import java.util.Collection;
import java.util.LinkedList;
import java.util.List;
import java.util.Optional;

/**
* Shadow select statement data source mappings retriever.
Expand All @@ -60,7 +58,8 @@ protected Collection<ShadowColumnCondition> getShadowColumnConditions(final Stri
if (1 != columns.size()) {
continue;
}
ShadowExtractor.extractValues(each, parameters).map(values -> new ShadowColumnCondition(getOwnerTableName(columns.iterator().next()), shadowColumnName, values)).ifPresent(result::add);
ShadowExtractor.extractValues(each, parameters).map(values -> new ShadowColumnCondition(
columns.iterator().next().getColumnBoundInfo().getOriginalTable().getValue(), shadowColumnName, values)).ifPresent(result::add);
}
return result;
}
Expand All @@ -74,11 +73,4 @@ private Collection<ExpressionSegment> getWhereSegment() {
}
return result;
}

private String getOwnerTableName(final ColumnSegment columnSegment) {
Optional<OwnerSegment> owner = columnSegment.getOwner();
return owner.isPresent()
? sqlStatementContext.getTablesContext().getTableAliasNameMap().get(owner.get().getIdentifier().getValue())
: sqlStatementContext.getTablesContext().getTableNames().iterator().next();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,12 @@
import org.apache.shardingsphere.infra.binder.context.statement.dml.SelectStatementContext;
import org.apache.shardingsphere.shadow.condition.ShadowColumnCondition;
import org.apache.shardingsphere.shadow.route.util.ShadowExtractor;
import org.apache.shardingsphere.sql.parser.statement.core.extractor.ColumnExtractor;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.column.ColumnSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.ExpressionSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.predicate.WhereSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.OwnerSegment;
import org.apache.shardingsphere.sql.parser.statement.core.extractor.ColumnExtractor;
import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.bound.TableSegmentBoundInfo;
import org.apache.shardingsphere.sql.parser.statement.core.value.identifier.IdentifierValue;
import org.apache.shardingsphere.test.mock.AutoMockExtension;
import org.apache.shardingsphere.test.mock.StaticMockSettings;
Expand All @@ -36,7 +37,7 @@
import java.util.Collections;
import java.util.Optional;

import static org.apache.shardingsphere.test.matcher.ShardingSphereAssertionMatchers.deepEqual;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.mockito.Mockito.RETURNS_DEEP_STUBS;
import static org.mockito.Mockito.mock;
Expand All @@ -53,16 +54,22 @@ void assertRetrieveWithColumnOwner() {
ExpressionSegment expressionSegment = mock(ExpressionSegment.class);
when(whereSegment.getExpr()).thenReturn(expressionSegment);
when(sqlStatementContext.getWhereSegments()).thenReturn(Arrays.asList(whereSegment, mock(WhereSegment.class, RETURNS_DEEP_STUBS)));
ColumnSegment columnSegment = mock(ColumnSegment.class);
when(columnSegment.getOwner()).thenReturn(Optional.of(new OwnerSegment(0, 0, new IdentifierValue("foo"))));
ColumnSegment columnSegment = mock(ColumnSegment.class, RETURNS_DEEP_STUBS);
when(columnSegment.getColumnBoundInfo().getOriginalTable().getValue()).thenReturn("foo_tbl");
OwnerSegment ownerSegment = new OwnerSegment(0, 0, new IdentifierValue("foo"));
ownerSegment.setTableBoundInfo(new TableSegmentBoundInfo(new IdentifierValue("foo_db"), new IdentifierValue("foo_schema")));
when(columnSegment.getOwner()).thenReturn(Optional.of(ownerSegment));
when(ColumnExtractor.extract(expressionSegment)).thenReturn(Collections.singleton(columnSegment));
when(ShadowExtractor.extractValues(expressionSegment, Collections.singletonList("foo"))).thenReturn(Optional.of(Collections.singleton("foo")));
when(sqlStatementContext.getTablesContext().getTableNames()).thenReturn(Collections.singleton("foo_tbl"));
when(sqlStatementContext.getTablesContext().getTableAliasNameMap().get("foo")).thenReturn("foo_tbl");
ShadowSelectStatementDataSourceMappingsRetriever retriever = new ShadowSelectStatementDataSourceMappingsRetriever(sqlStatementContext, Collections.singletonList("foo"));
Collection<ShadowColumnCondition> actual = retriever.getShadowColumnConditions("foo_col");
Collection<ShadowColumnCondition> expected = Collections.singletonList(new ShadowColumnCondition("foo_tbl", "foo_col", Collections.singleton("foo")));
assertThat(actual, deepEqual(expected));
assertThat(actual.size(), is(1));
ShadowColumnCondition actualCondition = actual.iterator().next();
assertThat(actualCondition.getTable(), is("foo_tbl"));
assertThat(actualCondition.getColumn(), is("foo_col"));
assertThat(actualCondition.getValues(), is(Collections.singleton("foo")));

}

@Test
Expand All @@ -72,14 +79,18 @@ void assertRetrieveWithoutColumnOwner() {
ExpressionSegment expressionSegment = mock(ExpressionSegment.class);
when(whereSegment.getExpr()).thenReturn(expressionSegment);
when(sqlStatementContext.getWhereSegments()).thenReturn(Arrays.asList(whereSegment, mock(WhereSegment.class, RETURNS_DEEP_STUBS)));
ColumnSegment columnSegment = mock(ColumnSegment.class);
ColumnSegment columnSegment = mock(ColumnSegment.class, RETURNS_DEEP_STUBS);
when(columnSegment.getColumnBoundInfo().getOriginalTable().getValue()).thenReturn("foo_tbl");
when(columnSegment.getOwner()).thenReturn(Optional.empty());
when(ColumnExtractor.extract(expressionSegment)).thenReturn(Collections.singleton(columnSegment));
when(ShadowExtractor.extractValues(expressionSegment, Collections.singletonList("foo"))).thenReturn(Optional.of(Collections.singleton("foo")));
when(sqlStatementContext.getTablesContext().getTableNames()).thenReturn(Collections.singleton("foo_tbl"));
ShadowSelectStatementDataSourceMappingsRetriever retriever = new ShadowSelectStatementDataSourceMappingsRetriever(sqlStatementContext, Collections.singletonList("foo"));
Collection<ShadowColumnCondition> actual = retriever.getShadowColumnConditions("foo_col");
Collection<ShadowColumnCondition> expected = Collections.singletonList(new ShadowColumnCondition("foo_tbl", "foo_col", Collections.singleton("foo")));
assertThat(actual, deepEqual(expected));
assertThat(actual.size(), is(1));
ShadowColumnCondition actualCondition = actual.iterator().next();
assertThat(actualCondition.getTable(), is("foo_tbl"));
assertThat(actualCondition.getColumn(), is("foo_col"));
assertThat(actualCondition.getValues(), is(Collections.singleton("foo")));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ void assertCheckWhenIndexExistRenameIndexNotExistForPostgreSQL() {
when(schema.getAllTables()).thenReturn(Collections.singleton(table));
when(table.containsIndex("t_order_index")).thenReturn(true);
when(table.containsIndex("t_order_index_new")).thenReturn(false);
assertDoesNotThrow(() -> new ShardingAlterIndexSupportedChecker().check(rule, database, schema, new AlterIndexStatementContext(sqlStatement, "foo_db")));
assertDoesNotThrow(() -> new ShardingAlterIndexSupportedChecker().check(rule, database, schema, new AlterIndexStatementContext(sqlStatement)));
}

@Test
Expand All @@ -70,7 +70,7 @@ void assertCheckWhenIndexNotExistRenameIndexNotExistForPostgreSQL() {
sqlStatement.setRenameIndex(new IndexSegment(0, 0, new IndexNameSegment(0, 0, new IdentifierValue("t_order_index_new"))));
ShardingSphereTable table = mock(ShardingSphereTable.class);
when(database.getSchema("public").getTable("t_order")).thenReturn(table);
assertThrows(IndexNotFoundException.class, () -> new ShardingAlterIndexSupportedChecker().check(rule, database, mock(), new AlterIndexStatementContext(sqlStatement, "foo_db")));
assertThrows(IndexNotFoundException.class, () -> new ShardingAlterIndexSupportedChecker().check(rule, database, mock(), new AlterIndexStatementContext(sqlStatement)));
}

@Test
Expand All @@ -84,6 +84,6 @@ void assertCheckAlterIndexWhenIndexExistRenameIndexExistForPostgreSQL() {
when(table.containsIndex("t_order_index")).thenReturn(true);
when(table.containsIndex("t_order_index_new")).thenReturn(true);
assertThrows(DuplicateIndexException.class,
() -> new ShardingAlterIndexSupportedChecker().check(rule, database, schema, new AlterIndexStatementContext(sqlStatement, "foo_db")));
() -> new ShardingAlterIndexSupportedChecker().check(rule, database, schema, new AlterIndexStatementContext(sqlStatement)));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ void assertCheckWithRenameTableWithShardingTable() {
PostgreSQLAlterTableStatement sqlStatement = new PostgreSQLAlterTableStatement();
sqlStatement.setTable(new SimpleTableSegment(new TableNameSegment(0, 0, new IdentifierValue("t_order"))));
sqlStatement.setRenameTable(new SimpleTableSegment(new TableNameSegment(0, 0, new IdentifierValue("t_order_new"))));
AlterTableStatementContext sqlStatementContext = new AlterTableStatementContext(sqlStatement, "foo_db");
AlterTableStatementContext sqlStatementContext = new AlterTableStatementContext(sqlStatement);
when(rule.containsShardingTable(Arrays.asList("t_order", "t_order_new"))).thenReturn(true);
assertThrows(UnsupportedShardingOperationException.class, () -> new ShardingAlterTableSupportedChecker().check(rule, database, mock(), sqlStatementContext));
}
Expand Down
Loading

0 comments on commit f23d908

Please sign in to comment.