Skip to content

Commit

Permalink
Refactor alter column suboperations (#337)
Browse files Browse the repository at this point in the history
Remove duplication between 'alter column' sub-operations. Pull:

* column duplication and trigger creation on migration start
* column rename and trigger removal on complete
* trigger and column drop on rollback

up to the parent 'alter column' operation. This removes a lot of
duplicated code from the sub-operations and will make it easier to
support multiple sub-operations in one 'alter column' operation.

Part of #336
  • Loading branch information
andrew-farries authored Apr 10, 2024
1 parent d5d4bea commit d4445bb
Show file tree
Hide file tree
Showing 7 changed files with 180 additions and 612 deletions.
175 changes: 172 additions & 3 deletions pkg/migrations/op_alter_column.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,28 +5,155 @@ package migrations
import (
"context"
"database/sql"
"fmt"

"github.com/lib/pq"
"github.com/xataio/pgroll/pkg/schema"
)

var _ Operation = (*OpAlterColumn)(nil)

func (o *OpAlterColumn) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) {
table := s.GetTable(o.Table)
column := table.GetColumn(o.Column)

op := o.innerOperation()

return op.Start(ctx, conn, stateSchema, tr, s, cbs...)
if _, ok := op.(*OpRenameColumn); !ok {
// Duplicate the column on the underlying table.
d := duplicatorForOperation(o.innerOperation(), conn, table, column)
if err := d.Duplicate(ctx); err != nil {
return nil, fmt.Errorf("failed to duplicate column: %w", err)
}
}

// perform any operation specific start steps
tbl, err := op.Start(ctx, conn, stateSchema, tr, s, cbs...)
if err != nil {
return nil, err
}

// Add a trigger to copy values from the old column to the new, rewriting values using the `up` SQL.
// Rename column operations do not require this trigger.
if _, ok := op.(*OpRenameColumn); !ok {
err = createTrigger(ctx, conn, tr, triggerConfig{
Name: TriggerName(o.Table, o.Column),
Direction: TriggerDirectionUp,
Columns: table.Columns,
SchemaName: s.Name,
TableName: o.Table,
PhysicalColumn: TemporaryName(o.Column),
StateSchema: stateSchema,
SQL: o.upSQLForOperation(op),
})
if err != nil {
return nil, fmt.Errorf("failed to create up trigger: %w", err)
}

// Add the new column to the internal schema representation. This is done
// here, before creation of the down trigger, so that the trigger can declare
// a variable for the new column.
table.AddColumn(o.Column, schema.Column{
Name: TemporaryName(o.Column),
})

// Add a trigger to copy values from the new column to the old.
err = createTrigger(ctx, conn, tr, triggerConfig{
Name: TriggerName(o.Table, TemporaryName(o.Column)),
Direction: TriggerDirectionDown,
Columns: table.Columns,
SchemaName: s.Name,
TableName: o.Table,
PhysicalColumn: o.Column,
StateSchema: stateSchema,
SQL: o.downSQLForOperation(op),
})
if err != nil {
return nil, fmt.Errorf("failed to create down trigger: %w", err)
}
}

return tbl, nil
}

func (o *OpAlterColumn) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error {
op := o.innerOperation()

return op.Complete(ctx, conn, tr, s)
// Perform any operation specific completion steps
if err := op.Complete(ctx, conn, tr, s); err != nil {
return err
}

if _, ok := op.(*OpRenameColumn); !ok {
// Drop the old column
_, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE IF EXISTS %s DROP COLUMN IF EXISTS %s",
pq.QuoteIdentifier(o.Table),
pq.QuoteIdentifier(o.Column)))
if err != nil {
return err
}

// Remove the up function and trigger
_, err = conn.ExecContext(ctx, fmt.Sprintf("DROP FUNCTION IF EXISTS %s CASCADE",
pq.QuoteIdentifier(TriggerFunctionName(o.Table, o.Column))))
if err != nil {
return err
}

// Remove the down function and trigger
_, err = conn.ExecContext(ctx, fmt.Sprintf("DROP FUNCTION IF EXISTS %s CASCADE",
pq.QuoteIdentifier(TriggerFunctionName(o.Table, TemporaryName(o.Column)))))
if err != nil {
return err
}

// Rename the new column to the old column name
table := s.GetTable(o.Table)
column := table.GetColumn(o.Column)
if err := RenameDuplicatedColumn(ctx, conn, table, column); err != nil {
return err
}
}

return nil
}

func (o *OpAlterColumn) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error {
op := o.innerOperation()

return op.Rollback(ctx, conn, tr)
// Perform any operation specific rollback steps
if err := op.Rollback(ctx, conn, tr); err != nil {
return err
}

if _, ok := op.(*OpRenameColumn); !ok {
// Drop the new column
_, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s DROP COLUMN IF EXISTS %s",
pq.QuoteIdentifier(o.Table),
pq.QuoteIdentifier(TemporaryName(o.Column)),
))
if err != nil {
return err
}

// Remove the up function and trigger
_, err = conn.ExecContext(ctx, fmt.Sprintf("DROP FUNCTION IF EXISTS %s CASCADE",
pq.QuoteIdentifier(TriggerFunctionName(o.Table, o.Column)),
))
if err != nil {
return err
}

// Remove the down function and trigger
_, err = conn.ExecContext(ctx, fmt.Sprintf("DROP FUNCTION IF EXISTS %s CASCADE",
pq.QuoteIdentifier(TriggerFunctionName(o.Table, TemporaryName(o.Column))),
))
if err != nil {
return err
}
}

