diff --git a/pkg/sql2pgroll/create_table.go b/pkg/sql2pgroll/create_table.go index 01be619d..68a646be 100644 --- a/pkg/sql2pgroll/create_table.go +++ b/pkg/sql2pgroll/create_table.go @@ -4,7 +4,6 @@ package sql2pgroll import ( "fmt" - "slices" pgq "github.com/xataio/pg_query_go/v6" @@ -87,14 +86,6 @@ func convertColumnDef(tableName string, col *pgq.ColumnDef) (*migrations.Column, return nil, fmt.Errorf("error deparsing column type: %w", err) } - // Named inline constraints are not supported - anyNamed := slices.ContainsFunc(col.GetConstraints(), func(c *pgq.Node) bool { - return c.GetConstraint().GetConname() != "" - }) - if anyNamed { - return nil, nil - } - // Convert column constraints var notNull, pk, unique bool var check *migrations.CheckConstraint @@ -103,15 +94,31 @@ func convertColumnDef(tableName string, col *pgq.ColumnDef) (*migrations.Column, for _, c := range col.GetConstraints() { switch c.GetConstraint().GetContype() { case pgq.ConstrType_CONSTR_NULL: + // named NULL constraints are not supported + if isConstraintNamed(c.GetConstraint()) { + return nil, nil + } notNull = false case pgq.ConstrType_CONSTR_NOTNULL: + // named NOT NULL constraints are not supported + if isConstraintNamed(c.GetConstraint()) { + return nil, nil + } notNull = true case pgq.ConstrType_CONSTR_UNIQUE: + // named UNIQUE constraints are not supported + if isConstraintNamed(c.GetConstraint()) { + return nil, nil + } if !canConvertUniqueConstraint(c.GetConstraint()) { return nil, nil } unique = true case pgq.ConstrType_CONSTR_PRIMARY: + // named PRIMARY KEY constraints are not supported + if isConstraintNamed(c.GetConstraint()) { + return nil, nil + } if !canConvertPrimaryKeyConstraint(c.GetConstraint()) { return nil, nil } @@ -126,6 +133,10 @@ func convertColumnDef(tableName string, col *pgq.ColumnDef) (*migrations.Column, return nil, nil } case pgq.ConstrType_CONSTR_DEFAULT: + // named DEFAULT constraints are not supported + if isConstraintNamed(c.GetConstraint()) { + return nil, nil + } d, err := pgq.DeparseExpr(c.GetConstraint().GetRawExpr()) if err != nil { return nil, fmt.Errorf("error deparsing default value: %w", err) @@ -254,3 +265,15 @@ func convertInlineForeignKeyConstraint(tableName, columnName string, constraint Table: getQualifiedRelationName(constraint.GetPktable()), }, nil } + +// isConstraintNamed returns true iff `constraint` has a name. +// Column constraints defined inline in CREATE TABLE statements cab be either +// named or unnamed, for example: +// - CREATE TABLE t (a INT PRIMARY KEY); +// - CREATE TABLE t (a INT CONSTRAINT my_pk PRIMARY KEY); +// Likewise, table constraints can also be either named or unnamed, for example: +// - CREATE TABLE foo(a int, CONSTRAINT foo_check CHECK (a > 0)), +// - CREATE TABLE foo(a int, CHECK (a > 0)), +func isConstraintNamed(constraint *pgq.Constraint) bool { + return constraint.GetConname() != "" +} diff --git a/pkg/sql2pgroll/create_table_test.go b/pkg/sql2pgroll/create_table_test.go index f7235c3c..591fb50c 100644 --- a/pkg/sql2pgroll/create_table_test.go +++ b/pkg/sql2pgroll/create_table_test.go @@ -56,10 +56,22 @@ func TestConvertCreateTableStatements(t *testing.T) { sql: "CREATE TABLE foo(a int CHECK (a > 0))", expectedOp: expect.CreateTableOp10, }, + { + sql: "CREATE TABLE foo(a int CONSTRAINT my_check CHECK (a > 0))", + expectedOp: expect.CreateTableOp18, + }, { sql: "CREATE TABLE foo(a timestamptz DEFAULT now())", expectedOp: expect.CreateTableOp11, }, + { + sql: "CREATE TABLE foo(a int CONSTRAINT my_fk REFERENCES bar(b))", + expectedOp: expect.CreateTableOp19, + }, + { + sql: "CREATE TABLE foo(a int REFERENCES bar(b))", + expectedOp: expect.CreateTableOp12, + }, { sql: "CREATE TABLE foo(a int REFERENCES bar(b) NOT DEFERRABLE)", expectedOp: expect.CreateTableOp12, @@ -227,14 +239,13 @@ func TestUnconvertableCreateTableStatements(t *testing.T) { "CREATE TABLE foo(a int REFERENCES bar (b) ON UPDATE SET DEFAULT)", "CREATE TABLE foo(a int REFERENCES bar (b) MATCH FULL)", - // Named inline constraints are not supported - "CREATE TABLE foo(a int CONSTRAINT foo_check CHECK (a > 0))", - "CREATE TABLE foo(a int CONSTRAINT foo_unique UNIQUE)", - "CREATE TABLE foo(a int CONSTRAINT foo_pk PRIMARY KEY)", - "CREATE TABLE foo(a int CONSTRAINT foo_fk REFERENCES bar(b))", + // Named inline constraints are not supported for DEFAULT, NULL, NOT NULL, + // UNIQUE or PRIMARY KEY constraints "CREATE TABLE foo(a int CONSTRAINT foo_default DEFAULT 0)", "CREATE TABLE foo(a int CONSTRAINT foo_null NULL)", "CREATE TABLE foo(a int CONSTRAINT foo_notnull NOT NULL)", + "CREATE TABLE foo(a int CONSTRAINT foo_unique UNIQUE)", + "CREATE TABLE foo(a int CONSTRAINT foo_pk PRIMARY KEY)", // Generated columns are not supported "CREATE TABLE foo(a int GENERATED ALWAYS AS (1) STORED)", diff --git a/pkg/sql2pgroll/expect/create_table.go b/pkg/sql2pgroll/expect/create_table.go index 93c68ca6..b46adaa5 100644 --- a/pkg/sql2pgroll/expect/create_table.go +++ b/pkg/sql2pgroll/expect/create_table.go @@ -225,3 +225,35 @@ var CreateTableOp17 = &migrations.OpCreateTable{ }, }, } + +var CreateTableOp18 = &migrations.OpCreateTable{ + Name: "foo", + Columns: []migrations.Column{ + { + Name: "a", + Type: "int", + Nullable: true, + Check: &migrations.CheckConstraint{ + Name: "my_check", + Constraint: "a > 0", + }, + }, + }, +} + +var CreateTableOp19 = &migrations.OpCreateTable{ + Name: "foo", + Columns: []migrations.Column{ + { + Name: "a", + Type: "int", + Nullable: true, + References: &migrations.ForeignKeyReference{ + Name: "my_fk", + Table: "bar", + Column: "b", + OnDelete: migrations.ForeignKeyReferenceOnDeleteNOACTION, + }, + }, + }, +}