repo-sync-2024-10-23T11:09:21+0800 #12

Merged: 1 commit, merged Oct 23, 2024
@@ -46,10 +46,12 @@
import org.secretflow.dataproxy.common.exceptions.DataproxyException;
import org.secretflow.dataproxy.common.model.datasource.conn.OdpsConnConfig;
import org.secretflow.dataproxy.common.model.datasource.location.OdpsTableInfo;
import org.secretflow.dataproxy.common.utils.JsonUtils;
import org.secretflow.dataproxy.manager.DataWriter;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;

/**
@@ -72,6 +74,7 @@ public class OdpsDataWriter implements DataWriter {

private boolean isPartitioned = false;

private TableSchema odpsTableSchema = null;
private TableTunnel.UploadSession uploadSession = null;
private RecordWriter recordWriter = null;

@@ -152,15 +155,21 @@ private void initOdps() throws OdpsException, IOException {
// init odps client
Odps odps = initOdpsClient(this.connConfig);
// Pre-processing
preProcessing(odps, connConfig.getProjectName(), tableInfo.tableName(), this.convertToPartitionSpec(tableInfo.partitionSpec()));
PartitionSpec convertPartitionSpec = this.convertToPartitionSpec(tableInfo.partitionSpec());
preProcessing(odps, connConfig.getProjectName(), tableInfo.tableName(), convertPartitionSpec);
// init upload session
TableTunnel tunnel = new TableTunnel(odps);

if (isPartitioned) {
if (tableInfo.partitionSpec() == null || tableInfo.partitionSpec().isEmpty()) {
throw DataproxyException.of(DataproxyErrorCode.INVALID_PARTITION_SPEC, "partitionSpec is empty");
}
PartitionSpec partitionSpec = new PartitionSpec(tableInfo.partitionSpec());
assert this.odpsTableSchema != null;
List<Column> partitionColumns = this.odpsTableSchema.getPartitionColumns();
PartitionSpec partitionSpec = new PartitionSpec();
for (Column partitionColumn : partitionColumns) {
partitionSpec.set(partitionColumn.getName(), convertPartitionSpec.get(partitionColumn.getName()));
}
uploadSession = tunnel.createUploadSession(connConfig.getProjectName(), tableInfo.tableName(), partitionSpec, overwrite);
} else {
uploadSession = tunnel.createUploadSession(connConfig.getProjectName(), tableInfo.tableName(), overwrite);
@@ -271,7 +280,10 @@ private void preProcessing(Odps odps, String projectName, String tableName, PartitionSpec
} else {
log.info("odps table is exists, project: {}, table name: {}", projectName, tableName);
}
isPartitioned = odps.tables().get(projectName, tableName).isPartitioned();

Table table = odps.tables().get(projectName, tableName);
isPartitioned = table.isPartitioned();
this.setOdpsTableSchemaIfAbsent(table.getSchema());

if (isPartitioned) {
if (partitionSpec == null || partitionSpec.isEmpty()) {
@@ -334,8 +346,29 @@ private boolean createOdpsTable(Odps odps, String projectName, String tableName,
TableSchema tableSchema = convertToTableSchema(schema);
if (partitionSpec != null) {
// Infer partitioning field type as string.
partitionSpec.keys().forEach(key -> tableSchema.addPartitionColumn(Column.newBuilder(key, TypeInfoFactory.STRING).build()));
List<Column> tableSchemaColumns = tableSchema.getColumns();
List<Integer> partitionColumnIndexes = new ArrayList<>();
ArrayList<Column> partitionColumns = new ArrayList<>();

for (String key : partitionSpec.keys()) {
if (tableSchema.containsColumn(key)) {
log.info("tableSchemaColumns contains partition column: {}", key);
partitionColumnIndexes.add(tableSchema.getColumnIndex(key));
partitionColumns.add(tableSchema.getColumn(key));
} else {
log.info("tableSchemaColumns not contains partition column: {}", key);
partitionColumns.add(Column.newBuilder(key, TypeInfoFactory.STRING).build());
}
}

for (int i = 0; i < partitionColumnIndexes.size(); i++) {
tableSchemaColumns.remove(partitionColumnIndexes.get(i) - i);
}
log.info("tableSchemaColumns: {}, partitionColumnIndexes: {}", JsonUtils.toString(tableSchemaColumns), JsonUtils.toString(partitionColumnIndexes));
tableSchema.setColumns(tableSchemaColumns);
tableSchema.setPartitionColumns(partitionColumns);
}
log.info("create odps table schema: {}", JsonUtils.toString(tableSchema));
odps.tables().create(projectName, tableName, tableSchema, "", true, null, OdpsUtil.getSqlFlag(), null);
return true;
} catch (Exception e) {
@@ -355,6 +388,12 @@ private boolean createOdpsPartition(Odps odps, String projectName, String tableName
return false;
}

private void setOdpsTableSchemaIfAbsent(TableSchema tableSchema) {
if (odpsTableSchema == null) {
this.odpsTableSchema = tableSchema;
}
}

private TableSchema convertToTableSchema(Schema schema) {
List<Column> columns = schema.getFields().stream().map(this::convertToColumn).toList();
return TableSchema.builder().withColumns(columns).build();
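
As a reading aid, here is a standalone sketch (not code from this PR; the method name is illustrative) of the column-splitting rule the new createOdpsTable logic applies: any key named in the partition spec is moved out of the data columns, and a key absent from the converted schema is added as a STRING partition column.

// Sketch only. Assumes the com.aliyun.odps TableSchema/Column/PartitionSpec types
// and com.aliyun.odps.type.TypeInfoFactory used elsewhere in this file.
private static TableSchema splitPartitionColumns(TableSchema source, PartitionSpec partitionSpec) {
    List<Column> dataColumns = new ArrayList<>(source.getColumns());
    List<Column> partitionColumns = new ArrayList<>();
    for (String key : partitionSpec.keys()) {
        if (source.containsColumn(key)) {
            // The key is already a data column: move it over as a partition column.
            Column existing = source.getColumn(key);
            dataColumns.removeIf(column -> column.getName().equals(existing.getName()));
            partitionColumns.add(existing);
        } else {
            // Unknown key: infer the partition field type as string, as the PR does.
            partitionColumns.add(Column.newBuilder(key, TypeInfoFactory.STRING).build());
        }
    }
    TableSchema result = new TableSchema();
    result.setColumns(dataColumns);
    result.setPartitionColumns(partitionColumns);
    return result;
}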
@@ -26,11 +26,10 @@
import org.secretflow.dataproxy.common.model.datasource.location.JdbcLocationConfig;
import org.secretflow.dataproxy.common.utils.JsonUtils;
import org.secretflow.dataproxy.manager.DataWriter;
import org.secretflow.dataproxy.manager.connector.rdbms.adaptor.JdbcParameterBinder;

import java.io.IOException;
import java.sql.*;
import java.util.Arrays;
import java.sql.Connection;
import java.sql.SQLException;
import java.util.List;

/**
@@ -90,7 +89,9 @@ protected void initialize(Schema schema) {
log.info("[JdbcDataWriter] preSql execute start, sql: {}", JsonUtils.toJSONString(preSqlList));

try (Connection conn = this.jdbcAssistant.getDatabaseConn(dataSource)) {
executePreWorkSqls(conn, preSqlList);
// do nothing
// Avoid SQL injection issues
// About to Delete
} catch (SQLException e) {
throw DataproxyException.of(DataproxyErrorCode.JDBC_CREATE_TABLE_FAILED, e.getMessage(), e);
}
@@ -103,61 +104,7 @@ protected void initialize(Schema schema) {

@Override
public void write(VectorSchemaRoot root) throws IOException {
ensureInitialized(root.getSchema());

// Send each batch directly; do not accumulate
final int rowCount = root.getRowCount();
int recordCount = 0;

try (Connection conn = this.jdbcAssistant.getDatabaseConn(dataSource)) {
boolean finished = false;

if (this.jdbcAssistant.supportBatchInsert()) {
try (PreparedStatement preparedStatement = conn.prepareStatement(this.stmt)) {
if (rowCount != 0) {
final JdbcParameterBinder binder = JdbcParameterBinder.builder(preparedStatement, root).bindAll().build();
while (binder.next()) {
preparedStatement.addBatch();
}
int[] recordCounts = preparedStatement.executeBatch();
recordCount = Arrays.stream(recordCounts).sum();
}
finished = true;
} catch (Exception e) {
log.warn("[JdbcDataWriter] prepare batch write error, then dp will try to generate integral insert sql, stmt:{}", this.stmt, e);
}
}

// Prepare mode is not supported; a complete insert statement must be constructed
//insert into `default`.`test_table`(`int32`,`float64`,`string`) values(?,?,?)
if (!finished) {
String insertSql = null;
List<JDBCType> jdbcTypes = root.getFieldVectors().stream()
.map(vector -> this.jdbcAssistant.arrowTypeToJdbcType(vector.getField()))
.toList();

try (Statement statement = conn.createStatement()) {
// Write data row by row
for (int row = 0; row < root.getRowCount(); row++) {
String[] values = new String[root.getFieldVectors().size()];
for (int col = 0; col < root.getFieldVectors().size(); col++) {
values[col] = this.jdbcAssistant.serialize(jdbcTypes.get(col), root.getVector(col).getObject(row));
}

insertSql = String.format(this.stmt.replace("?", "%s"), (Object[]) values);
statement.execute(insertSql);
}
} catch (Exception e) {
log.error("[JdbcDataWriter] integral insert sql error, sql:{}", insertSql, e);
throw e;
}
}

log.info("[JdbcDataWriter] jdbc batch write success, record count:{}, table:{}", recordCount, this.composeTableName);
} catch (Exception e) {
log.error("[JdbcDataWriter] jdbc batch write failed, table:{}", this.composeTableName);
throw DataproxyException.of(DataproxyErrorCode.JDBC_INSERT_INTO_TABLE_FAILED, e);
}
throw DataproxyException.of(DataproxyErrorCode.JDBC_INSERT_INTO_TABLE_FAILED, "jdbc write is not supported");
}

@Override
@@ -179,15 +126,4 @@ public void close() throws Exception {
} catch (Exception ignored) {
}
}

void executePreWorkSqls(Connection conn, List<String> preWorkSqls) throws SQLException {
for (String sql : preWorkSqls) {
try (Statement statement = conn.createStatement()) {
statement.execute(sql);
} catch (SQLException e) {
log.error("[SinkJdbcHandler] 数据转移前预先执行SQL失败:{}", sql);
throw e;
}
}
}
}
@@ -296,8 +296,17 @@ public void getStreamReadData(CallContext context, Ticket ticket, ServerStreamListener
log.info("[getStreamReadData] parse command from ticket success, command:{}", JsonUtils.toJSONString(command));
try (ArrowReader arrowReader = dataProxyService.generateArrowReader(rootAllocator, (DatasetReadCommand) command.getCommandInfo())) {
listener.start(arrowReader.getVectorSchemaRoot());
while (arrowReader.loadNextBatch()) {
listener.putNext();

while (true) {
if (context.isCancelled()) {
log.warn("[getStreamReadData] get stream cancelled");
break;
}
if (arrowReader.loadNextBatch()) {
listener.putNext();
} else {
break;
}
}
listener.completed();
log.info("[getStreamReadData] get stream completed");
2 changes: 1 addition & 1 deletion dataproxy_sdk/python/dataproxy/version.py
@@ -13,4 +13,4 @@
# limitations under the License.


__version__ = "0.2.0.dev$$DATE$$"
__version__ = "0.2.0b0"
2 changes: 1 addition & 1 deletion pom.xml
@@ -271,7 +271,7 @@
<dependency>
<groupId>commons-io</groupId>
<artifactId>commons-io</artifactId>
<version>2.11.0</version>
<version>2.14.0</version>
</dependency>
<dependency>
<groupId>com.opencsv</groupId>