Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Auth/prevent lookup per call #5686

Merged
merged 15 commits into from
Aug 23, 2024
10 changes: 7 additions & 3 deletions flytectl/cmd/configuration/configuration.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,13 @@ func CreateConfigCommand() *cobra.Command {
configCmd := viper.GetConfigCommand()

getResourcesFuncs := map[string]cmdcore.CommandEntry{
"init": {CmdFunc: configInitFunc, Aliases: []string{""}, ProjectDomainNotRequired: true,
Short: initCmdShort,
Long: initCmdLong, PFlagProvider: initConfig.DefaultConfig},
"init": {
CmdFunc: configInitFunc,
Aliases: []string{""},
ProjectDomainNotRequired: true,
DisableFlyteClient: true,
Short: initCmdShort,
Long: initCmdLong, PFlagProvider: initConfig.DefaultConfig},
}

configCmd.Flags().BoolVar(&initConfig.DefaultConfig.Force, "force", false, "Force to overwrite the default config file without confirmation")
Expand Down
4 changes: 2 additions & 2 deletions flytectl/cmd/core/cmd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ func TestGenerateCommandFunc(t *testing.T) {
adminCfg.Endpoint = config.URL{URL: url.URL{Host: "dummyHost"}}
adminCfg.AuthType = admin.AuthTypePkce
rootCmd := &cobra.Command{}
cmdEntry := CommandEntry{CmdFunc: testCommandFunc, ProjectDomainNotRequired: true}
cmdEntry := CommandEntry{CmdFunc: testCommandFunc, ProjectDomainNotRequired: true, DisableFlyteClient: true}
fn := generateCommandFunc(cmdEntry)
assert.Nil(t, fn(rootCmd, []string{}))
})
Expand All @@ -30,7 +30,7 @@ func TestGenerateCommandFunc(t *testing.T) {
adminCfg := admin.GetConfig(context.Background())
adminCfg.Endpoint = config.URL{URL: url.URL{Host: ""}}
rootCmd := &cobra.Command{}
cmdEntry := CommandEntry{CmdFunc: testCommandFunc, ProjectDomainNotRequired: true}
cmdEntry := CommandEntry{CmdFunc: testCommandFunc, ProjectDomainNotRequired: true, DisableFlyteClient: true}
fn := generateCommandFunc(cmdEntry)
assert.Nil(t, fn(rootCmd, []string{}))
})
Expand Down
108 changes: 80 additions & 28 deletions flyteidl/clients/go/admin/auth_interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"errors"
"fmt"
"net/http"
"sync"

"golang.org/x/oauth2"
"google.golang.org/grpc"
Expand All @@ -20,33 +21,10 @@

