diff --git a/spring-boot-test-dbunit-integration-tests/src/test/kotlin/io/camassia/spring/dbunit/DatabaseTesterTest.kt b/spring-boot-test-dbunit-integration-tests/src/test/kotlin/io/camassia/spring/dbunit/DatabaseTesterTest.kt
index f8d8b46..8fcfbb2 100644
--- a/spring-boot-test-dbunit-integration-tests/src/test/kotlin/io/camassia/spring/dbunit/DatabaseTesterTest.kt
+++ b/spring-boot-test-dbunit-integration-tests/src/test/kotlin/io/camassia/spring/dbunit/DatabaseTesterTest.kt
@@ -39,6 +39,15 @@ class DatabaseTesterTest @Autowired constructor(
inner class GivenDataSet {
@Nested
inner class WithFilenames {
+ @Test
+ fun `should handle database clear`() {
+ dbunit.givenDataSet(DatabaseTesterTest::class, "/Empty123.xml")
+
+ assertThat(selectAllFrom("demo1")).isEmpty()
+ assertThat(selectAllFrom("demo2")).isEmpty()
+ assertThat(selectAllFrom("demo3")).isEmpty()
+ }
+
@Test
fun `should handle files with one row in multiple tables`() {
dbunit.givenDataSet(DatabaseTesterTest::class, "/Demo-WithOneRowInMultipleTables.xml")
@@ -107,6 +116,16 @@ class DatabaseTesterTest @Autowired constructor(
assertThat(result[0].component2()).isEqualTo("default")
}
+ @Test
+ fun `should use defaults for empty row`() {
+ dbunit.givenDataSet(DatabaseTesterTest::class, "/TemplatedDemo3.xml")
+
+ val result = selectAllFrom("demo3")
+ assertThat(result).hasSize(1)
+ assertThat(result[0].component1()).isEqualTo(3)
+ assertThat(result[0].component2()).isEqualTo("default")
+ }
+
@Test
fun `should replace special file params`() {
dbunit.givenDataSet(DatabaseTesterTest::class, "/Demo-WithFile.xml")
@@ -120,6 +139,18 @@ class DatabaseTesterTest @Autowired constructor(
@Nested
inner class WithTemplatedFiles {
+ @Test
+ fun `should handle database clear`() {
+ dbunit.givenDataSet(
+ DatabaseTesterTest::class,
+ File("/Empty123.xml")
+ )
+
+ assertThat(selectAllFrom("demo1")).isEmpty()
+ assertThat(selectAllFrom("demo2")).isEmpty()
+ assertThat(selectAllFrom("demo3")).isEmpty()
+ }
+
@Test
fun `should handle files with one row in multiple tables`() {
dbunit.givenDataSet(
@@ -230,6 +261,19 @@ class DatabaseTesterTest @Autowired constructor(
assertThat(result[0].component2()).isEqualTo("default")
}
+ @Test
+ fun `should use defaults for empty row`() {
+ dbunit.givenDataSet(
+ DatabaseTesterTest::class,
+ File("/TemplatedDemo3.xml")
+ )
+
+ val result = selectAllFrom("demo3")
+ assertThat(result).hasSize(1)
+ assertThat(result[0].component1()).isEqualTo(3)
+ assertThat(result[0].component2()).isEqualTo("default")
+ }
+
@Test
fun `should use defaults for missing overrides if available`() {
dbunit.givenDataSet(
@@ -280,6 +324,19 @@ class DatabaseTesterTest @Autowired constructor(
@Nested
inner class WithProgrammaticDataSet {
+ @Test
+ fun `should handle database clear`() {
+ dbunit.givenDataSet(
+ Table("demo1", emptyList()),
+ Table("demo2", emptyList()),
+ Table("demo3", emptyList())
+ )
+
+ assertThat(selectAllFrom("demo1")).isEmpty()
+ assertThat(selectAllFrom("demo2")).isEmpty()
+ assertThat(selectAllFrom("demo3")).isEmpty()
+ }
+
@Test
fun `should handle datasets with one row in multiple tables`() {
dbunit.givenDataSet(
@@ -337,6 +394,21 @@ class DatabaseTesterTest @Autowired constructor(
assertThat(result[0].component2()).isEqualTo("default")
}
+ @Test
+ fun `should use defaults for empty row`() {
+ dbunit.givenDataSet(
+ Table(
+ "demo3",
+ Row(emptyList())
+ )
+ )
+
+ val result = selectAllFrom("demo3")
+ assertThat(result).hasSize(1)
+ assertThat(result[0].component1()).isEqualTo(3)
+ assertThat(result[0].component2()).isEqualTo("default")
+ }
+
@Test
fun `should handle datasets with nulls`() {
dbunit.givenDataSet(
@@ -469,10 +541,11 @@ class DatabaseTesterTest @Autowired constructor(
fun `should build a dataset from the current db state`() {
insertInto("demo1", 1, "name1")
insertInto("demo2", 2, "name2")
+ insertInto("demo3", 3, "name3")
val dataSet = dbunit.createDataset()
- assertThat(dataSet.tableNames).containsExactlyInAnyOrder("DEMO1", "DEMO2")
+ assertThat(dataSet.tableNames).containsExactlyInAnyOrder("DEMO1", "DEMO2", "DEMO3")
val demo1 = dataSet.getTable("demo1")
assertThat(demo1.rowCount).isOne
assertThat(demo1.getValue(0, "id") as BigInteger).isEqualTo(1)
@@ -481,6 +554,10 @@ class DatabaseTesterTest @Autowired constructor(
assertThat(demo2.rowCount).isOne
assertThat(demo2.getValue(0, "id") as BigInteger).isEqualTo(2)
assertThat(demo2.getValue(0, "name")).isEqualTo("name2")
+ val demo3 = dataSet.getTable("demo3")
+ assertThat(demo3.rowCount).isOne
+ assertThat(demo3.getValue(0, "id") as BigInteger).isEqualTo(3)
+ assertThat(demo3.getValue(0, "name")).isEqualTo("name3")
}
@Test
@@ -544,6 +621,9 @@ class DatabaseTesterTest @Autowired constructor(
fun connectionSupplier(ds: DataSource) = DataSourceConnectionSupplier(ds)
@Bean
- fun tableDefault() = TableDefaults("demo1", Cell("name", "default"))
+ fun tableDefault1() = TableDefaults("demo1", Cell("name", "default"))
+
+ @Bean
+ fun tableDefault3() = TableDefaults("demo3", Cell("id", "3"), Cell("name", "default"))
}
-}
\ No newline at end of file
+}
diff --git a/spring-boot-test-dbunit-integration-tests/src/test/kotlin/io/camassia/spring/dbunit/RepositoryTest.kt b/spring-boot-test-dbunit-integration-tests/src/test/kotlin/io/camassia/spring/dbunit/RepositoryTest.kt
index dd5c8b1..c5b960f 100644
--- a/spring-boot-test-dbunit-integration-tests/src/test/kotlin/io/camassia/spring/dbunit/RepositoryTest.kt
+++ b/spring-boot-test-dbunit-integration-tests/src/test/kotlin/io/camassia/spring/dbunit/RepositoryTest.kt
@@ -11,6 +11,7 @@ import org.springframework.test.context.jdbc.SqlGroup
statements = [
"CREATE TABLE demo1 (id BIGINT NOT NULL, name VARCHAR(50), CONSTRAINT demo1_pk PRIMARY KEY (id))",
"CREATE TABLE demo2 (id BIGINT NOT NULL, name VARCHAR(50), CONSTRAINT demo2_pk PRIMARY KEY (id))",
+ "CREATE TABLE demo3 (id BIGINT NOT NULL, name VARCHAR(50), CONSTRAINT demo3_pk PRIMARY KEY (id))",
],
executionPhase = Sql.ExecutionPhase.BEFORE_TEST_METHOD
),
@@ -18,6 +19,7 @@ import org.springframework.test.context.jdbc.SqlGroup
statements = [
"DROP TABLE demo1",
"DROP TABLE demo2",
+ "DROP TABLE demo3",
],
executionPhase = Sql.ExecutionPhase.AFTER_TEST_METHOD
)
@@ -40,4 +42,4 @@ abstract class RepositoryTest {
}
}
-}
\ No newline at end of file
+}
diff --git a/spring-boot-test-dbunit-integration-tests/src/test/resources/Empty123.xml b/spring-boot-test-dbunit-integration-tests/src/test/resources/Empty123.xml
new file mode 100644
index 0000000..5c48908
--- /dev/null
+++ b/spring-boot-test-dbunit-integration-tests/src/test/resources/Empty123.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
diff --git a/spring-boot-test-dbunit-integration-tests/src/test/resources/TemplatedDemo3.xml b/spring-boot-test-dbunit-integration-tests/src/test/resources/TemplatedDemo3.xml
new file mode 100644
index 0000000..65a0af7
--- /dev/null
+++ b/spring-boot-test-dbunit-integration-tests/src/test/resources/TemplatedDemo3.xml
@@ -0,0 +1,4 @@
+
+
+
+
diff --git a/spring-boot-test-dbunit/src/main/kotlin/io/camassia/spring/dbunit/api/dataset/builder/TableBasedDataSetBuilder.kt b/spring-boot-test-dbunit/src/main/kotlin/io/camassia/spring/dbunit/api/dataset/builder/TableBasedDataSetBuilder.kt
index c6e2061..446c762 100644
--- a/spring-boot-test-dbunit/src/main/kotlin/io/camassia/spring/dbunit/api/dataset/builder/TableBasedDataSetBuilder.kt
+++ b/spring-boot-test-dbunit/src/main/kotlin/io/camassia/spring/dbunit/api/dataset/builder/TableBasedDataSetBuilder.kt
@@ -48,16 +48,18 @@ internal class TableBasedDataSetBuilder(
while (tables.next()) {
val table = tables.table
val tableName = table.tableMetaData.tableName
- val columnsFromDataSet = table.tableMetaData.columns.toSet()
- val columnsFromDefaults = this.extensions.defaults.forTable(tableName).map { Column(it.key, DataType.UNKNOWN) }
- val columns: Array = if(columnsFromDataSet.isNotEmpty()) (columnsFromDefaults + columnsFromDataSet).distinctBy { keyOf(it.columnName) }.toTypedArray() else emptyArray()
val primaryKeys = table.tableMetaData.primaryKeys
- consumer.startTable(DefaultTableMetaData(tableName, columns, primaryKeys))
-
if (table.rowCount == 0) {
- consumer.row(arrayOfNulls(columns.size))
+ consumer.startTable(DefaultTableMetaData(tableName, emptyArray(), primaryKeys))
+ consumer.row(emptyArray())
} else {
+ val columnsFromDataSet = table.tableMetaData.columns.toSet()
+ val columnsFromDefaults = this.extensions.defaults.forTable(tableName).map { Column(it.key, DataType.UNKNOWN) }
+ val columns: Array = (columnsFromDefaults + columnsFromDataSet).distinctBy { keyOf(it.columnName) }.toTypedArray()
+
+ consumer.startTable(DefaultTableMetaData(tableName, columns, primaryKeys))
+
(0 until table.rowCount).forEach { rowIndex ->
val cells: Array = columns.map { column ->
val rawCellValue: Any? = columnsFromDataSet.find { it.columnName.equals(column.columnName, ignoreCase) }?.let { table.getValue(rowIndex, it.columnName) }
@@ -91,29 +93,33 @@ internal class TableBasedDataSetBuilder(
tables.groupBy { keyOf(it.name) }.mapValues { (_, tables) ->
tables.flatMap { it.rows }
}.forEach { (tableName, rows) ->
- val defaults: Set = this.extensions.defaults.forTable(tableName)
- val columnsFromDataSet = rows.flatMap { it.cells }.map { Column(it.key, DataType.UNKNOWN) }
- val columnsFromDefaults = this.extensions.defaults.forTable(tableName).map { Column(it.key, DataType.UNKNOWN) }
- val columns: Array = if(columnsFromDataSet.isNotEmpty()) (columnsFromDefaults + columnsFromDataSet).distinctBy { keyOf(it.columnName) }.toTypedArray() else emptyArray()
-
- consumer.startTable(DefaultTableMetaData(tableName, columns))
-
- if (rows.isEmpty()) consumer.row(arrayOfNulls(columns.size))
- else rows.forEach { row ->
- val cells: Array = columns
- .map { column: Column ->
- val rawCellValue: Any? = row.cells.find { it.key.equals(column.columnName, ignoreCase) }?.value
- val fallback = if(rawCellValue == null) defaults.find { it.key.equals(column.columnName, ignoreCase) }?.value else null
- // Process all extensions and fallback to a default value if the cell is not set
- extensions
- .applyToCell(tableName, Cell(column.columnName, rawCellValue), overrides)
- .mapValue { value ->
- value ?: fallback
- }
- }
- .map { it.value }
- .toTypedArray()
- consumer.row(cells)
+ if (rows.isEmpty()) {
+ consumer.startTable(DefaultTableMetaData(tableName, emptyArray()))
+ consumer.row(emptyArray())
+ } else {
+ val defaults: Set = this.extensions.defaults.forTable(tableName)
+ val columnsFromDataSet = rows.flatMap { it.cells }.map { Column(it.key, DataType.UNKNOWN) }
+ val columnsFromDefaults = this.extensions.defaults.forTable(tableName).map { Column(it.key, DataType.UNKNOWN) }
+ val columns: Array = (columnsFromDefaults + columnsFromDataSet).distinctBy { keyOf(it.columnName) }.toTypedArray()
+
+ consumer.startTable(DefaultTableMetaData(tableName, columns))
+
+ rows.forEach { row ->
+ val cells: Array = columns
+ .map { column: Column ->
+ val rawCellValue: Any? = row.cells.find { it.key.equals(column.columnName, ignoreCase) }?.value
+ val fallback = if(rawCellValue == null) defaults.find { it.key.equals(column.columnName, ignoreCase) }?.value else null
+ // Process all extensions and fallback to a default value if the cell is not set
+ extensions
+ .applyToCell(tableName, Cell(column.columnName, rawCellValue), overrides)
+ .mapValue { value ->
+ value ?: fallback
+ }
+ }
+ .map { it.value }
+ .toTypedArray()
+ consumer.row(cells)
+ }
}
consumer.endTable()
@@ -122,4 +128,4 @@ internal class TableBasedDataSetBuilder(
consumer.endDataSet()
return DecoratedDataSet(dataSet)
}
-}
\ No newline at end of file
+}
| |