Skip to content

Commit

Permalink
Merge pull request #48 from cunnie/PR-force-path-style
Browse files Browse the repository at this point in the history
Add `virtual_hosted_style` flag to allow interactions with CloudFront
  • Loading branch information
vito authored Oct 9, 2016
2 parents 38b5bec + bee278e commit 0f38106
Show file tree
Hide file tree
Showing 7 changed files with 160 additions and 79 deletions.
9 changes: 7 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,12 @@ version numbers.
URLs provided are signed.

* `cloudfront_url`: *Optional.* The URL (scheme and domain) of your CloudFront
distribution that is fronting this bucket. This will be used in the `url`
file that is given to the following task.
distribution that is fronting this bucket (e.g
`https://d5yxxxxx.cloudfront.net`). This will affect `in` but not `check`
and `put`. `in` will ignore the `bucket` name setting, exclusively using the
`cloudfront_url`. When configuring CloudFront with versioned buckets, set
`Query String Forwarding and Caching` to `Forward all, cache based on all` to
ensure S3 calls succeed.

* `endpoint`: *Optional.* Custom endpoint for using S3 compatible provider.

Expand All @@ -34,6 +38,7 @@ version numbers.
* `sse_kms_key_id`: *Optional.* The ID of the AWS KMS master encryption key
used for the object.


### File Names

One of the following two options must be specified:
Expand Down
20 changes: 20 additions & 0 deletions cmd/in/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,12 @@ import (
"encoding/json"
"os"

"fmt"
"github.com/aws/aws-sdk-go/aws"
"github.com/concourse/s3-resource"
"github.com/concourse/s3-resource/in"
"net/url"
"strings"
)

