Skip to content

Commit

Permalink
Merge branch 'main' into fix_tables_creation_with_foreign_key_reference
Browse files Browse the repository at this point in the history
  • Loading branch information
joc-a authored Sep 4, 2023
2 parents 7f653d7 + c3b998d commit f79b70e
Show file tree
Hide file tree
Showing 7 changed files with 121 additions and 49 deletions.
1 change: 1 addition & 0 deletions exposed-core/api/exposed-core.api
Original file line number Diff line number Diff line change
Expand Up @@ -2240,6 +2240,7 @@ public class org/jetbrains/exposed/sql/Table : org/jetbrains/exposed/sql/ColumnS
public final fun getForeignKeys ()Ljava/util/List;
public final fun getIndices ()Ljava/util/List;
public fun getPrimaryKey ()Lorg/jetbrains/exposed/sql/Table$PrimaryKey;
public final fun getSchemaName ()Ljava/lang/String;
public fun getTableName ()Ljava/lang/String;
public fun hashCode ()I
public final fun index (Ljava/lang/String;Z[Lorg/jetbrains/exposed/sql/Column;Ljava/util/List;Ljava/lang/String;Lkotlin/jvm/functions/Function1;)V
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -257,17 +257,17 @@ data class Index(
/** Name of the index. */
val indexName: String
get() = customName ?: buildString {
append(table.nameInDatabaseCase())
append(table.nameInDatabaseCaseUnquoted())
append('_')
append(columns.joinToString("_") { it.name }.inProperCase())
append(columns.joinToString("_") { it.name })
functions?.let { f ->
if (columns.isNotEmpty()) append('_')
append(f.joinToString("_") { it.toString().substringBefore("(").lowercase() }.inProperCase())
append(f.joinToString("_") { it.toString().substringBefore("(").lowercase() })
}
if (unique) {
append("_unique".inProperCase())
append("_unique")
}
}
}.inProperCase()

init {
require(columns.isNotEmpty() || functions?.isNotEmpty() == true) { "At least one column or function is required to create an index" }
Expand Down
13 changes: 8 additions & 5 deletions exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/Table.kt
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,10 @@ open class Table(name: String = "") : ColumnSet(), DdlAware {
else -> javaClass.name.removePrefix("${javaClass.`package`.name}.").substringAfter('$').removeSuffix("Table")
}

internal val tableNameWithoutScheme: String get() = tableName.substringAfter(".")
/** Returns the schema name, or null if one does not exist for this table. */
val schemaName: String? = if (name.contains(".")) name.substringBeforeLast(".") else null

internal val tableNameWithoutScheme: String get() = tableName.substringAfterLast(".")

// Table name may contain quotes, remove those before appending
internal val tableNameWithoutSchemeSanitized: String get() = tableNameWithoutScheme.replace("\"", "").replace("'", "")
Expand Down Expand Up @@ -369,15 +372,15 @@ open class Table(name: String = "") : ColumnSet(), DdlAware {
fun nameInDatabaseCase(): String = tableName.inProperCase()

/**
* Returns the table name, in proper case, with wrapping single- and double-quotation characters removed.
* Returns the table name, without schema and in proper case, with wrapping single- and double-quotation characters removed.
*
* **Note** If used with MySQL or MariaDB, the column name is returned unchanged, since these databases use a
* **Note** If used with MySQL or MariaDB, the table name is returned unchanged, since these databases use a
* backtick character as the identifier quotation.
*/
fun nameInDatabaseCaseUnquoted(): String = if (currentDialect is MysqlDialect) {
nameInDatabaseCase()
tableNameWithoutScheme.inProperCase()
} else {
nameInDatabaseCase().trim('\"', '\'')
tableNameWithoutScheme.inProperCase().trim('\"', '\'')
}

override fun describe(s: Transaction, queryBuilder: QueryBuilder): Unit = queryBuilder { append(s.identity(this@Table)) }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1139,7 +1139,7 @@ abstract class VendorDialect(
}

override fun tableExists(table: Table): Boolean {
val tableScheme = table.tableName.substringBefore('.', "").takeIf { it.isNotEmpty() }
val tableScheme = table.schemaName
val scheme = tableScheme?.inProperCase() ?: TransactionManager.current().connection.metadata { currentScheme }
val allTables = getAllTableNamesCache().getValue(scheme)
return allTables.any {
Expand Down Expand Up @@ -1171,11 +1171,11 @@ abstract class VendorDialect(
): Map<Pair<Table, LinkedHashSet<Column<*>>>, List<ForeignKeyConstraint>> {
val constraints = HashMap<Pair<Table, LinkedHashSet<Column<*>>>, MutableList<ForeignKeyConstraint>>()

val tablesToLoad = tables.filter { !columnConstraintsCache.containsKey(it.nameInDatabaseCase()) }
val tablesToLoad = tables.filter { !columnConstraintsCache.containsKey(it.nameInDatabaseCaseUnquoted()) }

fillConstraintCacheForTables(tablesToLoad)
tables.forEach { table ->
columnConstraintsCache[table.nameInDatabaseCase()].orEmpty().forEach {
columnConstraintsCache[table.nameInDatabaseCaseUnquoted()].orEmpty().forEach {
constraints.getOrPut(table to it.from) { arrayListOf() }.add(it)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -294,11 +294,11 @@ open class MysqlDialect : VendorDialect(dialectName, MysqlDataTypeProvider, Mysq
}

override fun fillConstraintCacheForTables(tables: List<Table>) {
val allTables = SchemaUtils.sortTablesByReferences(tables).associateBy { it.nameInDatabaseCase() }
val allTables = SchemaUtils.sortTablesByReferences(tables).associateBy { it.nameInDatabaseCaseUnquoted() }
val allTableNames = allTables.keys
val inTableList = allTableNames.joinToString("','", prefix = " ku.TABLE_NAME IN ('", postfix = "')")
val tr = TransactionManager.current()
val schemaName = "'${getDatabase()}'"
val tableSchema = "'${tables.mapNotNull { it.schemaName }.toSet().singleOrNull() ?: getDatabase()}'"
val constraintsToLoad = HashMap<String, MutableMap<String, ForeignKeyConstraint>>()
tr.exec(
"""SELECT
Expand All @@ -312,9 +312,9 @@ open class MysqlDialect : VendorDialect(dialectName, MysqlDataTypeProvider, Mysq
FROM INFORMATION_SCHEMA.REFERENTIAL_CONSTRAINTS rc
INNER JOIN INFORMATION_SCHEMA.KEY_COLUMN_USAGE ku
ON ku.TABLE_SCHEMA = rc.CONSTRAINT_SCHEMA AND rc.CONSTRAINT_NAME = ku.CONSTRAINT_NAME
WHERE ku.TABLE_SCHEMA = $schemaName
AND ku.CONSTRAINT_SCHEMA = $schemaName
AND rc.CONSTRAINT_SCHEMA = $schemaName
WHERE ku.TABLE_SCHEMA = $tableSchema
AND ku.CONSTRAINT_SCHEMA = $tableSchema
AND rc.CONSTRAINT_SCHEMA = $tableSchema
AND $inTableList
ORDER BY ku.ORDINAL_POSITION
""".trimIndent()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,25 +146,32 @@ class JdbcDatabaseMetadataImpl(database: String, val metadata: DatabaseMetaData)
}

override fun columns(vararg tables: Table): Map<Table, List<ColumnMetadata>> {
val rs = metadata.getColumns(databaseName, currentScheme, "%", "%")
val result = rs.extractColumns(tables) {
// @see java.sql.DatabaseMetaData.getColumns
// That read should go first as Oracle driver closes connection after that
val defaultDbValue = it.getString("COLUMN_DEF")?.let { sanitizedDefault(it) }
val autoIncrement = it.getString("IS_AUTOINCREMENT") == "YES"
val type = it.getInt("DATA_TYPE")
val columnMetadata = ColumnMetadata(
it.getString("COLUMN_NAME"),
type,
it.getBoolean("NULLABLE"),
it.getInt("COLUMN_SIZE").takeIf { it != 0 },
autoIncrement,
// Not sure this filters enough but I dont think we ever want to have sequences here
defaultDbValue?.takeIf { !autoIncrement },
)
it.getString("TABLE_NAME") to columnMetadata
val result = mutableMapOf<Table, List<ColumnMetadata>>()
val useSchemaInsteadOfDatabase = currentDialect is MysqlDialect

val tablesBySchema = tables.groupBy { identifierManager.inProperCase(it.schemaName ?: currentScheme) }
tablesBySchema.forEach { (schema, schemaTables) ->
val catalog = if (!useSchemaInsteadOfDatabase || schema == currentScheme) databaseName else schema
val rs = metadata.getColumns(catalog, schema, "%", "%")
result += rs.extractColumns(schemaTables.toTypedArray()) {
// @see java.sql.DatabaseMetaData.getColumns
// That read should go first as Oracle driver closes connection after that
val defaultDbValue = it.getString("COLUMN_DEF")?.let { sanitizedDefault(it) }
val autoIncrement = it.getString("IS_AUTOINCREMENT") == "YES"
val type = it.getInt("DATA_TYPE")
val columnMetadata = ColumnMetadata(
it.getString("COLUMN_NAME"),
type,
it.getBoolean("NULLABLE"),
it.getInt("COLUMN_SIZE").takeIf { it != 0 },
autoIncrement,
// Not sure this filters enough but I dont think we ever want to have sequences here
defaultDbValue?.takeIf { !autoIncrement },
)
it.getString("TABLE_NAME") to columnMetadata
}
rs.close()
}
rs.close()
return result
}

Expand All @@ -188,20 +195,23 @@ class JdbcDatabaseMetadataImpl(database: String, val metadata: DatabaseMetaData)

private val existingIndicesCache = HashMap<Table, List<Index>>()

@Suppress("CyclomaticComplexMethod")
override fun existingIndices(vararg tables: Table): Map<Table, List<Index>> {
for (table in tables) {
val transaction = TransactionManager.current()
val (catalog, tableSchema) = tableCatalogAndSchema(table)

existingIndicesCache.getOrPut(table) {
val pkNames = metadata.getPrimaryKeys(databaseName, currentScheme, table.nameInDatabaseCaseUnquoted()).let { rs ->
val pkNames = metadata.getPrimaryKeys(catalog, tableSchema, table.nameInDatabaseCaseUnquoted()).let { rs ->
val names = arrayListOf<String>()
while (rs.next()) {
rs.getString("PK_NAME")?.let { names += it }
}
rs.close()
names
}
val rs = metadata.getIndexInfo(databaseName, currentScheme, table.nameInDatabaseCase(), false, false)
val storedIndexTable = if (tableSchema == currentScheme) table.nameInDatabaseCase() else table.nameInDatabaseCaseUnquoted()
val rs = metadata.getIndexInfo(catalog, tableSchema, storedIndexTable, false, false)

val tmpIndices = hashMapOf<Triple<String, Boolean, Op.TRUE?>, MutableList<String>>()

Expand Down Expand Up @@ -243,7 +253,8 @@ class JdbcDatabaseMetadataImpl(database: String, val metadata: DatabaseMetaData)

override fun existingPrimaryKeys(vararg tables: Table): Map<Table, PrimaryKeyMetadata?> {
return tables.associateWith { table ->
metadata.getPrimaryKeys(databaseName, currentScheme, table.nameInDatabaseCaseUnquoted()).let { rs ->
val (catalog, tableSchema) = tableCatalogAndSchema(table)
metadata.getPrimaryKeys(catalog, tableSchema, table.nameInDatabaseCaseUnquoted()).let { rs ->
val columnNames = mutableListOf<String>()
var pkName = ""
while (rs.next()) {
Expand All @@ -258,9 +269,10 @@ class JdbcDatabaseMetadataImpl(database: String, val metadata: DatabaseMetaData)

@Synchronized
override fun tableConstraints(tables: List<Table>): Map<String, List<ForeignKeyConstraint>> {
val allTables = SchemaUtils.sortTablesByReferences(tables).associateBy { it.nameInDatabaseCase() }
val allTables = SchemaUtils.sortTablesByReferences(tables).associateBy { it.nameInDatabaseCaseUnquoted() }
return allTables.keys.associateWith { table ->
metadata.getImportedKeys(databaseName, currentScheme, table).iterate {
val (catalog, tableSchema) = tableCatalogAndSchema(allTables[table]!!)
metadata.getImportedKeys(catalog, identifierManager.inProperCase(tableSchema), table).iterate {
val fromTableName = getString("FKTABLE_NAME")!!
val fromColumnName = identifierManager.quoteIdentifierWhenWrongCaseOrNecessary(getString("FKCOLUMN_NAME")!!)
val fromColumn = allTables[fromTableName]?.columns?.firstOrNull {
Expand Down Expand Up @@ -291,6 +303,24 @@ class JdbcDatabaseMetadataImpl(database: String, val metadata: DatabaseMetaData)
}
}

/**
* Returns the name of the database in which a [table] is found, as well as it's schema name.
*
* If the table name does not include a schema prefix, the metadata value `currentScheme` is used instead.
*
* MySQL/MariaDB are special cases in that a schema definition is treated like a separate database. This means that
* a connection to 'testDb' with a table defined as 'my_schema.my_table' will only successfully find the table's
* metadata if 'my_schema' is used as the database name.
*/
private fun tableCatalogAndSchema(table: Table): Pair<String, String> {
val tableSchema = identifierManager.inProperCase(table.schemaName ?: currentScheme)
return if (currentDialect is MysqlDialect && tableSchema != currentScheme) {
tableSchema to tableSchema
} else {
databaseName to tableSchema
}
}

@Synchronized
override fun cleanCache() {
existingIndicesCache.clear()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package org.jetbrains.exposed.sql.tests.shared.ddl
import org.jetbrains.exposed.dao.id.EntityID
import org.jetbrains.exposed.dao.id.IdTable
import org.jetbrains.exposed.dao.id.IntIdTable
import org.jetbrains.exposed.dao.id.LongIdTable
import org.jetbrains.exposed.sql.*
import org.jetbrains.exposed.sql.SqlExpressionBuilder.isNull
import org.jetbrains.exposed.sql.tests.DatabaseTestsBase
Expand Down Expand Up @@ -294,7 +295,7 @@ class CreateMissingTablesAndColumnsTests : DatabaseTestsBase() {
withDb { testDb ->
try {
// MySQL doesn't support default values on text columns, hence excluded
table = if(testDb != TestDB.MYSQL) {
table = if (testDb != TestDB.MYSQL) {
object : Table("varchar_test") {
val varchar = varchar("varchar_column", 255).default(" ")
val text = text("text_column").default(" ")
Expand Down Expand Up @@ -328,7 +329,7 @@ class CreateMissingTablesAndColumnsTests : DatabaseTestsBase() {

@Test
fun `columns with default values that are whitespaces shouldn't be treated as empty strings`() {
val tableWhitespaceDefaultVarchar = StringFieldTable("varchar_whitespace_test", false," ")
val tableWhitespaceDefaultVarchar = StringFieldTable("varchar_whitespace_test", false, " ")

val tableWhitespaceDefaultText = StringFieldTable("text_whitespace_test", true, " ")

Expand Down Expand Up @@ -538,9 +539,10 @@ class CreateMissingTablesAndColumnsTests : DatabaseTestsBase() {
uniqueIndex("index2", value2, value1)
}
}

@Test fun testCreateTableWithReferenceMultipleTimes() {
withTables(SessionsTable, PlayerTable) {

@Test
fun testCreateTableWithReferenceMultipleTimes() {
withTables(PlayerTable, SessionsTable) {
SchemaUtils.createMissingTablesAndColumns(PlayerTable, SessionsTable)
SchemaUtils.createMissingTablesAndColumns(PlayerTable, SessionsTable)
}
Expand All @@ -554,7 +556,8 @@ class CreateMissingTablesAndColumnsTests : DatabaseTestsBase() {
val playerId = integer("player_id").references(PlayerTable.id)
}

@Test fun createTableWithReservedIdentifierInColumnName() {
@Test
fun createTableWithReservedIdentifierInColumnName() {
withDb(TestDB.MYSQL) {
SchemaUtils.createMissingTablesAndColumns(T1, T2)
SchemaUtils.createMissingTablesAndColumns(T1, T2)
Expand All @@ -567,11 +570,13 @@ class CreateMissingTablesAndColumnsTests : DatabaseTestsBase() {
object ExplicitTable : IntIdTable() {
val playerId = integer("player_id").references(PlayerTable.id, fkName = "Explicit_FK_NAME")
}

object NonExplicitTable : IntIdTable() {
val playerId = integer("player_id").references(PlayerTable.id)
}

@Test fun explicitFkNameIsExplicit() {
@Test
fun explicitFkNameIsExplicit() {
withTables(ExplicitTable, NonExplicitTable) {
assertEquals("Explicit_FK_NAME", ExplicitTable.playerId.foreignKey!!.customFkName)
assertEquals(null, NonExplicitTable.playerId.foreignKey!!.customFkName)
Expand All @@ -582,6 +587,7 @@ class CreateMissingTablesAndColumnsTests : DatabaseTestsBase() {
val name = integer("name").uniqueIndex()
val tmp = varchar("temp", 255)
}

object T2 : Table("CHAIN") {
val ref = integer("ref").references(T1.name)
}
Expand Down Expand Up @@ -620,7 +626,7 @@ class CreateMissingTablesAndColumnsTests : DatabaseTestsBase() {

@Test
fun testCreateCompositePrimaryKeyTableAndCompositeForeignKeyTableMultipleTimes() {
withTables(CompositeForeignKeyTable, CompositePrimaryKeyTable) {
withTables(CompositePrimaryKeyTable, CompositeForeignKeyTable) {
SchemaUtils.createMissingTablesAndColumns(CompositePrimaryKeyTable, CompositeForeignKeyTable)
SchemaUtils.createMissingTablesAndColumns(CompositePrimaryKeyTable, CompositeForeignKeyTable)
}
Expand Down Expand Up @@ -661,4 +667,36 @@ class CreateMissingTablesAndColumnsTests : DatabaseTestsBase() {
SchemaUtils.createMissingTablesAndColumns(CompositeForeignKeyTable, CompositePrimaryKeyTable)
}
}

@Test
fun testCreateTableWithSchemaPrefix() {
val schemaName = "my_schema"
val schema = Schema(schemaName)
// index and foreign key both use table name to auto-generate their own names & to compare metadata
val parentTable = object : IntIdTable("$schemaName.parent_table") {
val secondId = integer("second_id").uniqueIndex()
}
val childTable = object : LongIdTable("$schemaName.child_table") {
val parent = reference("my_parent", parentTable)
}

// SQLite does not recognize creation of schema other than the attached database
withDb(excludeSettings = listOf(TestDB.SQLITE)) { testDb ->
SchemaUtils.createSchema(schema)
SchemaUtils.create(parentTable, childTable)

try {
SchemaUtils.createMissingTablesAndColumns(parentTable, childTable)
assertTrue(parentTable.exists())
assertTrue(childTable.exists())
} finally {
if (testDb == TestDB.SQLSERVER) {
SchemaUtils.drop(childTable, parentTable)
SchemaUtils.dropSchema(schema)
} else {
SchemaUtils.dropSchema(schema, cascade = true)
}
}
}
}
}

0 comments on commit f79b70e

Please sign in to comment.