Skip to content

Commit

Permalink
Add a way to set postgres role when executing migrations (#226)
Browse files Browse the repository at this point in the history
In same cases we want to set a specific role when executing migrations,
so the ownerhsip of the created/updated objects is different from the
pgroll user (storing pgroll state). This change allows to set a role
that will be set in the connection executing migrations.
  • Loading branch information
exekias authored Jan 15, 2024
1 parent 025a38f commit 73c2016
Show file tree
Hide file tree
Showing 7 changed files with 123 additions and 17 deletions.
4 changes: 4 additions & 0 deletions cmd/flags/flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,7 @@ func StateSchema() string {
func LockTimeout() int {
return viper.GetInt("LOCK_TIMEOUT")
}

func Role() string {
return viper.GetString("ROLE")
}
8 changes: 7 additions & 1 deletion cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,13 @@ func init() {
rootCmd.PersistentFlags().String("schema", "public", "Postgres schema to use for the migration")
rootCmd.PersistentFlags().String("pgroll-schema", "pgroll", "Postgres schema to use for pgroll internal state")
rootCmd.PersistentFlags().Int("lock-timeout", 500, "Postgres lock timeout in milliseconds for pgroll DDL operations")
rootCmd.PersistentFlags().String("role", "", "Optional postgres role to set when executing migrations")

viper.BindPFlag("PG_URL", rootCmd.PersistentFlags().Lookup("postgres-url"))
viper.BindPFlag("SCHEMA", rootCmd.PersistentFlags().Lookup("schema"))
viper.BindPFlag("STATE_SCHEMA", rootCmd.PersistentFlags().Lookup("pgroll-schema"))
viper.BindPFlag("LOCK_TIMEOUT", rootCmd.PersistentFlags().Lookup("lock-timeout"))
viper.BindPFlag("ROLE", rootCmd.PersistentFlags().Lookup("role"))
}

var rootCmd = &cobra.Command{
Expand All @@ -41,13 +43,17 @@ func NewRoll(ctx context.Context) (*roll.Roll, error) {
schema := flags.Schema()
stateSchema := flags.StateSchema()
lockTimeout := flags.LockTimeout()
role := flags.Role()

state, err := state.New(ctx, pgURL, stateSchema)
if err != nil {
return nil, err
}

return roll.New(ctx, pgURL, schema, lockTimeout, state)
return roll.New(ctx, pgURL, schema, state,
roll.WithLockTimeoutMs(lockTimeout),
roll.WithRole(role),
)
}

// Execute executes the root command.
Expand Down
2 changes: 2 additions & 0 deletions docs/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -537,12 +537,14 @@ The `pgroll` CLI has the following top-level flags:
* `--schema`: The Postgres schema in which migrations will be run (default `"public"`).
* `--pgroll-schema`: The Postgres schema in which `pgroll` will store its internal state (default: `"pgroll"`).
* `--lock-timeout`: The Postgres `lock_timeout` value to use for all `pgroll` DDL operations, specified in milliseconds (default `500`).
* `--role``: The Postgres role to use for all `pgroll` DDL operations (default: `""`, which doesn't set any role).

Each of these flags can also be set via an environment variable:
* `PGROLL_PG_URL`
* `PGROLL_SCHEMA`
* `PGROLL_STATE_SCHEMA`
* `PGROLL_LOCK_TIMEOUT`
* `PGROLL_ROLE`

The CLI flag takes precedence if a flag is set via both an environment variable and a CLI flag.

Expand Down
34 changes: 33 additions & 1 deletion pkg/roll/execute_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ func TestSchemaOptionIsRespected(t *testing.T) {
func TestLockTimeoutIsEnforced(t *testing.T) {
t.Parallel()

testutils.WithMigratorInSchemaWithLockTimeoutAndConnectionToContainer(t, "public", 100, func(mig *roll.Roll, db *sql.DB) {
testutils.WithMigratorInSchemaAndConnectionToContainerWithOptions(t, "public", []roll.Option{roll.WithLockTimeoutMs(100)}, func(mig *roll.Roll, db *sql.DB) {
ctx := context.Background()

// Start a create table migration
Expand Down Expand Up @@ -458,6 +458,38 @@ func TestStatusMethodReturnsCorrectStatus(t *testing.T) {
})
}

func TestRoleIsRespected(t *testing.T) {
t.Parallel()

testutils.WithMigratorInSchemaAndConnectionToContainerWithOptions(t, "public", []roll.Option{roll.WithRole("pgroll")}, func(mig *roll.Roll, db *sql.DB) {
ctx := context.Background()

// Start a create table migration
err := mig.Start(ctx, &migrations.Migration{
Name: "01_create_table",
Operations: migrations.Operations{createTableOp("table1")},
})
assert.NoError(t, err)

// Complete the create table migration
err = mig.Complete(ctx)
assert.NoError(t, err)

// Ensure that the table exists in the correct schema and owned by the correct role
var exists bool
err = db.QueryRow(`
SELECT EXISTS(
SELECT 1
FROM pg_catalog.pg_tables
WHERE tablename = $1
AND schemaname = $2
AND tableowner = $3
)`, "table1", "public", "pgroll").Scan(&exists)
assert.NoError(t, err)
assert.True(t, exists)
})
}

func createTableOp(tableName string) *migrations.OpCreateTable {
return &migrations.OpCreateTable{
Name: tableName,
Expand Down
27 changes: 27 additions & 0 deletions pkg/roll/options.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// SPDX-License-Identifier: Apache-2.0

package roll

type options struct {
// lock timeout in milliseconds for pgroll DDL operations
lockTimeoutMs int

// optional role to set before executing migrations
role string
}

type Option func(*options)

// WithLockTimeoutMs sets the lock timeout in milliseconds for pgroll DDL operations
func WithLockTimeoutMs(lockTimeoutMs int) Option {
return func(o *options) {
o.lockTimeoutMs = lockTimeoutMs
}
}

// WithRole sets the role to set before executing migrations
func WithRole(role string) Option {
return func(o *options) {
o.role = role
}
}
22 changes: 18 additions & 4 deletions pkg/roll/roll.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,12 @@ type Roll struct {
pgVersion PGVersion
}

func New(ctx context.Context, pgURL, schema string, lockTimeoutMs int, state *state.State) (*Roll, error) {
func New(ctx context.Context, pgURL, schema string, state *state.State, opts ...Option) (*Roll, error) {
options := &options{}
for _, o := range opts {
o(options)
}

dsn, err := pq.ParseURL(pgURL)
if err != nil {
dsn = pgURL
Expand All @@ -48,9 +53,18 @@ func New(ctx context.Context, pgURL, schema string, lockTimeoutMs int, state *st
return nil, fmt.Errorf("unable to set pgroll.internal to true: %w", err)
}

_, err = conn.ExecContext(ctx, fmt.Sprintf("SET lock_timeout to '%dms'", lockTimeoutMs))
if err != nil {
return nil, fmt.Errorf("unable to set lock_timeout: %w", err)
if options.lockTimeoutMs > 0 {
_, err = conn.ExecContext(ctx, fmt.Sprintf("SET lock_timeout to '%dms'", options.lockTimeoutMs))
if err != nil {
return nil, fmt.Errorf("unable to set lock_timeout: %w", err)
}
}

if options.role != "" {
_, err = conn.ExecContext(ctx, fmt.Sprintf("SET ROLE %s", options.role))
if err != nil {
return nil, fmt.Errorf("unable to set role to '%s': %w", options.role, err)
}
}

var pgMajorVersion PGVersion
Expand Down
43 changes: 32 additions & 11 deletions pkg/testutils/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,17 @@ func SharedTestMain(m *testing.M) {
os.Exit(1)
}

db, err := sql.Open("postgres", tConnStr)
if err != nil {
os.Exit(1)
}

// create handy role for tests
_, err = db.ExecContext(ctx, "CREATE ROLE pgroll")
if err != nil {
os.Exit(1)
}

exitCode := m.Run()

if err := ctr.Terminate(ctx); err != nil {
Expand Down Expand Up @@ -113,7 +124,7 @@ func WithStateAndConnectionToContainer(t *testing.T, fn func(*state.State, *sql.
fn(st, db)
}

func WithMigratorInSchemaWithLockTimeoutAndConnectionToContainer(t *testing.T, schema string, lockTimeoutMs int, fn func(mig *roll.Roll, db *sql.DB)) {
func WithMigratorInSchemaAndConnectionToContainerWithOptions(t *testing.T, schema string, opts []roll.Option, fn func(mig *roll.Roll, db *sql.DB)) {
t.Helper()
ctx := context.Background()

Expand Down Expand Up @@ -143,6 +154,17 @@ func WithMigratorInSchemaWithLockTimeoutAndConnectionToContainer(t *testing.T, s
u.Path = "/" + dbName
connStr := u.String()

db, err := sql.Open("postgres", connStr)
if err != nil {
t.Fatal(err)
}

t.Cleanup(func() {
if err := db.Close(); err != nil {
t.Fatalf("Failed to close database connection: %v", err)
}
})

st, err := state.New(ctx, connStr, "pgroll")
if err != nil {
t.Fatal(err)
Expand All @@ -153,7 +175,7 @@ func WithMigratorInSchemaWithLockTimeoutAndConnectionToContainer(t *testing.T, s
t.Fatal(err)
}

mig, err := roll.New(ctx, connStr, schema, lockTimeoutMs, st)
mig, err := roll.New(ctx, connStr, schema, st, opts...)
if err != nil {
t.Fatal(err)
}
Expand All @@ -164,29 +186,28 @@ func WithMigratorInSchemaWithLockTimeoutAndConnectionToContainer(t *testing.T, s
}
})

db, err := sql.Open("postgres", connStr)
_, err = db.ExecContext(ctx, fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %s", schema))
if err != nil {
t.Fatal(err)
}

_, err = db.ExecContext(ctx, fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %s", schema))
_, err = db.ExecContext(ctx, fmt.Sprintf("GRANT ALL PRIVILEGES ON SCHEMA %s TO pgroll", schema))
if err != nil {
t.Fatal(err)
}

t.Cleanup(func() {
if err := db.Close(); err != nil {
t.Fatalf("Failed to close database connection: %v", err)
}
})
_, err = db.ExecContext(ctx, fmt.Sprintf("GRANT ALL PRIVILEGES ON DATABASE %s TO pgroll", dbName))
if err != nil {
t.Fatal(err)
}

fn(mig, db)
}

func WithMigratorInSchemaAndConnectionToContainer(t *testing.T, schema string, fn func(mig *roll.Roll, db *sql.DB)) {
WithMigratorInSchemaWithLockTimeoutAndConnectionToContainer(t, schema, 500, fn)
WithMigratorInSchemaAndConnectionToContainerWithOptions(t, schema, []roll.Option{roll.WithLockTimeoutMs(500)}, fn)
}

func WithMigratorAndConnectionToContainer(t *testing.T, fn func(mig *roll.Roll, db *sql.DB)) {
WithMigratorInSchemaWithLockTimeoutAndConnectionToContainer(t, "public", 500, fn)
WithMigratorInSchemaAndConnectionToContainerWithOptions(t, "public", []roll.Option{roll.WithLockTimeoutMs(500)}, fn)
}

0 comments on commit 73c2016

Please sign in to comment.