// MaterializeCredentials will attempt to build a TokenSource given the anonymously available information exposed by the server.
// Once established, it'll invoke PerRPCCredentialsFuture.Store() on perRPCCredentials to populate it with the appropriate values.
func MaterializeCredentials(ctx context.Context, cfg *Config, tokenCache cache.TokenCache,
perRPCCredentials *PerRPCCredentialsFuture, proxyCredentialsFuture *PerRPCCredentialsFuture) error {
authMetadataClient, err := InitializeAuthMetadataClient(ctx, cfg, proxyCredentialsFuture)
if err != nil {
return fmt.Errorf("failed to initialized Auth Metadata Client. Error: %w", err)
}

tokenSourceProvider, err := NewTokenSourceProvider(ctx, cfg, tokenCache, authMetadataClient)
if err != nil {
return fmt.Errorf("failed to initialized token source provider. Err: %w", err)
}

authorizationMetadataKey := cfg.AuthorizationHeader
if len(authorizationMetadataKey) == 0 {
clientMetadata, err := authMetadataClient.GetPublicClientConfig(ctx, &service.PublicClientAuthConfigRequest{})
if err != nil {
return fmt.Errorf("failed to fetch client metadata. Error: %v", err)
}
authorizationMetadataKey = clientMetadata.AuthorizationMetadataKey
}

tokenSource, err := tokenSourceProvider.GetTokenSource(ctx)
if err != nil {
return fmt.Errorf("failed to get token source. Error: %w", err)
}
func MaterializeCredentials(tokenSource oauth2.TokenSource, cfg *Config, authorizationMetadataKey string,
perRPCCredentials *PerRPCCredentialsFuture) error {

_, err = tokenSource.Token()
_, err := tokenSource.Token()
if err != nil {
return fmt.Errorf("failed to issue token. Error: %w", err)
}
Expand Down Expand Up @@ -127,6 +105,60 @@
return context.WithValue(ctx, oauth2.HTTPClient, httpClient)
}

type OauthMetadataProvider struct {
authorizationMetadataKey string
tokenSource oauth2.TokenSource
once sync.Once
}

func (o *OauthMetadataProvider) getTokenSourceAndMetadata(cfg *Config, tokenCache cache.TokenCache, proxyCredentialsFuture *PerRPCCredentialsFuture) error {
ctx := context.Background()

authMetadataClient, err := InitializeAuthMetadataClient(ctx, cfg, proxyCredentialsFuture)
if err != nil {
return fmt.Errorf("failed to initialized Auth Metadata Client. Error: %w", err)

Check warning on line 119 in flyteidl/clients/go/admin/auth_interceptor.go

View check run for this annotation

Codecov / codecov/patch

flyteidl/clients/go/admin/auth_interceptor.go#L119

Added line #L119 was not covered by tests
}

tokenSourceProvider, err := NewTokenSourceProvider(ctx, cfg, tokenCache, authMetadataClient)
if err != nil {
return fmt.Errorf("failed to initialize token source provider. Err: %w", err)

Check warning on line 124 in flyteidl/clients/go/admin/auth_interceptor.go

View check run for this annotation

Codecov / codecov/patch

flyteidl/clients/go/admin/auth_interceptor.go#L124

Added line #L124 was not covered by tests
}

authorizationMetadataKey := cfg.AuthorizationHeader
if len(authorizationMetadataKey) == 0 {
clientMetadata, err := authMetadataClient.GetPublicClientConfig(ctx, &service.PublicClientAuthConfigRequest{})
if err != nil {
return fmt.Errorf("failed to fetch client metadata. Error: %v", err)
}
authorizationMetadataKey = clientMetadata.AuthorizationMetadataKey
}

tokenSource, err := tokenSourceProvider.GetTokenSource(ctx)
if err != nil {
return fmt.Errorf("failed to get token source. Error: %w", err)

Check warning on line 138 in flyteidl/clients/go/admin/auth_interceptor.go

View check run for this annotation

Codecov / codecov/patch

flyteidl/clients/go/admin/auth_interceptor.go#L138

Added line #L138 was not covered by tests
}

o.authorizationMetadataKey = authorizationMetadataKey
o.tokenSource = tokenSource

return nil
}

func (o *OauthMetadataProvider) GetOauthMetadata(cfg *Config, tokenCache cache.TokenCache, proxyCredentialsFuture *PerRPCCredentialsFuture) error {
// Ensure loadTokenRelated() is only executed once
var err error
o.once.Do(func() {
err = o.getTokenSourceAndMetadata(cfg, tokenCache, proxyCredentialsFuture)
if err != nil {
logger.Errorf(context.Background(), "Failed to load token related config. Error: %v", err)
}
})
if err != nil {
return err
}
return nil
}

// NewAuthInterceptor creates a new grpc.UnaryClientInterceptor that forwards the grpc call and inspects the error.
// It will first invoke the grpc pipeline (to proceed with the request) with no modifications. It's expected for the grpc
// pipeline to already have a grpc.WithPerRPCCredentials() DialOption. If the perRPCCredentials has already been initialized,
Expand All @@ -138,13 +170,26 @@
// a token source has been created, it'll invoke the grpc pipeline again, this time the grpc.PerRPCCredentials should
// be able to find and acquire a valid AccessToken to annotate the request with.
func NewAuthInterceptor(cfg *Config, tokenCache cache.TokenCache, credentialsFuture *PerRPCCredentialsFuture, proxyCredentialsFuture *PerRPCCredentialsFuture) grpc.UnaryClientInterceptor {

oauthMetadataProvider := OauthMetadataProvider{
once: sync.Once{},
}

return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {

ctx = setHTTPClientContext(ctx, cfg, proxyCredentialsFuture)

// If there is already a token in the cache (e.g. key-ring), we should use it immediately...
t, _ := tokenCache.GetToken()
if t != nil {
err := MaterializeCredentials(ctx, cfg, tokenCache, credentialsFuture, proxyCredentialsFuture)
err := oauthMetadataProvider.GetOauthMetadata(cfg, tokenCache, proxyCredentialsFuture)
if err != nil {
return err
}
authorizationMetadataKey := oauthMetadataProvider.authorizationMetadataKey
tokenSource := oauthMetadataProvider.tokenSource

err = MaterializeCredentials(tokenSource, cfg, authorizationMetadataKey, credentialsFuture)
if err != nil {
return fmt.Errorf("failed to materialize credentials. Error: %v", err)
}
Expand All @@ -157,6 +202,13 @@
if st, ok := status.FromError(err); ok {
// If the error we receive from executing the request expects
if shouldAttemptToAuthenticate(st.Code()) {
err := oauthMetadataProvider.GetOauthMetadata(cfg, tokenCache, proxyCredentialsFuture)
if err != nil {
return err

Check warning on line 207 in flyteidl/clients/go/admin/auth_interceptor.go

View check run for this annotation

Codecov / codecov/patch

flyteidl/clients/go/admin/auth_interceptor.go#L207

Added line #L207 was not covered by tests
}
authorizationMetadataKey := oauthMetadataProvider.authorizationMetadataKey
tokenSource := oauthMetadataProvider.tokenSource

err = func() error {
if !tokenCache.TryLock() {
tokenCache.CondWait()
Expand All @@ -171,7 +223,7 @@
}

logger.Debugf(ctx, "Request failed due to [%v]. Attempting to establish an authenticated connection and trying again.", st.Code())
newErr := MaterializeCredentials(ctx, cfg, tokenCache, credentialsFuture, proxyCredentialsFuture)
newErr := MaterializeCredentials(tokenSource, cfg, authorizationMetadataKey, credentialsFuture)
if newErr != nil {
errString := fmt.Sprintf("authentication error! Original Error: %v, Auth Error: %v", err, newErr)
logger.Errorf(ctx, errString)
Expand Down
88 changes: 78 additions & 10 deletions flyteidl/clients/go/admin/auth_interceptor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (

"github.com/flyteorg/flyte/flyteidl/clients/go/admin/cache/mocks"
adminMocks "github.com/flyteorg/flyte/flyteidl/clients/go/admin/mocks"

"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/service"
"github.com/flyteorg/flyte/flytestdlib/config"
"github.com/flyteorg/flyte/flytestdlib/logger"
Expand Down Expand Up @@ -141,11 +142,34 @@ func Test_newAuthInterceptor(t *testing.T) {
err := json.Unmarshal(plan, &tokenData)
assert.NoError(t, err)
t.Run("Other Error", func(t *testing.T) {
ctx := context.Background()
httpPort := rand.IntnRange(10000, 60000)
grpcPort := rand.IntnRange(10000, 60000)
m := &adminMocks.AuthMetadataServiceServer{}
m.OnGetOAuth2MetadataMatch(mock.Anything, mock.Anything).Return(&service.OAuth2MetadataResponse{
AuthorizationEndpoint: fmt.Sprintf("http://localhost:%d/oauth2/authorize", httpPort),
TokenEndpoint: fmt.Sprintf("http://localhost:%d/oauth2/token", httpPort),
JwksUri: fmt.Sprintf("http://localhost:%d/oauth2/jwks", httpPort),
}, nil)

m.OnGetPublicClientConfigMatch(mock.Anything, mock.Anything).Return(&service.PublicClientAuthConfigResponse{
Scopes: []string{"all"},
}, nil)

s := newAuthMetadataServer(t, grpcPort, httpPort, m)
assert.NoError(t, s.Start(ctx))
defer s.Close()
u, err := url.Parse(fmt.Sprintf("dns:///localhost:%d", grpcPort))
assert.NoError(t, err)
f := NewPerRPCCredentialsFuture()
p := NewPerRPCCredentialsFuture()
mockTokenCache := &mocks.TokenCache{}
mockTokenCache.OnGetTokenMatch().Return(&tokenData, nil)
interceptor := NewAuthInterceptor(&Config{}, mockTokenCache, f, p)
mockTokenCache.OnSaveTokenMatch(mock.Anything).Return(nil)
interceptor := NewAuthInterceptor(&Config{
Endpoint: config.URL{URL: *u},
UseInsecureConnection: true,
}, mockTokenCache, f, p)
otherError := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error {
return status.New(codes.Canceled, "").Err()
}
Expand Down Expand Up @@ -209,6 +233,14 @@ func Test_newAuthInterceptor(t *testing.T) {
httpPort := rand.IntnRange(10000, 60000)
grpcPort := rand.IntnRange(10000, 60000)
m := &adminMocks.AuthMetadataServiceServer{}
m.OnGetOAuth2MetadataMatch(mock.Anything, mock.Anything).Return(&service.OAuth2MetadataResponse{
AuthorizationEndpoint: fmt.Sprintf("http://localhost:%d/oauth2/authorize", httpPort),
TokenEndpoint: fmt.Sprintf("http://localhost:%d/oauth2/token", httpPort),
JwksUri: fmt.Sprintf("http://localhost:%d/oauth2/jwks", httpPort),
}, nil)
m.OnGetPublicClientConfigMatch(mock.Anything, mock.Anything).Return(&service.PublicClientAuthConfigResponse{
Scopes: []string{"all"},
}, nil)
s := newAuthMetadataServer(t, grpcPort, httpPort, m)
ctx := context.Background()
assert.NoError(t, s.Start(ctx))
Expand Down Expand Up @@ -283,12 +315,13 @@ func Test_newAuthInterceptor(t *testing.T) {
})
}

func TestMaterializeCredentials(t *testing.T) {
func TestNewAuthInterceptorAndMaterialize(t *testing.T) {
t.Run("No oauth2 metadata endpoint or Public client config lookup", func(t *testing.T) {
httpPort := rand.IntnRange(10000, 60000)
grpcPort := rand.IntnRange(10000, 60000)
fakeToken := &oauth2.Token{}
c := &mocks.TokenCache{}
c.OnGetTokenMatch().Return(nil, nil)
c.OnGetTokenMatch().Return(fakeToken, nil)
c.OnSaveTokenMatch(mock.Anything).Return(nil)
m := &adminMocks.AuthMetadataServiceServer{}
m.OnGetOAuth2MetadataMatch(mock.Anything, mock.Anything).Return(nil, errors.New("unexpected call to get oauth2 metadata"))
Expand All @@ -304,22 +337,30 @@ func TestMaterializeCredentials(t *testing.T) {
f := NewPerRPCCredentialsFuture()
p := NewPerRPCCredentialsFuture()

err = MaterializeCredentials(ctx, &Config{
cfg := &Config{
Endpoint: config.URL{URL: *u},
UseInsecureConnection: true,
AuthType: AuthTypeClientSecret,
TokenURL: fmt.Sprintf("http://localhost:%d/oauth2/token", httpPort),
Scopes: []string{"all"},
Audience: "http://localhost:30081",
AuthorizationHeader: "authorization",
}, c, f, p)
}

intercept := NewAuthInterceptor(cfg, c, f, p)
// Invoke Materialize inside the intercept
err = intercept(ctx, "GET", nil, nil, nil, func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error {
return nil
})
assert.NoError(t, err)
})

t.Run("Failed to fetch client metadata", func(t *testing.T) {
httpPort := rand.IntnRange(10000, 60000)
grpcPort := rand.IntnRange(10000, 60000)
c := &mocks.TokenCache{}
c.OnGetTokenMatch().Return(nil, nil)
fakeToken := &oauth2.Token{}
c.OnGetTokenMatch().Return(fakeToken, nil)
c.OnSaveTokenMatch(mock.Anything).Return(nil)
m := &adminMocks.AuthMetadataServiceServer{}
m.OnGetOAuth2MetadataMatch(mock.Anything, mock.Anything).Return(nil, errors.New("unexpected call to get oauth2 metadata"))
Expand All @@ -333,17 +374,44 @@ func TestMaterializeCredentials(t *testing.T) {
u, err := url.Parse(fmt.Sprintf("dns:///localhost:%d", grpcPort))
assert.NoError(t, err)

cfg := &Config{
Endpoint: config.URL{URL: *u},
UseInsecureConnection: true,
AuthType: AuthTypeClientSecret,
TokenURL: fmt.Sprintf("http://localhost:%d/api/v1/token", httpPort),
Scopes: []string{"all"},
}
f := NewPerRPCCredentialsFuture()
p := NewPerRPCCredentialsFuture()
intercept := NewAuthInterceptor(cfg, c, f, p)
err = intercept(ctx, "GET", nil, nil, nil, func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error {
return nil
})
assert.EqualError(t, err, "failed to fetch client metadata. Error: rpc error: code = Unknown desc = expected err")
})
}

func TestSimpleMaterializeCredentials(t *testing.T) {
t.Run("simple materialize", func(t *testing.T) {
httpPort := rand.IntnRange(10000, 60000)
grpcPort := rand.IntnRange(10000, 60000)
u, err := url.Parse(fmt.Sprintf("dns:///localhost:%d", grpcPort))
assert.NoError(t, err)

f := NewPerRPCCredentialsFuture()

err = MaterializeCredentials(ctx, &Config{
dummySource := DummyTestTokenSource{}

err = MaterializeCredentials(dummySource, &Config{
Endpoint: config.URL{URL: *u},
UseInsecureConnection: true,
AuthType: AuthTypeClientSecret,
TokenURL: fmt.Sprintf("http://localhost:%d/api/v1/token", httpPort),
TokenURL: fmt.Sprintf("http://localhost:%d/oauth2/token", httpPort),
Scopes: []string{"all"},
}, c, f, p)
assert.EqualError(t, err, "failed to fetch client metadata. Error: rpc error: code = Unknown desc = expected err")
Audience: "http://localhost:30081",
AuthorizationHeader: "authorization",
}, "authorization", f)
assert.NoError(t, err)
})
}

Expand Down
3 changes: 2 additions & 1 deletion flyteidl/clients/go/admin/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,8 +179,9 @@ func initializeClients(ctx context.Context, cfg *Config, tokenCache cache.TokenC
credentialsFuture := NewPerRPCCredentialsFuture()
proxyCredentialsFuture := NewPerRPCCredentialsFuture()

authInterceptor := NewAuthInterceptor(cfg, tokenCache, credentialsFuture, proxyCredentialsFuture)
opts = append(opts,
grpc.WithChainUnaryInterceptor(NewAuthInterceptor(cfg, tokenCache, credentialsFuture, proxyCredentialsFuture)),
grpc.WithChainUnaryInterceptor(authInterceptor),
grpc.WithPerRPCCredentials(credentialsFuture))

if cfg.DefaultServiceConfig != "" {
Expand Down
Loading
Loading