Skip to content

Commit

Permalink
Support more hash algorithms
Browse files Browse the repository at this point in the history
This adds support for sha1, md5, sha384, sha512 and sha256tree as hash
algorithms in checksum.sri.
  • Loading branch information
moroten committed Oct 10, 2024
1 parent 4d0f0e2 commit 65ba88a
Show file tree
Hide file tree
Showing 3 changed files with 190 additions and 35 deletions.
1 change: 1 addition & 0 deletions pkg/fetch/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ go_test(
"@bazel_remote_apis//build/bazel/remote/execution/v2:remote_execution_go_proto",
"@com_github_buildbarn_bb_storage//pkg/blobstore/buffer",
"@com_github_buildbarn_bb_storage//pkg/digest",
"@com_github_buildbarn_bb_storage//pkg/testutil",
"@com_github_golang_mock//gomock",
"@com_github_stretchr_testify//require",
"@org_golang_google_genproto_googleapis_rpc//status",
Expand Down
67 changes: 45 additions & 22 deletions pkg/fetch/http_fetcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package fetch
import (
"bytes"
"context"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"io"
Expand Down Expand Up @@ -50,19 +49,22 @@ func (hf *httpFetcher) FetchBlob(ctx context.Context, req *remoteasset.FetchBlob
// TODO: Address the following fields
// timeout := ptypes.Duration(req.timeout)
// oldestContentAccepted := ptypes.Timestamp(req.oldestContentAccepted)
expectedDigest, err := getChecksumSri(req.Qualifiers)
expectedDigest, digestFunctionEnum, err := getChecksumSri(req.Qualifiers)
if err != nil {
return nil, err
}
if digestFunctionEnum == remoteexecution.DigestFunction_UNKNOWN {
// Default to SHA256 if no digest is provided.
digestFunctionEnum = remoteexecution.DigestFunction_SHA256
}

auth, err := getAuthHeaders(req.Qualifiers)
if err != nil {
return nil, err
}

for _, uri := range req.Uris {

buffer, digest := hf.downloadBlob(ctx, uri, instanceName, expectedDigest, auth)
buffer, digest := hf.downloadBlob(ctx, uri, instanceName, expectedDigest, digestFunctionEnum, auth)
if _, err = buffer.GetSizeBytes(); err != nil {
log.Printf("Error downloading blob with URI %s: %v", uri, err)
continue
Expand Down Expand Up @@ -91,7 +93,7 @@ func (hf *httpFetcher) CheckQualifiers(qualifiers qualifier.Set) qualifier.Set {
return qualifier.Difference(qualifiers, qualifier.NewSet([]string{"checksum.sri", "bazel.auth_headers", "bazel.canonical_id"}))
}

func (hf *httpFetcher) downloadBlob(ctx context.Context, uri string, instanceName bb_digest.InstanceName, expectedDigest string, auth *AuthHeaders) (buffer.Buffer, bb_digest.Digest) {
func (hf *httpFetcher) downloadBlob(ctx context.Context, uri string, instanceName bb_digest.InstanceName, expectedDigest string, digestFunctionEnum remoteexecution.DigestFunction_Value, auth *AuthHeaders) (buffer.Buffer, bb_digest.Digest) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, uri, nil)
if err != nil {
return buffer.NewBufferFromError(util.StatusWrapWithCode(err, codes.Internal, "Failed to create HTTP request")), bb_digest.BadDigest
Expand All @@ -111,6 +113,11 @@ func (hf *httpFetcher) downloadBlob(ctx context.Context, uri string, instanceNam
return buffer.NewBufferFromError(status.Errorf(codes.Internal, "HTTP request failed with status %#v", resp.Status)), bb_digest.BadDigest
}

digestFunction, err := instanceName.GetDigestFunction(digestFunctionEnum, len(expectedDigest))
if err != nil {
return buffer.NewBufferFromError(util.StatusWrapfWithCode(err, codes.Internal, "Failed to get digest function for instance: %v", instanceName)), bb_digest.BadDigest
}

// Work out the digest of the downloaded data
//
// If the HTTP response includes the content length (indicated by the value
Expand Down Expand Up @@ -138,18 +145,14 @@ func (hf *httpFetcher) downloadBlob(ctx context.Context, uri string, instanceNam
// If we don't know what the hash should be we will need to work out the
// actual hash of the content
if expectedDigest == "" {
hasher := sha256.New()
hasher := digestFunction.NewGenerator(length)
hasher.Write(bodyBytes)
hash := hasher.Sum(nil)
expectedDigest = hex.EncodeToString(hash)
digest := hasher.Sum()
expectedDigest = digest.GetHashString()
}

body = io.NopCloser(bytes.NewBuffer(bodyBytes))
}
digestFunction, err := instanceName.GetDigestFunction(remoteexecution.DigestFunction_UNKNOWN, len(expectedDigest))
if err != nil {
return buffer.NewBufferFromError(util.StatusWrapfWithCode(err, codes.Internal, "Failed to get digest function for instance: %v", instanceName)), bb_digest.BadDigest
}
digest, err := digestFunction.NewDigest(expectedDigest, length)
if err != nil {
return buffer.NewBufferFromError(util.StatusWrapWithCode(err, codes.Internal, "Digest Creation failed")), bb_digest.BadDigest
Expand All @@ -160,21 +163,41 @@ func (hf *httpFetcher) downloadBlob(ctx context.Context, uri string, instanceNam
return buffer.NewCASBufferFromReader(digest, body, buffer.UserProvided), digest
}

func getChecksumSri(qualifiers []*remoteasset.Qualifier) (string, error) {
func getChecksumSri(qualifiers []*remoteasset.Qualifier) (string, remoteexecution.DigestFunction_Value, error) {
hashTypes := map[string]remoteexecution.DigestFunction_Value{
"sha256": remoteexecution.DigestFunction_SHA256,
"sha1": remoteexecution.DigestFunction_SHA1,
"md5": remoteexecution.DigestFunction_MD5,
"sha384": remoteexecution.DigestFunction_SHA384,
"sha512": remoteexecution.DigestFunction_SHA512,
"sha256tree": remoteexecution.DigestFunction_SHA256TREE,
}
expectedDigest := ""
digestFunctionEnum := remoteexecution.DigestFunction_UNKNOWN
for _, qualifier := range qualifiers {
if qualifier.Name == "checksum.sri" {
if strings.HasPrefix(qualifier.Value, "sha256-") {
b64hash := strings.TrimPrefix(qualifier.Value, "sha256-")
decoded, err := base64.StdEncoding.DecodeString(b64hash)
if err != nil {
return "", status.Errorf(codes.InvalidArgument, "Failed to decode checksum as b64 encoded sha256 sum: %s", err.Error())
}
return hex.EncodeToString(decoded), nil
if digestFunctionEnum != remoteexecution.DigestFunction_UNKNOWN {
return "", remoteexecution.DigestFunction_UNKNOWN, status.Errorf(codes.InvalidArgument, "Multiple checksum.sri provided")
}
parts := strings.SplitN(qualifier.Value, "-", 2)
if len(parts) != 2 {
return "", remoteexecution.DigestFunction_UNKNOWN, status.Errorf(codes.InvalidArgument, "Bad checksum.sri hash expression: %s", qualifier.Value)
}
hashName := parts[0]
b64hash := parts[1]
var ok bool
digestFunctionEnum, ok = hashTypes[hashName]
if !ok {
return "", remoteexecution.DigestFunction_UNKNOWN, status.Errorf(codes.InvalidArgument, "Unsupported checksum algorithm %s", hashName)
}
decoded, err := base64.StdEncoding.DecodeString(b64hash)
if err != nil {
return "", remoteexecution.DigestFunction_UNKNOWN, status.Errorf(codes.InvalidArgument, "Failed to decode checksum as base64 encoded %s sum: %s", hashName, err.Error())
}
return "", status.Errorf(codes.InvalidArgument, "Non sha256 checksums are not supported")
expectedDigest = hex.EncodeToString(decoded)
}
}
return "", nil
return expectedDigest, digestFunctionEnum, nil
}

func getAuthHeaders(qualifiers []*remoteasset.Qualifier) (*AuthHeaders, error) {
Expand Down
157 changes: 144 additions & 13 deletions pkg/fetch/http_fetcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/buildbarn/bb-remote-asset/internal/mock"
"github.com/buildbarn/bb-remote-asset/pkg/fetch"
bb_digest "github.com/buildbarn/bb-storage/pkg/digest"
"github.com/buildbarn/bb-storage/pkg/testutil"

remoteasset "github.com/bazelbuild/remote-apis/build/bazel/remote/asset/v1"
remoteexecution "github.com/bazelbuild/remote-apis/build/bazel/remote/execution/v2"
Expand Down Expand Up @@ -47,17 +48,70 @@ func (hm *headerMatcher) Matches(x interface{}) bool {
return true
}

func TestHTTPFetcherFetchBlob(t *testing.T) {
// TestHTTPFetcherFetchBlobSuccessSHA256 runs the shared FetchBlob success
// scenario with the SHA-256 digest function, passing the expected hex hash
// and a "sha256-" prefixed checksum.sri qualifier value to the helper.
func TestHTTPFetcherFetchBlobSuccessSHA256(t *testing.T) {
	const (
		wantHexHash = "185f8db32271fe25f561a6fc938b2e264306ec304eda518007d1764826381969"
		sriValue    = "sha256-GF+NsyJx/iX1Yab8k4suJkMG7DBO2lGAB9F2SCY4GWk="
	)
	testHTTPFetcherFetchBlobSuccessWithHasher(t, remoteexecution.DigestFunction_SHA256, wantHexHash, sriValue)
}

// TestHTTPFetcherFetchBlobSuccessSHA1 runs the shared FetchBlob success
// scenario with the SHA-1 digest function, passing the expected hex hash
// and a "sha1-" prefixed checksum.sri qualifier value to the helper.
func TestHTTPFetcherFetchBlobSuccessSHA1(t *testing.T) {
	const (
		wantHexHash = "f7ff9e8b7bb2e09b70935a5d785e0cc5d9d0abf0"
		sriValue    = "sha1-9/+ei3uy4Jtwk1pdeF4MxdnQq/A="
	)
	testHTTPFetcherFetchBlobSuccessWithHasher(t, remoteexecution.DigestFunction_SHA1, wantHexHash, sriValue)
}

// TestHTTPFetcherFetchBlobSuccessMD5 runs the shared FetchBlob success
// scenario with the MD5 digest function, passing the expected hex hash
// and an "md5-" prefixed checksum.sri qualifier value to the helper.
func TestHTTPFetcherFetchBlobSuccessMD5(t *testing.T) {
	const (
		wantHexHash = "8b1a9953c4611296a827abf8c47804d7"
		sriValue    = "md5-ixqZU8RhEpaoJ6v4xHgE1w=="
	)
	testHTTPFetcherFetchBlobSuccessWithHasher(t, remoteexecution.DigestFunction_MD5, wantHexHash, sriValue)
}

// TestHTTPFetcherFetchBlobSuccessSHA384 runs the shared FetchBlob success
// scenario with the SHA-384 digest function, passing the expected hex hash
// and a "sha384-" prefixed checksum.sri qualifier value to the helper.
func TestHTTPFetcherFetchBlobSuccessSHA384(t *testing.T) {
	const (
		wantHexHash = "3519fe5ad2c596efe3e276a6f351b8fc0b03db861782490d45f7598ebd0ab5fd5520ed102f38c4a5ec834e98668035fc"
		sriValue    = "sha384-NRn+WtLFlu/j4nam81G4/AsD24YXgkkNRfdZjr0Ktf1VIO0QLzjEpeyDTphmgDX8"
	)
	testHTTPFetcherFetchBlobSuccessWithHasher(t, remoteexecution.DigestFunction_SHA384, wantHexHash, sriValue)
}

// TestHTTPFetcherFetchBlobSuccessSHA512 runs the shared FetchBlob success
// scenario with the SHA-512 digest function, passing the expected hex hash
// and a "sha512-" prefixed checksum.sri qualifier value to the helper.
func TestHTTPFetcherFetchBlobSuccessSHA512(t *testing.T) {
	const (
		wantHexHash = "3615f80c9d293ed7402687f94b22d58e529b8cc7916f8fac7fddf7fbd5af4cf777d3d795a7a00a16bf7e7f3fb9561ee9baae480da9fe7a18769e71886b03f315"
		sriValue    = "sha512-NhX4DJ0pPtdAJof5SyLVjlKbjMeRb4+sf933+9WvTPd309eVp6AKFr9+fz+5Vh7puq5IDan+ehh2nnGIawPzFQ=="
	)
	testHTTPFetcherFetchBlobSuccessWithHasher(t, remoteexecution.DigestFunction_SHA512, wantHexHash, sriValue)
}

// TestHTTPFetcherFetchBlobSuccessSha256tree runs the shared FetchBlob success
// scenario with the SHA256TREE digest function, passing the expected hex hash
// and a "sha256tree-" prefixed checksum.sri qualifier value to the helper.
func TestHTTPFetcherFetchBlobSuccessSha256tree(t *testing.T) {
	const (
		wantHexHash = "35b974ff55d4c41ca000ea35b974ff55d4c41ca000eacf29125544cf29125544"
		sriValue    = "sha256tree-Nbl0/1XUxBygAOo1uXT/VdTEHKAA6s8pElVEzykSVUQ="
	)
	testHTTPFetcherFetchBlobSuccessWithHasher(t, remoteexecution.DigestFunction_SHA256TREE, wantHexHash, sriValue)
}

func testHTTPFetcherFetchBlobSuccessWithHasher(t *testing.T, digestFunctionEnum remoteexecution.DigestFunction_Value, hexHash string, sriChecksum string) {
ctrl, ctx := gomock.WithContext(context.Background(), t)

uri := "www.example.com"
request := &remoteasset.FetchBlobRequest{
InstanceName: "",
Uris: []string{uri, "www.another.com"},
Uris: []string{"www.example.com"},
Qualifiers: []*remoteasset.Qualifier{
{
Name: "checksum.sri",
Value: "sha256-GF+NsyJx/iX1Yab8k4suJkMG7DBO2lGAB9F2SCY4GWk=",
Value: sriChecksum,
},
},
}
Expand All @@ -67,12 +121,12 @@ func TestHTTPFetcherFetchBlob(t *testing.T) {
body := mock.NewMockReadCloser(ctrl)
helloDigest := bb_digest.MustNewDigest(
"",
remoteexecution.DigestFunction_SHA256,
"185f8db32271fe25f561a6fc938b2e264306ec304eda518007d1764826381969",
digestFunctionEnum,
hexHash,
5,
)

t.Run("Success", func(t *testing.T) {
t.Run("Success"+helloDigest.GetDigestFunction().GetEnumValue().String(), func(t *testing.T) {
httpDoCall := roundTripper.EXPECT().RoundTrip(gomock.Any()).Return(&http.Response{
Status: "200 Success",
StatusCode: 200,
Expand All @@ -82,7 +136,7 @@ func TestHTTPFetcherFetchBlob(t *testing.T) {
casBlobAccess.EXPECT().Put(ctx, helloDigest, gomock.Any()).Return(nil).After(httpDoCall)

response, err := HTTPFetcher.FetchBlob(ctx, request)
require.Nil(t, err)
require.NoError(t, err)
require.True(t, proto.Equal(response.BlobDigest, helloDigest.GetProto()))
require.Equal(t, response.Status.Code, int32(codes.OK))
})
Expand All @@ -102,10 +156,36 @@ func TestHTTPFetcherFetchBlob(t *testing.T) {
casBlobAccess.EXPECT().Put(ctx, helloDigest, gomock.Any()).Return(nil).After(bodyCloseCall)

response, err := HTTPFetcher.FetchBlob(ctx, request)
require.Nil(t, err)
require.NoError(t, err)
require.True(t, proto.Equal(response.BlobDigest, helloDigest.GetProto()))
require.Equal(t, response.Status.Code, int32(codes.OK))
})
}

func TestHTTPFetcherFetchBlob(t *testing.T) {
ctrl, ctx := gomock.WithContext(context.Background(), t)

uri := "www.example.com"
request := &remoteasset.FetchBlobRequest{
InstanceName: "",
Uris: []string{uri, "www.another.com"},
Qualifiers: []*remoteasset.Qualifier{
{
Name: "checksum.sri",
Value: "sha256-GF+NsyJx/iX1Yab8k4suJkMG7DBO2lGAB9F2SCY4GWk=",
},
},
}
casBlobAccess := mock.NewMockBlobAccess(ctrl)
roundTripper := mock.NewMockRoundTripper(ctrl)
HTTPFetcher := fetch.NewHTTPFetcher(&http.Client{Transport: roundTripper}, casBlobAccess)
body := mock.NewMockReadCloser(ctrl)
helloDigest := bb_digest.MustNewDigest(
"",
remoteexecution.DigestFunction_SHA256,
"185f8db32271fe25f561a6fc938b2e264306ec304eda518007d1764826381969",
5,
)

t.Run("SuccessNoExpectedDigest", func(t *testing.T) {
request := &remoteasset.FetchBlobRequest{
Expand All @@ -127,7 +207,7 @@ func TestHTTPFetcherFetchBlob(t *testing.T) {
casBlobAccess.EXPECT().Put(ctx, helloDigest, gomock.Any()).Return(nil).After(bodyCloseCall)

response, err := HTTPFetcher.FetchBlob(ctx, request)
require.Nil(t, err)
require.NoError(t, err)
require.True(t, proto.Equal(response.BlobDigest, helloDigest.GetProto()))
require.Equal(t, response.Status.Code, int32(codes.OK))
})
Expand All @@ -152,11 +232,62 @@ func TestHTTPFetcherFetchBlob(t *testing.T) {
casBlobAccess.EXPECT().Put(ctx, helloDigest, gomock.Any()).Return(nil).After(bodyCloseCall)

response, err := HTTPFetcher.FetchBlob(ctx, request)
require.Nil(t, err)
require.NoError(t, err)
require.True(t, proto.Equal(response.BlobDigest, helloDigest.GetProto()))
require.Equal(t, response.Status.Code, int32(codes.OK))
})

t.Run("UnknownChecksumSriAlgo", func(t *testing.T) {
request := &remoteasset.FetchBlobRequest{
InstanceName: "",
Uris: []string{uri, "www.another.com"},
Qualifiers: []*remoteasset.Qualifier{
{
Name: "checksum.sri",
Value: "sha0-GF+NsyJx/iX1Yab8k4suJkMG7DBO2lGAB9F2SCY4GWk=",
},
},
}

response, err := HTTPFetcher.FetchBlob(ctx, request)
testutil.RequireEqualStatus(t, status.Error(codes.InvalidArgument, "Unsupported checksum algorithm sha0"), err)
require.Nil(t, response)
})

t.Run("BadChecksumSriAlgo", func(t *testing.T) {
request := &remoteasset.FetchBlobRequest{
InstanceName: "",
Uris: []string{uri, "www.another.com"},
Qualifiers: []*remoteasset.Qualifier{
{
Name: "checksum.sri",
Value: "no_dash",
},
},
}

response, err := HTTPFetcher.FetchBlob(ctx, request)
testutil.RequireEqualStatus(t, status.Error(codes.InvalidArgument, "Bad checksum.sri hash expression: no_dash"), err)
require.Nil(t, response)
})

t.Run("BadChecksumSriBase64Value", func(t *testing.T) {
request := &remoteasset.FetchBlobRequest{
InstanceName: "",
Uris: []string{uri, "www.another.com"},
Qualifiers: []*remoteasset.Qualifier{
{
Name: "checksum.sri",
Value: "sha256-no-base64",
},
},
}

response, err := HTTPFetcher.FetchBlob(ctx, request)
testutil.RequireEqualStatus(t, status.Error(codes.InvalidArgument, "Failed to decode checksum as base64 encoded sha256 sum: illegal base64 data at input byte 2"), err)
require.Nil(t, response)
})

t.Run("OneFailOneSuccess", func(t *testing.T) {
httpFailCall := roundTripper.EXPECT().RoundTrip(gomock.Any()).Return(&http.Response{
Status: "404 Not Found",
Expand All @@ -171,7 +302,7 @@ func TestHTTPFetcherFetchBlob(t *testing.T) {
casBlobAccess.EXPECT().Put(ctx, helloDigest, gomock.Any()).Return(nil).After(httpSuccessCall)

response, err := HTTPFetcher.FetchBlob(ctx, request)
require.Nil(t, err)
require.NoError(t, err)
require.True(t, proto.Equal(response.BlobDigest, helloDigest.GetProto()))
require.Equal(t, response.Status.Code, int32(codes.OK))
})
Expand Down Expand Up @@ -216,7 +347,7 @@ func TestHTTPFetcherFetchBlob(t *testing.T) {
casBlobAccess.EXPECT().Put(ctx, helloDigest, gomock.Any()).Return(nil).After(httpDoCall)

response, err := HTTPFetcher.FetchBlob(ctx, request)
require.Nil(t, err)
require.NoError(t, err)
require.True(t, proto.Equal(response.BlobDigest, helloDigest.GetProto()))
require.Equal(t, response.Status.Code, int32(codes.OK))
})
Expand Down

0 comments on commit 65ba88a

Please sign in to comment.