Skip to content

Commit

Permalink
Remove DatabaseTypeEngine.getStorageType with data sources (#28644)
Browse files Browse the repository at this point in the history
* Remove DatabaseTypeEngine.getStorageType with data sources

* Refactor StorageUnit.storageType

* Fix test cases

* Fix test cases
  • Loading branch information
terrymanu authored Oct 5, 2023
1 parent b5e7f6d commit 00e2c08
Show file tree
Hide file tree
Showing 14 changed files with 58 additions and 54 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,12 @@ public final class DatabaseTypeEngine {
* @return protocol type
*/
public static DatabaseType getProtocolType(final String databaseName, final DatabaseConfiguration databaseConfig, final ConfigurationProperties props) {
return findConfiguredDatabaseType(props).orElseGet(() -> getStorageType(DataSourceStateManager.getInstance().getEnabledDataSources(databaseName, databaseConfig)));
Optional<DatabaseType> configuredDatabaseType = findConfiguredDatabaseType(props);
if (configuredDatabaseType.isPresent()) {
return configuredDatabaseType.get();
}
Collection<DataSource> enabledDataSources = DataSourceStateManager.getInstance().getEnabledDataSources(databaseName, databaseConfig);
return enabledDataSources.isEmpty() ? getDefaultStorageType() : getStorageType(enabledDataSources.iterator().next());
}

/**
Expand All @@ -66,7 +71,11 @@ public static DatabaseType getProtocolType(final String databaseName, final Data
*/
public static DatabaseType getProtocolType(final Map<String, ? extends DatabaseConfiguration> databaseConfigs, final ConfigurationProperties props) {
Optional<DatabaseType> configuredDatabaseType = findConfiguredDatabaseType(props);
return configuredDatabaseType.orElseGet(() -> getStorageType(getEnabledDataSources(databaseConfigs).values()));
if (configuredDatabaseType.isPresent()) {
return configuredDatabaseType.get();
}
Map<String, DataSource> enabledDataSources = getEnabledDataSources(databaseConfigs);
return enabledDataSources.isEmpty() ? getDefaultStorageType() : getStorageType(enabledDataSources.values().iterator().next());
}

private static Optional<DatabaseType> findConfiguredDatabaseType(final ConfigurationProperties props) {
Expand Down Expand Up @@ -101,18 +110,24 @@ public static Map<String, DatabaseType> getStorageTypes(final String databaseNam
/**
* Get storage type.
*
* @param dataSources data sources
* @param dataSource data source
* @return storage type
* @throws SQLWrapperException SQL wrapper exception
*/
public static DatabaseType getStorageType(final Collection<DataSource> dataSources) {
return dataSources.isEmpty() ? TypedSPILoader.getService(DatabaseType.class, DEFAULT_DATABASE_TYPE) : getStorageType(dataSources.iterator().next());
}

private static DatabaseType getStorageType(final DataSource dataSource) {
public static DatabaseType getStorageType(final DataSource dataSource) {
try (Connection connection = dataSource.getConnection()) {
return DatabaseTypeFactory.get(connection.getMetaData().getURL());
} catch (final SQLException ex) {
throw new SQLWrapperException(ex);
}
}

/**
* Get default storage type.
*
* @return default storage type
*/
public static DatabaseType getDefaultStorageType() {
return TypedSPILoader.getService(DatabaseType.class, DEFAULT_DATABASE_TYPE);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@
package org.apache.shardingsphere.infra.metadata.database.resource.unit;

import lombok.Getter;
import org.apache.shardingsphere.infra.database.DatabaseTypeEngine;
import org.apache.shardingsphere.infra.database.core.connector.ConnectionProperties;
import org.apache.shardingsphere.infra.database.core.connector.ConnectionPropertiesParser;
import org.apache.shardingsphere.infra.database.core.spi.DatabaseTypedSPILoader;
import org.apache.shardingsphere.infra.database.core.type.DatabaseType;
import org.apache.shardingsphere.infra.database.core.type.DatabaseTypeFactory;
import org.apache.shardingsphere.infra.datasource.pool.CatalogSwitchableDataSource;
import org.apache.shardingsphere.infra.datasource.pool.props.domain.DataSourcePoolProperties;
import org.apache.shardingsphere.infra.metadata.database.resource.node.StorageNode;
Expand Down Expand Up @@ -52,15 +52,11 @@ public StorageUnit(final String databaseName, final StorageNode storageNode, fin
this.storageNode = storageNode;
this.dataSource = new CatalogSwitchableDataSource(dataSource, storageNode.getCatalog(), storageNode.getUrl());
this.dataSourcePoolProperties = dataSourcePoolProperties;
storageType = DatabaseTypeFactory.get(storageNode.getUrl());
boolean isDataSourceEnabled = !DataSourceStateManager.getInstance().getEnabledDataSources(databaseName, Collections.singletonMap(storageNode.getName().getName(), dataSource)).isEmpty();
storageType = createStorageType(isDataSourceEnabled);
connectionProperties = createConnectionProperties(isDataSourceEnabled, storageNode);
}

private DatabaseType createStorageType(final boolean isDataSourceEnabled) {
return DatabaseTypeEngine.getStorageType(isDataSourceEnabled ? Collections.singleton(dataSource) : Collections.emptyList());
}

private ConnectionProperties createConnectionProperties(final boolean isDataSourceEnabled, final StorageNode storageNode) {
if (!isDataSourceEnabled) {
return null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@
import org.apache.shardingsphere.infra.database.core.type.DatabaseType;
import org.apache.shardingsphere.infra.database.mysql.type.MySQLDatabaseType;
import org.apache.shardingsphere.infra.database.postgresql.type.PostgreSQLDatabaseType;
import org.apache.shardingsphere.infra.fixture.FixtureRuleConfiguration;
import org.apache.shardingsphere.infra.exception.core.external.sql.type.wrapper.SQLWrapperException;
import org.apache.shardingsphere.infra.fixture.FixtureRuleConfiguration;
import org.apache.shardingsphere.infra.spi.type.typed.TypedSPILoader;
import org.apache.shardingsphere.test.fixture.jdbc.MockedDataSource;
import org.apache.shardingsphere.test.util.PropertiesBuilder;
Expand All @@ -35,8 +35,6 @@
import javax.sql.DataSource;
import java.sql.Connection;
import java.sql.SQLException;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Properties;

Expand Down Expand Up @@ -76,29 +74,20 @@ void assertGetStorageTypes() throws SQLException {
}

@Test
void assertGetStorageTypeWithEmptyDataSources() {
assertThat(DatabaseTypeEngine.getStorageType(Collections.emptyList()).getType(), is("MySQL"));
}

@Test
void assertGetStorageTypeWithDataSources() throws SQLException {
Collection<DataSource> dataSources = Arrays.asList(mockDataSource(TypedSPILoader.getService(DatabaseType.class, "H2")),
mockDataSource(TypedSPILoader.getService(DatabaseType.class, "H2")));
assertThat(DatabaseTypeEngine.getStorageType(dataSources).getType(), is("H2"));
}

@Test
void assertGetStorageTypeWithDifferentDataSourceTypes() throws SQLException {
Collection<DataSource> dataSources = Arrays.asList(mockDataSource(TypedSPILoader.getService(DatabaseType.class, "H2")),
mockDataSource(TypedSPILoader.getService(DatabaseType.class, "MySQL")));
assertThat(DatabaseTypeEngine.getStorageType(dataSources).getType(), is("H2"));
void assertGetStorageType() throws SQLException {
assertThat(DatabaseTypeEngine.getStorageType(mockDataSource(TypedSPILoader.getService(DatabaseType.class, "H2"))).getType(), is("H2"));
}

@Test
void assertGetStorageTypeWhenGetConnectionError() throws SQLException {
DataSource dataSource = mock(DataSource.class);
when(dataSource.getConnection()).thenThrow(SQLException.class);
assertThrows(SQLWrapperException.class, () -> DatabaseTypeEngine.getStorageType(Collections.singleton(dataSource)));
assertThrows(SQLWrapperException.class, () -> DatabaseTypeEngine.getStorageType(dataSource));
}

@Test
void assertGetDefaultStorageTypeWithEmptyDataSources() {
assertThat(DatabaseTypeEngine.getDefaultStorageType().getType(), is("MySQL"));
}

private DataSource mockDataSource(final DatabaseType databaseType) throws SQLException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
import org.apache.shardingsphere.infra.metadata.database.resource.ResourceMetaData;
import org.apache.shardingsphere.infra.metadata.database.resource.node.StorageNode;
import org.apache.shardingsphere.infra.metadata.database.resource.node.StorageNodeName;
import org.apache.shardingsphere.infra.metadata.database.resource.unit.StorageUnit;
import org.apache.shardingsphere.infra.metadata.database.rule.RuleMetaData;
import org.apache.shardingsphere.infra.rule.identifier.type.ResourceHeldRule;
Expand Down Expand Up @@ -91,7 +92,7 @@ private ShardingSphereDatabase mockDatabase(final ResourceMetaData resourceMetaD
when(result.getResourceMetaData()).thenReturn(resourceMetaData);
DataSourcePoolProperties dataSourcePoolProps = mock(DataSourcePoolProperties.class, RETURNS_DEEP_STUBS);
when(dataSourcePoolProps.getConnectionPropertySynonyms().getStandardProperties()).thenReturn(Collections.emptyMap());
StorageUnit storageUnit = new StorageUnit("foo_db", mock(StorageNode.class, RETURNS_DEEP_STUBS), dataSourcePoolProps, dataSource);
StorageUnit storageUnit = new StorageUnit("foo_db", new StorageNode(mock(StorageNodeName.class), "jdbc:mock://127.0.0.1/foo_ds"), dataSourcePoolProps, dataSource);
when(result.getResourceMetaData().getStorageUnits()).thenReturn(Collections.singletonMap("foo_db", storageUnit));
when(result.getRuleMetaData()).thenReturn(new RuleMetaData(Collections.singleton(databaseResourceHeldRule)));
return result;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
import org.apache.shardingsphere.transaction.spi.TransactionHook;

import java.util.Collection;
import java.util.Collections;
import java.util.Map;
import java.util.Optional;
import java.util.Properties;
Expand All @@ -52,7 +51,7 @@ public GlobalClockRule(final GlobalClockRuleConfiguration ruleConfig, final Map<

private Properties createProperties(final Map<String, ShardingSphereDatabase> databases) {
Properties result = new Properties();
DatabaseType storageType = findStorageType(databases.values()).orElseGet(() -> DatabaseTypeEngine.getStorageType(Collections.emptyList()));
DatabaseType storageType = findStorageType(databases.values()).orElseGet(DatabaseTypeEngine::getDefaultStorageType);
result.setProperty("trunkType", storageType.getTrunkDatabaseType().orElse(storageType).getType());
result.setProperty("enabled", String.valueOf(configuration.isEnabled()));
result.setProperty("type", configuration.getType());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ private Collection<String> decorateTables(final String databaseName, final Map<S
}
Map<String, DataSource> enabledDataSources = DataSourceStateManager.getInstance().getEnabledDataSources(databaseName, dataSources);
Map<String, DataSource> aggregatedDataSources = SingleTableLoadUtils.getAggregatedDataSourceMap(enabledDataSources, builtRules);
DatabaseType databaseType = DatabaseTypeEngine.getStorageType(enabledDataSources.values());
DatabaseType databaseType = enabledDataSources.isEmpty() ? DatabaseTypeEngine.getDefaultStorageType() : DatabaseTypeEngine.getStorageType(enabledDataSources.values().iterator().next());
Collection<String> excludedTables = SingleTableLoadUtils.getExcludedTables(builtRules);
Map<String, Collection<DataNode>> actualDataNodes = SingleTableDataNodeLoader.load(databaseName, databaseType, aggregatedDataSources, excludedTables);
Collection<DataNode> configuredDataNodes = SingleTableLoadUtils.convertToDataNodes(databaseName, databaseType, splitTables);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ public SingleRule(final SingleRuleConfiguration ruleConfig, final String databas
Map<String, DataSource> enabledDataSources = DataSourceStateManager.getInstance().getEnabledDataSources(databaseName, dataSourceMap);
Map<String, DataSource> aggregateDataSourceMap = SingleTableLoadUtils.getAggregatedDataSourceMap(enabledDataSources, builtRules);
dataSourceNames = aggregateDataSourceMap.keySet();
databaseType = DatabaseTypeEngine.getStorageType(enabledDataSources.values());
databaseType = enabledDataSources.isEmpty() ? DatabaseTypeEngine.getDefaultStorageType() : DatabaseTypeEngine.getStorageType(enabledDataSources.values().iterator().next());
singleTableDataNodes = SingleTableDataNodeLoader.load(databaseName, databaseType, aggregateDataSourceMap, builtRules, configuration.getTables());
singleTableDataNodes.forEach((key, value) -> tableNamesMapper.put(value.iterator().next().getTableName()));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
import org.apache.shardingsphere.infra.metadata.database.resource.ResourceMetaData;
import org.apache.shardingsphere.infra.metadata.database.resource.node.StorageNode;
import org.apache.shardingsphere.infra.metadata.database.resource.node.StorageNodeName;
import org.apache.shardingsphere.infra.metadata.database.resource.unit.StorageUnit;
import org.apache.shardingsphere.infra.metadata.database.rule.RuleMetaData;
import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereSchema;
Expand Down Expand Up @@ -150,8 +151,8 @@ private ShardingSphereDatabase mockDatabaseWithMultipleResources() {
Map<String, StorageUnit> storageUnits = new HashMap<>(2, 1F);
DataSourcePoolProperties dataSourcePoolProps = mock(DataSourcePoolProperties.class, RETURNS_DEEP_STUBS);
when(dataSourcePoolProps.getConnectionPropertySynonyms().getStandardProperties()).thenReturn(Collections.emptyMap());
storageUnits.put("ds_0", new StorageUnit("foo_db", mock(StorageNode.class, RETURNS_DEEP_STUBS), dataSourcePoolProps, new MockedDataSource()));
storageUnits.put("ds_1", new StorageUnit("foo_db", mock(StorageNode.class, RETURNS_DEEP_STUBS), dataSourcePoolProps, new MockedDataSource()));
storageUnits.put("ds_0", new StorageUnit("foo_db", new StorageNode(mock(StorageNodeName.class), "jdbc:mock://127.0.0.1/ds_0"), dataSourcePoolProps, new MockedDataSource()));
storageUnits.put("ds_1", new StorageUnit("foo_db", new StorageNode(mock(StorageNodeName.class), "jdbc:mock://127.0.0.1/ds_1"), dataSourcePoolProps, new MockedDataSource()));
ShardingSphereDatabase result = mock(ShardingSphereDatabase.class, RETURNS_DEEP_STUBS);
when(result.getResourceMetaData().getStorageUnits()).thenReturn(storageUnits);
when(result.getName()).thenReturn(DefaultDatabase.LOGIC_NAME);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@

package org.apache.shardingsphere.timeservice.type.database;

import org.apache.shardingsphere.infra.database.core.type.DatabaseType;
import org.apache.shardingsphere.infra.database.DatabaseTypeEngine;
import org.apache.shardingsphere.infra.datasource.pool.creator.DataSourcePoolCreator;
import org.apache.shardingsphere.infra.database.core.spi.DatabaseTypedSPILoader;
import org.apache.shardingsphere.infra.database.core.type.DatabaseType;
import org.apache.shardingsphere.infra.datasource.pool.creator.DataSourcePoolCreator;
import org.apache.shardingsphere.infra.yaml.config.swapper.resource.YamlDataSourceConfigurationSwapper;
import org.apache.shardingsphere.timeservice.spi.TimestampService;
import org.apache.shardingsphere.timeservice.type.database.exception.DatetimeLoadingException;
Expand All @@ -32,7 +32,6 @@
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Timestamp;
import java.util.Collections;
import java.util.Map.Entry;
import java.util.Properties;
import java.util.stream.Collectors;
Expand All @@ -50,7 +49,7 @@ public final class DatabaseTimestampService implements TimestampService {
public void init(final Properties props) {
dataSource = DataSourcePoolCreator.create(new YamlDataSourceConfigurationSwapper().swapToDataSourcePoolProperties(
props.entrySet().stream().collect(Collectors.toMap(entry -> entry.getKey().toString(), Entry::getValue))));
storageType = DatabaseTypeEngine.getStorageType(Collections.singleton(dataSource));
storageType = DatabaseTypeEngine.getStorageType(dataSource);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
import org.apache.shardingsphere.infra.metadata.database.resource.ResourceMetaData;
import org.apache.shardingsphere.infra.metadata.database.resource.node.StorageNode;
import org.apache.shardingsphere.infra.metadata.database.resource.node.StorageNodeName;
import org.apache.shardingsphere.infra.metadata.database.resource.unit.StorageUnit;
import org.apache.shardingsphere.test.fixture.jdbc.MockedDataSource;
import org.apache.shardingsphere.transaction.api.TransactionType;
Expand Down Expand Up @@ -101,8 +102,8 @@ private ResourceMetaData createResourceMetaData() {
Map<String, StorageUnit> storageUnits = new HashMap<>(2, 1F);
DataSourcePoolProperties dataSourcePoolProps = mock(DataSourcePoolProperties.class, RETURNS_DEEP_STUBS);
when(dataSourcePoolProps.getConnectionPropertySynonyms().getStandardProperties()).thenReturn(Collections.emptyMap());
storageUnits.put("ds_0", new StorageUnit("foo_db", mock(StorageNode.class, RETURNS_DEEP_STUBS), dataSourcePoolProps, new MockedDataSource()));
storageUnits.put("ds_1", new StorageUnit("foo_db", mock(StorageNode.class, RETURNS_DEEP_STUBS), dataSourcePoolProps, new MockedDataSource()));
storageUnits.put("ds_0", new StorageUnit("foo_db", new StorageNode(mock(StorageNodeName.class), "jdbc:mock://127.0.0.1/ds_0"), dataSourcePoolProps, new MockedDataSource()));
storageUnits.put("ds_1", new StorageUnit("foo_db", new StorageNode(mock(StorageNodeName.class), "jdbc:mock://127.0.0.1/ds_1"), dataSourcePoolProps, new MockedDataSource()));
ResourceMetaData result = mock(ResourceMetaData.class, RETURNS_DEEP_STUBS);
when(result.getStorageUnits()).thenReturn(storageUnits);
return result;
Expand All @@ -120,8 +121,8 @@ private ResourceMetaData createAddResourceMetaData() {
Map<String, StorageUnit> storageUnits = new HashMap<>(2, 1F);
DataSourcePoolProperties dataSourcePoolProps = mock(DataSourcePoolProperties.class, RETURNS_DEEP_STUBS);
when(dataSourcePoolProps.getConnectionPropertySynonyms().getStandardProperties()).thenReturn(Collections.emptyMap());
storageUnits.put("ds_0", new StorageUnit("foo_db", mock(StorageNode.class, RETURNS_DEEP_STUBS), dataSourcePoolProps, new MockedDataSource()));
storageUnits.put("ds_1", new StorageUnit("foo_db", mock(StorageNode.class, RETURNS_DEEP_STUBS), dataSourcePoolProps, new MockedDataSource()));
storageUnits.put("ds_0", new StorageUnit("foo_db", new StorageNode(mock(StorageNodeName.class), "jdbc:mock://127.0.0.1/ds_0"), dataSourcePoolProps, new MockedDataSource()));
storageUnits.put("ds_1", new StorageUnit("foo_db", new StorageNode(mock(StorageNodeName.class), "jdbc:mock://127.0.0.1/ds_1"), dataSourcePoolProps, new MockedDataSource()));
ResourceMetaData result = mock(ResourceMetaData.class, RETURNS_DEEP_STUBS);
when(result.getStorageUnits()).thenReturn(storageUnits);
return result;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ private ShardingSphereDatabase mockDatabase() {
DataSourcePoolProperties dataSourcePoolProps = mock(DataSourcePoolProperties.class, RETURNS_DEEP_STUBS);
when(dataSourcePoolProps.getConnectionPropertySynonyms().getStandardProperties()).thenReturn(Collections.emptyMap());
Map<String, StorageUnit> storageUnits = Collections.singletonMap("foo_ds",
new StorageUnit("foo_db", mock(StorageNode.class, RETURNS_DEEP_STUBS), dataSourcePoolProps, new MockedDataSource()));
new StorageUnit("foo_db", new StorageNode(mock(StorageNodeName.class), "jdbc:mock://127.0.0.1/foo_db"), dataSourcePoolProps, new MockedDataSource()));
when(result.getResourceMetaData().getStorageUnits()).thenReturn(storageUnits);
return result;
}
Expand Down
Loading

0 comments on commit 00e2c08

Please sign in to comment.