Skip to content

Commit

Permalink
Use sql bind info in EncryptInsertPredicateColumnTokenGenerator to av…
Browse files Browse the repository at this point in the history
…oid wrong column table mapping
  • Loading branch information
strongduanmu committed Dec 20, 2024
1 parent af96795 commit 111d22f
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,28 +23,21 @@
import org.apache.shardingsphere.infra.annotation.HighFrequencyInvocation;
import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
import org.apache.shardingsphere.infra.binder.context.statement.dml.InsertStatementContext;
import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereSchema;
import org.apache.shardingsphere.infra.rewrite.sql.token.common.generator.CollectionSQLTokenGenerator;
import org.apache.shardingsphere.infra.rewrite.sql.token.common.generator.aware.SchemaMetaDataAware;
import org.apache.shardingsphere.infra.rewrite.sql.token.common.pojo.SQLToken;

import java.util.Collection;
import java.util.Map;

/**
* Insert predicate column token generator for encrypt.
*/
@HighFrequencyInvocation
@RequiredArgsConstructor
@Setter
public final class EncryptInsertPredicateColumnTokenGenerator implements CollectionSQLTokenGenerator<SQLStatementContext>, SchemaMetaDataAware {
public final class EncryptInsertPredicateColumnTokenGenerator implements CollectionSQLTokenGenerator<SQLStatementContext> {

private final EncryptRule rule;

private Map<String, ShardingSphereSchema> schemas;

private ShardingSphereSchema defaultSchema;

@Override
public boolean isGenerateSQLToken(final SQLStatementContext sqlStatementContext) {
return sqlStatementContext instanceof InsertStatementContext && null != ((InsertStatementContext) sqlStatementContext).getInsertSelectContext()
Expand All @@ -54,8 +47,6 @@ public boolean isGenerateSQLToken(final SQLStatementContext sqlStatementContext)
@Override
public Collection<SQLToken> generateSQLTokens(final SQLStatementContext sqlStatementContext) {
EncryptPredicateColumnTokenGenerator generator = new EncryptPredicateColumnTokenGenerator(rule);
generator.setSchemas(schemas);
generator.setDefaultSchema(defaultSchema);
return generator.generateSQLTokens(((InsertStatementContext) sqlStatementContext).getInsertSelectContext().getSelectStatementContext());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,27 +30,23 @@
import org.apache.shardingsphere.infra.binder.context.segment.select.projection.impl.ColumnProjection;
import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
import org.apache.shardingsphere.infra.binder.context.statement.dml.SelectStatementContext;
import org.apache.shardingsphere.infra.binder.context.type.TableAvailable;
import org.apache.shardingsphere.infra.binder.context.type.WhereAvailable;
import org.apache.shardingsphere.infra.database.core.metadata.database.enums.QuoteCharacter;
import org.apache.shardingsphere.infra.database.core.type.DatabaseType;
import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereSchema;
import org.apache.shardingsphere.infra.rewrite.sql.token.common.generator.CollectionSQLTokenGenerator;
import org.apache.shardingsphere.infra.rewrite.sql.token.common.generator.aware.SchemaMetaDataAware;
import org.apache.shardingsphere.infra.rewrite.sql.token.common.pojo.SQLToken;
import org.apache.shardingsphere.infra.rewrite.sql.token.common.pojo.generic.SubstitutableColumnNameToken;
import org.apache.shardingsphere.sql.parser.statement.core.extractor.ExpressionExtractor;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.column.ColumnSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.BinaryOperationExpression;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.ExpressionSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.predicate.AndPredicate;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.predicate.WhereSegment;
import org.apache.shardingsphere.sql.parser.statement.core.extractor.ExpressionExtractor;
import org.apache.shardingsphere.sql.parser.statement.core.value.identifier.IdentifierValue;

import java.util.Collection;
import java.util.Collections;
import java.util.LinkedList;
import java.util.Map;
import java.util.Optional;

/**
Expand All @@ -59,14 +55,10 @@
@HighFrequencyInvocation
@RequiredArgsConstructor
@Setter
public final class EncryptPredicateColumnTokenGenerator implements CollectionSQLTokenGenerator<SQLStatementContext>, SchemaMetaDataAware {
public final class EncryptPredicateColumnTokenGenerator implements CollectionSQLTokenGenerator<SQLStatementContext> {

private final EncryptRule rule;

private Map<String, ShardingSphereSchema> schemas;

private ShardingSphereSchema defaultSchema;

@Override
public boolean isGenerateSQLToken(final SQLStatementContext sqlStatementContext) {
return sqlStatementContext instanceof WhereAvailable && !((WhereAvailable) sqlStatementContext).getWhereSegments().isEmpty();
Expand All @@ -77,19 +69,16 @@ public Collection<SQLToken> generateSQLTokens(final SQLStatementContext sqlState
Collection<SelectStatementContext> allSubqueryContexts = SQLStatementContextExtractor.getAllSubqueryContexts(sqlStatementContext);
Collection<WhereSegment> whereSegments = SQLStatementContextExtractor.getWhereSegments((WhereAvailable) sqlStatementContext, allSubqueryContexts);
Collection<ColumnSegment> columnSegments = SQLStatementContextExtractor.getColumnSegments((WhereAvailable) sqlStatementContext, allSubqueryContexts);
ShardingSphereSchema schema = ((TableAvailable) sqlStatementContext).getTablesContext().getSchemaName().map(schemas::get).orElseGet(() -> defaultSchema);
Map<String, String> columnExpressionTableNames = ((TableAvailable) sqlStatementContext).getTablesContext().findTableNames(columnSegments, schema);
return generateSQLTokens(columnSegments, columnExpressionTableNames, whereSegments, sqlStatementContext.getDatabaseType());
return generateSQLTokens(columnSegments, whereSegments, sqlStatementContext.getDatabaseType());
}

private Collection<SQLToken> generateSQLTokens(final Collection<ColumnSegment> columnSegments, final Map<String, String> columnExpressionTableNames,
final Collection<WhereSegment> whereSegments, final DatabaseType databaseType) {
private Collection<SQLToken> generateSQLTokens(final Collection<ColumnSegment> columnSegments, final Collection<WhereSegment> whereSegments, final DatabaseType databaseType) {
Collection<SQLToken> result = new LinkedList<>();
for (ColumnSegment each : columnSegments) {
String tableName = columnExpressionTableNames.getOrDefault(each.getExpression(), "");
Optional<EncryptTable> encryptTable = rule.findEncryptTable(tableName);
if (encryptTable.isPresent() && encryptTable.get().isEncryptColumn(each.getIdentifier().getValue())) {
result.add(buildSubstitutableColumnNameToken(encryptTable.get().getEncryptColumn(each.getIdentifier().getValue()), each, whereSegments, databaseType));
Optional<EncryptTable> encryptTable = rule.findEncryptTable(each.getColumnBoundInfo().getOriginalTable().getValue());
if (encryptTable.isPresent() && encryptTable.get().isEncryptColumn(each.getColumnBoundInfo().getOriginalColumn().getValue())) {
EncryptColumn encryptColumn = encryptTable.get().getEncryptColumn(each.getColumnBoundInfo().getOriginalColumn().getValue());
result.add(buildSubstitutableColumnNameToken(encryptColumn, each, whereSegments, databaseType));
}
}
return result;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,10 +168,14 @@ public static UpdateStatementContext createUpdateStatementContext() {
}

private static WhereSegment createWhereSegment() {
BinaryOperationExpression nameExpression = new BinaryOperationExpression(10, 24,
new ColumnSegment(10, 13, new IdentifierValue("name")), new LiteralExpressionSegment(18, 22, "LiLei"), "=", "name = 'LiLei'");
BinaryOperationExpression pwdExpression = new BinaryOperationExpression(30, 44,
new ColumnSegment(30, 32, new IdentifierValue("pwd")), new LiteralExpressionSegment(40, 45, "123456"), "=", "pwd = '123456'");
ColumnSegment nameColumnSegment = new ColumnSegment(10, 13, new IdentifierValue("name"));
nameColumnSegment.setColumnBoundInfo(
new ColumnSegmentBoundInfo(new TableSegmentBoundInfo(new IdentifierValue("foo_db"), new IdentifierValue("foo_db")), new IdentifierValue("t_user"), new IdentifierValue("name")));
BinaryOperationExpression nameExpression = new BinaryOperationExpression(10, 24, nameColumnSegment, new LiteralExpressionSegment(18, 22, "LiLei"), "=", "name = 'LiLei'");
ColumnSegment pwdColumnSegment = new ColumnSegment(30, 32, new IdentifierValue("pwd"));
pwdColumnSegment.setColumnBoundInfo(
new ColumnSegmentBoundInfo(new TableSegmentBoundInfo(new IdentifierValue("foo_db"), new IdentifierValue("foo_db")), new IdentifierValue("t_user"), new IdentifierValue("pwd")));
BinaryOperationExpression pwdExpression = new BinaryOperationExpression(30, 44, pwdColumnSegment, new LiteralExpressionSegment(40, 45, "123456"), "=", "pwd = '123456'");
return new WhereSegment(0, 0, new BinaryOperationExpression(0, 0, nameExpression, pwdExpression, "AND", "name = 'LiLei' AND pwd = '123456'"));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
import org.junit.jupiter.api.Test;

import java.util.Collection;
import java.util.Collections;

import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.MatcherAssert.assertThat;
Expand All @@ -41,13 +40,11 @@ void setup() {

@Test
void assertIsGenerateSQLToken() {
generator.setSchemas(Collections.emptyMap());
assertTrue(generator.isGenerateSQLToken(EncryptGeneratorFixtureBuilder.createUpdateStatementContext()));
}

@Test
void assertGenerateSQLTokenFromGenerateNewSQLToken() {
generator.setSchemas(Collections.emptyMap());
Collection<SQLToken> substitutableColumnNameTokens = generator.generateSQLTokens(EncryptGeneratorFixtureBuilder.createUpdateStatementContext());
assertThat(substitutableColumnNameTokens.size(), is(1));
assertThat(((SubstitutableColumnNameToken) substitutableColumnNameTokens.iterator().next()).toString(null), is("pwd_assist"));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,4 +86,12 @@
scenario-comments="Test join contains some encrypt columns in multi tables when use encrypt feature.">
<assertion expected-data-source-name="read_dataset" />
</test-case>

<test-case sql="SELECT u.* FROM t_user u WHERE u.telephone = (SELECT DISTINCT telephone FROM t_merchant WHERE telephone = ?)" scenario-types="encrypt" db-types="MySQL,PostgreSQL,openGauss">
<assertion parameters="86100000001:String" expected-data-source-name="read_dataset" />
</test-case>

<test-case sql="SELECT * FROM t_user WHERE telephone = (SELECT DISTINCT telephone FROM t_merchant WHERE telephone = ?)" scenario-types="encrypt" db-types="MySQL,PostgreSQL,openGauss">
<assertion parameters="86100000001:String" expected-data-source-name="read_dataset" />
</test-case>
</e2e-test-cases>

0 comments on commit 111d22f

Please sign in to comment.