diff --git a/dataproxy-manager/src/main/java/org/secretflow/dataproxy/manager/connector/odps/OdpsDataWriter.java b/dataproxy-manager/src/main/java/org/secretflow/dataproxy/manager/connector/odps/OdpsDataWriter.java
index aeab788..2e7d312 100644
--- a/dataproxy-manager/src/main/java/org/secretflow/dataproxy/manager/connector/odps/OdpsDataWriter.java
+++ b/dataproxy-manager/src/main/java/org/secretflow/dataproxy/manager/connector/odps/OdpsDataWriter.java
@@ -46,10 +46,12 @@
 import org.secretflow.dataproxy.common.exceptions.DataproxyException;
 import org.secretflow.dataproxy.common.model.datasource.conn.OdpsConnConfig;
 import org.secretflow.dataproxy.common.model.datasource.location.OdpsTableInfo;
+import org.secretflow.dataproxy.common.utils.JsonUtils;
 import org.secretflow.dataproxy.manager.DataWriter;
 
 import java.io.IOException;
 import java.nio.charset.StandardCharsets;
+import java.util.ArrayList;
 import java.util.List;
 
 /**
@@ -72,6 +74,7 @@ public class OdpsDataWriter implements DataWriter {
 
     private boolean isPartitioned = false;
 
+    private TableSchema odpsTableSchema = null;
     private TableTunnel.UploadSession uploadSession = null;
     private RecordWriter recordWriter = null;
 
@@ -152,7 +155,8 @@ private void initOdps() throws OdpsException, IOException {
         // init odps client
         Odps odps = initOdpsClient(this.connConfig);
         // Pre-processing
-        preProcessing(odps, connConfig.getProjectName(), tableInfo.tableName(), this.convertToPartitionSpec(tableInfo.partitionSpec()));
+        PartitionSpec convertPartitionSpec = this.convertToPartitionSpec(tableInfo.partitionSpec());
+        preProcessing(odps, connConfig.getProjectName(), tableInfo.tableName(), convertPartitionSpec);
 
         // init upload session
         TableTunnel tunnel = new TableTunnel(odps);
@@ -160,7 +164,12 @@
             if (tableInfo.partitionSpec() == null || tableInfo.partitionSpec().isEmpty()) {
                 throw DataproxyException.of(DataproxyErrorCode.INVALID_PARTITION_SPEC, "partitionSpec is empty");
             }
-            PartitionSpec partitionSpec = new PartitionSpec(tableInfo.partitionSpec());
+            assert this.odpsTableSchema != null;
+            List<Column> partitionColumns = this.odpsTableSchema.getPartitionColumns();
+            PartitionSpec partitionSpec = new PartitionSpec();
+            for (Column partitionColumn : partitionColumns) {
+                partitionSpec.set(partitionColumn.getName(), convertPartitionSpec.get(partitionColumn.getName()));
+            }
             uploadSession = tunnel.createUploadSession(connConfig.getProjectName(), tableInfo.tableName(), partitionSpec, overwrite);
         } else {
             uploadSession = tunnel.createUploadSession(connConfig.getProjectName(), tableInfo.tableName(), overwrite);
@@ -271,7 +280,10 @@ private void preProcessing(Odps odps, String projectName, String tableName, Part
         } else {
             log.info("odps table is exists, project: {}, table name: {}", projectName, tableName);
         }
-        isPartitioned = odps.tables().get(projectName, tableName).isPartitioned();
+
+        Table table = odps.tables().get(projectName, tableName);
+        isPartitioned = table.isPartitioned();
+        this.setOdpsTableSchemaIfAbsent(table.getSchema());
 
         if (isPartitioned) {
             if (partitionSpec == null || partitionSpec.isEmpty()) {
@@ -334,8 +346,29 @@ private boolean createOdpsTable(Odps odps, String projectName, String tableName,
             TableSchema tableSchema = convertToTableSchema(schema);
             if (partitionSpec != null) {
                 // Infer partitioning field type as string.
-                partitionSpec.keys().forEach(key -> tableSchema.addPartitionColumn(Column.newBuilder(key, TypeInfoFactory.STRING).build()));
+                List<Column> tableSchemaColumns = tableSchema.getColumns();
+                List<Integer> partitionColumnIndexes = new ArrayList<>();
+                ArrayList<Column> partitionColumns = new ArrayList<>();
+
+                for (String key : partitionSpec.keys()) {
+                    if (tableSchema.containsColumn(key)) {
+                        log.info("tableSchemaColumns contains partition column: {}", key);
+                        partitionColumnIndexes.add(tableSchema.getColumnIndex(key));
+                        partitionColumns.add(tableSchema.getColumn(key));
+                    } else {
+                        log.info("tableSchemaColumns not contains partition column: {}", key);
+                        partitionColumns.add(Column.newBuilder(key, TypeInfoFactory.STRING).build());
+                    }
+                }
+
+                for (int i = 0; i < partitionColumnIndexes.size(); i++) {
+                    tableSchemaColumns.remove(partitionColumnIndexes.get(i) - i);
+                }
+                log.info("tableSchemaColumns: {}, partitionColumnIndexes: {}", JsonUtils.toString(tableSchemaColumns), JsonUtils.toString(partitionColumnIndexes));
+                tableSchema.setColumns(tableSchemaColumns);
+                tableSchema.setPartitionColumns(partitionColumns);
             }
+            log.info("create odps table schema: {}", JsonUtils.toString(tableSchema));
             odps.tables().create(projectName, tableName, tableSchema, "", true, null, OdpsUtil.getSqlFlag(), null);
             return true;
         } catch (Exception e) {
@@ -355,6 +388,12 @@ private boolean createOdpsPartition(Odps odps, String projectName, String tableN
         return false;
     }
 
+    private void setOdpsTableSchemaIfAbsent(TableSchema tableSchema) {
+        if (odpsTableSchema == null) {
+            this.odpsTableSchema = tableSchema;
+        }
+    }
+
     private TableSchema convertToTableSchema(Schema schema) {
         List<Column> columns = schema.getFields().stream().map(this::convertToColumn).toList();
         return TableSchema.builder().withColumns(columns).build();
diff --git a/dataproxy-manager/src/main/java/org/secretflow/dataproxy/manager/connector/rdbms/JdbcDataWriter.java b/dataproxy-manager/src/main/java/org/secretflow/dataproxy/manager/connector/rdbms/JdbcDataWriter.java
index cde59d9..5b28999 100644
--- a/dataproxy-manager/src/main/java/org/secretflow/dataproxy/manager/connector/rdbms/JdbcDataWriter.java
+++ b/dataproxy-manager/src/main/java/org/secretflow/dataproxy/manager/connector/rdbms/JdbcDataWriter.java
@@ -26,11 +26,10 @@
 import org.secretflow.dataproxy.common.model.datasource.location.JdbcLocationConfig;
 import org.secretflow.dataproxy.common.utils.JsonUtils;
 import org.secretflow.dataproxy.manager.DataWriter;
-import org.secretflow.dataproxy.manager.connector.rdbms.adaptor.JdbcParameterBinder;
 
 import java.io.IOException;
-import java.sql.*;
-import java.util.Arrays;
+import java.sql.Connection;
+import java.sql.SQLException;
 import java.util.List;
 
 /**
@@ -90,7 +89,9 @@ protected void initialize(Schema schema) {
 
         log.info("[JdbcDataWriter] preSql execute start, sql: {}", JsonUtils.toJSONString(preSqlList));
         try (Connection conn = this.jdbcAssistant.getDatabaseConn(dataSource)) {
-            executePreWorkSqls(conn, preSqlList);
+            // do nothing
+            // Avoid SQL injection issues
+            // About to Delete
         } catch (SQLException e) {
             throw DataproxyException.of(DataproxyErrorCode.JDBC_CREATE_TABLE_FAILED, e.getMessage(), e);
         }
@@ -103,61 +104,7 @@
 
     @Override
     public void write(VectorSchemaRoot root) throws IOException {
-        ensureInitialized(root.getSchema());
-
-        // 每次直接发送,不积攒
-        final int rowCount = root.getRowCount();
-        int recordCount = 0;
-
-        try (Connection conn = this.jdbcAssistant.getDatabaseConn(dataSource)) {
-            boolean finished = false;
-
-            if (this.jdbcAssistant.supportBatchInsert()) {
-                try (PreparedStatement preparedStatement = conn.prepareStatement(this.stmt)) {
-                    if (rowCount != 0) {
-                        final JdbcParameterBinder binder = JdbcParameterBinder.builder(preparedStatement, root).bindAll().build();
-                        while (binder.next()) {
-                            preparedStatement.addBatch();
-                        }
-                        int[] recordCounts = preparedStatement.executeBatch();
-                        recordCount = Arrays.stream(recordCounts).sum();
-                    }
-                    finished = true;
-                } catch (Exception e) {
-                    log.warn("[JdbcDataWriter] prepare batch write error, then dp will try to generate integral insert sql, stmt:{}", this.stmt, e);
-                }
-            }
-
-            // 不支持prepare模式,需要构造完整insert语句
-            //insert into `default`.`test_table`(`int32`,`float64`,`string`) values(?,?,?)
-            if (!finished) {
-                String insertSql = null;
-                List<JDBCType> jdbcTypes = root.getFieldVectors().stream()
-                        .map(vector -> this.jdbcAssistant.arrowTypeToJdbcType(vector.getField()))
-                        .toList();
-
-                try (Statement statement = conn.createStatement()) {
-                    // 数据逐行写入
-                    for (int row = 0; row < root.getRowCount(); row++) {
-                        String[] values = new String[root.getFieldVectors().size()];
-                        for (int col = 0; col < root.getFieldVectors().size(); col++) {
-                            values[col] = this.jdbcAssistant.serialize(jdbcTypes.get(col), root.getVector(col).getObject(row));
-                        }
-
-                        insertSql = String.format(this.stmt.replace("?", "%s"), (Object[]) values);
-                        statement.execute(insertSql);
-                    }
-                } catch (Exception e) {
-                    log.error("[JdbcDataWriter] integral insert sql error, sql:{}", insertSql, e);
-                    throw e;
-                }
-            }
-
-            log.info("[JdbcDataWriter] jdbc batch write success, record count:{}, table:{}", recordCount, this.composeTableName);
-        } catch (Exception e) {
-            log.error("[JdbcDataWriter] jdbc batch write failed, table:{}", this.composeTableName);
-            throw DataproxyException.of(DataproxyErrorCode.JDBC_INSERT_INTO_TABLE_FAILED, e);
-        }
+        throw DataproxyException.of(DataproxyErrorCode.JDBC_INSERT_INTO_TABLE_FAILED, "jdbc not support write");
     }
 
     @Override
@@ -179,15 +126,4 @@ public void close() throws Exception {
         } catch (Exception ignored) {
         }
     }
-
-    void executePreWorkSqls(Connection conn, List<String> preWorkSqls) throws SQLException {
-        for (String sql : preWorkSqls) {
-            try (Statement statement = conn.createStatement()) {
-                statement.execute(sql);
-            } catch (SQLException e) {
-                log.error("[SinkJdbcHandler] 数据转移前预先执行SQL失败:{}", sql);
-                throw e;
-            }
-        }
-    }
 }
\ No newline at end of file
diff --git a/dataproxy-server/src/main/java/org/secretflow/dataproxy/server/flight/DataproxyProducerImpl.java b/dataproxy-server/src/main/java/org/secretflow/dataproxy/server/flight/DataproxyProducerImpl.java
index 621e2d4..adac904 100644
--- a/dataproxy-server/src/main/java/org/secretflow/dataproxy/server/flight/DataproxyProducerImpl.java
+++ b/dataproxy-server/src/main/java/org/secretflow/dataproxy/server/flight/DataproxyProducerImpl.java
@@ -296,8 +296,17 @@ public void getStreamReadData(CallContext context, Ticket ticket, ServerStreamLi
         log.info("[getStreamReadData] parse command from ticket success, command:{}", JsonUtils.toJSONString(command));
         try (ArrowReader arrowReader = dataProxyService.generateArrowReader(rootAllocator, (DatasetReadCommand) command.getCommandInfo())) {
             listener.start(arrowReader.getVectorSchemaRoot());
-            while (arrowReader.loadNextBatch()) {
-                listener.putNext();
+
+            while (true) {
+                if (context.isCancelled()) {
+                    log.warn("[getStreamReadData] get stream cancelled");
+                    break;
+                }
+                if (arrowReader.loadNextBatch()) {
+                    listener.putNext();
+                } else {
+                    break;
+                }
             }
             listener.completed();
log.info("[getStreamReadData] get stream completed"); diff --git a/dataproxy_sdk/python/dataproxy/version.py b/dataproxy_sdk/python/dataproxy/version.py index b0cf93d..11cf104 100644 --- a/dataproxy_sdk/python/dataproxy/version.py +++ b/dataproxy_sdk/python/dataproxy/version.py @@ -13,4 +13,4 @@ # limitations under the License. -__version__ = "0.2.0.dev$$DATE$$" +__version__ = "0.2.0b0" diff --git a/pom.xml b/pom.xml index 2f49697..e8166ae 100644 --- a/pom.xml +++ b/pom.xml @@ -271,7 +271,7 @@ commons-io commons-io - 2.11.0 + 2.14.0 com.opencsv