From e762fe6f5c10bcccbb188eee1455303a085eae00 Mon Sep 17 00:00:00 2001 From: Ramon de Klein Date: Thu, 26 Dec 2024 20:09:40 +0100 Subject: [PATCH] Allow custom `http.Client` in credential providers. --- pkg/credentials/assume_role.go | 10 +++++++++- pkg/credentials/iam_aws.go | 5 +++++ pkg/credentials/sts_client_grants.go | 10 +++++++++- pkg/credentials/sts_custom_identity.go | 11 ++++++++++- pkg/credentials/sts_ldap_identity.go | 11 ++++++++++- pkg/credentials/sts_tls_identity.go | 24 +++++++++++++++++++++--- pkg/credentials/sts_web_identity.go | 11 ++++++++++- 7 files changed, 74 insertions(+), 8 deletions(-) diff --git a/pkg/credentials/assume_role.go b/pkg/credentials/assume_role.go index 358a03b16..06f79dd56 100644 --- a/pkg/credentials/assume_role.go +++ b/pkg/credentials/assume_role.go @@ -76,6 +76,10 @@ type AssumeRoleResult struct { type STSAssumeRole struct { Expiry + // Optional http Client to use when connecting to MinIO STS service + // (overrides default client in CredContext) + Client *http.Client + // STS endpoint to fetch STS credentials. STSEndpoint string @@ -219,7 +223,11 @@ func getAssumeRoleCredentials(clnt *http.Client, endpoint string, opts STSAssume // Retrieve retrieves credentials from the MinIO service. // Error will be returned if the request fails. func (m *STSAssumeRole) Retrieve(cc *CredContext) (Value, error) { - a, err := getAssumeRoleCredentials(cc.Client, m.STSEndpoint, m.Options) + client := m.Client + if client == nil { + client = cc.Client + } + a, err := getAssumeRoleCredentials(client, m.STSEndpoint, m.Options) if err != nil { return Value{}, err } diff --git a/pkg/credentials/iam_aws.go b/pkg/credentials/iam_aws.go index e4f886402..b2ca9d61b 100644 --- a/pkg/credentials/iam_aws.go +++ b/pkg/credentials/iam_aws.go @@ -49,6 +49,10 @@ const DefaultExpiryWindow = -1 type IAM struct { Expiry + // Optional http Client to use when connecting to IAM metadata service + // (overrides default client in CredContext) + Client *http.Client + // Custom endpoint to fetch IAM role credentials. Endpoint string @@ -154,6 +158,7 @@ func (m *IAM) Retrieve(cc *CredContext) (Value, error) { } creds := &STSWebIdentity{ + Client: m.Client, STSEndpoint: endpoint, GetWebIDTokenExpiry: func() (*WebIdentityToken, error) { token, err := os.ReadFile(identityFile) diff --git a/pkg/credentials/sts_client_grants.go b/pkg/credentials/sts_client_grants.go index d687e38af..bbca025cb 100644 --- a/pkg/credentials/sts_client_grants.go +++ b/pkg/credentials/sts_client_grants.go @@ -72,6 +72,10 @@ type ClientGrantsToken struct { type STSClientGrants struct { Expiry + // Optional http Client to use when connecting to MinIO STS service. + // (overrides default client in CredContext) + Client *http.Client + // MinIO endpoint to fetch STS credentials. STSEndpoint string @@ -159,7 +163,11 @@ func getClientGrantsCredentials(clnt *http.Client, endpoint string, // Retrieve retrieves credentials from the MinIO service. // Error will be returned if the request fails. func (m *STSClientGrants) Retrieve(cc *CredContext) (Value, error) { - a, err := getClientGrantsCredentials(cc.Client, m.STSEndpoint, m.GetClientGrantsTokenExpiry) + client := m.Client + if client == nil { + client = cc.Client + } + a, err := getClientGrantsCredentials(client, m.STSEndpoint, m.GetClientGrantsTokenExpiry) if err != nil { return Value{}, err } diff --git a/pkg/credentials/sts_custom_identity.go b/pkg/credentials/sts_custom_identity.go index 11ea30a2d..10555a207 100644 --- a/pkg/credentials/sts_custom_identity.go +++ b/pkg/credentials/sts_custom_identity.go @@ -53,6 +53,10 @@ type AssumeRoleWithCustomTokenResponse struct { type CustomTokenIdentity struct { Expiry + // Optional http Client to use when connecting to MinIO STS service. + // (overrides default client in CredContext) + Client *http.Client + // MinIO server STS endpoint to fetch STS credentials. STSEndpoint string @@ -90,7 +94,12 @@ func (c *CustomTokenIdentity) Retrieve(cc *CredContext) (value Value, err error) return value, err } - resp, err := cc.Client.Do(req) + client := c.Client + if client == nil { + client = cc.Client + } + + resp, err := client.Do(req) if err != nil { return value, err } diff --git a/pkg/credentials/sts_ldap_identity.go b/pkg/credentials/sts_ldap_identity.go index bf04735fc..5d401ca72 100644 --- a/pkg/credentials/sts_ldap_identity.go +++ b/pkg/credentials/sts_ldap_identity.go @@ -55,6 +55,10 @@ type LDAPIdentityResult struct { type LDAPIdentity struct { Expiry + // Optional http Client to use when connecting to MinIO STS service. + // (overrides default client in CredContext) + Client *http.Client + // Exported STS endpoint to fetch STS credentials. STSEndpoint string @@ -143,7 +147,12 @@ func (k *LDAPIdentity) Retrieve(cc *CredContext) (value Value, err error) { req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - resp, err := cc.Client.Do(req) + client := k.Client + if client == nil { + client = cc.Client + } + + resp, err := client.Do(req) if err != nil { return value, err } diff --git a/pkg/credentials/sts_tls_identity.go b/pkg/credentials/sts_tls_identity.go index 75d6cfc77..f2c34ce9c 100644 --- a/pkg/credentials/sts_tls_identity.go +++ b/pkg/credentials/sts_tls_identity.go @@ -33,6 +33,12 @@ import ( // livetime. type CertificateIdentityOption func(*STSCertificateIdentity) +// CertificateIdentityWithTransport returns a CertificateIdentityOption that +// customizes the STSCertificateIdentity with the given http.RoundTripper. +func CertificateIdentityWithTransport(t http.RoundTripper) CertificateIdentityOption { + return CertificateIdentityOption(func(i *STSCertificateIdentity) { i.Client.Transport = t }) +} + // CertificateIdentityWithExpiry returns a CertificateIdentityOption that // customizes the STSCertificateIdentity with the given livetime. // @@ -47,6 +53,10 @@ func CertificateIdentityWithExpiry(livetime time.Duration) CertificateIdentityOp type STSCertificateIdentity struct { Expiry + // Optional http Client to use when connecting to MinIO STS service. + // (overrides default client in CredContext) + Client *http.Client + // STSEndpoint is the base URL endpoint of the STS API. // For example, https://minio.local:9000 STSEndpoint string @@ -115,7 +125,12 @@ func (i *STSCertificateIdentity) Retrieve(cc *CredContext) (Value, error) { } req.Form.Add("DurationSeconds", strconv.FormatUint(uint64(livetime.Seconds()), 10)) - tr, ok := cc.Client.Transport.(*http.Transport) + client := i.Client + if client == nil { + client = cc.Client + } + + tr, ok := client.Transport.(*http.Transport) if !ok { return Value{}, fmt.Errorf("CredContext should contain an http.Transport value") } @@ -123,8 +138,11 @@ func (i *STSCertificateIdentity) Retrieve(cc *CredContext) (Value, error) { trCopy := tr.Clone() trCopy.TLSClientConfig.Certificates = []tls.Certificate{i.Certificate} - client := http.Client{ - Transport: trCopy, + client = &http.Client{ + Transport: trCopy, + CheckRedirect: client.CheckRedirect, + Jar: client.Jar, + Timeout: client.Timeout, } resp, err := client.Do(req) diff --git a/pkg/credentials/sts_web_identity.go b/pkg/credentials/sts_web_identity.go index d49cf96f1..a268eb9f0 100644 --- a/pkg/credentials/sts_web_identity.go +++ b/pkg/credentials/sts_web_identity.go @@ -69,6 +69,10 @@ type WebIdentityToken struct { type STSWebIdentity struct { Expiry + // Optional http Client to use when connecting to MinIO STS service. + // (overrides default client in CredContext) + Client *http.Client + // Exported STS endpoint to fetch STS credentials. STSEndpoint string @@ -216,7 +220,12 @@ func getWebIdentityCredentials(clnt *http.Client, endpoint, roleARN, roleSession // Retrieve retrieves credentials from the MinIO service. // Error will be returned if the request fails. func (m *STSWebIdentity) Retrieve(cc *CredContext) (Value, error) { - a, err := getWebIdentityCredentials(cc.Client, m.STSEndpoint, m.RoleARN, m.roleSessionName, m.Policy, m.GetWebIDTokenExpiry) + client := m.Client + if client == nil { + client = cc.Client + } + + a, err := getWebIdentityCredentials(client, m.STSEndpoint, m.RoleARN, m.roleSessionName, m.Policy, m.GetWebIDTokenExpiry) if err != nil { return Value{}, err }