diff --git a/api-presigned.go b/api-presigned.go index 9e85f8181..29642200e 100644 --- a/api-presigned.go +++ b/api-presigned.go @@ -140,7 +140,7 @@ func (c *Client) PresignedPostPolicy(ctx context.Context, p *PostPolicy) (u *url } // Get credentials from the configured credentials provider. - credValues, err := c.credsProvider.Get() + credValues, err := c.credsProvider.GetWithContext(c.CredContext()) if err != nil { return nil, nil, err } diff --git a/api.go b/api.go index 86d072e4f..5bcd903e3 100644 --- a/api.go +++ b/api.go @@ -600,9 +600,9 @@ func (c *Client) executeMethod(ctx context.Context, method string, metadata requ return nil, errors.New(c.endpointURL.String() + " is offline.") } - var retryable bool // Indicates if request can be retried. - var bodySeeker io.Seeker // Extracted seeker from io.Reader. - var reqRetry = c.maxRetries // Indicates how many times we can retry the request + var retryable bool // Indicates if request can be retried. + var bodySeeker io.Seeker // Extracted seeker from io.Reader. + reqRetry := c.maxRetries // Indicates how many times we can retry the request if metadata.contentBody != nil { // Check if body is seekable then it is retryable. @@ -808,7 +808,7 @@ func (c *Client) newRequest(ctx context.Context, method string, metadata request } // Get credentials from the configured credentials provider. - value, err := c.credsProvider.Get() + value, err := c.credsProvider.GetWithContext(c.CredContext()) if err != nil { return nil, err } @@ -1018,3 +1018,14 @@ func (c *Client) isVirtualHostStyleRequest(url url.URL, bucketName string) bool // path style requests return s3utils.IsVirtualHostSupported(url, bucketName) } + +// CredContext returns the context for fetching credentials +func (c *Client) CredContext() *credentials.CredContext { + httpClient := c.httpClient + if httpClient == nil { + httpClient = http.DefaultClient + } + return &credentials.CredContext{ + Client: httpClient, + } +} diff --git a/bucket-cache.go b/bucket-cache.go index b1d3b3852..4e4305acd 100644 --- a/bucket-cache.go +++ b/bucket-cache.go @@ -212,7 +212,7 @@ func (c *Client) getBucketLocationRequest(ctx context.Context, bucketName string c.setUserAgent(req) // Get credentials from the configured credentials provider. - value, err := c.credsProvider.Get() + value, err := c.credsProvider.GetWithContext(c.CredContext()) if err != nil { return nil, err } diff --git a/bucket-cache_test.go b/bucket-cache_test.go index 61ac9bd55..40f76ca20 100644 --- a/bucket-cache_test.go +++ b/bucket-cache_test.go @@ -97,7 +97,7 @@ func TestGetBucketLocationRequest(t *testing.T) { c.setUserAgent(req) // Get credentials from the configured credentials provider. - value, err := c.credsProvider.Get() + value, err := c.credsProvider.GetWithContext(c.CredContext()) if err != nil { return nil, err } diff --git a/pkg/credentials/assume_role.go b/pkg/credentials/assume_role.go index d245bc07a..06f79dd56 100644 --- a/pkg/credentials/assume_role.go +++ b/pkg/credentials/assume_role.go @@ -76,7 +76,8 @@ type AssumeRoleResult struct { type STSAssumeRole struct { Expiry - // Required http Client to use when connecting to MinIO STS service. + // Optional http Client to use when connecting to MinIO STS service + // (overrides default client in CredContext) Client *http.Client // STS endpoint to fetch STS credentials. @@ -115,9 +116,6 @@ func NewSTSAssumeRole(stsEndpoint string, opts STSAssumeRoleOptions) (*Credentia return nil, errors.New("AssumeRole credentials access/secretkey is mandatory") } return New(&STSAssumeRole{ - Client: &http.Client{ - Transport: http.DefaultTransport, - }, STSEndpoint: stsEndpoint, Options: opts, }), nil @@ -224,8 +222,12 @@ func getAssumeRoleCredentials(clnt *http.Client, endpoint string, opts STSAssume // Retrieve retrieves credentials from the MinIO service. // Error will be returned if the request fails. -func (m *STSAssumeRole) Retrieve() (Value, error) { - a, err := getAssumeRoleCredentials(m.Client, m.STSEndpoint, m.Options) +func (m *STSAssumeRole) Retrieve(cc *CredContext) (Value, error) { + client := m.Client + if client == nil { + client = cc.Client + } + a, err := getAssumeRoleCredentials(client, m.STSEndpoint, m.Options) if err != nil { return Value{}, err } diff --git a/pkg/credentials/chain.go b/pkg/credentials/chain.go index ddccfb173..7e963fe10 100644 --- a/pkg/credentials/chain.go +++ b/pkg/credentials/chain.go @@ -60,9 +60,9 @@ func NewChainCredentials(providers []Provider) *Credentials { // // If a provider is found with credentials, it will be cached and any calls // to IsExpired() will return the expired state of the cached provider. -func (c *Chain) Retrieve() (Value, error) { +func (c *Chain) Retrieve(cc *CredContext) (Value, error) { for _, p := range c.Providers { - creds, _ := p.Retrieve() + creds, _ := p.Retrieve(cc) // Always prioritize non-anonymous providers, if any. if creds.AccessKeyID == "" && creds.SecretAccessKey == "" { continue diff --git a/pkg/credentials/chain_test.go b/pkg/credentials/chain_test.go index 280b37c64..80177148d 100644 --- a/pkg/credentials/chain_test.go +++ b/pkg/credentials/chain_test.go @@ -28,7 +28,7 @@ type testCredProvider struct { err error } -func (s *testCredProvider) Retrieve() (Value, error) { +func (s *testCredProvider) Retrieve(_ *CredContext) (Value, error) { s.expired = false return s.creds, s.err } @@ -59,7 +59,7 @@ func TestChainGet(t *testing.T) { }, } - creds, err := p.Retrieve() + creds, err := p.Retrieve(defaultCredContext) if err != nil { t.Fatal(err) } @@ -95,7 +95,7 @@ func TestChainIsExpired(t *testing.T) { t.Fatal("Expected expired to be true before any Retrieve") } - _, err := p.Retrieve() + _, err := p.Retrieve(defaultCredContext) if err != nil { t.Fatal(err) } @@ -112,7 +112,7 @@ func TestChainWithNoProvider(t *testing.T) { if !p.IsExpired() { t.Fatal("Expected to be expired with no providers") } - _, err := p.Retrieve() + _, err := p.Retrieve(defaultCredContext) if err != nil { if err.Error() != "No valid providers found []" { t.Error(err) @@ -136,7 +136,7 @@ func TestChainProviderWithNoValidProvider(t *testing.T) { t.Fatal("Expected to be expired with no providers") } - _, err := p.Retrieve() + _, err := p.Retrieve(defaultCredContext) if err != nil { if err.Error() != "No valid providers found [FirstError SecondError]" { t.Error(err) diff --git a/pkg/credentials/credentials.go b/pkg/credentials/credentials.go index 68f9b3815..b9f846cf0 100644 --- a/pkg/credentials/credentials.go +++ b/pkg/credentials/credentials.go @@ -18,6 +18,7 @@ package credentials import ( + "net/http" "sync" "time" ) @@ -30,6 +31,10 @@ const ( defaultExpiryWindow = 0.8 ) +// defaultCredContext is used when the credential context doesn't +// actually matter or the default context is suitable. +var defaultCredContext = &CredContext{Client: http.DefaultClient} + // A Value is the S3 credentials value for individual credential fields. type Value struct { // S3 Access key ID @@ -54,13 +59,21 @@ type Value struct { type Provider interface { // Retrieve returns nil if it successfully retrieved the value. // Error is returned if the value were not obtainable, or empty. - Retrieve() (Value, error) + Retrieve(cc *CredContext) (Value, error) // IsExpired returns if the credentials are no longer valid, and need // to be retrieved. IsExpired() bool } +// CredContext is passed to the Retrieve function of a provider to provide +// some additional context to retrieve credentials. +type CredContext struct { + // Client specifies the HTTP client that should be used if an HTTP + // request is to be made to fetch the credentials. + Client *http.Client +} + // A Expiry provides shared expiration logic to be used by credentials // providers to implement expiry functionality. // @@ -146,7 +159,24 @@ func New(provider Provider) *Credentials { // // If Credentials.Expire() was called the credentials Value will be force // expired, and the next call to Get() will cause them to be refreshed. +// +// Deprecated: Get() exists for historical compatibility and should not be +// used. To get new credentials use the Credentials.GetWithContext function +// to ensure the proper context (i.e. HTTP client) will be used. func (c *Credentials) Get() (Value, error) { + return c.GetWithContext(defaultCredContext) +} + +// GetWithContext returns the credentials value, or error if the +// credentials Value failed to be retrieved. +// +// Will return the cached credentials Value if it has not expired. If the +// credentials Value has expired the Provider's Retrieve() will be called +// to refresh the credentials. +// +// If Credentials.Expire() was called the credentials Value will be force +// expired, and the next call to Get() will cause them to be refreshed. +func (c *Credentials) GetWithContext(cc *CredContext) (Value, error) { if c == nil { return Value{}, nil } @@ -155,7 +185,7 @@ func (c *Credentials) Get() (Value, error) { defer c.Unlock() if c.isExpired() { - creds, err := c.provider.Retrieve() + creds, err := c.provider.Retrieve(cc) if err != nil { return Value{}, err } diff --git a/pkg/credentials/credentials_test.go b/pkg/credentials/credentials_test.go index 828345027..ac77a3a5e 100644 --- a/pkg/credentials/credentials_test.go +++ b/pkg/credentials/credentials_test.go @@ -28,7 +28,7 @@ type credProvider struct { err error } -func (s *credProvider) Retrieve() (Value, error) { +func (s *credProvider) Retrieve(_ *CredContext) (Value, error) { s.expired = false return s.creds, s.err } @@ -47,7 +47,7 @@ func TestCredentialsGet(t *testing.T) { expired: true, }) - creds, err := c.Get() + creds, err := c.GetWithContext(defaultCredContext) if err != nil { t.Fatal(err) } @@ -65,7 +65,7 @@ func TestCredentialsGet(t *testing.T) { func TestCredentialsGetWithError(t *testing.T) { c := New(&credProvider{err: errors.New("Custom error")}) - _, err := c.Get() + _, err := c.GetWithContext(defaultCredContext) if err != nil { if err.Error() != "Custom error" { t.Errorf("Expected \"Custom error\", got %s", err.Error()) diff --git a/pkg/credentials/env_aws.go b/pkg/credentials/env_aws.go index b6e60d0e1..a6fc95d03 100644 --- a/pkg/credentials/env_aws.go +++ b/pkg/credentials/env_aws.go @@ -38,7 +38,7 @@ func NewEnvAWS() *Credentials { } // Retrieve retrieves the keys from the environment. -func (e *EnvAWS) Retrieve() (Value, error) { +func (e *EnvAWS) Retrieve(_ *CredContext) (Value, error) { e.retrieved = false id := os.Getenv("AWS_ACCESS_KEY_ID") diff --git a/pkg/credentials/env_minio.go b/pkg/credentials/env_minio.go index 5bfeab140..ba1e94934 100644 --- a/pkg/credentials/env_minio.go +++ b/pkg/credentials/env_minio.go @@ -39,7 +39,7 @@ func NewEnvMinio() *Credentials { } // Retrieve retrieves the keys from the environment. -func (e *EnvMinio) Retrieve() (Value, error) { +func (e *EnvMinio) Retrieve(_ *CredContext) (Value, error) { e.retrieved = false id := os.Getenv("MINIO_ROOT_USER") diff --git a/pkg/credentials/env_test.go b/pkg/credentials/env_test.go index 5e9240e33..8e374d2d0 100644 --- a/pkg/credentials/env_test.go +++ b/pkg/credentials/env_test.go @@ -34,7 +34,7 @@ func TestEnvAWSRetrieve(t *testing.T) { t.Error("Expect creds to be expired before retrieve.") } - creds, err := e.Retrieve() + creds, err := e.Retrieve(defaultCredContext) if err != nil { t.Fatal(err) } @@ -63,7 +63,7 @@ func TestEnvAWSRetrieve(t *testing.T) { SignerType: SignatureV4, } - creds, err = e.Retrieve() + creds, err = e.Retrieve(defaultCredContext) if err != nil { t.Fatal(err) } @@ -84,7 +84,7 @@ func TestEnvMinioRetrieve(t *testing.T) { t.Error("Expect creds to be expired before retrieve.") } - creds, err := e.Retrieve() + creds, err := e.Retrieve(defaultCredContext) if err != nil { t.Fatal(err) } diff --git a/pkg/credentials/file_aws_credentials.go b/pkg/credentials/file_aws_credentials.go index 541e1a72f..5fd08950a 100644 --- a/pkg/credentials/file_aws_credentials.go +++ b/pkg/credentials/file_aws_credentials.go @@ -73,7 +73,7 @@ func NewFileAWSCredentials(filename, profile string) *Credentials { // Retrieve reads and extracts the shared credentials from the current // users home directory. -func (p *FileAWSCredentials) Retrieve() (Value, error) { +func (p *FileAWSCredentials) Retrieve(_ *CredContext) (Value, error) { if p.Filename == "" { p.Filename = os.Getenv("AWS_SHARED_CREDENTIALS_FILE") if p.Filename == "" { diff --git a/pkg/credentials/file_minio_client.go b/pkg/credentials/file_minio_client.go index 750e26ffa..0e1131fca 100644 --- a/pkg/credentials/file_minio_client.go +++ b/pkg/credentials/file_minio_client.go @@ -58,7 +58,7 @@ func NewFileMinioClient(filename, alias string) *Credentials { // Retrieve reads and extracts the shared credentials from the current // users home directory. -func (p *FileMinioClient) Retrieve() (Value, error) { +func (p *FileMinioClient) Retrieve(_ *CredContext) (Value, error) { if p.Filename == "" { if value, ok := os.LookupEnv("MINIO_SHARED_CREDENTIALS_FILE"); ok { p.Filename = value diff --git a/pkg/credentials/file_test.go b/pkg/credentials/file_test.go index fab48dc44..3136a45c3 100644 --- a/pkg/credentials/file_test.go +++ b/pkg/credentials/file_test.go @@ -31,7 +31,7 @@ func TestFileAWS(t *testing.T) { os.Clearenv() creds := NewFileAWSCredentials("credentials.sample", "") - credValues, err := creds.Get() + credValues, err := creds.GetWithContext(defaultCredContext) if err != nil { t.Fatal(err) } @@ -48,7 +48,7 @@ func TestFileAWS(t *testing.T) { os.Setenv("AWS_SHARED_CREDENTIALS_FILE", "credentials.sample") creds = NewFileAWSCredentials("", "") - credValues, err = creds.Get() + credValues, err = creds.GetWithContext(defaultCredContext) if err != nil { t.Fatal(err) } @@ -70,7 +70,7 @@ func TestFileAWS(t *testing.T) { os.Setenv("AWS_SHARED_CREDENTIALS_FILE", filepath.Join(wd, "credentials.sample")) creds = NewFileAWSCredentials("", "") - credValues, err = creds.Get() + credValues, err = creds.GetWithContext(defaultCredContext) if err != nil { t.Fatal(err) } @@ -89,7 +89,7 @@ func TestFileAWS(t *testing.T) { os.Setenv("AWS_PROFILE", "no_token") creds = NewFileAWSCredentials("credentials.sample", "") - credValues, err = creds.Get() + credValues, err = creds.GetWithContext(defaultCredContext) if err != nil { t.Fatal(err) } @@ -104,7 +104,7 @@ func TestFileAWS(t *testing.T) { os.Clearenv() creds = NewFileAWSCredentials("credentials.sample", "no_token") - credValues, err = creds.Get() + credValues, err = creds.GetWithContext(defaultCredContext) if err != nil { t.Fatal(err) } @@ -117,7 +117,7 @@ func TestFileAWS(t *testing.T) { } creds = NewFileAWSCredentials("credentials-non-existent.sample", "no_token") - _, err = creds.Get() + _, err = creds.GetWithContext(defaultCredContext) if !os.IsNotExist(err) { t.Errorf("Expected open non-existent.json: no such file or directory, got %s", err) } @@ -128,7 +128,7 @@ func TestFileAWS(t *testing.T) { os.Clearenv() creds = NewFileAWSCredentials("credentials.sample", "with_process") - credValues, err = creds.Get() + credValues, err = creds.GetWithContext(defaultCredContext) if err != nil { t.Fatal(err) } @@ -151,7 +151,7 @@ func TestFileMinioClient(t *testing.T) { os.Clearenv() creds := NewFileMinioClient("config.json.sample", "") - credValues, err := creds.Get() + credValues, err := creds.GetWithContext(defaultCredContext) if err != nil { t.Fatal(err) } @@ -170,7 +170,7 @@ func TestFileMinioClient(t *testing.T) { os.Setenv("MINIO_ALIAS", "play") creds = NewFileMinioClient("config.json.sample", "") - credValues, err = creds.Get() + credValues, err = creds.GetWithContext(defaultCredContext) if err != nil { t.Fatal(err) } @@ -188,7 +188,7 @@ func TestFileMinioClient(t *testing.T) { os.Clearenv() creds = NewFileMinioClient("config.json.sample", "play") - credValues, err = creds.Get() + credValues, err = creds.GetWithContext(defaultCredContext) if err != nil { t.Fatal(err) } @@ -204,7 +204,7 @@ func TestFileMinioClient(t *testing.T) { } creds = NewFileMinioClient("non-existent.json", "play") - _, err = creds.Get() + _, err = creds.GetWithContext(defaultCredContext) if !os.IsNotExist(err) { t.Errorf("Expected open non-existent.json: no such file or directory, got %s", err) } diff --git a/pkg/credentials/iam_aws.go b/pkg/credentials/iam_aws.go index ea4b3ef93..083bf01bd 100644 --- a/pkg/credentials/iam_aws.go +++ b/pkg/credentials/iam_aws.go @@ -49,7 +49,8 @@ const DefaultExpiryWindow = -1 type IAM struct { Expiry - // Required http Client to use when connecting to IAM metadata service. + // Optional http Client to use when connecting to IAM metadata service + // (overrides default client in CredContext) Client *http.Client // Custom endpoint to fetch IAM role credentials. @@ -90,9 +91,6 @@ const ( // NewIAM returns a pointer to a new Credentials object wrapping the IAM. func NewIAM(endpoint string) *Credentials { return New(&IAM{ - Client: &http.Client{ - Transport: http.DefaultTransport, - }, Endpoint: endpoint, }) } @@ -100,7 +98,7 @@ func NewIAM(endpoint string) *Credentials { // Retrieve retrieves credentials from the EC2 service. // Error will be returned if the request fails, or unable to extract // the desired -func (m *IAM) Retrieve() (Value, error) { +func (m *IAM) Retrieve(cc *CredContext) (Value, error) { token := os.Getenv("AWS_CONTAINER_AUTHORIZATION_TOKEN") if token == "" { token = m.Container.AuthorizationToken @@ -144,6 +142,11 @@ func (m *IAM) Retrieve() (Value, error) { var roleCreds ec2RoleCredRespBody var err error + client := m.Client + if client == nil { + client = cc.Client + } + endpoint := m.Endpoint switch { case identityFile != "": @@ -160,7 +163,7 @@ func (m *IAM) Retrieve() (Value, error) { } creds := &STSWebIdentity{ - Client: m.Client, + Client: client, STSEndpoint: endpoint, GetWebIDTokenExpiry: func() (*WebIdentityToken, error) { token, err := os.ReadFile(identityFile) @@ -174,7 +177,7 @@ func (m *IAM) Retrieve() (Value, error) { roleSessionName: roleSessionName, } - stsWebIdentityCreds, err := creds.Retrieve() + stsWebIdentityCreds, err := creds.Retrieve(cc) if err == nil { m.SetExpiration(creds.Expiration(), DefaultExpiryWindow) } @@ -185,11 +188,11 @@ func (m *IAM) Retrieve() (Value, error) { endpoint = fmt.Sprintf("%s%s", DefaultECSRoleEndpoint, relativeURI) } - roleCreds, err = getEcsTaskCredentials(m.Client, endpoint, token) + roleCreds, err = getEcsTaskCredentials(client, endpoint, token) case tokenFile != "" && fullURI != "": endpoint = fullURI - roleCreds, err = getEKSPodIdentityCredentials(m.Client, endpoint, tokenFile) + roleCreds, err = getEKSPodIdentityCredentials(client, endpoint, tokenFile) case fullURI != "": if len(endpoint) == 0 { @@ -203,10 +206,10 @@ func (m *IAM) Retrieve() (Value, error) { } } - roleCreds, err = getEcsTaskCredentials(m.Client, endpoint, token) + roleCreds, err = getEcsTaskCredentials(client, endpoint, token) default: - roleCreds, err = getCredentials(m.Client, endpoint) + roleCreds, err = getCredentials(client, endpoint) } if err != nil { diff --git a/pkg/credentials/iam_aws_test.go b/pkg/credentials/iam_aws_test.go index 4089c13ed..b05c34357 100644 --- a/pkg/credentials/iam_aws_test.go +++ b/pkg/credentials/iam_aws_test.go @@ -156,7 +156,7 @@ func initStsTestServer(expireOn string) *httptest.Server { func TestIAMMalformedEndpoint(t *testing.T) { creds := NewIAM("%%%%") - _, err := creds.Get() + _, err := creds.GetWithContext(defaultCredContext) if err == nil { t.Fatal("Unexpected should fail here") } @@ -168,7 +168,7 @@ func TestIAMFailServer(t *testing.T) { creds := NewIAM(server.URL) - _, err := creds.Get() + _, err := creds.GetWithContext(defaultCredContext) if err == nil { t.Fatal("Unexpected should fail here") } @@ -182,7 +182,7 @@ func TestIAMNoRoles(t *testing.T) { defer server.Close() creds := NewIAM(server.URL) - _, err := creds.Get() + _, err := creds.GetWithContext(defaultCredContext) if err == nil { t.Fatal("Unexpected should fail here") } @@ -196,11 +196,10 @@ func TestIAM(t *testing.T) { defer server.Close() p := &IAM{ - Client: http.DefaultClient, Endpoint: server.URL, } - creds, err := p.Retrieve() + creds, err := p.Retrieve(defaultCredContext) if err != nil { t.Fatal(err) } @@ -227,11 +226,10 @@ func TestIAMFailAssume(t *testing.T) { defer server.Close() p := &IAM{ - Client: http.DefaultClient, Endpoint: server.URL, } - _, err := p.Retrieve() + _, err := p.Retrieve(defaultCredContext) if err == nil { t.Fatal("Unexpected success, should fail") } @@ -245,7 +243,6 @@ func TestIAMIsExpired(t *testing.T) { defer server.Close() p := &IAM{ - Client: http.DefaultClient, Endpoint: server.URL, } p.CurrentTime = func() time.Time { @@ -256,7 +253,7 @@ func TestIAMIsExpired(t *testing.T) { t.Error("Expected creds to be expired before retrieve.") } - _, err := p.Retrieve() + _, err := p.Retrieve(defaultCredContext) if err != nil { t.Fatal(err) } @@ -278,11 +275,10 @@ func TestEcsTask(t *testing.T) { server := initEcsTaskTestServer("2014-12-16T01:51:37Z") defer server.Close() p := &IAM{ - Client: http.DefaultClient, Endpoint: server.URL, } os.Setenv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI", "/v2/credentials?id=task_credential_id") - creds, err := p.Retrieve() + creds, err := p.Retrieve(defaultCredContext) os.Unsetenv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI") if err != nil { t.Errorf("Unexpected failure %s", err) @@ -307,12 +303,10 @@ func TestEcsTask(t *testing.T) { func TestEcsTaskFullURI(t *testing.T) { server := initEcsTaskTestServer("2014-12-16T01:51:37Z") defer server.Close() - p := &IAM{ - Client: http.DefaultClient, - } + p := &IAM{} os.Setenv("AWS_CONTAINER_CREDENTIALS_FULL_URI", fmt.Sprintf("%s%s", server.URL, "/v2/credentials?id=task_credential_id")) - creds, err := p.Retrieve() + creds, err := p.Retrieve(defaultCredContext) os.Unsetenv("AWS_CONTAINER_CREDENTIALS_FULL_URI") if err != nil { t.Errorf("Unexpected failure %s", err) @@ -338,7 +332,6 @@ func TestSts(t *testing.T) { server := initStsTestServer("2014-12-16T01:51:37Z") defer server.Close() p := &IAM{ - Client: http.DefaultClient, Endpoint: server.URL, } @@ -352,7 +345,7 @@ func TestSts(t *testing.T) { os.Setenv("AWS_WEB_IDENTITY_TOKEN_FILE", f.Name()) os.Setenv("AWS_ROLE_ARN", "arn:aws:sts::123456789012:assumed-role/FederatedWebIdentityRole/app1") - creds, err := p.Retrieve() + creds, err := p.Retrieve(defaultCredContext) os.Unsetenv("AWS_WEB_IDENTITY_TOKEN_FILE") os.Unsetenv("AWS_ROLE_ARN") if err != nil { @@ -379,7 +372,6 @@ func TestStsCn(t *testing.T) { server := initStsTestServer("2014-12-16T01:51:37Z") defer server.Close() p := &IAM{ - Client: http.DefaultClient, Endpoint: server.URL, } @@ -394,7 +386,7 @@ func TestStsCn(t *testing.T) { os.Setenv("AWS_REGION", "cn-northwest-1") os.Setenv("AWS_WEB_IDENTITY_TOKEN_FILE", f.Name()) os.Setenv("AWS_ROLE_ARN", "arn:aws:sts::123456789012:assumed-role/FederatedWebIdentityRole/app1") - creds, err := p.Retrieve() + creds, err := p.Retrieve(defaultCredContext) os.Unsetenv("AWS_WEB_IDENTITY_TOKEN_FILE") os.Unsetenv("AWS_ROLE_ARN") if err != nil { @@ -420,10 +412,9 @@ func TestStsCn(t *testing.T) { func TestIMDSv1Blocked(t *testing.T) { server := initIMDSv2Server("2014-12-16T01:51:37Z", false) p := &IAM{ - Client: http.DefaultClient, Endpoint: server.URL, } - _, err := p.Retrieve() + _, err := p.Retrieve(defaultCredContext) if err != nil { t.Errorf("Unexpected IMDSv2 failure %s", err) } diff --git a/pkg/credentials/static.go b/pkg/credentials/static.go index 7dde00b0a..ea19080c2 100644 --- a/pkg/credentials/static.go +++ b/pkg/credentials/static.go @@ -51,7 +51,7 @@ func NewStatic(id, secret, token string, signerType SignatureType) *Credentials } // Retrieve returns the static credentials. -func (s *Static) Retrieve() (Value, error) { +func (s *Static) Retrieve(_ *CredContext) (Value, error) { if s.AccessKeyID == "" || s.SecretAccessKey == "" { // Anonymous is not an error return Value{SignerType: SignatureAnonymous}, nil diff --git a/pkg/credentials/static_test.go b/pkg/credentials/static_test.go index 65bec0565..a70cf368b 100644 --- a/pkg/credentials/static_test.go +++ b/pkg/credentials/static_test.go @@ -21,7 +21,7 @@ import "testing" func TestStaticGet(t *testing.T) { creds := NewStatic("UXHW", "SECRET", "", SignatureV4) - credValues, err := creds.Get() + credValues, err := creds.GetWithContext(defaultCredContext) if err != nil { t.Fatal(err) } @@ -46,7 +46,7 @@ func TestStaticGet(t *testing.T) { } creds = NewStatic("", "", "", SignatureDefault) - credValues, err = creds.Get() + credValues, err = creds.GetWithContext(defaultCredContext) if err != nil { t.Fatal(err) } diff --git a/pkg/credentials/sts_client_grants.go b/pkg/credentials/sts_client_grants.go index 62bfbb6b0..bbca025cb 100644 --- a/pkg/credentials/sts_client_grants.go +++ b/pkg/credentials/sts_client_grants.go @@ -72,7 +72,8 @@ type ClientGrantsToken struct { type STSClientGrants struct { Expiry - // Required http Client to use when connecting to MinIO STS service. + // Optional http Client to use when connecting to MinIO STS service. + // (overrides default client in CredContext) Client *http.Client // MinIO endpoint to fetch STS credentials. @@ -97,9 +98,6 @@ func NewSTSClientGrants(stsEndpoint string, getClientGrantsTokenExpiry func() (* return nil, errors.New("Client grants access token and expiry retrieval function should be defined") } return New(&STSClientGrants{ - Client: &http.Client{ - Transport: http.DefaultTransport, - }, STSEndpoint: stsEndpoint, GetClientGrantsTokenExpiry: getClientGrantsTokenExpiry, }), nil @@ -164,8 +162,12 @@ func getClientGrantsCredentials(clnt *http.Client, endpoint string, // Retrieve retrieves credentials from the MinIO service. // Error will be returned if the request fails. -func (m *STSClientGrants) Retrieve() (Value, error) { - a, err := getClientGrantsCredentials(m.Client, m.STSEndpoint, m.GetClientGrantsTokenExpiry) +func (m *STSClientGrants) Retrieve(cc *CredContext) (Value, error) { + client := m.Client + if client == nil { + client = cc.Client + } + a, err := getClientGrantsCredentials(client, m.STSEndpoint, m.GetClientGrantsTokenExpiry) if err != nil { return Value{}, err } diff --git a/pkg/credentials/sts_custom_identity.go b/pkg/credentials/sts_custom_identity.go index 75e1a77d3..10555a207 100644 --- a/pkg/credentials/sts_custom_identity.go +++ b/pkg/credentials/sts_custom_identity.go @@ -53,6 +53,8 @@ type AssumeRoleWithCustomTokenResponse struct { type CustomTokenIdentity struct { Expiry + // Optional http Client to use when connecting to MinIO STS service. + // (overrides default client in CredContext) Client *http.Client // MinIO server STS endpoint to fetch STS credentials. @@ -70,7 +72,7 @@ type CustomTokenIdentity struct { } // Retrieve - to satisfy Provider interface; fetches credentials from MinIO. -func (c *CustomTokenIdentity) Retrieve() (value Value, err error) { +func (c *CustomTokenIdentity) Retrieve(cc *CredContext) (value Value, err error) { u, err := url.Parse(c.STSEndpoint) if err != nil { return value, err @@ -92,7 +94,12 @@ func (c *CustomTokenIdentity) Retrieve() (value Value, err error) { return value, err } - resp, err := c.Client.Do(req) + client := c.Client + if client == nil { + client = cc.Client + } + + resp, err := client.Do(req) if err != nil { return value, err } @@ -122,7 +129,6 @@ func (c *CustomTokenIdentity) Retrieve() (value Value, err error) { // AssumeRoleWithCustomToken STS API. func NewCustomTokenCredentials(stsEndpoint, token, roleArn string, optFuncs ...CustomTokenOpt) (*Credentials, error) { c := CustomTokenIdentity{ - Client: &http.Client{Transport: http.DefaultTransport}, STSEndpoint: stsEndpoint, Token: token, RoleArn: roleArn, diff --git a/pkg/credentials/sts_ldap_identity.go b/pkg/credentials/sts_ldap_identity.go index b8df289f2..5d401ca72 100644 --- a/pkg/credentials/sts_ldap_identity.go +++ b/pkg/credentials/sts_ldap_identity.go @@ -55,7 +55,8 @@ type LDAPIdentityResult struct { type LDAPIdentity struct { Expiry - // Required http Client to use when connecting to MinIO STS service. + // Optional http Client to use when connecting to MinIO STS service. + // (overrides default client in CredContext) Client *http.Client // Exported STS endpoint to fetch STS credentials. @@ -77,7 +78,6 @@ type LDAPIdentity struct { // Identity. func NewLDAPIdentity(stsEndpoint, ldapUsername, ldapPassword string, optFuncs ...LDAPIdentityOpt) (*Credentials, error) { l := LDAPIdentity{ - Client: &http.Client{Transport: http.DefaultTransport}, STSEndpoint: stsEndpoint, LDAPUsername: ldapUsername, LDAPPassword: ldapPassword, @@ -113,7 +113,6 @@ func LDAPIdentityExpiryOpt(d time.Duration) LDAPIdentityOpt { // Deprecated: Use the `LDAPIdentityPolicyOpt` with `NewLDAPIdentity` instead. func NewLDAPIdentityWithSessionPolicy(stsEndpoint, ldapUsername, ldapPassword, policy string) (*Credentials, error) { return New(&LDAPIdentity{ - Client: &http.Client{Transport: http.DefaultTransport}, STSEndpoint: stsEndpoint, LDAPUsername: ldapUsername, LDAPPassword: ldapPassword, @@ -123,7 +122,7 @@ func NewLDAPIdentityWithSessionPolicy(stsEndpoint, ldapUsername, ldapPassword, p // Retrieve gets the credential by calling the MinIO STS API for // LDAP on the configured stsEndpoint. -func (k *LDAPIdentity) Retrieve() (value Value, err error) { +func (k *LDAPIdentity) Retrieve(cc *CredContext) (value Value, err error) { u, err := url.Parse(k.STSEndpoint) if err != nil { return value, err @@ -148,7 +147,12 @@ func (k *LDAPIdentity) Retrieve() (value Value, err error) { req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - resp, err := k.Client.Do(req) + client := k.Client + if client == nil { + client = cc.Client + } + + resp, err := client.Do(req) if err != nil { return value, err } diff --git a/pkg/credentials/sts_tls_identity.go b/pkg/credentials/sts_tls_identity.go index 10083502d..16e30d5e6 100644 --- a/pkg/credentials/sts_tls_identity.go +++ b/pkg/credentials/sts_tls_identity.go @@ -20,8 +20,8 @@ import ( "crypto/tls" "encoding/xml" "errors" + "fmt" "io" - "net" "net/http" "net/url" "strconv" @@ -36,7 +36,12 @@ type CertificateIdentityOption func(*STSCertificateIdentity) // CertificateIdentityWithTransport returns a CertificateIdentityOption that // customizes the STSCertificateIdentity with the given http.RoundTripper. func CertificateIdentityWithTransport(t http.RoundTripper) CertificateIdentityOption { - return CertificateIdentityOption(func(i *STSCertificateIdentity) { i.Client.Transport = t }) + return CertificateIdentityOption(func(i *STSCertificateIdentity) { + if i.Client == nil { + i.Client = &http.Client{} + } + i.Client.Transport = t + }) } // CertificateIdentityWithExpiry returns a CertificateIdentityOption that @@ -53,6 +58,10 @@ func CertificateIdentityWithExpiry(livetime time.Duration) CertificateIdentityOp type STSCertificateIdentity struct { Expiry + // Optional http Client to use when connecting to MinIO STS service. + // (overrides default client in CredContext) + Client *http.Client + // STSEndpoint is the base URL endpoint of the STS API. // For example, https://minio.local:9000 STSEndpoint string @@ -68,17 +77,9 @@ type STSCertificateIdentity struct { // The default livetime is one hour. S3CredentialLivetime time.Duration - // Client is the HTTP client used to authenticate and fetch - // S3 credentials. - // - // A custom TLS client configuration can be specified by - // using a custom http.Transport: - // Client: http.Client { - // Transport: &http.Transport{ - // TLSClientConfig: &tls.Config{}, - // }, - // } - Client http.Client + // Certificate is the client certificate that is used for + // STS authentication. + Certificate tls.Certificate } var _ Provider = (*STSWebIdentity)(nil) // compiler check @@ -95,23 +96,7 @@ func NewSTSCertificateIdentity(endpoint string, certificate tls.Certificate, opt } identity := &STSCertificateIdentity{ STSEndpoint: endpoint, - Client: http.Client{ - Transport: &http.Transport{ - Proxy: http.ProxyFromEnvironment, - DialContext: (&net.Dialer{ - Timeout: 30 * time.Second, - KeepAlive: 30 * time.Second, - }).DialContext, - ForceAttemptHTTP2: true, - MaxIdleConns: 100, - IdleConnTimeout: 90 * time.Second, - TLSHandshakeTimeout: 10 * time.Second, - ExpectContinueTimeout: 5 * time.Second, - TLSClientConfig: &tls.Config{ - Certificates: []tls.Certificate{certificate}, - }, - }, - }, + Certificate: certificate, } for _, option := range options { option(identity) @@ -121,7 +106,7 @@ func NewSTSCertificateIdentity(endpoint string, certificate tls.Certificate, opt // Retrieve fetches a new set of S3 credentials from the configured // STS API endpoint. -func (i *STSCertificateIdentity) Retrieve() (Value, error) { +func (i *STSCertificateIdentity) Retrieve(cc *CredContext) (Value, error) { endpointURL, err := url.Parse(i.STSEndpoint) if err != nil { return Value{}, err @@ -145,7 +130,25 @@ func (i *STSCertificateIdentity) Retrieve() (Value, error) { } req.Form.Add("DurationSeconds", strconv.FormatUint(uint64(livetime.Seconds()), 10)) - resp, err := i.Client.Do(req) + client := i.Client + if client == nil { + client = cc.Client + } + + tr, ok := client.Transport.(*http.Transport) + if !ok { + return Value{}, fmt.Errorf("CredContext should contain an http.Transport value") + } + + // Clone the HTTP transport (patch the TLS client certificate) + trCopy := tr.Clone() + trCopy.TLSClientConfig.Certificates = []tls.Certificate{i.Certificate} + + // Clone the HTTP client (patch the HTTP transport) + clientCopy := *client + clientCopy.Transport = trCopy + + resp, err := clientCopy.Do(req) if err != nil { return Value{}, err } diff --git a/pkg/credentials/sts_web_identity.go b/pkg/credentials/sts_web_identity.go index 8c06bac60..a268eb9f0 100644 --- a/pkg/credentials/sts_web_identity.go +++ b/pkg/credentials/sts_web_identity.go @@ -69,7 +69,8 @@ type WebIdentityToken struct { type STSWebIdentity struct { Expiry - // Required http Client to use when connecting to MinIO STS service. + // Optional http Client to use when connecting to MinIO STS service. + // (overrides default client in CredContext) Client *http.Client // Exported STS endpoint to fetch STS credentials. @@ -104,9 +105,6 @@ func NewSTSWebIdentity(stsEndpoint string, getWebIDTokenExpiry func() (*WebIdent return nil, errors.New("Web ID token and expiry retrieval function should be defined") } i := &STSWebIdentity{ - Client: &http.Client{ - Transport: http.DefaultTransport, - }, STSEndpoint: stsEndpoint, GetWebIDTokenExpiry: getWebIDTokenExpiry, } @@ -221,8 +219,13 @@ func getWebIdentityCredentials(clnt *http.Client, endpoint, roleARN, roleSession // Retrieve retrieves credentials from the MinIO service. // Error will be returned if the request fails. -func (m *STSWebIdentity) Retrieve() (Value, error) { - a, err := getWebIdentityCredentials(m.Client, m.STSEndpoint, m.RoleARN, m.roleSessionName, m.Policy, m.GetWebIDTokenExpiry) +func (m *STSWebIdentity) Retrieve(cc *CredContext) (Value, error) { + client := m.Client + if client == nil { + client = cc.Client + } + + a, err := getWebIdentityCredentials(client, m.STSEndpoint, m.RoleARN, m.roleSessionName, m.Policy, m.GetWebIDTokenExpiry) if err != nil { return Value{}, err }