Skip to content

Commit

Permalink
feat: Add VerifyExpressions API (#937)
Browse files Browse the repository at this point in the history
* Add Expression verification API signature

* Implement Expression Verification APIs

* Address comments from design review

* Update thread pool

* Refactor spanner accessor and add UTs

* Remove expressions API

* Add expressions API

* Add test data

* Use spannerClient.DatabaseName() inside the API

* Address comments and rebase

* Fix sequence removal implementation and UTs

* Fix db name

* Address comments
  • Loading branch information
manitgupta authored Nov 28, 2024
1 parent 7f2d623 commit 7a76a1b
Show file tree
Hide file tree
Showing 9 changed files with 849 additions and 17 deletions.
5 changes: 5 additions & 0 deletions accessors/clients/spanner/client/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (

// SpannerClient is the abstraction over the Spanner data client used by the
// accessors. It is implemented by SpannerClientImpl and, for tests, by
// SpannerClientMock.
type SpannerClient interface {
// Single returns a ReadOnlyTransaction.
Single() ReadOnlyTransaction
// DatabaseName returns the name of the database the client refers to.
// NOTE(review): callers use the value as a full database URI — confirm the format.
DatabaseName() string
}

type ReadOnlyTransaction interface {
Expand Down Expand Up @@ -39,6 +40,10 @@ func (c *SpannerClientImpl) Single() ReadOnlyTransaction {
return &ReadOnlyTransactionImpl{rotxn: rotxn}
}

// DatabaseName returns the database name reported by the wrapped Spanner client.
func (c *SpannerClientImpl) DatabaseName() string {
return c.spannerClient.DatabaseName()
}

// ReadOnlyTransactionImpl wraps a *spanner.ReadOnlyTransaction.
type ReadOnlyTransactionImpl struct {
// rotxn is the underlying Spanner read-only transaction.
rotxn *spanner.ReadOnlyTransaction
}
Expand Down
5 changes: 5 additions & 0 deletions accessors/clients/spanner/client/mocks.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (

// SpannerClientMock is a test double whose methods forward to the
// corresponding function fields, letting each test supply its own behavior.
type SpannerClientMock struct {
// SingleMock is invoked by Single.
SingleMock func() ReadOnlyTransaction
// DatabaseNameMock is invoked by DatabaseName.
DatabaseNameMock func() string
}

type ReadOnlyTransactionMock struct {
Expand All @@ -23,6 +24,10 @@ func (scm SpannerClientMock) Single() ReadOnlyTransaction {
return scm.SingleMock()
}

// DatabaseName returns whatever the configured DatabaseNameMock returns.
// DatabaseNameMock must be set before calling; a nil function field panics.
func (scm SpannerClientMock) DatabaseName() string {
return scm.DatabaseNameMock()
}

// Query forwards ctx and stmt to the configured QueryMock and returns its result.
func (rom ReadOnlyTransactionMock) Query(ctx context.Context, stmt spanner.Statement) RowIterator {
return rom.QueryMock(ctx, stmt)
}
Expand Down
26 changes: 13 additions & 13 deletions cmd/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import (

sp "cloud.google.com/go/spanner"
database "cloud.google.com/go/spanner/admin/database/apiv1"
"github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/clients/datastream"
datastreamclient "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/clients/datastream"
storageclient "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/clients/storage"
datastream_accessor "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/datastream"
spanneraccessor "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/spanner"
Expand Down Expand Up @@ -160,18 +160,18 @@ func MigrateDatabase(ctx context.Context, migrationProjectId string, targetProfi

func migrateSchema(ctx context.Context, targetProfile profiles.TargetProfile, sourceProfile profiles.SourceProfile,
ioHelper *utils.IOStreams, conv *internal.Conv, dbURI string, adminClient *database.DatabaseAdminClient) error {
spA, err := spanneraccessor.NewSpannerAccessorClientImpl(ctx)
if err != nil {
return err
}
err = spA.CreateOrUpdateDatabase(ctx, dbURI, sourceProfile.Driver, conv, sourceProfile.Config.ConfigType)
if err != nil {
err = fmt.Errorf("can't create/update database: %v", err)
return err
}
metricsPopulation(ctx, sourceProfile.Driver, conv)
conv.Audit.Progress.UpdateProgress("Schema migration complete.", completionPercentage, internal.SchemaMigrationComplete)
return nil
spA, err := spanneraccessor.NewSpannerAccessorClientImpl(ctx)
if err != nil {
return err
}
err = spA.CreateOrUpdateDatabase(ctx, dbURI, sourceProfile.Driver, conv, sourceProfile.Config.ConfigType)
if err != nil {
err = fmt.Errorf("can't create/update database: %v", err)
return err
}
metricsPopulation(ctx, sourceProfile.Driver, conv)
conv.Audit.Progress.UpdateProgress("Schema migration complete.", completionPercentage, internal.SchemaMigrationComplete)
return nil
}

func migrateData(ctx context.Context, migrationProjectId string, targetProfile profiles.TargetProfile, sourceProfile profiles.SourceProfile,
Expand Down
5 changes: 5 additions & 0 deletions common/constants/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,4 +121,9 @@ const (
// GCS PUBSUB MODES
REGULAR_GCS string = "data"
DLQ_GCS string = "dlq"

// VerifyExpressions API
CHECK_EXPRESSION = "CHECK"
DEFAUT_EXPRESSION = "DEFAULT"

)
4 changes: 2 additions & 2 deletions conversion/data_from_database.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import (

"cloud.google.com/go/datastream/apiv1/datastreampb"
sp "cloud.google.com/go/spanner"
"github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/clients/datastream"
datastreamclient "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/clients/datastream"
storageclient "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/clients/storage"
datastream_accessor "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/datastream"
spanneraccessor "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/spanner"
Expand Down Expand Up @@ -60,7 +60,7 @@ func (dd *DataFromDatabaseImpl) dataFromDatabaseForDMSMigration() (*writer.Batch
// 5. Perform streaming migration via dataflow
func (dd *DataFromDatabaseImpl) dataFromDatabaseForDataflowMigration(migrationProjectId string, targetProfile profiles.TargetProfile, ctx context.Context, sourceProfile profiles.SourceProfile, conv *internal.Conv, is common.InfoSchemaInterface) (*writer.BatchWriter, error) {
// Fetch Spanner Region
if conv.SpRegion == "" {
if conv.SpRegion == "" {
spAcc, err := spanneraccessor.NewSpannerAccessorClientImpl(ctx)
if err != nil {
return nil, fmt.Errorf("unable to fetch Spanner Region for resource creation: %v", err)
Expand Down
143 changes: 143 additions & 0 deletions expressions_api/expression_verify.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
package expressions_api

import (
"context"
"encoding/json"
"fmt"
"sync"

spannerclient "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/clients/spanner/client"
spanneraccessor "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/spanner"
"github.com/GoogleCloudPlatform/spanner-migration-tool/common/constants"
"github.com/GoogleCloudPlatform/spanner-migration-tool/common/task"
"github.com/GoogleCloudPlatform/spanner-migration-tool/internal"
"github.com/GoogleCloudPlatform/spanner-migration-tool/spanner/ddl"
)

// THREAD_POOL is the value handed to RunParallelTasks when verifying
// expressions — presumably the number of concurrent workers; confirm against
// the task package.
const THREAD_POOL = 500

// ExpressionVerificationAccessor exposes the expression verification API.
type ExpressionVerificationAccessor interface {
// VerifyExpressions is a batch API which parallelizes expression verification calls.
// It takes the full set of expressions to verify and returns one output per
// expression plus an overall error.
VerifyExpressions(ctx context.Context, verifyExpressionsInput internal.VerifyExpressionsInput) internal.VerifyExpressionsOutput
}

// ExpressionVerificationAccessorImpl implements ExpressionVerificationAccessor
// on top of a Spanner accessor.
type ExpressionVerificationAccessorImpl struct {
// SpannerAccessor is used to check for, create and drop the staging database
// and to validate the generated SQL statements.
SpannerAccessor *spanneraccessor.SpannerAccessorImpl
}

// NewExpressionVerificationAccessorImpl builds an ExpressionVerificationAccessorImpl
// whose Spanner accessor is bound to the staging database "smt-staging-db" in the
// given project and instance.
//
// It returns the error from the underlying Spanner accessor constructor unchanged
// when that construction fails.
func NewExpressionVerificationAccessorImpl(ctx context.Context, project string, instance string) (*ExpressionVerificationAccessorImpl, error) {
	stagingDbURI := fmt.Sprintf("projects/%s/instances/%s/databases/%s", project, instance, "smt-staging-db")
	accessor, err := spanneraccessor.NewSpannerAccessorClientImplWithSpannerClient(ctx, stagingDbURI)
	if err != nil {
		return nil, err
	}
	impl := &ExpressionVerificationAccessorImpl{SpannerAccessor: accessor}
	return impl, nil
}

// ExpressionVerificationInput is the per-expression work item handed to the
// parallel task runner. It is internal to the API implementation and should not
// leak out of the expressions_api package (member fields are not exported).
type ExpressionVerificationInput struct {
// spannerClient is the client of the accessor that issued the verification.
spannerClient spannerclient.SpannerClient
// expressionDetail describes the single expression to verify.
expressionDetail internal.ExpressionDetail
}

// VerifyExpressions verifies a batch of expressions against a staging database.
//
// The call proceeds as follows:
//  1. The request is validated; an invalid request is reported via Err.
//  2. Any pre-existing staging database is dropped.
//  3. A staging database is created from a copy of the input Conv with
//     user-configured expressions removed (see removeExpressions).
//  4. Every entry of ExpressionDetailList is verified through the parallel task
//     runner using verifyExpressionInternal.
//  5. The per-expression outputs are collected; if any of them carries an error,
//     the overall Err summarises how many did.
//
// The staging database is dropped again when the function returns.
func (ev *ExpressionVerificationAccessorImpl) VerifyExpressions(ctx context.Context, verifyExpressionsInput internal.VerifyExpressionsInput) internal.VerifyExpressionsOutput {
	err := ev.validateRequest(verifyExpressionsInput)
	if err != nil {
		return internal.VerifyExpressionsOutput{Err: err}
	}
	dbURI := ev.SpannerAccessor.SpannerClient.DatabaseName()
	dbExists, err := ev.SpannerAccessor.CheckExistingDb(ctx, dbURI)
	if err != nil {
		return internal.VerifyExpressionsOutput{Err: err}
	}
	if dbExists {
		// Start from a clean slate: a leftover staging database may hold a stale schema.
		if err := ev.SpannerAccessor.DropDatabase(ctx, dbURI); err != nil {
			return internal.VerifyExpressionsOutput{Err: err}
		}
	}
	verifyExpressionsInput.Conv, err = ev.removeExpressions(verifyExpressionsInput.Conv)
	if err != nil {
		return internal.VerifyExpressionsOutput{Err: err}
	}
	err = ev.SpannerAccessor.CreateDatabase(ctx, dbURI, verifyExpressionsInput.Conv, verifyExpressionsInput.Source, constants.DATAFLOW_MIGRATION)
	if err != nil {
		return internal.VerifyExpressionsOutput{Err: err}
	}
	// Drop the staging database after verifications are completed. This cleanup is
	// deliberately best-effort: a failure to drop must not mask the verification
	// results, and the next call drops any leftover database before recreating it.
	defer ev.SpannerAccessor.DropDatabase(ctx, dbURI)

	verificationInputList := make([]ExpressionVerificationInput, len(verifyExpressionsInput.ExpressionDetailList))
	for i, expressionDetail := range verifyExpressionsInput.ExpressionDetailList {
		verificationInputList[i] = ExpressionVerificationInput{
			spannerClient:    ev.SpannerAccessor.SpannerClient,
			expressionDetail: expressionDetail,
		}
	}
	r := task.RunParallelTasksImpl[ExpressionVerificationInput, internal.ExpressionVerificationOutput]{}
	expressionVerificationOutputList, err := r.RunParallelTasks(verificationInputList, THREAD_POOL, ev.verifyExpressionInternal, true)
	if err != nil {
		// A failure of the task runner itself means the results cannot be trusted;
		// surface it instead of reporting an apparently clean verification.
		return internal.VerifyExpressionsOutput{Err: err}
	}

	var verifyExpressionsOutput internal.VerifyExpressionsOutput
	// Plain int: a fixed 16-bit counter could wrap around on very large batches
	// and hide the aggregate error.
	errorCount := 0
	for _, expressionVerificationOutput := range expressionVerificationOutputList {
		verifyExpressionsOutput.ExpressionVerificationOutputList = append(verifyExpressionsOutput.ExpressionVerificationOutputList, expressionVerificationOutput.Result)
		if expressionVerificationOutput.Result.Err != nil {
			errorCount++
		}
	}
	if errorCount != 0 {
		verifyExpressionsOutput.Err = fmt.Errorf("%d expressions either failed verification or did not get verified. Please look at the individual errors returned for each expression", errorCount)
	}
	return verifyExpressionsOutput
}

// verifyExpressionInternal verifies a single expression and is the task function
// executed by the parallel task runner.
//
// A SQL statement is built from the expression depending on its Type:
//   - CHECK:   "SELECT 1 from <ReferenceElement.Name> where <Expression>;"
//   - DEFAULT: "SELECT CAST(<Expression> as <ReferenceElement.Name>)"
//
// and handed to SpannerAccessor.ValidateDML. Any other Type yields a failed
// output with an "invalid expression type requested" error.
//
// The returned TaskResult never carries a task-level Err; verification failures
// are reported inside the ExpressionVerificationOutput, which always includes the
// ExpressionDetail it refers to so callers can match outputs back to inputs.
// The mutex argument is required by the task runner signature and is not used.
func (ev *ExpressionVerificationAccessorImpl) verifyExpressionInternal(expressionVerificationInput ExpressionVerificationInput, mutex *sync.Mutex) task.TaskResult[internal.ExpressionVerificationOutput] {
	detail := expressionVerificationInput.expressionDetail
	var sqlStatement string
	switch detail.Type {
	case constants.CHECK_EXPRESSION:
		sqlStatement = fmt.Sprintf("SELECT 1 from %s where %s;", detail.ReferenceElement.Name, detail.Expression)
	case constants.DEFAUT_EXPRESSION:
		sqlStatement = fmt.Sprintf("SELECT CAST(%s as %s)", detail.Expression, detail.ReferenceElement.Name)
	default:
		// Include the expression detail so the caller can tell which expression was rejected.
		return task.TaskResult[internal.ExpressionVerificationOutput]{
			Result: internal.ExpressionVerificationOutput{
				Result:           false,
				Err:              fmt.Errorf("invalid expression type requested"),
				ExpressionDetail: detail,
			},
			Err: nil,
		}
	}
	// The task runner signature carries no context, so a background context is used here.
	result, err := ev.SpannerAccessor.ValidateDML(context.Background(), sqlStatement)
	return task.TaskResult[internal.ExpressionVerificationOutput]{
		Result: internal.ExpressionVerificationOutput{
			Result:           result,
			Err:              err,
			ExpressionDetail: detail,
		},
		Err: nil,
	}
}

// validateRequest checks that the mandatory fields of a verification request are set.
//
// It returns an error when Conv is nil or Source is empty, or when any entry of
// ExpressionDetailList has an empty ExpressionId, Expression, Type or
// ReferenceElement.Name. The first offending entry is reported. It returns nil
// when the request is well-formed.
func (ev *ExpressionVerificationAccessorImpl) validateRequest(verifyExpressionsInput internal.VerifyExpressionsInput) error {
	switch {
	case verifyExpressionsInput.Conv == nil, verifyExpressionsInput.Source == "":
		return fmt.Errorf("one of conv or source is empty. These are mandatory fields = %v", verifyExpressionsInput)
	}
	details := verifyExpressionsInput.ExpressionDetailList
	for i := range details {
		d := details[i]
		incomplete := d.ExpressionId == "" ||
			d.Expression == "" ||
			d.Type == "" ||
			d.ReferenceElement.Name == ""
		if incomplete {
			return fmt.Errorf("one of expressionId, expression, type or referenceElement.Name is empty. These are mandatory fields = %v", d)
		}
	}
	return nil
}

// removeExpressions returns a deep copy of inputConv with user-configured
// expressions stripped, leaving inputConv itself untouched.
//
// The staging database only needs table and column definitions; dropping the
// expressions ensures its creation cannot fail because of inconsistent
// expressions configured during a schema conversion session. The copy is made
// through a JSON round trip. On the copy, SpSequences is set to nil and the
// AutoGen setting of every column in SpSchema is reset to its zero value.
//
// TODO: Implement similar stripping for DEFAULT and CHECK constraints as well.
func (ev *ExpressionVerificationAccessorImpl) removeExpressions(inputConv *internal.Conv) (*internal.Conv, error) {
	serialized, err := json.Marshal(inputConv)
	if err != nil {
		return nil, fmt.Errorf("error marshaling conv: %v", err)
	}
	stripped := &internal.Conv{}
	if err = json.Unmarshal(serialized, stripped); err != nil {
		return nil, fmt.Errorf("error unmarshaling conv: %v", err)
	}
	// Sequences are not needed in the staging database.
	stripped.SpSequences = nil
	for _, tbl := range stripped.SpSchema {
		for name := range tbl.ColDefs {
			col := tbl.ColDefs[name]
			col.AutoGen = ddl.AutoGenCol{}
			tbl.ColDefs[name] = col
		}
	}
	return stripped, nil
}
Loading

0 comments on commit 7a76a1b

Please sign in to comment.