diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 000000000..13566b81b
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Editor-based HTTP Client requests
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
diff --git a/.idea/gorm.iml b/.idea/gorm.iml
new file mode 100644
index 000000000..5e764c4f0
--- /dev/null
+++ b/.idea/gorm.iml
@@ -0,0 +1,9 @@
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/markdown.xml b/.idea/markdown.xml
new file mode 100644
index 000000000..1e3409412
--- /dev/null
+++ b/.idea/markdown.xml
@@ -0,0 +1,9 @@
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100644
index 000000000..1ff239ec2
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 000000000..94a25f7f4
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/callback_create.go b/callback_create.go
index 59840f863..4b42f5f63 100644
--- a/callback_create.go
+++ b/callback_create.go
@@ -130,7 +130,7 @@ func createCallback(scope *Scope) {
// execute create sql: no primaryField
if primaryField == nil {
- if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
+ if result, err := scope.SQLDB().ExecContext(scope.Context(), scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
// set rows affected count
scope.db.RowsAffected, _ = result.RowsAffected()
@@ -146,7 +146,7 @@ func createCallback(scope *Scope) {
// execute create sql: lastInsertID implemention for majority of dialects
if lastInsertIDReturningSuffix == "" && lastInsertIDOutputInterstitial == "" {
- if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
+ if result, err := scope.SQLDB().ExecContext(scope.Context(), scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
// set rows affected count
scope.db.RowsAffected, _ = result.RowsAffected()
@@ -162,7 +162,7 @@ func createCallback(scope *Scope) {
// execute create sql: dialects with additional lastInsertID requirements (currently postgres & mssql)
if primaryField.Field.CanAddr() {
- if err := scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil {
+ if err := scope.SQLDB().QueryRowContext(scope.Context(), scope.SQL, scope.SQLVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil {
primaryField.IsBlank = false
scope.db.RowsAffected = 1
}
diff --git a/callback_query.go b/callback_query.go
index f75627152..b159f2d8c 100644
--- a/callback_query.go
+++ b/callback_query.go
@@ -65,7 +65,7 @@ func queryCallback(scope *Scope) {
scope.SQL = fmt.Sprint(str) + scope.SQL
}
- if rows, err := scope.SQLDB().Query(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
+ if rows, err := scope.SQLDB().QueryContext(scope.Context(), scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
defer rows.Close()
columns, _ := rows.Columns()
diff --git a/callback_row_query.go b/callback_row_query.go
index 43b21f83f..d383e1129 100644
--- a/callback_row_query.go
+++ b/callback_row_query.go
@@ -29,9 +29,9 @@ func rowQueryCallback(scope *Scope) {
}
if rowResult, ok := result.(*RowQueryResult); ok {
- rowResult.Row = scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...)
+ rowResult.Row = scope.SQLDB().QueryRowContext(scope.Context(), scope.SQL, scope.SQLVars...)
} else if rowsResult, ok := result.(*RowsQueryResult); ok {
- rowsResult.Rows, rowsResult.Error = scope.SQLDB().Query(scope.SQL, scope.SQLVars...)
+ rowsResult.Rows, rowsResult.Error = scope.SQLDB().QueryContext(scope.Context(), scope.SQL, scope.SQLVars...)
}
}
}
diff --git a/customize_column_test.go b/customize_column_test.go
index c236ac24e..da65406e7 100644
--- a/customize_column_test.go
+++ b/customize_column_test.go
@@ -26,7 +26,7 @@ func TestCustomizeColumn(t *testing.T) {
DB.AutoMigrate(&CustomizeColumn{})
scope := DB.NewScope(&CustomizeColumn{})
- if !scope.Dialect().HasColumn(scope.TableName(), col) {
+ if !scope.Dialect().HasColumn(scope.Context(), scope.TableName(), col) {
t.Errorf("CustomizeColumn should have column %s", col)
}
diff --git a/dialect.go b/dialect.go
index c742efcd2..0d5408102 100644
--- a/dialect.go
+++ b/dialect.go
@@ -1,6 +1,7 @@
package gorm
import (
+ "context"
"database/sql"
"fmt"
"reflect"
@@ -24,17 +25,17 @@ type Dialect interface {
DataTypeOf(field *StructField) string
// HasIndex check has index or not
- HasIndex(tableName string, indexName string) bool
+ HasIndex(ctx context.Context, tableName string, indexName string) bool
// HasForeignKey check has foreign key or not
- HasForeignKey(tableName string, foreignKeyName string) bool
+ HasForeignKey(ctx context.Context, tableName string, foreignKeyName string) bool
// RemoveIndex remove index
- RemoveIndex(tableName string, indexName string) error
+ RemoveIndex(ctx context.Context, tableName string, indexName string) error
// HasTable check has table or not
- HasTable(tableName string) bool
+ HasTable(ctx context.Context, tableName string) bool
// HasColumn check has column or not
- HasColumn(tableName string, columnName string) bool
+ HasColumn(ctx context.Context, tableName string, columnName string) bool
// ModifyColumn modify column's type
- ModifyColumn(tableName string, columnName string, typ string) error
+ ModifyColumn(ctx context.Context, tableName string, columnName string, typ string) error
// LimitAndOffsetSQL return generated SQL with Limit and Offset, as mssql has special case
LimitAndOffsetSQL(limit, offset interface{}) (string, error)
@@ -54,7 +55,7 @@ type Dialect interface {
NormalizeIndexAndColumn(indexName, columnName string) (string, string)
// CurrentDatabase return current database name
- CurrentDatabase() string
+ CurrentDatabase(ctx context.Context) string
}
var dialectsMap = map[string]Dialect{}
@@ -67,9 +68,9 @@ func newDialect(name string, db SQLCommon) Dialect {
}
fmt.Printf("`%v` is not officially supported, running under compatibility mode.\n", name)
- commontDialect := &commonDialect{}
- commontDialect.SetDB(db)
- return commontDialect
+ cd := &commonDialect{}
+ cd.SetDB(db)
+ return cd
}
// RegisterDialect register new dialect
@@ -138,10 +139,10 @@ var ParseFieldStructForDialect = func(field *StructField, dialect Dialect) (fiel
return fieldValue, dataType, size, strings.TrimSpace(additionalType)
}
-func currentDatabaseAndTable(dialect Dialect, tableName string) (string, string) {
+func currentDatabaseAndTable(ctx context.Context, dialect Dialect, tableName string) (string, string) {
if strings.Contains(tableName, ".") {
splitStrings := strings.SplitN(tableName, ".", 2)
return splitStrings[0], splitStrings[1]
}
- return dialect.CurrentDatabase(), tableName
+ return dialect.CurrentDatabase(ctx), tableName
}
diff --git a/dialect_common.go b/dialect_common.go
index d549510cc..f331a0b27 100644
--- a/dialect_common.go
+++ b/dialect_common.go
@@ -1,6 +1,7 @@
package gorm
import (
+ "context"
"fmt"
"reflect"
"regexp"
@@ -99,43 +100,43 @@ func (s *commonDialect) DataTypeOf(field *StructField) string {
return fmt.Sprintf("%v %v", sqlType, additionalType)
}
-func (s commonDialect) HasIndex(tableName string, indexName string) bool {
+func (s commonDialect) HasIndex(ctx context.Context, tableName string, indexName string) bool {
var count int
- currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
- s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", currentDatabase, tableName, indexName).Scan(&count)
+ currentDatabase, tableName := currentDatabaseAndTable(ctx, &s, tableName)
+ s.db.QueryRowContext(ctx, "SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", currentDatabase, tableName, indexName).Scan(&count)
return count > 0
}
-func (s commonDialect) RemoveIndex(tableName string, indexName string) error {
- _, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v", indexName))
+func (s commonDialect) RemoveIndex(ctx context.Context, tableName string, indexName string) error {
+ _, err := s.db.ExecContext(ctx, fmt.Sprintf("DROP INDEX %v", indexName))
return err
}
-func (s commonDialect) HasForeignKey(tableName string, foreignKeyName string) bool {
+func (s commonDialect) HasForeignKey(ctx context.Context, tableName string, foreignKeyName string) bool {
return false
}
-func (s commonDialect) HasTable(tableName string) bool {
+func (s commonDialect) HasTable(ctx context.Context, tableName string) bool {
var count int
- currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
- s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", currentDatabase, tableName).Scan(&count)
+ currentDatabase, tableName := currentDatabaseAndTable(ctx, &s, tableName)
+ s.db.QueryRowContext(ctx, "SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", currentDatabase, tableName).Scan(&count)
return count > 0
}
-func (s commonDialect) HasColumn(tableName string, columnName string) bool {
+func (s commonDialect) HasColumn(ctx context.Context, tableName string, columnName string) bool {
var count int
- currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
- s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", currentDatabase, tableName, columnName).Scan(&count)
+ currentDatabase, tableName := currentDatabaseAndTable(ctx, &s, tableName)
+ s.db.QueryRowContext(ctx, "SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", currentDatabase, tableName, columnName).Scan(&count)
return count > 0
}
-func (s commonDialect) ModifyColumn(tableName string, columnName string, typ string) error {
- _, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v TYPE %v", tableName, columnName, typ))
+func (s commonDialect) ModifyColumn(ctx context.Context, tableName string, columnName string, typ string) error {
+ _, err := s.db.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v TYPE %v", tableName, columnName, typ))
return err
}
-func (s commonDialect) CurrentDatabase() (name string) {
- s.db.QueryRow("SELECT DATABASE()").Scan(&name)
+func (s commonDialect) CurrentDatabase(ctx context.Context) (name string) {
+ s.db.QueryRowContext(ctx, "SELECT DATABASE()").Scan(&name)
return
}
diff --git a/dialect_mysql.go b/dialect_mysql.go
index b4467ffa1..e8c27f995 100644
--- a/dialect_mysql.go
+++ b/dialect_mysql.go
@@ -1,6 +1,7 @@
package gorm
import (
+ "context"
"crypto/sha1"
"database/sql"
"fmt"
@@ -129,13 +130,13 @@ func (s *mysql) DataTypeOf(field *StructField) string {
return fmt.Sprintf("%v %v", sqlType, additionalType)
}
-func (s mysql) RemoveIndex(tableName string, indexName string) error {
- _, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName)))
+func (s mysql) RemoveIndex(ctx context.Context, tableName string, indexName string) error {
+ _, err := s.db.ExecContext(ctx, fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName)))
return err
}
-func (s mysql) ModifyColumn(tableName string, columnName string, typ string) error {
- _, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v MODIFY COLUMN %v %v", tableName, columnName, typ))
+func (s mysql) ModifyColumn(ctx context.Context, tableName string, columnName string, typ string) error {
+ _, err := s.db.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %v MODIFY COLUMN %v %v", tableName, columnName, typ))
return err
}
@@ -162,18 +163,18 @@ func (s mysql) LimitAndOffsetSQL(limit, offset interface{}) (sql string, err err
return
}
-func (s mysql) HasForeignKey(tableName string, foreignKeyName string) bool {
+func (s mysql) HasForeignKey(ctx context.Context, tableName string, foreignKeyName string) bool {
var count int
- currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
- s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE CONSTRAINT_SCHEMA=? AND TABLE_NAME=? AND CONSTRAINT_NAME=? AND CONSTRAINT_TYPE='FOREIGN KEY'", currentDatabase, tableName, foreignKeyName).Scan(&count)
+ currentDatabase, tableName := currentDatabaseAndTable(ctx, &s, tableName)
+ s.db.QueryRowContext(ctx, "SELECT count(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE CONSTRAINT_SCHEMA=? AND TABLE_NAME=? AND CONSTRAINT_NAME=? AND CONSTRAINT_TYPE='FOREIGN KEY'", currentDatabase, tableName, foreignKeyName).Scan(&count)
return count > 0
}
-func (s mysql) HasTable(tableName string) bool {
- currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
+func (s mysql) HasTable(ctx context.Context, tableName string) bool {
+ currentDatabase, tableName := currentDatabaseAndTable(ctx, &s, tableName)
var name string
// allow mysql database name with '-' character
- if err := s.db.QueryRow(fmt.Sprintf("SHOW TABLES FROM `%s` WHERE `Tables_in_%s` = ?", currentDatabase, currentDatabase), tableName).Scan(&name); err != nil {
+ if err := s.db.QueryRowContext(ctx, fmt.Sprintf("SHOW TABLES FROM `%s` WHERE `Tables_in_%s` = ?", currentDatabase, currentDatabase), tableName).Scan(&name); err != nil {
if err == sql.ErrNoRows {
return false
}
@@ -183,9 +184,9 @@ func (s mysql) HasTable(tableName string) bool {
}
}
-func (s mysql) HasIndex(tableName string, indexName string) bool {
- currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
- if rows, err := s.db.Query(fmt.Sprintf("SHOW INDEXES FROM `%s` FROM `%s` WHERE Key_name = ?", tableName, currentDatabase), indexName); err != nil {
+func (s mysql) HasIndex(ctx context.Context, tableName string, indexName string) bool {
+ currentDatabase, tableName := currentDatabaseAndTable(ctx, &s, tableName)
+ if rows, err := s.db.QueryContext(ctx, fmt.Sprintf("SHOW INDEXES FROM `%s` FROM `%s` WHERE Key_name = ?", tableName, currentDatabase), indexName); err != nil {
panic(err)
} else {
defer rows.Close()
@@ -193,9 +194,9 @@ func (s mysql) HasIndex(tableName string, indexName string) bool {
}
}
-func (s mysql) HasColumn(tableName string, columnName string) bool {
- currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
- if rows, err := s.db.Query(fmt.Sprintf("SHOW COLUMNS FROM `%s` FROM `%s` WHERE Field = ?", tableName, currentDatabase), columnName); err != nil {
+func (s mysql) HasColumn(ctx context.Context, tableName string, columnName string) bool {
+ currentDatabase, tableName := currentDatabaseAndTable(ctx, &s, tableName)
+ if rows, err := s.db.QueryContext(ctx, fmt.Sprintf("SHOW COLUMNS FROM `%s` FROM `%s` WHERE Field = ?", tableName, currentDatabase), columnName); err != nil {
panic(err)
} else {
defer rows.Close()
@@ -203,8 +204,8 @@ func (s mysql) HasColumn(tableName string, columnName string) bool {
}
}
-func (s mysql) CurrentDatabase() (name string) {
- s.db.QueryRow("SELECT DATABASE()").Scan(&name)
+func (s mysql) CurrentDatabase(ctx context.Context) (name string) {
+ s.db.QueryRowContext(ctx, "SELECT DATABASE()").Scan(&name)
return
}
diff --git a/dialect_postgres.go b/dialect_postgres.go
index d2df31318..89571fb30 100644
--- a/dialect_postgres.go
+++ b/dialect_postgres.go
@@ -1,6 +1,7 @@
package gorm
import (
+ "context"
"encoding/json"
"fmt"
"reflect"
@@ -91,32 +92,32 @@ func (s *postgres) DataTypeOf(field *StructField) string {
return fmt.Sprintf("%v %v", sqlType, additionalType)
}
-func (s postgres) HasIndex(tableName string, indexName string) bool {
+func (s postgres) HasIndex(ctx context.Context, tableName string, indexName string) bool {
var count int
- s.db.QueryRow("SELECT count(*) FROM pg_indexes WHERE tablename = $1 AND indexname = $2 AND schemaname = CURRENT_SCHEMA()", tableName, indexName).Scan(&count)
+ s.db.QueryRowContext(ctx, "SELECT count(*) FROM pg_indexes WHERE tablename = $1 AND indexname = $2 AND schemaname = CURRENT_SCHEMA()", tableName, indexName).Scan(&count)
return count > 0
}
-func (s postgres) HasForeignKey(tableName string, foreignKeyName string) bool {
+func (s postgres) HasForeignKey(ctx context.Context, tableName string, foreignKeyName string) bool {
var count int
- s.db.QueryRow("SELECT count(con.conname) FROM pg_constraint con WHERE $1::regclass::oid = con.conrelid AND con.conname = $2 AND con.contype='f'", tableName, foreignKeyName).Scan(&count)
+ s.db.QueryRowContext(ctx, "SELECT count(con.conname) FROM pg_constraint con WHERE $1::regclass::oid = con.conrelid AND con.conname = $2 AND con.contype='f'", tableName, foreignKeyName).Scan(&count)
return count > 0
}
-func (s postgres) HasTable(tableName string) bool {
+func (s postgres) HasTable(ctx context.Context, tableName string) bool {
var count int
- s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = $1 AND table_type = 'BASE TABLE' AND table_schema = CURRENT_SCHEMA()", tableName).Scan(&count)
+ s.db.QueryRowContext(ctx, "SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = $1 AND table_type = 'BASE TABLE' AND table_schema = CURRENT_SCHEMA()", tableName).Scan(&count)
return count > 0
}
-func (s postgres) HasColumn(tableName string, columnName string) bool {
+func (s postgres) HasColumn(ctx context.Context, tableName string, columnName string) bool {
var count int
- s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = $1 AND column_name = $2 AND table_schema = CURRENT_SCHEMA()", tableName, columnName).Scan(&count)
+ s.db.QueryRowContext(ctx, "SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = $1 AND column_name = $2 AND table_schema = CURRENT_SCHEMA()", tableName, columnName).Scan(&count)
return count > 0
}
-func (s postgres) CurrentDatabase() (name string) {
- s.db.QueryRow("SELECT CURRENT_DATABASE()").Scan(&name)
+func (s postgres) CurrentDatabase(ctx context.Context) (name string) {
+ s.db.QueryRowContext(ctx, "SELECT CURRENT_DATABASE()").Scan(&name)
return
}
diff --git a/dialect_sqlite3.go b/dialect_sqlite3.go
index 5f96c363a..6db7191f4 100644
--- a/dialect_sqlite3.go
+++ b/dialect_sqlite3.go
@@ -1,6 +1,7 @@
package gorm
import (
+ "context"
"fmt"
"reflect"
"strings"
@@ -70,25 +71,25 @@ func (s *sqlite3) DataTypeOf(field *StructField) string {
return fmt.Sprintf("%v %v", sqlType, additionalType)
}
-func (s sqlite3) HasIndex(tableName string, indexName string) bool {
+func (s sqlite3) HasIndex(ctx context.Context, tableName string, indexName string) bool {
var count int
- s.db.QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE '%%INDEX %v ON%%'", indexName), tableName).Scan(&count)
+ s.db.QueryRowContext(ctx, fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE '%%INDEX %v ON%%'", indexName), tableName).Scan(&count)
return count > 0
}
-func (s sqlite3) HasTable(tableName string) bool {
+func (s sqlite3) HasTable(ctx context.Context, tableName string) bool {
var count int
- s.db.QueryRow("SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", tableName).Scan(&count)
+ s.db.QueryRowContext(ctx, "SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", tableName).Scan(&count)
return count > 0
}
-func (s sqlite3) HasColumn(tableName string, columnName string) bool {
+func (s sqlite3) HasColumn(ctx context.Context, tableName string, columnName string) bool {
var count int
- s.db.QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE '%%\"%v\" %%' OR sql LIKE '%%%v %%');\n", columnName, columnName), tableName).Scan(&count)
+ s.db.QueryRowContext(ctx, fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE '%%\"%v\" %%' OR sql LIKE '%%%v %%');\n", columnName, columnName), tableName).Scan(&count)
return count > 0
}
-func (s sqlite3) CurrentDatabase() (name string) {
+func (s sqlite3) CurrentDatabase(ctx context.Context) (name string) {
var (
ifaces = make([]interface{}, 3)
pointers = make([]*string, 3)
@@ -97,7 +98,7 @@ func (s sqlite3) CurrentDatabase() (name string) {
for i = 0; i < 3; i++ {
ifaces[i] = &pointers[i]
}
- if err := s.db.QueryRow("PRAGMA database_list").Scan(ifaces...); err != nil {
+ if err := s.db.QueryRowContext(ctx, "PRAGMA database_list").Scan(ifaces...); err != nil {
return
}
if pointers[1] != nil {
diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go
index a516ed4af..98bb2f590 100644
--- a/dialects/mssql/mssql.go
+++ b/dialects/mssql/mssql.go
@@ -1,6 +1,7 @@
package mssql
import (
+ "context"
"database/sql/driver"
"encoding/json"
"errors"
@@ -12,6 +13,7 @@ import (
// Importing mssql driver package only in dialect file, otherwide not needed
_ "github.com/denisenkom/go-mssqldb"
+
"github.com/jinzhu/gorm"
)
@@ -122,21 +124,21 @@ func (s mssql) fieldCanAutoIncrement(field *gorm.StructField) bool {
return field.IsPrimaryKey
}
-func (s mssql) HasIndex(tableName string, indexName string) bool {
+func (s mssql) HasIndex(ctx context.Context, tableName string, indexName string) bool {
var count int
- s.db.QueryRow("SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName).Scan(&count)
+ s.db.QueryRowContext(ctx, "SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName).Scan(&count)
return count > 0
}
-func (s mssql) RemoveIndex(tableName string, indexName string) error {
- _, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName)))
+func (s mssql) RemoveIndex(ctx context.Context, tableName string, indexName string) error {
+ _, err := s.db.ExecContext(ctx, fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName)))
return err
}
-func (s mssql) HasForeignKey(tableName string, foreignKeyName string) bool {
+func (s mssql) HasForeignKey(ctx context.Context, tableName string, foreignKeyName string) bool {
var count int
- currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
- s.db.QueryRow(`SELECT count(*)
+ currentDatabase, tableName := currentDatabaseAndTable(ctx, &s, tableName)
+ s.db.QueryRowContext(ctx, `SELECT count(*)
FROM sys.foreign_keys as F inner join sys.tables as T on F.parent_object_id=T.object_id
inner join information_schema.tables as I on I.TABLE_NAME = T.name
WHERE F.name = ?
@@ -144,27 +146,27 @@ func (s mssql) HasForeignKey(tableName string, foreignKeyName string) bool {
return count > 0
}
-func (s mssql) HasTable(tableName string) bool {
+func (s mssql) HasTable(ctx context.Context, tableName string) bool {
var count int
- currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
- s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, currentDatabase).Scan(&count)
+ currentDatabase, tableName := currentDatabaseAndTable(ctx, &s, tableName)
+ s.db.QueryRowContext(ctx, "SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, currentDatabase).Scan(&count)
return count > 0
}
-func (s mssql) HasColumn(tableName string, columnName string) bool {
+func (s mssql) HasColumn(ctx context.Context, tableName string, columnName string) bool {
var count int
- currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
- s.db.QueryRow("SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", currentDatabase, tableName, columnName).Scan(&count)
+ currentDatabase, tableName := currentDatabaseAndTable(ctx, &s, tableName)
+ s.db.QueryRowContext(ctx, "SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", currentDatabase, tableName, columnName).Scan(&count)
return count > 0
}
-func (s mssql) ModifyColumn(tableName string, columnName string, typ string) error {
- _, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v %v", tableName, columnName, typ))
+func (s mssql) ModifyColumn(ctx context.Context, tableName string, columnName string, typ string) error {
+ _, err := s.db.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v %v", tableName, columnName, typ))
return err
}
-func (s mssql) CurrentDatabase() (name string) {
- s.db.QueryRow("SELECT DB_NAME() AS [Current Database]").Scan(&name)
+func (s mssql) CurrentDatabase(ctx context.Context) (name string) {
+ s.db.QueryRowContext(ctx, "SELECT DB_NAME() AS [Current Database]").Scan(&name)
return
}
@@ -220,12 +222,12 @@ func (mssql) NormalizeIndexAndColumn(indexName, columnName string) (string, stri
return indexName, columnName
}
-func currentDatabaseAndTable(dialect gorm.Dialect, tableName string) (string, string) {
+func currentDatabaseAndTable(ctx context.Context, dialect gorm.Dialect, tableName string) (string, string) {
if strings.Contains(tableName, ".") {
splitStrings := strings.SplitN(tableName, ".", 2)
return splitStrings[0], splitStrings[1]
}
- return dialect.CurrentDatabase(), tableName
+ return dialect.CurrentDatabase(ctx), tableName
}
// JSON type to support easy handling of JSON data in character table fields
diff --git a/embedded_struct_test.go b/embedded_struct_test.go
index 5f8ece573..b0c8f7ef6 100644
--- a/embedded_struct_test.go
+++ b/embedded_struct_test.go
@@ -29,7 +29,7 @@ type EngadgetPost struct {
func TestPrefixColumnNameForEmbeddedStruct(t *testing.T) {
dialect := DB.NewScope(&EngadgetPost{}).Dialect()
engadgetPostScope := DB.NewScope(&EngadgetPost{})
- if !dialect.HasColumn(engadgetPostScope.TableName(), "author_id") || !dialect.HasColumn(engadgetPostScope.TableName(), "author_name") || !dialect.HasColumn(engadgetPostScope.TableName(), "author_email") {
+ if !dialect.HasColumn(engadgetPostScope.Context(), engadgetPostScope.TableName(), "author_id") || !dialect.HasColumn(engadgetPostScope.Context(), engadgetPostScope.TableName(), "author_name") || !dialect.HasColumn(engadgetPostScope.Context(), engadgetPostScope.TableName(), "author_email") {
t.Errorf("should has prefix for embedded columns")
}
@@ -38,7 +38,7 @@ func TestPrefixColumnNameForEmbeddedStruct(t *testing.T) {
}
hnScope := DB.NewScope(&HNPost{})
- if !dialect.HasColumn(hnScope.TableName(), "user_id") || !dialect.HasColumn(hnScope.TableName(), "user_name") || !dialect.HasColumn(hnScope.TableName(), "user_email") {
+ if !dialect.HasColumn(hnScope.Context(), hnScope.TableName(), "user_id") || !dialect.HasColumn(hnScope.Context(), hnScope.TableName(), "user_name") || !dialect.HasColumn(hnScope.Context(), hnScope.TableName(), "user_email") {
t.Errorf("should has prefix for embedded columns")
}
}
diff --git a/go.sum b/go.sum
index d30630b95..b7956d8a6 100644
--- a/go.sum
+++ b/go.sum
@@ -16,10 +16,7 @@ github.com/lib/pq v1.1.1 h1:sJZmqHoEaY7f+NPP8pgLB/WxulyR3fewgCM2qaSlBb4=
github.com/lib/pq v1.1.1/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
github.com/mattn/go-sqlite3 v1.14.0 h1:mLyGNKR8+Vv9CAU7PphKa2hkEqxxhn8i32J6FPj1/QA=
github.com/mattn/go-sqlite3 v1.14.0/go.mod h1:JIl7NbARA7phWnGvh0LKTyg7S9BA+6gx71ShQilpsus=
-github.com/mattn/go-sqlite3 v2.0.1+incompatible h1:xQ15muvnzGBHpIpdrNi1DA5x0+TcBZzsIDwmw9uTHzw=
-github.com/mattn/go-sqlite3 v2.0.1+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
-golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c h1:Vj5n4GlwjmQteupaxJ9+0FNOmBrHfq7vN4btdGoDZgI=
golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191205180655-e7c4368fe9dd h1:GGJVjV8waZKRHrgwvtH66z9ZGVurTD1MT0n1Bb+q4aM=
golang.org/x/crypto v0.0.0-20191205180655-e7c4368fe9dd/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
diff --git a/interface.go b/interface.go
index fe6492314..8f119521d 100644
--- a/interface.go
+++ b/interface.go
@@ -7,10 +7,10 @@ import (
// SQLCommon is the minimal database connection functionality gorm requires. Implemented by *sql.DB.
type SQLCommon interface {
- Exec(query string, args ...interface{}) (sql.Result, error)
- Prepare(query string) (*sql.Stmt, error)
- Query(query string, args ...interface{}) (*sql.Rows, error)
- QueryRow(query string, args ...interface{}) *sql.Row
+ ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
+ PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
+ QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
+ QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
}
type sqlDb interface {
diff --git a/main.go b/main.go
index 0c247b6b9..2ef94e16b 100644
--- a/main.go
+++ b/main.go
@@ -20,6 +20,7 @@ type DB struct {
// single db
db SQLCommon
+ ctx context.Context
blockGlobalUpdate bool
logMode logModeValue
logger logger
@@ -83,9 +84,10 @@ func Open(dialect string, args ...interface{}) (db *DB, err error) {
}
db = &DB{
- db: dbSQL,
- logger: defaultLogger,
-
+ db: dbSQL,
+ ctx: context.Background(),
+ logger: defaultLogger,
+
// Create a clone of the default logger to avoid mutating a shared object when
// multiple gorm connections are created simultaneously.
callbacks: DefaultCallback.clone(defaultLogger),
@@ -233,6 +235,14 @@ func (s *DB) SubQuery() *SqlExpr {
return Expr(fmt.Sprintf("(%v)", scope.SQL), scope.SQLVars...)
}
+// WithContext returns the DB with the supplied context attached
+func (s *DB) WithContext(ctx context.Context) *DB {
+ dbClone := s.clone()
+ dbClone.ctx = ctx
+
+ return dbClone
+}
+
// Where return a new relation, filter records with given conditions, accepts `map`, `struct` or `string` as conditions, refer http://jinzhu.github.io/gorm/crud.html#query
func (s *DB) Where(query interface{}, args ...interface{}) *DB {
return s.clone().search.Where(query, args...).db
@@ -680,7 +690,7 @@ func (s *DB) HasTable(value interface{}) bool {
tableName = scope.TableName()
}
- has := scope.Dialect().HasTable(tableName)
+ has := scope.Dialect().HasTable(s.ctx, tableName)
s.AddError(scope.db.Error)
return has
}
@@ -800,7 +810,7 @@ func (s *DB) SetJoinTableHandler(source interface{}, column string, handler Join
destination := (&Scope{Value: reflect.New(field.Struct.Type).Interface()}).GetModelStruct().ModelType
handler.Setup(field.Relationship, many2many, source, destination)
field.Relationship.JoinTableHandler = handler
- if table := handler.Table(s); scope.Dialect().HasTable(table) {
+ if table := handler.Table(s); scope.Dialect().HasTable(s.ctx, table) {
s.Table(table).AutoMigrate(handler)
}
}
@@ -847,6 +857,7 @@ func (s *DB) GetErrors() []error {
func (s *DB) clone() *DB {
db := &DB{
db: s.db,
+ ctx: s.ctx,
parent: s.parent,
logger: s.logger,
logMode: s.logMode,
diff --git a/migration_test.go b/migration_test.go
index 063c6f648..be459b852 100644
--- a/migration_test.go
+++ b/migration_test.go
@@ -284,8 +284,8 @@ func getPreparedUser(name string, role string) *User {
}
type Panda struct {
- Number int64 `gorm:"unique_index:number"`
- Name string `gorm:"column:name;type:varchar(255);default:null"`
+ Number int64 `gorm:"unique_index:number"`
+ Name string `gorm:"column:name;type:varchar(255);default:null"`
}
func runMigration() {
@@ -312,7 +312,7 @@ func TestIndexes(t *testing.T) {
}
scope := DB.NewScope(&Email{})
- if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email") {
+ if !scope.Dialect().HasIndex(scope.Context(), scope.TableName(), "idx_email_email") {
t.Errorf("Email should have index idx_email_email")
}
@@ -320,7 +320,7 @@ func TestIndexes(t *testing.T) {
t.Errorf("Got error when tried to remove index: %+v", err)
}
- if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email") {
+ if scope.Dialect().HasIndex(scope.Context(), scope.TableName(), "idx_email_email") {
t.Errorf("Email's index idx_email_email should be deleted")
}
@@ -328,7 +328,7 @@ func TestIndexes(t *testing.T) {
t.Errorf("Got error when tried to create index: %+v", err)
}
- if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") {
+ if !scope.Dialect().HasIndex(scope.Context(), scope.TableName(), "idx_email_email_and_user_id") {
t.Errorf("Email should have index idx_email_email_and_user_id")
}
@@ -336,7 +336,7 @@ func TestIndexes(t *testing.T) {
t.Errorf("Got error when tried to remove index: %+v", err)
}
- if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") {
+ if scope.Dialect().HasIndex(scope.Context(), scope.TableName(), "idx_email_email_and_user_id") {
t.Errorf("Email's index idx_email_email_and_user_id should be deleted")
}
@@ -344,7 +344,7 @@ func TestIndexes(t *testing.T) {
t.Errorf("Got error when tried to create index: %+v", err)
}
- if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") {
+ if !scope.Dialect().HasIndex(scope.Context(), scope.TableName(), "idx_email_email_and_user_id") {
t.Errorf("Email should have index idx_email_email_and_user_id")
}
@@ -366,7 +366,7 @@ func TestIndexes(t *testing.T) {
t.Errorf("Got error when tried to remove index: %+v", err)
}
- if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") {
+ if scope.Dialect().HasIndex(scope.Context(), scope.TableName(), "idx_email_email_and_user_id") {
t.Errorf("Email's index idx_email_email_and_user_id should be deleted")
}
@@ -396,11 +396,11 @@ func TestAutoMigration(t *testing.T) {
DB.Save(&EmailWithIdx{Email: "jinzhu@example.org", UserAgent: "pc", RegisteredAt: &now})
scope := DB.NewScope(&EmailWithIdx{})
- if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_agent") {
+ if !scope.Dialect().HasIndex(scope.Context(), scope.TableName(), "idx_email_agent") {
t.Errorf("Failed to create index")
}
- if !scope.Dialect().HasIndex(scope.TableName(), "uix_email_with_idxes_registered_at") {
+ if !scope.Dialect().HasIndex(scope.Context(), scope.TableName(), "uix_email_with_idxes_registered_at") {
t.Errorf("Failed to create index")
}
@@ -479,23 +479,23 @@ func TestMultipleIndexes(t *testing.T) {
DB.Save(&MultipleIndexes{UserID: 1, Name: "jinzhu", Email: "jinzhu@example.org", Other: "foo"})
scope := DB.NewScope(&MultipleIndexes{})
- if !scope.Dialect().HasIndex(scope.TableName(), "uix_multipleindexes_user_name") {
+ if !scope.Dialect().HasIndex(scope.Context(), scope.TableName(), "uix_multipleindexes_user_name") {
t.Errorf("Failed to create index")
}
- if !scope.Dialect().HasIndex(scope.TableName(), "uix_multipleindexes_user_email") {
+ if !scope.Dialect().HasIndex(scope.Context(), scope.TableName(), "uix_multipleindexes_user_email") {
t.Errorf("Failed to create index")
}
- if !scope.Dialect().HasIndex(scope.TableName(), "uix_multiple_indexes_email") {
+ if !scope.Dialect().HasIndex(scope.Context(), scope.TableName(), "uix_multiple_indexes_email") {
t.Errorf("Failed to create index")
}
- if !scope.Dialect().HasIndex(scope.TableName(), "idx_multipleindexes_user_other") {
+ if !scope.Dialect().HasIndex(scope.Context(), scope.TableName(), "idx_multipleindexes_user_other") {
t.Errorf("Failed to create index")
}
- if !scope.Dialect().HasIndex(scope.TableName(), "idx_multiple_indexes_other") {
+ if !scope.Dialect().HasIndex(scope.Context(), scope.TableName(), "idx_multiple_indexes_other") {
t.Errorf("Failed to create index")
}
@@ -576,7 +576,7 @@ func TestIndexWithPrefixLength(t *testing.T) {
if err := DB.CreateTable(table).Error; err != nil {
t.Errorf("Failed to create %s table: %v", tableName, err)
}
- if !scope.Dialect().HasIndex(tableName, "idx_index_with_prefixes_length") {
+ if !scope.Dialect().HasIndex(scope.Context(), tableName, "idx_index_with_prefixes_length") {
t.Errorf("Failed to create %s table index:", tableName)
}
})
diff --git a/scope.go b/scope.go
index ea12ee2f1..e48dfd092 100644
--- a/scope.go
+++ b/scope.go
@@ -2,6 +2,7 @@ package gorm
import (
"bytes"
+ "context"
"database/sql"
"database/sql/driver"
"errors"
@@ -66,6 +67,10 @@ func (scope *Scope) Dialect() Dialect {
return scope.db.dialect
}
+func (scope *Scope) Context() context.Context {
+ return scope.db.ctx
+}
+
// Quote used to quote string to escape them for database
func (scope *Scope) Quote(str string) string {
if strings.Contains(str, ".") {
@@ -361,7 +366,7 @@ func (scope *Scope) Exec() *Scope {
defer scope.trace(NowFunc())
if !scope.HasError() {
- if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
+ if result, err := scope.SQLDB().ExecContext(scope.Context(), scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
if count, err := result.RowsAffected(); scope.Err(err) == nil {
scope.db.RowsAffected = count
}
@@ -1136,7 +1141,7 @@ func (scope *Scope) createJoinTable(field *StructField) {
if relationship := field.Relationship; relationship != nil && relationship.JoinTableHandler != nil {
joinTableHandler := relationship.JoinTableHandler
joinTable := joinTableHandler.Table(scope.db)
- if !scope.Dialect().HasTable(joinTable) {
+ if !scope.Dialect().HasTable(scope.Context(), joinTable) {
toScope := &Scope{Value: reflect.New(field.Struct.Type).Interface()}
var sqlTypes, primaryKeys []string
@@ -1209,7 +1214,7 @@ func (scope *Scope) dropTable() *Scope {
}
func (scope *Scope) modifyColumn(column string, typ string) {
- scope.db.AddError(scope.Dialect().ModifyColumn(scope.QuotedTableName(), scope.Quote(column), typ))
+ scope.db.AddError(scope.Dialect().ModifyColumn(scope.Context(), scope.QuotedTableName(), scope.Quote(column), typ))
}
func (scope *Scope) dropColumn(column string) {
@@ -1217,7 +1222,7 @@ func (scope *Scope) dropColumn(column string) {
}
func (scope *Scope) addIndex(unique bool, indexName string, column ...string) {
- if scope.Dialect().HasIndex(scope.TableName(), indexName) {
+ if scope.Dialect().HasIndex(scope.Context(), scope.TableName(), indexName) {
return
}
@@ -1238,7 +1243,7 @@ func (scope *Scope) addForeignKey(field string, dest string, onDelete string, on
// Compatible with old generated key
keyName := scope.Dialect().BuildKeyName(scope.TableName(), field, dest, "foreign")
- if scope.Dialect().HasForeignKey(scope.TableName(), keyName) {
+ if scope.Dialect().HasForeignKey(scope.Context(), scope.TableName(), keyName) {
return
}
var query = `ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s ON DELETE %s ON UPDATE %s;`
@@ -1247,7 +1252,7 @@ func (scope *Scope) addForeignKey(field string, dest string, onDelete string, on
func (scope *Scope) removeForeignKey(field string, dest string) {
keyName := scope.Dialect().BuildKeyName(scope.TableName(), field, dest, "foreign")
- if !scope.Dialect().HasForeignKey(scope.TableName(), keyName) {
+ if !scope.Dialect().HasForeignKey(scope.Context(), scope.TableName(), keyName) {
return
}
var mysql mysql
@@ -1262,18 +1267,18 @@ func (scope *Scope) removeForeignKey(field string, dest string) {
}
func (scope *Scope) removeIndex(indexName string) {
- scope.Dialect().RemoveIndex(scope.TableName(), indexName)
+ scope.Dialect().RemoveIndex(scope.Context(), scope.TableName(), indexName)
}
func (scope *Scope) autoMigrate() *Scope {
tableName := scope.TableName()
quotedTableName := scope.QuotedTableName()
- if !scope.Dialect().HasTable(tableName) {
+ if !scope.Dialect().HasTable(scope.Context(), tableName) {
scope.createTable()
} else {
for _, field := range scope.GetModelStruct().StructFields {
- if !scope.Dialect().HasColumn(tableName, field.DBName) {
+ if !scope.Dialect().HasColumn(scope.Context(), tableName, field.DBName) {
if field.IsNormal {
sqlTag := scope.Dialect().DataTypeOf(field)
scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, scope.Quote(field.DBName), sqlTag)).Exec()