Skip to content

Commit

Permalink
Various AWS improvements (dapr#525)
Browse files Browse the repository at this point in the history
* Support sessiontoken

* Fixed tests and other components

* Fixed tests and things

* fmt

* Fix lint errors

* gofmt

* Fixed lint bugs

* Remove unneeded parameter

* gofmt

Co-authored-by: Yaron Schneider <[email protected]>
  • Loading branch information
trondhindenes and yaron2 authored Nov 23, 2020
1 parent b1f1ecc commit 8d69783
Show file tree
Hide file tree
Showing 15 changed files with 151 additions and 106 deletions.
5 changes: 3 additions & 2 deletions authentication/aws/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@ import (
"github.com/aws/aws-sdk-go/aws/session"
)

func GetClient(accessKey string, secretKey string, region string, endpoint string) (*session.Session, error) {
func GetClient(accessKey string, secretKey string, sessionToken string, region string, endpoint string) (*session.Session, error) {
awsConfig := aws.NewConfig()

if region != "" {
awsConfig = awsConfig.WithRegion(region)
}

if accessKey != "" && secretKey != "" {
awsConfig = awsConfig.WithCredentials(credentials.NewStaticCredentials(accessKey, secretKey, ""))
awsConfig = awsConfig.WithCredentials(credentials.NewStaticCredentials(accessKey, secretKey, sessionToken))
}

if endpoint != "" {
Expand Down
14 changes: 7 additions & 7 deletions bindings/aws/dynamodb/dynamodb.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,12 @@ type DynamoDB struct {
}

type dynamoDBMetadata struct {
Region string `json:"region"`
Endpoint string `json:"endpoint"`
AccessKey string `json:"accessKey"`
SecretKey string `json:"secretKey"`
Table string `json:"table"`
Region string `json:"region"`
Endpoint string `json:"endpoint"`
AccessKey string `json:"accessKey"`
SecretKey string `json:"secretKey"`
SessionToken string `json:"sessionToken"`
Table string `json:"table"`
}

// NewDynamoDB returns a new DynamoDB instance
Expand Down Expand Up @@ -99,11 +100,10 @@ func (d *DynamoDB) getDynamoDBMetadata(spec bindings.Metadata) (*dynamoDBMetadat
}

func (d *DynamoDB) getClient(metadata *dynamoDBMetadata) (*dynamodb.DynamoDB, error) {
sess, err := aws_auth.GetClient(metadata.AccessKey, metadata.SecretKey, metadata.Region, metadata.Endpoint)
sess, err := aws_auth.GetClient(metadata.AccessKey, metadata.SecretKey, metadata.SessionToken, metadata.Region, metadata.Endpoint)
if err != nil {
return nil, err
}

c := dynamodb.New(sess)

return c, nil
Expand Down
5 changes: 4 additions & 1 deletion bindings/aws/dynamodb/dynamodb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@ import (

func TestParseMetadata(t *testing.T) {
m := bindings.Metadata{}
m.Properties = map[string]string{"AccessKey": "a", "Region": "a", "SecretKey": "a", "Table": "a", "Endpoint": "a"}
m.Properties = map[string]string{
"AccessKey": "a", "Region": "a", "SecretKey": "a", "Table": "a", "Endpoint": "a", "SessionToken": "t",
}
dy := DynamoDB{}
meta, err := dy.getDynamoDBMetadata(m)
assert.Nil(t, err)
Expand All @@ -23,4 +25,5 @@ func TestParseMetadata(t *testing.T) {
assert.Equal(t, "a", meta.SecretKey)
assert.Equal(t, "a", meta.Table)
assert.Equal(t, "a", meta.Endpoint)
assert.Equal(t, "t", meta.SessionToken)
}
4 changes: 2 additions & 2 deletions bindings/aws/kinesis/kinesis.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ type kinesisMetadata struct {
Endpoint string `json:"endpoint"`
AccessKey string `json:"accessKey"`
SecretKey string `json:"secretKey"`
SessionToken string `json:"sessionToken"`
KinesisConsumerMode kinesisConsumerMode `json:"mode"`
}

Expand Down Expand Up @@ -295,11 +296,10 @@ func (a *AWSKinesis) waitUntilConsumerExists(ctx aws.Context, input *kinesis.Des
}

func (a *AWSKinesis) getClient(metadata *kinesisMetadata) (*kinesis.Kinesis, error) {
sess, err := aws_auth.GetClient(metadata.AccessKey, metadata.SecretKey, metadata.Region, metadata.Endpoint)
sess, err := aws_auth.GetClient(metadata.AccessKey, metadata.SecretKey, metadata.SessionToken, metadata.Region, metadata.Endpoint)
if err != nil {
return nil, err
}

k := kinesis.New(sess)

return k, nil
Expand Down
2 changes: 2 additions & 0 deletions bindings/aws/kinesis/kinesis_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ func TestParseMetadata(t *testing.T) {
"StreamName": "stream",
"Mode": "extended",
"Endpoint": "endpoint",
"SessionToken": "token",
}
kinesis := AWSKinesis{}
meta, err := kinesis.parseMetadata(m)
Expand All @@ -32,5 +33,6 @@ func TestParseMetadata(t *testing.T) {
assert.Equal(t, "test", meta.ConsumerName)
assert.Equal(t, "stream", meta.StreamName)
assert.Equal(t, "endpoint", meta.Endpoint)
assert.Equal(t, "token", meta.SessionToken)
assert.Equal(t, kinesisConsumerMode("extended"), meta.KinesisConsumerMode)
}
13 changes: 7 additions & 6 deletions bindings/aws/s3/s3.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,12 @@ type AWSS3 struct {
}

type s3Metadata struct {
Region string `json:"region"`
Endpoint string `json:"endpoint"`
AccessKey string `json:"accessKey"`
SecretKey string `json:"secretKey"`
Bucket string `json:"bucket"`
Region string `json:"region"`
Endpoint string `json:"endpoint"`
AccessKey string `json:"accessKey"`
SecretKey string `json:"secretKey"`
SessionToken string `json:"sessionToken"`
Bucket string `json:"bucket"`
}

// NewAWSS3 returns a new AWSS3 instance
Expand Down Expand Up @@ -92,7 +93,7 @@ func (s *AWSS3) parseMetadata(metadata bindings.Metadata) (*s3Metadata, error) {
}

func (s *AWSS3) getClient(metadata *s3Metadata) (*s3manager.Uploader, error) {
sess, err := aws_auth.GetClient(metadata.AccessKey, metadata.SecretKey, metadata.Region, metadata.Endpoint)
sess, err := aws_auth.GetClient(metadata.AccessKey, metadata.SecretKey, metadata.SessionToken, metadata.Region, metadata.Endpoint)
if err != nil {
return nil, err
}
Expand Down
5 changes: 4 additions & 1 deletion bindings/aws/s3/s3_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@ import (

func TestParseMetadata(t *testing.T) {
m := bindings.Metadata{}
m.Properties = map[string]string{"AccessKey": "key", "Region": "region", "SecretKey": "secret", "Bucket": "test", "Endpoint": "endpoint"}
m.Properties = map[string]string{
"AccessKey": "key", "Region": "region", "SecretKey": "secret", "Bucket": "test", "Endpoint": "endpoint", "SessionToken": "token",
}
s3 := AWSS3{}
meta, err := s3.parseMetadata(m)
assert.Nil(t, err)
Expand All @@ -23,4 +25,5 @@ func TestParseMetadata(t *testing.T) {
assert.Equal(t, "secret", meta.SecretKey)
assert.Equal(t, "test", meta.Bucket)
assert.Equal(t, "endpoint", meta.Endpoint)
assert.Equal(t, "token", meta.SessionToken)
}
13 changes: 7 additions & 6 deletions bindings/aws/sns/sns.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,12 @@ type AWSSNS struct {
}

type snsMetadata struct {
TopicArn string `json:"topicArn"`
Region string `json:"region"`
Endpoint string `json:"endpoint"`
AccessKey string `json:"accessKey"`
SecretKey string `json:"secretKey"`
TopicArn string `json:"topicArn"`
Region string `json:"region"`
Endpoint string `json:"endpoint"`
AccessKey string `json:"accessKey"`
SecretKey string `json:"secretKey"`
SessionToken string `json:"sessionToken"`
}

type dataPayload struct {
Expand Down Expand Up @@ -73,7 +74,7 @@ func (a *AWSSNS) parseMetadata(metadata bindings.Metadata) (*snsMetadata, error)
}

func (a *AWSSNS) getClient(metadata *snsMetadata) (*sns.SNS, error) {
sess, err := aws_auth.GetClient(metadata.AccessKey, metadata.SecretKey, metadata.Region, metadata.Endpoint)
sess, err := aws_auth.GetClient(metadata.AccessKey, metadata.SecretKey, metadata.SessionToken, metadata.Region, metadata.Endpoint)
if err != nil {
return nil, err
}
Expand Down
5 changes: 4 additions & 1 deletion bindings/aws/sns/sns_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@ import (

func TestParseMetadata(t *testing.T) {
m := bindings.Metadata{}
m.Properties = map[string]string{"TopicArn": "a", "Region": "a", "AccessKey": "a", "SecretKey": "a", "Endpoint": "a"}
m.Properties = map[string]string{
"TopicArn": "a", "Region": "a", "AccessKey": "a", "SecretKey": "a", "Endpoint": "a", "SessionToken": "t",
}
s := AWSSNS{}
snsM, err := s.parseMetadata(m)
assert.Nil(t, err)
Expand All @@ -23,4 +25,5 @@ func TestParseMetadata(t *testing.T) {
assert.Equal(t, "a", snsM.AccessKey)
assert.Equal(t, "a", snsM.SecretKey)
assert.Equal(t, "a", snsM.Endpoint)
assert.Equal(t, "t", snsM.SessionToken)
}
13 changes: 7 additions & 6 deletions bindings/aws/sqs/sqs.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,12 @@ type AWSSQS struct {
}

type sqsMetadata struct {
QueueName string `json:"queueName"`
Region string `json:"region"`
Endpoint string `json:"endpoint"`
AccessKey string `json:"accessKey"`
SecretKey string `json:"secretKey"`
QueueName string `json:"queueName"`
Region string `json:"region"`
Endpoint string `json:"endpoint"`
AccessKey string `json:"accessKey"`
SecretKey string `json:"secretKey"`
SessionToken string `json:"sessionToken"`
}

// NewAWSSQS returns a new AWS SQS instance
Expand Down Expand Up @@ -132,7 +133,7 @@ func (a *AWSSQS) parseSQSMetadata(metadata bindings.Metadata) (*sqsMetadata, err
}

func (a *AWSSQS) getClient(metadata *sqsMetadata) (*sqs.SQS, error) {
sess, err := aws_auth.GetClient(metadata.AccessKey, metadata.SecretKey, metadata.Region, metadata.Endpoint)
sess, err := aws_auth.GetClient(metadata.AccessKey, metadata.SecretKey, metadata.SessionToken, metadata.Region, metadata.Endpoint)
if err != nil {
return nil, err
}
Expand Down
5 changes: 4 additions & 1 deletion bindings/aws/sqs/sqs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@ import (

func TestParseMetadata(t *testing.T) {
m := bindings.Metadata{}
m.Properties = map[string]string{"QueueName": "a", "Region": "a", "AccessKey": "a", "SecretKey": "a", "Endpoint": "a"}
m.Properties = map[string]string{
"QueueName": "a", "Region": "a", "AccessKey": "a", "SecretKey": "a", "Endpoint": "a", "SessionToken": "t",
}
s := AWSSQS{}
sqsM, err := s.parseSQSMetadata(m)
assert.Nil(t, err)
Expand All @@ -23,4 +25,5 @@ func TestParseMetadata(t *testing.T) {
assert.Equal(t, "a", sqsM.AccessKey)
assert.Equal(t, "a", sqsM.SecretKey)
assert.Equal(t, "a", sqsM.Endpoint)
assert.Equal(t, "t", sqsM.SessionToken)
}
71 changes: 35 additions & 36 deletions pubsub/aws/snssqs/snssqs.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ type snsSqs struct {
topicHash map[string]string
// key is the topic name, value holds the ARN of the queue and its url
queues map[string]*sqsQueueInfo
awsAcctID string
snsClient *sns.SNS
sqsClient *sqs.SQS
metadata *snsSqsMetadata
Expand All @@ -41,23 +40,23 @@ type snsSqsMetadata struct {
sqsQueueName string

// aws endpoint for the component to use.
awsEndpoint string
// aws account ID to use for SNS/SQS. Required
awsAccountID string
// aws secret corresponding to the account ID. Required
awsSecret string
// aws token to use. Required
awsToken string
// aws region in which SNS/SQS should create resources. Required
awsRegion string
Endpoint string
// access key to use for accessing sqs/sns
AccessKey string
// secret key to use for accessing sqs/sns
SecretKey string
// aws session token to use.
SessionToken string
// aws region in which SNS/SQS should create resources
Region string

// amount of time in seconds that a message is hidden from receive requests after it is sent to a subscriber. Default: 10
messageVisibilityTimeout int64
// number of times to resend a message after processing of that message fails before removing that message from the queue. Default: 10
messageRetryLimit int64
// amount of time to await receipt of a message before making another request. Default: 1
messageWaitTimeSeconds int64
// maximum number of messsages to receive from the queue at a time. Default: 10, Maximum: 10
// maximum number of messages to receive from the queue at a time. Default: 10, Maximum: 10
messageMaxNumber int64
}

Expand All @@ -70,6 +69,17 @@ func NewSnsSqs(l logger.Logger) pubsub.PubSub {
return &snsSqs{logger: l, subscriptions: []*string{}}
}

func getAliasedProperty(aliases []string, metadata pubsub.Metadata) (string, bool) {
props := metadata.Properties
for _, s := range aliases {
if val, ok := props[s]; ok {
return val, true
}
}

return "", false
}

func parseInt64(input string, propertyName string) (int64, error) {
number, err := strconv.Atoi(input)
if err != nil {
Expand All @@ -94,39 +104,29 @@ func (s *snsSqs) getSnsSqsMetatdata(metadata pubsub.Metadata) (*snsSqsMetadata,
md.sqsQueueName = metadata.Properties["consumerID"]
s.logger.Debugf("Setting queue name to %s", md.sqsQueueName)

if val, ok := props["awsEndpoint"]; ok {
md.awsEndpoint = val
if val, ok := getAliasedProperty([]string{"Endpoint", "endpoint"}, metadata); ok {
s.logger.Debugf("endpoint: %s", val)
md.Endpoint = val
}

val, ok := props["awsAccountID"]

if !ok {
return nil, errors.New("missing required property: awsAccountID")
if val, ok := getAliasedProperty([]string{"awsAccountID", "accessKey"}, metadata); ok {
s.logger.Debugf("AccessKey: %s", val)
md.AccessKey = val
}

md.awsAccountID = val

val, ok = props["awsSecret"]
if !ok {
return nil, errors.New("missing required property: awsSecret")
if val, ok := getAliasedProperty([]string{"awsSecret", "secretKey"}, metadata); ok {
s.logger.Debugf("awsToken: %s", val)
md.SecretKey = val
}

md.awsSecret = val

val, ok = props["awsToken"]
if !ok {
md.awsToken = ""
} else {
md.awsToken = val
if val, ok := getAliasedProperty([]string{"sessionToken"}, metadata); ok {
md.SessionToken = val
}

val, ok = props["awsRegion"]
if !ok {
return nil, errors.New("missing required property: awsRegion")
if val, ok := getAliasedProperty([]string{"awsRegion", "region"}, metadata); ok {
md.Region = val
}

md.awsRegion = val

if val, ok := props["messageVisibilityTimeout"]; !ok {
md.messageVisibilityTimeout = 10
} else {
Expand Down Expand Up @@ -205,8 +205,7 @@ func (s *snsSqs) Init(metadata pubsub.Metadata) error {
s.topics = make(map[string]string)
s.topicHash = make(map[string]string)
s.queues = make(map[string]*sqsQueueInfo)
s.awsAcctID = md.awsAccountID
sess, err := aws_auth.GetClient(s.awsAcctID, md.awsSecret, md.awsRegion, md.awsEndpoint)
sess, err := aws_auth.GetClient(md.AccessKey, md.SecretKey, md.SessionToken, md.Region, md.Endpoint)
if err != nil {
return err
}
Expand Down
Loading

0 comments on commit 8d69783

Please sign in to comment.