Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
reductionista committed Jan 7, 2025
1 parent 1c74a3b commit c208014
Show file tree
Hide file tree
Showing 6 changed files with 75 additions and 18 deletions.
6 changes: 4 additions & 2 deletions integration-tests/smoke/event_loader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (

contract "github.com/smartcontractkit/chainlink-solana/contracts/generated/log_read_test"
"github.com/smartcontractkit/chainlink-solana/pkg/solana/client"
"github.com/smartcontractkit/chainlink-solana/pkg/solana/config"
"github.com/smartcontractkit/chainlink-solana/pkg/solana/logpoller"

"github.com/smartcontractkit/chainlink-solana/integration-tests/solclient"
Expand All @@ -49,7 +50,8 @@ func TestEventLoader(t *testing.T) {
require.NoError(t, err)

rpcURL, wsURL := setupTestValidator(t, privateKey.PublicKey().String())
rpcClient := rpc.New(rpcURL)
cl, rpcClient, err := client.NewTestClient(rpcURL, config.NewDefault(), 1*time.Second, logger.Nop())
require.NoError(t, err)
wsClient, err := ws.Connect(ctx, wsURL)
require.NoError(t, err)

Expand All @@ -62,7 +64,7 @@ func TestEventLoader(t *testing.T) {
parser := &printParser{t: t}
sender := newLogSender(t, rpcClient, wsClient)
collector := logpoller.NewEncodedLogCollector(
rpcClient,
cl,
parser,
logger.Nop(),
)
Expand Down
15 changes: 11 additions & 4 deletions pkg/solana/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,18 +71,25 @@ type Client struct {
requestGroup *singleflight.Group
}

func NewClient(endpoint string, cfg config.Config, requestTimeout time.Duration, log logger.Logger) (*Client, error) {
return &Client{
// Return both the client and the underlying rpc client for testing
func NewTestClient(endpoint string, cfg config.Config, requestTimeout time.Duration, log logger.Logger) (*Client, *rpc.Client, error) {
rpcClient := Client{
url: endpoint,
rpc: rpc.New(endpoint),
skipPreflight: cfg.SkipPreflight(),
commitment: cfg.Commitment(),
maxRetries: cfg.MaxRetries(),
txTimeout: cfg.TxTimeout(),
contextDuration: requestTimeout,
log: log,
requestGroup: &singleflight.Group{},
}, nil
}
rpcClient.rpc = rpc.New(endpoint)
return &rpcClient, rpcClient.rpc, nil
}

func NewClient(endpoint string, cfg config.Config, requestTimeout time.Duration, log logger.Logger) (*Client, error) {
rpcClient, _, err := NewTestClient(endpoint, cfg, requestTimeout, log)
return rpcClient, err
}

func (c *Client) latency(name string) func() {
Expand Down
22 changes: 10 additions & 12 deletions pkg/solana/logpoller/filters.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@ import (

"github.com/gagliardetto/solana-go"
"github.com/smartcontractkit/chainlink-common/pkg/logger"

"github.com/smartcontractkit/chainlink-solana/pkg/solana/logpoller/utils"
)

type filters struct {
Expand Down Expand Up @@ -88,8 +86,6 @@ func (fl *filters) RegisterFilter(ctx context.Context, filter Filter) error {
return fmt.Errorf("failed to load filters: %w", err)
}

filter.EventSig = utils.Discriminator("event", filter.EventName)

fl.filtersMutex.Lock()
defer fl.filtersMutex.Unlock()

Expand Down Expand Up @@ -134,17 +130,17 @@ func (fl *filters) RegisterFilter(ctx context.Context, filter Filter) error {
}

programID := filter.Address.ToSolana().String()
if _, ok := fl.knownPrograms[programID]; !ok {
if _, ok = fl.knownPrograms[programID]; !ok {
fl.knownPrograms[programID] = 1
} else {
fl.knownPrograms[programID]++
}

discriminator := base64.StdEncoding.EncodeToString(filter.EventSig[:])[:10]
discriminatorHead := filter.Discriminator()[:10]
if _, ok := fl.knownPrograms[programID]; !ok {
fl.knownDiscriminators[discriminator] = 1
fl.knownDiscriminators[discriminatorHead] = 1
} else {
fl.knownDiscriminators[discriminator]++
fl.knownDiscriminators[discriminatorHead]++
}

return nil
Expand Down Expand Up @@ -220,13 +216,13 @@ func (fl *filters) removeFilterFromIndexes(filter Filter) {
}
}

discriminator := base64.StdEncoding.EncodeToString(filter.EventSig[:])[:10]
if refcount, ok := fl.knownDiscriminators[discriminator]; ok {
discriminatorHead := filter.Discriminator()[:10]
if refcount, ok := fl.knownDiscriminators[discriminatorHead]; ok {
refcount--
if refcount > 0 {
fl.knownDiscriminators[discriminator] = refcount
fl.knownDiscriminators[discriminatorHead] = refcount
} else {
delete(fl.knownDiscriminators, discriminator)
delete(fl.knownDiscriminators, discriminatorHead)
}
}
}
Expand Down Expand Up @@ -345,6 +341,8 @@ func (fl *filters) LoadFilters(ctx context.Context) error {
fl.filtersByAddress = make(map[PublicKey]map[EventSignature]map[int64]struct{})
fl.filtersToBackfill = make(map[int64]struct{})
fl.filtersToDelete = make(map[int64]Filter)
fl.knownPrograms = make(map[string]uint)
fl.knownDiscriminators = make(map[string]uint)

filters, err := fl.orm.SelectFilters(ctx)
if err != nil {
Expand Down
29 changes: 29 additions & 0 deletions pkg/solana/logpoller/filters_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,12 @@ func TestFilters_LoadFilters(t *testing.T) {
happyPath2,
}, nil).Once()

orm.On("SelectSeqNums", mock.Anything).Return(map[int64]int64{
1: 18,
2: 25,
3: 0,
}, nil)

err := fs.LoadFilters(ctx)
require.EqualError(t, err, "failed to select filters from db: db failed")
err = fs.LoadFilters(ctx)
Expand Down Expand Up @@ -110,6 +116,7 @@ func TestFilters_RegisterFilter(t *testing.T) {
const filterName = "Filter"
dbFilter := Filter{Name: filterName}
orm.On("SelectFilters", mock.Anything).Return([]Filter{dbFilter}, nil).Once()
orm.On("SelectSeqNums", mock.Anything).Return(map[int64]int64{}, nil)
newFilter := dbFilter
tc.ModifyField(&newFilter)
err := fs.RegisterFilter(tests.Context(t), newFilter)
Expand All @@ -122,6 +129,7 @@ func TestFilters_RegisterFilter(t *testing.T) {
fs := newFilters(lggr, orm)
const filterName = "Filter"
orm.On("SelectFilters", mock.Anything).Return(nil, nil).Once()
orm.On("SelectSeqNums", mock.Anything).Return(map[int64]int64{}, nil).Once()
orm.On("InsertFilter", mock.Anything, mock.Anything).Return(int64(0), errors.New("failed to insert")).Once()
filter := Filter{Name: filterName}
err := fs.RegisterFilter(tests.Context(t), filter)
Expand Down Expand Up @@ -149,6 +157,7 @@ func TestFilters_RegisterFilter(t *testing.T) {
fs := newFilters(lggr, orm)
const filterName = "Filter"
orm.On("SelectFilters", mock.Anything).Return(nil, nil).Once()
orm.On("SelectSeqNums", mock.Anything).Return(map[int64]int64{}, nil).Once()
const filterID = int64(10)
orm.On("InsertFilter", mock.Anything, mock.Anything).Return(filterID, nil).Once()
err := fs.RegisterFilter(tests.Context(t), Filter{Name: filterName})
Expand Down Expand Up @@ -180,6 +189,7 @@ func TestFilters_UnregisterFilter(t *testing.T) {
fs := newFilters(lggr, orm)
const filterName = "Filter"
orm.On("SelectFilters", mock.Anything).Return(nil, nil).Once()
orm.On("SelectSeqNums", mock.Anything).Return(map[int64]int64{}, nil).Once()
err := fs.UnregisterFilter(tests.Context(t), filterName)
require.NoError(t, err)
})
Expand All @@ -189,6 +199,7 @@ func TestFilters_UnregisterFilter(t *testing.T) {
const filterName = "Filter"
const id int64 = 10
orm.On("SelectFilters", mock.Anything).Return([]Filter{{ID: id, Name: filterName}}, nil).Once()
orm.On("SelectSeqNums", mock.Anything).Return(map[int64]int64{}, nil).Once()
orm.On("MarkFilterDeleted", mock.Anything, id).Return(errors.New("db query failed")).Once()
err := fs.UnregisterFilter(tests.Context(t), filterName)
require.EqualError(t, err, "failed to mark filter deleted: db query failed")
Expand All @@ -199,6 +210,7 @@ func TestFilters_UnregisterFilter(t *testing.T) {
const filterName = "Filter"
const id int64 = 10
orm.On("SelectFilters", mock.Anything).Return([]Filter{{ID: id, Name: filterName}}, nil).Once()
orm.On("SelectSeqNums", mock.Anything).Return(map[int64]int64{}, nil).Once()
orm.On("MarkFilterDeleted", mock.Anything, id).Return(nil).Once()
err := fs.UnregisterFilter(tests.Context(t), filterName)
require.NoError(t, err)
Expand Down Expand Up @@ -226,6 +238,9 @@ func TestFilters_PruneFilters(t *testing.T) {
Name: "To keep",
},
}, nil).Once()
orm.On("SelectSeqNums", mock.Anything).Return(map[int64]int64{
2: 25,
}, nil).Once()
orm.On("DeleteFilters", mock.Anything, map[int64]Filter{toDelete.ID: toDelete}).Return(nil).Once()
err := fs.PruneFilters(tests.Context(t))
require.NoError(t, err)
Expand All @@ -246,6 +261,10 @@ func TestFilters_PruneFilters(t *testing.T) {
Name: "To keep",
},
}, nil).Once()
orm.EXPECT().SelectSeqNums(mock.Anything).Return(map[int64]int64{
1: 18,
2: 25,
}, nil).Once()
newToDelete := Filter{
ID: 3,
Name: "To delete 2",
Expand Down Expand Up @@ -291,6 +310,12 @@ func TestFilters_MatchingFilters(t *testing.T) {
EventSig: expectedFilter1.EventSig,
}
orm.On("SelectFilters", mock.Anything).Return([]Filter{expectedFilter1, expectedFilter2, sameAddress, sameEventSig}, nil).Once()
orm.On("SelectSeqNums", mock.Anything).Return(map[int64]int64{
1: 18,
2: 25,
3: 14,
4: 0,
}, nil)
filters := newFilters(lggr, orm)
err := filters.LoadFilters(tests.Context(t))
require.NoError(t, err)
Expand Down Expand Up @@ -319,6 +344,10 @@ func TestFilters_GetFiltersToBackfill(t *testing.T) {
Name: "notBackfilled",
}
orm.EXPECT().SelectFilters(mock.Anything).Return([]Filter{backfilledFilter, notBackfilled}, nil).Once()
orm.EXPECT().SelectSeqNums(mock.Anything).Return(map[int64]int64{
1: 18,
2: 25,
}, nil)
filters := newFilters(lggr, orm)
err := filters.LoadFilters(tests.Context(t))
require.NoError(t, err)
Expand Down
13 changes: 13 additions & 0 deletions pkg/solana/logpoller/loader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,13 @@ func TestEncodedLogCollector_ParseSingleEvent(t *testing.T) {
GetBlockWithOpts(mock.Anything, mock.Anything, mock.Anything).
RunAndReturn(func(_ context.Context, slot uint64, _ *rpc.GetBlockOpts) (*rpc.GetBlockResult, error) {
height := slot - 1
timeStamp := solana.UnixTimeSeconds(time.Now().Unix())

result := rpc.GetBlockResult{
Transactions: []rpc.TransactionWithMeta{},
Signatures: []solana.Signature{},
BlockHeight: &height,
BlockTime: &timeStamp,
}

_, _ = rand.Read(result.Blockhash[:])
Expand Down Expand Up @@ -132,6 +134,8 @@ func TestEncodedLogCollector_MultipleEventOrdered(t *testing.T) {
hashes := make([]solana.Hash, len(slots))
scrambler := &slotUnsync{ch: make(chan struct{})}

timeStamp := solana.UnixTimeSeconds(time.Now().Unix())

for idx := range len(sigs) {
_, _ = rand.Read(sigs[idx][:])
_, _ = rand.Read(hashes[idx][:])
Expand Down Expand Up @@ -176,6 +180,7 @@ func TestEncodedLogCollector_MultipleEventOrdered(t *testing.T) {
Transactions: []rpc.TransactionWithMeta{},
Signatures: []solana.Signature{},
BlockHeight: &height,
BlockTime: &timeStamp,
}, nil
}

Expand All @@ -190,6 +195,7 @@ func TestEncodedLogCollector_MultipleEventOrdered(t *testing.T) {
},
Signatures: []solana.Signature{sigs[slotIdx]},
BlockHeight: &height,
BlockTime: &timeStamp,
}, nil
})

Expand All @@ -199,6 +205,7 @@ func TestEncodedLogCollector_MultipleEventOrdered(t *testing.T) {
BlockData: logpoller.BlockData{
SlotNumber: 41,
BlockHeight: 40,
BlockTime: timeStamp,
BlockHash: hashes[3],
TransactionHash: sigs[3],
TransactionIndex: 0,
Expand All @@ -211,6 +218,7 @@ func TestEncodedLogCollector_MultipleEventOrdered(t *testing.T) {
BlockData: logpoller.BlockData{
SlotNumber: 42,
BlockHeight: 41,
BlockTime: timeStamp,
BlockHash: hashes[2],
TransactionHash: sigs[2],
TransactionIndex: 0,
Expand All @@ -223,6 +231,7 @@ func TestEncodedLogCollector_MultipleEventOrdered(t *testing.T) {
BlockData: logpoller.BlockData{
SlotNumber: 43,
BlockHeight: 42,
BlockTime: timeStamp,
BlockHash: hashes[1],
TransactionHash: sigs[1],
TransactionIndex: 0,
Expand All @@ -235,6 +244,7 @@ func TestEncodedLogCollector_MultipleEventOrdered(t *testing.T) {
BlockData: logpoller.BlockData{
SlotNumber: 44,
BlockHeight: 43,
BlockTime: timeStamp,
BlockHash: hashes[0],
TransactionHash: sigs[0],
TransactionIndex: 0,
Expand Down Expand Up @@ -337,12 +347,14 @@ func TestEncodedLogCollector_BackfillForAddress(t *testing.T) {
}

height := slot - 1
timeStamp := solana.UnixTimeSeconds(time.Now().Unix())

if idx == -1 {
return &rpc.GetBlockResult{
Transactions: []rpc.TransactionWithMeta{},
Signatures: []solana.Signature{},
BlockHeight: &height,
BlockTime: &timeStamp,
}, nil
}

Expand All @@ -361,6 +373,7 @@ func TestEncodedLogCollector_BackfillForAddress(t *testing.T) {
},
Signatures: []solana.Signature{sigs[idx*2], sigs[(idx*2)+1]},
BlockHeight: &height,
BlockTime: &timeStamp,
}, nil
})

Expand Down
8 changes: 8 additions & 0 deletions pkg/solana/logpoller/models.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
package logpoller

import (
"encoding/base64"
"time"

"github.com/lib/pq"

"github.com/smartcontractkit/chainlink-solana/pkg/solana/logpoller/utils"
)

type Filter struct {
Expand All @@ -26,6 +29,11 @@ func (f Filter) MatchSameLogs(other Filter) bool {
f.EventIdl.Equal(other.EventIdl) && f.SubkeyPaths.Equal(other.SubkeyPaths)
}

func (f Filter) Discriminator() string {
d := utils.Discriminator("event", f.Name)
return base64.StdEncoding.EncodeToString(d[:])
}

type Log struct {
ID int64
FilterID int64
Expand Down

0 comments on commit c208014

Please sign in to comment.