From 4c22d004e43625aeabedb0bbffa0be43c38c18de Mon Sep 17 00:00:00 2001 From: Vivek Yadav Date: Fri, 22 Nov 2024 15:26:50 +0530 Subject: [PATCH] Added the backend changes for check constraints --- internal/convert.go | 1 + internal/helpers.go | 7 +- internal/mapping.go | 13 ++ internal/reports/report_helpers.go | 7 + schema/schema.go | 38 ++-- sources/common/infoschema.go | 23 +-- sources/common/toddl.go | 32 +++- sources/common/toddl_test.go | 30 ++++ sources/mysql/infoschema.go | 100 +++++++++-- sources/mysql/infoschema_test.go | 115 ++++++++++++ spanner/ddl/ast.go | 55 ++++-- spanner/ddl/ast_test.go | 61 +++++-- webv2/api/schema.go | 103 +++++++++++ webv2/api/schema_test.go | 272 +++++++++++++++++++++++++++++ webv2/routes.go | 4 +- webv2/table/review_table_schema.go | 9 +- webv2/table/update_table_schema.go | 10 ++ 17 files changed, 809 insertions(+), 71 deletions(-) diff --git a/internal/convert.go b/internal/convert.go index 6b089053d..81339bda9 100644 --- a/internal/convert.go +++ b/internal/convert.go @@ -128,6 +128,7 @@ const ( SequenceCreated ForeignKeyActionNotSupported NumericPKNotSupported + TypeMismatch ) const ( diff --git a/internal/helpers.go b/internal/helpers.go index 8690dd10c..d4f1de43d 100644 --- a/internal/helpers.go +++ b/internal/helpers.go @@ -26,7 +26,7 @@ import ( type Counter struct { counterMutex sync.Mutex - ObjectId string + ObjectId string } var Cntr Counter @@ -65,6 +65,11 @@ func GenerateForeignkeyId() string { func GenerateIndexesId() string { return GenerateId("i") } + +func GenerateCheckConstrainstId() string { + return GenerateId("ck") +} + func GenerateRuleId() string { return GenerateId("r") } diff --git a/internal/mapping.go b/internal/mapping.go index 0eb9bd055..1b955bdba 100644 --- a/internal/mapping.go +++ b/internal/mapping.go @@ -243,6 +243,19 @@ func ToSpannerIndexName(conv *Conv, srcIndexName string) string { return getSpannerValidName(conv, srcIndexName) } +// Note that the check constraints names in spanner have to 
be globally unique +// (across the database). But in some source databases, such as MySQL, +// they only have to be unique for a table. Hence we must map each source +// constraint name to a unique spanner constraint name. +func ToSpannerCheckConstraintName(conv *Conv, srcCheckConstraintName string) string { + return getSpannerValidName(conv, srcCheckConstraintName) +} + +func GetSpannerValidExpression(cks []ddl.Checkconstraint) []ddl.Checkconstraint { + // TODO validate the check constraints data with batch verification then send back + return cks +} + // conv.UsedNames tracks Spanner names that have been used for table names, foreign key constraints // and indexes. We use this to ensure we generate unique names when // we map from source dbs to Spanner since Spanner requires all these names to be diff --git a/internal/reports/report_helpers.go b/internal/reports/report_helpers.go index a066f0cde..34cb58511 100644 --- a/internal/reports/report_helpers.go +++ b/internal/reports/report_helpers.go @@ -403,6 +403,13 @@ func buildTableReportBody(conv *internal.Conv, tableId string, issues map[string Description: fmt.Sprintf("UNIQUE constraint on column(s) '%s' replaced with primary key since table '%s' didn't have one. Spanner requires a primary key for every table", strings.Join(uniquePK, ", "), conv.SpSchema[tableId].Name), } l = append(l, toAppend) + case internal.TypeMismatch: + toAppend := Issue{ + Category: IssueDB[i].Category, + Description: fmt.Sprintf("Table '%s': Type mismatch in '%s'column affecting check constraints. Verify data type compatibility with constraint logic", conv.SpSchema[tableId].Name, conv.SpSchema[tableId].ColDefs[colId].Name), + } + l = append(l, toAppend) + default: toAppend := Issue{ Category: IssueDB[i].Category, diff --git a/schema/schema.go b/schema/schema.go index 48b125bba..6bade2a4e 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -35,26 +35,27 @@ import ( // Table represents a database table. 
type Table struct { - Name string - Schema string - ColIds []string // List of column Ids (for predictable iteration order e.g. printing). - ColDefs map[string]Column // Details of columns. - ColNameIdMap map[string]string `json:"-"` // Computed every time just after conv is generated or after any column renaming - PrimaryKeys []Key - ForeignKeys []ForeignKey - Indexes []Index - Id string + Name string + Schema string + ColIds []string // List of column Ids (for predictable iteration order e.g. printing). + ColDefs map[string]Column // Details of columns. + ColNameIdMap map[string]string `json:"-"` // Computed every time just after conv is generated or after any column renaming + PrimaryKeys []Key + ForeignKeys []ForeignKey + CheckConstraints []CheckConstraints + Indexes []Index + Id string } // Column represents a database column. // TODO: add support for foreign keys. type Column struct { - Name string - Type Type - NotNull bool - Ignored Ignored - Id string - AutoGen ddl.AutoGenCol + Name string + Type Type + NotNull bool + Ignored Ignored + Id string + AutoGen ddl.AutoGenCol } // ForeignKey represents a foreign key. @@ -76,6 +77,13 @@ type ForeignKey struct { Id string } +// CheckConstraints represents a Check Constrainst. +type CheckConstraints struct { + Name string + Expr string + Id string +} + // Key respresents a primary key or index key. 
type Key struct { ColId string diff --git a/sources/common/infoschema.go b/sources/common/infoschema.go index 5144d2613..da2612895 100644 --- a/sources/common/infoschema.go +++ b/sources/common/infoschema.go @@ -36,7 +36,7 @@ type InfoSchema interface { GetColumns(conv *internal.Conv, table SchemaAndName, constraints map[string][]string, primaryKeys []string) (map[string]schema.Column, []string, error) GetRowsFromTable(conv *internal.Conv, srcTable string) (interface{}, error) GetRowCount(table SchemaAndName) (int64, error) - GetConstraints(conv *internal.Conv, table SchemaAndName) ([]string, map[string][]string, error) + GetConstraints(conv *internal.Conv, table SchemaAndName) ([]string, []schema.CheckConstraints, map[string][]string, error) GetForeignKeys(conv *internal.Conv, table SchemaAndName) (foreignKeys []schema.ForeignKey, err error) GetIndexes(conv *internal.Conv, table SchemaAndName, colNameIdMp map[string]string) ([]schema.Index, error) ProcessData(conv *internal.Conv, tableId string, srcSchema schema.Table, spCols []string, spSchema ddl.CreateTable, additionalAttributes internal.AdditionalDataAttributes) error @@ -185,7 +185,7 @@ func (is *InfoSchemaImpl) processTable(conv *internal.Conv, table SchemaAndName, var t schema.Table fmt.Println("processing schema for table", table) tblId := internal.GenerateTableId() - primaryKeys, constraints, err := infoSchema.GetConstraints(conv, table) + primaryKeys, checkConstraints, constraints, err := infoSchema.GetConstraints(conv, table) if err != nil { return t, fmt.Errorf("couldn't get constraints for table %s.%s: %s", table.Schema, table.Name, err) } @@ -215,15 +215,16 @@ func (is *InfoSchemaImpl) processTable(conv *internal.Conv, table SchemaAndName, schemaPKeys = append(schemaPKeys, schema.Key{ColId: colNameIdMap[k]}) } t = schema.Table{ - Id: tblId, - Name: name, - Schema: table.Schema, - ColIds: colIds, - ColNameIdMap: colNameIdMap, - ColDefs: colDefs, - PrimaryKeys: schemaPKeys, - Indexes: indexes, - 
ForeignKeys: foreignKeys} + Id: tblId, + Name: name, + Schema: table.Schema, + ColIds: colIds, + ColNameIdMap: colNameIdMap, + ColDefs: colDefs, + PrimaryKeys: schemaPKeys, + CheckConstraints: checkConstraints, + Indexes: indexes, + ForeignKeys: foreignKeys} return t, nil } diff --git a/sources/common/toddl.go b/sources/common/toddl.go index 16f4cf3ce..bf4379288 100644 --- a/sources/common/toddl.go +++ b/sources/common/toddl.go @@ -167,14 +167,15 @@ func (ss *SchemaToSpannerImpl) SchemaToSpannerDDLHelper(conv *internal.Conv, tod } comment := "Spanner schema for source table " + quoteIfNeeded(srcTable.Name) conv.SpSchema[srcTable.Id] = ddl.CreateTable{ - Name: spTableName, - ColIds: spColIds, - ColDefs: spColDef, - PrimaryKeys: cvtPrimaryKeys(srcTable.PrimaryKeys), - ForeignKeys: cvtForeignKeys(conv, spTableName, srcTable.Id, srcTable.ForeignKeys, isRestore), - Indexes: cvtIndexes(conv, srcTable.Id, srcTable.Indexes, spColIds, spColDef), - Comment: comment, - Id: srcTable.Id} + Name: spTableName, + ColIds: spColIds, + ColDefs: spColDef, + PrimaryKeys: cvtPrimaryKeys(srcTable.PrimaryKeys), + ForeignKeys: cvtForeignKeys(conv, spTableName, srcTable.Id, srcTable.ForeignKeys, isRestore), + CheckConstraint: cvtCheckConstraint(conv, srcTable.CheckConstraints), + Indexes: cvtIndexes(conv, srcTable.Id, srcTable.Indexes, spColIds, spColDef), + Comment: comment, + Id: srcTable.Id} return nil } @@ -234,6 +235,21 @@ func cvtForeignKeys(conv *internal.Conv, spTableName string, srcTableId string, return spKeys } +func cvtCheckConstraint(conv *internal.Conv, srcKeys []schema.CheckConstraints) []ddl.Checkconstraint { + var spcks []ddl.Checkconstraint + + for _, cks := range srcKeys { + spcks = append(spcks, ddl.Checkconstraint{ + Id: cks.Id, + Name: internal.ToSpannerCheckConstraintName(conv, cks.Name), + Expr: cks.Expr, + }) + + } + + return internal.GetSpannerValidExpression(spcks) +} + func CvtForeignKeysHelper(conv *internal.Conv, spTableName string, srcTableId string, srcKey 
schema.ForeignKey, isRestore bool) (ddl.Foreignkey, error) { if len(srcKey.ColIds) != len(srcKey.ReferColumnIds) { conv.Unexpected(fmt.Sprintf("ConvertForeignKeys: ColIds and referColumns don't have the same lengths: len(columns)=%d, len(referColumns)=%d for source tableId: %s, referenced table: %s", len(srcKey.ColIds), len(srcKey.ReferColumnIds), srcTableId, srcKey.ReferTableId)) diff --git a/sources/common/toddl_test.go b/sources/common/toddl_test.go index 63ecad71b..fd1504a84 100644 --- a/sources/common/toddl_test.go +++ b/sources/common/toddl_test.go @@ -428,3 +428,33 @@ func Test_SchemaToSpannerSequenceHelper(t *testing.T) { assert.Equal(t, expectedConv, conv) } } +func Test_cvtCheckContraint(t *testing.T) { + + conv := internal.MakeConv() + srcSchema := []schema.CheckConstraints{ + { + Id: "ck1", + Name: "check_1", + Expr: "age > 0", + }, + { + Id: "ck1", + Name: "check_2", + Expr: "age < 99", + }, + } + spSchema := []ddl.Checkconstraint{ + { + Id: "ck1", + Name: "check_1", + Expr: "age > 0", + }, + { + Id: "ck1", + Name: "check_2", + Expr: "age < 99", + }, + } + result := cvtCheckConstraint(conv, srcSchema) + assert.Equal(t, spSchema, result) +} diff --git a/sources/mysql/infoschema.go b/sources/mysql/infoschema.go index 87458188f..7f7a69028 100644 --- a/sources/mysql/infoschema.go +++ b/sources/mysql/infoschema.go @@ -199,7 +199,7 @@ func (isi InfoSchemaImpl) GetColumns(conv *internal.Conv, table common.SchemaAnd // We've already filtered out PRIMARY KEY. switch c { case "CHECK": - ignored.Check = true + case "FOREIGN KEY", "PRIMARY KEY", "UNIQUE": // Nothing to do here -- these are all handled elsewhere. } @@ -237,29 +237,107 @@ func (isi InfoSchemaImpl) GetColumns(conv *internal.Conv, table common.SchemaAnd // other constraints. Note: we need to preserve ordinal order of // columns in primary key constraints. // Note that foreign key constraints are handled in getForeignKeys. 
-func (isi InfoSchemaImpl) GetConstraints(conv *internal.Conv, table common.SchemaAndName) ([]string, map[string][]string, error) { +func (isi InfoSchemaImpl) GetConstraints(conv *internal.Conv, table common.SchemaAndName) ([]string, []schema.CheckConstraints, map[string][]string, error) { q := `SELECT k.COLUMN_NAME, t.CONSTRAINT_TYPE FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS AS t INNER JOIN INFORMATION_SCHEMA.KEY_COLUMN_USAGE AS k ON t.CONSTRAINT_NAME = k.CONSTRAINT_NAME AND t.CONSTRAINT_SCHEMA = k.CONSTRAINT_SCHEMA AND t.TABLE_NAME=k.TABLE_NAME WHERE k.TABLE_SCHEMA = ? AND k.TABLE_NAME = ? ORDER BY k.ordinal_position;` - rows, err := isi.Db.Query(q, table.Schema, table.Name) + + q1 := `SELECT + COALESCE(k.COLUMN_NAME, '') AS COLUMN_NAME, + t.CONSTRAINT_NAME, + t.CONSTRAINT_TYPE, + COALESCE(c.CHECK_CLAUSE, '') AS CHECK_CLAUSE + FROM + INFORMATION_SCHEMA.TABLE_CONSTRAINTS AS t + LEFT JOIN + INFORMATION_SCHEMA.KEY_COLUMN_USAGE AS k + ON t.CONSTRAINT_NAME = k.CONSTRAINT_NAME + AND t.CONSTRAINT_SCHEMA = k.CONSTRAINT_SCHEMA + AND t.TABLE_NAME = k.TABLE_NAME + LEFT JOIN + INFORMATION_SCHEMA.CHECK_CONSTRAINTS AS c + ON t.CONSTRAINT_NAME = c.CONSTRAINT_NAME + WHERE + t.TABLE_SCHEMA = ? + AND t.TABLE_NAME = ? 
+ ORDER BY k.ORDINAL_POSITION; + ` + checkQuery := `SELECT COUNT(*) + FROM INFORMATION_SCHEMA.TABLES + WHERE TABLE_SCHEMA = 'INFORMATION_SCHEMA' + AND TABLE_NAME = 'CHECK_CONSTRAINTS';` + var tableExistsCount int + rows1, err := isi.Db.Query(checkQuery) if err != nil { - return nil, nil, err + return nil, nil, nil, err + } + for rows1.Next() { + err1 := rows1.Scan(&tableExistsCount) + if err1 != nil { + conv.Unexpected(fmt.Sprintf("Can't scan: %v", err)) + return nil, nil, nil, err + } + } + + defer rows1.Close() + + tableExists := tableExistsCount > 0 + + var finalQuery string + if tableExists { + finalQuery = q1 + } else { + finalQuery = q + } + + rows, err := isi.Db.Query(finalQuery, table.Schema, table.Name) + + if err != nil { + return nil, nil, nil, err } defer rows.Close() var primaryKeys []string - var col, constraint string + var checkKeys []schema.CheckConstraints + var col, constraintName, constraint, checkClause string m := make(map[string][]string) for rows.Next() { - err := rows.Scan(&col, &constraint) - if err != nil { - conv.Unexpected(fmt.Sprintf("Can't scan: %v", err)) - continue + if tableExists { + err := rows.Scan(&col, &constraintName, &constraint, &checkClause) + if err != nil { + conv.Unexpected(fmt.Sprintf("Can't scan: %v", err)) + continue + } + } else { + err := rows.Scan(&col, &constraintName, &constraint, &checkClause) + if err != nil { + conv.Unexpected(fmt.Sprintf("Can't scan: %v", err)) + continue + } } if col == "" || constraint == "" { - conv.Unexpected(fmt.Sprintf("Got empty col or constraint")) + + if tableExists { + if constraintName == "" || checkClause == "" { + conv.Unexpected(fmt.Sprintf("Got empty constraintName or checkClause")) + continue + } + switch constraint { + case "CHECK": + checkClause = strings.ReplaceAll(checkClause, "_utf8mb4\\", "") + checkClause = strings.ReplaceAll(checkClause, "\\", "") + + checkKeys = append(checkKeys, schema.CheckConstraints{Name: constraintName, Expr: string(checkClause), Id: 
internal.GenerateCheckConstrainstId()}) + default: + m[col] = append(m[col], constraint) + } + } else { + conv.Unexpected(fmt.Sprintf("Got empty col or constraint")) + } + continue + } switch constraint { case "PRIMARY KEY": @@ -268,7 +346,7 @@ func (isi InfoSchemaImpl) GetConstraints(conv *internal.Conv, table common.Schem m[col] = append(m[col], constraint) } } - return primaryKeys, m, nil + return primaryKeys, checkKeys, m, nil } // GetForeignKeys return list all the foreign keys constraints. diff --git a/sources/mysql/infoschema_test.go b/sources/mysql/infoschema_test.go index a61c04580..ffc7e460b 100644 --- a/sources/mysql/infoschema_test.go +++ b/sources/mysql/infoschema_test.go @@ -17,6 +17,7 @@ package mysql import ( "database/sql" "database/sql/driver" + "fmt" "testing" "github.com/DATA-DOG/go-sqlmock" @@ -530,3 +531,117 @@ func mkMockDB(t *testing.T, ms []mockSpec) *sql.DB { } return db } +func TestGetConstraints(t *testing.T) { + + case1 := []mockSpec{ + { + query: `SELECT COUNT\(\*\) + FROM INFORMATION_SCHEMA.TABLES + WHERE TABLE_SCHEMA = 'INFORMATION_SCHEMA' + AND TABLE_NAME = 'CHECK_CONSTRAINTS'; + `, + cols: []string{"COUNT"}, + rows: [][]driver.Value{{1}}, + }, + { + query: `(?i)SELECT\s+COALESCE\(k.COLUMN_NAME,\s*''\)\s+AS\s+COLUMN_NAME,\s+t\.CONSTRAINT_NAME,\s+t\.CONSTRAINT_TYPE,\s+COALESCE\(c.CHECK_CLAUSE,\s*''\)\s+AS\s+CHECK_CLAUSE\s+FROM\s+INFORMATION_SCHEMA\.TABLE_CONSTRAINTS\s+AS\s+t\s+LEFT\s+JOIN\s+INFORMATION_SCHEMA\.KEY_COLUMN_USAGE\s+AS\s+k\s+ON\s+t\.CONSTRAINT_NAME\s*=\s*k\.CONSTRAINT_NAME\s+AND\s+t\.CONSTRAINT_SCHEMA\s*=\s*k\.CONSTRAINT_SCHEMA\s+AND\s+t\.TABLE_NAME\s*=\s*k\.TABLE_NAME\s+LEFT\s+JOIN\s+INFORMATION_SCHEMA\.CHECK_CONSTRAINTS\s+AS\s+c\s+ON\s+t\.CONSTRAINT_NAME\s*=\s*c\.CONSTRAINT_NAME\s+WHERE\s+t\.TABLE_SCHEMA\s*=\s*\?\s+AND\s+t\.TABLE_NAME\s*=\s*\?\s*ORDER\s+BY\s+k\.ORDINAL_POSITION`, + args: []driver.Value{"test_schema", "test_table"}, + cols: []string{"COLUMN_NAME", "CONSTRAINT_NAME", "CONSTRAINT_TYPE", "CHECK_CLAUSE"}, + 
rows: [][]driver.Value{{"id", "PRIMARY", "PRIMARY KEY", ""}, {"", "chk_test", "CHECK", "amount > 0"}}, + }, + } + + case2 := []mockSpec{ + { + query: `SELECT COUNT\(\*\) + FROM INFORMATION_SCHEMA.TABLES + WHERE TABLE_SCHEMA = 'INFORMATION_SCHEMA' + AND TABLE_NAME = 'CHECK_CONSTRAINTS'; + `, + cols: []string{"COUNT"}, + rows: [][]driver.Value{{0}}, + }, + { + query: `(?i)SELECT\s+k\.COLUMN_NAME,\s+t\.CONSTRAINT_TYPE\s+FROM\s+INFORMATION_SCHEMA\.TABLE_CONSTRAINTS\s+AS\s+t\s+INNER\s+JOIN\s+INFORMATION_SCHEMA\.KEY_COLUMN_USAGE\s+AS\s+k\s+ON\s+t\.CONSTRAINT_NAME\s*=\s*k\.CONSTRAINT_NAME\s+AND\s+t\.CONSTRAINT_SCHEMA\s*=\s*k\.CONSTRAINT_SCHEMA\s+AND\s+t\.TABLE_NAME\s*=\s*k\.TABLE_NAME\s+WHERE\s+k\.TABLE_SCHEMA\s*=\s*\?\s+AND\s+k\.TABLE_NAME\s*=\s*\?\s*ORDER\s+BY\s+k\.ORDINAL_POSITION;`, + args: []driver.Value{"test_schema", "test_table"}, + cols: []string{"COLUMN_NAME", "CONSTRAINT_NAME", "CONSTRAINT_TYPE", "CHECK_CLAUSE"}, + rows: [][]driver.Value{{"id", "PRIMARY", "PRIMARY KEY", ""}}, + }, + } + + cases := []struct { + db []mockSpec + tableExists bool + }{ + { + db: case1, + tableExists: true, + }, + { + db: case2, + tableExists: false, + }, + } + + for _, tc := range cases { + if tc.tableExists { + db := mkMockDB(t, tc.db) + + defer db.Close() + + isi := InfoSchemaImpl{Db: db} + + table := common.SchemaAndName{ + Schema: "test_schema", + Name: "test_table", + } + + conv := new(internal.Conv) + + primaryKeys, checkKeys, constraints, err := isi.GetConstraints(conv, table) + if err != nil { + t.Fatalf("expected no error, but got %v", err) + } + + expectedPrimaryKeys := []string{"id"} + if fmt.Sprintf("%v", primaryKeys) != fmt.Sprintf("%v", expectedPrimaryKeys) { + t.Errorf("expected %v, got %v for primary keys", expectedPrimaryKeys, primaryKeys) + } + + expectedCheckKeys := []schema.CheckConstraints{ + {Name: "chk_test", Expr: "amount > 0", Id: "ck1"}, + } + + assert.Equal(t, expectedCheckKeys, checkKeys) + assert.Equal(t, expectedPrimaryKeys, primaryKeys) + 
assert.Empty(t, constraints) + } else { + db := mkMockDB(t, tc.db) + + defer db.Close() + + isi := InfoSchemaImpl{Db: db} + + table := common.SchemaAndName{ + Schema: "test_schema", + Name: "test_table", + } + + conv := new(internal.Conv) + + primaryKeys, checkKeys, constraints, err := isi.GetConstraints(conv, table) + if err != nil { + t.Fatalf("expected no error, but got %v", err) + } + + expectedPrimaryKeys := []string{"id"} + if fmt.Sprintf("%v", primaryKeys) != fmt.Sprintf("%v", expectedPrimaryKeys) { + t.Errorf("expected %v, got %v for primary keys", expectedPrimaryKeys, primaryKeys) + } + + assert.Equal(t, expectedPrimaryKeys, primaryKeys) + assert.Empty(t, checkKeys) + assert.Empty(t, constraints) + } + } +} diff --git a/spanner/ddl/ast.go b/spanner/ddl/ast.go index 9490a89e3..d38db2105 100644 --- a/spanner/ddl/ast.go +++ b/spanner/ddl/ast.go @@ -264,6 +264,12 @@ type IndexKey struct { Order int } +type Checkconstraint struct { + Id string + Name string + Expr string +} + // PrintPkOrIndexKey unparses the primary or index keys. func (idx IndexKey) PrintPkOrIndexKey(ct CreateTable, c Config) string { col := c.quote(ct.ColDefs[idx.ColId].Name) @@ -318,16 +324,17 @@ func (k Foreignkey) PrintForeignKey(c Config) string { // // create_table: CREATE TABLE table_name ([column_def, ...] 
) primary_key [, cluster] type CreateTable struct { - Name string - ColIds []string // Provides names and order of columns - ShardIdColumn string - ColDefs map[string]ColumnDef // Provides definition of columns (a map for simpler/faster lookup during type processing) - PrimaryKeys []IndexKey - ForeignKeys []Foreignkey - Indexes []CreateIndex - ParentTable InterleavedParent //if not empty, this table will be interleaved - Comment string - Id string + Name string + ColIds []string // Provides names and order of columns + ShardIdColumn string + ColDefs map[string]ColumnDef // Provides definition of columns (a map for simpler/faster lookup during type processing) + PrimaryKeys []IndexKey + ForeignKeys []Foreignkey + Indexes []CreateIndex + ParentTable InterleavedParent //if not empty, this table will be interleaved + CheckConstraint []Checkconstraint + Comment string + Id string } // PrintCreateTable unparses a CREATE TABLE statement. @@ -381,13 +388,20 @@ func (ct CreateTable) PrintCreateTable(spSchema Schema, config Config) string { } } + var checkString string + if len(ct.CheckConstraint) != 0 { + checkString = PrintCheckConstraintTable(ct.CheckConstraint) + } else { + checkString = "" + } + if len(keys) == 0 { - return fmt.Sprintf("%sCREATE TABLE %s (\n%s) %s", tableComment, config.quote(ct.Name), cols, interleave) + return fmt.Sprintf("%sCREATE TABLE %s (\n%s %s) %s", tableComment, config.quote(ct.Name), cols, checkString, interleave) } if config.SpDialect == constants.DIALECT_POSTGRESQL { return fmt.Sprintf("%sCREATE TABLE %s (\n%s\tPRIMARY KEY (%s)\n)%s", tableComment, config.quote(ct.Name), cols, strings.Join(keys, ", "), interleave) } - return fmt.Sprintf("%sCREATE TABLE %s (\n%s) PRIMARY KEY (%s)%s", tableComment, config.quote(ct.Name), cols, strings.Join(keys, ", "), interleave) + return fmt.Sprintf("%sCREATE TABLE %s (\n%s %s) PRIMARY KEY (%s)%s", tableComment, config.quote(ct.Name), cols, checkString, strings.Join(keys, ", "), interleave) } // CreateIndex 
encodes the following DDL definition: @@ -494,6 +508,23 @@ func (k Foreignkey) PrintForeignKeyAlterTable(spannerSchema Schema, c Config, ta return s } +// PrintCheckConstraintTable unparses the check constraints using CHECK CONSTRAINTS. +func PrintCheckConstraintTable(cks []Checkconstraint) string { + + var s string + s = "" + for index, col := range cks { + if index == len(cks)-1 { + s = s + fmt.Sprintf("\tCONSTRAINT %s CHECK %s\n", col.Name, col.Expr) + } else { + s = s + fmt.Sprintf("\tCONSTRAINT %s CHECK %s,\n", col.Name, col.Expr) + } + + } + + return s +} + // Schema stores a map of table names and Tables. type Schema map[string]CreateTable diff --git a/spanner/ddl/ast_test.go b/spanner/ddl/ast_test.go index d6e53f96b..f55c22ed7 100644 --- a/spanner/ddl/ast_test.go +++ b/spanner/ddl/ast_test.go @@ -143,6 +143,10 @@ func TestPrintCreateTable(t *testing.T) { }, PrimaryKeys: []IndexKey{{ColId: "col1", Desc: true}}, ForeignKeys: nil, + CheckConstraint: []Checkconstraint{ + {Id: "ck1", Name: "check_1", Expr: "(age > 18)"}, + {Id: "ck2", Name: "check_2", Expr: "(age < 99)"}, + }, Indexes: nil, ParentTable: InterleavedParent{}, Comment: "", @@ -156,12 +160,13 @@ func TestPrintCreateTable(t *testing.T) { "col4": {Name: "col4", T: Type{Name: Int64}, NotNull: true}, "col5": {Name: "col5", T: Type{Name: String, Len: MaxLength}, NotNull: false}, }, - PrimaryKeys: []IndexKey{{ColId: "col4", Desc: true}}, - ForeignKeys: nil, - Indexes: nil, - ParentTable: InterleavedParent{Id: "t1", OnDelete: constants.FK_CASCADE}, - Comment: "", - Id: "t2", + PrimaryKeys: []IndexKey{{ColId: "col4", Desc: true}}, + ForeignKeys: nil, + Indexes: nil, + CheckConstraint: nil, + ParentTable: InterleavedParent{Id: "t1", OnDelete: constants.FK_CASCADE}, + Comment: "", + Id: "t2", }, "t3": CreateTable{ Name: "table3", @@ -170,12 +175,33 @@ func TestPrintCreateTable(t *testing.T) { ColDefs: map[string]ColumnDef{ "col6": {Name: "col6", T: Type{Name: Int64}, NotNull: true}, }, - PrimaryKeys: 
[]IndexKey{{ColId: "col6", Desc: true}}, + PrimaryKeys: []IndexKey{{ColId: "col6", Desc: true}}, + ForeignKeys: nil, + Indexes: nil, + CheckConstraint: nil, + ParentTable: InterleavedParent{Id: "t1", OnDelete: ""}, + Comment: "", + Id: "t3", + }, + "t4": CreateTable{ + Name: "table1", + ColIds: []string{"col1", "col2", "col3"}, + ShardIdColumn: "", + ColDefs: map[string]ColumnDef{ + "col1": {Name: "col1", T: Type{Name: Int64}, NotNull: true}, + "col2": {Name: "col2", T: Type{Name: String, Len: MaxLength}, NotNull: false}, + "col3": {Name: "col3", T: Type{Name: Bytes, Len: int64(42)}, NotNull: false}, + }, + PrimaryKeys: nil, ForeignKeys: nil, + CheckConstraint: []Checkconstraint{ + {Id: "ck1", Name: "check_1", Expr: "(age > 18)"}, + {Id: "ck2", Name: "check_2", Expr: "(age < 99)"}, + }, Indexes: nil, - ParentTable: InterleavedParent{Id: "t1", OnDelete: ""}, + ParentTable: InterleavedParent{}, Comment: "", - Id: "t3", + Id: "t1", }, } tests := []struct { @@ -191,7 +217,8 @@ func TestPrintCreateTable(t *testing.T) { "CREATE TABLE table1 (\n" + " col1 INT64 NOT NULL ,\n" + " col2 STRING(MAX),\n" + - " col3 BYTES(42),\n" + + " col3 BYTES(42),\n " + + "CONSTRAINT check_1 CHECK (age > 18),\nCONSTRAINT check_2 CHECK (age < 99)\n" + ") PRIMARY KEY (col1 DESC)", }, { @@ -201,7 +228,8 @@ func TestPrintCreateTable(t *testing.T) { "CREATE TABLE `table1` (\n" + " `col1` INT64 NOT NULL ,\n" + " `col2` STRING(MAX),\n" + - " `col3` BYTES(42),\n" + + " `col3` BYTES(42),\n " + + "CONSTRAINT check_1 CHECK (age > 18),\nCONSTRAINT check_2 CHECK (age < 99)\n" + ") PRIMARY KEY (`col1` DESC)", }, { @@ -223,6 +251,17 @@ func TestPrintCreateTable(t *testing.T) { ") PRIMARY KEY (col6 DESC),\n" + "INTERLEAVE IN PARENT table1", }, + { + "no quote", + false, + s["t4"], + "CREATE TABLE table1 (\n" + + " col1 INT64 NOT NULL ,\n" + + " col2 STRING(MAX),\n" + + " col3 BYTES(42),\n " + + "CONSTRAINT check_1 CHECK (age > 18),\nCONSTRAINT check_2 CHECK (age < 99)\n" + + ") ", + }, } for _, tc := range 
tests { assert.Equal(t, tc.expected, tc.ct.PrintCreateTable(s, Config{ProtectIds: tc.protectIds})) diff --git a/webv2/api/schema.go b/webv2/api/schema.go index f8b8ca3d3..c1ea6d2be 100644 --- a/webv2/api/schema.go +++ b/webv2/api/schema.go @@ -488,6 +488,109 @@ func RestoreSecondaryIndex(w http.ResponseWriter, r *http.Request) { } +// UpdateCheckConstraint processes the request to update spanner table check constraints, ensuring session and schema validity, and responds with the updated conversion metadata. +func UpdateCheckConstraint(w http.ResponseWriter, r *http.Request) { + tableId := r.FormValue("table") + reqBody, err := ioutil.ReadAll(r.Body) + if err != nil { + http.Error(w, fmt.Sprintf("Body Read Error : %v", err), http.StatusInternalServerError) + } + sessionState := session.GetSessionState() + if sessionState.Conv == nil || sessionState.Driver == "" { + http.Error(w, fmt.Sprintf("Schema is not converted or Driver is not configured properly. Please retry converting the database to Spanner."), http.StatusNotFound) + return + } + sessionState.Conv.ConvLock.Lock() + defer sessionState.Conv.ConvLock.Unlock() + + newCKs := []ddl.Checkconstraint{} + if err = json.Unmarshal(reqBody, &newCKs); err != nil { + http.Error(w, fmt.Sprintf("Request Body parse error : %v", err), http.StatusBadRequest) + return + } + + sp := sessionState.Conv.SpSchema[tableId] + + sp.CheckConstraint = newCKs + + sessionState.Conv.SpSchema[tableId] = sp + + session.UpdateSessionFile() + + convm := session.ConvWithMetadata{ + SessionMetadata: sessionState.SessionMetadata, + Conv: *sessionState.Conv, + } + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(convm) + +} + +func doesNameExist(spcks []ddl.Checkconstraint, targetName string) bool { + for _, spck := range spcks { + if strings.Contains(spck.Expr, targetName) { + return true + } + } + return false +} + +// ValidateCheckConstraint verifies if the type of a database column has been altered and add an error if a change is 
detected. +func ValidateCheckConstraint(w http.ResponseWriter, r *http.Request) { + sessionState := session.GetSessionState() + if sessionState.Conv == nil || sessionState.Driver == "" { + http.Error(w, fmt.Sprintf("Schema is not converted or Driver is not configured properly. Please retry converting the database to Spanner."), http.StatusNotFound) + return + } + sessionState.Conv.ConvLock.Lock() + defer sessionState.Conv.ConvLock.Unlock() + + sp := sessionState.Conv.SpSchema + srcschema := sessionState.Conv.SrcSchema + + flag := true + + schemaissue := []internal.SchemaIssue{} + + for _, src := range srcschema { + + for _, col := range sp[src.Id].ColDefs { + + if len(sp[src.Id].CheckConstraint) != 0 { + spType := col.T.Name + srcType := srcschema[src.Id].ColDefs[col.Id].Type + + actualType := mysqlDefaultTypeMap[srcType.Name] + + if actualType.Name != spType { + + columnName := sp[src.Id].ColDefs[col.Id].Name + spcks := sp[src.Id].CheckConstraint + if doesNameExist(spcks, columnName) { + flag = false + + schemaissue = sessionState.Conv.SchemaIssues[src.Id].ColumnLevelIssues[col.Id] + + if !utilities.IsSchemaIssuePresent(schemaissue, internal.TypeMismatch) { + schemaissue = append(schemaissue, internal.TypeMismatch) + } + + sessionState.Conv.SchemaIssues[src.Id].ColumnLevelIssues[col.Id] = schemaissue + + break + + } + } + } + + } + + } + + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(flag) +} + // renameForeignKeys checks the new names for spanner name validity, ensures the new names are already not used by existing tables // secondary indexes or foreign key constraints. If above checks passed then foreignKey renaming reflected in the schema else appropriate // error thrown. 
diff --git a/webv2/api/schema_test.go b/webv2/api/schema_test.go index 9e7894ea8..c1ca74042 100644 --- a/webv2/api/schema_test.go +++ b/webv2/api/schema_test.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/json" "fmt" + "io" "net/http" "net/http/httptest" "strings" @@ -2541,3 +2542,274 @@ func TestGetAutoGenMapMySQL(t *testing.T) { } } +func TestUpdateCheckConstraint(t *testing.T) { + sessionState := session.GetSessionState() + sessionState.Driver = constants.MYSQL + sessionState.Conv = internal.MakeConv() + + tableID := "table1" + + expectedCheckConstraint := []ddl.Checkconstraint{ + {Id: "ck1", Name: "check_1", Expr: "(age > 18)"}, + {Id: "ck2", Name: "check_2", Expr: "(age < 99)"}, + } + + checkConstraints := []schema.CheckConstraints{ + {Id: "ck1", Name: "check_1", Expr: "(age > 18)"}, + {Id: "ck2", Name: "check_2", Expr: "(age < 99)"}, + } + + body, err := json.Marshal(checkConstraints) + assert.NoError(t, err) + + req, err := http.NewRequest("POST", "update/cks", bytes.NewBuffer(body)) + assert.NoError(t, err) + + q := req.URL.Query() + q.Add("table", tableID) + req.URL.RawQuery = q.Encode() + + rr := httptest.NewRecorder() + + handler := http.HandlerFunc(api.UpdateCheckConstraint) + + handler.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusOK, rr.Code) + + updatedSp := sessionState.Conv.SpSchema[tableID] + + assert.Equal(t, expectedCheckConstraint, updatedSp.CheckConstraint) +} + +func TestUpdateCheckConstraint_ParseError(t *testing.T) { + sessionState := session.GetSessionState() + sessionState.Driver = constants.MYSQL + sessionState.Conv = internal.MakeConv() + + invalidJSON := "invalid json body" + + rr := httptest.NewRecorder() + req, err := http.NewRequest("POST", "update/cks", io.NopCloser(strings.NewReader(invalidJSON))) + assert.NoError(t, err) + + handler := http.HandlerFunc(api.UpdateCheckConstraint) + handler.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusBadRequest, rr.Code) + + expectedErrorMessage := "Request Body parse error" + 
assert.Contains(t, rr.Body.String(), expectedErrorMessage) +} + +func (errReader) Read(p []byte) (n int, err error) { + return 0, fmt.Errorf("simulated read error") +} + +func TestUpdateCheckConstraint_ImproperSession(t *testing.T) { + sessionState := session.GetSessionState() + sessionState.Conv = nil // Simulate no conversion + + rr := httptest.NewRecorder() + req, err := http.NewRequest("POST", "update/cks", io.NopCloser(errReader{})) + assert.NoError(t, err) + + handler := http.HandlerFunc(api.UpdateCheckConstraint) + handler.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusInternalServerError, rr.Code) + assert.Contains(t, rr.Body.String(), "Schema is not converted or Driver is not configured properly") + +} + +func TestValidateCheckConstraint_ImproperSession(t *testing.T) { + sessionState := session.GetSessionState() + sessionState.Conv = nil // Simulate no conversion + + rr := httptest.NewRecorder() + req := httptest.NewRequest("GET", "/validateCheckConstraint", nil) + + handler := http.HandlerFunc(api.ValidateCheckConstraint) + handler.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusNotFound, rr.Code) + assert.Contains(t, rr.Body.String(), "Schema is not converted or Driver is not configured properly") + +} + +func TestValidateCheckConstraint_NoTypeMismatch(t *testing.T) { + sessionState := session.GetSessionState() + sessionState.Driver = constants.MYSQL + sessionState.Conv = internal.MakeConv() + + buildConvMySQL_NoTypeMatch(sessionState.Conv) + rr1 := httptest.NewRecorder() + req1 := httptest.NewRequest("GET", "/spannerDefaultTypeMap", nil) + + handler1 := http.HandlerFunc(api.SpannerDefaultTypeMap) + handler1.ServeHTTP(rr1, req1) + + rr := httptest.NewRecorder() + req := httptest.NewRequest("GET", "/validateCheckConstraint", nil) + + handler := http.HandlerFunc(api.ValidateCheckConstraint) + handler.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusOK, rr.Code) + var responseFlag bool + json.NewDecoder(rr.Body).Decode(&responseFlag) + assert.True(t, 
responseFlag)
+}
+
+func TestValidateCheckConstraint_TypeMismatch(t *testing.T) {
+	sessionState := session.GetSessionState()
+	sessionState.Driver = constants.MYSQL
+	sessionState.Conv = internal.MakeConv()
+
+	buildConvMySQL_TypeMatch(sessionState.Conv)
+
+	rr1 := httptest.NewRecorder()
+	req1 := httptest.NewRequest("GET", "/spannerDefaultTypeMap", nil)
+
+	handler1 := http.HandlerFunc(api.SpannerDefaultTypeMap)
+	handler1.ServeHTTP(rr1, req1)
+
+	rr := httptest.NewRecorder()
+	req := httptest.NewRequest("GET", "/validateCheckConstraint", nil)
+
+	handler := http.HandlerFunc(api.ValidateCheckConstraint)
+	handler.ServeHTTP(rr, req)
+
+	assert.Equal(t, http.StatusOK, rr.Code)
+	var responseFlag bool
+	json.NewDecoder(rr.Body).Decode(&responseFlag)
+	assert.False(t, responseFlag)
+	issues := sessionState.Conv.SchemaIssues["t1"].ColumnLevelIssues["c2"]
+	assert.Contains(t, issues, internal.TypeMismatch)
+}
+
+func buildConvMySQL_NoTypeMatch(conv *internal.Conv) {
+	conv.SrcSchema = map[string]schema.Table{
+		"t1": {
+			Name: "table1",
+			Id: "t1",
+			ColIds: []string{"c1", "c2", "c3"},
+			CheckConstraints: []schema.CheckConstraints{
+				{
+					Id: "ck1",
+					Name: "check_1",
+					Expr: "age > 0",
+				},
+				{
+					Id: "ck2",
+					Name: "check_2",
+					Expr: "age < 99",
+				},
+			},
+			ColDefs: map[string]schema.Column{
+				"c1": {Name: "a", Id: "c1", Type: schema.Type{Name: "json"}},
+				"c2": {Name: "b", Id: "c2", Type: schema.Type{Name: "decimal"}},
+				"c3": {Name: "c", Id: "c3", Type: schema.Type{Name: "datetime"}},
+			},
+			PrimaryKeys: []schema.Key{{ColId: "c1"}}},
+	}
+	conv.SpSchema = map[string]ddl.CreateTable{
+		"t1": {
+			Name: "table1",
+			Id: "t1",
+			ColIds: []string{"c1", "c2", "c3"},
+			CheckConstraint: []ddl.Checkconstraint{
+				{
+					Id: "ck1",
+					Name: "check_1",
+					Expr: "age > 0",
+				},
+				{
+					Id: "ck2",
+					Name: "check_2",
+					Expr: "age < 99",
+				},
+			},
+			ColDefs: map[string]ddl.ColumnDef{
+				"c1": {Name: "a", Id: "c1", T: ddl.Type{Name: ddl.JSON}},
+				"c2": {Name: "b", Id: "c2", T: ddl.Type{Name: ddl.Numeric}},
+				"c3": {Name: "c", Id: "c3", T: ddl.Type{Name: ddl.Timestamp}},
+			},
+			PrimaryKeys: []ddl.IndexKey{{ColId: "c1"}},
+		},
+	}
+
+	conv.SchemaIssues = map[string]internal.TableIssues{
+		"t1": {
+			ColumnLevelIssues: map[string][]internal.SchemaIssue{
+				"c1": {internal.Widened},
+				"c2": {internal.Time},
+			},
+		},
+	}
+	conv.SyntheticPKeys["t2"] = internal.SyntheticPKey{"c20", 0}
+	conv.Audit.MigrationType = migration.MigrationData_SCHEMA_AND_DATA.Enum()
+}
+
+func buildConvMySQL_TypeMatch(conv *internal.Conv) {
+	conv.SrcSchema = map[string]schema.Table{
+		"t1": {
+			Name: "table1",
+			Id: "t1",
+			ColIds: []string{"c1", "c2", "c3"},
+			CheckConstraints: []schema.CheckConstraints{
+				{
+					Id: "ck1",
+					Name: "check_1",
+					Expr: "age > 0",
+				},
+				{
+					Id: "ck2",
+					Name: "check_2",
+					Expr: "age < 99",
+				},
+			},
+			ColDefs: map[string]schema.Column{
+				"c1": {Name: "a", Id: "c1", Type: schema.Type{Name: "json"}},
+				"c2": {Name: "age", Id: "c2", Type: schema.Type{Name: "decimal"}},
+				"c3": {Name: "c", Id: "c3", Type: schema.Type{Name: "datetime"}},
+			},
+			PrimaryKeys: []schema.Key{{ColId: "c1"}}},
+	}
+	conv.SpSchema = map[string]ddl.CreateTable{
+		"t1": {
+			Name: "table1",
+			Id: "t1",
+			ColIds: []string{"c1", "c2", "c3"},
+			CheckConstraint: []ddl.Checkconstraint{
+				{
+					Id: "ck1",
+					Name: "check_1",
+					Expr: "age > 0",
+				},
+				{
+					Id: "ck2",
+					Name: "check_2",
+					Expr: "age < 99",
+				},
+			},
+			ColDefs: map[string]ddl.ColumnDef{
+				"c1": {Name: "a", Id: "c1", T: ddl.Type{Name: ddl.JSON}},
+				"c2": {Name: "age", Id: "c2", T: ddl.Type{Name: ddl.String}},
+				"c3": {Name: "c", Id: "c3", T: ddl.Type{Name: ddl.Timestamp}},
+			},
+			PrimaryKeys: []ddl.IndexKey{{ColId: "c1"}},
+		},
+	}
+
+	conv.SchemaIssues = map[string]internal.TableIssues{
+		"t1": {
+			ColumnLevelIssues: map[string][]internal.SchemaIssue{
+				"c1": {internal.Widened},
+				"c2": {internal.Time},
+			},
+		},
+	}
+	conv.SyntheticPKeys["t2"] = internal.SyntheticPKey{"c20", 0}
+	conv.Audit.MigrationType = 
migration.MigrationData_SCHEMA_AND_DATA.Enum() +} diff --git a/webv2/routes.go b/webv2/routes.go index bdfb3419c..bcb02d980 100644 --- a/webv2/routes.go +++ b/webv2/routes.go @@ -45,7 +45,7 @@ func getRoutes() *mux.Router { } ctx := context.Background() - spClient, _:= spinstanceadmin.NewInstanceAdminClientImpl(ctx) + spClient, _ := spinstanceadmin.NewInstanceAdminClientImpl(ctx) dsClient, _ := ds.NewDatastreamClientImpl(ctx) storageclient, _ := storageclient.NewStorageClientImpl(ctx) validateResourceImpl := conversion.NewValidateResourcesImpl(&spanneraccessor.SpannerAccessorImpl{}, spClient, &datastream_accessor.DatastreamAccessorImpl{}, @@ -76,6 +76,7 @@ func getRoutes() *mux.Router { router.HandleFunc("/spannerDefaultTypeMap", api.SpannerDefaultTypeMap).Methods("GET") router.HandleFunc("/autoGenMap", api.GetAutoGenMap).Methods("GET") router.HandleFunc("/getSequenceKind", api.GetSequenceKind).Methods("GET") + router.HandleFunc("/validateCheckConstraint", api.ValidateCheckConstraint).Methods("GET") router.HandleFunc("/setparent", api.SetParentTable).Methods("GET") router.HandleFunc("/removeParent", api.RemoveParentTable).Methods("POST") @@ -93,6 +94,7 @@ func getRoutes() *mux.Router { router.HandleFunc("/UpdateSequence", api.UpdateSequence).Methods("POST") router.HandleFunc("/update/fks", api.UpdateForeignKeys).Methods("POST") + router.HandleFunc("/update/cks", api.UpdateCheckConstraint).Methods("POST") router.HandleFunc("/update/indexes", api.UpdateIndexes).Methods("POST") // Session Management diff --git a/webv2/table/review_table_schema.go b/webv2/table/review_table_schema.go index 0ad6fb0ec..5a5d7c31b 100644 --- a/webv2/table/review_table_schema.go +++ b/webv2/table/review_table_schema.go @@ -108,6 +108,13 @@ func ReviewTableSchema(w http.ResponseWriter, r *http.Request) { return } } + oldName := conv.SrcSchema[tableId].ColDefs[colId].Name + + for i := range conv.SpSchema[tableId].CheckConstraint { + originalString := 
conv.SpSchema[tableId].CheckConstraint[i].Expr + updatedValue := strings.ReplaceAll(originalString, oldName, v.Rename) + conv.SpSchema[tableId].CheckConstraint[i].Expr = updatedValue + } interleaveTableSchema = reviewRenameColumn(v.Rename, tableId, colId, conv, interleaveTableSchema) @@ -148,7 +155,7 @@ func ReviewTableSchema(w http.ResponseWriter, r *http.Request) { } } - if !v.Removed && !v.Add && v.Rename== ""{ + if !v.Removed && !v.Add && v.Rename == "" { sequences := UpdateAutoGenCol(v.AutoGen, tableId, colId, conv) conv.SpSequences = sequences } diff --git a/webv2/table/update_table_schema.go b/webv2/table/update_table_schema.go index 6e03c709c..76eaeca3a 100644 --- a/webv2/table/update_table_schema.go +++ b/webv2/table/update_table_schema.go @@ -19,6 +19,7 @@ import ( "fmt" "io/ioutil" "net/http" + "strings" "github.com/GoogleCloudPlatform/spanner-migration-tool/internal" "github.com/GoogleCloudPlatform/spanner-migration-tool/sources/common" @@ -55,6 +56,7 @@ type updateTable struct { // (3) Rename column. // (4) Add or Remove NotNull constraint. // (5) Update Spanner type. +// (6) Update Check constraints Name. func UpdateTableSchema(w http.ResponseWriter, r *http.Request) { reqBody, err := ioutil.ReadAll(r.Body) @@ -96,6 +98,14 @@ func UpdateTableSchema(w http.ResponseWriter, r *http.Request) { if v.Rename != "" && v.Rename != conv.SpSchema[tableId].ColDefs[colId].Name { + oldName := conv.SrcSchema[tableId].ColDefs[colId].Name + + for i := range conv.SpSchema[tableId].CheckConstraint { + originalString := conv.SpSchema[tableId].CheckConstraint[i].Expr + updatedValue := strings.ReplaceAll(originalString, oldName, v.Rename) + conv.SpSchema[tableId].CheckConstraint[i].Expr = updatedValue + } + renameColumn(v.Rename, tableId, colId, conv) }