Skip to content

Commit

Permalink
add context timeout for lib
Browse files Browse the repository at this point in the history
  • Loading branch information
mazchew committed Sep 22, 2024
1 parent d276277 commit ec5ee7d
Show file tree
Hide file tree
Showing 11 changed files with 508 additions and 74 deletions.
88 changes: 78 additions & 10 deletions client/gosqldriver/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,21 +69,29 @@ type heraConnection struct {
clientinfo *netstring.Netstring

// Context support
watching bool
watcher chan<- context.Context
finished chan<- struct{}
closech chan struct{}
watching bool
watcher chan<- context.Context
finished chan<- struct{}
closech chan struct{}
cancelled atomicError // set non-nil if conn is canceled
closed atomicBool // set when conn is closed, before closech is closed
}

// NewHeraConnection creates a structure implementing a driver.Con interface
func NewHeraConnection(conn net.Conn) driver.Conn {
hera := &heraConnection{conn: conn, id: conn.RemoteAddr().String(), reader: netstring.NewNetstringReader(conn), corrID: corrIDUnsetCmd}
if logger.GetLogger().V(logger.Info) {
logger.GetLogger().Log(logger.Info, hera.id, "create driver connection")
hera := &heraConnection{conn: conn,
id: conn.RemoteAddr().String(),
reader: netstring.NewNetstringReader(conn),
corrID: corrIDUnsetCmd,
closech: make(chan struct{}),
}

hera.startWatcher()

if logger.GetLogger().V(logger.Info) {
logger.GetLogger().Log(logger.Info, hera.id, "create driver connection")
}

return hera
}

Expand Down Expand Up @@ -112,6 +120,8 @@ func (c *heraConnection) startWatcher() {
}()
}

// finish is called when the query has succeeded.

