diff --git a/infra/rewrite/src/test/java/org/apache/shardingsphere/infra/rewrite/engine/GenericSQLRewriteEngineTest.java b/infra/rewrite/src/test/java/org/apache/shardingsphere/infra/rewrite/engine/GenericSQLRewriteEngineTest.java index 77eb054a6ce049..5eede96db73907 100644 --- a/infra/rewrite/src/test/java/org/apache/shardingsphere/infra/rewrite/engine/GenericSQLRewriteEngineTest.java +++ b/infra/rewrite/src/test/java/org/apache/shardingsphere/infra/rewrite/engine/GenericSQLRewriteEngineTest.java @@ -22,7 +22,7 @@ import org.apache.shardingsphere.infra.database.core.type.DatabaseType; import org.apache.shardingsphere.infra.hint.HintValueContext; import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase; -import org.apache.shardingsphere.infra.metadata.database.resource.unit.StorageUnit; +import org.apache.shardingsphere.infra.metadata.database.resource.unit.NewStorageUnitMetaData; import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereSchema; import org.apache.shardingsphere.infra.rewrite.context.SQLRewriteContext; import org.apache.shardingsphere.infra.rewrite.engine.result.GenericSQLRewriteResult; @@ -32,9 +32,11 @@ import org.junit.jupiter.api.Test; import java.util.Collections; +import java.util.Map; import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.MatcherAssert.assertThat; +import static org.mockito.Mockito.RETURNS_DEEP_STUBS; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -44,7 +46,7 @@ class GenericSQLRewriteEngineTest { void assertRewrite() { DatabaseType databaseType = mock(DatabaseType.class); SQLTranslatorRule rule = new SQLTranslatorRule(new SQLTranslatorRuleConfiguration()); - GenericSQLRewriteResult actual = new GenericSQLRewriteEngine(rule, databaseType, Collections.singletonMap("ds_0", mockStorageUnit(databaseType))) + GenericSQLRewriteResult actual = new GenericSQLRewriteEngine(rule, databaseType, mockStorageUnitMetaData(databaseType)) .rewrite(new SQLRewriteContext(mockDatabase(), mock(CommonSQLStatementContext.class), "SELECT 1", Collections.emptyList(), mock(ConnectionContext.class), new HintValueContext())); assertThat(actual.getSqlRewriteUnit().getSql(), is("SELECT 1")); @@ -61,10 +63,10 @@ void assertRewriteStorageTypeIsEmpty() { assertThat(actual.getSqlRewriteUnit().getParameters(), is(Collections.emptyList())); } - private StorageUnit mockStorageUnit(final DatabaseType databaseType) { - StorageUnit result = mock(StorageUnit.class); - when(result.getStorageType()).thenReturn(databaseType); - return result; + private Map mockStorageUnitMetaData(final DatabaseType databaseType) { + NewStorageUnitMetaData result = mock(NewStorageUnitMetaData.class, RETURNS_DEEP_STUBS); + when(result.getStorageUnit().getStorageType()).thenReturn(databaseType); + return Collections.singletonMap("ds_0", result); } private ShardingSphereDatabase mockDatabase() { diff --git a/infra/rewrite/src/test/java/org/apache/shardingsphere/infra/rewrite/engine/RouteSQLRewriteEngineTest.java b/infra/rewrite/src/test/java/org/apache/shardingsphere/infra/rewrite/engine/RouteSQLRewriteEngineTest.java index abdd7eb2e5b08b..d442e855a71367 100644 --- a/infra/rewrite/src/test/java/org/apache/shardingsphere/infra/rewrite/engine/RouteSQLRewriteEngineTest.java +++ b/infra/rewrite/src/test/java/org/apache/shardingsphere/infra/rewrite/engine/RouteSQLRewriteEngineTest.java @@ -26,7 +26,7 @@ import org.apache.shardingsphere.infra.datanode.DataNode; import org.apache.shardingsphere.infra.hint.HintValueContext; import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase; -import org.apache.shardingsphere.infra.metadata.database.resource.unit.StorageUnit; +import org.apache.shardingsphere.infra.metadata.database.resource.unit.NewStorageUnitMetaData; import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereSchema; import org.apache.shardingsphere.infra.rewrite.context.SQLRewriteContext; import org.apache.shardingsphere.infra.rewrite.engine.result.RouteSQLRewriteResult; @@ -60,7 +60,7 @@ void assertRewriteWithStandardParameterBuilder() { routeContext.getRouteUnits().add(routeUnit); DatabaseType databaseType = mock(DatabaseType.class); RouteSQLRewriteResult actual = new RouteSQLRewriteEngine( - new SQLTranslatorRule(new SQLTranslatorRuleConfiguration()), databaseType, mockStorageUnits(databaseType)).rewrite(sqlRewriteContext, routeContext); + new SQLTranslatorRule(new SQLTranslatorRuleConfiguration()), databaseType, mockStorageUnitMetaData(databaseType)).rewrite(sqlRewriteContext, routeContext); assertThat(actual.getSqlRewriteUnits().size(), is(1)); assertThat(actual.getSqlRewriteUnits().get(routeUnit).getSql(), is("SELECT ?")); assertThat(actual.getSqlRewriteUnits().get(routeUnit).getParameters(), is(Collections.singletonList(1))); @@ -79,7 +79,7 @@ void assertRewriteWithStandardParameterBuilderWhenNeedAggregateRewrite() { routeContext.getRouteUnits().add(secondRouteUnit); DatabaseType databaseType = mock(DatabaseType.class); RouteSQLRewriteResult actual = new RouteSQLRewriteEngine( - new SQLTranslatorRule(new SQLTranslatorRuleConfiguration()), databaseType, mockStorageUnits(databaseType)).rewrite(sqlRewriteContext, routeContext); + new SQLTranslatorRule(new SQLTranslatorRuleConfiguration()), databaseType, mockStorageUnitMetaData(databaseType)).rewrite(sqlRewriteContext, routeContext); assertThat(actual.getSqlRewriteUnits().size(), is(1)); assertThat(actual.getSqlRewriteUnits().get(firstRouteUnit).getSql(), is("SELECT ? UNION ALL SELECT ?")); assertThat(actual.getSqlRewriteUnits().get(firstRouteUnit).getParameters(), is(Arrays.asList(1, 1))); @@ -99,7 +99,7 @@ void assertRewriteWithGroupedParameterBuilderForBroadcast() { routeContext.getRouteUnits().add(routeUnit); DatabaseType databaseType = mock(DatabaseType.class); RouteSQLRewriteResult actual = new RouteSQLRewriteEngine( - new SQLTranslatorRule(new SQLTranslatorRuleConfiguration()), databaseType, mockStorageUnits(databaseType)).rewrite(sqlRewriteContext, routeContext); + new SQLTranslatorRule(new SQLTranslatorRuleConfiguration()), databaseType, mockStorageUnitMetaData(databaseType)).rewrite(sqlRewriteContext, routeContext); assertThat(actual.getSqlRewriteUnits().size(), is(1)); assertThat(actual.getSqlRewriteUnits().get(routeUnit).getSql(), is("INSERT INTO tbl VALUES (?)")); assertThat(actual.getSqlRewriteUnits().get(routeUnit).getParameters(), is(Collections.singletonList(1))); @@ -121,7 +121,7 @@ void assertRewriteWithGroupedParameterBuilderForRouteWithSameDataNode() { routeContext.getOriginalDataNodes().add(Collections.singletonList(new DataNode("ds.tbl_0"))); DatabaseType databaseType = mock(DatabaseType.class); RouteSQLRewriteResult actual = new RouteSQLRewriteEngine( - new SQLTranslatorRule(new SQLTranslatorRuleConfiguration()), databaseType, mockStorageUnits(databaseType)).rewrite(sqlRewriteContext, routeContext); + new SQLTranslatorRule(new SQLTranslatorRuleConfiguration()), databaseType, mockStorageUnitMetaData(databaseType)).rewrite(sqlRewriteContext, routeContext); assertThat(actual.getSqlRewriteUnits().size(), is(1)); assertThat(actual.getSqlRewriteUnits().get(routeUnit).getSql(), is("INSERT INTO tbl VALUES (?)")); assertThat(actual.getSqlRewriteUnits().get(routeUnit).getParameters(), is(Collections.singletonList(1))); @@ -142,7 +142,7 @@ void assertRewriteWithGroupedParameterBuilderForRouteWithEmptyDataNode() { routeContext.getOriginalDataNodes().add(Collections.emptyList()); DatabaseType databaseType = mock(DatabaseType.class); RouteSQLRewriteResult actual = new RouteSQLRewriteEngine( - new SQLTranslatorRule(new SQLTranslatorRuleConfiguration()), databaseType, mockStorageUnits(databaseType)).rewrite(sqlRewriteContext, routeContext); + new SQLTranslatorRule(new SQLTranslatorRuleConfiguration()), databaseType, mockStorageUnitMetaData(databaseType)).rewrite(sqlRewriteContext, routeContext); assertThat(actual.getSqlRewriteUnits().size(), is(1)); assertThat(actual.getSqlRewriteUnits().get(routeUnit).getSql(), is("INSERT INTO tbl VALUES (?)")); assertThat(actual.getSqlRewriteUnits().get(routeUnit).getParameters(), is(Collections.singletonList(1))); @@ -163,7 +163,7 @@ void assertRewriteWithGroupedParameterBuilderForRouteWithNotSameDataNode() { routeContext.getOriginalDataNodes().add(Collections.singletonList(new DataNode("ds_1.tbl_1"))); DatabaseType databaseType = mock(DatabaseType.class); RouteSQLRewriteResult actual = new RouteSQLRewriteEngine( - new SQLTranslatorRule(new SQLTranslatorRuleConfiguration()), databaseType, mockStorageUnits(databaseType)).rewrite(sqlRewriteContext, routeContext); + new SQLTranslatorRule(new SQLTranslatorRuleConfiguration()), databaseType, mockStorageUnitMetaData(databaseType)).rewrite(sqlRewriteContext, routeContext); assertThat(actual.getSqlRewriteUnits().size(), is(1)); assertThat(actual.getSqlRewriteUnits().get(routeUnit).getSql(), is("INSERT INTO tbl VALUES (?)")); assertTrue(actual.getSqlRewriteUnits().get(routeUnit).getParameters().isEmpty()); @@ -176,9 +176,9 @@ private ShardingSphereDatabase mockDatabase() { return result; } - private Map mockStorageUnits(final DatabaseType databaseType) { - StorageUnit result = mock(StorageUnit.class); - when(result.getStorageType()).thenReturn(databaseType); + private Map mockStorageUnitMetaData(final DatabaseType databaseType) { + NewStorageUnitMetaData result = mock(NewStorageUnitMetaData.class, RETURNS_DEEP_STUBS); + when(result.getStorageUnit().getStorageType()).thenReturn(databaseType); return Collections.singletonMap("ds_0", result); } }