diff --git a/pkg/migrations/op_add_column.go b/pkg/migrations/op_add_column.go index c27a6d32..09d81699 100644 --- a/pkg/migrations/op_add_column.go +++ b/pkg/migrations/op_add_column.go @@ -18,7 +18,7 @@ var _ Operation = (*OpAddColumn)(nil) func (o *OpAddColumn) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { table := s.GetTable(o.Table) - if err := addColumn(ctx, conn, *o, table); err != nil { + if err := addColumn(ctx, conn, *o, table, tr); err != nil { return nil, fmt.Errorf("failed to start add column operation: %w", err) } @@ -182,7 +182,7 @@ func (o *OpAddColumn) Validate(ctx context.Context, s *schema.Schema) error { return nil } -func addColumn(ctx context.Context, conn *sql.DB, o OpAddColumn, t *schema.Table) error { +func addColumn(ctx context.Context, conn *sql.DB, o OpAddColumn, t *schema.Table, tr SQLTransformer) error { // don't add non-nullable columns with no default directly // they are handled by: // - adding the column as nullable @@ -203,10 +203,16 @@ func addColumn(ctx context.Context, conn *sql.DB, o OpAddColumn, t *schema.Table o.Column.Check = nil o.Column.Name = TemporaryName(o.Column.Name) - _, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s", + colSQL, err := ColumnToSQL(o.Column, tr) + if err != nil { + return err + } + + _, err = conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s", pq.QuoteIdentifier(t.Name), - ColumnToSQL(o.Column), + colSQL, )) + return err } diff --git a/pkg/migrations/op_add_column_test.go b/pkg/migrations/op_add_column_test.go index 957161ee..0a248053 100644 --- a/pkg/migrations/op_add_column_test.go +++ b/pkg/migrations/op_add_column_test.go @@ -9,6 +9,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/xataio/pgroll/pkg/migrations" + "github.com/xataio/pgroll/pkg/roll" "github.com/xataio/pgroll/pkg/testutils" ) @@ -1289,3 +1290,103 @@ func TestAddColumnWithComment(t *testing.T) { }, }}) } + +func TestAddColumnDefaultTransformation(t *testing.T) { + t.Parallel() + + sqlTransformer := testutils.NewMockSQLTransformer(map[string]string{ + "'default value 1'": "'rewritten'", + "'default value 2'": testutils.MockSQLTransformerError, + }) + + ExecuteTests(t, TestCases{ + { + name: "column default is rewritten by the SQL transformer", + migrations: []migrations.Migration{ + { + Name: "01_create_table", + Operations: migrations.Operations{ + &migrations.OpCreateTable{ + Name: "users", + Columns: []migrations.Column{ + { + Name: "id", + Type: "serial", + Pk: ptr(true), + }, + }, + }, + }, + }, + { + Name: "02_add_column", + Operations: migrations.Operations{ + &migrations.OpAddColumn{ + Table: "users", + Column: migrations.Column{ + Name: "name", + Type: "text", + Default: ptr("'default value 1'"), + }, + }, + }, + }, + }, + afterStart: func(t *testing.T, db *sql.DB, schema string) { + // Insert some data into the table + MustInsert(t, db, schema, "02_add_column", "users", map[string]string{ + "id": "1", + }) + + // Ensure the row has the rewritten default value. + rows := MustSelect(t, db, schema, "02_add_column", "users") + assert.Equal(t, []map[string]any{ + {"id": 1, "name": "rewritten"}, + }, rows) + }, + afterRollback: func(t *testing.T, db *sql.DB, schema string) { + }, + afterComplete: func(t *testing.T, db *sql.DB, schema string) { + // Ensure the row has the rewritten default value. + rows := MustSelect(t, db, schema, "02_add_column", "users") + assert.Equal(t, []map[string]any{ + {"id": 1, "name": "rewritten"}, + }, rows) + }, + }, + { + name: "operation fails when the SQL transformer returns an error", + migrations: []migrations.Migration{ + { + Name: "01_create_table", + Operations: migrations.Operations{ + &migrations.OpCreateTable{ + Name: "users", + Columns: []migrations.Column{ + { + Name: "id", + Type: "serial", + Pk: ptr(true), + }, + }, + }, + }, + }, + { + Name: "02_add_column", + Operations: migrations.Operations{ + &migrations.OpAddColumn{ + Table: "users", + Column: migrations.Column{ + Name: "name", + Type: "text", + Default: ptr("'default value 2'"), + }, + }, + }, + }, + }, + wantStartErr: testutils.ErrMockSQLTransformer, + }, + }, roll.WithSQLTransformer(sqlTransformer)) +} diff --git a/pkg/migrations/op_create_table.go b/pkg/migrations/op_create_table.go index 46a71ef4..93b8c7ec 100644 --- a/pkg/migrations/op_create_table.go +++ b/pkg/migrations/op_create_table.go @@ -15,10 +15,17 @@ import ( var _ Operation = (*OpCreateTable)(nil) func (o *OpCreateTable) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { + // Generate SQL for the columns in the table + columnsSQL, err := columnsToSQL(o.Columns, tr) + if err != nil { + return nil, fmt.Errorf("failed to create columns SQL: %w", err) + } + + // Create the table under a temporary name tempName := TemporaryName(o.Name) - _, err := conn.ExecContext(ctx, fmt.Sprintf("CREATE TABLE %s (%s)", + _, err = conn.ExecContext(ctx, fmt.Sprintf("CREATE TABLE %s (%s)", pq.QuoteIdentifier(tempName), - columnsToSQL(o.Columns))) + columnsSQL)) if err != nil { return nil, err } @@ -104,18 +111,22 @@ func (o *OpCreateTable) Validate(ctx context.Context, s *schema.Schema) error { return nil } -func columnsToSQL(cols []Column) string { +func columnsToSQL(cols []Column, tr SQLTransformer) (string, error) { var sql string for i, col := range cols { if i > 0 { sql += ", " } - sql += ColumnToSQL(col) + colSQL, err := ColumnToSQL(col, tr) + if err != nil { + return "", err + } + sql += colSQL } - return sql + return sql, nil } -func ColumnToSQL(col Column) string { +func ColumnToSQL(col Column, tr SQLTransformer) (string, error) { sql := fmt.Sprintf("%s %s", pq.QuoteIdentifier(col.Name), col.Type) if col.IsPrimaryKey() { @@ -128,7 +139,11 @@ func ColumnToSQL(col Column) string { sql += " NOT NULL" } if col.Default != nil { - sql += fmt.Sprintf(" DEFAULT %s", *col.Default) + d, err := tr.TransformSQL(*col.Default) + if err != nil { + return "", err + } + sql += fmt.Sprintf(" DEFAULT %s", d) } if col.References != nil { onDelete := "NO ACTION" @@ -147,5 +162,5 @@ func ColumnToSQL(col Column) string { pq.QuoteIdentifier(col.Check.Name), col.Check.Constraint) } - return sql + return sql, nil } diff --git a/pkg/migrations/op_create_table_test.go b/pkg/migrations/op_create_table_test.go index 597cf99b..85de4713 100644 --- a/pkg/migrations/op_create_table_test.go +++ b/pkg/migrations/op_create_table_test.go @@ -7,6 +7,7 @@ import ( "testing" "github.com/xataio/pgroll/pkg/migrations" + "github.com/xataio/pgroll/pkg/roll" "github.com/xataio/pgroll/pkg/testutils" "github.com/stretchr/testify/assert" @@ -466,3 +467,92 @@ func TestCreateTableValidation(t *testing.T) { }, }) } + +func TestCreateTableColumnDefaultTransformation(t *testing.T) { + t.Parallel() + + sqlTransformer := testutils.NewMockSQLTransformer(map[string]string{ + "'default value 1'": "'rewritten'", + "'default value 2'": testutils.MockSQLTransformerError, + }) + + ExecuteTests(t, TestCases{ + { + name: "column default is rewritten by the SQL transformer", + migrations: []migrations.Migration{ + { + Name: "01_create_table", + Operations: migrations.Operations{ + &migrations.OpCreateTable{ + Name: "users", + Columns: []migrations.Column{ + { + Name: "id", + Type: "serial", + Pk: ptr(true), + }, + { + Name: "name", + Type: "text", + Default: ptr("'default value 1'"), + }, + }, + }, + }, + }, + }, + afterStart: func(t *testing.T, db *sql.DB, schema string) { + // Insert some data into the table + MustInsert(t, db, schema, "01_create_table", "users", map[string]string{ + "id": "1", + }) + + // Ensure the row has the rewritten default value. + rows := MustSelect(t, db, schema, "01_create_table", "users") + assert.Equal(t, []map[string]any{ + {"id": 1, "name": "rewritten"}, + }, rows) + }, + afterRollback: func(t *testing.T, db *sql.DB, schema string) { + }, + afterComplete: func(t *testing.T, db *sql.DB, schema string) { + // Insert some data into the table + MustInsert(t, db, schema, "01_create_table", "users", map[string]string{ + "id": "1", + }) + + // Ensure the row has the rewritten default value. + rows := MustSelect(t, db, schema, "01_create_table", "users") + assert.Equal(t, []map[string]any{ + {"id": 1, "name": "rewritten"}, + }, rows) + }, + }, + { + name: "create table fails when the SQL transformer returns an error", + migrations: []migrations.Migration{ + { + Name: "01_create_table", + Operations: migrations.Operations{ + &migrations.OpCreateTable{ + Name: "users", + Columns: []migrations.Column{ + { + Name: "id", + Type: "serial", + Pk: ptr(true), + }, + { + Name: "name", + Type: "text", + Default: ptr("'default value 2'"), + }, + }, + }, + }, + }, + }, + wantStartErr: testutils.ErrMockSQLTransformer, + }, + }, roll.WithSQLTransformer(sqlTransformer)) +}