Skip to content

Commit

Permalink
Allow custom http.Client in credential providers.
Browse files Browse the repository at this point in the history
  • Loading branch information
ramondeklein committed Dec 26, 2024
1 parent 2be0ccc commit e762fe6
Show file tree
Hide file tree
Showing 7 changed files with 74 additions and 8 deletions.
10 changes: 9 additions & 1 deletion pkg/credentials/assume_role.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,10 @@ type AssumeRoleResult struct {
type STSAssumeRole struct {
Expiry

// Optional http Client to use when connecting to MinIO STS service
// (overrides default client in CredContext)
Client *http.Client

// STS endpoint to fetch STS credentials.
STSEndpoint string

Expand Down Expand Up @@ -219,7 +223,11 @@ func getAssumeRoleCredentials(clnt *http.Client, endpoint string, opts STSAssume
// Retrieve retrieves credentials from the MinIO service.
// Error will be returned if the request fails.
func (m *STSAssumeRole) Retrieve(cc *CredContext) (Value, error) {
a, err := getAssumeRoleCredentials(cc.Client, m.STSEndpoint, m.Options)
client := m.Client
if client == nil {
client = cc.Client
}
a, err := getAssumeRoleCredentials(client, m.STSEndpoint, m.Options)
if err != nil {
return Value{}, err
}
Expand Down
5 changes: 5 additions & 0 deletions pkg/credentials/iam_aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ const DefaultExpiryWindow = -1
type IAM struct {
Expiry

// Optional http Client to use when connecting to IAM metadata service
// (overrides default client in CredContext)
Client *http.Client

// Custom endpoint to fetch IAM role credentials.
Endpoint string

Expand Down Expand Up @@ -154,6 +158,7 @@ func (m *IAM) Retrieve(cc *CredContext) (Value, error) {
}

creds := &STSWebIdentity{
Client: m.Client,
STSEndpoint: endpoint,
GetWebIDTokenExpiry: func() (*WebIdentityToken, error) {
token, err := os.ReadFile(identityFile)
Expand Down
10 changes: 9 additions & 1 deletion pkg/credentials/sts_client_grants.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,10 @@ type ClientGrantsToken struct {
type STSClientGrants struct {
Expiry

// Optional http Client to use when connecting to MinIO STS service.
// (overrides default client in CredContext)
Client *http.Client

// MinIO endpoint to fetch STS credentials.
STSEndpoint string

Expand Down Expand Up @@ -159,7 +163,11 @@ func getClientGrantsCredentials(clnt *http.Client, endpoint string,
// Retrieve retrieves credentials from the MinIO service.
// Error will be returned if the request fails.
func (m *STSClientGrants) Retrieve(cc *CredContext) (Value, error) {
a, err := getClientGrantsCredentials(cc.Client, m.STSEndpoint, m.GetClientGrantsTokenExpiry)
client := m.Client
if client == nil {
client = cc.Client
}
a, err := getClientGrantsCredentials(client, m.STSEndpoint, m.GetClientGrantsTokenExpiry)
if err != nil {
return Value{}, err
}
Expand Down
11 changes: 10 additions & 1 deletion pkg/credentials/sts_custom_identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ type AssumeRoleWithCustomTokenResponse struct {
type CustomTokenIdentity struct {
Expiry

// Optional http Client to use when connecting to MinIO STS service.
// (overrides default client in CredContext)
Client *http.Client

// MinIO server STS endpoint to fetch STS credentials.
STSEndpoint string

Expand Down Expand Up @@ -90,7 +94,12 @@ func (c *CustomTokenIdentity) Retrieve(cc *CredContext) (value Value, err error)
return value, err
}

resp, err := cc.Client.Do(req)
client := c.Client
if client == nil {
client = cc.Client
}

resp, err := client.Do(req)
if err != nil {
return value, err
}
Expand Down
11 changes: 10 additions & 1 deletion pkg/credentials/sts_ldap_identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ type LDAPIdentityResult struct {
type LDAPIdentity struct {
Expiry

// Optional http Client to use when connecting to MinIO STS service.
// (overrides default client in CredContext)
Client *http.Client

// Exported STS endpoint to fetch STS credentials.
STSEndpoint string

Expand Down Expand Up @@ -143,7 +147,12 @@ func (k *LDAPIdentity) Retrieve(cc *CredContext) (value Value, err error) {

req.Header.Set("Content-Type", "application/x-www-form-urlencoded")

resp, err := cc.Client.Do(req)
client := k.Client
if client == nil {
client = cc.Client
}

resp, err := client.Do(req)
if err != nil {
return value, err
}
Expand Down
24 changes: 21 additions & 3 deletions pkg/credentials/sts_tls_identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,12 @@ import (
// livetime.
type CertificateIdentityOption func(*STSCertificateIdentity)

// CertificateIdentityWithTransport returns a CertificateIdentityOption that
// customizes the STSCertificateIdentity with the given http.RoundTripper.
func CertificateIdentityWithTransport(t http.RoundTripper) CertificateIdentityOption {
return CertificateIdentityOption(func(i *STSCertificateIdentity) { i.Client.Transport = t })
}

// CertificateIdentityWithExpiry returns a CertificateIdentityOption that
// customizes the STSCertificateIdentity with the given livetime.
//
Expand All @@ -47,6 +53,10 @@ func CertificateIdentityWithExpiry(livetime time.Duration) CertificateIdentityOp
type STSCertificateIdentity struct {
Expiry

// Optional http Client to use when connecting to MinIO STS service.
// (overrides default client in CredContext)
Client *http.Client

// STSEndpoint is the base URL endpoint of the STS API.
// For example, https://minio.local:9000
STSEndpoint string
Expand Down Expand Up @@ -115,16 +125,24 @@ func (i *STSCertificateIdentity) Retrieve(cc *CredContext) (Value, error) {
}
req.Form.Add("DurationSeconds", strconv.FormatUint(uint64(livetime.Seconds()), 10))

tr, ok := cc.Client.Transport.(*http.Transport)
client := i.Client
if client == nil {
client = cc.Client
}

tr, ok := client.Transport.(*http.Transport)
if !ok {
return Value{}, fmt.Errorf("CredContext should contain an http.Transport value")
}

trCopy := tr.Clone()
trCopy.TLSClientConfig.Certificates = []tls.Certificate{i.Certificate}

client := http.Client{
Transport: trCopy,
client = &http.Client{
Transport: trCopy,
CheckRedirect: client.CheckRedirect,
Jar: client.Jar,
Timeout: client.Timeout,
}

resp, err := client.Do(req)
Expand Down
11 changes: 10 additions & 1 deletion pkg/credentials/sts_web_identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,10 @@ type WebIdentityToken struct {
type STSWebIdentity struct {
Expiry

// Optional http Client to use when connecting to MinIO STS service.
// (overrides default client in CredContext)
Client *http.Client

// Exported STS endpoint to fetch STS credentials.
STSEndpoint string

Expand Down Expand Up @@ -216,7 +220,12 @@ func getWebIdentityCredentials(clnt *http.Client, endpoint, roleARN, roleSession
// Retrieve retrieves credentials from the MinIO service.
// Error will be returned if the request fails.
func (m *STSWebIdentity) Retrieve(cc *CredContext) (Value, error) {
a, err := getWebIdentityCredentials(cc.Client, m.STSEndpoint, m.RoleARN, m.roleSessionName, m.Policy, m.GetWebIDTokenExpiry)
client := m.Client
if client == nil {
client = cc.Client
}

a, err := getWebIdentityCredentials(client, m.STSEndpoint, m.RoleARN, m.roleSessionName, m.Policy, m.GetWebIDTokenExpiry)
if err != nil {
return Value{}, err
}
Expand Down

0 comments on commit e762fe6

Please sign in to comment.