Skip to content

Commit

Permalink
net, backend: fix wrong format in error packet (#357)
Browse files Browse the repository at this point in the history
  • Loading branch information
djshow832 authored Sep 4, 2023
1 parent cc62e05 commit 53a957f
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 52 deletions.
24 changes: 12 additions & 12 deletions pkg/proxy/backend/authenticator.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (
"net"
"time"

"github.com/pingcap/tidb/parser/mysql"
"github.com/go-mysql-org/go-mysql/mysql"
"github.com/pingcap/tidb/util/hack"
"github.com/pingcap/tiproxy/lib/util/errors"
pnet "github.com/pingcap/tiproxy/pkg/proxy/net"
Expand Down Expand Up @@ -73,7 +73,7 @@ func (auth *Authenticator) writeProxyProtocol(clientIO, backendIO *pnet.PacketIO
}

func (auth *Authenticator) verifyBackendCaps(logger *zap.Logger, backendCapability pnet.Capability) error {
requiredBackendCaps := defRequiredBackendCaps & pnet.Capability(auth.capability)
requiredBackendCaps := defRequiredBackendCaps & auth.capability
if auth.requireBackendTLS {
requiredBackendCaps |= pnet.ClientSSL
}
Expand All @@ -100,7 +100,7 @@ func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, cctx ConnConte
}

cid, _ := cctx.Value(ConnContextKeyConnID).(uint64)
if err := clientIO.WriteInitialHandshake(proxyCapability, auth.salt, mysql.AuthNativePassword, handshakeHandler.GetServerVersion(), cid); err != nil {
if err := clientIO.WriteInitialHandshake(proxyCapability, auth.salt, pnet.AuthNativePassword, handshakeHandler.GetServerVersion(), cid); err != nil {
return err
}
pkt, isSSL, err := clientIO.ReadSSLRequestOrHandshakeResp()
Expand All @@ -126,7 +126,7 @@ func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, cctx ConnConte
}
if commonCaps := frontendCapability & requiredFrontendCaps; commonCaps != requiredFrontendCaps {
logger.Error("require frontend capabilities", zap.Stringer("common", commonCaps), zap.Stringer("required", requiredFrontendCaps))
if writeErr := clientIO.WriteErrPacket(mysql.ErrNotSupportedAuthMode); writeErr != nil {
if writeErr := clientIO.WriteErrPacket(mysql.NewDefaultError(mysql.ER_NOT_SUPPORTED_AUTH_MODE)); writeErr != nil {
return writeErr
}
return errors.Wrapf(ErrCapabilityNegotiation, "require %s from frontend", requiredFrontendCaps&^commonCaps)
Expand Down Expand Up @@ -220,14 +220,14 @@ loop:
return err
}
switch serverPkt[0] {
case mysql.OKHeader:
case pnet.OKHeader.Byte():
return nil
case mysql.ErrHeader:
case pnet.ErrHeader.Byte():
return pnet.ParseErrorPacket(serverPkt)
default: // mysql.AuthSwitchRequest, ShaCommand
if serverPkt[0] == mysql.AuthSwitchRequest {
if serverPkt[0] == pnet.AuthSwitchHeader.Byte() {
pluginName = string(serverPkt[1 : bytes.IndexByte(serverPkt[1:], 0)+1])
} else if serverPkt[0] == 1 && pluginName == mysql.AuthCachingSha2Password && len(serverPkt) == 2 && serverPkt[1] == 3 {
} else if serverPkt[0] == 1 && pluginName == pnet.AuthCachingSha2Password && len(serverPkt) == 2 && serverPkt[1] == 3 {
// caching_sha2_password fast path
continue loop
}
Expand Down Expand Up @@ -262,7 +262,7 @@ func (auth *Authenticator) handshakeSecondTime(logger *zap.Logger, clientIO, bac
return err
}

if err := auth.verifyBackendCaps(logger, pnet.Capability(backendCapability)); err != nil {
if err := auth.verifyBackendCaps(logger, backendCapability); err != nil {
return err
}

Expand Down Expand Up @@ -320,7 +320,7 @@ func (auth *Authenticator) writeAuthHandshake(
enableTLS = true
} else {
// When client TLS is disabled, also disables proxy TLS.
enableTLS = pnet.Capability(auth.capability)&pnet.ClientSSL != 0 && backendCapability&pnet.ClientSSL != 0 && backendTLSConfig != nil
enableTLS = auth.capability&pnet.ClientSSL != 0 && backendCapability&pnet.ClientSSL != 0 && backendTLSConfig != nil
}
if enableTLS {
resp.Capability |= pnet.ClientSSL
Expand Down Expand Up @@ -355,9 +355,9 @@ func (auth *Authenticator) handleSecondAuthResult(backendIO *pnet.PacketIO) erro
}

switch data[0] {
case mysql.OKHeader:
case pnet.OKHeader.Byte():
return nil
case mysql.ErrHeader:
case pnet.ErrHeader.Byte():
return pnet.ParseErrorPacket(data)
default: // mysql.AuthSwitchRequest, ShaCommand:
return errors.Errorf("read unexpected command: %#x", data[0])
Expand Down
35 changes: 17 additions & 18 deletions pkg/proxy/backend/mock_backend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@ import (
"crypto/tls"
"encoding/binary"

gomysql "github.com/go-mysql-org/go-mysql/mysql"
"github.com/pingcap/tidb/parser/mysql"
"github.com/go-mysql-org/go-mysql/mysql"
pnet "github.com/pingcap/tiproxy/pkg/proxy/net"
)

Expand Down Expand Up @@ -74,7 +73,7 @@ func (mb *mockBackend) authenticate(packetIO *pnet.PacketIO) error {
}
// upgrade to TLS
capability := binary.LittleEndian.Uint16(clientPkt[:2])
sslEnabled := uint32(capability)&mysql.ClientSSL > 0 && mb.capability&pnet.ClientSSL > 0
sslEnabled := pnet.Capability(capability)&pnet.ClientSSL > 0 && mb.capability&pnet.ClientSSL > 0
if sslEnabled {
if _, err = packetIO.ServerTLSHandshake(mb.tlsConfig); err != nil {
return err
Expand All @@ -98,7 +97,7 @@ func (mb *mockBackend) authenticate(packetIO *pnet.PacketIO) error {
}

func (mb *mockBackend) verifyPassword(packetIO *pnet.PacketIO, resp *pnet.HandshakeResp) error {
if resp.AuthPlugin != mysql.AuthTiDBSessionToken {
if resp.AuthPlugin != pnet.AuthTiDBSessionToken {
var err error
if err = packetIO.WriteSwitchRequest(mb.authPlugin, mb.salt); err != nil {
return err
Expand All @@ -107,7 +106,7 @@ func (mb *mockBackend) verifyPassword(packetIO *pnet.PacketIO, resp *pnet.Handsh
return err
}
switch mb.authPlugin {
case mysql.AuthCachingSha2Password:
case pnet.AuthCachingSha2Password:
if err = packetIO.WriteShaCommand(); err != nil {
return err
}
Expand All @@ -121,7 +120,7 @@ func (mb *mockBackend) verifyPassword(packetIO *pnet.PacketIO, resp *pnet.Handsh
return err
}
} else {
if err := packetIO.WriteErrPacket(mysql.ErrAccessDenied); err != nil {
if err := packetIO.WriteErrPacket(mysql.NewDefaultError(mysql.ER_ACCESS_DENIED_ERROR)); err != nil {
return err
}
}
Expand Down Expand Up @@ -150,7 +149,7 @@ func (mb *mockBackend) respondOnce(packetIO *pnet.PacketIO) error {
case responseTypeOK:
return mb.respondOK(packetIO)
case responseTypeErr:
return packetIO.WriteErrPacket(mysql.ErrUnknown)
return packetIO.WriteErrPacket(mysql.NewDefaultError(mysql.ER_UNKNOWN_ERROR))
case responseTypeResultSet:
if pnet.Command(pkt[0]) == pnet.ComQuery && string(pkt[1:]) == sqlQueryState {
return mb.respondSessionStates(packetIO)
Expand Down Expand Up @@ -179,16 +178,16 @@ func (mb *mockBackend) respondOnce(packetIO *pnet.PacketIO) error {
case responseTypeNone:
return nil
}
return packetIO.WriteErrPacket(mysql.ErrUnknown)
return packetIO.WriteErrPacket(mysql.NewDefaultError(mysql.ER_UNKNOWN_ERROR))
}

func (mb *mockBackend) respondOK(packetIO *pnet.PacketIO) error {
for i := 0; i < mb.stmtNum; i++ {
status := mb.status
if i < mb.stmtNum-1 {
status |= mysql.ServerMoreResultsExists
status |= mysql.SERVER_MORE_RESULTS_EXISTS
} else {
status &= ^mysql.ServerMoreResultsExists
status &= ^mysql.SERVER_MORE_RESULTS_EXISTS
}
if err := packetIO.WriteOKPacket(status, pnet.OKHeader); err != nil {
return err
Expand Down Expand Up @@ -242,16 +241,16 @@ func (mb *mockBackend) respondResultSet(packetIO *pnet.PacketIO) error {
}

func (mb *mockBackend) writeResultSet(packetIO *pnet.PacketIO, names []string, values [][]any) error {
rs, err := gomysql.BuildSimpleTextResultset(names, values)
rs, err := mysql.BuildSimpleTextResultset(names, values)
if err != nil {
return err
}
for i := 0; i < mb.stmtNum; i++ {
status := mb.status
if i < mb.stmtNum-1 {
status |= mysql.ServerMoreResultsExists
status |= mysql.SERVER_MORE_RESULTS_EXISTS
} else {
status &= ^mysql.ServerMoreResultsExists
status &= ^mysql.SERVER_MORE_RESULTS_EXISTS
}
data := pnet.DumpLengthEncodedInt(nil, uint64(len(names)))
if err := packetIO.WritePacket(data, false); err != nil {
Expand All @@ -263,7 +262,7 @@ func (mb *mockBackend) writeResultSet(packetIO *pnet.PacketIO, names []string, v
}
}

if status&mysql.ServerStatusCursorExists == 0 {
if status&mysql.SERVER_STATUS_CURSOR_EXISTS == 0 {
if mb.capability&pnet.ClientDeprecateEOF == 0 {
if err := packetIO.WriteEOFPacket(status); err != nil {
return err
Expand Down Expand Up @@ -291,12 +290,12 @@ func (mb *mockBackend) respondLoadFile(packetIO *pnet.PacketIO) error {
for i := 0; i < mb.stmtNum; i++ {
status := mb.status
if i < mb.stmtNum-1 {
status |= mysql.ServerMoreResultsExists
status |= mysql.SERVER_MORE_RESULTS_EXISTS
} else {
status &= ^mysql.ServerMoreResultsExists
status &= ^mysql.SERVER_MORE_RESULTS_EXISTS
}
data := make([]byte, 0, 1+len(mockCmdStr))
data = append(data, mysql.LocalInFileHeader)
data = append(data, pnet.LocalInFileHeader.Byte())
data = append(data, []byte(mockCmdStr)...)
if err := packetIO.WritePacket(data, true); err != nil {
return err
Expand All @@ -321,7 +320,7 @@ func (mb *mockBackend) respondLoadFile(packetIO *pnet.PacketIO) error {

// respond to Prepare
func (mb *mockBackend) respondPrepare(packetIO *pnet.PacketIO) error {
data := []byte{mysql.OKHeader}
data := []byte{pnet.OKHeader.Byte()}
data = pnet.DumpUint32(data, uint32(mockCmdInt))
data = pnet.DumpUint16(data, uint16(mb.columns))
data = pnet.DumpUint16(data, uint16(mb.params))
Expand Down
32 changes: 10 additions & 22 deletions pkg/proxy/net/packetio_mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ package net

import (
"encoding/binary"
"fmt"

"github.com/go-mysql-org/go-mysql/mysql"
"github.com/pingcap/tiproxy/lib/util/errors"
Expand Down Expand Up @@ -96,26 +95,14 @@ func (p *PacketIO) ReadSSLRequestOrHandshakeResp() (pkt []byte, isSSL bool, err
}

// WriteErrPacket writes an Error packet.
func (p *PacketIO) WriteErrPacket(code uint16, message ...any) error {
data := make([]byte, 0, 9+len(message))
func (p *PacketIO) WriteErrPacket(merr *mysql.MyError) error {
data := make([]byte, 0, 9+len(merr.Message))
data = append(data, ErrHeader.Byte())
data = append(data, byte(code), byte(code>>8))

// TODO: ClientProtocol41 must be enabled for state
data = append(data, byte(merr.Code), byte(merr.Code>>8))
// ClientProtocol41 is always enabled.
data = append(data, '#')
s, ok := mysql.MySQLState[code]
if !ok {
s = mysql.DEFAULT_MYSQL_STATE
}
data = append(data, s...)

var msg string
if format, ok := mysql.MySQLErrName[code]; ok {
msg = fmt.Sprintf(format, message...)
} else {
msg = fmt.Sprint(message...)
}
data = append(data, msg...)
data = append(data, merr.State...)
data = append(data, merr.Message...)
return p.WritePacket(data, true)
}

Expand All @@ -124,7 +111,7 @@ func (p *PacketIO) WriteOKPacket(status uint16, header Header) error {
data := make([]byte, 0, 7)
data = append(data, header.Byte())
data = append(data, 0, 0)
// ClientProtocol41 must be enabled.
// ClientProtocol41 is always enabled.
data = DumpUint16(data, status)
data = append(data, 0, 0)
return p.WritePacket(data, true)
Expand All @@ -135,7 +122,7 @@ func (p *PacketIO) WriteEOFPacket(status uint16) error {
data := make([]byte, 0, 5)
data = append(data, EOFHeader.Byte())
data = append(data, 0, 0)
// ClientProtocol41 must be enabled.
// ClientProtocol41 is always enabled.
data = DumpUint16(data, status)
return p.WritePacket(data, true)
}
Expand All @@ -149,7 +136,8 @@ func (p *PacketIO) WriteUserError(err error) {
if !errors.As(err, &ue) {
return
}
if writeErr := p.WriteErrPacket(mysql.ER_UNKNOWN_ERROR, ue.UserMsg()); writeErr != nil {
myErr := mysql.NewError(mysql.ER_UNKNOWN_ERROR, ue.UserMsg())
if writeErr := p.WriteErrPacket(myErr); writeErr != nil {
p.logger.Error("writing error to client failed", zap.NamedError("mysql_err", err), zap.NamedError("write_err", writeErr))
}
}
30 changes: 30 additions & 0 deletions pkg/proxy/net/packetio_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"testing"
"time"

"github.com/go-mysql-org/go-mysql/mysql"
"github.com/pingcap/tiproxy/lib/config"
"github.com/pingcap/tiproxy/lib/util/logger"
"github.com/pingcap/tiproxy/lib/util/security"
Expand Down Expand Up @@ -257,3 +258,32 @@ func TestKeepAlive(t *testing.T) {
1,
)
}

func TestPredefinedPacket(t *testing.T) {
testTCPConn(t,
func(t *testing.T, cli *PacketIO) {
data, err := cli.ReadPacket()
require.NoError(t, err)
merr := ParseErrorPacket(data).(*mysql.MyError)
require.Equal(t, uint16(mysql.ER_UNKNOWN_ERROR), merr.Code)
require.Equal(t, "Unknown error", merr.Message)

data, err = cli.ReadPacket()
require.NoError(t, err)
merr = ParseErrorPacket(data).(*mysql.MyError)
require.Equal(t, uint16(mysql.ER_UNKNOWN_ERROR), merr.Code)
require.Equal(t, "test error", merr.Message)

data, err = cli.ReadPacket()
require.NoError(t, err)
res := ParseOKPacket(data)
require.Equal(t, uint16(100), res.Status)
},
func(t *testing.T, srv *PacketIO) {
require.NoError(t, srv.WriteErrPacket(mysql.NewDefaultError(mysql.ER_UNKNOWN_ERROR)))
require.NoError(t, srv.WriteErrPacket(mysql.NewError(mysql.ER_UNKNOWN_ERROR, "test error")))
require.NoError(t, srv.WriteOKPacket(100, OKHeader))
},
1,
)
}

0 comments on commit 53a957f

Please sign in to comment.