From ffd4555a0c45badc76308e434dce6171638ef150 Mon Sep 17 00:00:00 2001 From: David Piegza <697113+davidpiegza@users.noreply.github.com> Date: Mon, 25 Sep 2023 10:53:31 +0000 Subject: [PATCH] Add shutdown state in MySQL server plugin --- go/mysql/conn.go | 5 +++++ go/vt/vtgate/plugin_mysql_server.go | 11 ++++++++--- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/go/mysql/conn.go b/go/mysql/conn.go index 6f3643ebc7f..87959a03302 100644 --- a/go/mysql/conn.go +++ b/go/mysql/conn.go @@ -209,6 +209,7 @@ type Conn struct { // cancel keep the cancel function for the current executing query. // this is used by `kill [query|connection] ID` command from other connection. cancel context.CancelFunc + // this is used to mark the connection to be closed so that the command phase for the connection can be stopped and // the connection gets closed. closing bool @@ -1711,3 +1712,7 @@ func (c *Conn) IsMarkedForClose() bool { func GetTestConn() *Conn { return newConn(testConn{}) } + +func (c *Conn) IsShuttingDown() bool { + return c.listener.shutdown.Load() +} diff --git a/go/vt/vtgate/plugin_mysql_server.go b/go/vt/vtgate/plugin_mysql_server.go index bfbb7b105f8..273592b5bf7 100644 --- a/go/vt/vtgate/plugin_mysql_server.go +++ b/go/vt/vtgate/plugin_mysql_server.go @@ -201,6 +201,12 @@ func startSpan(ctx context.Context, query, label string) (trace.Span, context.Co } func (vh *vtgateHandler) ComQuery(c *mysql.Conn, query string, callback func(*sqltypes.Result) error) error { + session := vh.session(c) + if c.IsShuttingDown() && !session.InTransaction { + c.MarkForClose() + return sqlerror.NewSQLError(sqlerror.ERServerShutdown, sqlerror.SSNetError, "Server shutdown in progress") + } + ctx, cancel := context.WithCancel(context.Background()) c.UpdateCancelCtx(cancel) @@ -229,7 +235,6 @@ func (vh *vtgateHandler) ComQuery(c *mysql.Conn, query string, callback func(*sq "VTGate MySQL Connector" /* subcomponent: part of the client */) ctx = callerid.NewContext(ctx, ef, im) - session := vh.session(c) if !session.InTransaction { vh.busyConnections.Add(1) } @@ -614,11 +619,11 @@ func newMysqlUnixSocket(address string, authServer mysql.AuthServer, handler mys func (srv *mysqlServer) shutdownMysqlProtocolAndDrain() { if srv.tcpListener != nil { - srv.tcpListener.Close() + srv.tcpListener.Shutdown() srv.tcpListener = nil } if srv.unixListener != nil { - srv.unixListener.Close() + srv.unixListener.Shutdown() srv.unixListener = nil } if srv.sigChan != nil {