Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[COR-1114] Fix token validity check logic to use exp field in access… #5998

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 7 additions & 12 deletions flyteidl/clients/go/admin/auth_interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
"google.golang.org/grpc/status"

"github.com/flyteorg/flyte/flyteidl/clients/go/admin/cache"
"github.com/flyteorg/flyte/flyteidl/clients/go/admin/utils"
"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/service"
"github.com/flyteorg/flyte/flytestdlib/logger"
)
Expand All @@ -23,12 +24,10 @@
// Once established, it'll invoke PerRPCCredentialsFuture.Store() on perRPCCredentials to populate it with the appropriate values.
func MaterializeCredentials(tokenSource oauth2.TokenSource, cfg *Config, authorizationMetadataKey string,
perRPCCredentials *PerRPCCredentialsFuture) error {

_, err := tokenSource.Token()
if err != nil {
return fmt.Errorf("failed to issue token. Error: %w", err)
}

wrappedTokenSource := NewCustomHeaderTokenSource(tokenSource, cfg.UseInsecureConnection, authorizationMetadataKey)
perRPCCredentials.Store(wrappedTokenSource)

Expand Down Expand Up @@ -119,11 +118,6 @@
return fmt.Errorf("failed to initialized Auth Metadata Client. Error: %w", err)
}

tokenSourceProvider, err := NewTokenSourceProvider(ctx, cfg, tokenCache, authMetadataClient)
if err != nil {
return fmt.Errorf("failed to initialize token source provider. Err: %w", err)
}

authorizationMetadataKey := cfg.AuthorizationHeader
if len(authorizationMetadataKey) == 0 {
clientMetadata, err := authMetadataClient.GetPublicClientConfig(ctx, &service.PublicClientAuthConfigRequest{})
Expand All @@ -133,7 +127,7 @@
authorizationMetadataKey = clientMetadata.AuthorizationMetadataKey
}

tokenSource, err := tokenSourceProvider.GetTokenSource(ctx)
tokenSource, err := NewInMemoryTokenSourceProvider(tokenCache).GetTokenSource(ctx)
if err != nil {
return fmt.Errorf("failed to get token source. Error: %w", err)
}
Expand Down Expand Up @@ -188,10 +182,11 @@
}
authorizationMetadataKey := oauthMetadataProvider.authorizationMetadataKey
tokenSource := oauthMetadataProvider.tokenSource

err = MaterializeCredentials(tokenSource, cfg, authorizationMetadataKey, credentialsFuture)
if err != nil {
return fmt.Errorf("failed to materialize credentials. Error: %v", err)
if isValid := utils.Valid(t); isValid {
err = MaterializeCredentials(tokenSource, cfg, authorizationMetadataKey, credentialsFuture)
if err != nil {
return fmt.Errorf("failed to materialize credentials. Error: %v", err)
}

Check warning on line 189 in flyteidl/clients/go/admin/auth_interceptor.go

View check run for this annotation

Codecov / codecov/patch

flyteidl/clients/go/admin/auth_interceptor.go#L188-L189

Added lines #L188 - L189 were not covered by tests
}
}

Expand Down
12 changes: 5 additions & 7 deletions flyteidl/clients/go/admin/auth_interceptor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,16 @@ package admin

