From 786cb00ff2da5d5d173238361f935912132e60b5 Mon Sep 17 00:00:00 2001 From: Sergei Trofimov Date: Thu, 12 Oct 2023 17:02:11 +0100 Subject: [PATCH] WIP BORKED --- comid/ccaplatformconfigid.go | 57 ++++- comid/example_cca_refval_test.go | 18 +- comid/example_psa_refval_test.go | 6 +- comid/example_test.go | 24 +- comid/measurement.go | 414 +++++++++++++++---------------- comid/measurement_test.go | 225 +++++------------ comid/psareferencevalue.go | 100 +++++++- comid/psareferencevalue_test.go | 5 +- 8 files changed, 442 insertions(+), 407 deletions(-) diff --git a/comid/ccaplatformconfigid.go b/comid/ccaplatformconfigid.go index e485083f..b32e21cd 100644 --- a/comid/ccaplatformconfigid.go +++ b/comid/ccaplatformconfigid.go @@ -3,11 +3,15 @@ package comid -import "fmt" +import ( + "errors" + "fmt" + "unicode/utf8" +) -type CCAPlatformConfigID string +var CCAPlatformConfigIDType = "cca.platform-config-id" -type TaggedCCAPlatformConfigID CCAPlatformConfigID +type CCAPlatformConfigID string func (o CCAPlatformConfigID) Empty() bool { return o == "" @@ -27,3 +31,50 @@ func (o CCAPlatformConfigID) Get() (CCAPlatformConfigID, error) { } return o, nil } + +type TaggedCCAPlatformConfigID CCAPlatformConfigID + +func NewTaggedCCAPlatormConfigID(val any) (*TaggedCCAPlatformConfigID, error) { + var ret TaggedCCAPlatformConfigID + switch t := val.(type) { + case TaggedCCAPlatformConfigID: + ret = t + case *TaggedCCAPlatformConfigID: + ret = *t + case CCAPlatformConfigID: + ret = TaggedCCAPlatformConfigID(t) + case *CCAPlatformConfigID: + ret = TaggedCCAPlatformConfigID(*t) + case string: + ret = TaggedCCAPlatformConfigID(t) + case []byte: + if !utf8.Valid(t) { + return nil, errors.New("bytes do not form a valid UTF-8 string") + } + ret = TaggedCCAPlatformConfigID(t) + default: + return nil, fmt.Errorf("unexpected type for CCA platform-config-id: %T", t) + } + + return &ret, nil +} + +func (o TaggedCCAPlatformConfigID) Valid() error { + if o == "" { + return errors.New("empty") + } + + return nil +} + +func (o TaggedCCAPlatformConfigID) String() string { + return string(o) +} + +func (o TaggedCCAPlatformConfigID) Type() string { + return CCAPlatformConfigIDType +} + +func (o TaggedCCAPlatformConfigID) IsZero() bool { + return len(o) == 0 +} diff --git a/comid/example_cca_refval_test.go b/comid/example_cca_refval_test.go index a9d1102b..6fe06775 100644 --- a/comid/example_cca_refval_test.go +++ b/comid/example_cca_refval_test.go @@ -66,21 +66,23 @@ func extractCCARefVal(rv ReferenceValue) error { if !m.Key.IsSet() { return fmt.Errorf("mKey not set at index %d", i) } - if m.Key.IsPSARefValID() { + + switch t := m.Key.Value.(type) { + case TaggedPSARefValID: if err := extractSwMeasurement(m); err != nil { return fmt.Errorf("extracting measurement at index %d: %w", i, err) } - } - if m.Key.IsCCAPlatformConfigID() { + case TaggedCCAPlatformConfigID: if err := extractCCARefValID(m.Key); err != nil { return fmt.Errorf("extracting cca-refval-id: %w", err) } if err := extractRawValue(m.Val.RawValue); err != nil { return fmt.Errorf("extracting raw vlue: %w", err) } - - return nil + default: + return fmt.Errorf("unexpected Mkey type: %T", t) } + } return nil @@ -105,9 +107,9 @@ func extractCCARefValID(k *Mkey) error { return fmt.Errorf("no measurement key") } - id, err := k.GetCCAPlatformConfigID() - if err != nil { - return fmt.Errorf("getting CCA platform config id: %w", err) + id, ok := k.Value.(TaggedCCAPlatformConfigID) + if !ok { + return fmt.Errorf("expected CCA platform config id, found: %T", k.Value) } fmt.Printf("Label: %s\n", id) return nil diff --git a/comid/example_psa_refval_test.go b/comid/example_psa_refval_test.go index 8848e90f..5dbef5c5 100644 --- a/comid/example_psa_refval_test.go +++ b/comid/example_psa_refval_test.go @@ -111,9 +111,9 @@ func extractPSARefValID(k *Mkey) error { return fmt.Errorf("no measurement key") } - id, err := k.GetPSARefValID() - if err != nil { - return fmt.Errorf("getting PSA refval id: %w", err) + id, ok := k.Value.(TaggedPSARefValID) + if !ok { + return fmt.Errorf("expected PSA refval id, found: %T", k.Value) } fmt.Printf("SignerID: %x\n", id.SignerID) diff --git a/comid/example_test.go b/comid/example_test.go index 2e8a5ec9..5bab7d06 100644 --- a/comid/example_test.go +++ b/comid/example_test.go @@ -31,8 +31,7 @@ func Example_encode() { }, Measurements: *NewMeasurements(). AddMeasurement( - NewMeasurement(). - SetKeyUUID(TestUUID). + MustNewUUIDMeasurement(TestUUID). SetRawValueBytes([]byte{0x01, 0x02, 0x03, 0x04}, []byte{0xff, 0xff, 0xff, 0xff}). SetSVN(2). AddDigest(swid.Sha256_32, []byte{0xab, 0xcd, 0xef, 0x00}). @@ -60,8 +59,7 @@ func Example_encode() { }, Measurements: *NewMeasurements(). AddMeasurement( - NewMeasurement(). - SetKeyUUID(TestUUID). + MustNewUUIDMeasurement(TestUUID). SetRawValueBytes([]byte{0x01, 0x02, 0x03, 0x04}, []byte{0xff, 0xff, 0xff, 0xff}). SetMinSVN(2). AddDigest(swid.Sha256_32, []byte{0xab, 0xcd, 0xef, 0x00}). @@ -126,18 +124,16 @@ func Example_encode_PSA() { }, Measurements: *NewMeasurements(). AddMeasurement( - NewPSAMeasurement( - *NewPSARefValID(TestSignerID). - SetLabel("BL"). - SetVersion("5.0.5"), - ).AddDigest(swid.Sha256_32, []byte{0xab, 0xcd, 0xef, 0x00}), + MustNewPSAMeasurement( + MustCreatePSARefValID( + TestSignerID, "BL", "5.0.5", + )).AddDigest(swid.Sha256_32, []byte{0xab, 0xcd, 0xef, 0x00}), ). AddMeasurement( - NewPSAMeasurement( - *NewPSARefValID(TestSignerID). - SetLabel("PRoT"). - SetVersion("1.3.5"), - ).AddDigest(swid.Sha256_32, []byte{0xab, 0xcd, 0xef, 0x00}), + MustNewPSAMeasurement( + MustCreatePSARefValID( + TestSignerID, "PRoT", "1.3.5", + )).AddDigest(swid.Sha256_32, []byte{0xab, 0xcd, 0xef, 0x00}), ), }, ). diff --git a/comid/measurement.go b/comid/measurement.go index b1909d87..9754ebda 100644 --- a/comid/measurement.go +++ b/comid/measurement.go @@ -5,6 +5,7 @@ package comid import ( "encoding/json" + "errors" "fmt" "net" @@ -23,191 +24,180 @@ type Measurement struct { AuthorizedBy *CryptoKey `cbor:"2,keyasint,omitempty" json:"authorized-by,omitempty"` } +type IMKeyValue interface { + extensions.ITypeChoiceValue +} + // Mkey stores a $measured-element-type-choice. // The supported types are UUID, PSA refval-id, CCA platform-config-id and unsigned integer // TO DO Add tagged OID: see https://github.com/veraison/corim/issues/35 type Mkey struct { - val interface{} + Value IMKeyValue } func (o Mkey) IsSet() bool { - return o.val != nil + return o.Value != nil } func (o Mkey) Valid() error { - switch t := o.val.(type) { - case TaggedUUID: - if UUID(t).Empty() { - return fmt.Errorf("empty UUID") - } - return nil - case TaggedPSARefValID: - return PSARefValID(t).Valid() - case TaggedCCAPlatformConfigID: - if CCAPlatformConfigID(t).Empty() { - return fmt.Errorf("empty CCAPlatformConfigID") - } - case uint64: - if o.val == nil { - return fmt.Errorf("empty uint Mkey") - } - return nil - default: - return fmt.Errorf("unknown measurement key type: %T", t) + if o.Value == nil { + return errors.New("Mkey value not set") + } + + return o.Value.Valid() +} + +// UnmarshalJSON deserializes the supplied JSON object into the target MKey +// The key object must have the following shape: +// +// { +// "type": "", +// "value": "" +// } +// +// where must be one of the known IMKeyValue implementation +// type names (available in the base implementation: "uuid", "oid", +// "psa.impl-id"), and is the class id value encoded as +// a string. The exact encoding is depenent. For the base +// implmentation types it is +// +// oid: dot-seprated integers, e.g. "1.2.3.4" +// psa.refval-id: JSON representation of the PSA refval-id +// uuid: standard UUID string representation, e.g. "550e8400-e29b-41d4-a716-446655440000" +func (o *Mkey) UnmarshalJSON(data []byte) error { + var value encoding.TypeAndValue + + if err := json.Unmarshal(data, &value); err != nil { + return err + } + + if value.Type == "" { + return errors.New("measurement type not set") + } + + factory, ok := mkeyValueRegister[value.Type] + if !ok { + return fmt.Errorf("unknown measurement type: %q", value.Type) + } + + v, err := factory(value.Value) + if err != nil { + return err } + + o.Value = v.Value + + return o.Valid() + return nil } -func (o Mkey) IsPSARefValID() bool { - _, ok := o.val.(TaggedPSARefValID) - return ok +// MarshalJSON serializes the target Mkey into the type'n'value JSON object +func (o Mkey) MarshalJSON() ([]byte, error) { + value := encoding.TypeAndValue{ + Type: o.Value.Type(), + Value: o.Value.String(), + } + + return json.Marshal(value) } -func (o Mkey) IsCCAPlatformConfigID() bool { - _, ok := o.val.(TaggedCCAPlatformConfigID) - return ok +func (o Mkey) MarshalCBOR() ([]byte, error) { + return em.Marshal(o.Value) } -func (o Mkey) GetPSARefValID() (PSARefValID, error) { - switch t := o.val.(type) { - case TaggedPSARefValID: - return PSARefValID(t), nil - default: - return PSARefValID{}, fmt.Errorf("measurement-key type is: %T", t) - } +func (o *Mkey) UnmarshalCBOR(data []byte) error { + return dm.Unmarshal(data, &o.Value) } -func (o Mkey) GetCCAPlatformConfigID() (CCAPlatformConfigID, error) { - switch t := o.val.(type) { - case TaggedCCAPlatformConfigID: - return CCAPlatformConfigID(t), nil - default: - return CCAPlatformConfigID(""), fmt.Errorf("measurement-key type is: %T", t) +var UintType = "uint" + +type UintMkey uint64 + +func NewUintMkey(val any) (*UintMkey, error) { + var ret UintMkey + + if val == nil { + return &ret, nil } -} -func (o Mkey) GetKeyUint() (uint64, error) { - switch t := o.val.(type) { + switch t := val.(type) { case uint64: - return t, nil + ret = UintMkey(t) + case uint: + ret = UintMkey(t) default: - return MaxUint64, fmt.Errorf("measurement-key type is: %T", t) + return nil, fmt.Errorf("unexpected type for MkeyUint: %T", t) } + + return &ret, nil } -// UnmarshalJSON deserializes the type'n'value JSON object into the target Mkey -func (o *Mkey) UnmarshalJSON(data []byte) error { - var v tnv +func (o UintMkey) Valid() error { + return nil +} - if err := json.Unmarshal(data, &v); err != nil { - return err +func (o UintMkey) String() string { + return fmt.Sprint(uint64(o)) +} + +func (o UintMkey) Type() string { + return UintType +} + +func NewMkeyOID(val any) (*Mkey, error) { + ret, err := NewTaggedOID(val) + if err != nil { + return nil, err } - switch v.Type { - case "uuid": - var x UUID - if err := x.UnmarshalJSON(v.Value); err != nil { - return fmt.Errorf( - "cannot unmarshal $measured-element-type-choice of type UUID: %w", - err, - ) - } - o.val = TaggedUUID(x) - case "psa.refval-id": - var x PSARefValID - if err := json.Unmarshal(v.Value, &x); err != nil { - return fmt.Errorf( - "cannot unmarshal $measured-element-type-choice of type PSARefValID: %w", - err, - ) - } - if err := x.Valid(); err != nil { - return fmt.Errorf( - "cannot unmarshal $measured-element-type-choice of type PSARefValID: %w", - err, - ) - } - o.val = TaggedPSARefValID(x) - case "cca.platform-config-id": - var x CCAPlatformConfigID - if err := json.Unmarshal(v.Value, &x); err != nil { - return fmt.Errorf( - "cannot unmarshal $measured-element-type-choice of type CCAPlatformConfigID: %w", - err, - ) - } - if x.Empty() { - return fmt.Errorf( - "cannot unmarshal $measured-element-type-choice of type CCAPlatformConfigID: empty label", - ) - } - o.val = TaggedCCAPlatformConfigID(x) - case "uint": - var x uint64 - if err := json.Unmarshal(v.Value, &x); err != nil { - return fmt.Errorf( - "cannot unmarshal $measured-element-type-choice of type uint: %w", - err, - ) - } - o.val = x - default: - return fmt.Errorf("unknown type %s for $measured-element-type-choice", v.Type) + return &Mkey{ret}, nil +} + +func NewMkeyUUID(val any) (*Mkey, error) { + ret, err := NewTaggedUUID(val) + if err != nil { + return nil, err } - return nil + return &Mkey{ret}, nil } -// MarshalJSON serializes the target Mkey into the type'n'value JSON object -// Supported types are: uuid, psa.refval-id and unsigned integer -func (o Mkey) MarshalJSON() ([]byte, error) { - var ( - v tnv - b []byte - err error - ) - - switch t := o.val.(type) { - case TaggedUUID: - uuidString := UUID(t).String() - b, err = json.Marshal(uuidString) - if err != nil { - return nil, err - } - v = tnv{Type: "uuid", Value: b} - case TaggedPSARefValID: - b, err = json.Marshal(t) - if err != nil { - return nil, err - } - v = tnv{Type: "psa.refval-id", Value: b} - case TaggedCCAPlatformConfigID: - b, err = json.Marshal(t) - if err != nil { - return nil, err - } - v = tnv{Type: "cca.platform-config-id", Value: b} +func NewMkeyUint(val any) (*Mkey, error) { + ret, err := NewUintMkey(val) + if err != nil { + return nil, err + } - case uint64: - b, err = json.Marshal(t) - if err != nil { - return nil, err - } - v = tnv{Type: "uint", Value: b} + return &Mkey{ret}, nil +} - default: - return nil, fmt.Errorf("unknown type %T for mkey", t) +func NewMkeyPSARefvalID(val any) (*Mkey, error) { + ret, err := NewTaggedPSARefValID(val) + if err != nil { + return nil, err } - return json.Marshal(v) + return &Mkey{ret}, nil } -func (o Mkey) MarshalCBOR() ([]byte, error) { - return em.Marshal(o.val) +func NewMkeyCCAPlatformConfigID(val any) (*Mkey, error) { + ret, err := NewTaggedCCAPlatormConfigID(val) + if err != nil { + return nil, err + } + + return &Mkey{ret}, nil } -func (o *Mkey) UnmarshalCBOR(data []byte) error { - return dm.Unmarshal(data, &o.val) +type IMkeyFactory = func(val any) (*Mkey, error) + +var mkeyValueRegister = map[string]IMkeyFactory{ + OIDType: NewMkeyOID, + UUIDType: NewMkeyUUID, + UintType: NewMkeyUint, + PSARefValIDType: NewMkeyPSARefvalID, + CCAPlatformConfigIDType: NewMkeyCCAPlatformConfigID, } // Mval stores a measurement-values-map with JSON and CBOR serializations. @@ -320,95 +310,105 @@ func (o Version) Valid() error { return nil } -// NewMeasurement instantiates an empty measurement -func NewMeasurement() *Measurement { - return &Measurement{} -} - -// SetKeyPSARefValID sets the key of the target measurement-map to the supplied -// PSA refval-id -func (o *Measurement) SetKeyPSARefValID(psaRefValID PSARefValID) *Measurement { - if o != nil { - if psaRefValID.Valid() != nil { - return nil - } - o.Key = &Mkey{ - val: TaggedPSARefValID(psaRefValID), - } +func NewMeasurement(val any, typ string) (*Measurement, error) { + keyFactory, ok := mkeyValueRegister[typ] + if !ok { + return nil, fmt.Errorf("unknown Mkey type: %s", typ) } - return o -} -// SetKeyCCAPlatformConfigID sets the key of the target measurement-map to the supplied -// CCA platform-config-id -func (o *Measurement) SetKeyCCAPlatformConfigID(ccaPlatformConfigID CCAPlatformConfigID) *Measurement { - if o != nil { - if ccaPlatformConfigID.Empty() { - return nil - } - o.Key = &Mkey{ - val: TaggedCCAPlatformConfigID(ccaPlatformConfigID), - } + key, err := keyFactory(val) + if err != nil { + return nil, fmt.Errorf("invalid key: %w", err) } - return o -} -// SetKeyKeyUUID sets the key of the target measurement-map to the supplied -// UUID -func (o *Measurement) SetKeyUUID(u UUID) *Measurement { - if o != nil { - if u.Empty() { - return nil - } + if err = key.Valid(); err != nil { + return nil, fmt.Errorf("invalid key: %w", err) + } - if u.Valid() != nil { - return nil - } + var ret Measurement + ret.Key = key - o.Key = &Mkey{ - val: TaggedUUID(u), - } - } - return o + return &ret, nil } -// SetKeyUint sets the key of the target measurement-map to the supplied -// unsigned integer -func (o *Measurement) SetKeyUint(u uint64) *Measurement { - if o != nil { - o.Key = &Mkey{ - val: u, - } +func MustNewMeasurement(val any, typ string) *Measurement { + ret, err := NewMeasurement(val, typ) + + if err != nil { + panic(err) } - return o + + return ret } // NewPSAMeasurement instantiates a new measurement-map with the key set to the // supplied PSA refval-id -func NewPSAMeasurement(psaRefValID PSARefValID) *Measurement { - m := &Measurement{} - return m.SetKeyPSARefValID(psaRefValID) +func NewPSAMeasurement(key any) (*Measurement, error) { + return NewMeasurement(key, PSARefValIDType) +} + +func MustNewPSAMeasurement(key any) *Measurement { + ret, err := NewPSAMeasurement(key) + + if err != nil { + panic(err) + } + + return ret } // NewCCAPlatCfgMeasurement instantiates a new measurement-map with the key set to the // supplied CCA platform-config-id -func NewCCAPlatCfgMeasurement(ccaPlatformConfigID CCAPlatformConfigID) *Measurement { - m := &Measurement{} - return m.SetKeyCCAPlatformConfigID(ccaPlatformConfigID) +func NewCCAPlatCfgMeasurement(key any) (*Measurement, error) { + return NewMeasurement(key, CCAPlatformConfigIDType) +} + +func MustNewCCAPlatCfgMeasurement(key any) *Measurement { + ret, err := NewCCAPlatCfgMeasurement(key) + + if err != nil { + panic(err) + } + + return ret } // NewUUIDMeasurement instantiates a new measurement-map with the key set to the // supplied UUID -func NewUUIDMeasurement(uuid UUID) *Measurement { - m := &Measurement{} - return m.SetKeyUUID(uuid) +func NewUUIDMeasurement(key any) (*Measurement, error) { + return NewMeasurement(key, UUIDType) +} + +func MustNewUUIDMeasurement(key any) *Measurement { + ret, err := NewUUIDMeasurement(key) + + if err != nil { + panic(err) + } + + return ret } // NewUintMeasurement instantiates a new measurement-map with the key set to the // supplied Uint -func NewUintMeasurement(mkey uint64) *Measurement { - m := &Measurement{} - return m.SetKeyUint(mkey) +func NewUintMeasurement(key any) (*Measurement, error) { + return NewMeasurement(key, UintType) +} + +func MustNewUintMeasurement(key any) *Measurement { + ret, err := NewUintMeasurement(key) + + if err != nil { + panic(err) + } + + return ret +} + +// NewOIDMeasurement instantiates a new measurement-map with the key set to the +// supplied OID +func NewOIDMeasurement(key any) (*Measurement, error) { + return NewMeasurement(key, OIDType) } func (o *Measurement) SetVersion(ver string, scheme int64) *Measurement { diff --git a/comid/measurement_test.go b/comid/measurement_test.go index d5d58bff..c74d79f3 100644 --- a/comid/measurement_test.go +++ b/comid/measurement_test.go @@ -14,86 +14,86 @@ import ( ) func TestMeasurement_NewUUIDMeasurement_good_uuid(t *testing.T) { - tv := NewUUIDMeasurement(TestUUID) - assert.NotNil(t, tv) + _, err := NewUUIDMeasurement(TestUUID) + assert.NoError(t, err) } func TestMeasurement_NewUUIDMeasurement_empty_uuid(t *testing.T) { emptyUUID := UUID{} - tv := NewUUIDMeasurement(emptyUUID) + _, err := NewUUIDMeasurement(emptyUUID) - assert.Nil(t, tv) + assert.EqualError(t, err, + "invalid key: expecting RFC4122 UUID, got Reserved instead") } func TestMeasurement_NewUIntMeasurement(t *testing.T) { var TestUint uint64 = 35 - tv := NewUintMeasurement(TestUint) + _, err := NewUintMeasurement(TestUint) - assert.NotNil(t, tv) + assert.NoError(t, err) } func TestMeasurement_NewPSAMeasurement_empty(t *testing.T) { emptyPSARefValID := PSARefValID{} - tv := NewPSAMeasurement(emptyPSARefValID) - assert.Nil(t, tv) + _, err := NewPSAMeasurement(emptyPSARefValID) + + assert.EqualError(t, err, "invalid key: missing mandatory signer ID") } func TestMeasurement_NewPSAMeasurement_no_values(t *testing.T) { - psaRefValID := - NewPSARefValID(TestSignerID). - SetLabel("PRoT"). - SetVersion("1.2.3") + psaRefValID, err := NewPSARefValID(TestSignerID) + require.NoError(t, err) + psaRefValID.SetLabel("PRoT") + psaRefValID.SetVersion("1.2.3") require.NotNil(t, psaRefValID) - tv := NewPSAMeasurement(*psaRefValID) - assert.NotNil(t, tv) + tv, err := NewPSAMeasurement(*psaRefValID) + assert.NoError(t, err) - err := tv.Valid() + err = tv.Valid() assert.EqualError(t, err, "no measurement value set") } func TestMeasurement_NewCCAPlatCfgMeasurement_no_values(t *testing.T) { ccaplatID := CCAPlatformConfigID(TestCCALabel) - tv := NewCCAPlatCfgMeasurement(ccaplatID) - assert.NotNil(t, tv) + tv, err := NewCCAPlatCfgMeasurement(ccaplatID) + assert.NoError(t, err) - err := tv.Valid() + err = tv.Valid() assert.EqualError(t, err, "no measurement value set") } func TestMeasurement_NewCCAPlatCfgMeasurement_valid_meas(t *testing.T) { ccaplatID := CCAPlatformConfigID(TestCCALabel) - tv := NewCCAPlatCfgMeasurement(ccaplatID).SetRawValueBytes([]byte{0x01, 0x02, 0x03, 0x04}, []byte{}) - assert.NotNil(t, tv) + tv, err := NewCCAPlatCfgMeasurement(ccaplatID) + assert.NoError(t, err) - err := tv.Valid() - assert.Nil(t, err) + tv.SetRawValueBytes([]byte{0x01, 0x02, 0x03, 0x04}, []byte{}) + + err = tv.Valid() + assert.NoError(t, err) } func TestMeasurement_NewPSAMeasurement_one_value(t *testing.T) { - psaRefValID := - NewPSARefValID(TestSignerID). - SetLabel("PRoT"). - SetVersion("1.2.3") - require.NotNil(t, psaRefValID) + tv, err := NewPSAMeasurement(MustCreatePSARefValID(TestSignerID, "PRoT", "1.2.3")) + require.NoError(t, err) - tv := NewPSAMeasurement(*psaRefValID).SetIPaddr(TestIPaddr) - assert.NotNil(t, tv) + tv.SetIPaddr(TestIPaddr) - err := tv.Valid() + err = tv.Valid() assert.Nil(t, err) } func TestMeasurement_NewUUIDMeasurement_no_values(t *testing.T) { - tv := NewUUIDMeasurement(TestUUID) - require.NotNil(t, tv) + tv, err := NewUUIDMeasurement(TestUUID) + require.NoError(t, err) - err := tv.Valid() + err = tv.Valid() assert.EqualError(t, err, "no measurement value set") } @@ -101,26 +101,27 @@ func TestMeasurement_NewUUIDMeasurement_some_value(t *testing.T) { var vs swid.VersionScheme require.NoError(t, vs.SetCode(swid.VersionSchemeSemVer)) - tv := NewUUIDMeasurement(TestUUID). - SetMinSVN(2). + tv, err := NewUUIDMeasurement(TestUUID) + require.NoError(t, err) + + tv.SetMinSVN(2). SetFlagsTrue(FlagIsDebug). SetVersion("1.2.3", swid.VersionSchemeSemVer) - require.NotNil(t, tv) - err := tv.Valid() + err = tv.Valid() assert.Nil(t, err) } func TestMeasurement_NewUUIDMeasurement_bad_digest(t *testing.T) { - tv := NewUUIDMeasurement(TestUUID) - require.NotNil(t, tv) + tv, err := NewUUIDMeasurement(TestUUID) + require.NoError(t, err) assert.Nil(t, tv.AddDigest(swid.Sha256, []byte{0xff})) } func TestMeasurement_NewUUIDMeasurement_bad_ueid(t *testing.T) { - tv := NewUUIDMeasurement(TestUUID) - require.NotNil(t, tv) + tv, err := NewUUIDMeasurement(TestUUID) + require.NoError(t, err) badUEID := eat.UEID{ 0xFF, // Invalid @@ -131,8 +132,8 @@ func TestMeasurement_NewUUIDMeasurement_bad_ueid(t *testing.T) { } func TestMeasurement_NewUUIDMeasurement_bad_uuid(t *testing.T) { - tv := NewUUIDMeasurement(TestUUID) - require.NotNil(t, tv) + tv, err := NewUUIDMeasurement(TestUUID) + require.NoError(t, err) nonRFC4122UUID, err := ParseUUID("f47ac10b-58cc-4372-c567-0e02b2c3d479") require.Nil(t, err) @@ -147,7 +148,7 @@ var ( func TestMkey_Valid_no_value(t *testing.T) { mkey := &Mkey{} - expectedErr := "unknown measurement key type: " + expectedErr := "Mkey value not set" err := mkey.Valid() assert.EqualError(t, err, expectedErr) } @@ -203,9 +204,9 @@ func TestMKey_UnmarshalCBOR_CCAPlatformConfigID_ok(t *testing.T) { mkey := &Mkey{} err := mkey.UnmarshalCBOR(tv.input) assert.Nil(t, err) - actual, err := mkey.GetCCAPlatformConfigID() - assert.Nil(t, err) - assert.Equal(t, tv.expected, actual) + actual, ok := mkey.Value.(*TaggedCCAPlatformConfigID) + assert.True(t, ok) + assert.Equal(t, tv.expected, *actual) fmt.Printf("CBOR: %x\n", actual) } } @@ -230,36 +231,13 @@ func TestMKey_MarshalCBOR_uint_ok(t *testing.T) { } for _, tv := range tvs { - mkey := &Mkey{tv.mkey} + mkey := &Mkey{UintMkey(tv.mkey)} actual, err := mkey.MarshalCBOR() assert.Nil(t, err) assert.Equal(t, tv.expected, actual) fmt.Printf("CBOR: %x\n", actual) } } -func TestMkey_MarshalCBOR_uint_not_ok(t *testing.T) { - tvs := []struct { - input interface{} - expected string - }{ - { - input: 123.456, - expected: "unknown measurement key type: float64", - }, - { - input: "sample", - expected: "unknown measurement key type: string", - }, - } - - for _, tv := range tvs { - mkey := &Mkey{tv.input} - _, err := mkey.MarshalCBOR() - assert.Nil(t, err) - err = mkey.Valid() - assert.EqualError(t, err, tv.expected) - } -} func TestMkey_UnmarshalCBOR_uint_ok(t *testing.T) { tvs := []struct { @@ -284,10 +262,10 @@ func TestMkey_UnmarshalCBOR_uint_ok(t *testing.T) { mKey := &Mkey{} err := mKey.UnmarshalCBOR(tv.mkey) - assert.Nil(t, err) - actual, err := mKey.GetKeyUint() - assert.Nil(t, err) - assert.Equal(t, tv.expected, actual) + require.NoError(t, err) + actual, ok := mKey.Value.(*UintMkey) + assert.True(t, ok) + assert.Equal(t, tv.expected, uint64(*actual)) } } @@ -315,38 +293,9 @@ func TestMkey_UnmarshalCBOR_not_ok(t *testing.T) { } } -func TestMkey_UnmarshalCBOR_uint_not_ok(t *testing.T) { - tvs := []struct { - input []byte - expected string - }{ - { - input: []byte{0xd8, 0x25, 0x50, 0x31, 0xfb, 0x5a, 0xbf, 0x02, - 0x3e, 0x49, 0x92, 0xaa, 0x4e, 0x95, 0xf9, 0xc1, - 0x50, 0x3b, 0xfa}, - expected: "measurement-key type is: comid.TaggedUUID", - }, - { - input: []byte{0xd8, 0x21, 0x50, 0x31, 0xfb, 0x5a, 0xff, 0x12, - 0xFF, 0xFF, 0x92, 0xaa, 0x4e, 0x95, 0xf9, 0xc1, - 0x50, 0x3b, 0xfa}, - expected: "measurement-key type is: cbor.Tag", - }, - } - - for _, tv := range tvs { - mKey := &Mkey{} - - err := mKey.UnmarshalCBOR(tv.input) - assert.Nil(t, err) - _, err = mKey.GetKeyUint() - assert.EqualError(t, err, tv.expected) - } -} - func TestMKey_MarshalJSON_CCAPlatformConfigID_ok(t *testing.T) { refval := TestCCALabel - mkey := &Mkey{val: TaggedCCAPlatformConfigID(refval)} + mkey := &Mkey{Value: TaggedCCAPlatformConfigID(refval)} expected := `{"type":"cca.platform-config-id","value":"cca-platform-config"}` @@ -359,15 +308,15 @@ func TestMKey_MarshalJSON_CCAPlatformConfigID_ok(t *testing.T) { func TestMKey_UnMarshalJSON_CCAPlatformConfigID_ok(t *testing.T) { input := []byte(`{"type":"cca.platform-config-id","value":"cca-platform-config"}`) - expected := CCAPlatformConfigID(TestCCALabel) + expected := TaggedCCAPlatformConfigID(TestCCALabel) mKey := &Mkey{} err := mKey.UnmarshalJSON(input) assert.Nil(t, err) - actual, err := mKey.GetCCAPlatformConfigID() - assert.Nil(t, err) - assert.Equal(t, expected, actual) + actual, ok := mKey.Value.(*TaggedCCAPlatformConfigID) + assert.True(t, ok) + assert.Equal(t, expected, *actual) } @@ -404,7 +353,7 @@ func TestMkey_MarshalJSON_uint_ok(t *testing.T) { for _, tv := range tvs { - mkey := &Mkey{tv.mkey} + mkey := &Mkey{UintMkey(tv.mkey)} actual, err := mkey.MarshalJSON() assert.Nil(t, err) @@ -414,31 +363,6 @@ func TestMkey_MarshalJSON_uint_ok(t *testing.T) { } } -func TestMkey_MarshalJSON_uint_not_ok(t *testing.T) { - tvs := []struct { - input interface{} - expected string - }{ - { - input: 123.456, - expected: "unknown type float64 for mkey", - }, - { - input: "sample", - expected: "unknown type string for mkey", - }, - } - - for _, tv := range tvs { - - mkey := &Mkey{tv.input} - - _, err := mkey.MarshalJSON() - - assert.EqualError(t, err, tv.expected) - } -} - func TestMkey_UnmarshalJSON_uint_ok(t *testing.T) { tvs := []struct { input []byte @@ -463,9 +387,9 @@ func TestMkey_UnmarshalJSON_uint_ok(t *testing.T) { err := mKey.UnmarshalJSON(tv.input) assert.Nil(t, err) - actual, err := mKey.GetKeyUint() - assert.Nil(t, err) - assert.Equal(t, tv.expected, actual) + actual, ok := mKey.Value.(*UintMkey) + assert.True(t, ok) + assert.Equal(t, tv.expected, uint64(*actual)) } } @@ -492,28 +416,3 @@ func TestMkey_UnmarshalJSON_notok(t *testing.T) { assert.EqualError(t, err, tv.expected) } } - -func TestMkey_UnmarshalJSON_uint_notok(t *testing.T) { - tvs := []struct { - input []byte - expected string - }{ - { - input: []byte(`{"type":"uuid","value":"31fb5abf-023e-4992-aa4e-95f9c1503bfa"}`), - expected: "measurement-key type is: comid.TaggedUUID", - }, - { - input: []byte(`{"type":"psa.refval-id","value":{"label": "BL","version": "2.1.0","signer-id": "rLsRx+TaIXIFUjzkzhokWuGiOa48a/2eeHH35di66Gs="}}`), - expected: "measurement-key type is: comid.TaggedPSARefValID", - }, - } - - for _, tv := range tvs { - mKey := &Mkey{} - - err := mKey.UnmarshalJSON(tv.input) - assert.Nil(t, err) - _, err = mKey.GetKeyUint() - assert.EqualError(t, err, tv.expected) - } -} diff --git a/comid/psareferencevalue.go b/comid/psareferencevalue.go index cde3304b..1ea91303 100644 --- a/comid/psareferencevalue.go +++ b/comid/psareferencevalue.go @@ -4,9 +4,12 @@ package comid import ( + "encoding/json" "fmt" ) +var PSARefValIDType = "psa.refval-id" + // PSARefValID stores a PSA refval-id with CBOR and JSON serializations // (See https://datatracker.ietf.org/doc/html/draft-xyz-rats-psa-endorsements) type PSARefValID struct { @@ -30,18 +33,56 @@ func (o PSARefValID) Valid() error { return nil } -type TaggedPSARefValID PSARefValID +func CreatePSARefValID(signerID []byte, label, version string) (*PSARefValID, error) { + ret, err := NewPSARefValID(signerID) + if err != nil { + return nil, err + } -func NewPSARefValID(signerID []byte) *PSARefValID { - switch len(signerID) { - case 32, 48, 64: - default: - return nil + ret.SetLabel(label) + ret.SetVersion(version) + + return ret, nil +} + +func MustCreatePSARefValID(signerID []byte, label, version string) *PSARefValID { + ret, err := CreatePSARefValID(signerID, label, version) + + if err != nil { + panic(err) } - return &PSARefValID{ - SignerID: signerID, + return ret +} + +func NewPSARefValID(val any) (*PSARefValID, error) { + var ret PSARefValID + + if val == nil { + return &ret, nil } + + switch t := val.(type) { + case PSARefValID: + ret = t + case *PSARefValID: + ret = *t + case string: + if err := json.Unmarshal([]byte(t), &ret); err != nil { + return nil, err + } + case []byte: + switch len(t) { + case 32, 48, 64: + ret.SignerID = t + default: + return nil, fmt.Errorf("unexpected PSA RefVal ID length: %d", len(t)) + } + default: + return nil, fmt.Errorf("unexpected type for PSA RefVal ID: %T", t) + } + + return &ret, nil } func (o *PSARefValID) SetLabel(label string) *PSARefValID { @@ -57,3 +98,46 @@ func (o *PSARefValID) SetVersion(version string) *PSARefValID { } return o } + +type TaggedPSARefValID PSARefValID + +func NewTaggedPSARefValID(val any) (*TaggedPSARefValID, error) { + var ret TaggedPSARefValID + + switch t := val.(type) { + case TaggedPSARefValID: + ret = t + case *TaggedPSARefValID: + ret = *t + default: + refvalID, err := NewPSARefValID(val) + if err != nil { + return nil, err + } + ret = TaggedPSARefValID(*refvalID) + + } + + return &ret, nil +} + +func (o TaggedPSARefValID) Valid() error { + return PSARefValID(o).Valid() +} + +func (o TaggedPSARefValID) String() string { + ret, err := json.Marshal(o) + if err != nil { + return "" + } + + return string(ret) +} + +func (o TaggedPSARefValID) Type() string { + return PSARefValIDType +} + +func (o TaggedPSARefValID) IsZero() bool { + return len(o.SignerID) == 0 +} diff --git a/comid/psareferencevalue_test.go b/comid/psareferencevalue_test.go index 069c432e..73ed1ed0 100644 --- a/comid/psareferencevalue_test.go +++ b/comid/psareferencevalue_test.go @@ -7,6 +7,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestPSARefValID_Valid_SignerID_range(t *testing.T) { @@ -15,7 +16,9 @@ func TestPSARefValID_Valid_SignerID_range(t *testing.T) { for i := 1; i <= 100; i++ { signerID = append(signerID, byte(0xff)) - tv := NewPSARefValID(signerID) + tv, err := NewPSARefValID(signerID) + require.NoError(t, err) + switch i { case 32, 48, 64: assert.NotNil(t, tv)