diff --git a/spring-boot-test-dbunit-integration-tests/src/test/kotlin/io/camassia/spring/dbunit/DatabaseTesterTest.kt b/spring-boot-test-dbunit-integration-tests/src/test/kotlin/io/camassia/spring/dbunit/DatabaseTesterTest.kt index f8d8b46..8fcfbb2 100644 --- a/spring-boot-test-dbunit-integration-tests/src/test/kotlin/io/camassia/spring/dbunit/DatabaseTesterTest.kt +++ b/spring-boot-test-dbunit-integration-tests/src/test/kotlin/io/camassia/spring/dbunit/DatabaseTesterTest.kt @@ -39,6 +39,15 @@ class DatabaseTesterTest @Autowired constructor( inner class GivenDataSet { @Nested inner class WithFilenames { + @Test + fun `should handle database clear`() { + dbunit.givenDataSet(DatabaseTesterTest::class, "/Empty123.xml") + + assertThat(selectAllFrom("demo1")).isEmpty() + assertThat(selectAllFrom("demo2")).isEmpty() + assertThat(selectAllFrom("demo3")).isEmpty() + } + @Test fun `should handle files with one row in multiple tables`() { dbunit.givenDataSet(DatabaseTesterTest::class, "/Demo-WithOneRowInMultipleTables.xml") @@ -107,6 +116,16 @@ class DatabaseTesterTest @Autowired constructor( assertThat(result[0].component2()).isEqualTo("default") } + @Test + fun `should use defaults for empty row`() { + dbunit.givenDataSet(DatabaseTesterTest::class, "/TemplatedDemo3.xml") + + val result = selectAllFrom("demo3") + assertThat(result).hasSize(1) + assertThat(result[0].component1()).isEqualTo(3) + assertThat(result[0].component2()).isEqualTo("default") + } + @Test fun `should replace special file params`() { dbunit.givenDataSet(DatabaseTesterTest::class, "/Demo-WithFile.xml") @@ -120,6 +139,18 @@ class DatabaseTesterTest @Autowired constructor( @Nested inner class WithTemplatedFiles { + @Test + fun `should handle database clear`() { + dbunit.givenDataSet( + DatabaseTesterTest::class, + File("/Empty123.xml") + ) + + assertThat(selectAllFrom("demo1")).isEmpty() + assertThat(selectAllFrom("demo2")).isEmpty() + assertThat(selectAllFrom("demo3")).isEmpty() + } + @Test fun `should handle files with one row in multiple tables`() { dbunit.givenDataSet( @@ -230,6 +261,19 @@ class DatabaseTesterTest @Autowired constructor( assertThat(result[0].component2()).isEqualTo("default") } + @Test + fun `should use defaults for empty row`() { + dbunit.givenDataSet( + DatabaseTesterTest::class, + File("/TemplatedDemo3.xml") + ) + + val result = selectAllFrom("demo3") + assertThat(result).hasSize(1) + assertThat(result[0].component1()).isEqualTo(3) + assertThat(result[0].component2()).isEqualTo("default") + } + @Test fun `should use defaults for missing overrides if available`() { dbunit.givenDataSet( @@ -280,6 +324,19 @@ class DatabaseTesterTest @Autowired constructor( @Nested inner class WithProgrammaticDataSet { + @Test + fun `should handle database clear`() { + dbunit.givenDataSet( + Table("demo1", emptyList()), + Table("demo2", emptyList()), + Table("demo3", emptyList()) + ) + + assertThat(selectAllFrom("demo1")).isEmpty() + assertThat(selectAllFrom("demo2")).isEmpty() + assertThat(selectAllFrom("demo3")).isEmpty() + } + @Test fun `should handle datasets with one row in multiple tables`() { dbunit.givenDataSet( @@ -337,6 +394,21 @@ class DatabaseTesterTest @Autowired constructor( assertThat(result[0].component2()).isEqualTo("default") } + @Test + fun `should use defaults for empty row`() { + dbunit.givenDataSet( + Table( + "demo3", + Row(emptyList()) + ) + ) + + val result = selectAllFrom("demo3") + assertThat(result).hasSize(1) + assertThat(result[0].component1()).isEqualTo(3) + assertThat(result[0].component2()).isEqualTo("default") + } + @Test fun `should handle datasets with nulls`() { dbunit.givenDataSet( @@ -469,10 +541,11 @@ class DatabaseTesterTest @Autowired constructor( fun `should build a dataset from the current db state`() { insertInto("demo1", 1, "name1") insertInto("demo2", 2, "name2") + insertInto("demo3", 3, "name3") val dataSet = dbunit.createDataset() - assertThat(dataSet.tableNames).containsExactlyInAnyOrder("DEMO1", "DEMO2") + assertThat(dataSet.tableNames).containsExactlyInAnyOrder("DEMO1", "DEMO2", "DEMO3") val demo1 = dataSet.getTable("demo1") assertThat(demo1.rowCount).isOne assertThat(demo1.getValue(0, "id") as BigInteger).isEqualTo(1) @@ -481,6 +554,10 @@ class DatabaseTesterTest @Autowired constructor( assertThat(demo2.rowCount).isOne assertThat(demo2.getValue(0, "id") as BigInteger).isEqualTo(2) assertThat(demo2.getValue(0, "name")).isEqualTo("name2") + val demo3 = dataSet.getTable("demo3") + assertThat(demo3.rowCount).isOne + assertThat(demo3.getValue(0, "id") as BigInteger).isEqualTo(3) + assertThat(demo3.getValue(0, "name")).isEqualTo("name3") } @Test @@ -544,6 +621,9 @@ class DatabaseTesterTest @Autowired constructor( fun connectionSupplier(ds: DataSource) = DataSourceConnectionSupplier(ds) @Bean - fun tableDefault() = TableDefaults("demo1", Cell("name", "default")) + fun tableDefault1() = TableDefaults("demo1", Cell("name", "default")) + + @Bean + fun tableDefault3() = TableDefaults("demo3", Cell("id", "3"), Cell("name", "default")) } -} \ No newline at end of file +} diff --git a/spring-boot-test-dbunit-integration-tests/src/test/kotlin/io/camassia/spring/dbunit/RepositoryTest.kt b/spring-boot-test-dbunit-integration-tests/src/test/kotlin/io/camassia/spring/dbunit/RepositoryTest.kt index dd5c8b1..c5b960f 100644 --- a/spring-boot-test-dbunit-integration-tests/src/test/kotlin/io/camassia/spring/dbunit/RepositoryTest.kt +++ b/spring-boot-test-dbunit-integration-tests/src/test/kotlin/io/camassia/spring/dbunit/RepositoryTest.kt @@ -11,6 +11,7 @@ import org.springframework.test.context.jdbc.SqlGroup statements = [ "CREATE TABLE demo1 (id BIGINT NOT NULL, name VARCHAR(50), CONSTRAINT demo1_pk PRIMARY KEY (id))", "CREATE TABLE demo2 (id BIGINT NOT NULL, name VARCHAR(50), CONSTRAINT demo2_pk PRIMARY KEY (id))", + "CREATE TABLE demo3 (id BIGINT NOT NULL, name VARCHAR(50), CONSTRAINT demo3_pk PRIMARY KEY (id))", ], executionPhase = Sql.ExecutionPhase.BEFORE_TEST_METHOD ), @@ -18,6 +19,7 @@ import org.springframework.test.context.jdbc.SqlGroup statements = [ "DROP TABLE demo1", "DROP TABLE demo2", + "DROP TABLE demo3", ], executionPhase = Sql.ExecutionPhase.AFTER_TEST_METHOD ) @@ -40,4 +42,4 @@ abstract class RepositoryTest { } } -} \ No newline at end of file +} diff --git a/spring-boot-test-dbunit-integration-tests/src/test/resources/Empty123.xml b/spring-boot-test-dbunit-integration-tests/src/test/resources/Empty123.xml new file mode 100644 index 0000000..5c48908 --- /dev/null +++ b/spring-boot-test-dbunit-integration-tests/src/test/resources/Empty123.xml @@ -0,0 +1,6 @@ + + + + + + diff --git a/spring-boot-test-dbunit-integration-tests/src/test/resources/TemplatedDemo3.xml b/spring-boot-test-dbunit-integration-tests/src/test/resources/TemplatedDemo3.xml new file mode 100644 index 0000000..65a0af7 --- /dev/null +++ b/spring-boot-test-dbunit-integration-tests/src/test/resources/TemplatedDemo3.xml @@ -0,0 +1,4 @@ + + + + diff --git a/spring-boot-test-dbunit/src/main/kotlin/io/camassia/spring/dbunit/api/dataset/builder/TableBasedDataSetBuilder.kt b/spring-boot-test-dbunit/src/main/kotlin/io/camassia/spring/dbunit/api/dataset/builder/TableBasedDataSetBuilder.kt index c6e2061..446c762 100644 --- a/spring-boot-test-dbunit/src/main/kotlin/io/camassia/spring/dbunit/api/dataset/builder/TableBasedDataSetBuilder.kt +++ b/spring-boot-test-dbunit/src/main/kotlin/io/camassia/spring/dbunit/api/dataset/builder/TableBasedDataSetBuilder.kt @@ -48,16 +48,18 @@ internal class TableBasedDataSetBuilder( while (tables.next()) { val table = tables.table val tableName = table.tableMetaData.tableName - val columnsFromDataSet = table.tableMetaData.columns.toSet() - val columnsFromDefaults = this.extensions.defaults.forTable(tableName).map { Column(it.key, DataType.UNKNOWN) } - val columns: Array = if(columnsFromDataSet.isNotEmpty()) (columnsFromDefaults + columnsFromDataSet).distinctBy { keyOf(it.columnName) }.toTypedArray() else emptyArray() val primaryKeys = table.tableMetaData.primaryKeys - consumer.startTable(DefaultTableMetaData(tableName, columns, primaryKeys)) - if (table.rowCount == 0) { - consumer.row(arrayOfNulls(columns.size)) + consumer.startTable(DefaultTableMetaData(tableName, emptyArray(), primaryKeys)) + consumer.row(emptyArray()) } else { + val columnsFromDataSet = table.tableMetaData.columns.toSet() + val columnsFromDefaults = this.extensions.defaults.forTable(tableName).map { Column(it.key, DataType.UNKNOWN) } + val columns: Array = (columnsFromDefaults + columnsFromDataSet).distinctBy { keyOf(it.columnName) }.toTypedArray() + + consumer.startTable(DefaultTableMetaData(tableName, columns, primaryKeys)) + (0 until table.rowCount).forEach { rowIndex -> val cells: Array = columns.map { column -> val rawCellValue: Any? = columnsFromDataSet.find { it.columnName.equals(column.columnName, ignoreCase) }?.let { table.getValue(rowIndex, it.columnName) } @@ -91,29 +93,33 @@ internal class TableBasedDataSetBuilder( tables.groupBy { keyOf(it.name) }.mapValues { (_, tables) -> tables.flatMap { it.rows } }.forEach { (tableName, rows) -> - val defaults: Set = this.extensions.defaults.forTable(tableName) - val columnsFromDataSet = rows.flatMap { it.cells }.map { Column(it.key, DataType.UNKNOWN) } - val columnsFromDefaults = this.extensions.defaults.forTable(tableName).map { Column(it.key, DataType.UNKNOWN) } - val columns: Array = if(columnsFromDataSet.isNotEmpty()) (columnsFromDefaults + columnsFromDataSet).distinctBy { keyOf(it.columnName) }.toTypedArray() else emptyArray() - - consumer.startTable(DefaultTableMetaData(tableName, columns)) - - if (rows.isEmpty()) consumer.row(arrayOfNulls(columns.size)) - else rows.forEach { row -> - val cells: Array = columns - .map { column: Column -> - val rawCellValue: Any? = row.cells.find { it.key.equals(column.columnName, ignoreCase) }?.value - val fallback = if(rawCellValue == null) defaults.find { it.key.equals(column.columnName, ignoreCase) }?.value else null - // Process all extensions and fallback to a default value if the cell is not set - extensions - .applyToCell(tableName, Cell(column.columnName, rawCellValue), overrides) - .mapValue { value -> - value ?: fallback - } - } - .map { it.value } - .toTypedArray() - consumer.row(cells) + if (rows.isEmpty()) { + consumer.startTable(DefaultTableMetaData(tableName, emptyArray())) + consumer.row(emptyArray()) + } else { + val defaults: Set = this.extensions.defaults.forTable(tableName) + val columnsFromDataSet = rows.flatMap { it.cells }.map { Column(it.key, DataType.UNKNOWN) } + val columnsFromDefaults = this.extensions.defaults.forTable(tableName).map { Column(it.key, DataType.UNKNOWN) } + val columns: Array = (columnsFromDefaults + columnsFromDataSet).distinctBy { keyOf(it.columnName) }.toTypedArray() + + consumer.startTable(DefaultTableMetaData(tableName, columns)) + + rows.forEach { row -> + val cells: Array = columns + .map { column: Column -> + val rawCellValue: Any? = row.cells.find { it.key.equals(column.columnName, ignoreCase) }?.value + val fallback = if(rawCellValue == null) defaults.find { it.key.equals(column.columnName, ignoreCase) }?.value else null + // Process all extensions and fallback to a default value if the cell is not set + extensions + .applyToCell(tableName, Cell(column.columnName, rawCellValue), overrides) + .mapValue { value -> + value ?: fallback + } + } + .map { it.value } + .toTypedArray() + consumer.row(cells) + } } consumer.endTable() @@ -122,4 +128,4 @@ internal class TableBasedDataSetBuilder( consumer.endDataSet() return DecoratedDataSet(dataSet) } -} \ No newline at end of file +}