Skip to content

Commit

Permalink
Add RuleMetaData for SQLTranslator interface
Browse files Browse the repository at this point in the history
  • Loading branch information
strongduanmu committed Oct 16, 2023
1 parent dd5257d commit bc7c13f
Show file tree
Hide file tree
Showing 12 changed files with 43 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,10 @@ public SQLRewriteResult rewrite(final String sql, final List<Object> params, fin
SQLTranslatorRule rule = globalRuleMetaData.getSingleRule(SQLTranslatorRule.class);
DatabaseType protocolType = database.getProtocolType();
Map<String, StorageUnit> storageUnits = database.getResourceMetaData().getStorageUnits();
RuleMetaData ruleMetaData = database.getRuleMetaData();
return routeContext.getRouteUnits().isEmpty()
? new GenericSQLRewriteEngine(rule, protocolType, storageUnits).rewrite(sqlRewriteContext)
: new RouteSQLRewriteEngine(rule, protocolType, storageUnits).rewrite(sqlRewriteContext, routeContext);
? new GenericSQLRewriteEngine(rule, protocolType, storageUnits, ruleMetaData).rewrite(sqlRewriteContext)
: new RouteSQLRewriteEngine(rule, protocolType, storageUnits, ruleMetaData).rewrite(sqlRewriteContext, routeContext);
}

