-
Notifications
You must be signed in to change notification settings - Fork 17
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
This is a PoC for an extension mechanism that would allow calling code to supply extension implementations for sockets defined by the CoRIM spec. This is done by embedding an Extensions object at socket locations. This object wraps an Interface implementations of which are user-supplied. Extensions provides uniform accessor mechanism for the user extensions without them needing to implement themselves. A custom serialisation mechanism is used to ensure correct CBOR is generated for the embedded structs. Signed-off-by: Sergei Trofimov <[email protected]>
- Loading branch information
Showing
5 changed files
with
616 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,150 @@ | ||
// Copyright 2023 Contributors to the Veraison project. | ||
// SPDX-License-Identifier: Apache-2.0 | ||
package corim | ||
|
||
import ( | ||
"errors" | ||
"fmt" | ||
"reflect" | ||
"strings" | ||
) | ||
|
||
var ErrExtensionNotFound = errors.New("extension not found") | ||
|
||
type IExtensionsValue interface{} | ||
|
||
type IEntityValidator interface { | ||
ValidEntity(*Entity) error | ||
} | ||
|
||
type Extensions struct { | ||
IExtensionsValue | ||
} | ||
|
||
func (o *Extensions) ValidEntity(entity *Entity) error { | ||
if !o.HaveExtensions() { | ||
return nil | ||
} | ||
|
||
ev, ok := o.IExtensionsValue.(IEntityValidator) | ||
if ok { | ||
if err := ev.ValidEntity(entity); err != nil { | ||
return err | ||
} | ||
} | ||
|
||
return nil | ||
} | ||
|
||
func (o *Extensions) HaveExtensions() bool { | ||
return o.IExtensionsValue != nil | ||
} | ||
|
||
func (o *Extensions) Get(name string) (any, error) { | ||
if o.IExtensionsValue == nil { | ||
return nil, fmt.Errorf("%w: %s", ErrExtensionNotFound, name) | ||
} | ||
|
||
extType := reflect.TypeOf(o.IExtensionsValue) | ||
extVal := reflect.ValueOf(o.IExtensionsValue) | ||
if extType.Kind() == reflect.Pointer { | ||
extType = extType.Elem() | ||
extVal = extVal.Elem() | ||
} | ||
|
||
var fieldName, fieldJSONTag, fieldCBORTag string | ||
for i := 0; i < extVal.NumField(); i++ { | ||
typeField := extType.Field(i) | ||
fieldName = typeField.Name | ||
|
||
tag, ok := typeField.Tag.Lookup("json") | ||
if ok { | ||
fieldJSONTag = strings.Split(tag, ",")[0] | ||
} | ||
|
||
tag, ok = typeField.Tag.Lookup("cbor") | ||
if ok { | ||
fieldCBORTag = strings.Split(tag, ",")[0] | ||
} | ||
|
||
if fieldName == name || fieldJSONTag == name || fieldCBORTag == name { | ||
return extVal.Field(i).Interface(), nil | ||
} | ||
} | ||
|
||
return nil, fmt.Errorf("%w: %s", ErrExtensionNotFound, name) | ||
} | ||
|
||
func (o *Extensions) GetString(name string) (string, error) { | ||
v, err := o.Get(name) | ||
if err != nil { | ||
return "", err | ||
} | ||
|
||
switch t := v.(type) { | ||
case string: | ||
return t, nil | ||
default: | ||
return fmt.Sprintf("%v", t), nil | ||
} | ||
} | ||
|
||
func (o *Extensions) GetInt(name string) (int64, error) { | ||
v, err := o.Get(name) | ||
if err != nil { | ||
return 0, err | ||
} | ||
|
||
val := reflect.ValueOf(v) | ||
if val.CanInt() { | ||
return val.Int(), nil | ||
} | ||
|
||
return 0, fmt.Errorf("%s is not an integer: %v (%T)", name, v, v) | ||
} | ||
|
||
func (o *Extensions) Set(name string, value any) error { | ||
if o.IExtensionsValue == nil { | ||
return fmt.Errorf("%w: %s", ErrExtensionNotFound, name) | ||
} | ||
|
||
extType := reflect.TypeOf(o.IExtensionsValue) | ||
extVal := reflect.ValueOf(o.IExtensionsValue) | ||
if extType.Kind() == reflect.Pointer { | ||
extType = extType.Elem() | ||
extVal = extVal.Elem() | ||
} | ||
|
||
var fieldName, fieldJSONTag, fieldCBORTag string | ||
for i := 0; i < extVal.NumField(); i++ { | ||
typeField := extType.Field(i) | ||
valField := extVal.Field(i) | ||
fieldName = typeField.Name | ||
|
||
tag, ok := typeField.Tag.Lookup("json") | ||
if ok { | ||
fieldJSONTag = strings.Split(tag, ",")[0] | ||
} | ||
|
||
tag, ok = typeField.Tag.Lookup("cbor") | ||
if ok { | ||
fieldCBORTag = strings.Split(tag, ",")[0] | ||
} | ||
|
||
if fieldName == name || fieldJSONTag == name || fieldCBORTag == name { | ||
newVal := reflect.ValueOf(value) | ||
if newVal.CanConvert(valField.Type()) { | ||
valField.Set(newVal.Convert(valField.Type())) | ||
return nil | ||
} | ||
|
||
return fmt.Errorf( | ||
"cannot set field %q (of type %s) to %v (%T)", | ||
name, typeField.Type.Name(), | ||
value, value, | ||
) | ||
} | ||
} | ||
|
||
return fmt.Errorf("%w: %s", ErrExtensionNotFound, name) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
// Copyright 2023 Contributors to the Veraison project. | ||
// SPDX-License-Identifier: Apache-2.0 | ||
package corim | ||
|
||
import ( | ||
"errors" | ||
"testing" | ||
|
||
"github.com/fxamacker/cbor/v2" | ||
"github.com/stretchr/testify/assert" | ||
"github.com/stretchr/testify/require" | ||
) | ||
|
||
type TestExtensions struct { | ||
Address string `cbor:"-1,keyasint,omitempty" json:"address,omitempty"` | ||
Size int `cbor:"-2,keyasint,omitempty" json:"size,omitempty"` | ||
} | ||
|
||
func (o TestExtensions) ValidEntity(ent *Entity) error { | ||
if ent.EntityName != "Futurama" { | ||
return errors.New(`EntityName must be "Futurama"`) // nolint:golint | ||
} | ||
|
||
return nil | ||
} | ||
|
||
func TestEntityExtensions_GetSet(t *testing.T) { | ||
extsVal := TestExtensions{ | ||
Address: "742 Evergreen Terrace", | ||
Size: 6, | ||
} | ||
exts := &Extensions{&extsVal} | ||
|
||
v, err := exts.GetInt("size") | ||
assert.NoError(t, err) | ||
assert.Equal(t, int64(6), v) | ||
|
||
s, err := exts.GetString("address") | ||
assert.NoError(t, err) | ||
assert.Equal(t, "742 Evergreen Terrace", s) | ||
|
||
_, err = exts.GetInt("address") | ||
assert.EqualError(t, err, "address is not an integer: 742 Evergreen Terrace (string)") | ||
|
||
_, err = exts.GetInt("foo") | ||
assert.EqualError(t, err, "extension not found: foo") | ||
|
||
err = exts.Set("-1", "123 Fake Street") | ||
assert.NoError(t, err) | ||
|
||
s, err = exts.GetString("address") | ||
assert.NoError(t, err) | ||
assert.Equal(t, "123 Fake Street", s) | ||
|
||
err = exts.Set("Size", "foo") | ||
assert.EqualError(t, err, `cannot set field "Size" (of type int) to foo (string)`) | ||
|
||
ent := NewEntity() | ||
ent.RegisterExtensions(&extsVal) | ||
|
||
obtainedVal := ent.GetExtensions().(*TestExtensions) | ||
assert.EqualValues(t, extsVal, *obtainedVal) | ||
} | ||
|
||
func TestEntityExtensions_Valid(t *testing.T) { | ||
ent := NewEntity() | ||
ent.SetEntityName("The Simpsons") | ||
ent.SetRoles(RoleManifestCreator) | ||
|
||
err := ent.Valid() | ||
assert.NoError(t, err) | ||
|
||
ent.RegisterExtensions(&TestExtensions{}) | ||
err = ent.Valid() | ||
assert.EqualError(t, err, `EntityName must be "Futurama"`) | ||
|
||
ent.SetEntityName("Futurama") | ||
err = ent.Valid() | ||
assert.NoError(t, err) | ||
} | ||
|
||
func TestEntityExtensions_CBOR(t *testing.T) { | ||
data := []byte{ | ||
0xa4, // map(4) | ||
|
||
0x00, // key 0 | ||
0x64, // val tstr(4) | ||
0x61, 0x63, 0x6d, 0x65, // "acme" | ||
|
||
0x02, // key 2 | ||
0x81, // array(1) | ||
0x01, // 1 | ||
|
||
0x20, // key -1 | ||
0x63, // val tstr(3) | ||
0x66, 0x6f, 0x6f, // "foo" | ||
|
||
0x21, // key -2 | ||
0x06, // val 6 | ||
} | ||
|
||
ent := NewEntity() | ||
ent.RegisterExtensions(&TestExtensions{}) | ||
|
||
err := cbor.Unmarshal(data, &ent) | ||
assert.NoError(t, err) | ||
|
||
assert.Equal(t, ent.EntityName, "acme") | ||
|
||
address, err := ent.Get("address") | ||
require.NoError(t, err) | ||
assert.Equal(t, address, "foo") | ||
} |
Oops, something went wrong.