Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use sql bind info in EncryptInsertPredicateColumnTokenGenerator to avoid wrong column table mapping #34110

Merged
merged 2 commits into from
Dec 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
1. Encrypt: Fixes merge exception without encrypt rule in database - [#33708](https://github.com/apache/shardingsphere/pull/33708)
1. SQL Binder: Fixes the expression segment cannot find the outer table when binding - [#34015](https://github.com/apache/shardingsphere/pull/34015)
1. Proxy: Fixes "ALL PRIVILEGES ON `DB`.*" is not recognized during SELECT privilege verification for MySQL - [#34037](https://github.com/apache/shardingsphere/pull/34037)
1. Encrypt: Use sql bind info in EncryptInsertPredicateColumnTokenGenerator to avoid wrong column table mapping - [#34110](https://github.com/apache/shardingsphere/pull/34110)

### Change Logs

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,28 +23,21 @@
import org.apache.shardingsphere.infra.annotation.HighFrequencyInvocation;
import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
import org.apache.shardingsphere.infra.binder.context.statement.dml.InsertStatementContext;
import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereSchema;
import org.apache.shardingsphere.infra.rewrite.sql.token.common.generator.CollectionSQLTokenGenerator;
import org.apache.shardingsphere.infra.rewrite.sql.token.common.generator.aware.SchemaMetaDataAware;
import org.apache.shardingsphere.infra.rewrite.sql.token.common.pojo.SQLToken;

import java.util.Collection;
import java.util.Map;

/**
* Insert predicate column token generator for encrypt.
*/
@HighFrequencyInvocation
@RequiredArgsConstructor
@Setter
public final class EncryptInsertPredicateColumnTokenGenerator implements CollectionSQLTokenGenerator<SQLStatementContext>, SchemaMetaDataAware {
public final class EncryptInsertPredicateColumnTokenGenerator implements CollectionSQLTokenGenerator<SQLStatementContext> {

private final EncryptRule rule;

private Map<String, ShardingSphereSchema> schemas;

private ShardingSphereSchema defaultSchema;

@Override
public boolean isGenerateSQLToken(final SQLStatementContext sqlStatementContext) {
return sqlStatementContext instanceof InsertStatementContext && null != ((InsertStatementContext) sqlStatementContext).getInsertSelectContext()
Expand All @@ -54,8 +47,6 @@ public boolean isGenerateSQLToken(final SQLStatementContext sqlStatementContext)
@Override
public Collection<SQLToken> generateSQLTokens(final SQLStatementContext sqlStatementContext) {
EncryptPredicateColumnTokenGenerator generator = new EncryptPredicateColumnTokenGenerator(rule);
generator.setSchemas(schemas);
generator.setDefaultSchema(defaultSchema);
return generator.generateSQLTokens(((InsertStatementContext) sqlStatementContext).getInsertSelectContext().getSelectStatementContext());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,27 +30,23 @@
import org.apache.shardingsphere.infra.binder.context.segment.select.projection.impl.ColumnProjection;
import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
import org.apache.shardingsphere.infra.binder.context.statement.dml.SelectStatementContext;
import org.apache.shardingsphere.infra.binder.context.type.TableAvailable;
import org.apache.shardingsphere.infra.binder.context.type.WhereAvailable;
import org.apache.shardingsphere.infra.database.core.metadata.database.enums.QuoteCharacter;
import org.apache.shardingsphere.infra.database.core.type.DatabaseType;
import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereSchema;
import org.apache.shardingsphere.infra.rewrite.sql.token.common.generator.CollectionSQLTokenGenerator;
import org.apache.shardingsphere.infra.rewrite.sql.token.common.generator.aware.SchemaMetaDataAware;
import org.apache.shardingsphere.infra.rewrite.sql.token.common.pojo.SQLToken;
import org.apache.shardingsphere.infra.rewrite.sql.token.common.pojo.generic.SubstitutableColumnNameToken;
import org.apache.shardingsphere.sql.parser.statement.core.extractor.ExpressionExtractor;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.column.ColumnSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.BinaryOperationExpression;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.ExpressionSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.predicate.AndPredicate;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.predicate.WhereSegment;
import org.apache.shardingsphere.sql.parser.statement.core.extractor.ExpressionExtractor;
import org.apache.shardingsphere.sql.parser.statement.core.value.identifier.IdentifierValue;

import java.util.Collection;
import java.util.Collections;
import java.util.LinkedList;
import java.util.Map;
import java.util.Optional;

/**
Expand All @@ -59,14 +55,10 @@
@HighFrequencyInvocation
@RequiredArgsConstructor
@Setter
public final class EncryptPredicateColumnTokenGenerator implements CollectionSQLTokenGenerator<SQLStatementContext>, SchemaMetaDataAware {
public final class EncryptPredicateColumnTokenGenerator implements CollectionSQLTokenGenerator<SQLStatementContext> {

private final EncryptRule rule;

private Map<String, ShardingSphereSchema> schemas;

private ShardingSphereSchema defaultSchema;

@Override
public boolean isGenerateSQLToken(final SQLStatementContext sqlStatementContext) {
return sqlStatementContext instanceof WhereAvailable && !((WhereAvailable) sqlStatementContext).getWhereSegments().isEmpty();
Expand All @@ -77,19 +69,16 @@ public Collection<SQLToken> generateSQLTokens(final SQLStatementContext sqlState
Collection<SelectStatementContext> allSubqueryContexts = SQLStatementContextExtractor.getAllSubqueryContexts(sqlStatementContext);
Collection<WhereSegment> whereSegments = SQLStatementContextExtractor.getWhereSegments((WhereAvailable) sqlStatementContext, allSubqueryContexts);
Collection<ColumnSegment> columnSegments = SQLStatementContextExtractor.getColumnSegments((WhereAvailable) sqlStatementContext, allSubqueryContexts);
ShardingSphereSchema schema = ((TableAvailable) sqlStatementContext).getTablesContext().getSchemaName().map(schemas::get).orElseGet(() -> defaultSchema);
Map<String, String> columnExpressionTableNames = ((TableAvailable) sqlStatementContext).getTablesContext().findTableNames(columnSegments, schema);
return generateSQLTokens(columnSegments, columnExpressionTableNames, whereSegments, sqlStatementContext.getDatabaseType());
return generateSQLTokens(columnSegments, whereSegments, sqlStatementContext.getDatabaseType());
}

private Collection<SQLToken> generateSQLTokens(final Collection<ColumnSegment> columnSegments, final Map<String, String> columnExpressionTableNames,
final Collection<WhereSegment> whereSegments, final DatabaseType databaseType) {
private Collection<SQLToken> generateSQLTokens(final Collection<ColumnSegment> columnSegments, final Collection<WhereSegment> whereSegments, final DatabaseType databaseType) {
Collection<SQLToken> result = new LinkedList<>();
for (ColumnSegment each : columnSegments) {
String tableName = columnExpressionTableNames.getOrDefault(each.getExpression(), "");
Optional<EncryptTable> encryptTable = rule.findEncryptTable(tableName);
if (encryptTable.isPresent() && encryptTable.get().isEncryptColumn(each.getIdentifier().getValue())) {
result.add(buildSubstitutableColumnNameToken(encryptTable.get().getEncryptColumn(each.getIdentifier().getValue()), each, whereSegments, databaseType));
Optional<EncryptTable> encryptTable = rule.findEncryptTable(each.getColumnBoundInfo().getOriginalTable().getValue());
if (encryptTable.isPresent() && encryptTable.get().isEncryptColumn(each.getColumnBoundInfo().getOriginalColumn().getValue())) {
EncryptColumn encryptColumn = encryptTable.get().getEncryptColumn(each.getColumnBoundInfo().getOriginalColumn().getValue());
result.add(buildSubstitutableColumnNameToken(encryptColumn, each, whereSegments, databaseType));
}
}
return result;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,10 +168,14 @@ public static UpdateStatementContext createUpdateStatementContext() {
}

private static WhereSegment createWhereSegment() {
BinaryOperationExpression nameExpression = new BinaryOperationExpression(10, 24,
new ColumnSegment(10, 13, new IdentifierValue("name")), new LiteralExpressionSegment(18, 22, "LiLei"), "=", "name = 'LiLei'");
BinaryOperationExpression pwdExpression = new BinaryOperationExpression(30, 44,
new ColumnSegment(30, 32, new IdentifierValue("pwd")), new LiteralExpressionSegment(40, 45, "123456"), "=", "pwd = '123456'");
ColumnSegment nameColumnSegment = new ColumnSegment(10, 13, new IdentifierValue("name"));
nameColumnSegment.setColumnBoundInfo(
new ColumnSegmentBoundInfo(new TableSegmentBoundInfo(new IdentifierValue("foo_db"), new IdentifierValue("foo_db")), new IdentifierValue("t_user"), new IdentifierValue("name")));
BinaryOperationExpression nameExpression = new BinaryOperationExpression(10, 24, nameColumnSegment, new LiteralExpressionSegment(18, 22, "LiLei"), "=", "name = 'LiLei'");
ColumnSegment pwdColumnSegment = new ColumnSegment(30, 32, new IdentifierValue("pwd"));
pwdColumnSegment.setColumnBoundInfo(
new ColumnSegmentBoundInfo(new TableSegmentBoundInfo(new IdentifierValue("foo_db"), new IdentifierValue("foo_db")), new IdentifierValue("t_user"), new IdentifierValue("pwd")));
BinaryOperationExpression pwdExpression = new BinaryOperationExpression(30, 44, pwdColumnSegment, new LiteralExpressionSegment(40, 45, "123456"), "=", "pwd = '123456'");
return new WhereSegment(0, 0, new BinaryOperationExpression(0, 0, nameExpression, pwdExpression, "AND", "name = 'LiLei' AND pwd = '123456'"));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
import org.junit.jupiter.api.Test;

import java.util.Collection;
import java.util.Collections;

import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.MatcherAssert.assertThat;
Expand All @@ -41,13 +40,11 @@ void setup() {

@Test
void assertIsGenerateSQLToken() {
generator.setSchemas(Collections.emptyMap());
assertTrue(generator.isGenerateSQLToken(EncryptGeneratorFixtureBuilder.createUpdateStatementContext()));
}

@Test
void assertGenerateSQLTokenFromGenerateNewSQLToken() {
generator.setSchemas(Collections.emptyMap());
Collection<SQLToken> substitutableColumnNameTokens = generator.generateSQLTokens(EncryptGeneratorFixtureBuilder.createUpdateStatementContext());
assertThat(substitutableColumnNameTokens.size(), is(1));
assertThat(((SubstitutableColumnNameToken) substitutableColumnNameTokens.iterator().next()).toString(null), is("pwd_assist"));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,4 +86,12 @@
scenario-comments="Test join contains some encrypt columns in multi tables when use encrypt feature.">
<assertion expected-data-source-name="read_dataset" />
</test-case>

<test-case sql="SELECT u.* FROM t_user u WHERE u.telephone = (SELECT DISTINCT telephone FROM t_merchant WHERE telephone = ?)" scenario-types="encrypt" db-types="MySQL,PostgreSQL,openGauss">
<assertion parameters="86100000001:String" expected-data-source-name="read_dataset" />
</test-case>

<test-case sql="SELECT * FROM t_user WHERE telephone = (SELECT DISTINCT telephone FROM t_merchant WHERE telephone = ?)" scenario-types="encrypt" db-types="MySQL,PostgreSQL,openGauss">
<assertion parameters="86100000001:String" expected-data-source-name="read_dataset" />
</test-case>
</e2e-test-cases>
Loading