return nil
}

func (o *OpAlterColumn) Validate(ctx context.Context, s *schema.Schema) error {
Expand Down Expand Up @@ -155,3 +282,45 @@ func (o *OpAlterColumn) numChanges() int {

return fieldsSet
}

// duplicatorForOperation returns a Duplicator for the given operation.
func duplicatorForOperation(op Operation, conn *sql.DB, table *schema.Table, column *schema.Column) *Duplicator {
d := NewColumnDuplicator(conn, table, column)

switch op := (op).(type) {
case *OpDropNotNull:
d = d.WithoutNotNull()
case *OpChangeType:
d = d.WithType(op.Type)
}
return d
}

// downSQLForOperation returns the down SQL for the given operation, applying
// an appropriate default if none is provided.
func (o *OpAlterColumn) downSQLForOperation(op Operation) string {
if o.Down != "" {
return o.Down
}

switch (op).(type) {
case *OpSetUnique, *OpSetNotNull:
return pq.QuoteIdentifier(o.Column)
}

return ""
}

// upSQLForOperation returns the up SQL for the given operation, applying
// an appropriate default if none is provided.
func (o *OpAlterColumn) upSQLForOperation(op Operation) string {
if o.Up != "" {
return o.Up
}

if _, ok := op.(*OpDropNotNull); ok {
return pq.QuoteIdentifier(o.Column)
}

return ""
}
99 changes: 1 addition & 98 deletions pkg/migrations/op_change_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@ package migrations
import (
"context"
"database/sql"
"fmt"

"github.com/lib/pq"
"github.com/xataio/pgroll/pkg/schema"
)

Expand All @@ -23,111 +21,16 @@ var _ Operation = (*OpChangeType)(nil)

func (o *OpChangeType) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) {
table := s.GetTable(o.Table)
column := table.GetColumn(o.Column)

// Create a copy of the column on the underlying table.
d := NewColumnDuplicator(conn, table, column).WithType(o.Type)
if err := d.Duplicate(ctx); err != nil {
return nil, fmt.Errorf("failed to duplicate column: %w", err)
}

// Add a trigger to copy values from the old column to the new, rewriting values using the `up` SQL.
err := createTrigger(ctx, conn, tr, triggerConfig{
Name: TriggerName(o.Table, o.Column),
Direction: TriggerDirectionUp,
Columns: table.Columns,
SchemaName: s.Name,
TableName: o.Table,
PhysicalColumn: TemporaryName(o.Column),
StateSchema: stateSchema,
SQL: o.Up,
})
if err != nil {
return nil, fmt.Errorf("failed to create up trigger: %w", err)
}

// Add the new column to the internal schema representation. This is done
// here, before creation of the down trigger, so that the trigger can declare
// a variable for the new column.
table.AddColumn(o.Column, schema.Column{
Name: TemporaryName(o.Column),
})

// Add a trigger to copy values from the new column to the old, rewriting values using the `down` SQL.
err = createTrigger(ctx, conn, tr, triggerConfig{
Name: TriggerName(o.Table, TemporaryName(o.Column)),
Direction: TriggerDirectionDown,
Columns: table.Columns,
SchemaName: s.Name,
TableName: o.Table,
PhysicalColumn: o.Column,
StateSchema: stateSchema,
SQL: o.Down,
})
if err != nil {
return nil, fmt.Errorf("failed to create down trigger: %w", err)
}

return table, nil
}

func (o *OpChangeType) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error {
// Remove the up function and trigger
_, err := conn.ExecContext(ctx, fmt.Sprintf("DROP FUNCTION IF EXISTS %s CASCADE",
pq.QuoteIdentifier(TriggerFunctionName(o.Table, o.Column))))
if err != nil {
return err
}

// Remove the down function and trigger
_, err = conn.ExecContext(ctx, fmt.Sprintf("DROP FUNCTION IF EXISTS %s CASCADE",
pq.QuoteIdentifier(TriggerFunctionName(o.Table, TemporaryName(o.Column)))))
if err != nil {
return err
}

// Drop the old column
_, err = conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE IF EXISTS %s DROP COLUMN IF EXISTS %s",
pq.QuoteIdentifier(o.Table),
pq.QuoteIdentifier(o.Column)))
if err != nil {
return err
}

// Rename the new column to the old column name
table := s.GetTable(o.Table)
column := table.GetColumn(o.Column)
if err := RenameDuplicatedColumn(ctx, conn, table, column); err != nil {
return err
}

return nil
}

func (o *OpChangeType) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error {
// Drop the new column
_, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s DROP COLUMN IF EXISTS %s",
pq.QuoteIdentifier(o.Table),
pq.QuoteIdentifier(TemporaryName(o.Column)),
))
if err != nil {
return err
}

// Remove the up function and trigger
_, err = conn.ExecContext(ctx, fmt.Sprintf("DROP FUNCTION IF EXISTS %s CASCADE",
pq.QuoteIdentifier(TriggerFunctionName(o.Table, o.Column)),
))
if err != nil {
return err
}

// Remove the down function and trigger
_, err = conn.ExecContext(ctx, fmt.Sprintf("DROP FUNCTION IF EXISTS %s CASCADE",
pq.QuoteIdentifier(TriggerFunctionName(o.Table, TemporaryName(o.Column))),
))

return err
return nil
}

func (o *OpChangeType) Validate(ctx context.Context, s *schema.Schema) error {
Expand Down
Loading

0 comments on commit d4445bb

Please sign in to comment.