diff --git a/pkg/sql2pgroll/alter_table.go b/pkg/sql2pgroll/alter_table.go index 699757d9..752c4b32 100644 --- a/pkg/sql2pgroll/alter_table.go +++ b/pkg/sql2pgroll/alter_table.go @@ -4,7 +4,6 @@ package sql2pgroll import ( "fmt" - "strconv" "github.com/oapi-codegen/nullable" pgq "github.com/xataio/pg_query_go/v6" @@ -272,29 +271,16 @@ func convertAlterTableSetColumnDefault(stmt *pgq.AlterTableStmt, cmd *pgq.AlterT operation.Default = nullable.NewNullNullable[string]() return operation, nil } - - // We have a constant - switch v := c.GetVal().(type) { - case *pgq.A_Const_Sval: - operation.Default = nullable.NewNullableWithValue(fmt.Sprintf("'%s'", v.Sval.GetSval())) - case *pgq.A_Const_Ival: - operation.Default = nullable.NewNullableWithValue(strconv.FormatInt(int64(v.Ival.Ival), 10)) - case *pgq.A_Const_Fval: - operation.Default = nullable.NewNullableWithValue(v.Fval.Fval) - case *pgq.A_Const_Boolval: - operation.Default = nullable.NewNullableWithValue(strconv.FormatBool(v.Boolval.Boolval)) - case *pgq.A_Const_Bsval: - operation.Default = nullable.NewNullableWithValue(fmt.Sprintf("'%s'", v.Bsval.GetBsval())) - default: - return nil, fmt.Errorf("unknown constant type: %T", c.GetVal()) - } - - return operation, nil } + // We're setting it to an expression if cmd.GetDef() != nil { - // We're setting it to something other than a constant - return nil, nil + def, err := pgq.DeparseExpr(cmd.GetDef()) + if err != nil { + return nil, fmt.Errorf("failed to deparse expression: %w", err) + } + operation.Default = nullable.NewNullableWithValue(def) + return operation, nil } // We're not setting it to anything, which is the case when we are dropping it diff --git a/pkg/sql2pgroll/alter_table_test.go b/pkg/sql2pgroll/alter_table_test.go index f3665aa2..071a1410 100644 --- a/pkg/sql2pgroll/alter_table_test.go +++ b/pkg/sql2pgroll/alter_table_test.go @@ -64,6 +64,14 @@ func TestConvertAlterTableStatements(t *testing.T) { sql: "ALTER TABLE foo ALTER COLUMN bar SET DEFAULT null", expectedOp: expect.AlterColumnOp7, }, + { + sql: "ALTER TABLE foo ALTER COLUMN bar SET DEFAULT now()", + expectedOp: expect.AlterColumnOp11, + }, + { + sql: "ALTER TABLE foo ALTER COLUMN bar SET DEFAULT (first_name || ' ' || last_name)", + expectedOp: expect.AlterColumnOp12, + }, { sql: "ALTER TABLE foo ADD CONSTRAINT bar UNIQUE (a)", expectedOp: expect.CreateConstraintOp1, @@ -146,9 +154,6 @@ func TestUnconvertableAlterTableStatements(t *testing.T) { "ALTER TABLE foo DROP COLUMN bar CASCADE", "ALTER TABLE foo DROP COLUMN IF EXISTS bar", - // Non literal default values - "ALTER TABLE foo ALTER COLUMN bar SET DEFAULT now()", - // Unsupported foreign key statements "ALTER TABLE foo ADD CONSTRAINT fk_bar_cd FOREIGN KEY (a, b) REFERENCES bar (c, d) ON UPDATE RESTRICT;", "ALTER TABLE foo ADD CONSTRAINT fk_bar_cd FOREIGN KEY (a, b) REFERENCES bar (c, d) ON UPDATE CASCADE;", diff --git a/pkg/sql2pgroll/expect/alter_column.go b/pkg/sql2pgroll/expect/alter_column.go index 231f6663..adfdbd57 100644 --- a/pkg/sql2pgroll/expect/alter_column.go +++ b/pkg/sql2pgroll/expect/alter_column.go @@ -82,7 +82,23 @@ var AlterColumnOp9 = &migrations.OpAlterColumn{ var AlterColumnOp10 = &migrations.OpAlterColumn{ Table: "foo", Column: "bar", - Default: nullable.NewNullableWithValue("'b0101'"), + Default: nullable.NewNullableWithValue("b'0101'"), + Up: sql2pgroll.PlaceHolderSQL, + Down: sql2pgroll.PlaceHolderSQL, +} + +var AlterColumnOp11 = &migrations.OpAlterColumn{ + Table: "foo", + Column: "bar", + Default: nullable.NewNullableWithValue("now()"), + Up: sql2pgroll.PlaceHolderSQL, + Down: sql2pgroll.PlaceHolderSQL, +} + +var AlterColumnOp12 = &migrations.OpAlterColumn{ + Table: "foo", + Column: "bar", + Default: nullable.NewNullableWithValue("(first_name || ' ') || last_name"), Up: sql2pgroll.PlaceHolderSQL, Down: sql2pgroll.PlaceHolderSQL, }