Skip to content

Commit

Permalink
watch client connection state in test
Browse files Browse the repository at this point in the history
  • Loading branch information
GavinFrazar authored and github-actions committed Dec 3, 2024
1 parent ce348d4 commit ff2dfa0
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 19 deletions.
7 changes: 3 additions & 4 deletions lib/srv/db/access_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ import (
"testing"
"time"

gspanner "cloud.google.com/go/spanner"
"github.com/ClickHouse/ch-go"
cqlclient "github.com/datastax/go-cassandra-native-protocol/client"
elastic "github.com/elastic/go-elasticsearch/v8"
Expand Down Expand Up @@ -2129,7 +2128,7 @@ func (c *testContext) dynamodbClient(ctx context.Context, teleportUser, dbServic
return db, proxy, nil
}

func (c *testContext) spannerClient(ctx context.Context, teleportUser, dbService, dbUser, dbName string) (*gspanner.Client, *alpnproxy.LocalProxy, error) {
func (c *testContext) spannerClient(ctx context.Context, teleportUser, dbService, dbUser, dbName string) (*spanner.SpannerTestClient, *alpnproxy.LocalProxy, error) {
route := tlsca.RouteToDatabase{
ServiceName: dbService,
Protocol: defaults.ProtocolSpanner,
Expand All @@ -2142,7 +2141,7 @@ func (c *testContext) spannerClient(ctx context.Context, teleportUser, dbService
return nil, nil, trace.Wrap(err)
}

db, err := spanner.MakeTestClient(ctx, common.TestClientConfig{
clt, err := spanner.MakeTestClient(ctx, common.TestClientConfig{
AuthClient: c.authClient,
AuthServer: c.authServer,
Address: proxy.GetAddr(),
Expand All @@ -2154,7 +2153,7 @@ func (c *testContext) spannerClient(ctx context.Context, teleportUser, dbService
return nil, nil, trace.Wrap(err)
}

return db, proxy, nil
return clt, proxy, nil
}

type roleOptFn func(types.Role)
Expand Down
46 changes: 37 additions & 9 deletions lib/srv/db/spanner/test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import (
"github.com/sirupsen/logrus"
"google.golang.org/api/option"
"google.golang.org/grpc"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/metadata"
Expand All @@ -47,11 +48,35 @@ import (
"github.com/gravitational/teleport/lib/tlsca"
)

func MakeTestClient(ctx context.Context, config common.TestClientConfig) (*spanner.Client, error) {
// SpannerTestClient wraps a [spanner.Client] and provides direct access to the
// underlying [grpc.ClientConn] of the client.
type SpannerTestClient struct {
ClientConn *grpc.ClientConn
*spanner.Client
}

// WaitForConnectionState waits until the spanner client's underlying gRPC
// connection transitions into the given state or the context expires.
func (c *SpannerTestClient) WaitForConnectionState(ctx context.Context, wantState connectivity.State) error {
for {
s := c.ClientConn.GetState()
if s == wantState {
return nil
}
if s == connectivity.Shutdown {
return trace.Errorf("spanner test client connection has shutdown")
}
if !c.ClientConn.WaitForStateChange(ctx, s) {
return ctx.Err()
}
}
}

func MakeTestClient(ctx context.Context, config common.TestClientConfig) (*SpannerTestClient, error) {
return makeTestClient(ctx, config, false)
}

func makeTestClient(ctx context.Context, config common.TestClientConfig, useTLS bool) (*spanner.Client, error) {
func makeTestClient(ctx context.Context, config common.TestClientConfig, useTLS bool) (*SpannerTestClient, error) {
databaseID, err := getDatabaseID(ctx, config.RouteToDatabase, config.AuthServer)
if err != nil {
return nil, trace.Wrap(err)
Expand All @@ -68,13 +93,13 @@ func makeTestClient(ctx context.Context, config common.TestClientConfig, useTLS
transportOpt = grpc.WithTransportCredentials(insecure.NewCredentials())
}

cc, err := grpc.NewClient(config.Address, transportOpt)
if err != nil {
return nil, trace.Wrap(err)
}

opts := []option.ClientOption{
// dial with custom transport security
option.WithGRPCDialOption(transportOpt),
// create 1 connection
option.WithGRPCConnectionPool(1),
// connect to the Teleport endpoint
option.WithEndpoint(config.Address),
option.WithGRPCConn(cc),
// client should not bring any GCP credentials
option.WithoutAuthentication(),
}
Expand All @@ -86,7 +111,10 @@ func makeTestClient(ctx context.Context, config common.TestClientConfig, useTLS
if err != nil {
return nil, trace.Wrap(err)
}
return clt, nil
return &SpannerTestClient{
ClientConn: cc,
Client: clt,
}, nil
}

func getDatabaseID(ctx context.Context, route tlsca.RouteToDatabase, getter services.DatabaseServersGetter) (string, error) {
Expand Down
18 changes: 12 additions & 6 deletions lib/srv/db/spanner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (

gspanner "cloud.google.com/go/spanner"
"github.com/stretchr/testify/require"
"google.golang.org/grpc/connectivity"

"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/api/types/events"
Expand Down Expand Up @@ -234,7 +235,15 @@ func TestAuditSpanner(t *testing.T) {
_ = localProxy.Close()
})

require.NoError(t, err)
require.NoError(t, clt.WaitForConnectionState(ctx, connectivity.Ready))
reconnectingCh := make(chan bool)
go func() {
// we should observe the connection leave the "ready" state after
// it gets an access denied error.
ctx, cancel := context.WithTimeout(ctx, time.Second*10)
defer cancel()
reconnectingCh <- clt.ClientConn.WaitForStateChange(ctx, connectivity.Ready)
}()

row, err := pingSpanner(ctx, clt, 42)
require.Error(t, err)
Expand All @@ -246,10 +255,7 @@ func TestAuditSpanner(t *testing.T) {
require.True(t, ok)
require.Equal(t, "googlesql", dbStart1.DatabaseName)

// the connection should start to shut down, but wait some time for that
// to happen before sending further RPCs so that the client will dial
// a new connection before sending another query RPC.
time.Sleep(500 * time.Millisecond)
require.True(t, <-reconnectingCh, "timed out waiting for the spanner client to reconnect")
row, err = pingSpanner(ctx, clt, 42)
require.Error(t, err)
require.ErrorContains(t, err, "access to db denied")
Expand Down Expand Up @@ -312,7 +318,7 @@ func TestAuditSpanner(t *testing.T) {
})
}

func pingSpanner(ctx context.Context, clt *gspanner.Client, want int64) (*gspanner.Row, error) {
func pingSpanner(ctx context.Context, clt *spanner.SpannerTestClient, want int64) (*gspanner.Row, error) {
query := gspanner.NewStatement(fmt.Sprintf("SELECT %d", want))
rowIter := clt.Single().Query(ctx, query)
defer rowIter.Stop()
Expand Down

0 comments on commit ff2dfa0

Please sign in to comment.