diff --git a/pkg/sql2pgroll/alter_table.go b/pkg/sql2pgroll/alter_table.go index ca693aac..db65d33f 100644 --- a/pkg/sql2pgroll/alter_table.go +++ b/pkg/sql2pgroll/alter_table.go @@ -11,7 +11,10 @@ import ( "github.com/xataio/pgroll/pkg/migrations" ) -const PlaceHolderSQL = "TODO: Implement SQL data migration" +const ( + PlaceHolderColumnName = "placeholder" + PlaceHolderSQL = "TODO: Implement SQL data migration" +) // convertAlterTableStmt converts an ALTER TABLE statement to pgroll operations. func convertAlterTableStmt(stmt *pgq.AlterTableStmt) (migrations.Operations, error) { @@ -99,11 +102,12 @@ func convertAlterTableAlterColumnType(stmt *pgq.AlterTableStmt, cmd *pgq.AlterTa }, nil } -// convertAlterTableAddConstraint converts SQL statements that add UNIQUE or FOREIGN KEY constraints, +// convertAlterTableAddConstraint converts SQL statements that add constraints, // for example: // // `ALTER TABLE foo ADD CONSTRAINT bar UNIQUE (a)` // `ALTER TABLE foo ADD CONSTRAINT fk_bar_c FOREIGN KEY (a) REFERENCES bar (c);` +// `ALTER TABLE foo ADD CONSTRAINT bar CHECK (age > 0)` // // An OpCreateConstraint operation is returned. func convertAlterTableAddConstraint(stmt *pgq.AlterTableStmt, cmd *pgq.AlterTableCmd) (migrations.Operation, error) { @@ -119,6 +123,8 @@ func convertAlterTableAddConstraint(stmt *pgq.AlterTableStmt, cmd *pgq.AlterTabl op, err = convertAlterTableAddUniqueConstraint(stmt, node.Constraint) case pgq.ConstrType_CONSTR_FOREIGN: op, err = convertAlterTableAddForeignKeyConstraint(stmt, node.Constraint) + case pgq.ConstrType_CONSTR_CHECK: + op, err = convertAlterTableAddCheckConstraint(stmt, node.Constraint) default: return nil, nil } @@ -229,6 +235,10 @@ func convertAlterTableAddForeignKeyConstraint(stmt *pgq.AlterTableStmt, constrai } func canConvertAlterTableAddForeignKeyConstraint(constraint *pgq.Constraint) bool { + if constraint.SkipValidation { + return false + } + switch constraint.GetFkUpdAction() { case "r", "c", "n", "d": // RESTRICT, CASCADE, SET NULL, SET DEFAULT @@ -248,6 +258,53 @@ func canConvertAlterTableAddForeignKeyConstraint(constraint *pgq.Constraint) boo return true } +// convertAlterTableAddCheckConstraint converts SQL statements like: +// +// `ALTER TABLE foo ADD CONSTRAINT bar CHECK (age > 0)` +// +// to an OpCreateConstraint operation. +func convertAlterTableAddCheckConstraint(stmt *pgq.AlterTableStmt, constraint *pgq.Constraint) (migrations.Operation, error) { + if !canConvertCheckConstraint(constraint) { + return nil, nil + } + + tableName := stmt.GetRelation().GetRelname() + if stmt.GetRelation().GetSchemaname() != "" { + tableName = stmt.GetRelation().GetSchemaname() + "." + tableName + } + + expr, err := pgq.DeparseExpr(constraint.GetRawExpr()) + if err != nil { + return nil, fmt.Errorf("failed to deparse CHECK expression: %w", err) + } + + return &migrations.OpCreateConstraint{ + Type: migrations.OpCreateConstraintTypeCheck, + Name: constraint.GetConname(), + Table: tableName, + Check: ptr(expr), + Columns: []string{PlaceHolderColumnName}, + Up: migrations.MultiColumnUpSQL{ + PlaceHolderColumnName: PlaceHolderSQL, + }, + Down: migrations.MultiColumnDownSQL{ + PlaceHolderColumnName: PlaceHolderSQL, + }, + }, nil +} + +// canConvertCheckConstraint checks if the CHECK constraint `constraint` can +// be faithfully converted to an OpCreateConstraint operation without losing +// information. +func canConvertCheckConstraint(constraint *pgq.Constraint) bool { + switch { + case constraint.IsNoInherit, constraint.SkipValidation: + return false + default: + return true + } +} + // convertAlterTableSetColumnDefault converts SQL statements like: // // `ALTER TABLE foo COLUMN bar SET DEFAULT 'foo'` @@ -317,10 +374,10 @@ func convertAlterTableDropConstraint(stmt *pgq.AlterTableStmt, cmd *pgq.AlterTab return &migrations.OpDropMultiColumnConstraint{ Up: migrations.MultiColumnUpSQL{ - "placeholder": PlaceHolderSQL, + PlaceHolderColumnName: PlaceHolderSQL, }, Down: migrations.MultiColumnDownSQL{ - "placeholder": PlaceHolderSQL, + PlaceHolderColumnName: PlaceHolderSQL, }, Table: tableName, Name: cmd.GetName(), diff --git a/pkg/sql2pgroll/alter_table_test.go b/pkg/sql2pgroll/alter_table_test.go index d50e1c90..dd584db6 100644 --- a/pkg/sql2pgroll/alter_table_test.go +++ b/pkg/sql2pgroll/alter_table_test.go @@ -112,10 +112,6 @@ func TestConvertAlterTableStatements(t *testing.T) { sql: "ALTER TABLE foo ADD CONSTRAINT fk_bar_c FOREIGN KEY (a) REFERENCES bar (c);", expectedOp: expect.AddForeignKeyOp2, }, - { - sql: "ALTER TABLE foo ADD CONSTRAINT fk_bar_c FOREIGN KEY (a) REFERENCES bar (c) NOT VALID;", - expectedOp: expect.AddForeignKeyOp2, - }, { sql: "ALTER TABLE schema_a.foo ADD CONSTRAINT fk_bar_c FOREIGN KEY (a) REFERENCES schema_a.bar (c);", expectedOp: expect.AddForeignKeyOp3, @@ -136,6 +132,14 @@ func TestConvertAlterTableStatements(t *testing.T) { sql: "ALTER TABLE foo DROP CONSTRAINT IF EXISTS constraint_foo RESTRICT", expectedOp: expect.OpDropConstraintWithTable("foo"), }, + { + sql: "ALTER TABLE foo ADD CONSTRAINT bar CHECK (age > 0)", + expectedOp: expect.CreateConstraintOp3, + }, + { + sql: "ALTER TABLE schema.foo ADD CONSTRAINT bar CHECK (age > 0)", + expectedOp: expect.CreateConstraintOp4, + }, } for _, tc := range tests { @@ -176,11 +180,17 @@ func TestUnconvertableAlterTableStatements(t *testing.T) { "ALTER TABLE foo ADD CONSTRAINT fk_bar_cd FOREIGN KEY (a, b) REFERENCES bar (c, d) ON UPDATE SET NULL;", "ALTER TABLE foo ADD CONSTRAINT fk_bar_cd FOREIGN KEY (a, b) REFERENCES bar (c, d) ON UPDATE SET DEFAULT;", "ALTER TABLE foo ADD CONSTRAINT fk_bar_cd FOREIGN KEY (a, b) REFERENCES bar (c, d) MATCH FULL;", + "ALTER TABLE foo ADD CONSTRAINT fk_bar_cd FOREIGN KEY (a, b) REFERENCES bar (c, d) NOT VALID", // MATCH PARTIAL is not implemented in the actual parser yet //"ALTER TABLE foo ADD CONSTRAINT fk_bar_cd FOREIGN KEY (a, b) REFERENCES bar (c, d) MATCH PARTIAL;", // Drop constraint with CASCADE "ALTER TABLE foo DROP CONSTRAINT bar CASCADE", + + // NO INHERIT and NOT VALID options on CHECK constraints are not + // representable by `OpCreateConstraint` + "ALTER TABLE foo ADD CONSTRAINT bar CHECK (age > 0) NO INHERIT", + "ALTER TABLE foo ADD CONSTRAINT bar CHECK (age > 0) NOT VALID", } for _, sql := range tests { diff --git a/pkg/sql2pgroll/expect/create_constraint.go b/pkg/sql2pgroll/expect/create_constraint.go index a408fcae..d98a1c21 100644 --- a/pkg/sql2pgroll/expect/create_constraint.go +++ b/pkg/sql2pgroll/expect/create_constraint.go @@ -30,3 +30,31 @@ var CreateConstraintOp2 = &migrations.OpCreateConstraint{ "b": sql2pgroll.PlaceHolderSQL, }, } + +var CreateConstraintOp3 = &migrations.OpCreateConstraint{ + Type: migrations.OpCreateConstraintTypeCheck, + Name: "bar", + Table: "foo", + Check: ptr("age > 0"), + Columns: []string{sql2pgroll.PlaceHolderColumnName}, + Up: map[string]string{ + sql2pgroll.PlaceHolderColumnName: sql2pgroll.PlaceHolderSQL, + }, + Down: map[string]string{ + sql2pgroll.PlaceHolderColumnName: sql2pgroll.PlaceHolderSQL, + }, +} + +var CreateConstraintOp4 = &migrations.OpCreateConstraint{ + Type: migrations.OpCreateConstraintTypeCheck, + Name: "bar", + Table: "schema.foo", + Check: ptr("age > 0"), + Columns: []string{sql2pgroll.PlaceHolderColumnName}, + Up: map[string]string{ + sql2pgroll.PlaceHolderColumnName: sql2pgroll.PlaceHolderSQL, + }, + Down: map[string]string{ + sql2pgroll.PlaceHolderColumnName: sql2pgroll.PlaceHolderSQL, + }, +} diff --git a/pkg/sql2pgroll/expect/drop_constraint.go b/pkg/sql2pgroll/expect/drop_constraint.go index 206789c2..b44d6c73 100644 --- a/pkg/sql2pgroll/expect/drop_constraint.go +++ b/pkg/sql2pgroll/expect/drop_constraint.go @@ -10,10 +10,10 @@ import ( func OpDropConstraintWithTable(table string) *migrations.OpDropMultiColumnConstraint { return &migrations.OpDropMultiColumnConstraint{ Up: migrations.MultiColumnUpSQL{ - "placeholder": sql2pgroll.PlaceHolderSQL, + sql2pgroll.PlaceHolderColumnName: sql2pgroll.PlaceHolderSQL, }, Down: migrations.MultiColumnDownSQL{ - "placeholder": sql2pgroll.PlaceHolderSQL, + sql2pgroll.PlaceHolderColumnName: sql2pgroll.PlaceHolderSQL, }, Table: table, Name: "constraint_foo",