Skip to content

Commit

Permalink
Merge pull request #153 from mdreem/allowRoleAssumption
Browse files Browse the repository at this point in the history
Allow role assumption
  • Loading branch information
clarafu authored Feb 1, 2022
2 parents 5298d59 + 746c018 commit 80c54bd
Show file tree
Hide file tree
Showing 13 changed files with 72 additions and 7 deletions.
9 changes: 9 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ version numbers.
* `session_token`: *Optional.* The AWS STS session token to use when
accessing the bucket.

* `aws_role_arn`: *Optional.* The AWS role ARN to be assumed by the user
identified by `access_key_id` and `secret_access_key`.

* `region_name`: *Optional.* The region the bucket is in. Defaults to
`us-east-1`.

Expand Down Expand Up @@ -252,6 +255,12 @@ docker build . -t s3-resource --target tests -f dockerfiles/ubuntu/Dockerfile \
--build-arg S3_ENDPOINT="https://s3.amazonaws.com"
```

##### Integration tests using role assumption

If `S3_TESTING_AWS_ROLE_ARN` is set to a role ARN, this role will be assumed for accessing
the S3 bucket during integration tests. The whole integration test suite runs either
completely using role assumption or completely by direct access via the credentials.

##### Required IAM permissions

In addition to the required permissions above, the `s3:PutObjectTagging` permission is required to run integration tests.
Expand Down
1 change: 1 addition & 0 deletions cmd/check/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ func main() {
os.Stderr,
awsConfig,
request.Source.UseV2Signing,
request.Source.AwsRoleARN,
)

command := check.NewCommand(client)
Expand Down
1 change: 1 addition & 0 deletions cmd/in/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ func main() {
os.Stderr,
awsConfig,
request.Source.UseV2Signing,
request.Source.AwsRoleARN,
)

command := in.NewCommand(client)
Expand Down
1 change: 1 addition & 0 deletions cmd/out/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ func main() {
os.Stderr,
awsConfig,
request.Source.UseV2Signing,
request.Source.AwsRoleARN,
)

command := out.NewCommand(os.Stderr, client)
Expand Down
1 change: 1 addition & 0 deletions dockerfiles/alpine/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ FROM resource AS tests
ARG S3_TESTING_ACCESS_KEY_ID
ARG S3_TESTING_SECRET_ACCESS_KEY
ARG S3_TESTING_SESSION_TOKEN
ARG S3_TESTING_AWS_ROLE_ARN
ARG S3_VERSIONED_TESTING_BUCKET
ARG S3_TESTING_BUCKET
ARG S3_TESTING_REGION
Expand Down
1 change: 1 addition & 0 deletions dockerfiles/ubuntu/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ FROM resource AS tests
ARG S3_TESTING_ACCESS_KEY_ID
ARG S3_TESTING_SECRET_ACCESS_KEY
ARG S3_TESTING_SESSION_TOKEN
ARG S3_TESTING_AWS_ROLE_ARN
ARG S3_VERSIONED_TESTING_BUCKET
ARG S3_TESTING_BUCKET
ARG S3_TESTING_REGION
Expand Down
6 changes: 3 additions & 3 deletions in/command_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ var _ = Describe("In Command", func() {
})
})

Context("when configured globaly to skip download", func() {
Context("when configured globally to skip download", func() {
BeforeEach(func() {
request.Source.SkipDownload = true
})
Expand All @@ -91,7 +91,7 @@ var _ = Describe("In Command", func() {
})
})

Context("when configured localy to skip download", func() {
Context("when configured locally to skip download", func() {
BeforeEach(func() {
request.Params.SkipDownload = "true"
})
Expand All @@ -103,7 +103,7 @@ var _ = Describe("In Command", func() {
})
})

Context("when override localy to not skip download", func() {
Context("when override locally to not skip download", func() {
BeforeEach(func() {
request.Source.SkipDownload = true
request.Params.SkipDownload = "false"
Expand Down
5 changes: 5 additions & 0 deletions integration/check_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ var _ = Describe("check", func() {
AccessKeyID: accessKeyID,
SecretAccessKey: secretAccessKey,
SessionToken: sessionToken,
AwsRoleARN: awsRoleARN,
Bucket: versionedBucketName,
RegionName: regionName,
Regexp: "some-regex",
Expand Down Expand Up @@ -83,6 +84,7 @@ var _ = Describe("check", func() {
AccessKeyID: accessKeyID,
SecretAccessKey: secretAccessKey,
SessionToken: sessionToken,
AwsRoleARN: awsRoleARN,
Bucket: bucketName,
RegionName: regionName,
Endpoint: endpoint,
Expand Down Expand Up @@ -178,6 +180,7 @@ var _ = Describe("check", func() {
AccessKeyID: accessKeyID,
SecretAccessKey: secretAccessKey,
SessionToken: sessionToken,
AwsRoleARN: awsRoleARN,
Bucket: versionedBucketName,
RegionName: regionName,
Endpoint: endpoint,
Expand Down Expand Up @@ -321,6 +324,7 @@ var _ = Describe("check", func() {
AccessKeyID: accessKeyID,
SecretAccessKey: secretAccessKey,
SessionToken: sessionToken,
AwsRoleARN: awsRoleARN,
Bucket: bucketName,
RegionName: regionName,
Endpoint: endpoint,
Expand Down Expand Up @@ -484,6 +488,7 @@ var _ = Describe("check", func() {
AccessKeyID: accessKeyID,
SecretAccessKey: secretAccessKey,
SessionToken: sessionToken,
AwsRoleARN: awsRoleARN,
Bucket: versionedBucketName,
RegionName: regionName,
Endpoint: endpoint,
Expand Down
6 changes: 6 additions & 0 deletions integration/in_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ var _ = Describe("in", func() {
AccessKeyID: accessKeyID,
SecretAccessKey: secretAccessKey,
SessionToken: sessionToken,
AwsRoleARN: awsRoleARN,
Bucket: versionedBucketName,
RegionName: regionName,
Endpoint: endpoint,
Expand Down Expand Up @@ -92,6 +93,7 @@ var _ = Describe("in", func() {
AccessKeyID: accessKeyID,
SecretAccessKey: secretAccessKey,
SessionToken: sessionToken,
AwsRoleARN: awsRoleARN,
Bucket: bucketName,
RegionName: regionName,
Endpoint: endpoint,
Expand Down Expand Up @@ -215,6 +217,7 @@ var _ = Describe("in", func() {
AccessKeyID: accessKeyID,
SecretAccessKey: secretAccessKey,
SessionToken: sessionToken,
AwsRoleARN: awsRoleARN,
Bucket: versionedBucketName,
RegionName: regionName,
Endpoint: endpoint,
Expand Down Expand Up @@ -346,6 +349,7 @@ var _ = Describe("in", func() {
AccessKeyID: accessKeyID,
SecretAccessKey: secretAccessKey,
SessionToken: sessionToken,
AwsRoleARN: awsRoleARN,
CloudfrontURL: os.Getenv("S3_TESTING_CLOUDFRONT_URL"),
RegionName: regionName,
Endpoint: endpoint,
Expand Down Expand Up @@ -421,6 +425,7 @@ var _ = Describe("in", func() {
AccessKeyID: accessKeyID,
SecretAccessKey: secretAccessKey,
SessionToken: sessionToken,
AwsRoleARN: awsRoleARN,
CloudfrontURL: "https://no-dots-here",
RegionName: regionName,
Endpoint: endpoint,
Expand Down Expand Up @@ -452,6 +457,7 @@ var _ = Describe("in", func() {
AccessKeyID: accessKeyID,
SecretAccessKey: secretAccessKey,
SessionToken: sessionToken,
AwsRoleARN: awsRoleARN,
Bucket: bucketName,
RegionName: regionName,
Endpoint: endpoint,
Expand Down
17 changes: 14 additions & 3 deletions integration/integration_suite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package integration_test

import (
"encoding/json"
"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
"io/ioutil"
"os"

Expand All @@ -24,6 +25,7 @@ func TestIntegration(t *testing.T) {
var accessKeyID = os.Getenv("S3_TESTING_ACCESS_KEY_ID")
var secretAccessKey = os.Getenv("S3_TESTING_SECRET_ACCESS_KEY")
var sessionToken = os.Getenv("S3_TESTING_SESSION_TOKEN")
var awsRoleARN = os.Getenv("S3_TESTING_AWS_ROLE_ARN")
var versionedBucketName = os.Getenv("S3_VERSIONED_TESTING_BUCKET")
var bucketName = os.Getenv("S3_TESTING_BUCKET")
var regionName = os.Getenv("S3_TESTING_REGION")
Expand Down Expand Up @@ -82,7 +84,7 @@ func getSessionTokenS3Client(awsConfig *aws.Config) (*s3.S3, s3resource.S3Client
false,
)
s3Service := s3.New(session.New(newAwsConfig), newAwsConfig)
s3client := s3resource.NewS3Client(ioutil.Discard, newAwsConfig, v2signing == "true")
s3client := s3resource.NewS3Client(ioutil.Discard, newAwsConfig, v2signing == "true", awsRoleARN)

return s3Service, s3client
}
Expand Down Expand Up @@ -128,9 +130,18 @@ var _ = SynchronizedBeforeSuite(func() []byte {
false,
)

s3Service = s3.New(session.New(awsConfig), awsConfig)
additionalAwsConfig := aws.Config{}
if len(awsRoleARN) != 0 {
stsConfig := awsConfig.Copy()
stsConfig.Endpoint = nil
stsSession := session.Must(session.NewSession(stsConfig))
roleCredentials := stscreds.NewCredentials(stsSession, awsRoleARN)

s3client = s3resource.NewS3Client(ioutil.Discard, awsConfig, v2signing == "true")
additionalAwsConfig.Credentials = roleCredentials
}

s3Service = s3.New(session.New(awsConfig), awsConfig, &additionalAwsConfig)
s3client = s3resource.NewS3Client(ioutil.Discard, awsConfig, v2signing == "true", awsRoleARN)
}
})

Expand Down
10 changes: 10 additions & 0 deletions integration/out_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ var _ = Describe("out", func() {
AccessKeyID: accessKeyID,
SecretAccessKey: secretAccessKey,
SessionToken: sessionToken,
AwsRoleARN: awsRoleARN,
Bucket: versionedBucketName,
RegionName: regionName,
Endpoint: endpoint,
Expand Down Expand Up @@ -94,6 +95,7 @@ var _ = Describe("out", func() {
AccessKeyID: accessKeyID,
SecretAccessKey: secretAccessKey,
SessionToken: sessionToken,
AwsRoleARN: awsRoleARN,
Bucket: bucketName,
RegionName: regionName,
Endpoint: endpoint,
Expand Down Expand Up @@ -137,6 +139,7 @@ var _ = Describe("out", func() {
AccessKeyID: accessKeyID,
SecretAccessKey: secretAccessKey,
SessionToken: sessionToken,
AwsRoleARN: awsRoleARN,
Bucket: bucketName,
RegionName: regionName,
Endpoint: endpoint,
Expand Down Expand Up @@ -178,6 +181,7 @@ var _ = Describe("out", func() {
AccessKeyID: accessKeyID,
SecretAccessKey: secretAccessKey,
SessionToken: sessionToken,
AwsRoleARN: awsRoleARN,
Bucket: bucketName,
RegionName: regionName,
Endpoint: endpoint,
Expand Down Expand Up @@ -224,6 +228,7 @@ var _ = Describe("out", func() {
AccessKeyID: accessKeyID,
SecretAccessKey: secretAccessKey,
SessionToken: sessionToken,
AwsRoleARN: awsRoleARN,
Bucket: bucketName,
RegionName: regionName,
Endpoint: endpoint,
Expand Down Expand Up @@ -309,6 +314,7 @@ var _ = Describe("out", func() {
AccessKeyID: accessKeyID,
SecretAccessKey: secretAccessKey,
SessionToken: sessionToken,
AwsRoleARN: awsRoleARN,
Bucket: bucketName,
RegionName: regionName,
Endpoint: endpoint,
Expand Down Expand Up @@ -346,6 +352,7 @@ var _ = Describe("out", func() {
AccessKeyID: accessKeyID,
SecretAccessKey: secretAccessKey,
SessionToken: sessionToken,
AwsRoleARN: awsRoleARN,
Bucket: bucketName,
RegionName: regionName,
Endpoint: endpoint,
Expand Down Expand Up @@ -400,6 +407,7 @@ var _ = Describe("out", func() {
AccessKeyID: accessKeyID,
SecretAccessKey: secretAccessKey,
SessionToken: sessionToken,
AwsRoleARN: awsRoleARN,
Bucket: bucketName,
RegionName: regionName,
VersionedFile: filepath.Join(directoryPrefix, "file-to-upload"),
Expand Down Expand Up @@ -450,6 +458,7 @@ var _ = Describe("out", func() {
AccessKeyID: accessKeyID,
SecretAccessKey: secretAccessKey,
SessionToken: sessionToken,
AwsRoleARN: awsRoleARN,
Bucket: versionedBucketName,
RegionName: regionName,
VersionedFile: filepath.Join(directoryPrefix, "file-to-upload"),
Expand Down Expand Up @@ -508,6 +517,7 @@ var _ = Describe("out", func() {
AccessKeyID: accessKeyID,
SecretAccessKey: secretAccessKey,
SessionToken: sessionToken,
AwsRoleARN: awsRoleARN,
Bucket: versionedBucketName,
RegionName: regionName,
Endpoint: endpoint,
Expand Down
1 change: 1 addition & 0 deletions models.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ type Source struct {
AccessKeyID string `json:"access_key_id"`
SecretAccessKey string `json:"secret_access_key"`
SessionToken string `json:"session_token"`
AwsRoleARN string `json:"aws_role_arn"`
Bucket string `json:"bucket"`
Regexp string `json:"regexp"`
VersionedFile string `json:"versioned_file"`
Expand Down
20 changes: 19 additions & 1 deletion s3client.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
"io"
"io/ioutil"
"os"
Expand Down Expand Up @@ -68,9 +69,13 @@ func NewS3Client(
progressOutput io.Writer,
awsConfig *aws.Config,
useV2Signing bool,
roleToAssume string,
) S3Client {
sess := session.New(awsConfig)
client := s3.New(sess, awsConfig)

assumedRoleAwsConfig := fetchCredentialsForRoleIfDefined(roleToAssume, awsConfig)

client := s3.New(sess, awsConfig, &assumedRoleAwsConfig)

if useV2Signing {
setv2Handlers(client)
Expand All @@ -84,6 +89,19 @@ func NewS3Client(
}
}

func fetchCredentialsForRoleIfDefined(roleToAssume string, awsConfig *aws.Config) aws.Config {
assumedRoleAwsConfig := aws.Config{}
if len(roleToAssume) != 0 {
stsConfig := awsConfig.Copy()
stsConfig.Endpoint = nil
stsSession := session.Must(session.NewSession(stsConfig))
roleCredentials := stscreds.NewCredentials(stsSession, roleToAssume)

assumedRoleAwsConfig.Credentials = roleCredentials
}
return assumedRoleAwsConfig
}

func NewAwsConfig(
accessKey string,
secretKey string,
Expand Down

0 comments on commit 80c54bd

Please sign in to comment.