Skip to content

Commit

Permalink
Merge branch 'master' of https://github.com/projectleo/gorm
Browse files Browse the repository at this point in the history
  • Loading branch information
asahasrabuddhe committed Jul 27, 2022
2 parents 2c91e85 + cbcbf63 commit 603f1c3
Show file tree
Hide file tree
Showing 21 changed files with 189 additions and 129 deletions.
8 changes: 8 additions & 0 deletions .idea/.gitignore

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 9 additions & 0 deletions .idea/gorm.iml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 9 additions & 0 deletions .idea/markdown.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 8 additions & 0 deletions .idea/modules.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions .idea/vcs.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions callback_create.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ func createCallback(scope *Scope) {

// execute create sql: no primaryField
if primaryField == nil {
if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
if result, err := scope.SQLDB().ExecContext(scope.Context(), scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
// set rows affected count
scope.db.RowsAffected, _ = result.RowsAffected()

Expand All @@ -146,7 +146,7 @@ func createCallback(scope *Scope) {

// execute create sql: lastInsertID implemention for majority of dialects
if lastInsertIDReturningSuffix == "" && lastInsertIDOutputInterstitial == "" {
if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
if result, err := scope.SQLDB().ExecContext(scope.Context(), scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
// set rows affected count
scope.db.RowsAffected, _ = result.RowsAffected()

Expand All @@ -162,7 +162,7 @@ func createCallback(scope *Scope) {

// execute create sql: dialects with additional lastInsertID requirements (currently postgres & mssql)
if primaryField.Field.CanAddr() {
if err := scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil {
if err := scope.SQLDB().QueryRowContext(scope.Context(), scope.SQL, scope.SQLVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil {
primaryField.IsBlank = false
scope.db.RowsAffected = 1
}
Expand Down
2 changes: 1 addition & 1 deletion callback_query.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func queryCallback(scope *Scope) {
scope.SQL = fmt.Sprint(str) + scope.SQL
}

if rows, err := scope.SQLDB().Query(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
if rows, err := scope.SQLDB().QueryContext(scope.Context(), scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
defer rows.Close()

columns, _ := rows.Columns()
Expand Down
4 changes: 2 additions & 2 deletions callback_row_query.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ func rowQueryCallback(scope *Scope) {
}

if rowResult, ok := result.(*RowQueryResult); ok {
rowResult.Row = scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...)
rowResult.Row = scope.SQLDB().QueryRowContext(scope.Context(), scope.SQL, scope.SQLVars...)
} else if rowsResult, ok := result.(*RowsQueryResult); ok {
rowsResult.Rows, rowsResult.Error = scope.SQLDB().Query(scope.SQL, scope.SQLVars...)
rowsResult.Rows, rowsResult.Error = scope.SQLDB().QueryContext(scope.Context(), scope.SQL, scope.SQLVars...)
}
}
}
2 changes: 1 addition & 1 deletion customize_column_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ func TestCustomizeColumn(t *testing.T) {
DB.AutoMigrate(&CustomizeColumn{})

scope := DB.NewScope(&CustomizeColumn{})
if !scope.Dialect().HasColumn(scope.TableName(), col) {
if !scope.Dialect().HasColumn(scope.Context(), scope.TableName(), col) {
t.Errorf("CustomizeColumn should have column %s", col)
}

Expand Down
25 changes: 13 additions & 12 deletions dialect.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package gorm

import (
"context"
"database/sql"
"fmt"
"reflect"
Expand All @@ -24,17 +25,17 @@ type Dialect interface {
DataTypeOf(field *StructField) string

// HasIndex check has index or not
HasIndex(tableName string, indexName string) bool
HasIndex(ctx context.Context, tableName string, indexName string) bool
// HasForeignKey check has foreign key or not
HasForeignKey(tableName string, foreignKeyName string) bool
HasForeignKey(ctx context.Context, tableName string, foreignKeyName string) bool
// RemoveIndex remove index
RemoveIndex(tableName string, indexName string) error
RemoveIndex(ctx context.Context, tableName string, indexName string) error
// HasTable check has table or not
HasTable(tableName string) bool
HasTable(ctx context.Context, tableName string) bool
// HasColumn check has column or not
HasColumn(tableName string, columnName string) bool
HasColumn(ctx context.Context, tableName string, columnName string) bool
// ModifyColumn modify column's type
ModifyColumn(tableName string, columnName string, typ string) error
ModifyColumn(ctx context.Context, tableName string, columnName string, typ string) error

// LimitAndOffsetSQL return generated SQL with Limit and Offset, as mssql has special case
LimitAndOffsetSQL(limit, offset interface{}) (string, error)
Expand All @@ -54,7 +55,7 @@ type Dialect interface {
NormalizeIndexAndColumn(indexName, columnName string) (string, string)

// CurrentDatabase return current database name
CurrentDatabase() string
CurrentDatabase(ctx context.Context) string
}

var dialectsMap = map[string]Dialect{}
Expand All @@ -67,9 +68,9 @@ func newDialect(name string, db SQLCommon) Dialect {
}

fmt.Printf("`%v` is not officially supported, running under compatibility mode.\n", name)
commontDialect := &commonDialect{}
commontDialect.SetDB(db)
return commontDialect
cd := &commonDialect{}
cd.SetDB(db)
return cd
}

// RegisterDialect register new dialect
Expand Down Expand Up @@ -138,10 +139,10 @@ var ParseFieldStructForDialect = func(field *StructField, dialect Dialect) (fiel
return fieldValue, dataType, size, strings.TrimSpace(additionalType)
}

func currentDatabaseAndTable(dialect Dialect, tableName string) (string, string) {
func currentDatabaseAndTable(ctx context.Context, dialect Dialect, tableName string) (string, string) {
if strings.Contains(tableName, ".") {
splitStrings := strings.SplitN(tableName, ".", 2)
return splitStrings[0], splitStrings[1]
}
return dialect.CurrentDatabase(), tableName
return dialect.CurrentDatabase(ctx), tableName
}
33 changes: 17 additions & 16 deletions dialect_common.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package gorm

import (
"context"
"fmt"
"reflect"
"regexp"
Expand Down Expand Up @@ -99,43 +100,43 @@ func (s *commonDialect) DataTypeOf(field *StructField) string {
return fmt.Sprintf("%v %v", sqlType, additionalType)
}

func (s commonDialect) HasIndex(tableName string, indexName string) bool {
func (s commonDialect) HasIndex(ctx context.Context, tableName string, indexName string) bool {
var count int
currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", currentDatabase, tableName, indexName).Scan(&count)
currentDatabase, tableName := currentDatabaseAndTable(ctx, &s, tableName)
s.db.QueryRowContext(ctx, "SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", currentDatabase, tableName, indexName).Scan(&count)
return count > 0
}

func (s commonDialect) RemoveIndex(tableName string, indexName string) error {
_, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v", indexName))
func (s commonDialect) RemoveIndex(ctx context.Context, tableName string, indexName string) error {
_, err := s.db.ExecContext(ctx, fmt.Sprintf("DROP INDEX %v", indexName))
return err
}

func (s commonDialect) HasForeignKey(tableName string, foreignKeyName string) bool {
func (s commonDialect) HasForeignKey(ctx context.Context, tableName string, foreignKeyName string) bool {
return false
}

func (s commonDialect) HasTable(tableName string) bool {
func (s commonDialect) HasTable(ctx context.Context, tableName string) bool {
var count int
currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", currentDatabase, tableName).Scan(&count)
currentDatabase, tableName := currentDatabaseAndTable(ctx, &s, tableName)
s.db.QueryRowContext(ctx, "SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", currentDatabase, tableName).Scan(&count)
return count > 0
}

func (s commonDialect) HasColumn(tableName string, columnName string) bool {
func (s commonDialect) HasColumn(ctx context.Context, tableName string, columnName string) bool {
var count int
currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", currentDatabase, tableName, columnName).Scan(&count)
currentDatabase, tableName := currentDatabaseAndTable(ctx, &s, tableName)
s.db.QueryRowContext(ctx, "SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", currentDatabase, tableName, columnName).Scan(&count)
return count > 0
}

func (s commonDialect) ModifyColumn(tableName string, columnName string, typ string) error {
_, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v TYPE %v", tableName, columnName, typ))
func (s commonDialect) ModifyColumn(ctx context.Context, tableName string, columnName string, typ string) error {
_, err := s.db.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v TYPE %v", tableName, columnName, typ))
return err
}

func (s commonDialect) CurrentDatabase() (name string) {
s.db.QueryRow("SELECT DATABASE()").Scan(&name)
func (s commonDialect) CurrentDatabase(ctx context.Context) (name string) {
s.db.QueryRowContext(ctx, "SELECT DATABASE()").Scan(&name)
return
}

Expand Down
37 changes: 19 additions & 18 deletions dialect_mysql.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package gorm

import (
"context"
"crypto/sha1"
"database/sql"
"fmt"
Expand Down Expand Up @@ -129,13 +130,13 @@ func (s *mysql) DataTypeOf(field *StructField) string {
return fmt.Sprintf("%v %v", sqlType, additionalType)
}

func (s mysql) RemoveIndex(tableName string, indexName string) error {
_, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName)))
func (s mysql) RemoveIndex(ctx context.Context, tableName string, indexName string) error {
_, err := s.db.ExecContext(ctx, fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName)))
return err
}

func (s mysql) ModifyColumn(tableName string, columnName string, typ string) error {
_, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v MODIFY COLUMN %v %v", tableName, columnName, typ))
func (s mysql) ModifyColumn(ctx context.Context, tableName string, columnName string, typ string) error {
_, err := s.db.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %v MODIFY COLUMN %v %v", tableName, columnName, typ))
return err
}

Expand All @@ -162,18 +163,18 @@ func (s mysql) LimitAndOffsetSQL(limit, offset interface{}) (sql string, err err
return
}

func (s mysql) HasForeignKey(tableName string, foreignKeyName string) bool {
func (s mysql) HasForeignKey(ctx context.Context, tableName string, foreignKeyName string) bool {
var count int
currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE CONSTRAINT_SCHEMA=? AND TABLE_NAME=? AND CONSTRAINT_NAME=? AND CONSTRAINT_TYPE='FOREIGN KEY'", currentDatabase, tableName, foreignKeyName).Scan(&count)
currentDatabase, tableName := currentDatabaseAndTable(ctx, &s, tableName)
s.db.QueryRowContext(ctx, "SELECT count(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE CONSTRAINT_SCHEMA=? AND TABLE_NAME=? AND CONSTRAINT_NAME=? AND CONSTRAINT_TYPE='FOREIGN KEY'", currentDatabase, tableName, foreignKeyName).Scan(&count)
return count > 0
}

func (s mysql) HasTable(tableName string) bool {
currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
func (s mysql) HasTable(ctx context.Context, tableName string) bool {
currentDatabase, tableName := currentDatabaseAndTable(ctx, &s, tableName)
var name string
// allow mysql database name with '-' character
if err := s.db.QueryRow(fmt.Sprintf("SHOW TABLES FROM `%s` WHERE `Tables_in_%s` = ?", currentDatabase, currentDatabase), tableName).Scan(&name); err != nil {
if err := s.db.QueryRowContext(ctx, fmt.Sprintf("SHOW TABLES FROM `%s` WHERE `Tables_in_%s` = ?", currentDatabase, currentDatabase), tableName).Scan(&name); err != nil {
if err == sql.ErrNoRows {
return false
}
Expand All @@ -183,28 +184,28 @@ func (s mysql) HasTable(tableName string) bool {
}
}

func (s mysql) HasIndex(tableName string, indexName string) bool {
currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
if rows, err := s.db.Query(fmt.Sprintf("SHOW INDEXES FROM `%s` FROM `%s` WHERE Key_name = ?", tableName, currentDatabase), indexName); err != nil {
func (s mysql) HasIndex(ctx context.Context, tableName string, indexName string) bool {
currentDatabase, tableName := currentDatabaseAndTable(ctx, &s, tableName)
if rows, err := s.db.QueryContext(ctx, fmt.Sprintf("SHOW INDEXES FROM `%s` FROM `%s` WHERE Key_name = ?", tableName, currentDatabase), indexName); err != nil {
panic(err)
} else {
defer rows.Close()
return rows.Next()
}
}

func (s mysql) HasColumn(tableName string, columnName string) bool {
currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
if rows, err := s.db.Query(fmt.Sprintf("SHOW COLUMNS FROM `%s` FROM `%s` WHERE Field = ?", tableName, currentDatabase), columnName); err != nil {
func (s mysql) HasColumn(ctx context.Context, tableName string, columnName string) bool {
currentDatabase, tableName := currentDatabaseAndTable(ctx, &s, tableName)
if rows, err := s.db.QueryContext(ctx, fmt.Sprintf("SHOW COLUMNS FROM `%s` FROM `%s` WHERE Field = ?", tableName, currentDatabase), columnName); err != nil {
panic(err)
} else {
defer rows.Close()
return rows.Next()
}
}

func (s mysql) CurrentDatabase() (name string) {
s.db.QueryRow("SELECT DATABASE()").Scan(&name)
func (s mysql) CurrentDatabase(ctx context.Context) (name string) {
s.db.QueryRowContext(ctx, "SELECT DATABASE()").Scan(&name)
return
}

Expand Down
21 changes: 11 additions & 10 deletions dialect_postgres.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package gorm

import (
"context"
"encoding/json"
"fmt"
"reflect"
Expand Down Expand Up @@ -91,32 +92,32 @@ func (s *postgres) DataTypeOf(field *StructField) string {
return fmt.Sprintf("%v %v", sqlType, additionalType)
}

func (s postgres) HasIndex(tableName string, indexName string) bool {
func (s postgres) HasIndex(ctx context.Context, tableName string, indexName string) bool {
var count int
s.db.QueryRow("SELECT count(*) FROM pg_indexes WHERE tablename = $1 AND indexname = $2 AND schemaname = CURRENT_SCHEMA()", tableName, indexName).Scan(&count)
s.db.QueryRowContext(ctx, "SELECT count(*) FROM pg_indexes WHERE tablename = $1 AND indexname = $2 AND schemaname = CURRENT_SCHEMA()", tableName, indexName).Scan(&count)
return count > 0
}

func (s postgres) HasForeignKey(tableName string, foreignKeyName string) bool {
func (s postgres) HasForeignKey(ctx context.Context, tableName string, foreignKeyName string) bool {
var count int
s.db.QueryRow("SELECT count(con.conname) FROM pg_constraint con WHERE $1::regclass::oid = con.conrelid AND con.conname = $2 AND con.contype='f'", tableName, foreignKeyName).Scan(&count)
s.db.QueryRowContext(ctx, "SELECT count(con.conname) FROM pg_constraint con WHERE $1::regclass::oid = con.conrelid AND con.conname = $2 AND con.contype='f'", tableName, foreignKeyName).Scan(&count)
return count > 0
}

func (s postgres) HasTable(tableName string) bool {
func (s postgres) HasTable(ctx context.Context, tableName string) bool {
var count int
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = $1 AND table_type = 'BASE TABLE' AND table_schema = CURRENT_SCHEMA()", tableName).Scan(&count)
s.db.QueryRowContext(ctx, "SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = $1 AND table_type = 'BASE TABLE' AND table_schema = CURRENT_SCHEMA()", tableName).Scan(&count)
return count > 0
}

func (s postgres) HasColumn(tableName string, columnName string) bool {
func (s postgres) HasColumn(ctx context.Context, tableName string, columnName string) bool {
var count int
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = $1 AND column_name = $2 AND table_schema = CURRENT_SCHEMA()", tableName, columnName).Scan(&count)
s.db.QueryRowContext(ctx, "SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = $1 AND column_name = $2 AND table_schema = CURRENT_SCHEMA()", tableName, columnName).Scan(&count)
return count > 0
}

func (s postgres) CurrentDatabase() (name string) {
s.db.QueryRow("SELECT CURRENT_DATABASE()").Scan(&name)
func (s postgres) CurrentDatabase(ctx context.Context) (name string) {
s.db.QueryRowContext(ctx, "SELECT CURRENT_DATABASE()").Scan(&name)
return
}

Expand Down
Loading

0 comments on commit 603f1c3

Please sign in to comment.