From 00c59d373f4be90ed44cb55f3c8f4585329d6862 Mon Sep 17 00:00:00 2001 From: Priyanshu Thapliyal Date: Wed, 25 Dec 2024 00:02:28 +0530 Subject: [PATCH] Add methods to directly add Reference Values and Endorsed Values Signed-off-by: Priyanshu Thapliyal --- comid/comid.go | 142 +++++++++++++++++++++++++++++++++++------- comid/comid_test.go | 30 +++++++++ comid/digests.go | 63 +++++++++---------- comid/hashalg.go | 77 +++++++++++++++++++++++ comid/hashalg_test.go | 108 ++++++++++++++++++++++++++++++++ comid/valuetriple.go | 31 +++++++++ 6 files changed, 398 insertions(+), 53 deletions(-) create mode 100644 comid/hashalg.go create mode 100644 comid/hashalg_test.go diff --git a/comid/comid.go b/comid/comid.go index 116100f4..fda09be1 100644 --- a/comid/comid.go +++ b/comid/comid.go @@ -243,27 +243,27 @@ func (o *Comid) AddDevIdentityKey(val KeyTriple) *Comid { } func (o Comid) Valid() error { - if err := o.TagIdentity.Valid(); err != nil { - return fmt.Errorf("tag-identity validation failed: %w", err) - } - - if o.Entities != nil { - if err := o.Entities.Valid(); err != nil { - return fmt.Errorf("entities validation failed: %w", err) - } - } - - if o.LinkedTags != nil { - if err := o.LinkedTags.Valid(); err != nil { - return fmt.Errorf("linked-tags validation failed: %w", err) - } - } - - if err := o.Triples.Valid(); err != nil { - return fmt.Errorf("triples validation failed: %w", err) - } - - return o.Extensions.validComid(&o) + if err := o.TagIdentity.Valid(); err != nil { + return fmt.Errorf("tag-identity validation failed: %v", err) // Changed %w to %v + } + + if o.Entities != nil { + if err := o.Entities.Valid(); err != nil { + return fmt.Errorf("entities validation failed: %v", err) // Changed %w to %v + } + } + + if o.LinkedTags != nil { + if err := o.LinkedTags.Valid(); err != nil { + return fmt.Errorf("linked-tags validation failed: %v", err) // Changed %w to %v + } + } + + if err := o.Triples.Valid(); err != nil { + return fmt.Errorf("triples validation failed: %v", err) // Changed %w to %v + } + + return o.Extensions.validComid(&o) } // ToCBOR serializes the target Comid to CBOR @@ -321,3 +321,103 @@ func (o Comid) ToJSONPretty(indent string) ([]byte, error) { return json.MarshalIndent(&o, "", indent) } + +// AddSimpleReferenceValue adds a reference value with a single measurement +func (o *Comid) AddSimpleReferenceValue(env Environment, measurement *Measurement) error { + if err := env.Valid(); err != nil { + return fmt.Errorf("invalid environment: %w", err) + } + + if measurement == nil { + return fmt.Errorf("measurement cannot be nil") + } + + if o.Triples.ReferenceValues == nil { + o.Triples.ReferenceValues = NewValueTriples() + } + + builder := NewReferenceValueBuilder(). + WithEnvironment(env). + WithMeasurement(measurement) + + triple, err := builder.Build() + if err != nil { + return fmt.Errorf("building reference value: %w", err) + } + + if res := o.AddReferenceValue(*triple); res == nil { + return fmt.Errorf("failed to add reference value") + } + + return nil +} + +func (o *Comid) AddDigestReferenceValue(env Environment, alg string, digest []byte) error { + if len(digest) == 0 { + return fmt.Errorf("digest cannot be empty") + } + hashAlg := HashAlgFromString(alg) + if !hashAlg.Valid() { + return fmt.Errorf("unrecognized algorithm %q", alg) + } + m := &Measurement{ + Val: Mval{ + Digests: NewDigests(), + }, + } + if m.Val.Digests.AddDigest(hashAlg.ToUint64(), digest) == nil { + return fmt.Errorf("failed to create hash entry") + } + return o.AddSimpleReferenceValue(env, m) +} + +// AddRawReferenceValue adds a reference value with raw measurement data +func (o *Comid) AddRawReferenceValue(env Environment, raw []byte) error { + if len(raw) == 0 { + return fmt.Errorf("raw value cannot be empty") + } + + m := &Measurement{ + Val: Mval{ + RawValue: NewRawValue().SetBytes(raw), + }, + } + + return o.AddSimpleReferenceValue(env, m) +} + +// AddReferenceValueDirect adds a reference value directly to the reference-triples list without creating instances for Measurement and ValueTriples. +func (o *Comid) AddReferenceValueDirect(environment Environment, measurements Measurements) *Comid { + if o != nil { + val := ValueTriple{ + Environment: environment, + Measurements: measurements, + } + if o.Triples.ReferenceValues == nil { + o.Triples.ReferenceValues = NewValueTriples() + } + + if o.Triples.AddReferenceValue(val) == nil { + return nil + } + } + return o +} + +// AddEndorsedValueDirect adds an endorsed value directly to the endorsed-triples list without creating instances for Measurement and ValueTriples. +func (o *Comid) AddEndorsedValueDirect(environment Environment, measurements Measurements) *Comid { + if o != nil { + val := ValueTriple{ + Environment: environment, + Measurements: measurements, + } + if o.Triples.EndorsedValues == nil { + o.Triples.EndorsedValues = NewValueTriples() + } + + if o.Triples.AddEndorsedValue(val) == nil { + return nil + } + } + return o +} \ No newline at end of file diff --git a/comid/comid_test.go b/comid/comid_test.go index 7aed60dd..add76c58 100644 --- a/comid/comid_test.go +++ b/comid/comid_test.go @@ -94,3 +94,33 @@ func Test_String2URI_nok(t *testing.T) { _, err := String2URI(&s) assert.EqualError(t, err, `expecting an absolute URI: "@@@" is not an absolute URI`) } + + +func Test_Comid_SimpleReferenceValue(t *testing.T) { + c := NewComid() + env := Environment{ + Instance: MustNewUUIDInstance(TestUUID), + } + + // Test digest reference value + err := c.AddDigestReferenceValue(env, "sha-256", []byte{ + 0x00, 0x01, 0x02, 0x03, + 0x04, 0x05, 0x06, 0x07, + 0x08, 0x09, 0x0a, 0x0b, + 0x0c, 0x0d, 0x0e, 0x0f, + 0x10, 0x11, 0x12, 0x13, + 0x14, 0x15, 0x16, 0x17, + 0x18, 0x19, 0x1a, 0x1b, + 0x1c, 0x1d, 0x1e, 0x1f, + }) + require.NoError(t, err) + + // Verify values were added + require.NotNil(t, c.Triples.ReferenceValues) + require.Len(t, c.Triples.ReferenceValues.Values, 1) + + // Verify digest value + rv := c.Triples.ReferenceValues.Values[0] + require.NotNil(t, rv.Measurements.Values[0].Val.Digests) + require.Equal(t, HashAlgSHA256.ToUint64(), (*rv.Measurements.Values[0].Val.Digests)[0].HashAlgID) +} \ No newline at end of file diff --git a/comid/digests.go b/comid/digests.go index f6f4f8b7..3b8054bf 100644 --- a/comid/digests.go +++ b/comid/digests.go @@ -1,53 +1,52 @@ -// Copyright 2021 Contributors to the Veraison project. -// SPDX-License-Identifier: Apache-2.0 - package comid import ( - "fmt" - - "github.com/veraison/swid" + "fmt" + "github.com/veraison/swid" ) -// Digests is an alias for an array of SWID HashEntry +// Digests is an array of SWID HashEntry type Digests []swid.HashEntry // NewDigests instantiates an empty array of Digests func NewDigests() *Digests { - return new(Digests) + return new(Digests) } -// AddDigest create a new digest from the supplied arguments and appends it to -// the (already instantiated) Digests target. The method is a no-op if it is -// invoked on a nil target and will refuse to add inconsistent algo/value -// combinations. +// AddDigest create a new digest from the supplied arguments and appends it to the (already instantiated) Digests target. +// The method is a no-op if it is invoked on a nil target and will refuse to add inconsistent algo/value combinations. func (o *Digests) AddDigest(algID uint64, value []byte) *Digests { - if o != nil { - he := NewHashEntry(algID, value) - if he == nil { - return nil - } - *o = append(*o, *he) - } - return o + if o != nil { + he := NewHashEntry(algID, value) + if he == nil { + return nil + } + *o = append(*o, *he) + } + return o } func (o Digests) Valid() error { - for i, m := range o { - if err := swid.ValidHashEntry(m.HashAlgID, m.HashValue); err != nil { - return fmt.Errorf("digest at index %d: %w", i, err) - } - } - return nil + if len(o) == 0 { + return fmt.Errorf("digests must not be empty") + } + + for i, m := range o { + if err := swid.ValidHashEntry(m.HashAlgID, m.HashValue); err != nil { + return fmt.Errorf("digest at index %d: %w", i, err) + } + } + return nil } + func NewHashEntry(algID uint64, value []byte) *swid.HashEntry { - var he swid.HashEntry + var he swid.HashEntry - err := he.Set(algID, value) - if err != nil { - return nil - } + err := he.Set(algID, value) + if err != nil { + return nil + } - return &he + return &he } diff --git a/comid/hashalg.go b/comid/hashalg.go new file mode 100644 index 00000000..ffe8ab2c --- /dev/null +++ b/comid/hashalg.go @@ -0,0 +1,77 @@ +package comid + +import ( + "fmt" + "strings" + "encoding/json" +) + +type HashAlg uint64 + +const ( + HashAlgSHA256 HashAlg = 1 + HashAlgSHA384 HashAlg = 2 + HashAlgSHA512 HashAlg = 3 +) + +func (h HashAlg) Valid() bool { + return h >= HashAlgSHA256 && h <= HashAlgSHA512 +} + +func HashAlgFromString(s string) HashAlg { + switch strings.ToLower(s) { + case "sha-256": + return HashAlgSHA256 + case "sha-384": + return HashAlgSHA384 + case "sha-512": + return HashAlgSHA512 + default: + return 0 + } +} + +func (h HashAlg) String() string { + switch h { + case HashAlgSHA256: + return "sha-256" + case HashAlgSHA384: + return "sha-384" + case HashAlgSHA512: + return "sha-512" + default: + return fmt.Sprintf("unknown(%d)", h) + } +} + +func (h HashAlg) MarshalJSON() ([]byte, error) { + return json.Marshal(h.String()) +} +func (h *HashAlg) UnmarshalJSON(data []byte) error { + var s string + if err := json.Unmarshal(data, &s); err != nil { + return err + } + *h = HashAlgFromString(s) + if !h.Valid() { + return fmt.Errorf("invalid hash algorithm: %s", s) + } + return nil +} + +// ToUint64 returns 0 if invalid, otherwise the numeric value. +func (h HashAlg) ToUint64() uint64 { + if !h.Valid() { + return 0 + } + return uint64(h) +} + +// HashAlgFromUint64 returns 0 if v is invalid, otherwise the matching HashAlg. +func HashAlgFromUint64(v uint64) HashAlg { + h := HashAlg(v) + if !h.Valid() { + return 0 + } + return h +} \ No newline at end of file diff --git a/comid/hashalg_test.go b/comid/hashalg_test.go new file mode 100644 index 00000000..3b95a8c3 --- /dev/null +++ b/comid/hashalg_test.go @@ -0,0 +1,108 @@ +package comid + +import ( + "encoding/json" + "testing" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_HashAlgFromString(t *testing.T) { + tests := []struct { + name string + input string + want HashAlg + }{ + {"sha-256", "sha-256", HashAlgSHA256}, + {"SHA-256", "SHA-256", HashAlgSHA256}, + {"sha-384", "sha-384", HashAlgSHA384}, + {"sha-512", "sha-512", HashAlgSHA512}, + {"invalid", "invalid", 0}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := HashAlgFromString(tt.input) + assert.Equal(t, tt.want, got) + }) + } +} + +func Test_HashAlg_String(t *testing.T) { + tests := []struct { + name string + hash HashAlg + want string + }{ + {"sha-256", HashAlgSHA256, "sha-256"}, + {"sha-384", HashAlgSHA384, "sha-384"}, + {"sha-512", HashAlgSHA512, "sha-512"}, + {"invalid", 99, "unknown(99)"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.hash.String() + assert.Equal(t, tt.want, got) + }) + } +} + +func Test_HashAlg_JSON(t *testing.T) { + tests := []struct { + name string + hash HashAlg + want string + }{ + {"sha-256", HashAlgSHA256, `"sha-256"`}, + {"sha-384", HashAlgSHA384, `"sha-384"`}, + {"sha-512", HashAlgSHA512, `"sha-512"`}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + data, err := json.Marshal(tt.hash) + require.NoError(t, err) + assert.Equal(t, tt.want, string(data)) + + var got HashAlg + err = json.Unmarshal(data, &got) + require.NoError(t, err) + assert.Equal(t, tt.hash, got) + }) + } +} + +// ...existing code... + +func Test_HashAlg_Uint64(t *testing.T) { + tests := []struct { + name string + hash HashAlg + want uint64 + }{ + {"sha-256", HashAlgSHA256, 1}, + {"sha-384", HashAlgSHA384, 2}, + {"sha-512", HashAlgSHA512, 3}, + {"invalid", HashAlg(99), 0}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // 1) Check forward conversion (HashAlg → uint64). + got := tt.hash.ToUint64() + assert.Equal(t, tt.want, got) + + // 2) Check backward conversion (uint64 → HashAlg). + // For valid alg we expect round-trip equality, for invalid we do not. + recon := HashAlgFromUint64(tt.want) + if tt.name == "invalid" { + // The want is 0, so we expect recon == 0, not 99. + assert.Equal(t, HashAlg(0), recon) + } else { + // Valid case - recon should match original hash. + assert.Equal(t, tt.hash, recon) + } + }) + } +} \ No newline at end of file diff --git a/comid/valuetriple.go b/comid/valuetriple.go index a7662ea0..45891319 100644 --- a/comid/valuetriple.go +++ b/comid/valuetriple.go @@ -89,3 +89,34 @@ func (o ValueTriples) MarshalJSON() ([]byte, error) { func (o *ValueTriples) UnmarshalJSON(data []byte) error { return (*extensions.Collection[ValueTriple, *ValueTriple])(o).UnmarshalJSON(data) } + +// ReferenceValueBuilder provides a fluent interface for building reference values +type ReferenceValueBuilder struct { + triple ValueTriple +} + +// NewReferenceValueBuilder creates a new builder for reference values +func NewReferenceValueBuilder() *ReferenceValueBuilder { + return &ReferenceValueBuilder{ + triple: ValueTriple{ + Measurements: *NewMeasurements(), + }, + } +} + +func (b *ReferenceValueBuilder) WithEnvironment(env Environment) *ReferenceValueBuilder { + b.triple.Environment = env + return b +} + +func (b *ReferenceValueBuilder) WithMeasurement(m *Measurement) *ReferenceValueBuilder { + b.triple.Measurements.Add(m) + return b +} + +func (b *ReferenceValueBuilder) Build() (*ValueTriple, error) { + if err := b.triple.Valid(); err != nil { + return nil, fmt.Errorf("invalid reference value: %w", err) + } + return &b.triple, nil +}