From bdcca6b4a88110569bbfe1fa23ce9014154fe61d Mon Sep 17 00:00:00 2001 From: Sergei Trofimov Date: Wed, 13 Sep 2023 17:48:30 +0100 Subject: [PATCH] Implement type choice extensions Implement an extensions mechanism for entries identified as "type-choice" in the [CoRIM] spec draft 2. This done by defining a common registration pattern for ITypeChoiceValue interface implementations (the interface is often extended for specific extension points). A Register function is used to associate a factory that produces a type choice instance populated with an appropriate value with a CBOR tag. The factory function must be able to handle nil values, resulting in the zero-value for the associated type; aside from the, the range of valid values the factory can handle is type-specific. A couple of the entries indicates as type choices in the spec (rel and role) are in fact extensible enums, rather than types, so they're implemented somewhat differently. The registration function for them just takes the new value and associated name; there is no need for a factory function. Additionally: - add tagged-int-type implementation for class-id - add EntityName type to represent entity-name type choice (entity-name was previously implemented as a string). [CoRIM]: https://www.ietf.org/archive/id/draft-ietf-rats-corim-02.html Signed-off-by: Sergei Trofimov --- comid/attestverifkey_test.go | 5 +- comid/cbor.go | 31 +- comid/ccaplatformconfigid.go | 75 ++++- comid/ccaplatformconfigid_test.go | 64 ++++ comid/class.go | 32 +- comid/classid.go | 474 ++++++++++++++++---------- comid/classid_test.go | 272 +++++++++++++-- comid/comid.go | 2 +- comid/cryptokey.go | 322 +++++++++++------- comid/cryptokey_test.go | 67 +++- comid/devidentitykey_test.go | 4 +- comid/entity.go | 231 ++++++++++++- comid/entity_test.go | 184 +++++++++++ comid/environment_test.go | 32 +- comid/example_cca_refval_test.go | 18 +- comid/example_psa_keys_test.go | 7 +- comid/example_psa_refval_test.go | 14 +- comid/example_test.go | 42 ++- comid/group.go | 149 +++++---- comid/group_test.go | 62 ++++ comid/instance.go | 278 +++++++++------- comid/instance_test.go | 67 +++- comid/measurement.go | 532 +++++++++++++++++------------- comid/measurement_test.go | 321 +++++++++--------- comid/oid.go | 77 +++++ comid/psareferencevalue.go | 100 +++++- comid/psareferencevalue_test.go | 18 +- comid/referencevalue_test.go | 3 +- comid/rel.go | 41 ++- comid/rel_test.go | 17 + comid/role.go | 17 + comid/role_test.go | 17 + comid/svn.go | 268 +++++++++++---- comid/svn_test.go | 182 ++++++++++ comid/ueid.go | 71 +++- comid/ueid_test.go | 30 ++ comid/uuid.go | 91 ++++- comid/uuid_test.go | 24 ++ corim/cbor.go | 29 +- corim/entity.go | 238 ++++++++++++- corim/entity_test.go | 192 ++++++++++- corim/extensions_test.go | 4 +- corim/role.go | 20 +- corim/role_test.go | 18 + corim/unsignedcorim_test.go | 5 +- encoding/json.go | 32 ++ encoding/json_test.go | 31 ++ extensions/typechoice.go | 33 ++ 48 files changed, 3732 insertions(+), 1111 deletions(-) create mode 100644 comid/group_test.go create mode 100644 comid/svn_test.go create mode 100644 comid/ueid_test.go create mode 100644 comid/uuid_test.go create mode 100644 extensions/typechoice.go diff --git a/comid/attestverifkey_test.go b/comid/attestverifkey_test.go index 659a0abc..20c791d9 100644 --- a/comid/attestverifkey_test.go +++ b/comid/attestverifkey_test.go @@ -23,16 +23,17 @@ func TestAttestVerifKey_Valid_empty(t *testing.T) { testerr: "environment validation failed: environment must not be empty", }, { - env: Environment{Instance: NewInstanceUEID(TestUEID)}, + env: Environment{Instance: MustNewUEIDInstance(TestUEID)}, verifkey: CryptoKeys{}, testerr: "verification keys validation failed: no keys to validate", }, { - env: Environment{Instance: NewInstanceUEID(TestUEID)}, + env: Environment{Instance: MustNewUEIDInstance(TestUEID)}, verifkey: CryptoKeys{&invalidKey}, testerr: "verification keys validation failed: invalid key at index 0: key value not set", }, } + for _, tv := range tvs { av := AttestVerifKey{Environment: tv.env, VerifKeys: tv.verifkey} err := av.Valid() diff --git a/comid/cbor.go b/comid/cbor.go index c44d518d..ab8a9233 100644 --- a/comid/cbor.go +++ b/comid/cbor.go @@ -4,6 +4,7 @@ package comid import ( + "fmt" "reflect" cbor "github.com/fxamacker/cbor/v2" @@ -12,16 +13,14 @@ import ( var ( em, emError = initCBOREncMode() dm, dmError = initCBORDecMode() -) -func comidTags() cbor.TagSet { - comidTagsMap := map[uint64]interface{}{ + comidTagsMap = map[uint64]interface{}{ 32: TaggedURI(""), 37: TaggedUUID{}, 111: TaggedOID{}, // CoMID tags 550: TaggedUEID{}, - //551: To Do see: https://github.com/veraison/corim/issues/32 + 551: TaggedInt(0), 552: TaggedSVN(0), 553: TaggedMinSVN(0), 554: TaggedPKIXBase64Key(""), @@ -37,7 +36,9 @@ func comidTags() cbor.TagSet { 601: TaggedPSARefValID{}, 602: TaggedCCAPlatformConfigID(""), } +) +func comidTags() cbor.TagSet { opts := cbor.TagOptions{ EncTag: cbor.EncTagRequired, DecTag: cbor.DecTagRequired, @@ -69,6 +70,28 @@ func initCBORDecMode() (dm cbor.DecMode, err error) { return decOpt.DecModeWithTags(comidTags()) } +func registerCOMIDTag(tag uint64, t interface{}) error { + if _, exists := comidTagsMap[tag]; exists { + return fmt.Errorf("tag %d is already registered", tag) + } + + comidTagsMap[tag] = t + + var err error + + em, err = initCBOREncMode() + if err != nil { + return err + } + + dm, err = initCBORDecMode() + if err != nil { + return err + } + + return nil +} + func init() { if emError != nil { panic(emError) diff --git a/comid/ccaplatformconfigid.go b/comid/ccaplatformconfigid.go index e485083f..a55271ff 100644 --- a/comid/ccaplatformconfigid.go +++ b/comid/ccaplatformconfigid.go @@ -3,11 +3,16 @@ package comid -import "fmt" +import ( + "encoding/json" + "errors" + "fmt" + "unicode/utf8" +) -type CCAPlatformConfigID string +var CCAPlatformConfigIDType = "cca.platform-config-id" -type TaggedCCAPlatformConfigID CCAPlatformConfigID +type CCAPlatformConfigID string func (o CCAPlatformConfigID) Empty() bool { return o == "" @@ -27,3 +32,67 @@ func (o CCAPlatformConfigID) Get() (CCAPlatformConfigID, error) { } return o, nil } + +type TaggedCCAPlatformConfigID CCAPlatformConfigID + +func NewTaggedCCAPlatformConfigID(val any) (*TaggedCCAPlatformConfigID, error) { + var ret TaggedCCAPlatformConfigID + + if val == nil { + return &ret, nil + } + + switch t := val.(type) { + case TaggedCCAPlatformConfigID: + ret = t + case *TaggedCCAPlatformConfigID: + ret = *t + case CCAPlatformConfigID: + ret = TaggedCCAPlatformConfigID(t) + case *CCAPlatformConfigID: + ret = TaggedCCAPlatformConfigID(*t) + case string: + ret = TaggedCCAPlatformConfigID(t) + case []byte: + if !utf8.Valid(t) { + return nil, errors.New("bytes do not form a valid UTF-8 string") + } + ret = TaggedCCAPlatformConfigID(t) + default: + return nil, fmt.Errorf("unexpected type for CCA platform-config-id: %T", t) + } + + return &ret, nil +} + +func (o TaggedCCAPlatformConfigID) Valid() error { + if o == "" { + return errors.New("empty value") + } + + return nil +} + +func (o TaggedCCAPlatformConfigID) String() string { + return string(o) +} + +func (o TaggedCCAPlatformConfigID) Type() string { + return CCAPlatformConfigIDType +} + +func (o TaggedCCAPlatformConfigID) IsZero() bool { + return len(o) == 0 +} + +func (o *TaggedCCAPlatformConfigID) UnmarshalJSON(data []byte) error { + var temp string + + if err := json.Unmarshal(data, &temp); err != nil { + return err + } + + *o = TaggedCCAPlatformConfigID(temp) + + return nil +} diff --git a/comid/ccaplatformconfigid_test.go b/comid/ccaplatformconfigid_test.go index b0a0a37c..543390fd 100644 --- a/comid/ccaplatformconfigid_test.go +++ b/comid/ccaplatformconfigid_test.go @@ -29,3 +29,67 @@ func TestCCAPlatformConfigID_Get_nok(t *testing.T) { _, err := cca.Get() assert.EqualError(t, err, expectedErr) } + +func TestNewTaggedCCAPlatformConfigID(t *testing.T) { + testID := TaggedCCAPlatformConfigID("test") + untagged := CCAPlatformConfigID("test") + + for _, tv := range []struct { + Name string + Input any + Err string + Expected TaggedCCAPlatformConfigID + }{ + { + Name: "TaggedCCAPlatformConfigID ok", + Input: testID, + Expected: testID, + }, + { + Name: "*TaggedCCAPlatformConfigID ok", + Input: &testID, + Expected: testID, + }, + { + Name: "CCAPlatformConfigID ok", + Input: untagged, + Expected: testID, + }, + { + Name: "*CCAPlatformConfigID ok", + Input: &untagged, + Expected: testID, + }, + { + Name: "string ok", + Input: "test", + Expected: testID, + }, + { + Name: "[]byte ok", + Input: []byte{0x74, 0x65, 0x73, 0x74}, + Expected: testID, + }, + { + Name: "[]byte not ok", + Input: []byte{0x80, 0x65, 0x73, 0x74}, + Err: "bytes do not form a valid UTF-8 string", + }, + { + Name: "bad type", + Input: 7, + Err: "unexpected type for CCA platform-config-id: int", + }, + } { + t.Run(tv.Name, func(t *testing.T) { + out, err := NewTaggedCCAPlatformConfigID(tv.Input) + + if tv.Err != "" { + assert.Nil(t, out) + assert.EqualError(t, err, tv.Err) + } else { + assert.Equal(t, tv.Expected, *out) + } + }) + } +} diff --git a/comid/class.go b/comid/class.go index 085e703b..ab414740 100644 --- a/comid/class.go +++ b/comid/class.go @@ -23,39 +23,33 @@ type Class struct { // NewClassUUID instantiates a new Class object with the specified UUID as // identifier func NewClassUUID(uuid UUID) *Class { - c := Class{ - ClassID: &ClassID{}, - } - - if c.ClassID.SetUUID(uuid) == nil { + classID, err := NewUUIDClassID(uuid) + if err != nil { return nil } - return &c + + return &Class{ClassID: classID} } // NewClassImplID instantiates a new Class object that identifies the specified PSA // Implementation ID func NewClassImplID(implID ImplID) *Class { - c := Class{ - ClassID: &ClassID{}, - } - - if c.ClassID.SetImplID(implID) == nil { + classID, err := NewImplIDClassID(implID) + if err != nil { return nil } - return &c + + return &Class{ClassID: classID} } // NewClassOID instantiates a new Class object that identifies the OID func NewClassOID(oid string) *Class { - c := Class{ - ClassID: &ClassID{}, - } - - if c.ClassID.SetOID(oid) == nil { + classID, err := NewOIDClassID(oid) + if err != nil { return nil } - return &c + + return &Class{ClassID: classID} } // SetVendor sets the vendor metadata to the supplied string @@ -131,7 +125,7 @@ func (o *Class) SetIndex(index uint64) *Class { // Valid checks the non-empty<> constraint on the map func (o Class) Valid() error { // check non-empty<{ ... }> - if (o.ClassID == nil || o.ClassID.Unset()) && + if (o.ClassID == nil || !o.ClassID.IsSet()) && o.Vendor == nil && o.Model == nil && o.Layer == nil && o.Index == nil { return fmt.Errorf("class must not be empty") } diff --git a/comid/classid.go b/comid/classid.go index 760f8e5d..33021599 100644 --- a/comid/classid.go +++ b/comid/classid.go @@ -5,245 +5,375 @@ package comid import ( "encoding/base64" + "encoding/binary" "encoding/json" + "errors" "fmt" + "strconv" + + "github.com/veraison/corim/encoding" + "github.com/veraison/corim/extensions" ) -// ClassID represents a $class-id-type-choice, which can be one of TaggedUUID, -// TaggedOID, or TaggedImplID (PSA-specific extension) +// ClassID identifies the environment via a well-known identifier. This can be +// an OID, a UUID, or a profile-defined extension type. type ClassID struct { - val interface{} + Value IClassIDValue } -type ClassIDType uint16 +// NewClassID creates a new ClassID of the specified type using the specified value. +func NewClassID(val any, typ string) (*ClassID, error) { + factory, ok := classIDValueRegister[typ] + if !ok { + return nil, fmt.Errorf("unknown class id type: %s", typ) + } -const ( - ClassIDTypeUUID = ClassIDType(iota) - ClassIDTypeImplID - ClassIDTypeOID + return factory(val) +} - ClassIDTypeUnknown = ^ClassIDType(0) -) +// Valid returns nil if the ClassID is valid, or an error describing the +// problem, if it is not. +func (o ClassID) Valid() error { + if o.Value == nil { + return errors.New("nil value") + } -// SetUUID sets the value of the targed ClassID to the supplied UUID -func (o *ClassID) SetUUID(uuid UUID) *ClassID { - if o != nil { - o.val = TaggedUUID(uuid) + return o.Value.Valid() +} + +// Type returns the type of the ClassID +func (o ClassID) Type() string { + if o.Value == nil { + return "" } - return o + + return o.Value.Type() } -type ImplID [32]byte -type TaggedImplID ImplID +// Bytes returns a []byte containing the raw bytes of the class id value +func (o ClassID) Bytes() []byte { + if o.Value == nil { + return []byte{} + } + return o.Value.Bytes() +} -func (o ImplID) MarshalJSON() ([]byte, error) { - return json.Marshal(o[:]) +// IsSet returns true iff the underlying class id value has been set (is not nil) +func (o ClassID) IsSet() bool { + return o.Value != nil } -func (o *ImplID) UnmarshalJSON(data []byte) error { - var b []byte +// MarshalCBOR serializes the target ClassID to CBOR +func (o ClassID) MarshalCBOR() ([]byte, error) { + return em.Marshal(o.Value) +} - if err := json.Unmarshal(data, &b); err != nil { - return fmt.Errorf("bad ImplID: %w", err) +// UnmarshalCBOR deserializes the supplied CBOR buffer into the target ClassID. +// It is undefined behavior to try and inspect the target ClassID in case this +// method returns an error. +func (o *ClassID) UnmarshalCBOR(data []byte) error { + return dm.Unmarshal(data, &o.Value) +} + +// UnmarshalJSON deserializes the supplied JSON object into the target ClassID +// The class id object must have the following shape: +// +// { +// "type": "", +// "value": +// } +// +// where must be one of the known IClassIDValue implementation +// type names (available in this implementation: "uuid", "oid", +// "psa.impl-id", "int"), and is the JSON encoding of the underlying +// class id value. The exact encoding is dependent. For the base +// implementation types it is +// +// oid: dot-seprated integers, e.g. "1.2.3.4" +// psa.impl-id: base64-encoded bytes, e.g. "YWNtZS1pbXBsZW1lbnRhdGlvbi1pZC0wMDAwMDAwMDE=" +// uuid: standard UUID string representation, e.g. "550e8400-e29b-41d4-a716-446655440000" +// int: an integer value, e.g. 7 +func (o *ClassID) UnmarshalJSON(data []byte) error { + var tnv encoding.TypeAndValue + + if err := json.Unmarshal(data, &tnv); err != nil { + return fmt.Errorf("class id decoding failure: %w", err) } - if nb := len(b); nb != 32 { - return fmt.Errorf("bad ImplID format: got %d bytes, want 32", nb) + decoded, err := NewClassID(nil, tnv.Type) + if err != nil { + return err + } + + if err := json.Unmarshal(tnv.Value, &decoded.Value); err != nil { + return fmt.Errorf( + "cannot unmarshal class id: %w", + err, + ) + } + + if err := decoded.Value.Valid(); err != nil { + return fmt.Errorf("invalid %s: %w", tnv.Type, err) } - copy(o[:], b) + o.Value = decoded.Value return nil } -type TaggedOID OID +// MarshalJSON serializes the target ClassID to JSON +func (o ClassID) MarshalJSON() ([]byte, error) { + return extensions.TypeChoiceValueMarshalJSON(o.Value) +} -// SetImplID sets the value of the targed ClassID to the supplied PSA -// Implementation ID (see Section 3.2.2 of draft-tschofenig-rats-psa-token) -func (o *ClassID) SetImplID(implID ImplID) *ClassID { - if o != nil { - o.val = TaggedImplID(implID) +// String returns a printable string of the ClassID value. UUIDs use the +// canonical 8-4-4-4-12 format, PSA Implementation IDs are base64 encoded. +// OIDs are output in dotted-decimal notation. +func (o ClassID) String() string { + if o.Value == nil { + return "" } - return o + + return o.Value.String() +} + +type IClassIDValue interface { + extensions.ITypeChoiceValue + + Bytes() []byte +} + +const ImplIDType = "psa.impl-id" + +type ImplID [32]byte + +func (o ImplID) String() string { + return base64.StdEncoding.EncodeToString(o[:]) +} + +func (o ImplID) Valid() error { + return nil } -func (o ClassID) GetImplID() (ImplID, error) { - switch t := o.val.(type) { +type TaggedImplID ImplID + +func NewImplIDClassID(val any) (*ClassID, error) { + var ret TaggedImplID + + if val == nil { + return &ClassID{&TaggedImplID{}}, nil + } + + switch t := val.(type) { + case []byte: + if nb := len(t); nb != 32 { + return nil, fmt.Errorf("bad psa.impl-id: got %d bytes, want 32", nb) + } + + copy(ret[:], t) + case string: + v, err := base64.StdEncoding.DecodeString(t) + if err != nil { + return nil, fmt.Errorf("bad psa.impl-id: %w", err) + } + + if nb := len(v); nb != 32 { + return nil, fmt.Errorf("bad psa.impl-id: decoded %d bytes, want 32", nb) + } + + copy(ret[:], v) case TaggedImplID: - return ImplID(t), nil + copy(ret[:], t[:]) + case *TaggedImplID: + copy(ret[:], (*t)[:]) + case ImplID: + copy(ret[:], t[:]) + case *ImplID: + copy(ret[:], (*t)[:]) default: - return ImplID{}, fmt.Errorf("class-id type is: %T", t) + return nil, fmt.Errorf("unexpected type for psa.impl-id: %T", t) } + + return &ClassID{&ret}, nil } -// SetOID sets the value of the targed ClassID to the supplied OID. -// The OID is a string in dotted-decimal notation -func (o *ClassID) SetOID(s string) *ClassID { - if o != nil { - var berOID OID - if berOID.FromString(s) != nil { - return nil - } - o.val = TaggedOID(berOID) +func MustNewImplIDClassID(val any) *ClassID { + ret, err := NewImplIDClassID(val) + if err != nil { + panic(err) } - return o + + return ret } -// MarshalCBOR serializes the target ClassID to CBOR -func (o ClassID) MarshalCBOR() ([]byte, error) { - return em.Marshal(o.val) +func (o TaggedImplID) Valid() error { + return ImplID(o).Valid() } -// UnmarshalCBOR deserializes the supplied CBOR buffer into the target ClassID. -// It is undefined behavior to try and inspect the target ClassID in case this -// method returns an error. -func (o *ClassID) UnmarshalCBOR(data []byte) error { - var implID TaggedImplID +func (o TaggedImplID) String() string { + return ImplID(o).String() +} - if dm.Unmarshal(data, &implID) == nil { - o.val = implID - return nil +func (o TaggedImplID) Type() string { + return ImplIDType +} + +func (o TaggedImplID) Bytes() []byte { + return o[:] +} + +func (o *TaggedImplID) MarshalJSON() ([]byte, error) { + return json.Marshal((*o)[:]) +} + +func (o *TaggedImplID) UnmarshalJSON(data []byte) error { + var out []byte + if err := json.Unmarshal(data, &out); err != nil { + return err + } + + if len(out) != 32 { + return fmt.Errorf("bad psa.impl-id: decoded %d bytes, want 32", len(out)) } - var uuid TaggedUUID + copy((*o)[:], out) + + return nil +} - if dm.Unmarshal(data, &uuid) == nil { - o.val = uuid - return nil +func NewOIDClassID(val any) (*ClassID, error) { + ret, err := NewTaggedOID(val) + if err != nil { + return nil, err } - var oid TaggedOID + return &ClassID{ret}, nil +} - if dm.Unmarshal(data, &oid) == nil { - o.val = oid - return nil +func MustNewOIDClassID(val any) *ClassID { + ret, err := NewOIDClassID(val) + if err != nil { + panic(err) } - return fmt.Errorf("unknown class id (CBOR: %x)", data) + return ret } -// UnmarshalJSON deserializes the supplied JSON object into the target ClassID -// The class id object must have one of the following shapes: -// -// UUID: -// -// { -// "type": "uuid", -// "value": "69E027B2-7157-4758-BCB4-D9F167FE49EA" -// } -// -// OID: -// -// { -// "type": "oid", -// "value": "2.16.840.1.113741.1.15.4.2" -// } -// -// PSA Implementation ID: -// -// { -// "type": "psa.impl-id", -// "value": "YWNtZS1pbXBsZW1lbnRhdGlvbi1pZC0wMDAwMDAwMDE=" -// } -func (o *ClassID) UnmarshalJSON(data []byte) error { - var v tnv +func NewUUIDClassID(val any) (*ClassID, error) { + if val == nil { + return &ClassID{&TaggedUUID{}}, nil + } - if err := json.Unmarshal(data, &v); err != nil { - return err + ret, err := NewTaggedUUID(val) + if err != nil { + return nil, err } - switch v.Type { - case "uuid": // nolint: goconst - var x UUID - if err := x.UnmarshalJSON(v.Value); err != nil { - return err - } - o.val = TaggedUUID(x) - case "oid": - var x OID - if err := x.UnmarshalJSON(v.Value); err != nil { - return err - } - o.val = TaggedOID(x) - case "psa.impl-id": - var x ImplID - if err := x.UnmarshalJSON(v.Value); err != nil { - return err - } - o.val = TaggedImplID(x) - default: - return fmt.Errorf("unknown type '%s' for class id", v.Type) + return &ClassID{ret}, nil +} + +func MustNewUUIDClassID(val any) *ClassID { + ret, err := NewUUIDClassID(val) + if err != nil { + panic(err) } - return nil + return ret } -// MarshalJSON serializes the target ClassID to JSON -func (o ClassID) MarshalJSON() ([]byte, error) { - var ( - v tnv - b []byte - err error - ) - - switch t := o.val.(type) { - case TaggedUUID: - b, err = UUID(t).MarshalJSON() - if err != nil { - return nil, err - } - v = tnv{Type: "uuid", Value: b} - case TaggedOID: - b, err = OID(t).MarshalJSON() +const IntType = "int" + +type TaggedInt int + +func NewIntClassID(val any) (*ClassID, error) { + if val == nil { + zeroVal := TaggedInt(0) + return &ClassID{&zeroVal}, nil + } + + var ret TaggedInt + + switch t := val.(type) { + case string: + i, err := strconv.Atoi(t) if err != nil { - return nil, err + return nil, fmt.Errorf("bad int: %w", err) } - v = tnv{Type: "oid", Value: b} - case TaggedImplID: - b, err = ImplID(t).MarshalJSON() - if err != nil { - return nil, err + ret = TaggedInt(i) + case []byte: + if len(t) != 8 { + return nil, fmt.Errorf("bad int: want 8 bytes, got %d bytes", len(t)) } - v = tnv{Type: "psa.impl-id", Value: b} + ret = TaggedInt(binary.BigEndian.Uint64(t)) + case int: + ret = TaggedInt(t) + case *int: + ret = TaggedInt(*t) + case int64: + ret = TaggedInt(t) + case *int64: + ret = TaggedInt(*t) + case uint64: + ret = TaggedInt(t) + case *uint64: + ret = TaggedInt(*t) default: - return nil, fmt.Errorf("unknown type %T for class-id", t) + return nil, fmt.Errorf("unexpected type for int: %T", t) } - return json.Marshal(v) + if err := ret.Valid(); err != nil { + return nil, err + } + + return &ClassID{&ret}, nil } -// Type returns the type of the target ClassID, i.e., one of UUID, OID or PSA -// Implementation ID -func (o ClassID) Type() ClassIDType { - switch o.val.(type) { - case TaggedUUID: - return ClassIDTypeUUID - case TaggedImplID: - return ClassIDTypeImplID - case TaggedOID: - return ClassIDTypeOID - } - return ClassIDTypeUnknown +func (o TaggedInt) String() string { + return fmt.Sprint(int(o)) } -// String returns a printable string of the ClassID value. UUIDs use the -// canonical 8-4-4-4-12 format, PSA Implementation IDs are base64 encoded. -// OIDs are output in dotted-decimal notation. -func (o ClassID) String() string { - switch t := o.val.(type) { - case TaggedUUID: - return UUID(t).String() - case TaggedImplID: - b := [32]byte(t) - return base64.StdEncoding.EncodeToString(b[:]) - case TaggedOID: - return OID(t).String() - default: - return "" - } +func (o TaggedInt) Valid() error { + return nil +} + +func (o TaggedInt) Type() string { + return "int" +} + +func (o TaggedInt) Bytes() []byte { + var ret [8]byte + binary.BigEndian.PutUint64(ret[:], uint64(o)) + return ret[:] } -// Unset tests whether the target ClassID has been initialized -func (o ClassID) Unset() bool { - return o.val == nil || o.Type() == ClassIDTypeUnknown +type IClassIDFactory func(any) (*ClassID, error) + +var classIDValueRegister = map[string]IClassIDFactory{ + OIDType: NewOIDClassID, + UUIDType: NewUUIDClassID, + IntType: NewIntClassID, + + ImplIDType: NewImplIDClassID, +} + +// RegisterClassIDType registers a new IClassIDValue implementation (created +// by the provided IClassIDFactory) under the specified CBOR tag. +func RegisterClassIDType(tag uint64, factory IClassIDFactory) error { + nilVal, err := factory(nil) + if err != nil { + return err + } + + typ := nilVal.Type() + if _, exists := classIDValueRegister[typ]; exists { + return fmt.Errorf("class ID type with name %q already exists", typ) + } + + if err := registerCOMIDTag(tag, nilVal.Value); err != nil { + return err + } + + classIDValueRegister[typ] = factory + + return nil } diff --git a/comid/classid_test.go b/comid/classid_test.go index 51b07391..aaeeb393 100644 --- a/comid/classid_test.go +++ b/comid/classid_test.go @@ -4,6 +4,8 @@ package comid import ( + "encoding/binary" + "encoding/json" "fmt" "testing" @@ -12,9 +14,7 @@ import ( ) func TestClassID_MarshalCBOR_UUID(t *testing.T) { - var tv ClassID - - require.NotNil(t, tv.SetUUID(TestUUID)) + tv := MustNewUUIDClassID(TestUUID) // 37(h'31FB5ABF023E4992AA4E95F9C1503BFA') // tag(37): d8 25 @@ -29,9 +29,7 @@ func TestClassID_MarshalCBOR_UUID(t *testing.T) { } func TestClassID_MarshalCBOR_ImplID(t *testing.T) { - var tv ClassID - - require.NotNil(t, tv.SetImplID(TestImplID)) + tv := MustNewImplIDClassID(TestImplID) // 600 (h'61636D652D696D706C656D656E746174696F6E2D69642D303030303030303031') // tag(600): d9 0258 @@ -66,7 +64,7 @@ func TestClassID_UnmarshalCBOR_UUID_OK(t *testing.T) { err := actual.UnmarshalCBOR(tv) assert.Nil(t, err) - assert.Equal(t, ClassIDTypeUUID, actual.Type()) + assert.Equal(t, "uuid", actual.Type()) assert.Equal(t, TestUUIDString, actual.String()) } @@ -79,7 +77,7 @@ func TestClassID_UnmarshalCBOR_ImplID_OK(t *testing.T) { err := actual.UnmarshalCBOR(tv) assert.Nil(t, err) - assert.Equal(t, ClassIDTypeImplID, actual.Type()) + assert.Equal(t, "psa.impl-id", actual.Type()) assert.Equal(t, expected, actual.String()) } @@ -88,12 +86,10 @@ func TestClassID_UnmarshalCBOR_badInput(t *testing.T) { hex := "582061636d652d696d706c656d656e746174696f6e2d69642d303030303030303031" tv := MustHexDecode(t, hex) - expectedError := fmt.Sprintf("unknown class id (CBOR: %s)", hex) - var actual ClassID err := actual.UnmarshalCBOR(tv) - assert.EqualError(t, err, expectedError) + assert.EqualError(t, err, "cbor: cannot unmarshal byte string into Go value of type comid.IClassIDValue") } func TestClassID_UnmarshalJSON_UUID(t *testing.T) { @@ -105,7 +101,7 @@ func TestClassID_UnmarshalJSON_UUID(t *testing.T) { err := actual.UnmarshalJSON([]byte(tv)) assert.Nil(t, err) - assert.Equal(t, ClassIDTypeUUID, actual.Type()) + assert.Equal(t, "uuid", actual.Type()) assert.Equal(t, TestUUIDString, actual.String()) } @@ -122,7 +118,7 @@ func TestClassID_UnmarshalJSON_ImplID(t *testing.T) { err := actual.UnmarshalJSON([]byte(tv)) assert.Nil(t, err) - assert.Equal(t, ClassIDTypeImplID, actual.Type()) + assert.Equal(t, "psa.impl-id", actual.Type()) // the returned string is the base64 encoding of the stored binary assert.Equal(t, expected, actual.String()) } @@ -133,8 +129,8 @@ func TestClassID_UnmarshalJSON_badInput_unknown_type(t *testing.T) { var actual ClassID err := actual.UnmarshalJSON([]byte(tv)) - assert.EqualError(t, err, "unknown type 'FOOBAR' for class id") - assert.Equal(t, ClassIDTypeUnknown, actual.Type()) + assert.EqualError(t, err, "unknown class id type: FOOBAR") + assert.Equal(t, "", actual.Type()) } func TestClassID_UnmarshalJSON_badInput_missing_value(t *testing.T) { @@ -143,8 +139,8 @@ func TestClassID_UnmarshalJSON_badInput_missing_value(t *testing.T) { var actual ClassID err := actual.UnmarshalJSON([]byte(tv)) - assert.EqualError(t, err, "bad ImplID: unexpected end of JSON input") - assert.Equal(t, ClassIDTypeUnknown, actual.Type()) + assert.EqualError(t, err, "class id decoding failure: no value provided for psa.impl-id") + assert.Equal(t, "", actual.Type()) } func TestClassID_UnmarshalJSON_badInput_empty_value(t *testing.T) { @@ -153,8 +149,8 @@ func TestClassID_UnmarshalJSON_badInput_empty_value(t *testing.T) { var actual ClassID err := actual.UnmarshalJSON([]byte(tv)) - assert.EqualError(t, err, "bad ImplID format: got 0 bytes, want 32") - assert.Equal(t, ClassIDTypeUnknown, actual.Type()) + assert.EqualError(t, err, "cannot unmarshal class id: bad psa.impl-id: decoded 0 bytes, want 32") + assert.Equal(t, "", actual.Type()) } func TestClassID_UnmarshalJSON_badInput_badly_encoded_ImplID_value(t *testing.T) { @@ -163,8 +159,8 @@ func TestClassID_UnmarshalJSON_badInput_badly_encoded_ImplID_value(t *testing.T) var actual ClassID err := actual.UnmarshalJSON([]byte(tv)) - assert.EqualError(t, err, "bad ImplID: illegal base64 data at input byte 0") - assert.Equal(t, ClassIDTypeUnknown, actual.Type()) + assert.EqualError(t, err, "cannot unmarshal class id: illegal base64 data at input byte 0") + assert.Equal(t, "", actual.Type()) } func TestClassID_UnmarshalJSON_badInput_badly_encoded_UUID_value(t *testing.T) { @@ -173,8 +169,8 @@ func TestClassID_UnmarshalJSON_badInput_badly_encoded_UUID_value(t *testing.T) { var actual ClassID err := actual.UnmarshalJSON([]byte(tv)) - assert.EqualError(t, err, "bad UUID: invalid UUID length: 9") - assert.Equal(t, ClassIDTypeUnknown, actual.Type()) + assert.EqualError(t, err, "cannot unmarshal class id: bad UUID: invalid UUID length: 9") + assert.Equal(t, "", actual.Type()) } func TestClassID_SetOID_ok(t *testing.T) { @@ -200,8 +196,7 @@ func TestClassID_SetOID_ok(t *testing.T) { } for _, tv := range tvs { - c := ClassID{} - assert.NotNil(t, c.SetOID(tv)) + c := MustNewOIDClassID(tv) assert.Equal(t, tv, c.String()) } } @@ -219,7 +214,230 @@ func TestClassID_SetOID_bad(t *testing.T) { } for _, tv := range tvs { - c := ClassID{} - assert.Nil(t, c.SetOID(tv)) + c, err := NewOIDClassID(tv) + assert.NotNil(t, err) + assert.Nil(t, c) + } +} + +func Test_NewImplIDClassID(t *testing.T) { + classID, err := NewImplIDClassID(nil) + expected := [32]byte{} + require.NoError(t, err) + assert.Equal(t, expected[:], classID.Bytes()) + + taggedImplID := TaggedImplID(TestImplID) + + for _, v := range []any{ + TestImplID, + &TestImplID, + taggedImplID, + &taggedImplID, + taggedImplID.Bytes(), + } { + classID, err = NewImplIDClassID(v) + require.NoError(t, err) + assert.Equal(t, taggedImplID.Bytes(), classID.Bytes()) + } + + expected = [32]byte{ + 0x61, 0x63, 0x6d, 0x65, 0x2d, 0x69, 0x6d, 0x70, + 0x6c, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x61, 0x74, + 0x69, 0x6f, 0x6e, 0x2d, 0x69, 0x64, 0x2d, 0x30, + 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x31, + } + classID, err = NewImplIDClassID("YWNtZS1pbXBsZW1lbnRhdGlvbi1pZC0wMDAwMDAwMDE=") + require.NoError(t, err) + assert.Equal(t, expected[:], classID.Bytes()) + + _, err = NewImplIDClassID(7) + assert.EqualError(t, err, "unexpected type for psa.impl-id: int") +} + +func Test_NewUUIDClassID(t *testing.T) { + classID, err := NewUUIDClassID(nil) + + expected := [16]byte{} + require.NoError(t, err) + assert.Equal(t, expected[:], classID.Bytes()) + + taggedUUID := TaggedUUID(TestUUID) + + for _, v := range []any{ + TestUUID, + &TestUUID, + taggedUUID, + &taggedUUID, + taggedUUID.Bytes(), + } { + classID, err = NewUUIDClassID(v) + require.NoError(t, err) + assert.Equal(t, taggedUUID.Bytes(), classID.Bytes()) + } + + classID, err = NewUUIDClassID(taggedUUID.String()) + require.NoError(t, err) + assert.Equal(t, taggedUUID.Bytes(), classID.Bytes()) +} + +func Test_NewOIDClassID(t *testing.T) { + classID, err := NewOIDClassID(nil) + + expected := []byte{} + require.NoError(t, err) + assert.Equal(t, expected, classID.Bytes()) + + var oid OID + require.NoError(t, oid.FromString(TestOID)) + taggedOID := TaggedOID(oid) + + for _, v := range []any{ + TestOID, + oid, + &oid, + taggedOID, + &taggedOID, + taggedOID.Bytes(), + } { + classID, err = NewOIDClassID(v) + require.NoError(t, err) + expected := taggedOID.Bytes() + got := classID.Bytes() + assert.Equal(t, expected, got) + } + + classID, err = NewOIDClassID(taggedOID.String()) + require.NoError(t, err) + assert.Equal(t, taggedOID.Bytes(), classID.Bytes()) +} + +func Test_NewIntClassID(t *testing.T) { + classID, err := NewIntClassID(nil) + require.NoError(t, err) + assert.Equal(t, []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, classID.Bytes()) + + testInt := 7 + testInt64 := int64(7) + testUint64 := uint64(7) + + var testBytes [8]byte + binary.BigEndian.PutUint64(testBytes[:], testUint64) + + for _, v := range []any{ + testInt, + &testInt, + testInt64, + &testInt64, + testUint64, + &testUint64, + "7", + testBytes[:], + } { + classID, err = NewIntClassID(v) + require.NoError(t, err) + got := classID.Bytes() + assert.Equal(t, testBytes[:], got) + } +} + +func Test_TaggedInt(t *testing.T) { + val := TaggedInt(7) + assert.Equal(t, "7", val.String()) + assert.Equal(t, []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x07}, val.Bytes()) + assert.Equal(t, "int", val.Type()) + assert.NoError(t, val.Valid()) + + classID := ClassID{&val} + + bytes, err := em.Marshal(classID) + require.NoError(t, err) + assert.Equal(t, []byte{ + 0xd9, 0x02, 0x27, // tag 551 + 0x07, // int 7 + }, bytes) + + var out ClassID + err = dm.Unmarshal(bytes, &out) + require.NoError(t, err) + assert.Equal(t, classID, out) + + jsonBytes, err := json.Marshal(classID) + require.NoError(t, err) + assert.Equal(t, `{"type":"int","value":7}`, string(jsonBytes)) + + out = ClassID{} + err = json.Unmarshal(jsonBytes, &out) + require.NoError(t, err) + assert.Equal(t, classID, out) +} + +type testClassID [4]byte + +func newTestClassID(val any) (*ClassID, error) { + return &ClassID{&testClassID{0x74, 0x65, 0x73, 0x74}}, nil +} + +func (o testClassID) Bytes() []byte { + return o[:] +} + +func (o testClassID) Type() string { + return "test-class-id" +} + +func (o testClassID) String() string { + return "test" +} + +func (o testClassID) Valid() error { + return nil +} + +func (o testClassID) MarshalJSON() ([]byte, error) { + return json.Marshal(o.String()) +} + +func (o *testClassID) UnmarshalJSON(data []byte) error { + var out string + if err := json.Unmarshal(data, &out); err != nil { + return err } + + if len(out) != 4 { + return fmt.Errorf("bad testClassID: decoded %d bytes, want 4", len(out)) + } + + copy((*o)[:], []byte(out)) + + return nil +} + +func Test_RegisterClassIDType(t *testing.T) { + err := RegisterClassIDType(99999, newTestClassID) + require.NoError(t, err) + + classID, err := newTestClassID(nil) + require.NoError(t, err) + + data, err := json.Marshal(classID) + require.NoError(t, err) + assert.Equal(t, string(data), `{"type":"test-class-id","value":"test"}`) + + var out ClassID + err = json.Unmarshal(data, &out) + require.NoError(t, err) + assert.Equal(t, classID.Bytes(), out.Bytes()) + + data, err = em.Marshal(classID) + require.NoError(t, err) + assert.Equal(t, data, []byte{ + 0xda, 0x0, 0x1, 0x86, 0x9f, // tag 99999 + 0x44, // bstr(4) + 0x74, 0x65, 0x73, 0x74, // "test" + }) + + var out2 ClassID + err = dm.Unmarshal(data, &out2) + require.NoError(t, err) + assert.Equal(t, classID.Bytes(), out2.Bytes()) } diff --git a/comid/comid.go b/comid/comid.go index 194745d3..8652d747 100644 --- a/comid/comid.go +++ b/comid/comid.go @@ -120,7 +120,7 @@ func (o *Comid) AddEntity(name string, regID *string, roles ...Role) *Comid { } e := Entity{ - EntityName: name, + EntityName: MustNewStringEntityName(name), RegID: uri, Roles: rs, } diff --git a/comid/cryptokey.go b/comid/cryptokey.go index 53d99df8..c2d1879c 100644 --- a/comid/cryptokey.go +++ b/comid/cryptokey.go @@ -14,6 +14,8 @@ import ( "fmt" "github.com/fxamacker/cbor/v2" + "github.com/veraison/corim/encoding" + "github.com/veraison/corim/extensions" "github.com/veraison/go-cose" "github.com/veraison/swid" ) @@ -58,50 +60,12 @@ type CryptoKey struct { // specified crypto key type. For PKIX types, k must be a string. For COSE_Key, // k must be a []byte. For thumbprint types, k must be a swid.HashEntry. func NewCryptoKey(k any, typ string) (*CryptoKey, error) { - switch typ { - case PKIXBase64KeyType: - v, ok := k.(string) - if !ok { - return nil, fmt.Errorf("value must be a string; found %T", k) - } - return NewPKIXBase64Key(v) - case PKIXBase64CertType: - v, ok := k.(string) - if !ok { - return nil, fmt.Errorf("value must be a string; found %T", k) - } - return NewPKIXBase64Cert(v) - case PKIXBase64CertPathType: - v, ok := k.(string) - if !ok { - return nil, fmt.Errorf("value must be a string; found %T", k) - } - return NewPKIXBase64CertPath(v) - case COSEKeyType: - v, ok := k.([]byte) - if !ok { - return nil, fmt.Errorf("value must be a []byte; found %T", k) - } - return NewCOSEKey(v) - case ThumbprintType, CertThumbprintType, CertPathThumbprintType: - v, ok := k.(swid.HashEntry) - if !ok { - return nil, fmt.Errorf("value must be a swid.HashEntry; found %T", k) - } - switch typ { - case ThumbprintType: - return NewThumbprint(v) - case CertThumbprintType: - return NewCertThumbprint(v) - case CertPathThumbprintType: - return NewCertPathThumbprint(v) - default: - // Should never here because of the the outer case clause - panic(fmt.Sprintf("unexpected thumbprint type: %s", typ)) - } - default: + factory, ok := cryptoKeyValueRegister[typ] + if !ok { return nil, fmt.Errorf("unexpected CryptoKey type: %s", typ) } + + return factory(k) } // MustNewCryptoKey is the same as NewCryptoKey, but does not return an error, @@ -126,6 +90,11 @@ func (o CryptoKey) Valid() error { return o.Value.Valid() } +// Type returns the type of the CryptoKey value +func (o CryptoKey) Type() string { + return o.Value.Type() +} + // PublicKey returns a crypto.PublicKey constructed from the CryptoKey's // underlying value. This returns an error if the CryptoKey is one of the // thumbprint types. @@ -136,30 +105,14 @@ func (o CryptoKey) PublicKey() (crypto.PublicKey, error) { // MarshalJSON returns a []byte containing the JSON representation of the // CryptoKey. func (o CryptoKey) MarshalJSON() ([]byte, error) { - value := struct { - Type string `json:"type"` - Value string `json:"value"` - }{ - Value: o.Value.String(), - } - - switch o.Value.(type) { - case TaggedPKIXBase64Key: - value.Type = PKIXBase64KeyType - case TaggedPKIXBase64Cert: - value.Type = PKIXBase64CertType - case TaggedPKIXBase64CertPath: - value.Type = PKIXBase64CertPathType - case TaggedCOSEKey: - value.Type = COSEKeyType - case TaggedThumbprint: - value.Type = ThumbprintType - case TaggedCertThumbprint: - value.Type = CertThumbprintType - case TaggedCertPathThumbprint: - value.Type = CertPathThumbprintType - default: - return nil, fmt.Errorf("unexpected ICryptoKeyValue type: %T", o.Value) + valueBytes, err := json.Marshal(o.Value.String()) + if err != nil { + return nil, err + } + + value := encoding.TypeAndValue{ + Type: o.Value.Type(), + Value: valueBytes, } return json.Marshal(value) @@ -168,10 +121,7 @@ func (o CryptoKey) MarshalJSON() ([]byte, error) { // UnmarshalJSON populates the CryptoKey from the JSON representation inside // the provided []byte. func (o *CryptoKey) UnmarshalJSON(b []byte) error { - var value struct { - Type string `json:"type"` - Value string `json:"value"` - } + var value encoding.TypeAndValue if err := json.Unmarshal(b, &value); err != nil { return err @@ -181,36 +131,23 @@ func (o *CryptoKey) UnmarshalJSON(b []byte) error { return errors.New("key type not set") } - switch value.Type { - case PKIXBase64KeyType: - o.Value = TaggedPKIXBase64Key(value.Value) - case PKIXBase64CertType: - o.Value = TaggedPKIXBase64Cert(value.Value) - case PKIXBase64CertPathType: - o.Value = TaggedPKIXBase64CertPath(value.Value) - case COSEKeyType: - data, err := base64.StdEncoding.DecodeString(value.Value) - if err != nil { - return fmt.Errorf("base64 decode error: %w", err) - } - o.Value = TaggedCOSEKey(data) - case ThumbprintType, CertThumbprintType, CertPathThumbprintType: - he, err := swid.ParseHashEntry(value.Value) - if err != nil { - return fmt.Errorf("swid.HashEntry decode error: %w", err) - } - switch value.Type { - case ThumbprintType: - o.Value = TaggedThumbprint{digest{he}} - case CertThumbprintType: - o.Value = TaggedCertThumbprint{digest{he}} - case CertPathThumbprintType: - o.Value = TaggedCertPathThumbprint{digest{he}} - } - default: + factory, ok := cryptoKeyValueRegister[value.Type] + if !ok { return fmt.Errorf("unexpected ICryptoKeyValue type: %q", value.Type) } + var valueString string + if err := json.Unmarshal(value.Value, &valueString); err != nil { + return err + } + + k, err := factory(valueString) + if err != nil { + return err + } + + o.Value = k.Value + return o.Valid() } @@ -229,11 +166,8 @@ func (o *CryptoKey) UnmarshalCBOR(b []byte) error { // ICryptoKeyValue is the interface implemented by the concrete CryptoKey value // types. type ICryptoKeyValue interface { - // String returns the string representation of the ICryptoKeyValue. - String() string - // Valid returns an error if validation of the ICryptoKeyValue fails, - // or nil if it succeeds. - Valid() error + extensions.ITypeChoiceValue + // PublicKey returns a crypto.PublicKey constructed from the // ICryptoKeyValue's underlying value. This returns an error if the // ICryptoKeyValue is one of the thumbprint types. @@ -244,7 +178,12 @@ type ICryptoKeyValue interface { // https://www.rfc-editor.org/rfc/rfc7468#section-13 type TaggedPKIXBase64Key string -func NewPKIXBase64Key(s string) (*CryptoKey, error) { +func NewPKIXBase64Key(k any) (*CryptoKey, error) { + s, ok := k.(string) + if !ok { + return nil, fmt.Errorf("value must be a string; found %T", k) + } + key := TaggedPKIXBase64Key(s) if err := key.Valid(); err != nil { return nil, err @@ -252,8 +191,8 @@ func NewPKIXBase64Key(s string) (*CryptoKey, error) { return &CryptoKey{key}, nil } -func MustNewPKIXBase64Key(s string) *CryptoKey { - key, err := NewPKIXBase64Key(s) +func MustNewPKIXBase64Key(k any) *CryptoKey { + key, err := NewPKIXBase64Key(k) if err != nil { panic(err) } @@ -269,6 +208,10 @@ func (o TaggedPKIXBase64Key) Valid() error { return err } +func (o TaggedPKIXBase64Key) Type() string { + return PKIXBase64KeyType +} + func (o TaggedPKIXBase64Key) PublicKey() (crypto.PublicKey, error) { if string(o) == "" { return nil, errors.New("key value not set") @@ -302,7 +245,12 @@ func (o TaggedPKIXBase64Key) PublicKey() (crypto.PublicKey, error) { // certificate. See https://www.rfc-editor.org/rfc/rfc7468#section-5 type TaggedPKIXBase64Cert string -func NewPKIXBase64Cert(s string) (*CryptoKey, error) { +func NewPKIXBase64Cert(k any) (*CryptoKey, error) { + s, ok := k.(string) + if !ok { + return nil, fmt.Errorf("value must be a string; found %T", k) + } + cert := TaggedPKIXBase64Cert(s) if err := cert.Valid(); err != nil { return nil, err @@ -310,8 +258,8 @@ func NewPKIXBase64Cert(s string) (*CryptoKey, error) { return &CryptoKey{cert}, nil } -func MustNewPKIXBase64Cert(s string) *CryptoKey { - cert, err := NewPKIXBase64Cert(s) +func MustNewPKIXBase64Cert(k any) *CryptoKey { + cert, err := NewPKIXBase64Cert(k) if err != nil { panic(err) } @@ -327,6 +275,10 @@ func (o TaggedPKIXBase64Cert) Valid() error { return err } +func (o TaggedPKIXBase64Cert) Type() string { + return PKIXBase64CertType +} + func (o TaggedPKIXBase64Cert) PublicKey() (crypto.PublicKey, error) { cert, err := o.cert() if err != nil { @@ -375,7 +327,11 @@ func (o TaggedPKIXBase64Cert) cert() (*x509.Certificate, error) { // directly certifies the one preceding. type TaggedPKIXBase64CertPath string -func NewPKIXBase64CertPath(s string) (*CryptoKey, error) { +func NewPKIXBase64CertPath(k any) (*CryptoKey, error) { + s, ok := k.(string) + if !ok { + return nil, fmt.Errorf("value must be a string; found %T", k) + } cert := TaggedPKIXBase64CertPath(s) if err := cert.Valid(); err != nil { @@ -385,8 +341,8 @@ func NewPKIXBase64CertPath(s string) (*CryptoKey, error) { return &CryptoKey{cert}, nil } -func MustNewPKIXBase64CertPath(s string) *CryptoKey { - cert, err := NewPKIXBase64CertPath(s) +func MustNewPKIXBase64CertPath(k any) *CryptoKey { + cert, err := NewPKIXBase64CertPath(k) if err != nil { panic(err) @@ -404,6 +360,10 @@ func (o TaggedPKIXBase64CertPath) Valid() error { return err } +func (o TaggedPKIXBase64CertPath) Type() string { + return PKIXBase64CertPathType +} + func (o TaggedPKIXBase64CertPath) PublicKey() (crypto.PublicKey, error) { certs, err := o.certPath() if err != nil { @@ -468,7 +428,22 @@ func (o TaggedPKIXBase64CertPath) certPath() ([]*x509.Certificate, error) { // https://www.rfc-editor.org/rfc/rfc9052#section-7 type TaggedCOSEKey []byte -func NewCOSEKey(b []byte) (*CryptoKey, error) { +func NewCOSEKey(k any) (*CryptoKey, error) { + var b []byte + var err error + + switch t := k.(type) { + case []byte: + b = t + case string: + b, err = base64.StdEncoding.DecodeString(t) + if err != nil { + return nil, fmt.Errorf("base64 decode error: %w", err) + } + default: + return nil, fmt.Errorf("value must be a []byte or a string; found %T", k) + } + key := TaggedCOSEKey(b) if err := key.Valid(); err != nil { @@ -478,8 +453,8 @@ func NewCOSEKey(b []byte) (*CryptoKey, error) { return &CryptoKey{key}, nil } -func MustNewCOSEKey(b []byte) *CryptoKey { - key, err := NewCOSEKey(b) +func MustNewCOSEKey(k any) *CryptoKey { + key, err := NewCOSEKey(k) if err != nil { panic(err) @@ -509,6 +484,10 @@ func (o TaggedCOSEKey) Valid() error { return err } +func (o TaggedCOSEKey) Type() string { + return COSEKeyType +} + func (o TaggedCOSEKey) PublicKey() (crypto.PublicKey, error) { if len(o) == 0 { return nil, errors.New("empty COSE_Key value") @@ -608,7 +587,22 @@ type TaggedThumbprint struct { digest } -func NewThumbprint(he swid.HashEntry) (*CryptoKey, error) { +func NewThumbprint(k any) (*CryptoKey, error) { + var he swid.HashEntry + var err error + + switch t := k.(type) { + case string: + he, err = swid.ParseHashEntry(t) + if err != nil { + return nil, fmt.Errorf("swid.HashEntry decode error: %w", err) + } + case swid.HashEntry: + he = t + default: + return nil, fmt.Errorf("value must be a swid.HashEntry or a string; found %T", k) + } + key := &CryptoKey{TaggedThumbprint{digest{he}}} if err := key.Valid(); err != nil { @@ -618,8 +612,8 @@ func NewThumbprint(he swid.HashEntry) (*CryptoKey, error) { return key, nil } -func MustNewThumbprint(he swid.HashEntry) *CryptoKey { - key, err := NewThumbprint(he) +func MustNewThumbprint(k any) *CryptoKey { + key, err := NewThumbprint(k) if err != nil { panic(err) @@ -628,13 +622,32 @@ func MustNewThumbprint(he swid.HashEntry) *CryptoKey { return key } +func (o TaggedThumbprint) Type() string { + return ThumbprintType +} + // TaggedCertThumbprint represents a digest of a certificate. The digest value // may be used to find the certificate if contained in a lookup table. type TaggedCertThumbprint struct { digest } -func NewCertThumbprint(he swid.HashEntry) (*CryptoKey, error) { +func NewCertThumbprint(k any) (*CryptoKey, error) { + var he swid.HashEntry + var err error + + switch t := k.(type) { + case string: + he, err = swid.ParseHashEntry(t) + if err != nil { + return nil, fmt.Errorf("swid.HashEntry decode error: %w", err) + } + case swid.HashEntry: + he = t + default: + return nil, fmt.Errorf("value must be a swid.HashEntry or a string; found %T", k) + } + key := &CryptoKey{TaggedCertThumbprint{digest{he}}} if err := key.Valid(); err != nil { @@ -644,8 +657,8 @@ func NewCertThumbprint(he swid.HashEntry) (*CryptoKey, error) { return key, nil } -func MustNewCertThumbprint(he swid.HashEntry) *CryptoKey { - key, err := NewCertThumbprint(he) +func MustNewCertThumbprint(k any) *CryptoKey { + key, err := NewCertThumbprint(k) if err != nil { panic(err) @@ -654,6 +667,10 @@ func MustNewCertThumbprint(he swid.HashEntry) *CryptoKey { return key } +func (o TaggedCertThumbprint) Type() string { + return CertThumbprintType +} + // TaggedCertPathThumbprint represents a digest of a certification path. The // digest value may be used to find the certificate path if contained in a // lookup table. @@ -661,7 +678,22 @@ type TaggedCertPathThumbprint struct { digest } -func NewCertPathThumbprint(he swid.HashEntry) (*CryptoKey, error) { +func NewCertPathThumbprint(k any) (*CryptoKey, error) { + var he swid.HashEntry + var err error + + switch t := k.(type) { + case string: + he, err = swid.ParseHashEntry(t) + if err != nil { + return nil, fmt.Errorf("swid.HashEntry decode error: %w", err) + } + case swid.HashEntry: + he = t + default: + return nil, fmt.Errorf("value must be a swid.HashEntry or a string; found %T", k) + } + key := &CryptoKey{TaggedCertPathThumbprint{digest{he}}} if err := key.Valid(); err != nil { @@ -671,8 +703,8 @@ func NewCertPathThumbprint(he swid.HashEntry) (*CryptoKey, error) { return key, nil } -func MustNewCertPathThumbprint(he swid.HashEntry) *CryptoKey { - key, err := NewCertPathThumbprint(he) +func MustNewCertPathThumbprint(k any) *CryptoKey { + key, err := NewCertPathThumbprint(k) if err != nil { panic(err) @@ -680,3 +712,47 @@ func MustNewCertPathThumbprint(he swid.HashEntry) *CryptoKey { return key } + +func (o TaggedCertPathThumbprint) Type() string { + return CertPathThumbprintType +} + +// ICryptoKeyFactory creates a *CryptoKey from the specified any value. When +// passed nil as input, this must return the equivaluent nil-value for the +// associated ICryptoKeyValue implementation. +type ICryptoKeyFactory func(any) (*CryptoKey, error) + +var cryptoKeyValueRegister = map[string]ICryptoKeyFactory{ + // types defined by the core spec + PKIXBase64KeyType: NewPKIXBase64Key, + PKIXBase64CertType: NewPKIXBase64Cert, + PKIXBase64CertPathType: NewPKIXBase64CertPath, + COSEKeyType: NewCOSEKey, + ThumbprintType: NewThumbprint, + CertThumbprintType: NewCertThumbprint, + CertPathThumbprintType: NewCertPathThumbprint, +} + +// RegisterCryptoKeyType registers a new ICryptoKeyValue implementation +// (created by the provided ICryptoKeyFactory) under the specified type name +// and CBOR tag. +func RegisterCryptoKeyType(tag uint64, factory ICryptoKeyFactory) error { + + nilVal, err := factory(nil) + if err != nil { + return err + } + + typ := nilVal.Type() + if _, exists := cryptoKeyValueRegister[typ]; exists { + return fmt.Errorf("crypto key type with name %q already exists", typ) + } + + if err := registerCOMIDTag(tag, nilVal.Value); err != nil { + return err + } + + cryptoKeyValueRegister[typ] = factory + + return nil +} diff --git a/comid/cryptokey_test.go b/comid/cryptokey_test.go index 1c0812fc..40c5af9f 100644 --- a/comid/cryptokey_test.go +++ b/comid/cryptokey_test.go @@ -4,6 +4,7 @@ package comid import ( + "crypto" "encoding/base64" "encoding/json" "fmt" @@ -149,7 +150,7 @@ func Test_CryptoKey_NewCOSEKey(t *testing.T) { } func Test_CryptoKey_NewThumbprint(t *testing.T) { - type newKeyFunc func(swid.HashEntry) (*CryptoKey, error) + type newKeyFunc func(any) (*CryptoKey, error) for _, newFunc := range []newKeyFunc{ NewThumbprint, @@ -177,7 +178,7 @@ func Test_CryptoKey_NewThumbprint(t *testing.T) { assert.Contains(t, err.Error(), "length mismatch for hash algorithm") } - type mustNewKeyFunc func(swid.HashEntry) *CryptoKey + type mustNewKeyFunc func(any) *CryptoKey for _, mustNewFunc := range []mustNewKeyFunc{ MustNewThumbprint, @@ -271,7 +272,7 @@ func Test_CryptoKey_UnmarshalJSON_negative(t *testing.T) { }, { Val: `{"value":"deadbeef"}`, - ErrMsg: "key type not set", + ErrMsg: "type not set", }, { Val: `{"type": "cose-key", "value":";;;"}`, @@ -371,22 +372,22 @@ func Test_NewCryptoKey_negative(t *testing.T) { { Type: COSEKeyType, In: 7, - ErrMsg: "value must be a []byte; found int", + ErrMsg: "value must be a []byte or a string; found int", }, { Type: ThumbprintType, In: 7, - ErrMsg: "value must be a swid.HashEntry; found int", + ErrMsg: "value must be a swid.HashEntry or a string; found int", }, { Type: CertThumbprintType, In: 7, - ErrMsg: "value must be a swid.HashEntry; found int", + ErrMsg: "value must be a swid.HashEntry or a string; found int", }, { Type: CertPathThumbprintType, In: 7, - ErrMsg: "value must be a swid.HashEntry; found int", + ErrMsg: "value must be a swid.HashEntry or a string; found int", }, { Type: "random-key", @@ -399,3 +400,55 @@ func Test_NewCryptoKey_negative(t *testing.T) { assert.ErrorContains(t, err, tv.ErrMsg) } } + +type testCryptoKey [4]byte + +func newTestCryptoKey(val any) (*CryptoKey, error) { + return &CryptoKey{&testCryptoKey{0x74, 0x64, 0x73, 0x74}}, nil +} + +func (o testCryptoKey) PublicKey() (crypto.PublicKey, error) { + return crypto.PublicKey(o[:]), nil +} + +func (o testCryptoKey) Type() string { + return "test-crypto-key" +} + +func (o testCryptoKey) String() string { + return "test" +} + +func (o testCryptoKey) Valid() error { + return nil +} + +func Test_RegisterCryptoKeyTypeType(t *testing.T) { + err := RegisterCryptoKeyType(99998, newTestCryptoKey) + require.NoError(t, err) + + key, err := newTestCryptoKey(nil) + require.NoError(t, err) + + data, err := json.Marshal(key) + require.NoError(t, err) + assert.Equal(t, string(data), `{"type":"test-crypto-key","value":"test"}`) + + var out CryptoKey + err = json.Unmarshal(data, &out) + require.NoError(t, err) + assert.EqualValues(t, key, &out) + + data, err = em.Marshal(key) + require.NoError(t, err) + assert.Equal(t, data, []byte{ + 0xda, 0x0, 0x1, 0x86, 0x9e, // tag 99998 + 0x44, // bstr(4) + 0x74, 0x64, 0x73, 0x74, // "test" + }) + + var out2 CryptoKey + err = dm.Unmarshal(data, &out2) + require.NoError(t, err) + assert.Equal(t, key, &out2) +} diff --git a/comid/devidentitykey_test.go b/comid/devidentitykey_test.go index 14f2f6e1..b12f6a41 100644 --- a/comid/devidentitykey_test.go +++ b/comid/devidentitykey_test.go @@ -23,12 +23,12 @@ func TestDevIdentityKey_Valid_empty(t *testing.T) { testerr: "environment validation failed: environment must not be empty", }, { - env: Environment{Instance: NewInstanceUEID(TestUEID)}, + env: Environment{Instance: MustNewUEIDInstance(TestUEID)}, verifkey: CryptoKeys{}, testerr: "verification keys validation failed: no keys to validate", }, { - env: Environment{Instance: NewInstanceUEID(TestUEID)}, + env: Environment{Instance: MustNewUEIDInstance(TestUEID)}, verifkey: CryptoKeys{&invalidKey}, testerr: "verification keys validation failed: invalid key at index 0: key value not set", }, diff --git a/comid/entity.go b/comid/entity.go index b207c7d8..0ec06bfe 100644 --- a/comid/entity.go +++ b/comid/entity.go @@ -4,23 +4,20 @@ package comid import ( + "encoding/json" + "errors" "fmt" + "unicode/utf8" "github.com/veraison/corim/encoding" "github.com/veraison/corim/extensions" ) -type TaggedURI string - -func (o TaggedURI) Empty() bool { - return o == "" -} - // Entity stores an entity-map capable of CBOR and JSON serializations. type Entity struct { - EntityName string `cbor:"0,keyasint" json:"name"` - RegID *TaggedURI `cbor:"1,keyasint,omitempty" json:"regid,omitempty"` - Roles Roles `cbor:"2,keyasint" json:"roles"` + EntityName *EntityName `cbor:"0,keyasint" json:"name"` + RegID *TaggedURI `cbor:"1,keyasint,omitempty" json:"regid,omitempty"` + Roles Roles `cbor:"2,keyasint" json:"roles"` Extensions } @@ -41,7 +38,7 @@ func (o *Entity) SetEntityName(name string) *Entity { if name == "" { return nil } - o.EntityName = name + o.EntityName = MustNewStringEntityName(name) } return o } @@ -68,10 +65,14 @@ func (o *Entity) SetRoles(roles ...Role) *Entity { // Valid checks for validity of the fields within each Entity func (o Entity) Valid() error { - if o.EntityName == "" { + if o.EntityName == nil { return fmt.Errorf("invalid entity: empty entity-name") } + if err := o.EntityName.Valid(); err != nil { + return fmt.Errorf("invalid entity: %w", err) + } + if o.RegID != nil && o.RegID.Empty() { return fmt.Errorf("invalid entity: empty reg-id") } @@ -128,3 +129,211 @@ func (o Entities) Valid() error { } return nil } + +// EntityName encapsulates the name of the associated Entity. The CoRIM +// specification only allows for text (string) name, but this may be extended +// by other specifications. +type EntityName struct { + Value IEntityName +} + +// NewEntityName creates a new EntityName of the specified type using the +// provided value. +func NewEntityName(val any, typ string) (*EntityName, error) { + factory, ok := entityNameValueRegister[typ] + if !ok { + return nil, fmt.Errorf("unexpected entity name type: %s", typ) + } + + return factory(val) +} + +// MustNewEntityName is like NewEntityName, except it doesn't return an error, +// assuming that the provided value is valid. It panics if that isn't the case. +func MustNewEntityName(val any, typ string) *EntityName { + ret, err := NewEntityName(val, typ) + if err != nil { + panic(err) + } + + return ret +} + +func (o EntityName) String() string { + return o.Value.String() +} + +func (o EntityName) Valid() error { + if o.Value == nil { + return errors.New("empty entity name") + } + + return o.Value.Valid() +} + +func (o EntityName) MarshalCBOR() ([]byte, error) { + if err := o.Valid(); err != nil { + return nil, err + } + + return em.Marshal(o.Value) +} + +func (o *EntityName) UnmarshalCBOR(data []byte) error { + if len(data) == 0 { + return errors.New("empty") + } + + majorType := (data[0] & 0xe0) >> 5 + if majorType == 3 { // text string + var text string + + if err := dm.Unmarshal(data, &text); err != nil { + return err + } + + name := StringEntityName(text) + o.Value = &name + + return nil + } + + return dm.Unmarshal(data, &o.Value) +} + +func (o EntityName) MarshalJSON() ([]byte, error) { + if err := o.Valid(); err != nil { + return nil, err + } + + if o.Value.Type() == extensions.StringType { + return json.Marshal(o.Value.String()) + } + + return extensions.TypeChoiceValueMarshalJSON(o.Value) +} + +func (o *EntityName) UnmarshalJSON(data []byte) error { + var text string + if err := json.Unmarshal(data, &text); err == nil { + *o = *MustNewStringEntityName(text) + return nil + } + + var tnv encoding.TypeAndValue + + if err := json.Unmarshal(data, &tnv); err != nil { + return fmt.Errorf("entity name decoding failure: %w", err) + } + + decoded, err := NewEntityName(nil, tnv.Type) + if err != nil { + return err + } + + if err := json.Unmarshal(tnv.Value, &decoded.Value); err != nil { + return fmt.Errorf( + "cannot unmarshal entity name: %w", + err, + ) + } + + if err := decoded.Value.Valid(); err != nil { + return fmt.Errorf("invalid %s: %w", tnv.Type, err) + } + + o.Value = decoded.Value + + return nil +} + +type IEntityName interface { + extensions.ITypeChoiceValue +} + +type StringEntityName string + +func NewStringEntityName(val any) (*EntityName, error) { + var ret StringEntityName + + if val == nil { + ret = StringEntityName("") + return &EntityName{&ret}, nil + } + + switch t := val.(type) { + case string: + ret = StringEntityName(t) + case []byte: + if !utf8.Valid(t) { + return nil, errors.New("bytes do not form a valid UTF-8 string") + } + + ret = StringEntityName(t) + default: + return nil, fmt.Errorf("unexpected type for string entity name: %T", t) + } + + return &EntityName{&ret}, nil +} + +func MustNewStringEntityName(val any) *EntityName { + ret, err := NewStringEntityName(val) + if err != nil { + panic(err) + } + + return ret +} + +func (o StringEntityName) String() string { + return string(o) +} + +func (o StringEntityName) Type() string { + return extensions.StringType +} + +func (o StringEntityName) Valid() error { + if o == "" { + return errors.New("empty entity-name") + } + + return nil +} + +type IEntityNameFactory func(any) (*EntityName, error) + +var entityNameValueRegister = map[string]IEntityNameFactory{ + extensions.StringType: NewStringEntityName, +} + +// RegisterEntityNameType registers a new IEntityNameValue implementation +// (created by the provided IEntityNameFactory) under the specified type name +// and CBOR tag. +func RegisterEntityNameType(tag uint64, factory IEntityNameFactory) error { + + nilVal, err := factory(nil) + if err != nil { + return err + } + + typ := nilVal.Value.Type() + if _, exists := entityNameValueRegister[typ]; exists { + return fmt.Errorf("entity name type with name %q already exists", typ) + } + + if err := registerCOMIDTag(tag, nilVal.Value); err != nil { + return err + } + + entityNameValueRegister[typ] = factory + + return nil +} + +type TaggedURI string + +func (o TaggedURI) Empty() bool { + return o == "" +} diff --git a/comid/entity_test.go b/comid/entity_test.go index 61c0d1b3..d634a2f4 100644 --- a/comid/entity_test.go +++ b/comid/entity_test.go @@ -4,6 +4,8 @@ package comid import ( + "errors" + "fmt" "testing" "github.com/stretchr/testify/assert" @@ -83,3 +85,185 @@ func TestEntity_SetRegID_empty(t *testing.T) { assert.Nil(t, e.SetRegID("")) } + +type testEntityName uint64 + +func newTestEntityName(val any) (*EntityName, error) { + if val == nil { + v := testEntityName(0) + return &EntityName{&v}, nil + } + + u, ok := val.(uint64) + if !ok { + return nil, errors.New("must be uint64") + } + + v := testEntityName(u) + return &EntityName{&v}, nil +} + +func (o testEntityName) Type() string { + return "test" +} + +func (o testEntityName) String() string { + return fmt.Sprint(uint64(o)) +} + +func (o testEntityName) Valid() error { + return nil +} + +type testEntityNameBadType struct { + testEntityName +} + +func newTestEntityNameBadType(val any) (*EntityName, error) { + v := testEntityNameBadType{testEntityName(7)} + return &EntityName{&v}, nil +} + +func (o testEntityNameBadType) Type() string { + return "string" +} + +func Test_RegisterEntityNameType(t *testing.T) { + err := RegisterEntityNameType(32, newTestEntityName) + assert.EqualError(t, err, "tag 32 is already registered") + + err = RegisterEntityNameType(99994, newTestEntityNameBadType) + assert.EqualError(t, err, `entity name type with name "string" already exists`) + + registerTestEntityNameType(t) +} + +// Since there only one, untagged, entity name type in the core spec, we use +// the test type define above in order to test the marshaling code works +// properly. Since global environment is not reset when running multiple tests, +// we cannot simply call RegisterEntityNameType() inside each test that relies +// on the test type, as that will cause the "tag already registered" error. On +// the other hand, we do not want to create inter-test dependencies by relying +// that the test registering the type is run before the others that rely on it. +// To get around this, use this global flag to only register the test type if a +// previous test hasn't already done so. +var testEntityNameTypeRegistered = false + +func registerTestEntityNameType(t *testing.T) { + if !testEntityNameTypeRegistered { + err := RegisterEntityNameType(99994, newTestEntityName) + require.NoError(t, err) + + testEntityNameTypeRegistered = true + } +} + +func TestEntityName_CBOR(t *testing.T) { + registerTestEntityNameType(t) + + for _, tv := range []struct { + Value any + Type string + ExpectedBytes []byte + ExpectedString string + }{ + { + Value: "test", + Type: "string", + ExpectedBytes: []byte{ + 0x64, // tstr(4) + 0x74, 0x65, 0x73, 0x74, // "test" + }, + ExpectedString: "test", + }, + { + Value: uint64(7), + Type: "test", + ExpectedBytes: []byte{ + 0xda, 0x0, 0x1, 0x86, 0x9a, // tag 99994 + 0x07, // unsigned int(7) + }, + ExpectedString: "7", + }, + } { + t.Run(tv.Type, func(t *testing.T) { + en, err := NewEntityName(tv.Value, tv.Type) + require.NoError(t, err) + + data, err := en.MarshalCBOR() + require.NoError(t, err) + + assert.Equal(t, tv.ExpectedBytes, data) + + var out EntityName + + err = out.UnmarshalCBOR(data) + require.NoError(t, err) + + assert.Equal(t, tv.ExpectedString, out.String()) + }) + } +} + +func TestEntityName_JSON(t *testing.T) { + registerTestEntityNameType(t) + + for _, tv := range []struct { + Value any + Type string + ExpectedBytes []byte + ExpectedString string + }{ + { + Value: "test", + Type: "string", + ExpectedBytes: []byte(`"test"`), + ExpectedString: "test", + }, + { + Value: uint64(7), + Type: "test", + ExpectedBytes: []byte(`{"type":"test","value":7}`), + ExpectedString: "7", + }, + } { + t.Run(tv.Type, func(t *testing.T) { + en, err := NewEntityName(tv.Value, tv.Type) + require.NoError(t, err) + + data, err := en.MarshalJSON() + require.NoError(t, err) + + assert.Equal(t, tv.ExpectedBytes, data) + + var out EntityName + + err = out.UnmarshalJSON(data) + require.NoError(t, err) + + assert.Equal(t, tv.ExpectedString, out.String()) + }) + } +} + +func Test_NewStringEntityName(t *testing.T) { + out, err := NewStringEntityName(nil) + require.NoError(t, err) + assert.EqualError(t, out.Valid(), "empty entity-name") + + out, err = NewStringEntityName([]byte("test")) + require.NoError(t, err) + assert.Equal(t, "test", out.String()) + + _, err = NewStringEntityName(7) + assert.EqualError(t, err, "unexpected type for string entity name: int") +} + +func Test_MustNewEntityName(t *testing.T) { + out := MustNewEntityName("test", "string") + assert.Equal(t, "test", out.String()) + + assert.Panics(t, func() { + MustNewEntityName(7, "int") + }) +} diff --git a/comid/environment_test.go b/comid/environment_test.go index c078ce90..f33d1ebf 100644 --- a/comid/environment_test.go +++ b/comid/environment_test.go @@ -46,7 +46,7 @@ func TestEnvironment_Valid_empty_group(t *testing.T) { err := tv.Valid() - assert.EqualError(t, err, "group validation failed: invalid group id") + assert.EqualError(t, err, "group validation failed: no value set") } func TestEnvironment_Valid_ok_with_class(t *testing.T) { tv := Environment{ @@ -78,7 +78,7 @@ func TestEnvironment_ToCBOR_class_only(t *testing.T) { func TestEnvironment_ToCBOR_class_and_instance(t *testing.T) { tv := Environment{ Class: NewClassUUID(TestUUID), - Instance: NewInstanceUEID(TestUEID), + Instance: MustNewUEIDInstance(TestUEID), } require.NotNil(t, tv.Class) require.NotNil(t, tv.Instance) @@ -96,7 +96,7 @@ func TestEnvironment_ToCBOR_class_and_instance(t *testing.T) { func TestEnvironment_ToCBOR_instance_only(t *testing.T) { tv := Environment{ - Instance: NewInstanceUEID(TestUEID), + Instance: MustNewUEIDInstance(TestUEID), } require.NotNil(t, tv.Instance) @@ -113,7 +113,7 @@ func TestEnvironment_ToCBOR_instance_only(t *testing.T) { func TestEnvironment_ToCBOR_group_only(t *testing.T) { tv := Environment{ - Group: NewGroupUUID(TestUUID), + Group: MustNewUUIDGroup(TestUUID), } require.NotNil(t, tv.Group) @@ -180,7 +180,7 @@ func TestEnvironment_FromCBOR_class_and_instance(t *testing.T) { assert.NotNil(t, actual.Class) assert.Equal(t, TestUUIDString, actual.Class.ClassID.String()) assert.NotNil(t, actual.Instance) - assert.Equal(t, TestUEIDString, actual.Instance.String()) + assert.Equal(t, []byte(TestUEID), actual.Instance.Bytes()) assert.Nil(t, actual.Group) } @@ -197,3 +197,25 @@ func TestEnvironment_FromCBOR_group_only(t *testing.T) { assert.NotNil(t, actual.Group) assert.Equal(t, TestUUIDString, actual.Group.String()) } + +func TestEnviroment_JSON(t *testing.T) { + testEnv := Environment{ + Class: NewClassUUID(TestUUID), + } + + out, err := testEnv.ToJSON() + require.NoError(t, err) + assert.Equal(t, `{"class":{"id":{"type":"uuid","value":"31fb5abf-023e-4992-aa4e-95f9c1503bfa"}}}`, string(out)) + + var outEnv Environment + + err = outEnv.FromJSON(out) + require.NoError(t, err) + assert.Equal(t, testEnv, outEnv) + + _, err = Environment{}.ToJSON() + assert.EqualError(t, err, "environment must not be empty") + + err = outEnv.FromJSON([]byte(`{"class": 7}`)) + assert.EqualError(t, err, "json: cannot unmarshal number into Go struct field Environment.class of type comid.Class") +} diff --git a/comid/example_cca_refval_test.go b/comid/example_cca_refval_test.go index a9d1102b..9b52304c 100644 --- a/comid/example_cca_refval_test.go +++ b/comid/example_cca_refval_test.go @@ -66,21 +66,23 @@ func extractCCARefVal(rv ReferenceValue) error { if !m.Key.IsSet() { return fmt.Errorf("mKey not set at index %d", i) } - if m.Key.IsPSARefValID() { + + switch t := m.Key.Value.(type) { + case *TaggedPSARefValID: if err := extractSwMeasurement(m); err != nil { return fmt.Errorf("extracting measurement at index %d: %w", i, err) } - } - if m.Key.IsCCAPlatformConfigID() { + case *TaggedCCAPlatformConfigID: if err := extractCCARefValID(m.Key); err != nil { return fmt.Errorf("extracting cca-refval-id: %w", err) } if err := extractRawValue(m.Val.RawValue); err != nil { return fmt.Errorf("extracting raw vlue: %w", err) } - - return nil + default: + return fmt.Errorf("unexpected Mkey type: %T", t) } + } return nil @@ -105,9 +107,9 @@ func extractCCARefValID(k *Mkey) error { return fmt.Errorf("no measurement key") } - id, err := k.GetCCAPlatformConfigID() - if err != nil { - return fmt.Errorf("getting CCA platform config id: %w", err) + id, ok := k.Value.(*TaggedCCAPlatformConfigID) + if !ok { + return fmt.Errorf("expected CCA platform config id, found: %T", k.Value) } fmt.Printf("Label: %s\n", id) return nil diff --git a/comid/example_psa_keys_test.go b/comid/example_psa_keys_test.go index edf95989..3780cbbd 100644 --- a/comid/example_psa_keys_test.go +++ b/comid/example_psa_keys_test.go @@ -70,12 +70,7 @@ func extractInstanceID(i *Instance) error { return fmt.Errorf("no instance") } - instID, err := i.GetUEID() - if err != nil { - return fmt.Errorf("extracting implemenetation-id: %w", err) - } - - fmt.Printf("InstanceID: %x\n", instID) + fmt.Printf("InstanceID: %x\n", i.Bytes()) return nil } diff --git a/comid/example_psa_refval_test.go b/comid/example_psa_refval_test.go index 230209d0..18817d91 100644 --- a/comid/example_psa_refval_test.go +++ b/comid/example_psa_refval_test.go @@ -111,9 +111,10 @@ func extractPSARefValID(k *Mkey) error { return fmt.Errorf("no measurement key") } - id, err := k.GetPSARefValID() - if err != nil { - return fmt.Errorf("getting PSA refval id: %w", err) + id, ok := k.Value.(*TaggedPSARefValID) + + if !ok { + return fmt.Errorf("expected PSA refval id, found: %T", k.Value) } fmt.Printf("SignerID: %x\n", id.SignerID) @@ -142,12 +143,11 @@ func extractImplementationID(c *Class) error { return fmt.Errorf("no class-id") } - implID, err := classID.GetImplID() - if err != nil { - return fmt.Errorf("extracting implemenetation-id: %w", err) + if classID.Type() != ImplIDType { + return fmt.Errorf("class id is not a psa.impl-id") } - fmt.Printf("ImplementationID: %x\n", implID) + fmt.Printf("ImplementationID: %x\n", classID.Bytes()) return nil } diff --git a/comid/example_test.go b/comid/example_test.go index 756829d3..9d3985c6 100644 --- a/comid/example_test.go +++ b/comid/example_test.go @@ -26,13 +26,12 @@ func Example_encode() { SetModel("RoadRunner"). SetLayer(0). SetIndex(1), - Instance: NewInstanceUEID(TestUEID), - Group: NewGroupUUID(TestUUID), + Instance: MustNewUEIDInstance(TestUEID), + Group: MustNewUUIDGroup(TestUUID), }, Measurements: *NewMeasurements(). AddMeasurement( - NewMeasurement(). - SetKeyUUID(TestUUID). + MustNewUUIDMeasurement(TestUUID). SetRawValueBytes([]byte{0x01, 0x02, 0x03, 0x04}, []byte{0xff, 0xff, 0xff, 0xff}). SetSVN(2). AddDigest(swid.Sha256_32, []byte{0xab, 0xcd, 0xef, 0x00}). @@ -55,13 +54,12 @@ func Example_encode() { SetModel("RoadRunner"). SetLayer(0). SetIndex(1), - Instance: NewInstanceUEID(TestUEID), - Group: NewGroupUUID(TestUUID), + Instance: MustNewUEIDInstance(TestUEID), + Group: MustNewUUIDGroup(TestUUID), }, Measurements: *NewMeasurements(). AddMeasurement( - NewMeasurement(). - SetKeyUUID(TestUUID). + MustNewUUIDMeasurement(TestUUID). SetRawValueBytes([]byte{0x01, 0x02, 0x03, 0x04}, []byte{0xff, 0xff, 0xff, 0xff}). SetMinSVN(2). AddDigest(swid.Sha256_32, []byte{0xab, 0xcd, 0xef, 0x00}). @@ -79,7 +77,7 @@ func Example_encode() { AddAttestVerifKey( AttestVerifKey{ Environment: Environment{ - Instance: NewInstanceUUID(uuid.UUID(TestUUID)), + Instance: MustNewUUIDInstance(uuid.UUID(TestUUID)), }, VerifKeys: *NewCryptoKeys(). Add( @@ -89,7 +87,7 @@ func Example_encode() { ).AddDevIdentityKey( DevIdentityKey{ Environment: Environment{ - Instance: NewInstanceUEID(TestUEID), + Instance: MustNewUEIDInstance(TestUEID), }, VerifKeys: *NewCryptoKeys(). Add( @@ -110,7 +108,7 @@ func Example_encode() { // Output: // a50065656e2d474201a10078206d792d6e733a61636d652d726f616472756e6e65722d737570706c656d656e740282a3006941434d45204c74642e01d8207468747470733a2f2f61636d652e6578616d706c6502820100a20069454d4341204c74642e0281020382a200781a6d792d6e733a61636d652d726f616472756e6e65722d626173650100a20078196d792d6e733a61636d652d726f616472756e6e65722d6f6c64010104a4008182a300a500d86f445502c000016941434d45204c74642e026a526f616452756e6e65720300040101d902264702deadbeefdead02d8255031fb5abf023e4992aa4e95f9c1503bfa81a200d8255031fb5abf023e4992aa4e95f9c1503bfa01aa01d90228020282820644abcdef00820644ffffffff03a201f403f504d9023044010203040544ffffffff064802005e1000000001075020010db8000000000000000000000068086c43303258373056484a484435094702deadbeefdead0a5031fb5abf023e4992aa4e95f9c1503bfa018182a300a500d8255031fb5abf023e4992aa4e95f9c1503bfa016941434d45204c74642e026a526f616452756e6e65720300040101d902264702deadbeefdead02d8255031fb5abf023e4992aa4e95f9c1503bfa81a200d8255031fb5abf023e4992aa4e95f9c1503bfa01aa01d90229020282820644abcdef00820644ffffffff03a300f401f403f504d9023044010203040544ffffffff064802005e1000000001075020010db8000000000000000000000068086c43303258373056484a484435094702deadbeefdead0a5031fb5abf023e4992aa4e95f9c1503bfa028182a101d8255031fb5abf023e4992aa4e95f9c1503bfa81d9022a78b12d2d2d2d2d424547494e205055424c4943204b45592d2d2d2d2d0a4d466b77457759484b6f5a497a6a3043415159494b6f5a497a6a304441516344516741455731427671462b2f727938425761375a454d553178595948455138420a6c4c54344d46484f614f2b4943547449767245654570722f7366544150363648326843486462354845584b74524b6f6436514c634f4c504131513d3d0a2d2d2d2d2d454e44205055424c4943204b45592d2d2d2d2d038182a101d902264702deadbeefdead81d9022a78b12d2d2d2d2d424547494e205055424c4943204b45592d2d2d2d2d0a4d466b77457759484b6f5a497a6a3043415159494b6f5a497a6a304441516344516741455731427671462b2f727938425761375a454d553178595948455138420a6c4c54344d46484f614f2b4943547449767245654570722f7366544150363648326843486462354845584b74524b6f6436514c634f4c504131513d3d0a2d2d2d2d2d454e44205055424c4943204b45592d2d2d2d2d - // {"lang":"en-GB","tag-identity":{"id":"my-ns:acme-roadrunner-supplement"},"entities":[{"name":"ACME Ltd.","regid":"https://acme.example","roles":["creator","tagCreator"]},{"name":"EMCA Ltd.","roles":["maintainer"]}],"linked-tags":[{"target":"my-ns:acme-roadrunner-base","rel":"supplements"},{"target":"my-ns:acme-roadrunner-old","rel":"replaces"}],"triples":{"reference-values":[{"environment":{"class":{"id":{"type":"oid","value":"2.5.2.8192"},"vendor":"ACME Ltd.","model":"RoadRunner","layer":0,"index":1},"instance":{"type":"ueid","value":"At6tvu/erQ=="},"group":{"type":"ueid","value":"31fb5abf-023e-4992-aa4e-95f9c1503bfa"}},"measurements":[{"key":{"type":"uuid","value":"31fb5abf-023e-4992-aa4e-95f9c1503bfa"},"value":{"svn":{"type":"exact-value","value":2},"digests":["sha-256-32;q83vAA==","sha-256-32;/////w=="],"flags":{"is-secure":false,"is-debug":true},"raw-value":{"type":"bytes","value":"AQIDBA=="},"raw-value-mask":"/////w==","mac-addr":"02:00:5e:10:00:00:00:01","ip-addr":"2001:db8::68","serial-number":"C02X70VHJHD5","ueid":"At6tvu/erQ==","uuid":"31fb5abf-023e-4992-aa4e-95f9c1503bfa"}}]}],"endorsed-values":[{"environment":{"class":{"id":{"type":"uuid","value":"31fb5abf-023e-4992-aa4e-95f9c1503bfa"},"vendor":"ACME Ltd.","model":"RoadRunner","layer":0,"index":1},"instance":{"type":"ueid","value":"At6tvu/erQ=="},"group":{"type":"ueid","value":"31fb5abf-023e-4992-aa4e-95f9c1503bfa"}},"measurements":[{"key":{"type":"uuid","value":"31fb5abf-023e-4992-aa4e-95f9c1503bfa"},"value":{"svn":{"type":"min-value","value":2},"digests":["sha-256-32;q83vAA==","sha-256-32;/////w=="],"flags":{"is-configured":false,"is-secure":false,"is-debug":true},"raw-value":{"type":"bytes","value":"AQIDBA=="},"raw-value-mask":"/////w==","mac-addr":"02:00:5e:10:00:00:00:01","ip-addr":"2001:db8::68","serial-number":"C02X70VHJHD5","ueid":"At6tvu/erQ==","uuid":"31fb5abf-023e-4992-aa4e-95f9c1503bfa"}}]}],"attester-verification-keys":[{"environment":{"instance":{"type":"uuid","value":"31fb5abf-023e-4992-aa4e-95f9c1503bfa"}},"verification-keys":[{"type":"pkix-base64-key","value":"-----BEGIN PUBLIC KEY-----\nMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEW1BvqF+/ry8BWa7ZEMU1xYYHEQ8B\nlLT4MFHOaO+ICTtIvrEeEpr/sfTAP66H2hCHdb5HEXKtRKod6QLcOLPA1Q==\n-----END PUBLIC KEY-----"}]}],"dev-identity-keys":[{"environment":{"instance":{"type":"ueid","value":"At6tvu/erQ=="}},"verification-keys":[{"type":"pkix-base64-key","value":"-----BEGIN PUBLIC KEY-----\nMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEW1BvqF+/ry8BWa7ZEMU1xYYHEQ8B\nlLT4MFHOaO+ICTtIvrEeEpr/sfTAP66H2hCHdb5HEXKtRKod6QLcOLPA1Q==\n-----END PUBLIC KEY-----"}]}]}} + // {"lang":"en-GB","tag-identity":{"id":"my-ns:acme-roadrunner-supplement"},"entities":[{"name":"ACME Ltd.","regid":"https://acme.example","roles":["creator","tagCreator"]},{"name":"EMCA Ltd.","roles":["maintainer"]}],"linked-tags":[{"target":"my-ns:acme-roadrunner-base","rel":"supplements"},{"target":"my-ns:acme-roadrunner-old","rel":"replaces"}],"triples":{"reference-values":[{"environment":{"class":{"id":{"type":"oid","value":"2.5.2.8192"},"vendor":"ACME Ltd.","model":"RoadRunner","layer":0,"index":1},"instance":{"type":"ueid","value":"At6tvu/erQ=="},"group":{"type":"uuid","value":"31fb5abf-023e-4992-aa4e-95f9c1503bfa"}},"measurements":[{"key":{"type":"uuid","value":"31fb5abf-023e-4992-aa4e-95f9c1503bfa"},"value":{"svn":{"type":"exact-value","value":2},"digests":["sha-256-32;q83vAA==","sha-256-32;/////w=="],"flags":{"is-secure":false,"is-debug":true},"raw-value":{"type":"bytes","value":"AQIDBA=="},"raw-value-mask":"/////w==","mac-addr":"02:00:5e:10:00:00:00:01","ip-addr":"2001:db8::68","serial-number":"C02X70VHJHD5","ueid":"At6tvu/erQ==","uuid":"31fb5abf-023e-4992-aa4e-95f9c1503bfa"}}]}],"endorsed-values":[{"environment":{"class":{"id":{"type":"uuid","value":"31fb5abf-023e-4992-aa4e-95f9c1503bfa"},"vendor":"ACME Ltd.","model":"RoadRunner","layer":0,"index":1},"instance":{"type":"ueid","value":"At6tvu/erQ=="},"group":{"type":"uuid","value":"31fb5abf-023e-4992-aa4e-95f9c1503bfa"}},"measurements":[{"key":{"type":"uuid","value":"31fb5abf-023e-4992-aa4e-95f9c1503bfa"},"value":{"svn":{"type":"min-value","value":2},"digests":["sha-256-32;q83vAA==","sha-256-32;/////w=="],"flags":{"is-configured":false,"is-secure":false,"is-debug":true},"raw-value":{"type":"bytes","value":"AQIDBA=="},"raw-value-mask":"/////w==","mac-addr":"02:00:5e:10:00:00:00:01","ip-addr":"2001:db8::68","serial-number":"C02X70VHJHD5","ueid":"At6tvu/erQ==","uuid":"31fb5abf-023e-4992-aa4e-95f9c1503bfa"}}]}],"attester-verification-keys":[{"environment":{"instance":{"type":"uuid","value":"31fb5abf-023e-4992-aa4e-95f9c1503bfa"}},"verification-keys":[{"type":"pkix-base64-key","value":"-----BEGIN PUBLIC KEY-----\nMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEW1BvqF+/ry8BWa7ZEMU1xYYHEQ8B\nlLT4MFHOaO+ICTtIvrEeEpr/sfTAP66H2hCHdb5HEXKtRKod6QLcOLPA1Q==\n-----END PUBLIC KEY-----"}]}],"dev-identity-keys":[{"environment":{"instance":{"type":"ueid","value":"At6tvu/erQ=="}},"verification-keys":[{"type":"pkix-base64-key","value":"-----BEGIN PUBLIC KEY-----\nMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEW1BvqF+/ry8BWa7ZEMU1xYYHEQ8B\nlLT4MFHOaO+ICTtIvrEeEpr/sfTAP66H2hCHdb5HEXKtRKod6QLcOLPA1Q==\n-----END PUBLIC KEY-----"}]}]}} } func Example_encode_PSA() { @@ -126,25 +124,23 @@ func Example_encode_PSA() { }, Measurements: *NewMeasurements(). AddMeasurement( - NewPSAMeasurement( - *NewPSARefValID(TestSignerID). - SetLabel("BL"). - SetVersion("5.0.5"), - ).AddDigest(swid.Sha256_32, []byte{0xab, 0xcd, 0xef, 0x00}), + MustNewPSAMeasurement( + MustCreatePSARefValID( + TestSignerID, "BL", "5.0.5", + )).AddDigest(swid.Sha256_32, []byte{0xab, 0xcd, 0xef, 0x00}), ). AddMeasurement( - NewPSAMeasurement( - *NewPSARefValID(TestSignerID). - SetLabel("PRoT"). - SetVersion("1.3.5"), - ).AddDigest(swid.Sha256_32, []byte{0xab, 0xcd, 0xef, 0x00}), + MustNewPSAMeasurement( + MustCreatePSARefValID( + TestSignerID, "PRoT", "1.3.5", + )).AddDigest(swid.Sha256_32, []byte{0xab, 0xcd, 0xef, 0x00}), ), }, ). AddAttestVerifKey( AttestVerifKey{ Environment: Environment{ - Instance: NewInstanceUEID(TestUEID), + Instance: MustNewUEIDInstance(TestUEID), }, VerifKeys: *NewCryptoKeys(). Add( @@ -175,7 +171,7 @@ func Example_encode_PSA_attestation_verification() { AddAttestVerifKey( AttestVerifKey{ Environment: Environment{ - Instance: NewInstanceUEID(TestUEID), + Instance: MustNewUEIDInstance(TestUEID), }, VerifKeys: *NewCryptoKeys(). Add( diff --git a/comid/group.go b/comid/group.go index adb9cdb4..06fd6709 100644 --- a/comid/group.go +++ b/comid/group.go @@ -5,66 +5,51 @@ package comid import ( "encoding/json" + "errors" "fmt" + + "github.com/veraison/corim/encoding" + "github.com/veraison/corim/extensions" ) // Group stores a group identity. The supported format is UUID. type Group struct { - val interface{} + Value IGroupValue } // NewGroup instantiates an empty group -func NewGroup() *Group { - return &Group{} -} - -// SetUUID sets the identity of the target group to the supplied UUID -func (o *Group) SetUUID(val UUID) *Group { - if o != nil { - o.val = TaggedUUID(val) +func NewGroup(val any, typ string) (*Group, error) { + factory, ok := groupValueRegister[typ] + if !ok { + return nil, fmt.Errorf("unknown group type: %s", typ) } - return o -} -// NewGroupUUID instantiates a new group with the supplied UUID identity -func NewGroupUUID(val UUID) *Group { - return NewGroup().SetUUID(val) + return factory(val) } // Valid checks for the validity of given group func (o Group) Valid() error { - if o.String() == "" { - return fmt.Errorf("invalid group id") + if o.Value == nil { + return errors.New("no value set") } - return nil + + return o.Value.Valid() } // String returns a printable string of the Group value. UUIDs use the // canonical 8-4-4-4-12 format, UEIDs are hex encoded. func (o Group) String() string { - switch t := o.val.(type) { - case TaggedUUID: - return UUID(t).String() - default: - return "" - } + return o.Value.String() } // MarshalCBOR serializes the target group to CBOR func (o Group) MarshalCBOR() ([]byte, error) { - return em.Marshal(o.val) + return em.Marshal(o.Value) } // UnmarshalCBOR deserializes the supplied CBOR into the target group func (o *Group) UnmarshalCBOR(data []byte) error { - var uuid TaggedUUID - - if dm.Unmarshal(data, &uuid) == nil { - o.val = uuid - return nil - } - - return fmt.Errorf("unknown group type (CBOR: %x)", data) + return dm.Unmarshal(data, &o.Value) } // UnmarshalJSON deserializes the supplied JSON type/value object into the Group @@ -75,43 +60,89 @@ func (o *Group) UnmarshalCBOR(data []byte) error { // "value": "69E027B2-7157-4758-BCB4-D9F167FE49EA" // } func (o *Group) UnmarshalJSON(data []byte) error { - var v tnv + var tnv encoding.TypeAndValue - if err := json.Unmarshal(data, &v); err != nil { + if err := json.Unmarshal(data, &tnv); err != nil { + return fmt.Errorf("group decoding failure: %w", err) + } + + decoded, err := NewGroup(nil, tnv.Type) + if err != nil { return err } - switch v.Type { - case "uuid": - var x UUID - if err := x.UnmarshalJSON(v.Value); err != nil { - return err - } - o.val = TaggedUUID(x) - default: - return fmt.Errorf("unknown type %s for group", v.Type) + if err := json.Unmarshal(tnv.Value, &decoded.Value); err != nil { + return fmt.Errorf( + "cannot unmarshal group: %w", + err, + ) + } + + if err := decoded.Value.Valid(); err != nil { + return fmt.Errorf("invalid %s: %w", tnv.Type, err) } + o.Value = decoded.Value + return nil } func (o Group) MarshalJSON() ([]byte, error) { - var ( - v tnv - b []byte - err error - ) - - switch t := o.val.(type) { - case TaggedUUID: - b, err = UUID(t).MarshalJSON() - if err != nil { - return nil, err - } - v = tnv{Type: "ueid", Value: b} - default: - return nil, fmt.Errorf("unknown type %T for group", t) + return extensions.TypeChoiceValueMarshalJSON(o.Value) +} + +type IGroupValue interface { + extensions.ITypeChoiceValue +} + +func NewUUIDGroup(val any) (*Group, error) { + if val == nil { + return &Group{&TaggedUUID{}}, nil + } + + u, err := NewTaggedUUID(val) + if err != nil { + return nil, err } - return json.Marshal(v) + return &Group{u}, nil +} + +func MustNewUUIDGroup(val any) *Group { + ret, err := NewUUIDGroup(val) + if err != nil { + panic(err) + } + + return ret +} + +type IGroupFactory func(any) (*Group, error) + +var groupValueRegister = map[string]IGroupFactory{ + UUIDType: NewUUIDGroup, +} + +// RegisterGroupType registers a new IGroupValue implementation +// (created by the provided IGroupFactory) under the specified type name +// and CBOR tag. +func RegisterGroupType(tag uint64, factory IGroupFactory) error { + + nilVal, err := factory(nil) + if err != nil { + return err + } + + typ := nilVal.Value.Type() + if _, exists := groupValueRegister[typ]; exists { + return fmt.Errorf("Group type with name %q already exists", typ) + } + + if err := registerCOMIDTag(tag, nilVal.Value); err != nil { + return err + } + + groupValueRegister[typ] = factory + + return nil } diff --git a/comid/group_test.go b/comid/group_test.go new file mode 100644 index 00000000..4363f3b2 --- /dev/null +++ b/comid/group_test.go @@ -0,0 +1,62 @@ +package comid + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type testGroup uint64 + +func newTestGroup(val any) (*Group, error) { + v := testGroup(7) + return &Group{&v}, nil +} + +func (o testGroup) Type() string { + return "test-value" +} + +func (o testGroup) String() string { + return "test" +} + +func (o testGroup) Valid() error { + return nil +} + +type testGroupBadType struct { + testGroup +} + +func newTestGroupBadType(val any) (*Group, error) { + v := testGroupBadType{testGroup(7)} + return &Group{&v}, nil +} + +func (o testGroupBadType) Type() string { + return "uuid" +} + +func Test_RegisterGroupType(t *testing.T) { + err := RegisterGroupType(32, newTestGroup) + assert.EqualError(t, err, "tag 32 is already registered") + + err = RegisterGroupType(99993, newTestGroupBadType) + assert.EqualError(t, err, `Group type with name "uuid" already exists`) + + err = RegisterGroupType(99993, newTestGroup) + require.NoError(t, err) + +} + +func TestGroup_UmarshalJSON(t *testing.T) { + var group Group + + err := group.UnmarshalJSON([]byte(`{`)) + assert.EqualError(t, err, "group decoding failure: unexpected end of JSON input") + + err = group.UnmarshalJSON([]byte(`{"type":"uuid","value":"aaaa"}`)) + assert.EqualError(t, err, "cannot unmarshal group: bad UUID: invalid UUID length: 4") +} diff --git a/comid/instance.go b/comid/instance.go index 8201145d..69f529e1 100644 --- a/comid/instance.go +++ b/comid/instance.go @@ -1,51 +1,27 @@ package comid import ( - "encoding/hex" "encoding/json" "fmt" - "github.com/google/uuid" - "github.com/veraison/eat" + "github.com/veraison/corim/encoding" + "github.com/veraison/corim/extensions" ) // Instance stores an instance identity. The supported formats are UUID and UEID. type Instance struct { - val interface{} + Value IInstanceValue } -// NewInstance instantiates an empty instance -func NewInstance() *Instance { - return &Instance{} -} - -// SetUEID sets the identity of the target instance to the supplied UEID -func (o *Instance) SetUEID(val eat.UEID) *Instance { - if o != nil { - if val.Validate() != nil { - return nil - } - o.val = TaggedUEID(val) - } - return o -} - -// SetUUID sets the identity of the target instance to the supplied UUID -func (o *Instance) SetUUID(val uuid.UUID) *Instance { - if o != nil { - o.val = TaggedUUID(val) +// NewInstance creates a new instance with the value of the specified type +// populated using the provided value. +func NewInstance(val any, typ string) (*Instance, error) { + factory, ok := instanceValueRegister[typ] + if !ok { + return nil, fmt.Errorf("unknown instance type: %s", typ) } - return o -} - -// NewInstanceUEID instantiates a new instance with the supplied UEID identity -func NewInstanceUEID(val eat.UEID) *Instance { - return NewInstance().SetUEID(val) -} -// NewInstanceUUID instantiates a new instance with the supplied UUID identity -func NewInstanceUUID(val uuid.UUID) *Instance { - return NewInstance().SetUUID(val) + return factory(val) } // Valid checks for the validity of given instance @@ -56,124 +32,178 @@ func (o Instance) Valid() error { return nil } -func (o Instance) GetUEID() (eat.UEID, error) { - switch t := o.val.(type) { - case TaggedUEID: - return eat.UEID(t), nil - default: - return eat.UEID{}, fmt.Errorf("instance-id type is: %T", t) - } -} - -func (o Instance) GetUUID() (UUID, error) { - switch t := o.val.(type) { - case TaggedUUID: - return UUID(t), nil - default: - return UUID{}, fmt.Errorf("instance-id type is: %T", t) - } -} - // String returns a printable string of the Instance value. UUIDs use the // canonical 8-4-4-4-12 format, UEIDs are hex encoded. func (o Instance) String() string { - switch t := o.val.(type) { - case TaggedUUID: - return UUID(t).String() - case TaggedUEID: - return hex.EncodeToString(t) - default: + if o.Value == nil { return "" } + + return o.Value.String() +} + +// Type returns a string naming the type of the underlying Instance value. +func (o Instance) Type() string { + return o.Value.Type() +} + +// Bytes returns a []byte containing the bytes of the underlying Instance +// value. +func (o Instance) Bytes() []byte { + return o.Value.Bytes() } // MarshalCBOR serializes the target instance to CBOR func (o Instance) MarshalCBOR() ([]byte, error) { - return em.Marshal(o.val) + return em.Marshal(o.Value) } func (o *Instance) UnmarshalCBOR(data []byte) error { - var ueid TaggedUEID - - if dm.Unmarshal(data, &ueid) == nil { - o.val = ueid - return nil - } - - var u TaggedUUID - - if dm.Unmarshal(data, &u) == nil { - o.val = u - return nil - } - - return fmt.Errorf("unknown instance type (CBOR: %x)", data) + return dm.Unmarshal(data, &o.Value) } -// UnmarshalJSON deserializes the supplied JSON type/value object into the Group -// target. The supported formats are UUID, e.g.: +// UnmarshalJSON deserializes the supplied JSON object into the target Instance +// The instance object must have the following shape: // // { -// "type": "uuid", -// "value": "69E027B2-7157-4758-BCB4-D9F167FE49EA" +// "type": "", +// "value": // } // -// and UEID: +// where must be one of the known IInstanceValue implementation +// type names (available in the base implementation: "ueid" and "uuid"), and +// is the JSON encoding of the instance value. The exact +// encoding is dependent. For the base implmentation types it is // -// { -// "type": "ueid", -// "value": "Ad6tvu/erb7v3q2+796tvu8=" -// } +// ueid: base64-encoded bytes, e.g. "YWNtZS1pbXBsZW1lbnRhdGlvbi1pZC0wMDAwMDAwMDE=" +// uuid: standard UUID string representation, e.g. "550e8400-e29b-41d4-a716-446655440000" func (o *Instance) UnmarshalJSON(data []byte) error { - var v tnv + var tnv encoding.TypeAndValue + + if err := json.Unmarshal(data, &tnv); err != nil { + return fmt.Errorf("instance decoding failure: %w", err) + } - if err := json.Unmarshal(data, &v); err != nil { + decoded, err := NewInstance(nil, tnv.Type) + if err != nil { return err } - switch v.Type { - case "uuid": - var x UUID - if err := x.UnmarshalJSON(v.Value); err != nil { - return err - } - o.val = TaggedUUID(x) - case "ueid": - var x UEID - if err := x.UnmarshalJSON(v.Value); err != nil { - return err - } - o.val = TaggedUEID(x) - default: - return fmt.Errorf("unknown type %s for instance", v.Type) + if err := json.Unmarshal(tnv.Value, &decoded.Value); err != nil { + return fmt.Errorf( + "cannot unmarshal instance: %w", + err, + ) + } + + if err := decoded.Value.Valid(); err != nil { + return fmt.Errorf("invalid %s: %w", tnv.Type, err) } + o.Value = decoded.Value + return nil } +// MarshalJSON serializes the Instance into a JSON object. func (o Instance) MarshalJSON() ([]byte, error) { - var ( - v tnv - b []byte - err error - ) - - switch t := o.val.(type) { - case TaggedUUID: - b, err = UUID(t).MarshalJSON() - if err != nil { - return nil, err - } - v = tnv{Type: "uuid", Value: b} - case TaggedUEID: - b, err = UEID(t).MarshalJSON() - if err != nil { - return nil, err - } - v = tnv{Type: "ueid", Value: b} - default: - return nil, fmt.Errorf("unknown type %T for instance", t) - } - - return json.Marshal(v) + valueBytes, err := json.Marshal(o.Value) + if err != nil { + return nil, err + } + + value := encoding.TypeAndValue{ + Type: o.Value.Type(), + Value: valueBytes, + } + + return json.Marshal(value) +} + +// IInstanceValue is the interface implemented by all Instance value +// implementations. +type IInstanceValue interface { + extensions.ITypeChoiceValue + + Bytes() []byte +} + +// NewUEIDInstance instantiates a new instance with the supplied UEID identity. +func NewUEIDInstance(val any) (*Instance, error) { + if val == nil { + return &Instance{&TaggedUEID{}}, nil + } + + ret, err := NewTaggedUEID(val) + if err != nil { + return nil, err + } + return &Instance{ret}, nil +} + +// MustNewUEIDInstance is like NewUEIDInstance execept it does not return an +// error, assuming that the provided value is valid. It panics if that isn't +// the case. +func MustNewUEIDInstance(val any) *Instance { + ret, err := NewUEIDInstance(val) + if err != nil { + panic(err) + } + + return ret +} + +// NewUUIDInstance instantiates a new instance with the supplied UUID identity +func NewUUIDInstance(val any) (*Instance, error) { + if val == nil { + return &Instance{&TaggedUUID{}}, nil + } + + ret, err := NewTaggedUUID(val) + if err != nil { + return nil, err + } + + return &Instance{ret}, nil +} + +// MustNewUUIDInstance is like NewUUIDInstance execept it does not return an +// error, assuming that the provided value is valid. It panics if that isn't +// the case. +func MustNewUUIDInstance(val any) *Instance { + ret, err := NewUUIDInstance(val) + if err != nil { + panic(err) + } + + return ret +} + +type IInstanceFactory func(any) (*Instance, error) + +var instanceValueRegister = map[string]IInstanceFactory{ + UEIDType: NewUEIDInstance, + UUIDType: NewUUIDInstance, +} + +// RegisterInstanceType registers a new IInstanceValue implementation (created +// by the provided IInstanceFactory) under the specified CBOR tag. +func RegisterInstanceType(tag uint64, factory IInstanceFactory) error { + nilVal, err := factory(nil) + if err != nil { + return err + } + + typ := nilVal.Type() + if _, exists := instanceValueRegister[typ]; exists { + return fmt.Errorf("class ID type with name %q already exists", typ) + } + + if err := registerCOMIDTag(tag, nilVal.Value); err != nil { + return err + } + + instanceValueRegister[typ] = factory + + return nil } diff --git a/comid/instance_test.go b/comid/instance_test.go index 34e671e7..6ea22fd4 100644 --- a/comid/instance_test.go +++ b/comid/instance_test.go @@ -1,24 +1,69 @@ package comid import ( + "encoding/json" "testing" - "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestInstance_GetUUID_OK(t *testing.T) { - inst := NewInstanceUUID(uuid.UUID(TestUUID)) - require.NotNil(t, inst) - u, err := inst.GetUUID() - assert.Nil(t, err) - assert.Equal(t, u, TestUUID) + inst := MustNewUUIDInstance(TestUUID) + u, ok := inst.Value.(*TaggedUUID) + assert.True(t, ok) + assert.EqualValues(t, TestUUID, *u) } -func TestInstance_GetUUID_NOK(t *testing.T) { - inst := &Instance{} - expectedErr := "instance-id type is: " - _, err := inst.GetUUID() - assert.EqualError(t, err, expectedErr) +type testInstance string + +func newTestInstance(val any) (*Instance, error) { + ret := testInstance("test") + return &Instance{&ret}, nil +} + +func (o testInstance) Bytes() []byte { + return []byte(o) +} + +func (o testInstance) Type() string { + return "test-instance" +} + +func (o testInstance) String() string { + return string(o) +} + +func (o testInstance) Valid() error { + return nil +} + +func Test_RegisterInstanceType(t *testing.T) { + err := RegisterInstanceType(99997, newTestInstance) + require.NoError(t, err) + + instance, err := newTestInstance(nil) + require.NoError(t, err) + + data, err := json.Marshal(instance) + require.NoError(t, err) + assert.Equal(t, string(data), `{"type":"test-instance","value":"test"}`) + + var out Instance + err = json.Unmarshal(data, &out) + require.NoError(t, err) + assert.Equal(t, instance.Bytes(), out.Bytes()) + + data, err = em.Marshal(instance) + require.NoError(t, err) + assert.Equal(t, data, []byte{ + 0xda, 0x0, 0x1, 0x86, 0x9d, // tag 99997 + 0x64, // tstr(4) + 0x74, 0x65, 0x73, 0x74, // "test" + }) + + var out2 Instance + err = dm.Unmarshal(data, &out2) + require.NoError(t, err) + assert.Equal(t, instance.Bytes(), out2.Bytes()) } diff --git a/comid/measurement.go b/comid/measurement.go index 7bfd7f99..00ef949a 100644 --- a/comid/measurement.go +++ b/comid/measurement.go @@ -5,8 +5,10 @@ package comid import ( "encoding/json" + "errors" "fmt" "net" + "strconv" "github.com/veraison/corim/encoding" "github.com/veraison/corim/extensions" @@ -16,198 +18,275 @@ import ( const MaxUint64 = ^uint64(0) -// Measurement stores a measurement-map with CBOR and JSON serializations. -type Measurement struct { - Key *Mkey `cbor:"0,keyasint,omitempty" json:"key,omitempty"` - Val Mval `cbor:"1,keyasint" json:"value"` - AuthorizedBy *CryptoKey `cbor:"2,keyasint,omitempty" json:"authorized-by,omitempty"` -} - // Mkey stores a $measured-element-type-choice. // The supported types are UUID, PSA refval-id, CCA platform-config-id and unsigned integer // TO DO Add tagged OID: see https://github.com/veraison/corim/issues/35 type Mkey struct { - val interface{} + Value IMKeyValue +} + +// NewMkey creates a new Mkey of the specfied type using the provided value. +func NewMkey(val any, typ string) (*Mkey, error) { + factory, ok := mkeyValueRegister[typ] + if !ok { + return nil, fmt.Errorf("unexpected measurement key type: %q", typ) + } + + return factory(nil) +} + +// MustNewMkey is like NewMkey, execept it does not return an error, assuming +// that the provided value is valid. It panics if that is not the case. +func MustNewMkey(val any, typ string) *Mkey { + ret, err := NewMkey(val, typ) + if err != nil { + panic(err) + } + + return ret } +// IsSet returns true if the value of the Mkey is set. func (o Mkey) IsSet() bool { - return o.val != nil + return o.Value != nil } +// Valid returns nil if the Mkey is valid or an error describing the problem, +// if it is not. func (o Mkey) Valid() error { - switch t := o.val.(type) { - case TaggedUUID: - if UUID(t).Empty() { - return fmt.Errorf("empty UUID") - } - return nil - case TaggedPSARefValID: - return PSARefValID(t).Valid() - case TaggedCCAPlatformConfigID: - if CCAPlatformConfigID(t).Empty() { - return fmt.Errorf("empty CCAPlatformConfigID") - } - case uint64: - if o.val == nil { - return fmt.Errorf("empty uint Mkey") - } - return nil - default: - return fmt.Errorf("unknown measurement key type: %T", t) + if o.Value == nil { + return errors.New("Mkey value not set") + } + + if err := o.Value.Valid(); err != nil { + return fmt.Errorf("invalid %s: %w", o.Value.Type(), err) } + return nil } -func (o Mkey) IsPSARefValID() bool { - _, ok := o.val.(TaggedPSARefValID) - return ok -} +// UnmarshalJSON deserializes the supplied JSON object into the target MKey +// The key object must have the following shape: +// +// { +// "type": "", +// "value": +// } +// +// where must be one of the known IMKeyValue implementation +// type names (available in the base implementation: "uuid", "oid", +// "psa.impl-id"), and is the class id value serialized to +// JSON. The exact serialization is depenent. For the base +// implementation types it is +// +// oid: dot-seprated integers, e.g. "1.2.3.4" +// uuid: standard UUID string representation, e.g. "550e8400-e29b-41d4-a716-446655440000" +// psa.refval-id: JSON representation of the PSA refval-id +func (o *Mkey) UnmarshalJSON(data []byte) error { + var tnv encoding.TypeAndValue + + if err := json.Unmarshal(data, &tnv); err != nil { + return err + } + + decoded, err := NewMkey(nil, tnv.Type) + if err != nil { + return err + } + + if err := json.Unmarshal(tnv.Value, decoded.Value); err != nil { + return fmt.Errorf("invalid %s: %w", tnv.Type, err) + } + + if err := decoded.Value.Valid(); err != nil { + return fmt.Errorf("invalid %s: %w", tnv.Type, err) + } + + o.Value = decoded.Value -func (o Mkey) IsCCAPlatformConfigID() bool { - _, ok := o.val.(TaggedCCAPlatformConfigID) - return ok + return nil } -func (o Mkey) GetPSARefValID() (PSARefValID, error) { - switch t := o.val.(type) { - case TaggedPSARefValID: - return PSARefValID(t), nil - default: - return PSARefValID{}, fmt.Errorf("measurement-key type is: %T", t) +// MarshalJSON serializes the target Mkey into the type'n'value JSON object +func (o Mkey) MarshalJSON() ([]byte, error) { + valueBytes, err := json.Marshal(o.Value) + if err != nil { + return nil, err + } + + value := encoding.TypeAndValue{ + Type: o.Value.Type(), + Value: valueBytes, } + + return json.Marshal(value) } -func (o Mkey) GetCCAPlatformConfigID() (CCAPlatformConfigID, error) { - switch t := o.val.(type) { - case TaggedCCAPlatformConfigID: - return CCAPlatformConfigID(t), nil - default: - return CCAPlatformConfigID(""), fmt.Errorf("measurement-key type is: %T", t) +// MarshalCBOR serializes the taret mkey into CBOR-encoded bytes. +func (o Mkey) MarshalCBOR() ([]byte, error) { + return em.Marshal(o.Value) +} + +// UnmarshalCBOR deserializes the Mkey from the provided CBOR bytes. +func (o *Mkey) UnmarshalCBOR(data []byte) error { + if len(data) == 0 { + return errors.New("empty input") } + + majorType := (data[0] & 0xe0) >> 5 + if majorType == 6 { // tag + return dm.Unmarshal(data, &o.Value) + } + + // untagged value must be a uint + + var val UintMkey + if err := dm.Unmarshal(data, &val); err != nil { + return err + } + + o.Value = &val + return nil +} + +// IMKeyValue is the interface implemented by all Mkey value implementations. +type IMKeyValue interface { + extensions.ITypeChoiceValue } -func (o Mkey) GetKeyUint() (uint64, error) { - switch t := o.val.(type) { +const UintType = "uint" + +type UintMkey uint64 + +func NewUintMkey(val any) (*UintMkey, error) { + var ret UintMkey + + if val == nil { + return &ret, nil + } + + switch t := val.(type) { + case UintMkey: + ret = t + case *UintMkey: + ret = *t + case string: + u, err := strconv.ParseUint(t, 10, 64) + if err != nil { + return nil, err + } + ret = UintMkey(u) case uint64: - return t, nil + ret = UintMkey(t) + case uint: + ret = UintMkey(t) default: - return MaxUint64, fmt.Errorf("measurement-key type is: %T", t) + return nil, fmt.Errorf("unexpected type for UintMkey: %T", t) } + + return &ret, nil } -// UnmarshalJSON deserializes the type'n'value JSON object into the target Mkey -func (o *Mkey) UnmarshalJSON(data []byte) error { - var v tnv +func (o UintMkey) Valid() error { + return nil +} + +func (o UintMkey) String() string { + return fmt.Sprint(uint64(o)) +} - if err := json.Unmarshal(data, &v); err != nil { +func (o UintMkey) Type() string { + return UintType +} + +func (o *UintMkey) UnmarshalJSON(data []byte) error { + var tmp uint64 + + if err := json.Unmarshal(data, &tmp); err != nil { return err } - switch v.Type { - case "uuid": - var x UUID - if err := x.UnmarshalJSON(v.Value); err != nil { - return fmt.Errorf( - "cannot unmarshal $measured-element-type-choice of type UUID: %w", - err, - ) - } - o.val = TaggedUUID(x) - case "psa.refval-id": - var x PSARefValID - if err := json.Unmarshal(v.Value, &x); err != nil { - return fmt.Errorf( - "cannot unmarshal $measured-element-type-choice of type PSARefValID: %w", - err, - ) - } - if err := x.Valid(); err != nil { - return fmt.Errorf( - "cannot unmarshal $measured-element-type-choice of type PSARefValID: %w", - err, - ) - } - o.val = TaggedPSARefValID(x) - case "cca.platform-config-id": - var x CCAPlatformConfigID - if err := json.Unmarshal(v.Value, &x); err != nil { - return fmt.Errorf( - "cannot unmarshal $measured-element-type-choice of type CCAPlatformConfigID: %w", - err, - ) - } - if x.Empty() { - return fmt.Errorf( - "cannot unmarshal $measured-element-type-choice of type CCAPlatformConfigID: empty label", - ) - } - o.val = TaggedCCAPlatformConfigID(x) - case "uint": - var x uint64 - if err := json.Unmarshal(v.Value, &x); err != nil { - return fmt.Errorf( - "cannot unmarshal $measured-element-type-choice of type uint: %w", - err, - ) - } - o.val = x - default: - return fmt.Errorf("unknown type %s for $measured-element-type-choice", v.Type) - } + *o = UintMkey(tmp) return nil } -// MarshalJSON serializes the target Mkey into the type'n'value JSON object -// Supported types are: uuid, psa.refval-id and unsigned integer -func (o Mkey) MarshalJSON() ([]byte, error) { - var ( - v tnv - b []byte - err error - ) - - switch t := o.val.(type) { - case TaggedUUID: - uuidString := UUID(t).String() - b, err = json.Marshal(uuidString) - if err != nil { - return nil, err - } - v = tnv{Type: "uuid", Value: b} - case TaggedPSARefValID: - b, err = json.Marshal(t) - if err != nil { - return nil, err - } - v = tnv{Type: "psa.refval-id", Value: b} - case TaggedCCAPlatformConfigID: - b, err = json.Marshal(t) - if err != nil { - return nil, err - } - v = tnv{Type: "cca.platform-config-id", Value: b} +func NewMkeyOID(val any) (*Mkey, error) { + ret, err := NewTaggedOID(val) + if err != nil { + return nil, err + } - case uint64: - b, err = json.Marshal(t) - if err != nil { - return nil, err - } - v = tnv{Type: "uint", Value: b} + return &Mkey{ret}, nil +} - default: - return nil, fmt.Errorf("unknown type %T for mkey", t) +func NewMkeyUUID(val any) (*Mkey, error) { + ret, err := NewTaggedUUID(val) + if err != nil { + return nil, err } - return json.Marshal(v) + return &Mkey{ret}, nil } -func (o Mkey) MarshalCBOR() ([]byte, error) { - return em.Marshal(o.val) +func NewMkeyUint(val any) (*Mkey, error) { + ret, err := NewUintMkey(val) + if err != nil { + return nil, err + } + + return &Mkey{ret}, nil } -func (o *Mkey) UnmarshalCBOR(data []byte) error { - return dm.Unmarshal(data, &o.val) +func NewMkeyPSARefvalID(val any) (*Mkey, error) { + ret, err := NewTaggedPSARefValID(val) + if err != nil { + return nil, err + } + + return &Mkey{ret}, nil +} + +func NewMkeyCCAPlatformConfigID(val any) (*Mkey, error) { + ret, err := NewTaggedCCAPlatformConfigID(val) + if err != nil { + return nil, err + } + + return &Mkey{ret}, nil +} + +type IMkeyFactory = func(val any) (*Mkey, error) + +var mkeyValueRegister = map[string]IMkeyFactory{ + OIDType: NewMkeyOID, + UUIDType: NewMkeyUUID, + UintType: NewMkeyUint, + PSARefValIDType: NewMkeyPSARefvalID, + CCAPlatformConfigIDType: NewMkeyCCAPlatformConfigID, +} + +// RegisterMkeyType registers a new IMKeyValue implementation +// (created by the provided IMKeyFactory) under the specified CBOR tag. +func RegisterMkeyType(tag uint64, factory IMkeyFactory) error { + + nilVal, err := factory(nil) + if err != nil { + return err + } + + typ := nilVal.Value.Type() + if _, exists := mkeyValueRegister[typ]; exists { + return fmt.Errorf("mesurement key type with name %q already exists", typ) + } + + if err := registerCOMIDTag(tag, nilVal.Value); err != nil { + return err + } + + mkeyValueRegister[typ] = factory + + return nil } // Mval stores a measurement-values-map with JSON and CBOR serializations. @@ -330,95 +409,112 @@ func (o Version) Valid() error { return nil } -// NewMeasurement instantiates an empty measurement -func NewMeasurement() *Measurement { - return &Measurement{} +// Measurement stores a measurement-map with CBOR and JSON serializations. +type Measurement struct { + Key *Mkey `cbor:"0,keyasint,omitempty" json:"key,omitempty"` + Val Mval `cbor:"1,keyasint" json:"value"` + AuthorizedBy *CryptoKey `cbor:"2,keyasint,omitempty" json:"authorized-by,omitempty"` } -// SetKeyPSARefValID sets the key of the target measurement-map to the supplied -// PSA refval-id -func (o *Measurement) SetKeyPSARefValID(psaRefValID PSARefValID) *Measurement { - if o != nil { - if psaRefValID.Valid() != nil { - return nil - } - o.Key = &Mkey{ - val: TaggedPSARefValID(psaRefValID), - } +func NewMeasurement(val any, typ string) (*Measurement, error) { + keyFactory, ok := mkeyValueRegister[typ] + if !ok { + return nil, fmt.Errorf("unknown Mkey type: %s", typ) } - return o -} -// SetKeyCCAPlatformConfigID sets the key of the target measurement-map to the supplied -// CCA platform-config-id -func (o *Measurement) SetKeyCCAPlatformConfigID(ccaPlatformConfigID CCAPlatformConfigID) *Measurement { - if o != nil { - if ccaPlatformConfigID.Empty() { - return nil - } - o.Key = &Mkey{ - val: TaggedCCAPlatformConfigID(ccaPlatformConfigID), - } + key, err := keyFactory(val) + if err != nil { + return nil, fmt.Errorf("invalid key: %w", err) } - return o -} -// SetKeyKeyUUID sets the key of the target measurement-map to the supplied -// UUID -func (o *Measurement) SetKeyUUID(u UUID) *Measurement { - if o != nil { - if u.Empty() { - return nil - } + if err = key.Valid(); err != nil { + return nil, fmt.Errorf("invalid key: %w", err) + } - if u.Valid() != nil { - return nil - } + var ret Measurement + ret.Key = key - o.Key = &Mkey{ - val: TaggedUUID(u), - } - } - return o + return &ret, nil } -// SetKeyUint sets the key of the target measurement-map to the supplied -// unsigned integer -func (o *Measurement) SetKeyUint(u uint64) *Measurement { - if o != nil { - o.Key = &Mkey{ - val: u, - } +func MustNewMeasurement(val any, typ string) *Measurement { + ret, err := NewMeasurement(val, typ) + + if err != nil { + panic(err) } - return o + + return ret } // NewPSAMeasurement instantiates a new measurement-map with the key set to the // supplied PSA refval-id -func NewPSAMeasurement(psaRefValID PSARefValID) *Measurement { - m := &Measurement{} - return m.SetKeyPSARefValID(psaRefValID) +func NewPSAMeasurement(key any) (*Measurement, error) { + return NewMeasurement(key, PSARefValIDType) +} + +func MustNewPSAMeasurement(key any) *Measurement { + ret, err := NewPSAMeasurement(key) + + if err != nil { + panic(err) + } + + return ret } // NewCCAPlatCfgMeasurement instantiates a new measurement-map with the key set to the // supplied CCA platform-config-id -func NewCCAPlatCfgMeasurement(ccaPlatformConfigID CCAPlatformConfigID) *Measurement { - m := &Measurement{} - return m.SetKeyCCAPlatformConfigID(ccaPlatformConfigID) +func NewCCAPlatCfgMeasurement(key any) (*Measurement, error) { + return NewMeasurement(key, CCAPlatformConfigIDType) +} + +func MustNewCCAPlatCfgMeasurement(key any) *Measurement { + ret, err := NewCCAPlatCfgMeasurement(key) + + if err != nil { + panic(err) + } + + return ret } // NewUUIDMeasurement instantiates a new measurement-map with the key set to the // supplied UUID -func NewUUIDMeasurement(uuid UUID) *Measurement { - m := &Measurement{} - return m.SetKeyUUID(uuid) +func NewUUIDMeasurement(key any) (*Measurement, error) { + return NewMeasurement(key, UUIDType) +} + +func MustNewUUIDMeasurement(key any) *Measurement { + ret, err := NewUUIDMeasurement(key) + + if err != nil { + panic(err) + } + + return ret } // NewUintMeasurement instantiates a new measurement-map with the key set to the // supplied Uint -func NewUintMeasurement(mkey uint64) *Measurement { - m := &Measurement{} - return m.SetKeyUint(mkey) +func NewUintMeasurement(key any) (*Measurement, error) { + return NewMeasurement(key, UintType) +} + +func MustNewUintMeasurement(key any) *Measurement { + ret, err := NewUintMeasurement(key) + + if err != nil { + panic(err) + } + + return ret +} + +// NewOIDMeasurement instantiates a new measurement-map with the key set to the +// supplied OID +func NewOIDMeasurement(key any) (*Measurement, error) { + return NewMeasurement(key, OIDType) } func (o *Measurement) SetVersion(ver string, scheme int64) *Measurement { @@ -448,26 +544,14 @@ func (o *Measurement) SetRawValueBytes(rawValue, rawValueMask []byte) *Measureme // SetSVN sets the supplied svn in the measurement-values-map of the target // measurement func (o *Measurement) SetSVN(svn uint64) *Measurement { - if o != nil { - s := SVN{} - if s.SetSVN(svn) == nil { - return nil - } - o.Val.SVN = &s - } + o.Val.SVN = MustNewTaggedSVN(svn) return o } // SetMinSVN sets the supplied min-svn in the measurement-values-map of the // target measurement func (o *Measurement) SetMinSVN(svn uint64) *Measurement { - if o != nil { - s := SVN{} - if s.SetMinSVN(svn) == nil { - return nil - } - o.Val.SVN = &s - } + o.Val.SVN = MustNewTaggedMinSVN(svn) return o } diff --git a/comid/measurement_test.go b/comid/measurement_test.go index d5d58bff..c4f95e11 100644 --- a/comid/measurement_test.go +++ b/comid/measurement_test.go @@ -4,6 +4,7 @@ package comid import ( + "crypto" "fmt" "testing" @@ -14,86 +15,86 @@ import ( ) func TestMeasurement_NewUUIDMeasurement_good_uuid(t *testing.T) { - tv := NewUUIDMeasurement(TestUUID) - assert.NotNil(t, tv) + _, err := NewUUIDMeasurement(TestUUID) + assert.NoError(t, err) } func TestMeasurement_NewUUIDMeasurement_empty_uuid(t *testing.T) { emptyUUID := UUID{} - tv := NewUUIDMeasurement(emptyUUID) + _, err := NewUUIDMeasurement(emptyUUID) - assert.Nil(t, tv) + assert.EqualError(t, err, + "invalid key: expecting RFC4122 UUID, got Reserved instead") } func TestMeasurement_NewUIntMeasurement(t *testing.T) { var TestUint uint64 = 35 - tv := NewUintMeasurement(TestUint) + _, err := NewUintMeasurement(TestUint) - assert.NotNil(t, tv) + assert.NoError(t, err) } func TestMeasurement_NewPSAMeasurement_empty(t *testing.T) { emptyPSARefValID := PSARefValID{} - tv := NewPSAMeasurement(emptyPSARefValID) - assert.Nil(t, tv) + _, err := NewPSAMeasurement(emptyPSARefValID) + + assert.EqualError(t, err, "invalid key: invalid psa.refval-id: missing mandatory signer ID") } func TestMeasurement_NewPSAMeasurement_no_values(t *testing.T) { - psaRefValID := - NewPSARefValID(TestSignerID). - SetLabel("PRoT"). - SetVersion("1.2.3") + psaRefValID, err := NewPSARefValID(TestSignerID) + require.NoError(t, err) + psaRefValID.SetLabel("PRoT") + psaRefValID.SetVersion("1.2.3") require.NotNil(t, psaRefValID) - tv := NewPSAMeasurement(*psaRefValID) - assert.NotNil(t, tv) + tv, err := NewPSAMeasurement(*psaRefValID) + assert.NoError(t, err) - err := tv.Valid() + err = tv.Valid() assert.EqualError(t, err, "no measurement value set") } func TestMeasurement_NewCCAPlatCfgMeasurement_no_values(t *testing.T) { ccaplatID := CCAPlatformConfigID(TestCCALabel) - tv := NewCCAPlatCfgMeasurement(ccaplatID) - assert.NotNil(t, tv) + tv, err := NewCCAPlatCfgMeasurement(ccaplatID) + assert.NoError(t, err) - err := tv.Valid() + err = tv.Valid() assert.EqualError(t, err, "no measurement value set") } func TestMeasurement_NewCCAPlatCfgMeasurement_valid_meas(t *testing.T) { ccaplatID := CCAPlatformConfigID(TestCCALabel) - tv := NewCCAPlatCfgMeasurement(ccaplatID).SetRawValueBytes([]byte{0x01, 0x02, 0x03, 0x04}, []byte{}) - assert.NotNil(t, tv) + tv, err := NewCCAPlatCfgMeasurement(ccaplatID) + assert.NoError(t, err) - err := tv.Valid() - assert.Nil(t, err) + tv.SetRawValueBytes([]byte{0x01, 0x02, 0x03, 0x04}, []byte{}) + + err = tv.Valid() + assert.NoError(t, err) } func TestMeasurement_NewPSAMeasurement_one_value(t *testing.T) { - psaRefValID := - NewPSARefValID(TestSignerID). - SetLabel("PRoT"). - SetVersion("1.2.3") - require.NotNil(t, psaRefValID) + tv, err := NewPSAMeasurement(MustCreatePSARefValID(TestSignerID, "PRoT", "1.2.3")) + require.NoError(t, err) - tv := NewPSAMeasurement(*psaRefValID).SetIPaddr(TestIPaddr) - assert.NotNil(t, tv) + tv.SetIPaddr(TestIPaddr) - err := tv.Valid() + err = tv.Valid() assert.Nil(t, err) } func TestMeasurement_NewUUIDMeasurement_no_values(t *testing.T) { - tv := NewUUIDMeasurement(TestUUID) - require.NotNil(t, tv) + tv, err := NewUUIDMeasurement(TestUUID) + require.NoError(t, err) - err := tv.Valid() + err = tv.Valid() assert.EqualError(t, err, "no measurement value set") } @@ -101,26 +102,27 @@ func TestMeasurement_NewUUIDMeasurement_some_value(t *testing.T) { var vs swid.VersionScheme require.NoError(t, vs.SetCode(swid.VersionSchemeSemVer)) - tv := NewUUIDMeasurement(TestUUID). - SetMinSVN(2). + tv, err := NewUUIDMeasurement(TestUUID) + require.NoError(t, err) + + tv.SetMinSVN(2). SetFlagsTrue(FlagIsDebug). SetVersion("1.2.3", swid.VersionSchemeSemVer) - require.NotNil(t, tv) - err := tv.Valid() + err = tv.Valid() assert.Nil(t, err) } func TestMeasurement_NewUUIDMeasurement_bad_digest(t *testing.T) { - tv := NewUUIDMeasurement(TestUUID) - require.NotNil(t, tv) + tv, err := NewUUIDMeasurement(TestUUID) + require.NoError(t, err) assert.Nil(t, tv.AddDigest(swid.Sha256, []byte{0xff})) } func TestMeasurement_NewUUIDMeasurement_bad_ueid(t *testing.T) { - tv := NewUUIDMeasurement(TestUUID) - require.NotNil(t, tv) + tv, err := NewUUIDMeasurement(TestUUID) + require.NoError(t, err) badUEID := eat.UEID{ 0xFF, // Invalid @@ -131,8 +133,8 @@ func TestMeasurement_NewUUIDMeasurement_bad_ueid(t *testing.T) { } func TestMeasurement_NewUUIDMeasurement_bad_uuid(t *testing.T) { - tv := NewUUIDMeasurement(TestUUID) - require.NotNil(t, tv) + tv, err := NewUUIDMeasurement(TestUUID) + require.NoError(t, err) nonRFC4122UUID, err := ParseUUID("f47ac10b-58cc-4372-c567-0e02b2c3d479") require.Nil(t, err) @@ -147,7 +149,7 @@ var ( func TestMkey_Valid_no_value(t *testing.T) { mkey := &Mkey{} - expectedErr := "unknown measurement key type: " + expectedErr := "Mkey value not set" err := mkey.Valid() assert.EqualError(t, err, expectedErr) } @@ -183,19 +185,19 @@ func TestMKey_MarshalCBOR_CCAPlatformConfigID_ok(t *testing.T) { func TestMKey_UnmarshalCBOR_CCAPlatformConfigID_ok(t *testing.T) { tvs := []struct { input []byte - expected CCAPlatformConfigID + expected TaggedCCAPlatformConfigID }{ { input: MustHexDecode(t, "d9025a736363612d706c6174666f726d2d636f6e666967"), - expected: CCAPlatformConfigID(TestCCALabel), + expected: TaggedCCAPlatformConfigID(TestCCALabel), }, { input: MustHexDecode(t, "d9025a716d7974657374706c6174666f726d666967"), - expected: CCAPlatformConfigID("mytestplatformfig"), + expected: TaggedCCAPlatformConfigID("mytestplatformfig"), }, { input: MustHexDecode(t, "d9025a6c6d79746573746c6162656c32"), - expected: CCAPlatformConfigID("mytestlabel2"), + expected: TaggedCCAPlatformConfigID("mytestlabel2"), }, } @@ -203,9 +205,9 @@ func TestMKey_UnmarshalCBOR_CCAPlatformConfigID_ok(t *testing.T) { mkey := &Mkey{} err := mkey.UnmarshalCBOR(tv.input) assert.Nil(t, err) - actual, err := mkey.GetCCAPlatformConfigID() - assert.Nil(t, err) - assert.Equal(t, tv.expected, actual) + actual, ok := mkey.Value.(*TaggedCCAPlatformConfigID) + assert.True(t, ok) + assert.Equal(t, tv.expected, *actual) fmt.Printf("CBOR: %x\n", actual) } } @@ -230,36 +232,13 @@ func TestMKey_MarshalCBOR_uint_ok(t *testing.T) { } for _, tv := range tvs { - mkey := &Mkey{tv.mkey} + mkey := &Mkey{UintMkey(tv.mkey)} actual, err := mkey.MarshalCBOR() assert.Nil(t, err) assert.Equal(t, tv.expected, actual) fmt.Printf("CBOR: %x\n", actual) } } -func TestMkey_MarshalCBOR_uint_not_ok(t *testing.T) { - tvs := []struct { - input interface{} - expected string - }{ - { - input: 123.456, - expected: "unknown measurement key type: float64", - }, - { - input: "sample", - expected: "unknown measurement key type: string", - }, - } - - for _, tv := range tvs { - mkey := &Mkey{tv.input} - _, err := mkey.MarshalCBOR() - assert.Nil(t, err) - err = mkey.Valid() - assert.EqualError(t, err, tv.expected) - } -} func TestMkey_UnmarshalCBOR_uint_ok(t *testing.T) { tvs := []struct { @@ -284,10 +263,10 @@ func TestMkey_UnmarshalCBOR_uint_ok(t *testing.T) { mKey := &Mkey{} err := mKey.UnmarshalCBOR(tv.mkey) - assert.Nil(t, err) - actual, err := mKey.GetKeyUint() - assert.Nil(t, err) - assert.Equal(t, tv.expected, actual) + require.NoError(t, err) + actual, ok := mKey.Value.(*UintMkey) + require.True(t, ok) + assert.Equal(t, tv.expected, uint64(*actual)) } } @@ -315,38 +294,9 @@ func TestMkey_UnmarshalCBOR_not_ok(t *testing.T) { } } -func TestMkey_UnmarshalCBOR_uint_not_ok(t *testing.T) { - tvs := []struct { - input []byte - expected string - }{ - { - input: []byte{0xd8, 0x25, 0x50, 0x31, 0xfb, 0x5a, 0xbf, 0x02, - 0x3e, 0x49, 0x92, 0xaa, 0x4e, 0x95, 0xf9, 0xc1, - 0x50, 0x3b, 0xfa}, - expected: "measurement-key type is: comid.TaggedUUID", - }, - { - input: []byte{0xd8, 0x21, 0x50, 0x31, 0xfb, 0x5a, 0xff, 0x12, - 0xFF, 0xFF, 0x92, 0xaa, 0x4e, 0x95, 0xf9, 0xc1, - 0x50, 0x3b, 0xfa}, - expected: "measurement-key type is: cbor.Tag", - }, - } - - for _, tv := range tvs { - mKey := &Mkey{} - - err := mKey.UnmarshalCBOR(tv.input) - assert.Nil(t, err) - _, err = mKey.GetKeyUint() - assert.EqualError(t, err, tv.expected) - } -} - func TestMKey_MarshalJSON_CCAPlatformConfigID_ok(t *testing.T) { refval := TestCCALabel - mkey := &Mkey{val: TaggedCCAPlatformConfigID(refval)} + mkey := &Mkey{Value: TaggedCCAPlatformConfigID(refval)} expected := `{"type":"cca.platform-config-id","value":"cca-platform-config"}` @@ -359,30 +309,29 @@ func TestMKey_MarshalJSON_CCAPlatformConfigID_ok(t *testing.T) { func TestMKey_UnMarshalJSON_CCAPlatformConfigID_ok(t *testing.T) { input := []byte(`{"type":"cca.platform-config-id","value":"cca-platform-config"}`) - expected := CCAPlatformConfigID(TestCCALabel) + expected := TaggedCCAPlatformConfigID(TestCCALabel) mKey := &Mkey{} err := mKey.UnmarshalJSON(input) assert.Nil(t, err) - actual, err := mKey.GetCCAPlatformConfigID() - assert.Nil(t, err) - assert.Equal(t, expected, actual) + actual, ok := mKey.Value.(*TaggedCCAPlatformConfigID) + assert.True(t, ok) + assert.Equal(t, expected, *actual) } func TestMKey_UnMarshalJSON_CCAPlatformConfigID_not_ok(t *testing.T) { input := []byte(`{"type":"cca.platform-config-id","value":""}`) - expected := "cannot unmarshal $measured-element-type-choice of type CCAPlatformConfigID: empty label" + expected := "invalid cca.platform-config-id: empty value" mKey := &Mkey{} err := mKey.UnmarshalJSON(input) - assert.NotNil(t, err) - assert.Equal(t, expected, err.Error()) - + assert.EqualError(t, err, expected) } + func TestMkey_MarshalJSON_uint_ok(t *testing.T) { tvs := []struct { mkey uint64 @@ -404,7 +353,7 @@ func TestMkey_MarshalJSON_uint_ok(t *testing.T) { for _, tv := range tvs { - mkey := &Mkey{tv.mkey} + mkey := &Mkey{UintMkey(tv.mkey)} actual, err := mkey.MarshalJSON() assert.Nil(t, err) @@ -414,31 +363,6 @@ func TestMkey_MarshalJSON_uint_ok(t *testing.T) { } } -func TestMkey_MarshalJSON_uint_not_ok(t *testing.T) { - tvs := []struct { - input interface{} - expected string - }{ - { - input: 123.456, - expected: "unknown type float64 for mkey", - }, - { - input: "sample", - expected: "unknown type string for mkey", - }, - } - - for _, tv := range tvs { - - mkey := &Mkey{tv.input} - - _, err := mkey.MarshalJSON() - - assert.EqualError(t, err, tv.expected) - } -} - func TestMkey_UnmarshalJSON_uint_ok(t *testing.T) { tvs := []struct { input []byte @@ -463,9 +387,9 @@ func TestMkey_UnmarshalJSON_uint_ok(t *testing.T) { err := mKey.UnmarshalJSON(tv.input) assert.Nil(t, err) - actual, err := mKey.GetKeyUint() - assert.Nil(t, err) - assert.Equal(t, tv.expected, actual) + actual, ok := mKey.Value.(*UintMkey) + assert.True(t, ok) + assert.Equal(t, tv.expected, uint64(*actual)) } } @@ -476,11 +400,11 @@ func TestMkey_UnmarshalJSON_notok(t *testing.T) { }{ { input: []byte(`{"type":"uint","value":"abcdefg"}`), - expected: "cannot unmarshal $measured-element-type-choice of type uint: json: cannot unmarshal string into Go value of type uint64", + expected: `invalid uint: json: cannot unmarshal string into Go value of type uint64`, }, { input: []byte(`{"type":"uint","value":123.456}`), - expected: "cannot unmarshal $measured-element-type-choice of type uint: json: cannot unmarshal number 123.456 into Go value of type uint64", + expected: "invalid uint: json: cannot unmarshal number 123.456 into Go value of type uint64", }, } @@ -493,27 +417,102 @@ func TestMkey_UnmarshalJSON_notok(t *testing.T) { } } -func TestMkey_UnmarshalJSON_uint_notok(t *testing.T) { +func TestNewUintMkey(t *testing.T) { + testVal := UintMkey(7) + tvs := []struct { - input []byte - expected string + input any + expected UintMkey + err string }{ { - input: []byte(`{"type":"uuid","value":"31fb5abf-023e-4992-aa4e-95f9c1503bfa"}`), - expected: "measurement-key type is: comid.TaggedUUID", + input: testVal, + expected: testVal, + }, + { + input: &testVal, + expected: testVal, }, { - input: []byte(`{"type":"psa.refval-id","value":{"label": "BL","version": "2.1.0","signer-id": "rLsRx+TaIXIFUjzkzhokWuGiOa48a/2eeHH35di66Gs="}}`), - expected: "measurement-key type is: comid.TaggedPSARefValID", + input: uint(7), + expected: testVal, + }, + { + input: uint64(7), + expected: testVal, + }, + { + input: "7", + expected: testVal, + }, + { + input: true, + err: "unexpected type for UintMkey: bool", }, } for _, tv := range tvs { - mKey := &Mkey{} - - err := mKey.UnmarshalJSON(tv.input) - assert.Nil(t, err) - _, err = mKey.GetKeyUint() - assert.EqualError(t, err, tv.expected) + out, err := NewUintMkey(tv.input) + if tv.err != "" { + assert.Nil(t, out) + assert.EqualError(t, err, tv.err) + } else { + assert.Equal(t, tv.expected, *out) + } } } + +func TestNewMkeyOID(t *testing.T) { + var expectedOID OID + require.NoError(t, expectedOID.FromString(TestOID)) + expected := TaggedOID(expectedOID) + + out, err := NewMkeyOID(TestOID) + require.NoError(t, err) + assert.Equal(t, &expected, out.Value) +} + +type testMkey [4]byte + +func newTestMkey(val any) (*Mkey, error) { + return &Mkey{&testMkey{0x74, 0x64, 0x73, 0x74}}, nil +} + +func (o testMkey) PublicKey() (crypto.PublicKey, error) { + return crypto.PublicKey(o[:]), nil +} + +func (o testMkey) Type() string { + return "test-mkey" +} + +func (o testMkey) String() string { + return "test" +} + +func (o testMkey) Valid() error { + return nil +} + +type badMkey struct { + testMkey +} + +func (o badMkey) Type() string { + return "uuid" +} + +func newBadMkey(val any) (*Mkey, error) { + return &Mkey{&badMkey{testMkey{0x74, 0x64, 0x73, 0x74}}}, nil +} + +func TestRegisterMkeyType(t *testing.T) { + err := RegisterMkeyType(32, newTestMkey) + assert.EqualError(t, err, "tag 32 is already registered") + + err = RegisterMkeyType(99996, newBadMkey) + assert.EqualError(t, err, `mesurement key type with name "uuid" already exists`) + + err = RegisterMkeyType(99996, newTestMkey) + assert.NoError(t, err) +} diff --git a/comid/oid.go b/comid/oid.go index 822d7deb..24d366e6 100644 --- a/comid/oid.go +++ b/comid/oid.go @@ -11,6 +11,8 @@ import ( "strings" ) +const OIDType = "oid" + // BER-encoded absolute OID type OID []byte @@ -152,3 +154,78 @@ func (o *OID) UnmarshalJSON(data []byte) error { func (o OID) MarshalJSON() ([]byte, error) { return json.Marshal(o.String()) } + +type TaggedOID OID + +func NewTaggedOID(val any) (*TaggedOID, error) { + ret := TaggedOID{} + + if val == nil { + return &ret, nil + } + + switch t := val.(type) { + case string: + var berOID OID + if err := berOID.FromString(t); err != nil { + return nil, err + } + + ret = TaggedOID(berOID) + case []byte: + ret = make([]byte, len(t)) + copy(ret, t) + case TaggedOID: + ret = make([]byte, len(t)) + copy(ret, t) + case OID: + ret = make([]byte, len(t)) + copy(ret, t) + case *TaggedOID: + ret = make([]byte, len(*t)) + copy(ret, (*t)) + case *OID: + ret = make([]byte, len(*t)) + copy(ret, (*t)) + } + + return &ret, nil +} + +func (o TaggedOID) Type() string { + return OIDType +} + +func (o TaggedOID) String() string { + return OID(o).String() +} + +func (o TaggedOID) Valid() error { + return nil +} + +func (o TaggedOID) Bytes() []byte { + return o +} + +func (o *TaggedOID) FromString(s string) error { + return (*OID)(o).FromString(s) +} + +func (o *TaggedOID) UnmarshalJSON(data []byte) error { + var s string + + if err := json.Unmarshal(data, &s); err != nil { + return err + } + + if err := o.FromString(s); err != nil { + return err + } + + return nil +} + +func (o TaggedOID) MarshalJSON() ([]byte, error) { + return json.Marshal(o.String()) +} diff --git a/comid/psareferencevalue.go b/comid/psareferencevalue.go index cde3304b..0ffacf97 100644 --- a/comid/psareferencevalue.go +++ b/comid/psareferencevalue.go @@ -4,9 +4,12 @@ package comid import ( + "encoding/json" "fmt" ) +var PSARefValIDType = "psa.refval-id" + // PSARefValID stores a PSA refval-id with CBOR and JSON serializations // (See https://datatracker.ietf.org/doc/html/draft-xyz-rats-psa-endorsements) type PSARefValID struct { @@ -30,18 +33,56 @@ func (o PSARefValID) Valid() error { return nil } -type TaggedPSARefValID PSARefValID +func CreatePSARefValID(signerID []byte, label, version string) (*PSARefValID, error) { + ret, err := NewPSARefValID(signerID) + if err != nil { + return nil, err + } -func NewPSARefValID(signerID []byte) *PSARefValID { - switch len(signerID) { - case 32, 48, 64: - default: - return nil + ret.SetLabel(label) + ret.SetVersion(version) + + return ret, nil +} + +func MustCreatePSARefValID(signerID []byte, label, version string) *PSARefValID { + ret, err := CreatePSARefValID(signerID, label, version) + + if err != nil { + panic(err) } - return &PSARefValID{ - SignerID: signerID, + return ret +} + +func NewPSARefValID(val any) (*PSARefValID, error) { + var ret PSARefValID + + if val == nil { + return &ret, nil } + + switch t := val.(type) { + case PSARefValID: + ret = t + case *PSARefValID: + ret = *t + case string: + if err := json.Unmarshal([]byte(t), &ret); err != nil { + return nil, err + } + case []byte: + switch len(t) { + case 32, 48, 64: + ret.SignerID = t + default: + return nil, fmt.Errorf("invalid PSA RefVal ID length: %d", len(t)) + } + default: + return nil, fmt.Errorf("unexpected type for PSA RefVal ID: %T", t) + } + + return &ret, nil } func (o *PSARefValID) SetLabel(label string) *PSARefValID { @@ -57,3 +98,46 @@ func (o *PSARefValID) SetVersion(version string) *PSARefValID { } return o } + +type TaggedPSARefValID PSARefValID + +func NewTaggedPSARefValID(val any) (*TaggedPSARefValID, error) { + var ret TaggedPSARefValID + + switch t := val.(type) { + case TaggedPSARefValID: + ret = t + case *TaggedPSARefValID: + ret = *t + default: + refvalID, err := NewPSARefValID(val) + if err != nil { + return nil, err + } + ret = TaggedPSARefValID(*refvalID) + + } + + return &ret, nil +} + +func (o TaggedPSARefValID) Valid() error { + return PSARefValID(o).Valid() +} + +func (o TaggedPSARefValID) String() string { + ret, err := json.Marshal(o) + if err != nil { + return "" + } + + return string(ret) +} + +func (o TaggedPSARefValID) Type() string { + return PSARefValIDType +} + +func (o TaggedPSARefValID) IsZero() bool { + return len(o.SignerID) == 0 +} diff --git a/comid/psareferencevalue_test.go b/comid/psareferencevalue_test.go index 069c432e..72116cc2 100644 --- a/comid/psareferencevalue_test.go +++ b/comid/psareferencevalue_test.go @@ -4,9 +4,11 @@ package comid import ( + "fmt" "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestPSARefValID_Valid_SignerID_range(t *testing.T) { @@ -15,13 +17,27 @@ func TestPSARefValID_Valid_SignerID_range(t *testing.T) { for i := 1; i <= 100; i++ { signerID = append(signerID, byte(0xff)) - tv := NewPSARefValID(signerID) + tv, err := NewPSARefValID(signerID) + switch i { case 32, 48, 64: assert.NotNil(t, tv) assert.Nil(t, tv.Valid()) default: assert.Nil(t, tv) + assert.EqualError( + t, + err, + fmt.Sprintf("invalid PSA RefVal ID length: %d", i), + ) } } } + +func TestPSARefValID_Streing(t *testing.T) { + signerID := MustHexDecode(t, "deadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeef") + refvalID, err := NewTaggedPSARefValID(signerID) + require.NoError(t, err) + + assert.Equal(t, `{"signer-id":"3q2+796tvu/erb7v3q2+796tvu/erb7v3q2+796tvu8="}`, refvalID.String()) +} diff --git a/comid/referencevalue_test.go b/comid/referencevalue_test.go index 00b89d18..f999c876 100644 --- a/comid/referencevalue_test.go +++ b/comid/referencevalue_test.go @@ -13,10 +13,9 @@ func Test_ReferenceValue(t *testing.T) { err := rv.Valid() assert.EqualError(t, err, "environment validation failed: environment must not be empty") - rv.Environment.Instance = NewInstance() id, err := uuid.NewUUID() require.NoError(t, err) - rv.Environment.Instance.SetUUID(id) + rv.Environment.Instance = MustNewUUIDInstance(id) err = rv.Valid() assert.EqualError(t, err, "measurements validation failed: no measurement entries") } diff --git a/comid/rel.go b/comid/rel.go index ad35ed36..41188408 100644 --- a/comid/rel.go +++ b/comid/rel.go @@ -18,6 +18,35 @@ const ( RelUnset = ^Rel(0) ) +var ( + relToString = map[Rel]string{ + RelReplaces: "replaces", + RelSupplements: "supplements", + } + + stringToRel = map[string]Rel{ + "replaces": RelReplaces, + "supplements": RelSupplements, + } +) + +func RegisterRel(val int64, name string) error { + rel := Rel(val) + + if _, ok := relToString[rel]; ok { + return fmt.Errorf("rel with value %d already exists", val) + } + + if _, ok := stringToRel[name]; ok { + return fmt.Errorf("rel with name %q already exists", name) + } + + relToString[rel] = name + stringToRel[name] = rel + + return nil +} + func NewRel() *Rel { r := RelUnset return &r @@ -43,14 +72,12 @@ func (o Rel) Valid() error { } func (o Rel) String() string { - switch o { - case RelReplaces: - return "replaces" - case RelSupplements: - return "supplements" - default: - return fmt.Sprintf("rel(%d)", o) + ret, ok := relToString[o] + if ok { + return ret } + + return fmt.Sprintf("rel(%d)", o) } func (o Rel) ToCBOR() ([]byte, error) { diff --git a/comid/rel_test.go b/comid/rel_test.go index 29de1629..86f6d748 100644 --- a/comid/rel_test.go +++ b/comid/rel_test.go @@ -163,3 +163,20 @@ func TestRel_ToCBOR_fail_unset(t *testing.T) { assert.EqualError(t, err, "rel is unset") } + +func Test_RegisterRel(t *testing.T) { + err := RegisterRel(1, "augments") + assert.EqualError(t, err, "rel with value 1 already exists") + + err = RegisterRel(3, "replaces") + assert.EqualError(t, err, `rel with name "replaces" already exists`) + + err = RegisterRel(3, "augments") + assert.NoError(t, err) + + rel := Rel(3) + + out, err := rel.MarshalJSON() + require.NoError(t, err) + assert.Equal(t, `"augments"`, string(out)) +} diff --git a/comid/role.go b/comid/role.go index 1b08322b..14fc6e3b 100644 --- a/comid/role.go +++ b/comid/role.go @@ -40,6 +40,23 @@ var ( } ) +func RegisterRole(val int64, name string) error { + role := Role(val) + + if _, ok := roleToString[role]; ok { + return fmt.Errorf("role with value %d already exists", val) + } + + if _, ok := stringToRole[name]; ok { + return fmt.Errorf("role with name %q already exists", name) + } + + roleToString[role] = name + stringToRole[name] = role + + return nil +} + type Roles []Role func NewRoles() *Roles { diff --git a/comid/role_test.go b/comid/role_test.go index 4479960c..0b9d585c 100644 --- a/comid/role_test.go +++ b/comid/role_test.go @@ -210,3 +210,20 @@ func TestRoles_UnmarshalJSON_fail(t *testing.T) { assert.EqualError(t, err, tv.expectedErr) } } + +func Test_RegisterRole(t *testing.T) { + err := RegisterRole(1, "owner") + assert.EqualError(t, err, "role with value 1 already exists") + + err = RegisterRole(3, "maintainer") + assert.EqualError(t, err, `role with name "maintainer" already exists`) + + err = RegisterRole(3, "owner") + assert.NoError(t, err) + + roles := NewRoles().Add(Role(3)) + + out, err := roles.MarshalJSON() + require.NoError(t, err) + assert.Equal(t, `["owner"]`, string(out)) +} diff --git a/comid/svn.go b/comid/svn.go index ed7a26bb..fdf84d94 100644 --- a/comid/svn.go +++ b/comid/svn.go @@ -6,102 +6,254 @@ package comid import ( "encoding/json" "fmt" -) + "strconv" -type TaggedSVN uint64 -type TaggedMinSVN uint64 + "github.com/veraison/corim/encoding" + "github.com/veraison/corim/extensions" +) +// SVN is the Security Version Number. This typically changes only when a +// security relevant change is maded to the measured environment. type SVN struct { - val interface{} + Value ISVNValue } -func (o *SVN) SetSVN(val uint64) *SVN { - if o != nil { - o.val = TaggedSVN(val) +// NewSVN creates a new SVN of the specified type with the specified value. The +// type must be one of the ones defined by the spec ("exact-value", +// "min-value"), or has been registered with RegisterSVNType(). +func NewSVN(val any, typ string) (*SVN, error) { + factory, ok := svnValueRegister[typ] + if !ok { + return nil, fmt.Errorf("unknown SVN type: %s", typ) } - return o + + return factory(val) } -func (o *SVN) SetMinSVN(val uint64) *SVN { - if o != nil { - o.val = TaggedMinSVN(val) +// MustNewSVN is like NewSVN but does not return an error, assuming that the +// provided value is valid. It panics if this is not the case. +func MustNewSVN(val any, typ string) *SVN { + ret, err := NewSVN(val, typ) + if err != nil { + panic(err) } - return o + + return ret } +// MarshalCBOR returns the CBOR encoding of the SVN. func (o SVN) MarshalCBOR() ([]byte, error) { - return em.Marshal(o.val) + return em.Marshal(o.Value) } +// UnmarshalCBOR populates the SVN form the provided CBOR bytes. func (o *SVN) UnmarshalCBOR(data []byte) error { - var svn TaggedSVN + return dm.Unmarshal(data, &o.Value) +} + +// UnmarshalJSON deserializes the supplied JSON object into the target SVN +// The SVN object must have the following shape: +// +// { +// "type": "", +// "value": +// } +// +// where must be one of the known ISVNValue implementation +// type names (available in the base implementation: "exact-value", +// "min-value"), and is the JSON encoding of the underlying +// class id value. The exact encoding is dependent. For both base +// types, it is an integer (JSON number). +func (o *SVN) UnmarshalJSON(data []byte) error { + var tnv encoding.TypeAndValue - if dm.Unmarshal(data, &svn) == nil { - o.val = svn - return nil + if err := json.Unmarshal(data, &tnv); err != nil { + return fmt.Errorf("SVN decoding failure: %w", err) } - var minsvn TaggedMinSVN + decoded, err := NewSVN(nil, tnv.Type) + if err != nil { + return err + } - if dm.Unmarshal(data, &minsvn) == nil { - o.val = svn - return nil + if err := json.Unmarshal(tnv.Value, &decoded.Value); err != nil { + return fmt.Errorf("invalid SVN %s: %w", tnv.Type, err) } - return fmt.Errorf("unknown SVN (CBOR: %x)", data) + if err := decoded.Value.Valid(); err != nil { + return fmt.Errorf("invalid SVN %s: %w", tnv.Type, err) + } + + o.Value = decoded.Value + + return nil +} + +// MarshalJSON serializes the SVN int a JSON object +func (o SVN) MarshalJSON() ([]byte, error) { + return extensions.TypeChoiceValueMarshalJSON(o.Value) } -type svnJSONRepr tnv +// ISVNValue is the interface that must be implemented by all SVN values. +type ISVNValue interface { + extensions.ITypeChoiceValue +} -// Supported formats: -// { "type": "exact-value", "value": 123 } -> SVN -// { "type": "min-value", "value": 123 } -> MinSVN -func (o *SVN) UnmarshalJSON(data []byte) error { - var s svnJSONRepr +const ( + ExactValueType = "exact-value" + MinValueType = "min-value" +) - if err := json.Unmarshal(data, &s); err != nil { - return fmt.Errorf("SVN decoding failure: %w", err) - } +type TaggedSVN uint64 + +func NewTaggedSVN(val any) (*SVN, error) { + var ret TaggedSVN - var x uint64 - if err := json.Unmarshal(s.Value, &x); err != nil { - return fmt.Errorf( - "cannot unmarshal svn or min-svn: %w", - err, - ) + if val == nil { + return &SVN{&ret}, nil } - switch s.Type { - case "exact-value": - o.val = TaggedSVN(x) - case "min-value": - o.val = TaggedMinSVN(x) + switch t := val.(type) { + case string: + u, err := strconv.ParseUint(t, 10, 64) + if err != nil { + return nil, err + } + ret = TaggedSVN(u) + case TaggedSVN: + ret = t + case *TaggedSVN: + ret = *t + case uint64: + ret = TaggedSVN(t) + case uint: + ret = TaggedSVN(t) + case int: + if t < 0 { + return nil, fmt.Errorf("SVN cannot be negative: %d", t) + } + ret = TaggedSVN(t) + case int64: + if t < 0 { + return nil, fmt.Errorf("SVN cannot be negative: %d", t) + } + ret = TaggedSVN(t) default: - return fmt.Errorf("unknown comparison operator %s", s.Type) + return nil, fmt.Errorf("unexpected type for SVN exact-value: %T", t) } + return &SVN{&ret}, nil +} + +func MustNewTaggedSVN(val any) *SVN { + ret, err := NewTaggedSVN(val) + if err != nil { + panic(err) + } + + return ret +} + +func (o TaggedSVN) String() string { + return fmt.Sprint(uint64(o)) +} + +func (o TaggedSVN) Type() string { + return ExactValueType +} + +func (o TaggedSVN) Valid() error { return nil } -func (o SVN) MarshalJSON() ([]byte, error) { - var ( - v svnJSONRepr - b []byte - err error - ) +type TaggedMinSVN uint64 - b, err = json.Marshal(o.val) - if err != nil { - return nil, err +func NewTaggedMinSVN(val any) (*SVN, error) { + var ret TaggedMinSVN + + if val == nil { + return &SVN{&ret}, nil } - switch t := o.val.(type) { - case TaggedSVN: - v = svnJSONRepr{Type: "exact-value", Value: b} + + switch t := val.(type) { + case string: + u, err := strconv.ParseUint(t, 10, 64) + if err != nil { + return nil, err + } + ret = TaggedMinSVN(u) case TaggedMinSVN: - v = svnJSONRepr{Type: "min-value", Value: b} + ret = t + case *TaggedMinSVN: + ret = *t + case uint64: + ret = TaggedMinSVN(t) + case uint: + ret = TaggedMinSVN(t) + case int: + if t < 0 { + return nil, fmt.Errorf("SVN cannot be negative: %d", t) + } + ret = TaggedMinSVN(t) + case int64: + if t < 0 { + return nil, fmt.Errorf("SVN cannot be negative: %d", t) + } + ret = TaggedMinSVN(t) default: - return nil, fmt.Errorf("unknown SVN type: %T", t) + return nil, fmt.Errorf("unexpected type for SVN min-value: %T", t) + } + + return &SVN{&ret}, nil +} + +func MustNewTaggedMinSVN(val any) *SVN { + ret, err := NewTaggedMinSVN(val) + if err != nil { + panic(err) + } + + return ret +} + +func (o TaggedMinSVN) String() string { + return fmt.Sprint(uint64(o)) +} + +func (o TaggedMinSVN) Type() string { + return MinValueType +} + +func (o TaggedMinSVN) Valid() error { + return nil +} + +type ISVNFactory func(any) (*SVN, error) + +var svnValueRegister = map[string]ISVNFactory{ + ExactValueType: NewTaggedSVN, + MinValueType: NewTaggedMinSVN, +} + +// RegisterSVNType registers a new ISVNValue implementation +// (created by the provided ISVNFactory) under the specified CBOR tag. +func RegisterSVNType(tag uint64, factory ISVNFactory) error { + + nilVal, err := factory(nil) + if err != nil { + return err + } + + typ := nilVal.Value.Type() + if _, exists := svnValueRegister[typ]; exists { + return fmt.Errorf("SVN type with name %q already exists", typ) } - return json.Marshal(v) + if err := registerCOMIDTag(tag, nilVal.Value); err != nil { + return err + } + + svnValueRegister[typ] = factory + + return nil } diff --git a/comid/svn_test.go b/comid/svn_test.go new file mode 100644 index 00000000..0ffb8f28 --- /dev/null +++ b/comid/svn_test.go @@ -0,0 +1,182 @@ +package comid + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_NewSVN(t *testing.T) { + for _, tv := range []struct { + Name string + Input any + Expected uint64 + Err string + }{ + { + Name: "string ok", + Input: "7", + Expected: 7, + Err: "", + }, + { + Name: "string err", + Input: "test", + Expected: 0, + Err: `strconv.ParseUint: parsing "test": invalid syntax`, + }, + { + Name: "uint", + Input: uint(7), + Expected: 7, + Err: "", + }, + { + Name: "uint64", + Input: uint64(7), + Expected: 7, + Err: "", + }, + { + Name: "int ok", + Input: 7, + Expected: 7, + Err: "", + }, + { + Name: "int not ok", + Input: -7, + Expected: 0, + Err: "SVN cannot be negative: -7", + }, + { + Name: "int64 ok", + Input: int64(7), + Expected: 7, + Err: "", + }, + { + Name: "int64 not ok", + Input: int64(-7), + Expected: 0, + Err: "SVN cannot be negative: -7", + }, + { + Name: "nil", + Input: nil, + Expected: 0, + Err: "", + }, + } { + t.Run(tv.Name, func(t *testing.T) { + ret, err := NewSVN(tv.Input, "exact-value") + exact := TaggedSVN(tv.Expected) + expected := SVN{&exact} + + if tv.Err != "" { + assert.EqualError(t, err, tv.Err) + } else { + assert.Equal(t, &expected, ret) + } + + retMin, err := NewSVN(tv.Input, "min-value") + min := TaggedMinSVN(tv.Expected) + expected = SVN{&min} + + if tv.Err != "" { + assert.EqualError(t, err, tv.Err) + } else { + assert.Equal(t, &expected, retMin) + } + }) + } + + in := TaggedSVN(7) + + _, err := NewSVN(in, "exact-value") + assert.NoError(t, err) + + _, err = NewSVN(&in, "exact-value") + assert.NoError(t, err) + + _, err = NewSVN(true, "exact-value") + assert.EqualError(t, err, "unexpected type for SVN exact-value: bool") + + inMin := TaggedMinSVN(7) + + _, err = NewSVN(inMin, "min-value") + assert.NoError(t, err) + + _, err = NewSVN(&inMin, "min-value") + assert.NoError(t, err) + + _, err = NewSVN(true, "min-value") + assert.EqualError(t, err, "unexpected type for SVN min-value: bool") + + _, err = NewSVN(true, "test") + assert.EqualError(t, err, "unknown SVN type: test") + + ret := MustNewSVN(7, "exact-value") + assert.NotNil(t, ret) + + assert.Panics(t, func() { MustNewSVN(true, "exact-value") }) +} + +func TestSVN_JSON(t *testing.T) { + var v SVN + + err := v.UnmarshalJSON([]byte(`{"type":"exact-value","value":2.3}`)) + assert.EqualError(t, err, "invalid SVN exact-value: json: cannot unmarshal number 2.3 into Go value of type comid.TaggedSVN") + + err = v.UnmarshalJSON([]byte(`{"type":"test","value":7}`)) + assert.EqualError(t, err, "unknown SVN type: test") + + err = v.UnmarshalJSON([]byte(`@@@`)) + assert.EqualError(t, err, "SVN decoding failure: invalid character '@' looking for beginning of value") + +} + +type testSVN uint64 + +func newTestSVN(val any) (*SVN, error) { + v := testSVN(7) + return &SVN{&v}, nil +} + +func (o testSVN) Type() string { + return "test-value" +} + +func (o testSVN) String() string { + return "test" +} + +func (o testSVN) Valid() error { + return nil +} + +type testSVNBadType struct { + testSVN +} + +func newTestSVNBadType(val any) (*SVN, error) { + v := testSVNBadType{testSVN(7)} + return &SVN{&v}, nil +} + +func (o testSVNBadType) Type() string { + return "min-value" +} + +func Test_RegisterSVNType(t *testing.T) { + err := RegisterSVNType(32, newTestSVN) + assert.EqualError(t, err, "tag 32 is already registered") + + err = RegisterSVNType(99995, newTestSVNBadType) + assert.EqualError(t, err, `SVN type with name "min-value" already exists`) + + err = RegisterSVNType(99995, newTestSVN) + require.NoError(t, err) + +} diff --git a/comid/ueid.go b/comid/ueid.go index 765d1b86..cf64ffa2 100644 --- a/comid/ueid.go +++ b/comid/ueid.go @@ -4,18 +4,17 @@ package comid import ( - "encoding/json" + "encoding/base64" "fmt" "github.com/veraison/eat" ) +const UEIDType = "ueid" + // UEID is an Unique Entity Identifier type UEID eat.UEID -// TaggedUEID is an alias to allow automatic tagging of an UEID type -type TaggedUEID UEID - func (o UEID) Empty() bool { return len(o) == 0 } @@ -28,25 +27,65 @@ func (o UEID) Valid() error { return nil } -// UnmarshalJSON deserializes the supplied string into the UEID target -func (o *UEID) UnmarshalJSON(data []byte) error { - var b []byte +func (o UEID) String() string { + return base64.StdEncoding.EncodeToString(o) +} + +// TaggedUEID is an alias to allow automatic tagging of an UEID type +type TaggedUEID UEID + +func NewTaggedUEID(val any) (*TaggedUEID, error) { + var ret TaggedUEID - if err := json.Unmarshal(data, &b); err != nil { - return err + if val == nil { + return &ret, nil } - u := UEID(b) + switch t := val.(type) { + case string: + b, err := base64.StdEncoding.DecodeString(t) + if err != nil { + return nil, fmt.Errorf("bad UEID: %w", err) + } - if err := u.Valid(); err != nil { - return err + ret = TaggedUEID(b) + case []byte: + ret = TaggedUEID(t) + case TaggedUEID: + ret = append(ret, t...) + case *TaggedUEID: + ret = append(ret, *t...) + case UEID: + ret = append(ret, t...) + case *UEID: + ret = append(ret, *t...) + case eat.UEID: + ret = append(ret, t...) + case *eat.UEID: + ret = append(ret, *t...) + default: + return nil, fmt.Errorf("unexpeted type for UEID: %T", t) } - *o = u + if err := ret.Valid(); err != nil { + return nil, err + } - return nil + return &ret, nil +} + +func (o TaggedUEID) Valid() error { + return UEID(o).Valid() +} + +func (o TaggedUEID) String() string { + return UEID(o).String() +} + +func (o TaggedUEID) Type() string { + return "ueid" } -func (o UEID) MarshalJSON() ([]byte, error) { - return json.Marshal([]byte(o)) +func (o TaggedUEID) Bytes() []byte { + return []byte(o) } diff --git a/comid/ueid_test.go b/comid/ueid_test.go new file mode 100644 index 00000000..a119ad44 --- /dev/null +++ b/comid/ueid_test.go @@ -0,0 +1,30 @@ +package comid + +import ( + "encoding/base64" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_NewTaggedUEID(t *testing.T) { + ueid := UEID(TestUEID) + tagged := TaggedUEID(TestUEID) + bytes := MustHexDecode(t, TestUEIDString) + + for _, v := range []any{ + TestUEID, + &TestUEID, + ueid, + &ueid, + tagged, + &tagged, + bytes, + base64.StdEncoding.EncodeToString(bytes), + } { + ret, err := NewTaggedUEID(v) + require.NoError(t, err) + assert.Equal(t, []byte(TestUEID), ret.Bytes()) + } +} diff --git a/comid/uuid.go b/comid/uuid.go index c5c78d61..68fc630b 100644 --- a/comid/uuid.go +++ b/comid/uuid.go @@ -10,12 +10,11 @@ import ( "github.com/google/uuid" ) +const UUIDType = "uuid" + // UUID represents an Universally Unique Identifier (UUID, see RFC4122) type UUID uuid.UUID -// TaggedUUID is an alias to allow automatic tagging of a UUID type -type TaggedUUID UUID - // ParseUUID parses the supplied string into a UUID func ParseUUID(s string) (UUID, error) { v, err := uuid.Parse(s) @@ -64,3 +63,89 @@ func (o *UUID) UnmarshalJSON(data []byte) error { func (o UUID) MarshalJSON() ([]byte, error) { return json.Marshal(o.String()) } + +// TaggedUUID is an alias to allow automatic tagging of a UUID type +type TaggedUUID UUID + +func NewTaggedUUID(val any) (*TaggedUUID, error) { + var ret TaggedUUID + + switch t := val.(type) { + case string: + u, err := ParseUUID(t) + if err != nil { + return nil, fmt.Errorf("bad UUID: %w", err) + } + ret = TaggedUUID(u) + case []byte: + if len(t) != 16 { + return nil, fmt.Errorf( + "unexpected size for UUID: expected 16 bytes, found %d", + len(t), + ) + } + + copy(ret[:], t) + case TaggedUUID: + copy(ret[:], t[:]) + case *TaggedUUID: + copy(ret[:], (*t)[:]) + case UUID: + copy(ret[:], t[:]) + case *UUID: + copy(ret[:], (*t)[:]) + case uuid.UUID: + copy(ret[:], t[:]) + case *uuid.UUID: + copy(ret[:], (*t)[:]) + default: + return nil, fmt.Errorf("unexpected type for UUID: %T", t) + } + + if err := ret.Valid(); err != nil { + return nil, err + } + + return &ret, nil +} + +// String returns a string representation of the binary UUID +func (o TaggedUUID) String() string { + return UUID(o).String() +} + +func (o TaggedUUID) Valid() error { + return UUID(o).Valid() +} + +// Type returns a string containing type name. This is part of the +// ITypeChoiceValue implementation. +func (o TaggedUUID) Type() string { + return UUIDType +} + +// Bytes returns a []byte containing the raw UUID bytes +func (o TaggedUUID) Bytes() []byte { + return o[:] +} + +func (o TaggedUUID) MarshalJSON() ([]byte, error) { + temp := o.String() + return json.Marshal(temp) +} + +func (o *TaggedUUID) UnmarshalJSON(data []byte) error { + var temp string + if err := json.Unmarshal(data, &temp); err != nil { + return err + } + + u, err := ParseUUID(temp) + if err != nil { + return fmt.Errorf("bad UUID: %w", err) + } + + *o = TaggedUUID(u) + + return nil +} diff --git a/comid/uuid_test.go b/comid/uuid_test.go new file mode 100644 index 00000000..5308854a --- /dev/null +++ b/comid/uuid_test.go @@ -0,0 +1,24 @@ +package comid + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestUUID_JSON(t *testing.T) { + val := TaggedUUID(TestUUID) + expected := fmt.Sprintf(`"%s"`, val.String()) + + out, err := val.MarshalJSON() + require.NoError(t, err) + assert.Equal(t, expected, string(out)) + + var outUUID TaggedUUID + + err = outUUID.UnmarshalJSON(out) + require.NoError(t, err) + assert.Equal(t, val, outUUID) +} diff --git a/corim/cbor.go b/corim/cbor.go index ec15e95f..17464d1d 100644 --- a/corim/cbor.go +++ b/corim/cbor.go @@ -4,6 +4,7 @@ package corim import ( + "fmt" "reflect" cbor "github.com/fxamacker/cbor/v2" @@ -18,13 +19,13 @@ var ( var ( CoswidTag = []byte{0xd9, 0x01, 0xf9} // 505() ComidTag = []byte{0xd9, 0x01, 0xfa} // 506() -) -func corimTags() cbor.TagSet { - corimTagsMap := map[uint64]interface{}{ + corimTagsMap = map[uint64]interface{}{ 32: comid.TaggedURI(""), } +) +func corimTags() cbor.TagSet { opts := cbor.TagOptions{ EncTag: cbor.EncTagRequired, DecTag: cbor.DecTagRequired, @@ -57,6 +58,28 @@ func initCBORDecMode() (dm cbor.DecMode, err error) { return decOpt.DecModeWithTags(corimTags()) } +func registerCORIMTag(tag uint64, t interface{}) error { + if _, exists := corimTagsMap[tag]; exists { + return fmt.Errorf("tag %d is already registered", tag) + } + + corimTagsMap[tag] = t + + var err error + + em, err = initCBOREncMode() + if err != nil { + return err + } + + dm, err = initCBORDecMode() + if err != nil { + return err + } + + return nil +} + func init() { if emError != nil { panic(emError) diff --git a/corim/entity.go b/corim/entity.go index b9288095..ca44201e 100644 --- a/corim/entity.go +++ b/corim/entity.go @@ -4,7 +4,10 @@ package corim import ( + "encoding/json" + "errors" "fmt" + "unicode/utf8" "github.com/veraison/corim/comid" "github.com/veraison/corim/encoding" @@ -13,7 +16,7 @@ import ( // Entity stores an entity-map capable of CBOR and JSON serializations. type Entity struct { - EntityName string `cbor:"0,keyasint" json:"name"` + EntityName *EntityName `cbor:"0,keyasint" json:"name"` RegID *comid.TaggedURI `cbor:"1,keyasint,omitempty" json:"regid,omitempty"` Roles Roles `cbor:"2,keyasint" json:"roles"` @@ -35,12 +38,13 @@ func (o *Entity) GetExtensions() extensions.IExtensionsValue { } // SetEntityName is used to set the EntityName field of Entity using supplied name -func (o *Entity) SetEntityName(name string) *Entity { +func (o *Entity) SetEntityName(name any) *Entity { + if o != nil { if name == "" { return nil } - o.EntityName = name + o.EntityName = MustNewStringEntityName(name) } return o } @@ -74,10 +78,14 @@ func (o *Entity) SetRoles(roles ...Role) *Entity { // Valid checks for validity of the fields within each Entity func (o Entity) Valid() error { - if o.EntityName == "" { + if o.EntityName == nil { return fmt.Errorf("invalid entity: empty entity-name") } + if err := o.EntityName.Valid(); err != nil { + return fmt.Errorf("invalid entity: %w", err) + } + if o.RegID != nil && o.RegID.Empty() { return fmt.Errorf("invalid entity: empty reg-id") } @@ -134,3 +142,225 @@ func (o Entities) Valid() error { } return nil } + +// EntityName encapsulates the name of the associated Entity. The CoRIM +// specification only allows for text (string) name, but this may be extended +// by other specifications. +type EntityName struct { + Value IEntityNameValue +} + +// NewEntityName creates a new EntityName of the specified type using the +// provided value. +func NewEntityName(val any, typ string) (*EntityName, error) { + factory, ok := entityNameValueRegister[typ] + if !ok { + return nil, fmt.Errorf("unexpected entity name type: %s", typ) + } + + return factory(val) +} + +// MustNewEntityName is like NewEntityName, except it doesn't return an error, +// assuming that the provided value is valid. It panics if that isn't the case. +func MustNewEntityName(val any, typ string) *EntityName { + ret, err := NewEntityName(val, typ) + if err != nil { + panic(err) + } + + return ret +} + +// String returns the string representation of the EntityName +func (o EntityName) String() string { + return o.Value.String() +} + +// Valid returns nil if the underlying EntityName value is valid, or an error +// describing the problem otherwise. +func (o EntityName) Valid() error { + if o.Value == nil { + return errors.New("empty entity name") + } + + return o.Value.Valid() +} + +// MarshalCBOR serializes the EntityName into CBOR-encoded bytes. +func (o EntityName) MarshalCBOR() ([]byte, error) { + if err := o.Valid(); err != nil { + return nil, err + } + + return em.Marshal(o.Value) +} + +// UnmarshalCBOR deserializes the EntityName from CBOR-encoded bytes. +func (o *EntityName) UnmarshalCBOR(data []byte) error { + if len(data) == 0 { + return errors.New("empty") + } + + majorType := (data[0] & 0xe0) >> 5 + if majorType == 3 { // text string + var text string + + if err := dm.Unmarshal(data, &text); err != nil { + return err + } + + name := StringEntityName(text) + o.Value = &name + + return nil + } + + return dm.Unmarshal(data, &o.Value) +} + +// MarshalJSON serializes the EntityName into a JSON object. +func (o EntityName) MarshalJSON() ([]byte, error) { + if err := o.Valid(); err != nil { + return nil, err + } + + if o.Value.Type() == extensions.StringType { + return json.Marshal(o.Value.String()) + } + + valueBytes, err := json.Marshal(o.Value) + if err != nil { + return nil, err + } + + value := encoding.TypeAndValue{ + Type: o.Value.Type(), + Value: valueBytes, + } + + return json.Marshal(value) +} + +// UnmarshalJSON deserializes EntityName from the provided JSON object. +func (o *EntityName) UnmarshalJSON(data []byte) error { + var text string + if err := json.Unmarshal(data, &text); err == nil { + *o = *MustNewStringEntityName(text) + return nil + } + + var tnv encoding.TypeAndValue + + if err := json.Unmarshal(data, &tnv); err != nil { + return fmt.Errorf("entity name decoding failure: %w", err) + } + + decoded, err := NewEntityName(nil, tnv.Type) + if err != nil { + return err + } + + if err := json.Unmarshal(tnv.Value, &decoded.Value); err != nil { + return fmt.Errorf( + "cannot unmarshal entity name: %w", + err, + ) + } + + if err := decoded.Value.Valid(); err != nil { + return fmt.Errorf("invalid %s: %w", tnv.Type, err) + } + + o.Value = decoded.Value + + return nil +} + +// IEntityNameValue is the interface implemented by all EntityName value types. +type IEntityNameValue interface { + extensions.ITypeChoiceValue +} + +// StringEntityName is a text string EntityName with no other contraints. This +// is the only EntityName value type defined by the CoRIM specification itself. +type StringEntityName string + +func NewStringEntityName(val any) (*EntityName, error) { + var ret StringEntityName + + if val == nil { + ret = StringEntityName("") + return &EntityName{&ret}, nil + } + + switch t := val.(type) { + case string: + ret = StringEntityName(t) + case []byte: + if !utf8.Valid(t) { + return nil, errors.New("bytes do not form a valid UTF-8 string") + } + + ret = StringEntityName(t) + default: + return nil, fmt.Errorf("unexpected type for string entity name: %T", t) + } + + return &EntityName{&ret}, nil +} + +func MustNewStringEntityName(val any) *EntityName { + ret, err := NewStringEntityName(val) + if err != nil { + panic(err) + } + + return ret +} + +func (o StringEntityName) String() string { + return string(o) +} + +func (o StringEntityName) Type() string { + return extensions.StringType +} + +func (o StringEntityName) Valid() error { + if o == "" { + return errors.New("empty entity-name") + } + + return nil +} + +type IEntityNameFactory func(any) (*EntityName, error) + +var entityNameValueRegister = map[string]IEntityNameFactory{ + extensions.StringType: NewStringEntityName, +} + +// RegisterEntityNameType registers a new IEntityNameValue implementation +// (created by the provided IEntityNameFactory) under the specified type name +// and CBOR tag. +func RegisterEntityNameType(tag uint64, factory IEntityNameFactory) error { + + nilVal, err := factory(nil) + if err != nil { + return err + } + + typ := nilVal.Value.Type() + if _, exists := entityNameValueRegister[typ]; exists { + return fmt.Errorf("entity name type with name %q already exists", typ) + } + + if err := registerCORIMTag(tag, nilVal.Value); err != nil { + return err + } + + entityNameValueRegister[typ] = factory + + return nil +} diff --git a/corim/entity_test.go b/corim/entity_test.go index 457b3770..aec1be92 100644 --- a/corim/entity_test.go +++ b/corim/entity_test.go @@ -4,6 +4,8 @@ package corim import ( + "errors" + "fmt" "testing" "github.com/stretchr/testify/assert" @@ -21,7 +23,7 @@ func TestEntity_Valid_uninitialized(t *testing.T) { func TestEntity_Valid_empty_name(t *testing.T) { tv := Entity{ - EntityName: "", + EntityName: MustNewStringEntityName(""), } err := tv.Valid() @@ -33,7 +35,7 @@ func TestEntity_Valid_non_nil_empty_URI(t *testing.T) { emptyRegID := comid.TaggedURI("") tv := Entity{ - EntityName: "ACME Ltd.", + EntityName: MustNewStringEntityName("ACME Ltd."), RegID: &emptyRegID, } @@ -46,7 +48,7 @@ func TestEntity_Valid_missing_roles(t *testing.T) { regID := comid.TaggedURI("http://acme.example") tv := Entity{ - EntityName: "ACME Ltd.", + EntityName: MustNewStringEntityName("ACME Ltd."), RegID: ®ID, } @@ -59,7 +61,7 @@ func TestEntity_Valid_unknown_role(t *testing.T) { regID := comid.TaggedURI("http://acme.example") tv := Entity{ - EntityName: "ACME Ltd.", + EntityName: MustNewStringEntityName("ACME Ltd."), RegID: ®ID, Roles: Roles{Role(666)}, } @@ -92,3 +94,185 @@ func TestEntities_Valid_empty(t *testing.T) { err := es.Valid() assert.EqualError(t, err, "entity at index 0: invalid entity: empty entity-name") } + +type testEntityName uint64 + +func newTestEntityName(val any) (*EntityName, error) { + if val == nil { + v := testEntityName(0) + return &EntityName{&v}, nil + } + + u, ok := val.(uint64) + if !ok { + return nil, errors.New("must be uint64") + } + + v := testEntityName(u) + return &EntityName{&v}, nil +} + +func (o testEntityName) Type() string { + return "test" +} + +func (o testEntityName) String() string { + return fmt.Sprint(uint64(o)) +} + +func (o testEntityName) Valid() error { + return nil +} + +type testEntityNameBadType struct { + testEntityName +} + +func newTestEntityNameBadType(val any) (*EntityName, error) { + v := testEntityNameBadType{testEntityName(7)} + return &EntityName{&v}, nil +} + +func (o testEntityNameBadType) Type() string { + return "string" +} + +func Test_RegisterEntityNameType(t *testing.T) { + err := RegisterEntityNameType(32, newTestEntityName) + assert.EqualError(t, err, "tag 32 is already registered") + + err = RegisterEntityNameType(99994, newTestEntityNameBadType) + assert.EqualError(t, err, `entity name type with name "string" already exists`) + + registerTestEntityNameType(t) +} + +// Since there only one, untagged, entity name type in the core spec, we use +// the test type define above in order to test the marshaling code works +// properly. Since global environment is not reset when running multiple tests, +// we cannot simply call RegisterEntityNameType() inside each test that relies +// on the test type, as that will cause the "tag already registered" error. On +// the other hand, we do not want to create inter-test dependencies by relying +// that the test registering the type is run before the others that rely on it. +// To get around this, use this global flag to only register the test type if a +// previous test hasn't already done so. +var testEntityNameTypeRegistered = false + +func registerTestEntityNameType(t *testing.T) { + if !testEntityNameTypeRegistered { + err := RegisterEntityNameType(99994, newTestEntityName) + require.NoError(t, err) + + testEntityNameTypeRegistered = true + } +} + +func TestEntityName_CBOR(t *testing.T) { + registerTestEntityNameType(t) + + for _, tv := range []struct { + Value any + Type string + ExpectedBytes []byte + ExpectedString string + }{ + { + Value: "test", + Type: "string", + ExpectedBytes: []byte{ + 0x64, // tstr(4) + 0x74, 0x65, 0x73, 0x74, // "test" + }, + ExpectedString: "test", + }, + { + Value: uint64(7), + Type: "test", + ExpectedBytes: []byte{ + 0xda, 0x0, 0x1, 0x86, 0x9a, // tag 99994 + 0x07, // unsigned int(7) + }, + ExpectedString: "7", + }, + } { + t.Run(tv.Type, func(t *testing.T) { + en, err := NewEntityName(tv.Value, tv.Type) + require.NoError(t, err) + + data, err := en.MarshalCBOR() + require.NoError(t, err) + + assert.Equal(t, tv.ExpectedBytes, data) + + var out EntityName + + err = out.UnmarshalCBOR(data) + require.NoError(t, err) + + assert.Equal(t, tv.ExpectedString, out.String()) + }) + } +} + +func TestEntityName_JSON(t *testing.T) { + registerTestEntityNameType(t) + + for _, tv := range []struct { + Value any + Type string + ExpectedBytes []byte + ExpectedString string + }{ + { + Value: "test", + Type: "string", + ExpectedBytes: []byte(`"test"`), + ExpectedString: "test", + }, + { + Value: uint64(7), + Type: "test", + ExpectedBytes: []byte(`{"type":"test","value":7}`), + ExpectedString: "7", + }, + } { + t.Run(tv.Type, func(t *testing.T) { + en, err := NewEntityName(tv.Value, tv.Type) + require.NoError(t, err) + + data, err := en.MarshalJSON() + require.NoError(t, err) + + assert.Equal(t, tv.ExpectedBytes, data) + + var out EntityName + + err = out.UnmarshalJSON(data) + require.NoError(t, err) + + assert.Equal(t, tv.ExpectedString, out.String()) + }) + } +} + +func Test_NewStringEntityName(t *testing.T) { + out, err := NewStringEntityName(nil) + require.NoError(t, err) + assert.EqualError(t, out.Valid(), "empty entity-name") + + out, err = NewStringEntityName([]byte("test")) + require.NoError(t, err) + assert.Equal(t, "test", out.String()) + + _, err = NewStringEntityName(7) + assert.EqualError(t, err, "unexpected type for string entity name: int") +} + +func Test_MustNewEntityName(t *testing.T) { + out := MustNewEntityName("test", "string") + assert.Equal(t, "test", out.String()) + + assert.Panics(t, func() { + MustNewEntityName(7, "int") + }) +} diff --git a/corim/extensions_test.go b/corim/extensions_test.go index 653a13e1..811d8e4d 100644 --- a/corim/extensions_test.go +++ b/corim/extensions_test.go @@ -17,7 +17,7 @@ type TestExtensions struct { } func (o TestExtensions) ValidEntity(ent *Entity) error { - if ent.EntityName != "Futurama" { + if ent.EntityName.String() != "Futurama" { return errors.New(`EntityName must be "Futurama"`) // nolint:golint } @@ -78,7 +78,7 @@ func TestEntityExtensions_CBOR(t *testing.T) { err := cbor.Unmarshal(data, &ent) assert.NoError(t, err) - assert.Equal(t, ent.EntityName, "acme") + assert.Equal(t, ent.EntityName.String(), "acme") address, err := ent.Get("address") require.NoError(t, err) diff --git a/corim/role.go b/corim/role.go index 9bfc1276..d81d3cf7 100644 --- a/corim/role.go +++ b/corim/role.go @@ -24,6 +24,23 @@ var ( } ) +func RegisterRole(val int64, name string) error { + role := Role(val) + + if _, ok := roleToString[role]; ok { + return fmt.Errorf("role with value %d already exists", val) + } + + if _, ok := stringToRole[name]; ok { + return fmt.Errorf("role with name %q already exists", name) + } + + roleToString[role] = name + stringToRole[name] = role + + return nil +} + type Roles []Role func NewRoles() *Roles { @@ -44,7 +61,8 @@ func (o *Roles) Add(roles ...Role) *Roles { } func isRole(r Role) bool { - return r == RoleManifestCreator + _, ok := roleToString[r] + return ok } // Valid iterates over the range of individual roles to check for validity diff --git a/corim/role_test.go b/corim/role_test.go index 66df92b0..d2d13cd3 100644 --- a/corim/role_test.go +++ b/corim/role_test.go @@ -7,6 +7,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestRoles_ToJSON_ok(t *testing.T) { @@ -77,3 +78,20 @@ func TestRoles_FromJSON_fail(t *testing.T) { assert.EqualError(t, err, tv.expectedErr) } } + +func Test_RegisterRole(t *testing.T) { + err := RegisterRole(1, "owner") + assert.EqualError(t, err, "role with value 1 already exists") + + err = RegisterRole(3, "manifestCreator") + assert.EqualError(t, err, `role with name "manifestCreator" already exists`) + + err = RegisterRole(3, "owner") + assert.NoError(t, err) + + roles := NewRoles().Add(Role(3)) + + out, err := roles.MarshalJSON() + require.NoError(t, err) + assert.Equal(t, `["owner"]`, string(out)) +} diff --git a/corim/unsignedcorim_test.go b/corim/unsignedcorim_test.go index d9b4d105..1fa581b0 100644 --- a/corim/unsignedcorim_test.go +++ b/corim/unsignedcorim_test.go @@ -8,7 +8,6 @@ import ( "testing" "time" - "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/veraison/corim/comid" @@ -182,7 +181,7 @@ func TestUnsignedCorim_Valid_ok(t *testing.T) { AddAttestVerifKey( comid.AttestVerifKey{ Environment: comid.Environment{ - Instance: comid.NewInstanceUUID(uuid.UUID(comid.TestUUID)), + Instance: comid.MustNewUUIDInstance(comid.TestUUID), }, VerifKeys: *comid.NewCryptoKeys(). Add( @@ -264,7 +263,7 @@ func TestUnsignedCorim_AddEntity_full(t *testing.T) { expected := UnsignedCorim{ Entities: &Entities{ Entity{ - EntityName: name, + EntityName: MustNewStringEntityName(name), Roles: Roles{role}, RegID: &taggedRegID, }, diff --git a/encoding/json.go b/encoding/json.go index e8193e4b..f9b24f89 100644 --- a/encoding/json.go +++ b/encoding/json.go @@ -317,3 +317,35 @@ func skipValue(decoder *json.Decoder) error { } return nil } + +// TypeAndValue stores a JSON object with two attributes: a string "type" +// and a generic "value" (string) defined by type. This type is used in +// a few places to implement the choice types that CBOR handles using tags. +type TypeAndValue struct { + Type string `json:"type"` + Value json.RawMessage `json:"value"` +} + +func (o *TypeAndValue) UnmarshalJSON(data []byte) error { + var temp struct { + Type string `json:"type"` + Value json.RawMessage `json:"value"` + } + + if err := json.Unmarshal(data, &temp); err != nil { + return err + } + + if temp.Type == "" { + return errors.New("type not set") + } + + if len(temp.Value) == 0 { + return fmt.Errorf("no value provided for %s", temp.Type) + } + + o.Type = temp.Type + o.Value = temp.Value + + return nil +} diff --git a/encoding/json_test.go b/encoding/json_test.go index 132d1587..85eca9ff 100644 --- a/encoding/json_test.go +++ b/encoding/json_test.go @@ -119,3 +119,34 @@ func Test_skipValue(t *testing.T) { assert.NoError(t, err) assert.Equal(t, "other", token) } + +func TestTypeAndValue_UnmarshalJSON(t *testing.T) { + for _, tv := range []struct { + Input string + Expected TypeAndValue + Err string + }{ + { + Input: `{"type": "test", "value": "test"}`, + Expected: TypeAndValue{Type: "test", Value: []byte(`"test"`)}, + }, + { + Input: `{"type": "test"}`, + Err: "no value provided for test", + }, + { + Input: `{"value": "test"}`, + Err: "type not set", + }, + } { + var out TypeAndValue + err := out.UnmarshalJSON([]byte(tv.Input)) + + if tv.Err != "" { + assert.EqualError(t, err, tv.Err) + } else { + assert.NoError(t, err) + assert.Equal(t, tv.Expected, out) + } + } +} diff --git a/extensions/typechoice.go b/extensions/typechoice.go new file mode 100644 index 00000000..0a06ecb2 --- /dev/null +++ b/extensions/typechoice.go @@ -0,0 +1,33 @@ +package extensions + +import ( + "encoding/json" + + "github.com/veraison/corim/encoding" +) + +var StringType = "string" + +type ITypeChoiceValue interface { + // String returns the string representation of the ITypeChoiceValue. + String() string + // Valid returns an error if validation of the ITypeChoiceValue fails, + // or nil if it succeeds. + Valid() error + // Type returns the type name of this ITypeChoiceValue implementation. + Type() string +} + +func TypeChoiceValueMarshalJSON(v ITypeChoiceValue) ([]byte, error) { + valueBytes, err := json.Marshal(v) + if err != nil { + return nil, err + } + + value := encoding.TypeAndValue{ + Type: v.Type(), + Value: valueBytes, + } + + return json.Marshal(value) +}