Skip to content

Commit

Permalink
Example extension mechanism
Browse files Browse the repository at this point in the history
This is a PoC for an extension mechanism that would allow calling code
to supply extension implementations for sockets defined by the CoRIM
spec.

This is done by embedding an Extensions object at socket locations. This
object wraps an Interface implementations of which are user-supplied.
Extensions provides uniform accessor mechanism for the user extensions
without them needing to implement themselves. A custom serialisation
mechanism is used to ensure correct CBOR is generated for the embedded
structs.

Signed-off-by: Sergei Trofimov <[email protected]>
  • Loading branch information
setrofim committed Sep 19, 2023
1 parent 5a0612a commit c77a95c
Show file tree
Hide file tree
Showing 5 changed files with 616 additions and 1 deletion.
21 changes: 20 additions & 1 deletion corim/entity.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,30 @@ import (
"fmt"

"github.com/veraison/corim/comid"
"github.com/veraison/corim/encoding"
)

// Entity stores an entity-map capable of CBOR and JSON serializations.
type Entity struct {
EntityName string `cbor:"0,keyasint" json:"name"`
RegID *comid.TaggedURI `cbor:"1,keyasint,omitempty" json:"regid,omitempty"`
Roles Roles `cbor:"2,keyasint" json:"roles"`

Extensions
}

func NewEntity() *Entity {
return &Entity{}
}

func (o *Entity) RegisterExtensions(exts IExtensionsValue) {
o.Extensions.IExtensionsValue = exts
}

func (o *Entity) GetExtensions() IExtensionsValue {
return o.Extensions.IExtensionsValue
}

