Skip to content

Commit

Permalink
* Added ydb.WithNodeAddressMutator experimental option for mutate n…
Browse files Browse the repository at this point in the history
…ode addresses from `discovery.ListEndpoints` response
  • Loading branch information
asmyasnikov committed May 15, 2024
1 parent 9d49511 commit 2dbf057
Show file tree
Hide file tree
Showing 7 changed files with 347 additions and 33 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
* Added `ydb.WithNodeAddressMutator` experimental option for mutate node addresses from `discovery.ListEndpoints` response
* Added type assertion checks to enhance type safety and prevent unexpected panics in critical sections of the codebase

## v3.66.3
Expand Down
36 changes: 32 additions & 4 deletions internal/discovery/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package config
import (
"time"

"github.com/jonboulle/clockwork"

"github.com/ydb-platform/ydb-go-sdk/v3/internal/config"
"github.com/ydb-platform/ydb-go-sdk/v3/internal/meta"
"github.com/ydb-platform/ydb-go-sdk/v3/trace"
Expand All @@ -15,10 +17,12 @@ const (
type Config struct {
config.Common

endpoint string
database string
secure bool
meta *meta.Meta
endpoint string
database string
secure bool
meta *meta.Meta
addressMutator func(address string) string
clock clockwork.Clock

interval time.Duration
trace *trace.Discovery
Expand All @@ -28,6 +32,10 @@ func New(opts ...Option) *Config {
c := &Config{
interval: DefaultInterval,
trace: &trace.Discovery{},
addressMutator: func(address string) string {
return address
},
clock: clockwork.NewRealClock(),
}
for _, opt := range opts {
if opt != nil {
Expand All @@ -38,10 +46,18 @@ func New(opts ...Option) *Config {
return c
}

func (c *Config) MutateAddress(fqdn string) string {
return c.addressMutator(fqdn)
}

func (c *Config) Meta() *meta.Meta {
return c.meta
}

func (c *Config) Clock() clockwork.Clock {
return c.clock
}

func (c *Config) Interval() time.Duration {
return c.interval
}
Expand Down Expand Up @@ -85,6 +101,18 @@ func WithDatabase(database string) Option {
}
}

func WithClock(clock clockwork.Clock) Option {
return func(c *Config) {
c.clock = clock
}
}

func WithAddressMutator(addressMutator func(address string) string) Option {
return func(c *Config) {
c.addressMutator = addressMutator
}
}

// WithSecure set flag for secure connection
func WithSecure(ssl bool) Option {
return func(c *Config) {
Expand Down
76 changes: 49 additions & 27 deletions internal/discovery/discovery.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ import (
"github.com/ydb-platform/ydb-go-sdk/v3/trace"
)

//go:generate mockgen -destination grpc_client_mock_test.go -package discovery -write_package_comment=false github.com/ydb-platform/ydb-go-genproto/Ydb_Discovery_V1 DiscoveryServiceClient

func New(ctx context.Context, cc grpc.ClientConnInterface, config *config.Config) *Client {
return &Client{
config: config,
Expand All @@ -35,65 +37,85 @@ type Client struct {
client Ydb_Discovery_V1.DiscoveryServiceClient
}

// Discover cluster endpoints
func (c *Client) Discover(ctx context.Context) (endpoints []endpoint.Endpoint, err error) {
func discover(
ctx context.Context,
client Ydb_Discovery_V1.DiscoveryServiceClient,
config *config.Config,
) (endpoints []endpoint.Endpoint, location string, err error) {
var (
onDone = trace.DiscoveryOnDiscover(
c.config.Trace(), &ctx,
stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/3/internal/discovery.(*Client).Discover"),
c.config.Endpoint(), c.config.Database(),
)
request = Ydb_Discovery.ListEndpointsRequest{
Database: c.config.Database(),
Database: config.Database(),
}
response *Ydb_Discovery.ListEndpointsResponse
result Ydb_Discovery.ListEndpointsResult
location string
)
defer func() {
nodes := make([]trace.EndpointInfo, 0, len(endpoints))
for _, e := range endpoints {
nodes = append(nodes, e.Copy())
}
onDone(location, nodes, err)
}()

ctx, err = c.config.Meta().Context(ctx)
if err != nil {
return nil, xerrors.WithStackTrace(err)
}

response, err = c.client.ListEndpoints(ctx, &request)
response, err = client.ListEndpoints(ctx, &request)
if err != nil {
return nil, xerrors.WithStackTrace(err)
return nil, location, xerrors.WithStackTrace(err)
}

if response.GetOperation().GetStatus() != Ydb.StatusIds_SUCCESS {
return nil, xerrors.WithStackTrace(
return nil, location, xerrors.WithStackTrace(
xerrors.FromOperation(response.GetOperation()),
)
}

err = response.GetOperation().GetResult().UnmarshalTo(&result)
if err != nil {
return nil, xerrors.WithStackTrace(err)
return nil, location, xerrors.WithStackTrace(err)
}

location = result.GetSelfLocation()
endpoints = make([]endpoint.Endpoint, 0, len(result.GetEndpoints()))
for _, e := range result.GetEndpoints() {
if e.GetSsl() == c.config.Secure() {
if e.GetSsl() == config.Secure() {
endpoints = append(endpoints, endpoint.New(
net.JoinHostPort(e.GetAddress(), strconv.Itoa(int(e.GetPort()))),
net.JoinHostPort(
config.MutateAddress(e.GetAddress()),
strconv.Itoa(int(e.GetPort())),
),
endpoint.WithLocation(e.GetLocation()),
endpoint.WithID(e.GetNodeId()),
endpoint.WithLoadFactor(e.GetLoadFactor()),
endpoint.WithLocalDC(e.GetLocation() == location),
endpoint.WithServices(e.GetService()),
endpoint.WithLastUpdated(config.Clock().Now()),
))
}
}

return endpoints, result.GetSelfLocation(), nil
}

// Discover cluster endpoints
func (c *Client) Discover(ctx context.Context) (endpoints []endpoint.Endpoint, finalErr error) {
var (
onDone = trace.DiscoveryOnDiscover(
c.config.Trace(), &ctx,
stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/3/internal/discovery.(*Client).Discover"),
c.config.Endpoint(), c.config.Database(),
)
location string
)
defer func() {
nodes := make([]trace.EndpointInfo, 0, len(endpoints))
for _, e := range endpoints {
nodes = append(nodes, e.Copy())
}
onDone(location, nodes, finalErr)
}()

ctx, err := c.config.Meta().Context(ctx)
if err != nil {
return nil, xerrors.WithStackTrace(err)
}

endpoints, location, err = discover(ctx, c.client, c.config)
if err != nil {
return nil, xerrors.WithStackTrace(err)
}

return endpoints, nil
}

Expand Down
173 changes: 173 additions & 0 deletions internal/discovery/discovery_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
package discovery

import (
"testing"

"github.com/jonboulle/clockwork"
"github.com/stretchr/testify/require"
"github.com/ydb-platform/ydb-go-genproto/protos/Ydb"
"github.com/ydb-platform/ydb-go-genproto/protos/Ydb_Discovery"
"github.com/ydb-platform/ydb-go-genproto/protos/Ydb_Operations"
"go.uber.org/mock/gomock"
grpcCodes "google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/anypb"

"github.com/ydb-platform/ydb-go-sdk/v3/internal/discovery/config"
"github.com/ydb-platform/ydb-go-sdk/v3/internal/endpoint"
"github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors"
"github.com/ydb-platform/ydb-go-sdk/v3/internal/xtest"
)

func must[T any](t T, err error) T {
if err != nil {
panic(err)
}

return t
}

func TestDiscover(t *testing.T) {
t.Run("HappyWay", func(t *testing.T) {
ctx := xtest.Context(t)
ctrl := gomock.NewController(t)
clock := clockwork.NewFakeClock()
client := NewMockDiscoveryServiceClient(ctrl)
client.EXPECT().ListEndpoints(gomock.Any(), &Ydb_Discovery.ListEndpointsRequest{
Database: "test",
}).Return(&Ydb_Discovery.ListEndpointsResponse{
Operation: &Ydb_Operations.Operation{
Ready: true,
Status: Ydb.StatusIds_SUCCESS,
Result: must(anypb.New(&Ydb_Discovery.ListEndpointsResult{
Endpoints: []*Ydb_Discovery.EndpointInfo{
{
Address: "node1",
Port: 1,
Ssl: true,
},
{
Address: "node2",
Port: 2,
Location: "AZ0",
Ssl: true,
},
{
Address: "node3",
Port: 3,
Ssl: false,
},
{
Address: "node4",
Port: 4,
Location: "AZ0",
Ssl: false,
},
},
SelfLocation: "AZ0",
})),
},
}, nil)
endpoints, location, err := discover(ctx, client, config.New(
config.WithDatabase("test"),
config.WithSecure(false),
config.WithClock(clock),
))
require.NoError(t, err)
require.EqualValues(t, "AZ0", location)
require.EqualValues(t, []endpoint.Endpoint{
endpoint.New("node3:3",
endpoint.WithLocalDC(false),
endpoint.WithLastUpdated(clock.Now()),
),
endpoint.New("node4:4",
endpoint.WithLocalDC(true),
endpoint.WithLocation("AZ0"),
endpoint.WithLastUpdated(clock.Now()),
),
}, endpoints)
})
t.Run("TransportError", func(t *testing.T) {
ctx := xtest.Context(t)
ctrl := gomock.NewController(t)
client := NewMockDiscoveryServiceClient(ctrl)
client.EXPECT().ListEndpoints(gomock.Any(), &Ydb_Discovery.ListEndpointsRequest{
Database: "test",
}).Return(nil, xerrors.Transport(status.Error(grpcCodes.Unavailable, "")))
endpoints, location, err := discover(ctx, client, config.New(
config.WithDatabase("test"),
))
require.Error(t, err)
require.Empty(t, endpoints)
require.Equal(t, "", location)
require.True(t, xerrors.IsTransportError(err, grpcCodes.Unavailable))
})
t.Run("OperationError", func(t *testing.T) {
ctx := xtest.Context(t)
ctrl := gomock.NewController(t)
client := NewMockDiscoveryServiceClient(ctrl)
client.EXPECT().ListEndpoints(gomock.Any(), &Ydb_Discovery.ListEndpointsRequest{
Database: "test",
}).Return(&Ydb_Discovery.ListEndpointsResponse{
Operation: &Ydb_Operations.Operation{
Ready: true,
Status: Ydb.StatusIds_UNAVAILABLE,
},
}, nil)
endpoints, location, err := discover(ctx, client, config.New(
config.WithDatabase("test"),
))
require.Error(t, err)
require.Empty(t, endpoints)
require.Equal(t, "", location)
require.True(t, xerrors.IsOperationError(err, Ydb.StatusIds_UNAVAILABLE))
})
t.Run("WithAddressMutator", func(t *testing.T) {
ctx := xtest.Context(t)
ctrl := gomock.NewController(t)
clock := clockwork.NewFakeClock()
client := NewMockDiscoveryServiceClient(ctrl)
client.EXPECT().ListEndpoints(gomock.Any(), &Ydb_Discovery.ListEndpointsRequest{
Database: "test",
}).Return(&Ydb_Discovery.ListEndpointsResponse{
Operation: &Ydb_Operations.Operation{
Ready: true,
Status: Ydb.StatusIds_SUCCESS,
Result: must(anypb.New(&Ydb_Discovery.ListEndpointsResult{
Endpoints: []*Ydb_Discovery.EndpointInfo{
{
Address: "node1",
Port: 1,
},
{
Address: "node2",
Port: 2,
Location: "AZ0",
},
},
SelfLocation: "AZ0",
})),
},
}, nil)
endpoints, location, err := discover(ctx, client, config.New(
config.WithDatabase("test"),
config.WithAddressMutator(func(address string) string {
return "u-" + address
}),
config.WithClock(clock),
))
require.NoError(t, err)
require.EqualValues(t, "AZ0", location)
require.EqualValues(t, []endpoint.Endpoint{
endpoint.New("u-node1:1",
endpoint.WithLocalDC(false),
endpoint.WithLastUpdated(clock.Now()),
),
endpoint.New("u-node2:2",
endpoint.WithLocalDC(true),
endpoint.WithLocation("AZ0"),
endpoint.WithLastUpdated(clock.Now()),
),
}, endpoints)
})
}
Loading

0 comments on commit 2dbf057

Please sign in to comment.