diff --git a/lib/srv/db/access_test.go b/lib/srv/db/access_test.go index c48b4caa3f35a..14b37f32bfbca 100644 --- a/lib/srv/db/access_test.go +++ b/lib/srv/db/access_test.go @@ -32,7 +32,6 @@ import ( "testing" "time" - gspanner "cloud.google.com/go/spanner" "github.com/ClickHouse/ch-go" cqlclient "github.com/datastax/go-cassandra-native-protocol/client" elastic "github.com/elastic/go-elasticsearch/v8" @@ -2129,7 +2128,7 @@ func (c *testContext) dynamodbClient(ctx context.Context, teleportUser, dbServic return db, proxy, nil } -func (c *testContext) spannerClient(ctx context.Context, teleportUser, dbService, dbUser, dbName string) (*gspanner.Client, *alpnproxy.LocalProxy, error) { +func (c *testContext) spannerClient(ctx context.Context, teleportUser, dbService, dbUser, dbName string) (*spanner.SpannerTestClient, *alpnproxy.LocalProxy, error) { route := tlsca.RouteToDatabase{ ServiceName: dbService, Protocol: defaults.ProtocolSpanner, @@ -2142,7 +2141,7 @@ func (c *testContext) spannerClient(ctx context.Context, teleportUser, dbService return nil, nil, trace.Wrap(err) } - db, err := spanner.MakeTestClient(ctx, common.TestClientConfig{ + clt, err := spanner.MakeTestClient(ctx, common.TestClientConfig{ AuthClient: c.authClient, AuthServer: c.authServer, Address: proxy.GetAddr(), @@ -2154,7 +2153,7 @@ func (c *testContext) spannerClient(ctx context.Context, teleportUser, dbService return nil, nil, trace.Wrap(err) } - return db, proxy, nil + return clt, proxy, nil } type roleOptFn func(types.Role) diff --git a/lib/srv/db/spanner/test.go b/lib/srv/db/spanner/test.go index 4a7b5378766cd..44b74085c05bf 100644 --- a/lib/srv/db/spanner/test.go +++ b/lib/srv/db/spanner/test.go @@ -33,6 +33,7 @@ import ( "github.com/sirupsen/logrus" "google.golang.org/api/option" "google.golang.org/grpc" + "google.golang.org/grpc/connectivity" "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/metadata" @@ -47,11 +48,35 @@ import ( "github.com/gravitational/teleport/lib/tlsca" ) -func MakeTestClient(ctx context.Context, config common.TestClientConfig) (*spanner.Client, error) { +// SpannerTestClient wraps a [spanner.Client] and provides direct access to the +// underlying [grpc.ClientConn] of the client. +type SpannerTestClient struct { + ClientConn *grpc.ClientConn + *spanner.Client +} + +// WaitForConnectionState waits until the spanner client's underlying gRPC +// connection transitions into the given state or the context expires. +func (c *SpannerTestClient) WaitForConnectionState(ctx context.Context, wantState connectivity.State) error { + for { + s := c.ClientConn.GetState() + if s == wantState { + return nil + } + if s == connectivity.Shutdown { + return trace.Errorf("spanner test client connection has shutdown") + } + if !c.ClientConn.WaitForStateChange(ctx, s) { + return ctx.Err() + } + } +} + +func MakeTestClient(ctx context.Context, config common.TestClientConfig) (*SpannerTestClient, error) { return makeTestClient(ctx, config, false) } -func makeTestClient(ctx context.Context, config common.TestClientConfig, useTLS bool) (*spanner.Client, error) { +func makeTestClient(ctx context.Context, config common.TestClientConfig, useTLS bool) (*SpannerTestClient, error) { databaseID, err := getDatabaseID(ctx, config.RouteToDatabase, config.AuthServer) if err != nil { return nil, trace.Wrap(err) @@ -68,13 +93,13 @@ func makeTestClient(ctx context.Context, config common.TestClientConfig, useTLS transportOpt = grpc.WithTransportCredentials(insecure.NewCredentials()) } + cc, err := grpc.NewClient(config.Address, transportOpt) + if err != nil { + return nil, trace.Wrap(err) + } + opts := []option.ClientOption{ - // dial with custom transport security - option.WithGRPCDialOption(transportOpt), - // create 1 connection - option.WithGRPCConnectionPool(1), - // connect to the Teleport endpoint - option.WithEndpoint(config.Address), + option.WithGRPCConn(cc), // client should not bring any GCP credentials option.WithoutAuthentication(), } @@ -86,7 +111,10 @@ func makeTestClient(ctx context.Context, config common.TestClientConfig, useTLS if err != nil { return nil, trace.Wrap(err) } - return clt, nil + return &SpannerTestClient{ + ClientConn: cc, + Client: clt, + }, nil } func getDatabaseID(ctx context.Context, route tlsca.RouteToDatabase, getter services.DatabaseServersGetter) (string, error) { diff --git a/lib/srv/db/spanner_test.go b/lib/srv/db/spanner_test.go index 12ba6dfed2f38..4af7fbdb46cb8 100644 --- a/lib/srv/db/spanner_test.go +++ b/lib/srv/db/spanner_test.go @@ -28,6 +28,7 @@ import ( gspanner "cloud.google.com/go/spanner" "github.com/stretchr/testify/require" + "google.golang.org/grpc/connectivity" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/types/events" @@ -234,7 +235,15 @@ func TestAuditSpanner(t *testing.T) { _ = localProxy.Close() }) - require.NoError(t, err) + require.NoError(t, clt.WaitForConnectionState(ctx, connectivity.Ready)) + reconnectingCh := make(chan bool) + go func() { + // we should observe the connection leave the "ready" state after + // it gets an access denied error. + ctx, cancel := context.WithTimeout(ctx, time.Second*10) + defer cancel() + reconnectingCh <- clt.ClientConn.WaitForStateChange(ctx, connectivity.Ready) + }() row, err := pingSpanner(ctx, clt, 42) require.Error(t, err) @@ -246,10 +255,7 @@ func TestAuditSpanner(t *testing.T) { require.True(t, ok) require.Equal(t, "googlesql", dbStart1.DatabaseName) - // the connection should start to shut down, but wait some time for that - // to happen before sending further RPCs so that the client will dial - // a new connection before sending another query RPC. - time.Sleep(500 * time.Millisecond) + require.True(t, <-reconnectingCh, "timed out waiting for the spanner client to reconnect") row, err = pingSpanner(ctx, clt, 42) require.Error(t, err) require.ErrorContains(t, err, "access to db denied") @@ -312,7 +318,7 @@ func TestAuditSpanner(t *testing.T) { }) } -func pingSpanner(ctx context.Context, clt *gspanner.Client, want int64) (*gspanner.Row, error) { +func pingSpanner(ctx context.Context, clt *spanner.SpannerTestClient, want int64) (*gspanner.Row, error) { query := gspanner.NewStatement(fmt.Sprintf("SELECT %d", want)) rowIter := clt.Single().Query(ctx, query) defer rowIter.Stop()