Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
ryanslade committed Dec 18, 2024
1 parent 0d88089 commit 73906e3
Show file tree
Hide file tree
Showing 3 changed files with 320 additions and 28 deletions.
157 changes: 129 additions & 28 deletions pkg/sql2pgroll/alter_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ func convertAlterTableStmt(stmt *pgq.AlterTableStmt) (migrations.Operations, err
op, err = convertAlterTableSetColumnDefault(stmt, alterTableCmd)
case pgq.AlterTableType_AT_DropConstraint:
op, err = convertAlterTableDropConstraint(stmt, alterTableCmd)
case pgq.AlterTableType_AT_AddColumn:
op, err = convertAlterTableAddColumn(stmt, alterTableCmd)
}

if err != nil {
Expand Down Expand Up @@ -198,20 +200,9 @@ func convertAlterTableAddForeignKeyConstraint(stmt *pgq.AlterTableStmt, constrai
migs[column] = PlaceHolderSQL
}

var onDelete migrations.ForeignKeyReferenceOnDelete
switch constraint.GetFkDelAction() {
case "a":
onDelete = migrations.ForeignKeyReferenceOnDeleteNOACTION
case "c":
onDelete = migrations.ForeignKeyReferenceOnDeleteCASCADE
case "r":
onDelete = migrations.ForeignKeyReferenceOnDeleteRESTRICT
case "d":
onDelete = migrations.ForeignKeyReferenceOnDeleteSETDEFAULT
case "n":
onDelete = migrations.ForeignKeyReferenceOnDeleteSETNULL
default:
return nil, fmt.Errorf("unknown delete action: %q", constraint.GetFkDelAction())
onDelete, err := parseOnDeleteAction(constraint.GetFkDelAction())
if err != nil {
return nil, fmt.Errorf("failed to parse on delete action: %w", err)
}

tableName := getQualifiedRelationName(stmt.Relation)
Expand All @@ -232,6 +223,23 @@ func convertAlterTableAddForeignKeyConstraint(stmt *pgq.AlterTableStmt, constrai
}, nil
}

func parseOnDeleteAction(action string) (migrations.ForeignKeyReferenceOnDelete, error) {
switch action {
case "a":
return migrations.ForeignKeyReferenceOnDeleteNOACTION, nil
case "c":
return migrations.ForeignKeyReferenceOnDeleteCASCADE, nil
case "r":
return migrations.ForeignKeyReferenceOnDeleteRESTRICT, nil
case "d":
return migrations.ForeignKeyReferenceOnDeleteSETDEFAULT, nil
case "n":
return migrations.ForeignKeyReferenceOnDeleteSETNULL, nil
default:
return migrations.ForeignKeyReferenceOnDeleteNOACTION, fmt.Errorf("unknown delete action: %q", action)
}
}

func canConvertAlterTableAddForeignKeyConstraint(constraint *pgq.Constraint) bool {
if constraint.SkipValidation {
return false
Expand Down Expand Up @@ -319,21 +327,12 @@ func convertAlterTableSetColumnDefault(stmt *pgq.AlterTableStmt, cmd *pgq.AlterT
Up: PlaceHolderSQL,
}

if c := cmd.GetDef().GetAConst(); c != nil {
if c.GetIsnull() {
// The default can be set to null
operation.Default = nullable.NewNullNullable[string]()
return operation, nil
}
def, err := extractDefault(cmd.GetDef())
if err != nil {
return nil, err
}

// We're setting it to an expression
if cmd.GetDef() != nil {
def, err := pgq.DeparseExpr(cmd.GetDef())
if err != nil {
return nil, fmt.Errorf("failed to deparse expression: %w", err)
}
operation.Default = nullable.NewNullableWithValue(def)
if def.IsSpecified() {
operation.Default = def
return operation, nil
}

Expand All @@ -347,6 +346,24 @@ func convertAlterTableSetColumnDefault(stmt *pgq.AlterTableStmt, cmd *pgq.AlterT
return nil, nil
}

func extractDefault(node *pgq.Node) (nullable.Nullable[string], error) {
if c := node.GetAConst(); c != nil && c.GetIsnull() {
// The default can be set to null
return nullable.NewNullNullable[string](), nil
}

// It's an expression
if node != nil {
def, err := pgq.DeparseExpr(node)
if err != nil {
return nil, fmt.Errorf("failed to deparse expression: %w", err)
}
return nullable.NewNullableWithValue(def), nil
}

return nil, nil
}

// convertAlterTableDropConstraint convert DROP CONSTRAINT SQL into an OpDropMultiColumnConstraint.
// Because we are unable to infer the columns involved, placeholder migrations are used.
//
Expand Down Expand Up @@ -380,6 +397,90 @@ func canConvertDropConstraint(cmd *pgq.AlterTableCmd) bool {
return cmd.Behavior != pgq.DropBehavior_DROP_CASCADE
}

func convertAlterTableAddColumn(stmt *pgq.AlterTableStmt, cmd *pgq.AlterTableCmd) (migrations.Operation, error) {
if !canConvertAddColumn(cmd) {
return nil, nil
}

columnDef := cmd.GetDef().GetColumnDef()

columnType, err := pgq.DeparseTypeName(columnDef.GetTypeName())
if err != nil {
return nil, fmt.Errorf("failed to deparse type name: %w", err)
}

operation := &migrations.OpAddColumn{
Column: migrations.Column{
Name: columnDef.GetColname(),
Type: columnType,
},
Table: getQualifiedRelationName(stmt.GetRelation()),
Up: PlaceHolderSQL,
}

// Currently only handles DEFAULT
if len(columnDef.GetConstraints()) > 0 {
for _, constraint := range columnDef.GetConstraints() {
switch constraint.GetConstraint().GetContype() {
case pgq.ConstrType_CONSTR_NULL:
operation.Column.Nullable = true
case pgq.ConstrType_CONSTR_PRIMARY:
operation.Column.Pk = true
case pgq.ConstrType_CONSTR_UNIQUE:
operation.Column.Unique = true
case pgq.ConstrType_CONSTR_CHECK:
raw, err := pgq.DeparseExpr(constraint.GetConstraint().GetRawExpr())
if err != nil {
return nil, fmt.Errorf("failed to deparse raw expression: %w", err)
}
operation.Column.Check = &migrations.CheckConstraint{
Constraint: raw,
Name: constraint.GetConstraint().GetConname(),
}
case pgq.ConstrType_CONSTR_DEFAULT:
defaultExpr := constraint.GetConstraint().GetRawExpr()
def, err := extractDefault(defaultExpr)
if err != nil {
return nil, err
}
if !def.IsNull() {
v := def.MustGet()
operation.Column.Default = &v
}
case pgq.ConstrType_CONSTR_FOREIGN:
onDelete, err := parseOnDeleteAction(constraint.GetConstraint().GetFkDelAction())
if err != nil {
return nil, err
}
fk := &migrations.ForeignKeyReference{
Name: constraint.GetConstraint().GetConname(),
OnDelete: onDelete,
Column: constraint.GetConstraint().GetPkAttrs()[0].GetString_().GetSval(),
Table: getQualifiedRelationName(constraint.GetConstraint().GetPktable()),
}
operation.Column.References = fk
}
}
}

return operation, nil
}

func canConvertAddColumn(cmd *pgq.AlterTableCmd) bool {
for _, constraint := range cmd.GetDef().GetColumnDef().GetConstraints() {
switch constraint.GetConstraint().GetFkUpdAction() {
case "r", "c", "n", "d":
// RESTRICT, CASCADE, SET NULL, SET DEFAULT
return false
case "a":
// NO ACTION, the default
break
}
}

return true
}

func convertAlterTableDropColumn(stmt *pgq.AlterTableStmt, cmd *pgq.AlterTableCmd) (migrations.Operation, error) {
if !canConvertDropColumn(cmd) {
return nil, nil
Expand Down
80 changes: 80 additions & 0 deletions pkg/sql2pgroll/alter_table_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,76 @@ func TestConvertAlterTableStatements(t *testing.T) {
sql: "ALTER TABLE schema.foo ADD CONSTRAINT bar CHECK (age > 0)",
expectedOp: expect.CreateConstraintOp4,
},

// Add column
{
sql: "ALTER TABLE foo ADD COLUMN bar int",
expectedOp: expect.AddColumnOp1,
},
{
sql: "ALTER TABLE schema.foo ADD COLUMN bar int",
expectedOp: expect.AddColumnOp2,
},
{
sql: "ALTER TABLE foo ADD COLUMN bar int DEFAULT 123",
expectedOp: expect.AddColumnOp1WithDefault(ptr("123")),
},
{
sql: "ALTER TABLE foo ADD COLUMN bar int DEFAULT 'baz'",
expectedOp: expect.AddColumnOp1WithDefault(ptr("'baz'")),
},
{
sql: "ALTER TABLE foo ADD COLUMN bar int DEFAULT null",
expectedOp: expect.AddColumnOp1WithDefault(nil),
},
{
sql: "ALTER TABLE foo ADD COLUMN bar int NULL",
expectedOp: expect.AddColumnOp3,
},
{
sql: "ALTER TABLE foo ADD COLUMN bar int UNIQUE",
expectedOp: expect.AddColumnOp4,
},
{
sql: "ALTER TABLE foo ADD COLUMN bar int PRIMARY KEY",
expectedOp: expect.AddColumnOp5,
},
{
sql: "ALTER TABLE foo ADD COLUMN bar int CHECK (bar > 0)",
expectedOp: expect.AddColumnOp6,
},
{
sql: "ALTER TABLE foo ADD COLUMN bar int CONSTRAINT check_bar CHECK (bar > 0)",
expectedOp: expect.AddColumnOp7,
},
{
sql: "ALTER TABLE foo ADD COLUMN bar int CONSTRAINT fk_baz REFERENCES baz (bar)",
expectedOp: expect.AddColumnOp8WithOnDeleteAction(migrations.ForeignKeyReferenceOnDeleteNOACTION),
},
{
sql: "ALTER TABLE foo ADD COLUMN bar int CONSTRAINT fk_baz REFERENCES baz (bar) ON UPDATE NO ACTION",
expectedOp: expect.AddColumnOp8WithOnDeleteAction(migrations.ForeignKeyReferenceOnDeleteNOACTION),
},
{
sql: "ALTER TABLE foo ADD COLUMN bar int CONSTRAINT fk_baz REFERENCES baz (bar) ON DELETE NO ACTION",
expectedOp: expect.AddColumnOp8WithOnDeleteAction(migrations.ForeignKeyReferenceOnDeleteNOACTION),
},
{
sql: "ALTER TABLE foo ADD COLUMN bar int CONSTRAINT fk_baz REFERENCES baz (bar) ON DELETE RESTRICT",
expectedOp: expect.AddColumnOp8WithOnDeleteAction(migrations.ForeignKeyReferenceOnDeleteRESTRICT),
},
{
sql: "ALTER TABLE foo ADD COLUMN bar int CONSTRAINT fk_baz REFERENCES baz (bar) ON DELETE SET NULL ",
expectedOp: expect.AddColumnOp8WithOnDeleteAction(migrations.ForeignKeyReferenceOnDeleteSETNULL),
},
{
sql: "ALTER TABLE foo ADD COLUMN bar int CONSTRAINT fk_baz REFERENCES baz (bar) ON DELETE SET DEFAULT",
expectedOp: expect.AddColumnOp8WithOnDeleteAction(migrations.ForeignKeyReferenceOnDeleteSETDEFAULT),
},
{
sql: "ALTER TABLE foo ADD COLUMN bar int CONSTRAINT fk_baz REFERENCES baz (bar) ON DELETE CASCADE",
expectedOp: expect.AddColumnOp8WithOnDeleteAction(migrations.ForeignKeyReferenceOnDeleteCASCADE),
},
}

for _, tc := range tests {
Expand Down Expand Up @@ -191,6 +261,12 @@ func TestUnconvertableAlterTableStatements(t *testing.T) {
// representable by `OpCreateConstraint`
"ALTER TABLE foo ADD CONSTRAINT bar CHECK (age > 0) NO INHERIT",
"ALTER TABLE foo ADD CONSTRAINT bar CHECK (age > 0) NOT VALID",

// ADD COLUMN cases not yet covered
"ALTER TABLE foo ADD COLUMN bar int REFERENCES bar (c) ON UPDATE RESTRICT",
"ALTER TABLE foo ADD COLUMN bar int REFERENCES bar (c) ON UPDATE CASCADE",
"ALTER TABLE foo ADD COLUMN bar int REFERENCES bar (c) ON UPDATE SET NULL",
"ALTER TABLE foo ADD COLUMN bar int REFERENCES bar (c) ON UPDATE SET DEFAULT",
}

for _, sql := range tests {
Expand All @@ -204,3 +280,7 @@ func TestUnconvertableAlterTableStatements(t *testing.T) {
})
}
}

func ptr[T any](v T) *T {
return &v
}
Loading

0 comments on commit 73906e3

Please sign in to comment.