Skip to content

Commit

Permalink
Added Preparex to the DB statements
Browse files Browse the repository at this point in the history
  • Loading branch information
begmaroman committed Jul 17, 2024
1 parent d1534fc commit 9739f0d
Show file tree
Hide file tree
Showing 3 changed files with 173 additions and 97 deletions.
5 changes: 4 additions & 1 deletion cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,10 @@ func start(cliCtx *cli.Context) error {
log.Fatal(err)
}

storage := db.New(pg)
storage, err := db.New(cliCtx.Context, pg)
if err != nil {
log.Fatal(err)
}

// Load private key
pk, err := config.NewKeyFromKeystore(c.PrivateKey)
Expand Down
126 changes: 85 additions & 41 deletions db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,31 +104,69 @@ type DB interface {
// DB is the database layer of the data node
type pgDB struct {
pg *sqlx.DB

storeLastProcessedBlockStmt *sqlx.Stmt
getLastProcessedBlockStmt *sqlx.Stmt
getUnresolvedBatchKeysStmt *sqlx.Stmt
getOffChainDataStmt *sqlx.Stmt
countOffChainDataStmt *sqlx.Stmt
detectOffChainDataGapsStmt *sqlx.Stmt
}

// New instantiates a DB
func New(pg *sqlx.DB) DB {
return &pgDB{
pg: pg,
func New(ctx context.Context, pg *sqlx.DB) (DB, error) {
storeLastProcessedBlockStmt, err := pg.PreparexContext(ctx, storeLastProcessedBlockSQL)
if err != nil {
return nil, fmt.Errorf("failed to prepare the store last processed block statement: %w", err)
}

getLastProcessedBlockStmt, err := pg.PreparexContext(ctx, getLastProcessedBlockSQL)
if err != nil {
return nil, fmt.Errorf("failed to prepare the get last processed block statement: %w", err)
}

getUnresolvedBatchKeysStmt, err := pg.PreparexContext(ctx, getUnresolvedBatchKeysSQL)
if err != nil {
return nil, fmt.Errorf("failed to prepare the get unresolved batch keys statement: %w", err)
}

getOffChainDataStmt, err := pg.PreparexContext(ctx, getOffchainDataSQL)
if err != nil {
return nil, fmt.Errorf("failed to prepare the get offchain data statement: %w", err)
}

countOffChainDataStmt, err := pg.PreparexContext(ctx, countOffchainDataSQL)
if err != nil {
return nil, fmt.Errorf("failed to prepare the count offchain data statement: %w", err)
}

detectOffChainDataGapsStmt, err := pg.PreparexContext(ctx, selectOffchainDataGapsSQL)
if err != nil {
return nil, fmt.Errorf("failed to prepare the detect offchain data gaps statement: %w", err)
}

return &pgDB{
pg: pg,
storeLastProcessedBlockStmt: storeLastProcessedBlockStmt,
getLastProcessedBlockStmt: getLastProcessedBlockStmt,
getUnresolvedBatchKeysStmt: getUnresolvedBatchKeysStmt,
getOffChainDataStmt: getOffChainDataStmt,
countOffChainDataStmt: countOffChainDataStmt,
detectOffChainDataGapsStmt: detectOffChainDataGapsStmt,
}, nil
}

// StoreLastProcessedBlock stores a record of a block processed by the synchronizer for named task
func (db *pgDB) StoreLastProcessedBlock(ctx context.Context, block uint64, task string) error {
if _, err := db.pg.ExecContext(ctx, storeLastProcessedBlockSQL, task, block); err != nil {
return err
}

return nil
_, err := db.storeLastProcessedBlockStmt.ExecContext(ctx, task, block)
return err
}

// GetLastProcessedBlock returns the latest block successfully processed by the synchronizer for named task
func (db *pgDB) GetLastProcessedBlock(ctx context.Context, task string) (uint64, error) {
var (
lastBlock uint64
)
var lastBlock uint64

if err := db.pg.QueryRowContext(ctx, getLastProcessedBlockSQL, task).Scan(&lastBlock); err != nil {
if err := db.getLastProcessedBlockStmt.QueryRowContext(ctx, task).Scan(&lastBlock); err != nil {
return 0, err
}

Expand All @@ -142,12 +180,13 @@ func (db *pgDB) StoreUnresolvedBatchKeys(ctx context.Context, bks []types.BatchK
return err
}

stmt, err := tx.PreparexContext(ctx, storeUnresolvedBatchesSQL)
if err != nil {
return err
}

for _, bk := range bks {
if _, err = tx.ExecContext(
ctx, storeUnresolvedBatchesSQL,
bk.Number,
bk.Hash.Hex(),
); err != nil {
if _, err = stmt.ExecContext(ctx, bk.Number, bk.Hash.Hex()); err != nil {
if txErr := tx.Rollback(); txErr != nil {
return fmt.Errorf("%v: rollback caused by %v", txErr, err)
}
Expand All @@ -161,19 +200,21 @@ func (db *pgDB) StoreUnresolvedBatchKeys(ctx context.Context, bks []types.BatchK

// GetUnresolvedBatchKeys returns the unresolved batch keys from the database
func (db *pgDB) GetUnresolvedBatchKeys(ctx context.Context, limit uint) ([]types.BatchKey, error) {
rows, err := db.pg.QueryxContext(ctx, getUnresolvedBatchKeysSQL, limit)
rows, err := db.getUnresolvedBatchKeysStmt.QueryxContext(ctx, limit)
if err != nil {
return nil, err
}

defer rows.Close()

type row struct {
Number uint64 `db:"num"`
Hash string `db:"hash"`
}

var bks []types.BatchKey
for rows.Next() {
bk := struct {
Number uint64 `db:"num"`
Hash string `db:"hash"`
}{}
bk := row{}
if err = rows.StructScan(&bk); err != nil {
return nil, err
}
Expand All @@ -194,12 +235,13 @@ func (db *pgDB) DeleteUnresolvedBatchKeys(ctx context.Context, bks []types.Batch
return err
}

stmt, err := tx.PreparexContext(ctx, deleteUnresolvedBatchKeysSQL)
if err != nil {
return err
}

for _, bk := range bks {
if _, err = tx.ExecContext(
ctx, deleteUnresolvedBatchKeysSQL,
bk.Number,
bk.Hash.Hex(),
); err != nil {
if _, err = stmt.ExecContext(ctx, bk.Number, bk.Hash.Hex()); err != nil {
if txErr := tx.Rollback(); txErr != nil {
return fmt.Errorf("%v: rollback caused by %v", txErr, err)
}
Expand All @@ -218,13 +260,13 @@ func (db *pgDB) StoreOffChainData(ctx context.Context, od []types.OffChainData)
return err
}

stmt, err := tx.PreparexContext(ctx, storeOffChainDataSQL)
if err != nil {
return err
}

for _, d := range od {
if _, err = tx.ExecContext(
ctx, storeOffChainDataSQL,
d.Key.Hex(),
common.Bytes2Hex(d.Value),
d.BatchNum,
); err != nil {
if _, err = stmt.ExecContext(ctx, d.Key.Hex(), common.Bytes2Hex(d.Value), d.BatchNum); err != nil {
if txErr := tx.Rollback(); txErr != nil {
return fmt.Errorf("%v: rollback caused by %v", txErr, err)
}
Expand All @@ -244,7 +286,7 @@ func (db *pgDB) GetOffChainData(ctx context.Context, key common.Hash) (*types.Of
BatchNum uint64 `db:"batch_num"`
}{}

if err := db.pg.QueryRowxContext(ctx, getOffchainDataSQL, key.Hex()).StructScan(&data); err != nil {
if err := db.getOffChainDataStmt.QueryRowxContext(ctx, key.Hex()).StructScan(&data); err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, ErrStateNotSynchronized
}
Expand Down Expand Up @@ -285,13 +327,15 @@ func (db *pgDB) ListOffChainData(ctx context.Context, keys []common.Hash) ([]typ

defer rows.Close()

type row struct {
Key string `db:"key"`
Value string `db:"value"`
BatchNum uint64 `db:"batch_num"`
}

list := make([]types.OffChainData, 0, len(keys))
for rows.Next() {
data := struct {
Key string `db:"key"`
Value string `db:"value"`
BatchNum uint64 `db:"batch_num"`
}{}
data := row{}
if err = rows.StructScan(&data); err != nil {
return nil, err
}
Expand All @@ -309,7 +353,7 @@ func (db *pgDB) ListOffChainData(ctx context.Context, keys []common.Hash) ([]typ
// CountOffchainData returns the count of rows in the offchain_data table
func (db *pgDB) CountOffchainData(ctx context.Context) (uint64, error) {
var count uint64
if err := db.pg.QueryRowContext(ctx, countOffchainDataSQL).Scan(&count); err != nil {
if err := db.countOffChainDataStmt.QueryRowContext(ctx).Scan(&count); err != nil {
return 0, err
}

Expand All @@ -318,7 +362,7 @@ func (db *pgDB) CountOffchainData(ctx context.Context) (uint64, error) {

// DetectOffchainDataGaps returns the number of gaps in the offchain_data table
func (db *pgDB) DetectOffchainDataGaps(ctx context.Context) (map[uint64]uint64, error) {
rows, err := db.pg.QueryxContext(ctx, selectOffchainDataGapsSQL)
rows, err := db.detectOffChainDataGapsStmt.QueryxContext(ctx)
if err != nil {
return nil, err
}
Expand Down
Loading

0 comments on commit 9739f0d

Please sign in to comment.