Skip to content

Commit

Permalink
fix get aws creds from environment
Browse files Browse the repository at this point in the history
Signed-off-by: Fabian Martinez <[email protected]>
  • Loading branch information
famarting committed Nov 28, 2024
1 parent f48b412 commit c8d900a
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 19 deletions.
7 changes: 0 additions & 7 deletions common/authentication/aws/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,13 +112,6 @@ type Provider interface {
Close() error
}

func isX509Auth(m map[string]string) bool {
tp, _ := m["trustProfileArn"]
ta, _ := m["trustAnchorArn"]
ar, _ := m["assumeRoleArn"]
return tp != "" && ta != "" && ar != ""
}

func NewProvider(ctx context.Context, opts Options, cfg *aws.Config) (Provider, error) {
if isX509Auth(opts.Properties) {
return newX509(ctx, opts, cfg)
Expand Down
38 changes: 27 additions & 11 deletions common/authentication/aws/static.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,7 @@ type StaticAuth struct {

func newStaticIAM(_ context.Context, opts Options, cfg *aws.Config) (*StaticAuth, error) {
auth := &StaticAuth{
logger: opts.Logger,
region: &opts.Region,
endpoint: &opts.Endpoint,
accessKey: &opts.AccessKey,
secretKey: &opts.SecretKey,
sessionToken: &opts.SessionToken,
assumeRoleARN: &opts.AssumeRoleARN,
sessionName: &opts.SessionName,

logger: opts.Logger,
cfg: func() *aws.Config {
// if nil is passed or it's just a default cfg,
// then we use the options to build the aws cfg.
Expand All @@ -70,7 +62,29 @@ func newStaticIAM(_ context.Context, opts Options, cfg *aws.Config) (*StaticAuth
clients: newClients(),
}

initialSession, err := auth.getTokenClient()
if opts.Region != "" {
auth.region = &opts.Region
}
if opts.Endpoint != "" {
auth.endpoint = &opts.Endpoint
}
if opts.AccessKey != "" {
auth.accessKey = &opts.AccessKey
}
if opts.SecretKey != "" {
auth.secretKey = &opts.SecretKey
}
if opts.SessionToken != "" {
auth.sessionToken = &opts.SessionToken
}
if opts.AssumeRoleARN != "" {
auth.assumeRoleARN = &opts.AssumeRoleARN
}
if opts.SessionName != "" {
auth.sessionName = &opts.SessionName
}

initialSession, err := auth.createSession()
if err != nil {
return nil, fmt.Errorf("failed to get token client: %v", err)
}
Expand Down Expand Up @@ -243,7 +257,7 @@ func (a *StaticAuth) Kafka(opts KafkaOptions) (*KafkaClients, error) {
return a.clients.kafka, nil
}

func (a *StaticAuth) getTokenClient() (*session.Session, error) {
func (a *StaticAuth) createSession() (*session.Session, error) {
var awsConfig *aws.Config
if a.cfg == nil {
awsConfig = aws.NewConfig()
Expand All @@ -264,6 +278,8 @@ func (a *StaticAuth) getTokenClient() (*session.Session, error) {
awsConfig = awsConfig.WithEndpoint(*a.endpoint)
}

// TODO support assume role for all aws components

awsSession, err := session.NewSessionWithOptions(session.Options{
Config: *awsConfig,
SharedConfigState: session.SharedConfigEnable,
Expand Down
8 changes: 7 additions & 1 deletion common/authentication/aws/static_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,17 @@ func TestGetTokenClient(t *testing.T) {
endpoint: aws.String("https://test.endpoint.com"),
},
},
{
name: "creds from environment",
awsInstance: &StaticAuth{
region: aws.String("us-west-2"),
},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
session, err := tt.awsInstance.getTokenClient()
session, err := tt.awsInstance.createSession()
require.NotNil(t, session)
require.NoError(t, err)
assert.Equal(t, tt.awsInstance.region, session.Config.Region)
Expand Down
7 changes: 7 additions & 0 deletions common/authentication/aws/x509.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,13 @@ import (
"github.com/dapr/kit/ptr"
)

func isX509Auth(m map[string]string) bool {
tp := m["trustProfileArn"]
ta := m["trustAnchorArn"]
ar := m["assumeRoleArn"]
return tp != "" && ta != "" && ar != ""
}

type x509Options struct {
TrustProfileArn *string `json:"trustProfileArn" mapstructure:"trustProfileArn"`
TrustAnchorArn *string `json:"trustAnchorArn" mapstructure:"trustAnchorArn"`
Expand Down

0 comments on commit c8d900a

Please sign in to comment.