Skip to content

Commit

Permalink
Refactor SQLRewriteEntry
Browse files Browse the repository at this point in the history
  • Loading branch information
terrymanu committed Sep 29, 2023
1 parent 6cfb07a commit 3928c88
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import org.apache.shardingsphere.infra.database.core.type.DatabaseType;
import org.apache.shardingsphere.infra.hint.HintValueContext;
import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
import org.apache.shardingsphere.infra.metadata.database.resource.unit.StorageUnit;
import org.apache.shardingsphere.infra.metadata.database.resource.unit.NewStorageUnitMetaData;
import org.apache.shardingsphere.infra.metadata.database.rule.RuleMetaData;
import org.apache.shardingsphere.infra.rewrite.context.SQLRewriteContext;
import org.apache.shardingsphere.infra.rewrite.context.SQLRewriteContextDecorator;
Expand All @@ -35,11 +35,9 @@
import org.apache.shardingsphere.infra.spi.type.ordered.OrderedSPILoader;
import org.apache.shardingsphere.sqltranslator.rule.SQLTranslatorRule;

import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.stream.Collectors;

/**
* SQL rewrite entry.
Expand Down Expand Up @@ -79,11 +77,10 @@ public SQLRewriteResult rewrite(final String sql, final List<Object> params, fin
SQLRewriteContext sqlRewriteContext = createSQLRewriteContext(sql, params, sqlStatementContext, routeContext, connectionContext, hintValueContext);
SQLTranslatorRule rule = globalRuleMetaData.getSingleRule(SQLTranslatorRule.class);
DatabaseType protocolType = database.getProtocolType();
Map<String, StorageUnit> storageUnits = database.getResourceMetaData().getStorageUnitMetaData().getMetaDataMap().entrySet().stream()
.collect(Collectors.toMap(Entry::getKey, entry -> entry.getValue().getStorageUnit(), (oldValue, currentValue) -> oldValue, LinkedHashMap::new));
Map<String, NewStorageUnitMetaData> storageUnitMetaDataMap = database.getResourceMetaData().getStorageUnitMetaData().getMetaDataMap();
return routeContext.getRouteUnits().isEmpty()
? new GenericSQLRewriteEngine(rule, protocolType, storageUnits).rewrite(sqlRewriteContext)
: new RouteSQLRewriteEngine(rule, protocolType, storageUnits).rewrite(sqlRewriteContext, routeContext);
? new GenericSQLRewriteEngine(rule, protocolType, storageUnitMetaDataMap).rewrite(sqlRewriteContext)
: new RouteSQLRewriteEngine(rule, protocolType, storageUnitMetaDataMap).rewrite(sqlRewriteContext, routeContext);
}

private SQLRewriteContext createSQLRewriteContext(final String sql, final List<Object> params, final SQLStatementContext sqlStatementContext,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

import lombok.RequiredArgsConstructor;
import org.apache.shardingsphere.infra.database.core.type.DatabaseType;
import org.apache.shardingsphere.infra.metadata.database.resource.unit.StorageUnit;
import org.apache.shardingsphere.infra.metadata.database.resource.unit.NewStorageUnitMetaData;
import org.apache.shardingsphere.infra.rewrite.context.SQLRewriteContext;
import org.apache.shardingsphere.infra.rewrite.engine.result.GenericSQLRewriteResult;
import org.apache.shardingsphere.infra.rewrite.engine.result.SQLRewriteUnit;
Expand All @@ -38,7 +38,7 @@ public final class GenericSQLRewriteEngine {

private final DatabaseType protocolType;

private final Map<String, StorageUnit> storageUnits;
private final Map<String, NewStorageUnitMetaData> storageUnitMetaDataMap;

/**
* Rewrite SQL and parameters.
Expand All @@ -49,7 +49,7 @@ public final class GenericSQLRewriteEngine {
public GenericSQLRewriteResult rewrite(final SQLRewriteContext sqlRewriteContext) {
String sql = translatorRule.translate(
new DefaultSQLBuilder(sqlRewriteContext).toSQL(), sqlRewriteContext.getSqlStatementContext().getSqlStatement(), protocolType,
storageUnits.isEmpty() ? protocolType : storageUnits.values().iterator().next().getStorageType());
storageUnitMetaDataMap.isEmpty() ? protocolType : storageUnitMetaDataMap.values().iterator().next().getStorageUnit().getStorageType());
return new GenericSQLRewriteResult(new SQLRewriteUnit(sql, sqlRewriteContext.getParameterBuilder().getParameters()));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import org.apache.shardingsphere.infra.binder.context.statement.dml.SelectStatementContext;
import org.apache.shardingsphere.infra.database.core.type.DatabaseType;
import org.apache.shardingsphere.infra.datanode.DataNode;
import org.apache.shardingsphere.infra.metadata.database.resource.unit.StorageUnit;
import org.apache.shardingsphere.infra.metadata.database.resource.unit.NewStorageUnitMetaData;
import org.apache.shardingsphere.infra.rewrite.context.SQLRewriteContext;
import org.apache.shardingsphere.infra.rewrite.engine.result.RouteSQLRewriteResult;
import org.apache.shardingsphere.infra.rewrite.engine.result.SQLRewriteUnit;
Expand Down Expand Up @@ -54,7 +54,7 @@ public final class RouteSQLRewriteEngine {

private final DatabaseType protocolType;

private final Map<String, StorageUnit> storageUnits;
private final Map<String, NewStorageUnitMetaData> storageUnitMetaDataMap;

/**
* Rewrite SQL and parameters.
Expand Down Expand Up @@ -157,7 +157,7 @@ private boolean isInSameDataNode(final Collection<DataNode> dataNodes, final Rou
private Map<RouteUnit, SQLRewriteUnit> translate(final SQLStatement sqlStatement, final Map<RouteUnit, SQLRewriteUnit> sqlRewriteUnits) {
Map<RouteUnit, SQLRewriteUnit> result = new LinkedHashMap<>(sqlRewriteUnits.size(), 1F);
for (Entry<RouteUnit, SQLRewriteUnit> entry : sqlRewriteUnits.entrySet()) {
DatabaseType storageType = storageUnits.get(entry.getKey().getDataSourceMapper().getActualName()).getStorageType();
DatabaseType storageType = storageUnitMetaDataMap.get(entry.getKey().getDataSourceMapper().getActualName()).getStorageUnit().getStorageType();
String sql = translatorRule.translate(entry.getValue().getSql(), sqlStatement, protocolType, storageType);
SQLRewriteUnit sqlRewriteUnit = new SQLRewriteUnit(sql, entry.getValue().getParameters());
result.put(entry.getKey(), sqlRewriteUnit);
Expand Down

0 comments on commit 3928c88

Please sign in to comment.