import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
"os"
"strings"
"sync"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
Expand All @@ -25,6 +24,7 @@ import (
"github.com/flyteorg/flyte/flyteidl/clients/go/admin/cache/mocks"
adminMocks "github.com/flyteorg/flyte/flyteidl/clients/go/admin/mocks"

"github.com/flyteorg/flyte/flyteidl/clients/go/admin/utils"
"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/service"
"github.com/flyteorg/flyte/flytestdlib/config"
"github.com/flyteorg/flyte/flytestdlib/logger"
Expand Down Expand Up @@ -137,10 +137,7 @@ func newAuthMetadataServer(t testing.TB, grpcPort int, httpPort int, impl servic
}

func Test_newAuthInterceptor(t *testing.T) {
plan, _ := os.ReadFile("tokenorchestrator/testdata/token.json")
var tokenData oauth2.Token
err := json.Unmarshal(plan, &tokenData)
assert.NoError(t, err)
tokenData := utils.GenTokenWithCustomExpiry(t, time.Now().Add(20*time.Minute))
t.Run("Other Error", func(t *testing.T) {
ctx := context.Background()
httpPort := rand.IntnRange(10000, 60000)
Expand All @@ -164,12 +161,13 @@ func Test_newAuthInterceptor(t *testing.T) {
f := NewPerRPCCredentialsFuture()
p := NewPerRPCCredentialsFuture()
mockTokenCache := &mocks.TokenCache{}
mockTokenCache.OnGetTokenMatch().Return(&tokenData, nil)
mockTokenCache.OnGetTokenMatch().Return(tokenData, nil)
mockTokenCache.OnSaveTokenMatch(mock.Anything).Return(nil)
interceptor := NewAuthInterceptor(&Config{
Endpoint: config.URL{URL: *u},
UseInsecureConnection: true,
}, mockTokenCache, f, p)

otherError := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error {
return status.New(codes.Canceled, "").Err()
}
Expand Down
23 changes: 7 additions & 16 deletions flyteidl/clients/go/admin/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,10 @@ package admin

import (
"context"
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"os"
"testing"
"time"

Expand All @@ -24,6 +21,7 @@ import (
"github.com/flyteorg/flyte/flyteidl/clients/go/admin/oauth"
"github.com/flyteorg/flyte/flyteidl/clients/go/admin/pkce"
"github.com/flyteorg/flyte/flyteidl/clients/go/admin/tokenorchestrator"
"github.com/flyteorg/flyte/flyteidl/clients/go/admin/utils"
"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/service"
"github.com/flyteorg/flyte/flytestdlib/config"
"github.com/flyteorg/flyte/flytestdlib/logger"
Expand Down Expand Up @@ -231,15 +229,11 @@ func TestGetAuthenticationDialOptionPkce(t *testing.T) {
RedirectUri: "http://localhost:54545/callback",
}
http.DefaultServeMux = http.NewServeMux()
plan, _ := os.ReadFile("tokenorchestrator/testdata/token.json")
var tokenData oauth2.Token
err := json.Unmarshal(plan, &tokenData)
assert.NoError(t, err)
tokenData.Expiry = time.Now().Add(time.Minute)
tokenData := utils.GenTokenWithCustomExpiry(t, time.Now().Add(time.Minute))
t.Run("cache hit", func(t *testing.T) {
mockTokenCache := new(cachemocks.TokenCache)
mockAuthClient := new(mocks.AuthMetadataServiceClient)
mockTokenCache.OnGetTokenMatch().Return(&tokenData, nil)
mockTokenCache.OnGetTokenMatch().Return(tokenData, nil)
mockTokenCache.OnSaveTokenMatch(mock.Anything).Return(nil)
mockAuthClient.OnGetOAuth2MetadataMatch(mock.Anything, mock.Anything).Return(metadata, nil)
mockAuthClient.OnGetPublicClientConfigMatch(mock.Anything, mock.Anything).Return(clientMetatadata, nil)
Expand All @@ -249,11 +243,11 @@ func TestGetAuthenticationDialOptionPkce(t *testing.T) {
assert.NotNil(t, dialOption)
assert.Nil(t, err)
})
tokenData.Expiry = time.Now().Add(-time.Minute)
t.Run("cache miss auth failure", func(t *testing.T) {
tokenData = utils.GenTokenWithCustomExpiry(t, time.Now().Add(-time.Minute))
mockTokenCache := new(cachemocks.TokenCache)
mockAuthClient := new(mocks.AuthMetadataServiceClient)
mockTokenCache.OnGetTokenMatch().Return(&tokenData, nil)
mockTokenCache.OnGetTokenMatch().Return(tokenData, nil)
mockTokenCache.OnSaveTokenMatch(mock.Anything).Return(nil)
mockTokenCache.On("Lock").Return()
mockTokenCache.On("Unlock").Return()
Expand Down Expand Up @@ -284,14 +278,11 @@ func Test_getPkceAuthTokenSource(t *testing.T) {
mockAuthClient.OnGetPublicClientConfigMatch(mock.Anything, mock.Anything).Return(clientMetatadata, nil)

t.Run("cached token expired", func(t *testing.T) {
plan, _ := ioutil.ReadFile("tokenorchestrator/testdata/token.json")
var tokenData oauth2.Token
err := json.Unmarshal(plan, &tokenData)
assert.NoError(t, err)
tokenData := utils.GenTokenWithCustomExpiry(t, time.Now().Add(-time.Minute))

// populate the cache
tokenCache := cache.NewTokenCacheInMemoryProvider()
assert.NoError(t, tokenCache.SaveToken(&tokenData))
assert.NoError(t, tokenCache.SaveToken(tokenData))

baseOrchestrator := tokenorchestrator.BaseTokenOrchestrator{
ClientConfig: &oauth.Config{
Expand Down
46 changes: 39 additions & 7 deletions flyteidl/clients/go/admin/token_source_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
"github.com/flyteorg/flyte/flyteidl/clients/go/admin/externalprocess"
"github.com/flyteorg/flyte/flyteidl/clients/go/admin/pkce"
"github.com/flyteorg/flyte/flyteidl/clients/go/admin/tokenorchestrator"
"github.com/flyteorg/flyte/flyteidl/clients/go/admin/utils"
"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/service"
"github.com/flyteorg/flyte/flytestdlib/logger"
)
Expand Down Expand Up @@ -229,28 +230,36 @@
s.mu.Lock()
defer s.mu.Unlock()

if token, err := s.tokenCache.GetToken(); err == nil && token.Valid() {
return token, nil
token, err := s.tokenCache.GetToken()
if err != nil {
logger.Warnf(s.ctx, "failed to get token from cache: %v", err)

Check warning on line 235 in flyteidl/clients/go/admin/token_source_provider.go

View check run for this annotation

Codecov / codecov/patch

flyteidl/clients/go/admin/token_source_provider.go#L235

Added line #L235 was not covered by tests
} else {
if isValid := utils.Valid(token); isValid {
logger.Infof(context.Background(), "retrieved token from cache with expiry %v", token.Expiry)
return token, nil
}
}

totalAttempts := s.cfg.MaxRetries + 1 // Add one for initial request attempt
backoff := wait.Backoff{
Duration: s.cfg.PerRetryTimeout.Duration,
Steps: totalAttempts,
}
var token *oauth2.Token
err := retry.OnError(backoff, func(err error) bool {

err = retry.OnError(backoff, func(err error) bool {
return err != nil
}, func() (err error) {
token, err = s.new.Token()
if err != nil {
logger.Infof(s.ctx, "failed to get token: %w", err)
return fmt.Errorf("failed to get token: %w", err)
logger.Infof(s.ctx, "failed to get new token: %w", err)
return fmt.Errorf("failed to get new token: %w", err)
}
logger.Infof(context.Background(), "Fetched new token with expiry %v", token.Expiry)
return nil
})
if err != nil {
return nil, err
logger.Warnf(s.ctx, "failed to get new token: %v", err)
return nil, fmt.Errorf("failed to get new token: %w", err)
}
logger.Infof(s.ctx, "retrieved token with expiry %v", token.Expiry)

Expand All @@ -262,6 +271,29 @@
return token, nil
}

type InMemoryTokenSourceProvider struct {
tokenCache cache.TokenCache
}

func NewInMemoryTokenSourceProvider(tokenCache cache.TokenCache) TokenSourceProvider {
return InMemoryTokenSourceProvider{tokenCache: tokenCache}
}

func (i InMemoryTokenSourceProvider) GetTokenSource(ctx context.Context) (oauth2.TokenSource, error) {
return GetInMemoryAuthTokenSource(ctx, i.tokenCache)
}

// GetInMemoryAuthTokenSource Returns the token source with cached token
func GetInMemoryAuthTokenSource(ctx context.Context, tokenCache cache.TokenCache) (oauth2.TokenSource, error) {
authToken, err := tokenCache.GetToken()
if err != nil {
return nil, err
}

Check warning on line 291 in flyteidl/clients/go/admin/token_source_provider.go

View check run for this annotation

Codecov / codecov/patch

flyteidl/clients/go/admin/token_source_provider.go#L290-L291

Added lines #L290 - L291 were not covered by tests
return &pkce.SimpleTokenSource{
CachedToken: authToken,
}, nil
}

type DeviceFlowTokenSourceProvider struct {
tokenOrchestrator deviceflow.TokenOrchestrator
}
Expand Down
25 changes: 13 additions & 12 deletions flyteidl/clients/go/admin/token_source_provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (

tokenCacheMocks "github.com/flyteorg/flyte/flyteidl/clients/go/admin/cache/mocks"
adminMocks "github.com/flyteorg/flyte/flyteidl/clients/go/admin/mocks"
"github.com/flyteorg/flyte/flyteidl/clients/go/admin/utils"
"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/service"
)

Expand Down Expand Up @@ -88,9 +89,9 @@ func TestCustomTokenSource_Token(t *testing.T) {
minuteAgo := time.Now().Add(-time.Minute)
hourAhead := time.Now().Add(time.Hour)
twoHourAhead := time.Now().Add(2 * time.Hour)
invalidToken := oauth2.Token{AccessToken: "foo", Expiry: minuteAgo}
validToken := oauth2.Token{AccessToken: "foo", Expiry: hourAhead}
newToken := oauth2.Token{AccessToken: "foo", Expiry: twoHourAhead}
invalidToken := utils.GenTokenWithCustomExpiry(t, minuteAgo)
validToken := utils.GenTokenWithCustomExpiry(t, hourAhead)
newToken := utils.GenTokenWithCustomExpiry(t, twoHourAhead)

tests := []struct {
name string
Expand All @@ -101,24 +102,24 @@ func TestCustomTokenSource_Token(t *testing.T) {
{
name: "no cached token",
token: nil,
newToken: &newToken,
expectedToken: &newToken,
newToken: newToken,
expectedToken: newToken,
},
{
name: "cached token valid",
token: &validToken,
token: validToken,
newToken: nil,
expectedToken: &validToken,
expectedToken: validToken,
},
{
name: "cached token expired",
token: &invalidToken,
newToken: &newToken,
expectedToken: &newToken,
token: invalidToken,
newToken: newToken,
expectedToken: newToken,
},
{
name: "failed new token",
token: &invalidToken,
token: invalidToken,
newToken: nil,
expectedToken: nil,
},
Expand All @@ -138,7 +139,7 @@ func TestCustomTokenSource_Token(t *testing.T) {
assert.True(t, ok)

mockSource := &adminMocks.TokenSource{}
if test.token != &validToken {
if test.token != validToken {
if test.newToken != nil {
mockSource.OnToken().Return(test.newToken, nil)
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (

"github.com/flyteorg/flyte/flyteidl/clients/go/admin/cache"
"github.com/flyteorg/flyte/flyteidl/clients/go/admin/oauth"
"github.com/flyteorg/flyte/flyteidl/clients/go/admin/utils"
"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/service"
"github.com/flyteorg/flyte/flytestdlib/config"
"github.com/flyteorg/flyte/flytestdlib/logger"
Expand Down Expand Up @@ -52,7 +53,8 @@ func (t BaseTokenOrchestrator) FetchTokenFromCacheOrRefreshIt(ctx context.Contex
return nil, err
}

if token.Valid() {
if isValid := utils.Valid(token); isValid {
logger.Infof(context.Background(), "retrieved token from cache with expiry %v", token.Expiry)
return token, nil
}

Expand Down
Loading
Loading