diff --git a/comid/attestverifkey_test.go b/comid/attestverifkey_test.go index 659a0abc..dd321e2d 100644 --- a/comid/attestverifkey_test.go +++ b/comid/attestverifkey_test.go @@ -23,16 +23,17 @@ func TestAttestVerifKey_Valid_empty(t *testing.T) { testerr: "environment validation failed: environment must not be empty", }, { - env: Environment{Instance: NewInstanceUEID(TestUEID)}, + env: Environment{Instance: MustNewInstanceUEID(TestUEID)}, verifkey: CryptoKeys{}, testerr: "verification keys validation failed: no keys to validate", }, { - env: Environment{Instance: NewInstanceUEID(TestUEID)}, + env: Environment{Instance: MustNewInstanceUEID(TestUEID)}, verifkey: CryptoKeys{&invalidKey}, testerr: "verification keys validation failed: invalid key at index 0: key value not set", }, } + for _, tv := range tvs { av := AttestVerifKey{Environment: tv.env, VerifKeys: tv.verifkey} err := av.Valid() diff --git a/comid/cbor.go b/comid/cbor.go index c44d518d..2de1a2b8 100644 --- a/comid/cbor.go +++ b/comid/cbor.go @@ -4,6 +4,7 @@ package comid import ( + "fmt" "reflect" cbor "github.com/fxamacker/cbor/v2" @@ -12,10 +13,8 @@ import ( var ( em, emError = initCBOREncMode() dm, dmError = initCBORDecMode() -) -func comidTags() cbor.TagSet { - comidTagsMap := map[uint64]interface{}{ + comidTagsMap = map[uint64]interface{}{ 32: TaggedURI(""), 37: TaggedUUID{}, 111: TaggedOID{}, @@ -37,7 +36,9 @@ func comidTags() cbor.TagSet { 601: TaggedPSARefValID{}, 602: TaggedCCAPlatformConfigID(""), } +) +func comidTags() cbor.TagSet { opts := cbor.TagOptions{ EncTag: cbor.EncTagRequired, DecTag: cbor.DecTagRequired, @@ -69,6 +70,28 @@ func initCBORDecMode() (dm cbor.DecMode, err error) { return decOpt.DecModeWithTags(comidTags()) } +func registerCOMIDTag(tag uint64, t interface{}) error { + if _, exists := comidTagsMap[tag]; exists { + return fmt.Errorf("tag %d is already registered", tag) + } + + comidTagsMap[tag] = t + + var err error + + em, err = initCBOREncMode() + if err != nil { + return err + } + + dm, err = initCBORDecMode() + if err != nil { + return err + } + + return nil +} + func init() { if emError != nil { panic(emError) diff --git a/comid/class.go b/comid/class.go index 085e703b..ab414740 100644 --- a/comid/class.go +++ b/comid/class.go @@ -23,39 +23,33 @@ type Class struct { // NewClassUUID instantiates a new Class object with the specified UUID as // identifier func NewClassUUID(uuid UUID) *Class { - c := Class{ - ClassID: &ClassID{}, - } - - if c.ClassID.SetUUID(uuid) == nil { + classID, err := NewUUIDClassID(uuid) + if err != nil { return nil } - return &c + + return &Class{ClassID: classID} } // NewClassImplID instantiates a new Class object that identifies the specified PSA // Implementation ID func NewClassImplID(implID ImplID) *Class { - c := Class{ - ClassID: &ClassID{}, - } - - if c.ClassID.SetImplID(implID) == nil { + classID, err := NewImplIDClassID(implID) + if err != nil { return nil } - return &c + + return &Class{ClassID: classID} } // NewClassOID instantiates a new Class object that identifies the OID func NewClassOID(oid string) *Class { - c := Class{ - ClassID: &ClassID{}, - } - - if c.ClassID.SetOID(oid) == nil { + classID, err := NewOIDClassID(oid) + if err != nil { return nil } - return &c + + return &Class{ClassID: classID} } // SetVendor sets the vendor metadata to the supplied string @@ -131,7 +125,7 @@ func (o *Class) SetIndex(index uint64) *Class { // Valid checks the non-empty<> constraint on the map func (o Class) Valid() error { // check non-empty<{ ... }> - if (o.ClassID == nil || o.ClassID.Unset()) && + if (o.ClassID == nil || !o.ClassID.IsSet()) && o.Vendor == nil && o.Model == nil && o.Layer == nil && o.Index == nil { return fmt.Errorf("class must not be empty") } diff --git a/comid/classid.go b/comid/classid.go index 760f8e5d..47b7d396 100644 --- a/comid/classid.go +++ b/comid/classid.go @@ -6,244 +6,319 @@ package comid import ( "encoding/base64" "encoding/json" + "errors" "fmt" -) - -// ClassID represents a $class-id-type-choice, which can be one of TaggedUUID, -// TaggedOID, or TaggedImplID (PSA-specific extension) -type ClassID struct { - val interface{} -} -type ClassIDType uint16 + "github.com/veraison/corim/encoding" + "github.com/veraison/corim/extensions" +) const ( - ClassIDTypeUUID = ClassIDType(iota) - ClassIDTypeImplID - ClassIDTypeOID + IntType = "int" + OIDType = "oid" - ClassIDTypeUnknown = ^ClassIDType(0) + ImplIDType = "psa.impl-id" ) -// SetUUID sets the value of the targed ClassID to the supplied UUID -func (o *ClassID) SetUUID(uuid UUID) *ClassID { - if o != nil { - o.val = TaggedUUID(uuid) - } - return o -} - -type ImplID [32]byte -type TaggedImplID ImplID +type IClassIDValue interface { + extensions.ITypeChoiceValue -func (o ImplID) MarshalJSON() ([]byte, error) { - return json.Marshal(o[:]) + Bytes() []byte } -func (o *ImplID) UnmarshalJSON(data []byte) error { - var b []byte - - if err := json.Unmarshal(data, &b); err != nil { - return fmt.Errorf("bad ImplID: %w", err) - } +// ClassID represents a $class-id-type-choice, which can be one of TaggedUUID, +// TaggedOID, or TaggedImplID (PSA-specific extension) +type ClassID struct { + Value IClassIDValue +} - if nb := len(b); nb != 32 { - return fmt.Errorf("bad ImplID format: got %d bytes, want 32", nb) +func (o ClassID) Valid() error { + if o.Value == nil { + return errors.New("nil value") } - copy(o[:], b) - - return nil + return o.Value.Valid() } -type TaggedOID OID - -// SetImplID sets the value of the targed ClassID to the supplied PSA -// Implementation ID (see Section 3.2.2 of draft-tschofenig-rats-psa-token) -func (o *ClassID) SetImplID(implID ImplID) *ClassID { - if o != nil { - o.val = TaggedImplID(implID) +// Type returns the type of the ClassID +func (o ClassID) Type() string { + if o.Value == nil { + return "" } - return o + + return o.Value.Type() } -func (o ClassID) GetImplID() (ImplID, error) { - switch t := o.val.(type) { - case TaggedImplID: - return ImplID(t), nil - default: - return ImplID{}, fmt.Errorf("class-id type is: %T", t) +// Bytes returns a []byte containing the raw bytes of the class id value +func (o ClassID) Bytes() []byte { + if o.Value == nil { + return []byte{} } + return o.Value.Bytes() } -// SetOID sets the value of the targed ClassID to the supplied OID. -// The OID is a string in dotted-decimal notation -func (o *ClassID) SetOID(s string) *ClassID { - if o != nil { - var berOID OID - if berOID.FromString(s) != nil { - return nil - } - o.val = TaggedOID(berOID) - } - return o +// IsSet returns true iff the underlying class id value has been set (is not nil) +func (o ClassID) IsSet() bool { + return o.Value != nil } // MarshalCBOR serializes the target ClassID to CBOR func (o ClassID) MarshalCBOR() ([]byte, error) { - return em.Marshal(o.val) + return em.Marshal(o.Value) } // UnmarshalCBOR deserializes the supplied CBOR buffer into the target ClassID. // It is undefined behavior to try and inspect the target ClassID in case this // method returns an error. func (o *ClassID) UnmarshalCBOR(data []byte) error { - var implID TaggedImplID - - if dm.Unmarshal(data, &implID) == nil { - o.val = implID - return nil - } - - var uuid TaggedUUID - - if dm.Unmarshal(data, &uuid) == nil { - o.val = uuid - return nil - } - - var oid TaggedOID - - if dm.Unmarshal(data, &oid) == nil { - o.val = oid - return nil - } - - return fmt.Errorf("unknown class id (CBOR: %x)", data) + return dm.Unmarshal(data, &o.Value) } // UnmarshalJSON deserializes the supplied JSON object into the target ClassID -// The class id object must have one of the following shapes: -// -// UUID: -// -// { -// "type": "uuid", -// "value": "69E027B2-7157-4758-BCB4-D9F167FE49EA" -// } -// -// OID: +// The class id object must have the following shape: // // { -// "type": "oid", -// "value": "2.16.840.1.113741.1.15.4.2" +// "type": "", +// "value": "" // } // -// PSA Implementation ID: +// where must be one of the known IClassIDValue implementation +// type names (available in the base implementation: "uuid", "oid", +// "psa.impl-id"), and is the class id value encoded as +// a string. The exact encoding is depenent. For the base +// implmentation types it is // -// { -// "type": "psa.impl-id", -// "value": "YWNtZS1pbXBsZW1lbnRhdGlvbi1pZC0wMDAwMDAwMDE=" -// } +// oid: dot-seprated integers, e.g. "1.2.3.4" +// psa.impl-id: base64-encoded bytes, e.g. "YWNtZS1pbXBsZW1lbnRhdGlvbi1pZC0wMDAwMDAwMDE=" +// uuid: standard UUID string representation, e.g. "550e8400-e29b-41d4-a716-446655440000" func (o *ClassID) UnmarshalJSON(data []byte) error { - var v tnv + var value encoding.TypeAndValue - if err := json.Unmarshal(data, &v); err != nil { + if err := json.Unmarshal(data, &value); err != nil { return err } - switch v.Type { - case "uuid": // nolint: goconst - var x UUID - if err := x.UnmarshalJSON(v.Value); err != nil { - return err - } - o.val = TaggedUUID(x) - case "oid": - var x OID - if err := x.UnmarshalJSON(v.Value); err != nil { - return err - } - o.val = TaggedOID(x) - case "psa.impl-id": - var x ImplID - if err := x.UnmarshalJSON(v.Value); err != nil { - return err - } - o.val = TaggedImplID(x) - default: - return fmt.Errorf("unknown type '%s' for class id", v.Type) + if value.Type == "" { + return errors.New("key type not set") } - return nil + factory, ok := classIDValueRegister[value.Type] + if !ok { + return fmt.Errorf("unknown class id type: %q", value.Type) + } + + v, err := factory(value.Value) + if err != nil { + return err + } + + o.Value = v.Value + + return o.Valid() } // MarshalJSON serializes the target ClassID to JSON func (o ClassID) MarshalJSON() ([]byte, error) { - var ( - v tnv - b []byte - err error - ) - - switch t := o.val.(type) { - case TaggedUUID: - b, err = UUID(t).MarshalJSON() - if err != nil { - return nil, err + value := encoding.TypeAndValue{ + Type: o.Value.Type(), + Value: o.Value.String(), + } + + return json.Marshal(value) +} + +// String returns a printable string of the ClassID value. UUIDs use the +// canonical 8-4-4-4-12 format, PSA Implementation IDs are base64 encoded. +// OIDs are output in dotted-decimal notation. +func (o ClassID) String() string { + return o.Value.String() +} + +type ImplID [32]byte + +func (o ImplID) String() string { + return base64.StdEncoding.EncodeToString(o[:]) +} + +func (o ImplID) Valid() error { + return nil +} + +type TaggedImplID ImplID + +func NewImplIDClassID(val any) (*ClassID, error) { + var ret TaggedImplID + + if val == nil { + return &ClassID{&TaggedImplID{}}, nil + } + + switch t := val.(type) { + case []byte: + if nb := len(t); nb != 32 { + return nil, fmt.Errorf("bad ImplID: got %d bytes, want 32", nb) } - v = tnv{Type: "uuid", Value: b} - case TaggedOID: - b, err = OID(t).MarshalJSON() + + copy(ret[:], t) + case string: + v, err := base64.StdEncoding.DecodeString(t) if err != nil { - return nil, err + return nil, fmt.Errorf("bad ImplID: %w", err) } - v = tnv{Type: "oid", Value: b} - case TaggedImplID: - b, err = ImplID(t).MarshalJSON() - if err != nil { - return nil, err + + if nb := len(v); nb != 32 { + return nil, fmt.Errorf("bad ImplID: decoded %d bytes, want 32", nb) } - v = tnv{Type: "psa.impl-id", Value: b} + + copy(ret[:], v) + case TaggedImplID: + copy(ret[:], t[:]) + case *TaggedImplID: + copy(ret[:], (*t)[:]) + case ImplID: + copy(ret[:], t[:]) + case *ImplID: + copy(ret[:], (*t)[:]) default: - return nil, fmt.Errorf("unknown type %T for class-id", t) + return nil, fmt.Errorf("unexpected type for ImplID: %T", t) } - return json.Marshal(v) + return &ClassID{&ret}, nil } -// Type returns the type of the target ClassID, i.e., one of UUID, OID or PSA -// Implementation ID -func (o ClassID) Type() ClassIDType { - switch o.val.(type) { - case TaggedUUID: - return ClassIDTypeUUID - case TaggedImplID: - return ClassIDTypeImplID - case TaggedOID: - return ClassIDTypeOID +func MustNewImplIDClassID(val any) *ClassID { + ret, err := NewImplIDClassID(val) + if err != nil { + panic(err) } - return ClassIDTypeUnknown + + return ret } -// String returns a printable string of the ClassID value. UUIDs use the -// canonical 8-4-4-4-12 format, PSA Implementation IDs are base64 encoded. -// OIDs are output in dotted-decimal notation. -func (o ClassID) String() string { - switch t := o.val.(type) { - case TaggedUUID: - return UUID(t).String() - case TaggedImplID: - b := [32]byte(t) - return base64.StdEncoding.EncodeToString(b[:]) +func (o TaggedImplID) Valid() error { + return ImplID(o).Valid() +} + +func (o TaggedImplID) String() string { + return ImplID(o).String() +} + +func (o TaggedImplID) Type() string { + return ImplIDType +} + +func (o TaggedImplID) Bytes() []byte { + return o[:] +} + +type TaggedOID OID + +func NewOIDClassID(val any) (*ClassID, error) { + ret := TaggedOID{} + + if val == nil { + return &ClassID{&ret}, nil + } + + switch t := val.(type) { + case string: + var berOID OID + if err := berOID.FromString(t); err != nil { + return nil, err + } + + ret = TaggedOID(berOID) + case []byte: + ret = make([]byte, len(t)) + copy(ret, t) case TaggedOID: - return OID(t).String() - default: - return "" + ret = make([]byte, len(t)) + copy(ret, t) + case OID: + ret = make([]byte, len(t)) + copy(ret, t) + case *TaggedOID: + ret = make([]byte, len(*t)) + copy(ret, (*t)) + case *OID: + ret = make([]byte, len(*t)) + copy(ret, (*t)) + } + + return &ClassID{&ret}, nil +} + +func MustNewOIDClassID(val any) *ClassID { + ret, err := NewOIDClassID(val) + if err != nil { + panic(err) } + + return ret +} + +func (o TaggedOID) Type() string { + return OIDType +} + +func (o TaggedOID) String() string { + return OID(o).String() } -// Unset tests whether the target ClassID has been initialized -func (o ClassID) Unset() bool { - return o.val == nil || o.Type() == ClassIDTypeUnknown +func (o TaggedOID) Valid() error { + return nil +} + +func (o TaggedOID) Bytes() []byte { + return o +} + +func NewUUIDClassID(val any) (*ClassID, error) { + if val == nil { + return &ClassID{&TaggedUUID{}}, nil + } + + ret, err := NewTaggedUUID(val) + if err != nil { + return nil, err + } + + return &ClassID{ret}, nil +} + +func MustNewUUIDClassID(val any) *ClassID { + ret, err := NewUUIDClassID(val) + if err != nil { + panic(err) + } + + return ret +} + +type IClassIDFactory func(any) (*ClassID, error) + +var classIDValueRegister = map[string]IClassIDFactory{ + OIDType: NewOIDClassID, + ImplIDType: NewImplIDClassID, + UUIDType: NewUUIDClassID, +} + +func RegisterClassIDType(tag uint64, factory IClassIDFactory) error { + nilVal, err := factory(nil) + if err != nil { + return err + } + + typ := nilVal.Type() + if _, exists := classIDValueRegister[typ]; exists { + return fmt.Errorf("class ID type with name %q already exists", typ) + } + + if err := registerCOMIDTag(tag, nilVal.Value); err != nil { + return err + } + + classIDValueRegister[typ] = factory + + return nil } diff --git a/comid/classid_test.go b/comid/classid_test.go index 51b07391..0c6a59fb 100644 --- a/comid/classid_test.go +++ b/comid/classid_test.go @@ -4,6 +4,7 @@ package comid import ( + "encoding/json" "fmt" "testing" @@ -12,9 +13,7 @@ import ( ) func TestClassID_MarshalCBOR_UUID(t *testing.T) { - var tv ClassID - - require.NotNil(t, tv.SetUUID(TestUUID)) + tv := MustNewUUIDClassID(TestUUID) // 37(h'31FB5ABF023E4992AA4E95F9C1503BFA') // tag(37): d8 25 @@ -29,9 +28,7 @@ func TestClassID_MarshalCBOR_UUID(t *testing.T) { } func TestClassID_MarshalCBOR_ImplID(t *testing.T) { - var tv ClassID - - require.NotNil(t, tv.SetImplID(TestImplID)) + tv := MustNewImplIDClassID(TestImplID) // 600 (h'61636D652D696D706C656D656E746174696F6E2D69642D303030303030303031') // tag(600): d9 0258 @@ -66,7 +63,7 @@ func TestClassID_UnmarshalCBOR_UUID_OK(t *testing.T) { err := actual.UnmarshalCBOR(tv) assert.Nil(t, err) - assert.Equal(t, ClassIDTypeUUID, actual.Type()) + assert.Equal(t, "uuid", actual.Type()) assert.Equal(t, TestUUIDString, actual.String()) } @@ -79,7 +76,7 @@ func TestClassID_UnmarshalCBOR_ImplID_OK(t *testing.T) { err := actual.UnmarshalCBOR(tv) assert.Nil(t, err) - assert.Equal(t, ClassIDTypeImplID, actual.Type()) + assert.Equal(t, "psa.impl-id", actual.Type()) assert.Equal(t, expected, actual.String()) } @@ -88,12 +85,10 @@ func TestClassID_UnmarshalCBOR_badInput(t *testing.T) { hex := "582061636d652d696d706c656d656e746174696f6e2d69642d303030303030303031" tv := MustHexDecode(t, hex) - expectedError := fmt.Sprintf("unknown class id (CBOR: %s)", hex) - var actual ClassID err := actual.UnmarshalCBOR(tv) - assert.EqualError(t, err, expectedError) + assert.EqualError(t, err, "cbor: cannot unmarshal byte string into Go value of type comid.IClassIDValue") } func TestClassID_UnmarshalJSON_UUID(t *testing.T) { @@ -105,7 +100,7 @@ func TestClassID_UnmarshalJSON_UUID(t *testing.T) { err := actual.UnmarshalJSON([]byte(tv)) assert.Nil(t, err) - assert.Equal(t, ClassIDTypeUUID, actual.Type()) + assert.Equal(t, "uuid", actual.Type()) assert.Equal(t, TestUUIDString, actual.String()) } @@ -122,7 +117,7 @@ func TestClassID_UnmarshalJSON_ImplID(t *testing.T) { err := actual.UnmarshalJSON([]byte(tv)) assert.Nil(t, err) - assert.Equal(t, ClassIDTypeImplID, actual.Type()) + assert.Equal(t, "psa.impl-id", actual.Type()) // the returned string is the base64 encoding of the stored binary assert.Equal(t, expected, actual.String()) } @@ -133,8 +128,8 @@ func TestClassID_UnmarshalJSON_badInput_unknown_type(t *testing.T) { var actual ClassID err := actual.UnmarshalJSON([]byte(tv)) - assert.EqualError(t, err, "unknown type 'FOOBAR' for class id") - assert.Equal(t, ClassIDTypeUnknown, actual.Type()) + assert.EqualError(t, err, `unknown class id type: "FOOBAR"`) + assert.Equal(t, "", actual.Type()) } func TestClassID_UnmarshalJSON_badInput_missing_value(t *testing.T) { @@ -143,8 +138,8 @@ func TestClassID_UnmarshalJSON_badInput_missing_value(t *testing.T) { var actual ClassID err := actual.UnmarshalJSON([]byte(tv)) - assert.EqualError(t, err, "bad ImplID: unexpected end of JSON input") - assert.Equal(t, ClassIDTypeUnknown, actual.Type()) + assert.EqualError(t, err, "bad ImplID: decoded 0 bytes, want 32") + assert.Equal(t, "", actual.Type()) } func TestClassID_UnmarshalJSON_badInput_empty_value(t *testing.T) { @@ -153,8 +148,8 @@ func TestClassID_UnmarshalJSON_badInput_empty_value(t *testing.T) { var actual ClassID err := actual.UnmarshalJSON([]byte(tv)) - assert.EqualError(t, err, "bad ImplID format: got 0 bytes, want 32") - assert.Equal(t, ClassIDTypeUnknown, actual.Type()) + assert.EqualError(t, err, "bad ImplID: decoded 0 bytes, want 32") + assert.Equal(t, "", actual.Type()) } func TestClassID_UnmarshalJSON_badInput_badly_encoded_ImplID_value(t *testing.T) { @@ -164,7 +159,7 @@ func TestClassID_UnmarshalJSON_badInput_badly_encoded_ImplID_value(t *testing.T) err := actual.UnmarshalJSON([]byte(tv)) assert.EqualError(t, err, "bad ImplID: illegal base64 data at input byte 0") - assert.Equal(t, ClassIDTypeUnknown, actual.Type()) + assert.Equal(t, "", actual.Type()) } func TestClassID_UnmarshalJSON_badInput_badly_encoded_UUID_value(t *testing.T) { @@ -174,7 +169,7 @@ func TestClassID_UnmarshalJSON_badInput_badly_encoded_UUID_value(t *testing.T) { err := actual.UnmarshalJSON([]byte(tv)) assert.EqualError(t, err, "bad UUID: invalid UUID length: 9") - assert.Equal(t, ClassIDTypeUnknown, actual.Type()) + assert.Equal(t, "", actual.Type()) } func TestClassID_SetOID_ok(t *testing.T) { @@ -200,8 +195,7 @@ func TestClassID_SetOID_ok(t *testing.T) { } for _, tv := range tvs { - c := ClassID{} - assert.NotNil(t, c.SetOID(tv)) + c := MustNewOIDClassID(tv) assert.Equal(t, tv, c.String()) } } @@ -219,7 +213,151 @@ func TestClassID_SetOID_bad(t *testing.T) { } for _, tv := range tvs { - c := ClassID{} - assert.Nil(t, c.SetOID(tv)) + c, err := NewOIDClassID(tv) + assert.NotNil(t, err) + assert.Nil(t, c) + } +} + +func Test_NewImplIDClassID(t *testing.T) { + classID, err := NewImplIDClassID(nil) + expected := [32]byte{} + require.NoError(t, err) + assert.Equal(t, expected[:], classID.Bytes()) + + taggedImplID := TaggedImplID(TestImplID) + + for _, v := range []any{ + TestImplID, + &TestImplID, + taggedImplID, + &taggedImplID, + taggedImplID.Bytes(), + } { + classID, err = NewImplIDClassID(v) + require.NoError(t, err) + assert.Equal(t, taggedImplID.Bytes(), classID.Bytes()) + } + + expected = [32]byte{ + 0x61, 0x63, 0x6d, 0x65, 0x2d, 0x69, 0x6d, 0x70, + 0x6c, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x61, 0x74, + 0x69, 0x6f, 0x6e, 0x2d, 0x69, 0x64, 0x2d, 0x30, + 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x31, + } + classID, err = NewImplIDClassID("YWNtZS1pbXBsZW1lbnRhdGlvbi1pZC0wMDAwMDAwMDE=") + require.NoError(t, err) + assert.Equal(t, expected[:], classID.Bytes()) + + _, err = NewImplIDClassID(7) + assert.EqualError(t, err, "unexpected type for ImplID: int") +} + +func Test_NewUUIDClassID(t *testing.T) { + classID, err := NewUUIDClassID(nil) + + expected := [16]byte{} + require.NoError(t, err) + assert.Equal(t, expected[:], classID.Bytes()) + + taggedUUID := TaggedUUID(TestUUID) + + for _, v := range []any{ + TestUUID, + &TestUUID, + taggedUUID, + &taggedUUID, + taggedUUID.Bytes(), + } { + classID, err = NewUUIDClassID(v) + require.NoError(t, err) + assert.Equal(t, taggedUUID.Bytes(), classID.Bytes()) + } + + classID, err = NewUUIDClassID(taggedUUID.String()) + require.NoError(t, err) + assert.Equal(t, taggedUUID.Bytes(), classID.Bytes()) +} + +func Test_NewOIDClassID(t *testing.T) { + classID, err := NewOIDClassID(nil) + + expected := []byte{} + require.NoError(t, err) + assert.Equal(t, expected, classID.Bytes()) + + var oid OID + require.NoError(t, oid.FromString(TestOID)) + taggedOID := TaggedOID(oid) + + for _, v := range []any{ + TestOID, + oid, + &oid, + taggedOID, + &taggedOID, + taggedOID.Bytes(), + } { + classID, err = NewOIDClassID(v) + require.NoError(t, err) + expected := taggedOID.Bytes() + got := classID.Bytes() + assert.Equal(t, expected, got) } + + classID, err = NewOIDClassID(taggedOID.String()) + require.NoError(t, err) + assert.Equal(t, taggedOID.Bytes(), classID.Bytes()) +} + +type testClassID [4]byte + +func newTestClassID(val any) (*ClassID, error) { + return &ClassID{&testClassID{0x74, 0x65, 0x73, 0x74}}, nil +} + +func (o testClassID) Bytes() []byte { + return o[:] +} + +func (o testClassID) Type() string { + return "test-class-id" +} + +func (o testClassID) String() string { + return "test" +} + +func (o testClassID) Valid() error { + return nil +} + +func Test_RegisterClassIDType(t *testing.T) { + err := RegisterClassIDType(99999, newTestClassID) + require.NoError(t, err) + + classID, err := newTestClassID(nil) + require.NoError(t, err) + + data, err := json.Marshal(classID) + require.NoError(t, err) + assert.Equal(t, string(data), `{"type":"test-class-id","value":"test"}`) + + var out ClassID + err = json.Unmarshal(data, &out) + require.NoError(t, err) + assert.Equal(t, classID.Bytes(), out.Bytes()) + + data, err = em.Marshal(classID) + require.NoError(t, err) + assert.Equal(t, data, []byte{ + 0xda, 0x0, 0x1, 0x86, 0x9f, // tag 99999 + 0x44, // bstr(4) + 0x74, 0x65, 0x73, 0x74, // "test" + }) + + var out2 ClassID + err = dm.Unmarshal(data, &out2) + require.NoError(t, err) + assert.Equal(t, classID.Bytes(), out2.Bytes()) } diff --git a/comid/cryptokey.go b/comid/cryptokey.go index 53d99df8..c6cadac6 100644 --- a/comid/cryptokey.go +++ b/comid/cryptokey.go @@ -14,6 +14,8 @@ import ( "fmt" "github.com/fxamacker/cbor/v2" + "github.com/veraison/corim/encoding" + "github.com/veraison/corim/extensions" "github.com/veraison/go-cose" "github.com/veraison/swid" ) @@ -58,50 +60,12 @@ type CryptoKey struct { // specified crypto key type. For PKIX types, k must be a string. For COSE_Key, // k must be a []byte. For thumbprint types, k must be a swid.HashEntry. func NewCryptoKey(k any, typ string) (*CryptoKey, error) { - switch typ { - case PKIXBase64KeyType: - v, ok := k.(string) - if !ok { - return nil, fmt.Errorf("value must be a string; found %T", k) - } - return NewPKIXBase64Key(v) - case PKIXBase64CertType: - v, ok := k.(string) - if !ok { - return nil, fmt.Errorf("value must be a string; found %T", k) - } - return NewPKIXBase64Cert(v) - case PKIXBase64CertPathType: - v, ok := k.(string) - if !ok { - return nil, fmt.Errorf("value must be a string; found %T", k) - } - return NewPKIXBase64CertPath(v) - case COSEKeyType: - v, ok := k.([]byte) - if !ok { - return nil, fmt.Errorf("value must be a []byte; found %T", k) - } - return NewCOSEKey(v) - case ThumbprintType, CertThumbprintType, CertPathThumbprintType: - v, ok := k.(swid.HashEntry) - if !ok { - return nil, fmt.Errorf("value must be a swid.HashEntry; found %T", k) - } - switch typ { - case ThumbprintType: - return NewThumbprint(v) - case CertThumbprintType: - return NewCertThumbprint(v) - case CertPathThumbprintType: - return NewCertPathThumbprint(v) - default: - // Should never here because of the the outer case clause - panic(fmt.Sprintf("unexpected thumbprint type: %s", typ)) - } - default: + factory, ok := cryptoKeyValueRegister[typ] + if !ok { return nil, fmt.Errorf("unexpected CryptoKey type: %s", typ) } + + return factory(k) } // MustNewCryptoKey is the same as NewCryptoKey, but does not return an error, @@ -126,6 +90,11 @@ func (o CryptoKey) Valid() error { return o.Value.Valid() } +// Type returns the type of the CryptoKey value +func (o CryptoKey) Type() string { + return o.Value.Type() +} + // PublicKey returns a crypto.PublicKey constructed from the CryptoKey's // underlying value. This returns an error if the CryptoKey is one of the // thumbprint types. @@ -136,42 +105,18 @@ func (o CryptoKey) PublicKey() (crypto.PublicKey, error) { // MarshalJSON returns a []byte containing the JSON representation of the // CryptoKey. func (o CryptoKey) MarshalJSON() ([]byte, error) { - value := struct { - Type string `json:"type"` - Value string `json:"value"` - }{ + value := encoding.TypeAndValue{ + Type: o.Value.Type(), Value: o.Value.String(), } - switch o.Value.(type) { - case TaggedPKIXBase64Key: - value.Type = PKIXBase64KeyType - case TaggedPKIXBase64Cert: - value.Type = PKIXBase64CertType - case TaggedPKIXBase64CertPath: - value.Type = PKIXBase64CertPathType - case TaggedCOSEKey: - value.Type = COSEKeyType - case TaggedThumbprint: - value.Type = ThumbprintType - case TaggedCertThumbprint: - value.Type = CertThumbprintType - case TaggedCertPathThumbprint: - value.Type = CertPathThumbprintType - default: - return nil, fmt.Errorf("unexpected ICryptoKeyValue type: %T", o.Value) - } - return json.Marshal(value) } // UnmarshalJSON populates the CryptoKey from the JSON representation inside // the provided []byte. func (o *CryptoKey) UnmarshalJSON(b []byte) error { - var value struct { - Type string `json:"type"` - Value string `json:"value"` - } + var value encoding.TypeAndValue if err := json.Unmarshal(b, &value); err != nil { return err @@ -181,36 +126,18 @@ func (o *CryptoKey) UnmarshalJSON(b []byte) error { return errors.New("key type not set") } - switch value.Type { - case PKIXBase64KeyType: - o.Value = TaggedPKIXBase64Key(value.Value) - case PKIXBase64CertType: - o.Value = TaggedPKIXBase64Cert(value.Value) - case PKIXBase64CertPathType: - o.Value = TaggedPKIXBase64CertPath(value.Value) - case COSEKeyType: - data, err := base64.StdEncoding.DecodeString(value.Value) - if err != nil { - return fmt.Errorf("base64 decode error: %w", err) - } - o.Value = TaggedCOSEKey(data) - case ThumbprintType, CertThumbprintType, CertPathThumbprintType: - he, err := swid.ParseHashEntry(value.Value) - if err != nil { - return fmt.Errorf("swid.HashEntry decode error: %w", err) - } - switch value.Type { - case ThumbprintType: - o.Value = TaggedThumbprint{digest{he}} - case CertThumbprintType: - o.Value = TaggedCertThumbprint{digest{he}} - case CertPathThumbprintType: - o.Value = TaggedCertPathThumbprint{digest{he}} - } - default: + factory, ok := cryptoKeyValueRegister[value.Type] + if !ok { return fmt.Errorf("unexpected ICryptoKeyValue type: %q", value.Type) } + k, err := factory(value.Value) + if err != nil { + return err + } + + o.Value = k.Value + return o.Valid() } @@ -229,11 +156,8 @@ func (o *CryptoKey) UnmarshalCBOR(b []byte) error { // ICryptoKeyValue is the interface implemented by the concrete CryptoKey value // types. type ICryptoKeyValue interface { - // String returns the string representation of the ICryptoKeyValue. - String() string - // Valid returns an error if validation of the ICryptoKeyValue fails, - // or nil if it succeeds. - Valid() error + extensions.ITypeChoiceValue + // PublicKey returns a crypto.PublicKey constructed from the // ICryptoKeyValue's underlying value. This returns an error if the // ICryptoKeyValue is one of the thumbprint types. @@ -244,7 +168,12 @@ type ICryptoKeyValue interface { // https://www.rfc-editor.org/rfc/rfc7468#section-13 type TaggedPKIXBase64Key string -func NewPKIXBase64Key(s string) (*CryptoKey, error) { +func NewPKIXBase64Key(k any) (*CryptoKey, error) { + s, ok := k.(string) + if !ok { + return nil, fmt.Errorf("value must be a string; found %T", k) + } + key := TaggedPKIXBase64Key(s) if err := key.Valid(); err != nil { return nil, err @@ -252,8 +181,8 @@ func NewPKIXBase64Key(s string) (*CryptoKey, error) { return &CryptoKey{key}, nil } -func MustNewPKIXBase64Key(s string) *CryptoKey { - key, err := NewPKIXBase64Key(s) +func MustNewPKIXBase64Key(k any) *CryptoKey { + key, err := NewPKIXBase64Key(k) if err != nil { panic(err) } @@ -269,6 +198,10 @@ func (o TaggedPKIXBase64Key) Valid() error { return err } +func (o TaggedPKIXBase64Key) Type() string { + return PKIXBase64KeyType +} + func (o TaggedPKIXBase64Key) PublicKey() (crypto.PublicKey, error) { if string(o) == "" { return nil, errors.New("key value not set") @@ -302,7 +235,12 @@ func (o TaggedPKIXBase64Key) PublicKey() (crypto.PublicKey, error) { // certificate. See https://www.rfc-editor.org/rfc/rfc7468#section-5 type TaggedPKIXBase64Cert string -func NewPKIXBase64Cert(s string) (*CryptoKey, error) { +func NewPKIXBase64Cert(k any) (*CryptoKey, error) { + s, ok := k.(string) + if !ok { + return nil, fmt.Errorf("value must be a string; found %T", k) + } + cert := TaggedPKIXBase64Cert(s) if err := cert.Valid(); err != nil { return nil, err @@ -310,8 +248,8 @@ func NewPKIXBase64Cert(s string) (*CryptoKey, error) { return &CryptoKey{cert}, nil } -func MustNewPKIXBase64Cert(s string) *CryptoKey { - cert, err := NewPKIXBase64Cert(s) +func MustNewPKIXBase64Cert(k any) *CryptoKey { + cert, err := NewPKIXBase64Cert(k) if err != nil { panic(err) } @@ -327,6 +265,10 @@ func (o TaggedPKIXBase64Cert) Valid() error { return err } +func (o TaggedPKIXBase64Cert) Type() string { + return PKIXBase64CertType +} + func (o TaggedPKIXBase64Cert) PublicKey() (crypto.PublicKey, error) { cert, err := o.cert() if err != nil { @@ -375,7 +317,11 @@ func (o TaggedPKIXBase64Cert) cert() (*x509.Certificate, error) { // directly certifies the one preceding. type TaggedPKIXBase64CertPath string -func NewPKIXBase64CertPath(s string) (*CryptoKey, error) { +func NewPKIXBase64CertPath(k any) (*CryptoKey, error) { + s, ok := k.(string) + if !ok { + return nil, fmt.Errorf("value must be a string; found %T", k) + } cert := TaggedPKIXBase64CertPath(s) if err := cert.Valid(); err != nil { @@ -385,8 +331,8 @@ func NewPKIXBase64CertPath(s string) (*CryptoKey, error) { return &CryptoKey{cert}, nil } -func MustNewPKIXBase64CertPath(s string) *CryptoKey { - cert, err := NewPKIXBase64CertPath(s) +func MustNewPKIXBase64CertPath(k any) *CryptoKey { + cert, err := NewPKIXBase64CertPath(k) if err != nil { panic(err) @@ -404,6 +350,10 @@ func (o TaggedPKIXBase64CertPath) Valid() error { return err } +func (o TaggedPKIXBase64CertPath) Type() string { + return PKIXBase64CertPathType +} + func (o TaggedPKIXBase64CertPath) PublicKey() (crypto.PublicKey, error) { certs, err := o.certPath() if err != nil { @@ -468,7 +418,22 @@ func (o TaggedPKIXBase64CertPath) certPath() ([]*x509.Certificate, error) { // https://www.rfc-editor.org/rfc/rfc9052#section-7 type TaggedCOSEKey []byte -func NewCOSEKey(b []byte) (*CryptoKey, error) { +func NewCOSEKey(k any) (*CryptoKey, error) { + var b []byte + var err error + + switch t := k.(type) { + case []byte: + b = t + case string: + b, err = base64.StdEncoding.DecodeString(t) + if err != nil { + return nil, fmt.Errorf("base64 decode error: %w", err) + } + default: + return nil, fmt.Errorf("value must be a []byte or a string; found %T", k) + } + key := TaggedCOSEKey(b) if err := key.Valid(); err != nil { @@ -478,8 +443,8 @@ func NewCOSEKey(b []byte) (*CryptoKey, error) { return &CryptoKey{key}, nil } -func MustNewCOSEKey(b []byte) *CryptoKey { - key, err := NewCOSEKey(b) +func MustNewCOSEKey(k any) *CryptoKey { + key, err := NewCOSEKey(k) if err != nil { panic(err) @@ -509,6 +474,10 @@ func (o TaggedCOSEKey) Valid() error { return err } +func (o TaggedCOSEKey) Type() string { + return COSEKeyType +} + func (o TaggedCOSEKey) PublicKey() (crypto.PublicKey, error) { if len(o) == 0 { return nil, errors.New("empty COSE_Key value") @@ -608,7 +577,22 @@ type TaggedThumbprint struct { digest } -func NewThumbprint(he swid.HashEntry) (*CryptoKey, error) { +func NewThumbprint(k any) (*CryptoKey, error) { + var he swid.HashEntry + var err error + + switch t := k.(type) { + case string: + he, err = swid.ParseHashEntry(t) + if err != nil { + return nil, fmt.Errorf("swid.HashEntry decode error: %w", err) + } + case swid.HashEntry: + he = t + default: + return nil, fmt.Errorf("value must be a swid.HashEntry or a string; found %T", k) + } + key := &CryptoKey{TaggedThumbprint{digest{he}}} if err := key.Valid(); err != nil { @@ -618,8 +602,8 @@ func NewThumbprint(he swid.HashEntry) (*CryptoKey, error) { return key, nil } -func MustNewThumbprint(he swid.HashEntry) *CryptoKey { - key, err := NewThumbprint(he) +func MustNewThumbprint(k any) *CryptoKey { + key, err := NewThumbprint(k) if err != nil { panic(err) @@ -628,13 +612,32 @@ func MustNewThumbprint(he swid.HashEntry) *CryptoKey { return key } +func (o TaggedThumbprint) Type() string { + return ThumbprintType +} + // TaggedCertThumbprint represents a digest of a certificate. The digest value // may be used to find the certificate if contained in a lookup table. type TaggedCertThumbprint struct { digest } -func NewCertThumbprint(he swid.HashEntry) (*CryptoKey, error) { +func NewCertThumbprint(k any) (*CryptoKey, error) { + var he swid.HashEntry + var err error + + switch t := k.(type) { + case string: + he, err = swid.ParseHashEntry(t) + if err != nil { + return nil, fmt.Errorf("swid.HashEntry decode error: %w", err) + } + case swid.HashEntry: + he = t + default: + return nil, fmt.Errorf("value must be a swid.HashEntry or a string; found %T", k) + } + key := &CryptoKey{TaggedCertThumbprint{digest{he}}} if err := key.Valid(); err != nil { @@ -644,8 +647,8 @@ func NewCertThumbprint(he swid.HashEntry) (*CryptoKey, error) { return key, nil } -func MustNewCertThumbprint(he swid.HashEntry) *CryptoKey { - key, err := NewCertThumbprint(he) +func MustNewCertThumbprint(k any) *CryptoKey { + key, err := NewCertThumbprint(k) if err != nil { panic(err) @@ -654,6 +657,10 @@ func MustNewCertThumbprint(he swid.HashEntry) *CryptoKey { return key } +func (o TaggedCertThumbprint) Type() string { + return CertThumbprintType +} + // TaggedCertPathThumbprint represents a digest of a certification path. The // digest value may be used to find the certificate path if contained in a // lookup table. @@ -661,7 +668,22 @@ type TaggedCertPathThumbprint struct { digest } -func NewCertPathThumbprint(he swid.HashEntry) (*CryptoKey, error) { +func NewCertPathThumbprint(k any) (*CryptoKey, error) { + var he swid.HashEntry + var err error + + switch t := k.(type) { + case string: + he, err = swid.ParseHashEntry(t) + if err != nil { + return nil, fmt.Errorf("swid.HashEntry decode error: %w", err) + } + case swid.HashEntry: + he = t + default: + return nil, fmt.Errorf("value must be a swid.HashEntry or a string; found %T", k) + } + key := &CryptoKey{TaggedCertPathThumbprint{digest{he}}} if err := key.Valid(); err != nil { @@ -671,8 +693,8 @@ func NewCertPathThumbprint(he swid.HashEntry) (*CryptoKey, error) { return key, nil } -func MustNewCertPathThumbprint(he swid.HashEntry) *CryptoKey { - key, err := NewCertPathThumbprint(he) +func MustNewCertPathThumbprint(k any) *CryptoKey { + key, err := NewCertPathThumbprint(k) if err != nil { panic(err) @@ -680,3 +702,47 @@ func MustNewCertPathThumbprint(he swid.HashEntry) *CryptoKey { return key } + +func (o TaggedCertPathThumbprint) Type() string { + return CertPathThumbprintType +} + +// ICryptoKeyFactory creates a *CryptoKey from the specified any value. When +// passed nil as input, this must return the equivaluent nil-value for the +// associated ICryptoKeyValue implementation. +type ICryptoKeyFactory func(any) (*CryptoKey, error) + +var cryptoKeyValueRegister = map[string]ICryptoKeyFactory{ + // types defined by the core spec + PKIXBase64KeyType: NewPKIXBase64Key, + PKIXBase64CertType: NewPKIXBase64Cert, + PKIXBase64CertPathType: NewPKIXBase64CertPath, + COSEKeyType: NewCOSEKey, + ThumbprintType: NewThumbprint, + CertThumbprintType: NewCertThumbprint, + CertPathThumbprintType: NewCertPathThumbprint, +} + +// RegisterCryptoKeyType registeres a new ICryptoKeyValue implementation +// (created by the provided ICryptoKeyFactory) under the specified type name +// and CBOR tag. +func RegisterCryptoKeyType(tag uint64, factory ICryptoKeyFactory) error { + + nilVal, err := factory(nil) + if err != nil { + return err + } + + typ := nilVal.Type() + if _, exists := cryptoKeyValueRegister[typ]; exists { + return fmt.Errorf("crypto key type with name %q already exists", typ) + } + + if err := registerCOMIDTag(tag, nilVal.Value); err != nil { + return err + } + + cryptoKeyValueRegister[typ] = factory + + return nil +} diff --git a/comid/cryptokey_test.go b/comid/cryptokey_test.go index 1c0812fc..b16ed637 100644 --- a/comid/cryptokey_test.go +++ b/comid/cryptokey_test.go @@ -4,6 +4,7 @@ package comid import ( + "crypto" "encoding/base64" "encoding/json" "fmt" @@ -149,7 +150,7 @@ func Test_CryptoKey_NewCOSEKey(t *testing.T) { } func Test_CryptoKey_NewThumbprint(t *testing.T) { - type newKeyFunc func(swid.HashEntry) (*CryptoKey, error) + type newKeyFunc func(any) (*CryptoKey, error) for _, newFunc := range []newKeyFunc{ NewThumbprint, @@ -177,7 +178,7 @@ func Test_CryptoKey_NewThumbprint(t *testing.T) { assert.Contains(t, err.Error(), "length mismatch for hash algorithm") } - type mustNewKeyFunc func(swid.HashEntry) *CryptoKey + type mustNewKeyFunc func(any) *CryptoKey for _, mustNewFunc := range []mustNewKeyFunc{ MustNewThumbprint, @@ -371,22 +372,22 @@ func Test_NewCryptoKey_negative(t *testing.T) { { Type: COSEKeyType, In: 7, - ErrMsg: "value must be a []byte; found int", + ErrMsg: "value must be a []byte or a string; found int", }, { Type: ThumbprintType, In: 7, - ErrMsg: "value must be a swid.HashEntry; found int", + ErrMsg: "value must be a swid.HashEntry or a string; found int", }, { Type: CertThumbprintType, In: 7, - ErrMsg: "value must be a swid.HashEntry; found int", + ErrMsg: "value must be a swid.HashEntry or a string; found int", }, { Type: CertPathThumbprintType, In: 7, - ErrMsg: "value must be a swid.HashEntry; found int", + ErrMsg: "value must be a swid.HashEntry or a string; found int", }, { Type: "random-key", @@ -399,3 +400,55 @@ func Test_NewCryptoKey_negative(t *testing.T) { assert.ErrorContains(t, err, tv.ErrMsg) } } + +type testCryptoKey [4]byte + +func newTestCryptoKey(val any) (*CryptoKey, error) { + return &CryptoKey{&testCryptoKey{0x74, 0x64, 0x73, 0x74}}, nil +} + +func (o testCryptoKey) PublicKey() (crypto.PublicKey, error) { + return crypto.PublicKey(o[:]), nil +} + +func (o testCryptoKey) Type() string { + return "test-crypto-key" +} + +func (o testCryptoKey) String() string { + return "test" +} + +func (o testCryptoKey) Valid() error { + return nil +} + +func Test_RegisterCryptoKeyTypeType(t *testing.T) { + err := RegisterCryptoKeyType(99998, newTestCryptoKey) + require.NoError(t, err) + + key, err := newTestCryptoKey(nil) + require.NoError(t, err) + + data, err := json.Marshal(key) + require.NoError(t, err) + assert.Equal(t, string(data), `{"type":"test-crypto-key","value":"test"}`) + + var out CryptoKey + err = json.Unmarshal(data, &out) + require.NoError(t, err) + assert.EqualValues(t, key, &out) + + data, err = em.Marshal(key) + require.NoError(t, err) + assert.Equal(t, data, []byte{ + 0xda, 0x0, 0x1, 0x86, 0x9e, // tag 99998 + 0x44, // bstr(4) + 0x74, 0x64, 0x73, 0x74, // "test" + }) + + var out2 CryptoKey + err = dm.Unmarshal(data, &out2) + require.NoError(t, err) + assert.Equal(t, key, &out2) +} diff --git a/comid/devidentitykey_test.go b/comid/devidentitykey_test.go index 14f2f6e1..be618c8c 100644 --- a/comid/devidentitykey_test.go +++ b/comid/devidentitykey_test.go @@ -23,12 +23,12 @@ func TestDevIdentityKey_Valid_empty(t *testing.T) { testerr: "environment validation failed: environment must not be empty", }, { - env: Environment{Instance: NewInstanceUEID(TestUEID)}, + env: Environment{Instance: MustNewInstanceUEID(TestUEID)}, verifkey: CryptoKeys{}, testerr: "verification keys validation failed: no keys to validate", }, { - env: Environment{Instance: NewInstanceUEID(TestUEID)}, + env: Environment{Instance: MustNewInstanceUEID(TestUEID)}, verifkey: CryptoKeys{&invalidKey}, testerr: "verification keys validation failed: invalid key at index 0: key value not set", }, diff --git a/comid/environment_test.go b/comid/environment_test.go index c078ce90..54cc58be 100644 --- a/comid/environment_test.go +++ b/comid/environment_test.go @@ -78,7 +78,7 @@ func TestEnvironment_ToCBOR_class_only(t *testing.T) { func TestEnvironment_ToCBOR_class_and_instance(t *testing.T) { tv := Environment{ Class: NewClassUUID(TestUUID), - Instance: NewInstanceUEID(TestUEID), + Instance: MustNewInstanceUEID(TestUEID), } require.NotNil(t, tv.Class) require.NotNil(t, tv.Instance) @@ -96,7 +96,7 @@ func TestEnvironment_ToCBOR_class_and_instance(t *testing.T) { func TestEnvironment_ToCBOR_instance_only(t *testing.T) { tv := Environment{ - Instance: NewInstanceUEID(TestUEID), + Instance: MustNewInstanceUEID(TestUEID), } require.NotNil(t, tv.Instance) @@ -180,7 +180,7 @@ func TestEnvironment_FromCBOR_class_and_instance(t *testing.T) { assert.NotNil(t, actual.Class) assert.Equal(t, TestUUIDString, actual.Class.ClassID.String()) assert.NotNil(t, actual.Instance) - assert.Equal(t, TestUEIDString, actual.Instance.String()) + assert.Equal(t, []byte(TestUEID), actual.Instance.Bytes()) assert.Nil(t, actual.Group) } diff --git a/comid/example_psa_keys_test.go b/comid/example_psa_keys_test.go index edf95989..3780cbbd 100644 --- a/comid/example_psa_keys_test.go +++ b/comid/example_psa_keys_test.go @@ -70,12 +70,7 @@ func extractInstanceID(i *Instance) error { return fmt.Errorf("no instance") } - instID, err := i.GetUEID() - if err != nil { - return fmt.Errorf("extracting implemenetation-id: %w", err) - } - - fmt.Printf("InstanceID: %x\n", instID) + fmt.Printf("InstanceID: %x\n", i.Bytes()) return nil } diff --git a/comid/example_psa_refval_test.go b/comid/example_psa_refval_test.go index 230209d0..8848e90f 100644 --- a/comid/example_psa_refval_test.go +++ b/comid/example_psa_refval_test.go @@ -142,12 +142,11 @@ func extractImplementationID(c *Class) error { return fmt.Errorf("no class-id") } - implID, err := classID.GetImplID() - if err != nil { - return fmt.Errorf("extracting implemenetation-id: %w", err) + if classID.Type() != ImplIDType { + return fmt.Errorf("class id is not a psa.impl-id") } - fmt.Printf("ImplementationID: %x\n", implID) + fmt.Printf("ImplementationID: %x\n", classID.Bytes()) return nil } diff --git a/comid/example_test.go b/comid/example_test.go index 756829d3..2e8a5ec9 100644 --- a/comid/example_test.go +++ b/comid/example_test.go @@ -26,7 +26,7 @@ func Example_encode() { SetModel("RoadRunner"). SetLayer(0). SetIndex(1), - Instance: NewInstanceUEID(TestUEID), + Instance: MustNewInstanceUEID(TestUEID), Group: NewGroupUUID(TestUUID), }, Measurements: *NewMeasurements(). @@ -55,7 +55,7 @@ func Example_encode() { SetModel("RoadRunner"). SetLayer(0). SetIndex(1), - Instance: NewInstanceUEID(TestUEID), + Instance: MustNewInstanceUEID(TestUEID), Group: NewGroupUUID(TestUUID), }, Measurements: *NewMeasurements(). @@ -79,7 +79,7 @@ func Example_encode() { AddAttestVerifKey( AttestVerifKey{ Environment: Environment{ - Instance: NewInstanceUUID(uuid.UUID(TestUUID)), + Instance: MustNewInstanceUUID(uuid.UUID(TestUUID)), }, VerifKeys: *NewCryptoKeys(). Add( @@ -89,7 +89,7 @@ func Example_encode() { ).AddDevIdentityKey( DevIdentityKey{ Environment: Environment{ - Instance: NewInstanceUEID(TestUEID), + Instance: MustNewInstanceUEID(TestUEID), }, VerifKeys: *NewCryptoKeys(). Add( @@ -144,7 +144,7 @@ func Example_encode_PSA() { AddAttestVerifKey( AttestVerifKey{ Environment: Environment{ - Instance: NewInstanceUEID(TestUEID), + Instance: MustNewInstanceUEID(TestUEID), }, VerifKeys: *NewCryptoKeys(). Add( @@ -175,7 +175,7 @@ func Example_encode_PSA_attestation_verification() { AddAttestVerifKey( AttestVerifKey{ Environment: Environment{ - Instance: NewInstanceUEID(TestUEID), + Instance: MustNewInstanceUEID(TestUEID), }, VerifKeys: *NewCryptoKeys(). Add( diff --git a/comid/instance.go b/comid/instance.go index 8201145d..94286374 100644 --- a/comid/instance.go +++ b/comid/instance.go @@ -1,179 +1,178 @@ package comid import ( - "encoding/hex" "encoding/json" + "errors" "fmt" - "github.com/google/uuid" - "github.com/veraison/eat" + "github.com/veraison/corim/encoding" + "github.com/veraison/corim/extensions" ) +type IInstanceValue interface { + extensions.ITypeChoiceValue + + Bytes() []byte +} + // Instance stores an instance identity. The supported formats are UUID and UEID. type Instance struct { - val interface{} + Value IInstanceValue } -// NewInstance instantiates an empty instance -func NewInstance() *Instance { - return &Instance{} -} +// NewInstanceUEID instantiates a new instance with the supplied UEID identity +func NewInstanceUEID(val any) (*Instance, error) { + if val == nil { + return &Instance{&TaggedUEID{}}, nil + } -// SetUEID sets the identity of the target instance to the supplied UEID -func (o *Instance) SetUEID(val eat.UEID) *Instance { - if o != nil { - if val.Validate() != nil { - return nil - } - o.val = TaggedUEID(val) + ret, err := NewTaggedUEID(val) + if err != nil { + return nil, err } - return o + return &Instance{ret}, nil } -// SetUUID sets the identity of the target instance to the supplied UUID -func (o *Instance) SetUUID(val uuid.UUID) *Instance { - if o != nil { - o.val = TaggedUUID(val) +func MustNewInstanceUEID(val any) *Instance { + ret, err := NewInstanceUEID(val) + if err != nil { + panic(err) } - return o -} -// NewInstanceUEID instantiates a new instance with the supplied UEID identity -func NewInstanceUEID(val eat.UEID) *Instance { - return NewInstance().SetUEID(val) + return ret } // NewInstanceUUID instantiates a new instance with the supplied UUID identity -func NewInstanceUUID(val uuid.UUID) *Instance { - return NewInstance().SetUUID(val) -} +func NewInstanceUUID(val any) (*Instance, error) { + if val == nil { + return &Instance{&TaggedUUID{}}, nil + } -// Valid checks for the validity of given instance -func (o Instance) Valid() error { - if o.String() == "" { - return fmt.Errorf("invalid instance id") + ret, err := NewTaggedUUID(val) + if err != nil { + return nil, err } - return nil + + return &Instance{ret}, nil } -func (o Instance) GetUEID() (eat.UEID, error) { - switch t := o.val.(type) { - case TaggedUEID: - return eat.UEID(t), nil - default: - return eat.UEID{}, fmt.Errorf("instance-id type is: %T", t) +func MustNewInstanceUUID(val any) *Instance { + ret, err := NewInstanceUUID(val) + if err != nil { + panic(err) } + + return ret } -func (o Instance) GetUUID() (UUID, error) { - switch t := o.val.(type) { - case TaggedUUID: - return UUID(t), nil - default: - return UUID{}, fmt.Errorf("instance-id type is: %T", t) +// Valid checks for the validity of given instance +func (o Instance) Valid() error { + if o.String() == "" { + return fmt.Errorf("invalid instance id") } + return nil } // String returns a printable string of the Instance value. UUIDs use the // canonical 8-4-4-4-12 format, UEIDs are hex encoded. func (o Instance) String() string { - switch t := o.val.(type) { - case TaggedUUID: - return UUID(t).String() - case TaggedUEID: - return hex.EncodeToString(t) - default: + if o.Value == nil { return "" } + + return o.Value.String() +} + +func (o Instance) Type() string { + return o.Value.Type() +} + +func (o Instance) Bytes() []byte { + return o.Value.Bytes() } // MarshalCBOR serializes the target instance to CBOR func (o Instance) MarshalCBOR() ([]byte, error) { - return em.Marshal(o.val) + return em.Marshal(o.Value) } func (o *Instance) UnmarshalCBOR(data []byte) error { - var ueid TaggedUEID - - if dm.Unmarshal(data, &ueid) == nil { - o.val = ueid - return nil - } - - var u TaggedUUID - - if dm.Unmarshal(data, &u) == nil { - o.val = u - return nil - } - - return fmt.Errorf("unknown instance type (CBOR: %x)", data) + return dm.Unmarshal(data, &o.Value) } -// UnmarshalJSON deserializes the supplied JSON type/value object into the Group -// target. The supported formats are UUID, e.g.: +// UnmarshalJSON deserializes the supplied JSON object into the target Instance +// The instance object must have the following shape: // // { -// "type": "uuid", -// "value": "69E027B2-7157-4758-BCB4-D9F167FE49EA" +// "type": "", +// "value": "" // } // -// and UEID: +// where must be one of the known IInstanceValue implementation +// type names (available in the base implementation: "ueid" and "uuid"), and +// is the instance value encoded as a string. The exact +// encoding is depenent. For the base implmentation types it is // -// { -// "type": "ueid", -// "value": "Ad6tvu/erb7v3q2+796tvu8=" -// } +// ueid: base64-encoded bytes, e.g. "YWNtZS1pbXBsZW1lbnRhdGlvbi1pZC0wMDAwMDAwMDE=" +// uuid: standard UUID string representation, e.g. "550e8400-e29b-41d4-a716-446655440000" func (o *Instance) UnmarshalJSON(data []byte) error { - var v tnv + var value encoding.TypeAndValue - if err := json.Unmarshal(data, &v); err != nil { + if err := json.Unmarshal(data, &value); err != nil { return err } - switch v.Type { - case "uuid": - var x UUID - if err := x.UnmarshalJSON(v.Value); err != nil { - return err - } - o.val = TaggedUUID(x) - case "ueid": - var x UEID - if err := x.UnmarshalJSON(v.Value); err != nil { - return err - } - o.val = TaggedUEID(x) - default: - return fmt.Errorf("unknown type %s for instance", v.Type) + if value.Type == "" { + return errors.New("key type not set") } - return nil + factory, ok := instanceValueRegister[value.Type] + if !ok { + return fmt.Errorf("unknown class id type: %q", value.Type) + } + + v, err := factory(value.Value) + if err != nil { + return err + } + + o.Value = v.Value + + return o.Valid() } func (o Instance) MarshalJSON() ([]byte, error) { - var ( - v tnv - b []byte - err error - ) - - switch t := o.val.(type) { - case TaggedUUID: - b, err = UUID(t).MarshalJSON() - if err != nil { - return nil, err - } - v = tnv{Type: "uuid", Value: b} - case TaggedUEID: - b, err = UEID(t).MarshalJSON() - if err != nil { - return nil, err - } - v = tnv{Type: "ueid", Value: b} - default: - return nil, fmt.Errorf("unknown type %T for instance", t) - } - - return json.Marshal(v) + value := encoding.TypeAndValue{ + Type: o.Value.Type(), + Value: o.Value.String(), + } + + return json.Marshal(value) +} + +type IInstanceFactory func(any) (*Instance, error) + +var instanceValueRegister = map[string]IInstanceFactory{ + UEIDType: NewInstanceUEID, + UUIDType: NewInstanceUUID, +} + +func RegisterInstanceType(tag uint64, factory IInstanceFactory) error { + nilVal, err := factory(nil) + if err != nil { + return err + } + + typ := nilVal.Type() + if _, exists := instanceValueRegister[typ]; exists { + return fmt.Errorf("class ID type with name %q already exists", typ) + } + + if err := registerCOMIDTag(tag, nilVal.Value); err != nil { + return err + } + + instanceValueRegister[typ] = factory + + return nil } diff --git a/comid/instance_test.go b/comid/instance_test.go index 34e671e7..e65c99a3 100644 --- a/comid/instance_test.go +++ b/comid/instance_test.go @@ -1,24 +1,69 @@ package comid import ( + "encoding/json" "testing" - "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestInstance_GetUUID_OK(t *testing.T) { - inst := NewInstanceUUID(uuid.UUID(TestUUID)) - require.NotNil(t, inst) - u, err := inst.GetUUID() - assert.Nil(t, err) - assert.Equal(t, u, TestUUID) + inst := MustNewInstanceUUID(TestUUID) + u, ok := inst.Value.(*TaggedUUID) + assert.True(t, ok) + assert.EqualValues(t, TestUUID, *u) } -func TestInstance_GetUUID_NOK(t *testing.T) { - inst := &Instance{} - expectedErr := "instance-id type is: " - _, err := inst.GetUUID() - assert.EqualError(t, err, expectedErr) +type testInstance string + +func newTestInstance(val any) (*Instance, error) { + ret := testInstance("test") + return &Instance{&ret}, nil +} + +func (o testInstance) Bytes() []byte { + return []byte(o) +} + +func (o testInstance) Type() string { + return "test-instance" +} + +func (o testInstance) String() string { + return string(o) +} + +func (o testInstance) Valid() error { + return nil +} + +func Test_RegisterInstanceType(t *testing.T) { + err := RegisterInstanceType(99997, newTestInstance) + require.NoError(t, err) + + instance, err := newTestInstance(nil) + require.NoError(t, err) + + data, err := json.Marshal(instance) + require.NoError(t, err) + assert.Equal(t, string(data), `{"type":"test-instance","value":"test"}`) + + var out Instance + err = json.Unmarshal(data, &out) + require.NoError(t, err) + assert.Equal(t, instance.Bytes(), out.Bytes()) + + data, err = em.Marshal(instance) + require.NoError(t, err) + assert.Equal(t, data, []byte{ + 0xda, 0x0, 0x1, 0x86, 0x9d, // tag 99997 + 0x64, // tstr(4) + 0x74, 0x65, 0x73, 0x74, // "test" + }) + + var out2 Instance + err = dm.Unmarshal(data, &out2) + require.NoError(t, err) + assert.Equal(t, instance.Bytes(), out2.Bytes()) } diff --git a/comid/referencevalue_test.go b/comid/referencevalue_test.go index 00b89d18..89fb18ef 100644 --- a/comid/referencevalue_test.go +++ b/comid/referencevalue_test.go @@ -13,10 +13,9 @@ func Test_ReferenceValue(t *testing.T) { err := rv.Valid() assert.EqualError(t, err, "environment validation failed: environment must not be empty") - rv.Environment.Instance = NewInstance() id, err := uuid.NewUUID() require.NoError(t, err) - rv.Environment.Instance.SetUUID(id) + rv.Environment.Instance = MustNewInstanceUUID(id) err = rv.Valid() assert.EqualError(t, err, "measurements validation failed: no measurement entries") } diff --git a/comid/tmp b/comid/tmp new file mode 100644 index 00000000..58681bc6 --- /dev/null +++ b/comid/tmp @@ -0,0 +1,15 @@ +ImplementationID: 61636d652d696d706c656d656e746174696f6e2d69642d303030303030303031 +SignerID: acbb11c7e4da217205523ce4ce1a245ae1a239ae3c6bfd9e7871f7e5d8bae86b +Label: BL +Version: 2.1.0 +Digest: 87428fc522803d31065e7bce3cf03fe475096631e5e07bbd7a0fde60c4cf25c7 +SignerID: acbb11c7e4da217205523ce4ce1a245ae1a239ae3c6bfd9e7871f7e5d8bae86b +Label: PRoT +Version: 1.3.5 +Digest: 0263829989b6fd954f72baaf2fc64bc2e2f01d692d4de72986ea808f6e99813f +SignerID: acbb11c7e4da217205523ce4ce1a245ae1a239ae3c6bfd9e7871f7e5d8bae86b +Label: ARoT +Version: 0.1.4 +Digest: a3a5e715f0cc574a73c3f9bebb6bc24f32ffd5b67b387244c2c909da779a1478 +Label: a non-empty (unique) label +Raw value: 72617776616c75650a72617776616c75650a diff --git a/comid/ueid.go b/comid/ueid.go index 765d1b86..cf64ffa2 100644 --- a/comid/ueid.go +++ b/comid/ueid.go @@ -4,18 +4,17 @@ package comid import ( - "encoding/json" + "encoding/base64" "fmt" "github.com/veraison/eat" ) +const UEIDType = "ueid" + // UEID is an Unique Entity Identifier type UEID eat.UEID -// TaggedUEID is an alias to allow automatic tagging of an UEID type -type TaggedUEID UEID - func (o UEID) Empty() bool { return len(o) == 0 } @@ -28,25 +27,65 @@ func (o UEID) Valid() error { return nil } -// UnmarshalJSON deserializes the supplied string into the UEID target -func (o *UEID) UnmarshalJSON(data []byte) error { - var b []byte +func (o UEID) String() string { + return base64.StdEncoding.EncodeToString(o) +} + +// TaggedUEID is an alias to allow automatic tagging of an UEID type +type TaggedUEID UEID + +func NewTaggedUEID(val any) (*TaggedUEID, error) { + var ret TaggedUEID - if err := json.Unmarshal(data, &b); err != nil { - return err + if val == nil { + return &ret, nil } - u := UEID(b) + switch t := val.(type) { + case string: + b, err := base64.StdEncoding.DecodeString(t) + if err != nil { + return nil, fmt.Errorf("bad UEID: %w", err) + } - if err := u.Valid(); err != nil { - return err + ret = TaggedUEID(b) + case []byte: + ret = TaggedUEID(t) + case TaggedUEID: + ret = append(ret, t...) + case *TaggedUEID: + ret = append(ret, *t...) + case UEID: + ret = append(ret, t...) + case *UEID: + ret = append(ret, *t...) + case eat.UEID: + ret = append(ret, t...) + case *eat.UEID: + ret = append(ret, *t...) + default: + return nil, fmt.Errorf("unexpeted type for UEID: %T", t) } - *o = u + if err := ret.Valid(); err != nil { + return nil, err + } - return nil + return &ret, nil +} + +func (o TaggedUEID) Valid() error { + return UEID(o).Valid() +} + +func (o TaggedUEID) String() string { + return UEID(o).String() +} + +func (o TaggedUEID) Type() string { + return "ueid" } -func (o UEID) MarshalJSON() ([]byte, error) { - return json.Marshal([]byte(o)) +func (o TaggedUEID) Bytes() []byte { + return []byte(o) } diff --git a/comid/ueid_test.go b/comid/ueid_test.go new file mode 100644 index 00000000..a119ad44 --- /dev/null +++ b/comid/ueid_test.go @@ -0,0 +1,30 @@ +package comid + +import ( + "encoding/base64" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_NewTaggedUEID(t *testing.T) { + ueid := UEID(TestUEID) + tagged := TaggedUEID(TestUEID) + bytes := MustHexDecode(t, TestUEIDString) + + for _, v := range []any{ + TestUEID, + &TestUEID, + ueid, + &ueid, + tagged, + &tagged, + bytes, + base64.StdEncoding.EncodeToString(bytes), + } { + ret, err := NewTaggedUEID(v) + require.NoError(t, err) + assert.Equal(t, []byte(TestUEID), ret.Bytes()) + } +} diff --git a/comid/uuid.go b/comid/uuid.go index c5c78d61..15a16eab 100644 --- a/comid/uuid.go +++ b/comid/uuid.go @@ -10,12 +10,11 @@ import ( "github.com/google/uuid" ) +const UUIDType = "uuid" + // UUID represents an Universally Unique Identifier (UUID, see RFC4122) type UUID uuid.UUID -// TaggedUUID is an alias to allow automatic tagging of a UUID type -type TaggedUUID UUID - // ParseUUID parses the supplied string into a UUID func ParseUUID(s string) (UUID, error) { v, err := uuid.Parse(s) @@ -64,3 +63,68 @@ func (o *UUID) UnmarshalJSON(data []byte) error { func (o UUID) MarshalJSON() ([]byte, error) { return json.Marshal(o.String()) } + +// TaggedUUID is an alias to allow automatic tagging of a UUID type +type TaggedUUID UUID + +func NewTaggedUUID(val any) (*TaggedUUID, error) { + var ret TaggedUUID + + switch t := val.(type) { + case string: + u, err := ParseUUID(t) + if err != nil { + return nil, fmt.Errorf("bad UUID: %w", err) + } + ret = TaggedUUID(u) + case []byte: + if len(t) != 16 { + return nil, fmt.Errorf( + "unexpected size for UUID: expected 16 bytes, found %d", + len(t), + ) + } + + copy(ret[:], t) + case TaggedUUID: + copy(ret[:], t[:]) + case *TaggedUUID: + copy(ret[:], (*t)[:]) + case UUID: + copy(ret[:], t[:]) + case *UUID: + copy(ret[:], (*t)[:]) + case uuid.UUID: + copy(ret[:], t[:]) + case *uuid.UUID: + copy(ret[:], (*t)[:]) + default: + return nil, fmt.Errorf("unexpected type for UUID: %T", t) + } + + if err := ret.Valid(); err != nil { + return nil, err + } + + return &ret, nil +} + +// String returns a string representation of the binary UUID +func (o TaggedUUID) String() string { + return UUID(o).String() +} + +func (o TaggedUUID) Valid() error { + return UUID(o).Valid() +} + +// Type returns a string contianing type name. This is part of the +// ITypeChoiceValue implementation. +func (o TaggedUUID) Type() string { + return "uuid" +} + +// Bytes returns a []byte containing the raw UUID bytes +func (o TaggedUUID) Bytes() []byte { + return o[:] +} diff --git a/corim/unsignedcorim_test.go b/corim/unsignedcorim_test.go index d9b4d105..f24bd90a 100644 --- a/corim/unsignedcorim_test.go +++ b/corim/unsignedcorim_test.go @@ -8,7 +8,6 @@ import ( "testing" "time" - "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/veraison/corim/comid" @@ -182,7 +181,7 @@ func TestUnsignedCorim_Valid_ok(t *testing.T) { AddAttestVerifKey( comid.AttestVerifKey{ Environment: comid.Environment{ - Instance: comid.NewInstanceUUID(uuid.UUID(comid.TestUUID)), + Instance: comid.MustNewInstanceUUID(comid.TestUUID), }, VerifKeys: *comid.NewCryptoKeys(). Add( diff --git a/encoding/json.go b/encoding/json.go new file mode 100644 index 00000000..74dc6b08 --- /dev/null +++ b/encoding/json.go @@ -0,0 +1,9 @@ +package encoding + +// TypeAndValue stores a JSON object with two attributes: a string "type" +// and a generic "value" (string) defined by type. This type is used in +// a few places to implement the choice types that CBOR handles using tags. +type TypeAndValue struct { + Type string `json:"type"` + Value string `json:"value"` +} diff --git a/extensions/typechoice.go b/extensions/typechoice.go new file mode 100644 index 00000000..263013df --- /dev/null +++ b/extensions/typechoice.go @@ -0,0 +1,11 @@ +package extensions + +type ITypeChoiceValue interface { + // String returns the string representation of the ITypeChoiceValue. + String() string + // Valid returns an error if validation of the ITypeChoiceValue fails, + // or nil if it succeeds. + Valid() error + // Type returns the type name of this ITypeChoiceValue implementation. + Type() string +}