diff --git a/go/vt/vttablet/tabletserver/connpool/dbconn.go b/go/vt/vttablet/tabletserver/connpool/dbconn.go index ab5b69535d4..5632c651e84 100644 --- a/go/vt/vttablet/tabletserver/connpool/dbconn.go +++ b/go/vt/vttablet/tabletserver/connpool/dbconn.go @@ -52,7 +52,7 @@ type Conn struct { env tabletenv.Env dbaPool *dbconnpool.ConnectionPool stats *tabletenv.Stats - current atomic.Value + current atomic.Pointer[string] // err will be set if a query is killed through a Kill. errmu sync.Mutex @@ -76,7 +76,6 @@ func newPooledConn(ctx context.Context, pool *Pool, appParams dbconfigs.Connecto stats: pool.env.Stats(), dbaPool: pool.dbaPool, } - db.current.Store("") return db, nil } @@ -91,7 +90,6 @@ func NewConn(ctx context.Context, params dbconfigs.Connector, dbaPool *dbconnpoo dbaPool: dbaPool, stats: tabletenv.NewStats(servenv.NewExporter("Temp", "Tablet")), } - dbconn.current.Store("") if setting == nil { return dbconn, nil } @@ -152,8 +150,8 @@ func (dbc *Conn) Exec(ctx context.Context, query string, maxrows int, wantfields } func (dbc *Conn) execOnce(ctx context.Context, query string, maxrows int, wantfields bool) (*sqltypes.Result, error) { - dbc.current.Store(query) - defer dbc.current.Store("") + dbc.current.Store(&query) + defer dbc.current.Store(nil) // Check if the context is already past its deadline before // trying to execute the query. @@ -161,40 +159,33 @@ func (dbc *Conn) execOnce(ctx context.Context, query string, maxrows int, wantfi return nil, fmt.Errorf("%v before execution started", err) } - defer dbc.stats.MySQLTimings.Record("Exec", time.Now()) + now := time.Now() + defer dbc.stats.MySQLTimings.Record("Exec", now) - resultChan := make(chan *sqltypes.Result, 1) - errChan := make(chan error, 1) + type execResult struct { + result *sqltypes.Result + err error + } - startTime := time.Now() + ch := make(chan execResult) go func() { result, err := dbc.conn.ExecuteFetch(query, maxrows, wantfields) - if err != nil { - errChan <- err - } else { - resultChan <- result - } + ch <- execResult{result, err} }() - var err error - var result *sqltypes.Result - select { case <-ctx.Done(): killCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - dbc.KillWithContext(killCtx, ctx.Err().Error(), time.Since(startTime)) + _ = dbc.KillWithContext(killCtx, ctx.Err().Error(), time.Since(now)) return nil, dbc.Err() - case err = <-errChan: - case result = <-resultChan: - } - - if dbcErr := dbc.Err(); dbcErr != nil { - return nil, dbcErr + case r := <-ch: + if dbcErr := dbc.Err(); dbcErr != nil { + return nil, dbcErr + } + return r.result, r.err } - - return result, err } // ExecOnce executes the specified query, but does not retry on connection errors. @@ -270,35 +261,30 @@ func (dbc *Conn) Stream(ctx context.Context, query string, callback func(*sqltyp } func (dbc *Conn) streamOnce(ctx context.Context, query string, callback func(*sqltypes.Result) error, alloc func() *sqltypes.Result, streamBufferSize int) error { - defer dbc.stats.MySQLTimings.Record("ExecStream", time.Now()) - - dbc.current.Store(query) - defer dbc.current.Store("") + dbc.current.Store(&query) + defer dbc.current.Store(nil) - errChan := make(chan error, 1) - startTime := time.Now() + now := time.Now() + defer dbc.stats.MySQLTimings.Record("ExecStream", now) + ch := make(chan error) go func() { - errChan <- dbc.conn.ExecuteStreamFetch(query, callback, alloc, streamBufferSize) + ch <- dbc.conn.ExecuteStreamFetch(query, callback, alloc, streamBufferSize) }() - var err error - select { case <-ctx.Done(): killCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - dbc.KillWithContext(killCtx, ctx.Err().Error(), time.Since(startTime)) + _ = dbc.KillWithContext(killCtx, ctx.Err().Error(), time.Since(now)) return dbc.Err() - case err = <-errChan: - } - - if dbcErr := dbc.Err(); dbcErr != nil { - return dbcErr + case err := <-ch: + if dbcErr := dbc.Err(); dbcErr != nil { + return dbcErr + } + return err } - - return err } // StreamOnce executes the query and streams the results. But, does not retry on connection errors. @@ -401,7 +387,7 @@ func (dbc *Conn) Kill(reason string, elapsed time.Duration) error { return dbc.KillWithContext(context.Background(), reason, elapsed) } -// Kill kills the currently executing query both on MySQL side +// KillWithContext kills the currently executing query both on MySQL side // and on the connection side. If no query is executing, it's a no-op. // Kill will also not kill a query more than once. func (dbc *Conn) KillWithContext(ctx context.Context, reason string, elapsed time.Duration) error { @@ -426,18 +412,11 @@ func (dbc *Conn) KillWithContext(ctx context.Context, reason string, elapsed tim } defer killConn.Recycle() - errChan := make(chan error, 1) - resultChan := make(chan *sqltypes.Result, 1) - + ch := make(chan error) + sql := fmt.Sprintf("kill %d", dbc.conn.ID()) go func() { - sql := fmt.Sprintf("kill %d", dbc.conn.ID()) - // TODO: Allow pushing ctx down to ExecuteFetch. - result, err := killConn.Conn.ExecuteFetch(sql, 10000, false) - if err != nil { - errChan <- err - } else { - resultChan <- result - } + _, err := killConn.Conn.ExecuteFetch(sql, -1, false) + ch <- err }() select { @@ -448,17 +427,21 @@ func (dbc *Conn) KillWithContext(ctx context.Context, reason string, elapsed tim log.Warningf("Query may be hung: %s", dbc.CurrentForLogging()) return context.Cause(ctx) - case err := <-errChan: - log.Errorf("Could not kill query ID %v %s: %v", dbc.conn.ID(), dbc.CurrentForLogging(), err) - return err - case <-resultChan: + case err := <-ch: + if err != nil { + log.Errorf("Could not kill query ID %v %s: %v", dbc.conn.ID(), dbc.CurrentForLogging(), err) + return err + } return nil } } // Current returns the currently executing query. func (dbc *Conn) Current() string { - return dbc.current.Load().(string) + if q := dbc.current.Load(); q != nil { + return *q + } + return "" } // ID returns the connection id.