func main() {
Expand All @@ -27,6 +31,22 @@ func main() {
request.Source.DisableSSL,
)

if len(request.Source.CloudfrontURL) != 0 {
cloudfrontUrl, err := url.ParseRequestURI(request.Source.CloudfrontURL)
if err != nil {
s3resource.Fatal("parsing 'cloudfront_url'", err)
}
awsConfig.S3ForcePathStyle = aws.Bool(false)

splitResult := strings.Split(cloudfrontUrl.Host, ".")
if len(splitResult) < 2 {
s3resource.Fatal("verifying 'cloudfront_url'", fmt.Errorf("'%s' doesn't have enough dots ('.'), a typical format is 'https://d111111abcdef8.cloudfront.net'", request.Source.CloudfrontURL))
}
request.Source.Bucket = strings.Split(cloudfrontUrl.Host, ".")[0]
fqdn := strings.SplitAfterN(cloudfrontUrl.Host, ".", 2)[1]
awsConfig.Endpoint = aws.String(fmt.Sprintf("%s://%s", cloudfrontUrl.Scheme, fqdn))
}

client := s3resource.NewS3Client(
os.Stderr,
awsConfig,
Expand Down
15 changes: 0 additions & 15 deletions in/in_command.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"path"
"path/filepath"

"github.com/cloudfoundry/gunk/urljoiner"
"github.com/concourse/s3-resource"
"github.com/concourse/s3-resource/versions"
)
Expand All @@ -20,27 +19,13 @@ type RequestURLProvider struct {
}

func (up *RequestURLProvider) GetURL(request InRequest, remotePath string) string {
if request.Source.CloudfrontURL != "" {
return up.cloudfrontURL(request, remotePath)
}

return up.s3URL(request, remotePath)
}

func (up *RequestURLProvider) s3URL(request InRequest, remotePath string) string {
return up.s3Client.URL(request.Source.Bucket, remotePath, request.Source.Private, request.Version.VersionID)
}

func (up *RequestURLProvider) cloudfrontURL(request InRequest, remotePath string) string {
url := urljoiner.Join(request.Source.CloudfrontURL, remotePath)

if request.Version.VersionID != "" {
url = url + "?versionId=" + request.Version.VersionID
}

return url
}

type InCommand struct {
s3client s3resource.S3Client
urlProvider RequestURLProvider
Expand Down
19 changes: 0 additions & 19 deletions in/in_command_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,25 +109,6 @@ var _ = Describe("In Command", func() {
Ω(versionID).Should(BeEmpty())
})

Context("when using a CloudFront domain", func() {
BeforeEach(func() {
request.Source.CloudfrontURL = "https://1234567890.cloudfront.net"
})

It("creates a 'url' file that contains the URL including the CloudFront domain", func() {
urlPath := filepath.Join(destDir, "url")
Ω(urlPath).ShouldNot(ExistOnFilesystem())

_, err := command.Run(destDir, request)
Ω(err).ShouldNot(HaveOccurred())

Ω(urlPath).Should(ExistOnFilesystem())
contents, err := ioutil.ReadFile(urlPath)
Ω(err).ShouldNot(HaveOccurred())
Ω(string(contents)).Should(Equal("https://1234567890.cloudfront.net/files/a-file-1.3.tgz"))
})
})

Context("when configured with private URLs", func() {
BeforeEach(func() {
request.Source.Private = true
Expand Down
171 changes: 129 additions & 42 deletions integration/in_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package integration_test
import (
"bytes"
"encoding/json"
"fmt"
"io/ioutil"
"os"
"os/exec"
Expand Down Expand Up @@ -106,37 +107,23 @@ var _ = Describe("in", func() {
Ω(err).ShouldNot(HaveOccurred())
tempFile.Close()

err = ioutil.WriteFile(tempFile.Name(), []byte("some-file-1"), 0755)
Ω(err).ShouldNot(HaveOccurred())

_, err = s3client.UploadFile(bucketName, filepath.Join(directoryPrefix, "some-file-1"), tempFile.Name(), "private", "", "")
Ω(err).ShouldNot(HaveOccurred())

err = ioutil.WriteFile(tempFile.Name(), []byte("some-file-2"), 0755)
Ω(err).ShouldNot(HaveOccurred())

_, err = s3client.UploadFile(bucketName, filepath.Join(directoryPrefix, "some-file-2"), tempFile.Name(), "private", "", "")
Ω(err).ShouldNot(HaveOccurred())

err = ioutil.WriteFile(tempFile.Name(), []byte("some-file-3"), 0755)
Ω(err).ShouldNot(HaveOccurred())
for i := 1; i <= 3; i++ {
err = ioutil.WriteFile(tempFile.Name(), []byte(fmt.Sprintf("some-file-%d", i)), 0755)
Ω(err).ShouldNot(HaveOccurred())

_, err = s3client.UploadFile(bucketName, filepath.Join(directoryPrefix, "some-file-3"), tempFile.Name(), "private", "", "")
Ω(err).ShouldNot(HaveOccurred())
_, err = s3client.UploadFile(bucketName, filepath.Join(directoryPrefix, fmt.Sprintf("some-file-%d", i)), tempFile.Name(), "private", "", "")
Ω(err).ShouldNot(HaveOccurred())
}

err = os.Remove(tempFile.Name())
Ω(err).ShouldNot(HaveOccurred())
})

AfterEach(func() {
err := s3client.DeleteFile(bucketName, filepath.Join(directoryPrefix, "some-file-1"))
Ω(err).ShouldNot(HaveOccurred())

err = s3client.DeleteFile(bucketName, filepath.Join(directoryPrefix, "some-file-2"))
Ω(err).ShouldNot(HaveOccurred())

err = s3client.DeleteFile(bucketName, filepath.Join(directoryPrefix, "some-file-3"))
Ω(err).ShouldNot(HaveOccurred())
for i := 1; i <= 3; i++ {
err := s3client.DeleteFile(bucketName, filepath.Join(directoryPrefix, fmt.Sprintf("some-file-%d", i)))
Ω(err).ShouldNot(HaveOccurred())
}
})

It("downloads the file", func() {
Expand Down Expand Up @@ -201,22 +188,14 @@ var _ = Describe("in", func() {
Ω(err).ShouldNot(HaveOccurred())
tempFile.Close()

err = ioutil.WriteFile(tempFile.Name(), []byte("some-file-1"), 0755)
Ω(err).ShouldNot(HaveOccurred())

_, err = s3client.UploadFile(versionedBucketName, filepath.Join(directoryPrefix, "some-file"), tempFile.Name(), "private", "", "")
Ω(err).ShouldNot(HaveOccurred())

err = ioutil.WriteFile(tempFile.Name(), []byte("some-file-2"), 0755)
Ω(err).ShouldNot(HaveOccurred())

_, err = s3client.UploadFile(versionedBucketName, filepath.Join(directoryPrefix, "some-file"), tempFile.Name(), "private", "", "")
Ω(err).ShouldNot(HaveOccurred())

err = ioutil.WriteFile(tempFile.Name(), []byte("some-file-3"), 0755)
Ω(err).ShouldNot(HaveOccurred())
for i := 1; i <= 3; i++ {
err = ioutil.WriteFile(tempFile.Name(), []byte(fmt.Sprintf("some-file-%d", i)), 0755)
Ω(err).ShouldNot(HaveOccurred())

_, err = s3client.UploadFile(versionedBucketName, filepath.Join(directoryPrefix, "some-file"), tempFile.Name(), "private", "", "")
_, err = s3client.UploadFile(versionedBucketName, filepath.Join(directoryPrefix, "some-file"), tempFile.Name(), "private", "", "")
Ω(err).ShouldNot(HaveOccurred())
}
err = os.Remove(tempFile.Name())
Ω(err).ShouldNot(HaveOccurred())

versions, err := s3client.BucketFileVersions(versionedBucketName, filepath.Join(directoryPrefix, "some-file"))
Expand All @@ -226,9 +205,6 @@ var _ = Describe("in", func() {

err = json.NewEncoder(stdin).Encode(inRequest)
Ω(err).ShouldNot(HaveOccurred())

err = os.Remove(tempFile.Name())
Ω(err).ShouldNot(HaveOccurred())
})

AfterEach(func() {
Expand Down Expand Up @@ -280,4 +256,115 @@ var _ = Describe("in", func() {
})
})

Context("when cloudfront_url is set", func() {
var inRequest in.InRequest
var directoryPrefix string

BeforeEach(func() {
if len(os.Getenv("S3_TESTING_CLOUDFRONT_URL")) == 0 {
Skip("'S3_TESTING_CLOUDFRONT_URL' is not set, skipping.")
}

directoryPrefix = "in-request-cloudfront-files"
inRequest = in.InRequest{
Source: s3resource.Source{
AccessKeyID: accessKeyID,
SecretAccessKey: secretAccessKey,
CloudfrontURL: os.Getenv("S3_TESTING_CLOUDFRONT_URL"),
RegionName: regionName,
Regexp: filepath.Join(directoryPrefix, "some-file-(.*)"),
},
Version: s3resource.Version{
Path: filepath.Join(directoryPrefix, "some-file-2"),
},
}

err := json.NewEncoder(stdin).Encode(inRequest)
Ω(err).ShouldNot(HaveOccurred())

tempFile, err := ioutil.TempFile("", "file-to-upload")
Ω(err).ShouldNot(HaveOccurred())
tempFile.Close()

for i := 1; i <= 3; i++ {
err = ioutil.WriteFile(tempFile.Name(), []byte(fmt.Sprintf("some-file-%d", i)), 0755)
Ω(err).ShouldNot(HaveOccurred())

_, err = s3client.UploadFile(bucketName, filepath.Join(directoryPrefix, fmt.Sprintf("some-file-%d", i)), tempFile.Name(), "private", "", "")
Ω(err).ShouldNot(HaveOccurred())
}

err = os.Remove(tempFile.Name())
Ω(err).ShouldNot(HaveOccurred())
})

AfterEach(func() {
for i := 1; i <= 3; i++ {
err := s3client.DeleteFile(bucketName, filepath.Join(directoryPrefix, fmt.Sprintf("some-file-%d", i)))
Ω(err).ShouldNot(HaveOccurred())
}
})

It("downloads the file from CloudFront", func() {
reader := bytes.NewBuffer(session.Out.Contents())

var response in.InResponse
err := json.NewDecoder(reader).Decode(&response)
Ω(err).ShouldNot(HaveOccurred())

Ω(response).Should(Equal(in.InResponse{
Version: s3resource.Version{
Path: "in-request-cloudfront-files/some-file-2",
},
Metadata: []s3resource.MetadataPair{
{
Name: "filename",
Value: "some-file-2",
},
{
Name: "url",
Value: inRequest.Source.CloudfrontURL + "/in-request-cloudfront-files/some-file-2",
},
},
}))

Ω(filepath.Join(destDir, "some-file-2")).Should(BeARegularFile())
contents, err := ioutil.ReadFile(filepath.Join(destDir, "some-file-2"))
Ω(err).ShouldNot(HaveOccurred())
Ω(contents).Should(Equal([]byte("some-file-2")))

Ω(filepath.Join(destDir, "url")).Should(BeARegularFile())
urlContents, err := ioutil.ReadFile(filepath.Join(destDir, "url"))
Ω(err).ShouldNot(HaveOccurred())
Ω(urlContents).Should(Equal([]byte(inRequest.Source.CloudfrontURL + "/in-request-cloudfront-files/some-file-2")))
})
})

Context("when cloudfront_url is set but has too few dots", func() {
var inRequest in.InRequest

BeforeEach(func() {
inRequest = in.InRequest{
Source: s3resource.Source{
AccessKeyID: accessKeyID,
SecretAccessKey: secretAccessKey,
CloudfrontURL: "https://no-dots-here",
RegionName: regionName,
Regexp: "unused",
},
Version: s3resource.Version{
Path: "unused",
},
}

expectedExitStatus = 1

err := json.NewEncoder(stdin).Encode(inRequest)
Ω(err).ShouldNot(HaveOccurred())
})

It("returns an error", func() {
Ω(session.Err).Should(gbytes.Say(`'https://no-dots-here' doesn't have enough dots \('.'\), a typical format is 'https://d111111abcdef8.cloudfront.net'`))
})
})
})
1 change: 0 additions & 1 deletion integration/out_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,6 @@ var _ = Describe("out", func() {
Ω(session.Err).Should(gbytes.Say("object versioning not enabled"))
})
})

})

Context("with a versioned bucket", func() {
Expand Down
4 changes: 4 additions & 0 deletions s3client.go
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,7 @@ func (client *s3client) getBucketContents(bucketName string, prefix string) (map
}

if *listObjectsResponse.IsTruncated {
prevMarker := marker
if listObjectsResponse.NextMarker == nil {
// From the s3 docs: If response does not include the
// NextMarker and it is truncated, you can use the value of the
Expand All @@ -297,6 +298,9 @@ func (client *s3client) getBucketContents(bucketName string, prefix string) (map
} else {
marker = *listObjectsResponse.NextMarker
}
if marker == prevMarker {
return nil, errors.New("Unable to list all bucket objects; perhaps this is a CloudFront S3 bucket that needs its `Query String Forwarding and Caching` set to `Forward all, cache based on all`?")
}
} else {
break
}
Expand Down

0 comments on commit 0f38106

Please sign in to comment.