Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use a separate unit to manage the cache of prepared statements #2937

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
different version can lead to presubmits failing due to unexpected
diffs.

* Use a separate unit (StmtCache) to manage the cache of prepared statements, which wraps the sql.Stmt struct to handle and monitor the execution errors of the prepared statement. When an error occurs during statement execution, it closes the statement and clears the cache, as well as increments the error monitoring indicator.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure what 'unit' means?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The word 'separate unit' comes from the current TODO comment:

// TODO(al,martin): consider pulling this all out as a separate unit for reuse


### Misc

* Bump Go version from 1.17 to 1.19.
Expand Down
17 changes: 9 additions & 8 deletions storage/mysql/log_storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (
"github.com/google/trillian/monitoring"
"github.com/google/trillian/storage"
"github.com/google/trillian/storage/cache"
"github.com/google/trillian/storage/stmtcache"
"github.com/google/trillian/storage/tree"
"github.com/google/trillian/types"
"github.com/transparency-dev/merkle/compact"
Expand Down Expand Up @@ -135,7 +136,7 @@ func NewLogStorage(db *sql.DB, mf monitoring.MetricFactory) storage.LogStorage {
}
return &mySQLLogStorage{
admin: NewAdminStorage(db),
mySQLTreeStorage: newTreeStorage(db),
mySQLTreeStorage: newTreeStorage(db, mf),
metricFactory: mf,
}
}
Expand All @@ -144,16 +145,16 @@ func (m *mySQLLogStorage) CheckDatabaseAccessible(ctx context.Context) error {
return m.db.PingContext(ctx)
}

