diff --git a/.github/workflows/ci-go-cover.yml b/.github/workflows/ci-go-cover.yml index 94b48854..9dce92ca 100644 --- a/.github/workflows/ci-go-cover.yml +++ b/.github/workflows/ci-go-cover.yml @@ -26,7 +26,7 @@ jobs: steps: - uses: actions/setup-go@v2 with: - go-version: "1.18" + go-version: "1.19" - name: Checkout code uses: actions/checkout@v2 - name: Install mockgen diff --git a/.github/workflows/linters.yml b/.github/workflows/linters.yml index 665e7048..92368c61 100644 --- a/.github/workflows/linters.yml +++ b/.github/workflows/linters.yml @@ -10,7 +10,7 @@ jobs: steps: - uses: actions/setup-go@v2 with: - go-version: "1.18" + go-version: "1.19" - name: Checkout code uses: actions/checkout@v2 - name: Install golangci-lint diff --git a/Makefile b/Makefile index de492694..f5e8bbdb 100644 --- a/Makefile +++ b/Makefile @@ -8,6 +8,8 @@ GOPKG := github.com/veraison/corim/corim GOPKG += github.com/veraison/corim/comid GOPKG += github.com/veraison/corim/cots GOPKG += github.com/veraison/corim/cocli/cmd +GOPKG += github.com/veraison/corim/encoding +GOPKG += github.com/veraison/corim/extensions MOCKGEN := $(shell go env GOPATH)/bin/mockgen INTERFACES := cocli/cmd/isubmitter.go diff --git a/cocli/cmd/corimCreate_test.go b/cocli/cmd/corimCreate_test.go index 33f112b8..b2dfb770 100644 --- a/cocli/cmd/corimCreate_test.go +++ b/cocli/cmd/corimCreate_test.go @@ -117,7 +117,7 @@ func Test_CorimCreateCmd_with_a_bad_comid(t *testing.T) { cmd.SetArgs(args) err = cmd.Execute() - assert.EqualError(t, err, `error loading CoMID from bad-comid.cbor: cbor: unexpected "break" code`) + assert.EqualError(t, err, `error loading CoMID from bad-comid.cbor: expected map (CBOR Major Type 5), found Major Type 7`) } func Test_CorimCreateCmd_with_an_invalid_comid(t *testing.T) { @@ -138,7 +138,7 @@ func Test_CorimCreateCmd_with_an_invalid_comid(t *testing.T) { cmd.SetArgs(args) err = cmd.Execute() - assert.EqualError(t, err, `error adding CoMID from invalid-comid.cbor (check its validity using the "comid validate" sub-command)`) + assert.EqualError(t, err, `error loading CoMID from invalid-comid.cbor: missing mandatory field "Triples" (4)`) } func Test_CorimCreateCmd_with_a_bad_coswid(t *testing.T) { diff --git a/cocli/cmd/corimDisplay_test.go b/cocli/cmd/corimDisplay_test.go index 7f87cf25..5e1ccb47 100644 --- a/cocli/cmd/corimDisplay_test.go +++ b/cocli/cmd/corimDisplay_test.go @@ -76,7 +76,7 @@ func Test_CorimDisplayCmd_invalid_signed_corim(t *testing.T) { require.NoError(t, err) err = cmd.Execute() - assert.EqualError(t, err, "error decoding signed CoRIM from invalid.cbor: failed validation of unsigned CoRIM: empty id") + assert.EqualError(t, err, `error decoding signed CoRIM from invalid.cbor: failed CBOR decoding of unsigned CoRIM: unexpected EOF`) } func Test_CorimDisplayCmd_ok_top_level_view(t *testing.T) { diff --git a/cocli/cmd/corimExtract_test.go b/cocli/cmd/corimExtract_test.go index 2ec748d1..8b476c98 100644 --- a/cocli/cmd/corimExtract_test.go +++ b/cocli/cmd/corimExtract_test.go @@ -76,7 +76,7 @@ func Test_CorimExtractCmd_invalid_signed_corim(t *testing.T) { require.NoError(t, err) err = cmd.Execute() - assert.EqualError(t, err, "error decoding signed CoRIM from invalid.cbor: failed validation of unsigned CoRIM: empty id") + assert.EqualError(t, err, `error decoding signed CoRIM from invalid.cbor: failed CBOR decoding of unsigned CoRIM: unexpected EOF`) } func Test_CorimExtractCmd_ok_save_to_default_dir(t *testing.T) { diff --git a/cocli/cmd/corimSign_test.go b/cocli/cmd/corimSign_test.go index 27956a4f..bc460376 100644 --- a/cocli/cmd/corimSign_test.go +++ b/cocli/cmd/corimSign_test.go @@ -91,7 +91,7 @@ func Test_CorimSignCmd_bad_unsigned_corim(t *testing.T) { require.NoError(t, err) err = cmd.Execute() - assert.EqualError(t, err, "error decoding unsigned CoRIM from bad.txt: unexpected EOF") + assert.EqualError(t, err, "error decoding unsigned CoRIM from bad.txt: expected map (CBOR Major Type 5), found Major Type 3") } func Test_CorimSignCmd_invalid_unsigned_corim(t *testing.T) { @@ -109,7 +109,7 @@ func Test_CorimSignCmd_invalid_unsigned_corim(t *testing.T) { require.NoError(t, err) err = cmd.Execute() - assert.EqualError(t, err, "error validating CoRIM: tags validation failed: no tags") + assert.EqualError(t, err, `error decoding unsigned CoRIM from invalid.cbor: missing mandatory field "Tags" (1)`) } func Test_CorimSignCmd_non_existent_meta_file(t *testing.T) { diff --git a/comid/comid.go b/comid/comid.go index ea9779c7..194745d3 100644 --- a/comid/comid.go +++ b/comid/comid.go @@ -8,6 +8,8 @@ import ( "fmt" "net/url" + "github.com/veraison/corim/encoding" + "github.com/veraison/corim/extensions" "github.com/veraison/swid" ) @@ -19,6 +21,8 @@ type Comid struct { Entities *Entities `cbor:"2,keyasint,omitempty" json:"entities,omitempty"` LinkedTags *LinkedTags `cbor:"3,keyasint,omitempty" json:"linked-tags,omitempty"` Triples Triples `cbor:"4,keyasint" json:"triples"` + + Extensions } // NewComid instantiates an empty Comid @@ -26,6 +30,16 @@ func NewComid() *Comid { return &Comid{} } +// RegisterExtensions registers a struct as a collections of extensions +func (o *Comid) RegisterExtensions(exts extensions.IExtensionsValue) { + o.Extensions.Register(exts) +} + +// GetExtensions returns pervisouosly registered extension +func (o *Comid) GetExtensions() extensions.IExtensionsValue { + return o.Extensions.IExtensionsValue +} + // SetLanguage sets the language used in the target Comid to the supplied // language tag. See also: BCP 47 and the IANA Language subtag registry. func (o *Comid) SetLanguage(language string) *Comid { @@ -229,7 +243,7 @@ func (o Comid) Valid() error { return fmt.Errorf("triples validation failed: %w", err) } - return nil + return o.Extensions.ValidComid(&o) } // ToCBOR serializes the target Comid to CBOR @@ -238,17 +252,17 @@ func (o Comid) ToCBOR() ([]byte, error) { return nil, err } - return em.Marshal(&o) + return encoding.SerializeStructToCBOR(em, &o) } // FromCBOR deserializes a CBOR-encoded CoMID into the target Comid func (o *Comid) FromCBOR(data []byte) error { - return dm.Unmarshal(data, o) + return encoding.PopulateStructFromCBOR(dm, data, o) } // FromJSON deserializes a JSON-encoded CoMID into the target Comid func (o *Comid) FromJSON(data []byte) error { - return json.Unmarshal(data, o) + return encoding.PopulateStructFromJSON(data, o) } // ToJSON serializes the target Comid to JSON @@ -257,7 +271,7 @@ func (o Comid) ToJSON() ([]byte, error) { return nil, err } - return json.Marshal(&o) + return encoding.SerializeStructToJSON(&o) } func (o Comid) ToJSONPretty(indent string) ([]byte, error) { diff --git a/comid/comid_test.go b/comid/comid_test.go new file mode 100644 index 00000000..a8cdf198 --- /dev/null +++ b/comid/comid_test.go @@ -0,0 +1,51 @@ +package comid + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/veraison/swid" +) + +func Test_Comid_Extensions(t *testing.T) { + c := NewComid() + assert.Nil(t, c.GetExtensions()) + assert.Equal(t, "", c.MustGetString("field-one")) + + err := c.Set("field-one", "foo") + assert.EqualError(t, err, "extension not found: field-one") + + type ComidExt struct { + FieldOne string `cbor:"-1,keyasint" json:"field-one"` + } + + c.RegisterExtensions(&ComidExt{}) + + err = c.Set("field-one", "foo") + assert.NoError(t, err) + assert.Equal(t, "foo", c.MustGetString("-1")) +} + +func Test_Comid_ToJSONPretty(t *testing.T) { + c := NewComid() + + _, err := c.ToJSONPretty(" ") + assert.EqualError(t, err, "tag-identity validation failed: empty tag-id") + + c.TagIdentity = TagIdentity{TagID: *swid.NewTagID("test"), TagVersion: 1} + c.Triples = Triples{ReferenceValues: &[]ReferenceValue{}} + + expected := `{ + "tag-identity": { + "id": "test", + "version": 1 + }, + "triples": { + "reference-values": [] + } +}` + v, err := c.ToJSONPretty(" ") + require.NoError(t, err) + assert.Equal(t, expected, string(v)) +} diff --git a/comid/cryptokey_test.go b/comid/cryptokey_test.go index e1a3b4f8..1c0812fc 100644 --- a/comid/cryptokey_test.go +++ b/comid/cryptokey_test.go @@ -119,7 +119,7 @@ func Test_CryptoKey_NewCOSEKey(t *testing.T) { assert.EqualError(t, err, "empty COSE_Key bytes") _, err = NewCOSEKey([]byte("DEADBEEF")) - assert.Contains(t, err.Error(), "cbor: cannot unmarshal") + assert.Contains(t, err.Error(), "cbor: 3 bytes of extraneous data starting at index 5") badKey := []byte{ // taken from go-cose unit tests 0xa2, // map(2) diff --git a/comid/entity.go b/comid/entity.go index 694fef6c..278b7e52 100644 --- a/comid/entity.go +++ b/comid/entity.go @@ -3,7 +3,12 @@ package comid -import "fmt" +import ( + "fmt" + + "github.com/veraison/corim/encoding" + "github.com/veraison/corim/extensions" +) type TaggedURI string @@ -16,8 +21,21 @@ type Entity struct { EntityName string `cbor:"0,keyasint" json:"name"` RegID *TaggedURI `cbor:"1,keyasint,omitempty" json:"regid,omitempty"` Roles Roles `cbor:"2,keyasint" json:"roles"` + + Extensions +} + +// RegisterExtensions registers a struct as a collections of extensions +func (o *Entity) RegisterExtensions(exts extensions.IExtensionsValue) { + o.Extensions.Register(exts) } +// GetExtensions returns pervisouosly registered extension +func (o *Entity) GetExtensions() extensions.IExtensionsValue { + return o.Extensions.IExtensionsValue +} + +// SetEntityName is used to set the EntityName field of Entity using supplied name func (o *Entity) SetEntityName(name string) *Entity { if o != nil { if name == "" { @@ -28,6 +46,7 @@ func (o *Entity) SetEntityName(name string) *Entity { return o } +// SetRegID is used to set the RegID field of Entity using supplied uri func (o *Entity) SetRegID(uri string) *Entity { if o != nil { if uri == "" { @@ -39,6 +58,7 @@ func (o *Entity) SetRegID(uri string) *Entity { return o } +// SetRoles appends the supplied roles to the target entity. func (o *Entity) SetRoles(roles ...Role) *Entity { if o != nil { o.Roles.Add(roles...) @@ -46,6 +66,7 @@ func (o *Entity) SetRoles(roles ...Role) *Entity { return o } +// Valid checks for validity of the fields within each Entity func (o Entity) Valid() error { if o.EntityName == "" { return fmt.Errorf("invalid entity: empty entity-name") @@ -59,7 +80,27 @@ func (o Entity) Valid() error { return fmt.Errorf("invalid entity: %w", err) } - return nil + return o.Extensions.ValidEntity(&o) +} + +// UnmarshalCBOR deserializes from CBOR +func (o *Entity) UnmarshalCBOR(data []byte) error { + return encoding.PopulateStructFromCBOR(dm, data, o) +} + +// MarshalCBOR serializes to CBOR +func (o *Entity) MarshalCBOR() ([]byte, error) { + return encoding.SerializeStructToCBOR(em, o) +} + +// UnmarshalJSON deserializes from JSON +func (o *Entity) UnmarshalJSON(data []byte) error { + return encoding.PopulateStructFromJSON(data, o) +} + +// MarshalJSON serializes to JSON +func (o *Entity) MarshalJSON() ([]byte, error) { + return encoding.SerializeStructToJSON(o) } // Entities is an array of entity-map's @@ -78,6 +119,7 @@ func (o *Entities) AddEntity(e Entity) *Entities { return o } +// Valid iterates over the range of individual entities to check for validity func (o Entities) Valid() error { for i, m := range o { if err := m.Valid(); err != nil { diff --git a/comid/example_test.go b/comid/example_test.go index 7f159d44..756829d3 100644 --- a/comid/example_test.go +++ b/comid/example_test.go @@ -37,7 +37,8 @@ func Example_encode() { SetSVN(2). AddDigest(swid.Sha256_32, []byte{0xab, 0xcd, 0xef, 0x00}). AddDigest(swid.Sha256_32, []byte{0xff, 0xff, 0xff, 0xff}). - SetOpFlags(OpFlagNotSecure, OpFlagDebug). + SetFlagsTrue(FlagIsDebug). + SetFlagsFalse(FlagIsSecure). SetSerialNumber("C02X70VHJHD5"). SetUEID(TestUEID). SetUUID(TestUUID). @@ -65,7 +66,8 @@ func Example_encode() { SetMinSVN(2). AddDigest(swid.Sha256_32, []byte{0xab, 0xcd, 0xef, 0x00}). AddDigest(swid.Sha256_32, []byte{0xff, 0xff, 0xff, 0xff}). - SetOpFlags(OpFlagNotSecure, OpFlagDebug, OpFlagNotConfigured). + SetFlagsTrue(FlagIsDebug). + SetFlagsFalse(FlagIsSecure, FlagIsConfigured). SetSerialNumber("C02X70VHJHD5"). SetUEID(TestUEID). SetUUID(TestUUID). @@ -107,8 +109,8 @@ func Example_encode() { } // Output: - //a50065656e2d474201a10078206d792d6e733a61636d652d726f616472756e6e65722d737570706c656d656e740282a3006941434d45204c74642e01d8207468747470733a2f2f61636d652e6578616d706c6502820100a20069454d4341204c74642e0281020382a200781a6d792d6e733a61636d652d726f616472756e6e65722d626173650100a20078196d792d6e733a61636d652d726f616472756e6e65722d6f6c64010104a4008182a300a500d86f445502c000016941434d45204c74642e026a526f616452756e6e65720300040101d902264702deadbeefdead02d8255031fb5abf023e4992aa4e95f9c1503bfa81a200d8255031fb5abf023e4992aa4e95f9c1503bfa01aa01d90228020282820644abcdef00820644ffffffff030a04d9023044010203040544ffffffff064802005e1000000001075020010db8000000000000000000000068086c43303258373056484a484435094702deadbeefdead0a5031fb5abf023e4992aa4e95f9c1503bfa018182a300a500d8255031fb5abf023e4992aa4e95f9c1503bfa016941434d45204c74642e026a526f616452756e6e65720300040101d902264702deadbeefdead02d8255031fb5abf023e4992aa4e95f9c1503bfa81a200d8255031fb5abf023e4992aa4e95f9c1503bfa01aa01d90229020282820644abcdef00820644ffffffff030b04d9023044010203040544ffffffff064802005e1000000001075020010db8000000000000000000000068086c43303258373056484a484435094702deadbeefdead0a5031fb5abf023e4992aa4e95f9c1503bfa028182a101d8255031fb5abf023e4992aa4e95f9c1503bfa81d9022a78b12d2d2d2d2d424547494e205055424c4943204b45592d2d2d2d2d0a4d466b77457759484b6f5a497a6a3043415159494b6f5a497a6a304441516344516741455731427671462b2f727938425761375a454d553178595948455138420a6c4c54344d46484f614f2b4943547449767245654570722f7366544150363648326843486462354845584b74524b6f6436514c634f4c504131513d3d0a2d2d2d2d2d454e44205055424c4943204b45592d2d2d2d2d038182a101d902264702deadbeefdead81d9022a78b12d2d2d2d2d424547494e205055424c4943204b45592d2d2d2d2d0a4d466b77457759484b6f5a497a6a3043415159494b6f5a497a6a304441516344516741455731427671462b2f727938425761375a454d553178595948455138420a6c4c54344d46484f614f2b4943547449767245654570722f7366544150363648326843486462354845584b74524b6f6436514c634f4c504131513d3d0a2d2d2d2d2d454e44205055424c4943204b45592d2d2d2d2d - //{"lang":"en-GB","tag-identity":{"id":"my-ns:acme-roadrunner-supplement"},"entities":[{"name":"ACME Ltd.","regid":"https://acme.example","roles":["creator","tagCreator"]},{"name":"EMCA Ltd.","roles":["maintainer"]}],"linked-tags":[{"target":"my-ns:acme-roadrunner-base","rel":"supplements"},{"target":"my-ns:acme-roadrunner-old","rel":"replaces"}],"triples":{"reference-values":[{"environment":{"class":{"id":{"type":"oid","value":"2.5.2.8192"},"vendor":"ACME Ltd.","model":"RoadRunner","layer":0,"index":1},"instance":{"type":"ueid","value":"At6tvu/erQ=="},"group":{"type":"ueid","value":"31fb5abf-023e-4992-aa4e-95f9c1503bfa"}},"measurements":[{"key":{"type":"uuid","value":"31fb5abf-023e-4992-aa4e-95f9c1503bfa"},"value":{"svn":{"type":"exact-value","value":2},"digests":["sha-256-32;q83vAA==","sha-256-32;/////w=="],"op-flags":["notSecure","debug"],"raw-value":{"type":"bytes","value":"AQIDBA=="},"raw-value-mask":"/////w==","mac-addr":"02:00:5e:10:00:00:00:01","ip-addr":"2001:db8::68","serial-number":"C02X70VHJHD5","ueid":"At6tvu/erQ==","uuid":"31fb5abf-023e-4992-aa4e-95f9c1503bfa"}}]}],"endorsed-values":[{"environment":{"class":{"id":{"type":"uuid","value":"31fb5abf-023e-4992-aa4e-95f9c1503bfa"},"vendor":"ACME Ltd.","model":"RoadRunner","layer":0,"index":1},"instance":{"type":"ueid","value":"At6tvu/erQ=="},"group":{"type":"ueid","value":"31fb5abf-023e-4992-aa4e-95f9c1503bfa"}},"measurements":[{"key":{"type":"uuid","value":"31fb5abf-023e-4992-aa4e-95f9c1503bfa"},"value":{"svn":{"type":"min-value","value":2},"digests":["sha-256-32;q83vAA==","sha-256-32;/////w=="],"op-flags":["notConfigured","notSecure","debug"],"raw-value":{"type":"bytes","value":"AQIDBA=="},"raw-value-mask":"/////w==","mac-addr":"02:00:5e:10:00:00:00:01","ip-addr":"2001:db8::68","serial-number":"C02X70VHJHD5","ueid":"At6tvu/erQ==","uuid":"31fb5abf-023e-4992-aa4e-95f9c1503bfa"}}]}],"attester-verification-keys":[{"environment":{"instance":{"type":"uuid","value":"31fb5abf-023e-4992-aa4e-95f9c1503bfa"}},"verification-keys":[{"type":"pkix-base64-key","value":"-----BEGIN PUBLIC KEY-----\nMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEW1BvqF+/ry8BWa7ZEMU1xYYHEQ8B\nlLT4MFHOaO+ICTtIvrEeEpr/sfTAP66H2hCHdb5HEXKtRKod6QLcOLPA1Q==\n-----END PUBLIC KEY-----"}]}],"dev-identity-keys":[{"environment":{"instance":{"type":"ueid","value":"At6tvu/erQ=="}},"verification-keys":[{"type":"pkix-base64-key","value":"-----BEGIN PUBLIC KEY-----\nMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEW1BvqF+/ry8BWa7ZEMU1xYYHEQ8B\nlLT4MFHOaO+ICTtIvrEeEpr/sfTAP66H2hCHdb5HEXKtRKod6QLcOLPA1Q==\n-----END PUBLIC KEY-----"}]}]}} + // a50065656e2d474201a10078206d792d6e733a61636d652d726f616472756e6e65722d737570706c656d656e740282a3006941434d45204c74642e01d8207468747470733a2f2f61636d652e6578616d706c6502820100a20069454d4341204c74642e0281020382a200781a6d792d6e733a61636d652d726f616472756e6e65722d626173650100a20078196d792d6e733a61636d652d726f616472756e6e65722d6f6c64010104a4008182a300a500d86f445502c000016941434d45204c74642e026a526f616452756e6e65720300040101d902264702deadbeefdead02d8255031fb5abf023e4992aa4e95f9c1503bfa81a200d8255031fb5abf023e4992aa4e95f9c1503bfa01aa01d90228020282820644abcdef00820644ffffffff03a201f403f504d9023044010203040544ffffffff064802005e1000000001075020010db8000000000000000000000068086c43303258373056484a484435094702deadbeefdead0a5031fb5abf023e4992aa4e95f9c1503bfa018182a300a500d8255031fb5abf023e4992aa4e95f9c1503bfa016941434d45204c74642e026a526f616452756e6e65720300040101d902264702deadbeefdead02d8255031fb5abf023e4992aa4e95f9c1503bfa81a200d8255031fb5abf023e4992aa4e95f9c1503bfa01aa01d90229020282820644abcdef00820644ffffffff03a300f401f403f504d9023044010203040544ffffffff064802005e1000000001075020010db8000000000000000000000068086c43303258373056484a484435094702deadbeefdead0a5031fb5abf023e4992aa4e95f9c1503bfa028182a101d8255031fb5abf023e4992aa4e95f9c1503bfa81d9022a78b12d2d2d2d2d424547494e205055424c4943204b45592d2d2d2d2d0a4d466b77457759484b6f5a497a6a3043415159494b6f5a497a6a304441516344516741455731427671462b2f727938425761375a454d553178595948455138420a6c4c54344d46484f614f2b4943547449767245654570722f7366544150363648326843486462354845584b74524b6f6436514c634f4c504131513d3d0a2d2d2d2d2d454e44205055424c4943204b45592d2d2d2d2d038182a101d902264702deadbeefdead81d9022a78b12d2d2d2d2d424547494e205055424c4943204b45592d2d2d2d2d0a4d466b77457759484b6f5a497a6a3043415159494b6f5a497a6a304441516344516741455731427671462b2f727938425761375a454d553178595948455138420a6c4c54344d46484f614f2b4943547449767245654570722f7366544150363648326843486462354845584b74524b6f6436514c634f4c504131513d3d0a2d2d2d2d2d454e44205055424c4943204b45592d2d2d2d2d + // {"lang":"en-GB","tag-identity":{"id":"my-ns:acme-roadrunner-supplement"},"entities":[{"name":"ACME Ltd.","regid":"https://acme.example","roles":["creator","tagCreator"]},{"name":"EMCA Ltd.","roles":["maintainer"]}],"linked-tags":[{"target":"my-ns:acme-roadrunner-base","rel":"supplements"},{"target":"my-ns:acme-roadrunner-old","rel":"replaces"}],"triples":{"reference-values":[{"environment":{"class":{"id":{"type":"oid","value":"2.5.2.8192"},"vendor":"ACME Ltd.","model":"RoadRunner","layer":0,"index":1},"instance":{"type":"ueid","value":"At6tvu/erQ=="},"group":{"type":"ueid","value":"31fb5abf-023e-4992-aa4e-95f9c1503bfa"}},"measurements":[{"key":{"type":"uuid","value":"31fb5abf-023e-4992-aa4e-95f9c1503bfa"},"value":{"svn":{"type":"exact-value","value":2},"digests":["sha-256-32;q83vAA==","sha-256-32;/////w=="],"flags":{"is-secure":false,"is-debug":true},"raw-value":{"type":"bytes","value":"AQIDBA=="},"raw-value-mask":"/////w==","mac-addr":"02:00:5e:10:00:00:00:01","ip-addr":"2001:db8::68","serial-number":"C02X70VHJHD5","ueid":"At6tvu/erQ==","uuid":"31fb5abf-023e-4992-aa4e-95f9c1503bfa"}}]}],"endorsed-values":[{"environment":{"class":{"id":{"type":"uuid","value":"31fb5abf-023e-4992-aa4e-95f9c1503bfa"},"vendor":"ACME Ltd.","model":"RoadRunner","layer":0,"index":1},"instance":{"type":"ueid","value":"At6tvu/erQ=="},"group":{"type":"ueid","value":"31fb5abf-023e-4992-aa4e-95f9c1503bfa"}},"measurements":[{"key":{"type":"uuid","value":"31fb5abf-023e-4992-aa4e-95f9c1503bfa"},"value":{"svn":{"type":"min-value","value":2},"digests":["sha-256-32;q83vAA==","sha-256-32;/////w=="],"flags":{"is-configured":false,"is-secure":false,"is-debug":true},"raw-value":{"type":"bytes","value":"AQIDBA=="},"raw-value-mask":"/////w==","mac-addr":"02:00:5e:10:00:00:00:01","ip-addr":"2001:db8::68","serial-number":"C02X70VHJHD5","ueid":"At6tvu/erQ==","uuid":"31fb5abf-023e-4992-aa4e-95f9c1503bfa"}}]}],"attester-verification-keys":[{"environment":{"instance":{"type":"uuid","value":"31fb5abf-023e-4992-aa4e-95f9c1503bfa"}},"verification-keys":[{"type":"pkix-base64-key","value":"-----BEGIN PUBLIC KEY-----\nMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEW1BvqF+/ry8BWa7ZEMU1xYYHEQ8B\nlLT4MFHOaO+ICTtIvrEeEpr/sfTAP66H2hCHdb5HEXKtRKod6QLcOLPA1Q==\n-----END PUBLIC KEY-----"}]}],"dev-identity-keys":[{"environment":{"instance":{"type":"ueid","value":"At6tvu/erQ=="}},"verification-keys":[{"type":"pkix-base64-key","value":"-----BEGIN PUBLIC KEY-----\nMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEW1BvqF+/ry8BWa7ZEMU1xYYHEQ8B\nlLT4MFHOaO+ICTtIvrEeEpr/sfTAP66H2hCHdb5HEXKtRKod6QLcOLPA1Q==\n-----END PUBLIC KEY-----"}]}]}} } func Example_encode_PSA() { @@ -162,8 +164,8 @@ func Example_encode_PSA() { } // Output: - //a301a10078206d792d6e733a61636d652d726f616472756e6e65722d737570706c656d656e740281a3006941434d45204c74642e01d8207468747470733a2f2f61636d652e6578616d706c65028301000204a2008182a100a300d90258582061636d652d696d706c656d656e746174696f6e2d69642d303030303030303031016941434d45204c74642e026e526f616452756e6e657220322e3082a200d90259a30162424c0465352e302e35055820acbb11c7e4da217205523ce4ce1a245ae1a239ae3c6bfd9e7871f7e5d8bae86b01a10281820644abcdef00a200d90259a3016450526f540465312e332e35055820acbb11c7e4da217205523ce4ce1a245ae1a239ae3c6bfd9e7871f7e5d8bae86b01a10281820644abcdef00028182a101d902264702deadbeefdead81d9022a78b12d2d2d2d2d424547494e205055424c4943204b45592d2d2d2d2d0a4d466b77457759484b6f5a497a6a3043415159494b6f5a497a6a304441516344516741455731427671462b2f727938425761375a454d553178595948455138420a6c4c54344d46484f614f2b4943547449767245654570722f7366544150363648326843486462354845584b74524b6f6436514c634f4c504131513d3d0a2d2d2d2d2d454e44205055424c4943204b45592d2d2d2d2d - //{"tag-identity":{"id":"my-ns:acme-roadrunner-supplement"},"entities":[{"name":"ACME Ltd.","regid":"https://acme.example","roles":["creator","tagCreator","maintainer"]}],"triples":{"reference-values":[{"environment":{"class":{"id":{"type":"psa.impl-id","value":"YWNtZS1pbXBsZW1lbnRhdGlvbi1pZC0wMDAwMDAwMDE="},"vendor":"ACME Ltd.","model":"RoadRunner 2.0"}},"measurements":[{"key":{"type":"psa.refval-id","value":{"label":"BL","version":"5.0.5","signer-id":"rLsRx+TaIXIFUjzkzhokWuGiOa48a/2eeHH35di66Gs="}},"value":{"digests":["sha-256-32;q83vAA=="]}},{"key":{"type":"psa.refval-id","value":{"label":"PRoT","version":"1.3.5","signer-id":"rLsRx+TaIXIFUjzkzhokWuGiOa48a/2eeHH35di66Gs="}},"value":{"digests":["sha-256-32;q83vAA=="]}}]}],"attester-verification-keys":[{"environment":{"instance":{"type":"ueid","value":"At6tvu/erQ=="}},"verification-keys":[{"type":"pkix-base64-key","value":"-----BEGIN PUBLIC KEY-----\nMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEW1BvqF+/ry8BWa7ZEMU1xYYHEQ8B\nlLT4MFHOaO+ICTtIvrEeEpr/sfTAP66H2hCHdb5HEXKtRKod6QLcOLPA1Q==\n-----END PUBLIC KEY-----"}]}]}} + // a301a10078206d792d6e733a61636d652d726f616472756e6e65722d737570706c656d656e740281a3006941434d45204c74642e01d8207468747470733a2f2f61636d652e6578616d706c65028301000204a2008182a100a300d90258582061636d652d696d706c656d656e746174696f6e2d69642d303030303030303031016941434d45204c74642e026e526f616452756e6e657220322e3082a200d90259a30162424c0465352e302e35055820acbb11c7e4da217205523ce4ce1a245ae1a239ae3c6bfd9e7871f7e5d8bae86b01a10281820644abcdef00a200d90259a3016450526f540465312e332e35055820acbb11c7e4da217205523ce4ce1a245ae1a239ae3c6bfd9e7871f7e5d8bae86b01a10281820644abcdef00028182a101d902264702deadbeefdead81d9022a78b12d2d2d2d2d424547494e205055424c4943204b45592d2d2d2d2d0a4d466b77457759484b6f5a497a6a3043415159494b6f5a497a6a304441516344516741455731427671462b2f727938425761375a454d553178595948455138420a6c4c54344d46484f614f2b4943547449767245654570722f7366544150363648326843486462354845584b74524b6f6436514c634f4c504131513d3d0a2d2d2d2d2d454e44205055424c4943204b45592d2d2d2d2d + // {"tag-identity":{"id":"my-ns:acme-roadrunner-supplement"},"entities":[{"name":"ACME Ltd.","regid":"https://acme.example","roles":["creator","tagCreator","maintainer"]}],"triples":{"reference-values":[{"environment":{"class":{"id":{"type":"psa.impl-id","value":"YWNtZS1pbXBsZW1lbnRhdGlvbi1pZC0wMDAwMDAwMDE="},"vendor":"ACME Ltd.","model":"RoadRunner 2.0"}},"measurements":[{"key":{"type":"psa.refval-id","value":{"label":"BL","version":"5.0.5","signer-id":"rLsRx+TaIXIFUjzkzhokWuGiOa48a/2eeHH35di66Gs="}},"value":{"digests":["sha-256-32;q83vAA=="]}},{"key":{"type":"psa.refval-id","value":{"label":"PRoT","version":"1.3.5","signer-id":"rLsRx+TaIXIFUjzkzhokWuGiOa48a/2eeHH35di66Gs="}},"value":{"digests":["sha-256-32;q83vAA=="]}}]}],"attester-verification-keys":[{"environment":{"instance":{"type":"ueid","value":"At6tvu/erQ=="}},"verification-keys":[{"type":"pkix-base64-key","value":"-----BEGIN PUBLIC KEY-----\nMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEW1BvqF+/ry8BWa7ZEMU1xYYHEQ8B\nlLT4MFHOaO+ICTtIvrEeEpr/sfTAP66H2hCHdb5HEXKtRKod6QLcOLPA1Q==\n-----END PUBLIC KEY-----"}]}]}} } func Example_encode_PSA_attestation_verification() { diff --git a/comid/extensions.go b/comid/extensions.go new file mode 100644 index 00000000..b62da058 --- /dev/null +++ b/comid/extensions.go @@ -0,0 +1,173 @@ +// Copyright 2023 Contributors to the Veraison project. +// SPDX-License-Identifier: Apache-2.0 +package comid + +import ( + "github.com/veraison/corim/extensions" +) + +type IComidValidator interface { + ValidComid(*Comid) error +} + +type ITriplesValidator interface { + ValidTriples(*Triples) error +} + +type IMvalValidator interface { + ValidMval(*Mval) error +} + +type IEntityValidator interface { + ValidEntity(*Entity) error +} + +type IFlagsMapValidator interface { + ValidFlagsMap(*FlagsMap) error +} + +type IFlagSetter interface { + AnySet() bool + SetTrue(Flag) + SetFalse(Flag) + Clear(Flag) + Get(Flag) *bool +} + +type Extensions struct { + extensions.Extensions +} + +func (o *Extensions) ValidComid(comid *Comid) error { + if !o.HaveExtensions() { + return nil + } + + ev, ok := o.IExtensionsValue.(IComidValidator) + if ok { + if err := ev.ValidComid(comid); err != nil { + return err + } + } + + return nil +} + +func (o *Extensions) ValidTriples(triples *Triples) error { + if !o.HaveExtensions() { + return nil + } + + ev, ok := o.IExtensionsValue.(ITriplesValidator) + if ok { + if err := ev.ValidTriples(triples); err != nil { + return err + } + } + + return nil +} + +func (o *Extensions) ValidMval(triples *Mval) error { + if !o.HaveExtensions() { + return nil + } + + ev, ok := o.IExtensionsValue.(IMvalValidator) + if ok { + if err := ev.ValidMval(triples); err != nil { + return err + } + } + + return nil +} + +func (o *Extensions) ValidEntity(triples *Entity) error { + if !o.HaveExtensions() { + return nil + } + + ev, ok := o.IExtensionsValue.(IEntityValidator) + if ok { + if err := ev.ValidEntity(triples); err != nil { + return err + } + } + + return nil +} + +func (o *Extensions) ValidFlagsMap(triples *FlagsMap) error { + if !o.HaveExtensions() { + return nil + } + + ev, ok := o.IExtensionsValue.(IFlagsMapValidator) + if ok { + if err := ev.ValidFlagsMap(triples); err != nil { + return err + } + } + + return nil +} + +func (o *Extensions) SetTrue(flag Flag) { + if !o.HaveExtensions() { + return + } + + ev, ok := o.IExtensionsValue.(IFlagSetter) + if ok { + ev.SetTrue(flag) + } +} + +func (o *Extensions) SetFalse(flag Flag) { + if !o.HaveExtensions() { + return + } + + ev, ok := o.IExtensionsValue.(IFlagSetter) + if ok { + ev.SetFalse(flag) + } +} + +func (o *Extensions) Clear(flag Flag) { + if !o.HaveExtensions() { + return + } + + ev, ok := o.IExtensionsValue.(IFlagSetter) + if ok { + ev.Clear(flag) + } +} + +func (o *Extensions) Get(flag Flag) *bool { + if !o.HaveExtensions() { + return nil + } + + ev, ok := o.IExtensionsValue.(IFlagSetter) + if ok { + return ev.Get(flag) + } + + return nil +} + +func (o *Extensions) AnySet() bool { + if !o.HaveExtensions() { + return false + } + + ev, ok := o.IExtensionsValue.(IFlagSetter) + if ok { + return ev.AnySet() + } + + return false +} diff --git a/comid/extensions_test.go b/comid/extensions_test.go new file mode 100644 index 00000000..8b36c9a8 --- /dev/null +++ b/comid/extensions_test.go @@ -0,0 +1,96 @@ +package comid + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" +) + +var FlagTestFlag = Flag(-1) + +type TestExtension struct { + TestFlag *bool +} + +func (o *TestExtension) ValidComid(v *Comid) error { + return errors.New("invalid") +} + +func (o *TestExtension) ValidTriples(v *Triples) error { + return errors.New("invalid") +} + +func (o *TestExtension) ValidMval(v *Mval) error { + return errors.New("invalid") +} + +func (o *TestExtension) ValidFlagsMap(v *FlagsMap) error { + return errors.New("invalid") +} + +func (o *TestExtension) ValidEntity(v *Entity) error { + return errors.New("invalid") +} + +func (o *TestExtension) SetTrue(flag Flag) { + if flag == FlagTestFlag { + o.TestFlag = &True + } +} +func (o *TestExtension) SetFalse(flag Flag) { + if flag == FlagTestFlag { + o.TestFlag = &False + } +} + +func (o *TestExtension) Clear(flag Flag) { + if flag == FlagTestFlag { + o.TestFlag = nil + } +} + +func (o *TestExtension) Get(flag Flag) *bool { + if flag == FlagTestFlag { + return o.TestFlag + } + + return nil +} + +func (o *TestExtension) AnySet() bool { + return o.TestFlag != nil +} + +func Test_Extensions(t *testing.T) { + exts := Extensions{} + exts.Register(&TestExtension{}) + + err := exts.ValidComid(nil) + assert.EqualError(t, err, "invalid") + + err = exts.ValidTriples(nil) + assert.EqualError(t, err, "invalid") + + err = exts.ValidMval(nil) + assert.EqualError(t, err, "invalid") + + err = exts.ValidEntity(nil) + assert.EqualError(t, err, "invalid") + + err = exts.ValidFlagsMap(nil) + assert.EqualError(t, err, "invalid") + + assert.False(t, exts.AnySet()) + + exts.SetTrue(FlagTestFlag) + assert.True(t, exts.AnySet()) + assert.True(t, *exts.Get(FlagTestFlag)) + + exts.SetFalse(FlagTestFlag) + assert.False(t, *exts.Get(FlagTestFlag)) + + exts.Clear(FlagTestFlag) + assert.Nil(t, exts.Get(FlagTestFlag)) + assert.False(t, exts.AnySet()) +} diff --git a/comid/flagsmap.go b/comid/flagsmap.go index 279cf1a4..95737419 100644 --- a/comid/flagsmap.go +++ b/comid/flagsmap.go @@ -3,6 +3,11 @@ package comid +import ( + "github.com/veraison/corim/encoding" + "github.com/veraison/corim/extensions" +) + var True = true var False = false @@ -52,6 +57,8 @@ type FlagsMap struct { // IsTcb indicates whether the measured environment is a trusted // computing base. IsTcb *bool `cbor:"8,keyasint,omitempty" json:"is-tcb,omitempty"` + + Extensions } func NewFlagsMap() *FlagsMap { @@ -65,7 +72,7 @@ func (o *FlagsMap) AnySet() bool { return true } - return false + return o.Extensions.AnySet() } func (o *FlagsMap) SetTrue(flags ...Flag) { @@ -90,6 +97,7 @@ func (o *FlagsMap) SetTrue(flags ...Flag) { case FlagIsTcb: o.IsTcb = &True default: + o.Extensions.SetTrue(flag) } } } @@ -116,6 +124,7 @@ func (o *FlagsMap) SetFalse(flags ...Flag) { case FlagIsTcb: o.IsTcb = &False default: + o.Extensions.SetFalse(flag) } } } @@ -142,6 +151,7 @@ func (o *FlagsMap) Clear(flags ...Flag) { case FlagIsTcb: o.IsTcb = nil default: + o.Extensions.Clear(flag) } } } @@ -167,11 +177,41 @@ func (o *FlagsMap) Get(flag Flag) *bool { case FlagIsTcb: return o.IsTcb default: - return nil + return o.Extensions.Get(flag) } } +// RegisterExtensions registers a struct as a collections of extensions +func (o *FlagsMap) RegisterExtensions(exts extensions.IExtensionsValue) { + o.Extensions.Register(exts) +} + +// GetExtensions returns pervisouosly registered extension +func (o *FlagsMap) GetExtensions() extensions.IExtensionsValue { + return o.Extensions.IExtensionsValue +} + +// UnmarshalCBOR deserializes from CBOR +func (o *FlagsMap) UnmarshalCBOR(data []byte) error { + return encoding.PopulateStructFromCBOR(dm, data, o) +} + +// MarshalCBOR serializes to CBOR +func (o *FlagsMap) MarshalCBOR() ([]byte, error) { + return encoding.SerializeStructToCBOR(em, o) +} + +// UnmarshalJSON deserializes from JSON +func (o *FlagsMap) UnmarshalJSON(data []byte) error { + return encoding.PopulateStructFromJSON(data, o) +} + +// MarshalJSON serializes to JSON +func (o *FlagsMap) MarshalJSON() ([]byte, error) { + return encoding.SerializeStructToJSON(o) +} + // Valid returns an error if the FlagsMap is invalid. func (o FlagsMap) Valid() error { - return nil + return o.Extensions.ValidFlagsMap(&o) } diff --git a/comid/measurement.go b/comid/measurement.go index b1909d87..7bfd7f99 100644 --- a/comid/measurement.go +++ b/comid/measurement.go @@ -247,6 +247,16 @@ func (o *Mval) MarshalCBOR() ([]byte, error) { return encoding.SerializeStructToCBOR(em, o) } +// UnmarshalJSON deserializes from JSON +func (o *Mval) UnmarshalJSON(data []byte) error { + return encoding.PopulateStructFromJSON(data, o) +} + +// MarshalJSON serializes to JSON +func (o *Mval) MarshalJSON() ([]byte, error) { + return encoding.SerializeStructToJSON(o) +} + func (o Mval) Valid() error { if o.Ver == nil && o.SVN == nil && diff --git a/comid/measurement_test.go b/comid/measurement_test.go index c9abed17..d5d58bff 100644 --- a/comid/measurement_test.go +++ b/comid/measurement_test.go @@ -103,7 +103,7 @@ func TestMeasurement_NewUUIDMeasurement_some_value(t *testing.T) { tv := NewUUIDMeasurement(TestUUID). SetMinSVN(2). - SetOpFlags(OpFlagDebug). + SetFlagsTrue(FlagIsDebug). SetVersion("1.2.3", swid.VersionSchemeSemVer) require.NotNil(t, tv) diff --git a/comid/referencevalue_test.go b/comid/referencevalue_test.go new file mode 100644 index 00000000..00b89d18 --- /dev/null +++ b/comid/referencevalue_test.go @@ -0,0 +1,22 @@ +package comid + +import ( + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_ReferenceValue(t *testing.T) { + rv := ReferenceValue{} + err := rv.Valid() + assert.EqualError(t, err, "environment validation failed: environment must not be empty") + + rv.Environment.Instance = NewInstance() + id, err := uuid.NewUUID() + require.NoError(t, err) + rv.Environment.Instance.SetUUID(id) + err = rv.Valid() + assert.EqualError(t, err, "measurements validation failed: no measurement entries") +} diff --git a/comid/triples.go b/comid/triples.go index 28df6b75..8500734e 100644 --- a/comid/triples.go +++ b/comid/triples.go @@ -3,13 +3,50 @@ package comid -import "fmt" +import ( + "fmt" + + "github.com/veraison/corim/encoding" + "github.com/veraison/corim/extensions" +) type Triples struct { ReferenceValues *[]ReferenceValue `cbor:"0,keyasint,omitempty" json:"reference-values,omitempty"` EndorsedValues *[]EndorsedValue `cbor:"1,keyasint,omitempty" json:"endorsed-values,omitempty"` AttestVerifKeys *[]AttestVerifKey `cbor:"2,keyasint,omitempty" json:"attester-verification-keys,omitempty"` DevIdentityKeys *[]DevIdentityKey `cbor:"3,keyasint,omitempty" json:"dev-identity-keys,omitempty"` + + Extensions +} + +// RegisterExtensions registers a struct as a collections of extensions +func (o *Triples) RegisterExtensions(exts extensions.IExtensionsValue) { + o.Extensions.Register(exts) +} + +// GetExtensions returns pervisouosly registered extension +func (o *Triples) GetExtensions() extensions.IExtensionsValue { + return o.Extensions.IExtensionsValue +} + +// UnmarshalCBOR deserializes from CBOR +func (o *Triples) UnmarshalCBOR(data []byte) error { + return encoding.PopulateStructFromCBOR(dm, data, o) +} + +// MarshalCBOR serializes to CBOR +func (o *Triples) MarshalCBOR() ([]byte, error) { + return encoding.SerializeStructToCBOR(em, o) +} + +// UnmarshalJSON deserializes from JSON +func (o *Triples) UnmarshalJSON(data []byte) error { + return encoding.PopulateStructFromJSON(data, o) +} + +// MarshalJSON serializes to JSON +func (o *Triples) MarshalJSON() ([]byte, error) { + return encoding.SerializeStructToJSON(o) } // Valid checks that the Triples is valid as per the specification @@ -52,7 +89,7 @@ func (o Triples) Valid() error { } } - return nil + return o.Extensions.ValidTriples(&o) } func (o *Triples) AddReferenceValue(val ReferenceValue) *Triples { diff --git a/corim/entity.go b/corim/entity.go index 49aa1468..b9288095 100644 --- a/corim/entity.go +++ b/corim/entity.go @@ -7,6 +7,8 @@ import ( "fmt" "github.com/veraison/corim/comid" + "github.com/veraison/corim/encoding" + "github.com/veraison/corim/extensions" ) // Entity stores an entity-map capable of CBOR and JSON serializations. @@ -14,12 +16,24 @@ type Entity struct { EntityName string `cbor:"0,keyasint" json:"name"` RegID *comid.TaggedURI `cbor:"1,keyasint,omitempty" json:"regid,omitempty"` Roles Roles `cbor:"2,keyasint" json:"roles"` + + Extensions } func NewEntity() *Entity { return &Entity{} } +// RegisterExtensions registers a struct as a collections of extensions +func (o *Entity) RegisterExtensions(exts extensions.IExtensionsValue) { + o.Extensions.Register(exts) +} + +// GetExtensions returns pervisouosly registered extension +func (o *Entity) GetExtensions() extensions.IExtensionsValue { + return o.Extensions.IExtensionsValue +} + // SetEntityName is used to set the EntityName field of Entity using supplied name func (o *Entity) SetEntityName(name string) *Entity { if o != nil { @@ -72,7 +86,27 @@ func (o Entity) Valid() error { return fmt.Errorf("invalid entity: %w", err) } - return nil + return o.Extensions.ValidEntity(&o) +} + +// UnmarshalCBOR deserializes from CBOR +func (o *Entity) UnmarshalCBOR(data []byte) error { + return encoding.PopulateStructFromCBOR(dm, data, o) +} + +// MarshalCBOR serializes to CBOR +func (o *Entity) MarshalCBOR() ([]byte, error) { + return encoding.SerializeStructToCBOR(em, o) +} + +// UnmarshalJSON deserializes from JSON +func (o *Entity) UnmarshalJSON(data []byte) error { + return encoding.PopulateStructFromJSON(data, o) +} + +// MarshalJSON serializes to JSON +func (o *Entity) MarshalJSON() ([]byte, error) { + return encoding.SerializeStructToJSON(o) } // Entities is an array of entity-map's diff --git a/corim/extensions.go b/corim/extensions.go new file mode 100644 index 00000000..55f2410a --- /dev/null +++ b/corim/extensions.go @@ -0,0 +1,68 @@ +// Copyright 2023 Contributors to the Veraison project. +// SPDX-License-Identifier: Apache-2.0 +package corim + +import ( + "github.com/veraison/corim/extensions" +) + +type IEntityValidator interface { + ValidEntity(*Entity) error +} + +type ICorimValidator interface { + ValidCorim(*UnsignedCorim) error +} + +type ISignerValidator interface { + ValidSigner(*Signer) error +} + +type Extensions struct { + extensions.Extensions +} + +func (o *Extensions) ValidEntity(entity *Entity) error { + if !o.HaveExtensions() { + return nil + } + + ev, ok := o.IExtensionsValue.(IEntityValidator) + if ok { + if err := ev.ValidEntity(entity); err != nil { + return err + } + } + + return nil +} + +func (o *Extensions) ValidCorim(c *UnsignedCorim) error { + if !o.HaveExtensions() { + return nil + } + + ev, ok := o.IExtensionsValue.(ICorimValidator) + if ok { + if err := ev.ValidCorim(c); err != nil { + return err + } + } + + return nil +} + +func (o *Extensions) ValidSigner(signer *Signer) error { + if !o.HaveExtensions() { + return nil + } + + ev, ok := o.IExtensionsValue.(ISignerValidator) + if ok { + if err := ev.ValidSigner(signer); err != nil { + return err + } + } + + return nil +} diff --git a/corim/extensions_test.go b/corim/extensions_test.go new file mode 100644 index 00000000..653a13e1 --- /dev/null +++ b/corim/extensions_test.go @@ -0,0 +1,86 @@ +// Copyright 2023 Contributors to the Veraison project. +// SPDX-License-Identifier: Apache-2.0 +package corim + +import ( + "errors" + "testing" + + "github.com/fxamacker/cbor/v2" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type TestExtensions struct { + Address string `cbor:"-1,keyasint,omitempty" json:"address,omitempty"` + Size int `cbor:"-2,keyasint,omitempty" json:"size,omitempty"` +} + +func (o TestExtensions) ValidEntity(ent *Entity) error { + if ent.EntityName != "Futurama" { + return errors.New(`EntityName must be "Futurama"`) // nolint:golint + } + + return nil +} + +func (o TestExtensions) ValidCorim(c *UnsignedCorim) error { + return errors.New("invalid") +} + +func (o TestExtensions) ValidSigner(s *Signer) error { + return errors.New("invalid") +} + +func TestEntityExtensions_Valid(t *testing.T) { + ent := NewEntity() + ent.SetEntityName("The Simpsons") + ent.SetRoles(RoleManifestCreator) + + err := ent.Valid() + assert.NoError(t, err) + + ent.RegisterExtensions(&TestExtensions{}) + err = ent.Valid() + assert.EqualError(t, err, `EntityName must be "Futurama"`) + + ent.SetEntityName("Futurama") + err = ent.Valid() + assert.NoError(t, err) + + assert.EqualError(t, ent.Extensions.ValidCorim(nil), "invalid") + assert.EqualError(t, ent.Extensions.ValidSigner(nil), "invalid") +} + +func TestEntityExtensions_CBOR(t *testing.T) { + data := []byte{ + 0xa4, // map(4) + + 0x00, // key 0 + 0x64, // val tstr(4) + 0x61, 0x63, 0x6d, 0x65, // "acme" + + 0x02, // key 2 + 0x81, // array(1) + 0x01, // 1 + + 0x20, // key -1 + 0x63, // val tstr(3) + 0x66, 0x6f, 0x6f, // "foo" + + 0x21, // key -2 + 0x06, // val 6 + } + + ent := NewEntity() + ent.RegisterExtensions(&TestExtensions{}) + + err := cbor.Unmarshal(data, &ent) + assert.NoError(t, err) + + assert.Equal(t, ent.EntityName, "acme") + + address, err := ent.Get("address") + require.NoError(t, err) + assert.Equal(t, address, "foo") +} diff --git a/corim/meta.go b/corim/meta.go index 45f1ea49..4b1eeeb6 100644 --- a/corim/meta.go +++ b/corim/meta.go @@ -10,17 +10,31 @@ import ( "time" "github.com/veraison/corim/comid" + "github.com/veraison/corim/encoding" + "github.com/veraison/corim/extensions" ) type Signer struct { Name string `cbor:"0,keyasint" json:"name"` URI *comid.TaggedURI `cbor:"1,keyasint,omitempty" json:"uri,omitempty"` + + Extensions } func NewSigner() *Signer { return &Signer{} } +// RegisterExtensions registers a struct as a collections of extensions +func (o *Signer) RegisterExtensions(exts extensions.IExtensionsValue) { + o.Extensions.Register(exts) +} + +// GetExtensions returns pervisouosly registered extension +func (o *Signer) GetExtensions() extensions.IExtensionsValue { + return o.Extensions.IExtensionsValue +} + // SetName sets the target Signer's name to the supplied value func (o *Signer) SetName(name string) *Signer { if o != nil { @@ -61,7 +75,27 @@ func (o Signer) Valid() error { } } - return nil + return o.Extensions.ValidSigner(&o) +} + +// UnmarshalCBOR deserializes from CBOR +func (o *Signer) UnmarshalCBOR(data []byte) error { + return encoding.PopulateStructFromCBOR(dm, data, o) +} + +// MarshalCBOR serializes to CBOR +func (o *Signer) MarshalCBOR() ([]byte, error) { + return encoding.SerializeStructToCBOR(em, o) +} + +// UnmarshalJSON deserializes from JSON +func (o *Signer) UnmarshalJSON(data []byte) error { + return encoding.PopulateStructFromJSON(data, o) +} + +// MarshalJSON serializes to JSON +func (o *Signer) MarshalJSON() ([]byte, error) { + return encoding.SerializeStructToJSON(o) } // Meta stores a corim-meta-map with JSON and CBOR serializations. It carries diff --git a/corim/meta_test.go b/corim/meta_test.go index a4dd34cd..811495cf 100644 --- a/corim/meta_test.go +++ b/corim/meta_test.go @@ -183,3 +183,15 @@ func TestMeta_FromCBOR_full(t *testing.T) { assert.Equal(t, notBefore.Unix(), actual.Validity.NotBefore.Unix()) assert.Equal(t, notAfter.Unix(), actual.Validity.NotAfter.Unix()) } + +func Test_Signer_Valid(t *testing.T) { + var signer Signer + + assert.EqualError(t, signer.Valid(), "empty name") + + signer.Name = "test-signer" + uri := comid.TaggedURI("@@@") + signer.URI = &uri + + assert.EqualError(t, signer.Valid(), `invalid URI: "@@@" is not an absolute URI`) +} diff --git a/corim/signedcorim_test.go b/corim/signedcorim_test.go index 9e030b3c..eae15d45 100644 --- a/corim/signedcorim_test.go +++ b/corim/signedcorim_test.go @@ -246,7 +246,7 @@ func TestSignedCorim_FromCOSE_fail_invalid_corim(t *testing.T) { var actual SignedCorim err := actual.FromCOSE(tv) - assert.EqualError(t, err, "failed validation of unsigned CoRIM: tags validation failed: no tags") + assert.EqualError(t, err, `failed CBOR decoding of unsigned CoRIM: missing mandatory field "Tags" (1)`) } func TestSignedCorim_FromCOSE_fail_no_content_type(t *testing.T) { diff --git a/corim/unsignedcorim.go b/corim/unsignedcorim.go index 68ca6786..a7488500 100644 --- a/corim/unsignedcorim.go +++ b/corim/unsignedcorim.go @@ -4,12 +4,13 @@ package corim import ( - "encoding/json" "errors" "fmt" "time" "github.com/veraison/corim/cots" + "github.com/veraison/corim/encoding" + "github.com/veraison/corim/extensions" "github.com/veraison/corim/comid" "github.com/veraison/eat" @@ -19,12 +20,22 @@ import ( // UnsignedCorim is the top-level representation of the unsigned-corim-map with // CBOR and JSON serialization. type UnsignedCorim struct { - ID swid.TagID `cbor:"0,keyasint" json:"corim-id"` - Tags []Tag `cbor:"1,keyasint" json:"tags"` + ID swid.TagID `cbor:"0,keyasint" json:"corim-id"` + // note: even though tags are mandatory for CoRIM, we allow omitting + // them in our JSON templates for cocli (the min template just has + // corim-id). Since we're never writing JSON (so far), this normally + // wouldn't matter, however the custom serialization code we use to + // handle embedded structs relies on the omitempty entry to determine + // if a filed is optional, so we use it during unmarshaling as well as + // marshaling. Hence omitempty is present for the json tag, but not + // cbor. + Tags []Tag `cbor:"1,keyasint" json:"tags,omitempty"` DependentRims *[]Locator `cbor:"2,keyasint,omitempty" json:"dependent-rims,omitempty"` Profiles *[]eat.Profile `cbor:"3,keyasint,omitempty" json:"profiles,omitempty"` RimValidity *Validity `cbor:"4,keyasint,omitempty" json:"validity,omitempty"` Entities *Entities `cbor:"5,keyasint,omitempty" json:"entities,omitempty"` + + Extensions } // NewUnsignedCorim instantiates an empty UnsignedCorim @@ -32,6 +43,16 @@ func NewUnsignedCorim() *UnsignedCorim { return &UnsignedCorim{} } +// RegisterExtensions registers a struct as a collections of extensions +func (o *UnsignedCorim) RegisterExtensions(exts extensions.IExtensionsValue) { + o.Extensions.Register(exts) +} + +// GetExtensions returns pervisouosly registered extension +func (o *UnsignedCorim) GetExtensions() extensions.IExtensionsValue { + return o.Extensions.IExtensionsValue +} + // SetID sets the corim-id in the unsigned-corim-map to the supplied value. The // corim-id can be passed as UUID in string or binary form (i.e., byte array), // or as a (non-empty) string @@ -239,22 +260,22 @@ func (o UnsignedCorim) Valid() error { } } - return nil + return o.Extensions.ValidCorim(&o) } // ToCBOR serializes the target unsigned CoRIM to CBOR -func (o UnsignedCorim) ToCBOR() ([]byte, error) { - return em.Marshal(&o) +func (o *UnsignedCorim) ToCBOR() ([]byte, error) { + return encoding.SerializeStructToCBOR(em, o) } // FromCBOR deserializes a CBOR-encoded unsigned CoRIM into the target UnsignedCorim func (o *UnsignedCorim) FromCBOR(data []byte) error { - return dm.Unmarshal(data, o) + return encoding.PopulateStructFromCBOR(dm, data, o) } // FromJSON deserializes a JSON-encoded unsigned CoRIM into the target UnsignedCorim func (o *UnsignedCorim) FromJSON(data []byte) error { - return json.Unmarshal(data, o) + return encoding.PopulateStructFromJSON(data, o) } // Tag is either a CBOR-encoded CoMID, CoSWID or CoTS diff --git a/corim/unsignedcorim_test.go b/corim/unsignedcorim_test.go index d9b4d105..c2b7ee86 100644 --- a/corim/unsignedcorim_test.go +++ b/corim/unsignedcorim_test.go @@ -299,3 +299,11 @@ func TestUnsignedCorim_AddEntity_non_nil_empty_URI(t *testing.T) { assert.Nil(t, tv) } + +func TestUnsignedCorim_FromJSON(t *testing.T) { + data := []byte(`{"corim-id": "5c57e8f4-46cd-421b-91c9-08cf93e13cfc"}`) + + err := NewUnsignedCorim().FromJSON(data) + + assert.NoError(t, err) +} diff --git a/encoding/cbor.go b/encoding/cbor.go new file mode 100644 index 00000000..5d7a3533 --- /dev/null +++ b/encoding/cbor.go @@ -0,0 +1,409 @@ +// Copyright 2023 Contributors to the Veraison project. +// SPDX-License-Identifier: Apache-2.0 + +package encoding + +import ( + "encoding/binary" + "errors" + "fmt" + "math" + "reflect" + "strconv" + "strings" + + cbor "github.com/fxamacker/cbor/v2" +) + +func SerializeStructToCBOR(em cbor.EncMode, source any) ([]byte, error) { + rawMap := newStructFieldsCBOR() + + structType := reflect.TypeOf(source) + structVal := reflect.ValueOf(source) + + if err := doSerializeStructToCBOR(em, rawMap, structType, structVal); err != nil { + return nil, err + } + + return rawMap.ToCBOR(em) +} + +func doSerializeStructToCBOR( + em cbor.EncMode, + rawMap *structFieldsCBOR, + structType reflect.Type, + structVal reflect.Value, +) error { + if structType.Kind() == reflect.Pointer { + structType = structType.Elem() + structVal = structVal.Elem() + } + + var embeds []embedded + + for i := 0; i < structVal.NumField(); i++ { + typeField := structType.Field(i) + valField := structVal.Field(i) + + if collectEmbedded(&typeField, valField, &embeds) { + continue + } + + tag, ok := typeField.Tag.Lookup("cbor") + if !ok { + continue + } + + parts := strings.Split(tag, ",") + keyString := parts[0] + + isOmitEmpty := false + if len(parts) > 1 { + for _, option := range parts[1:] { + if option == omitempty { + isOmitEmpty = true + break + } + } + } + + // do not serialize zero values if the corresponding field is + // omitempty + if isOmitEmpty && valField.IsZero() { + continue + } + + keyInt, err := strconv.Atoi(keyString) + if err != nil { + return fmt.Errorf("non-integer cbor key: %s", keyString) + } + + data, err := em.Marshal(valField.Interface()) + if err != nil { + return fmt.Errorf("error marshaling field %q: %w", + typeField.Name, + err, + ) + } + + if err := rawMap.Add(keyInt, cbor.RawMessage(data)); err != nil { + return err + } + } + + for _, emb := range embeds { + if err := doSerializeStructToCBOR(em, rawMap, emb.Type, emb.Value); err != nil { + return err + } + } + + return nil +} + +func PopulateStructFromCBOR(dm cbor.DecMode, data []byte, dest any) error { + rawMap := newStructFieldsCBOR() + + if err := rawMap.FromCBOR(dm, data); err != nil { + return err + } + + structType := reflect.TypeOf(dest) + structVal := reflect.ValueOf(dest) + + return doPopulateStructFromCBOR(dm, rawMap, structType, structVal) +} + +func doPopulateStructFromCBOR( + dm cbor.DecMode, + rawMap *structFieldsCBOR, + structType reflect.Type, + structVal reflect.Value, +) error { + if structType.Kind() == reflect.Pointer { + structType = structType.Elem() + structVal = structVal.Elem() + } + + var embeds []embedded + + for i := 0; i < structVal.NumField(); i++ { + typeField := structType.Field(i) + valField := structVal.Field(i) + + if collectEmbedded(&typeField, valField, &embeds) { + continue + } + + tag, ok := typeField.Tag.Lookup("cbor") + if !ok { + continue + } + + parts := strings.Split(tag, ",") + keyString := parts[0] + + isOmitEmpty := false + if len(parts) > 1 { + for _, option := range parts[1:] { + if option == omitempty { + isOmitEmpty = true + break + } + } + } + + keyInt, err := strconv.Atoi(keyString) + if err != nil { + return fmt.Errorf("non-integer cbor key %s", keyString) + } + + rawVal, ok := rawMap.Get(keyInt) + if !ok { + if isOmitEmpty { + continue + } + + return fmt.Errorf("missing mandatory field %q (%d)", + typeField.Name, keyInt) + } + + fieldPtr := valField.Addr().Interface() + if err := dm.Unmarshal(rawVal, fieldPtr); err != nil { + return fmt.Errorf("error unmarshalling field %q: %w", + typeField.Name, + err, + ) + } + + rawMap.Delete(keyInt) + } + + for _, emb := range embeds { + if err := doPopulateStructFromCBOR(dm, rawMap, emb.Type, emb.Value); err != nil { + return err + } + } + + return nil +} + +// structFieldsCBOR is a specialized implementation of "OrderedMap", where the +// order of the keys is kept track of, and used when serializing the map to +// CBOR. While CBOR maps do not mandate any particular ordering, and so this +// isn't strictly necessary, it is useful to have a _stable_ serialization +// order for map keys to be compatible with regular Go struct serialization +// behavior. This is also useful for tests/examples that compare encoded +// []byte's. +type structFieldsCBOR struct { + Fields map[int]cbor.RawMessage + Keys []int +} + +func newStructFieldsCBOR() *structFieldsCBOR { + return &structFieldsCBOR{ + Fields: make(map[int]cbor.RawMessage), + } +} + +func (o structFieldsCBOR) Has(key int) bool { + _, ok := o.Fields[key] + return ok +} + +func (o *structFieldsCBOR) Add(key int, val cbor.RawMessage) error { + if o.Has(key) { + return fmt.Errorf("duplicate cbor key: %d", key) + } + + o.Fields[key] = val + o.Keys = append(o.Keys, key) + + return nil +} + +func (o *structFieldsCBOR) Get(key int) (cbor.RawMessage, bool) { + val, ok := o.Fields[key] + return val, ok +} + +func (o *structFieldsCBOR) Delete(key int) { + delete(o.Fields, key) + + for i, existing := range o.Keys { + if existing == key { + o.Keys = append(o.Keys[:i], o.Keys[i+1:]...) + } + } +} + +func (o *structFieldsCBOR) ToCBOR(em cbor.EncMode) ([]byte, error) { + var out []byte + + header := byte(0xa0) // 0b101_00000 -- Major Type 5 == map + mapLen := len(o.Keys) + if mapLen == 0 { + return []byte{header}, nil + } else if mapLen < 24 { + header = header | byte(mapLen) + out = append(out, header) + } else if mapLen <= math.MaxUint8 { + header = header | byte(24) + out = append(out, header, uint8(mapLen)) + } else if mapLen <= math.MaxUint16 { + header = header | byte(25) + out = append(out, header) + out = binary.BigEndian.AppendUint16(out, uint16(mapLen)) + } else { + header = header | byte(26) + out = append(out, header) + out = binary.BigEndian.AppendUint32(out, uint32(mapLen)) + } + // Since len() returns an int, the value cannot exceed MaxUint32, so + // the 8-byte length variant cannot occur. + + for _, key := range o.Keys { + marshalledKey, err := em.Marshal(key) + if err != nil { + return nil, fmt.Errorf("problem marshaling key %d: %w", key, err) + } + + out = append(out, marshalledKey...) + out = append(out, o.Fields[key]...) + } + + return out, nil +} + +func (o *structFieldsCBOR) FromCBOR(dm cbor.DecMode, data []byte) error { + if len(data) == 0 { + return errors.New("empty input") + } + + header := data[0] + rest := data[1:] + additionalInfo := 0x1f & header + + var err error + + majorType := (0xe0 & header) >> 5 + if majorType == 6 { // tag + _, rest, err = processAdditionalInfo(additionalInfo, rest) + if err != nil { + return err + } + + header = rest[0] + rest = rest[1:] + majorType = (0xe0 & header) >> 5 + additionalInfo = 0x1f & header + } + + if majorType != 5 { + return fmt.Errorf("expected map (CBOR Major Type 5), found Major Type %d", majorType) + } + + var mapLen int + + mapLen, rest, err = processAdditionalInfo(additionalInfo, rest) + if err != nil { + return err + } + + if mapLen != 0 { + o.Fields = make(map[int]cbor.RawMessage, mapLen) + + for i := 0; i < mapLen; i++ { + rest, err = o.unmarshalKeyValue(dm, rest) + if err != nil { + return fmt.Errorf("map item %d: %w", i, err) + } + } + } else { // mapLen == 0 --> indefinite encoding + o.Fields = make(map[int]cbor.RawMessage) + + i := 0 + done := false + for len(rest) > 0 { + if rest[0] == 0xFF { + done = true + break + } + + rest, err = o.unmarshalKeyValue(dm, rest) + if err != nil { + return fmt.Errorf("map item %d: %w", i, err) + } + + i++ + } + + if !done { + return errors.New("unexpected EOF") + } + } + + return nil +} + +func (o *structFieldsCBOR) unmarshalKeyValue(dm cbor.DecMode, rest []byte) ([]byte, error) { + var key int + var val cbor.RawMessage + var err error + + rest, err = dm.UnmarshalFirst(rest, &key) + if err != nil { + return rest, fmt.Errorf("could not unmarshal key: %w", err) + } + + rest, err = dm.UnmarshalFirst(rest, &val) + if err != nil { + return rest, fmt.Errorf("could not unmarshal value: %w", err) + } + + if err = o.Add(key, val); err != nil { + return rest, err + } + + return rest, nil +} + +func processAdditionalInfo( + additionalInfo byte, + data []byte, +) (int, []byte, error) { + var val int + rest := data + + if additionalInfo < 24 { + val = int(additionalInfo) + } else if additionalInfo < 28 { + switch additionalInfo - 23 { + case 1: + if len(data) < 1 { + return 0, nil, errors.New("unexpected EOF") + } + val = int(data[0]) + rest = data[1:] + case 2: + if len(data) < 2 { + return 0, nil, errors.New("unexpected EOF") + } + val = int(binary.BigEndian.Uint16(data[:2])) + rest = data[2:] + case 3: + if len(data) < 4 { + return 0, nil, errors.New("unexpected EOF") + } + val = int(binary.BigEndian.Uint32(data[:4])) + rest = data[4:] + default: + return 0, nil, errors.New("cbor: cannot decode length value of 8 bytes") + } + } else if additionalInfo == 31 { + val = 0 // indefinite encoding + } else { + return 0, nil, fmt.Errorf("cbor: unexpected additional information value %d", additionalInfo) + } + + return val, rest, nil +} diff --git a/encoding/cbor_test.go b/encoding/cbor_test.go new file mode 100644 index 00000000..1dda9146 --- /dev/null +++ b/encoding/cbor_test.go @@ -0,0 +1,279 @@ +// Copyright 2021 Contributors to the Veraison project. +// SPDX-License-Identifier: Apache-2.0 +package encoding + +import ( + "testing" + + "github.com/fxamacker/cbor/v2" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_PopulateStructFromCBOR_simple(t *testing.T) { + type SimpleStruct struct { + FieldOne string `cbor:"0,keyasint,omitempty"` + FieldTwo int `cbor:"1,keyasint"` + } + + var v SimpleStruct + + data := []byte{ + 0xa2, // map(2) + + 0x00, // key 0 + 0x64, // val tstr(4) + 0x61, 0x63, 0x6d, 0x65, // "acme" + + 0x01, // key 1 + 0x06, // val 6 + } + + dm, err := cbor.DecOptions{}.DecMode() + require.NoError(t, err) + + err = PopulateStructFromCBOR(dm, data, &v) + require.NoError(t, err) + assert.Equal(t, "acme", v.FieldOne) + assert.Equal(t, 6, v.FieldTwo) + + data = []byte{ + 0xa1, // map(1) + + 0x01, // key 1 + 0x06, // val 6 + } + v = SimpleStruct{} + + err = PopulateStructFromCBOR(dm, data, &v) + require.NoError(t, err) + assert.Equal(t, "", v.FieldOne) + assert.Equal(t, 6, v.FieldTwo) + + data = []byte{ + 0xa1, // map(1) + + 0x02, // key 2 + 0x64, // val tstr(4) + 0x61, 0x63, 0x6d, 0x65, // "acme" + } + v = SimpleStruct{} + + err = PopulateStructFromCBOR(dm, data, &v) + assert.EqualError(t, err, `missing mandatory field "FieldTwo" (1)`) + + err = PopulateStructFromCBOR(dm, []byte{0x01}, &v) + assert.EqualError(t, err, `expected map (CBOR Major Type 5), found Major Type 0`) + + type CompositeStruct struct { + FieldThree string `cbor:"2,keyasint"` + SimpleStruct + } + + var c CompositeStruct + + data = []byte{ + 0xa3, // map(3) + + 0x00, // key 0 + 0x64, // val tstr(4) + 0x61, 0x63, 0x6d, 0x65, // "acme" + + 0x01, // key 1 + 0x06, // val 6 + + 0x02, // key 2 + 0x63, // val tstr(3) + 0x66, 0x6f, 0x6f, // "foo" + } + + err = PopulateStructFromCBOR(dm, data, &c) + require.NoError(t, err) + assert.Equal(t, "acme", c.FieldOne) + assert.Equal(t, 6, c.FieldTwo) + assert.Equal(t, "foo", c.FieldThree) + + em, err := cbor.EncOptions{}.EncMode() + require.NoError(t, err) + + res, err := SerializeStructToCBOR(em, &c) + require.NoError(t, err) + + var c2 CompositeStruct + err = PopulateStructFromCBOR(dm, res, &c2) + require.NoError(t, err) + assert.EqualValues(t, c, c2) + +} + +func Test_structFieldsCBOR_CRUD(t *testing.T) { + sf := newStructFieldsCBOR() + + err := sf.Add(2, cbor.RawMessage{0x02}) + assert.NoError(t, err) + + err = sf.Add(1, cbor.RawMessage{0x01}) + assert.NoError(t, err) + + err = sf.Add(3, cbor.RawMessage{0x03}) + assert.NoError(t, err) + + assert.Equal(t, []int{2, 1, 3}, sf.Keys) + assert.True(t, sf.Has(3)) + assert.False(t, sf.Has(4)) + + val, ok := sf.Get(2) + assert.True(t, ok) + assert.Equal(t, cbor.RawMessage{0x2}, val) + + _, ok = sf.Get(4) + assert.False(t, ok) + + sf.Delete(2) + _, ok = sf.Get(2) + assert.False(t, ok) + + err = sf.Add(1, cbor.RawMessage{0x11}) + assert.EqualError(t, err, "duplicate cbor key: 1") +} + +func Test_structFieldsCBOR_CBOR_roundtrip(t *testing.T) { + em, err := cbor.EncOptions{}.EncMode() + require.NoError(t, err) + dm, err := cbor.DecOptions{}.DecMode() + require.NoError(t, err) + + sf := newStructFieldsCBOR() + + data, err := sf.ToCBOR(em) + require.NoError(t, err) + assert.Equal(t, data, []byte{0xa0}) // empty map + + for i := 0; i < 5; i++ { + err = sf.Add(i, cbor.RawMessage{0x00}) + require.NoError(t, err) + } + + data, err = sf.ToCBOR(em) + require.NoError(t, err) + assert.Equal(t, data, []byte{ + 0xa5, // map 5 + 0x00, 0x00, + 0x01, 0x00, + 0x02, 0x00, + 0x03, 0x00, + 0x04, 0x00, + }) + + sfOut := newStructFieldsCBOR() + err = sfOut.FromCBOR(dm, data) + require.NoError(t, err) + assert.Equal(t, sf, sfOut) + + for i := 5; i < 200; i++ { + err = sf.Add(i, cbor.RawMessage{0x00}) + require.NoError(t, err) + } + + data, err = sf.ToCBOR(em) + require.NoError(t, err) + assert.Equal(t, data[:2], []byte{ + 0xb8, 0xc8, // map 200 + }) + + sfOut = newStructFieldsCBOR() + err = sfOut.FromCBOR(dm, data) + require.NoError(t, err) + assert.Equal(t, sf, sfOut) + + for i := 200; i < 2048; i++ { + err = sf.Add(i, cbor.RawMessage{0x00}) + require.NoError(t, err) + } + + data, err = sf.ToCBOR(em) + require.NoError(t, err) + assert.Equal(t, data[:3], []byte{ + 0xb9, 0x08, 0x00, // map 2048 + }) + + sfOut = newStructFieldsCBOR() + err = sfOut.FromCBOR(dm, data) + require.NoError(t, err) + assert.Equal(t, sf, sfOut) +} + +func Test_structFieldsCBOR_CBOR_decode_tagged(t *testing.T) { + data := []byte{ + 0xc1, // tag 1 + 0xa5, // map 5 + 0x00, 0x00, + 0x01, 0x00, + 0x02, 0x00, + 0x03, 0x00, + 0x04, 0x00, + } + + dm, err := cbor.DecOptions{}.DecMode() + require.NoError(t, err) + + sfOut := newStructFieldsCBOR() + err = sfOut.FromCBOR(dm, data) + require.NoError(t, err) + assert.Equal(t, []int{0, 1, 2, 3, 4}, sfOut.Keys) +} + +func Test_structFieldsCBOR_CBOR_decode_indefinite(t *testing.T) { + data := []byte{ + 0xbf, // indefinite map + 0x00, 0x00, + 0x01, 0x00, + 0x02, 0x00, + 0x03, 0x00, + 0x04, 0x00, + 0xff, // break + } + + dm, err := cbor.DecOptions{}.DecMode() + require.NoError(t, err) + + sfOut := newStructFieldsCBOR() + err = sfOut.FromCBOR(dm, data) + require.NoError(t, err) + assert.Equal(t, []int{0, 1, 2, 3, 4}, sfOut.Keys) +} + +func Test_structFieldsCBOR_CBOR_decode_negative(t *testing.T) { + dm, err := cbor.DecOptions{}.DecMode() + require.NoError(t, err) + + sfOut := newStructFieldsCBOR() + err = sfOut.FromCBOR(dm, []byte{0xa1, 0xff, 0x00}) + assert.EqualError(t, err, `map item 0: could not unmarshal key: cbor: unexpected "break" code`) + err = sfOut.FromCBOR(dm, []byte{0xbf, 0x00, 0x00}) + assert.EqualError(t, err, `unexpected EOF`) + err = sfOut.FromCBOR(dm, []byte{0xa1, 0x00, 0xff}) + assert.EqualError(t, err, `map item 0: could not unmarshal value: cbor: unexpected "break" code`) + + err = sfOut.FromCBOR(dm, []byte{0x00}) + assert.EqualError(t, err, `expected map (CBOR Major Type 5), found Major Type 0`) +} + +func Test_processAdditionalInfo(t *testing.T) { + addInfo := byte(26) + data := []byte{0x00, 0x00, 0x00, 0x01} + + val, rest, err := processAdditionalInfo(addInfo, data) + require.NoError(t, err) + assert.Equal(t, 1, val) + assert.Equal(t, []byte{}, rest) + + _, _, err = processAdditionalInfo(byte(27), data) + assert.EqualError(t, err, "cbor: cannot decode length value of 8 bytes") + + _, _, err = processAdditionalInfo(byte(28), data) + assert.EqualError(t, err, "cbor: unexpected additional information value 28") + + _, _, err = processAdditionalInfo(addInfo, []byte{}) + assert.EqualError(t, err, "unexpected EOF") +} diff --git a/encoding/embedded.go b/encoding/embedded.go new file mode 100644 index 00000000..1bfeb999 --- /dev/null +++ b/encoding/embedded.go @@ -0,0 +1,49 @@ +package encoding + +import "reflect" + +const omitempty = "omitempty" + +type embedded struct { + Type reflect.Type + Value reflect.Value +} + +// collectEmbedded returns true if the Field is embedded (regardless of +// whether or not it was collected). +func collectEmbedded( + typeField *reflect.StructField, + valField reflect.Value, + embeds *[]embedded, +) bool { + // embedded fields are alway anonymous:w + if !typeField.Anonymous { + return false + } + + if typeField.Name == typeField.Type.Name() && + (typeField.Type.Kind() == reflect.Struct || + typeField.Type.Kind() == reflect.Interface) { + + var fieldType reflect.Type + var fieldValue reflect.Value + + if typeField.Type.Kind() == reflect.Interface { + fieldValue = valField.Elem() + if fieldValue.Kind() == reflect.Invalid { + // no value underlying the interface + return true + } + // use the interface's underlying value's real type + fieldType = valField.Elem().Type() + } else { + fieldType = typeField.Type + fieldValue = valField + } + + *embeds = append(*embeds, embedded{Type: fieldType, Value: fieldValue}) + return true + } + + return false +} diff --git a/encoding/json.go b/encoding/json.go new file mode 100644 index 00000000..e8193e4b --- /dev/null +++ b/encoding/json.go @@ -0,0 +1,319 @@ +package encoding + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "reflect" + "strings" +) + +func SerializeStructToJSON(source any) ([]byte, error) { + rawMap := newStructFieldsJSON() + + structType := reflect.TypeOf(source) + structVal := reflect.ValueOf(source) + + if err := doSerializeStructToJSON(rawMap, structType, structVal); err != nil { + return nil, err + } + + return rawMap.ToJSON() +} + +func doSerializeStructToJSON( + rawMap *structFieldsJSON, + structType reflect.Type, + structVal reflect.Value, +) error { + if structType.Kind() == reflect.Pointer { + structType = structType.Elem() + structVal = structVal.Elem() + } + + var embeds []embedded + + for i := 0; i < structVal.NumField(); i++ { + typeField := structType.Field(i) + valField := structVal.Field(i) + + if collectEmbedded(&typeField, valField, &embeds) { + continue + } + + tag, ok := typeField.Tag.Lookup("json") + if !ok { + continue + } + + parts := strings.Split(tag, ",") + key := parts[0] + + isOmitEmpty := false + if len(parts) > 1 { + for _, option := range parts[1:] { + if option == omitempty { + isOmitEmpty = true + break + } + } + } + + // do not serialize zero values if the corresponding field is + // omitempty + if isOmitEmpty && valField.IsZero() { + continue + } + + data, err := json.Marshal(valField.Interface()) + if err != nil { + return fmt.Errorf("error marshaling field %q: %w", + typeField.Name, + err, + ) + } + + if err := rawMap.Add(key, json.RawMessage(data)); err != nil { + return err + } + } + + for _, emb := range embeds { + if err := doSerializeStructToJSON(rawMap, emb.Type, emb.Value); err != nil { + return err + } + } + + return nil +} + +func PopulateStructFromJSON(data []byte, dest any) error { + rawMap := newStructFieldsJSON() + + if err := rawMap.FromJSON(data); err != nil { + return err + } + + structType := reflect.TypeOf(dest) + structVal := reflect.ValueOf(dest) + + return doPopulateStructFromJSON(rawMap, structType, structVal) +} + +func doPopulateStructFromJSON( + rawMap *structFieldsJSON, + structType reflect.Type, + structVal reflect.Value, +) error { + if structType.Kind() == reflect.Pointer { + structType = structType.Elem() + structVal = structVal.Elem() + } + + var embeds []embedded + + for i := 0; i < structVal.NumField(); i++ { + typeField := structType.Field(i) + valField := structVal.Field(i) + + if collectEmbedded(&typeField, valField, &embeds) { + continue + } + + tag, ok := typeField.Tag.Lookup("json") + if !ok { + continue + } + + parts := strings.Split(tag, ",") + key := parts[0] + + isOmitEmpty := false + if len(parts) > 1 { + for _, option := range parts[1:] { + if option == omitempty { + isOmitEmpty = true + break + } + } + } + + rawVal, ok := rawMap.Get(key) + if !ok { + if isOmitEmpty { + continue + } + + return fmt.Errorf("missing mandatory field %q (%q)", + typeField.Name, key) + } + + fieldPtr := valField.Addr().Interface() + if err := json.Unmarshal(rawVal, fieldPtr); err != nil { + return fmt.Errorf("error unmarshalling field %q: %w", + typeField.Name, + err, + ) + } + + rawMap.Delete(key) + } + + for _, emb := range embeds { + if err := doPopulateStructFromJSON(rawMap, emb.Type, emb.Value); err != nil { + return err + } + } + + return nil +} + +// structFieldsJSON is a specialized implementation of "OrderedMap", where the +// order of the keys is kept track of, and used when serializing the map to +// JSON. While JSON maps do not mandate any particular ordering, and so this +// isn't strictly necessary, it is useful to have a _stable_ serialization +// order for map keys to be compatible with regular Go struct serialization +// behavior. This is also useful for tests/examples that compare encoded +// []byte's. +type structFieldsJSON struct { + Fields map[string]json.RawMessage + Keys []string +} + +func newStructFieldsJSON() *structFieldsJSON { + return &structFieldsJSON{ + Fields: make(map[string]json.RawMessage), + } +} + +func (o structFieldsJSON) Has(key string) bool { + _, ok := o.Fields[key] + return ok +} + +func (o *structFieldsJSON) Add(key string, val json.RawMessage) error { + if o.Has(key) { + return fmt.Errorf("duplicate JSON key: %q", key) + } + + o.Fields[key] = val + o.Keys = append(o.Keys, key) + + return nil +} + +func (o *structFieldsJSON) Get(key string) (json.RawMessage, bool) { + val, ok := o.Fields[key] + return val, ok +} + +func (o *structFieldsJSON) Delete(key string) { + delete(o.Fields, key) + + for i, existing := range o.Keys { + if existing == key { + o.Keys = append(o.Keys[:i], o.Keys[i+1:]...) + } + } +} + +func (o *structFieldsJSON) ToJSON() ([]byte, error) { + var out bytes.Buffer + + out.Write([]byte("{")) + + first := true + for _, key := range o.Keys { + if first { + first = false + } else { + out.Write([]byte(",")) + } + marshaledKey, err := json.Marshal(key) + if err != nil { + return nil, fmt.Errorf("problem marshaling key %s: %w", key, err) + } + out.Write(marshaledKey) + out.Write([]byte(":")) + out.Write(o.Fields[key]) + } + + out.Write([]byte("}")) + + return out.Bytes(), nil +} + +func (o *structFieldsJSON) FromJSON(data []byte) error { + if err := json.Unmarshal(data, &o.Fields); err != nil { + return err + } + + return o.unmarshalKeys(data) +} + +func (o *structFieldsJSON) unmarshalKeys(data []byte) error { + + decoder := json.NewDecoder(bytes.NewReader(data)) + + token, err := decoder.Token() + if err != nil { + return err + } + + if token != json.Delim('{') { + return errors.New("expected start of object") + } + + var keys []string + + for { + token, err = decoder.Token() + if err != nil { + return err + } + + if token == json.Delim('}') { + break + } + + key, ok := token.(string) + if !ok { + return fmt.Errorf("expected string, found %T", token) + } + + keys = append(keys, key) + + if err := skipValue(decoder); err != nil { + return err + } + } + + o.Keys = keys + + return nil +} + +var errEndOfStream = errors.New("invalid end of array or object") + +func skipValue(decoder *json.Decoder) error { + + token, err := decoder.Token() + if err != nil { + return err + } + switch token { + case json.Delim('['), json.Delim('{'): + for { + if err := skipValue(decoder); err != nil { + if err == errEndOfStream { + break + } + return err + } + } + case json.Delim(']'), json.Delim('}'): + return errEndOfStream + } + return nil +} diff --git a/encoding/json_test.go b/encoding/json_test.go new file mode 100644 index 00000000..132d1587 --- /dev/null +++ b/encoding/json_test.go @@ -0,0 +1,121 @@ +package encoding + +import ( + "encoding/json" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_PopulateStructFromJSON(t *testing.T) { + type SimpleStruct struct { + FieldOne string `json:"field-one,omitempty"` + FieldTwo int `json:"field-two"` + } + + var v SimpleStruct + + data := []byte(`{"field-one": "acme", "field-two": 6}`) + + err := PopulateStructFromJSON(data, &v) + require.NoError(t, err) + assert.Equal(t, "acme", v.FieldOne) + assert.Equal(t, 6, v.FieldTwo) + + data = []byte(`{"field-two": 6}`) + v = SimpleStruct{} + + err = PopulateStructFromJSON(data, &v) + require.NoError(t, err) + assert.Equal(t, "", v.FieldOne) + assert.Equal(t, 6, v.FieldTwo) + + data = []byte(`{"field-one": "acme"}`) + v = SimpleStruct{} + + err = PopulateStructFromJSON(data, &v) + assert.EqualError(t, err, `missing mandatory field "FieldTwo" ("field-two")`) + + err = PopulateStructFromJSON([]byte("7"), &v) + assert.EqualError(t, err, `json: cannot unmarshal number into Go value of type map[string]json.RawMessage`) + + type CompositeStruct struct { + FieldThree string `json:"field-three"` + SimpleStruct + } + + var c CompositeStruct + + data = []byte(`{"field-one": "acme", "field-two": 6, "field-three": "foo"}`) + + err = PopulateStructFromJSON(data, &c) + require.NoError(t, err) + assert.Equal(t, "acme", c.FieldOne) + assert.Equal(t, 6, c.FieldTwo) + assert.Equal(t, "foo", c.FieldThree) + + res, err := SerializeStructToJSON(&c) + require.NoError(t, err) + + var c2 CompositeStruct + err = PopulateStructFromJSON(res, &c2) + require.NoError(t, err) + assert.EqualValues(t, c, c2) +} + +func Test_structFieldsJSON_CRUD(t *testing.T) { + sf := newStructFieldsJSON() + + err := sf.Add("two", json.RawMessage("2")) + assert.NoError(t, err) + + err = sf.Add("one", json.RawMessage("1")) + assert.NoError(t, err) + + err = sf.Add("three", json.RawMessage("3")) + assert.NoError(t, err) + + assert.Equal(t, []string{"two", "one", "three"}, sf.Keys) + assert.True(t, sf.Has("three")) + assert.False(t, sf.Has("four")) + + val, ok := sf.Get("two") + assert.True(t, ok) + assert.Equal(t, json.RawMessage("2"), val) + + _, ok = sf.Get("four") + assert.False(t, ok) + + sf.Delete("two") + _, ok = sf.Get("two") + assert.False(t, ok) + + err = sf.Add("one", json.RawMessage("4")) + assert.EqualError(t, err, `duplicate JSON key: "one"`) +} + +func Test_skipValue(t *testing.T) { + text := "" + decoder := json.NewDecoder(strings.NewReader(text)) + err := skipValue(decoder) + assert.EqualError(t, err, "EOF") + + text = "[]" + decoder = json.NewDecoder(strings.NewReader(text)) + _, _ = decoder.Token() // skip the '[' + err = skipValue(decoder) + assert.EqualError(t, err, "invalid end of array or object") + + text = `{"embed": {"one": 1, "two": [1,2,3]}, "other": 1}` + decoder = json.NewDecoder(strings.NewReader(text)) + _, _ = decoder.Token() // skip the '{' + _, _ = decoder.Token() // skip the '"embed"' + err = skipValue(decoder) + assert.NoError(t, err) + + token, err := decoder.Token() + assert.NoError(t, err) + assert.Equal(t, "other", token) +} diff --git a/extensions/extensions.go b/extensions/extensions.go new file mode 100644 index 00000000..4be60fd1 --- /dev/null +++ b/extensions/extensions.go @@ -0,0 +1,379 @@ +// Copyright 2023 Contributors to the Veraison project. +// SPDX-License-Identifier: Apache-2.0 +package extensions + +import ( + "errors" + "fmt" + "reflect" + "strings" + + "github.com/spf13/cast" +) + +var ErrExtensionNotFound = errors.New("extension not found") + +type IExtensionsValue any + +type Extensions struct { + IExtensionsValue `json:"extensions,omitempty"` +} + +func (o *Extensions) Register(exts IExtensionsValue) { + if reflect.TypeOf(exts).Kind() != reflect.Pointer { + panic("attempting to register a non-pointer IExtensionsValue") + } + + o.IExtensionsValue = exts +} + +func (o *Extensions) HaveExtensions() bool { + return o.IExtensionsValue != nil +} + +func (o *Extensions) Get(name string) (any, error) { + if o.IExtensionsValue == nil { + return nil, fmt.Errorf("%w: %s", ErrExtensionNotFound, name) + } + + extType := reflect.TypeOf(o.IExtensionsValue) + extVal := reflect.ValueOf(o.IExtensionsValue) + if extType.Kind() == reflect.Pointer { + extType = extType.Elem() + extVal = extVal.Elem() + } + + var fieldName, fieldJSONTag, fieldCBORTag string + for i := 0; i < extVal.NumField(); i++ { + typeField := extType.Field(i) + fieldName = typeField.Name + + tag, ok := typeField.Tag.Lookup("json") + if ok { + fieldJSONTag = strings.Split(tag, ",")[0] + } + + tag, ok = typeField.Tag.Lookup("cbor") + if ok { + fieldCBORTag = strings.Split(tag, ",")[0] + } + + if fieldName == name || fieldJSONTag == name || fieldCBORTag == name { + return extVal.Field(i).Interface(), nil + } + } + + return nil, fmt.Errorf("%w: %s", ErrExtensionNotFound, name) +} + +func (o *Extensions) MustGetString(name string) string { + v, _ := o.GetString(name) + return v +} + +func (o *Extensions) GetString(name string) (string, error) { + v, err := o.Get(name) + if err != nil { + return "", err + } + + return cast.ToStringE(v) +} + +func (o *Extensions) MustGetInt(name string) int { + v, _ := o.GetInt(name) + return v +} + +func (o *Extensions) GetInt(name string) (int, error) { + v, err := o.Get(name) + if err != nil { + return 0, err + } + + return cast.ToIntE(v) +} + +func (o *Extensions) MustGetInt64(name string) int64 { + v, _ := o.GetInt64(name) + return v +} + +func (o *Extensions) GetInt64(name string) (int64, error) { + v, err := o.Get(name) + if err != nil { + return 0, err + } + + return cast.ToInt64E(v) +} + +func (o *Extensions) MustGetInt32(name string) int32 { + v, _ := o.GetInt32(name) + return v +} + +func (o *Extensions) GetInt32(name string) (int32, error) { + v, err := o.Get(name) + if err != nil { + return 0, err + } + + return cast.ToInt32E(v) +} + +func (o *Extensions) MustGetInt16(name string) int16 { + v, _ := o.GetInt16(name) + return v +} + +func (o *Extensions) GetInt16(name string) (int16, error) { + v, err := o.Get(name) + if err != nil { + return 0, err + } + + return cast.ToInt16E(v) +} + +func (o *Extensions) MustGetInt8(name string) int8 { + v, _ := o.GetInt8(name) + return v +} + +func (o *Extensions) GetInt8(name string) (int8, error) { + v, err := o.Get(name) + if err != nil { + return 0, err + } + + return cast.ToInt8E(v) +} + +func (o *Extensions) MustGetUint(name string) uint { + v, _ := o.GetUint(name) + return v +} + +func (o *Extensions) GetUint(name string) (uint, error) { + v, err := o.Get(name) + if err != nil { + return 0, err + } + + return cast.ToUintE(v) +} + +func (o *Extensions) MustGetUint64(name string) uint64 { + v, _ := o.GetUint64(name) + return v +} + +func (o *Extensions) GetUint64(name string) (uint64, error) { + v, err := o.Get(name) + if err != nil { + return 0, err + } + + return cast.ToUint64E(v) +} + +func (o *Extensions) MustGetUint32(name string) uint32 { + v, _ := o.GetUint32(name) + return v +} + +func (o *Extensions) GetUint32(name string) (uint32, error) { + v, err := o.Get(name) + if err != nil { + return 0, err + } + + return cast.ToUint32E(v) +} + +func (o *Extensions) MustGetUint16(name string) uint16 { + v, _ := o.GetUint16(name) + return v +} + +func (o *Extensions) GetUint16(name string) (uint16, error) { + v, err := o.Get(name) + if err != nil { + return 0, err + } + + return cast.ToUint16E(v) +} + +func (o *Extensions) MustGetUint8(name string) uint8 { + v, _ := o.GetUint8(name) + return v +} + +func (o *Extensions) GetUint8(name string) (uint8, error) { + v, err := o.Get(name) + if err != nil { + return 0, err + } + + return cast.ToUint8E(v) +} + +func (o *Extensions) MustGetFloat32(name string) float32 { + v, _ := o.GetFloat32(name) + return v +} + +func (o *Extensions) GetFloat32(name string) (float32, error) { + v, err := o.Get(name) + if err != nil { + return 0, err + } + + return cast.ToFloat32E(v) +} + +func (o *Extensions) MustGetFloat64(name string) float64 { + v, _ := o.GetFloat64(name) + return v +} + +func (o *Extensions) GetFloat64(name string) (float64, error) { + v, err := o.Get(name) + if err != nil { + return 0, err + } + + return cast.ToFloat64E(v) +} + +func (o *Extensions) MustGetBool(name string) bool { + v, _ := o.GetBool(name) + return v +} + +func (o *Extensions) GetBool(name string) (bool, error) { + v, err := o.Get(name) + if err != nil { + return false, err + } + + return cast.ToBoolE(v) +} + +func (o *Extensions) MustGetSlice(name string) []any { + v, _ := o.GetSlice(name) + return v +} + +func (o *Extensions) GetSlice(name string) ([]any, error) { + v, err := o.Get(name) + if err != nil { + return []any{}, err + } + + return cast.ToSliceE(v) +} + +func (o *Extensions) MustGetIntSlice(name string) []int { + v, _ := o.GetIntSlice(name) + return v +} + +func (o *Extensions) GetIntSlice(name string) ([]int, error) { + v, err := o.Get(name) + if err != nil { + return []int{}, err + } + + return cast.ToIntSliceE(v) +} + +func (o *Extensions) MustGetStringSlice(name string) []string { + v, _ := o.GetStringSlice(name) + return v +} + +func (o *Extensions) GetStringSlice(name string) ([]string, error) { + v, err := o.Get(name) + if err != nil { + return []string{}, err + } + + return cast.ToStringSliceE(v) +} + +func (o *Extensions) MustGetStringMap(name string) map[string]any { + v, _ := o.GetStringMap(name) + return v +} + +func (o *Extensions) GetStringMap(name string) (map[string]any, error) { + v, err := o.Get(name) + if err != nil { + return map[string]any{}, err + } + + return cast.ToStringMapE(v) +} + +func (o *Extensions) MustGetStringMapString(name string) map[string]string { + v, _ := o.GetStringMapString(name) + return v +} + +func (o *Extensions) GetStringMapString(name string) (map[string]string, error) { + v, err := o.Get(name) + if err != nil { + return map[string]string{}, err + } + + return cast.ToStringMapStringE(v) +} + +func (o *Extensions) Set(name string, value any) error { + if o.IExtensionsValue == nil { + return fmt.Errorf("%w: %s", ErrExtensionNotFound, name) + } + + extType := reflect.TypeOf(o.IExtensionsValue) + extVal := reflect.ValueOf(o.IExtensionsValue) + if extType.Kind() == reflect.Pointer { + extType = extType.Elem() + extVal = extVal.Elem() + } + + var fieldName, fieldJSONTag, fieldCBORTag string + for i := 0; i < extVal.NumField(); i++ { + typeField := extType.Field(i) + valField := extVal.Field(i) + fieldName = typeField.Name + + tag, ok := typeField.Tag.Lookup("json") + if ok { + fieldJSONTag = strings.Split(tag, ",")[0] + } + + tag, ok = typeField.Tag.Lookup("cbor") + if ok { + fieldCBORTag = strings.Split(tag, ",")[0] + } + + if fieldName == name || fieldJSONTag == name || fieldCBORTag == name { + newVal := reflect.ValueOf(value) + if newVal.CanConvert(valField.Type()) { + valField.Set(newVal.Convert(valField.Type())) + return nil + } + + return fmt.Errorf( + "cannot set field %q (of type %s) to %v (%T)", + name, typeField.Type.Name(), + value, value, + ) + } + } + + return fmt.Errorf("%w: %s", ErrExtensionNotFound, name) +} diff --git a/extensions/extensions_test.go b/extensions/extensions_test.go new file mode 100644 index 00000000..74d44f52 --- /dev/null +++ b/extensions/extensions_test.go @@ -0,0 +1,122 @@ +// Copyright 2023 Contributors to the Veraison project. +// SPDX-License-Identifier: Apache-2.0 +package extensions + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +type Entity struct { + EntityName string + Roles []int64 + + Extensions +} + +type TestExtensions struct { + Address string `cbor:"-1,keyasint,omitempty" json:"address,omitempty"` + Size int `cbor:"-2,keyasint,omitempty" json:"size,omitempty"` + YearsOnAir float32 `cbor:"-3,keyasint,omitempty" json:"years-on-air,omitempty"` + StillAiring bool `cbor:"-4,keyasint,omitempty" json:"still-airing,omitempty"` + Ages []int `cbor:"-5,keyasint,omitempty" json:"ages,omitempty"` + Jobs map[string]string `cbor:"-6,keyasint,omitempty" json:"jobs,omitempty"` +} + +func TestExtensions_Register(t *testing.T) { + exts := Extensions{} + assert.False(t, exts.HaveExtensions()) + + exts.Register(&TestExtensions{}) + assert.True(t, exts.HaveExtensions()) + + badRegister := func() { + exts.Register(TestExtensions{}) + } + + assert.Panics(t, badRegister) +} + +func TestExtensions_GetSet(t *testing.T) { + extsVal := TestExtensions{ + Address: "742 Evergreen Terrace", + Size: 6, + YearsOnAir: 33.8, + StillAiring: true, + Ages: []int{2, 7, 8, 10, 37, 38}, + Jobs: map[string]string{ + "Homer": "safety inspector", + "Marge": "housewife", + "Bart": "elementary school student", + "Lisa": "elementary school student", + }, + } + exts := Extensions{IExtensionsValue: &extsVal} + + v, err := exts.GetInt("size") + assert.NoError(t, err) + assert.Equal(t, 6, v) + + assert.Equal(t, 6, exts.MustGetInt("size")) + assert.Equal(t, int64(6), exts.MustGetInt64("size")) + assert.Equal(t, int32(6), exts.MustGetInt32("size")) + assert.Equal(t, int16(6), exts.MustGetInt16("size")) + assert.Equal(t, int8(6), exts.MustGetInt8("size")) + + assert.Equal(t, uint(6), exts.MustGetUint("size")) + assert.Equal(t, uint64(6), exts.MustGetUint64("size")) + assert.Equal(t, uint32(6), exts.MustGetUint32("size")) + assert.Equal(t, uint16(6), exts.MustGetUint16("size")) + assert.Equal(t, uint8(6), exts.MustGetUint8("size")) + + assert.InEpsilon(t, float32(33.8), exts.MustGetFloat32("years-on-air"), 0.000001) + assert.InEpsilon(t, float64(33.8), exts.MustGetFloat64("-3"), 0.000001) + + assert.Equal(t, true, exts.MustGetBool("StillAiring")) + + _, err = exts.GetSlice("ages") + assert.EqualError(t, err, + `unable to cast []int{2, 7, 8, 10, 37, 38} of type []int to []interface{}`) + assert.Nil(t, exts.MustGetSlice("ages")) + + assert.EqualValues(t, []int{2, 7, 8, 10, 37, 38}, exts.MustGetIntSlice("ages")) + assert.EqualValues(t, []string{"2", "7", "8", "10", "37", "38"}, + exts.MustGetStringSlice("ages")) + + assert.EqualValues(t, map[string]string{ + "Homer": "safety inspector", + "Marge": "housewife", + "Bart": "elementary school student", + "Lisa": "elementary school student", + }, exts.MustGetStringMapString("jobs")) + + _, err = exts.GetStringMap("jobs") + assert.EqualError(t, err, + `unable to cast map[string]string{"Bart":"elementary school student", "Homer":"safety inspector", "Lisa":"elementary school student", "Marge":"housewife"} of type map[string]string to map[string]interface{}`) + m := exts.MustGetStringMap("jobs") + assert.Equal(t, map[string]any{}, m) + + s, err := exts.GetString("address") + assert.NoError(t, err) + assert.Equal(t, "742 Evergreen Terrace", s) + + _, err = exts.GetInt("address") + assert.EqualError(t, err, `unable to cast "742 Evergreen Terrace" of type string to int`) + + _, err = exts.GetInt("foo") + assert.EqualError(t, err, "extension not found: foo") + + err = exts.Set("-1", "123 Fake Street") + assert.NoError(t, err) + + s, err = exts.GetString("address") + assert.NoError(t, err) + assert.Equal(t, "123 Fake Street", s) + + err = exts.Set("Size", "foo") + assert.EqualError(t, err, `cannot set field "Size" (of type int) to foo (string)`) + + assert.Equal(t, "", exts.MustGetString("does-not-exist")) + assert.Equal(t, 0, exts.MustGetInt("does-not-exist")) +} diff --git a/go.mod b/go.mod index 9d95f047..96a5aa49 100644 --- a/go.mod +++ b/go.mod @@ -3,11 +3,12 @@ module github.com/veraison/corim go 1.18 require ( - github.com/fxamacker/cbor/v2 v2.4.0 + github.com/fxamacker/cbor/v2 v2.5.0 github.com/golang/mock v1.6.0 github.com/google/uuid v1.3.0 github.com/lestrrat-go/jwx/v2 v2.0.8 github.com/spf13/afero v1.9.2 + github.com/spf13/cast v1.4.1 github.com/spf13/cobra v1.2.1 github.com/spf13/viper v1.9.0 github.com/stretchr/testify v1.8.2 @@ -35,7 +36,6 @@ require ( github.com/moogar0880/problems v0.1.1 // indirect github.com/pelletier/go-toml v1.9.4 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - github.com/spf13/cast v1.4.1 // indirect github.com/spf13/jwalterweatherman v1.1.0 // indirect github.com/spf13/pflag v1.0.5 // indirect github.com/subosito/gotenv v1.2.0 // indirect diff --git a/go.sum b/go.sum index c6347523..e1ae4e55 100644 --- a/go.sum +++ b/go.sum @@ -90,8 +90,8 @@ github.com/fsnotify/fsnotify v1.5.1 h1:mZcQUHVQUQWoPXXtuf9yuEXKudkV2sx1E06UadKWp github.com/fsnotify/fsnotify v1.5.1/go.mod h1:T3375wBYaZdLLcVNkcVbzGHY7f1l/uK5T5Ai1i3InKU= github.com/fxamacker/cbor/v2 v2.2.0/go.mod h1:TA1xS00nchWmaBnEIxPSE5oHLuJBAVvqrtAnWBwBCVo= github.com/fxamacker/cbor/v2 v2.3.0/go.mod h1:TA1xS00nchWmaBnEIxPSE5oHLuJBAVvqrtAnWBwBCVo= -github.com/fxamacker/cbor/v2 v2.4.0 h1:ri0ArlOR+5XunOP8CRUowT0pSJOwhW098ZCUyskZD88= -github.com/fxamacker/cbor/v2 v2.4.0/go.mod h1:TA1xS00nchWmaBnEIxPSE5oHLuJBAVvqrtAnWBwBCVo= +github.com/fxamacker/cbor/v2 v2.5.0 h1:oHsG0V/Q6E/wqTS2O1Cozzsy69nqCiguo5Q1a1ADivE= +github.com/fxamacker/cbor/v2 v2.5.0/go.mod h1:TA1xS00nchWmaBnEIxPSE5oHLuJBAVvqrtAnWBwBCVo= github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=