Skip to content
This repository has been archived by the owner on Oct 23, 2023. It is now read-only.

Feat: Enable proxy-authorization in admin client #437

Closed
wants to merge 16 commits into from
Closed
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 77 additions & 8 deletions clients/go/admin/auth_interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package admin

import (
"context"
"errors"
"fmt"
"net/http"

Expand All @@ -16,10 +17,12 @@ import (
"google.golang.org/grpc"
)

const ProxyAuthorizationHeader = "proxy-authorization"

// MaterializeCredentials will attempt to build a TokenSource given the anonymously available information exposed by the server.
// Once established, it'll invoke PerRPCCredentialsFuture.Store() on perRPCCredentials to populate it with the appropriate values.
func MaterializeCredentials(ctx context.Context, cfg *Config, tokenCache cache.TokenCache, perRPCCredentials *PerRPCCredentialsFuture) error {
authMetadataClient, err := InitializeAuthMetadataClient(ctx, cfg)
func MaterializeCredentials(ctx context.Context, cfg *Config, tokenCache cache.TokenCache, perRPCCredentials *PerRPCCredentialsFuture, proxyCredentialsFuture *PerRPCCredentialsFuture) error {
authMetadataClient, err := InitializeAuthMetadataClient(ctx, cfg, proxyCredentialsFuture)
if err != nil {
return fmt.Errorf("failed to initialized Auth Metadata Client. Error: %w", err)
}
Expand Down Expand Up @@ -48,19 +51,70 @@ func MaterializeCredentials(ctx context.Context, cfg *Config, tokenCache cache.T
return nil
}

func GetProxyTokenSource(ctx context.Context, cfg *Config) (oauth2.TokenSource, error) {
tokenSourceProvider, err := NewExternalTokenSourceProvider(cfg.ProxyCommand)
if err != nil {
return nil, fmt.Errorf("failed to initialized proxy authorization token source provider. Err: %w", err)
}
proxyTokenSource, err := tokenSourceProvider.GetTokenSource(ctx)
if err != nil {
return nil, err
}
return proxyTokenSource, nil
}

func MaterializeProxyAuthCredentials(ctx context.Context, cfg *Config, proxyCredentialsFuture *PerRPCCredentialsFuture) error {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adapted from MaterializeAuthCredentials.

proxyTokenSource, err := GetProxyTokenSource(ctx, cfg)
if err != nil {
return err
}

wrappedTokenSource := NewCustomHeaderTokenSource(proxyTokenSource, cfg.UseInsecureConnection, ProxyAuthorizationHeader)
proxyCredentialsFuture.Store(wrappedTokenSource)

return nil
}

func shouldAttemptToAuthenticate(errorCode codes.Code) bool {
return errorCode == codes.Unauthenticated
}

type proxyAuthTransport struct {
transport http.RoundTripper
proxyCredentialsFuture *PerRPCCredentialsFuture
}

func (c *proxyAuthTransport) RoundTrip(req *http.Request) (*http.Response, error) {
// check if the proxy credentials future is initialized
if !c.proxyCredentialsFuture.IsInitialized() {
return nil, errors.New("proxy credentials future is not initialized")
}

metadata, err := c.proxyCredentialsFuture.GetRequestMetadata(context.Background(), "")
if err != nil {
return nil, err
}
token := metadata[ProxyAuthorizationHeader]
req.Header.Add(ProxyAuthorizationHeader, token)
return c.transport.RoundTrip(req)
}

// Set up http client used in oauth2
func setHTTPClientContext(ctx context.Context, cfg *Config) context.Context {
func setHTTPClientContext(ctx context.Context, cfg *Config, proxyCredentialsFuture *PerRPCCredentialsFuture) context.Context {
httpClient := &http.Client{}
transport := &http.Transport{}

if len(cfg.HTTPProxyURL.String()) > 0 {
// create a transport that uses the proxy
transport := &http.Transport{
Proxy: http.ProxyURL(&cfg.HTTPProxyURL.URL),
transport.Proxy = http.ProxyURL(&cfg.HTTPProxyURL.URL)
}

if cfg.ProxyCommand != nil {
httpClient.Transport = &proxyAuthTransport{
transport: transport,
proxyCredentialsFuture: proxyCredentialsFuture,
}
} else {
httpClient.Transport = transport
}

Expand All @@ -77,9 +131,9 @@ func setHTTPClientContext(ctx context.Context, cfg *Config) context.Context {
// more. It'll fail hard if it couldn't do so (i.e. it will no longer attempt to send an unauthenticated request). Once
// a token source has been created, it'll invoke the grpc pipeline again, this time the grpc.PerRPCCredentials should
// be able to find and acquire a valid AccessToken to annotate the request with.
func NewAuthInterceptor(cfg *Config, tokenCache cache.TokenCache, credentialsFuture *PerRPCCredentialsFuture) grpc.UnaryClientInterceptor {
func NewAuthInterceptor(cfg *Config, tokenCache cache.TokenCache, credentialsFuture *PerRPCCredentialsFuture, proxyCredentialsFuture *PerRPCCredentialsFuture) grpc.UnaryClientInterceptor {
return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
ctx = setHTTPClientContext(ctx, cfg)
ctx = setHTTPClientContext(ctx, cfg, proxyCredentialsFuture)

err := invoker(ctx, method, req, reply, cc, opts...)
if err != nil {
Expand All @@ -89,7 +143,7 @@ func NewAuthInterceptor(cfg *Config, tokenCache cache.TokenCache, credentialsFut
// If the error we receive from executing the request expects
if shouldAttemptToAuthenticate(st.Code()) {
logger.Debugf(ctx, "Request failed due to [%v]. Attempting to establish an authenticated connection and trying again.", st.Code())
newErr := MaterializeCredentials(ctx, cfg, tokenCache, credentialsFuture)
newErr := MaterializeCredentials(ctx, cfg, tokenCache, credentialsFuture, proxyCredentialsFuture)
if newErr != nil {
return fmt.Errorf("authentication error! Original Error: %v, Auth Error: %w", err, newErr)
}
Expand All @@ -102,3 +156,18 @@ func NewAuthInterceptor(cfg *Config, tokenCache cache.TokenCache, credentialsFut
return err
}
}

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adapted from NewAuthInterceptor.

func NewProxyAuthInterceptor(cfg *Config, proxyCredentialsFuture *PerRPCCredentialsFuture) grpc.UnaryClientInterceptor {
return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {

err := invoker(ctx, method, req, reply, cc, opts...)
if err != nil {
newErr := MaterializeProxyAuthCredentials(ctx, cfg, proxyCredentialsFuture)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The external command is called a single time per call of flytectl. After the call of the external command, the credentials future is initialized and can be reused for further grpc calls.

if newErr != nil {
return fmt.Errorf("proxy authorization error! Original Error: %v, Proxy Auth Error: %w", err, newErr)
}
return invoker(ctx, method, req, reply, cc, opts...)
}
return err
}
}
129 changes: 123 additions & 6 deletions clients/go/admin/auth_interceptor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"golang.org/x/oauth2"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
Expand Down Expand Up @@ -114,7 +115,8 @@ func newAuthMetadataServer(t testing.TB, port int, impl service.AuthMetadataServ
func Test_newAuthInterceptor(t *testing.T) {
t.Run("Other Error", func(t *testing.T) {
f := NewPerRPCCredentialsFuture()
interceptor := NewAuthInterceptor(&Config{}, &mocks.TokenCache{}, f)
p := NewPerRPCCredentialsFuture()
interceptor := NewAuthInterceptor(&Config{}, &mocks.TokenCache{}, f, p)
otherError := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error {
return status.New(codes.Canceled, "").Err()
}
Expand Down Expand Up @@ -146,11 +148,12 @@ func Test_newAuthInterceptor(t *testing.T) {
assert.NoError(t, err)

f := NewPerRPCCredentialsFuture()
p := NewPerRPCCredentialsFuture()
interceptor := NewAuthInterceptor(&Config{
Endpoint: config.URL{URL: *u},
UseInsecureConnection: true,
AuthType: AuthTypeClientSecret,
}, &mocks.TokenCache{}, f)
}, &mocks.TokenCache{}, f, p)
unauthenticated := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error {
return status.New(codes.Unauthenticated, "").Err()
}
Expand All @@ -177,11 +180,13 @@ func Test_newAuthInterceptor(t *testing.T) {
assert.NoError(t, err)

f := NewPerRPCCredentialsFuture()
p := NewPerRPCCredentialsFuture()

interceptor := NewAuthInterceptor(&Config{
Endpoint: config.URL{URL: *u},
UseInsecureConnection: true,
AuthType: AuthTypeClientSecret,
}, &mocks.TokenCache{}, f)
}, &mocks.TokenCache{}, f, p)
authenticated := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error {
return nil
}
Expand Down Expand Up @@ -216,11 +221,13 @@ func Test_newAuthInterceptor(t *testing.T) {
assert.NoError(t, err)

f := NewPerRPCCredentialsFuture()
p := NewPerRPCCredentialsFuture()

interceptor := NewAuthInterceptor(&Config{
Endpoint: config.URL{URL: *u},
UseInsecureConnection: true,
AuthType: AuthTypeClientSecret,
}, &mocks.TokenCache{}, f)
}, &mocks.TokenCache{}, f, p)
unauthenticated := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error {
return status.New(codes.Aborted, "").Err()
}
Expand All @@ -246,6 +253,8 @@ func TestMaterializeCredentials(t *testing.T) {
assert.NoError(t, err)

f := NewPerRPCCredentialsFuture()
p := NewPerRPCCredentialsFuture()

err = MaterializeCredentials(ctx, &Config{
Endpoint: config.URL{URL: *u},
UseInsecureConnection: true,
Expand All @@ -254,7 +263,7 @@ func TestMaterializeCredentials(t *testing.T) {
Scopes: []string{"all"},
Audience: "http://localhost:30081",
AuthorizationHeader: "authorization",
}, &mocks.TokenCache{}, f)
}, &mocks.TokenCache{}, f, p)
assert.NoError(t, err)
})
t.Run("Failed to fetch client metadata", func(t *testing.T) {
Expand All @@ -271,13 +280,121 @@ func TestMaterializeCredentials(t *testing.T) {
assert.NoError(t, err)

f := NewPerRPCCredentialsFuture()
p := NewPerRPCCredentialsFuture()

err = MaterializeCredentials(ctx, &Config{
Endpoint: config.URL{URL: *u},
UseInsecureConnection: true,
AuthType: AuthTypeClientSecret,
TokenURL: fmt.Sprintf("http://localhost:%d/api/v1/token", port),
Scopes: []string{"all"},
}, &mocks.TokenCache{}, f)
}, &mocks.TokenCache{}, f, p)
assert.EqualError(t, err, "failed to fetch client metadata. Error: rpc error: code = Unknown desc = expected err")
})
}

func TestNewProxyAuthInterceptor(t *testing.T) {
cfg := &Config{
ProxyCommand: []string{"echo", "test-token"},
}

p := NewPerRPCCredentialsFuture()

interceptor := NewProxyAuthInterceptor(cfg, p)

ctx := context.Background()
method := "/test.method"
req := "request"
reply := "reply"
cc := new(grpc.ClientConn)

errorInvoker := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error {
return errors.New("test error")
}

// Call should return an error and trigger the interceptor to materialize proxy auth credentials
err := interceptor(ctx, method, req, reply, cc, errorInvoker)
assert.Error(t, err)

// Check if proxyCredentialsFuture contains a proxy auth header token
creds, err := p.Get().GetRequestMetadata(ctx, "")
assert.True(t, p.IsInitialized())
assert.NoError(t, err)
assert.Equal(t, "Bearer test-token", creds[ProxyAuthorizationHeader])
}

type testRoundTripper struct {
RoundTripFunc func(req *http.Request) (*http.Response, error)
}

func (t *testRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
return t.RoundTripFunc(req)
}

func TestSetHTTPClientContext(t *testing.T) {
ctx := context.Background()

t.Run("no proxy command and no proxy url", func(t *testing.T) {
cfg := &Config{}

newCtx, err := setHTTPClientContext(ctx, cfg, nil)
assert.NoError(t, err)

httpClient, ok := newCtx.Value(oauth2.HTTPClient).(*http.Client)
assert.True(t, ok)

transport, ok := httpClient.Transport.(*http.Transport)
assert.True(t, ok)
assert.Nil(t, transport.Proxy)
})

t.Run("proxy url", func(t *testing.T) {
cfg := &Config{
HTTPProxyURL: config.
URL{URL: url.URL{
Scheme: "http",
Host: "localhost:8080",
}},
}
newCtx, err := setHTTPClientContext(ctx, cfg, nil)
assert.NoError(t, err)

httpClient, ok := newCtx.Value(oauth2.HTTPClient).(*http.Client)
assert.True(t, ok)

transport, ok := httpClient.Transport.(*http.Transport)
assert.True(t, ok)
assert.NotNil(t, transport.Proxy)
})

t.Run("proxy command adds proxy-authorization header", func(t *testing.T) {
cfg := &Config{
ProxyCommand: []string{"echo", "test-token-http-client"},
}

p := NewPerRPCCredentialsFuture()
MaterializeProxyAuthCredentials(ctx, cfg, p)

newCtx, err := setHTTPClientContext(ctx, cfg, p)
assert.NoError(t, err)

httpClient, ok := newCtx.Value(oauth2.HTTPClient).(*http.Client)
assert.True(t, ok)

pat, ok := httpClient.Transport.(*proxyAuthTransport)
assert.True(t, ok)

testRoundTripper := &testRoundTripper{
RoundTripFunc: func(req *http.Request) (*http.Response, error) {
// Check if the ProxyAuthorizationHeader is correctly set
assert.Equal(t, "Bearer test-token-http-client", req.Header.Get(ProxyAuthorizationHeader))
return &http.Response{StatusCode: http.StatusOK}, nil
},
}
pat.transport = testRoundTripper

req, _ := http.NewRequest("GET", "http://example.com", nil)
_, err = httpClient.Do(req)
assert.NoError(t, err)
})
}
19 changes: 13 additions & 6 deletions clients/go/admin/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,21 +110,21 @@ func getAuthenticationDialOption(ctx context.Context, cfg *Config, tokenSourcePr
}

// InitializeAuthMetadataClient creates a new anonymously Auth Metadata Service client.
func InitializeAuthMetadataClient(ctx context.Context, cfg *Config) (client service.AuthMetadataServiceClient, err error) {
func InitializeAuthMetadataClient(ctx context.Context, cfg *Config, proxyCredentialsFuture *PerRPCCredentialsFuture) (client service.AuthMetadataServiceClient, err error) {
// Create an unauthenticated connection to fetch AuthMetadata
authMetadataConnection, err := NewAdminConnection(ctx, cfg)
authMetadataConnection, err := NewAdminConnection(ctx, cfg, proxyCredentialsFuture)
if err != nil {
return nil, fmt.Errorf("failed to initialized admin connection. Error: %w", err)
}

return service.NewAuthMetadataServiceClient(authMetadataConnection), nil
}

func NewAdminConnection(ctx context.Context, cfg *Config, opts ...grpc.DialOption) (*grpc.ClientConn, error) {
func NewAdminConnection(ctx context.Context, cfg *Config, proxyCredentialsFuture *PerRPCCredentialsFuture, opts ...grpc.DialOption) (*grpc.ClientConn, error) {
if opts == nil {
// Initialize opts list to the potential number of options we will add. Initialization optimizes memory
// allocation.
opts = make([]grpc.DialOption, 0, 5)
opts = make([]grpc.DialOption, 0, 7)
}

if cfg.UseInsecureConnection {
Expand Down Expand Up @@ -153,6 +153,11 @@ func NewAdminConnection(ctx context.Context, cfg *Config, opts ...grpc.DialOptio

opts = append(opts, GetAdditionalAdminClientConfigOptions(cfg)...)

if cfg.ProxyCommand != nil {
opts = append(opts, grpc.WithChainUnaryInterceptor(NewProxyAuthInterceptor(cfg, proxyCredentialsFuture)))
opts = append(opts, grpc.WithPerRPCCredentials(proxyCredentialsFuture))
}

return grpc.Dial(cfg.Endpoint.String(), opts...)
}

Expand All @@ -172,15 +177,17 @@ func InitializeAdminClient(ctx context.Context, cfg *Config, opts ...grpc.DialOp
// for the process. Note that if called with different cfg/dialoptions, it will not refresh the connection.
func initializeClients(ctx context.Context, cfg *Config, tokenCache cache.TokenCache, opts ...grpc.DialOption) (*Clientset, error) {
credentialsFuture := NewPerRPCCredentialsFuture()
proxyCredentialsFuture := NewPerRPCCredentialsFuture()

opts = append(opts,
grpc.WithChainUnaryInterceptor(NewAuthInterceptor(cfg, tokenCache, credentialsFuture)),
grpc.WithChainUnaryInterceptor(NewAuthInterceptor(cfg, tokenCache, credentialsFuture, proxyCredentialsFuture)),
grpc.WithPerRPCCredentials(credentialsFuture))

if cfg.DefaultServiceConfig != "" {
opts = append(opts, grpc.WithDefaultServiceConfig(cfg.DefaultServiceConfig))
}

adminConnection, err := NewAdminConnection(ctx, cfg, opts...)
adminConnection, err := NewAdminConnection(ctx, cfg, proxyCredentialsFuture, opts...)
if err != nil {
logger.Panicf(ctx, "failed to initialized Admin connection. Err: %s", err.Error())
}
Expand Down
Loading