Skip to content

Commit

Permalink
core/services/relay/evm: switch RequestRound DB & Tracker to use sqlu…
Browse files Browse the repository at this point in the history
…til.DataSource (#12706)
  • Loading branch information
jmank88 authored Apr 5, 2024
1 parent ee52be7 commit 1efb525
Show file tree
Hide file tree
Showing 10 changed files with 113 additions and 61 deletions.
1 change: 1 addition & 0 deletions core/services/chainlink/relayer_factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ func (r *RelayerFactory) NewEVM(ctx context.Context, config EVMFactoryConfig) (m

relayerOpts := evmrelay.RelayerOpts{
DB: ccOpts.SqlxDB,
DS: ccOpts.DB,
QConfig: ccOpts.AppConfig.Database(),
CSAETHKeystore: config.CSAETHKeystore,
MercuryPool: r.MercuryPool,
Expand Down
1 change: 1 addition & 0 deletions core/services/job/spawner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,7 @@ func TestSpawner_CreateJobDeleteJob(t *testing.T) {

evmRelayer, err := evmrelayer.NewRelayer(lggr, chain, evmrelayer.RelayerOpts{
DB: db,
DS: db,
QConfig: testopts.GeneralConfig.Database(),
CSAETHKeystore: keyStore,
})
Expand Down
13 changes: 10 additions & 3 deletions core/services/relay/evm/evm.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
ocrtypes "github.com/smartcontractkit/libocr/offchainreporting2plus/types"

"github.com/smartcontractkit/chainlink-common/pkg/services"
"github.com/smartcontractkit/chainlink-common/pkg/sqlutil"
commontypes "github.com/smartcontractkit/chainlink-common/pkg/types"

txmgrcommon "github.com/smartcontractkit/chainlink/v2/common/txmgr"
Expand Down Expand Up @@ -70,7 +71,8 @@ func init() {
var _ commontypes.Relayer = &Relayer{} //nolint:staticcheck

type Relayer struct {
db *sqlx.DB
db *sqlx.DB // legacy: prefer to use ds instead
ds sqlutil.DataSource
chain legacyevm.Chain
lggr logger.Logger
ks CSAETHKeystore
Expand All @@ -93,7 +95,8 @@ type CSAETHKeystore interface {
}

type RelayerOpts struct {
*sqlx.DB
*sqlx.DB // legacy: prefer to use ds instead
DS sqlutil.DataSource
pg.QConfig
CSAETHKeystore
MercuryPool wsrpc.Pool
Expand All @@ -104,6 +107,9 @@ func (c RelayerOpts) Validate() error {
if c.DB == nil {
err = errors.Join(err, errors.New("nil DB"))
}
if c.DS == nil {
err = errors.Join(err, errors.New("nil DataSource"))
}
if c.QConfig == nil {
err = errors.Join(err, errors.New("nil QConfig"))
}
Expand All @@ -129,6 +135,7 @@ func NewRelayer(lggr logger.Logger, chain legacyevm.Chain, opts RelayerOpts) (*R
cdcFactory := llo.NewChannelDefinitionCacheFactory(lggr, lloORM, chain.LogPoller())
return &Relayer{
db: opts.DB,
ds: opts.DS,
chain: chain,
lggr: lggr,
ks: opts.CSAETHKeystore,
Expand Down Expand Up @@ -588,7 +595,7 @@ func (r *Relayer) NewMedianProvider(rargs commontypes.RelayArgs, pargs commontyp
return nil, err
}

medianContract, err := newMedianContract(configWatcher.ContractConfigTracker(), configWatcher.contractAddress, configWatcher.chain, rargs.JobID, r.db, lggr)
medianContract, err := newMedianContract(configWatcher.ContractConfigTracker(), configWatcher.contractAddress, configWatcher.chain, rargs.JobID, r.ds, lggr)
if err != nil {
return nil, err
}
Expand Down
8 changes: 7 additions & 1 deletion core/services/relay/evm/evm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (

"github.com/jmoiron/sqlx"

"github.com/smartcontractkit/chainlink-common/pkg/sqlutil"
"github.com/smartcontractkit/chainlink/v2/core/internal/testutils/configtest"
"github.com/smartcontractkit/chainlink/v2/core/services/pg"
"github.com/smartcontractkit/chainlink/v2/core/services/relay/evm"
Expand All @@ -16,6 +17,7 @@ func TestRelayerOpts_Validate(t *testing.T) {
cfg := configtest.NewTestGeneralConfig(t)
type fields struct {
DB *sqlx.DB
DS sqlutil.DataSource
QConfig pg.QConfig
CSAETHKeystore evm.CSAETHKeystore
}
Expand All @@ -28,27 +30,31 @@ func TestRelayerOpts_Validate(t *testing.T) {
name: "all invalid",
fields: fields{
DB: nil,
DS: nil,
QConfig: nil,
CSAETHKeystore: nil,
},
wantErrContains: `nil DB
nil DataSource
nil QConfig
nil Keystore`,
},
{
name: "missing db, keystore",
name: "missing db, ds, keystore",
fields: fields{
DB: nil,
QConfig: cfg.Database(),
},
wantErrContains: `nil DB
nil DataSource
nil Keystore`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := evm.RelayerOpts{
DB: tt.fields.DB,
DS: tt.fields.DS,
QConfig: tt.fields.QConfig,
CSAETHKeystore: tt.fields.CSAETHKeystore,
}
Expand Down
13 changes: 6 additions & 7 deletions core/services/relay/evm/median.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@ import (

"github.com/ethereum/go-ethereum/accounts/abi/bind"
"github.com/ethereum/go-ethereum/common"
"github.com/jmoiron/sqlx"
"github.com/pkg/errors"
"github.com/smartcontractkit/libocr/gethwrappers2/ocr2aggregator"
"github.com/smartcontractkit/libocr/offchainreporting2/reportingplugin/median"
"github.com/smartcontractkit/libocr/offchainreporting2plus/types"
ocrtypes "github.com/smartcontractkit/libocr/offchainreporting2plus/types"

"github.com/smartcontractkit/chainlink-common/pkg/services"
"github.com/smartcontractkit/chainlink-common/pkg/sqlutil"
"github.com/smartcontractkit/chainlink/v2/core/chains/legacyevm"
offchain_aggregator_wrapper "github.com/smartcontractkit/chainlink/v2/core/internal/gethwrappers2/generated/offchainaggregator"
"github.com/smartcontractkit/chainlink/v2/core/logger"
Expand All @@ -30,7 +30,7 @@ type medianContract struct {
requestRoundTracker *RequestRoundTracker
}

func newMedianContract(configTracker types.ContractConfigTracker, contractAddress common.Address, chain legacyevm.Chain, specID int32, db *sqlx.DB, lggr logger.Logger) (*medianContract, error) {
func newMedianContract(configTracker types.ContractConfigTracker, contractAddress common.Address, chain legacyevm.Chain, specID int32, ds sqlutil.DataSource, lggr logger.Logger) (*medianContract, error) {
lggr = lggr.Named("MedianContract")
contract, err := offchain_aggregator_wrapper.NewOffchainAggregator(contractAddress, chain.Client())
if err != nil {
Expand Down Expand Up @@ -58,16 +58,15 @@ func newMedianContract(configTracker types.ContractConfigTracker, contractAddres
chain.LogBroadcaster(),
specID,
lggr,
db,
NewRoundRequestedDB(db.DB, specID, lggr),
ds,
NewRoundRequestedDB(ds, specID, lggr),
chain.Config().EVM(),
chain.Config().Database(),
),
}, nil
}
func (oc *medianContract) Start(context.Context) error {
func (oc *medianContract) Start(ctx context.Context) error {
return oc.StartOnce("MedianContract", func() error {
return oc.requestRoundTracker.Start()
return oc.requestRoundTracker.Start(ctx)
})
}

Expand Down
53 changes: 37 additions & 16 deletions core/services/relay/evm/mocks/request_round_db.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

29 changes: 18 additions & 11 deletions core/services/relay/evm/request_round_db.go
Original file line number Diff line number Diff line change
@@ -1,43 +1,50 @@
package evm

import (
"database/sql"
"context"
"encoding/json"

"github.com/pkg/errors"
"github.com/smartcontractkit/libocr/gethwrappers2/ocr2aggregator"
ocrtypes "github.com/smartcontractkit/libocr/offchainreporting2plus/types"

"github.com/smartcontractkit/chainlink-common/pkg/sqlutil"
"github.com/smartcontractkit/chainlink/v2/core/logger"
"github.com/smartcontractkit/chainlink/v2/core/services/pg"
)

// RequestRoundDB stores requested rounds for querying by the median plugin.
type RequestRoundDB interface {
SaveLatestRoundRequested(tx pg.Queryer, rr ocr2aggregator.OCR2AggregatorRoundRequested) error
LoadLatestRoundRequested() (rr ocr2aggregator.OCR2AggregatorRoundRequested, err error)
SaveLatestRoundRequested(ctx context.Context, rr ocr2aggregator.OCR2AggregatorRoundRequested) error
LoadLatestRoundRequested(context.Context) (rr ocr2aggregator.OCR2AggregatorRoundRequested, err error)
Transact(context.Context, func(db RequestRoundDB) error) error
}

var _ RequestRoundDB = &requestRoundDB{}

//go:generate mockery --quiet --name RequestRoundDB --output ./mocks/ --case=underscore
type requestRoundDB struct {
*sql.DB
ds sqlutil.DataSource
oracleSpecID int32
lggr logger.Logger
}

// NewDB returns a new DB scoped to this oracleSpecID
func NewRoundRequestedDB(sqldb *sql.DB, oracleSpecID int32, lggr logger.Logger) *requestRoundDB {
return &requestRoundDB{sqldb, oracleSpecID, lggr}
func NewRoundRequestedDB(ds sqlutil.DataSource, oracleSpecID int32, lggr logger.Logger) *requestRoundDB {
return &requestRoundDB{ds, oracleSpecID, lggr}
}

func (d *requestRoundDB) SaveLatestRoundRequested(tx pg.Queryer, rr ocr2aggregator.OCR2AggregatorRoundRequested) error {
func (d *requestRoundDB) Transact(ctx context.Context, fn func(db RequestRoundDB) error) error {
return sqlutil.Transact(ctx, func(ds sqlutil.DataSource) RequestRoundDB {
return NewRoundRequestedDB(ds, d.oracleSpecID, d.lggr)
}, d.ds, nil, fn)
}

func (d *requestRoundDB) SaveLatestRoundRequested(ctx context.Context, rr ocr2aggregator.OCR2AggregatorRoundRequested) error {
rawLog, err := json.Marshal(rr.Raw)
if err != nil {
return errors.Wrap(err, "could not marshal log as JSON")
}
_, err = tx.Exec(`
_, err = d.ds.ExecContext(ctx, `
INSERT INTO ocr2_latest_round_requested (ocr2_oracle_spec_id, requester, config_digest, epoch, round, raw)
VALUES ($1,$2,$3,$4,$5,$6) ON CONFLICT (ocr2_oracle_spec_id) DO UPDATE SET
requester = EXCLUDED.requester,
Expand All @@ -50,9 +57,9 @@ VALUES ($1,$2,$3,$4,$5,$6) ON CONFLICT (ocr2_oracle_spec_id) DO UPDATE SET
return errors.Wrap(err, "could not save latest round requested")
}

func (d *requestRoundDB) LoadLatestRoundRequested() (ocr2aggregator.OCR2AggregatorRoundRequested, error) {
func (d *requestRoundDB) LoadLatestRoundRequested(ctx context.Context) (ocr2aggregator.OCR2AggregatorRoundRequested, error) {
rr := ocr2aggregator.OCR2AggregatorRoundRequested{}
rows, err := d.Query(`
rows, err := d.ds.QueryContext(ctx, `
SELECT requester, config_digest, epoch, round, raw
FROM ocr2_latest_round_requested
WHERE ocr2_oracle_spec_id = $1
Expand Down
18 changes: 9 additions & 9 deletions core/services/relay/evm/request_round_db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (
"github.com/smartcontractkit/chainlink/v2/core/internal/testutils/pgtest"
"github.com/smartcontractkit/chainlink/v2/core/logger"
"github.com/smartcontractkit/chainlink/v2/core/services/ocr2/testhelpers"
"github.com/smartcontractkit/chainlink/v2/core/services/pg"
"github.com/smartcontractkit/chainlink/v2/core/services/relay/evm"
)

Expand All @@ -23,8 +22,8 @@ func Test_DB_LatestRoundRequested(t *testing.T) {
require.NoError(t, err)

lggr := logger.TestLogger(t)
db := evm.NewRoundRequestedDB(sqlDB.DB, 1, lggr)
db2 := evm.NewRoundRequestedDB(sqlDB.DB, 2, lggr)
db := evm.NewRoundRequestedDB(sqlDB, 1, lggr)
db2 := evm.NewRoundRequestedDB(sqlDB, 2, lggr)

rawLog := cltest.LogFromFixture(t, "../../../testdata/jsonrpc/round_requested_log_1_1.json")

Expand All @@ -38,8 +37,8 @@ func Test_DB_LatestRoundRequested(t *testing.T) {

t.Run("saves latest round requested", func(t *testing.T) {
ctx := testutils.Context(t)
err := pg.SqlxTransaction(ctx, sqlDB, logger.TestLogger(t), func(q pg.Queryer) error {
return db.SaveLatestRoundRequested(q, rr)
err := db.Transact(ctx, func(tx evm.RequestRoundDB) error {
return tx.SaveLatestRoundRequested(ctx, rr)
})
require.NoError(t, err)

Expand All @@ -54,19 +53,20 @@ func Test_DB_LatestRoundRequested(t *testing.T) {
Raw: rawLog,
}

err = pg.SqlxTransaction(ctx, sqlDB, logger.TestLogger(t), func(q pg.Queryer) error {
return db.SaveLatestRoundRequested(q, rr)
err = db.Transact(ctx, func(tx evm.RequestRoundDB) error {
return tx.SaveLatestRoundRequested(ctx, rr)
})
require.NoError(t, err)
})

t.Run("loads latest round requested", func(t *testing.T) {
ctx := testutils.Context(t)
// There is no round for db2
lrr, err := db2.LoadLatestRoundRequested()
lrr, err := db2.LoadLatestRoundRequested(ctx)
require.NoError(t, err)
require.Equal(t, 0, int(lrr.Epoch))

lrr, err = db.LoadLatestRoundRequested()
lrr, err = db.LoadLatestRoundRequested(ctx)
require.NoError(t, err)

assert.Equal(t, rr, lrr)
Expand Down
Loading

0 comments on commit 1efb525

Please sign in to comment.