func (c *heraConnection) finish() {
if !c.watching || c.finished == nil {
return
Expand All @@ -127,13 +137,14 @@ func (c *heraConnection) cancel(err error) {
if logger.GetLogger().V(logger.Debug) {
logger.GetLogger().Log(logger.Debug, c.id, "ctx error:", err)
}
c.Close()
c.cancelled.Set(err)
c.cleanup()
}

func (c *heraConnection) watchCancel(ctx context.Context) error {
if c.watching {
// Reach here if canceled, the connection is already invalid
c.Close()
// Reach here if cancelled, the connection is already invalid
c.cleanup()
return nil
}
// When ctx is already cancelled, don't watch it.
Expand All @@ -154,6 +165,37 @@ func (c *heraConnection) watchCancel(ctx context.Context) error {
return nil
}

// Closes the network connection and unsets internal variables. Do not call this
// function after successfully authentication, call Close instead. This function
// is called before auth or on auth failure because HERA will have already
// closed the network connection.
func (c *heraConnection) cleanup() {
if c.closed.Swap(true) {
return
}

// Makes cleanup idempotent
close(c.closech)
if c.conn == nil {
return
}
c.finish()
if err := c.conn.Close(); err != nil {
logger.GetLogger().Log(logger.Alert, err)
}
}

// error
func (c *heraConnection) error() error {
if c.closed.Load() {
if err := c.cancelled.Value(); err != nil {
return err
}
return ErrInvalidConn
}
return nil
}

// Prepare returns a prepared statement, bound to this connection.
func (c *heraConnection) Prepare(query string) (driver.Stmt, error) {
if logger.GetLogger().V(logger.Debug) {
Expand Down Expand Up @@ -183,16 +225,28 @@ func (c *heraConnection) Begin() (driver.Tx, error) {
if logger.GetLogger().V(logger.Debug) {
logger.GetLogger().Log(logger.Debug, c.id, "begin txn")
}
if c.closed.Load() {
logger.GetLogger().Log(logger.Alert, ErrInvalidConn)
return nil, driver.ErrBadConn
}
return &tx{hera: c}, nil
}

// internal function to execute commands
func (c *heraConnection) exec(cmd int, payload []byte) error {
if c.closed.Load() {
logger.GetLogger().Log(logger.Alert, ErrInvalidConn)
return driver.ErrBadConn
}
return c.execNs(netstring.NewNetstringFrom(cmd, payload))
}

// internal function to execute commands
func (c *heraConnection) execNs(ns *netstring.Netstring) error {
if c.closed.Load() {
logger.GetLogger().Log(logger.Alert, ErrInvalidConn)
return driver.ErrBadConn
}
if logger.GetLogger().V(logger.Verbose) {
payload := string(ns.Payload)
if len(payload) > 1000 {
Expand All @@ -206,6 +260,10 @@ func (c *heraConnection) execNs(ns *netstring.Netstring) error {

// returns the next message from the connection
func (c *heraConnection) getResponse() (*netstring.Netstring, error) {
if c.closed.Load() {
logger.GetLogger().Log(logger.Alert, ErrInvalidConn)
return nil, driver.ErrBadConn
}
ns, err := c.reader.ReadNext()
if err != nil {
if logger.GetLogger().V(logger.Warning) {
Expand Down Expand Up @@ -283,6 +341,11 @@ func (c *heraConnection) SetClientInfo(poolName string, host string) error {
return nil
}

if c.closed.Load() {
logger.GetLogger().Log(logger.Alert, ErrInvalidConn)
return driver.ErrBadConn
}

pid := os.Getpid()
data := fmt.Sprintf("PID: %d, HOST: %s, Poolname: %s, Command: SetClientInfo,", pid, host, poolName)
c.clientinfo = netstring.NewNetstringFrom(common.CmdClientInfo, []byte(string(data)))
Expand Down Expand Up @@ -315,6 +378,11 @@ func (c *heraConnection) SetClientInfoWithPoolStack(poolName string, host string
return nil
}

if c.closed.Load() {
logger.GetLogger().Log(logger.Alert, ErrInvalidConn)
return driver.ErrBadConn
}

pid := os.Getpid()
data := fmt.Sprintf("PID: %d, HOST: %s, Poolname: %s, PoolStack: %s, Command: SetClientInfo,", pid, host, poolName, poolStack)
c.clientinfo = netstring.NewNetstringFrom(common.CmdClientInfo, []byte(string(data)))
Expand Down
8 changes: 4 additions & 4 deletions client/gosqldriver/statement.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ func (st *stmt) NumInput() int {
// Implements driver.Stmt.
// Exec executes a query that doesn't return rows, such as an INSERT or UPDATE.
func (st *stmt) Exec(args []driver.Value) (driver.Result, error) {
defer st.hera.finish()
sk := 0
if len(st.hera.getShardKeyPayload()) > 0 {
sk = 1
Expand Down Expand Up @@ -113,9 +114,9 @@ func (st *stmt) Exec(args []driver.Value) (driver.Result, error) {
func (st *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
//TODO: honor the context timeout and return when it is canceled
if err := st.hera.watchCancel(ctx); err != nil {
fmt.Println("=== reach here 1 ====")
return nil, err
}
defer st.hera.finish()

sk := 0
if len(st.hera.getShardKeyPayload()) > 0 {
Expand All @@ -132,8 +133,6 @@ func (st *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driv
cmd := netstring.NewNetstringEmbedded(nss)
err = st.hera.execNs(cmd)
if err != nil {
fmt.Println("=== reach here 2 ====")
st.hera.finish()
return nil, err
}

Expand All @@ -151,6 +150,7 @@ func (st *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driv
// Implements driver.Stmt.
// Query executes a query that may return rows, such as a SELECT.
func (st *stmt) Query(args []driver.Value) (driver.Rows, error) {
defer st.hera.finish()
sk := 0
if len(st.hera.getShardKeyPayload()) > 0 {
sk = 1
Expand Down Expand Up @@ -188,6 +188,7 @@ func (st *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (dri
if err := st.hera.watchCancel(ctx); err != nil {
return nil, err
}
defer st.hera.finish()

sk := 0
if len(st.hera.getShardKeyPayload()) > 0 {
Expand All @@ -204,7 +205,6 @@ func (st *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (dri
cmd := netstring.NewNetstringEmbedded(nss)
err = st.hera.execNs(cmd)
if err != nil {
st.hera.finish()
return nil, err
}

Expand Down
67 changes: 67 additions & 0 deletions client/gosqldriver/utils.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package gosqldriver

import (
"errors"
"sync"
"sync/atomic"
)

var ErrInvalidConn = errors.New("invalid connection")

// atomicError provides thread-safe error handling
type atomicError struct {
value atomic.Value
mu sync.Mutex
}

// Set sets the error value atomically. The value must not be nil.
func (ae *atomicError) Set(err error) {
if err == nil {
panic("atomicError: nil error value")
}
ae.mu.Lock()
defer ae.mu.Unlock()
ae.value.Store(err)
}

// Value returns the current error value, or nil if none is set.
func (ae *atomicError) Value() error {
v := ae.value.Load()
if v == nil {
return nil
}
return v.(error)
}

type atomicBool struct {
value uint32
mu sync.Mutex
}

// Store sets the value of the bool regardless of the previous value
func (ab *atomicBool) Store(value bool) {
ab.mu.Lock()
defer ab.mu.Unlock()
if value {
atomic.StoreUint32(&ab.value, 1)
} else {
atomic.StoreUint32(&ab.value, 0)
}
}

// Load returns whether the current boolean value is true
func (ab *atomicBool) Load() bool {
ab.mu.Lock()
defer ab.mu.Unlock()
return atomic.LoadUint32(&ab.value) > 0
}

// Swap sets the value of the bool and returns the old value.
func (ab *atomicBool) Swap(value bool) bool {
ab.mu.Lock()
defer ab.mu.Unlock()
if value {
return atomic.SwapUint32(&ab.value, 1) > 0
}
return atomic.SwapUint32(&ab.value, 0) > 0
}
14 changes: 10 additions & 4 deletions lib/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,17 @@ package lib
import (
"errors"
"fmt"
"github.com/paypal/hera/cal"
"github.com/paypal/hera/config"
"github.com/paypal/hera/utility/logger"
"os"
"path/filepath"
"strings"
"sync/atomic"

"github.com/paypal/hera/cal"
"github.com/paypal/hera/config"
"github.com/paypal/hera/utility/logger"
)

//The Config contains all the static configuration
// The Config contains all the static configuration
type Config struct {
CertChainFile string
KeyFile string // leave blank for no SSL
Expand Down Expand Up @@ -179,6 +180,9 @@ type Config struct {

// Max desired percentage of healthy workers for the worker pool
MaxDesiredHealthyWorkerPct int

//Timeout for management queries.
ManagementQueriesTimeoutInUs int
}

// The OpsConfig contains the configuration that can be modified during run time
Expand Down Expand Up @@ -465,6 +469,8 @@ func InitConfig() error {
gAppConfig.MaxDesiredHealthyWorkerPct = 90
}

gAppConfig.ManagementQueriesTimeoutInUs = cdb.GetOrDefaultInt("management_queries_timeout_us", 200000) //200 milliseconds

return nil
}

Expand Down
2 changes: 1 addition & 1 deletion lib/querybindblocker.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ func InitQueryBindBlocker(modName string) {
}

func loadBlockQueryBind(db *sql.DB) {
ctx, cancel := context.WithTimeout(context.Background(), 5000*time.Millisecond)
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(GetConfig().ManagementQueriesTimeoutInUs)*time.Microsecond)
defer cancel()
conn, err := db.Conn(ctx)
if err != nil {
Expand Down
Loading

0 comments on commit ec5ee7d

Please sign in to comment.