From ba4cf94061d2bdaacae859300f2b1ff88b6f5efc Mon Sep 17 00:00:00 2001 From: Sergei Trofimov Date: Wed, 13 Sep 2023 17:48:30 +0100 Subject: [PATCH] WIP: typechoice (also adds tagged-int-type implementation for class-id) Signed-off-by: Sergei Trofimov --- comid/attestverifkey_test.go | 5 +- comid/cbor.go | 31 +- comid/ccaplatformconfigid.go | 75 ++++- comid/ccaplatformconfigid_test.go | 64 ++++ comid/class.go | 32 +- comid/classid.go | 450 ++++++++++++++++---------- comid/classid_test.go | 249 ++++++++++++-- comid/comid.go | 2 +- comid/cryptokey.go | 322 +++++++++++------- comid/cryptokey_test.go | 67 +++- comid/devidentitykey_test.go | 4 +- comid/entity.go | 213 +++++++++++- comid/entity_test.go | 175 ++++++++++ comid/environment_test.go | 28 +- comid/example_cca_refval_test.go | 18 +- comid/example_psa_keys_test.go | 7 +- comid/example_psa_refval_test.go | 14 +- comid/example_test.go | 36 +-- comid/group.go | 2 +- comid/instance.go | 245 +++++++------- comid/instance_test.go | 67 +++- comid/measurement.go | 520 +++++++++++++++++------------- comid/measurement_test.go | 321 +++++++++--------- comid/oid.go | 55 ++++ comid/psareferencevalue.go | 100 +++++- comid/psareferencevalue_test.go | 18 +- comid/referencevalue_test.go | 3 +- comid/role.go | 17 + comid/role_test.go | 17 + comid/svn.go | 251 ++++++++++---- comid/svn_test.go | 182 +++++++++++ comid/ueid.go | 71 +++- comid/ueid_test.go | 30 ++ comid/uuid.go | 91 +++++- comid/uuid_test.go | 24 ++ corim/cbor.go | 29 +- corim/entity.go | 212 +++++++++++- corim/entity_test.go | 183 ++++++++++- corim/extensions_test.go | 4 +- corim/role.go | 20 +- corim/role_test.go | 18 ++ corim/unsignedcorim_test.go | 5 +- encoding/json.go | 39 +++ encoding/json_test.go | 38 +++ extensions/typechoice.go | 13 + 45 files changed, 3346 insertions(+), 1021 deletions(-) create mode 100644 comid/svn_test.go create mode 100644 comid/ueid_test.go create mode 100644 comid/uuid_test.go create mode 100644 encoding/json.go create mode 100644 encoding/json_test.go create mode 100644 extensions/typechoice.go diff --git a/comid/attestverifkey_test.go b/comid/attestverifkey_test.go index 659a0abc..dd321e2d 100644 --- a/comid/attestverifkey_test.go +++ b/comid/attestverifkey_test.go @@ -23,16 +23,17 @@ func TestAttestVerifKey_Valid_empty(t *testing.T) { testerr: "environment validation failed: environment must not be empty", }, { - env: Environment{Instance: NewInstanceUEID(TestUEID)}, + env: Environment{Instance: MustNewInstanceUEID(TestUEID)}, verifkey: CryptoKeys{}, testerr: "verification keys validation failed: no keys to validate", }, { - env: Environment{Instance: NewInstanceUEID(TestUEID)}, + env: Environment{Instance: MustNewInstanceUEID(TestUEID)}, verifkey: CryptoKeys{&invalidKey}, testerr: "verification keys validation failed: invalid key at index 0: key value not set", }, } + for _, tv := range tvs { av := AttestVerifKey{Environment: tv.env, VerifKeys: tv.verifkey} err := av.Valid() diff --git a/comid/cbor.go b/comid/cbor.go index c44d518d..ab8a9233 100644 --- a/comid/cbor.go +++ b/comid/cbor.go @@ -4,6 +4,7 @@ package comid import ( + "fmt" "reflect" cbor "github.com/fxamacker/cbor/v2" @@ -12,16 +13,14 @@ import ( var ( em, emError = initCBOREncMode() dm, dmError = initCBORDecMode() -) -func comidTags() cbor.TagSet { - comidTagsMap := map[uint64]interface{}{ + comidTagsMap = map[uint64]interface{}{ 32: TaggedURI(""), 37: TaggedUUID{}, 111: TaggedOID{}, // CoMID tags 550: TaggedUEID{}, - //551: To Do see: https://github.com/veraison/corim/issues/32 + 551: TaggedInt(0), 552: TaggedSVN(0), 553: TaggedMinSVN(0), 554: TaggedPKIXBase64Key(""), @@ -37,7 +36,9 @@ func comidTags() cbor.TagSet { 601: TaggedPSARefValID{}, 602: TaggedCCAPlatformConfigID(""), } +) +func comidTags() cbor.TagSet { opts := cbor.TagOptions{ EncTag: cbor.EncTagRequired, DecTag: cbor.DecTagRequired, @@ -69,6 +70,28 @@ func initCBORDecMode() (dm cbor.DecMode, err error) { return decOpt.DecModeWithTags(comidTags()) } +func registerCOMIDTag(tag uint64, t interface{}) error { + if _, exists := comidTagsMap[tag]; exists { + return fmt.Errorf("tag %d is already registered", tag) + } + + comidTagsMap[tag] = t + + var err error + + em, err = initCBOREncMode() + if err != nil { + return err + } + + dm, err = initCBORDecMode() + if err != nil { + return err + } + + return nil +} + func init() { if emError != nil { panic(emError) diff --git a/comid/ccaplatformconfigid.go b/comid/ccaplatformconfigid.go index e485083f..a55271ff 100644 --- a/comid/ccaplatformconfigid.go +++ b/comid/ccaplatformconfigid.go @@ -3,11 +3,16 @@ package comid -import "fmt" +import ( + "encoding/json" + "errors" + "fmt" + "unicode/utf8" +) -type CCAPlatformConfigID string +var CCAPlatformConfigIDType = "cca.platform-config-id" -type TaggedCCAPlatformConfigID CCAPlatformConfigID +type CCAPlatformConfigID string func (o CCAPlatformConfigID) Empty() bool { return o == "" @@ -27,3 +32,67 @@ func (o CCAPlatformConfigID) Get() (CCAPlatformConfigID, error) { } return o, nil } + +type TaggedCCAPlatformConfigID CCAPlatformConfigID + +func NewTaggedCCAPlatformConfigID(val any) (*TaggedCCAPlatformConfigID, error) { + var ret TaggedCCAPlatformConfigID + + if val == nil { + return &ret, nil + } + + switch t := val.(type) { + case TaggedCCAPlatformConfigID: + ret = t + case *TaggedCCAPlatformConfigID: + ret = *t + case CCAPlatformConfigID: + ret = TaggedCCAPlatformConfigID(t) + case *CCAPlatformConfigID: + ret = TaggedCCAPlatformConfigID(*t) + case string: + ret = TaggedCCAPlatformConfigID(t) + case []byte: + if !utf8.Valid(t) { + return nil, errors.New("bytes do not form a valid UTF-8 string") + } + ret = TaggedCCAPlatformConfigID(t) + default: + return nil, fmt.Errorf("unexpected type for CCA platform-config-id: %T", t) + } + + return &ret, nil +} + +func (o TaggedCCAPlatformConfigID) Valid() error { + if o == "" { + return errors.New("empty value") + } + + return nil +} + +func (o TaggedCCAPlatformConfigID) String() string { + return string(o) +} + +func (o TaggedCCAPlatformConfigID) Type() string { + return CCAPlatformConfigIDType +} + +func (o TaggedCCAPlatformConfigID) IsZero() bool { + return len(o) == 0 +} + +func (o *TaggedCCAPlatformConfigID) UnmarshalJSON(data []byte) error { + var temp string + + if err := json.Unmarshal(data, &temp); err != nil { + return err + } + + *o = TaggedCCAPlatformConfigID(temp) + + return nil +} diff --git a/comid/ccaplatformconfigid_test.go b/comid/ccaplatformconfigid_test.go index b0a0a37c..543390fd 100644 --- a/comid/ccaplatformconfigid_test.go +++ b/comid/ccaplatformconfigid_test.go @@ -29,3 +29,67 @@ func TestCCAPlatformConfigID_Get_nok(t *testing.T) { _, err := cca.Get() assert.EqualError(t, err, expectedErr) } + +func TestNewTaggedCCAPlatformConfigID(t *testing.T) { + testID := TaggedCCAPlatformConfigID("test") + untagged := CCAPlatformConfigID("test") + + for _, tv := range []struct { + Name string + Input any + Err string + Expected TaggedCCAPlatformConfigID + }{ + { + Name: "TaggedCCAPlatformConfigID ok", + Input: testID, + Expected: testID, + }, + { + Name: "*TaggedCCAPlatformConfigID ok", + Input: &testID, + Expected: testID, + }, + { + Name: "CCAPlatformConfigID ok", + Input: untagged, + Expected: testID, + }, + { + Name: "*CCAPlatformConfigID ok", + Input: &untagged, + Expected: testID, + }, + { + Name: "string ok", + Input: "test", + Expected: testID, + }, + { + Name: "[]byte ok", + Input: []byte{0x74, 0x65, 0x73, 0x74}, + Expected: testID, + }, + { + Name: "[]byte not ok", + Input: []byte{0x80, 0x65, 0x73, 0x74}, + Err: "bytes do not form a valid UTF-8 string", + }, + { + Name: "bad type", + Input: 7, + Err: "unexpected type for CCA platform-config-id: int", + }, + } { + t.Run(tv.Name, func(t *testing.T) { + out, err := NewTaggedCCAPlatformConfigID(tv.Input) + + if tv.Err != "" { + assert.Nil(t, out) + assert.EqualError(t, err, tv.Err) + } else { + assert.Equal(t, tv.Expected, *out) + } + }) + } +} diff --git a/comid/class.go b/comid/class.go index 085e703b..ab414740 100644 --- a/comid/class.go +++ b/comid/class.go @@ -23,39 +23,33 @@ type Class struct { // NewClassUUID instantiates a new Class object with the specified UUID as // identifier func NewClassUUID(uuid UUID) *Class { - c := Class{ - ClassID: &ClassID{}, - } - - if c.ClassID.SetUUID(uuid) == nil { + classID, err := NewUUIDClassID(uuid) + if err != nil { return nil } - return &c + + return &Class{ClassID: classID} } // NewClassImplID instantiates a new Class object that identifies the specified PSA // Implementation ID func NewClassImplID(implID ImplID) *Class { - c := Class{ - ClassID: &ClassID{}, - } - - if c.ClassID.SetImplID(implID) == nil { + classID, err := NewImplIDClassID(implID) + if err != nil { return nil } - return &c + + return &Class{ClassID: classID} } // NewClassOID instantiates a new Class object that identifies the OID func NewClassOID(oid string) *Class { - c := Class{ - ClassID: &ClassID{}, - } - - if c.ClassID.SetOID(oid) == nil { + classID, err := NewOIDClassID(oid) + if err != nil { return nil } - return &c + + return &Class{ClassID: classID} } // SetVendor sets the vendor metadata to the supplied string @@ -131,7 +125,7 @@ func (o *Class) SetIndex(index uint64) *Class { // Valid checks the non-empty<> constraint on the map func (o Class) Valid() error { // check non-empty<{ ... }> - if (o.ClassID == nil || o.ClassID.Unset()) && + if (o.ClassID == nil || !o.ClassID.IsSet()) && o.Vendor == nil && o.Model == nil && o.Layer == nil && o.Index == nil { return fmt.Errorf("class must not be empty") } diff --git a/comid/classid.go b/comid/classid.go index 760f8e5d..0b9c6ccf 100644 --- a/comid/classid.go +++ b/comid/classid.go @@ -5,245 +5,351 @@ package comid import ( "encoding/base64" + "encoding/binary" "encoding/json" + "errors" "fmt" + "strconv" + + "github.com/veraison/corim/encoding" + "github.com/veraison/corim/extensions" ) +const ( + IntType = "int" + ImplIDType = "psa.impl-id" +) + +type IClassIDValue interface { + extensions.ITypeChoiceValue + + Bytes() []byte +} + // ClassID represents a $class-id-type-choice, which can be one of TaggedUUID, // TaggedOID, or TaggedImplID (PSA-specific extension) type ClassID struct { - val interface{} + Value IClassIDValue } -type ClassIDType uint16 +func (o ClassID) Valid() error { + if o.Value == nil { + return errors.New("nil value") + } -const ( - ClassIDTypeUUID = ClassIDType(iota) - ClassIDTypeImplID - ClassIDTypeOID + return o.Value.Valid() +} - ClassIDTypeUnknown = ^ClassIDType(0) -) +// Type returns the type of the ClassID +func (o ClassID) Type() string { + if o.Value == nil { + return "" + } + + return o.Value.Type() +} -// SetUUID sets the value of the targed ClassID to the supplied UUID -func (o *ClassID) SetUUID(uuid UUID) *ClassID { - if o != nil { - o.val = TaggedUUID(uuid) +// Bytes returns a []byte containing the raw bytes of the class id value +func (o ClassID) Bytes() []byte { + if o.Value == nil { + return []byte{} } - return o + return o.Value.Bytes() } -type ImplID [32]byte -type TaggedImplID ImplID +// IsSet returns true iff the underlying class id value has been set (is not nil) +func (o ClassID) IsSet() bool { + return o.Value != nil +} + +// MarshalCBOR serializes the target ClassID to CBOR +func (o ClassID) MarshalCBOR() ([]byte, error) { + return em.Marshal(o.Value) +} -func (o ImplID) MarshalJSON() ([]byte, error) { - return json.Marshal(o[:]) +// UnmarshalCBOR deserializes the supplied CBOR buffer into the target ClassID. +// It is undefined behavior to try and inspect the target ClassID in case this +// method returns an error. +func (o *ClassID) UnmarshalCBOR(data []byte) error { + return dm.Unmarshal(data, &o.Value) } -func (o *ImplID) UnmarshalJSON(data []byte) error { - var b []byte +// UnmarshalJSON deserializes the supplied JSON object into the target ClassID +// The class id object must have the following shape: +// +// { +// "type": "", +// "value": "" +// } +// +// where must be one of the known IClassIDValue implementation +// type names (available in the base implementation: "uuid", "oid", +// "psa.impl-id"), and is the class id value encoded as +// a string. The exact encoding is depenent. For the base +// implmentation types it is +// +// oid: dot-seprated integers, e.g. "1.2.3.4" +// psa.impl-id: base64-encoded bytes, e.g. "YWNtZS1pbXBsZW1lbnRhdGlvbi1pZC0wMDAwMDAwMDE=" +// uuid: standard UUID string representation, e.g. "550e8400-e29b-41d4-a716-446655440000" +func (o *ClassID) UnmarshalJSON(data []byte) error { + var value encoding.TypeAndValue - if err := json.Unmarshal(data, &b); err != nil { - return fmt.Errorf("bad ImplID: %w", err) + if err := json.Unmarshal(data, &value); err != nil { + return err } - if nb := len(b); nb != 32 { - return fmt.Errorf("bad ImplID format: got %d bytes, want 32", nb) + if value.Type == "" { + return errors.New("class id type not set") } - copy(o[:], b) + factory, ok := classIDValueRegister[value.Type] + if !ok { + return fmt.Errorf("unknown class id type: %q", value.Type) + } - return nil + var valueString string + if err := json.Unmarshal(value.Value, &valueString); err != nil { + return err + } + + v, err := factory(valueString) + if err != nil { + return err + } + + o.Value = v.Value + + return o.Valid() } -type TaggedOID OID +// MarshalJSON serializes the target ClassID to JSON +func (o ClassID) MarshalJSON() ([]byte, error) { + valueBytes, err := json.Marshal(o.Value.String()) + if err != nil { + return nil, err + } -// SetImplID sets the value of the targed ClassID to the supplied PSA -// Implementation ID (see Section 3.2.2 of draft-tschofenig-rats-psa-token) -func (o *ClassID) SetImplID(implID ImplID) *ClassID { - if o != nil { - o.val = TaggedImplID(implID) + value := encoding.TypeAndValue{ + Type: o.Value.Type(), + Value: valueBytes, } - return o + + return json.Marshal(value) +} + +// String returns a printable string of the ClassID value. UUIDs use the +// canonical 8-4-4-4-12 format, PSA Implementation IDs are base64 encoded. +// OIDs are output in dotted-decimal notation. +func (o ClassID) String() string { + return o.Value.String() +} + +type ImplID [32]byte + +func (o ImplID) String() string { + return base64.StdEncoding.EncodeToString(o[:]) } -func (o ClassID) GetImplID() (ImplID, error) { - switch t := o.val.(type) { +func (o ImplID) Valid() error { + return nil +} + +type TaggedImplID ImplID + +func NewImplIDClassID(val any) (*ClassID, error) { + var ret TaggedImplID + + if val == nil { + return &ClassID{&TaggedImplID{}}, nil + } + + switch t := val.(type) { + case []byte: + if nb := len(t); nb != 32 { + return nil, fmt.Errorf("bad ImplID: got %d bytes, want 32", nb) + } + + copy(ret[:], t) + case string: + v, err := base64.StdEncoding.DecodeString(t) + if err != nil { + return nil, fmt.Errorf("bad ImplID: %w", err) + } + + if nb := len(v); nb != 32 { + return nil, fmt.Errorf("bad ImplID: decoded %d bytes, want 32", nb) + } + + copy(ret[:], v) case TaggedImplID: - return ImplID(t), nil + copy(ret[:], t[:]) + case *TaggedImplID: + copy(ret[:], (*t)[:]) + case ImplID: + copy(ret[:], t[:]) + case *ImplID: + copy(ret[:], (*t)[:]) default: - return ImplID{}, fmt.Errorf("class-id type is: %T", t) + return nil, fmt.Errorf("unexpected type for ImplID: %T", t) } + + return &ClassID{&ret}, nil } -// SetOID sets the value of the targed ClassID to the supplied OID. -// The OID is a string in dotted-decimal notation -func (o *ClassID) SetOID(s string) *ClassID { - if o != nil { - var berOID OID - if berOID.FromString(s) != nil { - return nil - } - o.val = TaggedOID(berOID) +func MustNewImplIDClassID(val any) *ClassID { + ret, err := NewImplIDClassID(val) + if err != nil { + panic(err) } - return o + + return ret } -// MarshalCBOR serializes the target ClassID to CBOR -func (o ClassID) MarshalCBOR() ([]byte, error) { - return em.Marshal(o.val) +func (o TaggedImplID) Valid() error { + return ImplID(o).Valid() } -// UnmarshalCBOR deserializes the supplied CBOR buffer into the target ClassID. -// It is undefined behavior to try and inspect the target ClassID in case this -// method returns an error. -func (o *ClassID) UnmarshalCBOR(data []byte) error { - var implID TaggedImplID +func (o TaggedImplID) String() string { + return ImplID(o).String() +} - if dm.Unmarshal(data, &implID) == nil { - o.val = implID - return nil - } +func (o TaggedImplID) Type() string { + return ImplIDType +} - var uuid TaggedUUID +func (o TaggedImplID) Bytes() []byte { + return o[:] +} - if dm.Unmarshal(data, &uuid) == nil { - o.val = uuid - return nil +func NewOIDClassID(val any) (*ClassID, error) { + ret, err := NewTaggedOID(val) + if err != nil { + return nil, err } - var oid TaggedOID + return &ClassID{ret}, nil +} - if dm.Unmarshal(data, &oid) == nil { - o.val = oid - return nil +func MustNewOIDClassID(val any) *ClassID { + ret, err := NewOIDClassID(val) + if err != nil { + panic(err) } - return fmt.Errorf("unknown class id (CBOR: %x)", data) + return ret } -// UnmarshalJSON deserializes the supplied JSON object into the target ClassID -// The class id object must have one of the following shapes: -// -// UUID: -// -// { -// "type": "uuid", -// "value": "69E027B2-7157-4758-BCB4-D9F167FE49EA" -// } -// -// OID: -// -// { -// "type": "oid", -// "value": "2.16.840.1.113741.1.15.4.2" -// } -// -// PSA Implementation ID: -// -// { -// "type": "psa.impl-id", -// "value": "YWNtZS1pbXBsZW1lbnRhdGlvbi1pZC0wMDAwMDAwMDE=" -// } -func (o *ClassID) UnmarshalJSON(data []byte) error { - var v tnv +func NewUUIDClassID(val any) (*ClassID, error) { + if val == nil { + return &ClassID{&TaggedUUID{}}, nil + } - if err := json.Unmarshal(data, &v); err != nil { - return err + ret, err := NewTaggedUUID(val) + if err != nil { + return nil, err } - switch v.Type { - case "uuid": // nolint: goconst - var x UUID - if err := x.UnmarshalJSON(v.Value); err != nil { - return err - } - o.val = TaggedUUID(x) - case "oid": - var x OID - if err := x.UnmarshalJSON(v.Value); err != nil { - return err - } - o.val = TaggedOID(x) - case "psa.impl-id": - var x ImplID - if err := x.UnmarshalJSON(v.Value); err != nil { - return err - } - o.val = TaggedImplID(x) - default: - return fmt.Errorf("unknown type '%s' for class id", v.Type) + return &ClassID{ret}, nil +} + +func MustNewUUIDClassID(val any) *ClassID { + ret, err := NewUUIDClassID(val) + if err != nil { + panic(err) } - return nil + return ret } -// MarshalJSON serializes the target ClassID to JSON -func (o ClassID) MarshalJSON() ([]byte, error) { - var ( - v tnv - b []byte - err error - ) - - switch t := o.val.(type) { - case TaggedUUID: - b, err = UUID(t).MarshalJSON() - if err != nil { - return nil, err - } - v = tnv{Type: "uuid", Value: b} - case TaggedOID: - b, err = OID(t).MarshalJSON() +type TaggedInt int + +func NewIntClassID(val any) (*ClassID, error) { + if val == nil { + zeroVal := TaggedInt(0) + return &ClassID{&zeroVal}, nil + } + + var ret TaggedInt + + switch t := val.(type) { + case string: + i, err := strconv.Atoi(t) if err != nil { - return nil, err + return nil, fmt.Errorf("bad int: %w", err) } - v = tnv{Type: "oid", Value: b} - case TaggedImplID: - b, err = ImplID(t).MarshalJSON() - if err != nil { - return nil, err + ret = TaggedInt(i) + case []byte: + if len(t) != 8 { + return nil, fmt.Errorf("bad int: want 8 bytes, got %d bytes", len(t)) } - v = tnv{Type: "psa.impl-id", Value: b} + ret = TaggedInt(binary.BigEndian.Uint64(t)) + case int: + ret = TaggedInt(t) + case *int: + ret = TaggedInt(*t) + case int64: + ret = TaggedInt(t) + case *int64: + ret = TaggedInt(*t) + case uint64: + ret = TaggedInt(t) + case *uint64: + ret = TaggedInt(*t) default: - return nil, fmt.Errorf("unknown type %T for class-id", t) + return nil, fmt.Errorf("unexpected type for int: %T", t) + } + + if err := ret.Valid(); err != nil { + return nil, err } - return json.Marshal(v) + return &ClassID{&ret}, nil } -// Type returns the type of the target ClassID, i.e., one of UUID, OID or PSA -// Implementation ID -func (o ClassID) Type() ClassIDType { - switch o.val.(type) { - case TaggedUUID: - return ClassIDTypeUUID - case TaggedImplID: - return ClassIDTypeImplID - case TaggedOID: - return ClassIDTypeOID - } - return ClassIDTypeUnknown +func (o TaggedInt) String() string { + return fmt.Sprint(int(o)) } -// String returns a printable string of the ClassID value. UUIDs use the -// canonical 8-4-4-4-12 format, PSA Implementation IDs are base64 encoded. -// OIDs are output in dotted-decimal notation. -func (o ClassID) String() string { - switch t := o.val.(type) { - case TaggedUUID: - return UUID(t).String() - case TaggedImplID: - b := [32]byte(t) - return base64.StdEncoding.EncodeToString(b[:]) - case TaggedOID: - return OID(t).String() - default: - return "" - } +func (o TaggedInt) Valid() error { + return nil } -// Unset tests whether the target ClassID has been initialized -func (o ClassID) Unset() bool { - return o.val == nil || o.Type() == ClassIDTypeUnknown +func (o TaggedInt) Type() string { + return "int" +} + +func (o TaggedInt) Bytes() []byte { + var ret [8]byte + binary.BigEndian.PutUint64(ret[:], uint64(o)) + return ret[:] +} + +type IClassIDFactory func(any) (*ClassID, error) + +var classIDValueRegister = map[string]IClassIDFactory{ + OIDType: NewOIDClassID, + UUIDType: NewUUIDClassID, + IntType: NewIntClassID, + + ImplIDType: NewImplIDClassID, +} + +func RegisterClassIDType(tag uint64, factory IClassIDFactory) error { + nilVal, err := factory(nil) + if err != nil { + return err + } + + typ := nilVal.Type() + if _, exists := classIDValueRegister[typ]; exists { + return fmt.Errorf("class ID type with name %q already exists", typ) + } + + if err := registerCOMIDTag(tag, nilVal.Value); err != nil { + return err + } + + classIDValueRegister[typ] = factory + + return nil } diff --git a/comid/classid_test.go b/comid/classid_test.go index 51b07391..9b0a7bf8 100644 --- a/comid/classid_test.go +++ b/comid/classid_test.go @@ -4,6 +4,8 @@ package comid import ( + "encoding/binary" + "encoding/json" "fmt" "testing" @@ -12,9 +14,7 @@ import ( ) func TestClassID_MarshalCBOR_UUID(t *testing.T) { - var tv ClassID - - require.NotNil(t, tv.SetUUID(TestUUID)) + tv := MustNewUUIDClassID(TestUUID) // 37(h'31FB5ABF023E4992AA4E95F9C1503BFA') // tag(37): d8 25 @@ -29,9 +29,7 @@ func TestClassID_MarshalCBOR_UUID(t *testing.T) { } func TestClassID_MarshalCBOR_ImplID(t *testing.T) { - var tv ClassID - - require.NotNil(t, tv.SetImplID(TestImplID)) + tv := MustNewImplIDClassID(TestImplID) // 600 (h'61636D652D696D706C656D656E746174696F6E2D69642D303030303030303031') // tag(600): d9 0258 @@ -66,7 +64,7 @@ func TestClassID_UnmarshalCBOR_UUID_OK(t *testing.T) { err := actual.UnmarshalCBOR(tv) assert.Nil(t, err) - assert.Equal(t, ClassIDTypeUUID, actual.Type()) + assert.Equal(t, "uuid", actual.Type()) assert.Equal(t, TestUUIDString, actual.String()) } @@ -79,7 +77,7 @@ func TestClassID_UnmarshalCBOR_ImplID_OK(t *testing.T) { err := actual.UnmarshalCBOR(tv) assert.Nil(t, err) - assert.Equal(t, ClassIDTypeImplID, actual.Type()) + assert.Equal(t, "psa.impl-id", actual.Type()) assert.Equal(t, expected, actual.String()) } @@ -88,12 +86,10 @@ func TestClassID_UnmarshalCBOR_badInput(t *testing.T) { hex := "582061636d652d696d706c656d656e746174696f6e2d69642d303030303030303031" tv := MustHexDecode(t, hex) - expectedError := fmt.Sprintf("unknown class id (CBOR: %s)", hex) - var actual ClassID err := actual.UnmarshalCBOR(tv) - assert.EqualError(t, err, expectedError) + assert.EqualError(t, err, "cbor: cannot unmarshal byte string into Go value of type comid.IClassIDValue") } func TestClassID_UnmarshalJSON_UUID(t *testing.T) { @@ -105,7 +101,7 @@ func TestClassID_UnmarshalJSON_UUID(t *testing.T) { err := actual.UnmarshalJSON([]byte(tv)) assert.Nil(t, err) - assert.Equal(t, ClassIDTypeUUID, actual.Type()) + assert.Equal(t, "uuid", actual.Type()) assert.Equal(t, TestUUIDString, actual.String()) } @@ -122,7 +118,7 @@ func TestClassID_UnmarshalJSON_ImplID(t *testing.T) { err := actual.UnmarshalJSON([]byte(tv)) assert.Nil(t, err) - assert.Equal(t, ClassIDTypeImplID, actual.Type()) + assert.Equal(t, "psa.impl-id", actual.Type()) // the returned string is the base64 encoding of the stored binary assert.Equal(t, expected, actual.String()) } @@ -133,8 +129,8 @@ func TestClassID_UnmarshalJSON_badInput_unknown_type(t *testing.T) { var actual ClassID err := actual.UnmarshalJSON([]byte(tv)) - assert.EqualError(t, err, "unknown type 'FOOBAR' for class id") - assert.Equal(t, ClassIDTypeUnknown, actual.Type()) + assert.EqualError(t, err, `unknown class id type: "FOOBAR"`) + assert.Equal(t, "", actual.Type()) } func TestClassID_UnmarshalJSON_badInput_missing_value(t *testing.T) { @@ -143,8 +139,8 @@ func TestClassID_UnmarshalJSON_badInput_missing_value(t *testing.T) { var actual ClassID err := actual.UnmarshalJSON([]byte(tv)) - assert.EqualError(t, err, "bad ImplID: unexpected end of JSON input") - assert.Equal(t, ClassIDTypeUnknown, actual.Type()) + assert.EqualError(t, err, "no value provided for psa.impl-id") + assert.Equal(t, "", actual.Type()) } func TestClassID_UnmarshalJSON_badInput_empty_value(t *testing.T) { @@ -153,8 +149,8 @@ func TestClassID_UnmarshalJSON_badInput_empty_value(t *testing.T) { var actual ClassID err := actual.UnmarshalJSON([]byte(tv)) - assert.EqualError(t, err, "bad ImplID format: got 0 bytes, want 32") - assert.Equal(t, ClassIDTypeUnknown, actual.Type()) + assert.EqualError(t, err, "bad ImplID: decoded 0 bytes, want 32") + assert.Equal(t, "", actual.Type()) } func TestClassID_UnmarshalJSON_badInput_badly_encoded_ImplID_value(t *testing.T) { @@ -164,7 +160,7 @@ func TestClassID_UnmarshalJSON_badInput_badly_encoded_ImplID_value(t *testing.T) err := actual.UnmarshalJSON([]byte(tv)) assert.EqualError(t, err, "bad ImplID: illegal base64 data at input byte 0") - assert.Equal(t, ClassIDTypeUnknown, actual.Type()) + assert.Equal(t, "", actual.Type()) } func TestClassID_UnmarshalJSON_badInput_badly_encoded_UUID_value(t *testing.T) { @@ -174,7 +170,7 @@ func TestClassID_UnmarshalJSON_badInput_badly_encoded_UUID_value(t *testing.T) { err := actual.UnmarshalJSON([]byte(tv)) assert.EqualError(t, err, "bad UUID: invalid UUID length: 9") - assert.Equal(t, ClassIDTypeUnknown, actual.Type()) + assert.Equal(t, "", actual.Type()) } func TestClassID_SetOID_ok(t *testing.T) { @@ -200,8 +196,7 @@ func TestClassID_SetOID_ok(t *testing.T) { } for _, tv := range tvs { - c := ClassID{} - assert.NotNil(t, c.SetOID(tv)) + c := MustNewOIDClassID(tv) assert.Equal(t, tv, c.String()) } } @@ -219,7 +214,211 @@ func TestClassID_SetOID_bad(t *testing.T) { } for _, tv := range tvs { - c := ClassID{} - assert.Nil(t, c.SetOID(tv)) + c, err := NewOIDClassID(tv) + assert.NotNil(t, err) + assert.Nil(t, c) + } +} + +func Test_NewImplIDClassID(t *testing.T) { + classID, err := NewImplIDClassID(nil) + expected := [32]byte{} + require.NoError(t, err) + assert.Equal(t, expected[:], classID.Bytes()) + + taggedImplID := TaggedImplID(TestImplID) + + for _, v := range []any{ + TestImplID, + &TestImplID, + taggedImplID, + &taggedImplID, + taggedImplID.Bytes(), + } { + classID, err = NewImplIDClassID(v) + require.NoError(t, err) + assert.Equal(t, taggedImplID.Bytes(), classID.Bytes()) + } + + expected = [32]byte{ + 0x61, 0x63, 0x6d, 0x65, 0x2d, 0x69, 0x6d, 0x70, + 0x6c, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x61, 0x74, + 0x69, 0x6f, 0x6e, 0x2d, 0x69, 0x64, 0x2d, 0x30, + 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x31, } + classID, err = NewImplIDClassID("YWNtZS1pbXBsZW1lbnRhdGlvbi1pZC0wMDAwMDAwMDE=") + require.NoError(t, err) + assert.Equal(t, expected[:], classID.Bytes()) + + _, err = NewImplIDClassID(7) + assert.EqualError(t, err, "unexpected type for ImplID: int") +} + +func Test_NewUUIDClassID(t *testing.T) { + classID, err := NewUUIDClassID(nil) + + expected := [16]byte{} + require.NoError(t, err) + assert.Equal(t, expected[:], classID.Bytes()) + + taggedUUID := TaggedUUID(TestUUID) + + for _, v := range []any{ + TestUUID, + &TestUUID, + taggedUUID, + &taggedUUID, + taggedUUID.Bytes(), + } { + classID, err = NewUUIDClassID(v) + require.NoError(t, err) + assert.Equal(t, taggedUUID.Bytes(), classID.Bytes()) + } + + classID, err = NewUUIDClassID(taggedUUID.String()) + require.NoError(t, err) + assert.Equal(t, taggedUUID.Bytes(), classID.Bytes()) +} + +func Test_NewOIDClassID(t *testing.T) { + classID, err := NewOIDClassID(nil) + + expected := []byte{} + require.NoError(t, err) + assert.Equal(t, expected, classID.Bytes()) + + var oid OID + require.NoError(t, oid.FromString(TestOID)) + taggedOID := TaggedOID(oid) + + for _, v := range []any{ + TestOID, + oid, + &oid, + taggedOID, + &taggedOID, + taggedOID.Bytes(), + } { + classID, err = NewOIDClassID(v) + require.NoError(t, err) + expected := taggedOID.Bytes() + got := classID.Bytes() + assert.Equal(t, expected, got) + } + + classID, err = NewOIDClassID(taggedOID.String()) + require.NoError(t, err) + assert.Equal(t, taggedOID.Bytes(), classID.Bytes()) +} + +func Test_NewIntClassID(t *testing.T) { + classID, err := NewIntClassID(nil) + require.NoError(t, err) + assert.Equal(t, []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, classID.Bytes()) + + testInt := 7 + testInt64 := int64(7) + testUint64 := uint64(7) + + var testBytes [8]byte + binary.BigEndian.PutUint64(testBytes[:], testUint64) + + for _, v := range []any{ + testInt, + &testInt, + testInt64, + &testInt64, + testUint64, + &testUint64, + "7", + testBytes[:], + } { + classID, err = NewIntClassID(v) + require.NoError(t, err) + got := classID.Bytes() + assert.Equal(t, testBytes[:], got) + } +} + +func Test_TaggedInt(t *testing.T) { + val := TaggedInt(7) + assert.Equal(t, "7", val.String()) + assert.Equal(t, []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x07}, val.Bytes()) + assert.Equal(t, "int", val.Type()) + assert.NoError(t, val.Valid()) + + classID := ClassID{&val} + + bytes, err := em.Marshal(classID) + require.NoError(t, err) + assert.Equal(t, []byte{ + 0xd9, 0x02, 0x27, // tag 551 + 0x07, // int 7 + }, bytes) + + var out ClassID + err = dm.Unmarshal(bytes, &out) + require.NoError(t, err) + assert.Equal(t, classID, out) + + jsonBytes, err := json.Marshal(classID) + require.NoError(t, err) + assert.Equal(t, `{"type":"int","value":"7"}`, string(jsonBytes)) + + out = ClassID{} + err = json.Unmarshal(jsonBytes, &out) + require.NoError(t, err) + assert.Equal(t, classID, out) +} + +type testClassID [4]byte + +func newTestClassID(val any) (*ClassID, error) { + return &ClassID{&testClassID{0x74, 0x65, 0x73, 0x74}}, nil +} + +func (o testClassID) Bytes() []byte { + return o[:] +} + +func (o testClassID) Type() string { + return "test-class-id" +} + +func (o testClassID) String() string { + return "test" +} + +func (o testClassID) Valid() error { + return nil +} + +func Test_RegisterClassIDType(t *testing.T) { + err := RegisterClassIDType(99999, newTestClassID) + require.NoError(t, err) + + classID, err := newTestClassID(nil) + require.NoError(t, err) + + data, err := json.Marshal(classID) + require.NoError(t, err) + assert.Equal(t, string(data), `{"type":"test-class-id","value":"test"}`) + + var out ClassID + err = json.Unmarshal(data, &out) + require.NoError(t, err) + assert.Equal(t, classID.Bytes(), out.Bytes()) + + data, err = em.Marshal(classID) + require.NoError(t, err) + assert.Equal(t, data, []byte{ + 0xda, 0x0, 0x1, 0x86, 0x9f, // tag 99999 + 0x44, // bstr(4) + 0x74, 0x65, 0x73, 0x74, // "test" + }) + + var out2 ClassID + err = dm.Unmarshal(data, &out2) + require.NoError(t, err) + assert.Equal(t, classID.Bytes(), out2.Bytes()) } diff --git a/comid/comid.go b/comid/comid.go index 63ae6969..50781985 100644 --- a/comid/comid.go +++ b/comid/comid.go @@ -120,7 +120,7 @@ func (o *Comid) AddEntity(name string, regID *string, roles ...Role) *Comid { } e := Entity{ - EntityName: name, + EntityName: MustNewStringEntityName(name), RegID: uri, Roles: rs, } diff --git a/comid/cryptokey.go b/comid/cryptokey.go index 53d99df8..c2d1879c 100644 --- a/comid/cryptokey.go +++ b/comid/cryptokey.go @@ -14,6 +14,8 @@ import ( "fmt" "github.com/fxamacker/cbor/v2" + "github.com/veraison/corim/encoding" + "github.com/veraison/corim/extensions" "github.com/veraison/go-cose" "github.com/veraison/swid" ) @@ -58,50 +60,12 @@ type CryptoKey struct { // specified crypto key type. For PKIX types, k must be a string. For COSE_Key, // k must be a []byte. For thumbprint types, k must be a swid.HashEntry. func NewCryptoKey(k any, typ string) (*CryptoKey, error) { - switch typ { - case PKIXBase64KeyType: - v, ok := k.(string) - if !ok { - return nil, fmt.Errorf("value must be a string; found %T", k) - } - return NewPKIXBase64Key(v) - case PKIXBase64CertType: - v, ok := k.(string) - if !ok { - return nil, fmt.Errorf("value must be a string; found %T", k) - } - return NewPKIXBase64Cert(v) - case PKIXBase64CertPathType: - v, ok := k.(string) - if !ok { - return nil, fmt.Errorf("value must be a string; found %T", k) - } - return NewPKIXBase64CertPath(v) - case COSEKeyType: - v, ok := k.([]byte) - if !ok { - return nil, fmt.Errorf("value must be a []byte; found %T", k) - } - return NewCOSEKey(v) - case ThumbprintType, CertThumbprintType, CertPathThumbprintType: - v, ok := k.(swid.HashEntry) - if !ok { - return nil, fmt.Errorf("value must be a swid.HashEntry; found %T", k) - } - switch typ { - case ThumbprintType: - return NewThumbprint(v) - case CertThumbprintType: - return NewCertThumbprint(v) - case CertPathThumbprintType: - return NewCertPathThumbprint(v) - default: - // Should never here because of the the outer case clause - panic(fmt.Sprintf("unexpected thumbprint type: %s", typ)) - } - default: + factory, ok := cryptoKeyValueRegister[typ] + if !ok { return nil, fmt.Errorf("unexpected CryptoKey type: %s", typ) } + + return factory(k) } // MustNewCryptoKey is the same as NewCryptoKey, but does not return an error, @@ -126,6 +90,11 @@ func (o CryptoKey) Valid() error { return o.Value.Valid() } +// Type returns the type of the CryptoKey value +func (o CryptoKey) Type() string { + return o.Value.Type() +} + // PublicKey returns a crypto.PublicKey constructed from the CryptoKey's // underlying value. This returns an error if the CryptoKey is one of the // thumbprint types. @@ -136,30 +105,14 @@ func (o CryptoKey) PublicKey() (crypto.PublicKey, error) { // MarshalJSON returns a []byte containing the JSON representation of the // CryptoKey. func (o CryptoKey) MarshalJSON() ([]byte, error) { - value := struct { - Type string `json:"type"` - Value string `json:"value"` - }{ - Value: o.Value.String(), - } - - switch o.Value.(type) { - case TaggedPKIXBase64Key: - value.Type = PKIXBase64KeyType - case TaggedPKIXBase64Cert: - value.Type = PKIXBase64CertType - case TaggedPKIXBase64CertPath: - value.Type = PKIXBase64CertPathType - case TaggedCOSEKey: - value.Type = COSEKeyType - case TaggedThumbprint: - value.Type = ThumbprintType - case TaggedCertThumbprint: - value.Type = CertThumbprintType - case TaggedCertPathThumbprint: - value.Type = CertPathThumbprintType - default: - return nil, fmt.Errorf("unexpected ICryptoKeyValue type: %T", o.Value) + valueBytes, err := json.Marshal(o.Value.String()) + if err != nil { + return nil, err + } + + value := encoding.TypeAndValue{ + Type: o.Value.Type(), + Value: valueBytes, } return json.Marshal(value) @@ -168,10 +121,7 @@ func (o CryptoKey) MarshalJSON() ([]byte, error) { // UnmarshalJSON populates the CryptoKey from the JSON representation inside // the provided []byte. func (o *CryptoKey) UnmarshalJSON(b []byte) error { - var value struct { - Type string `json:"type"` - Value string `json:"value"` - } + var value encoding.TypeAndValue if err := json.Unmarshal(b, &value); err != nil { return err @@ -181,36 +131,23 @@ func (o *CryptoKey) UnmarshalJSON(b []byte) error { return errors.New("key type not set") } - switch value.Type { - case PKIXBase64KeyType: - o.Value = TaggedPKIXBase64Key(value.Value) - case PKIXBase64CertType: - o.Value = TaggedPKIXBase64Cert(value.Value) - case PKIXBase64CertPathType: - o.Value = TaggedPKIXBase64CertPath(value.Value) - case COSEKeyType: - data, err := base64.StdEncoding.DecodeString(value.Value) - if err != nil { - return fmt.Errorf("base64 decode error: %w", err) - } - o.Value = TaggedCOSEKey(data) - case ThumbprintType, CertThumbprintType, CertPathThumbprintType: - he, err := swid.ParseHashEntry(value.Value) - if err != nil { - return fmt.Errorf("swid.HashEntry decode error: %w", err) - } - switch value.Type { - case ThumbprintType: - o.Value = TaggedThumbprint{digest{he}} - case CertThumbprintType: - o.Value = TaggedCertThumbprint{digest{he}} - case CertPathThumbprintType: - o.Value = TaggedCertPathThumbprint{digest{he}} - } - default: + factory, ok := cryptoKeyValueRegister[value.Type] + if !ok { return fmt.Errorf("unexpected ICryptoKeyValue type: %q", value.Type) } + var valueString string + if err := json.Unmarshal(value.Value, &valueString); err != nil { + return err + } + + k, err := factory(valueString) + if err != nil { + return err + } + + o.Value = k.Value + return o.Valid() } @@ -229,11 +166,8 @@ func (o *CryptoKey) UnmarshalCBOR(b []byte) error { // ICryptoKeyValue is the interface implemented by the concrete CryptoKey value // types. type ICryptoKeyValue interface { - // String returns the string representation of the ICryptoKeyValue. - String() string - // Valid returns an error if validation of the ICryptoKeyValue fails, - // or nil if it succeeds. - Valid() error + extensions.ITypeChoiceValue + // PublicKey returns a crypto.PublicKey constructed from the // ICryptoKeyValue's underlying value. This returns an error if the // ICryptoKeyValue is one of the thumbprint types. @@ -244,7 +178,12 @@ type ICryptoKeyValue interface { // https://www.rfc-editor.org/rfc/rfc7468#section-13 type TaggedPKIXBase64Key string -func NewPKIXBase64Key(s string) (*CryptoKey, error) { +func NewPKIXBase64Key(k any) (*CryptoKey, error) { + s, ok := k.(string) + if !ok { + return nil, fmt.Errorf("value must be a string; found %T", k) + } + key := TaggedPKIXBase64Key(s) if err := key.Valid(); err != nil { return nil, err @@ -252,8 +191,8 @@ func NewPKIXBase64Key(s string) (*CryptoKey, error) { return &CryptoKey{key}, nil } -func MustNewPKIXBase64Key(s string) *CryptoKey { - key, err := NewPKIXBase64Key(s) +func MustNewPKIXBase64Key(k any) *CryptoKey { + key, err := NewPKIXBase64Key(k) if err != nil { panic(err) } @@ -269,6 +208,10 @@ func (o TaggedPKIXBase64Key) Valid() error { return err } +func (o TaggedPKIXBase64Key) Type() string { + return PKIXBase64KeyType +} + func (o TaggedPKIXBase64Key) PublicKey() (crypto.PublicKey, error) { if string(o) == "" { return nil, errors.New("key value not set") @@ -302,7 +245,12 @@ func (o TaggedPKIXBase64Key) PublicKey() (crypto.PublicKey, error) { // certificate. See https://www.rfc-editor.org/rfc/rfc7468#section-5 type TaggedPKIXBase64Cert string -func NewPKIXBase64Cert(s string) (*CryptoKey, error) { +func NewPKIXBase64Cert(k any) (*CryptoKey, error) { + s, ok := k.(string) + if !ok { + return nil, fmt.Errorf("value must be a string; found %T", k) + } + cert := TaggedPKIXBase64Cert(s) if err := cert.Valid(); err != nil { return nil, err @@ -310,8 +258,8 @@ func NewPKIXBase64Cert(s string) (*CryptoKey, error) { return &CryptoKey{cert}, nil } -func MustNewPKIXBase64Cert(s string) *CryptoKey { - cert, err := NewPKIXBase64Cert(s) +func MustNewPKIXBase64Cert(k any) *CryptoKey { + cert, err := NewPKIXBase64Cert(k) if err != nil { panic(err) } @@ -327,6 +275,10 @@ func (o TaggedPKIXBase64Cert) Valid() error { return err } +func (o TaggedPKIXBase64Cert) Type() string { + return PKIXBase64CertType +} + func (o TaggedPKIXBase64Cert) PublicKey() (crypto.PublicKey, error) { cert, err := o.cert() if err != nil { @@ -375,7 +327,11 @@ func (o TaggedPKIXBase64Cert) cert() (*x509.Certificate, error) { // directly certifies the one preceding. type TaggedPKIXBase64CertPath string -func NewPKIXBase64CertPath(s string) (*CryptoKey, error) { +func NewPKIXBase64CertPath(k any) (*CryptoKey, error) { + s, ok := k.(string) + if !ok { + return nil, fmt.Errorf("value must be a string; found %T", k) + } cert := TaggedPKIXBase64CertPath(s) if err := cert.Valid(); err != nil { @@ -385,8 +341,8 @@ func NewPKIXBase64CertPath(s string) (*CryptoKey, error) { return &CryptoKey{cert}, nil } -func MustNewPKIXBase64CertPath(s string) *CryptoKey { - cert, err := NewPKIXBase64CertPath(s) +func MustNewPKIXBase64CertPath(k any) *CryptoKey { + cert, err := NewPKIXBase64CertPath(k) if err != nil { panic(err) @@ -404,6 +360,10 @@ func (o TaggedPKIXBase64CertPath) Valid() error { return err } +func (o TaggedPKIXBase64CertPath) Type() string { + return PKIXBase64CertPathType +} + func (o TaggedPKIXBase64CertPath) PublicKey() (crypto.PublicKey, error) { certs, err := o.certPath() if err != nil { @@ -468,7 +428,22 @@ func (o TaggedPKIXBase64CertPath) certPath() ([]*x509.Certificate, error) { // https://www.rfc-editor.org/rfc/rfc9052#section-7 type TaggedCOSEKey []byte -func NewCOSEKey(b []byte) (*CryptoKey, error) { +func NewCOSEKey(k any) (*CryptoKey, error) { + var b []byte + var err error + + switch t := k.(type) { + case []byte: + b = t + case string: + b, err = base64.StdEncoding.DecodeString(t) + if err != nil { + return nil, fmt.Errorf("base64 decode error: %w", err) + } + default: + return nil, fmt.Errorf("value must be a []byte or a string; found %T", k) + } + key := TaggedCOSEKey(b) if err := key.Valid(); err != nil { @@ -478,8 +453,8 @@ func NewCOSEKey(b []byte) (*CryptoKey, error) { return &CryptoKey{key}, nil } -func MustNewCOSEKey(b []byte) *CryptoKey { - key, err := NewCOSEKey(b) +func MustNewCOSEKey(k any) *CryptoKey { + key, err := NewCOSEKey(k) if err != nil { panic(err) @@ -509,6 +484,10 @@ func (o TaggedCOSEKey) Valid() error { return err } +func (o TaggedCOSEKey) Type() string { + return COSEKeyType +} + func (o TaggedCOSEKey) PublicKey() (crypto.PublicKey, error) { if len(o) == 0 { return nil, errors.New("empty COSE_Key value") @@ -608,7 +587,22 @@ type TaggedThumbprint struct { digest } -func NewThumbprint(he swid.HashEntry) (*CryptoKey, error) { +func NewThumbprint(k any) (*CryptoKey, error) { + var he swid.HashEntry + var err error + + switch t := k.(type) { + case string: + he, err = swid.ParseHashEntry(t) + if err != nil { + return nil, fmt.Errorf("swid.HashEntry decode error: %w", err) + } + case swid.HashEntry: + he = t + default: + return nil, fmt.Errorf("value must be a swid.HashEntry or a string; found %T", k) + } + key := &CryptoKey{TaggedThumbprint{digest{he}}} if err := key.Valid(); err != nil { @@ -618,8 +612,8 @@ func NewThumbprint(he swid.HashEntry) (*CryptoKey, error) { return key, nil } -func MustNewThumbprint(he swid.HashEntry) *CryptoKey { - key, err := NewThumbprint(he) +func MustNewThumbprint(k any) *CryptoKey { + key, err := NewThumbprint(k) if err != nil { panic(err) @@ -628,13 +622,32 @@ func MustNewThumbprint(he swid.HashEntry) *CryptoKey { return key } +func (o TaggedThumbprint) Type() string { + return ThumbprintType +} + // TaggedCertThumbprint represents a digest of a certificate. The digest value // may be used to find the certificate if contained in a lookup table. type TaggedCertThumbprint struct { digest } -func NewCertThumbprint(he swid.HashEntry) (*CryptoKey, error) { +func NewCertThumbprint(k any) (*CryptoKey, error) { + var he swid.HashEntry + var err error + + switch t := k.(type) { + case string: + he, err = swid.ParseHashEntry(t) + if err != nil { + return nil, fmt.Errorf("swid.HashEntry decode error: %w", err) + } + case swid.HashEntry: + he = t + default: + return nil, fmt.Errorf("value must be a swid.HashEntry or a string; found %T", k) + } + key := &CryptoKey{TaggedCertThumbprint{digest{he}}} if err := key.Valid(); err != nil { @@ -644,8 +657,8 @@ func NewCertThumbprint(he swid.HashEntry) (*CryptoKey, error) { return key, nil } -func MustNewCertThumbprint(he swid.HashEntry) *CryptoKey { - key, err := NewCertThumbprint(he) +func MustNewCertThumbprint(k any) *CryptoKey { + key, err := NewCertThumbprint(k) if err != nil { panic(err) @@ -654,6 +667,10 @@ func MustNewCertThumbprint(he swid.HashEntry) *CryptoKey { return key } +func (o TaggedCertThumbprint) Type() string { + return CertThumbprintType +} + // TaggedCertPathThumbprint represents a digest of a certification path. The // digest value may be used to find the certificate path if contained in a // lookup table. @@ -661,7 +678,22 @@ type TaggedCertPathThumbprint struct { digest } -func NewCertPathThumbprint(he swid.HashEntry) (*CryptoKey, error) { +func NewCertPathThumbprint(k any) (*CryptoKey, error) { + var he swid.HashEntry + var err error + + switch t := k.(type) { + case string: + he, err = swid.ParseHashEntry(t) + if err != nil { + return nil, fmt.Errorf("swid.HashEntry decode error: %w", err) + } + case swid.HashEntry: + he = t + default: + return nil, fmt.Errorf("value must be a swid.HashEntry or a string; found %T", k) + } + key := &CryptoKey{TaggedCertPathThumbprint{digest{he}}} if err := key.Valid(); err != nil { @@ -671,8 +703,8 @@ func NewCertPathThumbprint(he swid.HashEntry) (*CryptoKey, error) { return key, nil } -func MustNewCertPathThumbprint(he swid.HashEntry) *CryptoKey { - key, err := NewCertPathThumbprint(he) +func MustNewCertPathThumbprint(k any) *CryptoKey { + key, err := NewCertPathThumbprint(k) if err != nil { panic(err) @@ -680,3 +712,47 @@ func MustNewCertPathThumbprint(he swid.HashEntry) *CryptoKey { return key } + +func (o TaggedCertPathThumbprint) Type() string { + return CertPathThumbprintType +} + +// ICryptoKeyFactory creates a *CryptoKey from the specified any value. When +// passed nil as input, this must return the equivaluent nil-value for the +// associated ICryptoKeyValue implementation. +type ICryptoKeyFactory func(any) (*CryptoKey, error) + +var cryptoKeyValueRegister = map[string]ICryptoKeyFactory{ + // types defined by the core spec + PKIXBase64KeyType: NewPKIXBase64Key, + PKIXBase64CertType: NewPKIXBase64Cert, + PKIXBase64CertPathType: NewPKIXBase64CertPath, + COSEKeyType: NewCOSEKey, + ThumbprintType: NewThumbprint, + CertThumbprintType: NewCertThumbprint, + CertPathThumbprintType: NewCertPathThumbprint, +} + +// RegisterCryptoKeyType registers a new ICryptoKeyValue implementation +// (created by the provided ICryptoKeyFactory) under the specified type name +// and CBOR tag. +func RegisterCryptoKeyType(tag uint64, factory ICryptoKeyFactory) error { + + nilVal, err := factory(nil) + if err != nil { + return err + } + + typ := nilVal.Type() + if _, exists := cryptoKeyValueRegister[typ]; exists { + return fmt.Errorf("crypto key type with name %q already exists", typ) + } + + if err := registerCOMIDTag(tag, nilVal.Value); err != nil { + return err + } + + cryptoKeyValueRegister[typ] = factory + + return nil +} diff --git a/comid/cryptokey_test.go b/comid/cryptokey_test.go index 1c0812fc..40c5af9f 100644 --- a/comid/cryptokey_test.go +++ b/comid/cryptokey_test.go @@ -4,6 +4,7 @@ package comid import ( + "crypto" "encoding/base64" "encoding/json" "fmt" @@ -149,7 +150,7 @@ func Test_CryptoKey_NewCOSEKey(t *testing.T) { } func Test_CryptoKey_NewThumbprint(t *testing.T) { - type newKeyFunc func(swid.HashEntry) (*CryptoKey, error) + type newKeyFunc func(any) (*CryptoKey, error) for _, newFunc := range []newKeyFunc{ NewThumbprint, @@ -177,7 +178,7 @@ func Test_CryptoKey_NewThumbprint(t *testing.T) { assert.Contains(t, err.Error(), "length mismatch for hash algorithm") } - type mustNewKeyFunc func(swid.HashEntry) *CryptoKey + type mustNewKeyFunc func(any) *CryptoKey for _, mustNewFunc := range []mustNewKeyFunc{ MustNewThumbprint, @@ -271,7 +272,7 @@ func Test_CryptoKey_UnmarshalJSON_negative(t *testing.T) { }, { Val: `{"value":"deadbeef"}`, - ErrMsg: "key type not set", + ErrMsg: "type not set", }, { Val: `{"type": "cose-key", "value":";;;"}`, @@ -371,22 +372,22 @@ func Test_NewCryptoKey_negative(t *testing.T) { { Type: COSEKeyType, In: 7, - ErrMsg: "value must be a []byte; found int", + ErrMsg: "value must be a []byte or a string; found int", }, { Type: ThumbprintType, In: 7, - ErrMsg: "value must be a swid.HashEntry; found int", + ErrMsg: "value must be a swid.HashEntry or a string; found int", }, { Type: CertThumbprintType, In: 7, - ErrMsg: "value must be a swid.HashEntry; found int", + ErrMsg: "value must be a swid.HashEntry or a string; found int", }, { Type: CertPathThumbprintType, In: 7, - ErrMsg: "value must be a swid.HashEntry; found int", + ErrMsg: "value must be a swid.HashEntry or a string; found int", }, { Type: "random-key", @@ -399,3 +400,55 @@ func Test_NewCryptoKey_negative(t *testing.T) { assert.ErrorContains(t, err, tv.ErrMsg) } } + +type testCryptoKey [4]byte + +func newTestCryptoKey(val any) (*CryptoKey, error) { + return &CryptoKey{&testCryptoKey{0x74, 0x64, 0x73, 0x74}}, nil +} + +func (o testCryptoKey) PublicKey() (crypto.PublicKey, error) { + return crypto.PublicKey(o[:]), nil +} + +func (o testCryptoKey) Type() string { + return "test-crypto-key" +} + +func (o testCryptoKey) String() string { + return "test" +} + +func (o testCryptoKey) Valid() error { + return nil +} + +func Test_RegisterCryptoKeyTypeType(t *testing.T) { + err := RegisterCryptoKeyType(99998, newTestCryptoKey) + require.NoError(t, err) + + key, err := newTestCryptoKey(nil) + require.NoError(t, err) + + data, err := json.Marshal(key) + require.NoError(t, err) + assert.Equal(t, string(data), `{"type":"test-crypto-key","value":"test"}`) + + var out CryptoKey + err = json.Unmarshal(data, &out) + require.NoError(t, err) + assert.EqualValues(t, key, &out) + + data, err = em.Marshal(key) + require.NoError(t, err) + assert.Equal(t, data, []byte{ + 0xda, 0x0, 0x1, 0x86, 0x9e, // tag 99998 + 0x44, // bstr(4) + 0x74, 0x64, 0x73, 0x74, // "test" + }) + + var out2 CryptoKey + err = dm.Unmarshal(data, &out2) + require.NoError(t, err) + assert.Equal(t, key, &out2) +} diff --git a/comid/devidentitykey_test.go b/comid/devidentitykey_test.go index 14f2f6e1..be618c8c 100644 --- a/comid/devidentitykey_test.go +++ b/comid/devidentitykey_test.go @@ -23,12 +23,12 @@ func TestDevIdentityKey_Valid_empty(t *testing.T) { testerr: "environment validation failed: environment must not be empty", }, { - env: Environment{Instance: NewInstanceUEID(TestUEID)}, + env: Environment{Instance: MustNewInstanceUEID(TestUEID)}, verifkey: CryptoKeys{}, testerr: "verification keys validation failed: no keys to validate", }, { - env: Environment{Instance: NewInstanceUEID(TestUEID)}, + env: Environment{Instance: MustNewInstanceUEID(TestUEID)}, verifkey: CryptoKeys{&invalidKey}, testerr: "verification keys validation failed: invalid key at index 0: key value not set", }, diff --git a/comid/entity.go b/comid/entity.go index 8c475cee..b52b04ca 100644 --- a/comid/entity.go +++ b/comid/entity.go @@ -4,12 +4,211 @@ package comid import ( + "encoding/json" + "errors" "fmt" + "unicode/utf8" "github.com/veraison/corim/encoding" "github.com/veraison/corim/extensions" ) +type EntityName struct { + Value IEntityName +} + +func NewEntityName(val any, typ string) (*EntityName, error) { + factory, ok := entityNameValueRegister[typ] + if !ok { + return nil, fmt.Errorf("unexpected entity name type: %s", typ) + } + + return factory(val) +} + +func (o EntityName) String() string { + return o.Value.String() +} + +func (o EntityName) Valid() error { + if o.Value == nil { + return errors.New("empty entity name") + } + + return o.Value.Valid() +} + +func (o EntityName) MarshalCBOR() ([]byte, error) { + if err := o.Valid(); err != nil { + return nil, err + } + + return em.Marshal(o.Value) +} + +func (o *EntityName) UnmarshalCBOR(data []byte) error { + if len(data) == 0 { + return errors.New("empty") + } + + majorType := (data[0] & 0xe0) >> 5 + if majorType == 3 { // text string + var text string + + if err := dm.Unmarshal(data, &text); err != nil { + return err + } + + name := StringEntityName(text) + o.Value = &name + + return nil + } + + return dm.Unmarshal(data, &o.Value) +} + +func (o EntityName) MarshalJSON() ([]byte, error) { + if err := o.Valid(); err != nil { + return nil, err + } + + if o.Value.Type() == extensions.StringType { + return json.Marshal(o.Value.String()) + } + + valueBytes, err := json.Marshal(o.Value) + if err != nil { + return nil, err + } + + value := encoding.TypeAndValue{ + Type: o.Value.Type(), + Value: valueBytes, + } + + return json.Marshal(value) +} + +func (o *EntityName) UnmarshalJSON(data []byte) error { + var text string + if err := json.Unmarshal(data, &text); err == nil { + *o = *MustNewStringEntityName(text) + return nil + } + + var tnv encoding.TypeAndValue + + if err := json.Unmarshal(data, &tnv); err != nil { + return fmt.Errorf("entity name decoding failure: %w", err) + } + + decoded, err := NewEntityName(nil, tnv.Type) + if err != nil { + return err + } + + if err := json.Unmarshal(tnv.Value, &decoded.Value); err != nil { + return fmt.Errorf( + "cannot unmarshal entity name: %w", + err, + ) + } + + if err := decoded.Value.Valid(); err != nil { + return fmt.Errorf("invalid %s: %w", tnv.Type, err) + } + + o.Value = decoded.Value + + return nil +} + +type IEntityName interface { + extensions.ITypeChoiceValue +} + +type IEntityNameFactory func(any) (*EntityName, error) + +var entityNameValueRegister = map[string]IEntityNameFactory{ + extensions.StringType: NewStringEntityName, +} + +// RegisterEntityNameType registers a new IEntityNameValue implementation +// (created by the provided IEntityNameFactory) under the specified type name +// and CBOR tag. +func RegisterEntityNameType(tag uint64, factory IEntityNameFactory) error { + + nilVal, err := factory(nil) + if err != nil { + return err + } + + typ := nilVal.Value.Type() + if _, exists := entityNameValueRegister[typ]; exists { + return fmt.Errorf("entity name type with name %q already exists", typ) + } + + if err := registerCOMIDTag(tag, nilVal.Value); err != nil { + return err + } + + entityNameValueRegister[typ] = factory + + return nil +} + +type StringEntityName string + +func NewStringEntityName(val any) (*EntityName, error) { + var ret StringEntityName + + if val == nil { + ret = StringEntityName("") + return &EntityName{&ret}, nil + } + + switch t := val.(type) { + case string: + ret = StringEntityName(t) + case []byte: + if !utf8.Valid(t) { + return nil, errors.New("bytes do not form a valid UTF-8 string") + } + + ret = StringEntityName(t) + default: + return nil, fmt.Errorf("unexpected type for string entity name: %T", t) + } + + return &EntityName{&ret}, nil +} + +func MustNewStringEntityName(val any) *EntityName { + ret, err := NewStringEntityName(val) + if err != nil { + panic(err) + } + + return ret +} + +func (o StringEntityName) String() string { + return string(o) +} + +func (o StringEntityName) Type() string { + return extensions.StringType +} + +func (o StringEntityName) Valid() error { + if o == "" { + return errors.New("empty entity-name") + } + + return nil +} + type TaggedURI string func (o TaggedURI) Empty() bool { @@ -18,9 +217,9 @@ func (o TaggedURI) Empty() bool { // Entity stores an entity-map capable of CBOR and JSON serializations. type Entity struct { - EntityName string `cbor:"0,keyasint" json:"name"` - RegID *TaggedURI `cbor:"1,keyasint,omitempty" json:"regid,omitempty"` - Roles Roles `cbor:"2,keyasint" json:"roles"` + EntityName *EntityName `cbor:"0,keyasint" json:"name"` + RegID *TaggedURI `cbor:"1,keyasint,omitempty" json:"regid,omitempty"` + Roles Roles `cbor:"2,keyasint" json:"roles"` Extensions } @@ -41,7 +240,7 @@ func (o *Entity) SetEntityName(name string) *Entity { if name == "" { return nil } - o.EntityName = name + o.EntityName = MustNewStringEntityName(name) } return o } @@ -68,10 +267,14 @@ func (o *Entity) SetRoles(roles ...Role) *Entity { // Valid checks for validity of the fields within each Entity func (o Entity) Valid() error { - if o.EntityName == "" { + if o.EntityName == nil { return fmt.Errorf("invalid entity: empty entity-name") } + if err := o.EntityName.Valid(); err != nil { + return fmt.Errorf("invalid entity: %w", err) + } + if o.RegID != nil && o.RegID.Empty() { return fmt.Errorf("invalid entity: empty reg-id") } diff --git a/comid/entity_test.go b/comid/entity_test.go index 61c0d1b3..4d3eea45 100644 --- a/comid/entity_test.go +++ b/comid/entity_test.go @@ -4,6 +4,8 @@ package comid import ( + "errors" + "fmt" "testing" "github.com/stretchr/testify/assert" @@ -83,3 +85,176 @@ func TestEntity_SetRegID_empty(t *testing.T) { assert.Nil(t, e.SetRegID("")) } + +type testEntityName uint64 + +func newTestEntityName(val any) (*EntityName, error) { + if val == nil { + v := testEntityName(0) + return &EntityName{&v}, nil + } + + u, ok := val.(uint64) + if !ok { + return nil, errors.New("must be uint64") + } + + v := testEntityName(u) + return &EntityName{&v}, nil +} + +func (o testEntityName) Type() string { + return "test" +} + +func (o testEntityName) String() string { + return fmt.Sprint(uint64(o)) +} + +func (o testEntityName) Valid() error { + return nil +} + +type testEntityNameBadType struct { + testEntityName +} + +func newTestEntityNameBadType(val any) (*EntityName, error) { + v := testEntityNameBadType{testEntityName(7)} + return &EntityName{&v}, nil +} + +func (o testEntityNameBadType) Type() string { + return "string" +} + +func Test_RegisterEntityNameType(t *testing.T) { + err := RegisterEntityNameType(32, newTestEntityName) + assert.EqualError(t, err, "tag 32 is already registered") + + err = RegisterEntityNameType(99994, newTestEntityNameBadType) + assert.EqualError(t, err, `entity name type with name "string" already exists`) + + registerTestEntityNameType(t) +} + +// Since there only one, untagged, entity name type in the core spec, we use +// the test type define above in order to test the marshalling code works +// properly. Since global enviroment is not reset when running multiple tests, +// we cannot simply call RegisterEntityNameType() inside each test that relies +// on the test type, as that will cause the "tag already registred" error. On +// the other hand, we do not want to create inter-test dependencies by relying +// that the test registering the type is run before the others that rely on it. +// To get around this, use this global flag to only register the test type if a +// previous test hasn't already done so. +var testEntityNameTypeRegistered = false + +func registerTestEntityNameType(t *testing.T) { + if !testEntityNameTypeRegistered { + err := RegisterEntityNameType(99994, newTestEntityName) + require.NoError(t, err) + + testEntityNameTypeRegistered = true + } +} + +func TestEntityName_CBOR(t *testing.T) { + registerTestEntityNameType(t) + + for _, tv := range []struct { + Value any + Type string + ExpectedBytes []byte + ExpectedString string + }{ + { + Value: "test", + Type: "string", + ExpectedBytes: []byte{ + 0x64, // tstr(4) + 0x74, 0x65, 0x73, 0x74, // "test" + }, + ExpectedString: "test", + }, + { + Value: uint64(7), + Type: "test", + ExpectedBytes: []byte{ + 0xda, 0x0, 0x1, 0x86, 0x9a, // tag 99994 + 0x07, // unsigned int(7) + }, + ExpectedString: "7", + }, + } { + t.Run(tv.Type, func(t *testing.T) { + en, err := NewEntityName(tv.Value, tv.Type) + require.NoError(t, err) + + data, err := en.MarshalCBOR() + require.NoError(t, err) + + assert.Equal(t, tv.ExpectedBytes, data) + + var out EntityName + + err = out.UnmarshalCBOR(data) + require.NoError(t, err) + + assert.Equal(t, tv.ExpectedString, out.String()) + }) + } +} + +func TestEntityName_JSON(t *testing.T) { + registerTestEntityNameType(t) + + for _, tv := range []struct { + Value any + Type string + ExpectedBytes []byte + ExpectedString string + }{ + { + Value: "test", + Type: "string", + ExpectedBytes: []byte(`"test"`), + ExpectedString: "test", + }, + { + Value: uint64(7), + Type: "test", + ExpectedBytes: []byte(`{"type":"test","value":7}`), + ExpectedString: "7", + }, + } { + t.Run(tv.Type, func(t *testing.T) { + en, err := NewEntityName(tv.Value, tv.Type) + require.NoError(t, err) + + data, err := en.MarshalJSON() + require.NoError(t, err) + + assert.Equal(t, tv.ExpectedBytes, data) + + var out EntityName + + err = out.UnmarshalJSON(data) + require.NoError(t, err) + + assert.Equal(t, tv.ExpectedString, out.String()) + }) + } +} + +func Test_NewStringEntityName(t *testing.T) { + out, err := NewStringEntityName(nil) + require.NoError(t, err) + assert.EqualError(t, out.Valid(), "empty entity-name") + + out, err = NewStringEntityName([]byte("test")) + require.NoError(t, err) + assert.Equal(t, "test", out.String()) + + _, err = NewStringEntityName(7) + assert.EqualError(t, err, "unexpected type for string entity name: int") +} diff --git a/comid/environment_test.go b/comid/environment_test.go index c078ce90..99761acf 100644 --- a/comid/environment_test.go +++ b/comid/environment_test.go @@ -78,7 +78,7 @@ func TestEnvironment_ToCBOR_class_only(t *testing.T) { func TestEnvironment_ToCBOR_class_and_instance(t *testing.T) { tv := Environment{ Class: NewClassUUID(TestUUID), - Instance: NewInstanceUEID(TestUEID), + Instance: MustNewInstanceUEID(TestUEID), } require.NotNil(t, tv.Class) require.NotNil(t, tv.Instance) @@ -96,7 +96,7 @@ func TestEnvironment_ToCBOR_class_and_instance(t *testing.T) { func TestEnvironment_ToCBOR_instance_only(t *testing.T) { tv := Environment{ - Instance: NewInstanceUEID(TestUEID), + Instance: MustNewInstanceUEID(TestUEID), } require.NotNil(t, tv.Instance) @@ -180,7 +180,7 @@ func TestEnvironment_FromCBOR_class_and_instance(t *testing.T) { assert.NotNil(t, actual.Class) assert.Equal(t, TestUUIDString, actual.Class.ClassID.String()) assert.NotNil(t, actual.Instance) - assert.Equal(t, TestUEIDString, actual.Instance.String()) + assert.Equal(t, []byte(TestUEID), actual.Instance.Bytes()) assert.Nil(t, actual.Group) } @@ -197,3 +197,25 @@ func TestEnvironment_FromCBOR_group_only(t *testing.T) { assert.NotNil(t, actual.Group) assert.Equal(t, TestUUIDString, actual.Group.String()) } + +func TestEnviroment_JSON(t *testing.T) { + testEnv := Environment{ + Class: NewClassUUID(TestUUID), + } + + out, err := testEnv.ToJSON() + require.NoError(t, err) + assert.Equal(t, `{"class":{"id":{"type":"uuid","value":"31fb5abf-023e-4992-aa4e-95f9c1503bfa"}}}`, string(out)) + + var outEnv Environment + + err = outEnv.FromJSON(out) + require.NoError(t, err) + assert.Equal(t, testEnv, outEnv) + + _, err = Environment{}.ToJSON() + assert.EqualError(t, err, "environment must not be empty") + + err = outEnv.FromJSON([]byte(`{"class": 7}`)) + assert.EqualError(t, err, "json: cannot unmarshal number into Go struct field Environment.class of type comid.Class") +} diff --git a/comid/example_cca_refval_test.go b/comid/example_cca_refval_test.go index a9d1102b..9b52304c 100644 --- a/comid/example_cca_refval_test.go +++ b/comid/example_cca_refval_test.go @@ -66,21 +66,23 @@ func extractCCARefVal(rv ReferenceValue) error { if !m.Key.IsSet() { return fmt.Errorf("mKey not set at index %d", i) } - if m.Key.IsPSARefValID() { + + switch t := m.Key.Value.(type) { + case *TaggedPSARefValID: if err := extractSwMeasurement(m); err != nil { return fmt.Errorf("extracting measurement at index %d: %w", i, err) } - } - if m.Key.IsCCAPlatformConfigID() { + case *TaggedCCAPlatformConfigID: if err := extractCCARefValID(m.Key); err != nil { return fmt.Errorf("extracting cca-refval-id: %w", err) } if err := extractRawValue(m.Val.RawValue); err != nil { return fmt.Errorf("extracting raw vlue: %w", err) } - - return nil + default: + return fmt.Errorf("unexpected Mkey type: %T", t) } + } return nil @@ -105,9 +107,9 @@ func extractCCARefValID(k *Mkey) error { return fmt.Errorf("no measurement key") } - id, err := k.GetCCAPlatformConfigID() - if err != nil { - return fmt.Errorf("getting CCA platform config id: %w", err) + id, ok := k.Value.(*TaggedCCAPlatformConfigID) + if !ok { + return fmt.Errorf("expected CCA platform config id, found: %T", k.Value) } fmt.Printf("Label: %s\n", id) return nil diff --git a/comid/example_psa_keys_test.go b/comid/example_psa_keys_test.go index edf95989..3780cbbd 100644 --- a/comid/example_psa_keys_test.go +++ b/comid/example_psa_keys_test.go @@ -70,12 +70,7 @@ func extractInstanceID(i *Instance) error { return fmt.Errorf("no instance") } - instID, err := i.GetUEID() - if err != nil { - return fmt.Errorf("extracting implemenetation-id: %w", err) - } - - fmt.Printf("InstanceID: %x\n", instID) + fmt.Printf("InstanceID: %x\n", i.Bytes()) return nil } diff --git a/comid/example_psa_refval_test.go b/comid/example_psa_refval_test.go index 230209d0..18817d91 100644 --- a/comid/example_psa_refval_test.go +++ b/comid/example_psa_refval_test.go @@ -111,9 +111,10 @@ func extractPSARefValID(k *Mkey) error { return fmt.Errorf("no measurement key") } - id, err := k.GetPSARefValID() - if err != nil { - return fmt.Errorf("getting PSA refval id: %w", err) + id, ok := k.Value.(*TaggedPSARefValID) + + if !ok { + return fmt.Errorf("expected PSA refval id, found: %T", k.Value) } fmt.Printf("SignerID: %x\n", id.SignerID) @@ -142,12 +143,11 @@ func extractImplementationID(c *Class) error { return fmt.Errorf("no class-id") } - implID, err := classID.GetImplID() - if err != nil { - return fmt.Errorf("extracting implemenetation-id: %w", err) + if classID.Type() != ImplIDType { + return fmt.Errorf("class id is not a psa.impl-id") } - fmt.Printf("ImplementationID: %x\n", implID) + fmt.Printf("ImplementationID: %x\n", classID.Bytes()) return nil } diff --git a/comid/example_test.go b/comid/example_test.go index 756829d3..5bab7d06 100644 --- a/comid/example_test.go +++ b/comid/example_test.go @@ -26,13 +26,12 @@ func Example_encode() { SetModel("RoadRunner"). SetLayer(0). SetIndex(1), - Instance: NewInstanceUEID(TestUEID), + Instance: MustNewInstanceUEID(TestUEID), Group: NewGroupUUID(TestUUID), }, Measurements: *NewMeasurements(). AddMeasurement( - NewMeasurement(). - SetKeyUUID(TestUUID). + MustNewUUIDMeasurement(TestUUID). SetRawValueBytes([]byte{0x01, 0x02, 0x03, 0x04}, []byte{0xff, 0xff, 0xff, 0xff}). SetSVN(2). AddDigest(swid.Sha256_32, []byte{0xab, 0xcd, 0xef, 0x00}). @@ -55,13 +54,12 @@ func Example_encode() { SetModel("RoadRunner"). SetLayer(0). SetIndex(1), - Instance: NewInstanceUEID(TestUEID), + Instance: MustNewInstanceUEID(TestUEID), Group: NewGroupUUID(TestUUID), }, Measurements: *NewMeasurements(). AddMeasurement( - NewMeasurement(). - SetKeyUUID(TestUUID). + MustNewUUIDMeasurement(TestUUID). SetRawValueBytes([]byte{0x01, 0x02, 0x03, 0x04}, []byte{0xff, 0xff, 0xff, 0xff}). SetMinSVN(2). AddDigest(swid.Sha256_32, []byte{0xab, 0xcd, 0xef, 0x00}). @@ -79,7 +77,7 @@ func Example_encode() { AddAttestVerifKey( AttestVerifKey{ Environment: Environment{ - Instance: NewInstanceUUID(uuid.UUID(TestUUID)), + Instance: MustNewInstanceUUID(uuid.UUID(TestUUID)), }, VerifKeys: *NewCryptoKeys(). Add( @@ -89,7 +87,7 @@ func Example_encode() { ).AddDevIdentityKey( DevIdentityKey{ Environment: Environment{ - Instance: NewInstanceUEID(TestUEID), + Instance: MustNewInstanceUEID(TestUEID), }, VerifKeys: *NewCryptoKeys(). Add( @@ -126,25 +124,23 @@ func Example_encode_PSA() { }, Measurements: *NewMeasurements(). AddMeasurement( - NewPSAMeasurement( - *NewPSARefValID(TestSignerID). - SetLabel("BL"). - SetVersion("5.0.5"), - ).AddDigest(swid.Sha256_32, []byte{0xab, 0xcd, 0xef, 0x00}), + MustNewPSAMeasurement( + MustCreatePSARefValID( + TestSignerID, "BL", "5.0.5", + )).AddDigest(swid.Sha256_32, []byte{0xab, 0xcd, 0xef, 0x00}), ). AddMeasurement( - NewPSAMeasurement( - *NewPSARefValID(TestSignerID). - SetLabel("PRoT"). - SetVersion("1.3.5"), - ).AddDigest(swid.Sha256_32, []byte{0xab, 0xcd, 0xef, 0x00}), + MustNewPSAMeasurement( + MustCreatePSARefValID( + TestSignerID, "PRoT", "1.3.5", + )).AddDigest(swid.Sha256_32, []byte{0xab, 0xcd, 0xef, 0x00}), ), }, ). AddAttestVerifKey( AttestVerifKey{ Environment: Environment{ - Instance: NewInstanceUEID(TestUEID), + Instance: MustNewInstanceUEID(TestUEID), }, VerifKeys: *NewCryptoKeys(). Add( @@ -175,7 +171,7 @@ func Example_encode_PSA_attestation_verification() { AddAttestVerifKey( AttestVerifKey{ Environment: Environment{ - Instance: NewInstanceUEID(TestUEID), + Instance: MustNewInstanceUEID(TestUEID), }, VerifKeys: *NewCryptoKeys(). Add( diff --git a/comid/group.go b/comid/group.go index adb9cdb4..9adc9e4c 100644 --- a/comid/group.go +++ b/comid/group.go @@ -82,7 +82,7 @@ func (o *Group) UnmarshalJSON(data []byte) error { } switch v.Type { - case "uuid": + case UUIDType: var x UUID if err := x.UnmarshalJSON(v.Value); err != nil { return err diff --git a/comid/instance.go b/comid/instance.go index 8201145d..2bacc686 100644 --- a/comid/instance.go +++ b/comid/instance.go @@ -1,179 +1,188 @@ package comid import ( - "encoding/hex" "encoding/json" + "errors" "fmt" - "github.com/google/uuid" - "github.com/veraison/eat" + "github.com/veraison/corim/encoding" + "github.com/veraison/corim/extensions" ) +type IInstanceValue interface { + extensions.ITypeChoiceValue + + Bytes() []byte +} + // Instance stores an instance identity. The supported formats are UUID and UEID. type Instance struct { - val interface{} + Value IInstanceValue } -// NewInstance instantiates an empty instance -func NewInstance() *Instance { - return &Instance{} -} +// NewInstanceUEID instantiates a new instance with the supplied UEID identity +func NewInstanceUEID(val any) (*Instance, error) { + if val == nil { + return &Instance{&TaggedUEID{}}, nil + } -// SetUEID sets the identity of the target instance to the supplied UEID -func (o *Instance) SetUEID(val eat.UEID) *Instance { - if o != nil { - if val.Validate() != nil { - return nil - } - o.val = TaggedUEID(val) + ret, err := NewTaggedUEID(val) + if err != nil { + return nil, err } - return o + return &Instance{ret}, nil } -// SetUUID sets the identity of the target instance to the supplied UUID -func (o *Instance) SetUUID(val uuid.UUID) *Instance { - if o != nil { - o.val = TaggedUUID(val) +func MustNewInstanceUEID(val any) *Instance { + ret, err := NewInstanceUEID(val) + if err != nil { + panic(err) } - return o -} -// NewInstanceUEID instantiates a new instance with the supplied UEID identity -func NewInstanceUEID(val eat.UEID) *Instance { - return NewInstance().SetUEID(val) + return ret } // NewInstanceUUID instantiates a new instance with the supplied UUID identity -func NewInstanceUUID(val uuid.UUID) *Instance { - return NewInstance().SetUUID(val) -} +func NewInstanceUUID(val any) (*Instance, error) { + if val == nil { + return &Instance{&TaggedUUID{}}, nil + } -// Valid checks for the validity of given instance -func (o Instance) Valid() error { - if o.String() == "" { - return fmt.Errorf("invalid instance id") + ret, err := NewTaggedUUID(val) + if err != nil { + return nil, err } - return nil + + return &Instance{ret}, nil } -func (o Instance) GetUEID() (eat.UEID, error) { - switch t := o.val.(type) { - case TaggedUEID: - return eat.UEID(t), nil - default: - return eat.UEID{}, fmt.Errorf("instance-id type is: %T", t) +func MustNewInstanceUUID(val any) *Instance { + ret, err := NewInstanceUUID(val) + if err != nil { + panic(err) } + + return ret } -func (o Instance) GetUUID() (UUID, error) { - switch t := o.val.(type) { - case TaggedUUID: - return UUID(t), nil - default: - return UUID{}, fmt.Errorf("instance-id type is: %T", t) +// Valid checks for the validity of given instance +func (o Instance) Valid() error { + if o.String() == "" { + return fmt.Errorf("invalid instance id") } + return nil } // String returns a printable string of the Instance value. UUIDs use the // canonical 8-4-4-4-12 format, UEIDs are hex encoded. func (o Instance) String() string { - switch t := o.val.(type) { - case TaggedUUID: - return UUID(t).String() - case TaggedUEID: - return hex.EncodeToString(t) - default: + if o.Value == nil { return "" } + + return o.Value.String() +} + +func (o Instance) Type() string { + return o.Value.Type() +} + +func (o Instance) Bytes() []byte { + return o.Value.Bytes() } // MarshalCBOR serializes the target instance to CBOR func (o Instance) MarshalCBOR() ([]byte, error) { - return em.Marshal(o.val) + return em.Marshal(o.Value) } func (o *Instance) UnmarshalCBOR(data []byte) error { - var ueid TaggedUEID - - if dm.Unmarshal(data, &ueid) == nil { - o.val = ueid - return nil - } - - var u TaggedUUID - - if dm.Unmarshal(data, &u) == nil { - o.val = u - return nil - } - - return fmt.Errorf("unknown instance type (CBOR: %x)", data) + return dm.Unmarshal(data, &o.Value) } -// UnmarshalJSON deserializes the supplied JSON type/value object into the Group -// target. The supported formats are UUID, e.g.: +// UnmarshalJSON deserializes the supplied JSON object into the target Instance +// The instance object must have the following shape: // // { -// "type": "uuid", -// "value": "69E027B2-7157-4758-BCB4-D9F167FE49EA" +// "type": "", +// "value": "" // } // -// and UEID: +// where must be one of the known IInstanceValue implementation +// type names (available in the base implementation: "ueid" and "uuid"), and +// is the instance value encoded as a string. The exact +// encoding is depenent. For the base implmentation types it is // -// { -// "type": "ueid", -// "value": "Ad6tvu/erb7v3q2+796tvu8=" -// } +// ueid: base64-encoded bytes, e.g. "YWNtZS1pbXBsZW1lbnRhdGlvbi1pZC0wMDAwMDAwMDE=" +// uuid: standard UUID string representation, e.g. "550e8400-e29b-41d4-a716-446655440000" func (o *Instance) UnmarshalJSON(data []byte) error { - var v tnv + var value encoding.TypeAndValue - if err := json.Unmarshal(data, &v); err != nil { + if err := json.Unmarshal(data, &value); err != nil { return err } - switch v.Type { - case "uuid": - var x UUID - if err := x.UnmarshalJSON(v.Value); err != nil { - return err - } - o.val = TaggedUUID(x) - case "ueid": - var x UEID - if err := x.UnmarshalJSON(v.Value); err != nil { - return err - } - o.val = TaggedUEID(x) - default: - return fmt.Errorf("unknown type %s for instance", v.Type) + if value.Type == "" { + return errors.New("key type not set") } - return nil + factory, ok := instanceValueRegister[value.Type] + if !ok { + return fmt.Errorf("unknown class id type: %q", value.Type) + } + + var valueString string + if err := json.Unmarshal(value.Value, &valueString); err != nil { + return err + } + + v, err := factory(valueString) + if err != nil { + return err + } + + o.Value = v.Value + + return o.Valid() } func (o Instance) MarshalJSON() ([]byte, error) { - var ( - v tnv - b []byte - err error - ) - - switch t := o.val.(type) { - case TaggedUUID: - b, err = UUID(t).MarshalJSON() - if err != nil { - return nil, err - } - v = tnv{Type: "uuid", Value: b} - case TaggedUEID: - b, err = UEID(t).MarshalJSON() - if err != nil { - return nil, err - } - v = tnv{Type: "ueid", Value: b} - default: - return nil, fmt.Errorf("unknown type %T for instance", t) - } - - return json.Marshal(v) + valueBytes, err := json.Marshal(o.Value) + if err != nil { + return nil, err + } + + value := encoding.TypeAndValue{ + Type: o.Value.Type(), + Value: valueBytes, + } + + return json.Marshal(value) +} + +type IInstanceFactory func(any) (*Instance, error) + +var instanceValueRegister = map[string]IInstanceFactory{ + UEIDType: NewInstanceUEID, + UUIDType: NewInstanceUUID, +} + +func RegisterInstanceType(tag uint64, factory IInstanceFactory) error { + nilVal, err := factory(nil) + if err != nil { + return err + } + + typ := nilVal.Type() + if _, exists := instanceValueRegister[typ]; exists { + return fmt.Errorf("class ID type with name %q already exists", typ) + } + + if err := registerCOMIDTag(tag, nilVal.Value); err != nil { + return err + } + + instanceValueRegister[typ] = factory + + return nil } diff --git a/comid/instance_test.go b/comid/instance_test.go index 34e671e7..e65c99a3 100644 --- a/comid/instance_test.go +++ b/comid/instance_test.go @@ -1,24 +1,69 @@ package comid import ( + "encoding/json" "testing" - "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestInstance_GetUUID_OK(t *testing.T) { - inst := NewInstanceUUID(uuid.UUID(TestUUID)) - require.NotNil(t, inst) - u, err := inst.GetUUID() - assert.Nil(t, err) - assert.Equal(t, u, TestUUID) + inst := MustNewInstanceUUID(TestUUID) + u, ok := inst.Value.(*TaggedUUID) + assert.True(t, ok) + assert.EqualValues(t, TestUUID, *u) } -func TestInstance_GetUUID_NOK(t *testing.T) { - inst := &Instance{} - expectedErr := "instance-id type is: " - _, err := inst.GetUUID() - assert.EqualError(t, err, expectedErr) +type testInstance string + +func newTestInstance(val any) (*Instance, error) { + ret := testInstance("test") + return &Instance{&ret}, nil +} + +func (o testInstance) Bytes() []byte { + return []byte(o) +} + +func (o testInstance) Type() string { + return "test-instance" +} + +func (o testInstance) String() string { + return string(o) +} + +func (o testInstance) Valid() error { + return nil +} + +func Test_RegisterInstanceType(t *testing.T) { + err := RegisterInstanceType(99997, newTestInstance) + require.NoError(t, err) + + instance, err := newTestInstance(nil) + require.NoError(t, err) + + data, err := json.Marshal(instance) + require.NoError(t, err) + assert.Equal(t, string(data), `{"type":"test-instance","value":"test"}`) + + var out Instance + err = json.Unmarshal(data, &out) + require.NoError(t, err) + assert.Equal(t, instance.Bytes(), out.Bytes()) + + data, err = em.Marshal(instance) + require.NoError(t, err) + assert.Equal(t, data, []byte{ + 0xda, 0x0, 0x1, 0x86, 0x9d, // tag 99997 + 0x64, // tstr(4) + 0x74, 0x65, 0x73, 0x74, // "test" + }) + + var out2 Instance + err = dm.Unmarshal(data, &out2) + require.NoError(t, err) + assert.Equal(t, instance.Bytes(), out2.Bytes()) } diff --git a/comid/measurement.go b/comid/measurement.go index b1909d87..eca0858d 100644 --- a/comid/measurement.go +++ b/comid/measurement.go @@ -5,8 +5,10 @@ package comid import ( "encoding/json" + "errors" "fmt" "net" + "strconv" "github.com/veraison/corim/encoding" "github.com/veraison/corim/extensions" @@ -16,198 +18,267 @@ import ( const MaxUint64 = ^uint64(0) -// Measurement stores a measurement-map with CBOR and JSON serializations. -type Measurement struct { - Key *Mkey `cbor:"0,keyasint,omitempty" json:"key,omitempty"` - Val Mval `cbor:"1,keyasint" json:"value"` - AuthorizedBy *CryptoKey `cbor:"2,keyasint,omitempty" json:"authorized-by,omitempty"` +type IMKeyValue interface { + extensions.ITypeChoiceValue } // Mkey stores a $measured-element-type-choice. // The supported types are UUID, PSA refval-id, CCA platform-config-id and unsigned integer // TO DO Add tagged OID: see https://github.com/veraison/corim/issues/35 type Mkey struct { - val interface{} + Value IMKeyValue +} + +func NewMkey(val any, typ string) (*Mkey, error) { + factory, ok := mkeyValueRegister[typ] + if !ok { + return nil, fmt.Errorf("unexpected measurement key type: %q", typ) + } + + return factory(nil) +} + +func MustNewMkey(val any, typ string) *Mkey { + ret, err := NewMkey(val, typ) + if err != nil { + panic(err) + } + + return ret } func (o Mkey) IsSet() bool { - return o.val != nil + return o.Value != nil } func (o Mkey) Valid() error { - switch t := o.val.(type) { - case TaggedUUID: - if UUID(t).Empty() { - return fmt.Errorf("empty UUID") - } - return nil - case TaggedPSARefValID: - return PSARefValID(t).Valid() - case TaggedCCAPlatformConfigID: - if CCAPlatformConfigID(t).Empty() { - return fmt.Errorf("empty CCAPlatformConfigID") - } - case uint64: - if o.val == nil { - return fmt.Errorf("empty uint Mkey") - } - return nil - default: - return fmt.Errorf("unknown measurement key type: %T", t) + if o.Value == nil { + return errors.New("Mkey value not set") } + + if err := o.Value.Valid(); err != nil { + return fmt.Errorf("invalid %s: %w", o.Value.Type(), err) + } + return nil } -func (o Mkey) IsPSARefValID() bool { - _, ok := o.val.(TaggedPSARefValID) - return ok -} +// UnmarshalJSON deserializes the supplied JSON object into the target MKey +// The key object must have the following shape: +// +// { +// "type": "", +// "value": +// } +// +// where must be one of the known IMKeyValue implementation +// type names (available in the base implementation: "uuid", "oid", +// "psa.impl-id"), and is the class id value serialized to +// JSON. The exact serialization is depenent. For the base +// implementation types it is +// +// oid: dot-seprated integers, e.g. "1.2.3.4" +// psa.refval-id: JSON representation of the PSA refval-id +// uuid: standard UUID string representation, e.g. "550e8400-e29b-41d4-a716-446655440000" +func (o *Mkey) UnmarshalJSON(data []byte) error { + var tnv encoding.TypeAndValue -func (o Mkey) IsCCAPlatformConfigID() bool { - _, ok := o.val.(TaggedCCAPlatformConfigID) - return ok + if err := json.Unmarshal(data, &tnv); err != nil { + return err + } + + decoded, err := NewMkey(nil, tnv.Type) + if err != nil { + return err + } + + if err := json.Unmarshal(tnv.Value, decoded.Value); err != nil { + return fmt.Errorf("invalid %s: %w", tnv.Type, err) + } + + if err := decoded.Value.Valid(); err != nil { + return fmt.Errorf("invalid %s: %w", tnv.Type, err) + } + + o.Value = decoded.Value + + return nil } -func (o Mkey) GetPSARefValID() (PSARefValID, error) { - switch t := o.val.(type) { - case TaggedPSARefValID: - return PSARefValID(t), nil - default: - return PSARefValID{}, fmt.Errorf("measurement-key type is: %T", t) +// MarshalJSON serializes the target Mkey into the type'n'value JSON object +func (o Mkey) MarshalJSON() ([]byte, error) { + valueBytes, err := json.Marshal(o.Value) + if err != nil { + return nil, err + } + + value := encoding.TypeAndValue{ + Type: o.Value.Type(), + Value: valueBytes, } + + return json.Marshal(value) } -func (o Mkey) GetCCAPlatformConfigID() (CCAPlatformConfigID, error) { - switch t := o.val.(type) { - case TaggedCCAPlatformConfigID: - return CCAPlatformConfigID(t), nil - default: - return CCAPlatformConfigID(""), fmt.Errorf("measurement-key type is: %T", t) +func (o Mkey) MarshalCBOR() ([]byte, error) { + return em.Marshal(o.Value) +} + +func (o *Mkey) UnmarshalCBOR(data []byte) error { + if len(data) == 0 { + return errors.New("empty input") } + + majorType := (data[0] & 0xe0) >> 5 + if majorType == 6 { // tag + return dm.Unmarshal(data, &o.Value) + } + + // untagged value must be a uint + + var val UintMkey + if err := dm.Unmarshal(data, &val); err != nil { + return err + } + + o.Value = &val + return nil } -func (o Mkey) GetKeyUint() (uint64, error) { - switch t := o.val.(type) { +var UintType = "uint" + +type UintMkey uint64 + +func NewUintMkey(val any) (*UintMkey, error) { + var ret UintMkey + + if val == nil { + return &ret, nil + } + + switch t := val.(type) { + case UintMkey: + ret = t + case *UintMkey: + ret = *t + case string: + u, err := strconv.ParseUint(t, 10, 64) + if err != nil { + return nil, err + } + ret = UintMkey(u) case uint64: - return t, nil + ret = UintMkey(t) + case uint: + ret = UintMkey(t) default: - return MaxUint64, fmt.Errorf("measurement-key type is: %T", t) + return nil, fmt.Errorf("unexpected type for UintMkey: %T", t) } + + return &ret, nil } -// UnmarshalJSON deserializes the type'n'value JSON object into the target Mkey -func (o *Mkey) UnmarshalJSON(data []byte) error { - var v tnv +func (o UintMkey) Valid() error { + return nil +} + +func (o UintMkey) String() string { + return fmt.Sprint(uint64(o)) +} + +func (o UintMkey) Type() string { + return UintType +} - if err := json.Unmarshal(data, &v); err != nil { +func (o *UintMkey) UnmarshalJSON(data []byte) error { + var tmp uint64 + + if err := json.Unmarshal(data, &tmp); err != nil { return err } - switch v.Type { - case "uuid": - var x UUID - if err := x.UnmarshalJSON(v.Value); err != nil { - return fmt.Errorf( - "cannot unmarshal $measured-element-type-choice of type UUID: %w", - err, - ) - } - o.val = TaggedUUID(x) - case "psa.refval-id": - var x PSARefValID - if err := json.Unmarshal(v.Value, &x); err != nil { - return fmt.Errorf( - "cannot unmarshal $measured-element-type-choice of type PSARefValID: %w", - err, - ) - } - if err := x.Valid(); err != nil { - return fmt.Errorf( - "cannot unmarshal $measured-element-type-choice of type PSARefValID: %w", - err, - ) - } - o.val = TaggedPSARefValID(x) - case "cca.platform-config-id": - var x CCAPlatformConfigID - if err := json.Unmarshal(v.Value, &x); err != nil { - return fmt.Errorf( - "cannot unmarshal $measured-element-type-choice of type CCAPlatformConfigID: %w", - err, - ) - } - if x.Empty() { - return fmt.Errorf( - "cannot unmarshal $measured-element-type-choice of type CCAPlatformConfigID: empty label", - ) - } - o.val = TaggedCCAPlatformConfigID(x) - case "uint": - var x uint64 - if err := json.Unmarshal(v.Value, &x); err != nil { - return fmt.Errorf( - "cannot unmarshal $measured-element-type-choice of type uint: %w", - err, - ) - } - o.val = x - default: - return fmt.Errorf("unknown type %s for $measured-element-type-choice", v.Type) - } + *o = UintMkey(tmp) return nil } -// MarshalJSON serializes the target Mkey into the type'n'value JSON object -// Supported types are: uuid, psa.refval-id and unsigned integer -func (o Mkey) MarshalJSON() ([]byte, error) { - var ( - v tnv - b []byte - err error - ) - - switch t := o.val.(type) { - case TaggedUUID: - uuidString := UUID(t).String() - b, err = json.Marshal(uuidString) - if err != nil { - return nil, err - } - v = tnv{Type: "uuid", Value: b} - case TaggedPSARefValID: - b, err = json.Marshal(t) - if err != nil { - return nil, err - } - v = tnv{Type: "psa.refval-id", Value: b} - case TaggedCCAPlatformConfigID: - b, err = json.Marshal(t) - if err != nil { - return nil, err - } - v = tnv{Type: "cca.platform-config-id", Value: b} +func NewMkeyOID(val any) (*Mkey, error) { + ret, err := NewTaggedOID(val) + if err != nil { + return nil, err + } - case uint64: - b, err = json.Marshal(t) - if err != nil { - return nil, err - } - v = tnv{Type: "uint", Value: b} + return &Mkey{ret}, nil +} - default: - return nil, fmt.Errorf("unknown type %T for mkey", t) +func NewMkeyUUID(val any) (*Mkey, error) { + ret, err := NewTaggedUUID(val) + if err != nil { + return nil, err } - return json.Marshal(v) + return &Mkey{ret}, nil } -func (o Mkey) MarshalCBOR() ([]byte, error) { - return em.Marshal(o.val) +func NewMkeyUint(val any) (*Mkey, error) { + ret, err := NewUintMkey(val) + if err != nil { + return nil, err + } + + return &Mkey{ret}, nil } -func (o *Mkey) UnmarshalCBOR(data []byte) error { - return dm.Unmarshal(data, &o.val) +func NewMkeyPSARefvalID(val any) (*Mkey, error) { + ret, err := NewTaggedPSARefValID(val) + if err != nil { + return nil, err + } + + return &Mkey{ret}, nil +} + +func NewMkeyCCAPlatformConfigID(val any) (*Mkey, error) { + ret, err := NewTaggedCCAPlatformConfigID(val) + if err != nil { + return nil, err + } + + return &Mkey{ret}, nil +} + +type IMkeyFactory = func(val any) (*Mkey, error) + +var mkeyValueRegister = map[string]IMkeyFactory{ + OIDType: NewMkeyOID, + UUIDType: NewMkeyUUID, + UintType: NewMkeyUint, + PSARefValIDType: NewMkeyPSARefvalID, + CCAPlatformConfigIDType: NewMkeyCCAPlatformConfigID, +} + +// RegisterMkeyType registers a new IMKeyValue implementation +// (created by the provided IMKeyFactory) under the specified type name +// and CBOR tag. +func RegisterMkeyType(tag uint64, factory IMkeyFactory) error { + + nilVal, err := factory(nil) + if err != nil { + return err + } + + typ := nilVal.Value.Type() + if _, exists := mkeyValueRegister[typ]; exists { + return fmt.Errorf("mesurement key type with name %q already exists", typ) + } + + if err := registerCOMIDTag(tag, nilVal.Value); err != nil { + return err + } + + mkeyValueRegister[typ] = factory + + return nil } // Mval stores a measurement-values-map with JSON and CBOR serializations. @@ -320,95 +391,112 @@ func (o Version) Valid() error { return nil } -// NewMeasurement instantiates an empty measurement -func NewMeasurement() *Measurement { - return &Measurement{} +// Measurement stores a measurement-map with CBOR and JSON serializations. +type Measurement struct { + Key *Mkey `cbor:"0,keyasint,omitempty" json:"key,omitempty"` + Val Mval `cbor:"1,keyasint" json:"value"` + AuthorizedBy *CryptoKey `cbor:"2,keyasint,omitempty" json:"authorized-by,omitempty"` } -// SetKeyPSARefValID sets the key of the target measurement-map to the supplied -// PSA refval-id -func (o *Measurement) SetKeyPSARefValID(psaRefValID PSARefValID) *Measurement { - if o != nil { - if psaRefValID.Valid() != nil { - return nil - } - o.Key = &Mkey{ - val: TaggedPSARefValID(psaRefValID), - } +func NewMeasurement(val any, typ string) (*Measurement, error) { + keyFactory, ok := mkeyValueRegister[typ] + if !ok { + return nil, fmt.Errorf("unknown Mkey type: %s", typ) } - return o -} -// SetKeyCCAPlatformConfigID sets the key of the target measurement-map to the supplied -// CCA platform-config-id -func (o *Measurement) SetKeyCCAPlatformConfigID(ccaPlatformConfigID CCAPlatformConfigID) *Measurement { - if o != nil { - if ccaPlatformConfigID.Empty() { - return nil - } - o.Key = &Mkey{ - val: TaggedCCAPlatformConfigID(ccaPlatformConfigID), - } + key, err := keyFactory(val) + if err != nil { + return nil, fmt.Errorf("invalid key: %w", err) } - return o -} -// SetKeyKeyUUID sets the key of the target measurement-map to the supplied -// UUID -func (o *Measurement) SetKeyUUID(u UUID) *Measurement { - if o != nil { - if u.Empty() { - return nil - } + if err = key.Valid(); err != nil { + return nil, fmt.Errorf("invalid key: %w", err) + } - if u.Valid() != nil { - return nil - } + var ret Measurement + ret.Key = key - o.Key = &Mkey{ - val: TaggedUUID(u), - } - } - return o + return &ret, nil } -// SetKeyUint sets the key of the target measurement-map to the supplied -// unsigned integer -func (o *Measurement) SetKeyUint(u uint64) *Measurement { - if o != nil { - o.Key = &Mkey{ - val: u, - } +func MustNewMeasurement(val any, typ string) *Measurement { + ret, err := NewMeasurement(val, typ) + + if err != nil { + panic(err) } - return o + + return ret } // NewPSAMeasurement instantiates a new measurement-map with the key set to the // supplied PSA refval-id -func NewPSAMeasurement(psaRefValID PSARefValID) *Measurement { - m := &Measurement{} - return m.SetKeyPSARefValID(psaRefValID) +func NewPSAMeasurement(key any) (*Measurement, error) { + return NewMeasurement(key, PSARefValIDType) +} + +func MustNewPSAMeasurement(key any) *Measurement { + ret, err := NewPSAMeasurement(key) + + if err != nil { + panic(err) + } + + return ret } // NewCCAPlatCfgMeasurement instantiates a new measurement-map with the key set to the // supplied CCA platform-config-id -func NewCCAPlatCfgMeasurement(ccaPlatformConfigID CCAPlatformConfigID) *Measurement { - m := &Measurement{} - return m.SetKeyCCAPlatformConfigID(ccaPlatformConfigID) +func NewCCAPlatCfgMeasurement(key any) (*Measurement, error) { + return NewMeasurement(key, CCAPlatformConfigIDType) +} + +func MustNewCCAPlatCfgMeasurement(key any) *Measurement { + ret, err := NewCCAPlatCfgMeasurement(key) + + if err != nil { + panic(err) + } + + return ret } // NewUUIDMeasurement instantiates a new measurement-map with the key set to the // supplied UUID -func NewUUIDMeasurement(uuid UUID) *Measurement { - m := &Measurement{} - return m.SetKeyUUID(uuid) +func NewUUIDMeasurement(key any) (*Measurement, error) { + return NewMeasurement(key, UUIDType) +} + +func MustNewUUIDMeasurement(key any) *Measurement { + ret, err := NewUUIDMeasurement(key) + + if err != nil { + panic(err) + } + + return ret } // NewUintMeasurement instantiates a new measurement-map with the key set to the // supplied Uint -func NewUintMeasurement(mkey uint64) *Measurement { - m := &Measurement{} - return m.SetKeyUint(mkey) +func NewUintMeasurement(key any) (*Measurement, error) { + return NewMeasurement(key, UintType) +} + +func MustNewUintMeasurement(key any) *Measurement { + ret, err := NewUintMeasurement(key) + + if err != nil { + panic(err) + } + + return ret +} + +// NewOIDMeasurement instantiates a new measurement-map with the key set to the +// supplied OID +func NewOIDMeasurement(key any) (*Measurement, error) { + return NewMeasurement(key, OIDType) } func (o *Measurement) SetVersion(ver string, scheme int64) *Measurement { @@ -438,26 +526,14 @@ func (o *Measurement) SetRawValueBytes(rawValue, rawValueMask []byte) *Measureme // SetSVN sets the supplied svn in the measurement-values-map of the target // measurement func (o *Measurement) SetSVN(svn uint64) *Measurement { - if o != nil { - s := SVN{} - if s.SetSVN(svn) == nil { - return nil - } - o.Val.SVN = &s - } + o.Val.SVN = MustNewTaggedSVN(svn) return o } // SetMinSVN sets the supplied min-svn in the measurement-values-map of the // target measurement func (o *Measurement) SetMinSVN(svn uint64) *Measurement { - if o != nil { - s := SVN{} - if s.SetMinSVN(svn) == nil { - return nil - } - o.Val.SVN = &s - } + o.Val.SVN = MustNewTaggedMinSVN(svn) return o } diff --git a/comid/measurement_test.go b/comid/measurement_test.go index d5d58bff..c4f95e11 100644 --- a/comid/measurement_test.go +++ b/comid/measurement_test.go @@ -4,6 +4,7 @@ package comid import ( + "crypto" "fmt" "testing" @@ -14,86 +15,86 @@ import ( ) func TestMeasurement_NewUUIDMeasurement_good_uuid(t *testing.T) { - tv := NewUUIDMeasurement(TestUUID) - assert.NotNil(t, tv) + _, err := NewUUIDMeasurement(TestUUID) + assert.NoError(t, err) } func TestMeasurement_NewUUIDMeasurement_empty_uuid(t *testing.T) { emptyUUID := UUID{} - tv := NewUUIDMeasurement(emptyUUID) + _, err := NewUUIDMeasurement(emptyUUID) - assert.Nil(t, tv) + assert.EqualError(t, err, + "invalid key: expecting RFC4122 UUID, got Reserved instead") } func TestMeasurement_NewUIntMeasurement(t *testing.T) { var TestUint uint64 = 35 - tv := NewUintMeasurement(TestUint) + _, err := NewUintMeasurement(TestUint) - assert.NotNil(t, tv) + assert.NoError(t, err) } func TestMeasurement_NewPSAMeasurement_empty(t *testing.T) { emptyPSARefValID := PSARefValID{} - tv := NewPSAMeasurement(emptyPSARefValID) - assert.Nil(t, tv) + _, err := NewPSAMeasurement(emptyPSARefValID) + + assert.EqualError(t, err, "invalid key: invalid psa.refval-id: missing mandatory signer ID") } func TestMeasurement_NewPSAMeasurement_no_values(t *testing.T) { - psaRefValID := - NewPSARefValID(TestSignerID). - SetLabel("PRoT"). - SetVersion("1.2.3") + psaRefValID, err := NewPSARefValID(TestSignerID) + require.NoError(t, err) + psaRefValID.SetLabel("PRoT") + psaRefValID.SetVersion("1.2.3") require.NotNil(t, psaRefValID) - tv := NewPSAMeasurement(*psaRefValID) - assert.NotNil(t, tv) + tv, err := NewPSAMeasurement(*psaRefValID) + assert.NoError(t, err) - err := tv.Valid() + err = tv.Valid() assert.EqualError(t, err, "no measurement value set") } func TestMeasurement_NewCCAPlatCfgMeasurement_no_values(t *testing.T) { ccaplatID := CCAPlatformConfigID(TestCCALabel) - tv := NewCCAPlatCfgMeasurement(ccaplatID) - assert.NotNil(t, tv) + tv, err := NewCCAPlatCfgMeasurement(ccaplatID) + assert.NoError(t, err) - err := tv.Valid() + err = tv.Valid() assert.EqualError(t, err, "no measurement value set") } func TestMeasurement_NewCCAPlatCfgMeasurement_valid_meas(t *testing.T) { ccaplatID := CCAPlatformConfigID(TestCCALabel) - tv := NewCCAPlatCfgMeasurement(ccaplatID).SetRawValueBytes([]byte{0x01, 0x02, 0x03, 0x04}, []byte{}) - assert.NotNil(t, tv) + tv, err := NewCCAPlatCfgMeasurement(ccaplatID) + assert.NoError(t, err) - err := tv.Valid() - assert.Nil(t, err) + tv.SetRawValueBytes([]byte{0x01, 0x02, 0x03, 0x04}, []byte{}) + + err = tv.Valid() + assert.NoError(t, err) } func TestMeasurement_NewPSAMeasurement_one_value(t *testing.T) { - psaRefValID := - NewPSARefValID(TestSignerID). - SetLabel("PRoT"). - SetVersion("1.2.3") - require.NotNil(t, psaRefValID) + tv, err := NewPSAMeasurement(MustCreatePSARefValID(TestSignerID, "PRoT", "1.2.3")) + require.NoError(t, err) - tv := NewPSAMeasurement(*psaRefValID).SetIPaddr(TestIPaddr) - assert.NotNil(t, tv) + tv.SetIPaddr(TestIPaddr) - err := tv.Valid() + err = tv.Valid() assert.Nil(t, err) } func TestMeasurement_NewUUIDMeasurement_no_values(t *testing.T) { - tv := NewUUIDMeasurement(TestUUID) - require.NotNil(t, tv) + tv, err := NewUUIDMeasurement(TestUUID) + require.NoError(t, err) - err := tv.Valid() + err = tv.Valid() assert.EqualError(t, err, "no measurement value set") } @@ -101,26 +102,27 @@ func TestMeasurement_NewUUIDMeasurement_some_value(t *testing.T) { var vs swid.VersionScheme require.NoError(t, vs.SetCode(swid.VersionSchemeSemVer)) - tv := NewUUIDMeasurement(TestUUID). - SetMinSVN(2). + tv, err := NewUUIDMeasurement(TestUUID) + require.NoError(t, err) + + tv.SetMinSVN(2). SetFlagsTrue(FlagIsDebug). SetVersion("1.2.3", swid.VersionSchemeSemVer) - require.NotNil(t, tv) - err := tv.Valid() + err = tv.Valid() assert.Nil(t, err) } func TestMeasurement_NewUUIDMeasurement_bad_digest(t *testing.T) { - tv := NewUUIDMeasurement(TestUUID) - require.NotNil(t, tv) + tv, err := NewUUIDMeasurement(TestUUID) + require.NoError(t, err) assert.Nil(t, tv.AddDigest(swid.Sha256, []byte{0xff})) } func TestMeasurement_NewUUIDMeasurement_bad_ueid(t *testing.T) { - tv := NewUUIDMeasurement(TestUUID) - require.NotNil(t, tv) + tv, err := NewUUIDMeasurement(TestUUID) + require.NoError(t, err) badUEID := eat.UEID{ 0xFF, // Invalid @@ -131,8 +133,8 @@ func TestMeasurement_NewUUIDMeasurement_bad_ueid(t *testing.T) { } func TestMeasurement_NewUUIDMeasurement_bad_uuid(t *testing.T) { - tv := NewUUIDMeasurement(TestUUID) - require.NotNil(t, tv) + tv, err := NewUUIDMeasurement(TestUUID) + require.NoError(t, err) nonRFC4122UUID, err := ParseUUID("f47ac10b-58cc-4372-c567-0e02b2c3d479") require.Nil(t, err) @@ -147,7 +149,7 @@ var ( func TestMkey_Valid_no_value(t *testing.T) { mkey := &Mkey{} - expectedErr := "unknown measurement key type: " + expectedErr := "Mkey value not set" err := mkey.Valid() assert.EqualError(t, err, expectedErr) } @@ -183,19 +185,19 @@ func TestMKey_MarshalCBOR_CCAPlatformConfigID_ok(t *testing.T) { func TestMKey_UnmarshalCBOR_CCAPlatformConfigID_ok(t *testing.T) { tvs := []struct { input []byte - expected CCAPlatformConfigID + expected TaggedCCAPlatformConfigID }{ { input: MustHexDecode(t, "d9025a736363612d706c6174666f726d2d636f6e666967"), - expected: CCAPlatformConfigID(TestCCALabel), + expected: TaggedCCAPlatformConfigID(TestCCALabel), }, { input: MustHexDecode(t, "d9025a716d7974657374706c6174666f726d666967"), - expected: CCAPlatformConfigID("mytestplatformfig"), + expected: TaggedCCAPlatformConfigID("mytestplatformfig"), }, { input: MustHexDecode(t, "d9025a6c6d79746573746c6162656c32"), - expected: CCAPlatformConfigID("mytestlabel2"), + expected: TaggedCCAPlatformConfigID("mytestlabel2"), }, } @@ -203,9 +205,9 @@ func TestMKey_UnmarshalCBOR_CCAPlatformConfigID_ok(t *testing.T) { mkey := &Mkey{} err := mkey.UnmarshalCBOR(tv.input) assert.Nil(t, err) - actual, err := mkey.GetCCAPlatformConfigID() - assert.Nil(t, err) - assert.Equal(t, tv.expected, actual) + actual, ok := mkey.Value.(*TaggedCCAPlatformConfigID) + assert.True(t, ok) + assert.Equal(t, tv.expected, *actual) fmt.Printf("CBOR: %x\n", actual) } } @@ -230,36 +232,13 @@ func TestMKey_MarshalCBOR_uint_ok(t *testing.T) { } for _, tv := range tvs { - mkey := &Mkey{tv.mkey} + mkey := &Mkey{UintMkey(tv.mkey)} actual, err := mkey.MarshalCBOR() assert.Nil(t, err) assert.Equal(t, tv.expected, actual) fmt.Printf("CBOR: %x\n", actual) } } -func TestMkey_MarshalCBOR_uint_not_ok(t *testing.T) { - tvs := []struct { - input interface{} - expected string - }{ - { - input: 123.456, - expected: "unknown measurement key type: float64", - }, - { - input: "sample", - expected: "unknown measurement key type: string", - }, - } - - for _, tv := range tvs { - mkey := &Mkey{tv.input} - _, err := mkey.MarshalCBOR() - assert.Nil(t, err) - err = mkey.Valid() - assert.EqualError(t, err, tv.expected) - } -} func TestMkey_UnmarshalCBOR_uint_ok(t *testing.T) { tvs := []struct { @@ -284,10 +263,10 @@ func TestMkey_UnmarshalCBOR_uint_ok(t *testing.T) { mKey := &Mkey{} err := mKey.UnmarshalCBOR(tv.mkey) - assert.Nil(t, err) - actual, err := mKey.GetKeyUint() - assert.Nil(t, err) - assert.Equal(t, tv.expected, actual) + require.NoError(t, err) + actual, ok := mKey.Value.(*UintMkey) + require.True(t, ok) + assert.Equal(t, tv.expected, uint64(*actual)) } } @@ -315,38 +294,9 @@ func TestMkey_UnmarshalCBOR_not_ok(t *testing.T) { } } -func TestMkey_UnmarshalCBOR_uint_not_ok(t *testing.T) { - tvs := []struct { - input []byte - expected string - }{ - { - input: []byte{0xd8, 0x25, 0x50, 0x31, 0xfb, 0x5a, 0xbf, 0x02, - 0x3e, 0x49, 0x92, 0xaa, 0x4e, 0x95, 0xf9, 0xc1, - 0x50, 0x3b, 0xfa}, - expected: "measurement-key type is: comid.TaggedUUID", - }, - { - input: []byte{0xd8, 0x21, 0x50, 0x31, 0xfb, 0x5a, 0xff, 0x12, - 0xFF, 0xFF, 0x92, 0xaa, 0x4e, 0x95, 0xf9, 0xc1, - 0x50, 0x3b, 0xfa}, - expected: "measurement-key type is: cbor.Tag", - }, - } - - for _, tv := range tvs { - mKey := &Mkey{} - - err := mKey.UnmarshalCBOR(tv.input) - assert.Nil(t, err) - _, err = mKey.GetKeyUint() - assert.EqualError(t, err, tv.expected) - } -} - func TestMKey_MarshalJSON_CCAPlatformConfigID_ok(t *testing.T) { refval := TestCCALabel - mkey := &Mkey{val: TaggedCCAPlatformConfigID(refval)} + mkey := &Mkey{Value: TaggedCCAPlatformConfigID(refval)} expected := `{"type":"cca.platform-config-id","value":"cca-platform-config"}` @@ -359,30 +309,29 @@ func TestMKey_MarshalJSON_CCAPlatformConfigID_ok(t *testing.T) { func TestMKey_UnMarshalJSON_CCAPlatformConfigID_ok(t *testing.T) { input := []byte(`{"type":"cca.platform-config-id","value":"cca-platform-config"}`) - expected := CCAPlatformConfigID(TestCCALabel) + expected := TaggedCCAPlatformConfigID(TestCCALabel) mKey := &Mkey{} err := mKey.UnmarshalJSON(input) assert.Nil(t, err) - actual, err := mKey.GetCCAPlatformConfigID() - assert.Nil(t, err) - assert.Equal(t, expected, actual) + actual, ok := mKey.Value.(*TaggedCCAPlatformConfigID) + assert.True(t, ok) + assert.Equal(t, expected, *actual) } func TestMKey_UnMarshalJSON_CCAPlatformConfigID_not_ok(t *testing.T) { input := []byte(`{"type":"cca.platform-config-id","value":""}`) - expected := "cannot unmarshal $measured-element-type-choice of type CCAPlatformConfigID: empty label" + expected := "invalid cca.platform-config-id: empty value" mKey := &Mkey{} err := mKey.UnmarshalJSON(input) - assert.NotNil(t, err) - assert.Equal(t, expected, err.Error()) - + assert.EqualError(t, err, expected) } + func TestMkey_MarshalJSON_uint_ok(t *testing.T) { tvs := []struct { mkey uint64 @@ -404,7 +353,7 @@ func TestMkey_MarshalJSON_uint_ok(t *testing.T) { for _, tv := range tvs { - mkey := &Mkey{tv.mkey} + mkey := &Mkey{UintMkey(tv.mkey)} actual, err := mkey.MarshalJSON() assert.Nil(t, err) @@ -414,31 +363,6 @@ func TestMkey_MarshalJSON_uint_ok(t *testing.T) { } } -func TestMkey_MarshalJSON_uint_not_ok(t *testing.T) { - tvs := []struct { - input interface{} - expected string - }{ - { - input: 123.456, - expected: "unknown type float64 for mkey", - }, - { - input: "sample", - expected: "unknown type string for mkey", - }, - } - - for _, tv := range tvs { - - mkey := &Mkey{tv.input} - - _, err := mkey.MarshalJSON() - - assert.EqualError(t, err, tv.expected) - } -} - func TestMkey_UnmarshalJSON_uint_ok(t *testing.T) { tvs := []struct { input []byte @@ -463,9 +387,9 @@ func TestMkey_UnmarshalJSON_uint_ok(t *testing.T) { err := mKey.UnmarshalJSON(tv.input) assert.Nil(t, err) - actual, err := mKey.GetKeyUint() - assert.Nil(t, err) - assert.Equal(t, tv.expected, actual) + actual, ok := mKey.Value.(*UintMkey) + assert.True(t, ok) + assert.Equal(t, tv.expected, uint64(*actual)) } } @@ -476,11 +400,11 @@ func TestMkey_UnmarshalJSON_notok(t *testing.T) { }{ { input: []byte(`{"type":"uint","value":"abcdefg"}`), - expected: "cannot unmarshal $measured-element-type-choice of type uint: json: cannot unmarshal string into Go value of type uint64", + expected: `invalid uint: json: cannot unmarshal string into Go value of type uint64`, }, { input: []byte(`{"type":"uint","value":123.456}`), - expected: "cannot unmarshal $measured-element-type-choice of type uint: json: cannot unmarshal number 123.456 into Go value of type uint64", + expected: "invalid uint: json: cannot unmarshal number 123.456 into Go value of type uint64", }, } @@ -493,27 +417,102 @@ func TestMkey_UnmarshalJSON_notok(t *testing.T) { } } -func TestMkey_UnmarshalJSON_uint_notok(t *testing.T) { +func TestNewUintMkey(t *testing.T) { + testVal := UintMkey(7) + tvs := []struct { - input []byte - expected string + input any + expected UintMkey + err string }{ { - input: []byte(`{"type":"uuid","value":"31fb5abf-023e-4992-aa4e-95f9c1503bfa"}`), - expected: "measurement-key type is: comid.TaggedUUID", + input: testVal, + expected: testVal, + }, + { + input: &testVal, + expected: testVal, }, { - input: []byte(`{"type":"psa.refval-id","value":{"label": "BL","version": "2.1.0","signer-id": "rLsRx+TaIXIFUjzkzhokWuGiOa48a/2eeHH35di66Gs="}}`), - expected: "measurement-key type is: comid.TaggedPSARefValID", + input: uint(7), + expected: testVal, + }, + { + input: uint64(7), + expected: testVal, + }, + { + input: "7", + expected: testVal, + }, + { + input: true, + err: "unexpected type for UintMkey: bool", }, } for _, tv := range tvs { - mKey := &Mkey{} - - err := mKey.UnmarshalJSON(tv.input) - assert.Nil(t, err) - _, err = mKey.GetKeyUint() - assert.EqualError(t, err, tv.expected) + out, err := NewUintMkey(tv.input) + if tv.err != "" { + assert.Nil(t, out) + assert.EqualError(t, err, tv.err) + } else { + assert.Equal(t, tv.expected, *out) + } } } + +func TestNewMkeyOID(t *testing.T) { + var expectedOID OID + require.NoError(t, expectedOID.FromString(TestOID)) + expected := TaggedOID(expectedOID) + + out, err := NewMkeyOID(TestOID) + require.NoError(t, err) + assert.Equal(t, &expected, out.Value) +} + +type testMkey [4]byte + +func newTestMkey(val any) (*Mkey, error) { + return &Mkey{&testMkey{0x74, 0x64, 0x73, 0x74}}, nil +} + +func (o testMkey) PublicKey() (crypto.PublicKey, error) { + return crypto.PublicKey(o[:]), nil +} + +func (o testMkey) Type() string { + return "test-mkey" +} + +func (o testMkey) String() string { + return "test" +} + +func (o testMkey) Valid() error { + return nil +} + +type badMkey struct { + testMkey +} + +func (o badMkey) Type() string { + return "uuid" +} + +func newBadMkey(val any) (*Mkey, error) { + return &Mkey{&badMkey{testMkey{0x74, 0x64, 0x73, 0x74}}}, nil +} + +func TestRegisterMkeyType(t *testing.T) { + err := RegisterMkeyType(32, newTestMkey) + assert.EqualError(t, err, "tag 32 is already registered") + + err = RegisterMkeyType(99996, newBadMkey) + assert.EqualError(t, err, `mesurement key type with name "uuid" already exists`) + + err = RegisterMkeyType(99996, newTestMkey) + assert.NoError(t, err) +} diff --git a/comid/oid.go b/comid/oid.go index 822d7deb..bbd5dbdd 100644 --- a/comid/oid.go +++ b/comid/oid.go @@ -11,6 +11,8 @@ import ( "strings" ) +const OIDType = "oid" + // BER-encoded absolute OID type OID []byte @@ -152,3 +154,56 @@ func (o *OID) UnmarshalJSON(data []byte) error { func (o OID) MarshalJSON() ([]byte, error) { return json.Marshal(o.String()) } + +type TaggedOID OID + +func NewTaggedOID(val any) (*TaggedOID, error) { + ret := TaggedOID{} + + if val == nil { + return &ret, nil + } + + switch t := val.(type) { + case string: + var berOID OID + if err := berOID.FromString(t); err != nil { + return nil, err + } + + ret = TaggedOID(berOID) + case []byte: + ret = make([]byte, len(t)) + copy(ret, t) + case TaggedOID: + ret = make([]byte, len(t)) + copy(ret, t) + case OID: + ret = make([]byte, len(t)) + copy(ret, t) + case *TaggedOID: + ret = make([]byte, len(*t)) + copy(ret, (*t)) + case *OID: + ret = make([]byte, len(*t)) + copy(ret, (*t)) + } + + return &ret, nil +} + +func (o TaggedOID) Type() string { + return OIDType +} + +func (o TaggedOID) String() string { + return OID(o).String() +} + +func (o TaggedOID) Valid() error { + return nil +} + +func (o TaggedOID) Bytes() []byte { + return o +} diff --git a/comid/psareferencevalue.go b/comid/psareferencevalue.go index cde3304b..0ffacf97 100644 --- a/comid/psareferencevalue.go +++ b/comid/psareferencevalue.go @@ -4,9 +4,12 @@ package comid import ( + "encoding/json" "fmt" ) +var PSARefValIDType = "psa.refval-id" + // PSARefValID stores a PSA refval-id with CBOR and JSON serializations // (See https://datatracker.ietf.org/doc/html/draft-xyz-rats-psa-endorsements) type PSARefValID struct { @@ -30,18 +33,56 @@ func (o PSARefValID) Valid() error { return nil } -type TaggedPSARefValID PSARefValID +func CreatePSARefValID(signerID []byte, label, version string) (*PSARefValID, error) { + ret, err := NewPSARefValID(signerID) + if err != nil { + return nil, err + } -func NewPSARefValID(signerID []byte) *PSARefValID { - switch len(signerID) { - case 32, 48, 64: - default: - return nil + ret.SetLabel(label) + ret.SetVersion(version) + + return ret, nil +} + +func MustCreatePSARefValID(signerID []byte, label, version string) *PSARefValID { + ret, err := CreatePSARefValID(signerID, label, version) + + if err != nil { + panic(err) } - return &PSARefValID{ - SignerID: signerID, + return ret +} + +func NewPSARefValID(val any) (*PSARefValID, error) { + var ret PSARefValID + + if val == nil { + return &ret, nil } + + switch t := val.(type) { + case PSARefValID: + ret = t + case *PSARefValID: + ret = *t + case string: + if err := json.Unmarshal([]byte(t), &ret); err != nil { + return nil, err + } + case []byte: + switch len(t) { + case 32, 48, 64: + ret.SignerID = t + default: + return nil, fmt.Errorf("invalid PSA RefVal ID length: %d", len(t)) + } + default: + return nil, fmt.Errorf("unexpected type for PSA RefVal ID: %T", t) + } + + return &ret, nil } func (o *PSARefValID) SetLabel(label string) *PSARefValID { @@ -57,3 +98,46 @@ func (o *PSARefValID) SetVersion(version string) *PSARefValID { } return o } + +type TaggedPSARefValID PSARefValID + +func NewTaggedPSARefValID(val any) (*TaggedPSARefValID, error) { + var ret TaggedPSARefValID + + switch t := val.(type) { + case TaggedPSARefValID: + ret = t + case *TaggedPSARefValID: + ret = *t + default: + refvalID, err := NewPSARefValID(val) + if err != nil { + return nil, err + } + ret = TaggedPSARefValID(*refvalID) + + } + + return &ret, nil +} + +func (o TaggedPSARefValID) Valid() error { + return PSARefValID(o).Valid() +} + +func (o TaggedPSARefValID) String() string { + ret, err := json.Marshal(o) + if err != nil { + return "" + } + + return string(ret) +} + +func (o TaggedPSARefValID) Type() string { + return PSARefValIDType +} + +func (o TaggedPSARefValID) IsZero() bool { + return len(o.SignerID) == 0 +} diff --git a/comid/psareferencevalue_test.go b/comid/psareferencevalue_test.go index 069c432e..72116cc2 100644 --- a/comid/psareferencevalue_test.go +++ b/comid/psareferencevalue_test.go @@ -4,9 +4,11 @@ package comid import ( + "fmt" "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestPSARefValID_Valid_SignerID_range(t *testing.T) { @@ -15,13 +17,27 @@ func TestPSARefValID_Valid_SignerID_range(t *testing.T) { for i := 1; i <= 100; i++ { signerID = append(signerID, byte(0xff)) - tv := NewPSARefValID(signerID) + tv, err := NewPSARefValID(signerID) + switch i { case 32, 48, 64: assert.NotNil(t, tv) assert.Nil(t, tv.Valid()) default: assert.Nil(t, tv) + assert.EqualError( + t, + err, + fmt.Sprintf("invalid PSA RefVal ID length: %d", i), + ) } } } + +func TestPSARefValID_Streing(t *testing.T) { + signerID := MustHexDecode(t, "deadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeef") + refvalID, err := NewTaggedPSARefValID(signerID) + require.NoError(t, err) + + assert.Equal(t, `{"signer-id":"3q2+796tvu/erb7v3q2+796tvu/erb7v3q2+796tvu8="}`, refvalID.String()) +} diff --git a/comid/referencevalue_test.go b/comid/referencevalue_test.go index 00b89d18..89fb18ef 100644 --- a/comid/referencevalue_test.go +++ b/comid/referencevalue_test.go @@ -13,10 +13,9 @@ func Test_ReferenceValue(t *testing.T) { err := rv.Valid() assert.EqualError(t, err, "environment validation failed: environment must not be empty") - rv.Environment.Instance = NewInstance() id, err := uuid.NewUUID() require.NoError(t, err) - rv.Environment.Instance.SetUUID(id) + rv.Environment.Instance = MustNewInstanceUUID(id) err = rv.Valid() assert.EqualError(t, err, "measurements validation failed: no measurement entries") } diff --git a/comid/role.go b/comid/role.go index 1b08322b..14fc6e3b 100644 --- a/comid/role.go +++ b/comid/role.go @@ -40,6 +40,23 @@ var ( } ) +func RegisterRole(val int64, name string) error { + role := Role(val) + + if _, ok := roleToString[role]; ok { + return fmt.Errorf("role with value %d already exists", val) + } + + if _, ok := stringToRole[name]; ok { + return fmt.Errorf("role with name %q already exists", name) + } + + roleToString[role] = name + stringToRole[name] = role + + return nil +} + type Roles []Role func NewRoles() *Roles { diff --git a/comid/role_test.go b/comid/role_test.go index 4479960c..0b9d585c 100644 --- a/comid/role_test.go +++ b/comid/role_test.go @@ -210,3 +210,20 @@ func TestRoles_UnmarshalJSON_fail(t *testing.T) { assert.EqualError(t, err, tv.expectedErr) } } + +func Test_RegisterRole(t *testing.T) { + err := RegisterRole(1, "owner") + assert.EqualError(t, err, "role with value 1 already exists") + + err = RegisterRole(3, "maintainer") + assert.EqualError(t, err, `role with name "maintainer" already exists`) + + err = RegisterRole(3, "owner") + assert.NoError(t, err) + + roles := NewRoles().Add(Role(3)) + + out, err := roles.MarshalJSON() + require.NoError(t, err) + assert.Equal(t, `["owner"]`, string(out)) +} diff --git a/comid/svn.go b/comid/svn.go index ed7a26bb..26a13f5f 100644 --- a/comid/svn.go +++ b/comid/svn.go @@ -6,102 +6,245 @@ package comid import ( "encoding/json" "fmt" -) + "strconv" -type TaggedSVN uint64 -type TaggedMinSVN uint64 + "github.com/veraison/corim/encoding" + "github.com/veraison/corim/extensions" +) type SVN struct { - val interface{} + Value ISVNValue } -func (o *SVN) SetSVN(val uint64) *SVN { - if o != nil { - o.val = TaggedSVN(val) +func NewSVN(val any, typ string) (*SVN, error) { + factory, ok := svnValueRegister[typ] + if !ok { + return nil, fmt.Errorf("unexpected SVN type: %s", typ) } - return o + + return factory(val) } -func (o *SVN) SetMinSVN(val uint64) *SVN { - if o != nil { - o.val = TaggedMinSVN(val) +func MustNewSVN(val any, typ string) *SVN { + ret, err := NewSVN(val, typ) + if err != nil { + panic(err) } - return o + + return ret } func (o SVN) MarshalCBOR() ([]byte, error) { - return em.Marshal(o.val) + return em.Marshal(o.Value) } func (o *SVN) UnmarshalCBOR(data []byte) error { - var svn TaggedSVN - - if dm.Unmarshal(data, &svn) == nil { - o.val = svn - return nil - } - - var minsvn TaggedMinSVN - - if dm.Unmarshal(data, &minsvn) == nil { - o.val = svn - return nil - } - - return fmt.Errorf("unknown SVN (CBOR: %x)", data) + return dm.Unmarshal(data, &o.Value) } -type svnJSONRepr tnv - // Supported formats: // { "type": "exact-value", "value": 123 } -> SVN // { "type": "min-value", "value": 123 } -> MinSVN func (o *SVN) UnmarshalJSON(data []byte) error { - var s svnJSONRepr + var tnv encoding.TypeAndValue - if err := json.Unmarshal(data, &s); err != nil { + if err := json.Unmarshal(data, &tnv); err != nil { return fmt.Errorf("SVN decoding failure: %w", err) } - var x uint64 - if err := json.Unmarshal(s.Value, &x); err != nil { + decoded, err := NewSVN(nil, tnv.Type) + if err != nil { + return err + } + + if err := json.Unmarshal(tnv.Value, &decoded.Value); err != nil { return fmt.Errorf( - "cannot unmarshal svn or min-svn: %w", + "cannot unmarshal svn: %w", err, ) } - switch s.Type { - case "exact-value": - o.val = TaggedSVN(x) - case "min-value": - o.val = TaggedMinSVN(x) - default: - return fmt.Errorf("unknown comparison operator %s", s.Type) + if err := decoded.Value.Valid(); err != nil { + return fmt.Errorf("invalid %s: %w", tnv.Type, err) } + o.Value = decoded.Value + return nil } func (o SVN) MarshalJSON() ([]byte, error) { - var ( - v svnJSONRepr - b []byte - err error - ) - - b, err = json.Marshal(o.val) + valueBytes, err := json.Marshal(o.Value) if err != nil { return nil, err } - switch t := o.val.(type) { + + value := encoding.TypeAndValue{ + Type: o.Value.Type(), + Value: valueBytes, + } + + return json.Marshal(value) +} + +type ISVNValue interface { + extensions.ITypeChoiceValue +} + +var TaggedSVNType = "exact-value" +var TaggedMinSVNType = "min-value" + +type TaggedSVN uint64 + +func NewTaggedSVN(val any) (*SVN, error) { + var ret TaggedSVN + + if val == nil { + return &SVN{&ret}, nil + } + + switch t := val.(type) { + case string: + u, err := strconv.ParseUint(t, 10, 64) + if err != nil { + return nil, err + } + ret = TaggedSVN(u) case TaggedSVN: - v = svnJSONRepr{Type: "exact-value", Value: b} + ret = t + case *TaggedSVN: + ret = *t + case uint64: + ret = TaggedSVN(t) + case uint: + ret = TaggedSVN(t) + case int: + if t < 0 { + return nil, fmt.Errorf("SVN cannot be negative: %d", t) + } + ret = TaggedSVN(t) + case int64: + if t < 0 { + return nil, fmt.Errorf("SVN cannot be negative: %d", t) + } + ret = TaggedSVN(t) + default: + return nil, fmt.Errorf("unexpected type for SVN exact value: %T", t) + } + + return &SVN{&ret}, nil +} + +func MustNewTaggedSVN(val any) *SVN { + ret, err := NewTaggedSVN(val) + if err != nil { + panic(err) + } + + return ret +} + +func (o TaggedSVN) String() string { + return fmt.Sprint(uint64(o)) +} + +func (o TaggedSVN) Type() string { + return TaggedSVNType +} + +func (o TaggedSVN) Valid() error { + return nil +} + +type TaggedMinSVN uint64 + +func NewTaggedMinSVN(val any) (*SVN, error) { + var ret TaggedMinSVN + + if val == nil { + return &SVN{&ret}, nil + } + + switch t := val.(type) { + case string: + u, err := strconv.ParseUint(t, 10, 64) + if err != nil { + return nil, err + } + ret = TaggedMinSVN(u) case TaggedMinSVN: - v = svnJSONRepr{Type: "min-value", Value: b} + ret = t + case *TaggedMinSVN: + ret = *t + case uint64: + ret = TaggedMinSVN(t) + case uint: + ret = TaggedMinSVN(t) + case int: + if t < 0 { + return nil, fmt.Errorf("SVN cannot be negative: %d", t) + } + ret = TaggedMinSVN(t) + case int64: + if t < 0 { + return nil, fmt.Errorf("SVN cannot be negative: %d", t) + } + ret = TaggedMinSVN(t) default: - return nil, fmt.Errorf("unknown SVN type: %T", t) + return nil, fmt.Errorf("unexpected type for SVN min value: %T", t) + } + + return &SVN{&ret}, nil +} + +func MustNewTaggedMinSVN(val any) *SVN { + ret, err := NewTaggedMinSVN(val) + if err != nil { + panic(err) } - return json.Marshal(v) + return ret +} + +func (o TaggedMinSVN) String() string { + return fmt.Sprint(uint64(o)) +} + +func (o TaggedMinSVN) Type() string { + return TaggedMinSVNType +} + +func (o TaggedMinSVN) Valid() error { + return nil +} + +type ISVNFactory func(any) (*SVN, error) + +var svnValueRegister = map[string]ISVNFactory{ + TaggedSVNType: NewTaggedSVN, + TaggedMinSVNType: NewTaggedMinSVN, +} + +// RegisterSVNType registers a new ISVNValue implementation +// (created by the provided ISVNFactory) under the specified type name +// and CBOR tag. +func RegisterSVNType(tag uint64, factory ISVNFactory) error { + + nilVal, err := factory(nil) + if err != nil { + return err + } + + typ := nilVal.Value.Type() + if _, exists := svnValueRegister[typ]; exists { + return fmt.Errorf("SVN type with name %q already exists", typ) + } + + if err := registerCOMIDTag(tag, nilVal.Value); err != nil { + return err + } + + svnValueRegister[typ] = factory + + return nil } diff --git a/comid/svn_test.go b/comid/svn_test.go new file mode 100644 index 00000000..d3fe5f49 --- /dev/null +++ b/comid/svn_test.go @@ -0,0 +1,182 @@ +package comid + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_NewSVN(t *testing.T) { + for _, tv := range []struct { + Name string + Input any + Expected uint64 + Err string + }{ + { + Name: "string ok", + Input: "7", + Expected: 7, + Err: "", + }, + { + Name: "string err", + Input: "test", + Expected: 0, + Err: `strconv.ParseUint: parsing "test": invalid syntax`, + }, + { + Name: "uint", + Input: uint(7), + Expected: 7, + Err: "", + }, + { + Name: "uint64", + Input: uint64(7), + Expected: 7, + Err: "", + }, + { + Name: "int ok", + Input: 7, + Expected: 7, + Err: "", + }, + { + Name: "int not ok", + Input: -7, + Expected: 0, + Err: "SVN cannot be negative: -7", + }, + { + Name: "int64 ok", + Input: int64(7), + Expected: 7, + Err: "", + }, + { + Name: "int64 not ok", + Input: int64(-7), + Expected: 0, + Err: "SVN cannot be negative: -7", + }, + { + Name: "nil", + Input: nil, + Expected: 0, + Err: "", + }, + } { + t.Run(tv.Name, func(t *testing.T) { + ret, err := NewSVN(tv.Input, "exact-value") + exact := TaggedSVN(tv.Expected) + expected := SVN{&exact} + + if tv.Err != "" { + assert.EqualError(t, err, tv.Err) + } else { + assert.Equal(t, &expected, ret) + } + + retMin, err := NewSVN(tv.Input, "min-value") + min := TaggedMinSVN(tv.Expected) + expected = SVN{&min} + + if tv.Err != "" { + assert.EqualError(t, err, tv.Err) + } else { + assert.Equal(t, &expected, retMin) + } + }) + } + + in := TaggedSVN(7) + + _, err := NewSVN(in, "exact-value") + assert.NoError(t, err) + + _, err = NewSVN(&in, "exact-value") + assert.NoError(t, err) + + _, err = NewSVN(true, "exact-value") + assert.EqualError(t, err, "unexpected type for SVN exact value: bool") + + inMin := TaggedMinSVN(7) + + _, err = NewSVN(inMin, "min-value") + assert.NoError(t, err) + + _, err = NewSVN(&inMin, "min-value") + assert.NoError(t, err) + + _, err = NewSVN(true, "min-value") + assert.EqualError(t, err, "unexpected type for SVN min value: bool") + + _, err = NewSVN(true, "test") + assert.EqualError(t, err, "unexpected SVN type: test") + + ret := MustNewSVN(7, "exact-value") + assert.NotNil(t, ret) + + assert.Panics(t, func() { MustNewSVN(true, "exact-value") }) +} + +func TestSVN_JSON(t *testing.T) { + var v SVN + + err := v.UnmarshalJSON([]byte(`{"type":"exact-value","value":2.3}`)) + assert.EqualError(t, err, "cannot unmarshal svn: json: cannot unmarshal number 2.3 into Go value of type comid.TaggedSVN") + + err = v.UnmarshalJSON([]byte(`{"type":"test","value":7}`)) + assert.EqualError(t, err, "unexpected SVN type: test") + + err = v.UnmarshalJSON([]byte(`@@@`)) + assert.EqualError(t, err, "SVN decoding failure: invalid character '@' looking for beginning of value") + +} + +type testSVN uint64 + +func newTestSVN(val any) (*SVN, error) { + v := testSVN(7) + return &SVN{&v}, nil +} + +func (o testSVN) Type() string { + return "test-value" +} + +func (o testSVN) String() string { + return "test" +} + +func (o testSVN) Valid() error { + return nil +} + +type testSVNBadType struct { + testSVN +} + +func newTestSVNBadType(val any) (*SVN, error) { + v := testSVNBadType{testSVN(7)} + return &SVN{&v}, nil +} + +func (o testSVNBadType) Type() string { + return "min-value" +} + +func Test_RegisterSVNType(t *testing.T) { + err := RegisterSVNType(32, newTestSVN) + assert.EqualError(t, err, "tag 32 is already registered") + + err = RegisterSVNType(99995, newTestSVNBadType) + assert.EqualError(t, err, `SVN type with name "min-value" already exists`) + + err = RegisterSVNType(99995, newTestSVN) + require.NoError(t, err) + +} diff --git a/comid/ueid.go b/comid/ueid.go index 765d1b86..cf64ffa2 100644 --- a/comid/ueid.go +++ b/comid/ueid.go @@ -4,18 +4,17 @@ package comid import ( - "encoding/json" + "encoding/base64" "fmt" "github.com/veraison/eat" ) +const UEIDType = "ueid" + // UEID is an Unique Entity Identifier type UEID eat.UEID -// TaggedUEID is an alias to allow automatic tagging of an UEID type -type TaggedUEID UEID - func (o UEID) Empty() bool { return len(o) == 0 } @@ -28,25 +27,65 @@ func (o UEID) Valid() error { return nil } -// UnmarshalJSON deserializes the supplied string into the UEID target -func (o *UEID) UnmarshalJSON(data []byte) error { - var b []byte +func (o UEID) String() string { + return base64.StdEncoding.EncodeToString(o) +} + +// TaggedUEID is an alias to allow automatic tagging of an UEID type +type TaggedUEID UEID + +func NewTaggedUEID(val any) (*TaggedUEID, error) { + var ret TaggedUEID - if err := json.Unmarshal(data, &b); err != nil { - return err + if val == nil { + return &ret, nil } - u := UEID(b) + switch t := val.(type) { + case string: + b, err := base64.StdEncoding.DecodeString(t) + if err != nil { + return nil, fmt.Errorf("bad UEID: %w", err) + } - if err := u.Valid(); err != nil { - return err + ret = TaggedUEID(b) + case []byte: + ret = TaggedUEID(t) + case TaggedUEID: + ret = append(ret, t...) + case *TaggedUEID: + ret = append(ret, *t...) + case UEID: + ret = append(ret, t...) + case *UEID: + ret = append(ret, *t...) + case eat.UEID: + ret = append(ret, t...) + case *eat.UEID: + ret = append(ret, *t...) + default: + return nil, fmt.Errorf("unexpeted type for UEID: %T", t) } - *o = u + if err := ret.Valid(); err != nil { + return nil, err + } - return nil + return &ret, nil +} + +func (o TaggedUEID) Valid() error { + return UEID(o).Valid() +} + +func (o TaggedUEID) String() string { + return UEID(o).String() +} + +func (o TaggedUEID) Type() string { + return "ueid" } -func (o UEID) MarshalJSON() ([]byte, error) { - return json.Marshal([]byte(o)) +func (o TaggedUEID) Bytes() []byte { + return []byte(o) } diff --git a/comid/ueid_test.go b/comid/ueid_test.go new file mode 100644 index 00000000..a119ad44 --- /dev/null +++ b/comid/ueid_test.go @@ -0,0 +1,30 @@ +package comid + +import ( + "encoding/base64" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_NewTaggedUEID(t *testing.T) { + ueid := UEID(TestUEID) + tagged := TaggedUEID(TestUEID) + bytes := MustHexDecode(t, TestUEIDString) + + for _, v := range []any{ + TestUEID, + &TestUEID, + ueid, + &ueid, + tagged, + &tagged, + bytes, + base64.StdEncoding.EncodeToString(bytes), + } { + ret, err := NewTaggedUEID(v) + require.NoError(t, err) + assert.Equal(t, []byte(TestUEID), ret.Bytes()) + } +} diff --git a/comid/uuid.go b/comid/uuid.go index c5c78d61..b09171dd 100644 --- a/comid/uuid.go +++ b/comid/uuid.go @@ -10,12 +10,11 @@ import ( "github.com/google/uuid" ) +const UUIDType = "uuid" + // UUID represents an Universally Unique Identifier (UUID, see RFC4122) type UUID uuid.UUID -// TaggedUUID is an alias to allow automatic tagging of a UUID type -type TaggedUUID UUID - // ParseUUID parses the supplied string into a UUID func ParseUUID(s string) (UUID, error) { v, err := uuid.Parse(s) @@ -64,3 +63,89 @@ func (o *UUID) UnmarshalJSON(data []byte) error { func (o UUID) MarshalJSON() ([]byte, error) { return json.Marshal(o.String()) } + +// TaggedUUID is an alias to allow automatic tagging of a UUID type +type TaggedUUID UUID + +func NewTaggedUUID(val any) (*TaggedUUID, error) { + var ret TaggedUUID + + switch t := val.(type) { + case string: + u, err := ParseUUID(t) + if err != nil { + return nil, fmt.Errorf("bad UUID: %w", err) + } + ret = TaggedUUID(u) + case []byte: + if len(t) != 16 { + return nil, fmt.Errorf( + "unexpected size for UUID: expected 16 bytes, found %d", + len(t), + ) + } + + copy(ret[:], t) + case TaggedUUID: + copy(ret[:], t[:]) + case *TaggedUUID: + copy(ret[:], (*t)[:]) + case UUID: + copy(ret[:], t[:]) + case *UUID: + copy(ret[:], (*t)[:]) + case uuid.UUID: + copy(ret[:], t[:]) + case *uuid.UUID: + copy(ret[:], (*t)[:]) + default: + return nil, fmt.Errorf("unexpected type for UUID: %T", t) + } + + if err := ret.Valid(); err != nil { + return nil, err + } + + return &ret, nil +} + +// String returns a string representation of the binary UUID +func (o TaggedUUID) String() string { + return UUID(o).String() +} + +func (o TaggedUUID) Valid() error { + return UUID(o).Valid() +} + +// Type returns a string containing type name. This is part of the +// ITypeChoiceValue implementation. +func (o TaggedUUID) Type() string { + return "uuid" +} + +// Bytes returns a []byte containing the raw UUID bytes +func (o TaggedUUID) Bytes() []byte { + return o[:] +} + +func (o TaggedUUID) MarshalJSON() ([]byte, error) { + temp := o.String() + return json.Marshal(temp) +} + +func (o *TaggedUUID) UnmarshalJSON(data []byte) error { + var temp string + if err := json.Unmarshal(data, &temp); err != nil { + return err + } + + u, err := ParseUUID(temp) + if err != nil { + return fmt.Errorf("bad UUID: %w", err) + } + + *o = TaggedUUID(u) + + return nil +} diff --git a/comid/uuid_test.go b/comid/uuid_test.go new file mode 100644 index 00000000..5308854a --- /dev/null +++ b/comid/uuid_test.go @@ -0,0 +1,24 @@ +package comid + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestUUID_JSON(t *testing.T) { + val := TaggedUUID(TestUUID) + expected := fmt.Sprintf(`"%s"`, val.String()) + + out, err := val.MarshalJSON() + require.NoError(t, err) + assert.Equal(t, expected, string(out)) + + var outUUID TaggedUUID + + err = outUUID.UnmarshalJSON(out) + require.NoError(t, err) + assert.Equal(t, val, outUUID) +} diff --git a/corim/cbor.go b/corim/cbor.go index ec15e95f..17464d1d 100644 --- a/corim/cbor.go +++ b/corim/cbor.go @@ -4,6 +4,7 @@ package corim import ( + "fmt" "reflect" cbor "github.com/fxamacker/cbor/v2" @@ -18,13 +19,13 @@ var ( var ( CoswidTag = []byte{0xd9, 0x01, 0xf9} // 505() ComidTag = []byte{0xd9, 0x01, 0xfa} // 506() -) -func corimTags() cbor.TagSet { - corimTagsMap := map[uint64]interface{}{ + corimTagsMap = map[uint64]interface{}{ 32: comid.TaggedURI(""), } +) +func corimTags() cbor.TagSet { opts := cbor.TagOptions{ EncTag: cbor.EncTagRequired, DecTag: cbor.DecTagRequired, @@ -57,6 +58,28 @@ func initCBORDecMode() (dm cbor.DecMode, err error) { return decOpt.DecModeWithTags(corimTags()) } +func registerCORIMTag(tag uint64, t interface{}) error { + if _, exists := corimTagsMap[tag]; exists { + return fmt.Errorf("tag %d is already registered", tag) + } + + corimTagsMap[tag] = t + + var err error + + em, err = initCBOREncMode() + if err != nil { + return err + } + + dm, err = initCBORDecMode() + if err != nil { + return err + } + + return nil +} + func init() { if emError != nil { panic(emError) diff --git a/corim/entity.go b/corim/entity.go index b305114b..453068cf 100644 --- a/corim/entity.go +++ b/corim/entity.go @@ -4,16 +4,215 @@ package corim import ( + "encoding/json" + "errors" "fmt" + "unicode/utf8" "github.com/veraison/corim/comid" "github.com/veraison/corim/encoding" "github.com/veraison/corim/extensions" ) +type EntityName struct { + Value IEntityName +} + +func NewEntityName(val any, typ string) (*EntityName, error) { + factory, ok := entityNameValueRegister[typ] + if !ok { + return nil, fmt.Errorf("unexpected entity name type: %s", typ) + } + + return factory(val) +} + +func (o EntityName) String() string { + return o.Value.String() +} + +func (o EntityName) Valid() error { + if o.Value == nil { + return errors.New("empty entity name") + } + + return o.Value.Valid() +} + +func (o EntityName) MarshalCBOR() ([]byte, error) { + if err := o.Valid(); err != nil { + return nil, err + } + + return em.Marshal(o.Value) +} + +func (o *EntityName) UnmarshalCBOR(data []byte) error { + if len(data) == 0 { + return errors.New("empty") + } + + majorType := (data[0] & 0xe0) >> 5 + if majorType == 3 { // text string + var text string + + if err := dm.Unmarshal(data, &text); err != nil { + return err + } + + name := StringEntityName(text) + o.Value = &name + + return nil + } + + return dm.Unmarshal(data, &o.Value) +} + +func (o EntityName) MarshalJSON() ([]byte, error) { + if err := o.Valid(); err != nil { + return nil, err + } + + if o.Value.Type() == extensions.StringType { + return json.Marshal(o.Value.String()) + } + + valueBytes, err := json.Marshal(o.Value) + if err != nil { + return nil, err + } + + value := encoding.TypeAndValue{ + Type: o.Value.Type(), + Value: valueBytes, + } + + return json.Marshal(value) +} + +func (o *EntityName) UnmarshalJSON(data []byte) error { + var text string + if err := json.Unmarshal(data, &text); err == nil { + *o = *MustNewStringEntityName(text) + return nil + } + + var tnv encoding.TypeAndValue + + if err := json.Unmarshal(data, &tnv); err != nil { + return fmt.Errorf("entity name decoding failure: %w", err) + } + + decoded, err := NewEntityName(nil, tnv.Type) + if err != nil { + return err + } + + if err := json.Unmarshal(tnv.Value, &decoded.Value); err != nil { + return fmt.Errorf( + "cannot unmarshal entity name: %w", + err, + ) + } + + if err := decoded.Value.Valid(); err != nil { + return fmt.Errorf("invalid %s: %w", tnv.Type, err) + } + + o.Value = decoded.Value + + return nil +} + +type IEntityName interface { + extensions.ITypeChoiceValue +} + +type IEntityNameFactory func(any) (*EntityName, error) + +var entityNameValueRegister = map[string]IEntityNameFactory{ + extensions.StringType: NewStringEntityName, +} + +// RegisterEntityNameType registers a new IEntityNameValue implementation +// (created by the provided IEntityNameFactory) under the specified type name +// and CBOR tag. +func RegisterEntityNameType(tag uint64, factory IEntityNameFactory) error { + + nilVal, err := factory(nil) + if err != nil { + return err + } + + typ := nilVal.Value.Type() + if _, exists := entityNameValueRegister[typ]; exists { + return fmt.Errorf("entity name type with name %q already exists", typ) + } + + if err := registerCORIMTag(tag, nilVal.Value); err != nil { + return err + } + + entityNameValueRegister[typ] = factory + + return nil +} + +type StringEntityName string + +func NewStringEntityName(val any) (*EntityName, error) { + var ret StringEntityName + + if val == nil { + ret = StringEntityName("") + return &EntityName{&ret}, nil + } + + switch t := val.(type) { + case string: + ret = StringEntityName(t) + case []byte: + if !utf8.Valid(t) { + return nil, errors.New("bytes do not form a valid UTF-8 string") + } + + ret = StringEntityName(t) + default: + return nil, fmt.Errorf("unexpected type for string entity name: %T", t) + } + + return &EntityName{&ret}, nil +} + +func MustNewStringEntityName(val any) *EntityName { + ret, err := NewStringEntityName(val) + if err != nil { + panic(err) + } + + return ret +} + +func (o StringEntityName) String() string { + return string(o) +} + +func (o StringEntityName) Type() string { + return extensions.StringType +} + +func (o StringEntityName) Valid() error { + if o == "" { + return errors.New("empty entity-name") + } + + return nil +} + // Entity stores an entity-map capable of CBOR and JSON serializations. type Entity struct { - EntityName string `cbor:"0,keyasint" json:"name"` + EntityName *EntityName `cbor:"0,keyasint" json:"name"` RegID *comid.TaggedURI `cbor:"1,keyasint,omitempty" json:"regid,omitempty"` Roles Roles `cbor:"2,keyasint" json:"roles"` @@ -35,12 +234,13 @@ func (o *Entity) GetExtensions() extensions.IExtensionsValue { } // SetEntityName is used to set the EntityName field of Entity using supplied name -func (o *Entity) SetEntityName(name string) *Entity { +func (o *Entity) SetEntityName(name any) *Entity { + if o != nil { if name == "" { return nil } - o.EntityName = name + o.EntityName = MustNewStringEntityName(name) } return o } @@ -74,10 +274,14 @@ func (o *Entity) SetRoles(roles ...Role) *Entity { // Valid checks for validity of the fields within each Entity func (o Entity) Valid() error { - if o.EntityName == "" { + if o.EntityName == nil { return fmt.Errorf("invalid entity: empty entity-name") } + if err := o.EntityName.Valid(); err != nil { + return fmt.Errorf("invalid entity: %w", err) + } + if o.RegID != nil && o.RegID.Empty() { return fmt.Errorf("invalid entity: empty reg-id") } diff --git a/corim/entity_test.go b/corim/entity_test.go index 457b3770..07fe6849 100644 --- a/corim/entity_test.go +++ b/corim/entity_test.go @@ -4,6 +4,8 @@ package corim import ( + "errors" + "fmt" "testing" "github.com/stretchr/testify/assert" @@ -21,7 +23,7 @@ func TestEntity_Valid_uninitialized(t *testing.T) { func TestEntity_Valid_empty_name(t *testing.T) { tv := Entity{ - EntityName: "", + EntityName: MustNewStringEntityName(""), } err := tv.Valid() @@ -33,7 +35,7 @@ func TestEntity_Valid_non_nil_empty_URI(t *testing.T) { emptyRegID := comid.TaggedURI("") tv := Entity{ - EntityName: "ACME Ltd.", + EntityName: MustNewStringEntityName("ACME Ltd."), RegID: &emptyRegID, } @@ -46,7 +48,7 @@ func TestEntity_Valid_missing_roles(t *testing.T) { regID := comid.TaggedURI("http://acme.example") tv := Entity{ - EntityName: "ACME Ltd.", + EntityName: MustNewStringEntityName("ACME Ltd."), RegID: ®ID, } @@ -59,7 +61,7 @@ func TestEntity_Valid_unknown_role(t *testing.T) { regID := comid.TaggedURI("http://acme.example") tv := Entity{ - EntityName: "ACME Ltd.", + EntityName: MustNewStringEntityName("ACME Ltd."), RegID: ®ID, Roles: Roles{Role(666)}, } @@ -92,3 +94,176 @@ func TestEntities_Valid_empty(t *testing.T) { err := es.Valid() assert.EqualError(t, err, "entity at index 0: invalid entity: empty entity-name") } + +type testEntityName uint64 + +func newTestEntityName(val any) (*EntityName, error) { + if val == nil { + v := testEntityName(0) + return &EntityName{&v}, nil + } + + u, ok := val.(uint64) + if !ok { + return nil, errors.New("must be uint64") + } + + v := testEntityName(u) + return &EntityName{&v}, nil +} + +func (o testEntityName) Type() string { + return "test" +} + +func (o testEntityName) String() string { + return fmt.Sprint(uint64(o)) +} + +func (o testEntityName) Valid() error { + return nil +} + +type testEntityNameBadType struct { + testEntityName +} + +func newTestEntityNameBadType(val any) (*EntityName, error) { + v := testEntityNameBadType{testEntityName(7)} + return &EntityName{&v}, nil +} + +func (o testEntityNameBadType) Type() string { + return "string" +} + +func Test_RegisterEntityNameType(t *testing.T) { + err := RegisterEntityNameType(32, newTestEntityName) + assert.EqualError(t, err, "tag 32 is already registered") + + err = RegisterEntityNameType(99994, newTestEntityNameBadType) + assert.EqualError(t, err, `entity name type with name "string" already exists`) + + registerTestEntityNameType(t) +} + +// Since there only one, untagged, entity name type in the core spec, we use +// the test type define above in order to test the marshalling code works +// properly. Since global enviroment is not reset when running multiple tests, +// we cannot simply call RegisterEntityNameType() inside each test that relies +// on the test type, as that will cause the "tag already registred" error. On +// the other hand, we do not want to create inter-test dependencies by relying +// that the test registering the type is run before the others that rely on it. +// To get around this, use this global flag to only register the test type if a +// previous test hasn't already done so. +var testEntityNameTypeRegistered = false + +func registerTestEntityNameType(t *testing.T) { + if !testEntityNameTypeRegistered { + err := RegisterEntityNameType(99994, newTestEntityName) + require.NoError(t, err) + + testEntityNameTypeRegistered = true + } +} + +func TestEntityName_CBOR(t *testing.T) { + registerTestEntityNameType(t) + + for _, tv := range []struct { + Value any + Type string + ExpectedBytes []byte + ExpectedString string + }{ + { + Value: "test", + Type: "string", + ExpectedBytes: []byte{ + 0x64, // tstr(4) + 0x74, 0x65, 0x73, 0x74, // "test" + }, + ExpectedString: "test", + }, + { + Value: uint64(7), + Type: "test", + ExpectedBytes: []byte{ + 0xda, 0x0, 0x1, 0x86, 0x9a, // tag 99994 + 0x07, // unsigned int(7) + }, + ExpectedString: "7", + }, + } { + t.Run(tv.Type, func(t *testing.T) { + en, err := NewEntityName(tv.Value, tv.Type) + require.NoError(t, err) + + data, err := en.MarshalCBOR() + require.NoError(t, err) + + assert.Equal(t, tv.ExpectedBytes, data) + + var out EntityName + + err = out.UnmarshalCBOR(data) + require.NoError(t, err) + + assert.Equal(t, tv.ExpectedString, out.String()) + }) + } +} + +func TestEntityName_JSON(t *testing.T) { + registerTestEntityNameType(t) + + for _, tv := range []struct { + Value any + Type string + ExpectedBytes []byte + ExpectedString string + }{ + { + Value: "test", + Type: "string", + ExpectedBytes: []byte(`"test"`), + ExpectedString: "test", + }, + { + Value: uint64(7), + Type: "test", + ExpectedBytes: []byte(`{"type":"test","value":7}`), + ExpectedString: "7", + }, + } { + t.Run(tv.Type, func(t *testing.T) { + en, err := NewEntityName(tv.Value, tv.Type) + require.NoError(t, err) + + data, err := en.MarshalJSON() + require.NoError(t, err) + + assert.Equal(t, tv.ExpectedBytes, data) + + var out EntityName + + err = out.UnmarshalJSON(data) + require.NoError(t, err) + + assert.Equal(t, tv.ExpectedString, out.String()) + }) + } +} + +func Test_NewStringEntityName(t *testing.T) { + out, err := NewStringEntityName(nil) + require.NoError(t, err) + assert.EqualError(t, out.Valid(), "empty entity-name") + + out, err = NewStringEntityName([]byte("test")) + require.NoError(t, err) + assert.Equal(t, "test", out.String()) + + _, err = NewStringEntityName(7) + assert.EqualError(t, err, "unexpected type for string entity name: int") +} diff --git a/corim/extensions_test.go b/corim/extensions_test.go index 653a13e1..811d8e4d 100644 --- a/corim/extensions_test.go +++ b/corim/extensions_test.go @@ -17,7 +17,7 @@ type TestExtensions struct { } func (o TestExtensions) ValidEntity(ent *Entity) error { - if ent.EntityName != "Futurama" { + if ent.EntityName.String() != "Futurama" { return errors.New(`EntityName must be "Futurama"`) // nolint:golint } @@ -78,7 +78,7 @@ func TestEntityExtensions_CBOR(t *testing.T) { err := cbor.Unmarshal(data, &ent) assert.NoError(t, err) - assert.Equal(t, ent.EntityName, "acme") + assert.Equal(t, ent.EntityName.String(), "acme") address, err := ent.Get("address") require.NoError(t, err) diff --git a/corim/role.go b/corim/role.go index 9bfc1276..d81d3cf7 100644 --- a/corim/role.go +++ b/corim/role.go @@ -24,6 +24,23 @@ var ( } ) +func RegisterRole(val int64, name string) error { + role := Role(val) + + if _, ok := roleToString[role]; ok { + return fmt.Errorf("role with value %d already exists", val) + } + + if _, ok := stringToRole[name]; ok { + return fmt.Errorf("role with name %q already exists", name) + } + + roleToString[role] = name + stringToRole[name] = role + + return nil +} + type Roles []Role func NewRoles() *Roles { @@ -44,7 +61,8 @@ func (o *Roles) Add(roles ...Role) *Roles { } func isRole(r Role) bool { - return r == RoleManifestCreator + _, ok := roleToString[r] + return ok } // Valid iterates over the range of individual roles to check for validity diff --git a/corim/role_test.go b/corim/role_test.go index 66df92b0..d2d13cd3 100644 --- a/corim/role_test.go +++ b/corim/role_test.go @@ -7,6 +7,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestRoles_ToJSON_ok(t *testing.T) { @@ -77,3 +78,20 @@ func TestRoles_FromJSON_fail(t *testing.T) { assert.EqualError(t, err, tv.expectedErr) } } + +func Test_RegisterRole(t *testing.T) { + err := RegisterRole(1, "owner") + assert.EqualError(t, err, "role with value 1 already exists") + + err = RegisterRole(3, "manifestCreator") + assert.EqualError(t, err, `role with name "manifestCreator" already exists`) + + err = RegisterRole(3, "owner") + assert.NoError(t, err) + + roles := NewRoles().Add(Role(3)) + + out, err := roles.MarshalJSON() + require.NoError(t, err) + assert.Equal(t, `["owner"]`, string(out)) +} diff --git a/corim/unsignedcorim_test.go b/corim/unsignedcorim_test.go index d9b4d105..fb75c192 100644 --- a/corim/unsignedcorim_test.go +++ b/corim/unsignedcorim_test.go @@ -8,7 +8,6 @@ import ( "testing" "time" - "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/veraison/corim/comid" @@ -182,7 +181,7 @@ func TestUnsignedCorim_Valid_ok(t *testing.T) { AddAttestVerifKey( comid.AttestVerifKey{ Environment: comid.Environment{ - Instance: comid.NewInstanceUUID(uuid.UUID(comid.TestUUID)), + Instance: comid.MustNewInstanceUUID(comid.TestUUID), }, VerifKeys: *comid.NewCryptoKeys(). Add( @@ -264,7 +263,7 @@ func TestUnsignedCorim_AddEntity_full(t *testing.T) { expected := UnsignedCorim{ Entities: &Entities{ Entity{ - EntityName: name, + EntityName: MustNewStringEntityName(name), Roles: Roles{role}, RegID: &taggedRegID, }, diff --git a/encoding/json.go b/encoding/json.go new file mode 100644 index 00000000..3bd756f6 --- /dev/null +++ b/encoding/json.go @@ -0,0 +1,39 @@ +package encoding + +import ( + "encoding/json" + "errors" + "fmt" +) + +// TypeAndValue stores a JSON object with two attributes: a string "type" +// and a generic "value" (string) defined by type. This type is used in +// a few places to implement the choice types that CBOR handles using tags. +type TypeAndValue struct { + Type string `json:"type"` + Value json.RawMessage `json:"value"` +} + +func (o *TypeAndValue) UnmarshalJSON(data []byte) error { + var temp struct { + Type string `json:"type"` + Value json.RawMessage `json:"value"` + } + + if err := json.Unmarshal(data, &temp); err != nil { + return err + } + + if temp.Type == "" { + return errors.New("type not set") + } + + if len(temp.Value) == 0 { + return fmt.Errorf("no value provided for %s", temp.Type) + } + + o.Type = temp.Type + o.Value = temp.Value + + return nil +} diff --git a/encoding/json_test.go b/encoding/json_test.go new file mode 100644 index 00000000..82ad6790 --- /dev/null +++ b/encoding/json_test.go @@ -0,0 +1,38 @@ +package encoding + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestTypeAndValue_UnmarshalJSON(t *testing.T) { + for _, tv := range []struct { + Input string + Expected TypeAndValue + Err string + }{ + { + Input: `{"type": "test", "value": "test"}`, + Expected: TypeAndValue{Type: "test", Value: []byte(`"test"`)}, + }, + { + Input: `{"type": "test"}`, + Err: "no value provided for test", + }, + { + Input: `{"value": "test"}`, + Err: "type not set", + }, + } { + var out TypeAndValue + err := out.UnmarshalJSON([]byte(tv.Input)) + + if tv.Err != "" { + assert.EqualError(t, err, tv.Err) + } else { + assert.NoError(t, err) + assert.Equal(t, tv.Expected, out) + } + } +} diff --git a/extensions/typechoice.go b/extensions/typechoice.go new file mode 100644 index 00000000..a942b25f --- /dev/null +++ b/extensions/typechoice.go @@ -0,0 +1,13 @@ +package extensions + +var StringType = "string" + +type ITypeChoiceValue interface { + // String returns the string representation of the ITypeChoiceValue. + String() string + // Valid returns an error if validation of the ITypeChoiceValue fails, + // or nil if it succeeds. + Valid() error + // Type returns the type name of this ITypeChoiceValue implementation. + Type() string +}