diff --git a/pkg/sql2pgroll/convert.go b/pkg/sql2pgroll/convert.go index 40ca02641..69d812a57 100644 --- a/pkg/sql2pgroll/convert.go +++ b/pkg/sql2pgroll/convert.go @@ -45,6 +45,8 @@ func convert(sql string) (migrations.Operations, error) { return convertAlterTableStmt(node.AlterTableStmt) case *pgq.Node_RenameStmt: return convertRenameStmt(node.RenameStmt) + case *pgq.Node_DropStmt: + return convertDropStatement(node.DropStmt) default: return makeRawSQLOperation(sql), nil } diff --git a/pkg/sql2pgroll/drop.go b/pkg/sql2pgroll/drop.go new file mode 100644 index 000000000..ed6596b21 --- /dev/null +++ b/pkg/sql2pgroll/drop.go @@ -0,0 +1,48 @@ +// SPDX-License-Identifier: Apache-2.0 + +package sql2pgroll + +import ( + "strings" + + pgq "github.com/pganalyze/pg_query_go/v6" + + "github.com/xataio/pgroll/pkg/migrations" +) + +// convertDropStatement converts supported drop statements to pgroll operations +func convertDropStatement(stmt *pgq.DropStmt) (migrations.Operations, error) { + if stmt.RemoveType == pgq.ObjectType_OBJECT_INDEX { + return convertDropIndexStatement(stmt) + } + return nil, nil +} + +// convertDropIndexStatement converts simple DROP INDEX statements to pgroll operations +func convertDropIndexStatement(stmt *pgq.DropStmt) (migrations.Operations, error) { + if !canConvertDropIndex(stmt) { + return nil, nil + } + items := stmt.GetObjects()[0].GetList().GetItems() + parts := make([]string, len(items)) + for i, item := range items { + parts[i] = item.GetString_().GetSval() + } + + return migrations.Operations{ + &migrations.OpDropIndex{ + Name: strings.Join(parts, "."), + }, + }, nil +} + +// canConvertDropIndex checks whether we can convert the statement without losing any information. +func canConvertDropIndex(stmt *pgq.DropStmt) bool { + if len(stmt.Objects) > 1 { + return false + } + if stmt.Behavior == pgq.DropBehavior_DROP_CASCADE { + return false + } + return true +} diff --git a/pkg/sql2pgroll/drop_test.go b/pkg/sql2pgroll/drop_test.go new file mode 100644 index 000000000..52c415eb1 --- /dev/null +++ b/pkg/sql2pgroll/drop_test.go @@ -0,0 +1,74 @@ +// SPDX-License-Identifier: Apache-2.0 + +package sql2pgroll_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/xataio/pgroll/pkg/migrations" + "github.com/xataio/pgroll/pkg/sql2pgroll" + "github.com/xataio/pgroll/pkg/sql2pgroll/expect" +) + +func TestDropIndexStatements(t *testing.T) { + t.Parallel() + + tests := []struct { + sql string + expectedOp migrations.Operation + }{ + { + sql: "DROP INDEX foo", + expectedOp: expect.DropIndexOp1, + }, + { + sql: "DROP INDEX myschema.foo", + expectedOp: expect.DropIndexOp2, + }, + { + sql: "DROP INDEX foo RESTRICT", + expectedOp: expect.DropIndexOp1, + }, + { + sql: "DROP INDEX IF EXISTS foo", + expectedOp: expect.DropIndexOp1, + }, + { + sql: "DROP INDEX CONCURRENTLY foo", + expectedOp: expect.DropIndexOp1, + }, + } + + for _, tc := range tests { + t.Run(tc.sql, func(t *testing.T) { + ops, err := sql2pgroll.Convert(tc.sql) + require.NoError(t, err) + + require.Len(t, ops, 1) + + assert.Equal(t, tc.expectedOp, ops[0]) + }) + } +} + +func TestUnconvertableDropIndexStatements(t *testing.T) { + t.Parallel() + + tests := []string{ + "DROP INDEX foo CASCADE", + } + + for _, sql := range tests { + t.Run(sql, func(t *testing.T) { + ops, err := sql2pgroll.Convert(sql) + require.NoError(t, err) + + require.Len(t, ops, 1) + + assert.Equal(t, expect.RawSQLOp(sql), ops[0]) + }) + } +} diff --git a/pkg/sql2pgroll/expect/drop_index.go b/pkg/sql2pgroll/expect/drop_index.go new file mode 100644 index 000000000..81d60e7b8 --- /dev/null +++ b/pkg/sql2pgroll/expect/drop_index.go @@ -0,0 +1,15 @@ +// SPDX-License-Identifier: Apache-2.0 + +package expect + +import ( + "github.com/xataio/pgroll/pkg/migrations" +) + +var DropIndexOp1 = &migrations.OpDropIndex{ + Name: "foo", +} + +var DropIndexOp2 = &migrations.OpDropIndex{ + Name: "myschema.foo", +}