diff --git a/proxy/frontend/type/mysql/src/main/java/org/apache/shardingsphere/proxy/frontend/mysql/command/query/binary/prepare/MySQLComStmtPrepareParameterMarkerExtractor.java b/proxy/frontend/type/mysql/src/main/java/org/apache/shardingsphere/proxy/frontend/mysql/command/query/binary/prepare/MySQLComStmtPrepareParameterMarkerExtractor.java index 2633afdb893cc..48bf68426d5cc 100644 --- a/proxy/frontend/type/mysql/src/main/java/org/apache/shardingsphere/proxy/frontend/mysql/command/query/binary/prepare/MySQLComStmtPrepareParameterMarkerExtractor.java +++ b/proxy/frontend/type/mysql/src/main/java/org/apache/shardingsphere/proxy/frontend/mysql/command/query/binary/prepare/MySQLComStmtPrepareParameterMarkerExtractor.java @@ -19,6 +19,8 @@ import lombok.AccessLevel; import lombok.NoArgsConstructor; +import org.apache.shardingsphere.infra.exception.core.ShardingSpherePreconditions; +import org.apache.shardingsphere.infra.exception.postgresql.exception.metadata.ColumnNotFoundException; import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereColumn; import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereSchema; import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereTable; @@ -32,8 +34,8 @@ import java.util.ArrayList; import java.util.Collection; import java.util.Collections; +import java.util.LinkedList; import java.util.List; -import java.util.ListIterator; import java.util.stream.Collectors; /** @@ -51,42 +53,54 @@ public final class MySQLComStmtPrepareParameterMarkerExtractor { * @return corresponding columns of parameter markers */ public static List findColumnsOfParameterMarkers(final SQLStatement sqlStatement, final ShardingSphereSchema schema) { - return sqlStatement instanceof InsertStatement ? findColumnsOfParameterMarkersForInsert((InsertStatement) sqlStatement, schema) : Collections.emptyList(); + return sqlStatement instanceof InsertStatement && ((InsertStatement) sqlStatement).getTable().isPresent() + ? findColumnsOfParameterMarkersForInsert((InsertStatement) sqlStatement, schema) + : Collections.emptyList(); } private static List findColumnsOfParameterMarkersForInsert(final InsertStatement insertStatement, final ShardingSphereSchema schema) { ShardingSphereTable table = schema.getTable(insertStatement.getTable().map(optional -> optional.getTableName().getIdentifier().getValue()).orElse("")); List columnNamesOfInsert = getColumnNamesOfInsertStatement(insertStatement, table); + List result = getParameterMarkerColumns(insertStatement, table, columnNamesOfInsert); + insertStatement.getOnDuplicateKeyColumns().ifPresent(optional -> result.addAll(getOnDuplicateKeyParameterMarkerColumns(optional.getColumns(), table))); + return result; + } + + private static List getColumnNamesOfInsertStatement(final InsertStatement insertStatement, final ShardingSphereTable table) { + return insertStatement.getColumns().isEmpty() ? table.getColumnNames() : insertStatement.getColumns().stream().map(each -> each.getIdentifier().getValue()).collect(Collectors.toList()); + } + + private static List getParameterMarkerColumns(final InsertStatement insertStatement, final ShardingSphereTable table, final List columnNamesOfInsert) { List result = new ArrayList<>(insertStatement.getParameterMarkerSegments().size()); for (InsertValuesSegment each : insertStatement.getValues()) { - ListIterator listIterator = each.getValues().listIterator(); - for (int columnIndex = listIterator.nextIndex(); listIterator.hasNext(); columnIndex = listIterator.nextIndex()) { - ExpressionSegment value = listIterator.next(); - if (!(value instanceof ParameterMarkerExpressionSegment)) { - continue; - } - String columnName = columnNamesOfInsert.get(columnIndex); - ShardingSphereColumn column = table.getColumn(columnName); - result.add(column); - } + result.addAll(getParameterMarkerColumns(table, columnNamesOfInsert, each)); } - insertStatement.getOnDuplicateKeyColumns().ifPresent(optional -> appendOnDuplicateKeyParameterMarkers(optional.getColumns(), table, result)); return result; } - private static List getColumnNamesOfInsertStatement(final InsertStatement insertStatement, final ShardingSphereTable table) { - return insertStatement.getColumns().isEmpty() ? table.getColumnNames() : insertStatement.getColumns().stream().map(each -> each.getIdentifier().getValue()).collect(Collectors.toList()); + private static List getParameterMarkerColumns(final ShardingSphereTable table, final List columnNamesOfInsert, final InsertValuesSegment segment) { + List result = new LinkedList<>(); + int index = 0; + for (ExpressionSegment each : segment.getValues()) { + if (each instanceof ParameterMarkerExpressionSegment) { + String columnName = columnNamesOfInsert.get(index); + ShardingSpherePreconditions.checkState(table.containsColumn(columnName), () -> new ColumnNotFoundException(table.getName(), columnName)); + result.add(table.getColumn(columnName)); + } + index++; + } + return result; } - private static void appendOnDuplicateKeyParameterMarkers(final Collection onDuplicateKeyColumns, - final ShardingSphereTable table, final List result) { + private static List getOnDuplicateKeyParameterMarkerColumns(final Collection onDuplicateKeyColumns, final ShardingSphereTable table) { + List result = new LinkedList<>(); for (ColumnAssignmentSegment each : onDuplicateKeyColumns) { - if (!(each.getValue() instanceof ParameterMarkerExpressionSegment)) { - continue; + if (each.getValue() instanceof ParameterMarkerExpressionSegment) { + String columnName = each.getColumns().iterator().next().getIdentifier().getValue(); + ShardingSpherePreconditions.checkState(table.containsColumn(columnName), () -> new ColumnNotFoundException(table.getName(), columnName)); + result.add(table.getColumn(columnName)); } - String columnName = each.getColumns().iterator().next().getIdentifier().getValue(); - ShardingSphereColumn column = table.getColumn(columnName); - result.add(column); } + return result; } } diff --git a/proxy/frontend/type/mysql/src/test/java/org/apache/shardingsphere/proxy/frontend/mysql/command/query/binary/prepare/MySQLComStmtPrepareParameterMarkerExtractorTest.java b/proxy/frontend/type/mysql/src/test/java/org/apache/shardingsphere/proxy/frontend/mysql/command/query/binary/prepare/MySQLComStmtPrepareParameterMarkerExtractorTest.java index ad048e5078cc4..51a9664b9d10f 100644 --- a/proxy/frontend/type/mysql/src/test/java/org/apache/shardingsphere/proxy/frontend/mysql/command/query/binary/prepare/MySQLComStmtPrepareParameterMarkerExtractorTest.java +++ b/proxy/frontend/type/mysql/src/test/java/org/apache/shardingsphere/proxy/frontend/mysql/command/query/binary/prepare/MySQLComStmtPrepareParameterMarkerExtractorTest.java @@ -37,11 +37,13 @@ class MySQLComStmtPrepareParameterMarkerExtractorTest { + private final DatabaseType databaseType = TypedSPILoader.getService(DatabaseType.class, "MySQL"); + @Test void assertFindColumnsOfParameterMarkersForInsertStatement() { - String sql = "insert into user (id, name, age) values (1, ?, ?), (?, 'bar', ?)"; - SQLStatement sqlStatement = new ShardingSphereSQLParserEngine(TypedSPILoader.getService(DatabaseType.class, "MySQL"), new CacheOption(0, 0L), new CacheOption(0, 0L)).parse(sql, false); - ShardingSphereSchema schema = prepareSchema(); + String sql = "INSERT INTO user (id, name, age) VALUES (1, ?, ?), (?, 'bar', ?)"; + SQLStatement sqlStatement = new ShardingSphereSQLParserEngine(databaseType, new CacheOption(0, 0L), new CacheOption(0, 0L)).parse(sql, false); + ShardingSphereSchema schema = createSchema(); List actual = MySQLComStmtPrepareParameterMarkerExtractor.findColumnsOfParameterMarkers(sqlStatement, schema); assertThat(actual.get(0), is(schema.getTable("user").getColumn("name"))); assertThat(actual.get(1), is(schema.getTable("user").getColumn("age"))); @@ -49,13 +51,11 @@ void assertFindColumnsOfParameterMarkersForInsertStatement() { assertThat(actual.get(3), is(schema.getTable("user").getColumn("age"))); } - private ShardingSphereSchema prepareSchema() { + private ShardingSphereSchema createSchema() { ShardingSphereTable table = new ShardingSphereTable("user", Arrays.asList( new ShardingSphereColumn("id", Types.BIGINT, true, false, false, false, true, false), new ShardingSphereColumn("name", Types.VARCHAR, false, false, false, false, false, false), new ShardingSphereColumn("age", Types.SMALLINT, false, false, false, false, true, false)), Collections.emptyList(), Collections.emptyList()); - ShardingSphereSchema result = new ShardingSphereSchema("foo_db"); - result.putTable(table); - return result; + return new ShardingSphereSchema("foo_db", Collections.singleton(table), Collections.emptyList()); } } diff --git a/proxy/frontend/type/postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/describe/PostgreSQLComDescribeExecutor.java b/proxy/frontend/type/postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/describe/PostgreSQLComDescribeExecutor.java index d6bdd6a1c3e39..0a5db3094daa6 100644 --- a/proxy/frontend/type/postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/describe/PostgreSQLComDescribeExecutor.java +++ b/proxy/frontend/type/postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/describe/PostgreSQLComDescribeExecutor.java @@ -135,9 +135,8 @@ private void describeInsertStatementByDatabaseMetaData(final PostgreSQLServerPre preparedStatement.setRowDescription(returningSegment.map(returning -> describeReturning(returning, table)).orElseGet(PostgreSQLNoDataPacket::getInstance)); int parameterMarkerIndex = 0; for (InsertValuesSegment each : insertStatement.getValues()) { - ListIterator listIterator = each.getValues().listIterator(); - for (int columnIndex = listIterator.nextIndex(); listIterator.hasNext(); columnIndex = listIterator.nextIndex()) { - ExpressionSegment value = listIterator.next(); + for (int i = 0; i < each.getValues().size(); i++) { + ExpressionSegment value = each.getValues().get(i); if (!(value instanceof ParameterMarkerExpressionSegment)) { continue; } @@ -145,10 +144,9 @@ private void describeInsertStatementByDatabaseMetaData(final PostgreSQLServerPre parameterMarkerIndex++; continue; } - String columnName = columnNamesOfInsert.get(columnIndex); - ShardingSphereColumn column = table.getColumn(columnName); - ShardingSpherePreconditions.checkNotNull(column, () -> new ColumnNotFoundException(logicTableName, columnName)); - preparedStatement.getParameterTypes().set(parameterMarkerIndex++, PostgreSQLColumnType.valueOfJDBCType(column.getDataType())); + String columnName = columnNamesOfInsert.get(i); + ShardingSpherePreconditions.checkState(table.containsColumn(columnName), () -> new ColumnNotFoundException(logicTableName, columnName)); + preparedStatement.getParameterTypes().set(parameterMarkerIndex++, PostgreSQLColumnType.valueOfJDBCType(table.getColumn(columnName).getDataType())); } } } @@ -179,8 +177,8 @@ private PostgreSQLRowDescriptionPacket describeReturning(final ReturningSegment Collection result = new LinkedList<>(); for (ProjectionSegment each : returningSegment.getProjections().getProjections()) { if (each instanceof ShorthandProjectionSegment) { - table.getAllColumns().stream().map(column -> new PostgreSQLColumnDescription(column.getName(), 0, column.getDataType(), estimateColumnLength(column.getDataType()), "")) - .forEach(result::add); + table.getAllColumns().stream() + .map(column -> new PostgreSQLColumnDescription(column.getName(), 0, column.getDataType(), estimateColumnLength(column.getDataType()), "")).forEach(result::add); } if (each instanceof ColumnProjectionSegment) { ColumnProjectionSegment segment = (ColumnProjectionSegment) each; diff --git a/proxy/frontend/type/postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/describe/PostgreSQLComDescribeExecutorTest.java b/proxy/frontend/type/postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/describe/PostgreSQLComDescribeExecutorTest.java index b79a9843b5aaf..74f57509cdd48 100644 --- a/proxy/frontend/type/postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/describe/PostgreSQLComDescribeExecutorTest.java +++ b/proxy/frontend/type/postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/describe/PostgreSQLComDescribeExecutorTest.java @@ -106,8 +106,9 @@ class PostgreSQLComDescribeExecutorTest { private static final String TABLE_NAME = "t_order"; - private static final SQLParserEngine SQL_PARSER_ENGINE = new ShardingSphereSQLParserEngine( - TypedSPILoader.getService(DatabaseType.class, "PostgreSQL"), new CacheOption(2000, 65535L), new CacheOption(128, 1024L)); + private static final DatabaseType DATABASE_TYPE = TypedSPILoader.getService(DatabaseType.class, "PostgreSQL"); + + private static final SQLParserEngine SQL_PARSER_ENGINE = new ShardingSphereSQLParserEngine(DATABASE_TYPE, new CacheOption(2000, 65535L), new CacheOption(128, 1024L)); @Mock private PortalContext portalContext; @@ -395,9 +396,9 @@ private ContextManager mockContextManager() { new ShardingSphereColumn("c", Types.CHAR, true, false, false, true, false, false), new ShardingSphereColumn("pad", Types.CHAR, true, false, false, true, false, false)); when(schema.getTable(TABLE_NAME)).thenReturn(new ShardingSphereTable(TABLE_NAME, columns, Collections.emptyList(), Collections.emptyList())); - when(result.getMetaDataContexts().getMetaData().getDatabase(DATABASE_NAME).getProtocolType()).thenReturn(TypedSPILoader.getService(DatabaseType.class, "PostgreSQL")); + when(result.getMetaDataContexts().getMetaData().getDatabase(DATABASE_NAME).getProtocolType()).thenReturn(DATABASE_TYPE); StorageUnit storageUnit = mock(StorageUnit.class, RETURNS_DEEP_STUBS); - when(storageUnit.getStorageType()).thenReturn(TypedSPILoader.getService(DatabaseType.class, "PostgreSQL")); + when(storageUnit.getStorageType()).thenReturn(DATABASE_TYPE); when(result.getMetaDataContexts().getMetaData().getDatabase(DATABASE_NAME).getResourceMetaData().getStorageUnits()).thenReturn(Collections.singletonMap("ds_0", storageUnit)); when(result.getMetaDataContexts().getMetaData().containsDatabase(DATABASE_NAME)).thenReturn(true); when(result.getMetaDataContexts().getMetaData().getDatabase(DATABASE_NAME).containsSchema("public")).thenReturn(true);