diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 000000000..13566b81b --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,8 @@ +# Default ignored files +/shelf/ +/workspace.xml +# Editor-based HTTP Client requests +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/.idea/gorm.iml b/.idea/gorm.iml new file mode 100644 index 000000000..5e764c4f0 --- /dev/null +++ b/.idea/gorm.iml @@ -0,0 +1,9 @@ + + + + + + + + + \ No newline at end of file diff --git a/.idea/markdown.xml b/.idea/markdown.xml new file mode 100644 index 000000000..1e3409412 --- /dev/null +++ b/.idea/markdown.xml @@ -0,0 +1,9 @@ + + + + + + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 000000000..1ff239ec2 --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 000000000..94a25f7f4 --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/callback_create.go b/callback_create.go index 59840f863..4b42f5f63 100644 --- a/callback_create.go +++ b/callback_create.go @@ -130,7 +130,7 @@ func createCallback(scope *Scope) { // execute create sql: no primaryField if primaryField == nil { - if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil { + if result, err := scope.SQLDB().ExecContext(scope.Context(), scope.SQL, scope.SQLVars...); scope.Err(err) == nil { // set rows affected count scope.db.RowsAffected, _ = result.RowsAffected() @@ -146,7 +146,7 @@ func createCallback(scope *Scope) { // execute create sql: lastInsertID implemention for majority of dialects if lastInsertIDReturningSuffix == "" && lastInsertIDOutputInterstitial == "" { - if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil { + if result, err := scope.SQLDB().ExecContext(scope.Context(), scope.SQL, scope.SQLVars...); scope.Err(err) == nil { // set rows affected count scope.db.RowsAffected, _ = result.RowsAffected() @@ -162,7 +162,7 @@ func createCallback(scope *Scope) { // execute create sql: dialects with additional lastInsertID requirements (currently postgres & mssql) if primaryField.Field.CanAddr() { - if err := scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil { + if err := scope.SQLDB().QueryRowContext(scope.Context(), scope.SQL, scope.SQLVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil { primaryField.IsBlank = false scope.db.RowsAffected = 1 } diff --git a/callback_query.go b/callback_query.go index f75627152..b159f2d8c 100644 --- a/callback_query.go +++ b/callback_query.go @@ -65,7 +65,7 @@ func queryCallback(scope *Scope) { scope.SQL = fmt.Sprint(str) + scope.SQL } - if rows, err := scope.SQLDB().Query(scope.SQL, scope.SQLVars...); scope.Err(err) == nil { + if rows, err := scope.SQLDB().QueryContext(scope.Context(), scope.SQL, scope.SQLVars...); scope.Err(err) == nil { defer rows.Close() columns, _ := rows.Columns() diff --git a/callback_row_query.go b/callback_row_query.go index 43b21f83f..d383e1129 100644 --- a/callback_row_query.go +++ b/callback_row_query.go @@ -29,9 +29,9 @@ func rowQueryCallback(scope *Scope) { } if rowResult, ok := result.(*RowQueryResult); ok { - rowResult.Row = scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...) + rowResult.Row = scope.SQLDB().QueryRowContext(scope.Context(), scope.SQL, scope.SQLVars...) } else if rowsResult, ok := result.(*RowsQueryResult); ok { - rowsResult.Rows, rowsResult.Error = scope.SQLDB().Query(scope.SQL, scope.SQLVars...) + rowsResult.Rows, rowsResult.Error = scope.SQLDB().QueryContext(scope.Context(), scope.SQL, scope.SQLVars...) } } } diff --git a/customize_column_test.go b/customize_column_test.go index c236ac24e..da65406e7 100644 --- a/customize_column_test.go +++ b/customize_column_test.go @@ -26,7 +26,7 @@ func TestCustomizeColumn(t *testing.T) { DB.AutoMigrate(&CustomizeColumn{}) scope := DB.NewScope(&CustomizeColumn{}) - if !scope.Dialect().HasColumn(scope.TableName(), col) { + if !scope.Dialect().HasColumn(scope.Context(), scope.TableName(), col) { t.Errorf("CustomizeColumn should have column %s", col) } diff --git a/dialect.go b/dialect.go index c742efcd2..0d5408102 100644 --- a/dialect.go +++ b/dialect.go @@ -1,6 +1,7 @@ package gorm import ( + "context" "database/sql" "fmt" "reflect" @@ -24,17 +25,17 @@ type Dialect interface { DataTypeOf(field *StructField) string // HasIndex check has index or not - HasIndex(tableName string, indexName string) bool + HasIndex(ctx context.Context, tableName string, indexName string) bool // HasForeignKey check has foreign key or not - HasForeignKey(tableName string, foreignKeyName string) bool + HasForeignKey(ctx context.Context, tableName string, foreignKeyName string) bool // RemoveIndex remove index - RemoveIndex(tableName string, indexName string) error + RemoveIndex(ctx context.Context, tableName string, indexName string) error // HasTable check has table or not - HasTable(tableName string) bool + HasTable(ctx context.Context, tableName string) bool // HasColumn check has column or not - HasColumn(tableName string, columnName string) bool + HasColumn(ctx context.Context, tableName string, columnName string) bool // ModifyColumn modify column's type - ModifyColumn(tableName string, columnName string, typ string) error + ModifyColumn(ctx context.Context, tableName string, columnName string, typ string) error // LimitAndOffsetSQL return generated SQL with Limit and Offset, as mssql has special case LimitAndOffsetSQL(limit, offset interface{}) (string, error) @@ -54,7 +55,7 @@ type Dialect interface { NormalizeIndexAndColumn(indexName, columnName string) (string, string) // CurrentDatabase return current database name - CurrentDatabase() string + CurrentDatabase(ctx context.Context) string } var dialectsMap = map[string]Dialect{} @@ -67,9 +68,9 @@ func newDialect(name string, db SQLCommon) Dialect { } fmt.Printf("`%v` is not officially supported, running under compatibility mode.\n", name) - commontDialect := &commonDialect{} - commontDialect.SetDB(db) - return commontDialect + cd := &commonDialect{} + cd.SetDB(db) + return cd } // RegisterDialect register new dialect @@ -138,10 +139,10 @@ var ParseFieldStructForDialect = func(field *StructField, dialect Dialect) (fiel return fieldValue, dataType, size, strings.TrimSpace(additionalType) } -func currentDatabaseAndTable(dialect Dialect, tableName string) (string, string) { +func currentDatabaseAndTable(ctx context.Context, dialect Dialect, tableName string) (string, string) { if strings.Contains(tableName, ".") { splitStrings := strings.SplitN(tableName, ".", 2) return splitStrings[0], splitStrings[1] } - return dialect.CurrentDatabase(), tableName + return dialect.CurrentDatabase(ctx), tableName } diff --git a/dialect_common.go b/dialect_common.go index d549510cc..f331a0b27 100644 --- a/dialect_common.go +++ b/dialect_common.go @@ -1,6 +1,7 @@ package gorm import ( + "context" "fmt" "reflect" "regexp" @@ -99,43 +100,43 @@ func (s *commonDialect) DataTypeOf(field *StructField) string { return fmt.Sprintf("%v %v", sqlType, additionalType) } -func (s commonDialect) HasIndex(tableName string, indexName string) bool { +func (s commonDialect) HasIndex(ctx context.Context, tableName string, indexName string) bool { var count int - currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", currentDatabase, tableName, indexName).Scan(&count) + currentDatabase, tableName := currentDatabaseAndTable(ctx, &s, tableName) + s.db.QueryRowContext(ctx, "SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", currentDatabase, tableName, indexName).Scan(&count) return count > 0 } -func (s commonDialect) RemoveIndex(tableName string, indexName string) error { - _, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v", indexName)) +func (s commonDialect) RemoveIndex(ctx context.Context, tableName string, indexName string) error { + _, err := s.db.ExecContext(ctx, fmt.Sprintf("DROP INDEX %v", indexName)) return err } -func (s commonDialect) HasForeignKey(tableName string, foreignKeyName string) bool { +func (s commonDialect) HasForeignKey(ctx context.Context, tableName string, foreignKeyName string) bool { return false } -func (s commonDialect) HasTable(tableName string) bool { +func (s commonDialect) HasTable(ctx context.Context, tableName string) bool { var count int - currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", currentDatabase, tableName).Scan(&count) + currentDatabase, tableName := currentDatabaseAndTable(ctx, &s, tableName) + s.db.QueryRowContext(ctx, "SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", currentDatabase, tableName).Scan(&count) return count > 0 } -func (s commonDialect) HasColumn(tableName string, columnName string) bool { +func (s commonDialect) HasColumn(ctx context.Context, tableName string, columnName string) bool { var count int - currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", currentDatabase, tableName, columnName).Scan(&count) + currentDatabase, tableName := currentDatabaseAndTable(ctx, &s, tableName) + s.db.QueryRowContext(ctx, "SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", currentDatabase, tableName, columnName).Scan(&count) return count > 0 } -func (s commonDialect) ModifyColumn(tableName string, columnName string, typ string) error { - _, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v TYPE %v", tableName, columnName, typ)) +func (s commonDialect) ModifyColumn(ctx context.Context, tableName string, columnName string, typ string) error { + _, err := s.db.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v TYPE %v", tableName, columnName, typ)) return err } -func (s commonDialect) CurrentDatabase() (name string) { - s.db.QueryRow("SELECT DATABASE()").Scan(&name) +func (s commonDialect) CurrentDatabase(ctx context.Context) (name string) { + s.db.QueryRowContext(ctx, "SELECT DATABASE()").Scan(&name) return } diff --git a/dialect_mysql.go b/dialect_mysql.go index b4467ffa1..e8c27f995 100644 --- a/dialect_mysql.go +++ b/dialect_mysql.go @@ -1,6 +1,7 @@ package gorm import ( + "context" "crypto/sha1" "database/sql" "fmt" @@ -129,13 +130,13 @@ func (s *mysql) DataTypeOf(field *StructField) string { return fmt.Sprintf("%v %v", sqlType, additionalType) } -func (s mysql) RemoveIndex(tableName string, indexName string) error { - _, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName))) +func (s mysql) RemoveIndex(ctx context.Context, tableName string, indexName string) error { + _, err := s.db.ExecContext(ctx, fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName))) return err } -func (s mysql) ModifyColumn(tableName string, columnName string, typ string) error { - _, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v MODIFY COLUMN %v %v", tableName, columnName, typ)) +func (s mysql) ModifyColumn(ctx context.Context, tableName string, columnName string, typ string) error { + _, err := s.db.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %v MODIFY COLUMN %v %v", tableName, columnName, typ)) return err } @@ -162,18 +163,18 @@ func (s mysql) LimitAndOffsetSQL(limit, offset interface{}) (sql string, err err return } -func (s mysql) HasForeignKey(tableName string, foreignKeyName string) bool { +func (s mysql) HasForeignKey(ctx context.Context, tableName string, foreignKeyName string) bool { var count int - currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE CONSTRAINT_SCHEMA=? AND TABLE_NAME=? AND CONSTRAINT_NAME=? AND CONSTRAINT_TYPE='FOREIGN KEY'", currentDatabase, tableName, foreignKeyName).Scan(&count) + currentDatabase, tableName := currentDatabaseAndTable(ctx, &s, tableName) + s.db.QueryRowContext(ctx, "SELECT count(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE CONSTRAINT_SCHEMA=? AND TABLE_NAME=? AND CONSTRAINT_NAME=? AND CONSTRAINT_TYPE='FOREIGN KEY'", currentDatabase, tableName, foreignKeyName).Scan(&count) return count > 0 } -func (s mysql) HasTable(tableName string) bool { - currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) +func (s mysql) HasTable(ctx context.Context, tableName string) bool { + currentDatabase, tableName := currentDatabaseAndTable(ctx, &s, tableName) var name string // allow mysql database name with '-' character - if err := s.db.QueryRow(fmt.Sprintf("SHOW TABLES FROM `%s` WHERE `Tables_in_%s` = ?", currentDatabase, currentDatabase), tableName).Scan(&name); err != nil { + if err := s.db.QueryRowContext(ctx, fmt.Sprintf("SHOW TABLES FROM `%s` WHERE `Tables_in_%s` = ?", currentDatabase, currentDatabase), tableName).Scan(&name); err != nil { if err == sql.ErrNoRows { return false } @@ -183,9 +184,9 @@ func (s mysql) HasTable(tableName string) bool { } } -func (s mysql) HasIndex(tableName string, indexName string) bool { - currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) - if rows, err := s.db.Query(fmt.Sprintf("SHOW INDEXES FROM `%s` FROM `%s` WHERE Key_name = ?", tableName, currentDatabase), indexName); err != nil { +func (s mysql) HasIndex(ctx context.Context, tableName string, indexName string) bool { + currentDatabase, tableName := currentDatabaseAndTable(ctx, &s, tableName) + if rows, err := s.db.QueryContext(ctx, fmt.Sprintf("SHOW INDEXES FROM `%s` FROM `%s` WHERE Key_name = ?", tableName, currentDatabase), indexName); err != nil { panic(err) } else { defer rows.Close() @@ -193,9 +194,9 @@ func (s mysql) HasIndex(tableName string, indexName string) bool { } } -func (s mysql) HasColumn(tableName string, columnName string) bool { - currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) - if rows, err := s.db.Query(fmt.Sprintf("SHOW COLUMNS FROM `%s` FROM `%s` WHERE Field = ?", tableName, currentDatabase), columnName); err != nil { +func (s mysql) HasColumn(ctx context.Context, tableName string, columnName string) bool { + currentDatabase, tableName := currentDatabaseAndTable(ctx, &s, tableName) + if rows, err := s.db.QueryContext(ctx, fmt.Sprintf("SHOW COLUMNS FROM `%s` FROM `%s` WHERE Field = ?", tableName, currentDatabase), columnName); err != nil { panic(err) } else { defer rows.Close() @@ -203,8 +204,8 @@ func (s mysql) HasColumn(tableName string, columnName string) bool { } } -func (s mysql) CurrentDatabase() (name string) { - s.db.QueryRow("SELECT DATABASE()").Scan(&name) +func (s mysql) CurrentDatabase(ctx context.Context) (name string) { + s.db.QueryRowContext(ctx, "SELECT DATABASE()").Scan(&name) return } diff --git a/dialect_postgres.go b/dialect_postgres.go index d2df31318..89571fb30 100644 --- a/dialect_postgres.go +++ b/dialect_postgres.go @@ -1,6 +1,7 @@ package gorm import ( + "context" "encoding/json" "fmt" "reflect" @@ -91,32 +92,32 @@ func (s *postgres) DataTypeOf(field *StructField) string { return fmt.Sprintf("%v %v", sqlType, additionalType) } -func (s postgres) HasIndex(tableName string, indexName string) bool { +func (s postgres) HasIndex(ctx context.Context, tableName string, indexName string) bool { var count int - s.db.QueryRow("SELECT count(*) FROM pg_indexes WHERE tablename = $1 AND indexname = $2 AND schemaname = CURRENT_SCHEMA()", tableName, indexName).Scan(&count) + s.db.QueryRowContext(ctx, "SELECT count(*) FROM pg_indexes WHERE tablename = $1 AND indexname = $2 AND schemaname = CURRENT_SCHEMA()", tableName, indexName).Scan(&count) return count > 0 } -func (s postgres) HasForeignKey(tableName string, foreignKeyName string) bool { +func (s postgres) HasForeignKey(ctx context.Context, tableName string, foreignKeyName string) bool { var count int - s.db.QueryRow("SELECT count(con.conname) FROM pg_constraint con WHERE $1::regclass::oid = con.conrelid AND con.conname = $2 AND con.contype='f'", tableName, foreignKeyName).Scan(&count) + s.db.QueryRowContext(ctx, "SELECT count(con.conname) FROM pg_constraint con WHERE $1::regclass::oid = con.conrelid AND con.conname = $2 AND con.contype='f'", tableName, foreignKeyName).Scan(&count) return count > 0 } -func (s postgres) HasTable(tableName string) bool { +func (s postgres) HasTable(ctx context.Context, tableName string) bool { var count int - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = $1 AND table_type = 'BASE TABLE' AND table_schema = CURRENT_SCHEMA()", tableName).Scan(&count) + s.db.QueryRowContext(ctx, "SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = $1 AND table_type = 'BASE TABLE' AND table_schema = CURRENT_SCHEMA()", tableName).Scan(&count) return count > 0 } -func (s postgres) HasColumn(tableName string, columnName string) bool { +func (s postgres) HasColumn(ctx context.Context, tableName string, columnName string) bool { var count int - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = $1 AND column_name = $2 AND table_schema = CURRENT_SCHEMA()", tableName, columnName).Scan(&count) + s.db.QueryRowContext(ctx, "SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = $1 AND column_name = $2 AND table_schema = CURRENT_SCHEMA()", tableName, columnName).Scan(&count) return count > 0 } -func (s postgres) CurrentDatabase() (name string) { - s.db.QueryRow("SELECT CURRENT_DATABASE()").Scan(&name) +func (s postgres) CurrentDatabase(ctx context.Context) (name string) { + s.db.QueryRowContext(ctx, "SELECT CURRENT_DATABASE()").Scan(&name) return } diff --git a/dialect_sqlite3.go b/dialect_sqlite3.go index 5f96c363a..6db7191f4 100644 --- a/dialect_sqlite3.go +++ b/dialect_sqlite3.go @@ -1,6 +1,7 @@ package gorm import ( + "context" "fmt" "reflect" "strings" @@ -70,25 +71,25 @@ func (s *sqlite3) DataTypeOf(field *StructField) string { return fmt.Sprintf("%v %v", sqlType, additionalType) } -func (s sqlite3) HasIndex(tableName string, indexName string) bool { +func (s sqlite3) HasIndex(ctx context.Context, tableName string, indexName string) bool { var count int - s.db.QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE '%%INDEX %v ON%%'", indexName), tableName).Scan(&count) + s.db.QueryRowContext(ctx, fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE '%%INDEX %v ON%%'", indexName), tableName).Scan(&count) return count > 0 } -func (s sqlite3) HasTable(tableName string) bool { +func (s sqlite3) HasTable(ctx context.Context, tableName string) bool { var count int - s.db.QueryRow("SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", tableName).Scan(&count) + s.db.QueryRowContext(ctx, "SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", tableName).Scan(&count) return count > 0 } -func (s sqlite3) HasColumn(tableName string, columnName string) bool { +func (s sqlite3) HasColumn(ctx context.Context, tableName string, columnName string) bool { var count int - s.db.QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE '%%\"%v\" %%' OR sql LIKE '%%%v %%');\n", columnName, columnName), tableName).Scan(&count) + s.db.QueryRowContext(ctx, fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE '%%\"%v\" %%' OR sql LIKE '%%%v %%');\n", columnName, columnName), tableName).Scan(&count) return count > 0 } -func (s sqlite3) CurrentDatabase() (name string) { +func (s sqlite3) CurrentDatabase(ctx context.Context) (name string) { var ( ifaces = make([]interface{}, 3) pointers = make([]*string, 3) @@ -97,7 +98,7 @@ func (s sqlite3) CurrentDatabase() (name string) { for i = 0; i < 3; i++ { ifaces[i] = &pointers[i] } - if err := s.db.QueryRow("PRAGMA database_list").Scan(ifaces...); err != nil { + if err := s.db.QueryRowContext(ctx, "PRAGMA database_list").Scan(ifaces...); err != nil { return } if pointers[1] != nil { diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go index a516ed4af..98bb2f590 100644 --- a/dialects/mssql/mssql.go +++ b/dialects/mssql/mssql.go @@ -1,6 +1,7 @@ package mssql import ( + "context" "database/sql/driver" "encoding/json" "errors" @@ -12,6 +13,7 @@ import ( // Importing mssql driver package only in dialect file, otherwide not needed _ "github.com/denisenkom/go-mssqldb" + "github.com/jinzhu/gorm" ) @@ -122,21 +124,21 @@ func (s mssql) fieldCanAutoIncrement(field *gorm.StructField) bool { return field.IsPrimaryKey } -func (s mssql) HasIndex(tableName string, indexName string) bool { +func (s mssql) HasIndex(ctx context.Context, tableName string, indexName string) bool { var count int - s.db.QueryRow("SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName).Scan(&count) + s.db.QueryRowContext(ctx, "SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName).Scan(&count) return count > 0 } -func (s mssql) RemoveIndex(tableName string, indexName string) error { - _, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName))) +func (s mssql) RemoveIndex(ctx context.Context, tableName string, indexName string) error { + _, err := s.db.ExecContext(ctx, fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName))) return err } -func (s mssql) HasForeignKey(tableName string, foreignKeyName string) bool { +func (s mssql) HasForeignKey(ctx context.Context, tableName string, foreignKeyName string) bool { var count int - currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) - s.db.QueryRow(`SELECT count(*) + currentDatabase, tableName := currentDatabaseAndTable(ctx, &s, tableName) + s.db.QueryRowContext(ctx, `SELECT count(*) FROM sys.foreign_keys as F inner join sys.tables as T on F.parent_object_id=T.object_id inner join information_schema.tables as I on I.TABLE_NAME = T.name WHERE F.name = ? @@ -144,27 +146,27 @@ func (s mssql) HasForeignKey(tableName string, foreignKeyName string) bool { return count > 0 } -func (s mssql) HasTable(tableName string) bool { +func (s mssql) HasTable(ctx context.Context, tableName string) bool { var count int - currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, currentDatabase).Scan(&count) + currentDatabase, tableName := currentDatabaseAndTable(ctx, &s, tableName) + s.db.QueryRowContext(ctx, "SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, currentDatabase).Scan(&count) return count > 0 } -func (s mssql) HasColumn(tableName string, columnName string) bool { +func (s mssql) HasColumn(ctx context.Context, tableName string, columnName string) bool { var count int - currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) - s.db.QueryRow("SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", currentDatabase, tableName, columnName).Scan(&count) + currentDatabase, tableName := currentDatabaseAndTable(ctx, &s, tableName) + s.db.QueryRowContext(ctx, "SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", currentDatabase, tableName, columnName).Scan(&count) return count > 0 } -func (s mssql) ModifyColumn(tableName string, columnName string, typ string) error { - _, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v %v", tableName, columnName, typ)) +func (s mssql) ModifyColumn(ctx context.Context, tableName string, columnName string, typ string) error { + _, err := s.db.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v %v", tableName, columnName, typ)) return err } -func (s mssql) CurrentDatabase() (name string) { - s.db.QueryRow("SELECT DB_NAME() AS [Current Database]").Scan(&name) +func (s mssql) CurrentDatabase(ctx context.Context) (name string) { + s.db.QueryRowContext(ctx, "SELECT DB_NAME() AS [Current Database]").Scan(&name) return } @@ -220,12 +222,12 @@ func (mssql) NormalizeIndexAndColumn(indexName, columnName string) (string, stri return indexName, columnName } -func currentDatabaseAndTable(dialect gorm.Dialect, tableName string) (string, string) { +func currentDatabaseAndTable(ctx context.Context, dialect gorm.Dialect, tableName string) (string, string) { if strings.Contains(tableName, ".") { splitStrings := strings.SplitN(tableName, ".", 2) return splitStrings[0], splitStrings[1] } - return dialect.CurrentDatabase(), tableName + return dialect.CurrentDatabase(ctx), tableName } // JSON type to support easy handling of JSON data in character table fields diff --git a/embedded_struct_test.go b/embedded_struct_test.go index 5f8ece573..b0c8f7ef6 100644 --- a/embedded_struct_test.go +++ b/embedded_struct_test.go @@ -29,7 +29,7 @@ type EngadgetPost struct { func TestPrefixColumnNameForEmbeddedStruct(t *testing.T) { dialect := DB.NewScope(&EngadgetPost{}).Dialect() engadgetPostScope := DB.NewScope(&EngadgetPost{}) - if !dialect.HasColumn(engadgetPostScope.TableName(), "author_id") || !dialect.HasColumn(engadgetPostScope.TableName(), "author_name") || !dialect.HasColumn(engadgetPostScope.TableName(), "author_email") { + if !dialect.HasColumn(engadgetPostScope.Context(), engadgetPostScope.TableName(), "author_id") || !dialect.HasColumn(engadgetPostScope.Context(), engadgetPostScope.TableName(), "author_name") || !dialect.HasColumn(engadgetPostScope.Context(), engadgetPostScope.TableName(), "author_email") { t.Errorf("should has prefix for embedded columns") } @@ -38,7 +38,7 @@ func TestPrefixColumnNameForEmbeddedStruct(t *testing.T) { } hnScope := DB.NewScope(&HNPost{}) - if !dialect.HasColumn(hnScope.TableName(), "user_id") || !dialect.HasColumn(hnScope.TableName(), "user_name") || !dialect.HasColumn(hnScope.TableName(), "user_email") { + if !dialect.HasColumn(hnScope.Context(), hnScope.TableName(), "user_id") || !dialect.HasColumn(hnScope.Context(), hnScope.TableName(), "user_name") || !dialect.HasColumn(hnScope.Context(), hnScope.TableName(), "user_email") { t.Errorf("should has prefix for embedded columns") } } diff --git a/go.sum b/go.sum index d30630b95..b7956d8a6 100644 --- a/go.sum +++ b/go.sum @@ -16,10 +16,7 @@ github.com/lib/pq v1.1.1 h1:sJZmqHoEaY7f+NPP8pgLB/WxulyR3fewgCM2qaSlBb4= github.com/lib/pq v1.1.1/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/mattn/go-sqlite3 v1.14.0 h1:mLyGNKR8+Vv9CAU7PphKa2hkEqxxhn8i32J6FPj1/QA= github.com/mattn/go-sqlite3 v1.14.0/go.mod h1:JIl7NbARA7phWnGvh0LKTyg7S9BA+6gx71ShQilpsus= -github.com/mattn/go-sqlite3 v2.0.1+incompatible h1:xQ15muvnzGBHpIpdrNi1DA5x0+TcBZzsIDwmw9uTHzw= -github.com/mattn/go-sqlite3 v2.0.1+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c h1:Vj5n4GlwjmQteupaxJ9+0FNOmBrHfq7vN4btdGoDZgI= golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191205180655-e7c4368fe9dd h1:GGJVjV8waZKRHrgwvtH66z9ZGVurTD1MT0n1Bb+q4aM= golang.org/x/crypto v0.0.0-20191205180655-e7c4368fe9dd/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= diff --git a/interface.go b/interface.go index fe6492314..8f119521d 100644 --- a/interface.go +++ b/interface.go @@ -7,10 +7,10 @@ import ( // SQLCommon is the minimal database connection functionality gorm requires. Implemented by *sql.DB. type SQLCommon interface { - Exec(query string, args ...interface{}) (sql.Result, error) - Prepare(query string) (*sql.Stmt, error) - Query(query string, args ...interface{}) (*sql.Rows, error) - QueryRow(query string, args ...interface{}) *sql.Row + ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) + PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) + QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) + QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row } type sqlDb interface { diff --git a/main.go b/main.go index 0c247b6b9..2ef94e16b 100644 --- a/main.go +++ b/main.go @@ -20,6 +20,7 @@ type DB struct { // single db db SQLCommon + ctx context.Context blockGlobalUpdate bool logMode logModeValue logger logger @@ -83,9 +84,10 @@ func Open(dialect string, args ...interface{}) (db *DB, err error) { } db = &DB{ - db: dbSQL, - logger: defaultLogger, - + db: dbSQL, + ctx: context.Background(), + logger: defaultLogger, + // Create a clone of the default logger to avoid mutating a shared object when // multiple gorm connections are created simultaneously. callbacks: DefaultCallback.clone(defaultLogger), @@ -233,6 +235,14 @@ func (s *DB) SubQuery() *SqlExpr { return Expr(fmt.Sprintf("(%v)", scope.SQL), scope.SQLVars...) } +// WithContext returns the DB with the supplied context attached +func (s *DB) WithContext(ctx context.Context) *DB { + dbClone := s.clone() + dbClone.ctx = ctx + + return dbClone +} + // Where return a new relation, filter records with given conditions, accepts `map`, `struct` or `string` as conditions, refer http://jinzhu.github.io/gorm/crud.html#query func (s *DB) Where(query interface{}, args ...interface{}) *DB { return s.clone().search.Where(query, args...).db @@ -680,7 +690,7 @@ func (s *DB) HasTable(value interface{}) bool { tableName = scope.TableName() } - has := scope.Dialect().HasTable(tableName) + has := scope.Dialect().HasTable(s.ctx, tableName) s.AddError(scope.db.Error) return has } @@ -800,7 +810,7 @@ func (s *DB) SetJoinTableHandler(source interface{}, column string, handler Join destination := (&Scope{Value: reflect.New(field.Struct.Type).Interface()}).GetModelStruct().ModelType handler.Setup(field.Relationship, many2many, source, destination) field.Relationship.JoinTableHandler = handler - if table := handler.Table(s); scope.Dialect().HasTable(table) { + if table := handler.Table(s); scope.Dialect().HasTable(s.ctx, table) { s.Table(table).AutoMigrate(handler) } } @@ -847,6 +857,7 @@ func (s *DB) GetErrors() []error { func (s *DB) clone() *DB { db := &DB{ db: s.db, + ctx: s.ctx, parent: s.parent, logger: s.logger, logMode: s.logMode, diff --git a/migration_test.go b/migration_test.go index 063c6f648..be459b852 100644 --- a/migration_test.go +++ b/migration_test.go @@ -284,8 +284,8 @@ func getPreparedUser(name string, role string) *User { } type Panda struct { - Number int64 `gorm:"unique_index:number"` - Name string `gorm:"column:name;type:varchar(255);default:null"` + Number int64 `gorm:"unique_index:number"` + Name string `gorm:"column:name;type:varchar(255);default:null"` } func runMigration() { @@ -312,7 +312,7 @@ func TestIndexes(t *testing.T) { } scope := DB.NewScope(&Email{}) - if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email") { + if !scope.Dialect().HasIndex(scope.Context(), scope.TableName(), "idx_email_email") { t.Errorf("Email should have index idx_email_email") } @@ -320,7 +320,7 @@ func TestIndexes(t *testing.T) { t.Errorf("Got error when tried to remove index: %+v", err) } - if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email") { + if scope.Dialect().HasIndex(scope.Context(), scope.TableName(), "idx_email_email") { t.Errorf("Email's index idx_email_email should be deleted") } @@ -328,7 +328,7 @@ func TestIndexes(t *testing.T) { t.Errorf("Got error when tried to create index: %+v", err) } - if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") { + if !scope.Dialect().HasIndex(scope.Context(), scope.TableName(), "idx_email_email_and_user_id") { t.Errorf("Email should have index idx_email_email_and_user_id") } @@ -336,7 +336,7 @@ func TestIndexes(t *testing.T) { t.Errorf("Got error when tried to remove index: %+v", err) } - if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") { + if scope.Dialect().HasIndex(scope.Context(), scope.TableName(), "idx_email_email_and_user_id") { t.Errorf("Email's index idx_email_email_and_user_id should be deleted") } @@ -344,7 +344,7 @@ func TestIndexes(t *testing.T) { t.Errorf("Got error when tried to create index: %+v", err) } - if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") { + if !scope.Dialect().HasIndex(scope.Context(), scope.TableName(), "idx_email_email_and_user_id") { t.Errorf("Email should have index idx_email_email_and_user_id") } @@ -366,7 +366,7 @@ func TestIndexes(t *testing.T) { t.Errorf("Got error when tried to remove index: %+v", err) } - if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") { + if scope.Dialect().HasIndex(scope.Context(), scope.TableName(), "idx_email_email_and_user_id") { t.Errorf("Email's index idx_email_email_and_user_id should be deleted") } @@ -396,11 +396,11 @@ func TestAutoMigration(t *testing.T) { DB.Save(&EmailWithIdx{Email: "jinzhu@example.org", UserAgent: "pc", RegisteredAt: &now}) scope := DB.NewScope(&EmailWithIdx{}) - if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_agent") { + if !scope.Dialect().HasIndex(scope.Context(), scope.TableName(), "idx_email_agent") { t.Errorf("Failed to create index") } - if !scope.Dialect().HasIndex(scope.TableName(), "uix_email_with_idxes_registered_at") { + if !scope.Dialect().HasIndex(scope.Context(), scope.TableName(), "uix_email_with_idxes_registered_at") { t.Errorf("Failed to create index") } @@ -479,23 +479,23 @@ func TestMultipleIndexes(t *testing.T) { DB.Save(&MultipleIndexes{UserID: 1, Name: "jinzhu", Email: "jinzhu@example.org", Other: "foo"}) scope := DB.NewScope(&MultipleIndexes{}) - if !scope.Dialect().HasIndex(scope.TableName(), "uix_multipleindexes_user_name") { + if !scope.Dialect().HasIndex(scope.Context(), scope.TableName(), "uix_multipleindexes_user_name") { t.Errorf("Failed to create index") } - if !scope.Dialect().HasIndex(scope.TableName(), "uix_multipleindexes_user_email") { + if !scope.Dialect().HasIndex(scope.Context(), scope.TableName(), "uix_multipleindexes_user_email") { t.Errorf("Failed to create index") } - if !scope.Dialect().HasIndex(scope.TableName(), "uix_multiple_indexes_email") { + if !scope.Dialect().HasIndex(scope.Context(), scope.TableName(), "uix_multiple_indexes_email") { t.Errorf("Failed to create index") } - if !scope.Dialect().HasIndex(scope.TableName(), "idx_multipleindexes_user_other") { + if !scope.Dialect().HasIndex(scope.Context(), scope.TableName(), "idx_multipleindexes_user_other") { t.Errorf("Failed to create index") } - if !scope.Dialect().HasIndex(scope.TableName(), "idx_multiple_indexes_other") { + if !scope.Dialect().HasIndex(scope.Context(), scope.TableName(), "idx_multiple_indexes_other") { t.Errorf("Failed to create index") } @@ -576,7 +576,7 @@ func TestIndexWithPrefixLength(t *testing.T) { if err := DB.CreateTable(table).Error; err != nil { t.Errorf("Failed to create %s table: %v", tableName, err) } - if !scope.Dialect().HasIndex(tableName, "idx_index_with_prefixes_length") { + if !scope.Dialect().HasIndex(scope.Context(), tableName, "idx_index_with_prefixes_length") { t.Errorf("Failed to create %s table index:", tableName) } }) diff --git a/scope.go b/scope.go index ea12ee2f1..e48dfd092 100644 --- a/scope.go +++ b/scope.go @@ -2,6 +2,7 @@ package gorm import ( "bytes" + "context" "database/sql" "database/sql/driver" "errors" @@ -66,6 +67,10 @@ func (scope *Scope) Dialect() Dialect { return scope.db.dialect } +func (scope *Scope) Context() context.Context { + return scope.db.ctx +} + // Quote used to quote string to escape them for database func (scope *Scope) Quote(str string) string { if strings.Contains(str, ".") { @@ -361,7 +366,7 @@ func (scope *Scope) Exec() *Scope { defer scope.trace(NowFunc()) if !scope.HasError() { - if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil { + if result, err := scope.SQLDB().ExecContext(scope.Context(), scope.SQL, scope.SQLVars...); scope.Err(err) == nil { if count, err := result.RowsAffected(); scope.Err(err) == nil { scope.db.RowsAffected = count } @@ -1136,7 +1141,7 @@ func (scope *Scope) createJoinTable(field *StructField) { if relationship := field.Relationship; relationship != nil && relationship.JoinTableHandler != nil { joinTableHandler := relationship.JoinTableHandler joinTable := joinTableHandler.Table(scope.db) - if !scope.Dialect().HasTable(joinTable) { + if !scope.Dialect().HasTable(scope.Context(), joinTable) { toScope := &Scope{Value: reflect.New(field.Struct.Type).Interface()} var sqlTypes, primaryKeys []string @@ -1209,7 +1214,7 @@ func (scope *Scope) dropTable() *Scope { } func (scope *Scope) modifyColumn(column string, typ string) { - scope.db.AddError(scope.Dialect().ModifyColumn(scope.QuotedTableName(), scope.Quote(column), typ)) + scope.db.AddError(scope.Dialect().ModifyColumn(scope.Context(), scope.QuotedTableName(), scope.Quote(column), typ)) } func (scope *Scope) dropColumn(column string) { @@ -1217,7 +1222,7 @@ func (scope *Scope) dropColumn(column string) { } func (scope *Scope) addIndex(unique bool, indexName string, column ...string) { - if scope.Dialect().HasIndex(scope.TableName(), indexName) { + if scope.Dialect().HasIndex(scope.Context(), scope.TableName(), indexName) { return } @@ -1238,7 +1243,7 @@ func (scope *Scope) addForeignKey(field string, dest string, onDelete string, on // Compatible with old generated key keyName := scope.Dialect().BuildKeyName(scope.TableName(), field, dest, "foreign") - if scope.Dialect().HasForeignKey(scope.TableName(), keyName) { + if scope.Dialect().HasForeignKey(scope.Context(), scope.TableName(), keyName) { return } var query = `ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s ON DELETE %s ON UPDATE %s;` @@ -1247,7 +1252,7 @@ func (scope *Scope) addForeignKey(field string, dest string, onDelete string, on func (scope *Scope) removeForeignKey(field string, dest string) { keyName := scope.Dialect().BuildKeyName(scope.TableName(), field, dest, "foreign") - if !scope.Dialect().HasForeignKey(scope.TableName(), keyName) { + if !scope.Dialect().HasForeignKey(scope.Context(), scope.TableName(), keyName) { return } var mysql mysql @@ -1262,18 +1267,18 @@ func (scope *Scope) removeForeignKey(field string, dest string) { } func (scope *Scope) removeIndex(indexName string) { - scope.Dialect().RemoveIndex(scope.TableName(), indexName) + scope.Dialect().RemoveIndex(scope.Context(), scope.TableName(), indexName) } func (scope *Scope) autoMigrate() *Scope { tableName := scope.TableName() quotedTableName := scope.QuotedTableName() - if !scope.Dialect().HasTable(tableName) { + if !scope.Dialect().HasTable(scope.Context(), tableName) { scope.createTable() } else { for _, field := range scope.GetModelStruct().StructFields { - if !scope.Dialect().HasColumn(tableName, field.DBName) { + if !scope.Dialect().HasColumn(scope.Context(), tableName, field.DBName) { if field.IsNormal { sqlTag := scope.Dialect().DataTypeOf(field) scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, scope.Quote(field.DBName), sqlTag)).Exec()