diff --git a/pkg/verify/signature.go b/pkg/verify/signature.go index 1d96c12..8d5569e 100644 --- a/pkg/verify/signature.go +++ b/pkg/verify/signature.go @@ -23,6 +23,7 @@ import ( "fmt" "hash" "io" + "slices" in_toto "github.com/in-toto/attestation/go/v1" "github.com/secure-systems-lab/go-securesystemslib/dsse" @@ -146,56 +147,40 @@ func verifyEnvelopeWithArtifact(verifier signature.Verifier, envelope EnvelopeCo if err = limitSubjects(statement); err != nil { return err } - - var artifactDigestAlgorithm string - var artifactDigest []byte - - // Determine artifact digest algorithm by looking at the first subject's - // digests. This assumes that if a statement contains multiple subjects, - // they all use the same digest algorithm(s). + // Sanity check (no subjects) if len(statement.Subject) == 0 { return errors.New("no subjects found in statement") } - if len(statement.Subject[0].Digest) == 0 { - return errors.New("no digests found in statement") - } - // Select the strongest digest algorithm available. - for _, alg := range []string{"sha512", "sha384", "sha256"} { - if _, ok := statement.Subject[0].Digest[alg]; ok { - artifactDigestAlgorithm = alg - continue - } - } - if artifactDigestAlgorithm == "" { - return errors.New("could not verify artifact: unsupported digest algorithm") + // determine which hash functions to use + hashFuncs, err := getHashFunctions(statement) + if err != nil { + return fmt.Errorf("could not verify artifact: unable to determine hash functions: %w", err) } // Compute digest of the artifact. - var hasher hash.Hash - switch artifactDigestAlgorithm { - case "sha512": - hasher = crypto.SHA512.New() - case "sha384": - hasher = crypto.SHA384.New() - case "sha256": - hasher = crypto.SHA256.New() - } + hasher := newMultihasher(hashFuncs) _, err = io.Copy(hasher, artifact) if err != nil { return fmt.Errorf("could not verify artifact: unable to calculate digest: %w", err) } - artifactDigest = hasher.Sum(nil) + artifactDigests := hasher.Sum(nil) // Look for artifact digest in statement for _, subject := range statement.Subject { - for alg, digest := range subject.Digest { - hexdigest, err := hex.DecodeString(digest) + for alg, hexdigest := range subject.Digest { + hf, err := algStringToHashFunc(alg) if err != nil { - return fmt.Errorf("could not verify artifact: unable to decode subject digest: %w", err) + continue } - if alg == artifactDigestAlgorithm && bytes.Equal(artifactDigest, hexdigest) { - return nil + if artifactDigest, ok := artifactDigests[hf]; ok { + digest, err := hex.DecodeString(hexdigest) + if err != nil { + continue + } + if bytes.Equal(artifactDigest, digest) { + return nil + } } } } @@ -269,3 +254,110 @@ func limitSubjects(statement *in_toto.Statement) error { } return nil } + +type multihasher struct { + hashfuncs []crypto.Hash + hashes []hash.Hash +} + +func newMultihasher(hashfuncs []crypto.Hash) *multihasher { + hashes := make([]hash.Hash, len(hashfuncs)) + for i := range hashfuncs { + hashes[i] = hashfuncs[i].New() + } + return &multihasher{ + hashfuncs: hashfuncs, + hashes: hashes, + } +} + +func (m *multihasher) Write(p []byte) (n int, err error) { + for i := range m.hashes { + n, err = m.hashes[i].Write(p) + if err != nil { + return + } + } + return +} + +func (m *multihasher) Sum(b []byte) map[crypto.Hash][]byte { + sums := make(map[crypto.Hash][]byte, len(m.hashes)) + for i := range m.hashes { + sums[m.hashfuncs[i]] = m.hashes[i].Sum(b) + } + return sums +} + +func algStringToHashFunc(alg string) (crypto.Hash, error) { + switch alg { + case "sha256": + return crypto.SHA256, nil + case "sha384": + return crypto.SHA384, nil + case "sha512": + return crypto.SHA512, nil + default: + return 0, errors.New("unsupported digest algorithm") + } +} + +// getHashFunctions returns the smallest subset of supported hash functions +// that are needed to verify all subjects in a statement. +func getHashFunctions(statement *in_toto.Statement) ([]crypto.Hash, error) { + if len(statement.Subject) == 0 { + return nil, errors.New("no subjects found in statement") + } + + supportedHashFuncs := []crypto.Hash{crypto.SHA512, crypto.SHA384, crypto.SHA256} + chosenHashFuncs := make([]crypto.Hash, 0, len(supportedHashFuncs)) + subjectHashFuncs := make([][]crypto.Hash, len(statement.Subject)) + + // go through the statement and make a simple data structure to hold the + // list of hash funcs for each subject (subjectHashFuncs) + for i, subject := range statement.Subject { + for alg := range subject.Digest { + hf, err := algStringToHashFunc(alg) + if err != nil { + continue + } + subjectHashFuncs[i] = append(subjectHashFuncs[i], hf) + } + } + + // for each subject, see if we have chosen a compatible hash func, and if + // not, add the first one that is supported + for _, hfs := range subjectHashFuncs { + // if any of the hash funcs are already in chosenHashFuncs, skip + if len(intersection(hfs, chosenHashFuncs)) > 0 { + continue + } + + // check each supported hash func and add it if the subject + // has a digest for it + for _, hf := range supportedHashFuncs { + if slices.Contains(hfs, hf) { + chosenHashFuncs = append(chosenHashFuncs, hf) + break + } + } + } + + if len(chosenHashFuncs) == 0 { + return nil, errors.New("no supported digest algorithms found") + } + + return chosenHashFuncs, nil +} + +func intersection(a, b []crypto.Hash) []crypto.Hash { + var result []crypto.Hash + for _, x := range a { + for _, y := range b { + if x == y { + result = append(result, x) + } + } + } + return result +} diff --git a/pkg/verify/signature_internal_test.go b/pkg/verify/signature_internal_test.go new file mode 100644 index 0000000..68ca7fe --- /dev/null +++ b/pkg/verify/signature_internal_test.go @@ -0,0 +1,108 @@ +// Copyright 2024 The Sigstore Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package verify + +import ( + "crypto" + "crypto/sha256" + "crypto/sha512" + "testing" + + in_toto "github.com/in-toto/attestation/go/v1" + "github.com/stretchr/testify/assert" +) + +func TestMultiHasher(t *testing.T) { + testBytes := []byte("Hello, world!") + hash256 := sha256.Sum256(testBytes) + hash512 := sha512.Sum512(testBytes) + + hasher := newMultihasher([]crypto.Hash{crypto.SHA256, crypto.SHA512}) + _, err := hasher.Write(testBytes) + assert.NoError(t, err) + + hashes := hasher.Sum(nil) + + assert.Equal(t, 2, len(hashes)) + assert.EqualValues(t, hash256[:], hashes[crypto.SHA256]) + assert.EqualValues(t, hash512[:], hashes[crypto.SHA512]) +} + +func makeStatement(subjectalgs [][]string) *in_toto.Statement { + statement := &in_toto.Statement{ + Subject: make([]*in_toto.ResourceDescriptor, len(subjectalgs)), + } + for i, subjectAlg := range subjectalgs { + statement.Subject[i] = &in_toto.ResourceDescriptor{ + Digest: make(map[string]string), + } + for _, digest := range subjectAlg { + // content of digest doesn't matter for this test + statement.Subject[i].Digest[digest] = "foobar" + } + } + return statement +} + +func TestGetHashFunctions(t *testing.T) { + for _, test := range []struct { + name string + algs [][]string + expectOutput []crypto.Hash + expectError bool + }{ + { + name: "choose strongest algorithm", + algs: [][]string{{"sha256", "sha512"}}, + expectOutput: []crypto.Hash{crypto.SHA512}, + }, + { + name: "choose both algorithms", + algs: [][]string{{"sha256"}, {"sha512"}}, + expectOutput: []crypto.Hash{crypto.SHA256, crypto.SHA512}, + }, + { + name: "choose one algorithm", + algs: [][]string{{"sha512"}, {"sha256", "sha512"}}, + expectOutput: []crypto.Hash{crypto.SHA512}, + }, + { + name: "choose two algorithms", + algs: [][]string{{"sha256", "sha512"}, {"sha384", "sha512"}, {"sha256", "sha384"}}, + expectOutput: []crypto.Hash{crypto.SHA512, crypto.SHA384}, + }, + { + name: "ignore unknown algorithm", + algs: [][]string{{"md5", "sha512"}, {"sha256", "sha512"}}, + expectOutput: []crypto.Hash{crypto.SHA512}, + }, + { + name: "no recognized algorithms", + algs: [][]string{{"md5"}, {"sha1"}}, + expectError: true, + }, + } { + t.Run(test.name, func(t *testing.T) { + statement := makeStatement(test.algs) + hfs, err := getHashFunctions(statement) + if test.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + assert.Equal(t, test.expectOutput, hfs) + }) + } +}