Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add unit tests for connection pooling #395

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
283 changes: 234 additions & 49 deletions client/gosqldriver/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
package gosqldriver

import (
"context"
"database/sql/driver"
"errors"
"fmt"
Expand All @@ -32,26 +33,168 @@ import (

var corrIDUnsetCmd = netstring.NewNetstringFrom(common.CmdClientCalCorrelationID, []byte("CorrId=NotSet"))

type heraConnectionInterface interface {
Prepare(query string) (driver.Stmt, error)
Close() error
Begin() (driver.Tx, error)
exec(cmd int, payload []byte) error
execNs(ns *netstring.Netstring) error
getResponse() (*netstring.Netstring, error)
SetShardID(shard int) error
ResetShardID() error
GetNumShards() (int, error)
SetShardKeyPayload(payload string)
ResetShardKeyPayload()
SetCalCorrID(corrID string)
SetClientInfo(poolName string, host string) error
SetClientInfoWithPoolStack(poolName string, host string, poolStack string) error
getID() string
getCorrID() *netstring.Netstring
getShardKeyPayload() []byte
setCorrID(*netstring.Netstring)
startWatcher()
finish()
cancel(err error)
watchCancel(ctx context.Context) error
}

type heraConnection struct {
id string // used for logging
conn net.Conn
reader *netstring.Reader
// for the sharding extension
shardKeyPayload []byte
// correlation id
corrID *netstring.Netstring
corrID *netstring.Netstring
clientinfo *netstring.Netstring

// Context support
watching bool
watcher chan<- context.Context
finished chan<- struct{}
closech chan struct{}
cancelled atomicError // set non-nil if conn is canceled
closed atomicBool // set when conn is closed, before closech is closed
}

// NewHeraConnection creates a structure implementing a driver.Con interface
func NewHeraConnection(conn net.Conn) driver.Conn {
hera := &heraConnection{conn: conn, id: conn.RemoteAddr().String(), reader: netstring.NewNetstringReader(conn), corrID: corrIDUnsetCmd}
hera := &heraConnection{conn: conn,
id: conn.RemoteAddr().String(),
reader: netstring.NewNetstringReader(conn),
corrID: corrIDUnsetCmd,
closech: make(chan struct{}),
}

hera.startWatcher()

if logger.GetLogger().V(logger.Info) {
logger.GetLogger().Log(logger.Info, hera.id, "create driver connection")
}

return hera
}

func (c *heraConnection) startWatcher() {
watcher := make(chan context.Context, 1)
c.watcher = watcher
finished := make(chan struct{})
c.finished = finished
go func() {
for {
var ctx context.Context
select {
case ctx = <-watcher:
case <-c.closech:
return
}

select {
case <-ctx.Done():
c.cancel(ctx.Err())
case <-finished:
case <-c.closech:
return
}
}
}()
}

// finish is called when the query has succeeded.

func (c *heraConnection) finish() {
if !c.watching || c.finished == nil {
return
}
select {
case c.finished <- struct{}{}:
c.watching = false
case <-c.closech:
}
}

func (c *heraConnection) cancel(err error) {
if logger.GetLogger().V(logger.Debug) {
logger.GetLogger().Log(logger.Debug, c.id, "ctx error:", err)
}
c.cancelled.Set(err)
c.cleanup()
}

func (c *heraConnection) watchCancel(ctx context.Context) error {
if c.watching {
// Reach here if cancelled, the connection is already invalid
c.cleanup()
return nil
}
// When ctx is already cancelled, don't watch it.
if err := ctx.Err(); err != nil {
return err
}
// When ctx is not cancellable, don't watch it.
if ctx.Done() == nil {
return nil
}

if c.watcher == nil {
return nil
}

c.watching = true
c.watcher <- ctx
return nil
}

// Closes the network connection and unsets internal variables. Do not call this
// function after successfully authentication, call Close instead. This function
// is called before auth or on auth failure because HERA will have already
// closed the network connection.
func (c *heraConnection) cleanup() {
if c.closed.Swap(true) {
return
}

// Makes cleanup idempotent
close(c.closech)
if c.conn == nil {
return
}
c.finish()
if err := c.conn.Close(); err != nil {
logger.GetLogger().Log(logger.Alert, err)
}
}

// error
func (c *heraConnection) error() error {
if c.closed.Load() {
if err := c.cancelled.Value(); err != nil {
return err
}
return ErrInvalidConn
}
return nil
}

// Prepare returns a prepared statement, bound to this connection.
func (c *heraConnection) Prepare(query string) (driver.Stmt, error) {
Expand Down Expand Up @@ -82,16 +225,28 @@ func (c *heraConnection) Begin() (driver.Tx, error) {
if logger.GetLogger().V(logger.Debug) {
logger.GetLogger().Log(logger.Debug, c.id, "begin txn")
}
if c.closed.Load() {
logger.GetLogger().Log(logger.Alert, ErrInvalidConn)
return nil, driver.ErrBadConn
}
return &tx{hera: c}, nil
}

// internal function to execute commands
func (c *heraConnection) exec(cmd int, payload []byte) error {
if c.closed.Load() {
logger.GetLogger().Log(logger.Alert, ErrInvalidConn)
return driver.ErrBadConn
}
return c.execNs(netstring.NewNetstringFrom(cmd, payload))
}

// internal function to execute commands
func (c *heraConnection) execNs(ns *netstring.Netstring) error {
if c.closed.Load() {
logger.GetLogger().Log(logger.Alert, ErrInvalidConn)
return driver.ErrBadConn
}
if logger.GetLogger().V(logger.Verbose) {
payload := string(ns.Payload)
if len(payload) > 1000 {
Expand All @@ -105,6 +260,10 @@ func (c *heraConnection) execNs(ns *netstring.Netstring) error {

// returns the next message from the connection
func (c *heraConnection) getResponse() (*netstring.Netstring, error) {
if c.closed.Load() {
logger.GetLogger().Log(logger.Alert, ErrInvalidConn)
return nil, driver.ErrBadConn
}
ns, err := c.reader.ReadNext()
if err != nil {
if logger.GetLogger().V(logger.Warning) {
Expand Down Expand Up @@ -177,66 +336,92 @@ func (c *heraConnection) SetCalCorrID(corrID string) {
}

// SetClientInfo actually sends it over to Hera server
func (c *heraConnection) SetClientInfo(poolName string, host string)(error){
func (c *heraConnection) SetClientInfo(poolName string, host string) error {
if len(poolName) <= 0 && len(host) <= 0 {
return nil
}

if c.closed.Load() {
logger.GetLogger().Log(logger.Alert, ErrInvalidConn)
return driver.ErrBadConn
}

pid := os.Getpid()
data := fmt.Sprintf("PID: %d, HOST: %s, Poolname: %s, Command: SetClientInfo,", pid, host, poolName)
c.clientinfo = netstring.NewNetstringFrom(common.CmdClientInfo, []byte(string(data)))
if logger.GetLogger().V(logger.Verbose) {
logger.GetLogger().Log(logger.Verbose, "SetClientInfo", c.clientinfo.Serialized)
}

_, err := c.conn.Write(c.clientinfo.Serialized)
if err != nil {
if logger.GetLogger().V(logger.Warning) {
logger.GetLogger().Log(logger.Warning, "Failed to send client info")
}
return errors.New("Failed custom auth, failed to send client info")
}
ns, err := c.reader.ReadNext()
if err != nil {
if logger.GetLogger().V(logger.Warning) {
logger.GetLogger().Log(logger.Warning, "Failed to read server info")
}
return errors.New("Failed to read server info")
}
if logger.GetLogger().V(logger.Debug) {
logger.GetLogger().Log(logger.Debug, "Server info:", string(ns.Payload))
}
c.clientinfo = netstring.NewNetstringFrom(common.CmdClientInfo, []byte(string(data)))
if logger.GetLogger().V(logger.Verbose) {
logger.GetLogger().Log(logger.Verbose, "SetClientInfo", c.clientinfo.Serialized)
}

_, err := c.conn.Write(c.clientinfo.Serialized)
if err != nil {
if logger.GetLogger().V(logger.Warning) {
logger.GetLogger().Log(logger.Warning, "Failed to send client info")
}
return errors.New("Failed custom auth, failed to send client info")
}
ns, err := c.reader.ReadNext()
if err != nil {
if logger.GetLogger().V(logger.Warning) {
logger.GetLogger().Log(logger.Warning, "Failed to read server info")
}
return errors.New("Failed to read server info")
}
if logger.GetLogger().V(logger.Debug) {
logger.GetLogger().Log(logger.Debug, "Server info:", string(ns.Payload))
}
return nil
}

func (c *heraConnection) SetClientInfoWithPoolStack(poolName string, host string, poolStack string)(error){
func (c *heraConnection) SetClientInfoWithPoolStack(poolName string, host string, poolStack string) error {
if len(poolName) <= 0 && len(host) <= 0 && len(poolStack) <= 0 {
return nil
}

if c.closed.Load() {
logger.GetLogger().Log(logger.Alert, ErrInvalidConn)
return driver.ErrBadConn
}

pid := os.Getpid()
data := fmt.Sprintf("PID: %d, HOST: %s, Poolname: %s, PoolStack: %s, Command: SetClientInfo,", pid, host, poolName, poolStack)
c.clientinfo = netstring.NewNetstringFrom(common.CmdClientInfo, []byte(string(data)))
if logger.GetLogger().V(logger.Verbose) {
logger.GetLogger().Log(logger.Verbose, "SetClientInfo", c.clientinfo.Serialized)
}

_, err := c.conn.Write(c.clientinfo.Serialized)
if err != nil {
if logger.GetLogger().V(logger.Warning) {
logger.GetLogger().Log(logger.Warning, "Failed to send client info")
}
return errors.New("Failed custom auth, failed to send client info")
}
ns, err := c.reader.ReadNext()
if err != nil {
if logger.GetLogger().V(logger.Warning) {
logger.GetLogger().Log(logger.Warning, "Failed to read server info")
}
return errors.New("Failed to read server info")
}
if logger.GetLogger().V(logger.Debug) {
logger.GetLogger().Log(logger.Debug, "Server info:", string(ns.Payload))
}
c.clientinfo = netstring.NewNetstringFrom(common.CmdClientInfo, []byte(string(data)))
if logger.GetLogger().V(logger.Verbose) {
logger.GetLogger().Log(logger.Verbose, "SetClientInfo", c.clientinfo.Serialized)
}

_, err := c.conn.Write(c.clientinfo.Serialized)
if err != nil {
if logger.GetLogger().V(logger.Warning) {
logger.GetLogger().Log(logger.Warning, "Failed to send client info")
}
return errors.New("Failed custom auth, failed to send client info")
}
ns, err := c.reader.ReadNext()
if err != nil {
if logger.GetLogger().V(logger.Warning) {
logger.GetLogger().Log(logger.Warning, "Failed to read server info")
}
return errors.New("Failed to read server info")
}
if logger.GetLogger().V(logger.Debug) {
logger.GetLogger().Log(logger.Debug, "Server info:", string(ns.Payload))
}
return nil
}
}

func (c *heraConnection) getID() string {
return c.id
}

func (c *heraConnection) getCorrID() *netstring.Netstring {
return c.corrID
}

func (c *heraConnection) getShardKeyPayload() []byte {
return c.shardKeyPayload
}

func (c *heraConnection) setCorrID(newCorrID *netstring.Netstring) {
c.corrID = newCorrID
}
Loading