Skip to content

Commit

Permalink
Refactor ParameterRewritersBuilder (#33617)
Browse files Browse the repository at this point in the history
  • Loading branch information
terrymanu authored Nov 11, 2024
1 parent adb7f79 commit 96443c5
Show file tree
Hide file tree
Showing 7 changed files with 35 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import org.apache.shardingsphere.infra.rewrite.context.SQLRewriteContext;
import org.apache.shardingsphere.infra.rewrite.context.SQLRewriteContextDecorator;
import org.apache.shardingsphere.infra.rewrite.parameter.rewriter.ParameterRewriter;
import org.apache.shardingsphere.infra.rewrite.parameter.rewriter.ParameterRewritersBuilder;
import org.apache.shardingsphere.infra.rewrite.sql.token.common.generator.builder.SQLTokenGeneratorBuilder;
import org.apache.shardingsphere.infra.route.context.RouteContext;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.column.ColumnSegment;
Expand All @@ -55,7 +56,7 @@ public void decorate(final EncryptRule rule, final ConfigurationProperties props
Collection<EncryptCondition> encryptConditions = createEncryptConditions(rule, sqlRewriteContext);
String databaseName = sqlRewriteContext.getDatabase().getName();
if (!sqlRewriteContext.getParameters().isEmpty()) {
Collection<ParameterRewriter> parameterRewriters = new EncryptParameterRewritersRegistry(rule, databaseName, sqlStatementContext, encryptConditions).getParameterRewriters();
Collection<ParameterRewriter> parameterRewriters = new ParameterRewritersBuilder(sqlStatementContext).build(new EncryptParameterRewritersRegistry(rule, databaseName, encryptConditions));
rewriteParameters(sqlRewriteContext, parameterRewriters);
}
SQLTokenGeneratorBuilder sqlTokenGeneratorBuilder = new EncryptTokenGenerateBuilder(rule, sqlStatementContext, encryptConditions, databaseName);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,10 @@
import org.apache.shardingsphere.encrypt.rewrite.parameter.rewriter.EncryptInsertValueParameterRewriter;
import org.apache.shardingsphere.encrypt.rewrite.parameter.rewriter.EncryptPredicateParameterRewriter;
import org.apache.shardingsphere.encrypt.rule.EncryptRule;
import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
import org.apache.shardingsphere.infra.rewrite.parameter.rewriter.ParameterRewriter;
import org.apache.shardingsphere.infra.rewrite.parameter.rewriter.ParameterRewritersRegistry;
import org.apache.shardingsphere.infra.rewrite.parameter.rewriter.ParameterRewritersBuilder;

import java.util.Arrays;
import java.util.Collection;

/**
Expand All @@ -42,13 +41,11 @@ public final class EncryptParameterRewritersRegistry implements ParameterRewrite

private final String databaseName;

private final SQLStatementContext sqlStatementContext;

private final Collection<EncryptCondition> encryptConditions;

@Override
public Collection<ParameterRewriter> getParameterRewriters() {
return new ParameterRewritersBuilder().build(sqlStatementContext,
return Arrays.asList(
new EncryptAssignmentParameterRewriter(rule, databaseName),
new EncryptPredicateParameterRewriter(rule, databaseName, encryptConditions),
new EncryptInsertPredicateParameterRewriter(rule, databaseName, encryptConditions),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,20 @@

package org.apache.shardingsphere.encrypt.rewrite.parameter;

import org.apache.shardingsphere.encrypt.rewrite.parameter.rewriter.EncryptAssignmentParameterRewriter;
import org.apache.shardingsphere.encrypt.rewrite.parameter.rewriter.EncryptInsertOnDuplicateKeyUpdateValueParameterRewriter;
import org.apache.shardingsphere.encrypt.rewrite.parameter.rewriter.EncryptInsertPredicateParameterRewriter;
import org.apache.shardingsphere.encrypt.rewrite.parameter.rewriter.EncryptInsertValueParameterRewriter;
import org.apache.shardingsphere.encrypt.rewrite.parameter.rewriter.EncryptPredicateParameterRewriter;
import org.apache.shardingsphere.encrypt.rule.EncryptRule;
import org.apache.shardingsphere.infra.binder.context.statement.dml.SelectStatementContext;
import org.apache.shardingsphere.infra.database.core.DefaultDatabase;
import org.apache.shardingsphere.infra.rewrite.parameter.rewriter.ParameterRewriter;
import org.junit.jupiter.api.Test;

import java.util.Collection;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

import static org.hamcrest.CoreMatchers.instanceOf;
import static org.hamcrest.CoreMatchers.is;
Expand All @@ -42,8 +47,12 @@ void assertGetParameterRewriters() {
when(rule.findEncryptTable("foo_tbl").isPresent()).thenReturn(true);
SelectStatementContext sqlStatementContext = mock(SelectStatementContext.class, RETURNS_DEEP_STUBS);
when(sqlStatementContext.getTablesContext().getTableNames()).thenReturn(Collections.singleton("foo_tbl"));
Collection<ParameterRewriter> actual = new EncryptParameterRewritersRegistry(rule, DefaultDatabase.LOGIC_NAME, sqlStatementContext, Collections.emptyList()).getParameterRewriters();
assertThat(actual.size(), is(1));
assertThat(actual.iterator().next(), instanceOf(EncryptPredicateParameterRewriter.class));
List<ParameterRewriter> actual = new ArrayList<>(new EncryptParameterRewritersRegistry(rule, DefaultDatabase.LOGIC_NAME, Collections.emptyList()).getParameterRewriters());
assertThat(actual.size(), is(5));
assertThat(actual.get(0), instanceOf(EncryptAssignmentParameterRewriter.class));
assertThat(actual.get(1), instanceOf(EncryptPredicateParameterRewriter.class));
assertThat(actual.get(2), instanceOf(EncryptInsertPredicateParameterRewriter.class));
assertThat(actual.get(3), instanceOf(EncryptInsertValueParameterRewriter.class));
assertThat(actual.get(4), instanceOf(EncryptInsertOnDuplicateKeyUpdateValueParameterRewriter.class));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.apache.shardingsphere.infra.rewrite.context.SQLRewriteContext;
import org.apache.shardingsphere.infra.rewrite.context.SQLRewriteContextDecorator;
import org.apache.shardingsphere.infra.rewrite.parameter.rewriter.ParameterRewriter;
import org.apache.shardingsphere.infra.rewrite.parameter.rewriter.ParameterRewritersBuilder;
import org.apache.shardingsphere.infra.route.context.RouteContext;
import org.apache.shardingsphere.sharding.constant.ShardingOrder;
import org.apache.shardingsphere.sharding.rewrite.parameter.ShardingParameterRewritersRegistry;
Expand All @@ -51,7 +52,7 @@ public void decorate(final ShardingRule rule, final ConfigurationProperties prop
return;
}
if (!sqlRewriteContext.getParameters().isEmpty()) {
Collection<ParameterRewriter> parameterRewriters = new ShardingParameterRewritersRegistry(routeContext, sqlStatementContext).getParameterRewriters();
Collection<ParameterRewriter> parameterRewriters = new ParameterRewritersBuilder(sqlStatementContext).build(new ShardingParameterRewritersRegistry(routeContext));
rewriteParameters(sqlRewriteContext, parameterRewriters);
}
sqlRewriteContext.addSQLTokenGenerators(new ShardingTokenGenerateBuilder(rule, routeContext, sqlStatementContext).getSQLTokenGenerators());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,13 @@
package org.apache.shardingsphere.sharding.rewrite.parameter;

import lombok.RequiredArgsConstructor;
import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
import org.apache.shardingsphere.infra.rewrite.parameter.rewriter.ParameterRewriter;
import org.apache.shardingsphere.infra.rewrite.parameter.rewriter.ParameterRewritersRegistry;
import org.apache.shardingsphere.infra.rewrite.parameter.rewriter.ParameterRewritersBuilder;
import org.apache.shardingsphere.infra.rewrite.parameter.rewriter.keygen.GeneratedKeyInsertValueParameterRewriter;
import org.apache.shardingsphere.infra.route.context.RouteContext;
import org.apache.shardingsphere.sharding.rewrite.parameter.impl.ShardingPaginationParameterRewriter;

import java.util.Arrays;
import java.util.Collection;

/**
Expand All @@ -36,10 +35,8 @@ public final class ShardingParameterRewritersRegistry implements ParameterRewrit

private final RouteContext routeContext;

private final SQLStatementContext sqlStatementContext;

@Override
public Collection<ParameterRewriter> getParameterRewriters() {
return new ParameterRewritersBuilder().build(sqlStatementContext, new GeneratedKeyInsertValueParameterRewriter(), new ShardingPaginationParameterRewriter(routeContext));
return Arrays.asList(new GeneratedKeyInsertValueParameterRewriter(), new ShardingPaginationParameterRewriter(routeContext));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,13 @@

import org.apache.shardingsphere.infra.binder.context.statement.dml.SelectStatementContext;
import org.apache.shardingsphere.infra.rewrite.parameter.rewriter.ParameterRewriter;
import org.apache.shardingsphere.infra.rewrite.parameter.rewriter.keygen.GeneratedKeyInsertValueParameterRewriter;
import org.apache.shardingsphere.infra.route.context.RouteContext;
import org.apache.shardingsphere.sharding.rewrite.parameter.impl.ShardingPaginationParameterRewriter;
import org.junit.jupiter.api.Test;

import java.util.Collection;
import java.util.ArrayList;
import java.util.List;

import static org.hamcrest.CoreMatchers.instanceOf;
import static org.hamcrest.CoreMatchers.is;
Expand All @@ -38,8 +40,9 @@ class ShardingParameterRewritersRegistryTest {
void assertGetParameterRewriters() {
SelectStatementContext statementContext = mock(SelectStatementContext.class, RETURNS_DEEP_STUBS);
when(statementContext.getPaginationContext().isHasPagination()).thenReturn(true);
Collection<ParameterRewriter> actual = new ShardingParameterRewritersRegistry(mock(RouteContext.class), statementContext).getParameterRewriters();
assertThat(actual.size(), is(1));
assertThat(actual.iterator().next(), instanceOf(ShardingPaginationParameterRewriter.class));
List<ParameterRewriter> actual = new ArrayList<>(new ShardingParameterRewritersRegistry(mock(RouteContext.class)).getParameterRewriters());
assertThat(actual.size(), is(2));
assertThat(actual.get(0), instanceOf(GeneratedKeyInsertValueParameterRewriter.class));
assertThat(actual.get(1), instanceOf(ShardingPaginationParameterRewriter.class));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.shardingsphere.infra.rewrite.parameter.rewriter;

import lombok.RequiredArgsConstructor;
import org.apache.shardingsphere.infra.annotation.HighFrequencyInvocation;
import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;

Expand All @@ -27,18 +28,20 @@
* Parameter rewriters builder.
*/
@HighFrequencyInvocation
@RequiredArgsConstructor
public final class ParameterRewritersBuilder {

private final SQLStatementContext sqlStatementContext;

/**
* Build parameter rewriters.
*
* @param sqlStatementContext SQL statement context
* @param rewriters parameter rewriters
* @param registry parameter rewriters registry
* @return built parameter rewriters
*/
public Collection<ParameterRewriter> build(final SQLStatementContext sqlStatementContext, final ParameterRewriter... rewriters) {
public Collection<ParameterRewriter> build(final ParameterRewritersRegistry registry) {
Collection<ParameterRewriter> result = new LinkedList<>();
for (ParameterRewriter each : rewriters) {
for (ParameterRewriter each : registry.getParameterRewriters()) {
if (each.isNeedRewrite(sqlStatementContext)) {
result.add(each);
}
Expand Down

0 comments on commit 96443c5

Please sign in to comment.