diff --git a/internal/connectors/db/connector.go b/internal/connectors/db/connector.go index 5b67cf9c..f0b1bec7 100644 --- a/internal/connectors/db/connector.go +++ b/internal/connectors/db/connector.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "io" "time" "ydbcp/internal/config" "ydbcp/internal/connectors/db/yql/queries" @@ -13,6 +14,7 @@ import ( "github.com/ydb-platform/ydb-go-sdk/v3" "github.com/ydb-platform/ydb-go-sdk/v3/balancers" + "github.com/ydb-platform/ydb-go-sdk/v3/query" "github.com/ydb-platform/ydb-go-sdk/v3/table" "github.com/ydb-platform/ydb-go-sdk/v3/table/result" table_types "github.com/ydb-platform/ydb-go-sdk/v3/table/types" @@ -21,22 +23,22 @@ import ( ) var ( - readTx = table.TxControl( - table.BeginTx( - table.WithOnlineReadOnly(), + readTx = query.TxControl( + query.BeginTx( + query.WithOnlineReadOnly(), ), - table.CommitTx(), + query.CommitTx(), ) - writeTx = table.TxControl( - table.BeginTx( - table.WithSerializableReadWrite(), + writeTx = query.TxControl( + query.BeginTx( + query.WithSerializableReadWrite(), ), - table.CommitTx(), + query.CommitTx(), ) ) type DBConnector interface { - GetTableClient() table.Client + GetQueryClient() query.Client SelectBackups(ctx context.Context, queryBuilder queries.ReadTableQuery) ( []*types.Backup, error, ) @@ -127,8 +129,8 @@ func NewYdbConnector(ctx context.Context, config config.YDBConnectionConfig) (*Y return &YdbConnector{driver: driver}, nil } -func (d *YdbConnector) GetTableClient() table.Client { - return d.driver.Table() +func (d *YdbConnector) GetQueryClient() query.Client { + return d.driver.Query() } func (d *YdbConnector) Close(ctx context.Context) { @@ -151,31 +153,54 @@ func DoStructSelect[T any]( if err != nil { return nil, err } - err = d.GetTableClient().Do( - ctx, func(ctx context.Context, s table.Session) error { + err = d.GetQueryClient().Do( + ctx, func(ctx context.Context, s query.Session) error { var ( - res result.Result + res query.Result ) _, res, err = s.Execute( ctx, - readTx, - queryFormat.QueryText, queryFormat.QueryParams, + queryFormat.QueryText, + query.WithParameters(queryFormat.QueryParams), + query.WithTxControl(readTx), ) if err != nil { return err } - defer func(res result.Result) { - err = res.Close() + defer func(res query.Result) { + err = res.Close(ctx) if err != nil { xlog.Error(ctx, "Error closing transaction result") } }(res) // result must be closed - if res.ResultSetCount() != 1 { - return errors.New("expected 1 result set") - } - for res.NextResultSet(ctx) { - for res.NextRow() { - entity, readErr := readLambda(res) + + resultSetCount := 0 + for { + resultSet, err := res.NextResultSet(ctx) + if err != nil { + if errors.Is(err, io.EOF) { + break + } + + return err + } + + resultSetCount++ + if resultSetCount > 1 { + return errors.New("expected 1 result set") + } + + for { + row, err := resultSet.NextRow(ctx) + if err != nil { + if errors.Is(err, io.EOF) { + break + } + + return err + } + + entity, readErr := readLambda(row) if readErr != nil { return readErr } @@ -205,31 +230,54 @@ func DoInterfaceSelect[T any]( if err != nil { return nil, err } - err = d.GetTableClient().Do( - ctx, func(ctx context.Context, s table.Session) error { + err = d.GetQueryClient().Do( + ctx, func(ctx context.Context, s query.Session) error { var ( - res result.Result + res query.Result ) _, res, err = s.Execute( ctx, - readTx, - queryFormat.QueryText, queryFormat.QueryParams, + queryFormat.QueryText, + query.WithParameters(queryFormat.QueryParams), + query.WithTxControl(readTx), ) if err != nil { return err } - defer func(res result.Result) { - err = res.Close() + defer func(res query.Result) { + err = res.Close(ctx) if err != nil { xlog.Error(ctx, "Error closing transaction result") } }(res) // result must be closed - if res.ResultSetCount() != 1 { - return errors.New("expected 1 result set") - } - for res.NextResultSet(ctx) { - for res.NextRow() { - entity, readErr := readLambda(res) + + resultSetCount := 0 + for { + resultSet, err := res.NextResultSet(ctx) + if err != nil { + if errors.Is(err, io.EOF) { + break + } + + return err + } + + resultSetCount++ + if resultSetCount > 1 { + return errors.New("expected 1 result set") + } + + for { + row, err := resultSet.NextRow(ctx) + if err != nil { + if errors.Is(err, io.EOF) { + break + } + + return err + } + + entity, readErr := readLambda(row) if readErr != nil { return readErr } @@ -251,13 +299,13 @@ func (d *YdbConnector) ExecuteUpsert(ctx context.Context, queryBuilder queries.W if err != nil { return err } - err = d.GetTableClient().Do( - ctx, func(ctx context.Context, s table.Session) (err error) { + err = d.GetQueryClient().Do( + ctx, func(ctx context.Context, s query.Session) (err error) { _, _, err = s.Execute( ctx, - writeTx, queryFormat.QueryText, - queryFormat.QueryParams, + query.WithParameters(queryFormat.QueryParams), + query.WithTxControl(writeTx), ) if err != nil { return err @@ -306,7 +354,7 @@ func (d *YdbConnector) SelectBackupSchedules( ctx, d, queryBuilder, - func(res result.Result) (*types.BackupSchedule, error) { + func(res query.Row) (*types.BackupSchedule, error) { return ReadBackupScheduleFromResultSet(res, false) }, ) @@ -319,7 +367,7 @@ func (d *YdbConnector) SelectBackupSchedulesWithRPOInfo( ctx, d, queryBuilder, - func(res result.Result) (*types.BackupSchedule, error) { + func(res query.Row) (*types.BackupSchedule, error) { return ReadBackupScheduleFromResultSet(res, true) }, ) diff --git a/internal/connectors/db/mock.go b/internal/connectors/db/mock.go index 43ebc895..1100ac0e 100644 --- a/internal/connectors/db/mock.go +++ b/internal/connectors/db/mock.go @@ -9,7 +9,7 @@ import ( "ydbcp/internal/connectors/db/yql/queries" "ydbcp/internal/types" - "github.com/ydb-platform/ydb-go-sdk/v3/table" + "github.com/ydb-platform/ydb-go-sdk/v3/query" ) type MockDBConnector struct { @@ -113,7 +113,7 @@ func (c *MockDBConnector) UpdateBackup( } func (c *MockDBConnector) Close(_ context.Context) {} -func (c *MockDBConnector) GetTableClient() table.Client { +func (c *MockDBConnector) GetQueryClient() query.Client { return nil } diff --git a/internal/connectors/db/process_result_set.go b/internal/connectors/db/process_result_set.go index 88f02bd2..07f3a514 100644 --- a/internal/connectors/db/process_result_set.go +++ b/internal/connectors/db/process_result_set.go @@ -7,15 +7,14 @@ import ( "ydbcp/internal/types" pb "ydbcp/pkg/proto/ydbcp/v1alpha1" - "github.com/ydb-platform/ydb-go-sdk/v3/table/result" - "github.com/ydb-platform/ydb-go-sdk/v3/table/result/named" + "github.com/ydb-platform/ydb-go-sdk/v3/query" "google.golang.org/protobuf/types/known/durationpb" "google.golang.org/protobuf/types/known/timestamppb" ) -type StructFromResultSet[T any] func(result result.Result) (*T, error) +type StructFromResultSet[T any] func(result query.Row) (*T, error) -type InterfaceFromResultSet[T any] func(result result.Result) (T, error) +type InterfaceFromResultSet[T any] func(result query.Row) (T, error) func StringOrDefault(str *string, def string) string { if str == nil { @@ -57,23 +56,35 @@ func auditFromDb(initiated *string, createdAt *time.Time, completedAt *time.Time } } +func checkRequiredFields(fields map[string]*string) error { + for fieldName, fieldValue := range fields { + if fieldValue == nil { + return fmt.Errorf("failed to read required field %s", fieldName) + } + } + return nil +} + //TODO: unit test this -func ReadBackupFromResultSet(res result.Result) (*types.Backup, error) { +func ReadBackupFromResultSet(res query.Row) (*types.Backup, error) { var ( - backupId string - containerId string - databaseName string - databaseEndpoint string - s3endpoint *string - s3region *string - s3bucket *string - s3pathprefix *string - status *string - message *string - size *int64 - scheduleId *string - sourcePaths *string + // required fields + backupId *string + containerId *string + databaseName *string + databaseEndpoint *string + + // optional fields + s3endpoint *string + s3region *string + s3bucket *string + s3pathPrefix *string + status *string + message *string + size *int64 + scheduleId *string + sourcePaths *string creator *string completedAt *time.Time @@ -81,30 +92,41 @@ func ReadBackupFromResultSet(res result.Result) (*types.Backup, error) { expireAt *time.Time ) + requiredFields := map[string]*string{ + "backup_id": backupId, + "container_id": containerId, + "database_name": databaseName, + "database_endpoint": databaseEndpoint, + } + err := res.ScanNamed( - named.Required("id", &backupId), - named.Required("container_id", &containerId), - named.Required("database", &databaseName), - named.Required("endpoint", &databaseEndpoint), - named.Optional("s3_endpoint", &s3endpoint), - named.Optional("s3_region", &s3region), - named.Optional("s3_bucket", &s3bucket), - named.Optional("s3_path_prefix", &s3pathprefix), - named.Optional("status", &status), - named.Optional("message", &message), - named.Optional("size", &size), - named.Optional("schedule_id", &scheduleId), - named.Optional("expire_at", &expireAt), - named.Optional("paths", &sourcePaths), - - named.Optional("created_at", &createdAt), - named.Optional("completed_at", &completedAt), - named.Optional("initiated", &creator), + query.Named("id", &backupId), + query.Named("container_id", &containerId), + query.Named("database", &databaseName), + query.Named("endpoint", &databaseEndpoint), + query.Named("s3_endpoint", &s3endpoint), + query.Named("s3_region", &s3region), + query.Named("s3_bucket", &s3bucket), + query.Named("s3_path_prefix", &s3pathPrefix), + query.Named("status", &status), + query.Named("message", &message), + query.Named("size", &size), + query.Named("schedule_id", &scheduleId), + query.Named("expire_at", &expireAt), + query.Named("paths", &sourcePaths), + + query.Named("created_at", &createdAt), + query.Named("completed_at", &completedAt), + query.Named("initiated", &creator), ) if err != nil { return nil, err } + if err = checkRequiredFields(requiredFields); err != nil { + return nil, err + } + sourcePathsSlice := make([]string, 0) if sourcePaths != nil { sourcePathsSlice, err = types.ParseSourcePaths(*sourcePaths) @@ -114,14 +136,14 @@ func ReadBackupFromResultSet(res result.Result) (*types.Backup, error) { } return &types.Backup{ - ID: backupId, - ContainerID: containerId, - DatabaseName: databaseName, - DatabaseEndpoint: databaseEndpoint, + ID: *backupId, + ContainerID: *containerId, + DatabaseName: *databaseName, + DatabaseEndpoint: *databaseEndpoint, S3Endpoint: StringOrEmpty(s3endpoint), S3Region: StringOrEmpty(s3region), S3Bucket: StringOrEmpty(s3bucket), - S3PathPrefix: StringOrEmpty(s3pathprefix), + S3PathPrefix: StringOrEmpty(s3pathPrefix), Status: StringOrDefault(status, types.BackupStateUnknown), Message: StringOrEmpty(message), AuditInfo: auditFromDb(creator, createdAt, completedAt), @@ -132,14 +154,16 @@ func ReadBackupFromResultSet(res result.Result) (*types.Backup, error) { }, nil } -func ReadOperationFromResultSet(res result.Result) (types.Operation, error) { +func ReadOperationFromResultSet(res query.Row) (types.Operation, error) { var ( - operationId string - containerId string - operationType string - databaseName string - databaseEndpoint string - + // required fields + operationId *string + containerId *string + operationType *string + databaseName *string + databaseEndpoint *string + + // optional fields backupId *string ydbOperationId *string operationStateBuf *string @@ -158,34 +182,48 @@ func ReadOperationFromResultSet(res result.Result) (types.Operation, error) { retriesCount *uint32 maxBackoff *time.Duration ) + + requiredFields := map[string]*string{ + "operation_id": operationId, + "container_id": containerId, + "operation_type": operationType, + "database_name": databaseName, + "database_endpoint": databaseEndpoint, + } + err := res.ScanNamed( - named.Required("id", &operationId), - named.Required("container_id", &containerId), - named.Required("type", &operationType), - named.Required("database", &databaseName), - named.Required("endpoint", &databaseEndpoint), - - named.Optional("backup_id", &backupId), - named.Optional("operation_id", &ydbOperationId), - named.Optional("status", &operationStateBuf), - named.Optional("message", &message), - named.Optional("paths", &sourcePaths), - named.Optional("paths_to_exclude", &sourcePathsToExclude), - - named.Optional("created_at", &createdAt), - named.Optional("completed_at", &completedAt), - named.Optional("initiated", &creator), - named.Optional("updated_at", &updatedAt), - named.Optional("parent_operation_id", &parentOperationID), - named.Optional("schedule_id", &scheduleID), - named.Optional("ttl", &ttl), - named.Optional("retries", &retries), - named.Optional("retries_count", &retriesCount), - named.Optional("retries_max_backoff", &maxBackoff), + query.Named("id", &operationId), + query.Named("container_id", &containerId), + query.Named("type", &operationType), + query.Named("database", &databaseName), + query.Named("endpoint", &databaseEndpoint), + + query.Named("backup_id", &backupId), + query.Named("operation_id", &ydbOperationId), + query.Named("status", &operationStateBuf), + query.Named("message", &message), + query.Named("paths", &sourcePaths), + query.Named("paths_to_exclude", &sourcePathsToExclude), + + query.Named("created_at", &createdAt), + query.Named("completed_at", &completedAt), + query.Named("initiated", &creator), + query.Named("updated_at", &updatedAt), + query.Named("parent_operation_id", &parentOperationID), + query.Named("schedule_id", &scheduleID), + query.Named("ttl", &ttl), + query.Named("retries", &retries), + query.Named("retries_count", &retriesCount), + query.Named("retries_max_backoff", &maxBackoff), ) if err != nil { return nil, err } + + if err = checkRequiredFields(requiredFields); err != nil { + return nil, err + } + operationState := types.OperationStateUnknown if operationStateBuf != nil { operationState = types.OperationState(*operationStateBuf) @@ -210,19 +248,19 @@ func ReadOperationFromResultSet(res result.Result) (types.Operation, error) { updatedTs = timestamppb.New(*updatedAt) } - if operationType == string(types.OperationTypeTB) { + if *operationType == string(types.OperationTypeTB) { if backupId == nil { return nil, fmt.Errorf("failed to read backup_id for TB operation: %s", operationId) } return &types.TakeBackupOperation{ - ID: operationId, + ID: *operationId, BackupID: *backupId, - ContainerID: containerId, + ContainerID: *containerId, State: operationState, Message: StringOrEmpty(message), YdbConnectionParams: types.YdbConnectionParams{ - Endpoint: databaseEndpoint, - DatabaseName: databaseName, + Endpoint: *databaseEndpoint, + DatabaseName: *databaseName, }, SourcePaths: sourcePathsSlice, SourcePathsToExclude: sourcePathsToExcludeSlice, @@ -231,26 +269,26 @@ func ReadOperationFromResultSet(res result.Result) (types.Operation, error) { UpdatedAt: updatedTs, ParentOperationID: parentOperationID, }, nil - } else if operationType == string(types.OperationTypeRB) { + } else if *operationType == string(types.OperationTypeRB) { if backupId == nil { return nil, fmt.Errorf("failed to read backup_id for RB operation: %s", operationId) } return &types.RestoreBackupOperation{ - ID: operationId, + ID: *operationId, BackupId: *backupId, - ContainerID: containerId, + ContainerID: *containerId, State: operationState, Message: StringOrEmpty(message), YdbConnectionParams: types.YdbConnectionParams{ - Endpoint: databaseEndpoint, - DatabaseName: databaseName, + Endpoint: *databaseEndpoint, + DatabaseName: *databaseName, }, YdbOperationId: StringOrEmpty(ydbOperationId), SourcePaths: sourcePathsSlice, Audit: auditFromDb(creator, createdAt, completedAt), UpdatedAt: updatedTs, }, nil - } else if operationType == string(types.OperationTypeDB) { + } else if *operationType == string(types.OperationTypeDB) { if backupId == nil { return nil, fmt.Errorf("failed to read backup_id for DB operation: %s", operationId) } @@ -261,12 +299,12 @@ func ReadOperationFromResultSet(res result.Result) (types.Operation, error) { } return &types.DeleteBackupOperation{ - ID: operationId, + ID: *operationId, BackupID: *backupId, - ContainerID: containerId, + ContainerID: *containerId, YdbConnectionParams: types.YdbConnectionParams{ - Endpoint: databaseEndpoint, - DatabaseName: databaseName, + Endpoint: *databaseEndpoint, + DatabaseName: *databaseName, }, State: operationState, Message: StringOrEmpty(message), @@ -274,7 +312,7 @@ func ReadOperationFromResultSet(res result.Result) (types.Operation, error) { PathPrefix: pathPrefix, UpdatedAt: updatedTs, }, nil - } else if operationType == string(types.OperationTypeTBWR) { + } else if *operationType == string(types.OperationTypeTBWR) { var retryConfig *pb.RetryConfig = nil if maxBackoff != nil { retryConfig = &pb.RetryConfig{ @@ -294,13 +332,13 @@ func ReadOperationFromResultSet(res result.Result) (types.Operation, error) { } return &types.TakeBackupWithRetryOperation{ TakeBackupOperation: types.TakeBackupOperation{ - ID: operationId, - ContainerID: containerId, + ID: *operationId, + ContainerID: *containerId, State: operationState, Message: StringOrEmpty(message), YdbConnectionParams: types.YdbConnectionParams{ - Endpoint: databaseEndpoint, - DatabaseName: databaseName, + Endpoint: *databaseEndpoint, + DatabaseName: *databaseName, }, SourcePaths: sourcePathsSlice, SourcePathsToExclude: sourcePathsToExcludeSlice, @@ -314,18 +352,20 @@ func ReadOperationFromResultSet(res result.Result) (types.Operation, error) { }, nil } - return &types.GenericOperation{ID: operationId}, nil + return &types.GenericOperation{ID: *operationId}, nil } -func ReadBackupScheduleFromResultSet(res result.Result, withRPOInfo bool) (*types.BackupSchedule, error) { +func ReadBackupScheduleFromResultSet(res query.Row, withRPOInfo bool) (*types.BackupSchedule, error) { var ( - ID string - containerID string - databaseName string - databaseEndpoint string + // required fields + ID *string + containerID *string + databaseName *string + databaseEndpoint *string - crontab string + crontab *string + // optional fields status *string initiated *string createdAt *time.Time @@ -340,27 +380,35 @@ func ReadBackupScheduleFromResultSet(res result.Result, withRPOInfo bool) (*type nextLaunch *time.Time ) - namedValues := []named.Value{ - named.Required("id", &ID), - named.Required("container_id", &containerID), - named.Required("database", &databaseName), - named.Required("endpoint", &databaseEndpoint), - named.Required("crontab", &crontab), - - named.Optional("status", &status), - named.Optional("initiated", &initiated), - named.Optional("created_at", &createdAt), - named.Optional("name", &name), - named.Optional("ttl", &ttl), - named.Optional("paths", &sourcePaths), - named.Optional("paths_to_exclude", &sourcePathsToExclude), - named.Optional("recovery_point_objective", &recoveryPointObjective), - named.Optional("next_launch", &nextLaunch), + requiredFields := map[string]*string{ + "id": ID, + "container_id": containerID, + "database_name": databaseName, + "database_endpoint": databaseEndpoint, + "crontab": crontab, + } + + namedValues := []query.NamedDestination{ + query.Named("id", &ID), + query.Named("container_id", &containerID), + query.Named("database", &databaseName), + query.Named("endpoint", &databaseEndpoint), + query.Named("crontab", &crontab), + + query.Named("status", &status), + query.Named("initiated", &initiated), + query.Named("created_at", &createdAt), + query.Named("name", &name), + query.Named("ttl", &ttl), + query.Named("paths", &sourcePaths), + query.Named("paths_to_exclude", &sourcePathsToExclude), + query.Named("recovery_point_objective", &recoveryPointObjective), + query.Named("next_launch", &nextLaunch), } if withRPOInfo { - namedValues = append(namedValues, named.Optional("last_backup_id", &lastBackupID)) - namedValues = append(namedValues, named.Optional("last_successful_backup_id", &lastSuccessfulBackupID)) - namedValues = append(namedValues, named.Optional("recovery_point", &recoveryPoint)) + namedValues = append(namedValues, query.Named("last_backup_id", &lastBackupID)) + namedValues = append(namedValues, query.Named("last_successful_backup_id", &lastSuccessfulBackupID)) + namedValues = append(namedValues, query.Named("recovery_point", &recoveryPoint)) } err := res.ScanNamed(namedValues...) @@ -369,6 +417,10 @@ func ReadBackupScheduleFromResultSet(res result.Result, withRPOInfo bool) (*type return nil, err } + if err = checkRequiredFields(requiredFields); err != nil { + return nil, err + } + var sourcePathsSlice []string var sourcePathsToExcludeSlice []string if sourcePaths != nil { @@ -397,17 +449,17 @@ func ReadBackupScheduleFromResultSet(res result.Result, withRPOInfo bool) (*type } return &types.BackupSchedule{ - ID: ID, - ContainerID: containerID, - DatabaseName: databaseName, - DatabaseEndpoint: databaseEndpoint, + ID: *ID, + ContainerID: *containerID, + DatabaseName: *databaseName, + DatabaseEndpoint: *databaseEndpoint, SourcePaths: sourcePathsSlice, SourcePathsToExclude: sourcePathsToExcludeSlice, Audit: auditFromDb(initiated, createdAt, nil), Name: name, Status: StringOrDefault(status, types.BackupScheduleStateUnknown), ScheduleSettings: &pb.BackupScheduleSettings{ - SchedulePattern: &pb.BackupSchedulePattern{Crontab: crontab}, + SchedulePattern: &pb.BackupSchedulePattern{Crontab: *crontab}, Ttl: ttlDuration, RecoveryPointObjective: rpoDuration, },