Skip to content

Commit

Permalink
[Audit remediation] Suggestion 1: Adopt a Verifiable Database / Batch…
Browse files Browse the repository at this point in the history
… Format (#100)
  • Loading branch information
begmaroman authored Jul 3, 2024
1 parent 97f3683 commit 48558e6
Show file tree
Hide file tree
Showing 18 changed files with 361 additions and 377 deletions.
69 changes: 33 additions & 36 deletions db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,15 @@ var (

// DB defines functions that a DB instance should implement
type DB interface {
StoreLastProcessedBlock(ctx context.Context, task string, block uint64) error
StoreLastProcessedBlock(ctx context.Context, block uint64, task string) error
GetLastProcessedBlock(ctx context.Context, task string) (uint64, error)

StoreUnresolvedBatchKeys(ctx context.Context, bks []types.BatchKey) error
GetUnresolvedBatchKeys(ctx context.Context, limit uint) ([]types.BatchKey, error)
DeleteUnresolvedBatchKeys(ctx context.Context, bks []types.BatchKey) error

Exists(ctx context.Context, key common.Hash) bool
GetOffChainData(ctx context.Context, key common.Hash) (types.ArgBytes, error)
ListOffChainData(ctx context.Context, keys []common.Hash) (map[common.Hash]types.ArgBytes, error)
GetOffChainData(ctx context.Context, key common.Hash) (*types.OffChainData, error)
ListOffChainData(ctx context.Context, keys []common.Hash) ([]types.OffChainData, error)
StoreOffChainData(ctx context.Context, od []types.OffChainData) error

CountOffchainData(ctx context.Context) (uint64, error)
Expand All @@ -46,7 +45,7 @@ func New(pg *sqlx.DB) DB {
}

// StoreLastProcessedBlock stores a record of a block processed by the synchronizer for named task
func (db *pgDB) StoreLastProcessedBlock(ctx context.Context, task string, block uint64) error {
func (db *pgDB) StoreLastProcessedBlock(ctx context.Context, block uint64, task string) error {
const storeLastProcessedBlockSQL = `
INSERT INTO data_node.sync_tasks (task, block)
VALUES ($1, $2)
Expand Down Expand Up @@ -165,27 +164,13 @@ func (db *pgDB) DeleteUnresolvedBatchKeys(ctx context.Context, bks []types.Batch
return tx.Commit()
}

// Exists checks if a key exists in offchain data table
func (db *pgDB) Exists(ctx context.Context, key common.Hash) bool {
const keyExists = "SELECT COUNT(*) FROM data_node.offchain_data WHERE key = $1;"

var (
count uint
)

if err := db.pg.QueryRowContext(ctx, keyExists, key.Hex()).Scan(&count); err != nil {
return false
}

return count > 0
}

// StoreOffChainData stores and array of key values in the Db
func (db *pgDB) StoreOffChainData(ctx context.Context, od []types.OffChainData) error {
const storeOffChainDataSQL = `
INSERT INTO data_node.offchain_data (key, value)
VALUES ($1, $2)
ON CONFLICT (key) DO NOTHING;
INSERT INTO data_node.offchain_data (key, value, batch_num)
VALUES ($1, $2, $3)
ON CONFLICT (key) DO UPDATE
SET value = EXCLUDED.value, batch_num = EXCLUDED.batch_num;
`

tx, err := db.pg.BeginTxx(ctx, nil)
Expand All @@ -198,6 +183,7 @@ func (db *pgDB) StoreOffChainData(ctx context.Context, od []types.OffChainData)
ctx, storeOffChainDataSQL,
d.Key.Hex(),
common.Bytes2Hex(d.Value),
d.BatchNum,
); err != nil {
if txErr := tx.Rollback(); txErr != nil {
return fmt.Errorf("%v: rollback caused by %v", txErr, err)
Expand All @@ -211,36 +197,42 @@ func (db *pgDB) StoreOffChainData(ctx context.Context, od []types.OffChainData)
}

// GetOffChainData returns the value identified by the key
func (db *pgDB) GetOffChainData(ctx context.Context, key common.Hash) (types.ArgBytes, error) {
func (db *pgDB) GetOffChainData(ctx context.Context, key common.Hash) (*types.OffChainData, error) {
const getOffchainDataSQL = `
SELECT value
SELECT key, value, batch_num
FROM data_node.offchain_data
WHERE key = $1 LIMIT 1;
`

var (
hexValue string
)
data := struct {
Key string `db:"key"`
Value string `db:"value"`
BatchNum uint64 `db:"batch_num"`
}{}

if err := db.pg.QueryRowxContext(ctx, getOffchainDataSQL, key.Hex()).Scan(&hexValue); err != nil {
if err := db.pg.QueryRowxContext(ctx, getOffchainDataSQL, key.Hex()).StructScan(&data); err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, ErrStateNotSynchronized
}

return nil, err
}

return common.FromHex(hexValue), nil
return &types.OffChainData{
Key: common.HexToHash(data.Key),
Value: common.FromHex(data.Value),
BatchNum: data.BatchNum,
}, nil
}

// ListOffChainData returns values identified by the given keys
func (db *pgDB) ListOffChainData(ctx context.Context, keys []common.Hash) (map[common.Hash]types.ArgBytes, error) {
func (db *pgDB) ListOffChainData(ctx context.Context, keys []common.Hash) ([]types.OffChainData, error) {
if len(keys) == 0 {
return nil, nil
}

const listOffchainDataSQL = `
SELECT key, value
SELECT key, value, batch_num
FROM data_node.offchain_data
WHERE key IN (?);
`
Expand All @@ -265,17 +257,22 @@ func (db *pgDB) ListOffChainData(ctx context.Context, keys []common.Hash) (map[c

defer rows.Close()

list := make(map[common.Hash]types.ArgBytes)
list := make([]types.OffChainData, 0, len(keys))
for rows.Next() {
data := struct {
Key string `db:"key"`
Value string `db:"value"`
Key string `db:"key"`
Value string `db:"value"`
BatchNum uint64 `db:"batch_num"`
}{}
if err = rows.StructScan(&data); err != nil {
return nil, err
}

list[common.HexToHash(data.Key)] = common.FromHex(data.Value)
list = append(list, types.OffChainData{
Key: common.HexToHash(data.Key),
Value: common.FromHex(data.Value),
BatchNum: data.BatchNum,
})
}

return list, nil
Expand Down
Loading

0 comments on commit 48558e6

Please sign in to comment.