Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add 'UseMigrationLock' flag to acquire exclusive lock while performing migrations for Postgres #596

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,10 @@ func NewApp() *cli.App {
Usage: "timeout for --wait flag",
Value: defaultDB.WaitTimeout,
},
&cli.BoolFlag{
Name: "migration-lock",
Usage: "use a lock during the migration so other dbmate instances can not run migrations at the same time",
},
}

app.Commands = []*cli.Command{
Expand Down Expand Up @@ -127,6 +131,7 @@ func NewApp() *cli.App {
Action: action(func(db *dbmate.DB, c *cli.Context) error {
db.Strict = c.Bool("strict")
db.Verbose = c.Bool("verbose")
db.UseMigrationLock = c.Bool("migration-lock")
return db.CreateAndMigrate()
}),
},
Expand Down Expand Up @@ -163,6 +168,7 @@ func NewApp() *cli.App {
Action: action(func(db *dbmate.DB, c *cli.Context) error {
db.Strict = c.Bool("strict")
db.Verbose = c.Bool("verbose")
db.UseMigrationLock = c.Bool("migration-lock")
return db.Migrate()
}),
},
Expand Down
56 changes: 54 additions & 2 deletions pkg/dbmate/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ type DB struct {
WaitInterval time.Duration
// WaitTimeout specifies maximum time for connection attempts
WaitTimeout time.Duration
// UseMigrationLock uses an exclusive lock while performing migrations
UseMigrationLock bool
}

// StatusResult represents an available migration status
Expand All @@ -83,6 +85,7 @@ func New(databaseURL *url.URL) *DB {
WaitBefore: false,
WaitInterval: time.Second,
WaitTimeout: 60 * time.Second,
UseMigrationLock: false,
}
}

Expand Down Expand Up @@ -153,12 +156,36 @@ func (db *DB) Wait() error {
}

// CreateAndMigrate creates the database (if necessary) and runs migrations
func (db *DB) CreateAndMigrate() error {
func (db *DB) CreateAndMigrate() (returnErr error) {
drv, err := db.Driver()
if err != nil {
return err
}

// try and acquire a lock for the duration of the migration.
// this is to prevent multiple instances performing the same migration in parallel.
if db.UseMigrationLock {
drvLock, ok := drv.(DriverMigrationLock)
if !ok {
return fmt.Errorf("driver does not support the use of a migration lock")
}

if err := drvLock.Lock(); err != nil {
return err
}

defer func() {
err := drvLock.Unlock()
if err != nil {
if returnErr != nil {
returnErr = fmt.Errorf("failed to unlock: %v: %w", err, returnErr)
return
}
returnErr = fmt.Errorf("failed to unlock: %w", err)
}
}()
}

// create database if it does not already exist
// skip this step if we cannot determine status
// (e.g. user does not have list database permission)
Expand Down Expand Up @@ -333,12 +360,37 @@ func (db *DB) openDatabaseForMigration(drv Driver) (*sql.DB, error) {
}

// Migrate migrates database to the latest version
func (db *DB) Migrate() error {
func (db *DB) Migrate() (returnErr error) {
drv, err := db.Driver()
if err != nil {
return err
}

if db.UseMigrationLock {
drvLock, ok := drv.(DriverMigrationLock)
if !ok {
return fmt.Errorf("driver does not support the use of a migration lock")
}

// only try and lock if we haven't already done so, e.g called CreateAndMigrate.
if !drvLock.IsLocked() {
if err := drvLock.Lock(); err != nil {
return err
}

defer func() {
err := drvLock.Unlock()
if err != nil {
if returnErr != nil {
returnErr = fmt.Errorf("failed to unlock: %v: %w", err, returnErr)
return
}
returnErr = fmt.Errorf("failed to unlock: %w", err)
}
}()
}
}

migrations, err := db.FindMigrations()
if err != nil {
return err
Expand Down
1 change: 1 addition & 0 deletions pkg/dbmate/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ func TestNew(t *testing.T) {
require.False(t, db.WaitBefore)
require.Equal(t, time.Second, db.WaitInterval)
require.Equal(t, 60*time.Second, db.WaitTimeout)
require.Equal(t, false, db.UseMigrationLock)
}

func TestGetDriver(t *testing.T) {
Expand Down
6 changes: 6 additions & 0 deletions pkg/dbmate/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@ type Driver interface {
QueryError(string, error) error
}

type DriverMigrationLock interface {
Lock() error
Unlock() error
IsLocked() bool
}

// DriverConfig holds configuration passed to driver constructors
type DriverConfig struct {
DatabaseURL *url.URL
Expand Down
48 changes: 48 additions & 0 deletions pkg/driver/postgres/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package postgres

import (
"bytes"
"context"
"database/sql"
"fmt"
"io"
Expand All @@ -28,6 +29,8 @@ type Driver struct {
migrationsTableName string
databaseURL *url.URL
log io.Writer

migrationLockTx *sql.Tx
}

// NewDriver initializes the driver
Expand Down Expand Up @@ -456,3 +459,48 @@ func (drv *Driver) quotedMigrationsTableNameParts(db dbutil.Transaction) (string
// if more than one part, we already have a schema
return quotedNameParts[0], strings.Join(quotedNameParts[1:], "."), nil
}

const lockKey = 48372615

func (drv *Driver) Lock() error {
if drv.migrationLockTx != nil {
return fmt.Errorf("already locked")
}

db, err := drv.Open()
if err != nil {
return err
}

tx, err := db.BeginTx(context.Background(), nil)
if err != nil {
return err
}

drv.migrationLockTx = tx

_, err = tx.Exec("SELECT pg_advisory_xact_lock($1)", lockKey)
if err != nil {
return fmt.Errorf("failed to acquire lock: %w", err)
}

return nil
}

func (drv *Driver) Unlock() error {
if drv.migrationLockTx == nil {
return fmt.Errorf("not locked")
}

if err := drv.migrationLockTx.Rollback(); err != nil {
return err
}

drv.migrationLockTx = nil

return nil
}

func (drv *Driver) IsLocked() bool {
return drv.migrationLockTx != nil
}
36 changes: 36 additions & 0 deletions pkg/driver/postgres/postgres_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"net/url"
"runtime"
"testing"
"time"

"github.com/amacneil/dbmate/v2/pkg/dbmate"
"github.com/amacneil/dbmate/v2/pkg/dbtest"
Expand Down Expand Up @@ -790,3 +791,38 @@ func TestPostgresMigrationsTableExists(t *testing.T) {
require.Equal(t, true, exists)
})
}

func TestPostgresMigrationLock(t *testing.T) {
t.Run("lock and unlock", func(t *testing.T) {
drv := testPostgresDriver(t)
err := drv.Lock()
require.NoError(t, err)

err = drv.Lock()
require.Error(t, err, "Should not be able to lock again without unlock")

err = drv.Unlock()
require.NoError(t, err, "Should be able to unlock")
})

t.Run("lock on one instance should block lock attempt on another", func(t *testing.T) {
drv1 := testPostgresDriver(t)
err1 := drv1.Lock()
require.NoError(t, err1)

var isUnlocked bool
go func() {
time.Sleep(10 * time.Millisecond)
err := drv1.Unlock()
require.NoError(t, err, "Should be able to unlock")
isUnlocked = true
}()

drv2 := testPostgresDriver(t)
err2 := drv2.Lock()
require.NoError(t, err2)
require.Equal(t, true, isUnlocked)
err2 = drv2.Unlock()
require.NoError(t, err2, "Should be able to unlock")
})
}
Loading