diff --git a/src/main/java/com/salesforce/dataloader/dao/database/DatabaseContext.java b/src/main/java/com/salesforce/dataloader/dao/database/DatabaseContext.java index c2685eb5..78ee56a6 100644 --- a/src/main/java/com/salesforce/dataloader/dao/database/DatabaseContext.java +++ b/src/main/java/com/salesforce/dataloader/dao/database/DatabaseContext.java @@ -39,6 +39,8 @@ import com.salesforce.dataloader.dyna.DateTimeConverter; import com.salesforce.dataloader.exception.DataAccessObjectInitializationException; import com.salesforce.dataloader.exception.ParameterLoadException; +import com.salesforce.dataloader.model.Row; +import com.salesforce.dataloader.model.RowInterface; /** * Describe your class here. @@ -148,18 +150,18 @@ public void replaceSqlParams(String sqlString) { * Values for the parameter replacement * @throws ParameterLoadException */ - public void setSqlParamValues(SqlConfig sqlConfig, AppConfig appConfig, Map paramValues) + public void setSqlParamValues(SqlConfig sqlConfig, AppConfig appConfig, RowInterface paramValues) throws ParameterLoadException { // detect if there're no parameters to set if (sqlConfig.getSqlParams() == null) { return; } if (paramValues == null) { - paramValues = new HashMap(); + paramValues = new Row(); } for (String paramName : sqlConfig.getSqlParams().keySet()) { String type = sqlConfig.getSqlParams().get(paramName); - if (paramValues.containsKey(paramName)) { + if (paramValues.getColumnNames().contains(paramName)) { Object sqlValue = mapParamToDbType(appConfig, paramValues.get(paramName), type); paramValues.put(paramName, sqlValue); } else { diff --git a/src/main/java/com/salesforce/dataloader/dao/database/DatabaseReader.java b/src/main/java/com/salesforce/dataloader/dao/database/DatabaseReader.java index 0a0e7e15..9808d6e3 100644 --- a/src/main/java/com/salesforce/dataloader/dao/database/DatabaseReader.java +++ b/src/main/java/com/salesforce/dataloader/dao/database/DatabaseReader.java @@ -29,6 +29,7 @@ import java.sql.*; import java.util.*; +import com.salesforce.dataloader.model.TableHeader; import com.salesforce.dataloader.model.TableRow; import org.apache.commons.dbcp2.BasicDataSource; @@ -124,8 +125,15 @@ private void setupQuery(Map params) throws DataAccessObjectInitia PreparedStatement statement = dbContext.prepareStatement(); // right now, query doesn't support data input -- all the parameters are static vs. update which takes data // for every put call + TableRow row = null; + if (params != null) { + Set colHeaderNames = params.keySet(); + @SuppressWarnings("unchecked") + TableHeader header = new TableHeader((List)(Object)Arrays.asList(colHeaderNames.toArray())); + row = new TableRow(header); + } dbContext.setSqlParamValues(sqlConfig, - this.getAppConfig(), params); + this.getAppConfig(), row); // set the query fetch size int fetchSize; diff --git a/src/main/java/com/salesforce/dataloader/dao/database/DatabaseWriter.java b/src/main/java/com/salesforce/dataloader/dao/database/DatabaseWriter.java index 3ab67c86..aab1f1ea 100644 --- a/src/main/java/com/salesforce/dataloader/dao/database/DatabaseWriter.java +++ b/src/main/java/com/salesforce/dataloader/dao/database/DatabaseWriter.java @@ -29,9 +29,7 @@ import java.sql.*; import java.util.*; -import com.salesforce.dataloader.model.Row; import com.salesforce.dataloader.model.RowInterface; -import com.salesforce.dataloader.model.TableRow; import org.apache.commons.dbcp2.BasicDataSource; import org.apache.logging.log4j.Logger; import com.salesforce.dataloader.util.DLLogManager; @@ -139,26 +137,13 @@ public boolean writeRowList(List inputRowList) throws Da try { //for batchsize = 1, don't do batching, this provides much better error output if(inputRowList.size() == 1) { - if (inputRowList.get(0) instanceof Row) { - dbContext.setSqlParamValues(sqlConfig, appConfig, (Row)inputRowList.get(0)); - } else if (inputRowList.get(0) == null) { - dbContext.setSqlParamValues(sqlConfig, appConfig, null); - } else { - dbContext.setSqlParamValues(sqlConfig, appConfig, - ((TableRow)inputRowList.get(0)).convertToRow()); - } + dbContext.setSqlParamValues(sqlConfig, appConfig, inputRowList.get(0)); currentRowNumber++; } else { // for each row set the Sql params in the prepared statement dbContext.getDataStatement().clearBatch(); for (RowInterface inputRow : inputRowList) { - if (inputRow instanceof Row) { - dbContext.setSqlParamValues(sqlConfig, appConfig, (Row)inputRow); - } else if (inputRow == null) { - dbContext.setSqlParamValues(sqlConfig, appConfig, null); - } else { - dbContext.setSqlParamValues(sqlConfig, appConfig, ((TableRow)inputRow).convertToRow()); - } + dbContext.setSqlParamValues(sqlConfig, appConfig, inputRow); dbContext.getDataStatement().addBatch(); currentRowNumber++; } @@ -296,9 +281,4 @@ public void setColumnNames(List columnNames) { // TODO: Ordered column names can possibly used for ordered output from the write. Currently, this is not used // since writeRow will contain column information anyway and order doesn't matter in database } - - public List getColumnNamesFromRow(Row row) throws DataAccessObjectInitializationException { - return null; - } - } diff --git a/src/test/java/com/salesforce/dataloader/process/DatabaseProcessTest.java b/src/test/java/com/salesforce/dataloader/process/DatabaseProcessTest.java index 3973f57b..013fdea3 100644 --- a/src/test/java/com/salesforce/dataloader/process/DatabaseProcessTest.java +++ b/src/test/java/com/salesforce/dataloader/process/DatabaseProcessTest.java @@ -73,7 +73,8 @@ public static Collection getTestParameters() { TestVariant.forSettings(TestSetting.BULK_API_ENABLED), TestVariant.forSettings(TestSetting.BULK_API_ENABLED, TestSetting.BULK_API_CACHE_DAO_UPLOAD_ENABLED), TestVariant.forSettings(TestSetting.BULK_API_ENABLED, TestSetting.BULK_API_ZIP_CONTENT_ENABLED), - TestVariant.forSettings(TestSetting.BULK_API_ENABLED, TestSetting.BULK_API_SERIAL_MODE_ENABLED) + TestVariant.forSettings(TestSetting.BULK_API_ENABLED, TestSetting.BULK_API_SERIAL_MODE_ENABLED), + TestVariant.forSettings(TestSetting.BULK_API_ENABLED, TestSetting.BULK_V2_API_ENABLED) ); }