Skip to content

Commit

Permalink
[COR-1114] Fix token validity check logic to use exp field in access …
Browse files Browse the repository at this point in the history
…token (#330)

* Add logs for token

* add logs

* Fixing the validity check logic for token

* nit

* nit

* Adding in memory token source provider

* nit

* changed Valid method to log and ignore parseDateClaim error

* nit

* Fix unit tests

* lint

* fix unit tests

Signed-off-by: pmahindrakar-oss <[email protected]>
  • Loading branch information
pmahindrakar-oss committed Nov 12, 2024
1 parent 3c3ae05 commit adc3ce2
Show file tree
Hide file tree
Showing 10 changed files with 161 additions and 90 deletions.
17 changes: 7 additions & 10 deletions flyteidl/clients/go/admin/auth_interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"google.golang.org/grpc/status"

"github.com/flyteorg/flyte/flyteidl/clients/go/admin/cache"
"github.com/flyteorg/flyte/flyteidl/clients/go/admin/utils"
"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/service"
"github.com/flyteorg/flyte/flytestdlib/logger"
)
Expand All @@ -23,12 +24,10 @@ const ProxyAuthorizationHeader = "proxy-authorization"
// Once established, it'll invoke PerRPCCredentialsFuture.Store() on perRPCCredentials to populate it with the appropriate values.
func MaterializeCredentials(tokenSource oauth2.TokenSource, cfg *Config, authorizationMetadataKey string,
perRPCCredentials *PerRPCCredentialsFuture) error {

_, err := tokenSource.Token()
if err != nil {
return fmt.Errorf("failed to issue token. Error: %w", err)
}

wrappedTokenSource := NewCustomHeaderTokenSource(tokenSource, cfg.UseInsecureConnection, authorizationMetadataKey)
perRPCCredentials.Store(wrappedTokenSource)

Expand Down Expand Up @@ -119,11 +118,6 @@ func (o *OauthMetadataProvider) getTokenSourceAndMetadata(cfg *Config, tokenCach
return fmt.Errorf("failed to initialized Auth Metadata Client. Error: %w", err)
}

tokenSourceProvider, err := NewTokenSourceProvider(ctx, cfg, tokenCache, authMetadataClient)
if err != nil {
return fmt.Errorf("failed to initialize token source provider. Err: %w", err)
}

authorizationMetadataKey := cfg.AuthorizationHeader
if len(authorizationMetadataKey) == 0 {
clientMetadata, err := authMetadataClient.GetPublicClientConfig(ctx, &service.PublicClientAuthConfigRequest{})
Expand All @@ -133,7 +127,7 @@ func (o *OauthMetadataProvider) getTokenSourceAndMetadata(cfg *Config, tokenCach
authorizationMetadataKey = clientMetadata.AuthorizationMetadataKey
}

tokenSource, err := tokenSourceProvider.GetTokenSource(ctx)
tokenSource, err := NewInMemoryTokenSourceProvider(tokenCache).GetTokenSource(ctx)
if err != nil {
return fmt.Errorf("failed to get token source. Error: %w", err)
}
Expand Down Expand Up @@ -190,8 +184,11 @@ func NewAuthInterceptor(cfg *Config, tokenCache cache.TokenCache, credentialsFut
tokenSource := oauthMetadataProvider.tokenSource

err = MaterializeCredentials(tokenSource, cfg, authorizationMetadataKey, credentialsFuture)
if err != nil {
return fmt.Errorf("failed to materialize credentials. Error: %v", err)
if isValid := utils.Valid(t); isValid {
err = MaterializeCredentials(tokenSource, cfg, authorizationMetadataKey, credentialsFuture)
if err != nil {
return fmt.Errorf("failed to materialize credentials. Error: %v", err)
}

Check warning on line 191 in flyteidl/clients/go/admin/auth_interceptor.go

View check run for this annotation

Codecov / codecov/patch

flyteidl/clients/go/admin/auth_interceptor.go#L190-L191

Added lines #L190 - L191 were not covered by tests
}
}

Expand Down
12 changes: 5 additions & 7 deletions flyteidl/clients/go/admin/auth_interceptor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,16 @@ package admin

import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
"os"
"strings"
"sync"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
Expand All @@ -25,6 +24,7 @@ import (
"github.com/flyteorg/flyte/flyteidl/clients/go/admin/cache/mocks"
adminMocks "github.com/flyteorg/flyte/flyteidl/clients/go/admin/mocks"

"github.com/flyteorg/flyte/flyteidl/clients/go/admin/utils"
"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/service"
"github.com/flyteorg/flyte/flytestdlib/config"
"github.com/flyteorg/flyte/flytestdlib/logger"
Expand Down Expand Up @@ -137,10 +137,7 @@ func newAuthMetadataServer(t testing.TB, grpcPort int, httpPort int, impl servic
}

func Test_newAuthInterceptor(t *testing.T) {
plan, _ := os.ReadFile("tokenorchestrator/testdata/token.json")
var tokenData oauth2.Token
err := json.Unmarshal(plan, &tokenData)
assert.NoError(t, err)
tokenData := utils.GenTokenWithCustomExpiry(t, time.Now().Add(20*time.Minute))
t.Run("Other Error", func(t *testing.T) {
ctx := context.Background()
httpPort := rand.IntnRange(10000, 60000)
Expand All @@ -164,12 +161,13 @@ func Test_newAuthInterceptor(t *testing.T) {
f := NewPerRPCCredentialsFuture()
p := NewPerRPCCredentialsFuture()
mockTokenCache := &mocks.TokenCache{}
mockTokenCache.OnGetTokenMatch().Return(&tokenData, nil)
mockTokenCache.OnGetTokenMatch().Return(tokenData, nil)
mockTokenCache.OnSaveTokenMatch(mock.Anything).Return(nil)
interceptor := NewAuthInterceptor(&Config{
Endpoint: config.URL{URL: *u},
UseInsecureConnection: true,
}, mockTokenCache, f, p)

otherError := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error {
return status.New(codes.Canceled, "").Err()
}
Expand Down
23 changes: 7 additions & 16 deletions flyteidl/clients/go/admin/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,10 @@ package admin

import (
"context"
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"os"
"testing"
"time"

Expand All @@ -24,6 +21,7 @@ import (
"github.com/flyteorg/flyte/flyteidl/clients/go/admin/oauth"
"github.com/flyteorg/flyte/flyteidl/clients/go/admin/pkce"
"github.com/flyteorg/flyte/flyteidl/clients/go/admin/tokenorchestrator"
"github.com/flyteorg/flyte/flyteidl/clients/go/admin/utils"
"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/service"
"github.com/flyteorg/flyte/flytestdlib/config"
"github.com/flyteorg/flyte/flytestdlib/logger"
Expand Down Expand Up @@ -231,15 +229,11 @@ func TestGetAuthenticationDialOptionPkce(t *testing.T) {
RedirectUri: "http://localhost:54545/callback",
}
http.DefaultServeMux = http.NewServeMux()
plan, _ := os.ReadFile("tokenorchestrator/testdata/token.json")
var tokenData oauth2.Token
err := json.Unmarshal(plan, &tokenData)
assert.NoError(t, err)
tokenData.Expiry = time.Now().Add(time.Minute)
tokenData := utils.GenTokenWithCustomExpiry(t, time.Now().Add(time.Minute))
t.Run("cache hit", func(t *testing.T) {
mockTokenCache := new(cachemocks.TokenCache)
mockAuthClient := new(mocks.AuthMetadataServiceClient)
mockTokenCache.OnGetTokenMatch().Return(&tokenData, nil)
mockTokenCache.OnGetTokenMatch().Return(tokenData, nil)
mockTokenCache.OnSaveTokenMatch(mock.Anything).Return(nil)
mockAuthClient.OnGetOAuth2MetadataMatch(mock.Anything, mock.Anything).Return(metadata, nil)
mockAuthClient.OnGetPublicClientConfigMatch(mock.Anything, mock.Anything).Return(clientMetatadata, nil)
Expand All @@ -249,11 +243,11 @@ func TestGetAuthenticationDialOptionPkce(t *testing.T) {
assert.NotNil(t, dialOption)
assert.Nil(t, err)
})
tokenData.Expiry = time.Now().Add(-time.Minute)
t.Run("cache miss auth failure", func(t *testing.T) {
tokenData = utils.GenTokenWithCustomExpiry(t, time.Now().Add(-time.Minute))
mockTokenCache := new(cachemocks.TokenCache)
mockAuthClient := new(mocks.AuthMetadataServiceClient)
mockTokenCache.OnGetTokenMatch().Return(&tokenData, nil)
mockTokenCache.OnGetTokenMatch().Return(tokenData, nil)
mockTokenCache.OnSaveTokenMatch(mock.Anything).Return(nil)
mockTokenCache.On("Lock").Return()
mockTokenCache.On("Unlock").Return()
Expand Down Expand Up @@ -284,14 +278,11 @@ func Test_getPkceAuthTokenSource(t *testing.T) {
mockAuthClient.OnGetPublicClientConfigMatch(mock.Anything, mock.Anything).Return(clientMetatadata, nil)

t.Run("cached token expired", func(t *testing.T) {
plan, _ := ioutil.ReadFile("tokenorchestrator/testdata/token.json")
var tokenData oauth2.Token
err := json.Unmarshal(plan, &tokenData)
assert.NoError(t, err)
tokenData := utils.GenTokenWithCustomExpiry(t, time.Now().Add(-time.Minute))

// populate the cache
tokenCache := cache.NewTokenCacheInMemoryProvider()
assert.NoError(t, tokenCache.SaveToken(&tokenData))
assert.NoError(t, tokenCache.SaveToken(tokenData))

baseOrchestrator := tokenorchestrator.BaseTokenOrchestrator{
ClientConfig: &oauth.Config{
Expand Down
46 changes: 39 additions & 7 deletions flyteidl/clients/go/admin/token_source_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"github.com/flyteorg/flyte/flyteidl/clients/go/admin/externalprocess"
"github.com/flyteorg/flyte/flyteidl/clients/go/admin/pkce"
"github.com/flyteorg/flyte/flyteidl/clients/go/admin/tokenorchestrator"
"github.com/flyteorg/flyte/flyteidl/clients/go/admin/utils"
"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/service"
"github.com/flyteorg/flyte/flytestdlib/logger"
)
Expand Down Expand Up @@ -229,28 +230,36 @@ func (s *customTokenSource) Token() (*oauth2.Token, error) {
s.mu.Lock()
defer s.mu.Unlock()

if token, err := s.tokenCache.GetToken(); err == nil && token.Valid() {
return token, nil
token, err := s.tokenCache.GetToken()
if err != nil {
logger.Warnf(s.ctx, "failed to get token from cache: %v", err)

Check warning on line 235 in flyteidl/clients/go/admin/token_source_provider.go

View check run for this annotation

Codecov / codecov/patch

flyteidl/clients/go/admin/token_source_provider.go#L235

Added line #L235 was not covered by tests
} else {
if isValid := utils.Valid(token); isValid {
logger.Infof(context.Background(), "retrieved token from cache with expiry %v", token.Expiry)
return token, nil
}
}

totalAttempts := s.cfg.MaxRetries + 1 // Add one for initial request attempt
backoff := wait.Backoff{
Duration: s.cfg.PerRetryTimeout.Duration,
Steps: totalAttempts,
}
var token *oauth2.Token
err := retry.OnError(backoff, func(err error) bool {

err = retry.OnError(backoff, func(err error) bool {
return err != nil
}, func() (err error) {
token, err = s.new.Token()
if err != nil {
logger.Infof(s.ctx, "failed to get token: %w", err)
return fmt.Errorf("failed to get token: %w", err)
logger.Infof(s.ctx, "failed to get new token: %w", err)
return fmt.Errorf("failed to get new token: %w", err)
}
logger.Infof(context.Background(), "Fetched new token with expiry %v", token.Expiry)
return nil
})
if err != nil {
return nil, err
logger.Warnf(s.ctx, "failed to get new token: %v", err)
return nil, fmt.Errorf("failed to get new token: %w", err)
}
logger.Infof(s.ctx, "retrieved token with expiry %v", token.Expiry)

Expand All @@ -262,6 +271,29 @@ func (s *customTokenSource) Token() (*oauth2.Token, error) {
return token, nil
}

type InMemoryTokenSourceProvider struct {
tokenCache cache.TokenCache
}

func NewInMemoryTokenSourceProvider(tokenCache cache.TokenCache) TokenSourceProvider {
return InMemoryTokenSourceProvider{tokenCache: tokenCache}
}

func (i InMemoryTokenSourceProvider) GetTokenSource(ctx context.Context) (oauth2.TokenSource, error) {
return GetInMemoryAuthTokenSource(ctx, i.tokenCache)
}

// GetInMemoryAuthTokenSource Returns the token source with cached token
func GetInMemoryAuthTokenSource(ctx context.Context, tokenCache cache.TokenCache) (oauth2.TokenSource, error) {
authToken, err := tokenCache.GetToken()
if err != nil {
return nil, err
}

Check warning on line 291 in flyteidl/clients/go/admin/token_source_provider.go

View check run for this annotation

Codecov / codecov/patch

flyteidl/clients/go/admin/token_source_provider.go#L290-L291

Added lines #L290 - L291 were not covered by tests
return &pkce.SimpleTokenSource{
CachedToken: authToken,
}, nil
}

type DeviceFlowTokenSourceProvider struct {
tokenOrchestrator deviceflow.TokenOrchestrator
}
Expand Down
25 changes: 13 additions & 12 deletions flyteidl/clients/go/admin/token_source_provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (

tokenCacheMocks "github.com/flyteorg/flyte/flyteidl/clients/go/admin/cache/mocks"
adminMocks "github.com/flyteorg/flyte/flyteidl/clients/go/admin/mocks"
"github.com/flyteorg/flyte/flyteidl/clients/go/admin/utils"
"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/service"
)

Expand Down Expand Up @@ -88,9 +89,9 @@ func TestCustomTokenSource_Token(t *testing.T) {
minuteAgo := time.Now().Add(-time.Minute)
hourAhead := time.Now().Add(time.Hour)
twoHourAhead := time.Now().Add(2 * time.Hour)
invalidToken := oauth2.Token{AccessToken: "foo", Expiry: minuteAgo}
validToken := oauth2.Token{AccessToken: "foo", Expiry: hourAhead}
newToken := oauth2.Token{AccessToken: "foo", Expiry: twoHourAhead}
invalidToken := utils.GenTokenWithCustomExpiry(t, minuteAgo)
validToken := utils.GenTokenWithCustomExpiry(t, hourAhead)
newToken := utils.GenTokenWithCustomExpiry(t, twoHourAhead)

tests := []struct {
name string
Expand All @@ -101,24 +102,24 @@ func TestCustomTokenSource_Token(t *testing.T) {
{
name: "no cached token",
token: nil,
newToken: &newToken,
expectedToken: &newToken,
newToken: newToken,
expectedToken: newToken,
},
{
name: "cached token valid",
token: &validToken,
token: validToken,
newToken: nil,
expectedToken: &validToken,
expectedToken: validToken,
},
{
name: "cached token expired",
token: &invalidToken,
newToken: &newToken,
expectedToken: &newToken,
token: invalidToken,
newToken: newToken,
expectedToken: newToken,
},
{
name: "failed new token",
token: &invalidToken,
token: invalidToken,
newToken: nil,
expectedToken: nil,
},
Expand All @@ -138,7 +139,7 @@ func TestCustomTokenSource_Token(t *testing.T) {
assert.True(t, ok)

mockSource := &adminMocks.TokenSource{}
if test.token != &validToken {
if test.token != validToken {
if test.newToken != nil {
mockSource.OnToken().Return(test.newToken, nil)
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (

"github.com/flyteorg/flyte/flyteidl/clients/go/admin/cache"
"github.com/flyteorg/flyte/flyteidl/clients/go/admin/oauth"
"github.com/flyteorg/flyte/flyteidl/clients/go/admin/utils"
"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/service"
"github.com/flyteorg/flyte/flytestdlib/config"
"github.com/flyteorg/flyte/flytestdlib/logger"
Expand Down Expand Up @@ -52,7 +53,8 @@ func (t BaseTokenOrchestrator) FetchTokenFromCacheOrRefreshIt(ctx context.Contex
return nil, err
}

if token.Valid() {
if isValid := utils.Valid(token); isValid {
logger.Infof(context.Background(), "retrieved token from cache with expiry %v", token.Expiry)
return token, nil
}

Expand Down
Loading

0 comments on commit adc3ce2

Please sign in to comment.