From 534145a25f4c095d0e9bee7ba646275d98cc642f Mon Sep 17 00:00:00 2001 From: Taher Lakdawala <78196491+taherkl@users.noreply.github.com> Date: Thu, 2 Jan 2025 17:39:33 +0530 Subject: [PATCH] verification API and dump flow changes to support MySQL CHECK CONSTRAINTS (#978) * verification ap and dump flow changes * fixed IT issue * Check constraints verificartion api v2 (#24) * handled function not found * added unhandled error * updated the error msg --------- Co-authored-by: Vivek Yadav * fix IT issue * comment addressed (#27) * comment addressed 1. rename the functionNotFound 2. added condition to call verification api * spell checked --------- Co-authored-by: Vivek Yadav * refactor the DbDumpImpl struct (#28) * refactor the DbDumpImpl struct * remove the GenerateCheckConstrainstExprId method --------- Co-authored-by: Vivek Yadav * fixed if condition --------- Co-authored-by: taherkl Co-authored-by: Vivek Yadav Co-authored-by: Vivek Yadav <105432992+VivekY1098@users.noreply.github.com> --- common/utils/utils.go | 4 +- conversion/conversion.go | 4 +- conversion/conversion_from_source.go | 7 +- conversion/conversion_helper.go | 7 +- expressions_api/expression_verify.go | 3 + internal/convert.go | 4 + internal/reports/report_helpers.go | 59 +++++- mocks/expressions_api_mock.go | 24 +++ schema/schema.go | 7 +- sources/common/dbdump.go | 5 +- sources/common/toddl.go | 157 ++++++++++++++- sources/common/toddl_test.go | 181 ++++++++++++++++-- sources/dynamodb/schema_test.go | 24 ++- sources/dynamodb/toddl_test.go | 25 ++- sources/mysql/infoschema.go | 7 +- sources/mysql/infoschema_test.go | 42 ++-- sources/mysql/mysqldump.go | 56 +++++- sources/mysql/mysqldump_test.go | 14 +- sources/mysql/report_test.go | 12 +- sources/mysql/toddl_test.go | 23 ++- sources/oracle/infoschema_test.go | 15 +- sources/oracle/toddl_test.go | 23 ++- sources/postgres/infoschema_test.go | 23 ++- sources/postgres/pgdump.go | 3 +- sources/postgres/pgdump_test.go | 34 +++- sources/postgres/report_test.go | 12 +- sources/postgres/toddl_test.go | 23 ++- sources/sqlserver/infoschema_test.go | 14 +- sources/sqlserver/toddl_test.go | 23 ++- spanner/ddl/ast.go | 7 +- testing/csv/integration_test.go | 2 +- testing/dynamodb/snapshot/integration_test.go | 2 +- .../dynamodb/streaming/integration_test.go | 2 +- testing/mysql/integration_test.go | 10 +- testing/oracle/integration_test.go | 5 +- testing/postgres/golden_test.go | 13 +- testing/postgres/integration_test.go | 12 +- testing/sqlserver/integration_test.go | 5 +- webv2/api/schema.go | 90 +++++---- webv2/api/schema_test.go | 151 +++++++++++++++ webv2/routes.go | 11 +- 41 files changed, 987 insertions(+), 158 deletions(-) create mode 100644 mocks/expressions_api_mock.go diff --git a/common/utils/utils.go b/common/utils/utils.go index a5b56859e..1c4fbd1da 100644 --- a/common/utils/utils.go +++ b/common/utils/utils.go @@ -446,12 +446,14 @@ func GetLegacyModeSupportedDrivers() []string { func ReadSpannerSchema(ctx context.Context, conv *internal.Conv, client *sp.Client) error { infoSchema := spanner.InfoSchemaImpl{Client: client, Ctx: ctx, SpDialect: conv.SpDialect} processSchema := common.ProcessSchemaImpl{} + expressionVerificationAccessor, _ := expressions_api.NewExpressionVerificationAccessorImpl(ctx, conv.SpProjectId, conv.SpInstanceId) ddlVerifier, err := expressions_api.NewDDLVerifierImpl(ctx, conv.SpProjectId, conv.SpInstanceId) if err != nil { return fmt.Errorf("error trying create ddl verifier: %v", err) } schemaToSpanner := common.SchemaToSpannerImpl{ - DdlV: 
ddlVerifier, + DdlV: ddlVerifier, + ExpressionVerificationAccessor: expressionVerificationAccessor, } err = processSchema.ProcessSchema(conv, infoSchema, common.DefaultWorkers, internal.AdditionalSchemaAttributes{IsSharded: false}, &schemaToSpanner, &common.UtilsOrderImpl{}, &common.InfoSchemaImpl{}) if err != nil { diff --git a/conversion/conversion.go b/conversion/conversion.go index f22098380..b200ec873 100644 --- a/conversion/conversion.go +++ b/conversion/conversion.go @@ -37,6 +37,7 @@ import ( "github.com/GoogleCloudPlatform/spanner-migration-tool/common/constants" "github.com/GoogleCloudPlatform/spanner-migration-tool/common/task" "github.com/GoogleCloudPlatform/spanner-migration-tool/common/utils" + "github.com/GoogleCloudPlatform/spanner-migration-tool/expressions_api" "github.com/GoogleCloudPlatform/spanner-migration-tool/internal" "github.com/GoogleCloudPlatform/spanner-migration-tool/internal/reports" "github.com/GoogleCloudPlatform/spanner-migration-tool/profiles" @@ -79,7 +80,8 @@ func (ci *ConvImpl) SchemaConv(migrationProjectId string, sourceProfile profiles case constants.POSTGRES, constants.MYSQL, constants.DYNAMODB, constants.SQLSERVER, constants.ORACLE: return schemaFromSource.schemaFromDatabase(migrationProjectId, sourceProfile, targetProfile, &GetInfoImpl{}, &common.ProcessSchemaImpl{}) case constants.PGDUMP, constants.MYSQLDUMP: - return schemaFromSource.SchemaFromDump(targetProfile.Conn.Sp.Project, targetProfile.Conn.Sp.Instance, sourceProfile.Driver, targetProfile.Conn.Sp.Dialect, ioHelper, &ProcessDumpByDialectImpl{}) + expressionVerificationAccessor, _ := expressions_api.NewExpressionVerificationAccessorImpl(context.Background(), targetProfile.Conn.Sp.Project, targetProfile.Conn.Sp.Instance) + return schemaFromSource.SchemaFromDump(targetProfile.Conn.Sp.Project, targetProfile.Conn.Sp.Instance, sourceProfile.Driver, targetProfile.Conn.Sp.Dialect, ioHelper, &ProcessDumpByDialectImpl{ExpressionVerificationAccessor: expressionVerificationAccessor}) default: return nil, fmt.Errorf("schema conversion for driver %s not supported", sourceProfile.Driver) } diff --git a/conversion/conversion_from_source.go b/conversion/conversion_from_source.go index e92b8608e..c6a8bb35e 100644 --- a/conversion/conversion_from_source.go +++ b/conversion/conversion_from_source.go @@ -102,8 +102,12 @@ func (sads *SchemaFromSourceImpl) schemaFromDatabase(migrationProjectId string, additionalSchemaAttributes := internal.AdditionalSchemaAttributes{ IsSharded: isSharded, } + + ctx := context.Background() + expressionVerificationAccessor, _ := expressions_api.NewExpressionVerificationAccessorImpl(ctx, conv.SpProjectId, conv.SpInstanceId) schemaToSpanner := common.SchemaToSpannerImpl{ - DdlV: sads.DdlVerifier, + DdlV: sads.DdlVerifier, + ExpressionVerificationAccessor: expressionVerificationAccessor, } return conv, processSchema.ProcessSchema(conv, infoSchema, common.DefaultWorkers, additionalSchemaAttributes, &schemaToSpanner, &common.UtilsOrderImpl{}, &common.InfoSchemaImpl{}) } @@ -118,6 +122,7 @@ func (sads *SchemaFromSourceImpl) SchemaFromDump(SpProjectId string, SpInstanceI ioHelper.BytesRead = n conv := internal.MakeConv() conv.SpDialect = spDialect + conv.Source = driver p := internal.NewProgress(n, "Generating schema", internal.Verbose(), false, int(internal.SchemaCreationInProgress)) r := internal.NewReader(bufio.NewReader(f), p) conv.SetSchemaMode() // Build schema and ignore data in dump. 
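Illustrative note (not part of the patch): the hunks above thread an ExpressionVerificationAccessor through both the database flow (conversion_from_source.go, where it is set on common.SchemaToSpannerImpl) and the dump flow (conversion.go, where it is passed into ProcessDumpByDialectImpl). A rough sketch of how a caller could assemble the dump-flow processor with the constructors and struct fields introduced in this change; buildDumpProcessor is a hypothetical helper and error handling is minimal.

// Illustrative only -- not part of this patch.
package example

import (
	"context"
	"fmt"

	"github.com/GoogleCloudPlatform/spanner-migration-tool/conversion"
	"github.com/GoogleCloudPlatform/spanner-migration-tool/expressions_api"
)

// buildDumpProcessor (hypothetical) wires the dump-flow processor with both
// verifiers, mirroring the construction shown in the hunks above.
func buildDumpProcessor(ctx context.Context, project, instance string) (*conversion.ProcessDumpByDialectImpl, error) {
	// Accessor used to verify CHECK constraint expressions against Spanner.
	// The patch discards the second return value, so the same is done here.
	accessor, _ := expressions_api.NewExpressionVerificationAccessorImpl(ctx, project, instance)

	// Verifier used for the other generated Spanner DDL expressions.
	ddlVerifier, err := expressions_api.NewDDLVerifierImpl(ctx, project, instance)
	if err != nil {
		return nil, fmt.Errorf("error trying create ddl verifier: %v", err)
	}

	return &conversion.ProcessDumpByDialectImpl{
		ExpressionVerificationAccessor: accessor,
		DdlVerifier:                    ddlVerifier,
	}, nil
}

In the database flow, the same accessor is simply assigned to the ExpressionVerificationAccessor field of common.SchemaToSpannerImpl, as the conversion_from_source.go hunk above shows.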
diff --git a/conversion/conversion_helper.go b/conversion/conversion_helper.go index f56cde213..c1b7429e4 100644 --- a/conversion/conversion_helper.go +++ b/conversion/conversion_helper.go @@ -47,7 +47,8 @@ type ProcessDumpByDialectInterface interface { } type ProcessDumpByDialectImpl struct { - DdlVerifier expressions_api.DDLVerifier + ExpressionVerificationAccessor expressions_api.ExpressionVerificationAccessor + DdlVerifier expressions_api.DDLVerifier } type PopulateDataConvInterface interface { @@ -92,9 +93,9 @@ func getSeekable(f *os.File) (*os.File, int64, error) { func (pdd *ProcessDumpByDialectImpl) ProcessDump(driver string, conv *internal.Conv, r *internal.Reader) error { switch driver { case constants.MYSQLDUMP: - return common.ProcessDbDump(conv, r, mysql.DbDumpImpl{}, pdd.DdlVerifier) + return common.ProcessDbDump(conv, r, mysql.DbDumpImpl{}, pdd.DdlVerifier, pdd.ExpressionVerificationAccessor) case constants.PGDUMP: - return common.ProcessDbDump(conv, r, postgres.DbDumpImpl{}, pdd.DdlVerifier) + return common.ProcessDbDump(conv, r, postgres.DbDumpImpl{}, pdd.DdlVerifier, pdd.ExpressionVerificationAccessor) default: return fmt.Errorf("process dump for driver %s not supported", driver) } diff --git a/expressions_api/expression_verify.go b/expressions_api/expression_verify.go index f33c030d9..dd3ebb3b1 100644 --- a/expressions_api/expression_verify.go +++ b/expressions_api/expression_verify.go @@ -160,6 +160,9 @@ func (ev *ExpressionVerificationAccessorImpl) removeExpressions(inputConv *inter //TODO: Implement similar checks for DEFAULT and CHECK constraints as well convCopy.SpSequences = nil for _, table := range convCopy.SpSchema { + table.CheckConstraints = []ddl.CheckConstraint{} + convCopy.SpSchema[table.Id] = table + for colName, colDef := range table.ColDefs { colDef.AutoGen = ddl.AutoGenCol{} colDef.DefaultValue = ddl.DefaultValue{} diff --git a/internal/convert.go b/internal/convert.go index d2e892131..068dff661 100644 --- a/internal/convert.go +++ b/internal/convert.go @@ -133,6 +133,10 @@ const ( NumericPKNotSupported TypeMismatch DefaultValueError + InvalidCondition + ColumnNotFound + CheckConstraintFunctionNotFound + GenericError ) const ( diff --git a/internal/reports/report_helpers.go b/internal/reports/report_helpers.go index 6abbe88b5..d95922615 100644 --- a/internal/reports/report_helpers.go +++ b/internal/reports/report_helpers.go @@ -111,6 +111,45 @@ func buildTableReportBody(conv *internal.Conv, tableId string, issues map[string } + // added if to add table level issue + if p.severity == Errors && len(tableLevelIssues) != 0 { + for _, issue := range tableLevelIssues { + switch issue { + case internal.TypeMismatch: + toAppend := Issue{ + Category: IssueDB[issue].Category, + Description: fmt.Sprintf("Table '%s': Type mismatch in check constraint. Verify that the column type matches the constraint logic.", conv.SpSchema[tableId].Name), + } + l = append(l, toAppend) + case internal.InvalidCondition: + toAppend := Issue{ + Category: IssueDB[issue].Category, + Description: fmt.Sprintf("Table '%s': Invalid condition in check constraint. Ensure the condition is compatible with the constraint logic.", conv.SpSchema[tableId].Name), + } + l = append(l, toAppend) + case internal.ColumnNotFound: + toAppend := Issue{ + Category: IssueDB[issue].Category, + Description: fmt.Sprintf("Table '%s': Column not found in check constraint. 
Verify that all referenced columns exist.", conv.SpSchema[tableId].Name), + } + l = append(l, toAppend) + + case internal.CheckConstraintFunctionNotFound: + toAppend := Issue{ + Category: IssueDB[issue].Category, + Description: fmt.Sprintf("Table '%s': Function not found in check constraint. Ensure all functions used in the condition are valid.", conv.SpSchema[tableId].Name), + } + l = append(l, toAppend) + case internal.GenericError: + toAppend := Issue{ + Category: IssueDB[issue].Category, + Description: fmt.Sprintf("Table '%s': Something went wrong in check constraint. Verify the conditions and constraint logic.", conv.SpSchema[tableId].Name), + } + l = append(l, toAppend) + } + } + } + if p.severity == warning { flag := false for _, spFk := range conv.SpSchema[tableId].ForeignKeys { @@ -118,7 +157,7 @@ func buildTableReportBody(conv *internal.Conv, tableId string, issues map[string if err != nil { continue } - if srcFk.OnDelete == "" && srcFk.OnUpdate == "" && flag == false { + if srcFk.OnDelete == "" && srcFk.OnUpdate == "" && !flag { flag = true issue := internal.ForeignKeyActionNotSupported toAppend := Issue{ @@ -403,18 +442,13 @@ func buildTableReportBody(conv *internal.Conv, tableId string, issues map[string Description: fmt.Sprintf("UNIQUE constraint on column(s) '%s' replaced with primary key since table '%s' didn't have one. Spanner requires a primary key for every table", strings.Join(uniquePK, ", "), conv.SpSchema[tableId].Name), } l = append(l, toAppend) + case internal.DefaultValueError: toAppend := Issue{ Category: IssueDB[i].Category, Description: fmt.Sprintf("%s for table '%s' column '%s'", IssueDB[i].Brief, conv.SpSchema[tableId].Name, spColName), } l = append(l, toAppend) - case internal.TypeMismatch: - toAppend := Issue{ - Category: IssueDB[i].Category, - Description: fmt.Sprintf("Table '%s': Type mismatch in '%s'column affecting check constraints. Verify data type compatibility with constraint logic", conv.SpSchema[tableId].Name, conv.SpSchema[tableId].ColDefs[colId].Name), - } - l = append(l, toAppend) default: toAppend := Issue{ Category: IssueDB[i].Category, @@ -526,9 +560,14 @@ var IssueDB = map[internal.SchemaIssue]struct { Category string // Standarized issue type CategoryDescription string }{ - internal.DefaultValue: {Brief: "Some columns have default values which Spanner migration tool does not migrate. Please add the default constraints manually after the migration is complete", Severity: note, batch: true, Category: "MISSING_DEFAULT_VALUE_CONSTRAINTS"}, - internal.ForeignKey: {Brief: "Spanner does not support foreign keys", Severity: warning, Category: "FOREIGN_KEY_USES"}, - internal.MultiDimensionalArray: {Brief: "Spanner doesn't support multi-dimensional arrays", Severity: warning, Category: "MULTI_DIMENSIONAL_ARRAY_USES"}, + internal.DefaultValue: {Brief: "Some columns have default values which Spanner migration tool does not migrate. 
Please add the default constraints manually after the migration is complete", Severity: note, batch: true, Category: "MISSING_DEFAULT_VALUE_CONSTRAINTS"}, + internal.TypeMismatch: {Brief: "Type mismatch in check constraint mention in table", Severity: warning, Category: "TYPE_MISMATCH"}, + internal.InvalidCondition: {Brief: "Invalid condition in check constraint mention in table", Severity: warning, Category: "INVALID_CONDITION"}, + internal.ColumnNotFound: {Brief: "Column not found in check constraint mention in the table", Severity: warning, Category: "COLUMN_NOT_FOUND"}, + internal.CheckConstraintFunctionNotFound: {Brief: "Function not found in check constraint mention in the table", Severity: warning, Category: "FUNCTION_NOT_FOUND"}, + internal.GenericError: {Brief: "Something went wrong", Severity: warning, Category: "UNHANDLE_ERROR"}, + internal.ForeignKey: {Brief: "Spanner does not support foreign keys", Severity: warning, Category: "FOREIGN_KEY_USES"}, + internal.MultiDimensionalArray: {Brief: "Spanner doesn't support multi-dimensional arrays", Severity: warning, Category: "MULTI_DIMENSIONAL_ARRAY_USES"}, internal.NoGoodType: {Brief: "No appropriate Spanner type. The column will be made nullable in Spanner", Severity: warning, Category: "INAPPROPRIATE_TYPE", CategoryDescription: "No appropriate Spanner type"}, internal.Numeric: {Brief: "Spanner does not support numeric. This type mapping could lose precision and is not recommended for production use", Severity: warning, Category: "NUMERIC_USES"}, diff --git a/mocks/expressions_api_mock.go b/mocks/expressions_api_mock.go new file mode 100644 index 000000000..b53dd93be --- /dev/null +++ b/mocks/expressions_api_mock.go @@ -0,0 +1,24 @@ +package mocks + +import ( + "context" + + "github.com/GoogleCloudPlatform/spanner-migration-tool/internal" + "github.com/stretchr/testify/mock" +) + +// MockExpressionVerificationAccessor is a mock of ExpressionVerificationAccessor +type MockExpressionVerificationAccessor struct { + mock.Mock +} + +// VerifyExpressions is a mocked method for expression verification +func (m *MockExpressionVerificationAccessor) VerifyExpressions(ctx context.Context, input internal.VerifyExpressionsInput) internal.VerifyExpressionsOutput { + args := m.Called(ctx, input) + return args.Get(0).(internal.VerifyExpressionsOutput) +} + +func (m *MockExpressionVerificationAccessor) RefreshSpannerClient(ctx context.Context, project, instance string ) error { + args := m.Called(ctx, project, instance) + return args.Get(0).(error) +} diff --git a/schema/schema.go b/schema/schema.go index eab021cf0..9119501b3 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -80,9 +80,10 @@ type ForeignKey struct { // CheckConstraints represents a check constraint defined in the schema. type CheckConstraint struct { - Name string - Expr string - Id string + Name string + Expr string + ExprId string + Id string } // Key respresents a primary key or index key. diff --git a/sources/common/dbdump.go b/sources/common/dbdump.go index 02bf75ab6..71740fbe8 100644 --- a/sources/common/dbdump.go +++ b/sources/common/dbdump.go @@ -30,7 +30,7 @@ type DbDump interface { // In schema mode, this method incrementally builds a schema (updating conv). // In data mode, this method uses this schema to convert data and writes it // to Spanner, using the data sink specified in conv. 
-func ProcessDbDump(conv *internal.Conv, r *internal.Reader, dbDump DbDump, ddlVerifier expressions_api.DDLVerifier) error { +func ProcessDbDump(conv *internal.Conv, r *internal.Reader, dbDump DbDump, ddlVerifier expressions_api.DDLVerifier, exprVerifier expressions_api.ExpressionVerificationAccessor) error { if err := dbDump.ProcessDump(conv, r); err != nil { return err } @@ -39,7 +39,8 @@ func ProcessDbDump(conv *internal.Conv, r *internal.Reader, dbDump DbDump, ddlVe utilsOrder.initPrimaryKeyOrder(conv) utilsOrder.initIndexOrder(conv) schemaToSpanner := SchemaToSpannerImpl{ - DdlV: ddlVerifier, + ExpressionVerificationAccessor: exprVerifier, + DdlV: ddlVerifier, } schemaToSpanner.SchemaToSpannerDDL(conv, dbDump.GetToDdl()) conv.AddPrimaryKeys() diff --git a/sources/common/toddl.go b/sources/common/toddl.go index b23d2d6f7..e73e681cd 100644 --- a/sources/common/toddl.go +++ b/sources/common/toddl.go @@ -30,6 +30,7 @@ While adding new methods or code here package common import ( + "context" "fmt" "reflect" "strconv" @@ -59,7 +60,16 @@ type SchemaToSpannerInterface interface { } type SchemaToSpannerImpl struct { - DdlV expressions_api.DDLVerifier + ExpressionVerificationAccessor expressions_api.ExpressionVerificationAccessor + DdlV expressions_api.DDLVerifier +} + +var ErrorTypeMapping = map[string]internal.SchemaIssue{ + "No matching signature for operator": internal.TypeMismatch, + "Syntax error": internal.InvalidCondition, + "Unrecognized name": internal.ColumnNotFound, + "Function not found": internal.CheckConstraintFunctionNotFound, + "unhandled error": internal.GenericError, } // SchemaToSpannerDDL performs schema conversion from the source DB schema to @@ -75,7 +85,17 @@ func (ss *SchemaToSpannerImpl) SchemaToSpannerDDL(conv *internal.Conv, toddl ToD srcTable := conv.SrcSchema[tableId] ss.SchemaToSpannerDDLHelper(conv, toddl, srcTable, false) } - if conv.Source == constants.MYSQL && conv.SpProjectId!="" && conv.SpInstanceId!=""{ + + if (conv.Source == constants.MYSQL || conv.Source == constants.MYSQLDUMP) && conv.SpProjectId != "" && conv.SpInstanceId != "" { + // Process and verify Check constraints for MySQL and MySQLDump flow only + err := ss.VerifyExpressions(conv) + if err != nil { + return err + } + } + + if conv.Source == constants.MYSQL && conv.SpProjectId != "" && conv.SpInstanceId != "" { + // Process and verify Spanner DDL expressions for MYSQL expressionDetails := ss.DdlV.GetSourceExpressionDetails(conv, tableIds) expressions, err := ss.DdlV.VerifySpannerDDL(conv, expressionDetails) if err != nil && !strings.Contains(err.Error(), "expressions either failed verification") { @@ -87,6 +107,131 @@ func (ss *SchemaToSpannerImpl) SchemaToSpannerDDL(conv *internal.Conv, toddl ToD return nil } +// GenerateExpressionDetailList it will generate the expression detail list which is used in verify expression method as a input +func GenerateExpressionDetailList(spschema ddl.Schema) []internal.ExpressionDetail { + expressionDetailList := []internal.ExpressionDetail{} + for _, sp := range spschema { + for _, cc := range sp.CheckConstraints { + + expressionDetail := internal.ExpressionDetail{ + Expression: cc.Expr, + Type: "CHECK", + ReferenceElement: internal.ReferenceElement{Name: sp.Name}, + ExpressionId: cc.ExprId, + Metadata: map[string]string{"tableId": sp.Id}, + } + expressionDetailList = append(expressionDetailList, expressionDetail) + } + } + + return expressionDetailList +} + +// RemoveError it will reset the table issue before re-populating +func RemoveError(tableIssues 
map[string]internal.TableIssues) map[string]internal.TableIssues { + + for tableId, TableIssues := range tableIssues { + for _, issue := range ErrorTypeMapping { + removedIssue := removeSchemaIssue(TableIssues.TableLevelIssues, issue) + TableIssues.TableLevelIssues = removedIssue + tableIssues[tableId] = TableIssues + } + + } + return tableIssues + +} + +// GetIssue collects every failed verification result, grouped by table, and maps each error to a schema issue. +func GetIssue(result internal.VerifyExpressionsOutput) map[string][]internal.SchemaIssue { + exprOutputsByTable := make(map[string][]internal.ExpressionVerificationOutput) + issues := make(map[string][]internal.SchemaIssue) + for _, ev := range result.ExpressionVerificationOutputList { + if !ev.Result { + tableId := ev.ExpressionDetail.Metadata["tableId"] + exprOutputsByTable[tableId] = append(exprOutputsByTable[tableId], ev) + } + } + + for tableId, exprOutputs := range exprOutputsByTable { + + for _, ev := range exprOutputs { + var issue internal.SchemaIssue + + switch { + case strings.Contains(ev.Err.Error(), "No matching signature for operator"): + issue = internal.TypeMismatch + case strings.Contains(ev.Err.Error(), "Syntax error"): + issue = internal.InvalidCondition + case strings.Contains(ev.Err.Error(), "Unrecognized name"): + issue = internal.ColumnNotFound + case strings.Contains(ev.Err.Error(), "Function not found"): + issue = internal.CheckConstraintFunctionNotFound + default: + issue = internal.GenericError + } + issues[tableId] = append(issues[tableId], issue) + + } + + } + + return issues + +} + +// VerifyExpressions uses the expressions API to validate check constraint expressions, surfaces the resulting errors +// in the suggestion tab, and removes the check constraints that fail verification. +func (ss *SchemaToSpannerImpl) VerifyExpressions(conv *internal.Conv) error { + ctx := context.Background() + + spschema := conv.SpSchema + + verifyExpressionsInput := internal.VerifyExpressionsInput{ + Conv: conv, + Source: conv.Source, + ExpressionDetailList: GenerateExpressionDetailList(spschema), + } + + result := ss.ExpressionVerificationAccessor.VerifyExpressions(ctx, verifyExpressionsInput) + if result.ExpressionVerificationOutputList == nil { + return result.Err + } + issueTypes := GetIssue(result) + if len(issueTypes) > 0 { + for tableId, issues := range issueTypes { + + for _, issue := range issues { + if _, exists := conv.SchemaIssues[tableId]; !exists { + conv.SchemaIssues[tableId] = internal.TableIssues{ + TableLevelIssues: []internal.SchemaIssue{}, + } + } + + tableIssue := conv.SchemaIssues[tableId] + + if !IsSchemaIssuePresent(tableIssue.TableLevelIssues, issue) { + tableIssue.TableLevelIssues = append(tableIssue.TableLevelIssues, issue) + } + conv.SchemaIssues[tableId] = tableIssue + } + } + } + + return nil +} + +// IsSchemaIssuePresent checks whether the given issue is already present in the schema issue list.
+func IsSchemaIssuePresent(schemaissue []internal.SchemaIssue, issue internal.SchemaIssue) bool { + + for _, s := range schemaissue { + if s == issue { + return true + } + } + return false +} + func (ss *SchemaToSpannerImpl) SchemaToSpannerDDLHelper(conv *internal.Conv, toddl ToDdl, srcTable schema.Table, isRestore bool) error { spTableName, err := internal.GetSpannerTable(conv, srcTable.Id) if err != nil { @@ -254,9 +399,10 @@ func cvtCheckConstraint(conv *internal.Conv, srcKeys []schema.CheckConstraint) [ for _, cc := range srcKeys { spcc = append(spcc, ddl.CheckConstraint{ - Id: cc.Id, - Name: internal.ToSpannerCheckConstraintName(conv, cc.Name), - Expr: cc.Expr, + Id: cc.Id, + Name: internal.ToSpannerCheckConstraintName(conv, cc.Name), + Expr: cc.Expr, + ExprId: cc.ExprId, }) } return spcc @@ -330,6 +476,7 @@ func SrcTableToSpannerDDL(conv *internal.Conv, toddl ToDdl, srcTable schema.Tabl conv.SpSchema[tableId] = spTable } } + internal.ResolveRefs(conv) return nil } diff --git a/sources/common/toddl_test.go b/sources/common/toddl_test.go index 4a85db225..127aaa35c 100644 --- a/sources/common/toddl_test.go +++ b/sources/common/toddl_test.go @@ -15,15 +15,19 @@ package common import ( + "context" + "errors" "reflect" "testing" "github.com/GoogleCloudPlatform/spanner-migration-tool/common/constants" "github.com/GoogleCloudPlatform/spanner-migration-tool/expressions_api" "github.com/GoogleCloudPlatform/spanner-migration-tool/internal" + "github.com/GoogleCloudPlatform/spanner-migration-tool/mocks" "github.com/GoogleCloudPlatform/spanner-migration-tool/schema" "github.com/GoogleCloudPlatform/spanner-migration-tool/spanner/ddl" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" ) func Test_quoteIfNeeded(t *testing.T) { @@ -437,36 +441,42 @@ func Test_cvtCheckContraint(t *testing.T) { conv := internal.MakeConv() srcSchema := []schema.CheckConstraint{ { - Id: "cc1", - Name: "check_1", - Expr: "age > 0", + Id: "cc1", + Name: "check_1", + Expr: "age > 0", + ExprId: "expr1", }, { - Id: "cc2", - Name: "check_2", - Expr: "age < 99", + Id: "cc2", + Name: "check_2", + Expr: "age < 99", + ExprId: "expr2", }, { - Id: "cc3", - Name: "@invalid_name", // incompatabile name - Expr: "age != 0", + Id: "cc3", + Name: "@invalid_name", // incompatabile name + Expr: "age != 0", + ExprId: "expr3", }, } spSchema := []ddl.CheckConstraint{ { - Id: "cc1", - Name: "check_1", - Expr: "age > 0", + Id: "cc1", + Name: "check_1", + Expr: "age > 0", + ExprId: "expr1", }, { - Id: "cc2", - Name: "check_2", - Expr: "age < 99", + Id: "cc2", + Name: "check_2", + Expr: "age < 99", + ExprId: "expr2", }, { - Id: "cc3", - Name: "Ainvalid_name", - Expr: "age != 0", + Id: "cc3", + Name: "Ainvalid_name", + Expr: "age != 0", + ExprId: "expr3", }, } result := cvtCheckConstraint(conv, srcSchema) @@ -581,3 +591,138 @@ func TestSpannerSchemaApplyExpressions(t *testing.T) { }) } } + +func TestVerifyCheckConstraintExpressions(t *testing.T) { + tests := []struct { + name string + expressions []ddl.CheckConstraint + expectedResults []internal.ExpressionVerificationOutput + expectedCheckConstraint []ddl.CheckConstraint + expectedResponse bool + }{ + { + name: "AllValidExpressions", + expressions: []ddl.CheckConstraint{ + {Expr: "(col1 > 0)", ExprId: "expr1", Name: "check1"}, + }, + expectedResults: []internal.ExpressionVerificationOutput{ + {Result: true, Err: nil, ExpressionDetail: internal.ExpressionDetail{Expression: "(col1 > 0)", Type: "CHECK", Metadata: map[string]string{"tableId": "t1"}, ExpressionId: "expr1"}}, + }, 
+ expectedCheckConstraint: []ddl.CheckConstraint{ + {Expr: "(col1 > 0)", ExprId: "expr1", Name: "check1"}, + }, + expectedResponse: false, + }, + { + name: "InvalidSyntaxError", + expressions: []ddl.CheckConstraint{ + {Expr: "(col1 > 0)", ExprId: "expr1", Name: "check1"}, + {Expr: "(col1 > 18", ExprId: "expr2", Name: "check2"}, + }, + expectedResults: []internal.ExpressionVerificationOutput{ + {Result: true, Err: nil, ExpressionDetail: internal.ExpressionDetail{Expression: "(col1 > 0)", Type: "CHECK", Metadata: map[string]string{"tableId": "t1"}, ExpressionId: "expr1"}}, + {Result: false, Err: errors.New("Syntax error ..."), ExpressionDetail: internal.ExpressionDetail{Expression: "(col1 > 18", Type: "CHECK", Metadata: map[string]string{"tableId": "t1"}, ExpressionId: "expr2"}}, + }, + expectedCheckConstraint: []ddl.CheckConstraint{ + {Expr: "(col1 > 0)", ExprId: "expr1", Name: "check1"}, + {Expr: "(col1 > 18", ExprId: "expr2", Name: "check2"}, + }, + expectedResponse: true, + }, + { + name: "NameError", + expressions: []ddl.CheckConstraint{ + {Expr: "(col1 > 0)", ExprId: "expr1", Name: "check1"}, + {Expr: "(col1 > 18)", ExprId: "expr2", Name: "check2"}, + }, + expectedResults: []internal.ExpressionVerificationOutput{ + {Result: true, Err: nil, ExpressionDetail: internal.ExpressionDetail{Expression: "(col1 > 0)", Type: "CHECK", Metadata: map[string]string{"tableId": "t1"}, ExpressionId: "expr1"}}, + {Result: false, Err: errors.New("Unrecognized name ..."), ExpressionDetail: internal.ExpressionDetail{Expression: "(col1 > 18)", Type: "CHECK", Metadata: map[string]string{"tableId": "t1"}, ExpressionId: "expr2"}}, + }, + expectedCheckConstraint: []ddl.CheckConstraint{ + {Expr: "(col1 > 0)", ExprId: "expr1", Name: "check1"}, + {Expr: "(col1 > 18)", ExprId: "expr2", Name: "check2"}, + }, + expectedResponse: true, + }, + { + name: "TypeError", + expressions: []ddl.CheckConstraint{ + {Expr: "(col1 > 0)", ExprId: "expr1", Name: "check1"}, + {Expr: "(col1 > 18)", ExprId: "expr2", Name: "check2"}, + }, + expectedResults: []internal.ExpressionVerificationOutput{ + {Result: true, Err: nil, ExpressionDetail: internal.ExpressionDetail{Expression: "(col1 > 0)", Type: "CHECK", Metadata: map[string]string{"tableId": "t1"}, ExpressionId: "expr1"}}, + {Result: false, Err: errors.New("No matching signature for operator"), ExpressionDetail: internal.ExpressionDetail{Expression: "(col1 > 18)", Type: "CHECK", Metadata: map[string]string{"tableId": "t1"}, ExpressionId: "expr2"}}, + }, + expectedCheckConstraint: []ddl.CheckConstraint{ + {Expr: "(col1 > 0)", ExprId: "expr1", Name: "check1"}, + {Expr: "(col1 > 18)", ExprId: "expr2", Name: "check2"}, + }, + expectedResponse: true, + }, + { + name: "FunctionError", + expressions: []ddl.CheckConstraint{ + {Expr: "(col1 > 0)", ExprId: "expr1", Name: "check1"}, + {Expr: "(col1 > 18)", ExprId: "expr2", Name: "check2"}, + }, + expectedResults: []internal.ExpressionVerificationOutput{ + {Result: true, Err: nil, ExpressionDetail: internal.ExpressionDetail{Expression: "(col1 > 0)", Type: "CHECK", Metadata: map[string]string{"tableId": "t1"}, ExpressionId: "expr1"}}, + {Result: false, Err: errors.New("Function not found"), ExpressionDetail: internal.ExpressionDetail{Expression: "(col1 > 18)", Type: "CHECK", Metadata: map[string]string{"tableId": "t1"}, ExpressionId: "expr2"}}, + }, + expectedCheckConstraint: []ddl.CheckConstraint{ + {Expr: "(col1 > 0)", ExprId: "expr1", Name: "check1"}, + {Expr: "(col1 > 18)", ExprId: "expr2", Name: "check2"}, + }, + expectedResponse: true, + }, 
+ { + name: "GenericError", + expressions: []ddl.CheckConstraint{ + {Expr: "(col1 > 0)", ExprId: "expr1", Name: "check1"}, + {Expr: "(col1 > 18)", ExprId: "expr2", Name: "check2"}, + }, + expectedResults: []internal.ExpressionVerificationOutput{ + {Result: true, Err: nil, ExpressionDetail: internal.ExpressionDetail{Expression: "(col1 > 0)", Type: "CHECK", Metadata: map[string]string{"tableId": "t1"}, ExpressionId: "expr1"}}, + {Result: false, Err: errors.New("Unhandle error"), ExpressionDetail: internal.ExpressionDetail{Expression: "(col1 > 18)", Type: "CHECK", Metadata: map[string]string{"tableId": "t1"}, ExpressionId: "expr2"}}, + }, + expectedCheckConstraint: []ddl.CheckConstraint{ + {Expr: "(col1 > 0)", ExprId: "expr1", Name: "check1"}, + {Expr: "(col1 > 18)", ExprId: "expr2", Name: "check2"}, + }, + expectedResponse: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + mockAccessor := new(mocks.MockExpressionVerificationAccessor) + handler := &SchemaToSpannerImpl{ExpressionVerificationAccessor: mockAccessor} + + conv := internal.MakeConv() + + ctx := context.Background() + + conv.SpSchema = map[string]ddl.CreateTable{ + "t1": { + Name: "table1", + Id: "t1", + PrimaryKeys: []ddl.IndexKey{{ColId: "c1"}}, + ColIds: []string{"c1"}, + ColDefs: map[string]ddl.ColumnDef{ + "c1": {Name: "col1", Id: "c1", T: ddl.Type{Name: ddl.Int64}}, + }, + CheckConstraints: tc.expressions, + }, + } + + mockAccessor.On("VerifyExpressions", ctx, mock.Anything).Return(internal.VerifyExpressionsOutput{ + ExpressionVerificationOutputList: tc.expectedResults, + }) + handler.VerifyExpressions(conv) + assert.Equal(t, conv.SpSchema["t1"].CheckConstraints, tc.expectedCheckConstraint) + + }) + } +} diff --git a/sources/dynamodb/schema_test.go b/sources/dynamodb/schema_test.go index ef02fcf4f..eb792b345 100644 --- a/sources/dynamodb/schema_test.go +++ b/sources/dynamodb/schema_test.go @@ -26,11 +26,13 @@ import ( "github.com/aws/aws-sdk-go/service/dynamodb/dynamodbiface" "github.com/aws/aws-sdk-go/service/dynamodbstreams" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" "go.uber.org/zap" "github.com/GoogleCloudPlatform/spanner-migration-tool/expressions_api" "github.com/GoogleCloudPlatform/spanner-migration-tool/internal" "github.com/GoogleCloudPlatform/spanner-migration-tool/logger" + "github.com/GoogleCloudPlatform/spanner-migration-tool/mocks" "github.com/GoogleCloudPlatform/spanner-migration-tool/schema" "github.com/GoogleCloudPlatform/spanner-migration-tool/sources/common" "github.com/GoogleCloudPlatform/spanner-migration-tool/spanner/ddl" @@ -187,11 +189,19 @@ func TestProcessSchema(t *testing.T) { scanOutputs: scanOutputs, } sampleSize := int64(10000) + mockAccessor := new(mocks.MockExpressionVerificationAccessor) + ctx := context.Background() + mockAccessor.On("VerifyExpressions", ctx, mock.Anything).Return(internal.VerifyExpressionsOutput{ + ExpressionVerificationOutputList: []internal.ExpressionVerificationOutput{ + {Result: true, Err: nil, ExpressionDetail: internal.ExpressionDetail{Expression: "(col1 > 0)", Type: "CHECK", Metadata: map[string]string{"tableId": "t1", "colId": "c1", "checkConstraintName": "check1"}, ExpressionId: "expr1"}}, + }, + }) conv := internal.MakeConv() processSchema := common.ProcessSchemaImpl{} schemaToSpanner := &common.SchemaToSpannerImpl{ - DdlV: &expressions_api.MockDDLVerifier{}, + ExpressionVerificationAccessor: mockAccessor, + DdlV: &expressions_api.MockDDLVerifier{}, } err := processSchema.ProcessSchema(conv, 
InfoSchemaImpl{client, nil, sampleSize}, 1, internal.AdditionalSchemaAttributes{}, schemaToSpanner, &common.UtilsOrderImpl{}, &common.InfoSchemaImpl{}) @@ -293,9 +303,17 @@ func TestProcessSchema_FullDataTypes(t *testing.T) { sampleSize := int64(10000) conv := internal.MakeConv() + mockAccessor := new(mocks.MockExpressionVerificationAccessor) + ctx := context.Background() + mockAccessor.On("VerifyExpressions", ctx, mock.Anything).Return(internal.VerifyExpressionsOutput{ + ExpressionVerificationOutputList: []internal.ExpressionVerificationOutput{ + {Result: true, Err: nil, ExpressionDetail: internal.ExpressionDetail{Expression: "(col1 > 0)", Type: "CHECK", Metadata: map[string]string{"tableId": "t1", "colId": "c1", "checkConstraintName": "check1"}, ExpressionId: "expr1"}}, + }, + }) processSchema := common.ProcessSchemaImpl{} schemaToSpanner := &common.SchemaToSpannerImpl{ - DdlV: &expressions_api.MockDDLVerifier{}, + ExpressionVerificationAccessor: mockAccessor, + DdlV: &expressions_api.MockDDLVerifier{}, } err := processSchema.ProcessSchema(conv, InfoSchemaImpl{client, nil, sampleSize}, 1, internal.AdditionalSchemaAttributes{}, schemaToSpanner, &common.UtilsOrderImpl{}, &common.InfoSchemaImpl{}) @@ -663,7 +681,7 @@ func TestInfoSchemaImpl_GetTables(t *testing.T) { isi := InfoSchemaImpl{client, nil, 10} tables, err := isi.GetTables() assert.Nil(t, err) - assert.Equal(t, []common.SchemaAndName{{Schema: "", Name: "table-a", Id: ""}, {"", "table-b", ""}}, tables) + assert.Equal(t, []common.SchemaAndName{{Schema: "", Name: "table-a", Id: ""}, {Schema: "", Name: "table-b", Id: ""}}, tables) } func TestInfoSchemaImpl_GetTableName(t *testing.T) { diff --git a/sources/dynamodb/toddl_test.go b/sources/dynamodb/toddl_test.go index c73eaf25a..b316108e1 100644 --- a/sources/dynamodb/toddl_test.go +++ b/sources/dynamodb/toddl_test.go @@ -15,16 +15,19 @@ package dynamodb import ( + "context" "testing" "github.com/GoogleCloudPlatform/spanner-migration-tool/common/constants" "github.com/GoogleCloudPlatform/spanner-migration-tool/expressions_api" "github.com/GoogleCloudPlatform/spanner-migration-tool/internal" + "github.com/GoogleCloudPlatform/spanner-migration-tool/mocks" "github.com/GoogleCloudPlatform/spanner-migration-tool/proto/migration" "github.com/GoogleCloudPlatform/spanner-migration-tool/schema" "github.com/GoogleCloudPlatform/spanner-migration-tool/sources/common" "github.com/GoogleCloudPlatform/spanner-migration-tool/spanner/ddl" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" ) func TestToSpannerType(t *testing.T) { @@ -59,8 +62,17 @@ func TestToSpannerType(t *testing.T) { } conv.SrcSchema[name] = srcSchema conv.Audit = audit + mockAccessor := new(mocks.MockExpressionVerificationAccessor) + ctx := context.Background() + mockAccessor.On("VerifyExpressions", ctx, mock.Anything).Return(internal.VerifyExpressionsOutput{ + ExpressionVerificationOutputList: []internal.ExpressionVerificationOutput{ + {Result: true, Err: nil, ExpressionDetail: internal.ExpressionDetail{Expression: "(col1 > 0)", Type: "CHECK", Metadata: map[string]string{"tableId": "t1", "colId": "c1", "checkConstraintName": "check1"}, ExpressionId: "expr1"}}, + }, + }) + schemaToSpanner := &common.SchemaToSpannerImpl{ - DdlV: &expressions_api.MockDDLVerifier{}, + ExpressionVerificationAccessor: mockAccessor, + DdlV: &expressions_api.MockDDLVerifier{}, } assert.Nil(t, schemaToSpanner.SchemaToSpannerDDL(conv, ToDdlImpl{})) actual := conv.SpSchema[name] @@ -122,8 +134,17 @@ func 
TestToSpannerPostgreSQLDialectType(t *testing.T) { } conv.SrcSchema["t1"] = srcSchema conv.Audit = audit + mockAccessor := new(mocks.MockExpressionVerificationAccessor) + ctx := context.Background() + mockAccessor.On("VerifyExpressions", ctx, mock.Anything).Return(internal.VerifyExpressionsOutput{ + ExpressionVerificationOutputList: []internal.ExpressionVerificationOutput{ + {Result: true, Err: nil, ExpressionDetail: internal.ExpressionDetail{Expression: "(col1 > 0)", Type: "CHECK", Metadata: map[string]string{"tableId": "t1", "colId": "c1", "checkConstraintName": "check1"}, ExpressionId: "expr1"}}, + }, + }) + schemaToSpanner := &common.SchemaToSpannerImpl{ - DdlV: &expressions_api.MockDDLVerifier{}, + ExpressionVerificationAccessor: mockAccessor, + DdlV: &expressions_api.MockDDLVerifier{}, } assert.Nil(t, schemaToSpanner.SchemaToSpannerDDL(conv, ToDdlImpl{})) actual := conv.SpSchema["t1"] diff --git a/sources/mysql/infoschema.go b/sources/mysql/infoschema.go index d118b16e2..bfc0ee917 100644 --- a/sources/mysql/infoschema.go +++ b/sources/mysql/infoschema.go @@ -18,6 +18,7 @@ import ( "context" "database/sql" "fmt" + "regexp" "sort" "strings" @@ -34,6 +35,8 @@ import ( "github.com/GoogleCloudPlatform/spanner-migration-tool/streaming" ) +var collationRegex = regexp.MustCompile(constants.DB_COLLATION_REGEX) + // InfoSchemaImpl is MySQL specific implementation for InfoSchema. type InfoSchemaImpl struct { DbName string @@ -273,7 +276,7 @@ func (isi InfoSchemaImpl) GetConstraints(conv *internal.Conv, table common.Schem func (isi InfoSchemaImpl) getConstraintsDQL() (string, error) { var tableExistsCount int // check if CHECK_CONSTRAINTS table exists. - checkQuery := `SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA = 'INFORMATION_SCHEMA' AND TABLE_NAME = 'CHECK_CONSTRAINTS';` + checkQuery := `SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLES WHERE (TABLE_SCHEMA = 'information_schema' OR TABLE_SCHEMA = 'INFORMATION_SCHEMA') AND TABLE_NAME = 'CHECK_CONSTRAINTS';` err := isi.Db.QueryRow(checkQuery).Scan(&tableExistsCount) if err != nil { return "", err @@ -341,7 +344,7 @@ func (isi InfoSchemaImpl) processRow( // Case added to handle check constraints case "CHECK": checkClause = collationRegex.ReplaceAllString(checkClause, "") - *checkKeys = append(*checkKeys, schema.CheckConstraint{Name: constraintName, Expr: checkClause, Id: internal.GenerateCheckConstrainstId()}) + *checkKeys = append(*checkKeys, schema.CheckConstraint{Name: constraintName, Expr: checkClause, ExprId: internal.GenerateExpressionId(), Id: internal.GenerateCheckConstrainstId()}) default: m[col] = append(m[col], constraintType) } diff --git a/sources/mysql/infoschema_test.go b/sources/mysql/infoschema_test.go index 08742fdf2..d6d74f476 100644 --- a/sources/mysql/infoschema_test.go +++ b/sources/mysql/infoschema_test.go @@ -15,6 +15,7 @@ package mysql import ( + "context" "database/sql" "database/sql/driver" "regexp" @@ -22,10 +23,12 @@ import ( "github.com/DATA-DOG/go-sqlmock" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" "github.com/GoogleCloudPlatform/spanner-migration-tool/common/constants" "github.com/GoogleCloudPlatform/spanner-migration-tool/expressions_api" "github.com/GoogleCloudPlatform/spanner-migration-tool/internal" + "github.com/GoogleCloudPlatform/spanner-migration-tool/mocks" "github.com/GoogleCloudPlatform/spanner-migration-tool/profiles" "github.com/GoogleCloudPlatform/spanner-migration-tool/schema" "github.com/GoogleCloudPlatform/spanner-migration-tool/sources/common" @@ 
-54,7 +57,7 @@ func TestProcessSchemaMYSQL(t *testing.T) { }, }, { - query: `SELECT COUNT\(\*\) FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA = 'INFORMATION_SCHEMA' AND TABLE_NAME = 'CHECK_CONSTRAINTS';`, + query: regexp.QuoteMeta(`SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLES WHERE (TABLE_SCHEMA = 'information_schema' OR TABLE_SCHEMA = 'INFORMATION_SCHEMA') AND TABLE_NAME = 'CHECK_CONSTRAINTS';`), args: nil, cols: []string{"count"}, rows: [][]driver.Value{ @@ -94,7 +97,7 @@ func TestProcessSchemaMYSQL(t *testing.T) { cols: []string{"INDEX_NAME", "COLUMN_NAME", "SEQ_IN_INDEX", "COLLATION", "NON_UNIQUE"}, }, { - query: `SELECT COUNT\(\*\) FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA = 'INFORMATION_SCHEMA' AND TABLE_NAME = 'CHECK_CONSTRAINTS';`, + query: regexp.QuoteMeta(`SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLES WHERE (TABLE_SCHEMA = 'information_schema' OR TABLE_SCHEMA = 'INFORMATION_SCHEMA') AND TABLE_NAME = 'CHECK_CONSTRAINTS';`), args: nil, cols: []string{"count"}, rows: [][]driver.Value{ @@ -143,7 +146,7 @@ func TestProcessSchemaMYSQL(t *testing.T) { }, }, { - query: `SELECT COUNT\(\*\) FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA = 'INFORMATION_SCHEMA' AND TABLE_NAME = 'CHECK_CONSTRAINTS';`, + query: regexp.QuoteMeta(`SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLES WHERE (TABLE_SCHEMA = 'information_schema' OR TABLE_SCHEMA = 'INFORMATION_SCHEMA') AND TABLE_NAME = 'CHECK_CONSTRAINTS';`), args: nil, cols: []string{"count"}, rows: [][]driver.Value{ @@ -179,7 +182,7 @@ func TestProcessSchemaMYSQL(t *testing.T) { cols: []string{"INDEX_NAME", "COLUMN_NAME", "SEQ_IN_INDEX", "COLLATION", "NON_UNIQUE"}, }, { - query: `SELECT COUNT\(\*\) FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA = 'INFORMATION_SCHEMA' AND TABLE_NAME = 'CHECK_CONSTRAINTS';`, + query: regexp.QuoteMeta(`SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLES WHERE (TABLE_SCHEMA = 'information_schema' OR TABLE_SCHEMA = 'INFORMATION_SCHEMA') AND TABLE_NAME = 'CHECK_CONSTRAINTS';`), args: nil, cols: []string{"count"}, rows: [][]driver.Value{ @@ -235,7 +238,7 @@ func TestProcessSchemaMYSQL(t *testing.T) { cols: []string{"INDEX_NAME", "COLUMN_NAME", "SEQ_IN_INDEX", "COLLATION", "NON_UNIQUE"}, }, { - query: `SELECT COUNT\(\*\) FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA = 'INFORMATION_SCHEMA' AND TABLE_NAME = 'CHECK_CONSTRAINTS';`, + query: regexp.QuoteMeta(`SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLES WHERE (TABLE_SCHEMA = 'information_schema' OR TABLE_SCHEMA = 'INFORMATION_SCHEMA') AND TABLE_NAME = 'CHECK_CONSTRAINTS';`), args: nil, cols: []string{"count"}, rows: [][]driver.Value{ @@ -416,7 +419,7 @@ func TestProcessData_MultiCol(t *testing.T) { rows: [][]driver.Value{{"test"}}, }, { - query: `SELECT COUNT\(\*\) FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA = 'INFORMATION_SCHEMA' AND TABLE_NAME = 'CHECK_CONSTRAINTS';`, + query: regexp.QuoteMeta(`SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLES WHERE (TABLE_SCHEMA = 'information_schema' OR TABLE_SCHEMA = 'INFORMATION_SCHEMA') AND TABLE_NAME = 'CHECK_CONSTRAINTS';`), args: nil, cols: []string{"count"}, rows: [][]driver.Value{ @@ -462,8 +465,17 @@ func TestProcessData_MultiCol(t *testing.T) { conv := internal.MakeConv() isi := InfoSchemaImpl{"test", db, "migration-project-id", profiles.SourceProfile{}, profiles.TargetProfile{}} processSchema := common.ProcessSchemaImpl{} + mockAccessor := new(mocks.MockExpressionVerificationAccessor) + ctx := context.Background() + mockAccessor.On("VerifyExpressions", ctx, 
mock.Anything).Return(internal.VerifyExpressionsOutput{ + ExpressionVerificationOutputList: []internal.ExpressionVerificationOutput{ + {Result: true, Err: nil, ExpressionDetail: internal.ExpressionDetail{Expression: "(col1 > 0)", Type: "CHECK", Metadata: map[string]string{"tableId": "t1", "colId": "c1", "checkConstraintName": "check1"}, ExpressionId: "expr1"}}, + }, + }) + schemaToSpanner := common.SchemaToSpannerImpl{ - DdlV: &expressions_api.MockDDLVerifier{}, + ExpressionVerificationAccessor: mockAccessor, + DdlV: &expressions_api.MockDDLVerifier{}, } err := processSchema.ProcessSchema(conv, isi, 1, internal.AdditionalSchemaAttributes{}, &schemaToSpanner, &common.UtilsOrderImpl{}, &common.InfoSchemaImpl{}) assert.Nil(t, err) @@ -523,7 +535,7 @@ func TestProcessSchema_Sharded(t *testing.T) { rows: [][]driver.Value{{"test"}}, }, { - query: `SELECT COUNT\(\*\) FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA = 'INFORMATION_SCHEMA' AND TABLE_NAME = 'CHECK_CONSTRAINTS';`, + query: regexp.QuoteMeta(`SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLES WHERE (TABLE_SCHEMA = 'information_schema' OR TABLE_SCHEMA = 'INFORMATION_SCHEMA') AND TABLE_NAME = 'CHECK_CONSTRAINTS';`), args: nil, cols: []string{"count"}, rows: [][]driver.Value{ @@ -568,9 +580,17 @@ func TestProcessSchema_Sharded(t *testing.T) { db := mkMockDB(t, ms) conv := internal.MakeConv() isi := InfoSchemaImpl{"test", db, "migration-project-id", profiles.SourceProfile{}, profiles.TargetProfile{}} + mockAccessor := new(mocks.MockExpressionVerificationAccessor) + ctx := context.Background() + mockAccessor.On("VerifyExpressions", ctx, mock.Anything).Return(internal.VerifyExpressionsOutput{ + ExpressionVerificationOutputList: []internal.ExpressionVerificationOutput{ + {Result: true, Err: nil, ExpressionDetail: internal.ExpressionDetail{Expression: "(col1 > 0)", Type: "CHECK", Metadata: map[string]string{"tableId": "t1", "colId": "c1", "checkConstraintName": "check1"}, ExpressionId: "expr1"}}, + }, + }) processSchema := common.ProcessSchemaImpl{} schemaToSpanner := common.SchemaToSpannerImpl{ - DdlV: &expressions_api.MockDDLVerifier{}, + ExpressionVerificationAccessor: mockAccessor, + DdlV: &expressions_api.MockDDLVerifier{}, } err := processSchema.ProcessSchema(conv, isi, 1, internal.AdditionalSchemaAttributes{IsSharded: true}, &schemaToSpanner, &common.UtilsOrderImpl{}, &common.InfoSchemaImpl{}) assert.Nil(t, err) @@ -639,7 +659,7 @@ func mkMockDB(t *testing.T, ms []mockSpec) *sql.DB { func TestGetConstraints_CheckConstraintsTableExists(t *testing.T) { ms := []mockSpec{ { - query: `SELECT COUNT\(\*\) FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA = 'INFORMATION_SCHEMA' AND TABLE_NAME = 'CHECK_CONSTRAINTS';`, + query: regexp.QuoteMeta(`SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLES WHERE (TABLE_SCHEMA = 'information_schema' OR TABLE_SCHEMA = 'INFORMATION_SCHEMA') AND TABLE_NAME = 'CHECK_CONSTRAINTS';`), cols: []string{"COUNT(*)"}, rows: [][]driver.Value{{1}}, }, @@ -675,7 +695,7 @@ func TestGetConstraints_CheckConstraintsTableExists(t *testing.T) { func TestGetConstraints_CheckConstraintsTableAbsent(t *testing.T) { ms := []mockSpec{ { - query: `SELECT COUNT\(\*\) FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA = 'INFORMATION_SCHEMA' AND TABLE_NAME = 'CHECK_CONSTRAINTS';`, + query: regexp.QuoteMeta(`SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLES WHERE (TABLE_SCHEMA = 'information_schema' OR TABLE_SCHEMA = 'INFORMATION_SCHEMA' ) AND TABLE_NAME = 'CHECK_CONSTRAINTS';`), cols: []string{"COUNT(*)"}, rows: [][]driver.Value{{0}}, }, diff --git 
a/sources/mysql/mysqldump.go b/sources/mysql/mysqldump.go index 05e473f1d..0fef98337 100644 --- a/sources/mysql/mysqldump.go +++ b/sources/mysql/mysqldump.go @@ -28,6 +28,7 @@ import ( "github.com/GoogleCloudPlatform/spanner-migration-tool/sources/common" "github.com/pingcap/tidb/parser" "github.com/pingcap/tidb/parser/ast" + "github.com/pingcap/tidb/parser/format" "github.com/pingcap/tidb/parser/opcode" "github.com/pingcap/tidb/types" driver "github.com/pingcap/tidb/types/parser_driver" @@ -36,7 +37,7 @@ import ( var valuesRegexp = regexp.MustCompile("\\((.*?)\\)") var insertRegexp = regexp.MustCompile("INSERT\\sINTO\\s(.*?)\\sVALUES\\s") var unsupportedRegexp = regexp.MustCompile("function|procedure|trigger") -var collationRegex = regexp.MustCompile(constants.DB_COLLATION_REGEX) +var dbcollationRegex = regexp.MustCompile("_[_A-Za-z0-9]+('([^']*)')") // MysqlSpatialDataTypes is an array of all MySQL spatial data types. var MysqlSpatialDataTypes = []string{"geometrycollection", "multipoint", "multilinestring", "multipolygon", "point", "linestring", "polygon", "geometry"} @@ -51,7 +52,8 @@ var spatialIndexRegex = regexp.MustCompile("(?i)\\sSPATIAL\\s") var spatialSridRegex = regexp.MustCompile("(?i)\\sSRID\\s\\d*") // DbDumpImpl MySQL specific implementation for DdlDumpImpl. -type DbDumpImpl struct{} +type DbDumpImpl struct { +} // GetToDdl function below implement the common.DbDump interface. func (ddi DbDumpImpl) GetToDdl() common.ToDdl { @@ -247,6 +249,9 @@ func processCreateTable(conv *internal.Conv, stmt *ast.CreateTableStmt) { var keys []schema.Key var fkeys []schema.ForeignKey var index []schema.Index + + checkConstraints := getCheckConstraints(stmt.Constraints) + for _, element := range stmt.Cols { _, col, constraint, err := processColumn(conv, tableName, element) if err != nil { @@ -285,14 +290,16 @@ func processCreateTable(conv *internal.Conv, stmt *ast.CreateTableStmt) { } conv.SchemaStatement(NodeType(stmt)) conv.SrcSchema[tableId] = schema.Table{ - Id: tableId, - Name: tableName, - ColIds: colIds, - ColNameIdMap: colNameIdMap, - ColDefs: colDef, - PrimaryKeys: keys, - ForeignKeys: fkeys, - Indexes: index} + Id: tableId, + Name: tableName, + ColIds: colIds, + ColNameIdMap: colNameIdMap, + ColDefs: colDef, + PrimaryKeys: keys, + ForeignKeys: fkeys, + Indexes: index, + CheckConstraints: checkConstraints, + } for _, constraint := range stmt.Constraints { processConstraint(conv, tableId, constraint, "CREATE TABLE", conv.SrcSchema[tableId].ColNameIdMap) } @@ -326,6 +333,35 @@ func processConstraint(conv *internal.Conv, tableId string, constraint *ast.Cons conv.SrcSchema[tableId] = st } +// method to get check constraints using tiDB parser +func getCheckConstraints(constraints []*ast.Constraint) (checkConstraints []schema.CheckConstraint) { + for _, constraint := range constraints { + if constraint.Tp == ast.ConstraintCheck { + exp := expressionToString(constraint.Expr) + exp = dbcollationRegex.ReplaceAllString(exp, "$1") + checkConstraint := schema.CheckConstraint{ + Name: constraint.Name, + Expr: exp, + ExprId: internal.GenerateExpressionId(), + Id: internal.GenerateCheckConstrainstId(), + } + checkConstraints = append(checkConstraints, checkConstraint) + } + } + return checkConstraints +} + +// converts an AST expression node to its string representation. 
+func expressionToString(expr ast.Node) string { + var sb strings.Builder + restoreCtx := format.NewRestoreCtx(format.DefaultRestoreFlags, &sb) + if err := expr.Restore(restoreCtx); err != nil { + fmt.Errorf("Error restoring expression: %v\n", err) + return "" + } + return sb.String() +} + // toSchemaKeys converts a string list of MySQL keys to schema keys. // Note that we map all MySQL keys to ascending ordered schema keys. // For primary keys: this is fine because MySQL primary keys are always ascending. diff --git a/sources/mysql/mysqldump_test.go b/sources/mysql/mysqldump_test.go index 0a2fffa53..7230e54a4 100644 --- a/sources/mysql/mysqldump_test.go +++ b/sources/mysql/mysqldump_test.go @@ -16,6 +16,7 @@ package mysql import ( "bufio" + "context" "fmt" "math/big" "math/bits" @@ -26,9 +27,11 @@ import ( "github.com/GoogleCloudPlatform/spanner-migration-tool/common/constants" "github.com/GoogleCloudPlatform/spanner-migration-tool/expressions_api" "github.com/GoogleCloudPlatform/spanner-migration-tool/internal" + "github.com/GoogleCloudPlatform/spanner-migration-tool/mocks" "github.com/GoogleCloudPlatform/spanner-migration-tool/sources/common" "github.com/GoogleCloudPlatform/spanner-migration-tool/spanner/ddl" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" ) func TestProcessMySQLDump_Scalar(t *testing.T) { @@ -929,14 +932,21 @@ func runProcessMySQLDump(s string) (*internal.Conv, []spannerData) { conv := internal.MakeConv() conv.SetLocation(time.UTC) conv.SetSchemaMode() + mockAccessor := new(mocks.MockExpressionVerificationAccessor) + ctx := context.Background() + mockAccessor.On("VerifyExpressions", ctx, mock.Anything).Return(internal.VerifyExpressionsOutput{ + ExpressionVerificationOutputList: []internal.ExpressionVerificationOutput{ + {Result: true, Err: nil, ExpressionDetail: internal.ExpressionDetail{Expression: "(col1 > 0)", Type: "CHECK", Metadata: map[string]string{"tableId": "t1", "colId": "c1", "checkConstraintName": "check1"}, ExpressionId: "expr1"}}, + }, + }) mysqlDbDump := DbDumpImpl{} - common.ProcessDbDump(conv, internal.NewReader(bufio.NewReader(strings.NewReader(s)), nil), mysqlDbDump, &expressions_api.MockDDLVerifier{}) + common.ProcessDbDump(conv, internal.NewReader(bufio.NewReader(strings.NewReader(s)), nil), mysqlDbDump, &expressions_api.MockDDLVerifier{}, mockAccessor) conv.SetDataMode() var rows []spannerData conv.SetDataSink(func(table string, cols []string, vals []interface{}) { rows = append(rows, spannerData{table: table, cols: cols, vals: vals}) }) - common.ProcessDbDump(conv, internal.NewReader(bufio.NewReader(strings.NewReader(s)), nil), mysqlDbDump, &expressions_api.MockDDLVerifier{}) + common.ProcessDbDump(conv, internal.NewReader(bufio.NewReader(strings.NewReader(s)), nil), mysqlDbDump, &expressions_api.MockDDLVerifier{}, mockAccessor) return conv, rows } diff --git a/sources/mysql/report_test.go b/sources/mysql/report_test.go index ac21f31cd..60c7a8633 100644 --- a/sources/mysql/report_test.go +++ b/sources/mysql/report_test.go @@ -17,6 +17,7 @@ package mysql import ( "bufio" "bytes" + "context" "encoding/json" "io/ioutil" "path/filepath" @@ -27,9 +28,11 @@ import ( "github.com/GoogleCloudPlatform/spanner-migration-tool/expressions_api" "github.com/GoogleCloudPlatform/spanner-migration-tool/internal" "github.com/GoogleCloudPlatform/spanner-migration-tool/internal/reports" + "github.com/GoogleCloudPlatform/spanner-migration-tool/mocks" "github.com/GoogleCloudPlatform/spanner-migration-tool/proto/migration" 
"github.com/GoogleCloudPlatform/spanner-migration-tool/sources/common" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" ) func TestReport(t *testing.T) { @@ -59,7 +62,14 @@ func TestReport(t *testing.T) { c text);` conv := internal.MakeConv() conv.SetSchemaMode() - common.ProcessDbDump(conv, internal.NewReader(bufio.NewReader(strings.NewReader(s)), nil), DbDumpImpl{}, &expressions_api.MockDDLVerifier{}) + mockAccessor := new(mocks.MockExpressionVerificationAccessor) + ctx := context.Background() + mockAccessor.On("VerifyExpressions", ctx, mock.Anything).Return(internal.VerifyExpressionsOutput{ + ExpressionVerificationOutputList: []internal.ExpressionVerificationOutput{ + {Result: true, Err: nil, ExpressionDetail: internal.ExpressionDetail{Expression: "(col1 > 0)", Type: "CHECK", Metadata: map[string]string{"tableId": "t1", "colId": "c1", "checkConstraintName": "check1"}, ExpressionId: "expr1"}}, + }, + }) + common.ProcessDbDump(conv, internal.NewReader(bufio.NewReader(strings.NewReader(s)), nil), DbDumpImpl{}, &expressions_api.MockDDLVerifier{}, mockAccessor) conv.SetDataMode() badSchemaTableId, err := internal.GetTableIdFromSpName(conv.SpSchema, "bad_schema") diff --git a/sources/mysql/toddl_test.go b/sources/mysql/toddl_test.go index 939463765..713efb0f4 100644 --- a/sources/mysql/toddl_test.go +++ b/sources/mysql/toddl_test.go @@ -15,15 +15,18 @@ package mysql import ( + "context" "testing" "github.com/GoogleCloudPlatform/spanner-migration-tool/common/constants" "github.com/GoogleCloudPlatform/spanner-migration-tool/expressions_api" "github.com/GoogleCloudPlatform/spanner-migration-tool/internal" + "github.com/GoogleCloudPlatform/spanner-migration-tool/mocks" "github.com/GoogleCloudPlatform/spanner-migration-tool/schema" "github.com/GoogleCloudPlatform/spanner-migration-tool/sources/common" "github.com/GoogleCloudPlatform/spanner-migration-tool/spanner/ddl" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" ) func TestToSpannerTypeInternal(t *testing.T) { @@ -178,8 +181,16 @@ func TestToSpannerType(t *testing.T) { PrimaryKeys: []schema.Key{{ColId: "c14"}}, } conv.UsedNames = map[string]bool{"ref_table": true, "ref_table2": true} + mockAccessor := new(mocks.MockExpressionVerificationAccessor) + ctx := context.Background() + mockAccessor.On("VerifyExpressions", ctx, mock.Anything).Return(internal.VerifyExpressionsOutput{ + ExpressionVerificationOutputList: []internal.ExpressionVerificationOutput{ + {Result: true, Err: nil, ExpressionDetail: internal.ExpressionDetail{Expression: "(col1 > 0)", Type: "CHECK", Metadata: map[string]string{"tableId": "t1", "colId": "c1", "checkConstraintName": "check1"}, ExpressionId: "expr1"}}, + }, + }) schemaToSpanner := common.SchemaToSpannerImpl{ - DdlV: &expressions_api.MockDDLVerifier{}, + ExpressionVerificationAccessor: mockAccessor, + DdlV: &expressions_api.MockDDLVerifier{}, } assert.Nil(t, schemaToSpanner.SchemaToSpannerDDL(conv, ToDdlImpl{})) actual := conv.SpSchema[tableId] @@ -274,8 +285,16 @@ func TestToSpannerPostgreSQLDialectType(t *testing.T) { PrimaryKeys: []schema.Key{{ColId: "c14"}}, } conv.UsedNames = map[string]bool{"ref_table": true, "ref_table2": true} + mockAccessor := new(mocks.MockExpressionVerificationAccessor) + ctx := context.Background() + mockAccessor.On("VerifyExpressions", ctx, mock.Anything).Return(internal.VerifyExpressionsOutput{ + ExpressionVerificationOutputList: []internal.ExpressionVerificationOutput{ + {Result: true, Err: nil, ExpressionDetail: 
internal.ExpressionDetail{Expression: "(col1 > 0)", Type: "CHECK", Metadata: map[string]string{"tableId": "t1", "colId": "c1", "checkConstraintName": "check1"}, ExpressionId: "expr1"}}, + }, + }) schemaToSpanner := common.SchemaToSpannerImpl{ - DdlV: &expressions_api.MockDDLVerifier{}, + ExpressionVerificationAccessor: mockAccessor, + DdlV: &expressions_api.MockDDLVerifier{}, } assert.Nil(t, schemaToSpanner.SchemaToSpannerDDL(conv, ToDdlImpl{})) actual := conv.SpSchema[tableId] diff --git a/sources/oracle/infoschema_test.go b/sources/oracle/infoschema_test.go index ee26b6a33..6a791387b 100644 --- a/sources/oracle/infoschema_test.go +++ b/sources/oracle/infoschema_test.go @@ -15,6 +15,7 @@ package oracle import ( + "context" "database/sql" "database/sql/driver" "fmt" @@ -22,9 +23,11 @@ import ( "github.com/DATA-DOG/go-sqlmock" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" "github.com/GoogleCloudPlatform/spanner-migration-tool/expressions_api" "github.com/GoogleCloudPlatform/spanner-migration-tool/internal" + "github.com/GoogleCloudPlatform/spanner-migration-tool/mocks" "github.com/GoogleCloudPlatform/spanner-migration-tool/profiles" "github.com/GoogleCloudPlatform/spanner-migration-tool/sources/common" "github.com/GoogleCloudPlatform/spanner-migration-tool/spanner/ddl" @@ -163,8 +166,18 @@ func TestProcessSchemaOracle(t *testing.T) { db := mkMockDB(t, ms) conv := internal.MakeConv() processSchema := common.ProcessSchemaImpl{} + mockAccessor := new(mocks.MockExpressionVerificationAccessor) + + ctx := context.Background() + mockAccessor.On("VerifyExpressions", ctx, mock.Anything).Return(internal.VerifyExpressionsOutput{ + ExpressionVerificationOutputList: []internal.ExpressionVerificationOutput{ + {Result: true, Err: nil, ExpressionDetail: internal.ExpressionDetail{Expression: "(col1 > 0)", Type: "CHECK", Metadata: map[string]string{"tableId": "t1", "colId": "c1", "checkConstraintName": "check1"}, ExpressionId: "expr1"}}, + }, + }) + schemaToSpanner := common.SchemaToSpannerImpl{ - DdlV: &expressions_api.MockDDLVerifier{}, + ExpressionVerificationAccessor: mockAccessor, + DdlV: &expressions_api.MockDDLVerifier{}, } err := processSchema.ProcessSchema(conv, InfoSchemaImpl{"test", db, "migration-project-id", profiles.SourceProfile{}, profiles.TargetProfile{}}, 1, internal.AdditionalSchemaAttributes{}, &schemaToSpanner, &common.UtilsOrderImpl{}, &common.InfoSchemaImpl{}) assert.Nil(t, err) diff --git a/sources/oracle/toddl_test.go b/sources/oracle/toddl_test.go index a7876ef92..07391a7ea 100644 --- a/sources/oracle/toddl_test.go +++ b/sources/oracle/toddl_test.go @@ -15,16 +15,19 @@ package oracle import ( + "context" "testing" "github.com/GoogleCloudPlatform/spanner-migration-tool/common/constants" "github.com/GoogleCloudPlatform/spanner-migration-tool/expressions_api" "github.com/GoogleCloudPlatform/spanner-migration-tool/internal" "github.com/GoogleCloudPlatform/spanner-migration-tool/logger" + "github.com/GoogleCloudPlatform/spanner-migration-tool/mocks" "github.com/GoogleCloudPlatform/spanner-migration-tool/schema" "github.com/GoogleCloudPlatform/spanner-migration-tool/sources/common" "github.com/GoogleCloudPlatform/spanner-migration-tool/spanner/ddl" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" "go.uber.org/zap" ) @@ -168,8 +171,16 @@ func TestToSpannerType(t *testing.T) { PrimaryKeys: []schema.Key{{ColId: "c12"}}, } conv.UsedNames = map[string]bool{"ref_table": true, "ref_table2": true} + mockAccessor := 
new(mocks.MockExpressionVerificationAccessor) + ctx := context.Background() + mockAccessor.On("VerifyExpressions", ctx, mock.Anything).Return(internal.VerifyExpressionsOutput{ + ExpressionVerificationOutputList: []internal.ExpressionVerificationOutput{ + {Result: true, Err: nil, ExpressionDetail: internal.ExpressionDetail{Expression: "(col1 > 0)", Type: "CHECK", Metadata: map[string]string{"tableId": "t1", "colId": "c1", "checkConstraintName": "check1"}, ExpressionId: "expr1"}}, + }, + }) schemaToSpanner := common.SchemaToSpannerImpl{ - DdlV: &expressions_api.MockDDLVerifier{}, + ExpressionVerificationAccessor: mockAccessor, + DdlV: &expressions_api.MockDDLVerifier{}, } assert.Nil(t, schemaToSpanner.SchemaToSpannerDDL(conv, ToDdlImpl{})) actual := conv.SpSchema[tableId] @@ -263,8 +274,16 @@ func TestToSpannerPostgreSQLDialectType(t *testing.T) { PrimaryKeys: []schema.Key{{ColId: "c15"}}, } conv.UsedNames = map[string]bool{"ref_table": true, "ref_table2": true} + mockAccessor := new(mocks.MockExpressionVerificationAccessor) + ctx := context.Background() + mockAccessor.On("VerifyExpressions", ctx, mock.Anything).Return(internal.VerifyExpressionsOutput{ + ExpressionVerificationOutputList: []internal.ExpressionVerificationOutput{ + {Result: true, Err: nil, ExpressionDetail: internal.ExpressionDetail{Expression: "(col1 > 0)", Type: "CHECK", Metadata: map[string]string{"tableId": "t1", "colId": "c1", "checkConstraintName": "check1"}, ExpressionId: "expr1"}}, + }, + }) schemaToSpanner := common.SchemaToSpannerImpl{ - DdlV: &expressions_api.MockDDLVerifier{}, + ExpressionVerificationAccessor: mockAccessor, + DdlV: &expressions_api.MockDDLVerifier{}, } assert.Nil(t, schemaToSpanner.SchemaToSpannerDDL(conv, ToDdlImpl{})) actual := conv.SpSchema[tableId] diff --git a/sources/postgres/infoschema_test.go b/sources/postgres/infoschema_test.go index c3d1c0f0b..7a808a21c 100644 --- a/sources/postgres/infoschema_test.go +++ b/sources/postgres/infoschema_test.go @@ -15,6 +15,7 @@ package postgres import ( + "context" "database/sql" "database/sql/driver" "math/big" @@ -27,11 +28,13 @@ import ( "github.com/GoogleCloudPlatform/spanner-migration-tool/expressions_api" "github.com/GoogleCloudPlatform/spanner-migration-tool/internal" "github.com/GoogleCloudPlatform/spanner-migration-tool/logger" + "github.com/GoogleCloudPlatform/spanner-migration-tool/mocks" "github.com/GoogleCloudPlatform/spanner-migration-tool/profiles" "github.com/GoogleCloudPlatform/spanner-migration-tool/schema" "github.com/GoogleCloudPlatform/spanner-migration-tool/sources/common" "github.com/GoogleCloudPlatform/spanner-migration-tool/spanner/ddl" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" "go.uber.org/zap" ) @@ -231,9 +234,17 @@ func TestProcessSchema(t *testing.T) { } db := mkMockDB(t, ms) conv := internal.MakeConv() + mockAccessor := new(mocks.MockExpressionVerificationAccessor) + ctx := context.Background() + mockAccessor.On("VerifyExpressions", ctx, mock.Anything).Return(internal.VerifyExpressionsOutput{ + ExpressionVerificationOutputList: []internal.ExpressionVerificationOutput{ + {Result: true, Err: nil, ExpressionDetail: internal.ExpressionDetail{Expression: "(col1 > 0)", Type: "CHECK", Metadata: map[string]string{"tableId": "t1", "colId": "c1", "checkConstraintName": "check1"}, ExpressionId: "expr1"}}, + }, + }) processSchema := common.ProcessSchemaImpl{} schemaToSpanner := common.SchemaToSpannerImpl{ - DdlV: &expressions_api.MockDDLVerifier{}, + ExpressionVerificationAccessor: mockAccessor, + DdlV: 
&expressions_api.MockDDLVerifier{}, } err := processSchema.ProcessSchema(conv, InfoSchemaImpl{db, "migration-project-id", profiles.SourceProfile{}, profiles.TargetProfile{}, newFalsePtr()}, 1, internal.AdditionalSchemaAttributes{}, &schemaToSpanner, &common.UtilsOrderImpl{}, &common.InfoSchemaImpl{}) assert.Nil(t, err) @@ -519,9 +530,17 @@ func TestConvertSqlRow_MultiCol(t *testing.T) { } db := mkMockDB(t, ms) conv := internal.MakeConv() + mockAccessor := new(mocks.MockExpressionVerificationAccessor) + ctx := context.Background() + mockAccessor.On("VerifyExpressions", ctx, mock.Anything).Return(internal.VerifyExpressionsOutput{ + ExpressionVerificationOutputList: []internal.ExpressionVerificationOutput{ + {Result: true, Err: nil, ExpressionDetail: internal.ExpressionDetail{Expression: "(col1 > 0)", Type: "CHECK", Metadata: map[string]string{"tableId": "t1", "colId": "c1", "checkConstraintName": "check1"}, ExpressionId: "expr1"}}, + }, + }) processSchema := common.ProcessSchemaImpl{} schemaToSpanner := common.SchemaToSpannerImpl{ - DdlV: &expressions_api.MockDDLVerifier{}, + ExpressionVerificationAccessor: mockAccessor, + DdlV: &expressions_api.MockDDLVerifier{}, } err := processSchema.ProcessSchema(conv, InfoSchemaImpl{db, "migration-project-id", profiles.SourceProfile{}, profiles.TargetProfile{}, newFalsePtr()}, 1, internal.AdditionalSchemaAttributes{}, &schemaToSpanner, &common.UtilsOrderImpl{}, &common.InfoSchemaImpl{}) assert.Nil(t, err) diff --git a/sources/postgres/pgdump.go b/sources/postgres/pgdump.go index 0cc9347ec..25eef201a 100644 --- a/sources/postgres/pgdump.go +++ b/sources/postgres/pgdump.go @@ -31,7 +31,8 @@ import ( ) // DbDumpImpl Postgres specific implementation for DdlDumpImpl. -type DbDumpImpl struct{} +type DbDumpImpl struct { +} type copyOrInsert struct { stmt stmtType diff --git a/sources/postgres/pgdump_test.go b/sources/postgres/pgdump_test.go index cf8823fd9..33df0afef 100644 --- a/sources/postgres/pgdump_test.go +++ b/sources/postgres/pgdump_test.go @@ -16,6 +16,7 @@ package postgres import ( "bufio" + "context" "fmt" "math/big" "math/bits" @@ -28,10 +29,12 @@ import ( "github.com/GoogleCloudPlatform/spanner-migration-tool/common/constants" "github.com/GoogleCloudPlatform/spanner-migration-tool/expressions_api" "github.com/GoogleCloudPlatform/spanner-migration-tool/internal" + "github.com/GoogleCloudPlatform/spanner-migration-tool/mocks" "github.com/GoogleCloudPlatform/spanner-migration-tool/sources/common" "github.com/GoogleCloudPlatform/spanner-migration-tool/spanner/ddl" pg_query "github.com/pganalyze/pg_query_go/v5" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" ) type spannerData struct { @@ -1512,7 +1515,14 @@ func TestProcessPgDump_WithUnparsableContent(t *testing.T) { conv := internal.MakeConv() conv.SetLocation(time.UTC) conv.SetSchemaMode() - err := common.ProcessDbDump(conv, internal.NewReader(bufio.NewReader(strings.NewReader(s)), nil), DbDumpImpl{}, &expressions_api.MockDDLVerifier{}) + mockAccessor := new(mocks.MockExpressionVerificationAccessor) + ctx := context.Background() + mockAccessor.On("VerifyExpressions", ctx, mock.Anything).Return(internal.VerifyExpressionsOutput{ + ExpressionVerificationOutputList: []internal.ExpressionVerificationOutput{ + {Result: true, Err: nil, ExpressionDetail: internal.ExpressionDetail{Expression: "(col1 > 0)", Type: "CHECK", Metadata: map[string]string{"tableId": "t1", "colId": "c1", "checkConstraintName": "check1"}, ExpressionId: "expr1"}}, + }, + }) + err := common.ProcessDbDump(conv, 
internal.NewReader(bufio.NewReader(strings.NewReader(s)), nil), DbDumpImpl{}, &expressions_api.MockDDLVerifier{}, mockAccessor) if err == nil { t.Fatalf("Expect an error, but got nil") } @@ -1525,15 +1535,22 @@ func runProcessPgDump(s string) (*internal.Conv, []spannerData) { conv := internal.MakeConv() conv.SetLocation(time.UTC) conv.SetSchemaMode() + mockAccessor := new(mocks.MockExpressionVerificationAccessor) + ctx := context.Background() + mockAccessor.On("VerifyExpressions", ctx, mock.Anything).Return(internal.VerifyExpressionsOutput{ + ExpressionVerificationOutputList: []internal.ExpressionVerificationOutput{ + {Result: true, Err: nil, ExpressionDetail: internal.ExpressionDetail{Expression: "(col1 > 0)", Type: "CHECK", Metadata: map[string]string{"tableId": "t1", "colId": "c1", "checkConstraintName": "check1"}, ExpressionId: "expr1"}}, + }, + }) pgDump := DbDumpImpl{} - common.ProcessDbDump(conv, internal.NewReader(bufio.NewReader(strings.NewReader(s)), nil), pgDump, &expressions_api.MockDDLVerifier{}) + common.ProcessDbDump(conv, internal.NewReader(bufio.NewReader(strings.NewReader(s)), nil), pgDump, &expressions_api.MockDDLVerifier{}, mockAccessor) conv.SetDataMode() var rows []spannerData conv.SetDataSink( func(table string, cols []string, vals []interface{}) { rows = append(rows, spannerData{table: table, cols: cols, vals: vals}) }) - common.ProcessDbDump(conv, internal.NewReader(bufio.NewReader(strings.NewReader(s)), nil), pgDump, &expressions_api.MockDDLVerifier{}) + common.ProcessDbDump(conv, internal.NewReader(bufio.NewReader(strings.NewReader(s)), nil), pgDump, &expressions_api.MockDDLVerifier{}, mockAccessor) return conv, rows } @@ -1542,15 +1559,22 @@ func runProcessPgDumpPGTarget(s string) (*internal.Conv, []spannerData) { conv.SpDialect = constants.DIALECT_POSTGRESQL conv.SetLocation(time.UTC) conv.SetSchemaMode() + mockAccessor := new(mocks.MockExpressionVerificationAccessor) + ctx := context.Background() + mockAccessor.On("VerifyExpressions", ctx, mock.Anything).Return(internal.VerifyExpressionsOutput{ + ExpressionVerificationOutputList: []internal.ExpressionVerificationOutput{ + {Result: true, Err: nil, ExpressionDetail: internal.ExpressionDetail{Expression: "(col1 > 0)", Type: "CHECK", Metadata: map[string]string{"tableId": "t1", "colId": "c1", "checkConstraintName": "check1"}, ExpressionId: "expr1"}}, + }, + }) pgDump := DbDumpImpl{} - common.ProcessDbDump(conv, internal.NewReader(bufio.NewReader(strings.NewReader(s)), nil), pgDump, &expressions_api.MockDDLVerifier{}) + common.ProcessDbDump(conv, internal.NewReader(bufio.NewReader(strings.NewReader(s)), nil), pgDump, &expressions_api.MockDDLVerifier{}, mockAccessor) conv.SetDataMode() var rows []spannerData conv.SetDataSink( func(table string, cols []string, vals []interface{}) { rows = append(rows, spannerData{table: table, cols: cols, vals: vals}) }) - common.ProcessDbDump(conv, internal.NewReader(bufio.NewReader(strings.NewReader(s)), nil), pgDump, &expressions_api.MockDDLVerifier{}) + common.ProcessDbDump(conv, internal.NewReader(bufio.NewReader(strings.NewReader(s)), nil), pgDump, &expressions_api.MockDDLVerifier{}, mockAccessor) return conv, rows } diff --git a/sources/postgres/report_test.go b/sources/postgres/report_test.go index 4892a5a51..79a0ca52d 100644 --- a/sources/postgres/report_test.go +++ b/sources/postgres/report_test.go @@ -17,6 +17,7 @@ package postgres import ( "bufio" "bytes" + "context" "encoding/json" "io/ioutil" "path/filepath" @@ -27,9 +28,11 @@ import ( 
"github.com/GoogleCloudPlatform/spanner-migration-tool/expressions_api" "github.com/GoogleCloudPlatform/spanner-migration-tool/internal" "github.com/GoogleCloudPlatform/spanner-migration-tool/internal/reports" + "github.com/GoogleCloudPlatform/spanner-migration-tool/mocks" "github.com/GoogleCloudPlatform/spanner-migration-tool/proto/migration" "github.com/GoogleCloudPlatform/spanner-migration-tool/sources/common" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" ) func TestReport(t *testing.T) { @@ -54,7 +57,14 @@ func TestReport(t *testing.T) { c text);` conv := internal.MakeConv() conv.SetSchemaMode() - common.ProcessDbDump(conv, internal.NewReader(bufio.NewReader(strings.NewReader(s)), nil), DbDumpImpl{}, &expressions_api.MockDDLVerifier{}) + mockAccessor := new(mocks.MockExpressionVerificationAccessor) + ctx := context.Background() + mockAccessor.On("VerifyExpressions", ctx, mock.Anything).Return(internal.VerifyExpressionsOutput{ + ExpressionVerificationOutputList: []internal.ExpressionVerificationOutput{ + {Result: true, Err: nil, ExpressionDetail: internal.ExpressionDetail{Expression: "(col1 > 0)", Type: "CHECK", Metadata: map[string]string{"tableId": "t1", "colId": "c1", "checkConstraintName": "check1"}, ExpressionId: "expr1"}}, + }, + }) + common.ProcessDbDump(conv, internal.NewReader(bufio.NewReader(strings.NewReader(s)), nil), DbDumpImpl{}, &expressions_api.MockDDLVerifier{}, mockAccessor) conv.SetDataMode() badSchemaTableId, err := internal.GetTableIdFromSpName(conv.SpSchema, "bad_schema") diff --git a/sources/postgres/toddl_test.go b/sources/postgres/toddl_test.go index b995dbdc2..5174c9dfe 100644 --- a/sources/postgres/toddl_test.go +++ b/sources/postgres/toddl_test.go @@ -15,16 +15,19 @@ package postgres import ( + "context" "sort" "testing" "github.com/GoogleCloudPlatform/spanner-migration-tool/common/constants" "github.com/GoogleCloudPlatform/spanner-migration-tool/expressions_api" "github.com/GoogleCloudPlatform/spanner-migration-tool/internal" + "github.com/GoogleCloudPlatform/spanner-migration-tool/mocks" "github.com/GoogleCloudPlatform/spanner-migration-tool/schema" "github.com/GoogleCloudPlatform/spanner-migration-tool/sources/common" "github.com/GoogleCloudPlatform/spanner-migration-tool/spanner/ddl" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" ) func TestToSpannerTypeInternal(t *testing.T) { @@ -169,8 +172,16 @@ func TestToSpannerType(t *testing.T) { PrimaryKeys: []schema.Key{{ColId: "c10"}}, } conv.UsedNames = map[string]bool{"ref_table": true, "ref_table2": true} + mockAccessor := new(mocks.MockExpressionVerificationAccessor) + ctx := context.Background() + mockAccessor.On("VerifyExpressions", ctx, mock.Anything).Return(internal.VerifyExpressionsOutput{ + ExpressionVerificationOutputList: []internal.ExpressionVerificationOutput{ + {Result: true, Err: nil, ExpressionDetail: internal.ExpressionDetail{Expression: "(col1 > 0)", Type: "CHECK", Metadata: map[string]string{"tableId": "t1", "colId": "c1", "checkConstraintName": "check1"}, ExpressionId: "expr1"}}, + }, + }) schemaToSpanner := common.SchemaToSpannerImpl{ - DdlV: &expressions_api.MockDDLVerifier{}, + DdlV: &expressions_api.MockDDLVerifier{}, + ExpressionVerificationAccessor: mockAccessor, } assert.Nil(t, schemaToSpanner.SchemaToSpannerDDL(conv, ToDdlImpl{})) actual := conv.SpSchema[tableId] @@ -260,8 +271,16 @@ func TestToExperimentalSpannerType(t *testing.T) { PrimaryKeys: []schema.Key{{ColId: "c10"}}, } conv.UsedNames = map[string]bool{"ref_table": true, 
"ref_table2": true} + mockAccessor := new(mocks.MockExpressionVerificationAccessor) + ctx := context.Background() + mockAccessor.On("VerifyExpressions", ctx, mock.Anything).Return(internal.VerifyExpressionsOutput{ + ExpressionVerificationOutputList: []internal.ExpressionVerificationOutput{ + {Result: true, Err: nil, ExpressionDetail: internal.ExpressionDetail{Expression: "(col1 > 0)", Type: "CHECK", Metadata: map[string]string{"tableId": "t1", "colId": "c1", "checkConstraintName": "check1"}, ExpressionId: "expr1"}}, + }, + }) schemaToSpanner := common.SchemaToSpannerImpl{ - DdlV: &expressions_api.MockDDLVerifier{}, + ExpressionVerificationAccessor: mockAccessor, + DdlV: &expressions_api.MockDDLVerifier{}, } assert.Nil(t, schemaToSpanner.SchemaToSpannerDDL(conv, ToDdlImpl{})) actual := conv.SpSchema[tableId] diff --git a/sources/sqlserver/infoschema_test.go b/sources/sqlserver/infoschema_test.go index 11fc7155d..fcf9cad98 100644 --- a/sources/sqlserver/infoschema_test.go +++ b/sources/sqlserver/infoschema_test.go @@ -15,6 +15,7 @@ package sqlserver import ( + "context" "database/sql" "database/sql/driver" "testing" @@ -23,9 +24,11 @@ import ( "github.com/GoogleCloudPlatform/spanner-migration-tool/expressions_api" "github.com/GoogleCloudPlatform/spanner-migration-tool/internal" "github.com/GoogleCloudPlatform/spanner-migration-tool/logger" + "github.com/GoogleCloudPlatform/spanner-migration-tool/mocks" "github.com/GoogleCloudPlatform/spanner-migration-tool/sources/common" "github.com/GoogleCloudPlatform/spanner-migration-tool/spanner/ddl" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" "go.uber.org/zap" ) @@ -249,9 +252,18 @@ func TestProcessSchema(t *testing.T) { } db := mkMockDB(t, ms) conv := internal.MakeConv() + mockAccessor := new(mocks.MockExpressionVerificationAccessor) + processSchema := common.ProcessSchemaImpl{} + ctx := context.Background() + mockAccessor.On("VerifyExpressions", ctx, mock.Anything).Return(internal.VerifyExpressionsOutput{ + ExpressionVerificationOutputList: []internal.ExpressionVerificationOutput{ + {Result: true, Err: nil, ExpressionDetail: internal.ExpressionDetail{Expression: "(col1 > 0)", Type: "CHECK", Metadata: map[string]string{"tableId": "t1", "colId": "c1", "checkConstraintName": "check1"}, ExpressionId: "expr1"}}, + }, + }) schemaToSpanner := common.SchemaToSpannerImpl{ - DdlV: &expressions_api.MockDDLVerifier{}, + ExpressionVerificationAccessor: mockAccessor, + DdlV: &expressions_api.MockDDLVerifier{}, } err := processSchema.ProcessSchema(conv, InfoSchemaImpl{"test", db}, 1, internal.AdditionalSchemaAttributes{}, &schemaToSpanner, &common.UtilsOrderImpl{}, &common.InfoSchemaImpl{}) assert.Nil(t, err) diff --git a/sources/sqlserver/toddl_test.go b/sources/sqlserver/toddl_test.go index d90698e78..2dee591eb 100644 --- a/sources/sqlserver/toddl_test.go +++ b/sources/sqlserver/toddl_test.go @@ -15,15 +15,18 @@ package sqlserver import ( + "context" "testing" "github.com/GoogleCloudPlatform/spanner-migration-tool/common/constants" "github.com/GoogleCloudPlatform/spanner-migration-tool/expressions_api" "github.com/GoogleCloudPlatform/spanner-migration-tool/internal" + "github.com/GoogleCloudPlatform/spanner-migration-tool/mocks" "github.com/GoogleCloudPlatform/spanner-migration-tool/schema" "github.com/GoogleCloudPlatform/spanner-migration-tool/sources/common" "github.com/GoogleCloudPlatform/spanner-migration-tool/spanner/ddl" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" ) func TestToSpannerTypeInternal(t 
*testing.T) { @@ -174,8 +177,16 @@ func TestToSpannerType(t *testing.T) { PrimaryKeys: []schema.Key{{ColId: "c19"}}, } conv.UsedNames = map[string]bool{"ref_table": true, "ref_table2": true} + mockAccessor := new(mocks.MockExpressionVerificationAccessor) + ctx := context.Background() + mockAccessor.On("VerifyExpressions", ctx, mock.Anything).Return(internal.VerifyExpressionsOutput{ + ExpressionVerificationOutputList: []internal.ExpressionVerificationOutput{ + {Result: true, Err: nil, ExpressionDetail: internal.ExpressionDetail{Expression: "(col1 > 0)", Type: "CHECK", Metadata: map[string]string{"tableId": "t1", "colId": "c1", "checkConstraintName": "check1"}, ExpressionId: "expr1"}}, + }, + }) schemaToSpanner := common.SchemaToSpannerImpl{ - DdlV: &expressions_api.MockDDLVerifier{}, + ExpressionVerificationAccessor: mockAccessor, + DdlV: &expressions_api.MockDDLVerifier{}, } assert.Nil(t, schemaToSpanner.SchemaToSpannerDDL(conv, ToDdlImpl{})) actual := conv.SpSchema[tableId] @@ -280,8 +291,16 @@ func TestToSpannerPostgreSQLDialectType(t *testing.T) { PrimaryKeys: []schema.Key{{ColId: "c19"}}, } conv.UsedNames = map[string]bool{"ref_table": true, "ref_table2": true} + mockAccessor := new(mocks.MockExpressionVerificationAccessor) + ctx := context.Background() + mockAccessor.On("VerifyExpressions", ctx, mock.Anything).Return(internal.VerifyExpressionsOutput{ + ExpressionVerificationOutputList: []internal.ExpressionVerificationOutput{ + {Result: true, Err: nil, ExpressionDetail: internal.ExpressionDetail{Expression: "(col1 > 0)", Type: "CHECK", Metadata: map[string]string{"tableId": "t1", "colId": "c1", "checkConstraintName": "check1"}, ExpressionId: "expr1"}}, + }, + }) schemaToSpanner := common.SchemaToSpannerImpl{ - DdlV: &expressions_api.MockDDLVerifier{}, + ExpressionVerificationAccessor: mockAccessor, + DdlV: &expressions_api.MockDDLVerifier{}, } assert.Nil(t, schemaToSpanner.SchemaToSpannerDDL(conv, ToDdlImpl{})) actual := conv.SpSchema[tableId] diff --git a/spanner/ddl/ast.go b/spanner/ddl/ast.go index 9b260371c..db2cf02cd 100644 --- a/spanner/ddl/ast.go +++ b/spanner/ddl/ast.go @@ -268,9 +268,10 @@ type IndexKey struct { } type CheckConstraint struct { - Id string - Name string - Expr string + Id string + Name string + Expr string + ExprId string } // PrintPkOrIndexKey unparses the primary or index keys. 
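Note on the test changes above: every source-specific test file repeats the same wiring, stubbing MockExpressionVerificationAccessor to report each CHECK expression as valid and handing it to SchemaToSpannerImpl alongside the existing MockDDLVerifier. A condensed sketch of that pattern follows; the helper name newVerifyingSchemaToSpanner and the package name are hypothetical and shown only for illustration.

package sketch // hypothetical package, not part of the patch

import (
	"context"

	"github.com/GoogleCloudPlatform/spanner-migration-tool/expressions_api"
	"github.com/GoogleCloudPlatform/spanner-migration-tool/internal"
	"github.com/GoogleCloudPlatform/spanner-migration-tool/mocks"
	"github.com/GoogleCloudPlatform/spanner-migration-tool/sources/common"
	"github.com/stretchr/testify/mock"
)

// newVerifyingSchemaToSpanner (hypothetical helper) mirrors the test setup above:
// the mock accessor answers VerifyExpressions with a single successful result,
// so schema conversion proceeds as if Spanner accepted every CHECK expression.
func newVerifyingSchemaToSpanner() common.SchemaToSpannerImpl {
	mockAccessor := new(mocks.MockExpressionVerificationAccessor)
	mockAccessor.On("VerifyExpressions", context.Background(), mock.Anything).Return(internal.VerifyExpressionsOutput{
		ExpressionVerificationOutputList: []internal.ExpressionVerificationOutput{
			{Result: true, ExpressionDetail: internal.ExpressionDetail{
				Expression:   "(col1 > 0)",
				Type:         "CHECK",
				Metadata:     map[string]string{"tableId": "t1", "colId": "c1", "checkConstraintName": "check1"},
				ExpressionId: "expr1",
			}},
		},
	})
	return common.SchemaToSpannerImpl{
		DdlV:                           &expressions_api.MockDDLVerifier{},
		ExpressionVerificationAccessor: mockAccessor,
	}
}

The same mock is also what the dump-based tests pass as the new final argument to common.ProcessDbDump.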
diff --git a/testing/csv/integration_test.go b/testing/csv/integration_test.go index 662837f06..372da696b 100644 --- a/testing/csv/integration_test.go +++ b/testing/csv/integration_test.go @@ -172,7 +172,7 @@ func TestIntegration_CSV_Command(t *testing.T) { defer dropDatabase(t, dbURI) createSpannerSchema(t, projectID, instanceID, dbName) - args := fmt.Sprintf("data -source=csv -source-profile=manifest=%s -target-profile='instance=%s,dbName=%s'", MANIFEST_FILE_NAME, instanceID, dbName) + args := fmt.Sprintf("data -source=csv -source-profile=manifest=%s -target-profile='instance=%s,dbName=%s,project=%s'", MANIFEST_FILE_NAME, instanceID, dbName, projectID) err := common.RunCommand(args, projectID) if err != nil { t.Fatal(err) diff --git a/testing/dynamodb/snapshot/integration_test.go b/testing/dynamodb/snapshot/integration_test.go index 936a78678..26544e964 100644 --- a/testing/dynamodb/snapshot/integration_test.go +++ b/testing/dynamodb/snapshot/integration_test.go @@ -220,7 +220,7 @@ func TestIntegration_DYNAMODB_Command(t *testing.T) { dbURI := fmt.Sprintf("projects/%s/instances/%s/databases/%s", projectID, instanceID, dbName) filePrefix := filepath.Join(tmpdir, dbName) - args := fmt.Sprintf(`schema-and-data -source=%s -prefix=%s -target-profile="instance=%s,dbName=%s"`, constants.DYNAMODB, filePrefix, instanceID, dbName) + args := fmt.Sprintf(`schema-and-data -source=%s -prefix=%s -target-profile="instance=%s,dbName=%s,project=%s"`, constants.DYNAMODB, filePrefix, instanceID, dbName, projectID) err := common.RunCommand(args, projectID) if err != nil { t.Fatal(err) diff --git a/testing/dynamodb/streaming/integration_test.go b/testing/dynamodb/streaming/integration_test.go index 9c44ade5f..2d0519d4d 100644 --- a/testing/dynamodb/streaming/integration_test.go +++ b/testing/dynamodb/streaming/integration_test.go @@ -334,7 +334,7 @@ func TestIntegration_DYNAMODB_Streaming_Command(t *testing.T) { } defer client.Close() - args := fmt.Sprintf(`schema-and-data -source=%s -prefix=%s -source-profile="enableStreaming=true" -target-profile="instance=%s,dbName=%s"`, constants.DYNAMODB, filePrefix, instanceID, dbName) + args := fmt.Sprintf(`schema-and-data -source=%s -prefix=%s -source-profile="enableStreaming=true" -target-profile="instance=%s,dbName=%s,project=%s"`, constants.DYNAMODB, filePrefix, instanceID, dbName, projectID) err = RunStreamingMigration(t, args, projectID, client) if err != nil { t.Fatal(err) diff --git a/testing/mysql/integration_test.go b/testing/mysql/integration_test.go index c8506449f..c25680623 100644 --- a/testing/mysql/integration_test.go +++ b/testing/mysql/integration_test.go @@ -122,7 +122,7 @@ func TestIntegration_MYSQL_SchemaAndDataSubcommand(t *testing.T) { filePrefix := filepath.Join(tmpdir, dbName) host, user, srcDb, password := os.Getenv("MYSQLHOST"), os.Getenv("MYSQLUSER"), os.Getenv("MYSQLDATABASE"), os.Getenv("MYSQLPWD") - args := fmt.Sprintf("schema-and-data -source=%s -prefix=%s -source-profile='host=%s,user=%s,dbName=%s,password=%s' -target-profile='instance=%s,dbName=%s'", constants.MYSQL, filePrefix, host, user, srcDb, password, instanceID, dbName) + args := fmt.Sprintf("schema-and-data -source=%s -prefix=%s -source-profile='host=%s,user=%s,dbName=%s,password=%s' -target-profile='instance=%s,dbName=%s,project=%s'", constants.MYSQL, filePrefix, host, user, srcDb, password, instanceID, dbName, projectID) err := common.RunCommand(args, projectID) if err != nil { t.Fatal(err) @@ -134,7 +134,7 @@ func TestIntegration_MYSQL_SchemaAndDataSubcommand(t *testing.T) { 
} func runSchemaSubcommand(t *testing.T, dbName, filePrefix, sessionFile, dumpFilePath string) { - args := fmt.Sprintf("schema -prefix %s -source=mysql -target-profile='instance=%s,dbName=%s' < %s", filePrefix, instanceID, dbName, dumpFilePath) + args := fmt.Sprintf("schema -prefix %s -source=mysql -target-profile='instance=%s,dbName=%s,project=%s' < %s", filePrefix, instanceID, dbName, projectID, dumpFilePath) err := common.RunCommand(args, projectID) if err != nil { t.Fatal(err) @@ -142,7 +142,7 @@ func runSchemaSubcommand(t *testing.T, dbName, filePrefix, sessionFile, dumpFile } func runDataSubcommand(t *testing.T, dbName, dbURI, filePrefix, sessionFile, dumpFilePath string) { - args := fmt.Sprintf("data -source=mysql -prefix %s -session %s -target-profile='instance=%s,dbName=%s' < %s", filePrefix, sessionFile, instanceID, dbName, dumpFilePath) + args := fmt.Sprintf("data -source=mysql -prefix %s -session %s -target-profile='instance=%s,dbName=%s,project=%s' < %s", filePrefix, sessionFile, instanceID, dbName, projectID, dumpFilePath) err := common.RunCommand(args, projectID) if err != nil { t.Fatal(err) @@ -150,7 +150,7 @@ func runDataSubcommand(t *testing.T, dbName, dbURI, filePrefix, sessionFile, dum } func runSchemaAndDataSubcommand(t *testing.T, dbName, dbURI, filePrefix, dumpFilePath string) { - args := fmt.Sprintf("schema-and-data -source=mysql -prefix %s -target-profile='instance=%s,dbName=%s' < %s", filePrefix, instanceID, dbName, dumpFilePath) + args := fmt.Sprintf("schema-and-data -source=mysql -prefix %s -target-profile='instance=%s,dbName=%s,project=%s' < %s", filePrefix, instanceID, dbName, projectID, dumpFilePath) err := common.RunCommand(args, projectID) if err != nil { t.Fatal(err) @@ -228,7 +228,7 @@ func TestIntegration_MYSQL_ForeignKeyActionMigration(t *testing.T) { filePrefix := filepath.Join(tmpdir, dbName) host, user, srcDb, password := os.Getenv("MYSQLHOST"), os.Getenv("MYSQLUSER"), os.Getenv("MYSQLDB_FKACTION"), os.Getenv("MYSQLPWD") - args := fmt.Sprintf("schema-and-data -source=%s -prefix=%s -source-profile='host=%s,user=%s,dbName=%s,password=%s' -target-profile='instance=%s,dbName=%s'", constants.MYSQL, filePrefix, host, user, srcDb, password, instanceID, dbName) + args := fmt.Sprintf("schema-and-data -source=%s -prefix=%s -source-profile='host=%s,user=%s,dbName=%s,password=%s' -target-profile='instance=%s,dbName=%s,project=%s'", constants.MYSQL, filePrefix, host, user, srcDb, password, instanceID, dbName, projectID) err := common.RunCommand(args, projectID) if err != nil { t.Fatal(err) diff --git a/testing/oracle/integration_test.go b/testing/oracle/integration_test.go index 5d065a45c..030a37d52 100644 --- a/testing/oracle/integration_test.go +++ b/testing/oracle/integration_test.go @@ -121,9 +121,10 @@ func TestIntegration_ORACLE_SchemaSubcommand(t *testing.T) { t.Parallel() tmpdir := prepareIntegrationTest(t) defer os.RemoveAll(tmpdir) + filePrefix := filepath.Join(tmpdir, "Oracle_IntTest.") - args := fmt.Sprintf("schema -prefix %s -source=%s -source-profile='host=localhost,user=STI,dbName=xe,password=test1,port=1521'", filePrefix, constants.ORACLE) + args := fmt.Sprintf("schema -prefix %s -source=%s -source-profile='host=localhost,user=STI,dbName=xe,password=test1,port=1521' -target-profile='instance=%s,dbName=xe,project=%s'", filePrefix, constants.ORACLE, instanceID, projectID) err := common.RunCommand(args, projectID) if err != nil { t.Fatal(err) @@ -138,7 +139,7 @@ func TestIntegration_ORACLE_SchemaAndDataSubcommand(t *testing.T) { dbURI := 
fmt.Sprintf("projects/%s/instances/%s/databases/%s", projectID, instanceID, dbName) filePrefix := filepath.Join(tmpdir, "Oracle_IntTest.") - args := fmt.Sprintf("schema-and-data -prefix %s -source=%s -source-profile='host=localhost,user=STI,dbName=xe,password=test1,port=1521' -target-profile='instance=%s,dbName=%s'", filePrefix, constants.ORACLE, instanceID, dbName) + args := fmt.Sprintf("schema-and-data -prefix %s -source=%s -source-profile='host=localhost,user=STI,dbName=xe,password=test1,port=1521' -target-profile='instance=%s,dbName=%s,project=%s'", filePrefix, constants.ORACLE, instanceID, dbName, projectID) err := common.RunCommand(args, projectID) if err != nil { t.Fatal(err) diff --git a/testing/postgres/golden_test.go b/testing/postgres/golden_test.go index c69569c12..1955f1b43 100644 --- a/testing/postgres/golden_test.go +++ b/testing/postgres/golden_test.go @@ -16,6 +16,7 @@ package postgres_test import ( "bufio" + "context" "strings" "testing" "time" @@ -24,11 +25,13 @@ import ( "github.com/GoogleCloudPlatform/spanner-migration-tool/expressions_api" "github.com/GoogleCloudPlatform/spanner-migration-tool/internal" "github.com/GoogleCloudPlatform/spanner-migration-tool/logger" + "github.com/GoogleCloudPlatform/spanner-migration-tool/mocks" "github.com/GoogleCloudPlatform/spanner-migration-tool/sources/common" "github.com/GoogleCloudPlatform/spanner-migration-tool/sources/postgres" "github.com/GoogleCloudPlatform/spanner-migration-tool/spanner/ddl" commonTesting "github.com/GoogleCloudPlatform/spanner-migration-tool/testing/common" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" "go.uber.org/zap" ) @@ -43,7 +46,12 @@ func TestGoldens(t *testing.T) { testCases := commonTesting.GoldenTestCasesFrom(t, GoldenTestsDir) t.Logf("executing %d test cases from %s", len(testCases), GoldenTestsDir) - schemaToSpanner := common.SchemaToSpannerImpl{} + mockAccessor := new(mocks.MockExpressionVerificationAccessor) + ctx := context.Background() + mockAccessor.On("VerifyExpressions", ctx, mock.Anything).Return(internal.VerifyExpressionsOutput{ + ExpressionVerificationOutputList: []internal.ExpressionVerificationOutput{{Result: true}}, + }, nil) + schemaToSpanner := common.SchemaToSpannerImpl{ExpressionVerificationAccessor: mockAccessor} for _, tc := range testCases { t.Run(tc.Name, func(t *testing.T) { @@ -54,7 +62,7 @@ func TestGoldens(t *testing.T) { err := common.ProcessDbDump( conv, internal.NewReader(bufio.NewReader(strings.NewReader(tc.Input)), nil), - postgres.DbDumpImpl{}, &expressions_api.MockDDLVerifier{}) + postgres.DbDumpImpl{}, &expressions_api.MockDDLVerifier{}, mockAccessor) if err != nil { t.Fatalf("error when processing dump %s: %s", tc.Input, err) } @@ -72,7 +80,6 @@ func TestGoldens(t *testing.T) { config.SpDialect = constants.DIALECT_POSTGRESQL actual = ddl.GetDDL(config, conv.SpSchema, conv.SpSequences) assert.Equal(t, tc.PSQLWant, formatDdl(actual)) - }) } } diff --git a/testing/postgres/integration_test.go b/testing/postgres/integration_test.go index 3dbefcb2d..1040d3831 100644 --- a/testing/postgres/integration_test.go +++ b/testing/postgres/integration_test.go @@ -121,7 +121,7 @@ func TestIntegration_PGDUMP_SchemaAndDataSubcommand(t *testing.T) { dataFilepath := "../../test_data/pg_dump.test.out" filePrefix := filepath.Join(tmpdir, dbName) - args := fmt.Sprintf("schema-and-data -prefix %s -source=postgres -target-profile='instance=%s,dbName=%s' < %s", filePrefix, instanceID, dbName, dataFilepath) + args := fmt.Sprintf("schema-and-data -prefix %s 
-source=postgres -target-profile='instance=%s,dbName=%s,project=%s' < %s", filePrefix, instanceID, dbName, projectID, dataFilepath) err := common.RunCommand(args, projectID) if err != nil { t.Fatal(err) @@ -148,7 +148,7 @@ func TestIntegration_PGDUMP_SchemaSubcommand(t *testing.T) { dataFilepath := "../../test_data/pg_dump.test.out" - args := fmt.Sprintf("schema -source=pg -target-profile='instance=%s,dbName=%s,dialect=%s' < %s", instanceID, dbName, dialect, dataFilepath) + args := fmt.Sprintf("schema -source=pg -target-profile='instance=%s,dbName=%s,dialect=%s,project=%s' < %s", instanceID, dbName, dialect, projectID, dataFilepath) err := common.RunCommand(args, projectID) if err != nil { t.Fatal(err) @@ -172,7 +172,7 @@ func TestIntegration_POSTGRES_SchemaAndDataSubcommand(t *testing.T) { dbURI := fmt.Sprintf("projects/%s/instances/%s/databases/%s", projectID, instanceID, dbName) filePrefix := filepath.Join(tmpdir, dbName) - args := fmt.Sprintf("schema-and-data -prefix %s -source=postgres -target-profile='instance=%s,dbName=%s'", filePrefix, instanceID, dbName) + args := fmt.Sprintf("schema-and-data -prefix %s -source=postgres -target-profile='instance=%s,dbName=%s,project=%s'", filePrefix, instanceID, dbName, projectID) err := common.RunCommand(args, projectID) if err != nil { t.Fatal(err) @@ -195,7 +195,7 @@ func TestIntegration_POSTGRES_SchemaSubcommand(t *testing.T) { dbURI := fmt.Sprintf("projects/%s/instances/%s/databases/%s", projectID, instanceID, dbName) filePrefix := filepath.Join(tmpdir, dbName) - args := fmt.Sprintf("schema -prefix %s -source=postgres -target-profile='instance=%s,dbName=%s'", filePrefix, instanceID, dbName) + args := fmt.Sprintf("schema -prefix %s -source=postgres -target-profile='instance=%s,dbName=%s,project=%s'", filePrefix, instanceID, dbName, projectID) err := common.RunCommand(args, "emulator-test-project") if err != nil { t.Fatal(err) @@ -219,7 +219,7 @@ func TestIntegration_PGDUMP_ForeignKeyActionMigration(t *testing.T) { dataFilepath := "../../test_data/pg_dump.test.out" filePrefix := filepath.Join(tmpdir, dbName) - args := fmt.Sprintf("schema-and-data -prefix %s -source=postgres -target-profile='instance=test-instance,dbName=%s' < %s", filePrefix, dbName, dataFilepath) + args := fmt.Sprintf("schema-and-data -prefix %s -source=postgres -target-profile='instance=test-instance,dbName=%s,project=%s' < %s", filePrefix, dbName, projectID, dataFilepath) err := common.RunCommand(args, "emulator-test-project") if err != nil { t.Fatal(err) @@ -243,7 +243,7 @@ func TestIntegration_POSTGRES_ForeignKeyActionMigration(t *testing.T) { dbURI := fmt.Sprintf("projects/%s/instances/%s/databases/%s", projectID, instanceID, dbName) filePrefix := filepath.Join(tmpdir, dbName) - args := fmt.Sprintf("schema-and-data -prefix %s -source=postgres -target-profile='instance=%s,dbName=%s'", filePrefix, instanceID, dbName) + args := fmt.Sprintf("schema-and-data -prefix %s -source=postgres -target-profile='instance=%s,dbName=%s,project=%s'", filePrefix, instanceID, dbName, projectID) err := common.RunCommand(args, "emulator-test-project") if err != nil { t.Fatal(err) diff --git a/testing/sqlserver/integration_test.go b/testing/sqlserver/integration_test.go index d8d56d1e0..fe38db442 100644 --- a/testing/sqlserver/integration_test.go +++ b/testing/sqlserver/integration_test.go @@ -109,9 +109,10 @@ func TestIntegration_SQLserver_SchemaSubcommand(t *testing.T) { t.Parallel() tmpdir := prepareIntegrationTest(t) defer os.RemoveAll(tmpdir) + dbName := "sqlserver-schema-and-data" 
filePrefix := filepath.Join(tmpdir, "SqlServer_IntTest.") - args := fmt.Sprintf("schema -prefix %s -source=sqlserver -source-profile='host=localhost,user=sa,dbName=SqlServer_IntTest'", filePrefix) + args := fmt.Sprintf("schema -prefix %s -source=sqlserver -source-profile='host=localhost,user=sa,dbName=SqlServer_IntTest' -target-profile='instance=%s,dbName=%s,project=%s'", filePrefix, instanceID, dbName, projectID) err := common.RunCommand(args, projectID) if err != nil { t.Fatal(err) @@ -126,7 +127,7 @@ func TestIntegration_SQLserver_SchemaAndDataSubcommand(t *testing.T) { dbURI := fmt.Sprintf("projects/%s/instances/%s/databases/%s", projectID, instanceID, dbName) filePrefix := filepath.Join(tmpdir, "SqlServer_IntTest.") - args := fmt.Sprintf("schema-and-data -prefix %s -source=%s -source-profile='host=localhost,user=sa,dbName=SqlServer_IntTest' -target-profile='instance=%s,dbName=%s'", filePrefix, constants.SQLSERVER, instanceID, dbName) + args := fmt.Sprintf("schema-and-data -prefix %s -source=%s -source-profile='host=localhost,user=sa,dbName=SqlServer_IntTest' -target-profile='instance=%s,dbName=%s,project=%s'", filePrefix, constants.SQLSERVER, instanceID, dbName, projectID) err := common.RunCommand(args, projectID) if err != nil { t.Fatal(err) diff --git a/webv2/api/schema.go b/webv2/api/schema.go index 7467f8ad2..22f2529c1 100644 --- a/webv2/api/schema.go +++ b/webv2/api/schema.go @@ -52,6 +52,10 @@ var ( var autoGenMap = make(map[string][]types.AutoGen) +type ExpressionsVerificationHandler struct { + ExpressionVerificationAccessor expressions_api.ExpressionVerificationAccessor +} + func init() { sessionState := session.GetSessionState() utilities.InitObjectId() @@ -62,7 +66,7 @@ func init() { // ConvertSchemaSQL converts source database to Spanner when using // with postgres and mysql driver. -func ConvertSchemaSQL(w http.ResponseWriter, r *http.Request) { +func (expressionVerificationHandler *ExpressionsVerificationHandler) ConvertSchemaSQL(w http.ResponseWriter, r *http.Request) { sessionState := session.GetSessionState() if sessionState.SourceDB == nil || sessionState.DbName == "" || sessionState.Driver == "" { http.Error(w, fmt.Sprintf("Database is not configured or Database connection is lost. Please set configuration and connect to database."), http.StatusNotFound) @@ -75,6 +79,9 @@ func ConvertSchemaSQL(w http.ResponseWriter, r *http.Request) { conv.SpInstanceId = sessionState.SpannerInstanceID conv.Source = sessionState.Driver conv.IsSharded = sessionState.IsSharded + conv.SpProjectId = sessionState.SpannerProjectId + conv.SpInstanceId = sessionState.SpannerInstanceID + conv.Source = sessionState.Driver var err error additionalSchemaAttributes := internal.AdditionalSchemaAttributes{ IsSharded: sessionState.IsSharded, @@ -87,7 +94,8 @@ func ConvertSchemaSQL(w http.ResponseWriter, r *http.Request) { return } schemaToSpanner := common.SchemaToSpannerImpl{ - DdlV: ddlVerifier, + ExpressionVerificationAccessor: expressionVerificationHandler.ExpressionVerificationAccessor, + DdlV: ddlVerifier, } switch sessionState.Driver { case constants.MYSQL: @@ -154,7 +162,7 @@ func ConvertSchemaSQL(w http.ResponseWriter, r *http.Request) { // ConvertSchemaDump converts schema from dump file to Spanner schema for // mysqldump and pg_dump driver. 
-func ConvertSchemaDump(w http.ResponseWriter, r *http.Request) { +func (expressionVerificationHandler *ExpressionsVerificationHandler) ConvertSchemaDump(w http.ResponseWriter, r *http.Request) { reqBody, err := ioutil.ReadAll(r.Body) if err != nil { http.Error(w, fmt.Sprintf("Body Read Error : %v", err), http.StatusInternalServerError) @@ -179,7 +187,7 @@ func ConvertSchemaDump(w http.ResponseWriter, r *http.Request) { sessionState := session.GetSessionState() SpProjectId := sessionState.SpannerProjectId SpInstanceId := sessionState.SpannerInstanceID - conv, err := schemaFromSource.SchemaFromDump(SpProjectId, SpInstanceId, sourceProfile.Driver, dc.SpannerDetails.Dialect, &utils.IOStreams{In: f, Out: os.Stdout}, &conversion.ProcessDumpByDialectImpl{}) + conv, err := schemaFromSource.SchemaFromDump(SpProjectId, SpInstanceId, sourceProfile.Driver, dc.SpannerDetails.Dialect, &utils.IOStreams{In: f, Out: os.Stdout}, &conversion.ProcessDumpByDialectImpl{ExpressionVerificationAccessor: expressionVerificationHandler.ExpressionVerificationAccessor}) if err != nil { http.Error(w, fmt.Sprintf("Schema Conversion Error : %v", err), http.StatusNotFound) return @@ -587,17 +595,9 @@ func UpdateCheckConstraint(w http.ResponseWriter, r *http.Request) { json.NewEncoder(w).Encode(convm) } -func doesNameExist(spcks []ddl.CheckConstraint, targetName string) bool { - for _, spck := range spcks { - if strings.Contains(spck.Expr, targetName) { - return true - } - } - return false -} - -// ValidateCheckConstraint verifies if the type of a database column has been altered and add an error if a change is detected. -func ValidateCheckConstraint(w http.ResponseWriter, r *http.Request) { +// VerifyCheckConstraintExpression uses the expressions API to validate check constraint expressions, adds the relevant error +// to the suggestion tab, and removes any check constraint whose expression fails verification. +func (expressionVerificationHandler *ExpressionsVerificationHandler) VerifyCheckConstraintExpression(w http.ResponseWriter, r *http.Request) { sessionState := session.GetSessionState() if sessionState.Conv == nil || sessionState.Driver == "" { http.Error(w, fmt.Sprintf("Schema is not converted or Driver is not configured properly.
Please retry converting the database to Spanner."), http.StatusNotFound) @@ -606,36 +606,50 @@ sessionState.Conv.ConvLock.Lock() defer sessionState.Conv.ConvLock.Unlock() - sp := sessionState.Conv.SpSchema - srcschema := sessionState.Conv.SrcSchema - flag := true - var schemaIssue []internal.SchemaIssue - - for _, src := range srcschema { - for _, col := range sp[src.Id].ColDefs { - if len(sp[src.Id].CheckConstraints) > 0 { - spType := col.T.Name - srcType := srcschema[src.Id].ColDefs[col.Id].Type - actualType := mysqlDefaultTypeMap[srcType.Name] - if actualType.Name != spType { - columnName := sp[src.Id].ColDefs[col.Id].Name - spcks := sp[src.Id].CheckConstraints - if doesNameExist(spcks, columnName) { - flag = false - schemaIssue = sessionState.Conv.SchemaIssues[src.Id].ColumnLevelIssues[col.Id] - if !utilities.IsSchemaIssuePresent(schemaIssue, internal.TypeMismatch) { - schemaIssue = append(schemaIssue, internal.TypeMismatch) - } - sessionState.Conv.SchemaIssues[src.Id].ColumnLevelIssues[col.Id] = schemaIssue - break + spschema := sessionState.Conv.SpSchema + + hasErrorOccurred := false + + ctx := context.Background() + + verifyExpressionsInput := internal.VerifyExpressionsInput{ + Conv: sessionState.Conv, + Source: "mysql", + ExpressionDetailList: common.GenerateExpressionDetailList(spschema), + } + sessionState.Conv.SchemaIssues = common.RemoveError(sessionState.Conv.SchemaIssues) + result := expressionVerificationHandler.ExpressionVerificationAccessor.VerifyExpressions(ctx, verifyExpressionsInput) + if result.ExpressionVerificationOutputList == nil { + http.Error(w, fmt.Sprintf("Unhandled error: %s", result.Err.Error()), http.StatusBadRequest) + return + } + + issueTypes := common.GetIssue(result) + if len(issueTypes) > 0 { + hasErrorOccurred = true + for tableId, issues := range issueTypes { + for _, issue := range issues { + if _, exists := sessionState.Conv.SchemaIssues[tableId]; !exists { + sessionState.Conv.SchemaIssues[tableId] = internal.TableIssues{ + TableLevelIssues: []internal.SchemaIssue{}, } } + + tableIssue := sessionState.Conv.SchemaIssues[tableId] + + if !utilities.IsSchemaIssuePresent(tableIssue.TableLevelIssues, issue) { + tableIssue.TableLevelIssues = append(tableIssue.TableLevelIssues, issue) + } + + sessionState.Conv.SchemaIssues[tableId] = tableIssue } } } + session.UpdateSessionFile() + w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(flag) + json.NewEncoder(w).Encode(hasErrorOccurred) } // renameForeignKeys checks the new names for spanner name validity, ensures the new names are already not used by existing tables diff --git a/webv2/api/schema_test.go b/webv2/api/schema_test.go index ecc8b7664..93f5488b5 100644 --- a/webv2/api/schema_test.go +++ b/webv2/api/schema_test.go @@ -3,6 +3,7 @@ package api_test import ( "bytes" "encoding/json" + "errors" "fmt" "io" "net/http" @@ -15,6 +16,7 @@ import ( "github.com/GoogleCloudPlatform/spanner-migration-tool/internal" "github.com/GoogleCloudPlatform/spanner-migration-tool/internal/reports" "github.com/GoogleCloudPlatform/spanner-migration-tool/logger" + "github.com/GoogleCloudPlatform/spanner-migration-tool/mocks" "github.com/GoogleCloudPlatform/spanner-migration-tool/proto/migration" "github.com/GoogleCloudPlatform/spanner-migration-tool/schema" "github.com/GoogleCloudPlatform/spanner-migration-tool/spanner/ddl" @@ -22,6 +24,7 @@ import ( "github.com/GoogleCloudPlatform/spanner-migration-tool/webv2/session"
"github.com/GoogleCloudPlatform/spanner-migration-tool/webv2/types" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" "go.uber.org/zap" ) @@ -829,6 +832,9 @@ func TestRenameIndexes(t *testing.T) { "t1": { Indexes: []ddl.CreateIndex{{Name: "idx_1", Id: "i1", TableId: "t1", Unique: false, Keys: []ddl.IndexKey{{ColId: "c2", Desc: false}}}, {Name: "idx_2", Id: "i2", TableId: "t1", Unique: false, Keys: []ddl.IndexKey{{ColId: "c2", Desc: false}}}}, + ForeignKeys: []ddl.Foreignkey{{Name: "fk1", Id: "fkId1", ColIds: []string{"c2"}, ReferTableId: "reft1", ReferColumnIds: []string{"ref_b"}}, + {Name: "fk2", Id: "fkId2", ColIds: []string{"c3", "d"}, ReferTableId: "reft2", ReferColumnIds: []string{"ref_c", "ref_d"}}}, + ParentTable: ddl.InterleavedParent{Id: "", OnDelete: ""}, }}, }, }, @@ -2682,3 +2688,148 @@ type errReader struct{} func (errReader) Read(p []byte) (n int, err error) { return 0, fmt.Errorf("simulated read error") } + +func TestVerifyCheckConstraintExpressions(t *testing.T) { + tests := []struct { + name string + expressions []ddl.CheckConstraint + expectedResults []internal.ExpressionVerificationOutput + expectedResponse bool + }{ + { + name: "AllValidExpressions", + expressions: []ddl.CheckConstraint{ + {Expr: "(col1 > 0)", ExprId: "expr1", Name: "check1"}, + {Expr: "(((col1 > 0) and (col2 like 'A%') and (col4 between 5 and 100)) or (col5 in ('Alpha', 'Beta', 'Gamma')) )", ExprId: "expr2", Name: "complex_check"}, + {Expr: "(col1 > 10)", ExprId: "expr3", Name: "conflict_check1"}, + {Expr: "(col1 < 40)", ExprId: "expr4", Name: "conflict_check2"}, + {Expr: "(col1 > 0)", ExprId: "expr5", Name: "auto_increment_check"}, + {Expr: "(price >= 0.01 AND price <= 10000.00)", ExprId: "expr6", Name: "numeric_check"}, + {Expr: "username <> 'invalid'", ExprId: "expr7", Name: "character_check"}, + {Expr: "(status IN ('Pending', 'In Progress', 'Completed', 'Cancelled'))", ExprId: "expr8", Name: "enumerate_check"}, + {Expr: "((col2 & 8) = 0)", ExprId: "expr9", Name: "bitwise_check"}, + {Expr: "featureA IN (0, 1)", ExprId: "expr10", Name: "boolean_check"}, + }, + expectedResults: []internal.ExpressionVerificationOutput{ + {Result: true, Err: nil, ExpressionDetail: internal.ExpressionDetail{Expression: "(col1 > 0)", Type: "CHECK", Metadata: map[string]string{"tableId": "t1", "colId": "c1", "checkConstraintName": "check1"}, ExpressionId: "expr1"}}, + {Result: true, Err: nil, ExpressionDetail: internal.ExpressionDetail{Expression: "(((col1 > 0) and (col2 like 'A%') and (col4 between 5 and 100)) or (col5 in ('Alpha', 'Beta', 'Gamma')) )", Type: "CHECK", Metadata: map[string]string{"tableId": "t1", "colId": "c2", "checkConstraintName": "complex_check"}, ExpressionId: "expr2"}}, + {Result: true, Err: nil, ExpressionDetail: internal.ExpressionDetail{Expression: "(col1 > 10)", Type: "CHECK", Metadata: map[string]string{"tableId": "t1", "colId": "c3", "checkConstraintName": "conflict_check1"}, ExpressionId: "expr3"}}, + {Result: true, Err: nil, ExpressionDetail: internal.ExpressionDetail{Expression: "(col1 < 40)", Type: "CHECK", Metadata: map[string]string{"tableId": "t1", "colId": "c4", "checkConstraintName": "conflict_check2"}, ExpressionId: "expr4"}}, + {Result: true, Err: nil, ExpressionDetail: internal.ExpressionDetail{Expression: "(col3 > 0)", Type: "CHECK", Metadata: map[string]string{"tableId": "t1", "colId": "c5", "checkConstraintName": "auto_increment_check"}, ExpressionId: "expr5"}}, + {Result: true, Err: nil, ExpressionDetail: internal.ExpressionDetail{Expression: "(price >= 
0.01 AND price <= 10000.00)", Type: "CHECK", Metadata: map[string]string{"tableId": "t1", "colId": "c6", "checkConstraintName": "numeric_check"}, ExpressionId: "expr6"}}, + {Result: true, Err: nil, ExpressionDetail: internal.ExpressionDetail{Expression: "username <> 'invalid'", Type: "CHECK", Metadata: map[string]string{"tableId": "t1", "colId": "c7", "checkConstraintName": "character_check"}, ExpressionId: "expr7"}}, + {Result: true, Err: nil, ExpressionDetail: internal.ExpressionDetail{Expression: "(status IN ('Pending', 'In Progress', 'Completed', 'Cancelled'))", Type: "CHECK", Metadata: map[string]string{"tableId": "t1", "colId": "c8", "checkConstraintName": "enumerate_check"}, ExpressionId: "expr8"}}, + {Result: true, Err: nil, ExpressionDetail: internal.ExpressionDetail{Expression: "((col2 & 8) = 0)", Type: "CHECK", Metadata: map[string]string{"tableId": "t1", "colId": "c9", "checkConstraintName": "bitwise_check"}, ExpressionId: "expr9"}}, + {Result: true, Err: nil, ExpressionDetail: internal.ExpressionDetail{Expression: "featureA IN (0, 1)", Type: "CHECK", Metadata: map[string]string{"tableId": "t1", "colId": "c10", "checkConstraintName": "boolean_check"}, ExpressionId: "expr10"}}, + }, + expectedResponse: false, + }, + { + name: "InvalidSyntaxError", + expressions: []ddl.CheckConstraint{ + {Expr: "(col1 > 0)", ExprId: "expr1", Name: "check1"}, + {Expr: "(col1 > 18", ExprId: "expr2", Name: "check2"}, + }, + expectedResults: []internal.ExpressionVerificationOutput{ + {Result: true, Err: nil, ExpressionDetail: internal.ExpressionDetail{Expression: "(col1 > 0)", Type: "CHECK", Metadata: map[string]string{"tableId": "t1", "colId": "c1", "checkConstraintName": "check1"}, ExpressionId: "expr1"}}, + {Result: false, Err: errors.New("Syntax error ..."), ExpressionDetail: internal.ExpressionDetail{Expression: "(col1 > 18", Type: "CHECK", Metadata: map[string]string{"tableId": "t1", "colId": "c1", "checkConstraintName": "check2"}, ExpressionId: "expr2"}}, + }, + expectedResponse: true, + }, + { + name: "NameError", + expressions: []ddl.CheckConstraint{ + {Expr: "(col1 > 0)", ExprId: "expr1", Name: "check1"}, + {Expr: "(col1 > 18)", ExprId: "expr2", Name: "check2"}, + }, + expectedResults: []internal.ExpressionVerificationOutput{ + {Result: true, Err: nil, ExpressionDetail: internal.ExpressionDetail{Expression: "(col1 > 0)", Type: "CHECK", Metadata: map[string]string{"tableId": "t1", "colId": "c1", "checkConstraintName": "check1"}, ExpressionId: "expr1"}}, + {Result: false, Err: errors.New("Unrecognized name ..."), ExpressionDetail: internal.ExpressionDetail{Expression: "(col1 > 18)", Type: "CHECK", Metadata: map[string]string{"tableId": "t1", "colId": "c1", "checkConstraintName": "check2"}, ExpressionId: "expr2"}}, + }, + expectedResponse: true, + }, + { + name: "TypeError", + expressions: []ddl.CheckConstraint{ + {Expr: "(col1 > 0)", ExprId: "expr1", Name: "check1"}, + {Expr: "(col1 > 18)", ExprId: "expr2", Name: "check2"}, + }, + expectedResults: []internal.ExpressionVerificationOutput{ + {Result: true, Err: nil, ExpressionDetail: internal.ExpressionDetail{Expression: "(col1 > 0)", Type: "CHECK", Metadata: map[string]string{"tableId": "t1", "colId": "c1", "checkConstraintName": "check1"}, ExpressionId: "expr1"}}, + {Result: false, Err: errors.New("No matching signature for operator"), ExpressionDetail: internal.ExpressionDetail{Expression: "(col1 > 18)", Type: "CHECK", Metadata: map[string]string{"tableId": "t1", "colId": "c1", "checkConstraintName": "check2"}, ExpressionId: "expr2"}}, + }, 
+ expectedResponse: true, + }, + { + name: "FunctionError", + expressions: []ddl.CheckConstraint{ + {Expr: "(col1 > 0)", ExprId: "expr1", Name: "check1"}, + {Expr: "(col1 > 18)", ExprId: "expr2", Name: "check2"}, + }, + expectedResults: []internal.ExpressionVerificationOutput{ + {Result: true, Err: nil, ExpressionDetail: internal.ExpressionDetail{Expression: "(col1 > 0)", Type: "CHECK", Metadata: map[string]string{"tableId": "t1", "colId": "c1", "checkConstraintName": "check1"}, ExpressionId: "expr1"}}, + {Result: false, Err: errors.New("Function not found"), ExpressionDetail: internal.ExpressionDetail{Expression: "(col1 > 18)", Type: "CHECK", Metadata: map[string]string{"tableId": "t1", "colId": "c1", "checkConstraintName": "check2"}, ExpressionId: "expr2"}}, + }, + expectedResponse: true, + }, + { + name: "GenericError", + expressions: []ddl.CheckConstraint{ + {Expr: "(col1 > 0)", ExprId: "expr1", Name: "check1"}, + {Expr: "(col1 > 18)", ExprId: "expr2", Name: "check2"}, + }, + expectedResults: []internal.ExpressionVerificationOutput{ + {Result: true, Err: nil, ExpressionDetail: internal.ExpressionDetail{Expression: "(col1 > 0)", Type: "CHECK", Metadata: map[string]string{"tableId": "t1", "colId": "c1", "checkConstraintName": "check1"}, ExpressionId: "expr1"}}, + {Result: false, Err: errors.New("Unhandle error"), ExpressionDetail: internal.ExpressionDetail{Expression: "(col1 > 18)", Type: "CHECK", Metadata: map[string]string{"tableId": "t1", "colId": "c1", "checkConstraintName": "check2"}, ExpressionId: "expr2"}}, + }, + expectedResponse: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + mockAccessor := new(mocks.MockExpressionVerificationAccessor) + handler := &api.ExpressionsVerificationHandler{ExpressionVerificationAccessor: mockAccessor} + + req, err := http.NewRequest("POST", "/verifyCheckConstraintExpression", nil) + if err != nil { + t.Fatal(err) + } + + ctx := req.Context() + sessionState := session.GetSessionState() + sessionState.Driver = constants.MYSQL + sessionState.SpannerInstanceID = "foo" + sessionState.SpannerProjectId = "daring-12" + sessionState.Conv = internal.MakeConv() + sessionState.Conv.SpSchema = map[string]ddl.CreateTable{ + "t1": { + Name: "table1", + Id: "t1", + PrimaryKeys: []ddl.IndexKey{{ColId: "c1"}}, + ColIds: []string{"c1"}, + ColDefs: map[string]ddl.ColumnDef{ + "c1": {Name: "col1", Id: "c1", T: ddl.Type{Name: ddl.Int64}}, + }, + CheckConstraints: tc.expressions, + }, + } + + mockAccessor.On("VerifyExpressions", ctx, mock.Anything).Return(internal.VerifyExpressionsOutput{ + ExpressionVerificationOutputList: tc.expectedResults, + }) + + rr := httptest.NewRecorder() + handler.VerifyCheckConstraintExpression(rr, req) + + assert.Equal(t, http.StatusOK, rr.Code) + + var response bool + err = json.NewDecoder(rr.Body).Decode(&response) + if err != nil { + t.Fatal(err) + } + assert.Equal(t, tc.expectedResponse, response) + }) + } +} diff --git a/webv2/routes.go b/webv2/routes.go index 777f3648c..90970c6ca 100644 --- a/webv2/routes.go +++ b/webv2/routes.go @@ -57,9 +57,15 @@ func getRoutes() *mux.Router { ValidateResources: validateResourceImpl, } + expressionVerificationAccessor, _ := expressions_api.NewExpressionVerificationAccessorImpl(ctx, session.GetSessionState().SpannerProjectId, session.GetSessionState().SpannerInstanceID) + + expressionVerificationHandler := api.ExpressionsVerificationHandler{ + ExpressionVerificationAccessor: expressionVerificationAccessor, + } + router.HandleFunc("/connect", 
databaseConnection).Methods("POST") - router.HandleFunc("/convert/infoschema", api.ConvertSchemaSQL).Methods("GET") - router.HandleFunc("/convert/dump", api.ConvertSchemaDump).Methods("POST") + router.HandleFunc("/convert/infoschema", expressionVerificationHandler.ConvertSchemaSQL).Methods("GET") + router.HandleFunc("/convert/dump", expressionVerificationHandler.ConvertSchemaDump).Methods("POST") router.HandleFunc("/convert/session", loadSession).Methods("POST") router.HandleFunc("/ddl", api.GetDDL).Methods("GET") router.HandleFunc("/seqDdl", api.GetSequenceDDL).Methods("GET") @@ -81,6 +87,7 @@ func getRoutes() *mux.Router { router.HandleFunc("/getSequenceKind", api.GetSequenceKind).Methods("GET") router.HandleFunc("/setparent", api.SetParentTable).Methods("GET") router.HandleFunc("/removeParent", api.RemoveParentTable).Methods("POST") + router.HandleFunc("/verifyCheckConstraintExpression", expressionVerificationHandler.VerifyCheckConstraintExpression).Methods("GET") // TODO:(searce) take constraint names themselves which are guaranteed to be unique for Spanner. router.HandleFunc("/drop/secondaryindex", api.DropSecondaryIndex).Methods("POST")
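For context on the wiring above: routes.go now constructs the accessor with expressions_api.NewExpressionVerificationAccessorImpl and exposes it through ExpressionsVerificationHandler, whose VerifyCheckConstraintExpression endpoint builds a VerifyExpressionsInput from the converted Spanner schema and records table-level issues for failing expressions. The sketch below shows the same flow invoked directly, outside the HTTP layer; it assumes a populated *internal.Conv, and the function name verifyCheckConstraints plus the package name are hypothetical.

package sketch // hypothetical package, not part of the patch

import (
	"context"
	"fmt"

	"github.com/GoogleCloudPlatform/spanner-migration-tool/expressions_api"
	"github.com/GoogleCloudPlatform/spanner-migration-tool/internal"
	"github.com/GoogleCloudPlatform/spanner-migration-tool/sources/common"
)

// verifyCheckConstraints (hypothetical helper) follows the handler's flow:
// gather every CHECK expression from the converted schema, send the batch to
// the verification API, and report each expression Spanner rejected.
func verifyCheckConstraints(ctx context.Context, conv *internal.Conv) error {
	accessor, err := expressions_api.NewExpressionVerificationAccessorImpl(ctx, conv.SpProjectId, conv.SpInstanceId)
	if err != nil {
		return fmt.Errorf("creating expression verification accessor: %w", err)
	}
	input := internal.VerifyExpressionsInput{
		Conv:                 conv,
		Source:               "mysql", // the handler currently hard-codes the MySQL source
		ExpressionDetailList: common.GenerateExpressionDetailList(conv.SpSchema),
	}
	output := accessor.VerifyExpressions(ctx, input)
	if output.ExpressionVerificationOutputList == nil {
		// Mirrors the handler: a nil list signals an unhandled verification error.
		return fmt.Errorf("verification failed: %v", output.Err)
	}
	for _, res := range output.ExpressionVerificationOutputList {
		if !res.Result {
			fmt.Printf("check constraint %q failed verification: %v\n",
				res.ExpressionDetail.Metadata["checkConstraintName"], res.Err)
		}
	}
	return nil
}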