From b8f1850b510cf18f57627a3514d82008992d9d24 Mon Sep 17 00:00:00 2001 From: Sergei Trofimov Date: Wed, 13 Sep 2023 17:48:30 +0100 Subject: [PATCH] WIP: dynamic ICryptoKeyValue registration prototype Signed-off-by: Sergei Trofimov --- comid/cbor.go | 29 ++- comid/class.go | 32 +-- comid/classid.go | 433 +++++++++++++++++++------------ comid/classid_test.go | 45 ++-- comid/cryptokey.go | 303 ++++++++++++--------- comid/cryptokey_test.go | 12 +- comid/example_psa_refval_test.go | 7 +- comid/tmp | 15 ++ comid/uuid.go | 20 ++ encoding/json.go | 9 + extensions/typechoice.go | 11 + 11 files changed, 572 insertions(+), 344 deletions(-) create mode 100644 comid/tmp create mode 100644 encoding/json.go create mode 100644 extensions/typechoice.go diff --git a/comid/cbor.go b/comid/cbor.go index c44d518d..2de1a2b8 100644 --- a/comid/cbor.go +++ b/comid/cbor.go @@ -4,6 +4,7 @@ package comid import ( + "fmt" "reflect" cbor "github.com/fxamacker/cbor/v2" @@ -12,10 +13,8 @@ import ( var ( em, emError = initCBOREncMode() dm, dmError = initCBORDecMode() -) -func comidTags() cbor.TagSet { - comidTagsMap := map[uint64]interface{}{ + comidTagsMap = map[uint64]interface{}{ 32: TaggedURI(""), 37: TaggedUUID{}, 111: TaggedOID{}, @@ -37,7 +36,9 @@ func comidTags() cbor.TagSet { 601: TaggedPSARefValID{}, 602: TaggedCCAPlatformConfigID(""), } +) +func comidTags() cbor.TagSet { opts := cbor.TagOptions{ EncTag: cbor.EncTagRequired, DecTag: cbor.DecTagRequired, @@ -69,6 +70,28 @@ func initCBORDecMode() (dm cbor.DecMode, err error) { return decOpt.DecModeWithTags(comidTags()) } +func registerCOMIDTag(tag uint64, t interface{}) error { + if _, exists := comidTagsMap[tag]; exists { + return fmt.Errorf("tag %d is already registered", tag) + } + + comidTagsMap[tag] = t + + var err error + + em, err = initCBOREncMode() + if err != nil { + return err + } + + dm, err = initCBORDecMode() + if err != nil { + return err + } + + return nil +} + func init() { if emError != nil { panic(emError) diff --git a/comid/class.go b/comid/class.go index 085e703b..ab414740 100644 --- a/comid/class.go +++ b/comid/class.go @@ -23,39 +23,33 @@ type Class struct { // NewClassUUID instantiates a new Class object with the specified UUID as // identifier func NewClassUUID(uuid UUID) *Class { - c := Class{ - ClassID: &ClassID{}, - } - - if c.ClassID.SetUUID(uuid) == nil { + classID, err := NewUUIDClassID(uuid) + if err != nil { return nil } - return &c + + return &Class{ClassID: classID} } // NewClassImplID instantiates a new Class object that identifies the specified PSA // Implementation ID func NewClassImplID(implID ImplID) *Class { - c := Class{ - ClassID: &ClassID{}, - } - - if c.ClassID.SetImplID(implID) == nil { + classID, err := NewImplIDClassID(implID) + if err != nil { return nil } - return &c + + return &Class{ClassID: classID} } // NewClassOID instantiates a new Class object that identifies the OID func NewClassOID(oid string) *Class { - c := Class{ - ClassID: &ClassID{}, - } - - if c.ClassID.SetOID(oid) == nil { + classID, err := NewOIDClassID(oid) + if err != nil { return nil } - return &c + + return &Class{ClassID: classID} } // SetVendor sets the vendor metadata to the supplied string @@ -131,7 +125,7 @@ func (o *Class) SetIndex(index uint64) *Class { // Valid checks the non-empty<> constraint on the map func (o Class) Valid() error { // check non-empty<{ ... }> - if (o.ClassID == nil || o.ClassID.Unset()) && + if (o.ClassID == nil || !o.ClassID.IsSet()) && o.Vendor == nil && o.Model == nil && o.Layer == nil && o.Index == nil { return fmt.Errorf("class must not be empty") } diff --git a/comid/classid.go b/comid/classid.go index 760f8e5d..4d13a6a3 100644 --- a/comid/classid.go +++ b/comid/classid.go @@ -6,38 +6,143 @@ package comid import ( "encoding/base64" "encoding/json" + "errors" "fmt" + + "github.com/veraison/corim/encoding" + "github.com/veraison/corim/extensions" +) + +const ( + UUIDType = "uuid" + IntType = "int" + OIDType = "oid" + + ImplIDType = "psa.impl-id" ) // ClassID represents a $class-id-type-choice, which can be one of TaggedUUID, // TaggedOID, or TaggedImplID (PSA-specific extension) type ClassID struct { - val interface{} + Value IClassIDValue } -type ClassIDType uint16 +func (o ClassID) Valid() error { + if o.Value == nil { + return errors.New("nil value") + } -const ( - ClassIDTypeUUID = ClassIDType(iota) - ClassIDTypeImplID - ClassIDTypeOID + return o.Value.Valid() +} - ClassIDTypeUnknown = ^ClassIDType(0) -) +// Type returns the type of the target ClassID, i.e., one of UUID, OID or PSA +// Implementation ID +func (o ClassID) Type() string { + if o.Value == nil { + return "" + } + + return o.Value.Type() +} -// SetUUID sets the value of the targed ClassID to the supplied UUID -func (o *ClassID) SetUUID(uuid UUID) *ClassID { - if o != nil { - o.val = TaggedUUID(uuid) +// Bytes returns a []byte containing the raw bytes of the class id value +func (o ClassID) Bytes() []byte { + if o.Value == nil { + return []byte{} } - return o + return o.Value.Bytes() +} + +// IsSet returns true iff the underlying class id value has been set (is not nil) +func (o ClassID) IsSet() bool { + return o.Value != nil +} + +// MarshalCBOR serializes the target ClassID to CBOR +func (o ClassID) MarshalCBOR() ([]byte, error) { + return em.Marshal(o.Value) +} + +// UnmarshalCBOR deserializes the supplied CBOR buffer into the target ClassID. +// It is undefined behavior to try and inspect the target ClassID in case this +// method returns an error. +func (o *ClassID) UnmarshalCBOR(data []byte) error { + return dm.Unmarshal(data, &o.Value) +} + +// UnmarshalJSON deserializes the supplied JSON object into the target ClassID +// The class id object must following shape: +// +// { +// "type": "", +// "value": "" +// } +// +// where must be one of the known IClassIDValue implementation +// type names (available in the base implementation: "uuid", "oid", +// "psa.impl-id"), and is the class id value encoded as +// a string. The exact encoding is depenent. For the base +// implmentation types it is +// uuid: standard UUID string reprsentation, e.g. "550e8400-e29b-41d4-a716-446655440000" +// oid: dot-seprated integers, e.g. "1.2.3.4" +// psa.impl-id: base64-encoded bytes, e.g. "YWNtZS1pbXBsZW1lbnRhdGlvbi1pZC0wMDAwMDAwMDE=" +func (o *ClassID) UnmarshalJSON(data []byte) error { + var value encoding.TypeAndValue + + if err := json.Unmarshal(data, &value); err != nil { + return err + } + + if value.Type == "" { + return errors.New("key type not set") + } + + factory, ok := classIDValueRegister[value.Type] + if !ok { + return fmt.Errorf("unknown class id type: %q", value.Type) + } + + v, err := factory(value.Value) + if err != nil { + return err + } + + o.Value = v.Value + + return o.Valid() +} + +// MarshalJSON serializes the target ClassID to JSON +func (o ClassID) MarshalJSON() ([]byte, error) { + value := encoding.TypeAndValue{ + Type: o.Value.Type(), + Value: o.Value.String(), + } + + return json.Marshal(value) +} + +// String returns a printable string of the ClassID value. UUIDs use the +// canonical 8-4-4-4-12 format, PSA Implementation IDs are base64 encoded. +// OIDs are output in dotted-decimal notation. +func (o ClassID) String() string { + return o.Value.String() +} + +type IClassIDValue interface { + extensions.ITypeChoiceValue + + Bytes() []byte } type ImplID [32]byte -type TaggedImplID ImplID -func (o ImplID) MarshalJSON() ([]byte, error) { - return json.Marshal(o[:]) +func (o ImplID) String() string { + return base64.StdEncoding.EncodeToString(o[:]) +} + +func (o ImplID) Valid() error { + return nil } func (o *ImplID) UnmarshalJSON(data []byte) error { @@ -56,194 +161,194 @@ func (o *ImplID) UnmarshalJSON(data []byte) error { return nil } -type TaggedOID OID +type TaggedImplID ImplID + +func NewImplIDClassID(val any) (*ClassID, error) { + var ret TaggedImplID -// SetImplID sets the value of the targed ClassID to the supplied PSA -// Implementation ID (see Section 3.2.2 of draft-tschofenig-rats-psa-token) -func (o *ClassID) SetImplID(implID ImplID) *ClassID { - if o != nil { - o.val = TaggedImplID(implID) + if val == nil { + return &ClassID{(*TaggedImplID)(new([32]byte))}, nil } - return o -} -func (o ClassID) GetImplID() (ImplID, error) { - switch t := o.val.(type) { + switch t := val.(type) { + case []byte: + if nb := len(t); nb != 32 { + return nil, fmt.Errorf("bad ImplID: got %d bytes, want 32", nb) + } + + copy(ret[:], t) + case string: + v, err := base64.StdEncoding.DecodeString(t) + if err != nil { + return nil, fmt.Errorf("bad ImplID: %w", err) + } + + if nb := len(v); nb != 32 { + return nil, fmt.Errorf("bad ImplID: decoded %d bytes, want 32", nb) + } + + copy(ret[:], v) case TaggedImplID: - return ImplID(t), nil + copy(ret[:], t[:]) + case *TaggedImplID: + copy(ret[:], (*t)[:]) + case ImplID: + copy(ret[:], t[:]) + case *ImplID: + copy(ret[:], (*t)[:]) default: - return ImplID{}, fmt.Errorf("class-id type is: %T", t) + return nil, fmt.Errorf("unexpected type of ImplID: %T", t) } + + return &ClassID{&ret}, nil } -// SetOID sets the value of the targed ClassID to the supplied OID. -// The OID is a string in dotted-decimal notation -func (o *ClassID) SetOID(s string) *ClassID { - if o != nil { - var berOID OID - if berOID.FromString(s) != nil { - return nil - } - o.val = TaggedOID(berOID) +func MustNewImplIDClassID(val any) *ClassID { + ret, err := NewImplIDClassID(val) + if err != nil { + panic(err) } - return o + + return ret } -// MarshalCBOR serializes the target ClassID to CBOR -func (o ClassID) MarshalCBOR() ([]byte, error) { - return em.Marshal(o.val) +func (o TaggedImplID) Valid() error { + return ImplID(o).Valid() } -// UnmarshalCBOR deserializes the supplied CBOR buffer into the target ClassID. -// It is undefined behavior to try and inspect the target ClassID in case this -// method returns an error. -func (o *ClassID) UnmarshalCBOR(data []byte) error { - var implID TaggedImplID +func (o TaggedImplID) String() string { + return ImplID(o).String() +} - if dm.Unmarshal(data, &implID) == nil { - o.val = implID - return nil - } +func (o TaggedImplID) Type() string { + return ImplIDType +} - var uuid TaggedUUID +func (o TaggedImplID) Bytes() []byte { + return o[:] +} - if dm.Unmarshal(data, &uuid) == nil { - o.val = uuid - return nil - } +type TaggedOID OID + +func NewOIDClassID(val any) (*ClassID, error) { + var ret TaggedOID - var oid TaggedOID + switch t := val.(type) { + case string: + var berOID OID + if err := berOID.FromString(t); err != nil { + return nil, err + } - if dm.Unmarshal(data, &oid) == nil { - o.val = oid - return nil + ret = TaggedOID(berOID) + case TaggedOID: + copy(ret, t) + case OID: + copy(ret, t) + case *TaggedOID: + copy(ret, (*t)) + case *OID: + copy(ret, (*t)) } - return fmt.Errorf("unknown class id (CBOR: %x)", data) + return &ClassID{&ret}, nil } -// UnmarshalJSON deserializes the supplied JSON object into the target ClassID -// The class id object must have one of the following shapes: -// -// UUID: -// -// { -// "type": "uuid", -// "value": "69E027B2-7157-4758-BCB4-D9F167FE49EA" -// } -// -// OID: -// -// { -// "type": "oid", -// "value": "2.16.840.1.113741.1.15.4.2" -// } -// -// PSA Implementation ID: -// -// { -// "type": "psa.impl-id", -// "value": "YWNtZS1pbXBsZW1lbnRhdGlvbi1pZC0wMDAwMDAwMDE=" -// } -func (o *ClassID) UnmarshalJSON(data []byte) error { - var v tnv - - if err := json.Unmarshal(data, &v); err != nil { - return err +func MustNewOIDClassID(val any) *ClassID { + ret, err := NewOIDClassID(val) + if err != nil { + panic(err) } - switch v.Type { - case "uuid": // nolint: goconst - var x UUID - if err := x.UnmarshalJSON(v.Value); err != nil { - return err - } - o.val = TaggedUUID(x) - case "oid": - var x OID - if err := x.UnmarshalJSON(v.Value); err != nil { - return err - } - o.val = TaggedOID(x) - case "psa.impl-id": - var x ImplID - if err := x.UnmarshalJSON(v.Value); err != nil { - return err - } - o.val = TaggedImplID(x) - default: - return fmt.Errorf("unknown type '%s' for class id", v.Type) - } + return ret +} + +func (o TaggedOID) Type() string { + return OIDType +} + +func (o TaggedOID) String() string { + return OID(o).String() +} +func (o TaggedOID) Valid() error { return nil } -// MarshalJSON serializes the target ClassID to JSON -func (o ClassID) MarshalJSON() ([]byte, error) { - var ( - v tnv - b []byte - err error - ) +func (o TaggedOID) Bytes() []byte { + return o +} - switch t := o.val.(type) { - case TaggedUUID: - b, err = UUID(t).MarshalJSON() - if err != nil { - return nil, err - } - v = tnv{Type: "uuid", Value: b} - case TaggedOID: - b, err = OID(t).MarshalJSON() +func NewUUIDClassID(val any) (*ClassID, error) { + ret := TaggedUUID{} + + if val == nil { + return &ClassID{&ret}, nil + } + + switch t := val.(type) { + case string: + uuid, err := ParseUUID(t) if err != nil { - return nil, err + return nil, fmt.Errorf("bad UUID: %w", err) } - v = tnv{Type: "oid", Value: b} - case TaggedImplID: - b, err = ImplID(t).MarshalJSON() - if err != nil { - return nil, err + ret = TaggedUUID(uuid) + case []byte: + if len(t) != 16 { + return nil, fmt.Errorf( + "unexpected size for UUID: expected 16 bytes, found %d", + len(t), + ) } - v = tnv{Type: "psa.impl-id", Value: b} + + copy(ret[:], t) + case TaggedUUID: + copy(ret[:], t[:]) + case *TaggedUUID: + copy(ret[:], (*t)[:]) + case UUID: + copy(ret[:], t[:]) + case *UUID: + copy(ret[:], (*t)[:]) default: - return nil, fmt.Errorf("unknown type %T for class-id", t) + return nil, fmt.Errorf("unexpected type for UUID: %T", t) } - return json.Marshal(v) + return &ClassID{&ret}, nil } -// Type returns the type of the target ClassID, i.e., one of UUID, OID or PSA -// Implementation ID -func (o ClassID) Type() ClassIDType { - switch o.val.(type) { - case TaggedUUID: - return ClassIDTypeUUID - case TaggedImplID: - return ClassIDTypeImplID - case TaggedOID: - return ClassIDTypeOID +func MustNewUUIDClassID(val any) *ClassID { + ret, err := NewUUIDClassID(val) + if err != nil { + panic(err) } - return ClassIDTypeUnknown + + return ret } -// String returns a printable string of the ClassID value. UUIDs use the -// canonical 8-4-4-4-12 format, PSA Implementation IDs are base64 encoded. -// OIDs are output in dotted-decimal notation. -func (o ClassID) String() string { - switch t := o.val.(type) { - case TaggedUUID: - return UUID(t).String() - case TaggedImplID: - b := [32]byte(t) - return base64.StdEncoding.EncodeToString(b[:]) - case TaggedOID: - return OID(t).String() - default: - return "" - } +type IClassIDFactory func(any) (*ClassID, error) + +var classIDValueRegister = map[string]IClassIDFactory{ + OIDType: NewOIDClassID, + ImplIDType: NewImplIDClassID, + UUIDType: NewUUIDClassID, } -// Unset tests whether the target ClassID has been initialized -func (o ClassID) Unset() bool { - return o.val == nil || o.Type() == ClassIDTypeUnknown +func RegisterClassIDType(typ string, tag uint64, factory IClassIDFactory) error { + if _, exists := classIDValueRegister[typ]; exists { + return fmt.Errorf("class ID type with name %q already exists", typ) + } + + nilVal, err := factory(nil) + if err != nil { + return err + } + + if err := registerCOMIDTag(tag, nilVal); err != nil { + return err + } + + classIDValueRegister[typ] = factory + + return nil } diff --git a/comid/classid_test.go b/comid/classid_test.go index 51b07391..7cf23847 100644 --- a/comid/classid_test.go +++ b/comid/classid_test.go @@ -8,13 +8,10 @@ import ( "testing" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func TestClassID_MarshalCBOR_UUID(t *testing.T) { - var tv ClassID - - require.NotNil(t, tv.SetUUID(TestUUID)) + tv := MustNewUUIDClassID(TestUUID) // 37(h'31FB5ABF023E4992AA4E95F9C1503BFA') // tag(37): d8 25 @@ -29,9 +26,7 @@ func TestClassID_MarshalCBOR_UUID(t *testing.T) { } func TestClassID_MarshalCBOR_ImplID(t *testing.T) { - var tv ClassID - - require.NotNil(t, tv.SetImplID(TestImplID)) + tv := MustNewImplIDClassID(TestImplID) // 600 (h'61636D652D696D706C656D656E746174696F6E2D69642D303030303030303031') // tag(600): d9 0258 @@ -66,7 +61,7 @@ func TestClassID_UnmarshalCBOR_UUID_OK(t *testing.T) { err := actual.UnmarshalCBOR(tv) assert.Nil(t, err) - assert.Equal(t, ClassIDTypeUUID, actual.Type()) + assert.Equal(t, "uuid", actual.Type()) assert.Equal(t, TestUUIDString, actual.String()) } @@ -79,7 +74,7 @@ func TestClassID_UnmarshalCBOR_ImplID_OK(t *testing.T) { err := actual.UnmarshalCBOR(tv) assert.Nil(t, err) - assert.Equal(t, ClassIDTypeImplID, actual.Type()) + assert.Equal(t, "psa.impl-id", actual.Type()) assert.Equal(t, expected, actual.String()) } @@ -88,12 +83,10 @@ func TestClassID_UnmarshalCBOR_badInput(t *testing.T) { hex := "582061636d652d696d706c656d656e746174696f6e2d69642d303030303030303031" tv := MustHexDecode(t, hex) - expectedError := fmt.Sprintf("unknown class id (CBOR: %s)", hex) - var actual ClassID err := actual.UnmarshalCBOR(tv) - assert.EqualError(t, err, expectedError) + assert.EqualError(t, err, "cbor: cannot unmarshal byte string into Go value of type comid.IClassIDValue") } func TestClassID_UnmarshalJSON_UUID(t *testing.T) { @@ -105,7 +98,7 @@ func TestClassID_UnmarshalJSON_UUID(t *testing.T) { err := actual.UnmarshalJSON([]byte(tv)) assert.Nil(t, err) - assert.Equal(t, ClassIDTypeUUID, actual.Type()) + assert.Equal(t, "uuid", actual.Type()) assert.Equal(t, TestUUIDString, actual.String()) } @@ -122,7 +115,7 @@ func TestClassID_UnmarshalJSON_ImplID(t *testing.T) { err := actual.UnmarshalJSON([]byte(tv)) assert.Nil(t, err) - assert.Equal(t, ClassIDTypeImplID, actual.Type()) + assert.Equal(t, "psa.impl-id", actual.Type()) // the returned string is the base64 encoding of the stored binary assert.Equal(t, expected, actual.String()) } @@ -133,8 +126,8 @@ func TestClassID_UnmarshalJSON_badInput_unknown_type(t *testing.T) { var actual ClassID err := actual.UnmarshalJSON([]byte(tv)) - assert.EqualError(t, err, "unknown type 'FOOBAR' for class id") - assert.Equal(t, ClassIDTypeUnknown, actual.Type()) + assert.EqualError(t, err, `unknown class id type: "FOOBAR"`) + assert.Equal(t, "", actual.Type()) } func TestClassID_UnmarshalJSON_badInput_missing_value(t *testing.T) { @@ -143,8 +136,8 @@ func TestClassID_UnmarshalJSON_badInput_missing_value(t *testing.T) { var actual ClassID err := actual.UnmarshalJSON([]byte(tv)) - assert.EqualError(t, err, "bad ImplID: unexpected end of JSON input") - assert.Equal(t, ClassIDTypeUnknown, actual.Type()) + assert.EqualError(t, err, "bad ImplID: decoded 0 bytes, want 32") + assert.Equal(t, "", actual.Type()) } func TestClassID_UnmarshalJSON_badInput_empty_value(t *testing.T) { @@ -153,8 +146,8 @@ func TestClassID_UnmarshalJSON_badInput_empty_value(t *testing.T) { var actual ClassID err := actual.UnmarshalJSON([]byte(tv)) - assert.EqualError(t, err, "bad ImplID format: got 0 bytes, want 32") - assert.Equal(t, ClassIDTypeUnknown, actual.Type()) + assert.EqualError(t, err, "bad ImplID: decoded 0 bytes, want 32") + assert.Equal(t, "", actual.Type()) } func TestClassID_UnmarshalJSON_badInput_badly_encoded_ImplID_value(t *testing.T) { @@ -164,7 +157,7 @@ func TestClassID_UnmarshalJSON_badInput_badly_encoded_ImplID_value(t *testing.T) err := actual.UnmarshalJSON([]byte(tv)) assert.EqualError(t, err, "bad ImplID: illegal base64 data at input byte 0") - assert.Equal(t, ClassIDTypeUnknown, actual.Type()) + assert.Equal(t, "", actual.Type()) } func TestClassID_UnmarshalJSON_badInput_badly_encoded_UUID_value(t *testing.T) { @@ -174,7 +167,7 @@ func TestClassID_UnmarshalJSON_badInput_badly_encoded_UUID_value(t *testing.T) { err := actual.UnmarshalJSON([]byte(tv)) assert.EqualError(t, err, "bad UUID: invalid UUID length: 9") - assert.Equal(t, ClassIDTypeUnknown, actual.Type()) + assert.Equal(t, "", actual.Type()) } func TestClassID_SetOID_ok(t *testing.T) { @@ -200,8 +193,7 @@ func TestClassID_SetOID_ok(t *testing.T) { } for _, tv := range tvs { - c := ClassID{} - assert.NotNil(t, c.SetOID(tv)) + c := MustNewOIDClassID(tv) assert.Equal(t, tv, c.String()) } } @@ -219,7 +211,8 @@ func TestClassID_SetOID_bad(t *testing.T) { } for _, tv := range tvs { - c := ClassID{} - assert.Nil(t, c.SetOID(tv)) + c, err := NewOIDClassID(tv) + assert.NotNil(t, err) + assert.Nil(t, c) } } diff --git a/comid/cryptokey.go b/comid/cryptokey.go index 53d99df8..7f9186bb 100644 --- a/comid/cryptokey.go +++ b/comid/cryptokey.go @@ -14,6 +14,8 @@ import ( "fmt" "github.com/fxamacker/cbor/v2" + "github.com/veraison/corim/encoding" + "github.com/veraison/corim/extensions" "github.com/veraison/go-cose" "github.com/veraison/swid" ) @@ -58,50 +60,12 @@ type CryptoKey struct { // specified crypto key type. For PKIX types, k must be a string. For COSE_Key, // k must be a []byte. For thumbprint types, k must be a swid.HashEntry. func NewCryptoKey(k any, typ string) (*CryptoKey, error) { - switch typ { - case PKIXBase64KeyType: - v, ok := k.(string) - if !ok { - return nil, fmt.Errorf("value must be a string; found %T", k) - } - return NewPKIXBase64Key(v) - case PKIXBase64CertType: - v, ok := k.(string) - if !ok { - return nil, fmt.Errorf("value must be a string; found %T", k) - } - return NewPKIXBase64Cert(v) - case PKIXBase64CertPathType: - v, ok := k.(string) - if !ok { - return nil, fmt.Errorf("value must be a string; found %T", k) - } - return NewPKIXBase64CertPath(v) - case COSEKeyType: - v, ok := k.([]byte) - if !ok { - return nil, fmt.Errorf("value must be a []byte; found %T", k) - } - return NewCOSEKey(v) - case ThumbprintType, CertThumbprintType, CertPathThumbprintType: - v, ok := k.(swid.HashEntry) - if !ok { - return nil, fmt.Errorf("value must be a swid.HashEntry; found %T", k) - } - switch typ { - case ThumbprintType: - return NewThumbprint(v) - case CertThumbprintType: - return NewCertThumbprint(v) - case CertPathThumbprintType: - return NewCertPathThumbprint(v) - default: - // Should never here because of the the outer case clause - panic(fmt.Sprintf("unexpected thumbprint type: %s", typ)) - } - default: + factory, ok := cryptoKeyValueRegister[typ] + if !ok { return nil, fmt.Errorf("unexpected CryptoKey type: %s", typ) } + + return factory(k) } // MustNewCryptoKey is the same as NewCryptoKey, but does not return an error, @@ -136,42 +100,18 @@ func (o CryptoKey) PublicKey() (crypto.PublicKey, error) { // MarshalJSON returns a []byte containing the JSON representation of the // CryptoKey. func (o CryptoKey) MarshalJSON() ([]byte, error) { - value := struct { - Type string `json:"type"` - Value string `json:"value"` - }{ + value := encoding.TypeAndValue{ + Type: o.Value.Type(), Value: o.Value.String(), } - switch o.Value.(type) { - case TaggedPKIXBase64Key: - value.Type = PKIXBase64KeyType - case TaggedPKIXBase64Cert: - value.Type = PKIXBase64CertType - case TaggedPKIXBase64CertPath: - value.Type = PKIXBase64CertPathType - case TaggedCOSEKey: - value.Type = COSEKeyType - case TaggedThumbprint: - value.Type = ThumbprintType - case TaggedCertThumbprint: - value.Type = CertThumbprintType - case TaggedCertPathThumbprint: - value.Type = CertPathThumbprintType - default: - return nil, fmt.Errorf("unexpected ICryptoKeyValue type: %T", o.Value) - } - return json.Marshal(value) } // UnmarshalJSON populates the CryptoKey from the JSON representation inside // the provided []byte. func (o *CryptoKey) UnmarshalJSON(b []byte) error { - var value struct { - Type string `json:"type"` - Value string `json:"value"` - } + var value encoding.TypeAndValue if err := json.Unmarshal(b, &value); err != nil { return err @@ -181,36 +121,18 @@ func (o *CryptoKey) UnmarshalJSON(b []byte) error { return errors.New("key type not set") } - switch value.Type { - case PKIXBase64KeyType: - o.Value = TaggedPKIXBase64Key(value.Value) - case PKIXBase64CertType: - o.Value = TaggedPKIXBase64Cert(value.Value) - case PKIXBase64CertPathType: - o.Value = TaggedPKIXBase64CertPath(value.Value) - case COSEKeyType: - data, err := base64.StdEncoding.DecodeString(value.Value) - if err != nil { - return fmt.Errorf("base64 decode error: %w", err) - } - o.Value = TaggedCOSEKey(data) - case ThumbprintType, CertThumbprintType, CertPathThumbprintType: - he, err := swid.ParseHashEntry(value.Value) - if err != nil { - return fmt.Errorf("swid.HashEntry decode error: %w", err) - } - switch value.Type { - case ThumbprintType: - o.Value = TaggedThumbprint{digest{he}} - case CertThumbprintType: - o.Value = TaggedCertThumbprint{digest{he}} - case CertPathThumbprintType: - o.Value = TaggedCertPathThumbprint{digest{he}} - } - default: + factory, ok := cryptoKeyValueRegister[value.Type] + if !ok { return fmt.Errorf("unexpected ICryptoKeyValue type: %q", value.Type) } + k, err := factory(value.Value) + if err != nil { + return err + } + + o.Value = k.Value + return o.Valid() } @@ -229,11 +151,8 @@ func (o *CryptoKey) UnmarshalCBOR(b []byte) error { // ICryptoKeyValue is the interface implemented by the concrete CryptoKey value // types. type ICryptoKeyValue interface { - // String returns the string representation of the ICryptoKeyValue. - String() string - // Valid returns an error if validation of the ICryptoKeyValue fails, - // or nil if it succeeds. - Valid() error + extensions.ITypeChoiceValue + // PublicKey returns a crypto.PublicKey constructed from the // ICryptoKeyValue's underlying value. This returns an error if the // ICryptoKeyValue is one of the thumbprint types. @@ -244,7 +163,12 @@ type ICryptoKeyValue interface { // https://www.rfc-editor.org/rfc/rfc7468#section-13 type TaggedPKIXBase64Key string -func NewPKIXBase64Key(s string) (*CryptoKey, error) { +func NewPKIXBase64Key(k any) (*CryptoKey, error) { + s, ok := k.(string) + if !ok { + return nil, fmt.Errorf("value must be a string; found %T", k) + } + key := TaggedPKIXBase64Key(s) if err := key.Valid(); err != nil { return nil, err @@ -252,8 +176,8 @@ func NewPKIXBase64Key(s string) (*CryptoKey, error) { return &CryptoKey{key}, nil } -func MustNewPKIXBase64Key(s string) *CryptoKey { - key, err := NewPKIXBase64Key(s) +func MustNewPKIXBase64Key(k any) *CryptoKey { + key, err := NewPKIXBase64Key(k) if err != nil { panic(err) } @@ -269,6 +193,10 @@ func (o TaggedPKIXBase64Key) Valid() error { return err } +func (o TaggedPKIXBase64Key) Type() string { + return PKIXBase64KeyType +} + func (o TaggedPKIXBase64Key) PublicKey() (crypto.PublicKey, error) { if string(o) == "" { return nil, errors.New("key value not set") @@ -302,7 +230,12 @@ func (o TaggedPKIXBase64Key) PublicKey() (crypto.PublicKey, error) { // certificate. See https://www.rfc-editor.org/rfc/rfc7468#section-5 type TaggedPKIXBase64Cert string -func NewPKIXBase64Cert(s string) (*CryptoKey, error) { +func NewPKIXBase64Cert(k any) (*CryptoKey, error) { + s, ok := k.(string) + if !ok { + return nil, fmt.Errorf("value must be a string; found %T", k) + } + cert := TaggedPKIXBase64Cert(s) if err := cert.Valid(); err != nil { return nil, err @@ -310,8 +243,8 @@ func NewPKIXBase64Cert(s string) (*CryptoKey, error) { return &CryptoKey{cert}, nil } -func MustNewPKIXBase64Cert(s string) *CryptoKey { - cert, err := NewPKIXBase64Cert(s) +func MustNewPKIXBase64Cert(k any) *CryptoKey { + cert, err := NewPKIXBase64Cert(k) if err != nil { panic(err) } @@ -327,6 +260,10 @@ func (o TaggedPKIXBase64Cert) Valid() error { return err } +func (o TaggedPKIXBase64Cert) Type() string { + return PKIXBase64CertType +} + func (o TaggedPKIXBase64Cert) PublicKey() (crypto.PublicKey, error) { cert, err := o.cert() if err != nil { @@ -375,7 +312,11 @@ func (o TaggedPKIXBase64Cert) cert() (*x509.Certificate, error) { // directly certifies the one preceding. type TaggedPKIXBase64CertPath string -func NewPKIXBase64CertPath(s string) (*CryptoKey, error) { +func NewPKIXBase64CertPath(k any) (*CryptoKey, error) { + s, ok := k.(string) + if !ok { + return nil, fmt.Errorf("value must be a string; found %T", k) + } cert := TaggedPKIXBase64CertPath(s) if err := cert.Valid(); err != nil { @@ -385,8 +326,8 @@ func NewPKIXBase64CertPath(s string) (*CryptoKey, error) { return &CryptoKey{cert}, nil } -func MustNewPKIXBase64CertPath(s string) *CryptoKey { - cert, err := NewPKIXBase64CertPath(s) +func MustNewPKIXBase64CertPath(k any) *CryptoKey { + cert, err := NewPKIXBase64CertPath(k) if err != nil { panic(err) @@ -404,6 +345,10 @@ func (o TaggedPKIXBase64CertPath) Valid() error { return err } +func (o TaggedPKIXBase64CertPath) Type() string { + return PKIXBase64CertPathType +} + func (o TaggedPKIXBase64CertPath) PublicKey() (crypto.PublicKey, error) { certs, err := o.certPath() if err != nil { @@ -468,7 +413,22 @@ func (o TaggedPKIXBase64CertPath) certPath() ([]*x509.Certificate, error) { // https://www.rfc-editor.org/rfc/rfc9052#section-7 type TaggedCOSEKey []byte -func NewCOSEKey(b []byte) (*CryptoKey, error) { +func NewCOSEKey(k any) (*CryptoKey, error) { + var b []byte + var err error + + switch t := k.(type) { + case []byte: + b = t + case string: + b, err = base64.StdEncoding.DecodeString(t) + if err != nil { + return nil, fmt.Errorf("base64 decode error: %w", err) + } + default: + return nil, fmt.Errorf("value must be a []byte or a string; found %T", k) + } + key := TaggedCOSEKey(b) if err := key.Valid(); err != nil { @@ -478,8 +438,8 @@ func NewCOSEKey(b []byte) (*CryptoKey, error) { return &CryptoKey{key}, nil } -func MustNewCOSEKey(b []byte) *CryptoKey { - key, err := NewCOSEKey(b) +func MustNewCOSEKey(k any) *CryptoKey { + key, err := NewCOSEKey(k) if err != nil { panic(err) @@ -509,6 +469,10 @@ func (o TaggedCOSEKey) Valid() error { return err } +func (o TaggedCOSEKey) Type() string { + return COSEKeyType +} + func (o TaggedCOSEKey) PublicKey() (crypto.PublicKey, error) { if len(o) == 0 { return nil, errors.New("empty COSE_Key value") @@ -608,7 +572,22 @@ type TaggedThumbprint struct { digest } -func NewThumbprint(he swid.HashEntry) (*CryptoKey, error) { +func NewThumbprint(k any) (*CryptoKey, error) { + var he swid.HashEntry + var err error + + switch t := k.(type) { + case string: + he, err = swid.ParseHashEntry(t) + if err != nil { + return nil, fmt.Errorf("swid.HashEntry decode error: %w", err) + } + case swid.HashEntry: + he = t + default: + return nil, fmt.Errorf("value must be a swid.HashEntry or a string; found %T", k) + } + key := &CryptoKey{TaggedThumbprint{digest{he}}} if err := key.Valid(); err != nil { @@ -618,8 +597,8 @@ func NewThumbprint(he swid.HashEntry) (*CryptoKey, error) { return key, nil } -func MustNewThumbprint(he swid.HashEntry) *CryptoKey { - key, err := NewThumbprint(he) +func MustNewThumbprint(k any) *CryptoKey { + key, err := NewThumbprint(k) if err != nil { panic(err) @@ -628,13 +607,32 @@ func MustNewThumbprint(he swid.HashEntry) *CryptoKey { return key } +func (o TaggedThumbprint) Type() string { + return ThumbprintType +} + // TaggedCertThumbprint represents a digest of a certificate. The digest value // may be used to find the certificate if contained in a lookup table. type TaggedCertThumbprint struct { digest } -func NewCertThumbprint(he swid.HashEntry) (*CryptoKey, error) { +func NewCertThumbprint(k any) (*CryptoKey, error) { + var he swid.HashEntry + var err error + + switch t := k.(type) { + case string: + he, err = swid.ParseHashEntry(t) + if err != nil { + return nil, fmt.Errorf("swid.HashEntry decode error: %w", err) + } + case swid.HashEntry: + he = t + default: + return nil, fmt.Errorf("value must be a swid.HashEntry or a string; found %T", k) + } + key := &CryptoKey{TaggedCertThumbprint{digest{he}}} if err := key.Valid(); err != nil { @@ -644,8 +642,8 @@ func NewCertThumbprint(he swid.HashEntry) (*CryptoKey, error) { return key, nil } -func MustNewCertThumbprint(he swid.HashEntry) *CryptoKey { - key, err := NewCertThumbprint(he) +func MustNewCertThumbprint(k any) *CryptoKey { + key, err := NewCertThumbprint(k) if err != nil { panic(err) @@ -654,6 +652,10 @@ func MustNewCertThumbprint(he swid.HashEntry) *CryptoKey { return key } +func (o TaggedCertThumbprint) Type() string { + return CertThumbprintType +} + // TaggedCertPathThumbprint represents a digest of a certification path. The // digest value may be used to find the certificate path if contained in a // lookup table. @@ -661,7 +663,22 @@ type TaggedCertPathThumbprint struct { digest } -func NewCertPathThumbprint(he swid.HashEntry) (*CryptoKey, error) { +func NewCertPathThumbprint(k any) (*CryptoKey, error) { + var he swid.HashEntry + var err error + + switch t := k.(type) { + case string: + he, err = swid.ParseHashEntry(t) + if err != nil { + return nil, fmt.Errorf("swid.HashEntry decode error: %w", err) + } + case swid.HashEntry: + he = t + default: + return nil, fmt.Errorf("value must be a swid.HashEntry or a string; found %T", k) + } + key := &CryptoKey{TaggedCertPathThumbprint{digest{he}}} if err := key.Valid(); err != nil { @@ -671,8 +688,8 @@ func NewCertPathThumbprint(he swid.HashEntry) (*CryptoKey, error) { return key, nil } -func MustNewCertPathThumbprint(he swid.HashEntry) *CryptoKey { - key, err := NewCertPathThumbprint(he) +func MustNewCertPathThumbprint(k any) *CryptoKey { + key, err := NewCertPathThumbprint(k) if err != nil { panic(err) @@ -680,3 +697,45 @@ func MustNewCertPathThumbprint(he swid.HashEntry) *CryptoKey { return key } + +func (o TaggedCertPathThumbprint) Type() string { + return CertPathThumbprintType +} + +// ICryptoKeyFactory creates a *CryptoKey from the specified any value. When +// passed nil as input, this must return the equivaluent nil-value for the +// associated ICryptoKeyValue implementation. +type ICryptoKeyFactory func(any) (*CryptoKey, error) + +var cryptoKeyValueRegister = map[string]ICryptoKeyFactory{ + // types defined by the core spec + PKIXBase64KeyType: NewPKIXBase64Key, + PKIXBase64CertType: NewPKIXBase64Cert, + PKIXBase64CertPathType: NewPKIXBase64CertPath, + COSEKeyType: NewCOSEKey, + ThumbprintType: NewThumbprint, + CertThumbprintType: NewCertThumbprint, + CertPathThumbprintType: NewCertPathThumbprint, +} + +// RegisterCryptoKeyType registeres a new ICryptoKeyValue implementation +// (created by the provided ICryptoKeyFactory) under the specified type name +// and CBOR tag. +func RegisterCryptoKeyType(typ string, tag uint64, factory ICryptoKeyFactory) error { + if _, exists := cryptoKeyValueRegister[typ]; exists { + return fmt.Errorf("crypto key type with name %q already exists", typ) + } + + nilVal, err := factory(nil) + if err != nil { + return err + } + + if err := registerCOMIDTag(tag, nilVal.Value); err != nil { + return err + } + + cryptoKeyValueRegister[typ] = factory + + return nil +} diff --git a/comid/cryptokey_test.go b/comid/cryptokey_test.go index 1c0812fc..c53f3125 100644 --- a/comid/cryptokey_test.go +++ b/comid/cryptokey_test.go @@ -149,7 +149,7 @@ func Test_CryptoKey_NewCOSEKey(t *testing.T) { } func Test_CryptoKey_NewThumbprint(t *testing.T) { - type newKeyFunc func(swid.HashEntry) (*CryptoKey, error) + type newKeyFunc func(any) (*CryptoKey, error) for _, newFunc := range []newKeyFunc{ NewThumbprint, @@ -177,7 +177,7 @@ func Test_CryptoKey_NewThumbprint(t *testing.T) { assert.Contains(t, err.Error(), "length mismatch for hash algorithm") } - type mustNewKeyFunc func(swid.HashEntry) *CryptoKey + type mustNewKeyFunc func(any) *CryptoKey for _, mustNewFunc := range []mustNewKeyFunc{ MustNewThumbprint, @@ -371,22 +371,22 @@ func Test_NewCryptoKey_negative(t *testing.T) { { Type: COSEKeyType, In: 7, - ErrMsg: "value must be a []byte; found int", + ErrMsg: "value must be a []byte or a string; found int", }, { Type: ThumbprintType, In: 7, - ErrMsg: "value must be a swid.HashEntry; found int", + ErrMsg: "value must be a swid.HashEntry or a string; found int", }, { Type: CertThumbprintType, In: 7, - ErrMsg: "value must be a swid.HashEntry; found int", + ErrMsg: "value must be a swid.HashEntry or a string; found int", }, { Type: CertPathThumbprintType, In: 7, - ErrMsg: "value must be a swid.HashEntry; found int", + ErrMsg: "value must be a swid.HashEntry or a string; found int", }, { Type: "random-key", diff --git a/comid/example_psa_refval_test.go b/comid/example_psa_refval_test.go index 230209d0..8848e90f 100644 --- a/comid/example_psa_refval_test.go +++ b/comid/example_psa_refval_test.go @@ -142,12 +142,11 @@ func extractImplementationID(c *Class) error { return fmt.Errorf("no class-id") } - implID, err := classID.GetImplID() - if err != nil { - return fmt.Errorf("extracting implemenetation-id: %w", err) + if classID.Type() != ImplIDType { + return fmt.Errorf("class id is not a psa.impl-id") } - fmt.Printf("ImplementationID: %x\n", implID) + fmt.Printf("ImplementationID: %x\n", classID.Bytes()) return nil } diff --git a/comid/tmp b/comid/tmp new file mode 100644 index 00000000..58681bc6 --- /dev/null +++ b/comid/tmp @@ -0,0 +1,15 @@ +ImplementationID: 61636d652d696d706c656d656e746174696f6e2d69642d303030303030303031 +SignerID: acbb11c7e4da217205523ce4ce1a245ae1a239ae3c6bfd9e7871f7e5d8bae86b +Label: BL +Version: 2.1.0 +Digest: 87428fc522803d31065e7bce3cf03fe475096631e5e07bbd7a0fde60c4cf25c7 +SignerID: acbb11c7e4da217205523ce4ce1a245ae1a239ae3c6bfd9e7871f7e5d8bae86b +Label: PRoT +Version: 1.3.5 +Digest: 0263829989b6fd954f72baaf2fc64bc2e2f01d692d4de72986ea808f6e99813f +SignerID: acbb11c7e4da217205523ce4ce1a245ae1a239ae3c6bfd9e7871f7e5d8bae86b +Label: ARoT +Version: 0.1.4 +Digest: a3a5e715f0cc574a73c3f9bebb6bc24f32ffd5b67b387244c2c909da779a1478 +Label: a non-empty (unique) label +Raw value: 72617776616c75650a72617776616c75650a diff --git a/comid/uuid.go b/comid/uuid.go index c5c78d61..9842a617 100644 --- a/comid/uuid.go +++ b/comid/uuid.go @@ -16,6 +16,26 @@ type UUID uuid.UUID // TaggedUUID is an alias to allow automatic tagging of a UUID type type TaggedUUID UUID +// String returns a string representation of the binary UUID +func (o TaggedUUID) String() string { + return UUID(o).String() +} + +func (o TaggedUUID) Valid() error { + return UUID(o).Valid() +} + +// Type returns a string contianing type name. This is part of the +// ITypeChoiceValue implementation. +func (o TaggedUUID) Type() string { + return "uuid" +} + +// Bytes returns a []byte containing the raw UUID bytes +func (o TaggedUUID) Bytes() []byte { + return o[:] +} + // ParseUUID parses the supplied string into a UUID func ParseUUID(s string) (UUID, error) { v, err := uuid.Parse(s) diff --git a/encoding/json.go b/encoding/json.go new file mode 100644 index 00000000..74dc6b08 --- /dev/null +++ b/encoding/json.go @@ -0,0 +1,9 @@ +package encoding + +// TypeAndValue stores a JSON object with two attributes: a string "type" +// and a generic "value" (string) defined by type. This type is used in +// a few places to implement the choice types that CBOR handles using tags. +type TypeAndValue struct { + Type string `json:"type"` + Value string `json:"value"` +} diff --git a/extensions/typechoice.go b/extensions/typechoice.go new file mode 100644 index 00000000..263013df --- /dev/null +++ b/extensions/typechoice.go @@ -0,0 +1,11 @@ +package extensions + +type ITypeChoiceValue interface { + // String returns the string representation of the ITypeChoiceValue. + String() string + // Valid returns an error if validation of the ITypeChoiceValue fails, + // or nil if it succeeds. + Valid() error + // Type returns the type name of this ITypeChoiceValue implementation. + Type() string +}