Skip to content

Commit

Permalink
Fix GenericSQLRewriteEngineTest
Browse files Browse the repository at this point in the history
  • Loading branch information
terrymanu committed Sep 29, 2023
1 parent 3928c88 commit 629b0bb
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import org.apache.shardingsphere.infra.database.core.type.DatabaseType;
import org.apache.shardingsphere.infra.hint.HintValueContext;
import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
import org.apache.shardingsphere.infra.metadata.database.resource.unit.StorageUnit;
import org.apache.shardingsphere.infra.metadata.database.resource.unit.NewStorageUnitMetaData;
import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereSchema;
import org.apache.shardingsphere.infra.rewrite.context.SQLRewriteContext;
import org.apache.shardingsphere.infra.rewrite.engine.result.GenericSQLRewriteResult;
Expand All @@ -32,9 +32,11 @@
import org.junit.jupiter.api.Test;

import java.util.Collections;
import java.util.Map;

import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.mockito.Mockito.RETURNS_DEEP_STUBS;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

Expand All @@ -44,7 +46,7 @@ class GenericSQLRewriteEngineTest {
void assertRewrite() {
DatabaseType databaseType = mock(DatabaseType.class);
SQLTranslatorRule rule = new SQLTranslatorRule(new SQLTranslatorRuleConfiguration());
GenericSQLRewriteResult actual = new GenericSQLRewriteEngine(rule, databaseType, Collections.singletonMap("ds_0", mockStorageUnit(databaseType)))
GenericSQLRewriteResult actual = new GenericSQLRewriteEngine(rule, databaseType, mockStorageUnitMetaData(databaseType))
.rewrite(new SQLRewriteContext(mockDatabase(), mock(CommonSQLStatementContext.class), "SELECT 1", Collections.emptyList(), mock(ConnectionContext.class),
new HintValueContext()));
assertThat(actual.getSqlRewriteUnit().getSql(), is("SELECT 1"));
Expand All @@ -61,10 +63,10 @@ void assertRewriteStorageTypeIsEmpty() {
assertThat(actual.getSqlRewriteUnit().getParameters(), is(Collections.emptyList()));
}

private StorageUnit mockStorageUnit(final DatabaseType databaseType) {
StorageUnit result = mock(StorageUnit.class);
when(result.getStorageType()).thenReturn(databaseType);
return result;
private Map<String, NewStorageUnitMetaData> mockStorageUnitMetaData(final DatabaseType databaseType) {
NewStorageUnitMetaData result = mock(NewStorageUnitMetaData.class, RETURNS_DEEP_STUBS);
when(result.getStorageUnit().getStorageType()).thenReturn(databaseType);
return Collections.singletonMap("ds_0", result);
}

private ShardingSphereDatabase mockDatabase() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import org.apache.shardingsphere.infra.datanode.DataNode;
import org.apache.shardingsphere.infra.hint.HintValueContext;
import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
import org.apache.shardingsphere.infra.metadata.database.resource.unit.StorageUnit;
import org.apache.shardingsphere.infra.metadata.database.resource.unit.NewStorageUnitMetaData;
import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereSchema;
import org.apache.shardingsphere.infra.rewrite.context.SQLRewriteContext;
import org.apache.shardingsphere.infra.rewrite.engine.result.RouteSQLRewriteResult;
Expand Down Expand Up @@ -60,7 +60,7 @@ void assertRewriteWithStandardParameterBuilder() {
routeContext.getRouteUnits().add(routeUnit);
DatabaseType databaseType = mock(DatabaseType.class);
RouteSQLRewriteResult actual = new RouteSQLRewriteEngine(
new SQLTranslatorRule(new SQLTranslatorRuleConfiguration()), databaseType, mockStorageUnits(databaseType)).rewrite(sqlRewriteContext, routeContext);
new SQLTranslatorRule(new SQLTranslatorRuleConfiguration()), databaseType, mockStorageUnitMetaData(databaseType)).rewrite(sqlRewriteContext, routeContext);
assertThat(actual.getSqlRewriteUnits().size(), is(1));
assertThat(actual.getSqlRewriteUnits().get(routeUnit).getSql(), is("SELECT ?"));
assertThat(actual.getSqlRewriteUnits().get(routeUnit).getParameters(), is(Collections.singletonList(1)));
Expand All @@ -79,7 +79,7 @@ void assertRewriteWithStandardParameterBuilderWhenNeedAggregateRewrite() {
routeContext.getRouteUnits().add(secondRouteUnit);
DatabaseType databaseType = mock(DatabaseType.class);
RouteSQLRewriteResult actual = new RouteSQLRewriteEngine(
new SQLTranslatorRule(new SQLTranslatorRuleConfiguration()), databaseType, mockStorageUnits(databaseType)).rewrite(sqlRewriteContext, routeContext);
new SQLTranslatorRule(new SQLTranslatorRuleConfiguration()), databaseType, mockStorageUnitMetaData(databaseType)).rewrite(sqlRewriteContext, routeContext);
assertThat(actual.getSqlRewriteUnits().size(), is(1));
assertThat(actual.getSqlRewriteUnits().get(firstRouteUnit).getSql(), is("SELECT ? UNION ALL SELECT ?"));
assertThat(actual.getSqlRewriteUnits().get(firstRouteUnit).getParameters(), is(Arrays.asList(1, 1)));
Expand All @@ -99,7 +99,7 @@ void assertRewriteWithGroupedParameterBuilderForBroadcast() {
routeContext.getRouteUnits().add(routeUnit);
DatabaseType databaseType = mock(DatabaseType.class);
RouteSQLRewriteResult actual = new RouteSQLRewriteEngine(
new SQLTranslatorRule(new SQLTranslatorRuleConfiguration()), databaseType, mockStorageUnits(databaseType)).rewrite(sqlRewriteContext, routeContext);
new SQLTranslatorRule(new SQLTranslatorRuleConfiguration()), databaseType, mockStorageUnitMetaData(databaseType)).rewrite(sqlRewriteContext, routeContext);
assertThat(actual.getSqlRewriteUnits().size(), is(1));
assertThat(actual.getSqlRewriteUnits().get(routeUnit).getSql(), is("INSERT INTO tbl VALUES (?)"));
assertThat(actual.getSqlRewriteUnits().get(routeUnit).getParameters(), is(Collections.singletonList(1)));
Expand All @@ -121,7 +121,7 @@ void assertRewriteWithGroupedParameterBuilderForRouteWithSameDataNode() {
routeContext.getOriginalDataNodes().add(Collections.singletonList(new DataNode("ds.tbl_0")));
DatabaseType databaseType = mock(DatabaseType.class);
RouteSQLRewriteResult actual = new RouteSQLRewriteEngine(
new SQLTranslatorRule(new SQLTranslatorRuleConfiguration()), databaseType, mockStorageUnits(databaseType)).rewrite(sqlRewriteContext, routeContext);
new SQLTranslatorRule(new SQLTranslatorRuleConfiguration()), databaseType, mockStorageUnitMetaData(databaseType)).rewrite(sqlRewriteContext, routeContext);
assertThat(actual.getSqlRewriteUnits().size(), is(1));
assertThat(actual.getSqlRewriteUnits().get(routeUnit).getSql(), is("INSERT INTO tbl VALUES (?)"));
assertThat(actual.getSqlRewriteUnits().get(routeUnit).getParameters(), is(Collections.singletonList(1)));
Expand All @@ -142,7 +142,7 @@ void assertRewriteWithGroupedParameterBuilderForRouteWithEmptyDataNode() {
routeContext.getOriginalDataNodes().add(Collections.emptyList());
DatabaseType databaseType = mock(DatabaseType.class);
RouteSQLRewriteResult actual = new RouteSQLRewriteEngine(
new SQLTranslatorRule(new SQLTranslatorRuleConfiguration()), databaseType, mockStorageUnits(databaseType)).rewrite(sqlRewriteContext, routeContext);
new SQLTranslatorRule(new SQLTranslatorRuleConfiguration()), databaseType, mockStorageUnitMetaData(databaseType)).rewrite(sqlRewriteContext, routeContext);
assertThat(actual.getSqlRewriteUnits().size(), is(1));
assertThat(actual.getSqlRewriteUnits().get(routeUnit).getSql(), is("INSERT INTO tbl VALUES (?)"));
assertThat(actual.getSqlRewriteUnits().get(routeUnit).getParameters(), is(Collections.singletonList(1)));
Expand All @@ -163,7 +163,7 @@ void assertRewriteWithGroupedParameterBuilderForRouteWithNotSameDataNode() {
routeContext.getOriginalDataNodes().add(Collections.singletonList(new DataNode("ds_1.tbl_1")));
DatabaseType databaseType = mock(DatabaseType.class);
RouteSQLRewriteResult actual = new RouteSQLRewriteEngine(
new SQLTranslatorRule(new SQLTranslatorRuleConfiguration()), databaseType, mockStorageUnits(databaseType)).rewrite(sqlRewriteContext, routeContext);
new SQLTranslatorRule(new SQLTranslatorRuleConfiguration()), databaseType, mockStorageUnitMetaData(databaseType)).rewrite(sqlRewriteContext, routeContext);
assertThat(actual.getSqlRewriteUnits().size(), is(1));
assertThat(actual.getSqlRewriteUnits().get(routeUnit).getSql(), is("INSERT INTO tbl VALUES (?)"));
assertTrue(actual.getSqlRewriteUnits().get(routeUnit).getParameters().isEmpty());
Expand All @@ -176,9 +176,9 @@ private ShardingSphereDatabase mockDatabase() {
return result;
}

private Map<String, StorageUnit> mockStorageUnits(final DatabaseType databaseType) {
StorageUnit result = mock(StorageUnit.class);
when(result.getStorageType()).thenReturn(databaseType);
private Map<String, NewStorageUnitMetaData> mockStorageUnitMetaData(final DatabaseType databaseType) {
NewStorageUnitMetaData result = mock(NewStorageUnitMetaData.class, RETURNS_DEEP_STUBS);
when(result.getStorageUnit().getStorageType()).thenReturn(databaseType);
return Collections.singletonMap("ds_0", result);
}
}

0 comments on commit 629b0bb

Please sign in to comment.