From 663624980e6215261e3584bf6ca29a59ae8ab782 Mon Sep 17 00:00:00 2001 From: Sergei Trofimov Date: Fri, 6 Oct 2023 10:58:36 +0100 Subject: [PATCH] Implement struct extensions Implement struct extension mechanism and add extension points in places specified by https://www.ietf.org/archive/id/draft-ietf-rats-corim-02.html Extensions are implemented via Extensions object that allows registering user-defined structs are extensions. Fields from these user-defined structs are merged into the parent struct that embeds the Extensions object via some custom serialization code (this required updating the cbor dependency to v2.5.0, and go version to 1.19). Extensions also implements Viper*-like accessor methods for convenient access to the extension fields from the parent struct. (* see github.com/spf13/viper) Signed-off-by: Sergei Trofimov --- .github/workflows/ci-go-cover.yml | 2 +- .github/workflows/linters.yml | 2 +- Makefile | 2 + cocli/cmd/corimCreate_test.go | 4 +- cocli/cmd/corimDisplay_test.go | 2 +- cocli/cmd/corimExtract_test.go | 2 +- cocli/cmd/corimSign_test.go | 4 +- comid/comid.go | 24 +- comid/comid_test.go | 51 ++++ comid/cryptokey_test.go | 2 +- comid/entity.go | 46 +++- comid/example_test.go | 14 +- comid/extensions.go | 173 +++++++++++++ comid/extensions_test.go | 96 +++++++ comid/flagsmap.go | 46 +++- comid/measurement.go | 10 + comid/measurement_test.go | 2 +- comid/referencevalue_test.go | 22 ++ comid/triples.go | 41 ++- corim/entity.go | 36 ++- corim/extensions.go | 68 +++++ corim/extensions_test.go | 86 +++++++ corim/meta.go | 36 ++- corim/meta_test.go | 12 + corim/signedcorim_test.go | 2 +- corim/unsignedcorim.go | 37 ++- corim/unsignedcorim_test.go | 8 + encoding/cbor.go | 409 ++++++++++++++++++++++++++++++ encoding/cbor_test.go | 279 ++++++++++++++++++++ encoding/embedded.go | 49 ++++ encoding/json.go | 319 +++++++++++++++++++++++ encoding/json_test.go | 121 +++++++++ extensions/extensions.go | 379 +++++++++++++++++++++++++++ extensions/extensions_test.go | 122 +++++++++ go.mod | 4 +- go.sum | 4 +- 36 files changed, 2473 insertions(+), 43 deletions(-) create mode 100644 comid/comid_test.go create mode 100644 comid/extensions.go create mode 100644 comid/extensions_test.go create mode 100644 comid/referencevalue_test.go create mode 100644 corim/extensions.go create mode 100644 corim/extensions_test.go create mode 100644 encoding/cbor.go create mode 100644 encoding/cbor_test.go create mode 100644 encoding/embedded.go create mode 100644 encoding/json.go create mode 100644 encoding/json_test.go create mode 100644 extensions/extensions.go create mode 100644 extensions/extensions_test.go diff --git a/.github/workflows/ci-go-cover.yml b/.github/workflows/ci-go-cover.yml index 94b48854..9dce92ca 100644 --- a/.github/workflows/ci-go-cover.yml +++ b/.github/workflows/ci-go-cover.yml @@ -26,7 +26,7 @@ jobs: steps: - uses: actions/setup-go@v2 with: - go-version: "1.18" + go-version: "1.19" - name: Checkout code uses: actions/checkout@v2 - name: Install mockgen diff --git a/.github/workflows/linters.yml b/.github/workflows/linters.yml index 665e7048..92368c61 100644 --- a/.github/workflows/linters.yml +++ b/.github/workflows/linters.yml @@ -10,7 +10,7 @@ jobs: steps: - uses: actions/setup-go@v2 with: - go-version: "1.18" + go-version: "1.19" - name: Checkout code uses: actions/checkout@v2 - name: Install golangci-lint diff --git a/Makefile b/Makefile index de492694..f5e8bbdb 100644 --- a/Makefile +++ b/Makefile @@ -8,6 +8,8 @@ GOPKG := github.com/veraison/corim/corim GOPKG += github.com/veraison/corim/comid GOPKG += github.com/veraison/corim/cots GOPKG += github.com/veraison/corim/cocli/cmd +GOPKG += github.com/veraison/corim/encoding +GOPKG += github.com/veraison/corim/extensions MOCKGEN := $(shell go env GOPATH)/bin/mockgen INTERFACES := cocli/cmd/isubmitter.go diff --git a/cocli/cmd/corimCreate_test.go b/cocli/cmd/corimCreate_test.go index 33f112b8..b2dfb770 100644 --- a/cocli/cmd/corimCreate_test.go +++ b/cocli/cmd/corimCreate_test.go @@ -117,7 +117,7 @@ func Test_CorimCreateCmd_with_a_bad_comid(t *testing.T) { cmd.SetArgs(args) err = cmd.Execute() - assert.EqualError(t, err, `error loading CoMID from bad-comid.cbor: cbor: unexpected "break" code`) + assert.EqualError(t, err, `error loading CoMID from bad-comid.cbor: expected map (CBOR Major Type 5), found Major Type 7`) } func Test_CorimCreateCmd_with_an_invalid_comid(t *testing.T) { @@ -138,7 +138,7 @@ func Test_CorimCreateCmd_with_an_invalid_comid(t *testing.T) { cmd.SetArgs(args) err = cmd.Execute() - assert.EqualError(t, err, `error adding CoMID from invalid-comid.cbor (check its validity using the "comid validate" sub-command)`) + assert.EqualError(t, err, `error loading CoMID from invalid-comid.cbor: missing mandatory field "Triples" (4)`) } func Test_CorimCreateCmd_with_a_bad_coswid(t *testing.T) { diff --git a/cocli/cmd/corimDisplay_test.go b/cocli/cmd/corimDisplay_test.go index 7f87cf25..5e1ccb47 100644 --- a/cocli/cmd/corimDisplay_test.go +++ b/cocli/cmd/corimDisplay_test.go @@ -76,7 +76,7 @@ func Test_CorimDisplayCmd_invalid_signed_corim(t *testing.T) { require.NoError(t, err) err = cmd.Execute() - assert.EqualError(t, err, "error decoding signed CoRIM from invalid.cbor: failed validation of unsigned CoRIM: empty id") + assert.EqualError(t, err, `error decoding signed CoRIM from invalid.cbor: failed CBOR decoding of unsigned CoRIM: unexpected EOF`) } func Test_CorimDisplayCmd_ok_top_level_view(t *testing.T) { diff --git a/cocli/cmd/corimExtract_test.go b/cocli/cmd/corimExtract_test.go index 2ec748d1..8b476c98 100644 --- a/cocli/cmd/corimExtract_test.go +++ b/cocli/cmd/corimExtract_test.go @@ -76,7 +76,7 @@ func Test_CorimExtractCmd_invalid_signed_corim(t *testing.T) { require.NoError(t, err) err = cmd.Execute() - assert.EqualError(t, err, "error decoding signed CoRIM from invalid.cbor: failed validation of unsigned CoRIM: empty id") + assert.EqualError(t, err, `error decoding signed CoRIM from invalid.cbor: failed CBOR decoding of unsigned CoRIM: unexpected EOF`) } func Test_CorimExtractCmd_ok_save_to_default_dir(t *testing.T) { diff --git a/cocli/cmd/corimSign_test.go b/cocli/cmd/corimSign_test.go index 27956a4f..bc460376 100644 --- a/cocli/cmd/corimSign_test.go +++ b/cocli/cmd/corimSign_test.go @@ -91,7 +91,7 @@ func Test_CorimSignCmd_bad_unsigned_corim(t *testing.T) { require.NoError(t, err) err = cmd.Execute() - assert.EqualError(t, err, "error decoding unsigned CoRIM from bad.txt: unexpected EOF") + assert.EqualError(t, err, "error decoding unsigned CoRIM from bad.txt: expected map (CBOR Major Type 5), found Major Type 3") } func Test_CorimSignCmd_invalid_unsigned_corim(t *testing.T) { @@ -109,7 +109,7 @@ func Test_CorimSignCmd_invalid_unsigned_corim(t *testing.T) { require.NoError(t, err) err = cmd.Execute() - assert.EqualError(t, err, "error validating CoRIM: tags validation failed: no tags") + assert.EqualError(t, err, `error decoding unsigned CoRIM from invalid.cbor: missing mandatory field "Tags" (1)`) } func Test_CorimSignCmd_non_existent_meta_file(t *testing.T) { diff --git a/comid/comid.go b/comid/comid.go index ea9779c7..194745d3 100644 --- a/comid/comid.go +++ b/comid/comid.go @@ -8,6 +8,8 @@ import ( "fmt" "net/url" + "github.com/veraison/corim/encoding" + "github.com/veraison/corim/extensions" "github.com/veraison/swid" ) @@ -19,6 +21,8 @@ type Comid struct { Entities *Entities `cbor:"2,keyasint,omitempty" json:"entities,omitempty"` LinkedTags *LinkedTags `cbor:"3,keyasint,omitempty" json:"linked-tags,omitempty"` Triples Triples `cbor:"4,keyasint" json:"triples"` + + Extensions } // NewComid instantiates an empty Comid @@ -26,6 +30,16 @@ func NewComid() *Comid { return &Comid{} } +// RegisterExtensions registers a struct as a collections of extensions +func (o *Comid) RegisterExtensions(exts extensions.IExtensionsValue) { + o.Extensions.Register(exts) +} + +// GetExtensions returns pervisouosly registered extension +func (o *Comid) GetExtensions() extensions.IExtensionsValue { + return o.Extensions.IExtensionsValue +} + // SetLanguage sets the language used in the target Comid to the supplied // language tag. See also: BCP 47 and the IANA Language subtag registry. func (o *Comid) SetLanguage(language string) *Comid { @@ -229,7 +243,7 @@ func (o Comid) Valid() error { return fmt.Errorf("triples validation failed: %w", err) } - return nil + return o.Extensions.ValidComid(&o) } // ToCBOR serializes the target Comid to CBOR @@ -238,17 +252,17 @@ func (o Comid) ToCBOR() ([]byte, error) { return nil, err } - return em.Marshal(&o) + return encoding.SerializeStructToCBOR(em, &o) } // FromCBOR deserializes a CBOR-encoded CoMID into the target Comid func (o *Comid) FromCBOR(data []byte) error { - return dm.Unmarshal(data, o) + return encoding.PopulateStructFromCBOR(dm, data, o) } // FromJSON deserializes a JSON-encoded CoMID into the target Comid func (o *Comid) FromJSON(data []byte) error { - return json.Unmarshal(data, o) + return encoding.PopulateStructFromJSON(data, o) } // ToJSON serializes the target Comid to JSON @@ -257,7 +271,7 @@ func (o Comid) ToJSON() ([]byte, error) { return nil, err } - return json.Marshal(&o) + return encoding.SerializeStructToJSON(&o) } func (o Comid) ToJSONPretty(indent string) ([]byte, error) { diff --git a/comid/comid_test.go b/comid/comid_test.go new file mode 100644 index 00000000..a8cdf198 --- /dev/null +++ b/comid/comid_test.go @@ -0,0 +1,51 @@ +package comid + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/veraison/swid" +) + +func Test_Comid_Extensions(t *testing.T) { + c := NewComid() + assert.Nil(t, c.GetExtensions()) + assert.Equal(t, "", c.MustGetString("field-one")) + + err := c.Set("field-one", "foo") + assert.EqualError(t, err, "extension not found: field-one") + + type ComidExt struct { + FieldOne string `cbor:"-1,keyasint" json:"field-one"` + } + + c.RegisterExtensions(&ComidExt{}) + + err = c.Set("field-one", "foo") + assert.NoError(t, err) + assert.Equal(t, "foo", c.MustGetString("-1")) +} + +func Test_Comid_ToJSONPretty(t *testing.T) { + c := NewComid() + + _, err := c.ToJSONPretty(" ") + assert.EqualError(t, err, "tag-identity validation failed: empty tag-id") + + c.TagIdentity = TagIdentity{TagID: *swid.NewTagID("test"), TagVersion: 1} + c.Triples = Triples{ReferenceValues: &[]ReferenceValue{}} + + expected := `{ + "tag-identity": { + "id": "test", + "version": 1 + }, + "triples": { + "reference-values": [] + } +}` + v, err := c.ToJSONPretty(" ") + require.NoError(t, err) + assert.Equal(t, expected, string(v)) +} diff --git a/comid/cryptokey_test.go b/comid/cryptokey_test.go index e1a3b4f8..1c0812fc 100644 --- a/comid/cryptokey_test.go +++ b/comid/cryptokey_test.go @@ -119,7 +119,7 @@ func Test_CryptoKey_NewCOSEKey(t *testing.T) { assert.EqualError(t, err, "empty COSE_Key bytes") _, err = NewCOSEKey([]byte("DEADBEEF")) - assert.Contains(t, err.Error(), "cbor: cannot unmarshal") + assert.Contains(t, err.Error(), "cbor: 3 bytes of extraneous data starting at index 5") badKey := []byte{ // taken from go-cose unit tests 0xa2, // map(2) diff --git a/comid/entity.go b/comid/entity.go index 694fef6c..278b7e52 100644 --- a/comid/entity.go +++ b/comid/entity.go @@ -3,7 +3,12 @@ package comid -import "fmt" +import ( + "fmt" + + "github.com/veraison/corim/encoding" + "github.com/veraison/corim/extensions" +) type TaggedURI string @@ -16,8 +21,21 @@ type Entity struct { EntityName string `cbor:"0,keyasint" json:"name"` RegID *TaggedURI `cbor:"1,keyasint,omitempty" json:"regid,omitempty"` Roles Roles `cbor:"2,keyasint" json:"roles"` + + Extensions +} + +// RegisterExtensions registers a struct as a collections of extensions +func (o *Entity) RegisterExtensions(exts extensions.IExtensionsValue) { + o.Extensions.Register(exts) } +// GetExtensions returns pervisouosly registered extension +func (o *Entity) GetExtensions() extensions.IExtensionsValue { + return o.Extensions.IExtensionsValue +} + +// SetEntityName is used to set the EntityName field of Entity using supplied name func (o *Entity) SetEntityName(name string) *Entity { if o != nil { if name == "" { @@ -28,6 +46,7 @@ func (o *Entity) SetEntityName(name string) *Entity { return o } +// SetRegID is used to set the RegID field of Entity using supplied uri func (o *Entity) SetRegID(uri string) *Entity { if o != nil { if uri == "" { @@ -39,6 +58,7 @@ func (o *Entity) SetRegID(uri string) *Entity { return o } +// SetRoles appends the supplied roles to the target entity. func (o *Entity) SetRoles(roles ...Role) *Entity { if o != nil { o.Roles.Add(roles...) @@ -46,6 +66,7 @@ func (o *Entity) SetRoles(roles ...Role) *Entity { return o } +// Valid checks for validity of the fields within each Entity func (o Entity) Valid() error { if o.EntityName == "" { return fmt.Errorf("invalid entity: empty entity-name") @@ -59,7 +80,27 @@ func (o Entity) Valid() error { return fmt.Errorf("invalid entity: %w", err) } - return nil + return o.Extensions.ValidEntity(&o) +} + +// UnmarshalCBOR deserializes from CBOR +func (o *Entity) UnmarshalCBOR(data []byte) error { + return encoding.PopulateStructFromCBOR(dm, data, o) +} + +// MarshalCBOR serializes to CBOR +func (o *Entity) MarshalCBOR() ([]byte, error) { + return encoding.SerializeStructToCBOR(em, o) +} + +// UnmarshalJSON deserializes from JSON +func (o *Entity) UnmarshalJSON(data []byte) error { + return encoding.PopulateStructFromJSON(data, o) +} + +// MarshalJSON serializes to JSON +func (o *Entity) MarshalJSON() ([]byte, error) { + return encoding.SerializeStructToJSON(o) } // Entities is an array of entity-map's @@ -78,6 +119,7 @@ func (o *Entities) AddEntity(e Entity) *Entities { return o } +// Valid iterates over the range of individual entities to check for validity func (o Entities) Valid() error { for i, m := range o { if err := m.Valid(); err != nil { diff --git a/comid/example_test.go b/comid/example_test.go index 7f159d44..756829d3 100644 --- a/comid/example_test.go +++ b/comid/example_test.go @@ -37,7 +37,8 @@ func Example_encode() { SetSVN(2). AddDigest(swid.Sha256_32, []byte{0xab, 0xcd, 0xef, 0x00}). AddDigest(swid.Sha256_32, []byte{0xff, 0xff, 0xff, 0xff}). - SetOpFlags(OpFlagNotSecure, OpFlagDebug). + SetFlagsTrue(FlagIsDebug). + SetFlagsFalse(FlagIsSecure). SetSerialNumber("C02X70VHJHD5"). SetUEID(TestUEID). SetUUID(TestUUID). @@ -65,7 +66,8 @@ func Example_encode() { SetMinSVN(2). AddDigest(swid.Sha256_32, []byte{0xab, 0xcd, 0xef, 0x00}). AddDigest(swid.Sha256_32, []byte{0xff, 0xff, 0xff, 0xff}). - SetOpFlags(OpFlagNotSecure, OpFlagDebug, OpFlagNotConfigured). + SetFlagsTrue(FlagIsDebug). + SetFlagsFalse(FlagIsSecure, FlagIsConfigured). SetSerialNumber("C02X70VHJHD5"). SetUEID(TestUEID). SetUUID(TestUUID). @@ -107,8 +109,8 @@ func Example_encode() { } // Output: - //a50065656e2d474201a10078206d792d6e733a61636d652d726f616472756e6e65722d737570706c656d656e740282a3006941434d45204c74642e01d8207468747470733a2f2f61636d652e6578616d706c6502820100a20069454d4341204c74642e0281020382a200781a6d792d6e733a61636d652d726f616472756e6e65722d626173650100a20078196d792d6e733a61636d652d726f616472756e6e65722d6f6c64010104a4008182a300a500d86f445502c000016941434d45204c74642e026a526f616452756e6e65720300040101d902264702deadbeefdead02d8255031fb5abf023e4992aa4e95f9c1503bfa81a200d8255031fb5abf023e4992aa4e95f9c1503bfa01aa01d90228020282820644abcdef00820644ffffffff030a04d9023044010203040544ffffffff064802005e1000000001075020010db8000000000000000000000068086c43303258373056484a484435094702deadbeefdead0a5031fb5abf023e4992aa4e95f9c1503bfa018182a300a500d8255031fb5abf023e4992aa4e95f9c1503bfa016941434d45204c74642e026a526f616452756e6e65720300040101d902264702deadbeefdead02d8255031fb5abf023e4992aa4e95f9c1503bfa81a200d8255031fb5abf023e4992aa4e95f9c1503bfa01aa01d90229020282820644abcdef00820644ffffffff030b04d9023044010203040544ffffffff064802005e1000000001075020010db8000000000000000000000068086c43303258373056484a484435094702deadbeefdead0a5031fb5abf023e4992aa4e95f9c1503bfa028182a101d8255031fb5abf023e4992aa4e95f9c1503bfa81d9022a78b12d2d2d2d2d424547494e205055424c4943204b45592d2d2d2d2d0a4d466b77457759484b6f5a497a6a3043415159494b6f5a497a6a304441516344516741455731427671462b2f727938425761375a454d553178595948455138420a6c4c54344d46484f614f2b4943547449767245654570722f7366544150363648326843486462354845584b74524b6f6436514c634f4c504131513d3d0a2d2d2d2d2d454e44205055424c4943204b45592d2d2d2d2d038182a101d902264702deadbeefdead81d9022a78b12d2d2d2d2d424547494e205055424c4943204b45592d2d2d2d2d0a4d466b77457759484b6f5a497a6a3043415159494b6f5a497a6a304441516344516741455731427671462b2f727938425761375a454d553178595948455138420a6c4c54344d46484f614f2b4943547449767245654570722f7366544150363648326843486462354845584b74524b6f6436514c634f4c504131513d3d0a2d2d2d2d2d454e44205055424c4943204b45592d2d2d2d2d - //{"lang":"en-GB","tag-identity":{"id":"my-ns:acme-roadrunner-supplement"},"entities":[{"name":"ACME Ltd.","regid":"https://acme.example","roles":["creator","tagCreator"]},{"name":"EMCA Ltd.","roles":["maintainer"]}],"linked-tags":[{"target":"my-ns:acme-roadrunner-base","rel":"supplements"},{"target":"my-ns:acme-roadrunner-old","rel":"replaces"}],"triples":{"reference-values":[{"environment":{"class":{"id":{"type":"oid","value":"2.5.2.8192"},"vendor":"ACME Ltd.","model":"RoadRunner","layer":0,"index":1},"instance":{"type":"ueid","value":"At6tvu/erQ=="},"group":{"type":"ueid","value":"31fb5abf-023e-4992-aa4e-95f9c1503bfa"}},"measurements":[{"key":{"type":"uuid","value":"31fb5abf-023e-4992-aa4e-95f9c1503bfa"},"value":{"svn":{"type":"exact-value","value":2},"digests":["sha-256-32;q83vAA==","sha-256-32;/////w=="],"op-flags":["notSecure","debug"],"raw-value":{"type":"bytes","value":"AQIDBA=="},"raw-value-mask":"/////w==","mac-addr":"02:00:5e:10:00:00:00:01","ip-addr":"2001:db8::68","serial-number":"C02X70VHJHD5","ueid":"At6tvu/erQ==","uuid":"31fb5abf-023e-4992-aa4e-95f9c1503bfa"}}]}],"endorsed-values":[{"environment":{"class":{"id":{"type":"uuid","value":"31fb5abf-023e-4992-aa4e-95f9c1503bfa"},"vendor":"ACME Ltd.","model":"RoadRunner","layer":0,"index":1},"instance":{"type":"ueid","value":"At6tvu/erQ=="},"group":{"type":"ueid","value":"31fb5abf-023e-4992-aa4e-95f9c1503bfa"}},"measurements":[{"key":{"type":"uuid","value":"31fb5abf-023e-4992-aa4e-95f9c1503bfa"},"value":{"svn":{"type":"min-value","value":2},"digests":["sha-256-32;q83vAA==","sha-256-32;/////w=="],"op-flags":["notConfigured","notSecure","debug"],"raw-value":{"type":"bytes","value":"AQIDBA=="},"raw-value-mask":"/////w==","mac-addr":"02:00:5e:10:00:00:00:01","ip-addr":"2001:db8::68","serial-number":"C02X70VHJHD5","ueid":"At6tvu/erQ==","uuid":"31fb5abf-023e-4992-aa4e-95f9c1503bfa"}}]}],"attester-verification-keys":[{"environment":{"instance":{"type":"uuid","value":"31fb5abf-023e-4992-aa4e-95f9c1503bfa"}},"verification-keys":[{"type":"pkix-base64-key","value":"-----BEGIN PUBLIC KEY-----\nMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEW1BvqF+/ry8BWa7ZEMU1xYYHEQ8B\nlLT4MFHOaO+ICTtIvrEeEpr/sfTAP66H2hCHdb5HEXKtRKod6QLcOLPA1Q==\n-----END PUBLIC KEY-----"}]}],"dev-identity-keys":[{"environment":{"instance":{"type":"ueid","value":"At6tvu/erQ=="}},"verification-keys":[{"type":"pkix-base64-key","value":"-----BEGIN PUBLIC KEY-----\nMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEW1BvqF+/ry8BWa7ZEMU1xYYHEQ8B\nlLT4MFHOaO+ICTtIvrEeEpr/sfTAP66H2hCHdb5HEXKtRKod6QLcOLPA1Q==\n-----END PUBLIC KEY-----"}]}]}} + // a50065656e2d474201a10078206d792d6e733a61636d652d726f616472756e6e65722d737570706c656d656e740282a3006941434d45204c74642e01d8207468747470733a2f2f61636d652e6578616d706c6502820100a20069454d4341204c74642e0281020382a200781a6d792d6e733a61636d652d726f616472756e6e65722d626173650100a20078196d792d6e733a61636d652d726f616472756e6e65722d6f6c64010104a4008182a300a500d86f445502c000016941434d45204c74642e026a526f616452756e6e65720300040101d902264702deadbeefdead02d8255031fb5abf023e4992aa4e95f9c1503bfa81a200d8255031fb5abf023e4992aa4e95f9c1503bfa01aa01d90228020282820644abcdef00820644ffffffff03a201f403f504d9023044010203040544ffffffff064802005e1000000001075020010db8000000000000000000000068086c43303258373056484a484435094702deadbeefdead0a5031fb5abf023e4992aa4e95f9c1503bfa018182a300a500d8255031fb5abf023e4992aa4e95f9c1503bfa016941434d45204c74642e026a526f616452756e6e65720300040101d902264702deadbeefdead02d8255031fb5abf023e4992aa4e95f9c1503bfa81a200d8255031fb5abf023e4992aa4e95f9c1503bfa01aa01d90229020282820644abcdef00820644ffffffff03a300f401f403f504d9023044010203040544ffffffff064802005e1000000001075020010db8000000000000000000000068086c43303258373056484a484435094702deadbeefdead0a5031fb5abf023e4992aa4e95f9c1503bfa028182a101d8255031fb5abf023e4992aa4e95f9c1503bfa81d9022a78b12d2d2d2d2d424547494e205055424c4943204b45592d2d2d2d2d0a4d466b77457759484b6f5a497a6a3043415159494b6f5a497a6a304441516344516741455731427671462b2f727938425761375a454d553178595948455138420a6c4c54344d46484f614f2b4943547449767245654570722f7366544150363648326843486462354845584b74524b6f6436514c634f4c504131513d3d0a2d2d2d2d2d454e44205055424c4943204b45592d2d2d2d2d038182a101d902264702deadbeefdead81d9022a78b12d2d2d2d2d424547494e205055424c4943204b45592d2d2d2d2d0a4d466b77457759484b6f5a497a6a3043415159494b6f5a497a6a304441516344516741455731427671462b2f727938425761375a454d553178595948455138420a6c4c54344d46484f614f2b4943547449767245654570722f7366544150363648326843486462354845584b74524b6f6436514c634f4c504131513d3d0a2d2d2d2d2d454e44205055424c4943204b45592d2d2d2d2d + // {"lang":"en-GB","tag-identity":{"id":"my-ns:acme-roadrunner-supplement"},"entities":[{"name":"ACME Ltd.","regid":"https://acme.example","roles":["creator","tagCreator"]},{"name":"EMCA Ltd.","roles":["maintainer"]}],"linked-tags":[{"target":"my-ns:acme-roadrunner-base","rel":"supplements"},{"target":"my-ns:acme-roadrunner-old","rel":"replaces"}],"triples":{"reference-values":[{"environment":{"class":{"id":{"type":"oid","value":"2.5.2.8192"},"vendor":"ACME Ltd.","model":"RoadRunner","layer":0,"index":1},"instance":{"type":"ueid","value":"At6tvu/erQ=="},"group":{"type":"ueid","value":"31fb5abf-023e-4992-aa4e-95f9c1503bfa"}},"measurements":[{"key":{"type":"uuid","value":"31fb5abf-023e-4992-aa4e-95f9c1503bfa"},"value":{"svn":{"type":"exact-value","value":2},"digests":["sha-256-32;q83vAA==","sha-256-32;/////w=="],"flags":{"is-secure":false,"is-debug":true},"raw-value":{"type":"bytes","value":"AQIDBA=="},"raw-value-mask":"/////w==","mac-addr":"02:00:5e:10:00:00:00:01","ip-addr":"2001:db8::68","serial-number":"C02X70VHJHD5","ueid":"At6tvu/erQ==","uuid":"31fb5abf-023e-4992-aa4e-95f9c1503bfa"}}]}],"endorsed-values":[{"environment":{"class":{"id":{"type":"uuid","value":"31fb5abf-023e-4992-aa4e-95f9c1503bfa"},"vendor":"ACME Ltd.","model":"RoadRunner","layer":0,"index":1},"instance":{"type":"ueid","value":"At6tvu/erQ=="},"group":{"type":"ueid","value":"31fb5abf-023e-4992-aa4e-95f9c1503bfa"}},"measurements":[{"key":{"type":"uuid","value":"31fb5abf-023e-4992-aa4e-95f9c1503bfa"},"value":{"svn":{"type":"min-value","value":2},"digests":["sha-256-32;q83vAA==","sha-256-32;/////w=="],"flags":{"is-configured":false,"is-secure":false,"is-debug":true},"raw-value":{"type":"bytes","value":"AQIDBA=="},"raw-value-mask":"/////w==","mac-addr":"02:00:5e:10:00:00:00:01","ip-addr":"2001:db8::68","serial-number":"C02X70VHJHD5","ueid":"At6tvu/erQ==","uuid":"31fb5abf-023e-4992-aa4e-95f9c1503bfa"}}]}],"attester-verification-keys":[{"environment":{"instance":{"type":"uuid","value":"31fb5abf-023e-4992-aa4e-95f9c1503bfa"}},"verification-keys":[{"type":"pkix-base64-key","value":"-----BEGIN PUBLIC KEY-----\nMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEW1BvqF+/ry8BWa7ZEMU1xYYHEQ8B\nlLT4MFHOaO+ICTtIvrEeEpr/sfTAP66H2hCHdb5HEXKtRKod6QLcOLPA1Q==\n-----END PUBLIC KEY-----"}]}],"dev-identity-keys":[{"environment":{"instance":{"type":"ueid","value":"At6tvu/erQ=="}},"verification-keys":[{"type":"pkix-base64-key","value":"-----BEGIN PUBLIC KEY-----\nMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEW1BvqF+/ry8BWa7ZEMU1xYYHEQ8B\nlLT4MFHOaO+ICTtIvrEeEpr/sfTAP66H2hCHdb5HEXKtRKod6QLcOLPA1Q==\n-----END PUBLIC KEY-----"}]}]}} } func Example_encode_PSA() { @@ -162,8 +164,8 @@ func Example_encode_PSA() { } // Output: - //a301a10078206d792d6e733a61636d652d726f616472756e6e65722d737570706c656d656e740281a3006941434d45204c74642e01d8207468747470733a2f2f61636d652e6578616d706c65028301000204a2008182a100a300d90258582061636d652d696d706c656d656e746174696f6e2d69642d303030303030303031016941434d45204c74642e026e526f616452756e6e657220322e3082a200d90259a30162424c0465352e302e35055820acbb11c7e4da217205523ce4ce1a245ae1a239ae3c6bfd9e7871f7e5d8bae86b01a10281820644abcdef00a200d90259a3016450526f540465312e332e35055820acbb11c7e4da217205523ce4ce1a245ae1a239ae3c6bfd9e7871f7e5d8bae86b01a10281820644abcdef00028182a101d902264702deadbeefdead81d9022a78b12d2d2d2d2d424547494e205055424c4943204b45592d2d2d2d2d0a4d466b77457759484b6f5a497a6a3043415159494b6f5a497a6a304441516344516741455731427671462b2f727938425761375a454d553178595948455138420a6c4c54344d46484f614f2b4943547449767245654570722f7366544150363648326843486462354845584b74524b6f6436514c634f4c504131513d3d0a2d2d2d2d2d454e44205055424c4943204b45592d2d2d2d2d - //{"tag-identity":{"id":"my-ns:acme-roadrunner-supplement"},"entities":[{"name":"ACME Ltd.","regid":"https://acme.example","roles":["creator","tagCreator","maintainer"]}],"triples":{"reference-values":[{"environment":{"class":{"id":{"type":"psa.impl-id","value":"YWNtZS1pbXBsZW1lbnRhdGlvbi1pZC0wMDAwMDAwMDE="},"vendor":"ACME Ltd.","model":"RoadRunner 2.0"}},"measurements":[{"key":{"type":"psa.refval-id","value":{"label":"BL","version":"5.0.5","signer-id":"rLsRx+TaIXIFUjzkzhokWuGiOa48a/2eeHH35di66Gs="}},"value":{"digests":["sha-256-32;q83vAA=="]}},{"key":{"type":"psa.refval-id","value":{"label":"PRoT","version":"1.3.5","signer-id":"rLsRx+TaIXIFUjzkzhokWuGiOa48a/2eeHH35di66Gs="}},"value":{"digests":["sha-256-32;q83vAA=="]}}]}],"attester-verification-keys":[{"environment":{"instance":{"type":"ueid","value":"At6tvu/erQ=="}},"verification-keys":[{"type":"pkix-base64-key","value":"-----BEGIN PUBLIC KEY-----\nMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEW1BvqF+/ry8BWa7ZEMU1xYYHEQ8B\nlLT4MFHOaO+ICTtIvrEeEpr/sfTAP66H2hCHdb5HEXKtRKod6QLcOLPA1Q==\n-----END PUBLIC KEY-----"}]}]}} + // a301a10078206d792d6e733a61636d652d726f616472756e6e65722d737570706c656d656e740281a3006941434d45204c74642e01d8207468747470733a2f2f61636d652e6578616d706c65028301000204a2008182a100a300d90258582061636d652d696d706c656d656e746174696f6e2d69642d303030303030303031016941434d45204c74642e026e526f616452756e6e657220322e3082a200d90259a30162424c0465352e302e35055820acbb11c7e4da217205523ce4ce1a245ae1a239ae3c6bfd9e7871f7e5d8bae86b01a10281820644abcdef00a200d90259a3016450526f540465312e332e35055820acbb11c7e4da217205523ce4ce1a245ae1a239ae3c6bfd9e7871f7e5d8bae86b01a10281820644abcdef00028182a101d902264702deadbeefdead81d9022a78b12d2d2d2d2d424547494e205055424c4943204b45592d2d2d2d2d0a4d466b77457759484b6f5a497a6a3043415159494b6f5a497a6a304441516344516741455731427671462b2f727938425761375a454d553178595948455138420a6c4c54344d46484f614f2b4943547449767245654570722f7366544150363648326843486462354845584b74524b6f6436514c634f4c504131513d3d0a2d2d2d2d2d454e44205055424c4943204b45592d2d2d2d2d + // {"tag-identity":{"id":"my-ns:acme-roadrunner-supplement"},"entities":[{"name":"ACME Ltd.","regid":"https://acme.example","roles":["creator","tagCreator","maintainer"]}],"triples":{"reference-values":[{"environment":{"class":{"id":{"type":"psa.impl-id","value":"YWNtZS1pbXBsZW1lbnRhdGlvbi1pZC0wMDAwMDAwMDE="},"vendor":"ACME Ltd.","model":"RoadRunner 2.0"}},"measurements":[{"key":{"type":"psa.refval-id","value":{"label":"BL","version":"5.0.5","signer-id":"rLsRx+TaIXIFUjzkzhokWuGiOa48a/2eeHH35di66Gs="}},"value":{"digests":["sha-256-32;q83vAA=="]}},{"key":{"type":"psa.refval-id","value":{"label":"PRoT","version":"1.3.5","signer-id":"rLsRx+TaIXIFUjzkzhokWuGiOa48a/2eeHH35di66Gs="}},"value":{"digests":["sha-256-32;q83vAA=="]}}]}],"attester-verification-keys":[{"environment":{"instance":{"type":"ueid","value":"At6tvu/erQ=="}},"verification-keys":[{"type":"pkix-base64-key","value":"-----BEGIN PUBLIC KEY-----\nMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEW1BvqF+/ry8BWa7ZEMU1xYYHEQ8B\nlLT4MFHOaO+ICTtIvrEeEpr/sfTAP66H2hCHdb5HEXKtRKod6QLcOLPA1Q==\n-----END PUBLIC KEY-----"}]}]}} } func Example_encode_PSA_attestation_verification() { diff --git a/comid/extensions.go b/comid/extensions.go new file mode 100644 index 00000000..b62da058 --- /dev/null +++ b/comid/extensions.go @@ -0,0 +1,173 @@ +// Copyright 2023 Contributors to the Veraison project. +// SPDX-License-Identifier: Apache-2.0 +package comid + +import ( + "github.com/veraison/corim/extensions" +) + +type IComidValidator interface { + ValidComid(*Comid) error +} + +type ITriplesValidator interface { + ValidTriples(*Triples) error +} + +type IMvalValidator interface { + ValidMval(*Mval) error +} + +type IEntityValidator interface { + ValidEntity(*Entity) error +} + +type IFlagsMapValidator interface { + ValidFlagsMap(*FlagsMap) error +} + +type IFlagSetter interface { + AnySet() bool + SetTrue(Flag) + SetFalse(Flag) + Clear(Flag) + Get(Flag) *bool +} + +type Extensions struct { + extensions.Extensions +} + +func (o *Extensions) ValidComid(comid *Comid) error { + if !o.HaveExtensions() { + return nil + } + + ev, ok := o.IExtensionsValue.(IComidValidator) + if ok { + if err := ev.ValidComid(comid); err != nil { + return err + } + } + + return nil +} + +func (o *Extensions) ValidTriples(triples *Triples) error { + if !o.HaveExtensions() { + return nil + } + + ev, ok := o.IExtensionsValue.(ITriplesValidator) + if ok { + if err := ev.ValidTriples(triples); err != nil { + return err + } + } + + return nil +} + +func (o *Extensions) ValidMval(triples *Mval) error { + if !o.HaveExtensions() { + return nil + } + + ev, ok := o.IExtensionsValue.(IMvalValidator) + if ok { + if err := ev.ValidMval(triples); err != nil { + return err + } + } + + return nil +} + +func (o *Extensions) ValidEntity(triples *Entity) error { + if !o.HaveExtensions() { + return nil + } + + ev, ok := o.IExtensionsValue.(IEntityValidator) + if ok { + if err := ev.ValidEntity(triples); err != nil { + return err + } + } + + return nil +} + +func (o *Extensions) ValidFlagsMap(triples *FlagsMap) error { + if !o.HaveExtensions() { + return nil + } + + ev, ok := o.IExtensionsValue.(IFlagsMapValidator) + if ok { + if err := ev.ValidFlagsMap(triples); err != nil { + return err + } + } + + return nil +} + +func (o *Extensions) SetTrue(flag Flag) { + if !o.HaveExtensions() { + return + } + + ev, ok := o.IExtensionsValue.(IFlagSetter) + if ok { + ev.SetTrue(flag) + } +} + +func (o *Extensions) SetFalse(flag Flag) { + if !o.HaveExtensions() { + return + } + + ev, ok := o.IExtensionsValue.(IFlagSetter) + if ok { + ev.SetFalse(flag) + } +} + +func (o *Extensions) Clear(flag Flag) { + if !o.HaveExtensions() { + return + } + + ev, ok := o.IExtensionsValue.(IFlagSetter) + if ok { + ev.Clear(flag) + } +} + +func (o *Extensions) Get(flag Flag) *bool { + if !o.HaveExtensions() { + return nil + } + + ev, ok := o.IExtensionsValue.(IFlagSetter) + if ok { + return ev.Get(flag) + } + + return nil +} + +func (o *Extensions) AnySet() bool { + if !o.HaveExtensions() { + return false + } + + ev, ok := o.IExtensionsValue.(IFlagSetter) + if ok { + return ev.AnySet() + } + + return false +} diff --git a/comid/extensions_test.go b/comid/extensions_test.go new file mode 100644 index 00000000..8b36c9a8 --- /dev/null +++ b/comid/extensions_test.go @@ -0,0 +1,96 @@ +package comid + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" +) + +var FlagTestFlag = Flag(-1) + +type TestExtension struct { + TestFlag *bool +} + +func (o *TestExtension) ValidComid(v *Comid) error { + return errors.New("invalid") +} + +func (o *TestExtension) ValidTriples(v *Triples) error { + return errors.New("invalid") +} + +func (o *TestExtension) ValidMval(v *Mval) error { + return errors.New("invalid") +} + +func (o *TestExtension) ValidFlagsMap(v *FlagsMap) error { + return errors.New("invalid") +} + +func (o *TestExtension) ValidEntity(v *Entity) error { + return errors.New("invalid") +} + +func (o *TestExtension) SetTrue(flag Flag) { + if flag == FlagTestFlag { + o.TestFlag = &True + } +} +func (o *TestExtension) SetFalse(flag Flag) { + if flag == FlagTestFlag { + o.TestFlag = &False + } +} + +func (o *TestExtension) Clear(flag Flag) { + if flag == FlagTestFlag { + o.TestFlag = nil + } +} + +func (o *TestExtension) Get(flag Flag) *bool { + if flag == FlagTestFlag { + return o.TestFlag + } + + return nil +} + +func (o *TestExtension) AnySet() bool { + return o.TestFlag != nil +} + +func Test_Extensions(t *testing.T) { + exts := Extensions{} + exts.Register(&TestExtension{}) + + err := exts.ValidComid(nil) + assert.EqualError(t, err, "invalid") + + err = exts.ValidTriples(nil) + assert.EqualError(t, err, "invalid") + + err = exts.ValidMval(nil) + assert.EqualError(t, err, "invalid") + + err = exts.ValidEntity(nil) + assert.EqualError(t, err, "invalid") + + err = exts.ValidFlagsMap(nil) + assert.EqualError(t, err, "invalid") + + assert.False(t, exts.AnySet()) + + exts.SetTrue(FlagTestFlag) + assert.True(t, exts.AnySet()) + assert.True(t, *exts.Get(FlagTestFlag)) + + exts.SetFalse(FlagTestFlag) + assert.False(t, *exts.Get(FlagTestFlag)) + + exts.Clear(FlagTestFlag) + assert.Nil(t, exts.Get(FlagTestFlag)) + assert.False(t, exts.AnySet()) +} diff --git a/comid/flagsmap.go b/comid/flagsmap.go index 279cf1a4..95737419 100644 --- a/comid/flagsmap.go +++ b/comid/flagsmap.go @@ -3,6 +3,11 @@ package comid +import ( + "github.com/veraison/corim/encoding" + "github.com/veraison/corim/extensions" +) + var True = true var False = false @@ -52,6 +57,8 @@ type FlagsMap struct { // IsTcb indicates whether the measured environment is a trusted // computing base. IsTcb *bool `cbor:"8,keyasint,omitempty" json:"is-tcb,omitempty"` + + Extensions } func NewFlagsMap() *FlagsMap { @@ -65,7 +72,7 @@ func (o *FlagsMap) AnySet() bool { return true } - return false + return o.Extensions.AnySet() } func (o *FlagsMap) SetTrue(flags ...Flag) { @@ -90,6 +97,7 @@ func (o *FlagsMap) SetTrue(flags ...Flag) { case FlagIsTcb: o.IsTcb = &True default: + o.Extensions.SetTrue(flag) } } } @@ -116,6 +124,7 @@ func (o *FlagsMap) SetFalse(flags ...Flag) { case FlagIsTcb: o.IsTcb = &False default: + o.Extensions.SetFalse(flag) } } } @@ -142,6 +151,7 @@ func (o *FlagsMap) Clear(flags ...Flag) { case FlagIsTcb: o.IsTcb = nil default: + o.Extensions.Clear(flag) } } } @@ -167,11 +177,41 @@ func (o *FlagsMap) Get(flag Flag) *bool { case FlagIsTcb: return o.IsTcb default: - return nil + return o.Extensions.Get(flag) } } +// RegisterExtensions registers a struct as a collections of extensions +func (o *FlagsMap) RegisterExtensions(exts extensions.IExtensionsValue) { + o.Extensions.Register(exts) +} + +// GetExtensions returns pervisouosly registered extension +func (o *FlagsMap) GetExtensions() extensions.IExtensionsValue { + return o.Extensions.IExtensionsValue +} + +// UnmarshalCBOR deserializes from CBOR +func (o *FlagsMap) UnmarshalCBOR(data []byte) error { + return encoding.PopulateStructFromCBOR(dm, data, o) +} + +// MarshalCBOR serializes to CBOR +func (o *FlagsMap) MarshalCBOR() ([]byte, error) { + return encoding.SerializeStructToCBOR(em, o) +} + +// UnmarshalJSON deserializes from JSON +func (o *FlagsMap) UnmarshalJSON(data []byte) error { + return encoding.PopulateStructFromJSON(data, o) +} + +// MarshalJSON serializes to JSON +func (o *FlagsMap) MarshalJSON() ([]byte, error) { + return encoding.SerializeStructToJSON(o) +} + // Valid returns an error if the FlagsMap is invalid. func (o FlagsMap) Valid() error { - return nil + return o.Extensions.ValidFlagsMap(&o) } diff --git a/comid/measurement.go b/comid/measurement.go index b1909d87..7bfd7f99 100644 --- a/comid/measurement.go +++ b/comid/measurement.go @@ -247,6 +247,16 @@ func (o *Mval) MarshalCBOR() ([]byte, error) { return encoding.SerializeStructToCBOR(em, o) } +// UnmarshalJSON deserializes from JSON +func (o *Mval) UnmarshalJSON(data []byte) error { + return encoding.PopulateStructFromJSON(data, o) +} + +// MarshalJSON serializes to JSON +func (o *Mval) MarshalJSON() ([]byte, error) { + return encoding.SerializeStructToJSON(o) +} + func (o Mval) Valid() error { if o.Ver == nil && o.SVN == nil && diff --git a/comid/measurement_test.go b/comid/measurement_test.go index c9abed17..d5d58bff 100644 --- a/comid/measurement_test.go +++ b/comid/measurement_test.go @@ -103,7 +103,7 @@ func TestMeasurement_NewUUIDMeasurement_some_value(t *testing.T) { tv := NewUUIDMeasurement(TestUUID). SetMinSVN(2). - SetOpFlags(OpFlagDebug). + SetFlagsTrue(FlagIsDebug). SetVersion("1.2.3", swid.VersionSchemeSemVer) require.NotNil(t, tv) diff --git a/comid/referencevalue_test.go b/comid/referencevalue_test.go new file mode 100644 index 00000000..00b89d18 --- /dev/null +++ b/comid/referencevalue_test.go @@ -0,0 +1,22 @@ +package comid + +import ( + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_ReferenceValue(t *testing.T) { + rv := ReferenceValue{} + err := rv.Valid() + assert.EqualError(t, err, "environment validation failed: environment must not be empty") + + rv.Environment.Instance = NewInstance() + id, err := uuid.NewUUID() + require.NoError(t, err) + rv.Environment.Instance.SetUUID(id) + err = rv.Valid() + assert.EqualError(t, err, "measurements validation failed: no measurement entries") +} diff --git a/comid/triples.go b/comid/triples.go index 28df6b75..8500734e 100644 --- a/comid/triples.go +++ b/comid/triples.go @@ -3,13 +3,50 @@ package comid -import "fmt" +import ( + "fmt" + + "github.com/veraison/corim/encoding" + "github.com/veraison/corim/extensions" +) type Triples struct { ReferenceValues *[]ReferenceValue `cbor:"0,keyasint,omitempty" json:"reference-values,omitempty"` EndorsedValues *[]EndorsedValue `cbor:"1,keyasint,omitempty" json:"endorsed-values,omitempty"` AttestVerifKeys *[]AttestVerifKey `cbor:"2,keyasint,omitempty" json:"attester-verification-keys,omitempty"` DevIdentityKeys *[]DevIdentityKey `cbor:"3,keyasint,omitempty" json:"dev-identity-keys,omitempty"` + + Extensions +} + +// RegisterExtensions registers a struct as a collections of extensions +func (o *Triples) RegisterExtensions(exts extensions.IExtensionsValue) { + o.Extensions.Register(exts) +} + +// GetExtensions returns pervisouosly registered extension +func (o *Triples) GetExtensions() extensions.IExtensionsValue { + return o.Extensions.IExtensionsValue +} + +// UnmarshalCBOR deserializes from CBOR +func (o *Triples) UnmarshalCBOR(data []byte) error { + return encoding.PopulateStructFromCBOR(dm, data, o) +} + +// MarshalCBOR serializes to CBOR +func (o *Triples) MarshalCBOR() ([]byte, error) { + return encoding.SerializeStructToCBOR(em, o) +} + +// UnmarshalJSON deserializes from JSON +func (o *Triples) UnmarshalJSON(data []byte) error { + return encoding.PopulateStructFromJSON(data, o) +} + +// MarshalJSON serializes to JSON +func (o *Triples) MarshalJSON() ([]byte, error) { + return encoding.SerializeStructToJSON(o) } // Valid checks that the Triples is valid as per the specification @@ -52,7 +89,7 @@ func (o Triples) Valid() error { } } - return nil + return o.Extensions.ValidTriples(&o) } func (o *Triples) AddReferenceValue(val ReferenceValue) *Triples { diff --git a/corim/entity.go b/corim/entity.go index 49aa1468..b9288095 100644 --- a/corim/entity.go +++ b/corim/entity.go @@ -7,6 +7,8 @@ import ( "fmt" "github.com/veraison/corim/comid" + "github.com/veraison/corim/encoding" + "github.com/veraison/corim/extensions" ) // Entity stores an entity-map capable of CBOR and JSON serializations. @@ -14,12 +16,24 @@ type Entity struct { EntityName string `cbor:"0,keyasint" json:"name"` RegID *comid.TaggedURI `cbor:"1,keyasint,omitempty" json:"regid,omitempty"` Roles Roles `cbor:"2,keyasint" json:"roles"` + + Extensions } func NewEntity() *Entity { return &Entity{} } +// RegisterExtensions registers a struct as a collections of extensions +func (o *Entity) RegisterExtensions(exts extensions.IExtensionsValue) { + o.Extensions.Register(exts) +} + +// GetExtensions returns pervisouosly registered extension +func (o *Entity) GetExtensions() extensions.IExtensionsValue { + return o.Extensions.IExtensionsValue +} + // SetEntityName is used to set the EntityName field of Entity using supplied name func (o *Entity) SetEntityName(name string) *Entity { if o != nil { @@ -72,7 +86,27 @@ func (o Entity) Valid() error { return fmt.Errorf("invalid entity: %w", err) } - return nil + return o.Extensions.ValidEntity(&o) +} + +// UnmarshalCBOR deserializes from CBOR +func (o *Entity) UnmarshalCBOR(data []byte) error { + return encoding.PopulateStructFromCBOR(dm, data, o) +} + +// MarshalCBOR serializes to CBOR +func (o *Entity) MarshalCBOR() ([]byte, error) { + return encoding.SerializeStructToCBOR(em, o) +} + +// UnmarshalJSON deserializes from JSON +func (o *Entity) UnmarshalJSON(data []byte) error { + return encoding.PopulateStructFromJSON(data, o) +} + +// MarshalJSON serializes to JSON +func (o *Entity) MarshalJSON() ([]byte, error) { + return encoding.SerializeStructToJSON(o) } // Entities is an array of entity-map's diff --git a/corim/extensions.go b/corim/extensions.go new file mode 100644 index 00000000..55f2410a --- /dev/null +++ b/corim/extensions.go @@ -0,0 +1,68 @@ +// Copyright 2023 Contributors to the Veraison project. +// SPDX-License-Identifier: Apache-2.0 +package corim + +import ( + "github.com/veraison/corim/extensions" +) + +type IEntityValidator interface { + ValidEntity(*Entity) error +} + +type ICorimValidator interface { + ValidCorim(*UnsignedCorim) error +} + +type ISignerValidator interface { + ValidSigner(*Signer) error +} + +type Extensions struct { + extensions.Extensions +} + +func (o *Extensions) ValidEntity(entity *Entity) error { + if !o.HaveExtensions() { + return nil + } + + ev, ok := o.IExtensionsValue.(IEntityValidator) + if ok { + if err := ev.ValidEntity(entity); err != nil { + return err + } + } + + return nil +} + +func (o *Extensions) ValidCorim(c *UnsignedCorim) error { + if !o.HaveExtensions() { + return nil + } + + ev, ok := o.IExtensionsValue.(ICorimValidator) + if ok { + if err := ev.ValidCorim(c); err != nil { + return err + } + } + + return nil +} + +func (o *Extensions) ValidSigner(signer *Signer) error { + if !o.HaveExtensions() { + return nil + } + + ev, ok := o.IExtensionsValue.(ISignerValidator) + if ok { + if err := ev.ValidSigner(signer); err != nil { + return err + } + } + + return nil +} diff --git a/corim/extensions_test.go b/corim/extensions_test.go new file mode 100644 index 00000000..653a13e1 --- /dev/null +++ b/corim/extensions_test.go @@ -0,0 +1,86 @@ +// Copyright 2023 Contributors to the Veraison project. +// SPDX-License-Identifier: Apache-2.0 +package corim + +import ( + "errors" + "testing" + + "github.com/fxamacker/cbor/v2" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type TestExtensions struct { + Address string `cbor:"-1,keyasint,omitempty" json:"address,omitempty"` + Size int `cbor:"-2,keyasint,omitempty" json:"size,omitempty"` +} + +func (o TestExtensions) ValidEntity(ent *Entity) error { + if ent.EntityName != "Futurama" { + return errors.New(`EntityName must be "Futurama"`) // nolint:golint + } + + return nil +} + +func (o TestExtensions) ValidCorim(c *UnsignedCorim) error { + return errors.New("invalid") +} + +func (o TestExtensions) ValidSigner(s *Signer) error { + return errors.New("invalid") +} + +func TestEntityExtensions_Valid(t *testing.T) { + ent := NewEntity() + ent.SetEntityName("The Simpsons") + ent.SetRoles(RoleManifestCreator) + + err := ent.Valid() + assert.NoError(t, err) + + ent.RegisterExtensions(&TestExtensions{}) + err = ent.Valid() + assert.EqualError(t, err, `EntityName must be "Futurama"`) + + ent.SetEntityName("Futurama") + err = ent.Valid() + assert.NoError(t, err) + + assert.EqualError(t, ent.Extensions.ValidCorim(nil), "invalid") + assert.EqualError(t, ent.Extensions.ValidSigner(nil), "invalid") +} + +func TestEntityExtensions_CBOR(t *testing.T) { + data := []byte{ + 0xa4, // map(4) + + 0x00, // key 0 + 0x64, // val tstr(4) + 0x61, 0x63, 0x6d, 0x65, // "acme" + + 0x02, // key 2 + 0x81, // array(1) + 0x01, // 1 + + 0x20, // key -1 + 0x63, // val tstr(3) + 0x66, 0x6f, 0x6f, // "foo" + + 0x21, // key -2 + 0x06, // val 6 + } + + ent := NewEntity() + ent.RegisterExtensions(&TestExtensions{}) + + err := cbor.Unmarshal(data, &ent) + assert.NoError(t, err) + + assert.Equal(t, ent.EntityName, "acme") + + address, err := ent.Get("address") + require.NoError(t, err) + assert.Equal(t, address, "foo") +} diff --git a/corim/meta.go b/corim/meta.go index 45f1ea49..4b1eeeb6 100644 --- a/corim/meta.go +++ b/corim/meta.go @@ -10,17 +10,31 @@ import ( "time" "github.com/veraison/corim/comid" + "github.com/veraison/corim/encoding" + "github.com/veraison/corim/extensions" ) type Signer struct { Name string `cbor:"0,keyasint" json:"name"` URI *comid.TaggedURI `cbor:"1,keyasint,omitempty" json:"uri,omitempty"` + + Extensions } func NewSigner() *Signer { return &Signer{} } +// RegisterExtensions registers a struct as a collections of extensions +func (o *Signer) RegisterExtensions(exts extensions.IExtensionsValue) { + o.Extensions.Register(exts) +} + +// GetExtensions returns pervisouosly registered extension +func (o *Signer) GetExtensions() extensions.IExtensionsValue { + return o.Extensions.IExtensionsValue +} + // SetName sets the target Signer's name to the supplied value func (o *Signer) SetName(name string) *Signer { if o != nil { @@ -61,7 +75,27 @@ func (o Signer) Valid() error { } } - return nil + return o.Extensions.ValidSigner(&o) +} + +// UnmarshalCBOR deserializes from CBOR +func (o *Signer) UnmarshalCBOR(data []byte) error { + return encoding.PopulateStructFromCBOR(dm, data, o) +} + +// MarshalCBOR serializes to CBOR +func (o *Signer) MarshalCBOR() ([]byte, error) { + return encoding.SerializeStructToCBOR(em, o) +} + +// UnmarshalJSON deserializes from JSON +func (o *Signer) UnmarshalJSON(data []byte) error { + return encoding.PopulateStructFromJSON(data, o) +} + +// MarshalJSON serializes to JSON +func (o *Signer) MarshalJSON() ([]byte, error) { + return encoding.SerializeStructToJSON(o) } // Meta stores a corim-meta-map with JSON and CBOR serializations. It carries diff --git a/corim/meta_test.go b/corim/meta_test.go index a4dd34cd..811495cf 100644 --- a/corim/meta_test.go +++ b/corim/meta_test.go @@ -183,3 +183,15 @@ func TestMeta_FromCBOR_full(t *testing.T) { assert.Equal(t, notBefore.Unix(), actual.Validity.NotBefore.Unix()) assert.Equal(t, notAfter.Unix(), actual.Validity.NotAfter.Unix()) } + +func Test_Signer_Valid(t *testing.T) { + var signer Signer + + assert.EqualError(t, signer.Valid(), "empty name") + + signer.Name = "test-signer" + uri := comid.TaggedURI("@@@") + signer.URI = &uri + + assert.EqualError(t, signer.Valid(), `invalid URI: "@@@" is not an absolute URI`) +} diff --git a/corim/signedcorim_test.go b/corim/signedcorim_test.go index 9e030b3c..eae15d45 100644 --- a/corim/signedcorim_test.go +++ b/corim/signedcorim_test.go @@ -246,7 +246,7 @@ func TestSignedCorim_FromCOSE_fail_invalid_corim(t *testing.T) { var actual SignedCorim err := actual.FromCOSE(tv) - assert.EqualError(t, err, "failed validation of unsigned CoRIM: tags validation failed: no tags") + assert.EqualError(t, err, `failed CBOR decoding of unsigned CoRIM: missing mandatory field "Tags" (1)`) } func TestSignedCorim_FromCOSE_fail_no_content_type(t *testing.T) { diff --git a/corim/unsignedcorim.go b/corim/unsignedcorim.go index 68ca6786..a7488500 100644 --- a/corim/unsignedcorim.go +++ b/corim/unsignedcorim.go @@ -4,12 +4,13 @@ package corim import ( - "encoding/json" "errors" "fmt" "time" "github.com/veraison/corim/cots" + "github.com/veraison/corim/encoding" + "github.com/veraison/corim/extensions" "github.com/veraison/corim/comid" "github.com/veraison/eat" @@ -19,12 +20,22 @@ import ( // UnsignedCorim is the top-level representation of the unsigned-corim-map with // CBOR and JSON serialization. type UnsignedCorim struct { - ID swid.TagID `cbor:"0,keyasint" json:"corim-id"` - Tags []Tag `cbor:"1,keyasint" json:"tags"` + ID swid.TagID `cbor:"0,keyasint" json:"corim-id"` + // note: even though tags are mandatory for CoRIM, we allow omitting + // them in our JSON templates for cocli (the min template just has + // corim-id). Since we're never writing JSON (so far), this normally + // wouldn't matter, however the custom serialization code we use to + // handle embedded structs relies on the omitempty entry to determine + // if a filed is optional, so we use it during unmarshaling as well as + // marshaling. Hence omitempty is present for the json tag, but not + // cbor. + Tags []Tag `cbor:"1,keyasint" json:"tags,omitempty"` DependentRims *[]Locator `cbor:"2,keyasint,omitempty" json:"dependent-rims,omitempty"` Profiles *[]eat.Profile `cbor:"3,keyasint,omitempty" json:"profiles,omitempty"` RimValidity *Validity `cbor:"4,keyasint,omitempty" json:"validity,omitempty"` Entities *Entities `cbor:"5,keyasint,omitempty" json:"entities,omitempty"` + + Extensions } // NewUnsignedCorim instantiates an empty UnsignedCorim @@ -32,6 +43,16 @@ func NewUnsignedCorim() *UnsignedCorim { return &UnsignedCorim{} } +// RegisterExtensions registers a struct as a collections of extensions +func (o *UnsignedCorim) RegisterExtensions(exts extensions.IExtensionsValue) { + o.Extensions.Register(exts) +} + +// GetExtensions returns pervisouosly registered extension +func (o *UnsignedCorim) GetExtensions() extensions.IExtensionsValue { + return o.Extensions.IExtensionsValue +} + // SetID sets the corim-id in the unsigned-corim-map to the supplied value. The // corim-id can be passed as UUID in string or binary form (i.e., byte array), // or as a (non-empty) string @@ -239,22 +260,22 @@ func (o UnsignedCorim) Valid() error { } } - return nil + return o.Extensions.ValidCorim(&o) } // ToCBOR serializes the target unsigned CoRIM to CBOR -func (o UnsignedCorim) ToCBOR() ([]byte, error) { - return em.Marshal(&o) +func (o *UnsignedCorim) ToCBOR() ([]byte, error) { + return encoding.SerializeStructToCBOR(em, o) } // FromCBOR deserializes a CBOR-encoded unsigned CoRIM into the target UnsignedCorim func (o *UnsignedCorim) FromCBOR(data []byte) error { - return dm.Unmarshal(data, o) + return encoding.PopulateStructFromCBOR(dm, data, o) } // FromJSON deserializes a JSON-encoded unsigned CoRIM into the target UnsignedCorim func (o *UnsignedCorim) FromJSON(data []byte) error { - return json.Unmarshal(data, o) + return encoding.PopulateStructFromJSON(data, o) } // Tag is either a CBOR-encoded CoMID, CoSWID or CoTS diff --git a/corim/unsignedcorim_test.go b/corim/unsignedcorim_test.go index d9b4d105..c2b7ee86 100644 --- a/corim/unsignedcorim_test.go +++ b/corim/unsignedcorim_test.go @@ -299,3 +299,11 @@ func TestUnsignedCorim_AddEntity_non_nil_empty_URI(t *testing.T) { assert.Nil(t, tv) } + +func TestUnsignedCorim_FromJSON(t *testing.T) { + data := []byte(`{"corim-id": "5c57e8f4-46cd-421b-91c9-08cf93e13cfc"}`) + + err := NewUnsignedCorim().FromJSON(data) + + assert.NoError(t, err) +} diff --git a/encoding/cbor.go b/encoding/cbor.go new file mode 100644 index 00000000..5d7a3533 --- /dev/null +++ b/encoding/cbor.go @@ -0,0 +1,409 @@ +// Copyright 2023 Contributors to the Veraison project. +// SPDX-License-Identifier: Apache-2.0 + +package encoding + +import ( + "encoding/binary" + "errors" + "fmt" + "math" + "reflect" + "strconv" + "strings" + + cbor "github.com/fxamacker/cbor/v2" +) + +func SerializeStructToCBOR(em cbor.EncMode, source any) ([]byte, error) { + rawMap := newStructFieldsCBOR() + + structType := reflect.TypeOf(source) + structVal := reflect.ValueOf(source) + + if err := doSerializeStructToCBOR(em, rawMap, structType, structVal); err != nil { + return nil, err + } + + return rawMap.ToCBOR(em) +} + +func doSerializeStructToCBOR( + em cbor.EncMode, + rawMap *structFieldsCBOR, + structType reflect.Type, + structVal reflect.Value, +) error { + if structType.Kind() == reflect.Pointer { + structType = structType.Elem() + structVal = structVal.Elem() + } + + var embeds []embedded + + for i := 0; i < structVal.NumField(); i++ { + typeField := structType.Field(i) + valField := structVal.Field(i) + + if collectEmbedded(&typeField, valField, &embeds) { + continue + } + + tag, ok := typeField.Tag.Lookup("cbor") + if !ok { + continue + } + + parts := strings.Split(tag, ",") + keyString := parts[0] + + isOmitEmpty := false + if len(parts) > 1 { + for _, option := range parts[1:] { + if option == omitempty { + isOmitEmpty = true + break + } + } + } + + // do not serialize zero values if the corresponding field is + // omitempty + if isOmitEmpty && valField.IsZero() { + continue + } + + keyInt, err := strconv.Atoi(keyString) + if err != nil { + return fmt.Errorf("non-integer cbor key: %s", keyString) + } + + data, err := em.Marshal(valField.Interface()) + if err != nil { + return fmt.Errorf("error marshaling field %q: %w", + typeField.Name, + err, + ) + } + + if err := rawMap.Add(keyInt, cbor.RawMessage(data)); err != nil { + return err + } + } + + for _, emb := range embeds { + if err := doSerializeStructToCBOR(em, rawMap, emb.Type, emb.Value); err != nil { + return err + } + } + + return nil +} + +func PopulateStructFromCBOR(dm cbor.DecMode, data []byte, dest any) error { + rawMap := newStructFieldsCBOR() + + if err := rawMap.FromCBOR(dm, data); err != nil { + return err + } + + structType := reflect.TypeOf(dest) + structVal := reflect.ValueOf(dest) + + return doPopulateStructFromCBOR(dm, rawMap, structType, structVal) +} + +func doPopulateStructFromCBOR( + dm cbor.DecMode, + rawMap *structFieldsCBOR, + structType reflect.Type, + structVal reflect.Value, +) error { + if structType.Kind() == reflect.Pointer { + structType = structType.Elem() + structVal = structVal.Elem() + } + + var embeds []embedded + + for i := 0; i < structVal.NumField(); i++ { + typeField := structType.Field(i) + valField := structVal.Field(i) + + if collectEmbedded(&typeField, valField, &embeds) { + continue + } + + tag, ok := typeField.Tag.Lookup("cbor") + if !ok { + continue + } + + parts := strings.Split(tag, ",") + keyString := parts[0] + + isOmitEmpty := false + if len(parts) > 1 { + for _, option := range parts[1:] { + if option == omitempty { + isOmitEmpty = true + break + } + } + } + + keyInt, err := strconv.Atoi(keyString) + if err != nil { + return fmt.Errorf("non-integer cbor key %s", keyString) + } + + rawVal, ok := rawMap.Get(keyInt) + if !ok { + if isOmitEmpty { + continue + } + + return fmt.Errorf("missing mandatory field %q (%d)", + typeField.Name, keyInt) + } + + fieldPtr := valField.Addr().Interface() + if err := dm.Unmarshal(rawVal, fieldPtr); err != nil { + return fmt.Errorf("error unmarshalling field %q: %w", + typeField.Name, + err, + ) + } + + rawMap.Delete(keyInt) + } + + for _, emb := range embeds { + if err := doPopulateStructFromCBOR(dm, rawMap, emb.Type, emb.Value); err != nil { + return err + } + } + + return nil +} + +// structFieldsCBOR is a specialized implementation of "OrderedMap", where the +// order of the keys is kept track of, and used when serializing the map to +// CBOR. While CBOR maps do not mandate any particular ordering, and so this +// isn't strictly necessary, it is useful to have a _stable_ serialization +// order for map keys to be compatible with regular Go struct serialization +// behavior. This is also useful for tests/examples that compare encoded +// []byte's. +type structFieldsCBOR struct { + Fields map[int]cbor.RawMessage + Keys []int +} + +func newStructFieldsCBOR() *structFieldsCBOR { + return &structFieldsCBOR{ + Fields: make(map[int]cbor.RawMessage), + } +} + +func (o structFieldsCBOR) Has(key int) bool { + _, ok := o.Fields[key] + return ok +} + +func (o *structFieldsCBOR) Add(key int, val cbor.RawMessage) error { + if o.Has(key) { + return fmt.Errorf("duplicate cbor key: %d", key) + } + + o.Fields[key] = val + o.Keys = append(o.Keys, key) + + return nil +} + +func (o *structFieldsCBOR) Get(key int) (cbor.RawMessage, bool) { + val, ok := o.Fields[key] + return val, ok +} + +func (o *structFieldsCBOR) Delete(key int) { + delete(o.Fields, key) + + for i, existing := range o.Keys { + if existing == key { + o.Keys = append(o.Keys[:i], o.Keys[i+1:]...) + } + } +} + +func (o *structFieldsCBOR) ToCBOR(em cbor.EncMode) ([]byte, error) { + var out []byte + + header := byte(0xa0) // 0b101_00000 -- Major Type 5 == map + mapLen := len(o.Keys) + if mapLen == 0 { + return []byte{header}, nil + } else if mapLen < 24 { + header = header | byte(mapLen) + out = append(out, header) + } else if mapLen <= math.MaxUint8 { + header = header | byte(24) + out = append(out, header, uint8(mapLen)) + } else if mapLen <= math.MaxUint16 { + header = header | byte(25) + out = append(out, header) + out = binary.BigEndian.AppendUint16(out, uint16(mapLen)) + } else { + header = header | byte(26) + out = append(out, header) + out = binary.BigEndian.AppendUint32(out, uint32(mapLen)) + } + // Since len() returns an int, the value cannot exceed MaxUint32, so + // the 8-byte length variant cannot occur. + + for _, key := range o.Keys { + marshalledKey, err := em.Marshal(key) + if err != nil { + return nil, fmt.Errorf("problem marshaling key %d: %w", key, err) + } + + out = append(out, marshalledKey...) + out = append(out, o.Fields[key]...) + } + + return out, nil +} + +func (o *structFieldsCBOR) FromCBOR(dm cbor.DecMode, data []byte) error { + if len(data) == 0 { + return errors.New("empty input") + } + + header := data[0] + rest := data[1:] + additionalInfo := 0x1f & header + + var err error + + majorType := (0xe0 & header) >> 5 + if majorType == 6 { // tag + _, rest, err = processAdditionalInfo(additionalInfo, rest) + if err != nil { + return err + } + + header = rest[0] + rest = rest[1:] + majorType = (0xe0 & header) >> 5 + additionalInfo = 0x1f & header + } + + if majorType != 5 { + return fmt.Errorf("expected map (CBOR Major Type 5), found Major Type %d", majorType) + } + + var mapLen int + + mapLen, rest, err = processAdditionalInfo(additionalInfo, rest) + if err != nil { + return err + } + + if mapLen != 0 { + o.Fields = make(map[int]cbor.RawMessage, mapLen) + + for i := 0; i < mapLen; i++ { + rest, err = o.unmarshalKeyValue(dm, rest) + if err != nil { + return fmt.Errorf("map item %d: %w", i, err) + } + } + } else { // mapLen == 0 --> indefinite encoding + o.Fields = make(map[int]cbor.RawMessage) + + i := 0 + done := false + for len(rest) > 0 { + if rest[0] == 0xFF { + done = true + break + } + + rest, err = o.unmarshalKeyValue(dm, rest) + if err != nil { + return fmt.Errorf("map item %d: %w", i, err) + } + + i++ + } + + if !done { + return errors.New("unexpected EOF") + } + } + + return nil +} + +func (o *structFieldsCBOR) unmarshalKeyValue(dm cbor.DecMode, rest []byte) ([]byte, error) { + var key int + var val cbor.RawMessage + var err error + + rest, err = dm.UnmarshalFirst(rest, &key) + if err != nil { + return rest, fmt.Errorf("could not unmarshal key: %w", err) + } + + rest, err = dm.UnmarshalFirst(rest, &val) + if err != nil { + return rest, fmt.Errorf("could not unmarshal value: %w", err) + } + + if err = o.Add(key, val); err != nil { + return rest, err + } + + return rest, nil +} + +func processAdditionalInfo( + additionalInfo byte, + data []byte, +) (int, []byte, error) { + var val int + rest := data + + if additionalInfo < 24 { + val = int(additionalInfo) + } else if additionalInfo < 28 { + switch additionalInfo - 23 { + case 1: + if len(data) < 1 { + return 0, nil, errors.New("unexpected EOF") + } + val = int(data[0]) + rest = data[1:] + case 2: + if len(data) < 2 { + return 0, nil, errors.New("unexpected EOF") + } + val = int(binary.BigEndian.Uint16(data[:2])) + rest = data[2:] + case 3: + if len(data) < 4 { + return 0, nil, errors.New("unexpected EOF") + } + val = int(binary.BigEndian.Uint32(data[:4])) + rest = data[4:] + default: + return 0, nil, errors.New("cbor: cannot decode length value of 8 bytes") + } + } else if additionalInfo == 31 { + val = 0 // indefinite encoding + } else { + return 0, nil, fmt.Errorf("cbor: unexpected additional information value %d", additionalInfo) + } + + return val, rest, nil +} diff --git a/encoding/cbor_test.go b/encoding/cbor_test.go new file mode 100644 index 00000000..1dda9146 --- /dev/null +++ b/encoding/cbor_test.go @@ -0,0 +1,279 @@ +// Copyright 2021 Contributors to the Veraison project. +// SPDX-License-Identifier: Apache-2.0 +package encoding + +import ( + "testing" + + "github.com/fxamacker/cbor/v2" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_PopulateStructFromCBOR_simple(t *testing.T) { + type SimpleStruct struct { + FieldOne string `cbor:"0,keyasint,omitempty"` + FieldTwo int `cbor:"1,keyasint"` + } + + var v SimpleStruct + + data := []byte{ + 0xa2, // map(2) + + 0x00, // key 0 + 0x64, // val tstr(4) + 0x61, 0x63, 0x6d, 0x65, // "acme" + + 0x01, // key 1 + 0x06, // val 6 + } + + dm, err := cbor.DecOptions{}.DecMode() + require.NoError(t, err) + + err = PopulateStructFromCBOR(dm, data, &v) + require.NoError(t, err) + assert.Equal(t, "acme", v.FieldOne) + assert.Equal(t, 6, v.FieldTwo) + + data = []byte{ + 0xa1, // map(1) + + 0x01, // key 1 + 0x06, // val 6 + } + v = SimpleStruct{} + + err = PopulateStructFromCBOR(dm, data, &v) + require.NoError(t, err) + assert.Equal(t, "", v.FieldOne) + assert.Equal(t, 6, v.FieldTwo) + + data = []byte{ + 0xa1, // map(1) + + 0x02, // key 2 + 0x64, // val tstr(4) + 0x61, 0x63, 0x6d, 0x65, // "acme" + } + v = SimpleStruct{} + + err = PopulateStructFromCBOR(dm, data, &v) + assert.EqualError(t, err, `missing mandatory field "FieldTwo" (1)`) + + err = PopulateStructFromCBOR(dm, []byte{0x01}, &v) + assert.EqualError(t, err, `expected map (CBOR Major Type 5), found Major Type 0`) + + type CompositeStruct struct { + FieldThree string `cbor:"2,keyasint"` + SimpleStruct + } + + var c CompositeStruct + + data = []byte{ + 0xa3, // map(3) + + 0x00, // key 0 + 0x64, // val tstr(4) + 0x61, 0x63, 0x6d, 0x65, // "acme" + + 0x01, // key 1 + 0x06, // val 6 + + 0x02, // key 2 + 0x63, // val tstr(3) + 0x66, 0x6f, 0x6f, // "foo" + } + + err = PopulateStructFromCBOR(dm, data, &c) + require.NoError(t, err) + assert.Equal(t, "acme", c.FieldOne) + assert.Equal(t, 6, c.FieldTwo) + assert.Equal(t, "foo", c.FieldThree) + + em, err := cbor.EncOptions{}.EncMode() + require.NoError(t, err) + + res, err := SerializeStructToCBOR(em, &c) + require.NoError(t, err) + + var c2 CompositeStruct + err = PopulateStructFromCBOR(dm, res, &c2) + require.NoError(t, err) + assert.EqualValues(t, c, c2) + +} + +func Test_structFieldsCBOR_CRUD(t *testing.T) { + sf := newStructFieldsCBOR() + + err := sf.Add(2, cbor.RawMessage{0x02}) + assert.NoError(t, err) + + err = sf.Add(1, cbor.RawMessage{0x01}) + assert.NoError(t, err) + + err = sf.Add(3, cbor.RawMessage{0x03}) + assert.NoError(t, err) + + assert.Equal(t, []int{2, 1, 3}, sf.Keys) + assert.True(t, sf.Has(3)) + assert.False(t, sf.Has(4)) + + val, ok := sf.Get(2) + assert.True(t, ok) + assert.Equal(t, cbor.RawMessage{0x2}, val) + + _, ok = sf.Get(4) + assert.False(t, ok) + + sf.Delete(2) + _, ok = sf.Get(2) + assert.False(t, ok) + + err = sf.Add(1, cbor.RawMessage{0x11}) + assert.EqualError(t, err, "duplicate cbor key: 1") +} + +func Test_structFieldsCBOR_CBOR_roundtrip(t *testing.T) { + em, err := cbor.EncOptions{}.EncMode() + require.NoError(t, err) + dm, err := cbor.DecOptions{}.DecMode() + require.NoError(t, err) + + sf := newStructFieldsCBOR() + + data, err := sf.ToCBOR(em) + require.NoError(t, err) + assert.Equal(t, data, []byte{0xa0}) // empty map + + for i := 0; i < 5; i++ { + err = sf.Add(i, cbor.RawMessage{0x00}) + require.NoError(t, err) + } + + data, err = sf.ToCBOR(em) + require.NoError(t, err) + assert.Equal(t, data, []byte{ + 0xa5, // map 5 + 0x00, 0x00, + 0x01, 0x00, + 0x02, 0x00, + 0x03, 0x00, + 0x04, 0x00, + }) + + sfOut := newStructFieldsCBOR() + err = sfOut.FromCBOR(dm, data) + require.NoError(t, err) + assert.Equal(t, sf, sfOut) + + for i := 5; i < 200; i++ { + err = sf.Add(i, cbor.RawMessage{0x00}) + require.NoError(t, err) + } + + data, err = sf.ToCBOR(em) + require.NoError(t, err) + assert.Equal(t, data[:2], []byte{ + 0xb8, 0xc8, // map 200 + }) + + sfOut = newStructFieldsCBOR() + err = sfOut.FromCBOR(dm, data) + require.NoError(t, err) + assert.Equal(t, sf, sfOut) + + for i := 200; i < 2048; i++ { + err = sf.Add(i, cbor.RawMessage{0x00}) + require.NoError(t, err) + } + + data, err = sf.ToCBOR(em) + require.NoError(t, err) + assert.Equal(t, data[:3], []byte{ + 0xb9, 0x08, 0x00, // map 2048 + }) + + sfOut = newStructFieldsCBOR() + err = sfOut.FromCBOR(dm, data) + require.NoError(t, err) + assert.Equal(t, sf, sfOut) +} + +func Test_structFieldsCBOR_CBOR_decode_tagged(t *testing.T) { + data := []byte{ + 0xc1, // tag 1 + 0xa5, // map 5 + 0x00, 0x00, + 0x01, 0x00, + 0x02, 0x00, + 0x03, 0x00, + 0x04, 0x00, + } + + dm, err := cbor.DecOptions{}.DecMode() + require.NoError(t, err) + + sfOut := newStructFieldsCBOR() + err = sfOut.FromCBOR(dm, data) + require.NoError(t, err) + assert.Equal(t, []int{0, 1, 2, 3, 4}, sfOut.Keys) +} + +func Test_structFieldsCBOR_CBOR_decode_indefinite(t *testing.T) { + data := []byte{ + 0xbf, // indefinite map + 0x00, 0x00, + 0x01, 0x00, + 0x02, 0x00, + 0x03, 0x00, + 0x04, 0x00, + 0xff, // break + } + + dm, err := cbor.DecOptions{}.DecMode() + require.NoError(t, err) + + sfOut := newStructFieldsCBOR() + err = sfOut.FromCBOR(dm, data) + require.NoError(t, err) + assert.Equal(t, []int{0, 1, 2, 3, 4}, sfOut.Keys) +} + +func Test_structFieldsCBOR_CBOR_decode_negative(t *testing.T) { + dm, err := cbor.DecOptions{}.DecMode() + require.NoError(t, err) + + sfOut := newStructFieldsCBOR() + err = sfOut.FromCBOR(dm, []byte{0xa1, 0xff, 0x00}) + assert.EqualError(t, err, `map item 0: could not unmarshal key: cbor: unexpected "break" code`) + err = sfOut.FromCBOR(dm, []byte{0xbf, 0x00, 0x00}) + assert.EqualError(t, err, `unexpected EOF`) + err = sfOut.FromCBOR(dm, []byte{0xa1, 0x00, 0xff}) + assert.EqualError(t, err, `map item 0: could not unmarshal value: cbor: unexpected "break" code`) + + err = sfOut.FromCBOR(dm, []byte{0x00}) + assert.EqualError(t, err, `expected map (CBOR Major Type 5), found Major Type 0`) +} + +func Test_processAdditionalInfo(t *testing.T) { + addInfo := byte(26) + data := []byte{0x00, 0x00, 0x00, 0x01} + + val, rest, err := processAdditionalInfo(addInfo, data) + require.NoError(t, err) + assert.Equal(t, 1, val) + assert.Equal(t, []byte{}, rest) + + _, _, err = processAdditionalInfo(byte(27), data) + assert.EqualError(t, err, "cbor: cannot decode length value of 8 bytes") + + _, _, err = processAdditionalInfo(byte(28), data) + assert.EqualError(t, err, "cbor: unexpected additional information value 28") + + _, _, err = processAdditionalInfo(addInfo, []byte{}) + assert.EqualError(t, err, "unexpected EOF") +} diff --git a/encoding/embedded.go b/encoding/embedded.go new file mode 100644 index 00000000..1bfeb999 --- /dev/null +++ b/encoding/embedded.go @@ -0,0 +1,49 @@ +package encoding + +import "reflect" + +const omitempty = "omitempty" + +type embedded struct { + Type reflect.Type + Value reflect.Value +} + +// collectEmbedded returns true if the Field is embedded (regardless of +// whether or not it was collected). +func collectEmbedded( + typeField *reflect.StructField, + valField reflect.Value, + embeds *[]embedded, +) bool { + // embedded fields are alway anonymous:w + if !typeField.Anonymous { + return false + } + + if typeField.Name == typeField.Type.Name() && + (typeField.Type.Kind() == reflect.Struct || + typeField.Type.Kind() == reflect.Interface) { + + var fieldType reflect.Type + var fieldValue reflect.Value + + if typeField.Type.Kind() == reflect.Interface { + fieldValue = valField.Elem() + if fieldValue.Kind() == reflect.Invalid { + // no value underlying the interface + return true + } + // use the interface's underlying value's real type + fieldType = valField.Elem().Type() + } else { + fieldType = typeField.Type + fieldValue = valField + } + + *embeds = append(*embeds, embedded{Type: fieldType, Value: fieldValue}) + return true + } + + return false +} diff --git a/encoding/json.go b/encoding/json.go new file mode 100644 index 00000000..e8193e4b --- /dev/null +++ b/encoding/json.go @@ -0,0 +1,319 @@ +package encoding + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "reflect" + "strings" +) + +func SerializeStructToJSON(source any) ([]byte, error) { + rawMap := newStructFieldsJSON() + + structType := reflect.TypeOf(source) + structVal := reflect.ValueOf(source) + + if err := doSerializeStructToJSON(rawMap, structType, structVal); err != nil { + return nil, err + } + + return rawMap.ToJSON() +} + +func doSerializeStructToJSON( + rawMap *structFieldsJSON, + structType reflect.Type, + structVal reflect.Value, +) error { + if structType.Kind() == reflect.Pointer { + structType = structType.Elem() + structVal = structVal.Elem() + } + + var embeds []embedded + + for i := 0; i < structVal.NumField(); i++ { + typeField := structType.Field(i) + valField := structVal.Field(i) + + if collectEmbedded(&typeField, valField, &embeds) { + continue + } + + tag, ok := typeField.Tag.Lookup("json") + if !ok { + continue + } + + parts := strings.Split(tag, ",") + key := parts[0] + + isOmitEmpty := false + if len(parts) > 1 { + for _, option := range parts[1:] { + if option == omitempty { + isOmitEmpty = true + break + } + } + } + + // do not serialize zero values if the corresponding field is + // omitempty + if isOmitEmpty && valField.IsZero() { + continue + } + + data, err := json.Marshal(valField.Interface()) + if err != nil { + return fmt.Errorf("error marshaling field %q: %w", + typeField.Name, + err, + ) + } + + if err := rawMap.Add(key, json.RawMessage(data)); err != nil { + return err + } + } + + for _, emb := range embeds { + if err := doSerializeStructToJSON(rawMap, emb.Type, emb.Value); err != nil { + return err + } + } + + return nil +} + +func PopulateStructFromJSON(data []byte, dest any) error { + rawMap := newStructFieldsJSON() + + if err := rawMap.FromJSON(data); err != nil { + return err + } + + structType := reflect.TypeOf(dest) + structVal := reflect.ValueOf(dest) + + return doPopulateStructFromJSON(rawMap, structType, structVal) +} + +func doPopulateStructFromJSON( + rawMap *structFieldsJSON, + structType reflect.Type, + structVal reflect.Value, +) error { + if structType.Kind() == reflect.Pointer { + structType = structType.Elem() + structVal = structVal.Elem() + } + + var embeds []embedded + + for i := 0; i < structVal.NumField(); i++ { + typeField := structType.Field(i) + valField := structVal.Field(i) + + if collectEmbedded(&typeField, valField, &embeds) { + continue + } + + tag, ok := typeField.Tag.Lookup("json") + if !ok { + continue + } + + parts := strings.Split(tag, ",") + key := parts[0] + + isOmitEmpty := false + if len(parts) > 1 { + for _, option := range parts[1:] { + if option == omitempty { + isOmitEmpty = true + break + } + } + } + + rawVal, ok := rawMap.Get(key) + if !ok { + if isOmitEmpty { + continue + } + + return fmt.Errorf("missing mandatory field %q (%q)", + typeField.Name, key) + } + + fieldPtr := valField.Addr().Interface() + if err := json.Unmarshal(rawVal, fieldPtr); err != nil { + return fmt.Errorf("error unmarshalling field %q: %w", + typeField.Name, + err, + ) + } + + rawMap.Delete(key) + } + + for _, emb := range embeds { + if err := doPopulateStructFromJSON(rawMap, emb.Type, emb.Value); err != nil { + return err + } + } + + return nil +} + +// structFieldsJSON is a specialized implementation of "OrderedMap", where the +// order of the keys is kept track of, and used when serializing the map to +// JSON. While JSON maps do not mandate any particular ordering, and so this +// isn't strictly necessary, it is useful to have a _stable_ serialization +// order for map keys to be compatible with regular Go struct serialization +// behavior. This is also useful for tests/examples that compare encoded +// []byte's. +type structFieldsJSON struct { + Fields map[string]json.RawMessage + Keys []string +} + +func newStructFieldsJSON() *structFieldsJSON { + return &structFieldsJSON{ + Fields: make(map[string]json.RawMessage), + } +} + +func (o structFieldsJSON) Has(key string) bool { + _, ok := o.Fields[key] + return ok +} + +func (o *structFieldsJSON) Add(key string, val json.RawMessage) error { + if o.Has(key) { + return fmt.Errorf("duplicate JSON key: %q", key) + } + + o.Fields[key] = val + o.Keys = append(o.Keys, key) + + return nil +} + +func (o *structFieldsJSON) Get(key string) (json.RawMessage, bool) { + val, ok := o.Fields[key] + return val, ok +} + +func (o *structFieldsJSON) Delete(key string) { + delete(o.Fields, key) + + for i, existing := range o.Keys { + if existing == key { + o.Keys = append(o.Keys[:i], o.Keys[i+1:]...) + } + } +} + +func (o *structFieldsJSON) ToJSON() ([]byte, error) { + var out bytes.Buffer + + out.Write([]byte("{")) + + first := true + for _, key := range o.Keys { + if first { + first = false + } else { + out.Write([]byte(",")) + } + marshaledKey, err := json.Marshal(key) + if err != nil { + return nil, fmt.Errorf("problem marshaling key %s: %w", key, err) + } + out.Write(marshaledKey) + out.Write([]byte(":")) + out.Write(o.Fields[key]) + } + + out.Write([]byte("}")) + + return out.Bytes(), nil +} + +func (o *structFieldsJSON) FromJSON(data []byte) error { + if err := json.Unmarshal(data, &o.Fields); err != nil { + return err + } + + return o.unmarshalKeys(data) +} + +func (o *structFieldsJSON) unmarshalKeys(data []byte) error { + + decoder := json.NewDecoder(bytes.NewReader(data)) + + token, err := decoder.Token() + if err != nil { + return err + } + + if token != json.Delim('{') { + return errors.New("expected start of object") + } + + var keys []string + + for { + token, err = decoder.Token() + if err != nil { + return err + } + + if token == json.Delim('}') { + break + } + + key, ok := token.(string) + if !ok { + return fmt.Errorf("expected string, found %T", token) + } + + keys = append(keys, key) + + if err := skipValue(decoder); err != nil { + return err + } + } + + o.Keys = keys + + return nil +} + +var errEndOfStream = errors.New("invalid end of array or object") + +func skipValue(decoder *json.Decoder) error { + + token, err := decoder.Token() + if err != nil { + return err + } + switch token { + case json.Delim('['), json.Delim('{'): + for { + if err := skipValue(decoder); err != nil { + if err == errEndOfStream { + break + } + return err + } + } + case json.Delim(']'), json.Delim('}'): + return errEndOfStream + } + return nil +} diff --git a/encoding/json_test.go b/encoding/json_test.go new file mode 100644 index 00000000..132d1587 --- /dev/null +++ b/encoding/json_test.go @@ -0,0 +1,121 @@ +package encoding + +import ( + "encoding/json" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_PopulateStructFromJSON(t *testing.T) { + type SimpleStruct struct { + FieldOne string `json:"field-one,omitempty"` + FieldTwo int `json:"field-two"` + } + + var v SimpleStruct + + data := []byte(`{"field-one": "acme", "field-two": 6}`) + + err := PopulateStructFromJSON(data, &v) + require.NoError(t, err) + assert.Equal(t, "acme", v.FieldOne) + assert.Equal(t, 6, v.FieldTwo) + + data = []byte(`{"field-two": 6}`) + v = SimpleStruct{} + + err = PopulateStructFromJSON(data, &v) + require.NoError(t, err) + assert.Equal(t, "", v.FieldOne) + assert.Equal(t, 6, v.FieldTwo) + + data = []byte(`{"field-one": "acme"}`) + v = SimpleStruct{} + + err = PopulateStructFromJSON(data, &v) + assert.EqualError(t, err, `missing mandatory field "FieldTwo" ("field-two")`) + + err = PopulateStructFromJSON([]byte("7"), &v) + assert.EqualError(t, err, `json: cannot unmarshal number into Go value of type map[string]json.RawMessage`) + + type CompositeStruct struct { + FieldThree string `json:"field-three"` + SimpleStruct + } + + var c CompositeStruct + + data = []byte(`{"field-one": "acme", "field-two": 6, "field-three": "foo"}`) + + err = PopulateStructFromJSON(data, &c) + require.NoError(t, err) + assert.Equal(t, "acme", c.FieldOne) + assert.Equal(t, 6, c.FieldTwo) + assert.Equal(t, "foo", c.FieldThree) + + res, err := SerializeStructToJSON(&c) + require.NoError(t, err) + + var c2 CompositeStruct + err = PopulateStructFromJSON(res, &c2) + require.NoError(t, err) + assert.EqualValues(t, c, c2) +} + +func Test_structFieldsJSON_CRUD(t *testing.T) { + sf := newStructFieldsJSON() + + err := sf.Add("two", json.RawMessage("2")) + assert.NoError(t, err) + + err = sf.Add("one", json.RawMessage("1")) + assert.NoError(t, err) + + err = sf.Add("three", json.RawMessage("3")) + assert.NoError(t, err) + + assert.Equal(t, []string{"two", "one", "three"}, sf.Keys) + assert.True(t, sf.Has("three")) + assert.False(t, sf.Has("four")) + + val, ok := sf.Get("two") + assert.True(t, ok) + assert.Equal(t, json.RawMessage("2"), val) + + _, ok = sf.Get("four") + assert.False(t, ok) + + sf.Delete("two") + _, ok = sf.Get("two") + assert.False(t, ok) + + err = sf.Add("one", json.RawMessage("4")) + assert.EqualError(t, err, `duplicate JSON key: "one"`) +} + +func Test_skipValue(t *testing.T) { + text := "" + decoder := json.NewDecoder(strings.NewReader(text)) + err := skipValue(decoder) + assert.EqualError(t, err, "EOF") + + text = "[]" + decoder = json.NewDecoder(strings.NewReader(text)) + _, _ = decoder.Token() // skip the '[' + err = skipValue(decoder) + assert.EqualError(t, err, "invalid end of array or object") + + text = `{"embed": {"one": 1, "two": [1,2,3]}, "other": 1}` + decoder = json.NewDecoder(strings.NewReader(text)) + _, _ = decoder.Token() // skip the '{' + _, _ = decoder.Token() // skip the '"embed"' + err = skipValue(decoder) + assert.NoError(t, err) + + token, err := decoder.Token() + assert.NoError(t, err) + assert.Equal(t, "other", token) +} diff --git a/extensions/extensions.go b/extensions/extensions.go new file mode 100644 index 00000000..4be60fd1 --- /dev/null +++ b/extensions/extensions.go @@ -0,0 +1,379 @@ +// Copyright 2023 Contributors to the Veraison project. +// SPDX-License-Identifier: Apache-2.0 +package extensions + +import ( + "errors" + "fmt" + "reflect" + "strings" + + "github.com/spf13/cast" +) + +var ErrExtensionNotFound = errors.New("extension not found") + +type IExtensionsValue any + +type Extensions struct { + IExtensionsValue `json:"extensions,omitempty"` +} + +func (o *Extensions) Register(exts IExtensionsValue) { + if reflect.TypeOf(exts).Kind() != reflect.Pointer { + panic("attempting to register a non-pointer IExtensionsValue") + } + + o.IExtensionsValue = exts +} + +func (o *Extensions) HaveExtensions() bool { + return o.IExtensionsValue != nil +} + +func (o *Extensions) Get(name string) (any, error) { + if o.IExtensionsValue == nil { + return nil, fmt.Errorf("%w: %s", ErrExtensionNotFound, name) + } + + extType := reflect.TypeOf(o.IExtensionsValue) + extVal := reflect.ValueOf(o.IExtensionsValue) + if extType.Kind() == reflect.Pointer { + extType = extType.Elem() + extVal = extVal.Elem() + } + + var fieldName, fieldJSONTag, fieldCBORTag string + for i := 0; i < extVal.NumField(); i++ { + typeField := extType.Field(i) + fieldName = typeField.Name + + tag, ok := typeField.Tag.Lookup("json") + if ok { + fieldJSONTag = strings.Split(tag, ",")[0] + } + + tag, ok = typeField.Tag.Lookup("cbor") + if ok { + fieldCBORTag = strings.Split(tag, ",")[0] + } + + if fieldName == name || fieldJSONTag == name || fieldCBORTag == name { + return extVal.Field(i).Interface(), nil + } + } + + return nil, fmt.Errorf("%w: %s", ErrExtensionNotFound, name) +} + +func (o *Extensions) MustGetString(name string) string { + v, _ := o.GetString(name) + return v +} + +func (o *Extensions) GetString(name string) (string, error) { + v, err := o.Get(name) + if err != nil { + return "", err + } + + return cast.ToStringE(v) +} + +func (o *Extensions) MustGetInt(name string) int { + v, _ := o.GetInt(name) + return v +} + +func (o *Extensions) GetInt(name string) (int, error) { + v, err := o.Get(name) + if err != nil { + return 0, err + } + + return cast.ToIntE(v) +} + +func (o *Extensions) MustGetInt64(name string) int64 { + v, _ := o.GetInt64(name) + return v +} + +func (o *Extensions) GetInt64(name string) (int64, error) { + v, err := o.Get(name) + if err != nil { + return 0, err + } + + return cast.ToInt64E(v) +} + +func (o *Extensions) MustGetInt32(name string) int32 { + v, _ := o.GetInt32(name) + return v +} + +func (o *Extensions) GetInt32(name string) (int32, error) { + v, err := o.Get(name) + if err != nil { + return 0, err + } + + return cast.ToInt32E(v) +} + +func (o *Extensions) MustGetInt16(name string) int16 { + v, _ := o.GetInt16(name) + return v +} + +func (o *Extensions) GetInt16(name string) (int16, error) { + v, err := o.Get(name) + if err != nil { + return 0, err + } + + return cast.ToInt16E(v) +} + +func (o *Extensions) MustGetInt8(name string) int8 { + v, _ := o.GetInt8(name) + return v +} + +func (o *Extensions) GetInt8(name string) (int8, error) { + v, err := o.Get(name) + if err != nil { + return 0, err + } + + return cast.ToInt8E(v) +} + +func (o *Extensions) MustGetUint(name string) uint { + v, _ := o.GetUint(name) + return v +} + +func (o *Extensions) GetUint(name string) (uint, error) { + v, err := o.Get(name) + if err != nil { + return 0, err + } + + return cast.ToUintE(v) +} + +func (o *Extensions) MustGetUint64(name string) uint64 { + v, _ := o.GetUint64(name) + return v +} + +func (o *Extensions) GetUint64(name string) (uint64, error) { + v, err := o.Get(name) + if err != nil { + return 0, err + } + + return cast.ToUint64E(v) +} + +func (o *Extensions) MustGetUint32(name string) uint32 { + v, _ := o.GetUint32(name) + return v +} + +func (o *Extensions) GetUint32(name string) (uint32, error) { + v, err := o.Get(name) + if err != nil { + return 0, err + } + + return cast.ToUint32E(v) +} + +func (o *Extensions) MustGetUint16(name string) uint16 { + v, _ := o.GetUint16(name) + return v +} + +func (o *Extensions) GetUint16(name string) (uint16, error) { + v, err := o.Get(name) + if err != nil { + return 0, err + } + + return cast.ToUint16E(v) +} + +func (o *Extensions) MustGetUint8(name string) uint8 { + v, _ := o.GetUint8(name) + return v +} + +func (o *Extensions) GetUint8(name string) (uint8, error) { + v, err := o.Get(name) + if err != nil { + return 0, err + } + + return cast.ToUint8E(v) +} + +func (o *Extensions) MustGetFloat32(name string) float32 { + v, _ := o.GetFloat32(name) + return v +} + +func (o *Extensions) GetFloat32(name string) (float32, error) { + v, err := o.Get(name) + if err != nil { + return 0, err + } + + return cast.ToFloat32E(v) +} + +func (o *Extensions) MustGetFloat64(name string) float64 { + v, _ := o.GetFloat64(name) + return v +} + +func (o *Extensions) GetFloat64(name string) (float64, error) { + v, err := o.Get(name) + if err != nil { + return 0, err + } + + return cast.ToFloat64E(v) +} + +func (o *Extensions) MustGetBool(name string) bool { + v, _ := o.GetBool(name) + return v +} + +func (o *Extensions) GetBool(name string) (bool, error) { + v, err := o.Get(name) + if err != nil { + return false, err + } + + return cast.ToBoolE(v) +} + +func (o *Extensions) MustGetSlice(name string) []any { + v, _ := o.GetSlice(name) + return v +} + +func (o *Extensions) GetSlice(name string) ([]any, error) { + v, err := o.Get(name) + if err != nil { + return []any{}, err + } + + return cast.ToSliceE(v) +} + +func (o *Extensions) MustGetIntSlice(name string) []int { + v, _ := o.GetIntSlice(name) + return v +} + +func (o *Extensions) GetIntSlice(name string) ([]int, error) { + v, err := o.Get(name) + if err != nil { + return []int{}, err + } + + return cast.ToIntSliceE(v) +} + +func (o *Extensions) MustGetStringSlice(name string) []string { + v, _ := o.GetStringSlice(name) + return v +} + +func (o *Extensions) GetStringSlice(name string) ([]string, error) { + v, err := o.Get(name) + if err != nil { + return []string{}, err + } + + return cast.ToStringSliceE(v) +} + +func (o *Extensions) MustGetStringMap(name string) map[string]any { + v, _ := o.GetStringMap(name) + return v +} + +func (o *Extensions) GetStringMap(name string) (map[string]any, error) { + v, err := o.Get(name) + if err != nil { + return map[string]any{}, err + } + + return cast.ToStringMapE(v) +} + +func (o *Extensions) MustGetStringMapString(name string) map[string]string { + v, _ := o.GetStringMapString(name) + return v +} + +func (o *Extensions) GetStringMapString(name string) (map[string]string, error) { + v, err := o.Get(name) + if err != nil { + return map[string]string{}, err + } + + return cast.ToStringMapStringE(v) +} + +func (o *Extensions) Set(name string, value any) error { + if o.IExtensionsValue == nil { + return fmt.Errorf("%w: %s", ErrExtensionNotFound, name) + } + + extType := reflect.TypeOf(o.IExtensionsValue) + extVal := reflect.ValueOf(o.IExtensionsValue) + if extType.Kind() == reflect.Pointer { + extType = extType.Elem() + extVal = extVal.Elem() + } + + var fieldName, fieldJSONTag, fieldCBORTag string + for i := 0; i < extVal.NumField(); i++ { + typeField := extType.Field(i) + valField := extVal.Field(i) + fieldName = typeField.Name + + tag, ok := typeField.Tag.Lookup("json") + if ok { + fieldJSONTag = strings.Split(tag, ",")[0] + } + + tag, ok = typeField.Tag.Lookup("cbor") + if ok { + fieldCBORTag = strings.Split(tag, ",")[0] + } + + if fieldName == name || fieldJSONTag == name || fieldCBORTag == name { + newVal := reflect.ValueOf(value) + if newVal.CanConvert(valField.Type()) { + valField.Set(newVal.Convert(valField.Type())) + return nil + } + + return fmt.Errorf( + "cannot set field %q (of type %s) to %v (%T)", + name, typeField.Type.Name(), + value, value, + ) + } + } + + return fmt.Errorf("%w: %s", ErrExtensionNotFound, name) +} diff --git a/extensions/extensions_test.go b/extensions/extensions_test.go new file mode 100644 index 00000000..74d44f52 --- /dev/null +++ b/extensions/extensions_test.go @@ -0,0 +1,122 @@ +// Copyright 2023 Contributors to the Veraison project. +// SPDX-License-Identifier: Apache-2.0 +package extensions + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +type Entity struct { + EntityName string + Roles []int64 + + Extensions +} + +type TestExtensions struct { + Address string `cbor:"-1,keyasint,omitempty" json:"address,omitempty"` + Size int `cbor:"-2,keyasint,omitempty" json:"size,omitempty"` + YearsOnAir float32 `cbor:"-3,keyasint,omitempty" json:"years-on-air,omitempty"` + StillAiring bool `cbor:"-4,keyasint,omitempty" json:"still-airing,omitempty"` + Ages []int `cbor:"-5,keyasint,omitempty" json:"ages,omitempty"` + Jobs map[string]string `cbor:"-6,keyasint,omitempty" json:"jobs,omitempty"` +} + +func TestExtensions_Register(t *testing.T) { + exts := Extensions{} + assert.False(t, exts.HaveExtensions()) + + exts.Register(&TestExtensions{}) + assert.True(t, exts.HaveExtensions()) + + badRegister := func() { + exts.Register(TestExtensions{}) + } + + assert.Panics(t, badRegister) +} + +func TestExtensions_GetSet(t *testing.T) { + extsVal := TestExtensions{ + Address: "742 Evergreen Terrace", + Size: 6, + YearsOnAir: 33.8, + StillAiring: true, + Ages: []int{2, 7, 8, 10, 37, 38}, + Jobs: map[string]string{ + "Homer": "safety inspector", + "Marge": "housewife", + "Bart": "elementary school student", + "Lisa": "elementary school student", + }, + } + exts := Extensions{IExtensionsValue: &extsVal} + + v, err := exts.GetInt("size") + assert.NoError(t, err) + assert.Equal(t, 6, v) + + assert.Equal(t, 6, exts.MustGetInt("size")) + assert.Equal(t, int64(6), exts.MustGetInt64("size")) + assert.Equal(t, int32(6), exts.MustGetInt32("size")) + assert.Equal(t, int16(6), exts.MustGetInt16("size")) + assert.Equal(t, int8(6), exts.MustGetInt8("size")) + + assert.Equal(t, uint(6), exts.MustGetUint("size")) + assert.Equal(t, uint64(6), exts.MustGetUint64("size")) + assert.Equal(t, uint32(6), exts.MustGetUint32("size")) + assert.Equal(t, uint16(6), exts.MustGetUint16("size")) + assert.Equal(t, uint8(6), exts.MustGetUint8("size")) + + assert.InEpsilon(t, float32(33.8), exts.MustGetFloat32("years-on-air"), 0.000001) + assert.InEpsilon(t, float64(33.8), exts.MustGetFloat64("-3"), 0.000001) + + assert.Equal(t, true, exts.MustGetBool("StillAiring")) + + _, err = exts.GetSlice("ages") + assert.EqualError(t, err, + `unable to cast []int{2, 7, 8, 10, 37, 38} of type []int to []interface{}`) + assert.Nil(t, exts.MustGetSlice("ages")) + + assert.EqualValues(t, []int{2, 7, 8, 10, 37, 38}, exts.MustGetIntSlice("ages")) + assert.EqualValues(t, []string{"2", "7", "8", "10", "37", "38"}, + exts.MustGetStringSlice("ages")) + + assert.EqualValues(t, map[string]string{ + "Homer": "safety inspector", + "Marge": "housewife", + "Bart": "elementary school student", + "Lisa": "elementary school student", + }, exts.MustGetStringMapString("jobs")) + + _, err = exts.GetStringMap("jobs") + assert.EqualError(t, err, + `unable to cast map[string]string{"Bart":"elementary school student", "Homer":"safety inspector", "Lisa":"elementary school student", "Marge":"housewife"} of type map[string]string to map[string]interface{}`) + m := exts.MustGetStringMap("jobs") + assert.Equal(t, map[string]any{}, m) + + s, err := exts.GetString("address") + assert.NoError(t, err) + assert.Equal(t, "742 Evergreen Terrace", s) + + _, err = exts.GetInt("address") + assert.EqualError(t, err, `unable to cast "742 Evergreen Terrace" of type string to int`) + + _, err = exts.GetInt("foo") + assert.EqualError(t, err, "extension not found: foo") + + err = exts.Set("-1", "123 Fake Street") + assert.NoError(t, err) + + s, err = exts.GetString("address") + assert.NoError(t, err) + assert.Equal(t, "123 Fake Street", s) + + err = exts.Set("Size", "foo") + assert.EqualError(t, err, `cannot set field "Size" (of type int) to foo (string)`) + + assert.Equal(t, "", exts.MustGetString("does-not-exist")) + assert.Equal(t, 0, exts.MustGetInt("does-not-exist")) +} diff --git a/go.mod b/go.mod index 9d95f047..96a5aa49 100644 --- a/go.mod +++ b/go.mod @@ -3,11 +3,12 @@ module github.com/veraison/corim go 1.18 require ( - github.com/fxamacker/cbor/v2 v2.4.0 + github.com/fxamacker/cbor/v2 v2.5.0 github.com/golang/mock v1.6.0 github.com/google/uuid v1.3.0 github.com/lestrrat-go/jwx/v2 v2.0.8 github.com/spf13/afero v1.9.2 + github.com/spf13/cast v1.4.1 github.com/spf13/cobra v1.2.1 github.com/spf13/viper v1.9.0 github.com/stretchr/testify v1.8.2 @@ -35,7 +36,6 @@ require ( github.com/moogar0880/problems v0.1.1 // indirect github.com/pelletier/go-toml v1.9.4 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - github.com/spf13/cast v1.4.1 // indirect github.com/spf13/jwalterweatherman v1.1.0 // indirect github.com/spf13/pflag v1.0.5 // indirect github.com/subosito/gotenv v1.2.0 // indirect diff --git a/go.sum b/go.sum index c6347523..e1ae4e55 100644 --- a/go.sum +++ b/go.sum @@ -90,8 +90,8 @@ github.com/fsnotify/fsnotify v1.5.1 h1:mZcQUHVQUQWoPXXtuf9yuEXKudkV2sx1E06UadKWp github.com/fsnotify/fsnotify v1.5.1/go.mod h1:T3375wBYaZdLLcVNkcVbzGHY7f1l/uK5T5Ai1i3InKU= github.com/fxamacker/cbor/v2 v2.2.0/go.mod h1:TA1xS00nchWmaBnEIxPSE5oHLuJBAVvqrtAnWBwBCVo= github.com/fxamacker/cbor/v2 v2.3.0/go.mod h1:TA1xS00nchWmaBnEIxPSE5oHLuJBAVvqrtAnWBwBCVo= -github.com/fxamacker/cbor/v2 v2.4.0 h1:ri0ArlOR+5XunOP8CRUowT0pSJOwhW098ZCUyskZD88= -github.com/fxamacker/cbor/v2 v2.4.0/go.mod h1:TA1xS00nchWmaBnEIxPSE5oHLuJBAVvqrtAnWBwBCVo= +github.com/fxamacker/cbor/v2 v2.5.0 h1:oHsG0V/Q6E/wqTS2O1Cozzsy69nqCiguo5Q1a1ADivE= +github.com/fxamacker/cbor/v2 v2.5.0/go.mod h1:TA1xS00nchWmaBnEIxPSE5oHLuJBAVvqrtAnWBwBCVo= github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=