From a0b37cf277de8bc145122c8309f312c465c03779 Mon Sep 17 00:00:00 2001 From: Ryan Slade Date: Tue, 17 Dec 2024 10:07:22 +0100 Subject: [PATCH] Add helper to get fully qualified relation name --- pkg/sql2pgroll/alter_table.go | 28 +++++++++++----------------- 1 file changed, 11 insertions(+), 17 deletions(-) diff --git a/pkg/sql2pgroll/alter_table.go b/pkg/sql2pgroll/alter_table.go index db65d33f..1cc0b35d 100644 --- a/pkg/sql2pgroll/alter_table.go +++ b/pkg/sql2pgroll/alter_table.go @@ -209,15 +209,8 @@ func convertAlterTableAddForeignKeyConstraint(stmt *pgq.AlterTableStmt, constrai return nil, fmt.Errorf("unknown delete action: %q", constraint.GetFkDelAction()) } - tableName := stmt.GetRelation().GetRelname() - if stmt.GetRelation().GetSchemaname() != "" { - tableName = stmt.GetRelation().GetSchemaname() + "." + tableName - } - - foreignTable := constraint.GetPktable().GetRelname() - if constraint.GetPktable().GetSchemaname() != "" { - foreignTable = constraint.GetPktable().GetSchemaname() + "." + foreignTable - } + tableName := getQualifiedRelationName(stmt.Relation) + foreignTable := getQualifiedRelationName(constraint.GetPktable()) return &migrations.OpCreateConstraint{ Columns: columns, @@ -268,10 +261,7 @@ func convertAlterTableAddCheckConstraint(stmt *pgq.AlterTableStmt, constraint *p return nil, nil } - tableName := stmt.GetRelation().GetRelname() - if stmt.GetRelation().GetSchemaname() != "" { - tableName = stmt.GetRelation().GetSchemaname() + "." + tableName - } + tableName := getQualifiedRelationName(stmt.GetRelation()) expr, err := pgq.DeparseExpr(constraint.GetRawExpr()) if err != nil { @@ -367,10 +357,7 @@ func convertAlterTableDropConstraint(stmt *pgq.AlterTableStmt, cmd *pgq.AlterTab return nil, nil } - tableName := stmt.GetRelation().GetRelname() - if stmt.GetRelation().GetSchemaname() != "" { - tableName = stmt.GetRelation().GetSchemaname() + "." + tableName - } + tableName := getQualifiedRelationName(stmt.GetRelation()) return &migrations.OpDropMultiColumnConstraint{ Up: migrations.MultiColumnUpSQL{ @@ -443,6 +430,13 @@ func canConvertColumnForSetDataType(column *pgq.ColumnDef) bool { return true } +func getQualifiedRelationName(rel *pgq.RangeVar) string { + if rel.GetSchemaname() == "" { + return rel.GetRelname() + } + return fmt.Sprintf("%s.%s", rel.GetSchemaname(), rel.GetRelname()) +} + func ptr[T any](x T) *T { return &x }