Skip to content

Commit

Permalink
Add discriminator to the account and allow encoding/decoding of non-a…
Browse files Browse the repository at this point in the history
…ccount types
  • Loading branch information
nolag committed Apr 23, 2024
1 parent fc974f2 commit 0f5ebb4
Show file tree
Hide file tree
Showing 8 changed files with 302 additions and 23 deletions.
2 changes: 1 addition & 1 deletion pkg/solana/chainreader/chain_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ func (s *SolanaChainReaderService) init(namespaces map[string]config.ChainReader
return err
}

idlCodec, err := codec.NewIDLCodec(idl, config.BuilderForEncoding(method.Encoding))
idlCodec, err := codec.NewIDLAccountCodec(idl, config.BuilderForEncoding(method.Encoding))
if err != nil {
return err
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/solana/chainreader/chain_reader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ func newTestIDLAndCodec(t *testing.T) (string, codec.IDL, types.RemoteCodec) {
t.FailNow()
}

entry, err := codec.NewIDLCodec(idl, binary.LittleEndian())
entry, err := codec.NewIDLAccountCodec(idl, binary.LittleEndian())
if err != nil {
t.Logf("failed to create new codec from test IDL: %s", err.Error())
t.FailNow()
Expand Down Expand Up @@ -722,7 +722,7 @@ func makeTestCodec(t *testing.T, rawIDL string, encoding config.EncodingType) ty
t.FailNow()
}

testCodec, err := codec.NewIDLCodec(idl, config.BuilderForEncoding(encoding))
testCodec, err := codec.NewIDLAccountCodec(idl, config.BuilderForEncoding(encoding))
if err != nil {
t.Logf("failed to create new codec from test IDL: %s", err.Error())
t.FailNow()
Expand Down
71 changes: 71 additions & 0 deletions pkg/solana/codec/discriminator.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
package codec

import (
"bytes"
"crypto/sha256"
"fmt"
"reflect"

"github.com/smartcontractkit/chainlink-common/pkg/codec/encodings"
"github.com/smartcontractkit/chainlink-common/pkg/types"
)

const discriminatorLength = 8

func NewDiscriminator(name string) encodings.TypeCodec {
sum := sha256.Sum256([]byte("account:" + name))
return &discriminator{hashPrefix: sum[:discriminatorLength]}
}

type discriminator struct {
hashPrefix []byte
}

func (d discriminator) Encode(value any, into []byte) ([]byte, error) {
if value == nil {
return append(into, d.hashPrefix...), nil
}

raw, ok := value.(*[]byte)
if !ok {
return nil, fmt.Errorf("%w: value must be a byte slice got %T", types.ErrInvalidType, value)
}

// inject if not specified
if raw == nil {
return append(into, d.hashPrefix...), nil
}

// Not sure if we should really be encoding accounts...
if !bytes.Equal(*raw, d.hashPrefix) {
return nil, fmt.Errorf("%w: invalid discriminator expected %x got %x", types.ErrInvalidType, d.hashPrefix, raw)
}

return append(into, *raw...), nil
}

func (d discriminator) Decode(encoded []byte) (any, []byte, error) {
raw, remaining, err := encodings.SafeDecode(encoded, discriminatorLength, func(raw []byte) []byte { return raw })
if err != nil {
return nil, nil, err
}

if !bytes.Equal(raw, d.hashPrefix) {
return nil, nil, fmt.Errorf("%w: invalid discriminator expected %x got %x", types.ErrInvalidEncoding, d.hashPrefix, raw)
}

return &raw, remaining, nil
}

func (d discriminator) GetType() reflect.Type {
// Pointer type so that nil can inject values and so that the NamedCodec won't wrap with no-nil pointer.
return reflect.TypeOf(&[]byte{})
}

func (d discriminator) Size(_ int) (int, error) {
return discriminatorLength, nil
}

func (d discriminator) FixedSize() (int, error) {
return discriminatorLength, nil
}
83 changes: 83 additions & 0 deletions pkg/solana/codec/discriminator_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package codec_test

import (
"crypto/sha256"
"errors"
"reflect"
"testing"

"github.com/smartcontractkit/chainlink-common/pkg/types"
"github.com/stretchr/testify/require"

"github.com/smartcontractkit/chainlink-solana/pkg/solana/codec"
)

func TestDiscriminator(t *testing.T) {
t.Run("encode and decode return the discriminator", func(t *testing.T) {
tmp := sha256.Sum256([]byte("account:Foo"))
expected := tmp[:8]
c := codec.NewDiscriminator("Foo")
encoded, err := c.Encode(&expected, nil)
require.NoError(t, err)
require.Equal(t, expected, encoded)
actual, remaining, err := c.Decode(encoded)
require.NoError(t, err)
require.Equal(t, &expected, actual)
require.Len(t, remaining, 0)
})

t.Run("encode returns an error if the discriminator is invalid", func(t *testing.T) {
c := codec.NewDiscriminator("Foo")
_, err := c.Encode(&[]byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}, nil)
require.True(t, errors.Is(err, types.ErrInvalidType))
})

t.Run("encode injects the discriminator if it's not provided", func(t *testing.T) {
tmp := sha256.Sum256([]byte("account:Foo"))
expected := tmp[:8]
c := codec.NewDiscriminator("Foo")
encoded, err := c.Encode(nil, nil)
require.NoError(t, err)
require.Equal(t, expected, encoded)
encoded, err = c.Encode((*[]byte)(nil), nil)
require.NoError(t, err)
require.Equal(t, expected, encoded)
})

t.Run("decode returns an error if the encoded value is too short", func(t *testing.T) {
c := codec.NewDiscriminator("Foo")
_, _, err := c.Decode([]byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06})
require.True(t, errors.Is(err, types.ErrInvalidEncoding))
})

t.Run("decode returns an error if the discriminator is invalid", func(t *testing.T) {
c := codec.NewDiscriminator("Foo")
_, _, err := c.Decode([]byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07})
require.True(t, errors.Is(err, types.ErrInvalidEncoding))
})

