Skip to content

Commit

Permalink
Set content length when doing multipart upload (#4)
Browse files Browse the repository at this point in the history
* Set content length when doing multipart upload

* Prevent infinite loop by separating the io.ReadFull()'s error
  • Loading branch information
edigaryev authored Jul 1, 2024
1 parent 775d4d8 commit 86453d5
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 17 deletions.
2 changes: 1 addition & 1 deletion internal/cache/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ type RemoteCache interface {
}

type MultipartUpload interface {
UploadPart(ctx context.Context, number int32, r io.Reader) error
UploadPart(ctx context.Context, number int32, r io.Reader, length int64) error
Size(ctx context.Context) (int64, error)
Commit(ctx context.Context) error
Rollback(ctx context.Context) error
Expand Down
13 changes: 7 additions & 6 deletions internal/cache/s3/multipartupload.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,17 @@ type MultipartUpload struct {
mtx sync.Mutex
}

func (mu *MultipartUpload) UploadPart(ctx context.Context, number int32, r io.Reader) error {
func (mu *MultipartUpload) UploadPart(ctx context.Context, number int32, r io.Reader, length int64) error {
// Work around https://github.com/aws/aws-sdk-go-v2/issues/2038
opt := s3pkg.WithAPIOptions(v4.SwapComputePayloadSHA256ForUnsignedPayloadMiddleware)

result, err := mu.client.UploadPart(ctx, &s3pkg.UploadPartInput{
Bucket: aws.String(mu.bucket),
Key: aws.String(mu.key),
UploadId: aws.String(mu.uploadID),
PartNumber: aws.Int32(number),
Body: r,
Bucket: aws.String(mu.bucket),
Key: aws.String(mu.key),
UploadId: aws.String(mu.uploadID),
PartNumber: aws.Int32(number),
Body: r,
ContentLength: aws.Int64(length),
}, opt)
if err != nil {
return err
Expand Down
6 changes: 4 additions & 2 deletions internal/cache/s3/s3_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ func TestSimple(t *testing.T) {
multipartUpload, err := cache.Put(ctx, "test")
require.NoError(t, err)

require.NoError(t, multipartUpload.UploadPart(ctx, 1, bytes.NewReader(contentBytes)))
require.NoError(t, multipartUpload.UploadPart(ctx, 1, bytes.NewReader(contentBytes),
int64(len(contentBytes))))

size, err := multipartUpload.Size(ctx)
require.NoError(t, err)
Expand All @@ -48,7 +49,8 @@ func TestSimple(t *testing.T) {
newContentsBytes := []byte("Bye bye!")
multipartUpload, err = cache.Put(ctx, "test")
require.NoError(t, err)
require.NoError(t, multipartUpload.UploadPart(ctx, 1, bytes.NewReader(newContentsBytes)))
require.NoError(t, multipartUpload.UploadPart(ctx, 1, bytes.NewReader(newContentsBytes),
int64(len(newContentsBytes))))
require.NoError(t, multipartUpload.Commit(ctx))

// Retrieval of a re-inserted key should yield modified contents
Expand Down
8 changes: 5 additions & 3 deletions internal/server/protocol/ghacache/ghacache.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,16 +142,18 @@ func (cache *GHACache) updateUploadable(c echo.Context) error {
return fail.Fail(c, http.StatusBadRequest, "failed to parse Content-Range header: %v", err)
}

if len(httpRanges) == 0 {
return fail.Fail(c, http.StatusBadRequest, "expected at least one Content-Range value")
if len(httpRanges) != 1 {
return fail.Fail(c, http.StatusBadRequest, "expected exactly one Content-Range value, got %d",
len(httpRanges))
}

partNumber, err := uploadable.RangeToPart.Tell(c.Request().Context(), httpRanges[0].Start, httpRanges[0].Length)
if err != nil {
return fail.Fail(c, http.StatusBadRequest, "%v", err)
}

return uploadable.MultipartUpload.UploadPart(c.Request().Context(), partNumber, c.Request().Body)
return uploadable.MultipartUpload.UploadPart(c.Request().Context(), partNumber, c.Request().Body,
httpRanges[0].Length)
}

func (cache *GHACache) commitUploadable(c echo.Context) error {
Expand Down
11 changes: 6 additions & 5 deletions internal/server/protocol/httpcache/httpcache.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,17 +121,18 @@ func (cache *HTTPCache) put(c echo.Context) error {
partNumber := int32(1)

for {
n, err := io.ReadFull(c.Request().Body, buf)
if err != nil && !(errors.Is(err, io.ErrUnexpectedEOF) || errors.Is(err, io.EOF)) {
n, readFullErr := io.ReadFull(c.Request().Body, buf)
if readFullErr != nil && !(errors.Is(readFullErr, io.ErrUnexpectedEOF) || errors.Is(readFullErr, io.EOF)) {
return fail.Fail(c, http.StatusInternalServerError, "failed to read data to be uploaded "+
"for cache key %q: %v", key, err)
"for cache key %q: %v", key, readFullErr)
}

if err := multipartUpload.UploadPart(c.Request().Context(), partNumber, bytes.NewReader(buf[:n])); err != nil {
err = multipartUpload.UploadPart(c.Request().Context(), partNumber, bytes.NewReader(buf[:n]), int64(n))
if err != nil {
return err
}

if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
if errors.Is(readFullErr, io.EOF) || errors.Is(readFullErr, io.ErrUnexpectedEOF) {
break
}

Expand Down

0 comments on commit 86453d5

Please sign in to comment.