// SetEntityName is used to set the EntityName field of Entity using supplied name
func (o *Entity) SetEntityName(name string) *Entity {
if o != nil {
Expand Down Expand Up @@ -72,7 +83,15 @@ func (o Entity) Valid() error {
return fmt.Errorf("invalid entity: %w", err)
}

return nil
return o.Extensions.ValidEntity(&o)
}

func (o *Entity) UnmarshalCBOR(data []byte) error {
return encoding.PopulateStructFromCBOR(dm, data, o)
}

func (o *Entity) MarshalCBOR() ([]byte, error) {
return encoding.SerializeStructToCBOR(em, o)
}

// Entities is an array of entity-map's
Expand Down
150 changes: 150 additions & 0 deletions corim/extensions.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
// Copyright 2023 Contributors to the Veraison project.
// SPDX-License-Identifier: Apache-2.0
package corim

import (
"errors"
"fmt"
"reflect"
"strings"
)

var ErrExtensionNotFound = errors.New("extension not found")

type IExtensionsValue interface{}

type IEntityValidator interface {
ValidEntity(*Entity) error
}

type Extensions struct {
IExtensionsValue
}

func (o *Extensions) ValidEntity(entity *Entity) error {
if !o.HaveExtensions() {
return nil
}

ev, ok := o.IExtensionsValue.(IEntityValidator)
if ok {
if err := ev.ValidEntity(entity); err != nil {
return err
}
}

return nil
}

func (o *Extensions) HaveExtensions() bool {
return o.IExtensionsValue != nil
}

func (o *Extensions) Get(name string) (any, error) {
if o.IExtensionsValue == nil {
return nil, fmt.Errorf("%w: %s", ErrExtensionNotFound, name)
}

extType := reflect.TypeOf(o.IExtensionsValue)
extVal := reflect.ValueOf(o.IExtensionsValue)
if extType.Kind() == reflect.Pointer {
extType = extType.Elem()
extVal = extVal.Elem()
}

var fieldName, fieldJSONTag, fieldCBORTag string
for i := 0; i < extVal.NumField(); i++ {
typeField := extType.Field(i)
fieldName = typeField.Name

tag, ok := typeField.Tag.Lookup("json")
if ok {
fieldJSONTag = strings.Split(tag, ",")[0]
}

tag, ok = typeField.Tag.Lookup("cbor")
if ok {
fieldCBORTag = strings.Split(tag, ",")[0]
}

if fieldName == name || fieldJSONTag == name || fieldCBORTag == name {
return extVal.Field(i).Interface(), nil
}
}

return nil, fmt.Errorf("%w: %s", ErrExtensionNotFound, name)
}

func (o *Extensions) GetString(name string) (string, error) {
v, err := o.Get(name)
if err != nil {
return "", err
}

switch t := v.(type) {
case string:
return t, nil
default:
return fmt.Sprintf("%v", t), nil
}
}

func (o *Extensions) GetInt(name string) (int64, error) {
v, err := o.Get(name)
if err != nil {
return 0, err
}

val := reflect.ValueOf(v)
if val.CanInt() {
return val.Int(), nil
}

return 0, fmt.Errorf("%s is not an integer: %v (%T)", name, v, v)
}

func (o *Extensions) Set(name string, value any) error {
if o.IExtensionsValue == nil {
return fmt.Errorf("%w: %s", ErrExtensionNotFound, name)
}

extType := reflect.TypeOf(o.IExtensionsValue)
extVal := reflect.ValueOf(o.IExtensionsValue)
if extType.Kind() == reflect.Pointer {
extType = extType.Elem()
extVal = extVal.Elem()
}

var fieldName, fieldJSONTag, fieldCBORTag string
for i := 0; i < extVal.NumField(); i++ {
typeField := extType.Field(i)
valField := extVal.Field(i)
fieldName = typeField.Name

tag, ok := typeField.Tag.Lookup("json")
if ok {
fieldJSONTag = strings.Split(tag, ",")[0]
}

tag, ok = typeField.Tag.Lookup("cbor")
if ok {
fieldCBORTag = strings.Split(tag, ",")[0]
}

if fieldName == name || fieldJSONTag == name || fieldCBORTag == name {
newVal := reflect.ValueOf(value)
if newVal.CanConvert(valField.Type()) {
valField.Set(newVal.Convert(valField.Type()))
return nil
}

return fmt.Errorf(
"cannot set field %q (of type %s) to %v (%T)",
name, typeField.Type.Name(),
value, value,
)
}
}

return fmt.Errorf("%w: %s", ErrExtensionNotFound, name)
}
113 changes: 113 additions & 0 deletions corim/extensions_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
// Copyright 2023 Contributors to the Veraison project.
// SPDX-License-Identifier: Apache-2.0
package corim

import (
"errors"
"testing"

"github.com/fxamacker/cbor/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

type TestExtensions struct {
Address string `cbor:"-1,keyasint,omitempty" json:"address,omitempty"`
Size int `cbor:"-2,keyasint,omitempty" json:"size,omitempty"`
}

func (o TestExtensions) ValidEntity(ent *Entity) error {
if ent.EntityName != "Futurama" {
return errors.New(`EntityName must be "Futurama"`) // nolint:golint
}

return nil
}

func TestEntityExtensions_GetSet(t *testing.T) {
extsVal := TestExtensions{
Address: "742 Evergreen Terrace",
Size: 6,
}
exts := &Extensions{&extsVal}

v, err := exts.GetInt("size")
assert.NoError(t, err)
assert.Equal(t, int64(6), v)

s, err := exts.GetString("address")
assert.NoError(t, err)
assert.Equal(t, "742 Evergreen Terrace", s)

_, err = exts.GetInt("address")
assert.EqualError(t, err, "address is not an integer: 742 Evergreen Terrace (string)")

_, err = exts.GetInt("foo")
assert.EqualError(t, err, "extension not found: foo")

err = exts.Set("-1", "123 Fake Street")
assert.NoError(t, err)

s, err = exts.GetString("address")
assert.NoError(t, err)
assert.Equal(t, "123 Fake Street", s)

err = exts.Set("Size", "foo")
assert.EqualError(t, err, `cannot set field "Size" (of type int) to foo (string)`)

ent := NewEntity()
ent.RegisterExtensions(&extsVal)

obtainedVal := ent.GetExtensions().(*TestExtensions)
assert.EqualValues(t, extsVal, *obtainedVal)
}

func TestEntityExtensions_Valid(t *testing.T) {
ent := NewEntity()
ent.SetEntityName("The Simpsons")
ent.SetRoles(RoleManifestCreator)

err := ent.Valid()
assert.NoError(t, err)

ent.RegisterExtensions(&TestExtensions{})
err = ent.Valid()
assert.EqualError(t, err, `EntityName must be "Futurama"`)

ent.SetEntityName("Futurama")
err = ent.Valid()
assert.NoError(t, err)
}

func TestEntityExtensions_CBOR(t *testing.T) {
data := []byte{
0xa4, // map(4)

0x00, // key 0
0x64, // val tstr(4)
0x61, 0x63, 0x6d, 0x65, // "acme"

0x02, // key 2
0x81, // array(1)
0x01, // 1

0x20, // key -1
0x63, // val tstr(3)
0x66, 0x6f, 0x6f, // "foo"

0x21, // key -2
0x06, // val 6
}

ent := NewEntity()
ent.RegisterExtensions(&TestExtensions{})

err := cbor.Unmarshal(data, &ent)
assert.NoError(t, err)

assert.Equal(t, ent.EntityName, "acme")

address, err := ent.Get("address")
require.NoError(t, err)
assert.Equal(t, address, "foo")
}
Loading

0 comments on commit c77a95c

Please sign in to comment.