t.Run("encode returns an error if the value is not a byte slice", func(t *testing.T) {
c := codec.NewDiscriminator("Foo")
_, err := c.Encode(42, nil)
require.True(t, errors.Is(err, types.ErrInvalidType))
})

t.Run("GetType returns the type of the discriminator", func(t *testing.T) {
c := codec.NewDiscriminator("Foo")
require.Equal(t, reflect.TypeOf(&[]byte{}), c.GetType())
})

t.Run("Size returns the length of the discriminator", func(t *testing.T) {
c := codec.NewDiscriminator("Foo")
size, err := c.Size(0)
require.NoError(t, err)
require.Equal(t, 8, size)
})

t.Run("FixedSize returns the length of the discriminator", func(t *testing.T) {
c := codec.NewDiscriminator("Foo")
size, err := c.FixedSize()
require.NoError(t, err)
require.Equal(t, 8, size)
})
}
41 changes: 30 additions & 11 deletions pkg/solana/codec/solana.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,18 @@ func NewNamedModifierCodec(original types.RemoteCodec, itemType string, modifier
return modCodec, err
}

// NewIDLCodec is for Anchor custom types
func NewIDLCodec(idl IDL, builder encodings.Builder) (types.RemoteCodec, error) {
accounts := make(encodings.LenientCodecFromTypeCodec)
// NewIDLAccountCodec is for Anchor custom types
func NewIDLAccountCodec(idl IDL, builder encodings.Builder) (types.RemoteCodec, error) {
return newIdlCoded(idl, builder, idl.Accounts, true)
}

func NewIDLDefinedTypesCodec(idl IDL, builder encodings.Builder) (types.RemoteCodec, error) {
return newIdlCoded(idl, builder, idl.Types, false)
}

func newIdlCoded(
idl IDL, builder encodings.Builder, from IdlTypeDefSlice, includeDiscriminator bool) (types.RemoteCodec, error) {
typeCodecs := make(encodings.LenientCodecFromTypeCodec)

refs := &codecRefs{
builder: builder,
Expand All @@ -71,22 +80,22 @@ func NewIDLCodec(idl IDL, builder encodings.Builder) (types.RemoteCodec, error)
dependencies: make(map[string][]string),
}

for _, account := range idl.Accounts {
for _, def := range from {
var (
name string
accCodec encodings.TypeCodec
err error
)

name, accCodec, err = createNamedCodec(account, refs)
name, accCodec, err = createNamedCodec(def, refs, includeDiscriminator)
if err != nil {
return nil, err
}

accounts[name] = accCodec
typeCodecs[name] = accCodec
}

return accounts, nil
return typeCodecs, nil
}

type codecRefs struct {
Expand All @@ -99,13 +108,14 @@ type codecRefs struct {
func createNamedCodec(
def IdlTypeDef,
refs *codecRefs,
includeDiscriminator bool,
) (string, encodings.TypeCodec, error) {
caser := cases.Title(language.English)
name := def.Name

switch def.Type.Kind {
case IdlTypeDefTyKindStruct:
return asStruct(def, refs, name, caser)
return asStruct(def, refs, name, caser, includeDiscriminator)
case IdlTypeDefTyKindEnum:
variants := def.Type.Variants
if !variants.IsAllUint8() {
Expand All @@ -123,8 +133,17 @@ func asStruct(
refs *codecRefs,
name string, // name is the struct name and can be used in dependency checks
caser cases.Caser,
includeDiscriminator bool,
) (string, encodings.TypeCodec, error) {
named := make([]encodings.NamedTypeCodec, len(*def.Type.Fields))
desLen := 0
if includeDiscriminator {
desLen = 1
}
named := make([]encodings.NamedTypeCodec, len(*def.Type.Fields)+desLen)

if includeDiscriminator {
named[0] = encodings.NamedTypeCodec{Name: "Discriminator" + name, Codec: NewDiscriminator(name)}
}

for idx, field := range *def.Type.Fields {
fieldName := field.Name
Expand All @@ -134,7 +153,7 @@ func asStruct(
return name, nil, err
}

named[idx] = encodings.NamedTypeCodec{Name: caser.String(fieldName), Codec: typedCodec}
named[idx+desLen] = encodings.NamedTypeCodec{Name: caser.String(fieldName), Codec: typedCodec}
}

structCodec, err := encodings.NewStructCodec(named)
Expand Down Expand Up @@ -188,7 +207,7 @@ func asDefined(parentTypeName string, definedName *IdlTypeDefined, refs *codecRe

saveDependency(refs, parentTypeName, definedName.Defined)

newTypeName, newTypeCodec, err := createNamedCodec(*nextDef, refs)
newTypeName, newTypeCodec, err := createNamedCodec(*nextDef, refs, false)
if err != nil {
return nil, err
}
Expand Down
44 changes: 38 additions & 6 deletions pkg/solana/codec/solana_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,19 @@ import (
"github.com/smartcontractkit/chainlink-solana/pkg/solana/codec/testutils"
)

func TestNewIDLCodec(t *testing.T) {
func TestNewIDLAccountCodec(t *testing.T) {
/// TODO this should run the codec interface tests
t.Parallel()

ctx := tests.Context(t)
_, _, entry := newTestIDLAndCodec(t)
_, _, entry := newTestIDLAndCodec(t, true)

expected := testutils.DefaultTestStruct
bts, err := entry.Encode(ctx, expected, testutils.TestStructWithNestedStruct)

// length of fields + discriminator
require.Equal(t, 262, len(bts))

require.NoError(t, err)

var decoded testutils.StructWithNestedStruct
Expand All @@ -35,11 +39,32 @@ func TestNewIDLCodec(t *testing.T) {
require.Equal(t, expected, decoded)
}

func TestNewIDLDefinedTypesCodecCodec(t *testing.T) {
/// TODO this should run the codec interface tests
t.Parallel()

ctx := tests.Context(t)
_, _, entry := newTestIDLAndCodec(t, false)

expected := testutils.DefaultTestStruct
bts, err := entry.Encode(ctx, expected, testutils.TestStructWithNestedStructType)

// length of fields without a discriminator
require.Equal(t, 254, len(bts))

require.NoError(t, err)

var decoded testutils.StructWithNestedStruct

require.NoError(t, entry.Decode(ctx, bts, &decoded, testutils.TestStructWithNestedStructType))
require.Equal(t, expected, decoded)
}

func TestNewIDLCodec_WithModifiers(t *testing.T) {
t.Parallel()

ctx := tests.Context(t)
_, _, idlCodec := newTestIDLAndCodec(t)
_, _, idlCodec := newTestIDLAndCodec(t, true)
modConfig := codeccommon.ModifiersConfig{
&codeccommon.RenameModifierConfig{Fields: map[string]string{"Value": "V"}},
}
Expand Down Expand Up @@ -113,12 +138,12 @@ func TestNewIDLCodec_CircularDependency(t *testing.T) {
t.FailNow()
}

_, err := codec.NewIDLCodec(idl, binary.LittleEndian())
_, err := codec.NewIDLAccountCodec(idl, binary.LittleEndian())

assert.ErrorIs(t, err, types.ErrInvalidConfig)
}

func newTestIDLAndCodec(t *testing.T) (string, codec.IDL, types.RemoteCodec) {
func newTestIDLAndCodec(t *testing.T, account bool) (string, codec.IDL, types.RemoteCodec) {
t.Helper()

var idl codec.IDL
Expand All @@ -127,7 +152,14 @@ func newTestIDLAndCodec(t *testing.T) (string, codec.IDL, types.RemoteCodec) {
t.FailNow()
}

entry, err := codec.NewIDLCodec(idl, binary.LittleEndian())
var entry types.RemoteCodec
var err error
if account {
entry, err = codec.NewIDLAccountCodec(idl, binary.LittleEndian())
} else {
entry, err = codec.NewIDLDefinedTypesCodec(idl, binary.LittleEndian())
}

if err != nil {
t.Logf("failed to create new codec from test IDL: %s", err.Error())
t.FailNow()
Expand Down
Loading

0 comments on commit 0f5ebb4

Please sign in to comment.