Skip to content

Commit

Permalink
Allow custom http.Client in credential providers.
Browse files Browse the repository at this point in the history
  • Loading branch information
ramondeklein committed Dec 26, 2024
1 parent 2be0ccc commit f9f409f
Show file tree
Hide file tree
Showing 7 changed files with 88 additions and 12 deletions.
10 changes: 9 additions & 1 deletion pkg/credentials/assume_role.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,10 @@ type AssumeRoleResult struct {
type STSAssumeRole struct {
Expiry

// Optional http Client to use when connecting to MinIO STS service
// (overrides default client in CredContext)
Client *http.Client

// STS endpoint to fetch STS credentials.
STSEndpoint string

Expand Down Expand Up @@ -219,7 +223,11 @@ func getAssumeRoleCredentials(clnt *http.Client, endpoint string, opts STSAssume
// Retrieve retrieves credentials from the MinIO service.
// Error will be returned if the request fails.
func (m *STSAssumeRole) Retrieve(cc *CredContext) (Value, error) {
a, err := getAssumeRoleCredentials(cc.Client, m.STSEndpoint, m.Options)
client := m.Client
if client == nil {
client = cc.Client
}
a, err := getAssumeRoleCredentials(client, m.STSEndpoint, m.Options)
if err != nil {
return Value{}, err
}
Expand Down
18 changes: 14 additions & 4 deletions pkg/credentials/iam_aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ const DefaultExpiryWindow = -1
type IAM struct {
Expiry

// Optional http Client to use when connecting to IAM metadata service
// (overrides default client in CredContext)
Client *http.Client

// Custom endpoint to fetch IAM role credentials.
Endpoint string

Expand Down Expand Up @@ -138,6 +142,11 @@ func (m *IAM) Retrieve(cc *CredContext) (Value, error) {
var roleCreds ec2RoleCredRespBody
var err error

client := m.Client
if client == nil {
client = cc.Client
}

endpoint := m.Endpoint
switch {
case identityFile != "":
Expand All @@ -154,6 +163,7 @@ func (m *IAM) Retrieve(cc *CredContext) (Value, error) {
}

creds := &STSWebIdentity{
Client: client,
STSEndpoint: endpoint,
GetWebIDTokenExpiry: func() (*WebIdentityToken, error) {
token, err := os.ReadFile(identityFile)
Expand All @@ -178,11 +188,11 @@ func (m *IAM) Retrieve(cc *CredContext) (Value, error) {
endpoint = fmt.Sprintf("%s%s", DefaultECSRoleEndpoint, relativeURI)
}

roleCreds, err = getEcsTaskCredentials(cc.Client, endpoint, token)
roleCreds, err = getEcsTaskCredentials(client, endpoint, token)

case tokenFile != "" && fullURI != "":
endpoint = fullURI
roleCreds, err = getEKSPodIdentityCredentials(cc.Client, endpoint, tokenFile)
roleCreds, err = getEKSPodIdentityCredentials(client, endpoint, tokenFile)

case fullURI != "":
if len(endpoint) == 0 {
Expand All @@ -196,10 +206,10 @@ func (m *IAM) Retrieve(cc *CredContext) (Value, error) {
}
}

roleCreds, err = getEcsTaskCredentials(cc.Client, endpoint, token)
roleCreds, err = getEcsTaskCredentials(client, endpoint, token)

default:
roleCreds, err = getCredentials(cc.Client, endpoint)
roleCreds, err = getCredentials(client, endpoint)
}

if err != nil {
Expand Down
10 changes: 9 additions & 1 deletion pkg/credentials/sts_client_grants.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,10 @@ type ClientGrantsToken struct {
type STSClientGrants struct {
Expiry

// Optional http Client to use when connecting to MinIO STS service.
// (overrides default client in CredContext)
Client *http.Client

// MinIO endpoint to fetch STS credentials.
STSEndpoint string

Expand Down Expand Up @@ -159,7 +163,11 @@ func getClientGrantsCredentials(clnt *http.Client, endpoint string,
// Retrieve retrieves credentials from the MinIO service.
// Error will be returned if the request fails.
func (m *STSClientGrants) Retrieve(cc *CredContext) (Value, error) {
a, err := getClientGrantsCredentials(cc.Client, m.STSEndpoint, m.GetClientGrantsTokenExpiry)
client := m.Client
if client == nil {
client = cc.Client
}
a, err := getClientGrantsCredentials(client, m.STSEndpoint, m.GetClientGrantsTokenExpiry)
if err != nil {
return Value{}, err
}
Expand Down
11 changes: 10 additions & 1 deletion pkg/credentials/sts_custom_identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ type AssumeRoleWithCustomTokenResponse struct {
type CustomTokenIdentity struct {
Expiry

// Optional http Client to use when connecting to MinIO STS service.
// (overrides default client in CredContext)
Client *http.Client

// MinIO server STS endpoint to fetch STS credentials.
STSEndpoint string

Expand Down Expand Up @@ -90,7 +94,12 @@ func (c *CustomTokenIdentity) Retrieve(cc *CredContext) (value Value, err error)
return value, err
}

resp, err := cc.Client.Do(req)
client := c.Client
if client == nil {
client = cc.Client
}

resp, err := client.Do(req)
if err != nil {
return value, err
}
Expand Down
11 changes: 10 additions & 1 deletion pkg/credentials/sts_ldap_identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ type LDAPIdentityResult struct {
type LDAPIdentity struct {
Expiry

// Optional http Client to use when connecting to MinIO STS service.
// (overrides default client in CredContext)
Client *http.Client

// Exported STS endpoint to fetch STS credentials.
STSEndpoint string

Expand Down Expand Up @@ -143,7 +147,12 @@ func (k *LDAPIdentity) Retrieve(cc *CredContext) (value Value, err error) {

req.Header.Set("Content-Type", "application/x-www-form-urlencoded")

resp, err := cc.Client.Do(req)
client := k.Client
if client == nil {
client = cc.Client
}

resp, err := client.Do(req)
if err != nil {
return value, err
}
Expand Down
29 changes: 26 additions & 3 deletions pkg/credentials/sts_tls_identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,17 @@ import (
// livetime.
type CertificateIdentityOption func(*STSCertificateIdentity)

// CertificateIdentityWithTransport returns a CertificateIdentityOption that
// customizes the STSCertificateIdentity with the given http.RoundTripper.
func CertificateIdentityWithTransport(t http.RoundTripper) CertificateIdentityOption {
return CertificateIdentityOption(func(i *STSCertificateIdentity) {
if i.Client == nil {
i.Client = &http.Client{}
}
i.Client.Transport = t
})
}

// CertificateIdentityWithExpiry returns a CertificateIdentityOption that
// customizes the STSCertificateIdentity with the given livetime.
//
Expand All @@ -47,6 +58,10 @@ func CertificateIdentityWithExpiry(livetime time.Duration) CertificateIdentityOp
type STSCertificateIdentity struct {
Expiry

// Optional http Client to use when connecting to MinIO STS service.
// (overrides default client in CredContext)
Client *http.Client

// STSEndpoint is the base URL endpoint of the STS API.
// For example, https://minio.local:9000
STSEndpoint string
Expand Down Expand Up @@ -115,16 +130,24 @@ func (i *STSCertificateIdentity) Retrieve(cc *CredContext) (Value, error) {
}
req.Form.Add("DurationSeconds", strconv.FormatUint(uint64(livetime.Seconds()), 10))

tr, ok := cc.Client.Transport.(*http.Transport)
client := i.Client
if client == nil {
client = cc.Client
}

tr, ok := client.Transport.(*http.Transport)
if !ok {
return Value{}, fmt.Errorf("CredContext should contain an http.Transport value")
}

trCopy := tr.Clone()
trCopy.TLSClientConfig.Certificates = []tls.Certificate{i.Certificate}

client := http.Client{
Transport: trCopy,
client = &http.Client{
Transport: trCopy,
CheckRedirect: client.CheckRedirect,
Jar: client.Jar,
Timeout: client.Timeout,
}

resp, err := client.Do(req)
Expand Down
11 changes: 10 additions & 1 deletion pkg/credentials/sts_web_identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,10 @@ type WebIdentityToken struct {
type STSWebIdentity struct {
Expiry

// Optional http Client to use when connecting to MinIO STS service.
// (overrides default client in CredContext)
Client *http.Client

// Exported STS endpoint to fetch STS credentials.
STSEndpoint string

Expand Down Expand Up @@ -216,7 +220,12 @@ func getWebIdentityCredentials(clnt *http.Client, endpoint, roleARN, roleSession
// Retrieve retrieves credentials from the MinIO service.
// Error will be returned if the request fails.
func (m *STSWebIdentity) Retrieve(cc *CredContext) (Value, error) {
a, err := getWebIdentityCredentials(cc.Client, m.STSEndpoint, m.RoleARN, m.roleSessionName, m.Policy, m.GetWebIDTokenExpiry)
client := m.Client
if client == nil {
client = cc.Client
}

a, err := getWebIdentityCredentials(client, m.STSEndpoint, m.RoleARN, m.roleSessionName, m.Policy, m.GetWebIDTokenExpiry)
if err != nil {
return Value{}, err
}
Expand Down

0 comments on commit f9f409f

Please sign in to comment.