Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Removed unnecessary DB transactions use #99

Merged
merged 8 commits into from
Jun 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 50 additions & 54 deletions db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ package db
import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"fmt"

"github.com/0xPolygon/cdk-data-availability/types"
"github.com/ethereum/go-ethereum/common"
Expand All @@ -18,30 +18,21 @@ var (

// DB defines functions that a DB instance should implement
type DB interface {
BeginStateTransaction(ctx context.Context) (Tx, error)

StoreLastProcessedBlock(ctx context.Context, task string, block uint64, dbTx sqlx.ExecerContext) error
StoreLastProcessedBlock(ctx context.Context, task string, block uint64) error
GetLastProcessedBlock(ctx context.Context, task string) (uint64, error)

StoreUnresolvedBatchKeys(ctx context.Context, bks []types.BatchKey, dbTx sqlx.ExecerContext) error
StoreUnresolvedBatchKeys(ctx context.Context, bks []types.BatchKey) error
GetUnresolvedBatchKeys(ctx context.Context, limit uint) ([]types.BatchKey, error)
DeleteUnresolvedBatchKeys(ctx context.Context, bks []types.BatchKey, dbTx sqlx.ExecerContext) error
DeleteUnresolvedBatchKeys(ctx context.Context, bks []types.BatchKey) error

Exists(ctx context.Context, key common.Hash) bool
GetOffChainData(ctx context.Context, key common.Hash, dbTx sqlx.QueryerContext) (types.ArgBytes, error)
ListOffChainData(ctx context.Context, keys []common.Hash, dbTx sqlx.QueryerContext) (map[common.Hash]types.ArgBytes, error)
StoreOffChainData(ctx context.Context, od []types.OffChainData, dbTx sqlx.ExecerContext) error
GetOffChainData(ctx context.Context, key common.Hash) (types.ArgBytes, error)
ListOffChainData(ctx context.Context, keys []common.Hash) (map[common.Hash]types.ArgBytes, error)
StoreOffChainData(ctx context.Context, od []types.OffChainData) error

CountOffchainData(ctx context.Context) (uint64, error)
}

// Tx is the interface that defines functions a db tx has to implement
type Tx interface {
sqlx.ExecerContext
sqlx.QueryerContext
driver.Tx
}

// DB is the database layer of the data node
type pgDB struct {
pg *sqlx.DB
Expand All @@ -54,21 +45,16 @@ func New(pg *sqlx.DB) DB {
}
}

// BeginStateTransaction begins a DB transaction. The caller is responsible for committing or rolling back the transaction
func (db *pgDB) BeginStateTransaction(ctx context.Context) (Tx, error) {
return db.pg.BeginTxx(ctx, nil)
}

// StoreLastProcessedBlock stores a record of a block processed by the synchronizer for named task
func (db *pgDB) StoreLastProcessedBlock(ctx context.Context, task string, block uint64, dbTx sqlx.ExecerContext) error {
func (db *pgDB) StoreLastProcessedBlock(ctx context.Context, task string, block uint64) error {
const storeLastProcessedBlockSQL = `
INSERT INTO data_node.sync_tasks (task, block)
VALUES ($1, $2)
ON CONFLICT (task) DO UPDATE
SET block = EXCLUDED.block, processed = NOW();
`

if _, err := db.execer(dbTx).ExecContext(ctx, storeLastProcessedBlockSQL, task, block); err != nil {
if _, err := db.pg.ExecContext(ctx, storeLastProcessedBlockSQL, task, block); err != nil {
return err
}

Expand All @@ -91,25 +77,33 @@ func (db *pgDB) GetLastProcessedBlock(ctx context.Context, task string) (uint64,
}

// StoreUnresolvedBatchKeys stores unresolved batch keys in the database
func (db *pgDB) StoreUnresolvedBatchKeys(ctx context.Context, bks []types.BatchKey, dbTx sqlx.ExecerContext) error {
func (db *pgDB) StoreUnresolvedBatchKeys(ctx context.Context, bks []types.BatchKey) error {
const storeUnresolvedBatchesSQL = `
INSERT INTO data_node.unresolved_batches (num, hash)
VALUES ($1, $2)
ON CONFLICT (num, hash) DO NOTHING;
`

execer := db.execer(dbTx)
tx, err := db.pg.BeginTxx(ctx, nil)
if err != nil {
return err
}

for _, bk := range bks {
if _, err := execer.ExecContext(
if _, err = tx.ExecContext(
ctx, storeUnresolvedBatchesSQL,
bk.Number,
bk.Hash.Hex(),
); err != nil {
if txErr := tx.Rollback(); txErr != nil {
return fmt.Errorf("%v: rollback caused by %v", txErr, err)
}

return err
}
}

return nil
return tx.Commit()
}

// GetUnresolvedBatchKeys returns the unresolved batch keys from the database
Expand Down Expand Up @@ -143,23 +137,32 @@ func (db *pgDB) GetUnresolvedBatchKeys(ctx context.Context, limit uint) ([]types
}

// DeleteUnresolvedBatchKeys deletes the unresolved batch keys from the database
func (db *pgDB) DeleteUnresolvedBatchKeys(ctx context.Context, bks []types.BatchKey, dbTx sqlx.ExecerContext) error {
func (db *pgDB) DeleteUnresolvedBatchKeys(ctx context.Context, bks []types.BatchKey) error {
const deleteUnresolvedBatchKeysSQL = `
DELETE FROM data_node.unresolved_batches
WHERE num = $1 AND hash = $2;
`

tx, err := db.pg.BeginTxx(ctx, nil)
if err != nil {
return err
}

for _, bk := range bks {
if _, err := db.execer(dbTx).ExecContext(
if _, err = tx.ExecContext(
ctx, deleteUnresolvedBatchKeysSQL,
bk.Number,
bk.Hash.Hex(),
); err != nil {
if txErr := tx.Rollback(); txErr != nil {
return fmt.Errorf("%v: rollback caused by %v", txErr, err)
}

return err
}
}

return nil
return tx.Commit()
}

// Exists checks if a key exists in offchain data table
Expand All @@ -178,29 +181,37 @@ func (db *pgDB) Exists(ctx context.Context, key common.Hash) bool {
}

// StoreOffChainData stores and array of key values in the Db
func (db *pgDB) StoreOffChainData(ctx context.Context, od []types.OffChainData, dbTx sqlx.ExecerContext) error {
func (db *pgDB) StoreOffChainData(ctx context.Context, od []types.OffChainData) error {
const storeOffChainDataSQL = `
INSERT INTO data_node.offchain_data (key, value)
VALUES ($1, $2)
ON CONFLICT (key) DO NOTHING;
`

execer := db.execer(dbTx)
tx, err := db.pg.BeginTxx(ctx, nil)
if err != nil {
return err
}

for _, d := range od {
if _, err := execer.ExecContext(
if _, err = tx.ExecContext(
ctx, storeOffChainDataSQL,
d.Key.Hex(),
common.Bytes2Hex(d.Value),
); err != nil {
if txErr := tx.Rollback(); txErr != nil {
return fmt.Errorf("%v: rollback caused by %v", txErr, err)
}

return err
}
}

return nil
return tx.Commit()
}

// GetOffChainData returns the value identified by the key
func (db *pgDB) GetOffChainData(ctx context.Context, key common.Hash, dbTx sqlx.QueryerContext) (types.ArgBytes, error) {
func (db *pgDB) GetOffChainData(ctx context.Context, key common.Hash) (types.ArgBytes, error) {
const getOffchainDataSQL = `
SELECT value
FROM data_node.offchain_data
Expand All @@ -211,18 +222,19 @@ func (db *pgDB) GetOffChainData(ctx context.Context, key common.Hash, dbTx sqlx.
hexValue string
)

if err := db.querier(dbTx).QueryRowxContext(ctx, getOffchainDataSQL, key.Hex()).Scan(&hexValue); err != nil {
if err := db.pg.QueryRowxContext(ctx, getOffchainDataSQL, key.Hex()).Scan(&hexValue); err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, ErrStateNotSynchronized
}

return nil, err
}

return common.FromHex(hexValue), nil
}

// ListOffChainData returns values identified by the given keys
func (db *pgDB) ListOffChainData(ctx context.Context, keys []common.Hash, dbTx sqlx.QueryerContext) (map[common.Hash]types.ArgBytes, error) {
func (db *pgDB) ListOffChainData(ctx context.Context, keys []common.Hash) (map[common.Hash]types.ArgBytes, error) {
if len(keys) == 0 {
return nil, nil
}
Expand All @@ -246,7 +258,7 @@ func (db *pgDB) ListOffChainData(ctx context.Context, keys []common.Hash, dbTx s
// sqlx.In returns queries with the `?` bindvar, we can rebind it for our backend
query = db.pg.Rebind(query)

rows, err := db.querier(dbTx).QueryxContext(ctx, query, args...)
rows, err := db.pg.QueryxContext(ctx, query, args...)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -280,19 +292,3 @@ func (db *pgDB) CountOffchainData(ctx context.Context) (uint64, error) {

return count, nil
}

func (db *pgDB) execer(dbTx sqlx.ExecerContext) sqlx.ExecerContext {
if dbTx != nil {
return dbTx
}

return db.pg
}

func (db *pgDB) querier(dbTx sqlx.QueryerContext) sqlx.QueryerContext {
if dbTx != nil {
return dbTx
}

return db.pg
}
48 changes: 27 additions & 21 deletions db/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func Test_DB_StoreLastProcessedBlock(t *testing.T) {

dbPG := New(wdb)

err = dbPG.StoreLastProcessedBlock(context.Background(), tt.task, tt.block, wdb)
err = dbPG.StoreLastProcessedBlock(context.Background(), tt.task, tt.block)
if tt.returnErr != nil {
require.ErrorIs(t, err, tt.returnErr)
} else {
Expand Down Expand Up @@ -116,7 +116,7 @@ func Test_DB_GetLastProcessedBlock(t *testing.T) {

dbPG := New(wdb)

err = dbPG.StoreLastProcessedBlock(context.Background(), tt.task, tt.block, wdb)
err = dbPG.StoreLastProcessedBlock(context.Background(), tt.task, tt.block)
require.NoError(t, err)

actual, err := dbPG.GetLastProcessedBlock(context.Background(), tt.task)
Expand Down Expand Up @@ -179,6 +179,7 @@ func Test_DB_StoreUnresolvedBatchKeys(t *testing.T) {

defer db.Close()

mock.ExpectBegin()
for _, o := range tt.bk {
expected := mock.ExpectExec(`INSERT INTO data_node\.unresolved_batches \(num, hash\) VALUES \(\$1, \$2\) ON CONFLICT \(num, hash\) DO NOTHING`).
WithArgs(o.Number, o.Hash.Hex())
Expand All @@ -188,12 +189,17 @@ func Test_DB_StoreUnresolvedBatchKeys(t *testing.T) {
expected.WillReturnResult(sqlmock.NewResult(int64(len(tt.bk)), int64(len(tt.bk))))
}
}
if tt.returnErr == nil {
mock.ExpectCommit()
} else {
mock.ExpectRollback()
}

wdb := sqlx.NewDb(db, "postgres")

dbPG := New(wdb)

err = dbPG.StoreUnresolvedBatchKeys(context.Background(), tt.bk, wdb)
err = dbPG.StoreUnresolvedBatchKeys(context.Background(), tt.bk)
if tt.returnErr != nil {
require.ErrorIs(t, err, tt.returnErr)
} else {
Expand Down Expand Up @@ -304,6 +310,7 @@ func Test_DB_DeleteUnresolvedBatchKeys(t *testing.T) {

defer db.Close()

mock.ExpectBegin()
for _, bk := range tt.bks {
expected := mock.ExpectExec(`DELETE FROM data_node\.unresolved_batches WHERE num = \$1 AND hash = \$2`).
WithArgs(bk.Number, bk.Hash.Hex())
Expand All @@ -313,12 +320,17 @@ func Test_DB_DeleteUnresolvedBatchKeys(t *testing.T) {
expected.WillReturnResult(sqlmock.NewResult(int64(len(tt.bks)), int64(len(tt.bks))))
}
}
if tt.returnErr != nil {
mock.ExpectRollback()
} else {
mock.ExpectCommit()
}

wdb := sqlx.NewDb(db, "postgres")

dbPG := New(wdb)

err = dbPG.DeleteUnresolvedBatchKeys(context.Background(), tt.bks, wdb)
err = dbPG.DeleteUnresolvedBatchKeys(context.Background(), tt.bks)
if tt.returnErr != nil {
require.ErrorIs(t, err, tt.returnErr)
} else {
Expand Down Expand Up @@ -377,6 +389,7 @@ func Test_DB_StoreOffChainData(t *testing.T) {

defer db.Close()

mock.ExpectBegin()
for _, o := range tt.od {
expected := mock.ExpectExec(`INSERT INTO data_node\.offchain_data \(key, value\) VALUES \(\$1, \$2\) ON CONFLICT \(key\) DO NOTHING`).
WithArgs(o.Key.Hex(), common.Bytes2Hex(o.Value))
Expand All @@ -386,12 +399,17 @@ func Test_DB_StoreOffChainData(t *testing.T) {
expected.WillReturnResult(sqlmock.NewResult(int64(len(tt.od)), int64(len(tt.od))))
}
}
if tt.returnErr == nil {
mock.ExpectCommit()
} else {
mock.ExpectRollback()
}

wdb := sqlx.NewDb(db, "postgres")

dbPG := New(wdb)

err = dbPG.StoreOffChainData(context.Background(), tt.od, wdb)
err = dbPG.StoreOffChainData(context.Background(), tt.od)
if tt.returnErr != nil {
require.ErrorIs(t, err, tt.returnErr)
} else {
Expand Down Expand Up @@ -467,7 +485,7 @@ func Test_DB_GetOffChainData(t *testing.T) {

dbPG := New(wdb)

data, err := dbPG.GetOffChainData(context.Background(), tt.key, wdb)
data, err := dbPG.GetOffChainData(context.Background(), tt.key)
if tt.returnErr != nil {
require.ErrorIs(t, err, tt.returnErr)
} else {
Expand Down Expand Up @@ -586,7 +604,7 @@ func Test_DB_ListOffChainData(t *testing.T) {

dbPG := New(wdb)

data, err := dbPG.ListOffChainData(context.Background(), tt.keys, wdb)
data, err := dbPG.ListOffChainData(context.Background(), tt.keys)
if tt.returnErr != nil {
require.ErrorIs(t, err, tt.returnErr)
} else {
Expand Down Expand Up @@ -760,13 +778,7 @@ func seedOffchainData(t *testing.T, db *sqlx.DB, mock sqlmock.Sqlmock, od []type
}
mock.ExpectCommit()

tx, err := db.BeginTxx(context.Background(), nil)
require.NoError(t, err)

err = New(db).StoreOffChainData(context.Background(), od, tx)
require.NoError(t, err)

err = tx.Commit()
err := New(db).StoreOffChainData(context.Background(), od)
require.NoError(t, err)
}

Expand All @@ -781,12 +793,6 @@ func seedUnresolvedBatchKeys(t *testing.T, db *sqlx.DB, mock sqlmock.Sqlmock, bk
}
mock.ExpectCommit()

tx, err := db.BeginTxx(context.Background(), nil)
require.NoError(t, err)

err = New(db).StoreUnresolvedBatchKeys(context.Background(), bk, tx)
require.NoError(t, err)

err = tx.Commit()
err := New(db).StoreUnresolvedBatchKeys(context.Background(), bk)
require.NoError(t, err)
}
Loading
Loading