diff --git a/corim/entity.go b/corim/entity.go index 49aa1468..97bec2b8 100644 --- a/corim/entity.go +++ b/corim/entity.go @@ -7,6 +7,7 @@ import ( "fmt" "github.com/veraison/corim/comid" + "github.com/veraison/corim/encoding" ) // Entity stores an entity-map capable of CBOR and JSON serializations. @@ -14,12 +15,22 @@ type Entity struct { EntityName string `cbor:"0,keyasint" json:"name"` RegID *comid.TaggedURI `cbor:"1,keyasint,omitempty" json:"regid,omitempty"` Roles Roles `cbor:"2,keyasint" json:"roles"` + + Extensions } func NewEntity() *Entity { return &Entity{} } +func (o *Entity) RegisterExtensions(exts IExtensionsValue) { + o.Extensions.IExtensionsValue = exts +} + +func (o *Entity) GetExtensions() IExtensionsValue { + return o.Extensions.IExtensionsValue +} + // SetEntityName is used to set the EntityName field of Entity using supplied name func (o *Entity) SetEntityName(name string) *Entity { if o != nil { @@ -72,7 +83,15 @@ func (o Entity) Valid() error { return fmt.Errorf("invalid entity: %w", err) } - return nil + return o.Extensions.ValidEntity(&o) +} + +func (o *Entity) UnmarshalCBOR(data []byte) error { + return encoding.PopulateStructFromCBOR(dm, data, o) +} + +func (o *Entity) MarshalCBOR() ([]byte, error) { + return encoding.SerializeStructToCBOR(em, o) } // Entities is an array of entity-map's diff --git a/corim/extensions.go b/corim/extensions.go new file mode 100644 index 00000000..c7dc5ced --- /dev/null +++ b/corim/extensions.go @@ -0,0 +1,150 @@ +// Copyright 2023 Contributors to the Veraison project. +// SPDX-License-Identifier: Apache-2.0 +package corim + +import ( + "errors" + "fmt" + "reflect" + "strings" +) + +var ErrExtensionNotFound = errors.New("extension not found") + +type IExtensionsValue interface{} + +type IEntityValidator interface { + ValidEntity(*Entity) error +} + +type Extensions struct { + IExtensionsValue +} + +func (o *Extensions) ValidEntity(entity *Entity) error { + if !o.HaveExtensions() { + return nil + } + + ev, ok := o.IExtensionsValue.(IEntityValidator) + if ok { + if err := ev.ValidEntity(entity); err != nil { + return err + } + } + + return nil +} + +func (o *Extensions) HaveExtensions() bool { + return o.IExtensionsValue != nil +} + +func (o *Extensions) Get(name string) (any, error) { + if o.IExtensionsValue == nil { + return nil, fmt.Errorf("%w: %s", ErrExtensionNotFound, name) + } + + extType := reflect.TypeOf(o.IExtensionsValue) + extVal := reflect.ValueOf(o.IExtensionsValue) + if extType.Kind() == reflect.Pointer { + extType = extType.Elem() + extVal = extVal.Elem() + } + + var fieldName, fieldJSONTag, fieldCBORTag string + for i := 0; i < extVal.NumField(); i++ { + typeField := extType.Field(i) + fieldName = typeField.Name + + tag, ok := typeField.Tag.Lookup("json") + if ok { + fieldJSONTag = strings.Split(tag, ",")[0] + } + + tag, ok = typeField.Tag.Lookup("cbor") + if ok { + fieldCBORTag = strings.Split(tag, ",")[0] + } + + if fieldName == name || fieldJSONTag == name || fieldCBORTag == name { + return extVal.Field(i).Interface(), nil + } + } + + return nil, fmt.Errorf("%w: %s", ErrExtensionNotFound, name) +} + +func (o *Extensions) GetString(name string) (string, error) { + v, err := o.Get(name) + if err != nil { + return "", err + } + + switch t := v.(type) { + case string: + return t, nil + default: + return fmt.Sprintf("%v", t), nil + } +} + +func (o *Extensions) GetInt(name string) (int64, error) { + v, err := o.Get(name) + if err != nil { + return 0, err + } + + val := reflect.ValueOf(v) + if val.CanInt() { + return val.Int(), nil + } + + return 0, fmt.Errorf("%s is not an integer: %v (%T)", name, v, v) +} + +func (o *Extensions) Set(name string, value any) error { + if o.IExtensionsValue == nil { + return fmt.Errorf("%w: %s", ErrExtensionNotFound, name) + } + + extType := reflect.TypeOf(o.IExtensionsValue) + extVal := reflect.ValueOf(o.IExtensionsValue) + if extType.Kind() == reflect.Pointer { + extType = extType.Elem() + extVal = extVal.Elem() + } + + var fieldName, fieldJSONTag, fieldCBORTag string + for i := 0; i < extVal.NumField(); i++ { + typeField := extType.Field(i) + valField := extVal.Field(i) + fieldName = typeField.Name + + tag, ok := typeField.Tag.Lookup("json") + if ok { + fieldJSONTag = strings.Split(tag, ",")[0] + } + + tag, ok = typeField.Tag.Lookup("cbor") + if ok { + fieldCBORTag = strings.Split(tag, ",")[0] + } + + if fieldName == name || fieldJSONTag == name || fieldCBORTag == name { + newVal := reflect.ValueOf(value) + if newVal.CanConvert(valField.Type()) { + valField.Set(newVal.Convert(valField.Type())) + return nil + } + + return fmt.Errorf( + "cannot set field %q (of type %s) to %v (%T)", + name, typeField.Type.Name(), + value, value, + ) + } + } + + return fmt.Errorf("%w: %s", ErrExtensionNotFound, name) +} diff --git a/corim/extensions_test.go b/corim/extensions_test.go new file mode 100644 index 00000000..b2c9173b --- /dev/null +++ b/corim/extensions_test.go @@ -0,0 +1,113 @@ +// Copyright 2023 Contributors to the Veraison project. +// SPDX-License-Identifier: Apache-2.0 +package corim + +import ( + "errors" + "testing" + + "github.com/fxamacker/cbor/v2" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type TestExtensions struct { + Address string `cbor:"-1,keyasint,omitempty" json:"address,omitempty"` + Size int `cbor:"-2,keyasint,omitempty" json:"size,omitempty"` +} + +func (o TestExtensions) ValidEntity(ent *Entity) error { + if ent.EntityName != "Futurama" { + return errors.New(`EntityName must be "Futurama"`) // nolint:golint + } + + return nil +} + +func TestEntityExtensions_GetSet(t *testing.T) { + extsVal := TestExtensions{ + Address: "742 Evergreen Terrace", + Size: 6, + } + exts := &Extensions{&extsVal} + + v, err := exts.GetInt("size") + assert.NoError(t, err) + assert.Equal(t, int64(6), v) + + s, err := exts.GetString("address") + assert.NoError(t, err) + assert.Equal(t, "742 Evergreen Terrace", s) + + _, err = exts.GetInt("address") + assert.EqualError(t, err, "address is not an integer: 742 Evergreen Terrace (string)") + + _, err = exts.GetInt("foo") + assert.EqualError(t, err, "extension not found: foo") + + err = exts.Set("-1", "123 Fake Street") + assert.NoError(t, err) + + s, err = exts.GetString("address") + assert.NoError(t, err) + assert.Equal(t, "123 Fake Street", s) + + err = exts.Set("Size", "foo") + assert.EqualError(t, err, `cannot set field "Size" (of type int) to foo (string)`) + + ent := NewEntity() + ent.RegisterExtensions(&extsVal) + + obtainedVal := ent.GetExtensions().(*TestExtensions) + assert.EqualValues(t, extsVal, *obtainedVal) +} + +func TestEntityExtensions_Valid(t *testing.T) { + ent := NewEntity() + ent.SetEntityName("The Simpsons") + ent.SetRoles(RoleManifestCreator) + + err := ent.Valid() + assert.NoError(t, err) + + ent.RegisterExtensions(&TestExtensions{}) + err = ent.Valid() + assert.EqualError(t, err, `EntityName must be "Futurama"`) + + ent.SetEntityName("Futurama") + err = ent.Valid() + assert.NoError(t, err) +} + +func TestEntityExtensions_CBOR(t *testing.T) { + data := []byte{ + 0xa4, // map(4) + + 0x00, // key 0 + 0x64, // val tstr(4) + 0x61, 0x63, 0x6d, 0x65, // "acme" + + 0x02, // key 2 + 0x81, // array(1) + 0x01, // 1 + + 0x20, // key -1 + 0x63, // val tstr(3) + 0x66, 0x6f, 0x6f, // "foo" + + 0x21, // key -2 + 0x06, // val 6 + } + + ent := NewEntity() + ent.RegisterExtensions(&TestExtensions{}) + + err := cbor.Unmarshal(data, &ent) + assert.NoError(t, err) + + assert.Equal(t, ent.EntityName, "acme") + + address, err := ent.Get("address") + require.NoError(t, err) + assert.Equal(t, address, "foo") +} diff --git a/encoding/cbor.go b/encoding/cbor.go new file mode 100644 index 00000000..c2c5428e --- /dev/null +++ b/encoding/cbor.go @@ -0,0 +1,226 @@ +// Copyright 2023 Contributors to the Veraison project. +// SPDX-License-Identifier: Apache-2.0 + +package encoding + +import ( + "fmt" + "reflect" + "strconv" + "strings" + + cbor "github.com/fxamacker/cbor/v2" +) + +type embedded struct { + Type reflect.Type + Value reflect.Value +} + +func SerializeStructToCBOR(em cbor.EncMode, source any) ([]byte, error) { + rawMap := make(map[int]cbor.RawMessage) + + structType := reflect.TypeOf(source) + structVal := reflect.ValueOf(source) + + if err := doSerializeStructToCBOR(em, rawMap, structType, structVal); err != nil { + return nil, err + } + + return em.Marshal(rawMap) +} + +func doSerializeStructToCBOR( + em cbor.EncMode, + rawMap map[int]cbor.RawMessage, + structType reflect.Type, + structVal reflect.Value, +) error { + if structType.Kind() == reflect.Pointer { + structType = structType.Elem() + structVal = structVal.Elem() + } + + var embeds []embedded + + for i := 0; i < structVal.NumField(); i++ { + typeField := structType.Field(i) + valField := structVal.Field(i) + + if collectEmbedded(&typeField, valField, &embeds) { + continue + } + + tag, ok := typeField.Tag.Lookup("cbor") + if !ok { + continue + } + + parts := strings.Split(tag, ",") + keyString := parts[0] + + isOmitEmpty := false + if len(parts) > 1 { + for _, option := range parts[1:] { + if option == "omitempty" { + isOmitEmpty = true + break + } + } + } + + // do not serialize zero values if the corresponding field is + // omitempty + if isOmitEmpty && valField.IsZero() { + continue + } + + keyInt, err := strconv.Atoi(keyString) + if err != nil { + return fmt.Errorf("non-integer cbor key: %s", keyString) + } + + if _, ok := rawMap[keyInt]; ok { + return fmt.Errorf("duplicate cbor key: %d", keyInt) + } + + data, err := em.Marshal(valField.Interface()) + if err != nil { + return fmt.Errorf("error marshaling field %q: %w", + typeField.Name, + err, + ) + } + + rawMap[keyInt] = cbor.RawMessage(data) + } + + for _, emb := range embeds { + if err := doSerializeStructToCBOR(em, rawMap, emb.Type, emb.Value); err != nil { + return err + } + } + + return nil +} + +func PopulateStructFromCBOR(dm cbor.DecMode, data []byte, dest any) error { + var rawMap map[int]cbor.RawMessage + + if err := dm.Unmarshal(data, &rawMap); err != nil { + return err + } + + structType := reflect.TypeOf(dest) + structVal := reflect.ValueOf(dest) + + return doPopulateStructFromCBOR(dm, rawMap, structType, structVal) +} + +func doPopulateStructFromCBOR( + dm cbor.DecMode, + rawMap map[int]cbor.RawMessage, + structType reflect.Type, + structVal reflect.Value, +) error { + if structType.Kind() == reflect.Pointer { + structType = structType.Elem() + structVal = structVal.Elem() + } + + var embeds []embedded + + for i := 0; i < structVal.NumField(); i++ { + typeField := structType.Field(i) + valField := structVal.Field(i) + + if collectEmbedded(&typeField, valField, &embeds) { + continue + } + + tag, ok := typeField.Tag.Lookup("cbor") + if !ok { + continue + } + + parts := strings.Split(tag, ",") + keyString := parts[0] + + isOmitEmpty := false + if len(parts) > 1 { + for _, option := range parts[1:] { + if option == "omitempty" { + isOmitEmpty = true + break + } + } + } + + keyInt, err := strconv.Atoi(keyString) + if err != nil { + return fmt.Errorf("non-integer cbor key %s", keyString) + } + + rawVal, ok := rawMap[keyInt] + if !ok { + if isOmitEmpty { + continue + } + + return fmt.Errorf("missing mandatory field %q (%d)", + typeField.Name, keyInt) + } + + fieldPtr := valField.Addr().Interface() + if err := dm.Unmarshal(rawVal, fieldPtr); err != nil { + return fmt.Errorf("error unmarshalling field %q: %w", + typeField.Name, + err, + ) + } + + delete(rawMap, keyInt) + } + + for _, emb := range embeds { + if err := doPopulateStructFromCBOR(dm, rawMap, emb.Type, emb.Value); err != nil { + return err + } + } + + return nil +} + +// collectEmbedded returns true if the Field is embedded (regardless of +// whether or not it was collected). +func collectEmbedded( + typeField *reflect.StructField, + valField reflect.Value, + embeds *[]embedded, +) bool { + if typeField.Name == typeField.Type.Name() && + (typeField.Type.Kind() == reflect.Struct || + typeField.Type.Kind() == reflect.Interface) { + + var fieldType reflect.Type + var fieldValue reflect.Value + + if typeField.Type.Kind() == reflect.Interface { + fieldValue = valField.Elem() + if fieldValue.Kind() == reflect.Invalid { + // no value underlying the interface + return true + } + // use the interface's underlying value's real type + fieldType = valField.Elem().Type() + } else { + fieldType = typeField.Type + fieldValue = valField + } + + *embeds = append(*embeds, embedded{Type: fieldType, Value: fieldValue}) + return true + } + + return false +} diff --git a/encoding/cbor_test.go b/encoding/cbor_test.go new file mode 100644 index 00000000..471c992c --- /dev/null +++ b/encoding/cbor_test.go @@ -0,0 +1,107 @@ +// Copyright 2021 Contributors to the Veraison project. +// SPDX-License-Identifier: Apache-2.0 +package encoding + +import ( + "testing" + + "github.com/fxamacker/cbor/v2" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_PopulateStructFromCBOR_simple(t *testing.T) { + type SimpleStruct struct { + FieldOne string `cbor:"0,keyasint,omitempty"` + FieldTwo int `cbor:"1,keyasint"` + } + + var v SimpleStruct + + data := []byte{ + 0xa2, // map(2) + + 0x00, // key 0 + 0x64, // val tstr(4) + 0x61, 0x63, 0x6d, 0x65, // "acme" + + 0x01, // key 1 + 0x06, // val 6 + } + + dm, err := cbor.DecOptions{}.DecMode() + require.NoError(t, err) + + err = PopulateStructFromCBOR(dm, data, &v) + require.NoError(t, err) + assert.Equal(t, "acme", v.FieldOne) + assert.Equal(t, 6, v.FieldTwo) + + data = []byte{ + 0xa1, // map(1) + + 0x01, // key 1 + 0x06, // val 6 + } + v = SimpleStruct{} + + err = PopulateStructFromCBOR(dm, data, &v) + require.NoError(t, err) + assert.Equal(t, "", v.FieldOne) + assert.Equal(t, 6, v.FieldTwo) + + data = []byte{ + 0xa1, // map(1) + + 0x02, // key 2 + 0x64, // val tstr(4) + 0x61, 0x63, 0x6d, 0x65, // "acme" + } + v = SimpleStruct{} + + err = PopulateStructFromCBOR(dm, data, &v) + assert.EqualError(t, err, `missing mandatory field "FieldTwo" (1)`) + + err = PopulateStructFromCBOR(dm, []byte{0x01}, &v) + assert.EqualError(t, err, `cbor: cannot unmarshal positive integer into Go value of type map[int]cbor.RawMessage`) + + type CompositeStruct struct { + FieldThree string `cbor:"2,keyasint"` + SimpleStruct + } + + var c CompositeStruct + + data = []byte{ + 0xa3, // map(3) + + 0x00, // key 0 + 0x64, // val tstr(4) + 0x61, 0x63, 0x6d, 0x65, // "acme" + + 0x01, // key 1 + 0x06, // val 6 + + 0x02, // key 2 + 0x63, // val tstr(3) + 0x66, 0x6f, 0x6f, // "foo" + } + + err = PopulateStructFromCBOR(dm, data, &c) + require.NoError(t, err) + assert.Equal(t, "acme", c.FieldOne) + assert.Equal(t, 6, c.FieldTwo) + assert.Equal(t, "foo", c.FieldThree) + + em, err := cbor.EncOptions{}.EncMode() + require.NoError(t, err) + + res, err := SerializeStructToCBOR(em, &c) + require.NoError(t, err) + + var c2 CompositeStruct + err = PopulateStructFromCBOR(dm, res, &c2) + require.NoError(t, err) + assert.EqualValues(t, c, c2) + +}