Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generate unique savepoint names for nested transactions #7174

Merged
merged 2 commits into from
Sep 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions finisher_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"database/sql"
"errors"
"fmt"
"hash/maphash"
"reflect"
"strings"

Expand Down Expand Up @@ -623,14 +624,15 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er
if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil {
// nested transaction
if !db.DisableNestedTransaction {
err = db.SavePoint(fmt.Sprintf("sp%p", fc)).Error
spID := new(maphash.Hash).Sum64()
err = db.SavePoint(fmt.Sprintf("sp%d", spID)).Error
if err != nil {
return
}
defer func() {
// Make sure to rollback when panic, Block error or Commit error
if panicked || err != nil {
db.RollbackTo(fmt.Sprintf("sp%p", fc))
db.RollbackTo(fmt.Sprintf("sp%d", spID))
}
}()
}
Expand Down
4 changes: 2 additions & 2 deletions tests/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ require (
github.com/microsoft/go-mssqldb v1.7.2 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rogpeppe/go-internal v1.12.0 // indirect
golang.org/x/crypto v0.24.0 // indirect
golang.org/x/text v0.16.0 // indirect
golang.org/x/crypto v0.26.0 // indirect
golang.org/x/text v0.17.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

Expand Down
68 changes: 68 additions & 0 deletions tests/transaction_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,74 @@ func TestNestedTransactionWithBlock(t *testing.T) {
}
}

func TestDeeplyNestedTransactionWithBlockAndWrappedCallback(t *testing.T) {
transaction := func(ctx context.Context, db *gorm.DB, callback func(ctx context.Context, db *gorm.DB) error) error {
return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
return callback(ctx, tx)
})
}
var (
user = *GetUser("transaction-nested", Config{})
user1 = *GetUser("transaction-nested-1", Config{})
user2 = *GetUser("transaction-nested-2", Config{})
)

if err := transaction(context.Background(), DB, func(ctx context.Context, tx *gorm.DB) error {
tx.Create(&user)

if err := tx.First(&User{}, "name = ?", user.Name).Error; err != nil {
t.Fatalf("Should find saved record")
}

if err := transaction(ctx, tx, func(ctx context.Context, tx1 *gorm.DB) error {
tx1.Create(&user1)

if err := tx1.First(&User{}, "name = ?", user1.Name).Error; err != nil {
t.Fatalf("Should find saved record")
}

if err := transaction(ctx, tx1, func(ctx context.Context, tx2 *gorm.DB) error {
tx2.Create(&user2)

if err := tx2.First(&User{}, "name = ?", user2.Name).Error; err != nil {
t.Fatalf("Should find saved record")
}

return errors.New("inner rollback")
}); err == nil {
t.Fatalf("nested transaction has no error")
}

return errors.New("rollback")
}); err == nil {
t.Fatalf("nested transaction should returns error")
}

if err := tx.First(&User{}, "name = ?", user1.Name).Error; err == nil {
t.Fatalf("Should not find rollbacked record")
}

if err := tx.First(&User{}, "name = ?", user2.Name).Error; err != nil {
t.Fatalf("Should find saved record")
}
return nil
}); err != nil {
t.Fatalf("no error should return, but got %v", err)
}

if err := DB.First(&User{}, "name = ?", user.Name).Error; err != nil {
t.Fatalf("Should find saved record")
}

if err := DB.First(&User{}, "name = ?", user1.Name).Error; err == nil {
t.Fatalf("Should not find rollbacked parent record")
}

if err := DB.First(&User{}, "name = ?", user2.Name).Error; err != nil {
t.Fatalf("Should not find rollbacked nested record")
}
}

func TestDisabledNestedTransaction(t *testing.T) {
var (
user = *GetUser("transaction-nested", Config{})
Expand Down
Loading