Skip to content

Commit

Permalink
[BCF-3143] Multiple Address Bindings (#849)
Browse files Browse the repository at this point in the history
* support for multiple address bindings

* fix errors
  • Loading branch information
EasterTheBunny authored Sep 11, 2024
1 parent 48b813a commit 5b8d287
Show file tree
Hide file tree
Showing 8 changed files with 294 additions and 102 deletions.
32 changes: 16 additions & 16 deletions pkg/solana/chainreader/account_read_binding.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ type BinaryDataReader interface {
// `idlAccount` refers to the account name in the IDL for which the codec has a type mapping.
type accountReadBinding struct {
idlAccount string
account solana.PublicKey
codec types.RemoteCodec
reader BinaryDataReader
opts *rpc.GetAccountInfoOpts
Expand All @@ -37,12 +36,19 @@ func newAccountReadBinding(acct string, codec types.RemoteCodec, reader BinaryDa

var _ readBinding = &accountReadBinding{}

func (b *accountReadBinding) PreLoad(ctx context.Context, result *loadedResult) {
func (b *accountReadBinding) PreLoad(ctx context.Context, address string, result *loadedResult) {
if result == nil {
return
}

bts, err := b.reader.ReadAll(ctx, b.account, b.opts)
account, err := solana.PublicKeyFromBase58(address)
if err != nil {
result.err <- err

return
}

bts, err := b.reader.ReadAll(ctx, account, b.opts)
if err != nil {
result.err <- fmt.Errorf("%w: failed to get binary data", err)

Expand All @@ -57,7 +63,7 @@ func (b *accountReadBinding) PreLoad(ctx context.Context, result *loadedResult)
}
}

func (b *accountReadBinding) GetLatestValue(ctx context.Context, _ any, outVal any, result *loadedResult) error {
func (b *accountReadBinding) GetLatestValue(ctx context.Context, address string, _ any, outVal any, result *loadedResult) error {
var (
bts []byte
err error
Expand All @@ -79,25 +85,19 @@ func (b *accountReadBinding) GetLatestValue(ctx context.Context, _ any, outVal a
return err
}
} else {
if bts, err = b.reader.ReadAll(ctx, b.account, b.opts); err != nil {
account, err := solana.PublicKeyFromBase58(address)
if err != nil {
return err
}

if bts, err = b.reader.ReadAll(ctx, account, b.opts); err != nil {
return fmt.Errorf("%w: failed to get binary data", err)
}
}

return b.codec.Decode(ctx, bts, outVal, b.idlAccount)
}

func (b *accountReadBinding) Bind(contract types.BoundContract) error {
account, err := solana.PublicKeyFromBase58(contract.Address)
if err != nil {
return err
}

b.account = account

return nil
}

func (b *accountReadBinding) CreateType(_ bool) (any, error) {
return b.codec.CreateType(b.idlAccount, false)
}
16 changes: 10 additions & 6 deletions pkg/solana/chainreader/account_read_binding_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,13 @@ func TestPreload(t *testing.T) {
err: make(chan error, 1),
}

binding.PreLoad(ctx, loaded)
pubKey := solana.NewWallet().PublicKey()

binding.PreLoad(ctx, pubKey.String(), loaded)

var result testStruct

err = binding.GetLatestValue(ctx, nil, &result, loaded)
err = binding.GetLatestValue(ctx, pubKey.String(), nil, &result, loaded)
elapsed := time.Since(start)

require.NoError(t, err)
Expand Down Expand Up @@ -74,15 +76,16 @@ func TestPreload(t *testing.T) {
cancel(expectedErr)
}()

pubKey := solana.NewWallet().PublicKey()
loaded := &loadedResult{
value: make(chan []byte, 1),
err: make(chan error, 1),
}
start := time.Now()
binding.PreLoad(ctx, loaded)
binding.PreLoad(ctx, pubKey.String(), loaded)

var result testStruct
err := binding.GetLatestValue(ctx, nil, &result, loaded)
err := binding.GetLatestValue(ctx, pubKey.String(), nil, &result, loaded)
elapsed := time.Since(start)

assert.ErrorIs(t, err, ctx.Err())
Expand All @@ -102,14 +105,15 @@ func TestPreload(t *testing.T) {
reader.On("ReadAll", mock.Anything, mock.Anything, mock.Anything).
Return([]byte{}, expectedErr)

pubKey := solana.NewWallet().PublicKey()
loaded := &loadedResult{
value: make(chan []byte, 1),
err: make(chan error, 1),
}
binding.PreLoad(ctx, loaded)
binding.PreLoad(ctx, pubKey.String(), loaded)

var result testStruct
err := binding.GetLatestValue(ctx, nil, &result, loaded)
err := binding.GetLatestValue(ctx, pubKey.String(), nil, &result, loaded)

assert.ErrorIs(t, err, expectedErr)
})
Expand Down
48 changes: 17 additions & 31 deletions pkg/solana/chainreader/bindings.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,14 @@ import (
"context"
"fmt"
"reflect"
"strconv"
"strings"

"github.com/gagliardetto/solana-go"
"github.com/smartcontractkit/chainlink-common/pkg/types"
)

type readBinding interface {
PreLoad(context.Context, *loadedResult)
GetLatestValue(ctx context.Context, params, returnVal any, preload *loadedResult) error
Bind(types.BoundContract) error
PreLoad(context.Context, string, *loadedResult)
GetLatestValue(ctx context.Context, address string, params, returnVal any, preload *loadedResult) error
CreateType(bool) (any, error)
}

Expand Down Expand Up @@ -111,34 +109,22 @@ func (b namespaceBindings) CreateType(namespace, methodName string, forEncoding
return reflect.New(reflect.StructOf(fields)).Interface(), nil
}

func (b namespaceBindings) Bind(boundContracts []types.BoundContract) error {
for _, bc := range boundContracts {
parts := strings.Split(bc.Name, ".")
if len(parts) != 3 {
return fmt.Errorf("%w: BoundContract.Name must follow pattern of [namespace.method.procedure_idx]", types.ErrInvalidConfig)
}

nbs, nbsExist := b[parts[0]]
if !nbsExist {
return fmt.Errorf("%w: no namespace named %s for %s", types.ErrInvalidConfig, parts[0], bc.Name)
}

mbs, mbsExists := nbs[parts[1]]
if !mbsExists {
return fmt.Errorf("%w: no method named %s for %s", types.ErrInvalidConfig, parts[1], bc.Name)
}

val, err := strconv.Atoi(parts[2])
if err != nil {
return fmt.Errorf("%w: procedure index not parsable for %s", types.ErrInvalidConfig, bc.Name)
}
func (b namespaceBindings) Bind(binding types.BoundContract) error {
_, nbsExist := b[binding.Name]
if !nbsExist {
return fmt.Errorf("%w: no namespace named %s", types.ErrInvalidConfig, binding.Name)
}

if len(mbs) <= val {
return fmt.Errorf("%w: no procedure for index %d for %s", types.ErrInvalidConfig, val, bc.Name)
}
readAddresses, err := decodeAddressMappings(binding.Address)
if err != nil {
return err
}

if err := mbs[val].Bind(bc); err != nil {
return err
for readName, addresses := range readAddresses {
for idx, address := range addresses {
if _, err := solana.PublicKeyFromBase58(address); err != nil {
return fmt.Errorf("%w: invalid address binding for %s at index %d: %s", types.ErrInvalidConfig, readName, idx, err.Error())
}
}
}

Expand Down
8 changes: 2 additions & 6 deletions pkg/solana/chainreader/bindings_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,13 +105,9 @@ type mockBinding struct {
mock.Mock
}

func (_m *mockBinding) PreLoad(context.Context, *loadedResult) {}
func (_m *mockBinding) PreLoad(context.Context, string, *loadedResult) {}

func (_m *mockBinding) GetLatestValue(ctx context.Context, params, returnVal any, _ *loadedResult) error {
return nil
}

func (_m *mockBinding) Bind(types.BoundContract) error {
func (_m *mockBinding) GetLatestValue(ctx context.Context, address string, params, returnVal any, _ *loadedResult) error {
return nil
}

Expand Down
81 changes: 68 additions & 13 deletions pkg/solana/chainreader/chain_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@ package chainreader

import (
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"reflect"
"strings"
"sync"

ag_solana "github.com/gagliardetto/solana-go"
Expand All @@ -31,6 +31,7 @@ type SolanaChainReaderService struct {

// internal values
bindings namespaceBindings
lookup *lookup

// service state management
wg sync.WaitGroup
Expand All @@ -48,6 +49,7 @@ func NewChainReaderService(lggr logger.Logger, dataReader BinaryDataReader, cfg
lggr: logger.Named(lggr, ServiceName),
client: dataReader,
bindings: namespaceBindings{},
lookup: newLookup(),
}

if err := svc.init(cfg.Namespaces); err != nil {
Expand Down Expand Up @@ -103,14 +105,30 @@ func (s *SolanaChainReaderService) GetLatestValue(ctx context.Context, readIdent
s.wg.Add(1)
defer s.wg.Done()

split := strings.Split(readIdentifier, ".")
contractName, method := split[0], split[1]
values, ok := s.lookup.getContractForReadIdentifiers(readIdentifier)
if !ok {
return fmt.Errorf("%w: no contract for read identifier %s", types.ErrInvalidType, readIdentifier)
}

addressMappings, err := decodeAddressMappings(values.address)
if err != nil {
return fmt.Errorf("%w: %s", types.ErrInvalidConfig, err)
}

bindings, err := s.bindings.GetReadBindings(contractName, method)
addresses, ok := addressMappings[values.readName]
if !ok {
return fmt.Errorf("%w: no addresses for readName %s", types.ErrInvalidConfig, values.readName)
}

bindings, err := s.bindings.GetReadBindings(values.contract, values.readName)
if err != nil {
return err
}

if len(addresses) != len(bindings) {
return fmt.Errorf("%w: addresses and bindings lengths do not match", types.ErrInvalidConfig)
}

localCtx, localCancel := context.WithCancel(ctx)

// the wait group ensures GetLatestValue returns only after all go-routines have completed
Expand Down Expand Up @@ -145,11 +163,11 @@ func (s *SolanaChainReaderService) GetLatestValue(ctx context.Context, readIdent
}

wg.Add(1)
go func(ctx context.Context, rb readBinding, res *loadedResult) {
go func(ctx context.Context, rb readBinding, res *loadedResult, address string) {
defer wg.Done()

rb.PreLoad(ctx, res)
}(localCtx, binding, results[idx])
rb.PreLoad(ctx, address, res)
}(localCtx, binding, results[idx], addresses[idx])
}
}

Expand All @@ -158,7 +176,7 @@ func (s *SolanaChainReaderService) GetLatestValue(ctx context.Context, readIdent
// in the case of no preloading, GetLatestValue will load and decode in
// sequence.
for idx, binding := range bindings {
if err := binding.GetLatestValue(ctx, params, returnVal, results[idx]); err != nil {
if err := binding.GetLatestValue(ctx, addresses[idx], params, returnVal, results[idx]); err != nil {
localCancel()

wg.Wait()
Expand Down Expand Up @@ -187,17 +205,36 @@ func (s *SolanaChainReaderService) QueryKey(_ context.Context, _ types.BoundCont
// Bind implements the types.ContractReader interface and allows new contract bindings to be added
// to the service.
func (s *SolanaChainReaderService) Bind(_ context.Context, bindings []types.BoundContract) error {
return s.bindings.Bind(bindings)
for _, binding := range bindings {
if err := s.bindings.Bind(binding); err != nil {
return err
}

s.lookup.bindAddressForContract(binding.Name, binding.Address)
}

return nil
}

func (s *SolanaChainReaderService) Unbind(_ context.Context, _ []types.BoundContract) error {
return errors.New("unimplemented")
// Unbind implements the types.ContractReader interface and allows existing contract bindings to be removed
// from the service.
func (s *SolanaChainReaderService) Unbind(_ context.Context, bindings []types.BoundContract) error {
for _, binding := range bindings {
s.lookup.unbindAddressForContract(binding.Name, binding.Address)
}

return nil
}

// CreateContractType implements the ContractTypeProvider interface and allows the chain reader
// service to explicitly define the expected type for a grpc server to provide.
func (s *SolanaChainReaderService) CreateContractType(contractName, itemType string, forEncoding bool) (any, error) {
return s.bindings.CreateType(contractName, itemType, forEncoding)
func (s *SolanaChainReaderService) CreateContractType(readIdentifier string, forEncoding bool) (any, error) {
values, ok := s.lookup.getContractForReadIdentifiers(readIdentifier)
if !ok {
return nil, fmt.Errorf("%w: no contract for read identifier", types.ErrInvalidConfig)
}

return s.bindings.CreateType(values.contract, values.readName, forEncoding)
}

func (s *SolanaChainReaderService) init(namespaces map[string]config.ChainReaderMethods) error {
Expand All @@ -213,6 +250,8 @@ func (s *SolanaChainReaderService) init(namespaces map[string]config.ChainReader
return err
}

s.lookup.addReadNameForContract(namespace, methodName)

for _, procedure := range method.Procedures {
mod, err := procedure.OutputModifications.ToModifier(codec.DecoderHooks...)
if err != nil {
Expand Down Expand Up @@ -275,3 +314,19 @@ func (r *accountDataReader) ReadAll(ctx context.Context, pk ag_solana.PublicKey,

return bts, nil
}

func decodeAddressMappings(encoded string) (map[string][]string, error) {
decoded, err := base64.StdEncoding.DecodeString(encoded)
if err != nil {
return nil, err
}

var readAddresses map[string][]string

err = json.Unmarshal(decoded, &readAddresses)
if err != nil {
return nil, err
}

return readAddresses, nil
}
Loading

0 comments on commit 5b8d287

Please sign in to comment.