diff --git a/pkg/proxy/backend/authenticator.go b/pkg/proxy/backend/authenticator.go index 64de7ae3..90249107 100644 --- a/pkg/proxy/backend/authenticator.go +++ b/pkg/proxy/backend/authenticator.go @@ -11,7 +11,7 @@ import ( "net" "time" - "github.com/pingcap/tidb/parser/mysql" + "github.com/go-mysql-org/go-mysql/mysql" "github.com/pingcap/tidb/util/hack" "github.com/pingcap/tiproxy/lib/util/errors" pnet "github.com/pingcap/tiproxy/pkg/proxy/net" @@ -73,7 +73,7 @@ func (auth *Authenticator) writeProxyProtocol(clientIO, backendIO *pnet.PacketIO } func (auth *Authenticator) verifyBackendCaps(logger *zap.Logger, backendCapability pnet.Capability) error { - requiredBackendCaps := defRequiredBackendCaps & pnet.Capability(auth.capability) + requiredBackendCaps := defRequiredBackendCaps & auth.capability if auth.requireBackendTLS { requiredBackendCaps |= pnet.ClientSSL } @@ -100,7 +100,7 @@ func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, cctx ConnConte } cid, _ := cctx.Value(ConnContextKeyConnID).(uint64) - if err := clientIO.WriteInitialHandshake(proxyCapability, auth.salt, mysql.AuthNativePassword, handshakeHandler.GetServerVersion(), cid); err != nil { + if err := clientIO.WriteInitialHandshake(proxyCapability, auth.salt, pnet.AuthNativePassword, handshakeHandler.GetServerVersion(), cid); err != nil { return err } pkt, isSSL, err := clientIO.ReadSSLRequestOrHandshakeResp() @@ -126,7 +126,7 @@ func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, cctx ConnConte } if commonCaps := frontendCapability & requiredFrontendCaps; commonCaps != requiredFrontendCaps { logger.Error("require frontend capabilities", zap.Stringer("common", commonCaps), zap.Stringer("required", requiredFrontendCaps)) - if writeErr := clientIO.WriteErrPacket(mysql.ErrNotSupportedAuthMode); writeErr != nil { + if writeErr := clientIO.WriteErrPacket(mysql.NewDefaultError(mysql.ER_NOT_SUPPORTED_AUTH_MODE)); writeErr != nil { return writeErr } return errors.Wrapf(ErrCapabilityNegotiation, "require %s from frontend", requiredFrontendCaps&^commonCaps) @@ -220,14 +220,14 @@ loop: return err } switch serverPkt[0] { - case mysql.OKHeader: + case pnet.OKHeader.Byte(): return nil - case mysql.ErrHeader: + case pnet.ErrHeader.Byte(): return pnet.ParseErrorPacket(serverPkt) default: // mysql.AuthSwitchRequest, ShaCommand - if serverPkt[0] == mysql.AuthSwitchRequest { + if serverPkt[0] == pnet.AuthSwitchHeader.Byte() { pluginName = string(serverPkt[1 : bytes.IndexByte(serverPkt[1:], 0)+1]) - } else if serverPkt[0] == 1 && pluginName == mysql.AuthCachingSha2Password && len(serverPkt) == 2 && serverPkt[1] == 3 { + } else if serverPkt[0] == 1 && pluginName == pnet.AuthCachingSha2Password && len(serverPkt) == 2 && serverPkt[1] == 3 { // caching_sha2_password fast path continue loop } @@ -262,7 +262,7 @@ func (auth *Authenticator) handshakeSecondTime(logger *zap.Logger, clientIO, bac return err } - if err := auth.verifyBackendCaps(logger, pnet.Capability(backendCapability)); err != nil { + if err := auth.verifyBackendCaps(logger, backendCapability); err != nil { return err } @@ -320,7 +320,7 @@ func (auth *Authenticator) writeAuthHandshake( enableTLS = true } else { // When client TLS is disabled, also disables proxy TLS. - enableTLS = pnet.Capability(auth.capability)&pnet.ClientSSL != 0 && backendCapability&pnet.ClientSSL != 0 && backendTLSConfig != nil + enableTLS = auth.capability&pnet.ClientSSL != 0 && backendCapability&pnet.ClientSSL != 0 && backendTLSConfig != nil } if enableTLS { resp.Capability |= pnet.ClientSSL @@ -355,9 +355,9 @@ func (auth *Authenticator) handleSecondAuthResult(backendIO *pnet.PacketIO) erro } switch data[0] { - case mysql.OKHeader: + case pnet.OKHeader.Byte(): return nil - case mysql.ErrHeader: + case pnet.ErrHeader.Byte(): return pnet.ParseErrorPacket(data) default: // mysql.AuthSwitchRequest, ShaCommand: return errors.Errorf("read unexpected command: %#x", data[0]) diff --git a/pkg/proxy/backend/mock_backend_test.go b/pkg/proxy/backend/mock_backend_test.go index 97d2ffe7..5b15fc79 100644 --- a/pkg/proxy/backend/mock_backend_test.go +++ b/pkg/proxy/backend/mock_backend_test.go @@ -7,8 +7,7 @@ import ( "crypto/tls" "encoding/binary" - gomysql "github.com/go-mysql-org/go-mysql/mysql" - "github.com/pingcap/tidb/parser/mysql" + "github.com/go-mysql-org/go-mysql/mysql" pnet "github.com/pingcap/tiproxy/pkg/proxy/net" ) @@ -74,7 +73,7 @@ func (mb *mockBackend) authenticate(packetIO *pnet.PacketIO) error { } // upgrade to TLS capability := binary.LittleEndian.Uint16(clientPkt[:2]) - sslEnabled := uint32(capability)&mysql.ClientSSL > 0 && mb.capability&pnet.ClientSSL > 0 + sslEnabled := pnet.Capability(capability)&pnet.ClientSSL > 0 && mb.capability&pnet.ClientSSL > 0 if sslEnabled { if _, err = packetIO.ServerTLSHandshake(mb.tlsConfig); err != nil { return err @@ -98,7 +97,7 @@ func (mb *mockBackend) authenticate(packetIO *pnet.PacketIO) error { } func (mb *mockBackend) verifyPassword(packetIO *pnet.PacketIO, resp *pnet.HandshakeResp) error { - if resp.AuthPlugin != mysql.AuthTiDBSessionToken { + if resp.AuthPlugin != pnet.AuthTiDBSessionToken { var err error if err = packetIO.WriteSwitchRequest(mb.authPlugin, mb.salt); err != nil { return err @@ -107,7 +106,7 @@ func (mb *mockBackend) verifyPassword(packetIO *pnet.PacketIO, resp *pnet.Handsh return err } switch mb.authPlugin { - case mysql.AuthCachingSha2Password: + case pnet.AuthCachingSha2Password: if err = packetIO.WriteShaCommand(); err != nil { return err } @@ -121,7 +120,7 @@ func (mb *mockBackend) verifyPassword(packetIO *pnet.PacketIO, resp *pnet.Handsh return err } } else { - if err := packetIO.WriteErrPacket(mysql.ErrAccessDenied); err != nil { + if err := packetIO.WriteErrPacket(mysql.NewDefaultError(mysql.ER_ACCESS_DENIED_ERROR)); err != nil { return err } } @@ -150,7 +149,7 @@ func (mb *mockBackend) respondOnce(packetIO *pnet.PacketIO) error { case responseTypeOK: return mb.respondOK(packetIO) case responseTypeErr: - return packetIO.WriteErrPacket(mysql.ErrUnknown) + return packetIO.WriteErrPacket(mysql.NewDefaultError(mysql.ER_UNKNOWN_ERROR)) case responseTypeResultSet: if pnet.Command(pkt[0]) == pnet.ComQuery && string(pkt[1:]) == sqlQueryState { return mb.respondSessionStates(packetIO) @@ -179,16 +178,16 @@ func (mb *mockBackend) respondOnce(packetIO *pnet.PacketIO) error { case responseTypeNone: return nil } - return packetIO.WriteErrPacket(mysql.ErrUnknown) + return packetIO.WriteErrPacket(mysql.NewDefaultError(mysql.ER_UNKNOWN_ERROR)) } func (mb *mockBackend) respondOK(packetIO *pnet.PacketIO) error { for i := 0; i < mb.stmtNum; i++ { status := mb.status if i < mb.stmtNum-1 { - status |= mysql.ServerMoreResultsExists + status |= mysql.SERVER_MORE_RESULTS_EXISTS } else { - status &= ^mysql.ServerMoreResultsExists + status &= ^mysql.SERVER_MORE_RESULTS_EXISTS } if err := packetIO.WriteOKPacket(status, pnet.OKHeader); err != nil { return err @@ -242,16 +241,16 @@ func (mb *mockBackend) respondResultSet(packetIO *pnet.PacketIO) error { } func (mb *mockBackend) writeResultSet(packetIO *pnet.PacketIO, names []string, values [][]any) error { - rs, err := gomysql.BuildSimpleTextResultset(names, values) + rs, err := mysql.BuildSimpleTextResultset(names, values) if err != nil { return err } for i := 0; i < mb.stmtNum; i++ { status := mb.status if i < mb.stmtNum-1 { - status |= mysql.ServerMoreResultsExists + status |= mysql.SERVER_MORE_RESULTS_EXISTS } else { - status &= ^mysql.ServerMoreResultsExists + status &= ^mysql.SERVER_MORE_RESULTS_EXISTS } data := pnet.DumpLengthEncodedInt(nil, uint64(len(names))) if err := packetIO.WritePacket(data, false); err != nil { @@ -263,7 +262,7 @@ func (mb *mockBackend) writeResultSet(packetIO *pnet.PacketIO, names []string, v } } - if status&mysql.ServerStatusCursorExists == 0 { + if status&mysql.SERVER_STATUS_CURSOR_EXISTS == 0 { if mb.capability&pnet.ClientDeprecateEOF == 0 { if err := packetIO.WriteEOFPacket(status); err != nil { return err @@ -291,12 +290,12 @@ func (mb *mockBackend) respondLoadFile(packetIO *pnet.PacketIO) error { for i := 0; i < mb.stmtNum; i++ { status := mb.status if i < mb.stmtNum-1 { - status |= mysql.ServerMoreResultsExists + status |= mysql.SERVER_MORE_RESULTS_EXISTS } else { - status &= ^mysql.ServerMoreResultsExists + status &= ^mysql.SERVER_MORE_RESULTS_EXISTS } data := make([]byte, 0, 1+len(mockCmdStr)) - data = append(data, mysql.LocalInFileHeader) + data = append(data, pnet.LocalInFileHeader.Byte()) data = append(data, []byte(mockCmdStr)...) if err := packetIO.WritePacket(data, true); err != nil { return err @@ -321,7 +320,7 @@ func (mb *mockBackend) respondLoadFile(packetIO *pnet.PacketIO) error { // respond to Prepare func (mb *mockBackend) respondPrepare(packetIO *pnet.PacketIO) error { - data := []byte{mysql.OKHeader} + data := []byte{pnet.OKHeader.Byte()} data = pnet.DumpUint32(data, uint32(mockCmdInt)) data = pnet.DumpUint16(data, uint16(mb.columns)) data = pnet.DumpUint16(data, uint16(mb.params)) diff --git a/pkg/proxy/net/packetio_mysql.go b/pkg/proxy/net/packetio_mysql.go index 313864bd..f17ba3e3 100644 --- a/pkg/proxy/net/packetio_mysql.go +++ b/pkg/proxy/net/packetio_mysql.go @@ -5,7 +5,6 @@ package net import ( "encoding/binary" - "fmt" "github.com/go-mysql-org/go-mysql/mysql" "github.com/pingcap/tiproxy/lib/util/errors" @@ -96,26 +95,14 @@ func (p *PacketIO) ReadSSLRequestOrHandshakeResp() (pkt []byte, isSSL bool, err } // WriteErrPacket writes an Error packet. -func (p *PacketIO) WriteErrPacket(code uint16, message ...any) error { - data := make([]byte, 0, 9+len(message)) +func (p *PacketIO) WriteErrPacket(merr *mysql.MyError) error { + data := make([]byte, 0, 9+len(merr.Message)) data = append(data, ErrHeader.Byte()) - data = append(data, byte(code), byte(code>>8)) - - // TODO: ClientProtocol41 must be enabled for state + data = append(data, byte(merr.Code), byte(merr.Code>>8)) + // ClientProtocol41 is always enabled. data = append(data, '#') - s, ok := mysql.MySQLState[code] - if !ok { - s = mysql.DEFAULT_MYSQL_STATE - } - data = append(data, s...) - - var msg string - if format, ok := mysql.MySQLErrName[code]; ok { - msg = fmt.Sprintf(format, message...) - } else { - msg = fmt.Sprint(message...) - } - data = append(data, msg...) + data = append(data, merr.State...) + data = append(data, merr.Message...) return p.WritePacket(data, true) } @@ -124,7 +111,7 @@ func (p *PacketIO) WriteOKPacket(status uint16, header Header) error { data := make([]byte, 0, 7) data = append(data, header.Byte()) data = append(data, 0, 0) - // ClientProtocol41 must be enabled. + // ClientProtocol41 is always enabled. data = DumpUint16(data, status) data = append(data, 0, 0) return p.WritePacket(data, true) @@ -135,7 +122,7 @@ func (p *PacketIO) WriteEOFPacket(status uint16) error { data := make([]byte, 0, 5) data = append(data, EOFHeader.Byte()) data = append(data, 0, 0) - // ClientProtocol41 must be enabled. + // ClientProtocol41 is always enabled. data = DumpUint16(data, status) return p.WritePacket(data, true) } @@ -149,7 +136,8 @@ func (p *PacketIO) WriteUserError(err error) { if !errors.As(err, &ue) { return } - if writeErr := p.WriteErrPacket(mysql.ER_UNKNOWN_ERROR, ue.UserMsg()); writeErr != nil { + myErr := mysql.NewError(mysql.ER_UNKNOWN_ERROR, ue.UserMsg()) + if writeErr := p.WriteErrPacket(myErr); writeErr != nil { p.logger.Error("writing error to client failed", zap.NamedError("mysql_err", err), zap.NamedError("write_err", writeErr)) } } diff --git a/pkg/proxy/net/packetio_test.go b/pkg/proxy/net/packetio_test.go index b3701ddd..14e0bdfc 100644 --- a/pkg/proxy/net/packetio_test.go +++ b/pkg/proxy/net/packetio_test.go @@ -9,6 +9,7 @@ import ( "testing" "time" + "github.com/go-mysql-org/go-mysql/mysql" "github.com/pingcap/tiproxy/lib/config" "github.com/pingcap/tiproxy/lib/util/logger" "github.com/pingcap/tiproxy/lib/util/security" @@ -257,3 +258,32 @@ func TestKeepAlive(t *testing.T) { 1, ) } + +func TestPredefinedPacket(t *testing.T) { + testTCPConn(t, + func(t *testing.T, cli *PacketIO) { + data, err := cli.ReadPacket() + require.NoError(t, err) + merr := ParseErrorPacket(data).(*mysql.MyError) + require.Equal(t, uint16(mysql.ER_UNKNOWN_ERROR), merr.Code) + require.Equal(t, "Unknown error", merr.Message) + + data, err = cli.ReadPacket() + require.NoError(t, err) + merr = ParseErrorPacket(data).(*mysql.MyError) + require.Equal(t, uint16(mysql.ER_UNKNOWN_ERROR), merr.Code) + require.Equal(t, "test error", merr.Message) + + data, err = cli.ReadPacket() + require.NoError(t, err) + res := ParseOKPacket(data) + require.Equal(t, uint16(100), res.Status) + }, + func(t *testing.T, srv *PacketIO) { + require.NoError(t, srv.WriteErrPacket(mysql.NewDefaultError(mysql.ER_UNKNOWN_ERROR))) + require.NoError(t, srv.WriteErrPacket(mysql.NewError(mysql.ER_UNKNOWN_ERROR, "test error"))) + require.NoError(t, srv.WriteOKPacket(100, OKHeader)) + }, + 1, + ) +}