From 9739f0d4d973392e92470ec9bcbfb70fb89632f2 Mon Sep 17 00:00:00 2001 From: begmaroman Date: Wed, 17 Jul 2024 08:53:57 +0100 Subject: [PATCH] Added Preparex to the DB statements --- cmd/main.go | 5 +- db/db.go | 126 ++++++++++++++++++++++++++++++--------------- db/db_test.go | 139 ++++++++++++++++++++++++++++++-------------------- 3 files changed, 173 insertions(+), 97 deletions(-) diff --git a/cmd/main.go b/cmd/main.go index 0923e21..b18bf6c 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -87,7 +87,10 @@ func start(cliCtx *cli.Context) error { log.Fatal(err) } - storage := db.New(pg) + storage, err := db.New(cliCtx.Context, pg) + if err != nil { + log.Fatal(err) + } // Load private key pk, err := config.NewKeyFromKeystore(c.PrivateKey) diff --git a/db/db.go b/db/db.go index 683792f..2d3139e 100644 --- a/db/db.go +++ b/db/db.go @@ -104,31 +104,69 @@ type DB interface { // DB is the database layer of the data node type pgDB struct { pg *sqlx.DB + + storeLastProcessedBlockStmt *sqlx.Stmt + getLastProcessedBlockStmt *sqlx.Stmt + getUnresolvedBatchKeysStmt *sqlx.Stmt + getOffChainDataStmt *sqlx.Stmt + countOffChainDataStmt *sqlx.Stmt + detectOffChainDataGapsStmt *sqlx.Stmt } // New instantiates a DB -func New(pg *sqlx.DB) DB { - return &pgDB{ - pg: pg, +func New(ctx context.Context, pg *sqlx.DB) (DB, error) { + storeLastProcessedBlockStmt, err := pg.PreparexContext(ctx, storeLastProcessedBlockSQL) + if err != nil { + return nil, fmt.Errorf("failed to prepare the store last processed block statement: %w", err) + } + + getLastProcessedBlockStmt, err := pg.PreparexContext(ctx, getLastProcessedBlockSQL) + if err != nil { + return nil, fmt.Errorf("failed to prepare the get last processed block statement: %w", err) + } + + getUnresolvedBatchKeysStmt, err := pg.PreparexContext(ctx, getUnresolvedBatchKeysSQL) + if err != nil { + return nil, fmt.Errorf("failed to prepare the get unresolved batch keys statement: %w", err) + } + + getOffChainDataStmt, err := pg.PreparexContext(ctx, getOffchainDataSQL) + if err != nil { + return nil, fmt.Errorf("failed to prepare the get offchain data statement: %w", err) } + + countOffChainDataStmt, err := pg.PreparexContext(ctx, countOffchainDataSQL) + if err != nil { + return nil, fmt.Errorf("failed to prepare the count offchain data statement: %w", err) + } + + detectOffChainDataGapsStmt, err := pg.PreparexContext(ctx, selectOffchainDataGapsSQL) + if err != nil { + return nil, fmt.Errorf("failed to prepare the detect offchain data gaps statement: %w", err) + } + + return &pgDB{ + pg: pg, + storeLastProcessedBlockStmt: storeLastProcessedBlockStmt, + getLastProcessedBlockStmt: getLastProcessedBlockStmt, + getUnresolvedBatchKeysStmt: getUnresolvedBatchKeysStmt, + getOffChainDataStmt: getOffChainDataStmt, + countOffChainDataStmt: countOffChainDataStmt, + detectOffChainDataGapsStmt: detectOffChainDataGapsStmt, + }, nil } // StoreLastProcessedBlock stores a record of a block processed by the synchronizer for named task func (db *pgDB) StoreLastProcessedBlock(ctx context.Context, block uint64, task string) error { - if _, err := db.pg.ExecContext(ctx, storeLastProcessedBlockSQL, task, block); err != nil { - return err - } - - return nil + _, err := db.storeLastProcessedBlockStmt.ExecContext(ctx, task, block) + return err } // GetLastProcessedBlock returns the latest block successfully processed by the synchronizer for named task func (db *pgDB) GetLastProcessedBlock(ctx context.Context, task string) (uint64, error) { - var ( - lastBlock uint64 - ) + var lastBlock uint64 - if err := db.pg.QueryRowContext(ctx, getLastProcessedBlockSQL, task).Scan(&lastBlock); err != nil { + if err := db.getLastProcessedBlockStmt.QueryRowContext(ctx, task).Scan(&lastBlock); err != nil { return 0, err } @@ -142,12 +180,13 @@ func (db *pgDB) StoreUnresolvedBatchKeys(ctx context.Context, bks []types.BatchK return err } + stmt, err := tx.PreparexContext(ctx, storeUnresolvedBatchesSQL) + if err != nil { + return err + } + for _, bk := range bks { - if _, err = tx.ExecContext( - ctx, storeUnresolvedBatchesSQL, - bk.Number, - bk.Hash.Hex(), - ); err != nil { + if _, err = stmt.ExecContext(ctx, bk.Number, bk.Hash.Hex()); err != nil { if txErr := tx.Rollback(); txErr != nil { return fmt.Errorf("%v: rollback caused by %v", txErr, err) } @@ -161,19 +200,21 @@ func (db *pgDB) StoreUnresolvedBatchKeys(ctx context.Context, bks []types.BatchK // GetUnresolvedBatchKeys returns the unresolved batch keys from the database func (db *pgDB) GetUnresolvedBatchKeys(ctx context.Context, limit uint) ([]types.BatchKey, error) { - rows, err := db.pg.QueryxContext(ctx, getUnresolvedBatchKeysSQL, limit) + rows, err := db.getUnresolvedBatchKeysStmt.QueryxContext(ctx, limit) if err != nil { return nil, err } defer rows.Close() + type row struct { + Number uint64 `db:"num"` + Hash string `db:"hash"` + } + var bks []types.BatchKey for rows.Next() { - bk := struct { - Number uint64 `db:"num"` - Hash string `db:"hash"` - }{} + bk := row{} if err = rows.StructScan(&bk); err != nil { return nil, err } @@ -194,12 +235,13 @@ func (db *pgDB) DeleteUnresolvedBatchKeys(ctx context.Context, bks []types.Batch return err } + stmt, err := tx.PreparexContext(ctx, deleteUnresolvedBatchKeysSQL) + if err != nil { + return err + } + for _, bk := range bks { - if _, err = tx.ExecContext( - ctx, deleteUnresolvedBatchKeysSQL, - bk.Number, - bk.Hash.Hex(), - ); err != nil { + if _, err = stmt.ExecContext(ctx, bk.Number, bk.Hash.Hex()); err != nil { if txErr := tx.Rollback(); txErr != nil { return fmt.Errorf("%v: rollback caused by %v", txErr, err) } @@ -218,13 +260,13 @@ func (db *pgDB) StoreOffChainData(ctx context.Context, od []types.OffChainData) return err } + stmt, err := tx.PreparexContext(ctx, storeOffChainDataSQL) + if err != nil { + return err + } + for _, d := range od { - if _, err = tx.ExecContext( - ctx, storeOffChainDataSQL, - d.Key.Hex(), - common.Bytes2Hex(d.Value), - d.BatchNum, - ); err != nil { + if _, err = stmt.ExecContext(ctx, d.Key.Hex(), common.Bytes2Hex(d.Value), d.BatchNum); err != nil { if txErr := tx.Rollback(); txErr != nil { return fmt.Errorf("%v: rollback caused by %v", txErr, err) } @@ -244,7 +286,7 @@ func (db *pgDB) GetOffChainData(ctx context.Context, key common.Hash) (*types.Of BatchNum uint64 `db:"batch_num"` }{} - if err := db.pg.QueryRowxContext(ctx, getOffchainDataSQL, key.Hex()).StructScan(&data); err != nil { + if err := db.getOffChainDataStmt.QueryRowxContext(ctx, key.Hex()).StructScan(&data); err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, ErrStateNotSynchronized } @@ -285,13 +327,15 @@ func (db *pgDB) ListOffChainData(ctx context.Context, keys []common.Hash) ([]typ defer rows.Close() + type row struct { + Key string `db:"key"` + Value string `db:"value"` + BatchNum uint64 `db:"batch_num"` + } + list := make([]types.OffChainData, 0, len(keys)) for rows.Next() { - data := struct { - Key string `db:"key"` - Value string `db:"value"` - BatchNum uint64 `db:"batch_num"` - }{} + data := row{} if err = rows.StructScan(&data); err != nil { return nil, err } @@ -309,7 +353,7 @@ func (db *pgDB) ListOffChainData(ctx context.Context, keys []common.Hash) ([]typ // CountOffchainData returns the count of rows in the offchain_data table func (db *pgDB) CountOffchainData(ctx context.Context) (uint64, error) { var count uint64 - if err := db.pg.QueryRowContext(ctx, countOffchainDataSQL).Scan(&count); err != nil { + if err := db.countOffChainDataStmt.QueryRowContext(ctx).Scan(&count); err != nil { return 0, err } @@ -318,7 +362,7 @@ func (db *pgDB) CountOffchainData(ctx context.Context) (uint64, error) { // DetectOffchainDataGaps returns the number of gaps in the offchain_data table func (db *pgDB) DetectOffchainDataGaps(ctx context.Context) (map[uint64]uint64, error) { - rows, err := db.pg.QueryxContext(ctx, selectOffchainDataGapsSQL) + rows, err := db.detectOffChainDataGapsStmt.QueryxContext(ctx) if err != nil { return nil, err } diff --git a/db/db_test.go b/db/db_test.go index 6a18698..3881041 100644 --- a/db/db_test.go +++ b/db/db_test.go @@ -4,6 +4,7 @@ import ( "context" "database/sql/driver" "errors" + "regexp" "testing" "github.com/0xPolygon/cdk-data-availability/types" @@ -46,6 +47,8 @@ func Test_DB_StoreLastProcessedBlock(t *testing.T) { defer db.Close() + constructorExpect(mock) + expected := mock.ExpectExec(`INSERT INTO data_node\.sync_tasks \(task, block\) VALUES \(\$1, \$2\) ON CONFLICT \(task\) DO UPDATE SET block = EXCLUDED\.block, processed = NOW\(\)`). WithArgs(tt.task, tt.block) if tt.returnErr != nil { @@ -56,7 +59,8 @@ func Test_DB_StoreLastProcessedBlock(t *testing.T) { wdb := sqlx.NewDb(db, "postgres") - dbPG := New(wdb) + dbPG, err := New(context.Background(), wdb) + require.NoError(t, err) err = dbPG.StoreLastProcessedBlock(context.Background(), tt.block, tt.task) if tt.returnErr != nil { @@ -103,6 +107,8 @@ func Test_DB_GetLastProcessedBlock(t *testing.T) { defer db.Close() + constructorExpect(mock) + mock.ExpectExec(`INSERT INTO data_node\.sync_tasks \(task, block\) VALUES \(\$1, \$2\) ON CONFLICT \(task\) DO UPDATE SET block = EXCLUDED\.block, processed = NOW\(\)`). WithArgs(tt.task, tt.block). WillReturnResult(sqlmock.NewResult(1, 1)) @@ -118,7 +124,8 @@ func Test_DB_GetLastProcessedBlock(t *testing.T) { wdb := sqlx.NewDb(db, "postgres") - dbPG := New(wdb) + dbPG, err := New(context.Background(), wdb) + require.NoError(t, err) err = dbPG.StoreLastProcessedBlock(context.Background(), tt.block, tt.task) require.NoError(t, err) @@ -183,11 +190,24 @@ func Test_DB_StoreUnresolvedBatchKeys(t *testing.T) { db, mock, err := sqlmock.New() require.NoError(t, err) + wdb := sqlx.NewDb(db, "postgres") + + mock.ExpectPrepare(regexp.QuoteMeta(storeLastProcessedBlockSQL)) + mock.ExpectPrepare(regexp.QuoteMeta(getLastProcessedBlockSQL)) + mock.ExpectPrepare(regexp.QuoteMeta(getUnresolvedBatchKeysSQL)) + mock.ExpectPrepare(regexp.QuoteMeta(getOffchainDataSQL)) + mock.ExpectPrepare(regexp.QuoteMeta(countOffchainDataSQL)) + mock.ExpectPrepare(regexp.QuoteMeta(selectOffchainDataGapsSQL)) + + dbPG, err := New(context.Background(), wdb) + require.NoError(t, err) + defer db.Close() mock.ExpectBegin() + mock.ExpectPrepare(regexp.QuoteMeta(storeUnresolvedBatchesSQL)) for _, o := range tt.bk { - expected := mock.ExpectExec(`INSERT INTO data_node\.unresolved_batches \(num, hash\) VALUES \(\$1, \$2\) ON CONFLICT \(num, hash\) DO NOTHING`). + expected := mock.ExpectExec(regexp.QuoteMeta(storeUnresolvedBatchesSQL)). WithArgs(o.Number, o.Hash.Hex()) if tt.returnErr != nil { expected.WillReturnError(tt.returnErr) @@ -201,10 +221,6 @@ func Test_DB_StoreUnresolvedBatchKeys(t *testing.T) { mock.ExpectRollback() } - wdb := sqlx.NewDb(db, "postgres") - - dbPG := New(wdb) - err = dbPG.StoreUnresolvedBatchKeys(context.Background(), tt.bk) if tt.returnErr != nil { require.ErrorIs(t, err, tt.returnErr) @@ -253,10 +269,14 @@ func Test_DB_GetUnresolvedBatchKeys(t *testing.T) { defer db.Close() + constructorExpect(mock) + wdb := sqlx.NewDb(db, "postgres") + dbPG, err := New(context.Background(), wdb) + require.NoError(t, err) // Seed data - seedUnresolvedBatchKeys(t, wdb, mock, tt.bks) + seedUnresolvedBatchKeys(t, dbPG, mock, tt.bks) var limit = uint(10) expected := mock.ExpectQuery(`SELECT num, hash FROM data_node\.unresolved_batches LIMIT \$1\;`).WithArgs(limit) @@ -269,8 +289,6 @@ func Test_DB_GetUnresolvedBatchKeys(t *testing.T) { } } - dbPG := New(wdb) - data, err := dbPG.GetUnresolvedBatchKeys(context.Background(), limit) if tt.returnErr != nil { require.ErrorIs(t, err, tt.returnErr) @@ -320,9 +338,12 @@ func Test_DB_DeleteUnresolvedBatchKeys(t *testing.T) { defer db.Close() + constructorExpect(mock) + mock.ExpectBegin() + mock.ExpectPrepare(regexp.QuoteMeta(deleteUnresolvedBatchKeysSQL)) for _, bk := range tt.bks { - expected := mock.ExpectExec(`DELETE FROM data_node\.unresolved_batches WHERE num = \$1 AND hash = \$2`). + expected := mock.ExpectExec(regexp.QuoteMeta(deleteUnresolvedBatchKeysSQL)). WithArgs(bk.Number, bk.Hash.Hex()) if tt.returnErr != nil { expected.WillReturnError(tt.returnErr) @@ -338,7 +359,8 @@ func Test_DB_DeleteUnresolvedBatchKeys(t *testing.T) { wdb := sqlx.NewDb(db, "postgres") - dbPG := New(wdb) + dbPG, err := New(context.Background(), wdb) + require.NoError(t, err) err = dbPG.DeleteUnresolvedBatchKeys(context.Background(), tt.bks) if tt.returnErr != nil { @@ -399,11 +421,18 @@ func Test_DB_StoreOffChainData(t *testing.T) { db, mock, err := sqlmock.New() require.NoError(t, err) + constructorExpect(mock) + + wdb := sqlx.NewDb(db, "postgres") + dbPG, err := New(context.Background(), wdb) + require.NoError(t, err) + defer db.Close() mock.ExpectBegin() + mock.ExpectPrepare(regexp.QuoteMeta(storeOffChainDataSQL)) for _, o := range tt.od { - expected := mock.ExpectExec(`INSERT INTO data_node\.offchain_data \(key, value, batch_num\) VALUES \(\$1, \$2, \$3\) ON CONFLICT \(key\) DO UPDATE SET value = EXCLUDED\.value, batch_num = EXCLUDED\.batch_num`). + expected := mock.ExpectExec(regexp.QuoteMeta(storeOffChainDataSQL)). WithArgs(o.Key.Hex(), common.Bytes2Hex(o.Value), o.BatchNum) if tt.returnErr != nil { expected.WillReturnError(tt.returnErr) @@ -417,10 +446,6 @@ func Test_DB_StoreOffChainData(t *testing.T) { mock.ExpectRollback() } - wdb := sqlx.NewDb(db, "postgres") - - dbPG := New(wdb) - err = dbPG.StoreOffChainData(context.Background(), tt.od) if tt.returnErr != nil { require.ErrorIs(t, err, tt.returnErr) @@ -486,14 +511,18 @@ func Test_DB_GetOffChainData(t *testing.T) { db, mock, err := sqlmock.New() require.NoError(t, err) - defer db.Close() + constructorExpect(mock) wdb := sqlx.NewDb(db, "postgres") + dbPG, err := New(context.Background(), wdb) + require.NoError(t, err) + + defer db.Close() // Seed data - seedOffchainData(t, wdb, mock, tt.od) + seedOffchainData(t, dbPG, mock, tt.od) - expected := mock.ExpectQuery(`SELECT key, value, batch_num FROM data_node\.offchain_data WHERE key = \$1 LIMIT 1`). + expected := mock.ExpectQuery(regexp.QuoteMeta(getOffchainDataSQL)). WithArgs(tt.key.Hex()) if tt.returnErr != nil { @@ -503,8 +532,6 @@ func Test_DB_GetOffChainData(t *testing.T) { AddRow(tt.expected.Key.Hex(), common.Bytes2Hex(tt.expected.Value), tt.expected.BatchNum)) } - dbPG := New(wdb) - data, err := dbPG.GetOffChainData(context.Background(), tt.key) if tt.returnErr != nil { require.ErrorIs(t, err, tt.returnErr) @@ -613,10 +640,14 @@ func Test_DB_ListOffChainData(t *testing.T) { defer db.Close() + constructorExpect(mock) + wdb := sqlx.NewDb(db, "postgres") + dbPG, err := New(context.Background(), wdb) + require.NoError(t, err) // Seed data - seedOffchainData(t, wdb, mock, tt.od) + seedOffchainData(t, dbPG, mock, tt.od) preparedKeys := make([]driver.Value, len(tt.keys)) for i, key := range tt.keys { @@ -638,8 +669,6 @@ func Test_DB_ListOffChainData(t *testing.T) { expected.WillReturnRows(returnData) } - dbPG := New(wdb) - data, err := dbPG.ListOffChainData(context.Background(), tt.keys) if tt.returnErr != nil { require.ErrorIs(t, err, tt.returnErr) @@ -698,12 +727,16 @@ func Test_DB_CountOffchainData(t *testing.T) { defer db.Close() + constructorExpect(mock) + wdb := sqlx.NewDb(db, "postgres") + dbPG, err := New(context.Background(), wdb) + require.NoError(t, err) // Seed data - seedOffchainData(t, wdb, mock, tt.od) + seedOffchainData(t, dbPG, mock, tt.od) - expected := mock.ExpectQuery(`SELECT COUNT\(\*\) FROM data_node\.offchain_data`) + expected := mock.ExpectQuery(regexp.QuoteMeta(countOffchainDataSQL)) if tt.returnErr != nil { expected.WillReturnError(tt.returnErr) @@ -711,8 +744,6 @@ func Test_DB_CountOffchainData(t *testing.T) { expected.WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(tt.count)) } - dbPG := New(wdb) - actual, err := dbPG.CountOffchainData(context.Background()) if tt.returnErr != nil { require.ErrorIs(t, err, tt.returnErr) @@ -775,27 +806,16 @@ func Test_DB_DetectOffchainDataGaps(t *testing.T) { defer db.Close() + constructorExpect(mock) + wdb := sqlx.NewDb(db, "postgres") + dbPG, err := New(context.Background(), wdb) + require.NoError(t, err) // Seed data - seedOffchainData(t, wdb, mock, tt.seed) - - expected := mock.ExpectQuery(`WITH numbered_batches AS \( - SELECT - batch_num, - ROW_NUMBER\(\) OVER \(ORDER BY batch_num\) AS row_number - FROM data_node\.offchain_data - \) - SELECT - nb1\.batch_num AS current_batch_num, - nb2\.batch_num AS next_batch_num - FROM - numbered_batches nb1 - LEFT JOIN numbered_batches nb2 ON nb1\.row_number = nb2\.row_number - 1 - WHERE - nb1\.batch_num IS NOT NULL - AND nb2\.batch_num IS NOT NULL - AND nb1\.batch_num \+ 1 <> nb2\.batch_num`) + seedOffchainData(t, dbPG, mock, tt.seed) + + expected := mock.ExpectQuery(regexp.QuoteMeta(selectOffchainDataGapsSQL)) if tt.returnErr != nil { expected.WillReturnError(tt.returnErr) @@ -807,8 +827,6 @@ func Test_DB_DetectOffchainDataGaps(t *testing.T) { expected.WillReturnRows(rows) } - dbPG := New(wdb) - actual, err := dbPG.DetectOffchainDataGaps(context.Background()) if tt.returnErr != nil { require.ErrorIs(t, err, tt.returnErr) @@ -822,32 +840,43 @@ func Test_DB_DetectOffchainDataGaps(t *testing.T) { } } -func seedOffchainData(t *testing.T, db *sqlx.DB, mock sqlmock.Sqlmock, od []types.OffChainData) { +func constructorExpect(mock sqlmock.Sqlmock) { + mock.ExpectPrepare(regexp.QuoteMeta(storeLastProcessedBlockSQL)) + mock.ExpectPrepare(regexp.QuoteMeta(getLastProcessedBlockSQL)) + mock.ExpectPrepare(regexp.QuoteMeta(getUnresolvedBatchKeysSQL)) + mock.ExpectPrepare(regexp.QuoteMeta(getOffchainDataSQL)) + mock.ExpectPrepare(regexp.QuoteMeta(countOffchainDataSQL)) + mock.ExpectPrepare(regexp.QuoteMeta(selectOffchainDataGapsSQL)) +} + +func seedOffchainData(t *testing.T, db DB, mock sqlmock.Sqlmock, od []types.OffChainData) { t.Helper() mock.ExpectBegin() + mock.ExpectPrepare(regexp.QuoteMeta(storeOffChainDataSQL)) for i, o := range od { - mock.ExpectExec(`INSERT INTO data_node\.offchain_data \(key, value, batch_num\) VALUES \(\$1, \$2, \$3\) ON CONFLICT \(key\) DO UPDATE SET value = EXCLUDED\.value, batch_num = EXCLUDED\.batch_num`). + mock.ExpectExec(regexp.QuoteMeta(storeOffChainDataSQL)). WithArgs(o.Key.Hex(), common.Bytes2Hex(o.Value), o.BatchNum). WillReturnResult(sqlmock.NewResult(int64(i+1), int64(i+1))) } mock.ExpectCommit() - err := New(db).StoreOffChainData(context.Background(), od) + err := db.StoreOffChainData(context.Background(), od) require.NoError(t, err) } -func seedUnresolvedBatchKeys(t *testing.T, db *sqlx.DB, mock sqlmock.Sqlmock, bk []types.BatchKey) { +func seedUnresolvedBatchKeys(t *testing.T, db DB, mock sqlmock.Sqlmock, bk []types.BatchKey) { t.Helper() mock.ExpectBegin() + mock.ExpectPrepare(regexp.QuoteMeta(storeUnresolvedBatchesSQL)) for i, o := range bk { - mock.ExpectExec(`INSERT INTO data_node\.unresolved_batches \(num, hash\) VALUES \(\$1, \$2\) ON CONFLICT \(num, hash\) DO NOTHING`). + mock.ExpectExec(regexp.QuoteMeta(storeUnresolvedBatchesSQL)). WithArgs(o.Number, o.Hash.Hex()). WillReturnResult(sqlmock.NewResult(int64(i+1), int64(i+1))) } mock.ExpectCommit() - err := New(db).StoreUnresolvedBatchKeys(context.Background(), bk) + err := db.StoreUnresolvedBatchKeys(context.Background(), bk) require.NoError(t, err) }