Skip to content

Commit

Permalink
Endpoint validation has been added (#43)
Browse files Browse the repository at this point in the history
  • Loading branch information
bma13 authored Aug 21, 2024
1 parent 592a797 commit 8e169b1
Show file tree
Hide file tree
Showing 6 changed files with 149 additions and 23 deletions.
4 changes: 4 additions & 0 deletions cmd/ydbcp/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ db_connection:
client_connection:
insecure: true
discovery: false
allowed_endpoint_domains:
- .allowed-domain.com
- allowed-hostname.domain.com
allow_insecure_endpoint: false

s3:
endpoint: s3.endpoint.com
Expand Down
9 changes: 8 additions & 1 deletion cmd/ydbcp/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,14 @@ func main() {
}
}()

backup.NewBackupService(dbConnector, clientConnector, configInstance.S3, authProvider).Register(server)
backup.NewBackupService(
dbConnector,
clientConnector,
configInstance.S3,
authProvider,
configInstance.ClientConnection.AllowedEndpointDomains,
configInstance.ClientConnection.AllowInsecureEndpoint,
).Register(server)
operation.NewOperationService(dbConnector, authProvider).Register(server)

if err := server.Start(ctx, &wg); err != nil {
Expand Down
7 changes: 2 additions & 5 deletions internal/auth/mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,13 +184,10 @@ func (p *MockAuthProvider) Authorize(
return results, subject, nil
}

func NewMockAuthProvider(ctx context.Context, options ...Option) (auth.AuthProvider, error) {
func NewMockAuthProvider(options ...Option) *MockAuthProvider {
p := &MockAuthProvider{}
for _, opt := range options {
opt(p)
}
if err := p.Init(ctx, ""); err != nil {
return nil, err
}
return p, nil
return p
}
26 changes: 21 additions & 5 deletions internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"os"
"regexp"
"strings"
"ydbcp/internal/util/xlog"

Expand All @@ -30,10 +31,12 @@ type YDBConnectionConfig struct {
}

type ClientConnectionConfig struct {
Insecure bool `yaml:"insecure"`
Discovery bool `yaml:"discovery" default:"true"`
DialTimeoutSeconds uint32 `yaml:"dial_timeout_seconds" default:"5"`
OAuth2KeyFile string `yaml:"oauth2_key_file"`
Insecure bool `yaml:"insecure"`
Discovery bool `yaml:"discovery" default:"true"`
DialTimeoutSeconds uint32 `yaml:"dial_timeout_seconds" default:"5"`
OAuth2KeyFile string `yaml:"oauth2_key_file"`
AllowedEndpointDomains []string `yaml:"allowed_endpoint_domains"`
AllowInsecureEndpoint bool `yaml:"allow_insecure_endpoint"`
}

type AuthConfig struct {
Expand All @@ -57,6 +60,10 @@ type Config struct {
GRPCServer GRPCServerConfig `yaml:"grpc_server"`
}

var (
validDomainFilter = regexp.MustCompile(`^[A-Za-z\.][A-Za-z0-9\-\.]+[A-Za-z]$`)
)

func (config Config) ToString() (string, error) {
data, err := yaml.Marshal(&config)
if err != nil {
Expand All @@ -82,11 +89,20 @@ func InitConfig(ctx context.Context, confPath string) (Config, error) {
zap.Error(err))
return Config{}, err
}
return config, nil
return config, config.Validate()
}
return Config{}, errors.New("configuration file path is empty")
}

func (c *Config) Validate() error {
for _, domain := range c.ClientConnection.AllowedEndpointDomains {
if !validDomainFilter.MatchString(domain) {
return fmt.Errorf("incorrect domain filter in allowed_endpoint_domains: %s", domain)
}
}
return nil
}

func readSecret(filename string) (string, error) {
rawSecret, err := os.ReadFile(filename)
if err != nil {
Expand Down
66 changes: 54 additions & 12 deletions internal/server/services/backup/backupservice.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package backup
import (
"context"
"path"
"regexp"
"strings"

table_types "github.com/ydb-platform/ydb-go-sdk/v3/table/types"
Expand All @@ -26,10 +27,39 @@ import (

type BackupService struct {
pb.UnimplementedBackupServiceServer
driver db.DBConnector
clientConn client.ClientConnector
s3 config.S3Config
auth ap.AuthProvider
driver db.DBConnector
clientConn client.ClientConnector
s3 config.S3Config
auth ap.AuthProvider
allowedEndpointDomains []string
allowInsecureEndpoint bool
}

var (
validEndpoint = regexp.MustCompile(`^(grpcs://|grpc://)?([A-Za-z0-9\-\.]+)(:[0-9]+)?$`)
)

func (s *BackupService) isAllowedEndpoint(e string) bool {
groups := validEndpoint.FindStringSubmatch(e)
if len(groups) < 3 {
return false
}
tls := groups[1] == "grpcs://"
if !tls && !s.allowInsecureEndpoint {
return false
}
fqdn := groups[2]

for _, domain := range s.allowedEndpointDomains {
if strings.HasPrefix(domain, ".") {
if strings.HasSuffix(fqdn, domain) {
return true
}
} else if fqdn == domain {
return true
}
}
return false
}

func (s *BackupService) GetBackup(ctx context.Context, request *pb.GetBackupRequest) (*pb.Backup, error) {
Expand Down Expand Up @@ -75,9 +105,13 @@ func (s *BackupService) MakeBackup(ctx context.Context, req *pb.MakeBackupReques
}
xlog.Debug(ctx, "MakeBackup", zap.String("subject", subject))

if !s.isAllowedEndpoint(req.DatabaseEndpoint) {
return nil, status.Errorf(codes.InvalidArgument, "endpoint of database is invalid or not allowed, endpoint %s", req.DatabaseEndpoint)
}

clientConnectionParams := types.YdbConnectionParams{
Endpoint: req.GetDatabaseEndpoint(),
DatabaseName: req.GetDatabaseName(),
Endpoint: req.DatabaseEndpoint,
DatabaseName: req.DatabaseName,
}
dsn := types.MakeYdbConnectionString(clientConnectionParams)
client, err := s.clientConn.Open(ctx, dsn)
Expand Down Expand Up @@ -186,9 +220,13 @@ func (s *BackupService) MakeRestore(ctx context.Context, req *pb.MakeRestoreRequ
}
xlog.Debug(ctx, "MakeRestore", zap.String("subject", subject))

if !s.isAllowedEndpoint(req.DatabaseEndpoint) {
return nil, status.Errorf(codes.InvalidArgument, "endpoint of database is invalid or not allowed, endpoint %s", req.DatabaseEndpoint)
}

clientConnectionParams := types.YdbConnectionParams{
Endpoint: req.GetDatabaseEndpoint(),
DatabaseName: req.GetDatabaseName(),
Endpoint: req.DatabaseEndpoint,
DatabaseName: req.DatabaseName,
}
dsn := types.MakeYdbConnectionString(clientConnectionParams)
client, err := s.clientConn.Open(ctx, dsn)
Expand Down Expand Up @@ -315,11 +353,15 @@ func NewBackupService(
clientConn client.ClientConnector,
s3 config.S3Config,
auth ap.AuthProvider,
allowedEndpointDomains []string,
allowInsecureEndpoint bool,
) *BackupService {
return &BackupService{
driver: driver,
clientConn: clientConn,
s3: s3,
auth: auth,
driver: driver,
clientConn: clientConn,
s3: s3,
auth: auth,
allowedEndpointDomains: allowedEndpointDomains,
allowInsecureEndpoint: allowInsecureEndpoint,
}
}
60 changes: 60 additions & 0 deletions internal/server/services/backup/backupservice_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package backup

import (
"testing"
"ydbcp/internal/auth"
"ydbcp/internal/config"
"ydbcp/internal/connectors/client"
"ydbcp/internal/connectors/db"

"github.com/stretchr/testify/assert"
)

func TestEndpointValidation(t *testing.T) {
dbConnector := db.NewMockDBConnector()
clientConnector := client.NewMockClientConnector()
auth := auth.NewMockAuthProvider()

s := NewBackupService(
dbConnector,
clientConnector,
config.S3Config{},
auth,
[]string{".valid.com", "hostname.good.com"},
true,
)

assert.True(t, s.isAllowedEndpoint("grpc://some-host.zone.valid.com"))
assert.False(t, s.isAllowedEndpoint("grpcs://host.zone.invalid.com"))
assert.True(t, s.isAllowedEndpoint("grpcs://hostname.good.com:1234"))
assert.True(t, s.isAllowedEndpoint("example.valid.com:1234"))
assert.False(t, s.isAllowedEndpoint("grpcs://something.hostname.good.com:1234"))
assert.False(t, s.isAllowedEndpoint(""))
assert.False(t, s.isAllowedEndpoint("grpcs://evilvalid.com:1234"))
assert.False(t, s.isAllowedEndpoint("badhostname.good.com"))
assert.False(t, s.isAllowedEndpoint("some^bad$symbols.valid.com"))
}

func TestEndpointSecureValidation(t *testing.T) {
dbConnector := db.NewMockDBConnector()
clientConnector := client.NewMockClientConnector()
auth := auth.NewMockAuthProvider()

s := NewBackupService(
dbConnector,
clientConnector,
config.S3Config{},
auth,
[]string{".valid.com", "hostname.good.com"},
false,
)

assert.False(t, s.isAllowedEndpoint("grpc://some-host.zone.valid.com"))
assert.False(t, s.isAllowedEndpoint("grpcs://host.zone.invalid.com"))
assert.False(t, s.isAllowedEndpoint("host.zone.valid.com"))
assert.True(t, s.isAllowedEndpoint("grpcs://hostname.good.com:1234"))
assert.False(t, s.isAllowedEndpoint("grpcs://something.hostname.good.com:1234"))
assert.False(t, s.isAllowedEndpoint(""))
assert.False(t, s.isAllowedEndpoint("grpcs://evilvalid.com:1234"))
assert.False(t, s.isAllowedEndpoint("badhostname.good.com"))
}

0 comments on commit 8e169b1

Please sign in to comment.