func (m *mySQLLogStorage) getLeavesByMerkleHashStmt(ctx context.Context, num int, orderBySequence bool) (*sql.Stmt, error) {
func (m *mySQLLogStorage) getLeavesByMerkleHashStmt(ctx context.Context, num int, orderBySequence bool) (*stmtcache.Stmt, error) {
if orderBySequence {
return m.getStmt(ctx, selectLeavesByMerkleHashOrderedBySequenceSQL, num, "?", "?")
return m.stmtCache.GetStmt(ctx, selectLeavesByMerkleHashOrderedBySequenceSQL, num, "?", "?")
}

return m.getStmt(ctx, selectLeavesByMerkleHashSQL, num, "?", "?")
return m.stmtCache.GetStmt(ctx, selectLeavesByMerkleHashSQL, num, "?", "?")
}

func (m *mySQLLogStorage) getLeavesByLeafIdentityHashStmt(ctx context.Context, num int) (*sql.Stmt, error) {
return m.getStmt(ctx, selectLeavesByLeafIdentityHashSQL, num, "?", "?")
func (m *mySQLLogStorage) getLeavesByLeafIdentityHashStmt(ctx context.Context, num int) (*stmtcache.Stmt, error) {
return m.stmtCache.GetStmt(ctx, selectLeavesByLeafIdentityHashSQL, num, "?", "?")
}

func (m *mySQLLogStorage) GetActiveLogIDs(ctx context.Context) ([]int64, error) {
Expand Down Expand Up @@ -730,8 +731,8 @@ func (t *logTreeTX) StoreSignedLogRoot(ctx context.Context, root *trillian.Signe
return checkResultOkAndRowCountIs(res, err, 1)
}

func (t *logTreeTX) getLeavesByHashInternal(ctx context.Context, leafHashes [][]byte, tmpl *sql.Stmt, desc string) ([]*trillian.LogLeaf, error) {
stx := t.tx.StmtContext(ctx, tmpl)
func (t *logTreeTX) getLeavesByHashInternal(ctx context.Context, leafHashes [][]byte, tmpl *stmtcache.Stmt, desc string) ([]*trillian.LogLeaf, error) {
stx := tmpl.WithTx(ctx, t.tx)
defer stx.Close()

var args []interface{}
Expand Down
71 changes: 13 additions & 58 deletions storage/mysql/tree_storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,12 @@ import (
"encoding/base64"
"fmt"
"runtime/debug"
"strings"
"sync"

"github.com/google/trillian"
"github.com/google/trillian/monitoring"
"github.com/google/trillian/storage/cache"
"github.com/google/trillian/storage/stmtcache"
"github.com/google/trillian/storage/storagepb"
"github.com/google/trillian/storage/tree"
"google.golang.org/protobuf/proto"
Expand All @@ -52,20 +53,15 @@ const (
AND Subtree.SubtreeRevision = x.MaxRevision
AND Subtree.TreeId = x.TreeId
AND Subtree.TreeId = ?`
placeholderSQL = "<placeholder>"
placeholderSQL = stmtcache.PlaceholderSQL
)

// mySQLTreeStorage is shared between the mySQLLog- and (forthcoming) mySQLMap-
// Storage implementations, and contains functionality which is common to both,
type mySQLTreeStorage struct {
db *sql.DB

// Must hold the mutex before manipulating the statement map. Sharing a lock because
// it only needs to be held while the statements are built, not while they execute and
// this will be a short time. These maps are from the number of placeholder '?'
// in the query to the statement that should be used.
statementMutex sync.Mutex
statements map[string]map[int]*sql.Stmt
stmtCache *stmtcache.StmtCache
}

// OpenDB opens a database connection for all MySQL-based storage implementations.
Expand All @@ -85,60 +81,19 @@ func OpenDB(dbURL string) (*sql.DB, error) {
return db, nil
}

func newTreeStorage(db *sql.DB) *mySQLTreeStorage {
func newTreeStorage(db *sql.DB, mf monitoring.MetricFactory) *mySQLTreeStorage {
return &mySQLTreeStorage{
db: db,
statements: make(map[string]map[int]*sql.Stmt),
db: db,
stmtCache: stmtcache.New(db, mf),
}
}

// expandPlaceholderSQL expands an sql statement by adding a specified number of '?'
// placeholder slots. At most one placeholder will be expanded.
func expandPlaceholderSQL(sql string, num int, first, rest string) string {
if num <= 0 {
panic(fmt.Errorf("trying to expand SQL placeholder with <= 0 parameters: %s", sql))
}

parameters := first + strings.Repeat(","+rest, num-1)

return strings.Replace(sql, placeholderSQL, parameters, 1)
}

// getStmt creates and caches sql.Stmt structs based on the passed in statement
// and number of bound arguments.
// TODO(al,martin): consider pulling this all out as a separate unit for reuse
// elsewhere.
func (m *mySQLTreeStorage) getStmt(ctx context.Context, statement string, num int, first, rest string) (*sql.Stmt, error) {
m.statementMutex.Lock()
defer m.statementMutex.Unlock()

if m.statements[statement] != nil {
if m.statements[statement][num] != nil {
// TODO(al,martin): we'll possibly need to expire Stmts from the cache,
// e.g. when DB connections break etc.
return m.statements[statement][num], nil
}
} else {
m.statements[statement] = make(map[int]*sql.Stmt)
}

s, err := m.db.PrepareContext(ctx, expandPlaceholderSQL(statement, num, first, rest))
if err != nil {
klog.Warningf("Failed to prepare statement %d: %s", num, err)
return nil, err
}

m.statements[statement][num] = s

return s, nil
}

func (m *mySQLTreeStorage) getSubtreeStmt(ctx context.Context, num int) (*sql.Stmt, error) {
return m.getStmt(ctx, selectSubtreeSQL, num, "?", "?")
func (m *mySQLTreeStorage) getSubtreeStmt(ctx context.Context, num int) (*stmtcache.Stmt, error) {
return m.stmtCache.GetStmt(ctx, selectSubtreeSQL, num, "?", "?")
}

func (m *mySQLTreeStorage) setSubtreeStmt(ctx context.Context, num int) (*sql.Stmt, error) {
return m.getStmt(ctx, insertSubtreeMultiSQL, num, "VALUES(?, ?, ?, ?)", "(?, ?, ?, ?)")
func (m *mySQLTreeStorage) setSubtreeStmt(ctx context.Context, num int) (*stmtcache.Stmt, error) {
return m.stmtCache.GetStmt(ctx, insertSubtreeMultiSQL, num, "VALUES(?, ?, ?, ?)", "(?, ?, ?, ?)")
}

func (m *mySQLTreeStorage) beginTreeTx(ctx context.Context, tree *trillian.Tree, hashSizeBytes int, subtreeCache *cache.SubtreeCache) (treeTX, error) {
Expand Down Expand Up @@ -183,7 +138,7 @@ func (t *treeTX) getSubtrees(ctx context.Context, treeRevision int64, ids [][]by
if err != nil {
return nil, err
}
stx := t.tx.StmtContext(ctx, tmpl)
stx := tmpl.WithTx(ctx, t.tx)
defer stx.Close()

args := make([]interface{}, 0, len(ids)+3)
Expand Down Expand Up @@ -291,7 +246,7 @@ func (t *treeTX) storeSubtrees(ctx context.Context, subtrees []*storagepb.Subtre
if err != nil {
return err
}
stx := t.tx.StmtContext(ctx, tmpl)
stx := tmpl.WithTx(ctx, t.tx)
defer stx.Close()

r, err := stx.ExecContext(ctx, args...)
Expand Down
200 changes: 200 additions & 0 deletions storage/stmtcache/stmtcache.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
// Package stmtcache contains tools for managing the prepared-statement cache.
package stmtcache

import (
"context"
"database/sql"
"fmt"
"strings"
"sync"

"github.com/google/trillian/monitoring"
"k8s.io/klog/v2"
)

var (
once sync.Once
errStmtCounter monitoring.Counter
)

// PlaceholderSQL SQL statement placeholder.
const PlaceholderSQL = "<placeholder>"

// Stmt is wraps the sql.Stmt struct for handling and monitoring SQL errors.
// If Stmt execution errors occur, it is automatically closed and the prepared statements in the cache are cleared.
type Stmt struct {
statement string
placeholderNum int
stmtCache *StmtCache
stmt *sql.Stmt
parentStmt *Stmt
}

// errHandler handling and monitoring SQL errors
// This err parameter is not currently used, but it may be necessary to perform more granular processing and monitoring of different errs in the future.
func (s *Stmt) errHandler(_ error) {
o := s
if s.parentStmt != nil {
o = s.parentStmt
}

if err := o.Close(); err != nil {
klog.Warningf("Failed to close stmt: %s", err)
}

if o.stmtCache != nil {
once.Do(func() {
errStmtCounter = o.stmtCache.mf.NewCounter("sql_stmt_errors", "Number of statement execution errors")
})

errStmtCounter.Inc()
}
}

// SQLStmt returns the referenced sql.Stmt struct.
func (s *Stmt) SQLStmt() *sql.Stmt {
return s.stmt
}

// Close closes the Stmt.
// Clear if Stmt belongs to cache
func (s *Stmt) Close() error {
if cache := s.stmtCache; cache != nil {
cache.clearOne(s)
}

return s.stmt.Close()
}

// WithTx returns a transaction-specific prepared statement from
// an existing statement.
// The transaction-specific Stmt is closed by the caller.
func (s *Stmt) WithTx(ctx context.Context, tx *sql.Tx) *Stmt {
parent := s
if s.parentStmt != nil {
parent = s.parentStmt
}
return &Stmt{
parentStmt: parent,
stmt: tx.StmtContext(ctx, parent.stmt),
}
}

// ExecContext executes a prepared statement with the given arguments and
// returns a Result summarizing the effect of the statement.
func (s *Stmt) ExecContext(ctx context.Context, args ...any) (sql.Result, error) {
res, err := s.stmt.ExecContext(ctx, args...)
if err != nil {
s.errHandler(err)
}
return res, err
}

// QueryContext executes a prepared query statement with the given arguments
// and returns the query results as a *Rows.
func (s *Stmt) QueryContext(ctx context.Context, args ...any) (*sql.Rows, error) {
res, err := s.stmt.QueryContext(ctx, args...)
if err != nil {
s.errHandler(err)
}
return res, err
}

// QueryRowContext executes a prepared query statement with the given arguments.
// If an error occurs during the execution of the statement, that error will
// be returned by a call to Scan on the returned *Row, which is always non-nil.
// If the query selects no rows, the *Row's Scan will return ErrNoRows.
// Otherwise, the *Row's Scan scans the first selected row and discards
// the rest.
func (s *Stmt) QueryRowContext(ctx context.Context, args ...any) *sql.Row {
res := s.stmt.QueryRowContext(ctx, args...)
if err := res.Err(); err != nil {
s.errHandler(err)
}
return res
}

// StmtCache is a cache of the sql.Stmt structs.
type StmtCache struct {
db *sql.DB
statementMutex sync.Mutex
statements map[string]map[int]*sql.Stmt
mf monitoring.MetricFactory
}

// New creates a StmtCache instance.
func New(db *sql.DB, mf monitoring.MetricFactory) *StmtCache {
if mf == nil {
mf = monitoring.InertMetricFactory{}
}

return &StmtCache{
db: db,
statements: make(map[string]map[int]*sql.Stmt),
mf: mf,
}
}

// clearOne clear the cache of a sql.Stmt.
func (sc *StmtCache) clearOne(s *Stmt) {
if s == nil || s.stmt == nil || s.stmtCache != sc {
return
}

sc.statementMutex.Lock()
defer sc.statementMutex.Unlock()

if _s, ok := sc.statements[s.statement][s.placeholderNum]; ok && _s == s.stmt {
sc.statements[s.statement][s.placeholderNum] = nil
}
}

func (sc *StmtCache) getStmt(ctx context.Context, statement string, num int, first, rest string) (*sql.Stmt, error) {
sc.statementMutex.Lock()
defer sc.statementMutex.Unlock()

if sc.statements[statement] != nil {
if sc.statements[statement][num] != nil {
return sc.statements[statement][num], nil
}
} else {
sc.statements[statement] = make(map[int]*sql.Stmt)
}

s, err := sc.db.PrepareContext(ctx, expandPlaceholderSQL(statement, num, first, rest))
if err != nil {
klog.Warningf("Failed to prepare statement %d: %s", num, err)
return nil, err
}

sc.statements[statement][num] = s

return s, nil
}

// expandPlaceholderSQL expands an sql statement by adding a specified number of '?'
// placeholder slots. At most one placeholder will be expanded.
func expandPlaceholderSQL(sql string, num int, first, rest string) string {
if num <= 0 {
panic(fmt.Errorf("trying to expand SQL placeholder with <= 0 parameters: %s", sql))
}

parameters := first + strings.Repeat(","+rest, num-1)

return strings.Replace(sql, PlaceholderSQL, parameters, 1)
}

// GetStmt creates and caches sql.Stmt and returns their wrapper Stmt.
func (sc *StmtCache) GetStmt(ctx context.Context, statement string, num int, first, rest string) (*Stmt, error) {
stmt, err := sc.getStmt(ctx, statement, num, first, rest)
if err != nil {
return nil, err
}

return &Stmt{
statement: statement,
placeholderNum: num,
stmtCache: sc,
stmt: stmt,
}, nil
}
Loading