Skip to content

Commit

Permalink
Check privilege when register or alter storage unit (apache#32172)
Browse files Browse the repository at this point in the history
  • Loading branch information
RaigorJiang authored and AbnerHuang2 committed Jul 24, 2024
1 parent c2c72f9 commit f6d5e66
Show file tree
Hide file tree
Showing 16 changed files with 100 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ void assertParseRQL() {

@Test
void assertParseRDL() {
assertParse(new RegisterStorageUnitStatement(false, Collections.emptyList()), "RDL=1");
assertParse(new RegisterStorageUnitStatement(false, Collections.emptyList(), Collections.emptySet()), "RDL=1");
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,14 @@
import java.util.stream.Collectors;

/**
* Storage units connect exception.
* Storage units validate exception.
*/
public final class StorageUnitsConnectException extends ResourceDefinitionException {
public final class StorageUnitsValidateException extends ResourceDefinitionException {

private static final long serialVersionUID = 1824912697040264268L;

public StorageUnitsConnectException(final Map<String, Exception> causes) {
super(XOpenSQLState.CONNECTION_EXCEPTION, 10, "Storage units can not connect, error messages are: %s.", causes.entrySet().stream().map(entry -> String.format(
public StorageUnitsValidateException(final Map<String, Exception> causes) {
super(XOpenSQLState.CONNECTION_EXCEPTION, 10, "Storage units validate error, messages are: %s.", causes.entrySet().stream().map(entry -> String.format(
"Storage unit name: '%s', error message is: %s", entry.getKey(), entry.getValue().getMessage())).collect(Collectors.joining(System.lineSeparator())));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@

import lombok.AccessLevel;
import lombok.NoArgsConstructor;
import org.apache.shardingsphere.infra.database.core.checker.DialectDatabaseEnvironmentChecker;
import org.apache.shardingsphere.infra.database.core.checker.PrivilegeCheckType;
import org.apache.shardingsphere.infra.database.core.spi.DatabaseTypedSPILoader;
import org.apache.shardingsphere.infra.database.core.type.DatabaseType;
import org.apache.shardingsphere.infra.database.core.type.DatabaseTypeFactory;
import org.apache.shardingsphere.infra.datasource.pool.creator.DataSourcePoolCreator;
import org.apache.shardingsphere.infra.datasource.pool.destroyer.DataSourcePoolDestroyer;
import org.apache.shardingsphere.infra.datasource.pool.props.domain.DataSourcePoolProperties;
Expand All @@ -27,9 +32,11 @@
import javax.sql.DataSource;
import java.sql.Connection;
import java.sql.SQLException;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;

/**
* Data source pool properties validator.
Expand All @@ -41,14 +48,15 @@ public final class DataSourcePoolPropertiesValidator {
* Validate data source pool properties map.
*
* @param propsMap data source pool properties map
* @param expectedPrivileges excepted privileges
* @return data source name and exception map
*/
public static Map<String, Exception> validate(final Map<String, DataSourcePoolProperties> propsMap) {
public static Map<String, Exception> validate(final Map<String, DataSourcePoolProperties> propsMap, final Collection<PrivilegeCheckType> expectedPrivileges) {
Map<String, Exception> result = new LinkedHashMap<>(propsMap.size(), 1F);
for (Entry<String, DataSourcePoolProperties> entry : propsMap.entrySet()) {
try {
validateProperties(entry.getKey(), entry.getValue());
validateConnection(entry.getKey(), entry.getValue());
validateConnection(entry.getKey(), entry.getValue(), expectedPrivileges);
} catch (final InvalidDataSourcePoolPropertiesException ex) {
result.put(entry.getKey(), ex);
}
Expand All @@ -64,11 +72,16 @@ private static void validateProperties(final String dataSourceName, final DataSo
}
}

private static void validateConnection(final String dataSourceName, final DataSourcePoolProperties props) throws InvalidDataSourcePoolPropertiesException {
private static void validateConnection(final String dataSourceName, final DataSourcePoolProperties props,
final Collection<PrivilegeCheckType> expectedPrivileges) throws InvalidDataSourcePoolPropertiesException {
DataSource dataSource = null;
try {
dataSource = DataSourcePoolCreator.create(props);
checkFailFast(dataSource);
if (expectedPrivileges.isEmpty() || expectedPrivileges.contains(PrivilegeCheckType.NONE)) {
checkFailFast(dataSource);
return;
}
checkPrivileges(dataSource, props, expectedPrivileges);
// CHECKSTYLE:OFF
} catch (final SQLException | RuntimeException ex) {
// CHECKSTYLE:ON
Expand All @@ -87,4 +100,14 @@ private static void checkFailFast(final DataSource dataSource) throws SQLExcepti
// CHECKSTYLE:ON
}
}

private static void checkPrivileges(final DataSource dataSource, final DataSourcePoolProperties props, final Collection<PrivilegeCheckType> expectedPrivileges) {
DatabaseType databaseType = DatabaseTypeFactory.get((String) props.getConnectionPropertySynonyms().getStandardProperties().get("url"));
Optional<DialectDatabaseEnvironmentChecker> checker = DatabaseTypedSPILoader.findService(DialectDatabaseEnvironmentChecker.class, databaseType);
if (checker.isPresent()) {
for (PrivilegeCheckType each : expectedPrivileges) {
checker.get().checkPrivilege(dataSource, each);
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,6 @@ static void setUp() throws ClassNotFoundException {
@Test
void assertValidate() {
assertTrue(DataSourcePoolPropertiesValidator.validate(
Collections.singletonMap("name", new DataSourcePoolProperties(HikariDataSource.class.getName(), Collections.singletonMap("jdbcUrl", "jdbc:mock")))).isEmpty());
Collections.singletonMap("name", new DataSourcePoolProperties(HikariDataSource.class.getName(), Collections.singletonMap("jdbcUrl", "jdbc:mock"))), Collections.emptySet()).isEmpty());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,5 @@
*/
public enum PrivilegeCheckType {

PIPELINE, SELECT, XA
NONE, PIPELINE, SELECT, XA
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.apache.shardingsphere.distsql.segment.URLBasedDataSourceSegment;
import org.apache.shardingsphere.distsql.segment.converter.DataSourceSegmentsConverter;
import org.apache.shardingsphere.distsql.statement.rdl.resource.unit.type.AlterStorageUnitStatement;
import org.apache.shardingsphere.infra.database.core.checker.PrivilegeCheckType;
import org.apache.shardingsphere.infra.database.core.connector.ConnectionProperties;
import org.apache.shardingsphere.infra.database.core.connector.url.JdbcUrl;
import org.apache.shardingsphere.infra.database.core.connector.url.StandardJdbcUrlParser;
Expand All @@ -35,8 +36,8 @@
import org.apache.shardingsphere.infra.exception.core.external.ShardingSphereExternalException;
import org.apache.shardingsphere.infra.exception.kernel.metadata.resource.storageunit.AlterStorageUnitConnectionInfoException;
import org.apache.shardingsphere.infra.exception.kernel.metadata.resource.storageunit.DuplicateStorageUnitException;
import org.apache.shardingsphere.infra.exception.kernel.metadata.resource.storageunit.StorageUnitsOperateException;
import org.apache.shardingsphere.infra.exception.kernel.metadata.resource.storageunit.MissingRequiredStorageUnitsException;
import org.apache.shardingsphere.infra.exception.kernel.metadata.resource.storageunit.StorageUnitsOperateException;
import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
import org.apache.shardingsphere.infra.metadata.database.resource.unit.StorageUnit;
import org.apache.shardingsphere.mode.manager.ContextManager;
Expand Down Expand Up @@ -64,7 +65,7 @@ public final class AlterStorageUnitExecutor implements DistSQLUpdateExecutor<Alt
public void executeUpdate(final AlterStorageUnitStatement sqlStatement, final ContextManager contextManager) {
checkBefore(sqlStatement);
Map<String, DataSourcePoolProperties> propsMap = DataSourceSegmentsConverter.convert(database.getProtocolType(), sqlStatement.getStorageUnits());
validateHandler.validate(propsMap);
validateHandler.validate(propsMap, sqlStatement.getExpectedPrivileges().stream().map(each -> PrivilegeCheckType.valueOf(each.toUpperCase())).collect(Collectors.toSet()));
try {
MetaDataContexts originalMetaDataContexts = contextManager.getMetaDataContexts();
contextManager.getPersistServiceFacade().getMetaDataManagerPersistService().alterStorageUnits(database.getName(), propsMap);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.apache.shardingsphere.distsql.segment.DataSourceSegment;
import org.apache.shardingsphere.distsql.segment.converter.DataSourceSegmentsConverter;
import org.apache.shardingsphere.distsql.statement.rdl.resource.unit.type.RegisterStorageUnitStatement;
import org.apache.shardingsphere.infra.database.core.checker.PrivilegeCheckType;
import org.apache.shardingsphere.infra.datasource.pool.props.domain.DataSourcePoolProperties;
import org.apache.shardingsphere.infra.exception.core.ShardingSpherePreconditions;
import org.apache.shardingsphere.infra.exception.core.external.ShardingSphereExternalException;
Expand Down Expand Up @@ -66,7 +67,7 @@ public void executeUpdate(final RegisterStorageUnitStatement sqlStatement, final
if (propsMap.isEmpty()) {
return;
}
validateHandler.validate(propsMap);
validateHandler.validate(propsMap, sqlStatement.getExpectedPrivileges().stream().map(each -> PrivilegeCheckType.valueOf(each.toUpperCase())).collect(Collectors.toSet()));
try {
MetaDataContexts originalMetaDataContexts = contextManager.getMetaDataContexts();
contextManager.getPersistServiceFacade().getMetaDataManagerPersistService().registerStorageUnits(database.getName(), propsMap);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,14 @@

package org.apache.shardingsphere.distsql.handler.validate;

import org.apache.shardingsphere.infra.database.core.checker.PrivilegeCheckType;
import org.apache.shardingsphere.infra.datasource.pool.props.domain.DataSourcePoolProperties;
import org.apache.shardingsphere.infra.datasource.pool.props.validator.DataSourcePoolPropertiesValidator;
import org.apache.shardingsphere.infra.exception.core.ShardingSpherePreconditions;
import org.apache.shardingsphere.infra.exception.kernel.metadata.resource.storageunit.StorageUnitsConnectException;
import org.apache.shardingsphere.infra.exception.kernel.metadata.resource.storageunit.StorageUnitsValidateException;

import java.util.Collection;
import java.util.Collections;
import java.util.Map;

/**
Expand All @@ -35,7 +38,17 @@ public final class DistSQLDataSourcePoolPropertiesValidator {
* @param propsMap data source pool properties map
*/
public void validate(final Map<String, DataSourcePoolProperties> propsMap) {
Map<String, Exception> exceptions = DataSourcePoolPropertiesValidator.validate(propsMap);
ShardingSpherePreconditions.checkMustEmpty(exceptions, () -> new StorageUnitsConnectException(exceptions));
validate(propsMap, Collections.emptySet());
}

/**
* Validate data source properties map.
*
* @param propsMap data source pool properties map
* @param expectedPrivileges expected privileges
*/
public void validate(final Map<String, DataSourcePoolProperties> propsMap, final Collection<PrivilegeCheckType> expectedPrivileges) {
Map<String, Exception> exceptions = DataSourcePoolPropertiesValidator.validate(propsMap, expectedPrivileges);
ShardingSpherePreconditions.checkMustEmpty(exceptions, () -> new StorageUnitsValidateException(exceptions));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -103,13 +103,14 @@ private ContextManager mockContextManager(final MetaDataContexts metaDataContext
}

private AlterStorageUnitStatement createAlterStorageUnitStatement(final String resourceName) {
return new AlterStorageUnitStatement(Collections.singleton(new URLBasedDataSourceSegment(resourceName, "jdbc:mysql://127.0.0.1:3306/ds_0", "root", "", new Properties())));
return new AlterStorageUnitStatement(Collections.singleton(new URLBasedDataSourceSegment(resourceName, "jdbc:mysql://127.0.0.1:3306/ds_0", "root", "", new Properties())),
Collections.emptySet());
}

private AlterStorageUnitStatement createAlterStorageUnitStatementWithDuplicateStorageUnitNames() {
return new AlterStorageUnitStatement(Arrays.asList(
new HostnameAndPortBasedDataSourceSegment("ds_0", "127.0.0.1", "3306", "ds_0", "root", "", new Properties()),
new URLBasedDataSourceSegment("ds_0", "jdbc:mysql://127.0.0.1:3306/ds_1", "root", "", new Properties())));
new URLBasedDataSourceSegment("ds_0", "jdbc:mysql://127.0.0.1:3306/ds_1", "root", "", new Properties())), Collections.emptySet());
}

private ConnectionProperties mockConnectionProperties(final String catalog) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,12 +91,13 @@ void assertExecuteUpdateWithDuplicateStorageUnitNamesWithDataSourceContainedRule
}

private RegisterStorageUnitStatement createRegisterStorageUnitStatement() {
return new RegisterStorageUnitStatement(false, Collections.singleton(new URLBasedDataSourceSegment("ds_0", "jdbc:mysql://127.0.0.1:3306/test0", "root", "", new Properties())));
return new RegisterStorageUnitStatement(false, Collections.singleton(new URLBasedDataSourceSegment("ds_0", "jdbc:mysql://127.0.0.1:3306/test0", "root", "", new Properties())),
Collections.emptySet());
}

private RegisterStorageUnitStatement createRegisterStorageUnitStatementWithDuplicateStorageUnitNames() {
return new RegisterStorageUnitStatement(false, Arrays.asList(
new HostnameAndPortBasedDataSourceSegment("ds_0", "127.0.0.1", "3306", "ds_0", "root", "", new Properties()),
new URLBasedDataSourceSegment("ds_0", "jdbc:mysql://127.0.0.1:3306/ds_1", "root", "", new Properties())));
new URLBasedDataSourceSegment("ds_0", "jdbc:mysql://127.0.0.1:3306/ds_1", "root", "", new Properties())), Collections.emptySet());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,26 +19,23 @@

import org.apache.shardingsphere.distsql.statement.rdl.resource.unit.type.RegisterStorageUnitStatement;
import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.table.SimpleTableSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.table.TableNameSegment;
import org.apache.shardingsphere.sql.parser.statement.core.statement.dml.SelectStatement;
import org.apache.shardingsphere.sql.parser.statement.core.value.identifier.IdentifierValue;
import org.apache.shardingsphere.sql.parser.statement.mysql.ddl.MySQLCreateTableStatement;
import org.apache.shardingsphere.sql.parser.statement.mysql.dml.MySQLInsertStatement;
import org.apache.shardingsphere.sql.parser.statement.mysql.dml.MySQLSelectStatement;
import org.junit.jupiter.api.Test;

import java.util.LinkedList;

import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.Mockito.mock;

class AutoCommitUtilsTest {

@Test
void assertNeedOpenTransactionForSelectStatement() {
SelectStatement selectStatement = new MySQLSelectStatement();
assertFalse(AutoCommitUtils.needOpenTransaction(selectStatement));
selectStatement.setFrom(new SimpleTableSegment(new TableNameSegment(0, 5, new IdentifierValue("foo"))));
selectStatement.setFrom(mock(SimpleTableSegment.class));
assertTrue(AutoCommitUtils.needOpenTransaction(selectStatement));
}

Expand All @@ -50,6 +47,6 @@ void assertNeedOpenTransactionForDDLOrDMLStatement() {

@Test
void assertNeedOpenTransactionForOtherStatement() {
assertFalse(AutoCommitUtils.needOpenTransaction(new RegisterStorageUnitStatement(false, new LinkedList<>())));
assertFalse(AutoCommitUtils.needOpenTransaction(mock(RegisterStorageUnitStatement.class)));
}
}
4 changes: 4 additions & 0 deletions parser/distsql/engine/src/main/antlr4/imports/Keyword.g4
Original file line number Diff line number Diff line change
Expand Up @@ -346,3 +346,7 @@ ALGORITHM
FORCE
: F O R C E
;

CHECK_PRIVILEGES
: C H E C K UL_ P R I V I L E G E S
;
16 changes: 14 additions & 2 deletions parser/distsql/engine/src/main/antlr4/imports/RDLStatement.g4
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,21 @@ grammar RDLStatement;
import BaseRule;

registerStorageUnit
: REGISTER STORAGE UNIT ifNotExists? storageUnitDefinition (COMMA_ storageUnitDefinition)*
: REGISTER STORAGE UNIT ifNotExists? storageUnitsDefinition (COMMA_ checkPrivileges)?
;

alterStorageUnit
: ALTER STORAGE UNIT storageUnitDefinition (COMMA_ storageUnitDefinition)*
: ALTER STORAGE UNIT storageUnitsDefinition (COMMA_ checkPrivileges)?
;

unregisterStorageUnit
: UNREGISTER STORAGE UNIT ifExists? storageUnitName (COMMA_ storageUnitName)* ignoreTables?
;

storageUnitsDefinition
: storageUnitDefinition (COMMA_ storageUnitDefinition)*
;

storageUnitDefinition
: storageUnitName LP_ (simpleSource | urlSource) COMMA_ USER EQ_ user (COMMA_ PASSWORD EQ_ password)? (COMMA_ propertiesDefinition)? RP_
;
Expand Down Expand Up @@ -80,3 +84,11 @@ ifExists
ifNotExists
: IF NOT EXISTS
;

checkPrivileges
: CHECK_PRIVILEGES EQ_ privilegeType (COMMA_ privilegeType)*
;

privilegeType
: IDENTIFIER_
;
Loading

0 comments on commit f6d5e66

Please sign in to comment.