Skip to content

Commit

Permalink
Refactor ShardingParameterRewriterBuilder (#33614)
Browse files Browse the repository at this point in the history
* Refactor ShardingParameterRewriterBuilder

* Refactor ShardingParameterRewriterBuilder
  • Loading branch information
terrymanu authored Nov 11, 2024
1 parent 9d9af4d commit afd2198
Show file tree
Hide file tree
Showing 5 changed files with 10 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,7 @@ public void decorate(final ShardingRule rule, final ConfigurationProperties prop
return;
}
if (!sqlRewriteContext.getParameters().isEmpty()) {
Collection<ParameterRewriter> parameterRewriters =
new ShardingParameterRewriterBuilder(routeContext, sqlRewriteContext.getDatabase().getSchemas(), sqlStatementContext).getParameterRewriters();
Collection<ParameterRewriter> parameterRewriters = new ShardingParameterRewriterBuilder(routeContext, sqlStatementContext).getParameterRewriters();
rewriteParameters(sqlRewriteContext, parameterRewriters);
}
sqlRewriteContext.addSQLTokenGenerators(new ShardingTokenGenerateBuilder(rule, routeContext, sqlStatementContext).getSQLTokenGenerators());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,14 @@

import lombok.RequiredArgsConstructor;
import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereSchema;
import org.apache.shardingsphere.infra.rewrite.parameter.rewriter.ParameterRewriter;
import org.apache.shardingsphere.infra.rewrite.parameter.rewriter.ParameterRewriterBuilder;
import org.apache.shardingsphere.infra.rewrite.parameter.rewriter.keygen.GeneratedKeyInsertValueParameterRewriter;
import org.apache.shardingsphere.infra.rewrite.sql.token.common.generator.aware.RouteContextAware;
import org.apache.shardingsphere.infra.rewrite.sql.token.common.generator.aware.SchemaMetaDataAware;
import org.apache.shardingsphere.infra.route.context.RouteContext;
import org.apache.shardingsphere.sharding.rewrite.parameter.impl.ShardingPaginationParameterRewriter;

import java.util.Collection;
import java.util.LinkedList;
import java.util.Map;

/**
* Parameter rewriter builder for sharding.
Expand All @@ -40,25 +36,17 @@ public final class ShardingParameterRewriterBuilder implements ParameterRewriter

private final RouteContext routeContext;

private final Map<String, ShardingSphereSchema> schemas;

private final SQLStatementContext sqlStatementContext;

@Override
public Collection<ParameterRewriter> getParameterRewriters() {
Collection<ParameterRewriter> result = new LinkedList<>();
addParameterRewriter(result, new GeneratedKeyInsertValueParameterRewriter());
addParameterRewriter(result, new ShardingPaginationParameterRewriter());
addParameterRewriter(result, new ShardingPaginationParameterRewriter(routeContext));
return result;
}

private void addParameterRewriter(final Collection<ParameterRewriter> paramRewriters, final ParameterRewriter toBeAddedParamRewriter) {
if (toBeAddedParamRewriter instanceof SchemaMetaDataAware) {
((SchemaMetaDataAware) toBeAddedParamRewriter).setSchemas(schemas);
}
if (toBeAddedParamRewriter instanceof RouteContextAware) {
((RouteContextAware) toBeAddedParamRewriter).setRouteContext(routeContext);
}
if (toBeAddedParamRewriter.isNeedRewrite(sqlStatementContext)) {
paramRewriters.add(toBeAddedParamRewriter);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@

package org.apache.shardingsphere.sharding.rewrite.parameter.impl;

import lombok.Setter;
import org.apache.shardingsphere.infra.rewrite.sql.token.common.generator.aware.RouteContextAware;
import lombok.RequiredArgsConstructor;
import org.apache.shardingsphere.infra.binder.context.segment.select.pagination.PaginationContext;
import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
import org.apache.shardingsphere.infra.binder.context.statement.dml.SelectStatementContext;
Expand All @@ -32,10 +31,10 @@
/**
* Sharding pagination parameter rewriter.
*/
@Setter
public final class ShardingPaginationParameterRewriter implements ParameterRewriter, RouteContextAware {
@RequiredArgsConstructor
public final class ShardingPaginationParameterRewriter implements ParameterRewriter {

private RouteContext routeContext;
private final RouteContext routeContext;

@Override
public boolean isNeedRewrite(final SQLStatementContext sqlStatementContext) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,12 @@
package org.apache.shardingsphere.sharding.rewrite.parameter;

import org.apache.shardingsphere.infra.binder.context.statement.dml.SelectStatementContext;
import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereSchema;
import org.apache.shardingsphere.infra.rewrite.parameter.rewriter.ParameterRewriter;
import org.apache.shardingsphere.infra.route.context.RouteContext;
import org.apache.shardingsphere.sharding.rewrite.parameter.impl.ShardingPaginationParameterRewriter;
import org.junit.jupiter.api.Test;

import java.util.Collection;
import java.util.Collections;

import static org.hamcrest.CoreMatchers.instanceOf;
import static org.hamcrest.CoreMatchers.is;
Expand All @@ -41,8 +39,7 @@ class ShardingParameterRewriterBuilderTest {
void assertGetParameterRewritersWhenPaginationIsNeedRewrite() {
SelectStatementContext statementContext = mock(SelectStatementContext.class, RETURNS_DEEP_STUBS);
when(statementContext.getPaginationContext().isHasPagination()).thenReturn(true);
Collection<ParameterRewriter> actual = new ShardingParameterRewriterBuilder(
mock(RouteContext.class), Collections.singletonMap("test", mock(ShardingSphereSchema.class)), statementContext).getParameterRewriters();
Collection<ParameterRewriter> actual = new ShardingParameterRewriterBuilder(mock(RouteContext.class), statementContext).getParameterRewriters();
assertThat(actual.size(), is(1));
assertThat(actual.iterator().next(), instanceOf(ShardingPaginationParameterRewriter.class));
}
Expand All @@ -53,6 +50,6 @@ void assertGetParameterRewritersWhenPaginationIsNotNeedRewrite() {
when(routeContext.isSingleRouting()).thenReturn(true);
SelectStatementContext statementContext = mock(SelectStatementContext.class, RETURNS_DEEP_STUBS);
when(statementContext.getPaginationContext().isHasPagination()).thenReturn(true);
assertTrue(new ShardingParameterRewriterBuilder(routeContext, Collections.singletonMap("test", mock(ShardingSphereSchema.class)), statementContext).getParameterRewriters().isEmpty());
assertTrue(new ShardingParameterRewriterBuilder(routeContext, statementContext).getParameterRewriters().isEmpty());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,8 @@ class ShardingPaginationParameterRewriterTest {

@Test
void assertIsNeedRewrite() {
ShardingPaginationParameterRewriter paramRewriter = new ShardingPaginationParameterRewriter();
RouteContext routeContext = mock(RouteContext.class);
paramRewriter.setRouteContext(routeContext);
ShardingPaginationParameterRewriter paramRewriter = new ShardingPaginationParameterRewriter(routeContext);
InsertStatementContext insertStatementContext = mock(InsertStatementContext.class);
assertFalse(paramRewriter.isNeedRewrite(insertStatementContext));
SelectStatementContext selectStatementContext = mock(SelectStatementContext.class, RETURNS_DEEP_STUBS);
Expand All @@ -81,7 +80,7 @@ void assertRewrite() {
when(pagination.getRevisedOffset()).thenReturn(TEST_REVISED_OFFSET);
when(pagination.getRevisedRowCount(selectStatementContext)).thenReturn(TEST_REVISED_ROW_COUNT);
when(selectStatementContext.getPaginationContext()).thenReturn(pagination);
new ShardingPaginationParameterRewriter().rewrite(standardParamBuilder, selectStatementContext, null);
new ShardingPaginationParameterRewriter(null).rewrite(standardParamBuilder, selectStatementContext, null);
assertTrue(addOffsetParametersFlag);
assertTrue(addRowCountParameterFlag);
}
Expand Down

0 comments on commit afd2198

Please sign in to comment.