private SQLRewriteContext createSQLRewriteContext(final String sql, final List<Object> params, final SQLStatementContext sqlStatementContext,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import lombok.RequiredArgsConstructor;
import org.apache.shardingsphere.infra.database.core.type.DatabaseType;
import org.apache.shardingsphere.infra.metadata.database.resource.unit.StorageUnit;
import org.apache.shardingsphere.infra.metadata.database.rule.RuleMetaData;
import org.apache.shardingsphere.infra.rewrite.context.SQLRewriteContext;
import org.apache.shardingsphere.infra.rewrite.engine.result.GenericSQLRewriteResult;
import org.apache.shardingsphere.infra.rewrite.engine.result.SQLRewriteUnit;
Expand All @@ -40,6 +41,8 @@ public final class GenericSQLRewriteEngine {

private final Map<String, StorageUnit> storageUnits;

private final RuleMetaData globalRuleMetaData;

/**
* Rewrite SQL and parameters.
*
Expand All @@ -49,7 +52,7 @@ public final class GenericSQLRewriteEngine {
public GenericSQLRewriteResult rewrite(final SQLRewriteContext sqlRewriteContext) {
String sql = translatorRule.translate(
new DefaultSQLBuilder(sqlRewriteContext).toSQL(), sqlRewriteContext.getSqlStatementContext().getSqlStatement(), protocolType,
storageUnits.isEmpty() ? protocolType : storageUnits.values().iterator().next().getStorageType());
storageUnits.isEmpty() ? protocolType : storageUnits.values().iterator().next().getStorageType(), globalRuleMetaData);
return new GenericSQLRewriteResult(new SQLRewriteUnit(sql, sqlRewriteContext.getParameterBuilder().getParameters()));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.apache.shardingsphere.infra.database.core.type.DatabaseType;
import org.apache.shardingsphere.infra.datanode.DataNode;
import org.apache.shardingsphere.infra.metadata.database.resource.unit.StorageUnit;
import org.apache.shardingsphere.infra.metadata.database.rule.RuleMetaData;
import org.apache.shardingsphere.infra.rewrite.context.SQLRewriteContext;
import org.apache.shardingsphere.infra.rewrite.engine.result.RouteSQLRewriteResult;
import org.apache.shardingsphere.infra.rewrite.engine.result.SQLRewriteUnit;
Expand Down Expand Up @@ -56,6 +57,8 @@ public final class RouteSQLRewriteEngine {

private final Map<String, StorageUnit> storageUnits;

private final RuleMetaData globalRuleMetaData;

/**
* Rewrite SQL and parameters.
*
Expand Down Expand Up @@ -158,7 +161,7 @@ private Map<RouteUnit, SQLRewriteUnit> translate(final SQLStatement sqlStatement
Map<RouteUnit, SQLRewriteUnit> result = new LinkedHashMap<>(sqlRewriteUnits.size(), 1F);
for (Entry<RouteUnit, SQLRewriteUnit> entry : sqlRewriteUnits.entrySet()) {
DatabaseType storageType = storageUnits.get(entry.getKey().getDataSourceMapper().getActualName()).getStorageType();
String sql = translatorRule.translate(entry.getValue().getSql(), sqlStatement, protocolType, storageType);
String sql = translatorRule.translate(entry.getValue().getSql(), sqlStatement, protocolType, storageType, globalRuleMetaData);
SQLRewriteUnit sqlRewriteUnit = new SQLRewriteUnit(sql, entry.getValue().getParameters());
result.put(entry.getKey(), sqlRewriteUnit);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.apache.shardingsphere.infra.hint.HintValueContext;
import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
import org.apache.shardingsphere.infra.metadata.database.resource.unit.StorageUnit;
import org.apache.shardingsphere.infra.metadata.database.rule.RuleMetaData;
import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereSchema;
import org.apache.shardingsphere.infra.rewrite.context.SQLRewriteContext;
import org.apache.shardingsphere.infra.rewrite.engine.result.GenericSQLRewriteResult;
Expand All @@ -46,7 +47,7 @@ class GenericSQLRewriteEngineTest {
void assertRewrite() {
DatabaseType databaseType = mock(DatabaseType.class);
SQLTranslatorRule rule = new SQLTranslatorRule(new SQLTranslatorRuleConfiguration());
GenericSQLRewriteResult actual = new GenericSQLRewriteEngine(rule, databaseType, mockStorageUnits(databaseType))
GenericSQLRewriteResult actual = new GenericSQLRewriteEngine(rule, databaseType, mockStorageUnits(databaseType), mock(RuleMetaData.class))
.rewrite(new SQLRewriteContext(mockDatabase(), mock(CommonSQLStatementContext.class), "SELECT 1", Collections.emptyList(), mock(ConnectionContext.class),
new HintValueContext()));
assertThat(actual.getSqlRewriteUnit().getSql(), is("SELECT 1"));
Expand All @@ -56,7 +57,7 @@ void assertRewrite() {
@Test
void assertRewriteStorageTypeIsEmpty() {
SQLTranslatorRule rule = new SQLTranslatorRule(new SQLTranslatorRuleConfiguration());
GenericSQLRewriteResult actual = new GenericSQLRewriteEngine(rule, mock(DatabaseType.class), Collections.emptyMap())
GenericSQLRewriteResult actual = new GenericSQLRewriteEngine(rule, mock(DatabaseType.class), Collections.emptyMap(), mock(RuleMetaData.class))
.rewrite(new SQLRewriteContext(mockDatabase(), mock(CommonSQLStatementContext.class), "SELECT 1", Collections.emptyList(), mock(ConnectionContext.class),
new HintValueContext()));
assertThat(actual.getSqlRewriteUnit().getSql(), is("SELECT 1"));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.apache.shardingsphere.infra.hint.HintValueContext;
import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
import org.apache.shardingsphere.infra.metadata.database.resource.unit.StorageUnit;
import org.apache.shardingsphere.infra.metadata.database.rule.RuleMetaData;
import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereSchema;
import org.apache.shardingsphere.infra.rewrite.context.SQLRewriteContext;
import org.apache.shardingsphere.infra.rewrite.engine.result.RouteSQLRewriteResult;
Expand Down Expand Up @@ -60,7 +61,7 @@ void assertRewriteWithStandardParameterBuilder() {
routeContext.getRouteUnits().add(routeUnit);
DatabaseType databaseType = mock(DatabaseType.class);
RouteSQLRewriteResult actual = new RouteSQLRewriteEngine(
new SQLTranslatorRule(new SQLTranslatorRuleConfiguration()), databaseType, mockStorageUnits(databaseType)).rewrite(sqlRewriteContext, routeContext);
new SQLTranslatorRule(new SQLTranslatorRuleConfiguration()), databaseType, mockStorageUnits(databaseType), mock(RuleMetaData.class)).rewrite(sqlRewriteContext, routeContext);
assertThat(actual.getSqlRewriteUnits().size(), is(1));
assertThat(actual.getSqlRewriteUnits().get(routeUnit).getSql(), is("SELECT ?"));
assertThat(actual.getSqlRewriteUnits().get(routeUnit).getParameters(), is(Collections.singletonList(1)));
Expand All @@ -79,7 +80,7 @@ void assertRewriteWithStandardParameterBuilderWhenNeedAggregateRewrite() {
routeContext.getRouteUnits().add(secondRouteUnit);
DatabaseType databaseType = mock(DatabaseType.class);
RouteSQLRewriteResult actual = new RouteSQLRewriteEngine(
new SQLTranslatorRule(new SQLTranslatorRuleConfiguration()), databaseType, mockStorageUnits(databaseType)).rewrite(sqlRewriteContext, routeContext);
new SQLTranslatorRule(new SQLTranslatorRuleConfiguration()), databaseType, mockStorageUnits(databaseType), mock(RuleMetaData.class)).rewrite(sqlRewriteContext, routeContext);
assertThat(actual.getSqlRewriteUnits().size(), is(1));
assertThat(actual.getSqlRewriteUnits().get(firstRouteUnit).getSql(), is("SELECT ? UNION ALL SELECT ?"));
assertThat(actual.getSqlRewriteUnits().get(firstRouteUnit).getParameters(), is(Arrays.asList(1, 1)));
Expand All @@ -99,7 +100,7 @@ void assertRewriteWithGroupedParameterBuilderForBroadcast() {
routeContext.getRouteUnits().add(routeUnit);
DatabaseType databaseType = mock(DatabaseType.class);
RouteSQLRewriteResult actual = new RouteSQLRewriteEngine(
new SQLTranslatorRule(new SQLTranslatorRuleConfiguration()), databaseType, mockStorageUnits(databaseType)).rewrite(sqlRewriteContext, routeContext);
new SQLTranslatorRule(new SQLTranslatorRuleConfiguration()), databaseType, mockStorageUnits(databaseType), mock(RuleMetaData.class)).rewrite(sqlRewriteContext, routeContext);
assertThat(actual.getSqlRewriteUnits().size(), is(1));
assertThat(actual.getSqlRewriteUnits().get(routeUnit).getSql(), is("INSERT INTO tbl VALUES (?)"));
assertThat(actual.getSqlRewriteUnits().get(routeUnit).getParameters(), is(Collections.singletonList(1)));
Expand All @@ -121,7 +122,7 @@ void assertRewriteWithGroupedParameterBuilderForRouteWithSameDataNode() {
routeContext.getOriginalDataNodes().add(Collections.singletonList(new DataNode("ds.tbl_0")));
DatabaseType databaseType = mock(DatabaseType.class);
RouteSQLRewriteResult actual = new RouteSQLRewriteEngine(
new SQLTranslatorRule(new SQLTranslatorRuleConfiguration()), databaseType, mockStorageUnits(databaseType)).rewrite(sqlRewriteContext, routeContext);
new SQLTranslatorRule(new SQLTranslatorRuleConfiguration()), databaseType, mockStorageUnits(databaseType), mock(RuleMetaData.class)).rewrite(sqlRewriteContext, routeContext);
assertThat(actual.getSqlRewriteUnits().size(), is(1));
assertThat(actual.getSqlRewriteUnits().get(routeUnit).getSql(), is("INSERT INTO tbl VALUES (?)"));
assertThat(actual.getSqlRewriteUnits().get(routeUnit).getParameters(), is(Collections.singletonList(1)));
Expand All @@ -142,7 +143,7 @@ void assertRewriteWithGroupedParameterBuilderForRouteWithEmptyDataNode() {
routeContext.getOriginalDataNodes().add(Collections.emptyList());
DatabaseType databaseType = mock(DatabaseType.class);
RouteSQLRewriteResult actual = new RouteSQLRewriteEngine(
new SQLTranslatorRule(new SQLTranslatorRuleConfiguration()), databaseType, mockStorageUnits(databaseType)).rewrite(sqlRewriteContext, routeContext);
new SQLTranslatorRule(new SQLTranslatorRuleConfiguration()), databaseType, mockStorageUnits(databaseType), mock(RuleMetaData.class)).rewrite(sqlRewriteContext, routeContext);
assertThat(actual.getSqlRewriteUnits().size(), is(1));
assertThat(actual.getSqlRewriteUnits().get(routeUnit).getSql(), is("INSERT INTO tbl VALUES (?)"));
assertThat(actual.getSqlRewriteUnits().get(routeUnit).getParameters(), is(Collections.singletonList(1)));
Expand All @@ -163,7 +164,7 @@ void assertRewriteWithGroupedParameterBuilderForRouteWithNotSameDataNode() {
routeContext.getOriginalDataNodes().add(Collections.singletonList(new DataNode("ds_1.tbl_1")));
DatabaseType databaseType = mock(DatabaseType.class);
RouteSQLRewriteResult actual = new RouteSQLRewriteEngine(
new SQLTranslatorRule(new SQLTranslatorRuleConfiguration()), databaseType, mockStorageUnits(databaseType)).rewrite(sqlRewriteContext, routeContext);
new SQLTranslatorRule(new SQLTranslatorRuleConfiguration()), databaseType, mockStorageUnits(databaseType), mock(RuleMetaData.class)).rewrite(sqlRewriteContext, routeContext);
assertThat(actual.getSqlRewriteUnits().size(), is(1));
assertThat(actual.getSqlRewriteUnits().get(routeUnit).getSql(), is("INSERT INTO tbl VALUES (?)"));
assertTrue(actual.getSqlRewriteUnits().get(routeUnit).getParameters().isEmpty());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.shardingsphere.sqltranslator.spi;

import org.apache.shardingsphere.infra.database.core.type.DatabaseType;
import org.apache.shardingsphere.infra.metadata.database.rule.RuleMetaData;
import org.apache.shardingsphere.infra.spi.annotation.SingletonSPI;
import org.apache.shardingsphere.infra.spi.type.typed.TypedSPI;
import org.apache.shardingsphere.sql.parser.sql.common.statement.SQLStatement;
Expand All @@ -35,7 +36,8 @@ public interface SQLTranslator extends TypedSPI {
* @param sqlStatement to be translated SQL statement
* @param protocolType protocol type
* @param storageType storage type
* @param globalRuleMetaData global rule meta data
* @return translated SQL
*/
String translate(String sql, SQLStatement sqlStatement, DatabaseType protocolType, DatabaseType storageType);
String translate(String sql, SQLStatement sqlStatement, DatabaseType protocolType, DatabaseType storageType, RuleMetaData globalRuleMetaData);
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import lombok.Getter;
import org.apache.shardingsphere.infra.database.core.type.DatabaseType;
import org.apache.shardingsphere.infra.metadata.database.rule.RuleMetaData;
import org.apache.shardingsphere.infra.rule.identifier.scope.GlobalRule;
import org.apache.shardingsphere.infra.spi.type.typed.TypedSPILoader;
import org.apache.shardingsphere.sql.parser.sql.common.statement.SQLStatement;
Expand Down Expand Up @@ -51,14 +52,15 @@ public SQLTranslatorRule(final SQLTranslatorRuleConfiguration ruleConfig) {
* @param sqlStatement to be translated SQL statement
* @param protocolType protocol type
* @param storageType storage type
* @param globalRuleMetaData global rule meta data
* @return translated SQL
*/
public String translate(final String sql, final SQLStatement sqlStatement, final DatabaseType protocolType, final DatabaseType storageType) {
public String translate(final String sql, final SQLStatement sqlStatement, final DatabaseType protocolType, final DatabaseType storageType, final RuleMetaData globalRuleMetaData) {
if (protocolType.equals(storageType) || null == storageType) {
return sql;
}
try {
return translator.translate(sql, sqlStatement, protocolType, storageType);
return translator.translate(sql, sqlStatement, protocolType, storageType, globalRuleMetaData);
} catch (final SQLTranslationException ex) {
if (useOriginalSQLWhenTranslatingFailed) {
return sql;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.shardingsphere.sqltranslator.rule;

import org.apache.shardingsphere.infra.database.core.type.DatabaseType;
import org.apache.shardingsphere.infra.metadata.database.rule.RuleMetaData;
import org.apache.shardingsphere.infra.spi.type.typed.TypedSPILoader;
import org.apache.shardingsphere.sqltranslator.api.config.SQLTranslatorRuleConfiguration;
import org.apache.shardingsphere.sqltranslator.exception.syntax.UnsupportedTranslatedDatabaseException;
Expand All @@ -28,46 +29,47 @@
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.Mockito.mock;

class SQLTranslatorRuleTest {

@Test
void assertTranslateWhenProtocolSameAsStorage() {
String expected = "select 1";
DatabaseType databaseType = TypedSPILoader.getService(DatabaseType.class, "PostgreSQL");
String actual = new SQLTranslatorRule(new SQLTranslatorRuleConfiguration("CONVERT_TO_UPPER_CASE", false)).translate(expected, null, databaseType, databaseType);
String actual = new SQLTranslatorRule(new SQLTranslatorRuleConfiguration("CONVERT_TO_UPPER_CASE", false)).translate(expected, null, databaseType, databaseType, mock(RuleMetaData.class));
assertThat(actual, is(expected));
}

@Test
void assertTranslateWhenNoStorage() {
String expected = "select 1";
String actual = new SQLTranslatorRule(new SQLTranslatorRuleConfiguration("CONVERT_TO_UPPER_CASE", false)).translate(
expected, null, TypedSPILoader.getService(DatabaseType.class, "PostgreSQL"), null);
expected, null, TypedSPILoader.getService(DatabaseType.class, "PostgreSQL"), null, mock(RuleMetaData.class));
assertThat(actual, is(expected));
}

@Test
void assertTranslateWithProtocolDifferentWithStorage() {
String input = "select 1";
String actual = new SQLTranslatorRule(new SQLTranslatorRuleConfiguration("CONVERT_TO_UPPER_CASE", false)).translate(
input, null, TypedSPILoader.getService(DatabaseType.class, "PostgreSQL"), TypedSPILoader.getService(DatabaseType.class, "MySQL"));
input, null, TypedSPILoader.getService(DatabaseType.class, "PostgreSQL"), TypedSPILoader.getService(DatabaseType.class, "MySQL"), mock(RuleMetaData.class));
assertThat(actual, is(input.toUpperCase(Locale.ROOT)));
}

@Test
void assertUseOriginalSQLWhenTranslatingFailed() {
String expected = "select 1";
String actual = new SQLTranslatorRule(new SQLTranslatorRuleConfiguration("ALWAYS_FAILED", true)).translate(expected, null,
TypedSPILoader.getService(DatabaseType.class, "PostgreSQL"), TypedSPILoader.getService(DatabaseType.class, "MySQL"));
TypedSPILoader.getService(DatabaseType.class, "PostgreSQL"), TypedSPILoader.getService(DatabaseType.class, "MySQL"), mock(RuleMetaData.class));
assertThat(actual, is(expected));
}

@Test
void assertNotUseOriginalSQLWhenTranslatingFailed() {
assertThrows(UnsupportedTranslatedDatabaseException.class,
() -> new SQLTranslatorRule(new SQLTranslatorRuleConfiguration("ALWAYS_FAILED", false)).translate("", null,
TypedSPILoader.getService(DatabaseType.class, "PostgreSQL"), TypedSPILoader.getService(DatabaseType.class, "MySQL")));
TypedSPILoader.getService(DatabaseType.class, "PostgreSQL"), TypedSPILoader.getService(DatabaseType.class, "MySQL"), mock(RuleMetaData.class)));
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,15 @@
package org.apache.shardingsphere.sqltranslator.rule.fixture;

import org.apache.shardingsphere.infra.database.core.type.DatabaseType;
import org.apache.shardingsphere.infra.metadata.database.rule.RuleMetaData;
import org.apache.shardingsphere.sql.parser.sql.common.statement.SQLStatement;
import org.apache.shardingsphere.sqltranslator.exception.syntax.UnsupportedTranslatedDatabaseException;
import org.apache.shardingsphere.sqltranslator.spi.SQLTranslator;

public final class AlwaysFailedSQLTranslator implements SQLTranslator {

@Override
public String translate(final String sql, final SQLStatement sqlStatement, final DatabaseType protocolType, final DatabaseType storageType) {
public String translate(final String sql, final SQLStatement sqlStatement, final DatabaseType protocolType, final DatabaseType storageType, final RuleMetaData globalRuleMetaData) {
throw new UnsupportedTranslatedDatabaseException(storageType);
}

Expand Down
Loading

0 comments on commit bc7c13f

Please sign in to comment.