Skip to content

Commit

Permalink
Support more hash algorithms (#47)
Browse files Browse the repository at this point in the history
* Support more hash algorithms

This adds support for sha1, md5, sha384, sha512 and sha256tree as hash
algorithm in checksum.sri.

* Fix style
  • Loading branch information
moroten authored Oct 14, 2024
1 parent 4d0f0e2 commit 5a41232
Show file tree
Hide file tree
Showing 3 changed files with 190 additions and 35 deletions.
1 change: 1 addition & 0 deletions pkg/fetch/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ go_test(
"@bazel_remote_apis//build/bazel/remote/execution/v2:remote_execution_go_proto",
"@com_github_buildbarn_bb_storage//pkg/blobstore/buffer",
"@com_github_buildbarn_bb_storage//pkg/digest",
"@com_github_buildbarn_bb_storage//pkg/testutil",
"@com_github_golang_mock//gomock",
"@com_github_stretchr_testify//require",
"@org_golang_google_genproto_googleapis_rpc//status",
Expand Down
67 changes: 45 additions & 22 deletions pkg/fetch/http_fetcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package fetch
import (
"bytes"
"context"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"io"
Expand Down Expand Up @@ -50,19 +49,22 @@ func (hf *httpFetcher) FetchBlob(ctx context.Context, req *remoteasset.FetchBlob
// TODO: Address the following fields
// timeout := ptypes.Duration(req.timeout)
// oldestContentAccepted := ptypes.Timestamp(req.oldestContentAccepted)
expectedDigest, err := getChecksumSri(req.Qualifiers)
expectedDigest, digestFunctionEnum, err := getChecksumSri(req.Qualifiers)
if err != nil {
return nil, err
}
if digestFunctionEnum == remoteexecution.DigestFunction_UNKNOWN {
// Default to SHA256 if no digest is provided.
digestFunctionEnum = remoteexecution.DigestFunction_SHA256
}

auth, err := getAuthHeaders(req.Qualifiers)
if err != nil {
return nil, err
}

for _, uri := range req.Uris {

buffer, digest := hf.downloadBlob(ctx, uri, instanceName, expectedDigest, auth)
buffer, digest := hf.downloadBlob(ctx, uri, instanceName, expectedDigest, digestFunctionEnum, auth)
if _, err = buffer.GetSizeBytes(); err != nil {
log.Printf("Error downloading blob with URI %s: %v", uri, err)
continue
Expand Down Expand Up @@ -91,7 +93,7 @@ func (hf *httpFetcher) CheckQualifiers(qualifiers qualifier.Set) qualifier.Set {
return qualifier.Difference(qualifiers, qualifier.NewSet([]string{"checksum.sri", "bazel.auth_headers", "bazel.canonical_id"}))
}

func (hf *httpFetcher) downloadBlob(ctx context.Context, uri string, instanceName bb_digest.InstanceName, expectedDigest string, auth *AuthHeaders) (buffer.Buffer, bb_digest.Digest) {
func (hf *httpFetcher) downloadBlob(ctx context.Context, uri string, instanceName bb_digest.InstanceName, expectedDigest string, digestFunctionEnum remoteexecution.DigestFunction_Value, auth *AuthHeaders) (buffer.Buffer, bb_digest.Digest) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, uri, nil)
if err != nil {
return buffer.NewBufferFromError(util.StatusWrapWithCode(err, codes.Internal, "Failed to create HTTP request")), bb_digest.BadDigest
Expand All @@ -111,6 +113,11 @@ func (hf *httpFetcher) downloadBlob(ctx context.Context, uri string, instanceNam
return buffer.NewBufferFromError(status.Errorf(codes.Internal, "HTTP request failed with status %#v", resp.Status)), bb_digest.BadDigest
}

digestFunction, err := instanceName.GetDigestFunction(digestFunctionEnum, len(expectedDigest))
if err != nil {
return buffer.NewBufferFromError(util.StatusWrapfWithCode(err, codes.Internal, "Failed to get digest function for instance: %v", instanceName)), bb_digest.BadDigest
}

// Work out the digest of the downloaded data
//
// If the HTTP response includes the content length (indicated by the value
Expand Down Expand Up @@ -138,18 +145,14 @@ func (hf *httpFetcher) downloadBlob(ctx context.Context, uri string, instanceNam
// If we don't know what the hash should be we will need to work out the
// actual hash of the content
if expectedDigest == "" {
hasher := sha256.New()
hasher := digestFunction.NewGenerator(length)
hasher.Write(bodyBytes)
hash := hasher.Sum(nil)
expectedDigest = hex.EncodeToString(hash)
digest := hasher.Sum()
expectedDigest = digest.GetHashString()
}

body = io.NopCloser(bytes.NewBuffer(bodyBytes))
}
digestFunction, err := instanceName.GetDigestFunction(remoteexecution.DigestFunction_UNKNOWN, len(expectedDigest))
if err != nil {
return buffer.NewBufferFromError(util.StatusWrapfWithCode(err, codes.Internal, "Failed to get digest function for instance: %v", instanceName)), bb_digest.BadDigest
}
digest, err := digestFunction.NewDigest(expectedDigest, length)
if err != nil {
return buffer.NewBufferFromError(util.StatusWrapWithCode(err, codes.Internal, "Digest Creation failed")), bb_digest.BadDigest
Expand All @@ -160,21 +163,41 @@ func (hf *httpFetcher) downloadBlob(ctx context.Context, uri string, instanceNam
return buffer.NewCASBufferFromReader(digest, body, buffer.UserProvided), digest
}

func getChecksumSri(qualifiers []*remoteasset.Qualifier) (string, error) {
func getChecksumSri(qualifiers []*remoteasset.Qualifier) (string, remoteexecution.DigestFunction_Value, error) {
hashTypes := map[string]remoteexecution.DigestFunction_Value{
"sha256": remoteexecution.DigestFunction_SHA256,
"sha1": remoteexecution.DigestFunction_SHA1,
"md5": remoteexecution.DigestFunction_MD5,
"sha384": remoteexecution.DigestFunction_SHA384,
"sha512": remoteexecution.DigestFunction_SHA512,
"sha256tree": remoteexecution.DigestFunction_SHA256TREE,
}
expectedDigest := ""
digestFunctionEnum := remoteexecution.DigestFunction_UNKNOWN
for _, qualifier := range qualifiers {
if qualifier.Name == "checksum.sri" {
if strings.HasPrefix(qualifier.Value, "sha256-") {
b64hash := strings.TrimPrefix(qualifier.Value, "sha256-")
decoded, err := base64.StdEncoding.DecodeString(b64hash)
if err != nil {
return "", status.Errorf(codes.InvalidArgument, "Failed to decode checksum as b64 encoded sha256 sum: %s", err.Error())
}
return hex.EncodeToString(decoded), nil
if digestFunctionEnum != remoteexecution.DigestFunction_UNKNOWN {
return "", remoteexecution.DigestFunction_UNKNOWN, status.Errorf(codes.InvalidArgument, "Multiple checksum.sri provided")
}
parts := strings.SplitN(qualifier.Value, "-", 2)
if len(parts) != 2 {
return "", remoteexecution.DigestFunction_UNKNOWN, status.Errorf(codes.InvalidArgument, "Bad checksum.sri hash expression: %s", qualifier.Value)
}
hashName := parts[0]
b64hash := parts[1]
var ok bool
digestFunctionEnum, ok = hashTypes[hashName]
if !ok {
return "", remoteexecution.DigestFunction_UNKNOWN, status.Errorf(codes.InvalidArgument, "Unsupported checksum algorithm %s", hashName)
}
decoded, err := base64.StdEncoding.DecodeString(b64hash)
if err != nil {
return "", remoteexecution.DigestFunction_UNKNOWN, status.Errorf(codes.InvalidArgument, "Failed to decode checksum as base64 encoded %s sum: %s", hashName, err.Error())
}
return "", status.Errorf(codes.InvalidArgument, "Non sha256 checksums are not supported")
expectedDigest = hex.EncodeToString(decoded)
}
}
return "", nil
return expectedDigest, digestFunctionEnum, nil
}

func getAuthHeaders(qualifiers []*remoteasset.Qualifier) (*AuthHeaders, error) {
Expand Down
157 changes: 144 additions & 13 deletions pkg/fetch/http_fetcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/buildbarn/bb-remote-asset/internal/mock"
"github.com/buildbarn/bb-remote-asset/pkg/fetch"
bb_digest "github.com/buildbarn/bb-storage/pkg/digest"
"github.com/buildbarn/bb-storage/pkg/testutil"

remoteasset "github.com/bazelbuild/remote-apis/build/bazel/remote/asset/v1"
remoteexecution "github.com/bazelbuild/remote-apis/build/bazel/remote/execution/v2"
Expand Down Expand Up @@ -47,17 +48,70 @@ func (hm *headerMatcher) Matches(x interface{}) bool {
return true
}

func TestHTTPFetcherFetchBlob(t *testing.T) {
func TestHTTPFetcherFetchBlobSuccessSHA256(t *testing.T) {
testHTTPFetcherFetchBlobSuccessWithHasher(
t,
remoteexecution.DigestFunction_SHA256,
"185f8db32271fe25f561a6fc938b2e264306ec304eda518007d1764826381969",
"sha256-GF+NsyJx/iX1Yab8k4suJkMG7DBO2lGAB9F2SCY4GWk=",
)
}

func TestHTTPFetcherFetchBlobSuccessSHA1(t *testing.T) {
testHTTPFetcherFetchBlobSuccessWithHasher(
t,
remoteexecution.DigestFunction_SHA1,
"f7ff9e8b7bb2e09b70935a5d785e0cc5d9d0abf0",
"sha1-9/+ei3uy4Jtwk1pdeF4MxdnQq/A=",
)
}

func TestHTTPFetcherFetchBlobSuccessMD5(t *testing.T) {
testHTTPFetcherFetchBlobSuccessWithHasher(
t,
remoteexecution.DigestFunction_MD5,
"8b1a9953c4611296a827abf8c47804d7",
"md5-ixqZU8RhEpaoJ6v4xHgE1w==",
)
}

func TestHTTPFetcherFetchBlobSuccessSHA384(t *testing.T) {
testHTTPFetcherFetchBlobSuccessWithHasher(
t,
remoteexecution.DigestFunction_SHA384,
"3519fe5ad2c596efe3e276a6f351b8fc0b03db861782490d45f7598ebd0ab5fd5520ed102f38c4a5ec834e98668035fc",
"sha384-NRn+WtLFlu/j4nam81G4/AsD24YXgkkNRfdZjr0Ktf1VIO0QLzjEpeyDTphmgDX8",
)
}

func TestHTTPFetcherFetchBlobSuccessSHA512(t *testing.T) {
testHTTPFetcherFetchBlobSuccessWithHasher(
t,
remoteexecution.DigestFunction_SHA512,
"3615f80c9d293ed7402687f94b22d58e529b8cc7916f8fac7fddf7fbd5af4cf777d3d795a7a00a16bf7e7f3fb9561ee9baae480da9fe7a18769e71886b03f315",
"sha512-NhX4DJ0pPtdAJof5SyLVjlKbjMeRb4+sf933+9WvTPd309eVp6AKFr9+fz+5Vh7puq5IDan+ehh2nnGIawPzFQ==",
)
}

func TestHTTPFetcherFetchBlobSuccessSha256tree(t *testing.T) {
testHTTPFetcherFetchBlobSuccessWithHasher(
t,
remoteexecution.DigestFunction_SHA256TREE,
"35b974ff55d4c41ca000ea35b974ff55d4c41ca000eacf29125544cf29125544",
"sha256tree-Nbl0/1XUxBygAOo1uXT/VdTEHKAA6s8pElVEzykSVUQ=",
)
}

func testHTTPFetcherFetchBlobSuccessWithHasher(t *testing.T, digestFunctionEnum remoteexecution.DigestFunction_Value, hexHash, sriChecksum string) {
ctrl, ctx := gomock.WithContext(context.Background(), t)

uri := "www.example.com"
request := &remoteasset.FetchBlobRequest{
InstanceName: "",
Uris: []string{uri, "www.another.com"},
Uris: []string{"www.example.com"},
Qualifiers: []*remoteasset.Qualifier{
{
Name: "checksum.sri",
Value: "sha256-GF+NsyJx/iX1Yab8k4suJkMG7DBO2lGAB9F2SCY4GWk=",
Value: sriChecksum,
},
},
}
Expand All @@ -67,12 +121,12 @@ func TestHTTPFetcherFetchBlob(t *testing.T) {
body := mock.NewMockReadCloser(ctrl)
helloDigest := bb_digest.MustNewDigest(
"",
remoteexecution.DigestFunction_SHA256,
"185f8db32271fe25f561a6fc938b2e264306ec304eda518007d1764826381969",
digestFunctionEnum,
hexHash,
5,
)

t.Run("Success", func(t *testing.T) {
t.Run("Success"+helloDigest.GetDigestFunction().GetEnumValue().String(), func(t *testing.T) {
httpDoCall := roundTripper.EXPECT().RoundTrip(gomock.Any()).Return(&http.Response{
Status: "200 Success",
StatusCode: 200,
Expand All @@ -82,7 +136,7 @@ func TestHTTPFetcherFetchBlob(t *testing.T) {
casBlobAccess.EXPECT().Put(ctx, helloDigest, gomock.Any()).Return(nil).After(httpDoCall)

response, err := HTTPFetcher.FetchBlob(ctx, request)
require.Nil(t, err)
require.NoError(t, err)
require.True(t, proto.Equal(response.BlobDigest, helloDigest.GetProto()))
require.Equal(t, response.Status.Code, int32(codes.OK))
})
Expand All @@ -102,10 +156,36 @@ func TestHTTPFetcherFetchBlob(t *testing.T) {
casBlobAccess.EXPECT().Put(ctx, helloDigest, gomock.Any()).Return(nil).After(bodyCloseCall)

response, err := HTTPFetcher.FetchBlob(ctx, request)
require.Nil(t, err)
require.NoError(t, err)
require.True(t, proto.Equal(response.BlobDigest, helloDigest.GetProto()))
require.Equal(t, response.Status.Code, int32(codes.OK))
})
}

func TestHTTPFetcherFetchBlob(t *testing.T) {
ctrl, ctx := gomock.WithContext(context.Background(), t)

uri := "www.example.com"
request := &remoteasset.FetchBlobRequest{
InstanceName: "",
Uris: []string{uri, "www.another.com"},
Qualifiers: []*remoteasset.Qualifier{
{
Name: "checksum.sri",
Value: "sha256-GF+NsyJx/iX1Yab8k4suJkMG7DBO2lGAB9F2SCY4GWk=",
},
},
}
casBlobAccess := mock.NewMockBlobAccess(ctrl)
roundTripper := mock.NewMockRoundTripper(ctrl)
HTTPFetcher := fetch.NewHTTPFetcher(&http.Client{Transport: roundTripper}, casBlobAccess)
body := mock.NewMockReadCloser(ctrl)
helloDigest := bb_digest.MustNewDigest(
"",
remoteexecution.DigestFunction_SHA256,
"185f8db32271fe25f561a6fc938b2e264306ec304eda518007d1764826381969",
5,
)

t.Run("SuccessNoExpectedDigest", func(t *testing.T) {
request := &remoteasset.FetchBlobRequest{
Expand All @@ -127,7 +207,7 @@ func TestHTTPFetcherFetchBlob(t *testing.T) {
casBlobAccess.EXPECT().Put(ctx, helloDigest, gomock.Any()).Return(nil).After(bodyCloseCall)

response, err := HTTPFetcher.FetchBlob(ctx, request)
require.Nil(t, err)
require.NoError(t, err)
require.True(t, proto.Equal(response.BlobDigest, helloDigest.GetProto()))
require.Equal(t, response.Status.Code, int32(codes.OK))
})
Expand All @@ -152,11 +232,62 @@ func TestHTTPFetcherFetchBlob(t *testing.T) {
casBlobAccess.EXPECT().Put(ctx, helloDigest, gomock.Any()).Return(nil).After(bodyCloseCall)

response, err := HTTPFetcher.FetchBlob(ctx, request)
require.Nil(t, err)
require.NoError(t, err)
require.True(t, proto.Equal(response.BlobDigest, helloDigest.GetProto()))
require.Equal(t, response.Status.Code, int32(codes.OK))
})

t.Run("UnknownChecksumSriAlgo", func(t *testing.T) {
request := &remoteasset.FetchBlobRequest{
InstanceName: "",
Uris: []string{uri, "www.another.com"},
Qualifiers: []*remoteasset.Qualifier{
{
Name: "checksum.sri",
Value: "sha0-GF+NsyJx/iX1Yab8k4suJkMG7DBO2lGAB9F2SCY4GWk=",
},
},
}

response, err := HTTPFetcher.FetchBlob(ctx, request)
testutil.RequireEqualStatus(t, status.Error(codes.InvalidArgument, "Unsupported checksum algorithm sha0"), err)
require.Nil(t, response)
})

t.Run("BadChecksumSriAlgo", func(t *testing.T) {
request := &remoteasset.FetchBlobRequest{
InstanceName: "",
Uris: []string{uri, "www.another.com"},
Qualifiers: []*remoteasset.Qualifier{
{
Name: "checksum.sri",
Value: "no_dash",
},
},
}

response, err := HTTPFetcher.FetchBlob(ctx, request)
testutil.RequireEqualStatus(t, status.Error(codes.InvalidArgument, "Bad checksum.sri hash expression: no_dash"), err)
require.Nil(t, response)
})

t.Run("BadChecksumSriBase64Value", func(t *testing.T) {
request := &remoteasset.FetchBlobRequest{
InstanceName: "",
Uris: []string{uri, "www.another.com"},
Qualifiers: []*remoteasset.Qualifier{
{
Name: "checksum.sri",
Value: "sha256-no-base64",
},
},
}

response, err := HTTPFetcher.FetchBlob(ctx, request)
testutil.RequireEqualStatus(t, status.Error(codes.InvalidArgument, "Failed to decode checksum as base64 encoded sha256 sum: illegal base64 data at input byte 2"), err)
require.Nil(t, response)
})

t.Run("OneFailOneSuccess", func(t *testing.T) {
httpFailCall := roundTripper.EXPECT().RoundTrip(gomock.Any()).Return(&http.Response{
Status: "404 Not Found",
Expand All @@ -171,7 +302,7 @@ func TestHTTPFetcherFetchBlob(t *testing.T) {
casBlobAccess.EXPECT().Put(ctx, helloDigest, gomock.Any()).Return(nil).After(httpSuccessCall)

response, err := HTTPFetcher.FetchBlob(ctx, request)
require.Nil(t, err)
require.NoError(t, err)
require.True(t, proto.Equal(response.BlobDigest, helloDigest.GetProto()))
require.Equal(t, response.Status.Code, int32(codes.OK))
})
Expand Down Expand Up @@ -216,7 +347,7 @@ func TestHTTPFetcherFetchBlob(t *testing.T) {
casBlobAccess.EXPECT().Put(ctx, helloDigest, gomock.Any()).Return(nil).After(httpDoCall)

response, err := HTTPFetcher.FetchBlob(ctx, request)
require.Nil(t, err)
require.NoError(t, err)
require.True(t, proto.Equal(response.BlobDigest, helloDigest.GetProto()))
require.Equal(t, response.Status.Code, int32(codes.OK))
})
Expand Down

0 comments on commit 5a41232

Please sign in to comment.