diff --git a/datacatalog/pkg/repositories/handle.go b/datacatalog/pkg/repositories/handle.go index e552fb9d01..e9bf395765 100644 --- a/datacatalog/pkg/repositories/handle.go +++ b/datacatalog/pkg/repositories/handle.go @@ -61,7 +61,7 @@ func (h *DBHandle) CreateDB(dbName string) error { result = h.db.Exec(createDBStatement) if result.Error != nil { - if !isPgErrorWithCode(result.Error, pqDbAlreadyExistsCode) { + if !database.IsPgErrorWithCode(result.Error, database.PqDbAlreadyExistsCode) { return result.Error } logger.Infof(context.TODO(), "Not creating database %s, already exists", dbName) diff --git a/datacatalog/pkg/repositories/initialize.go b/datacatalog/pkg/repositories/initialize.go index e020f5dc04..150801345f 100644 --- a/datacatalog/pkg/repositories/initialize.go +++ b/datacatalog/pkg/repositories/initialize.go @@ -2,11 +2,9 @@ package repositories import ( "context" - "errors" + "github.com/flyteorg/flyte/flytestdlib/database" "reflect" - "github.com/jackc/pgconn" - errors2 "github.com/flyteorg/flyte/datacatalog/pkg/repositories/errors" "github.com/flyteorg/flyte/datacatalog/pkg/runtime" "github.com/flyteorg/flyte/flytestdlib/logger" @@ -18,8 +16,6 @@ var migrateScope = migrationsScope.NewSubScope("migrate") // all postgres servers come by default with a db name named postgres const defaultDB = "postgres" -const pqInvalidDBCode = "3D000" -const pqDbAlreadyExistsCode = "42P04" // Migrate This command will run all the migrations for the database func Migrate(ctx context.Context) error { @@ -31,14 +27,14 @@ func Migrate(ctx context.Context) error { if err != nil { // if db does not exist, try creating it - cErr, ok := err.(errors2.ConnectError) + _, ok := err.(errors2.ConnectError) if !ok { logger.Errorf(ctx, "Failed to cast error of type: %v, err: %v", reflect.TypeOf(err), err) panic(err) } - pqError := cErr.Unwrap().(*pgconn.PgError) - if pqError.Code == pqInvalidDBCode { + + if database.IsPgErrorWithCode(err, database.PqInvalidDBCode) { logger.Warningf(ctx, "Database [%v] does not exist, trying to create it now", dbName) dbConfigValues.Postgres.DbName = defaultDB @@ -78,14 +74,3 @@ func Migrate(ctx context.Context) error { logger.Infof(ctx, "Ran DB migration successfully.") return nil } - -func isPgErrorWithCode(err error, code string) bool { - pgErr := &pgconn.PgError{} - if !errors.As(err, &pgErr) { - // err chain does not contain a pgconn.PgError - return false - } - - // pgconn.PgError found in chain and set to code specified - return pgErr.Code == code -} diff --git a/flyteadmin/pkg/repositories/errors/postgres.go b/flyteadmin/pkg/repositories/errors/postgres.go index 7be734aa13..3993e140c1 100644 --- a/flyteadmin/pkg/repositories/errors/postgres.go +++ b/flyteadmin/pkg/repositories/errors/postgres.go @@ -15,6 +15,7 @@ import ( "reflect" "github.com/jackc/pgconn" + pgxPgconn "github.com/jackc/pgx/v5/pgconn" "github.com/prometheus/client_golang/prometheus" "google.golang.org/grpc/codes" "gorm.io/gorm" @@ -68,23 +69,29 @@ func (p *postgresErrorTransformer) ToFlyteAdminError(err error) flyteAdminErrors err = unwrappedErr } - pqError, ok := err.(*pgconn.PgError) - if !ok { - logger.Debugf(context.Background(), "Unable to cast to pgconn.PgError. Error type: [%v]", - reflect.TypeOf(err)) - return p.fromGormError(err) + // Try things both ways, the two pgconns are different types. + if pqError, ok := err.(*pgconn.PgError); ok { + return p.flyteAdminErrorFromCode(pqError.Code, pqError.Message) + } else if pqError, ok := err.(*pgxPgconn.PgError); ok { + return p.flyteAdminErrorFromCode(pqError.Code, pqError.Message) } - switch pqError.Code { + logger.Debugf(context.Background(), "Unable to cast to pgconn.PgError. Error type: [%v]", + reflect.TypeOf(err)) + return p.fromGormError(err) +} + +func (p *postgresErrorTransformer) flyteAdminErrorFromCode(pqErrorCode string, message string) flyteAdminErrors.FlyteAdminError { + switch pqErrorCode { case uniqueConstraintViolationCode: p.metrics.AlreadyExistsError.Inc() - return flyteAdminErrors.NewFlyteAdminErrorf(codes.AlreadyExists, uniqueConstraintViolation, pqError.Message) + return flyteAdminErrors.NewFlyteAdminErrorf(codes.AlreadyExists, uniqueConstraintViolation, message) case undefinedTable: p.metrics.UndefinedTable.Inc() - return flyteAdminErrors.NewFlyteAdminErrorf(codes.InvalidArgument, unsupportedTableOperation, pqError.Message) + return flyteAdminErrors.NewFlyteAdminErrorf(codes.InvalidArgument, unsupportedTableOperation, message) default: p.metrics.PostgresError.Inc() - return flyteAdminErrors.NewFlyteAdminError(codes.Unknown, fmt.Sprintf(defaultPgError, pqError.Message)) + return flyteAdminErrors.NewFlyteAdminError(codes.Unknown, fmt.Sprintf(defaultPgError, message)) } } diff --git a/flyteadmin/pkg/repositories/errors/postgres_test.go b/flyteadmin/pkg/repositories/errors/postgres_test.go index dcc988e3a7..5b73e31bec 100644 --- a/flyteadmin/pkg/repositories/errors/postgres_test.go +++ b/flyteadmin/pkg/repositories/errors/postgres_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/jackc/pgconn" + pgxPgconn "github.com/jackc/pgx/v5/pgconn" "github.com/magiconair/properties/assert" "google.golang.org/grpc/codes" @@ -27,6 +28,16 @@ func TestToFlyteAdminError_UniqueConstraintViolation(t *testing.T) { assert.Equal(t, codes.AlreadyExists, transformedErr.Code()) assert.Equal(t, "value with matching already exists (message)", transformedErr.Error()) + + err2 := &pgxPgconn.PgError{ + Code: "23505", + Message: "message", + } + transformedErr = NewPostgresErrorTransformer(mockScope.NewTestScope()).ToFlyteAdminError(err2) + assert.Equal(t, codes.AlreadyExists, transformedErr.Code()) + assert.Equal(t, "value with matching already exists (message)", + transformedErr.Error()) + } func TestToFlyteAdminError_UnrecognizedPostgresError(t *testing.T) { @@ -38,4 +49,13 @@ func TestToFlyteAdminError_UnrecognizedPostgresError(t *testing.T) { assert.Equal(t, codes.Unknown, transformedErr.Code()) assert.Equal(t, "failed database operation with message", transformedErr.Error()) + + err2 := &pgxPgconn.PgError{ + Code: "foo", + Message: "message", + } + transformedErr = NewPostgresErrorTransformer(mockScope.NewTestScope()).ToFlyteAdminError(err2) + assert.Equal(t, codes.Unknown, transformedErr.Code()) + assert.Equal(t, "failed database operation with message", + transformedErr.Error()) } diff --git a/flytestdlib/database/postgres.go b/flytestdlib/database/postgres.go index 1254dd5d14..be5c118e58 100644 --- a/flytestdlib/database/postgres.go +++ b/flytestdlib/database/postgres.go @@ -16,8 +16,8 @@ import ( "github.com/flyteorg/flyte/flytestdlib/logger" ) -const pqInvalidDBCode = "3D000" -const pqDbAlreadyExistsCode = "42P04" +const PqInvalidDBCode = "3D000" +const PqDbAlreadyExistsCode = "42P04" const PgDuplicatedForeignKey = "23503" const PgDuplicatedKey = "23505" const defaultDB = "postgres" @@ -61,7 +61,7 @@ func CreatePostgresDbIfNotExists(ctx context.Context, gormConfig *gorm.Config, p if err == nil { return gormDb, nil } - if !IsPgErrorWithCode(err, pqInvalidDBCode) { + if !IsPgErrorWithCode(err, PqInvalidDBCode) { return nil, err } logger.Warningf(ctx, "Database [%v] does not exist", pgConfig.DbName) @@ -84,7 +84,7 @@ func CreatePostgresDbIfNotExists(ctx context.Context, gormConfig *gorm.Config, p result := gormDb.Exec(createDBStatement) if result.Error != nil { - if !IsPgErrorWithCode(result.Error, pqDbAlreadyExistsCode) { + if !IsPgErrorWithCode(result.Error, PqDbAlreadyExistsCode) { return nil, result.Error } logger.Warningf(ctx, "Got DB already exists error for [%s], skipping...", pgConfig.DbName) diff --git a/flytestdlib/database/postgres_test.go b/flytestdlib/database/postgres_test.go index b84698291c..311b05c351 100644 --- a/flytestdlib/database/postgres_test.go +++ b/flytestdlib/database/postgres_test.go @@ -112,7 +112,7 @@ func TestIsInvalidDBPgError(t *testing.T) { tc := tc t.Run(tc.Name, func(t *testing.T) { - assert.Equal(t, tc.ExpectedResult, IsPgErrorWithCode(tc.Err, pqInvalidDBCode)) + assert.Equal(t, tc.ExpectedResult, IsPgErrorWithCode(tc.Err, PqInvalidDBCode)) }) } } @@ -150,7 +150,7 @@ func TestIsPgDbAlreadyExistsError(t *testing.T) { for _, tc := range testCases { tc := tc t.Run(tc.Name, func(t *testing.T) { - assert.Equal(t, tc.ExpectedResult, IsPgErrorWithCode(tc.Err, pqDbAlreadyExistsCode)) + assert.Equal(t, tc.ExpectedResult, IsPgErrorWithCode(tc.Err, PqDbAlreadyExistsCode)) }) } }