From 1bd34db54d9e8af3c8ed5667236bc1da7d0bd785 Mon Sep 17 00:00:00 2001 From: Arthur Schreiber Date: Wed, 8 Nov 2023 14:52:08 +0000 Subject: [PATCH] Add new `KillWithContext` function. Signed-off-by: Arthur Schreiber --- .../vttablet/tabletserver/connpool/dbconn.go | 42 +++++++++++++++---- 1 file changed, 34 insertions(+), 8 deletions(-) diff --git a/go/vt/vttablet/tabletserver/connpool/dbconn.go b/go/vt/vttablet/tabletserver/connpool/dbconn.go index 63f4c73520e..3d3f866b180 100644 --- a/go/vt/vttablet/tabletserver/connpool/dbconn.go +++ b/go/vt/vttablet/tabletserver/connpool/dbconn.go @@ -362,10 +362,19 @@ func (dbc *Conn) IsClosed() bool { return dbc.conn.IsClosed() } +// Kill wraps KillWithContext using context.Background. +func (dbc *Conn) Kill(reason string, elapsed time.Duration) error { + return dbc.KillWithContext(context.Background(), reason, elapsed) +} + // Kill kills the currently executing query both on MySQL side // and on the connection side. If no query is executing, it's a no-op. // Kill will also not kill a query more than once. -func (dbc *Conn) Kill(reason string, elapsed time.Duration) error { +func (dbc *Conn) KillWithContext(ctx context.Context, reason string, elapsed time.Duration) error { + if cause := context.Cause(ctx); cause != nil { + return cause + } + dbc.stats.KillCounters.Add("Queries", 1) log.Infof("Due to %s, elapsed time: %v, killing query ID %v %s", reason, elapsed, dbc.conn.ID(), dbc.CurrentForLogging()) @@ -376,20 +385,37 @@ func (dbc *Conn) Kill(reason string, elapsed time.Duration) error { dbc.conn.Close() // Server side action. Kill the session. - killConn, err := dbc.dbaPool.Get(context.TODO()) + killConn, err := dbc.dbaPool.Get(ctx) if err != nil { log.Warningf("Failed to get conn from dba pool: %v", err) return err } defer killConn.Recycle() - sql := fmt.Sprintf("kill %d", dbc.conn.ID()) - _, err = killConn.Conn.ExecuteFetch(sql, 10000, false) - if err != nil { - log.Errorf("Could not kill query ID %v %s: %v", dbc.conn.ID(), - dbc.CurrentForLogging(), err) + + errChan := make(chan error, 1) + resultChan := make(chan *sqltypes.Result, 1) + + go func() { + sql := fmt.Sprintf("kill %d", dbc.conn.ID()) + // TODO: Allow pushing ctx down to ExecuteFetch. + result, err := killConn.Conn.ExecuteFetch(sql, 10000, false) + if err != nil { + errChan <- err + } else { + resultChan <- result + } + }() + + select { + case <-ctx.Done(): + killConn.Close() + return context.Cause(ctx) + case err := <-errChan: + log.Errorf("Could not kill query ID %v %s: %v", dbc.conn.ID(), dbc.CurrentForLogging(), err) return err + case <-resultChan: + return nil } - return nil } // Current returns the currently executing query.