From bee278ede5cc4d938efe5f6292a577fde06a6159 Mon Sep 17 00:00:00 2001 From: Lyle Franklin Date: Fri, 2 Sep 2016 09:54:04 -0700 Subject: [PATCH] `cloudfront_url` is honored for GET - not for CHECK or PUT - if `cloudfront_url` is specified, we GET from the cloudfront URL. - particularly effective when downloading assets from an unstable network (i.e. AWS China) ``` source: cloudfront_url: http://d111111abcdef8.cloudfront.net ``` Signed-off-by: Brian Cunnie [#129581055](https://www.pivotaltracker.com/story/show/129581055) --- README.md | 9 ++- cmd/in/main.go | 20 +++++ in/in_command.go | 15 ---- in/in_command_test.go | 19 ----- integration/in_test.go | 171 ++++++++++++++++++++++++++++++---------- integration/out_test.go | 1 - s3client.go | 4 + 7 files changed, 160 insertions(+), 79 deletions(-) diff --git a/README.md b/README.md index 47909175..ee18f49f 100644 --- a/README.md +++ b/README.md @@ -20,8 +20,12 @@ version numbers. URLs provided are signed. * `cloudfront_url`: *Optional.* The URL (scheme and domain) of your CloudFront - distribution that is fronting this bucket. This will be used in the `url` - file that is given to the following task. + distribution that is fronting this bucket (e.g + `https://d5yxxxxx.cloudfront.net`). This will affect `in` but not `check` + and `put`. `in` will ignore the `bucket` name setting, exclusively using the + `cloudfront_url`. When configuring CloudFront with versioned buckets, set + `Query String Forwarding and Caching` to `Forward all, cache based on all` to + ensure S3 calls succeed. * `endpoint`: *Optional.* Custom endpoint for using S3 compatible provider. @@ -34,6 +38,7 @@ version numbers. * `sse_kms_key_id`: *Optional.* The ID of the AWS KMS master encryption key used for the object. + ### File Names One of the following two options must be specified: diff --git a/cmd/in/main.go b/cmd/in/main.go index 5567ea91..505253be 100644 --- a/cmd/in/main.go +++ b/cmd/in/main.go @@ -4,8 +4,12 @@ import ( "encoding/json" "os" + "fmt" + "github.com/aws/aws-sdk-go/aws" "github.com/concourse/s3-resource" "github.com/concourse/s3-resource/in" + "net/url" + "strings" ) func main() { @@ -27,6 +31,22 @@ func main() { request.Source.DisableSSL, ) + if len(request.Source.CloudfrontURL) != 0 { + cloudfrontUrl, err := url.ParseRequestURI(request.Source.CloudfrontURL) + if err != nil { + s3resource.Fatal("parsing 'cloudfront_url'", err) + } + awsConfig.S3ForcePathStyle = aws.Bool(false) + + splitResult := strings.Split(cloudfrontUrl.Host, ".") + if len(splitResult) < 2 { + s3resource.Fatal("verifying 'cloudfront_url'", fmt.Errorf("'%s' doesn't have enough dots ('.'), a typical format is 'https://d111111abcdef8.cloudfront.net'", request.Source.CloudfrontURL)) + } + request.Source.Bucket = strings.Split(cloudfrontUrl.Host, ".")[0] + fqdn := strings.SplitAfterN(cloudfrontUrl.Host, ".", 2)[1] + awsConfig.Endpoint = aws.String(fmt.Sprintf("%s://%s", cloudfrontUrl.Scheme, fqdn)) + } + client := s3resource.NewS3Client( os.Stderr, awsConfig, diff --git a/in/in_command.go b/in/in_command.go index 5da54ee1..df8ec203 100644 --- a/in/in_command.go +++ b/in/in_command.go @@ -8,7 +8,6 @@ import ( "path" "path/filepath" - "github.com/cloudfoundry/gunk/urljoiner" "github.com/concourse/s3-resource" "github.com/concourse/s3-resource/versions" ) @@ -20,10 +19,6 @@ type RequestURLProvider struct { } func (up *RequestURLProvider) GetURL(request InRequest, remotePath string) string { - if request.Source.CloudfrontURL != "" { - return up.cloudfrontURL(request, remotePath) - } - return up.s3URL(request, remotePath) } @@ -31,16 +26,6 @@ func (up *RequestURLProvider) s3URL(request InRequest, remotePath string) string return up.s3Client.URL(request.Source.Bucket, remotePath, request.Source.Private, request.Version.VersionID) } -func (up *RequestURLProvider) cloudfrontURL(request InRequest, remotePath string) string { - url := urljoiner.Join(request.Source.CloudfrontURL, remotePath) - - if request.Version.VersionID != "" { - url = url + "?versionId=" + request.Version.VersionID - } - - return url -} - type InCommand struct { s3client s3resource.S3Client urlProvider RequestURLProvider diff --git a/in/in_command_test.go b/in/in_command_test.go index fd358404..ac7db481 100644 --- a/in/in_command_test.go +++ b/in/in_command_test.go @@ -109,25 +109,6 @@ var _ = Describe("In Command", func() { Ω(versionID).Should(BeEmpty()) }) - Context("when using a CloudFront domain", func() { - BeforeEach(func() { - request.Source.CloudfrontURL = "https://1234567890.cloudfront.net" - }) - - It("creates a 'url' file that contains the URL including the CloudFront domain", func() { - urlPath := filepath.Join(destDir, "url") - Ω(urlPath).ShouldNot(ExistOnFilesystem()) - - _, err := command.Run(destDir, request) - Ω(err).ShouldNot(HaveOccurred()) - - Ω(urlPath).Should(ExistOnFilesystem()) - contents, err := ioutil.ReadFile(urlPath) - Ω(err).ShouldNot(HaveOccurred()) - Ω(string(contents)).Should(Equal("https://1234567890.cloudfront.net/files/a-file-1.3.tgz")) - }) - }) - Context("when configured with private URLs", func() { BeforeEach(func() { request.Source.Private = true diff --git a/integration/in_test.go b/integration/in_test.go index 7c20101e..c6bf5054 100644 --- a/integration/in_test.go +++ b/integration/in_test.go @@ -3,6 +3,7 @@ package integration_test import ( "bytes" "encoding/json" + "fmt" "io/ioutil" "os" "os/exec" @@ -106,37 +107,23 @@ var _ = Describe("in", func() { Ω(err).ShouldNot(HaveOccurred()) tempFile.Close() - err = ioutil.WriteFile(tempFile.Name(), []byte("some-file-1"), 0755) - Ω(err).ShouldNot(HaveOccurred()) - - _, err = s3client.UploadFile(bucketName, filepath.Join(directoryPrefix, "some-file-1"), tempFile.Name(), "private", "", "") - Ω(err).ShouldNot(HaveOccurred()) - - err = ioutil.WriteFile(tempFile.Name(), []byte("some-file-2"), 0755) - Ω(err).ShouldNot(HaveOccurred()) - - _, err = s3client.UploadFile(bucketName, filepath.Join(directoryPrefix, "some-file-2"), tempFile.Name(), "private", "", "") - Ω(err).ShouldNot(HaveOccurred()) - - err = ioutil.WriteFile(tempFile.Name(), []byte("some-file-3"), 0755) - Ω(err).ShouldNot(HaveOccurred()) + for i := 1; i <= 3; i++ { + err = ioutil.WriteFile(tempFile.Name(), []byte(fmt.Sprintf("some-file-%d", i)), 0755) + Ω(err).ShouldNot(HaveOccurred()) - _, err = s3client.UploadFile(bucketName, filepath.Join(directoryPrefix, "some-file-3"), tempFile.Name(), "private", "", "") - Ω(err).ShouldNot(HaveOccurred()) + _, err = s3client.UploadFile(bucketName, filepath.Join(directoryPrefix, fmt.Sprintf("some-file-%d", i)), tempFile.Name(), "private", "", "") + Ω(err).ShouldNot(HaveOccurred()) + } err = os.Remove(tempFile.Name()) Ω(err).ShouldNot(HaveOccurred()) }) AfterEach(func() { - err := s3client.DeleteFile(bucketName, filepath.Join(directoryPrefix, "some-file-1")) - Ω(err).ShouldNot(HaveOccurred()) - - err = s3client.DeleteFile(bucketName, filepath.Join(directoryPrefix, "some-file-2")) - Ω(err).ShouldNot(HaveOccurred()) - - err = s3client.DeleteFile(bucketName, filepath.Join(directoryPrefix, "some-file-3")) - Ω(err).ShouldNot(HaveOccurred()) + for i := 1; i <= 3; i++ { + err := s3client.DeleteFile(bucketName, filepath.Join(directoryPrefix, fmt.Sprintf("some-file-%d", i))) + Ω(err).ShouldNot(HaveOccurred()) + } }) It("downloads the file", func() { @@ -201,22 +188,14 @@ var _ = Describe("in", func() { Ω(err).ShouldNot(HaveOccurred()) tempFile.Close() - err = ioutil.WriteFile(tempFile.Name(), []byte("some-file-1"), 0755) - Ω(err).ShouldNot(HaveOccurred()) - - _, err = s3client.UploadFile(versionedBucketName, filepath.Join(directoryPrefix, "some-file"), tempFile.Name(), "private", "", "") - Ω(err).ShouldNot(HaveOccurred()) - - err = ioutil.WriteFile(tempFile.Name(), []byte("some-file-2"), 0755) - Ω(err).ShouldNot(HaveOccurred()) - - _, err = s3client.UploadFile(versionedBucketName, filepath.Join(directoryPrefix, "some-file"), tempFile.Name(), "private", "", "") - Ω(err).ShouldNot(HaveOccurred()) - - err = ioutil.WriteFile(tempFile.Name(), []byte("some-file-3"), 0755) - Ω(err).ShouldNot(HaveOccurred()) + for i := 1; i <= 3; i++ { + err = ioutil.WriteFile(tempFile.Name(), []byte(fmt.Sprintf("some-file-%d", i)), 0755) + Ω(err).ShouldNot(HaveOccurred()) - _, err = s3client.UploadFile(versionedBucketName, filepath.Join(directoryPrefix, "some-file"), tempFile.Name(), "private", "", "") + _, err = s3client.UploadFile(versionedBucketName, filepath.Join(directoryPrefix, "some-file"), tempFile.Name(), "private", "", "") + Ω(err).ShouldNot(HaveOccurred()) + } + err = os.Remove(tempFile.Name()) Ω(err).ShouldNot(HaveOccurred()) versions, err := s3client.BucketFileVersions(versionedBucketName, filepath.Join(directoryPrefix, "some-file")) @@ -226,9 +205,6 @@ var _ = Describe("in", func() { err = json.NewEncoder(stdin).Encode(inRequest) Ω(err).ShouldNot(HaveOccurred()) - - err = os.Remove(tempFile.Name()) - Ω(err).ShouldNot(HaveOccurred()) }) AfterEach(func() { @@ -280,4 +256,115 @@ var _ = Describe("in", func() { }) }) + Context("when cloudfront_url is set", func() { + var inRequest in.InRequest + var directoryPrefix string + + BeforeEach(func() { + if len(os.Getenv("S3_TESTING_CLOUDFRONT_URL")) == 0 { + Skip("'S3_TESTING_CLOUDFRONT_URL' is not set, skipping.") + } + + directoryPrefix = "in-request-cloudfront-files" + inRequest = in.InRequest{ + Source: s3resource.Source{ + AccessKeyID: accessKeyID, + SecretAccessKey: secretAccessKey, + CloudfrontURL: os.Getenv("S3_TESTING_CLOUDFRONT_URL"), + RegionName: regionName, + Regexp: filepath.Join(directoryPrefix, "some-file-(.*)"), + }, + Version: s3resource.Version{ + Path: filepath.Join(directoryPrefix, "some-file-2"), + }, + } + + err := json.NewEncoder(stdin).Encode(inRequest) + Ω(err).ShouldNot(HaveOccurred()) + + tempFile, err := ioutil.TempFile("", "file-to-upload") + Ω(err).ShouldNot(HaveOccurred()) + tempFile.Close() + + for i := 1; i <= 3; i++ { + err = ioutil.WriteFile(tempFile.Name(), []byte(fmt.Sprintf("some-file-%d", i)), 0755) + Ω(err).ShouldNot(HaveOccurred()) + + _, err = s3client.UploadFile(bucketName, filepath.Join(directoryPrefix, fmt.Sprintf("some-file-%d", i)), tempFile.Name(), "private", "", "") + Ω(err).ShouldNot(HaveOccurred()) + } + + err = os.Remove(tempFile.Name()) + Ω(err).ShouldNot(HaveOccurred()) + }) + + AfterEach(func() { + for i := 1; i <= 3; i++ { + err := s3client.DeleteFile(bucketName, filepath.Join(directoryPrefix, fmt.Sprintf("some-file-%d", i))) + Ω(err).ShouldNot(HaveOccurred()) + } + }) + + It("downloads the file from CloudFront", func() { + reader := bytes.NewBuffer(session.Out.Contents()) + + var response in.InResponse + err := json.NewDecoder(reader).Decode(&response) + Ω(err).ShouldNot(HaveOccurred()) + + Ω(response).Should(Equal(in.InResponse{ + Version: s3resource.Version{ + Path: "in-request-cloudfront-files/some-file-2", + }, + Metadata: []s3resource.MetadataPair{ + { + Name: "filename", + Value: "some-file-2", + }, + { + Name: "url", + Value: inRequest.Source.CloudfrontURL + "/in-request-cloudfront-files/some-file-2", + }, + }, + })) + + Ω(filepath.Join(destDir, "some-file-2")).Should(BeARegularFile()) + contents, err := ioutil.ReadFile(filepath.Join(destDir, "some-file-2")) + Ω(err).ShouldNot(HaveOccurred()) + Ω(contents).Should(Equal([]byte("some-file-2"))) + + Ω(filepath.Join(destDir, "url")).Should(BeARegularFile()) + urlContents, err := ioutil.ReadFile(filepath.Join(destDir, "url")) + Ω(err).ShouldNot(HaveOccurred()) + Ω(urlContents).Should(Equal([]byte(inRequest.Source.CloudfrontURL + "/in-request-cloudfront-files/some-file-2"))) + }) + }) + + Context("when cloudfront_url is set but has too few dots", func() { + var inRequest in.InRequest + + BeforeEach(func() { + inRequest = in.InRequest{ + Source: s3resource.Source{ + AccessKeyID: accessKeyID, + SecretAccessKey: secretAccessKey, + CloudfrontURL: "https://no-dots-here", + RegionName: regionName, + Regexp: "unused", + }, + Version: s3resource.Version{ + Path: "unused", + }, + } + + expectedExitStatus = 1 + + err := json.NewEncoder(stdin).Encode(inRequest) + Ω(err).ShouldNot(HaveOccurred()) + }) + + It("returns an error", func() { + Ω(session.Err).Should(gbytes.Say(`'https://no-dots-here' doesn't have enough dots \('.'\), a typical format is 'https://d111111abcdef8.cloudfront.net'`)) + }) + }) }) diff --git a/integration/out_test.go b/integration/out_test.go index 07187b32..83b8586d 100644 --- a/integration/out_test.go +++ b/integration/out_test.go @@ -286,7 +286,6 @@ var _ = Describe("out", func() { Ω(session.Err).Should(gbytes.Say("object versioning not enabled")) }) }) - }) Context("with a versioned bucket", func() { diff --git a/s3client.go b/s3client.go index 8fbaac29..f7a711d1 100644 --- a/s3client.go +++ b/s3client.go @@ -288,6 +288,7 @@ func (client *s3client) getBucketContents(bucketName string, prefix string) (map } if *listObjectsResponse.IsTruncated { + prevMarker := marker if listObjectsResponse.NextMarker == nil { // From the s3 docs: If response does not include the // NextMarker and it is truncated, you can use the value of the @@ -297,6 +298,9 @@ func (client *s3client) getBucketContents(bucketName string, prefix string) (map } else { marker = *listObjectsResponse.NextMarker } + if marker == prevMarker { + return nil, errors.New("Unable to list all bucket objects; perhaps this is a CloudFront S3 bucket that needs its `Query String Forwarding and Caching` set to `Forward all, cache based on all`?") + } } else { break }