Skip to content

Commit

Permalink
Add tests for KillWithContext logic
Browse files Browse the repository at this point in the history
Signed-off-by: Daniel Joos <[email protected]>
  • Loading branch information
danieljoos committed Feb 21, 2024
1 parent c4af426 commit b7e6d3d
Show file tree
Hide file tree
Showing 3 changed files with 121 additions and 10 deletions.
9 changes: 8 additions & 1 deletion go/mysql/fakesqldb/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -376,11 +376,11 @@ func (db *DB) HandleQuery(c *mysql.Conn, query string, callback func(*sqltypes.R
}
key := strings.ToLower(query)
db.mu.Lock()
defer db.mu.Unlock()
db.queryCalled[key]++
db.querylog = append(db.querylog, key)
// Check if we should close the connection and provoke errno 2013.
if db.shouldClose.Load() {
defer db.mu.Unlock()
c.Close()

// log error
Expand All @@ -394,6 +394,8 @@ func (db *DB) HandleQuery(c *mysql.Conn, query string, callback func(*sqltypes.R
// The driver may send this at connection time, and we don't want it to
// interfere.
if key == "set names utf8" || strings.HasPrefix(key, "set collation_connection = ") {
defer db.mu.Unlock()

// log error
if err := callback(&sqltypes.Result{}); err != nil {
log.Errorf("callback failed : %v", err)
Expand All @@ -403,12 +405,14 @@ func (db *DB) HandleQuery(c *mysql.Conn, query string, callback func(*sqltypes.R

// check if we should reject it.
if err, ok := db.rejectedData[key]; ok {
db.mu.Unlock()
return err
}

// Check explicit queries from AddQuery().
result, ok := db.data[key]
if ok {
db.mu.Unlock()
if f := result.BeforeFunc; f != nil {
f()
}
Expand All @@ -419,6 +423,7 @@ func (db *DB) HandleQuery(c *mysql.Conn, query string, callback func(*sqltypes.R
for _, pat := range db.patternData {
if pat.expr.MatchString(query) {
userCallback, ok := db.queryPatternUserCallback[pat.expr]
db.mu.Unlock()
if ok {
userCallback(query)
}
Expand All @@ -429,6 +434,8 @@ func (db *DB) HandleQuery(c *mysql.Conn, query string, callback func(*sqltypes.R
}
}

defer db.mu.Unlock()

if db.neverFail.Load() {
return callback(&sqltypes.Result{})
}
Expand Down
24 changes: 15 additions & 9 deletions go/vt/vttablet/tabletserver/connpool/dbconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ import (
vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc"
)

const defaultKillTimeout = 5 * time.Second

// Conn is a db connection for tabletserver.
// It performs automatic reconnects as needed.
// Its Execute function has a timeout that can kill
Expand All @@ -57,6 +59,8 @@ type Conn struct {
// err will be set if a query is killed through a Kill.
errmu sync.Mutex
err error

killTimeout time.Duration
}

// NewConnection creates a new DBConn. It triggers a CheckMySQL if creation fails.
Expand All @@ -71,10 +75,11 @@ func newPooledConn(ctx context.Context, pool *Pool, appParams dbconfigs.Connecto
return nil, err
}
db := &Conn{
conn: c,
env: pool.env,
stats: pool.env.Stats(),
dbaPool: pool.dbaPool,
conn: c,
env: pool.env,
stats: pool.env.Stats(),
dbaPool: pool.dbaPool,
killTimeout: defaultKillTimeout,
}
return db, nil
}
Expand All @@ -86,9 +91,10 @@ func NewConn(ctx context.Context, params dbconfigs.Connector, dbaPool *dbconnpoo
return nil, err
}
dbconn := &Conn{
conn: c,
dbaPool: dbaPool,
stats: tabletenv.NewStats(servenv.NewExporter("Temp", "Tablet")),
conn: c,
dbaPool: dbaPool,
stats: tabletenv.NewStats(servenv.NewExporter("Temp", "Tablet")),
killTimeout: defaultKillTimeout,
}
if setting == nil {
return dbconn, nil
Expand Down Expand Up @@ -175,7 +181,7 @@ func (dbc *Conn) execOnce(ctx context.Context, query string, maxrows int, wantfi

select {
case <-ctx.Done():
killCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
killCtx, cancel := context.WithTimeout(context.Background(), dbc.killTimeout)
defer cancel()

_ = dbc.KillWithContext(killCtx, ctx.Err().Error(), time.Since(now))
Expand Down Expand Up @@ -274,7 +280,7 @@ func (dbc *Conn) streamOnce(ctx context.Context, query string, callback func(*sq

select {
case <-ctx.Done():
killCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
killCtx, cancel := context.WithTimeout(context.Background(), dbc.killTimeout)
defer cancel()

_ = dbc.KillWithContext(killCtx, ctx.Err().Error(), time.Since(now))
Expand Down
98 changes: 98 additions & 0 deletions go/vt/vttablet/tabletserver/connpool/dbconn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ import (
"vitess.io/vitess/go/mysql/fakesqldb"
"vitess.io/vitess/go/sqltypes"
querypb "vitess.io/vitess/go/vt/proto/query"
vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc"
"vitess.io/vitess/go/vt/vterrors"
)

func compareTimingCounts(t *testing.T, op string, delta int64, before, after map[string]int64) {
Expand Down Expand Up @@ -291,6 +293,57 @@ func TestDBConnKill(t *testing.T) {
}
}

func TestDBKillWithContext(t *testing.T) {
db := fakesqldb.New(t)
defer db.Close()
connPool := newPool()
connPool.Open(db.ConnParams(), db.ConnParams(), db.ConnParams())
defer connPool.Close()
dbConn, err := newPooledConn(context.Background(), connPool, db.ConnParams())
if dbConn != nil {
defer dbConn.Close()
}
require.NoError(t, err)

query := fmt.Sprintf("kill %d", dbConn.ID())
db.AddQuery(query, &sqltypes.Result{})
db.SetBeforeFunc(query, func() {
// should take longer than our context deadline below.
time.Sleep(200 * time.Millisecond)
})

ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()

// KillWithContext should return context.DeadlineExceeded
err = dbConn.KillWithContext(ctx, "test kill", 0)
require.ErrorIs(t, err, context.DeadlineExceeded)
}

func TestDBKillWithContextDoneContext(t *testing.T) {
db := fakesqldb.New(t)
defer db.Close()
connPool := newPool()
connPool.Open(db.ConnParams(), db.ConnParams(), db.ConnParams())
defer connPool.Close()
dbConn, err := newPooledConn(context.Background(), connPool, db.ConnParams())
if dbConn != nil {
defer dbConn.Close()
}
require.NoError(t, err)

query := fmt.Sprintf("kill %d", dbConn.ID())
db.AddRejectedQuery(query, errors.New("rejected"))

contextErr := errors.New("context error")
ctx, cancel := context.WithCancelCause(context.Background())
cancel(contextErr) // cancel the context immediately

// KillWithContext should return the cancellation cause
err = dbConn.KillWithContext(ctx, "test kill", 0)
require.ErrorIs(t, err, contextErr)
}

// TestDBConnClose tests that an Exec returns immediately if a connection
// is asynchronously killed (and closed) in the middle of an execution.
func TestDBConnClose(t *testing.T) {
Expand Down Expand Up @@ -519,3 +572,48 @@ func TestDBConnReApplySetting(t *testing.T) {

db.VerifyAllExecutedOrFail()
}

func TestDBExecOnceKillTimeout(t *testing.T) {
db := fakesqldb.New(t)
defer db.Close()
connPool := newPool()
connPool.Open(db.ConnParams(), db.ConnParams(), db.ConnParams())
defer connPool.Close()
dbConn, err := newPooledConn(context.Background(), connPool, db.ConnParams())
if dbConn != nil {
defer dbConn.Close()
}
require.NoError(t, err)

// A very long running query that will be killed.
expectedQuery := "select 1"
var timestampQuery time.Time
db.AddQuery(expectedQuery, &sqltypes.Result{})
db.SetBeforeFunc(expectedQuery, func() {
timestampQuery = time.Now()
// should take longer than our context deadline below.
time.Sleep(1000 * time.Millisecond)
})

// We expect a kill-query to be fired, too.
// It should also run into a timeout.
var timestampKill time.Time
dbConn.killTimeout = 100 * time.Millisecond
db.AddQueryPatternWithCallback(`kill \d+`, &sqltypes.Result{}, func(string) {
timestampKill = time.Now()
// should take longer than the configured kill timeout above.
time.Sleep(200 * time.Millisecond)
})

ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()

result, err := dbConn.ExecOnce(ctx, "select 1", 1, false)
timestampDone := time.Now()

require.Error(t, err)
require.Equal(t, vtrpcpb.Code_CANCELED, vterrors.Code(err))
require.Nil(t, result)
require.WithinDuration(t, timestampQuery, timestampKill, 150*time.Millisecond)
require.WithinDuration(t, timestampKill, timestampDone, 150*time.Millisecond)
}

0 comments on commit b7e6d3d

Please sign in to comment.