From ed02e8d7e2782c53cd9842121ffb9836089c2ca1 Mon Sep 17 00:00:00 2001 From: Ramon de Klein Date: Thu, 26 Dec 2024 12:43:33 +0100 Subject: [PATCH] Use proper HTTP client for fetching credentials --- api-presigned.go | 2 +- api.go | 19 +++++++-- bucket-cache.go | 2 +- bucket-cache_test.go | 2 +- pkg/credentials/assume_role.go | 10 +---- pkg/credentials/chain.go | 4 +- pkg/credentials/chain_test.go | 10 ++--- pkg/credentials/credentials.go | 34 ++++++++++++++- pkg/credentials/credentials_test.go | 6 +-- pkg/credentials/env_aws.go | 2 +- pkg/credentials/env_minio.go | 2 +- pkg/credentials/env_test.go | 6 +-- pkg/credentials/file_aws_credentials.go | 2 +- pkg/credentials/file_minio_client.go | 2 +- pkg/credentials/file_test.go | 22 +++++----- pkg/credentials/iam_aws.go | 19 +++------ pkg/credentials/iam_aws_test.go | 33 ++++++--------- pkg/credentials/static.go | 2 +- pkg/credentials/static_test.go | 4 +- pkg/credentials/sts_client_grants.go | 10 +---- pkg/credentials/sts_custom_identity.go | 7 +--- pkg/credentials/sts_ldap_identity.go | 9 +--- pkg/credentials/sts_tls_identity.go | 56 +++++++++---------------- pkg/credentials/sts_web_identity.go | 10 +---- 24 files changed, 128 insertions(+), 147 deletions(-) diff --git a/api-presigned.go b/api-presigned.go index 9e85f81816..29642200ee 100644 --- a/api-presigned.go +++ b/api-presigned.go @@ -140,7 +140,7 @@ func (c *Client) PresignedPostPolicy(ctx context.Context, p *PostPolicy) (u *url } // Get credentials from the configured credentials provider. - credValues, err := c.credsProvider.Get() + credValues, err := c.credsProvider.GetWithContext(c.CredContext()) if err != nil { return nil, nil, err } diff --git a/api.go b/api.go index 86d072e4f7..5bcd903e3e 100644 --- a/api.go +++ b/api.go @@ -600,9 +600,9 @@ func (c *Client) executeMethod(ctx context.Context, method string, metadata requ return nil, errors.New(c.endpointURL.String() + " is offline.") } - var retryable bool // Indicates if request can be retried. - var bodySeeker io.Seeker // Extracted seeker from io.Reader. - var reqRetry = c.maxRetries // Indicates how many times we can retry the request + var retryable bool // Indicates if request can be retried. + var bodySeeker io.Seeker // Extracted seeker from io.Reader. + reqRetry := c.maxRetries // Indicates how many times we can retry the request if metadata.contentBody != nil { // Check if body is seekable then it is retryable. @@ -808,7 +808,7 @@ func (c *Client) newRequest(ctx context.Context, method string, metadata request } // Get credentials from the configured credentials provider. - value, err := c.credsProvider.Get() + value, err := c.credsProvider.GetWithContext(c.CredContext()) if err != nil { return nil, err } @@ -1018,3 +1018,14 @@ func (c *Client) isVirtualHostStyleRequest(url url.URL, bucketName string) bool // path style requests return s3utils.IsVirtualHostSupported(url, bucketName) } + +// CredContext returns the context for fetching credentials +func (c *Client) CredContext() *credentials.CredContext { + httpClient := c.httpClient + if httpClient == nil { + httpClient = http.DefaultClient + } + return &credentials.CredContext{ + Client: httpClient, + } +} diff --git a/bucket-cache.go b/bucket-cache.go index b1d3b3852c..4e4305acd5 100644 --- a/bucket-cache.go +++ b/bucket-cache.go @@ -212,7 +212,7 @@ func (c *Client) getBucketLocationRequest(ctx context.Context, bucketName string c.setUserAgent(req) // Get credentials from the configured credentials provider. - value, err := c.credsProvider.Get() + value, err := c.credsProvider.GetWithContext(c.CredContext()) if err != nil { return nil, err } diff --git a/bucket-cache_test.go b/bucket-cache_test.go index 61ac9bd551..40f76ca209 100644 --- a/bucket-cache_test.go +++ b/bucket-cache_test.go @@ -97,7 +97,7 @@ func TestGetBucketLocationRequest(t *testing.T) { c.setUserAgent(req) // Get credentials from the configured credentials provider. - value, err := c.credsProvider.Get() + value, err := c.credsProvider.GetWithContext(c.CredContext()) if err != nil { return nil, err } diff --git a/pkg/credentials/assume_role.go b/pkg/credentials/assume_role.go index d245bc07a3..358a03b16d 100644 --- a/pkg/credentials/assume_role.go +++ b/pkg/credentials/assume_role.go @@ -76,9 +76,6 @@ type AssumeRoleResult struct { type STSAssumeRole struct { Expiry - // Required http Client to use when connecting to MinIO STS service. - Client *http.Client - // STS endpoint to fetch STS credentials. STSEndpoint string @@ -115,9 +112,6 @@ func NewSTSAssumeRole(stsEndpoint string, opts STSAssumeRoleOptions) (*Credentia return nil, errors.New("AssumeRole credentials access/secretkey is mandatory") } return New(&STSAssumeRole{ - Client: &http.Client{ - Transport: http.DefaultTransport, - }, STSEndpoint: stsEndpoint, Options: opts, }), nil @@ -224,8 +218,8 @@ func getAssumeRoleCredentials(clnt *http.Client, endpoint string, opts STSAssume // Retrieve retrieves credentials from the MinIO service. // Error will be returned if the request fails. -func (m *STSAssumeRole) Retrieve() (Value, error) { - a, err := getAssumeRoleCredentials(m.Client, m.STSEndpoint, m.Options) +func (m *STSAssumeRole) Retrieve(cc *CredContext) (Value, error) { + a, err := getAssumeRoleCredentials(cc.Client, m.STSEndpoint, m.Options) if err != nil { return Value{}, err } diff --git a/pkg/credentials/chain.go b/pkg/credentials/chain.go index ddccfb173f..7e963fe10c 100644 --- a/pkg/credentials/chain.go +++ b/pkg/credentials/chain.go @@ -60,9 +60,9 @@ func NewChainCredentials(providers []Provider) *Credentials { // // If a provider is found with credentials, it will be cached and any calls // to IsExpired() will return the expired state of the cached provider. -func (c *Chain) Retrieve() (Value, error) { +func (c *Chain) Retrieve(cc *CredContext) (Value, error) { for _, p := range c.Providers { - creds, _ := p.Retrieve() + creds, _ := p.Retrieve(cc) // Always prioritize non-anonymous providers, if any. if creds.AccessKeyID == "" && creds.SecretAccessKey == "" { continue diff --git a/pkg/credentials/chain_test.go b/pkg/credentials/chain_test.go index 280b37c649..cdd6ab7d81 100644 --- a/pkg/credentials/chain_test.go +++ b/pkg/credentials/chain_test.go @@ -28,7 +28,7 @@ type testCredProvider struct { err error } -func (s *testCredProvider) Retrieve() (Value, error) { +func (s *testCredProvider) Retrieve(_ *CredContext) (Value, error) { s.expired = false return s.creds, s.err } @@ -59,7 +59,7 @@ func TestChainGet(t *testing.T) { }, } - creds, err := p.Retrieve() + creds, err := p.Retrieve(DefaultCredContext) if err != nil { t.Fatal(err) } @@ -95,7 +95,7 @@ func TestChainIsExpired(t *testing.T) { t.Fatal("Expected expired to be true before any Retrieve") } - _, err := p.Retrieve() + _, err := p.Retrieve(DefaultCredContext) if err != nil { t.Fatal(err) } @@ -112,7 +112,7 @@ func TestChainWithNoProvider(t *testing.T) { if !p.IsExpired() { t.Fatal("Expected to be expired with no providers") } - _, err := p.Retrieve() + _, err := p.Retrieve(DefaultCredContext) if err != nil { if err.Error() != "No valid providers found []" { t.Error(err) @@ -136,7 +136,7 @@ func TestChainProviderWithNoValidProvider(t *testing.T) { t.Fatal("Expected to be expired with no providers") } - _, err := p.Retrieve() + _, err := p.Retrieve(DefaultCredContext) if err != nil { if err.Error() != "No valid providers found [FirstError SecondError]" { t.Error(err) diff --git a/pkg/credentials/credentials.go b/pkg/credentials/credentials.go index 68f9b38157..4564961b2f 100644 --- a/pkg/credentials/credentials.go +++ b/pkg/credentials/credentials.go @@ -18,6 +18,7 @@ package credentials import ( + "net/http" "sync" "time" ) @@ -30,6 +31,10 @@ const ( defaultExpiryWindow = 0.8 ) +var DefaultCredContext *CredContext = &CredContext{ + Client: http.DefaultClient, +} + // A Value is the S3 credentials value for individual credential fields. type Value struct { // S3 Access key ID @@ -54,13 +59,21 @@ type Value struct { type Provider interface { // Retrieve returns nil if it successfully retrieved the value. // Error is returned if the value were not obtainable, or empty. - Retrieve() (Value, error) + Retrieve(cc *CredContext) (Value, error) // IsExpired returns if the credentials are no longer valid, and need // to be retrieved. IsExpired() bool } +// CredContext is passed to the Retrieve function of a provider to provide +// some additional context to retrieve credentials. +type CredContext struct { + // Client specifies the HTTP client that should be used if an HTTP + // request is to be made to fetch the credentials. + Client *http.Client +} + // A Expiry provides shared expiration logic to be used by credentials // providers to implement expiry functionality. // @@ -146,7 +159,24 @@ func New(provider Provider) *Credentials { // // If Credentials.Expire() was called the credentials Value will be force // expired, and the next call to Get() will cause them to be refreshed. +// +// Deprecated: Get() exists for historical compatibility and should not be +// used. To get new credentials use the Credentials.GetWithContext function +// to ensure the proper context (i.e. HTTP client) will be used. func (c *Credentials) Get() (Value, error) { + return c.GetWithContext(DefaultCredContext) +} + +// GetWithContext returns the credentials value, or error if the +// credentials Value failed to be retrieved. +// +// Will return the cached credentials Value if it has not expired. If the +// credentials Value has expired the Provider's Retrieve() will be called +// to refresh the credentials. +// +// If Credentials.Expire() was called the credentials Value will be force +// expired, and the next call to Get() will cause them to be refreshed. +func (c *Credentials) GetWithContext(cc *CredContext) (Value, error) { if c == nil { return Value{}, nil } @@ -155,7 +185,7 @@ func (c *Credentials) Get() (Value, error) { defer c.Unlock() if c.isExpired() { - creds, err := c.provider.Retrieve() + creds, err := c.provider.Retrieve(cc) if err != nil { return Value{}, err } diff --git a/pkg/credentials/credentials_test.go b/pkg/credentials/credentials_test.go index 8283450275..d065fd6d12 100644 --- a/pkg/credentials/credentials_test.go +++ b/pkg/credentials/credentials_test.go @@ -28,7 +28,7 @@ type credProvider struct { err error } -func (s *credProvider) Retrieve() (Value, error) { +func (s *credProvider) Retrieve(_ *CredContext) (Value, error) { s.expired = false return s.creds, s.err } @@ -47,7 +47,7 @@ func TestCredentialsGet(t *testing.T) { expired: true, }) - creds, err := c.Get() + creds, err := c.GetWithContext(DefaultCredContext) if err != nil { t.Fatal(err) } @@ -65,7 +65,7 @@ func TestCredentialsGet(t *testing.T) { func TestCredentialsGetWithError(t *testing.T) { c := New(&credProvider{err: errors.New("Custom error")}) - _, err := c.Get() + _, err := c.GetWithContext(DefaultCredContext) if err != nil { if err.Error() != "Custom error" { t.Errorf("Expected \"Custom error\", got %s", err.Error()) diff --git a/pkg/credentials/env_aws.go b/pkg/credentials/env_aws.go index b6e60d0e16..a6fc95d03b 100644 --- a/pkg/credentials/env_aws.go +++ b/pkg/credentials/env_aws.go @@ -38,7 +38,7 @@ func NewEnvAWS() *Credentials { } // Retrieve retrieves the keys from the environment. -func (e *EnvAWS) Retrieve() (Value, error) { +func (e *EnvAWS) Retrieve(_ *CredContext) (Value, error) { e.retrieved = false id := os.Getenv("AWS_ACCESS_KEY_ID") diff --git a/pkg/credentials/env_minio.go b/pkg/credentials/env_minio.go index 5bfeab140a..ba1e949347 100644 --- a/pkg/credentials/env_minio.go +++ b/pkg/credentials/env_minio.go @@ -39,7 +39,7 @@ func NewEnvMinio() *Credentials { } // Retrieve retrieves the keys from the environment. -func (e *EnvMinio) Retrieve() (Value, error) { +func (e *EnvMinio) Retrieve(_ *CredContext) (Value, error) { e.retrieved = false id := os.Getenv("MINIO_ROOT_USER") diff --git a/pkg/credentials/env_test.go b/pkg/credentials/env_test.go index 5e9240e331..dd9f15280d 100644 --- a/pkg/credentials/env_test.go +++ b/pkg/credentials/env_test.go @@ -34,7 +34,7 @@ func TestEnvAWSRetrieve(t *testing.T) { t.Error("Expect creds to be expired before retrieve.") } - creds, err := e.Retrieve() + creds, err := e.Retrieve(DefaultCredContext) if err != nil { t.Fatal(err) } @@ -63,7 +63,7 @@ func TestEnvAWSRetrieve(t *testing.T) { SignerType: SignatureV4, } - creds, err = e.Retrieve() + creds, err = e.Retrieve(DefaultCredContext) if err != nil { t.Fatal(err) } @@ -84,7 +84,7 @@ func TestEnvMinioRetrieve(t *testing.T) { t.Error("Expect creds to be expired before retrieve.") } - creds, err := e.Retrieve() + creds, err := e.Retrieve(DefaultCredContext) if err != nil { t.Fatal(err) } diff --git a/pkg/credentials/file_aws_credentials.go b/pkg/credentials/file_aws_credentials.go index 541e1a72f0..5fd08950a4 100644 --- a/pkg/credentials/file_aws_credentials.go +++ b/pkg/credentials/file_aws_credentials.go @@ -73,7 +73,7 @@ func NewFileAWSCredentials(filename, profile string) *Credentials { // Retrieve reads and extracts the shared credentials from the current // users home directory. -func (p *FileAWSCredentials) Retrieve() (Value, error) { +func (p *FileAWSCredentials) Retrieve(_ *CredContext) (Value, error) { if p.Filename == "" { p.Filename = os.Getenv("AWS_SHARED_CREDENTIALS_FILE") if p.Filename == "" { diff --git a/pkg/credentials/file_minio_client.go b/pkg/credentials/file_minio_client.go index 750e26ffa8..0e1131fca7 100644 --- a/pkg/credentials/file_minio_client.go +++ b/pkg/credentials/file_minio_client.go @@ -58,7 +58,7 @@ func NewFileMinioClient(filename, alias string) *Credentials { // Retrieve reads and extracts the shared credentials from the current // users home directory. -func (p *FileMinioClient) Retrieve() (Value, error) { +func (p *FileMinioClient) Retrieve(_ *CredContext) (Value, error) { if p.Filename == "" { if value, ok := os.LookupEnv("MINIO_SHARED_CREDENTIALS_FILE"); ok { p.Filename = value diff --git a/pkg/credentials/file_test.go b/pkg/credentials/file_test.go index fab48dc441..3f164e79ad 100644 --- a/pkg/credentials/file_test.go +++ b/pkg/credentials/file_test.go @@ -31,7 +31,7 @@ func TestFileAWS(t *testing.T) { os.Clearenv() creds := NewFileAWSCredentials("credentials.sample", "") - credValues, err := creds.Get() + credValues, err := creds.GetWithContext(DefaultCredContext) if err != nil { t.Fatal(err) } @@ -48,7 +48,7 @@ func TestFileAWS(t *testing.T) { os.Setenv("AWS_SHARED_CREDENTIALS_FILE", "credentials.sample") creds = NewFileAWSCredentials("", "") - credValues, err = creds.Get() + credValues, err = creds.GetWithContext(DefaultCredContext) if err != nil { t.Fatal(err) } @@ -70,7 +70,7 @@ func TestFileAWS(t *testing.T) { os.Setenv("AWS_SHARED_CREDENTIALS_FILE", filepath.Join(wd, "credentials.sample")) creds = NewFileAWSCredentials("", "") - credValues, err = creds.Get() + credValues, err = creds.GetWithContext(DefaultCredContext) if err != nil { t.Fatal(err) } @@ -89,7 +89,7 @@ func TestFileAWS(t *testing.T) { os.Setenv("AWS_PROFILE", "no_token") creds = NewFileAWSCredentials("credentials.sample", "") - credValues, err = creds.Get() + credValues, err = creds.GetWithContext(DefaultCredContext) if err != nil { t.Fatal(err) } @@ -104,7 +104,7 @@ func TestFileAWS(t *testing.T) { os.Clearenv() creds = NewFileAWSCredentials("credentials.sample", "no_token") - credValues, err = creds.Get() + credValues, err = creds.GetWithContext(DefaultCredContext) if err != nil { t.Fatal(err) } @@ -117,7 +117,7 @@ func TestFileAWS(t *testing.T) { } creds = NewFileAWSCredentials("credentials-non-existent.sample", "no_token") - _, err = creds.Get() + _, err = creds.GetWithContext(DefaultCredContext) if !os.IsNotExist(err) { t.Errorf("Expected open non-existent.json: no such file or directory, got %s", err) } @@ -128,7 +128,7 @@ func TestFileAWS(t *testing.T) { os.Clearenv() creds = NewFileAWSCredentials("credentials.sample", "with_process") - credValues, err = creds.Get() + credValues, err = creds.GetWithContext(DefaultCredContext) if err != nil { t.Fatal(err) } @@ -151,7 +151,7 @@ func TestFileMinioClient(t *testing.T) { os.Clearenv() creds := NewFileMinioClient("config.json.sample", "") - credValues, err := creds.Get() + credValues, err := creds.GetWithContext(DefaultCredContext) if err != nil { t.Fatal(err) } @@ -170,7 +170,7 @@ func TestFileMinioClient(t *testing.T) { os.Setenv("MINIO_ALIAS", "play") creds = NewFileMinioClient("config.json.sample", "") - credValues, err = creds.Get() + credValues, err = creds.GetWithContext(DefaultCredContext) if err != nil { t.Fatal(err) } @@ -188,7 +188,7 @@ func TestFileMinioClient(t *testing.T) { os.Clearenv() creds = NewFileMinioClient("config.json.sample", "play") - credValues, err = creds.Get() + credValues, err = creds.GetWithContext(DefaultCredContext) if err != nil { t.Fatal(err) } @@ -204,7 +204,7 @@ func TestFileMinioClient(t *testing.T) { } creds = NewFileMinioClient("non-existent.json", "play") - _, err = creds.Get() + _, err = creds.GetWithContext(DefaultCredContext) if !os.IsNotExist(err) { t.Errorf("Expected open non-existent.json: no such file or directory, got %s", err) } diff --git a/pkg/credentials/iam_aws.go b/pkg/credentials/iam_aws.go index ea4b3ef937..e4f8864025 100644 --- a/pkg/credentials/iam_aws.go +++ b/pkg/credentials/iam_aws.go @@ -49,9 +49,6 @@ const DefaultExpiryWindow = -1 type IAM struct { Expiry - // Required http Client to use when connecting to IAM metadata service. - Client *http.Client - // Custom endpoint to fetch IAM role credentials. Endpoint string @@ -90,9 +87,6 @@ const ( // NewIAM returns a pointer to a new Credentials object wrapping the IAM. func NewIAM(endpoint string) *Credentials { return New(&IAM{ - Client: &http.Client{ - Transport: http.DefaultTransport, - }, Endpoint: endpoint, }) } @@ -100,7 +94,7 @@ func NewIAM(endpoint string) *Credentials { // Retrieve retrieves credentials from the EC2 service. // Error will be returned if the request fails, or unable to extract // the desired -func (m *IAM) Retrieve() (Value, error) { +func (m *IAM) Retrieve(cc *CredContext) (Value, error) { token := os.Getenv("AWS_CONTAINER_AUTHORIZATION_TOKEN") if token == "" { token = m.Container.AuthorizationToken @@ -160,7 +154,6 @@ func (m *IAM) Retrieve() (Value, error) { } creds := &STSWebIdentity{ - Client: m.Client, STSEndpoint: endpoint, GetWebIDTokenExpiry: func() (*WebIdentityToken, error) { token, err := os.ReadFile(identityFile) @@ -174,7 +167,7 @@ func (m *IAM) Retrieve() (Value, error) { roleSessionName: roleSessionName, } - stsWebIdentityCreds, err := creds.Retrieve() + stsWebIdentityCreds, err := creds.Retrieve(cc) if err == nil { m.SetExpiration(creds.Expiration(), DefaultExpiryWindow) } @@ -185,11 +178,11 @@ func (m *IAM) Retrieve() (Value, error) { endpoint = fmt.Sprintf("%s%s", DefaultECSRoleEndpoint, relativeURI) } - roleCreds, err = getEcsTaskCredentials(m.Client, endpoint, token) + roleCreds, err = getEcsTaskCredentials(cc.Client, endpoint, token) case tokenFile != "" && fullURI != "": endpoint = fullURI - roleCreds, err = getEKSPodIdentityCredentials(m.Client, endpoint, tokenFile) + roleCreds, err = getEKSPodIdentityCredentials(cc.Client, endpoint, tokenFile) case fullURI != "": if len(endpoint) == 0 { @@ -203,10 +196,10 @@ func (m *IAM) Retrieve() (Value, error) { } } - roleCreds, err = getEcsTaskCredentials(m.Client, endpoint, token) + roleCreds, err = getEcsTaskCredentials(cc.Client, endpoint, token) default: - roleCreds, err = getCredentials(m.Client, endpoint) + roleCreds, err = getCredentials(cc.Client, endpoint) } if err != nil { diff --git a/pkg/credentials/iam_aws_test.go b/pkg/credentials/iam_aws_test.go index 4089c13ed3..4c0b74e5b4 100644 --- a/pkg/credentials/iam_aws_test.go +++ b/pkg/credentials/iam_aws_test.go @@ -156,7 +156,7 @@ func initStsTestServer(expireOn string) *httptest.Server { func TestIAMMalformedEndpoint(t *testing.T) { creds := NewIAM("%%%%") - _, err := creds.Get() + _, err := creds.GetWithContext(DefaultCredContext) if err == nil { t.Fatal("Unexpected should fail here") } @@ -168,7 +168,7 @@ func TestIAMFailServer(t *testing.T) { creds := NewIAM(server.URL) - _, err := creds.Get() + _, err := creds.GetWithContext(DefaultCredContext) if err == nil { t.Fatal("Unexpected should fail here") } @@ -182,7 +182,7 @@ func TestIAMNoRoles(t *testing.T) { defer server.Close() creds := NewIAM(server.URL) - _, err := creds.Get() + _, err := creds.GetWithContext(DefaultCredContext) if err == nil { t.Fatal("Unexpected should fail here") } @@ -196,11 +196,10 @@ func TestIAM(t *testing.T) { defer server.Close() p := &IAM{ - Client: http.DefaultClient, Endpoint: server.URL, } - creds, err := p.Retrieve() + creds, err := p.Retrieve(DefaultCredContext) if err != nil { t.Fatal(err) } @@ -227,11 +226,10 @@ func TestIAMFailAssume(t *testing.T) { defer server.Close() p := &IAM{ - Client: http.DefaultClient, Endpoint: server.URL, } - _, err := p.Retrieve() + _, err := p.Retrieve(DefaultCredContext) if err == nil { t.Fatal("Unexpected success, should fail") } @@ -245,7 +243,6 @@ func TestIAMIsExpired(t *testing.T) { defer server.Close() p := &IAM{ - Client: http.DefaultClient, Endpoint: server.URL, } p.CurrentTime = func() time.Time { @@ -256,7 +253,7 @@ func TestIAMIsExpired(t *testing.T) { t.Error("Expected creds to be expired before retrieve.") } - _, err := p.Retrieve() + _, err := p.Retrieve(DefaultCredContext) if err != nil { t.Fatal(err) } @@ -278,11 +275,10 @@ func TestEcsTask(t *testing.T) { server := initEcsTaskTestServer("2014-12-16T01:51:37Z") defer server.Close() p := &IAM{ - Client: http.DefaultClient, Endpoint: server.URL, } os.Setenv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI", "/v2/credentials?id=task_credential_id") - creds, err := p.Retrieve() + creds, err := p.Retrieve(DefaultCredContext) os.Unsetenv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI") if err != nil { t.Errorf("Unexpected failure %s", err) @@ -307,12 +303,10 @@ func TestEcsTask(t *testing.T) { func TestEcsTaskFullURI(t *testing.T) { server := initEcsTaskTestServer("2014-12-16T01:51:37Z") defer server.Close() - p := &IAM{ - Client: http.DefaultClient, - } + p := &IAM{} os.Setenv("AWS_CONTAINER_CREDENTIALS_FULL_URI", fmt.Sprintf("%s%s", server.URL, "/v2/credentials?id=task_credential_id")) - creds, err := p.Retrieve() + creds, err := p.Retrieve(DefaultCredContext) os.Unsetenv("AWS_CONTAINER_CREDENTIALS_FULL_URI") if err != nil { t.Errorf("Unexpected failure %s", err) @@ -338,7 +332,6 @@ func TestSts(t *testing.T) { server := initStsTestServer("2014-12-16T01:51:37Z") defer server.Close() p := &IAM{ - Client: http.DefaultClient, Endpoint: server.URL, } @@ -352,7 +345,7 @@ func TestSts(t *testing.T) { os.Setenv("AWS_WEB_IDENTITY_TOKEN_FILE", f.Name()) os.Setenv("AWS_ROLE_ARN", "arn:aws:sts::123456789012:assumed-role/FederatedWebIdentityRole/app1") - creds, err := p.Retrieve() + creds, err := p.Retrieve(DefaultCredContext) os.Unsetenv("AWS_WEB_IDENTITY_TOKEN_FILE") os.Unsetenv("AWS_ROLE_ARN") if err != nil { @@ -379,7 +372,6 @@ func TestStsCn(t *testing.T) { server := initStsTestServer("2014-12-16T01:51:37Z") defer server.Close() p := &IAM{ - Client: http.DefaultClient, Endpoint: server.URL, } @@ -394,7 +386,7 @@ func TestStsCn(t *testing.T) { os.Setenv("AWS_REGION", "cn-northwest-1") os.Setenv("AWS_WEB_IDENTITY_TOKEN_FILE", f.Name()) os.Setenv("AWS_ROLE_ARN", "arn:aws:sts::123456789012:assumed-role/FederatedWebIdentityRole/app1") - creds, err := p.Retrieve() + creds, err := p.Retrieve(DefaultCredContext) os.Unsetenv("AWS_WEB_IDENTITY_TOKEN_FILE") os.Unsetenv("AWS_ROLE_ARN") if err != nil { @@ -420,10 +412,9 @@ func TestStsCn(t *testing.T) { func TestIMDSv1Blocked(t *testing.T) { server := initIMDSv2Server("2014-12-16T01:51:37Z", false) p := &IAM{ - Client: http.DefaultClient, Endpoint: server.URL, } - _, err := p.Retrieve() + _, err := p.Retrieve(DefaultCredContext) if err != nil { t.Errorf("Unexpected IMDSv2 failure %s", err) } diff --git a/pkg/credentials/static.go b/pkg/credentials/static.go index 7dde00b0a1..ea19080c20 100644 --- a/pkg/credentials/static.go +++ b/pkg/credentials/static.go @@ -51,7 +51,7 @@ func NewStatic(id, secret, token string, signerType SignatureType) *Credentials } // Retrieve returns the static credentials. -func (s *Static) Retrieve() (Value, error) { +func (s *Static) Retrieve(_ *CredContext) (Value, error) { if s.AccessKeyID == "" || s.SecretAccessKey == "" { // Anonymous is not an error return Value{SignerType: SignatureAnonymous}, nil diff --git a/pkg/credentials/static_test.go b/pkg/credentials/static_test.go index 65bec05654..ad913c9851 100644 --- a/pkg/credentials/static_test.go +++ b/pkg/credentials/static_test.go @@ -21,7 +21,7 @@ import "testing" func TestStaticGet(t *testing.T) { creds := NewStatic("UXHW", "SECRET", "", SignatureV4) - credValues, err := creds.Get() + credValues, err := creds.GetWithContext(DefaultCredContext) if err != nil { t.Fatal(err) } @@ -46,7 +46,7 @@ func TestStaticGet(t *testing.T) { } creds = NewStatic("", "", "", SignatureDefault) - credValues, err = creds.Get() + credValues, err = creds.GetWithContext(DefaultCredContext) if err != nil { t.Fatal(err) } diff --git a/pkg/credentials/sts_client_grants.go b/pkg/credentials/sts_client_grants.go index 62bfbb6b02..d687e38afd 100644 --- a/pkg/credentials/sts_client_grants.go +++ b/pkg/credentials/sts_client_grants.go @@ -72,9 +72,6 @@ type ClientGrantsToken struct { type STSClientGrants struct { Expiry - // Required http Client to use when connecting to MinIO STS service. - Client *http.Client - // MinIO endpoint to fetch STS credentials. STSEndpoint string @@ -97,9 +94,6 @@ func NewSTSClientGrants(stsEndpoint string, getClientGrantsTokenExpiry func() (* return nil, errors.New("Client grants access token and expiry retrieval function should be defined") } return New(&STSClientGrants{ - Client: &http.Client{ - Transport: http.DefaultTransport, - }, STSEndpoint: stsEndpoint, GetClientGrantsTokenExpiry: getClientGrantsTokenExpiry, }), nil @@ -164,8 +158,8 @@ func getClientGrantsCredentials(clnt *http.Client, endpoint string, // Retrieve retrieves credentials from the MinIO service. // Error will be returned if the request fails. -func (m *STSClientGrants) Retrieve() (Value, error) { - a, err := getClientGrantsCredentials(m.Client, m.STSEndpoint, m.GetClientGrantsTokenExpiry) +func (m *STSClientGrants) Retrieve(cc *CredContext) (Value, error) { + a, err := getClientGrantsCredentials(cc.Client, m.STSEndpoint, m.GetClientGrantsTokenExpiry) if err != nil { return Value{}, err } diff --git a/pkg/credentials/sts_custom_identity.go b/pkg/credentials/sts_custom_identity.go index 75e1a77d32..11ea30a2d5 100644 --- a/pkg/credentials/sts_custom_identity.go +++ b/pkg/credentials/sts_custom_identity.go @@ -53,8 +53,6 @@ type AssumeRoleWithCustomTokenResponse struct { type CustomTokenIdentity struct { Expiry - Client *http.Client - // MinIO server STS endpoint to fetch STS credentials. STSEndpoint string @@ -70,7 +68,7 @@ type CustomTokenIdentity struct { } // Retrieve - to satisfy Provider interface; fetches credentials from MinIO. -func (c *CustomTokenIdentity) Retrieve() (value Value, err error) { +func (c *CustomTokenIdentity) Retrieve(cc *CredContext) (value Value, err error) { u, err := url.Parse(c.STSEndpoint) if err != nil { return value, err @@ -92,7 +90,7 @@ func (c *CustomTokenIdentity) Retrieve() (value Value, err error) { return value, err } - resp, err := c.Client.Do(req) + resp, err := cc.Client.Do(req) if err != nil { return value, err } @@ -122,7 +120,6 @@ func (c *CustomTokenIdentity) Retrieve() (value Value, err error) { // AssumeRoleWithCustomToken STS API. func NewCustomTokenCredentials(stsEndpoint, token, roleArn string, optFuncs ...CustomTokenOpt) (*Credentials, error) { c := CustomTokenIdentity{ - Client: &http.Client{Transport: http.DefaultTransport}, STSEndpoint: stsEndpoint, Token: token, RoleArn: roleArn, diff --git a/pkg/credentials/sts_ldap_identity.go b/pkg/credentials/sts_ldap_identity.go index b8df289f20..bf04735fc3 100644 --- a/pkg/credentials/sts_ldap_identity.go +++ b/pkg/credentials/sts_ldap_identity.go @@ -55,9 +55,6 @@ type LDAPIdentityResult struct { type LDAPIdentity struct { Expiry - // Required http Client to use when connecting to MinIO STS service. - Client *http.Client - // Exported STS endpoint to fetch STS credentials. STSEndpoint string @@ -77,7 +74,6 @@ type LDAPIdentity struct { // Identity. func NewLDAPIdentity(stsEndpoint, ldapUsername, ldapPassword string, optFuncs ...LDAPIdentityOpt) (*Credentials, error) { l := LDAPIdentity{ - Client: &http.Client{Transport: http.DefaultTransport}, STSEndpoint: stsEndpoint, LDAPUsername: ldapUsername, LDAPPassword: ldapPassword, @@ -113,7 +109,6 @@ func LDAPIdentityExpiryOpt(d time.Duration) LDAPIdentityOpt { // Deprecated: Use the `LDAPIdentityPolicyOpt` with `NewLDAPIdentity` instead. func NewLDAPIdentityWithSessionPolicy(stsEndpoint, ldapUsername, ldapPassword, policy string) (*Credentials, error) { return New(&LDAPIdentity{ - Client: &http.Client{Transport: http.DefaultTransport}, STSEndpoint: stsEndpoint, LDAPUsername: ldapUsername, LDAPPassword: ldapPassword, @@ -123,7 +118,7 @@ func NewLDAPIdentityWithSessionPolicy(stsEndpoint, ldapUsername, ldapPassword, p // Retrieve gets the credential by calling the MinIO STS API for // LDAP on the configured stsEndpoint. -func (k *LDAPIdentity) Retrieve() (value Value, err error) { +func (k *LDAPIdentity) Retrieve(cc *CredContext) (value Value, err error) { u, err := url.Parse(k.STSEndpoint) if err != nil { return value, err @@ -148,7 +143,7 @@ func (k *LDAPIdentity) Retrieve() (value Value, err error) { req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - resp, err := k.Client.Do(req) + resp, err := cc.Client.Do(req) if err != nil { return value, err } diff --git a/pkg/credentials/sts_tls_identity.go b/pkg/credentials/sts_tls_identity.go index 10083502d1..75d6cfc779 100644 --- a/pkg/credentials/sts_tls_identity.go +++ b/pkg/credentials/sts_tls_identity.go @@ -20,8 +20,8 @@ import ( "crypto/tls" "encoding/xml" "errors" + "fmt" "io" - "net" "net/http" "net/url" "strconv" @@ -33,12 +33,6 @@ import ( // livetime. type CertificateIdentityOption func(*STSCertificateIdentity) -// CertificateIdentityWithTransport returns a CertificateIdentityOption that -// customizes the STSCertificateIdentity with the given http.RoundTripper. -func CertificateIdentityWithTransport(t http.RoundTripper) CertificateIdentityOption { - return CertificateIdentityOption(func(i *STSCertificateIdentity) { i.Client.Transport = t }) -} - // CertificateIdentityWithExpiry returns a CertificateIdentityOption that // customizes the STSCertificateIdentity with the given livetime. // @@ -68,17 +62,9 @@ type STSCertificateIdentity struct { // The default livetime is one hour. S3CredentialLivetime time.Duration - // Client is the HTTP client used to authenticate and fetch - // S3 credentials. - // - // A custom TLS client configuration can be specified by - // using a custom http.Transport: - // Client: http.Client { - // Transport: &http.Transport{ - // TLSClientConfig: &tls.Config{}, - // }, - // } - Client http.Client + // Certificate is the client certificate that is used for + // STS authentication. + Certificate tls.Certificate } var _ Provider = (*STSWebIdentity)(nil) // compiler check @@ -95,23 +81,7 @@ func NewSTSCertificateIdentity(endpoint string, certificate tls.Certificate, opt } identity := &STSCertificateIdentity{ STSEndpoint: endpoint, - Client: http.Client{ - Transport: &http.Transport{ - Proxy: http.ProxyFromEnvironment, - DialContext: (&net.Dialer{ - Timeout: 30 * time.Second, - KeepAlive: 30 * time.Second, - }).DialContext, - ForceAttemptHTTP2: true, - MaxIdleConns: 100, - IdleConnTimeout: 90 * time.Second, - TLSHandshakeTimeout: 10 * time.Second, - ExpectContinueTimeout: 5 * time.Second, - TLSClientConfig: &tls.Config{ - Certificates: []tls.Certificate{certificate}, - }, - }, - }, + Certificate: certificate, } for _, option := range options { option(identity) @@ -121,7 +91,7 @@ func NewSTSCertificateIdentity(endpoint string, certificate tls.Certificate, opt // Retrieve fetches a new set of S3 credentials from the configured // STS API endpoint. -func (i *STSCertificateIdentity) Retrieve() (Value, error) { +func (i *STSCertificateIdentity) Retrieve(cc *CredContext) (Value, error) { endpointURL, err := url.Parse(i.STSEndpoint) if err != nil { return Value{}, err @@ -145,7 +115,19 @@ func (i *STSCertificateIdentity) Retrieve() (Value, error) { } req.Form.Add("DurationSeconds", strconv.FormatUint(uint64(livetime.Seconds()), 10)) - resp, err := i.Client.Do(req) + tr, ok := cc.Client.Transport.(*http.Transport) + if !ok { + return Value{}, fmt.Errorf("CredContext should contain an http.Transport value") + } + + trCopy := tr.Clone() + trCopy.TLSClientConfig.Certificates = []tls.Certificate{i.Certificate} + + client := http.Client{ + Transport: trCopy, + } + + resp, err := client.Do(req) if err != nil { return Value{}, err } diff --git a/pkg/credentials/sts_web_identity.go b/pkg/credentials/sts_web_identity.go index 8c06bac60d..d49cf96f1b 100644 --- a/pkg/credentials/sts_web_identity.go +++ b/pkg/credentials/sts_web_identity.go @@ -69,9 +69,6 @@ type WebIdentityToken struct { type STSWebIdentity struct { Expiry - // Required http Client to use when connecting to MinIO STS service. - Client *http.Client - // Exported STS endpoint to fetch STS credentials. STSEndpoint string @@ -104,9 +101,6 @@ func NewSTSWebIdentity(stsEndpoint string, getWebIDTokenExpiry func() (*WebIdent return nil, errors.New("Web ID token and expiry retrieval function should be defined") } i := &STSWebIdentity{ - Client: &http.Client{ - Transport: http.DefaultTransport, - }, STSEndpoint: stsEndpoint, GetWebIDTokenExpiry: getWebIDTokenExpiry, } @@ -221,8 +215,8 @@ func getWebIdentityCredentials(clnt *http.Client, endpoint, roleARN, roleSession // Retrieve retrieves credentials from the MinIO service. // Error will be returned if the request fails. -func (m *STSWebIdentity) Retrieve() (Value, error) { - a, err := getWebIdentityCredentials(m.Client, m.STSEndpoint, m.RoleARN, m.roleSessionName, m.Policy, m.GetWebIDTokenExpiry) +func (m *STSWebIdentity) Retrieve(cc *CredContext) (Value, error) { + a, err := getWebIdentityCredentials(cc.Client, m.STSEndpoint, m.RoleARN, m.roleSessionName, m.Policy, m.GetWebIDTokenExpiry) if err != nil { return Value{}, err }