diff --git a/pkg/proxy/backend/authenticator.go b/pkg/proxy/backend/authenticator.go index 98996160..c753cd87 100644 --- a/pkg/proxy/backend/authenticator.go +++ b/pkg/proxy/backend/authenticator.go @@ -9,6 +9,7 @@ import ( "encoding/binary" "fmt" "net" + "strings" "github.com/go-mysql-org/go-mysql/mysql" "github.com/pingcap/tidb/util/hack" @@ -21,6 +22,7 @@ import ( const unknownAuthPlugin = "auth_unknown_plugin" const requiredFrontendCaps = pnet.ClientProtocol41 const defRequiredBackendCaps = pnet.ClientDeprecateEOF +const ER_INVALID_SEQUENCE = 8052 // SupportedServerCapabilities is the default supported capabilities. Other server capabilities are not supported. // TiDB supports ClientDeprecateEOF since v6.3.0. @@ -212,11 +214,6 @@ loop: for { serverPkt, err := backendIO.ReadPacket() if err != nil { - // tiproxy pp enabled, tidb pp disabled, tls disabled => invalid sequence - // tiproxy pp disabled, tidb pp enabled, tls disabled => invalid sequence - if pktIdx == 0 && errors.Is(err, pnet.ErrInvalidSequence) { - return errors.Wrap(ErrBackendPPV2, err) - } return err } var packetErr *mysql.MyError @@ -235,7 +232,7 @@ loop: return err } if packetErr != nil { - return errors.Wrap(ErrClientAuthFail, packetErr) + return handleHandshakeError(pktIdx, packetErr) } pktIdx++ @@ -428,3 +425,23 @@ func setCompress(packetIO *pnet.PacketIO, capability pnet.Capability, zstdLevel } return packetIO.SetCompressionAlgorithm(algorithm, zstdLevel) } + +// handleHandshakeError tries to recognize the error and report more friendly messages. +func handleHandshakeError(pktIdx int, packetErr *mysql.MyError) error { + if pktIdx == 0 { + // PPV2 errors only appear in the first packet + + // mysql ERROR 1156: Got packets out of order (proxy ppv2 = true) + if packetErr.Code == mysql.ER_NET_PACKETS_OUT_OF_ORDER || + // tidb ERROR 8052: invalid sequence, received 10 while expecting 1 (proxy ppv2 = true, db ppv2 = false) + packetErr.Code == ER_INVALID_SEQUENCE { + return errors.Wrap(ErrBackendPPV2, packetErr) + } + // tidb ERROR 1105: invalid PROXY Protocol Header (proxy ppv2 = false, db ppv2 = true, db ppv2 fallback = false) + // 1105 is UNKNOWN_ERR, so we judge the error by messages + if strings.Contains(packetErr.Message, "PROXY Protocol") { + return ErrBackendPPV2 + } + } + return errors.Wrap(ErrClientAuthFail, packetErr) +} diff --git a/pkg/proxy/backend/authenticator_test.go b/pkg/proxy/backend/authenticator_test.go index 2a6ccea2..abe31290 100644 --- a/pkg/proxy/backend/authenticator_test.go +++ b/pkg/proxy/backend/authenticator_test.go @@ -402,6 +402,10 @@ func TestProxyProtocol(t *testing.T) { cfgOverriders := getCfgCombinations(cfgs) for _, cfgs := range cfgOverriders { ts, clean := newTestSuite(t, tc, cfgs...) + // invalid sequence detection removed, backend will stuck if clients insists to send proxy header. + if !ts.mb.proxyProtocol && ts.mp.bcConfig.ProxyProtocol { + continue + } ts.authenticateFirstTime(t, func(t *testing.T, ts *testSuite) { // TiDB proxy-protocol can be set unfallbackable, but TiProxy proxy-protocol is always fallbackable. // So when backend enables proxy-protocol and proxy disables it, it still works well. diff --git a/pkg/proxy/backend/error.go b/pkg/proxy/backend/error.go index ed53d86c..1342852a 100644 --- a/pkg/proxy/backend/error.go +++ b/pkg/proxy/backend/error.go @@ -79,7 +79,7 @@ const ( SrcClientSQLErr // SrcProxyQuit includes: proxy graceful shutdown SrcProxyQuit - // SrcProxyMalformed includes: malformed packet; invalid sequence + // SrcProxyMalformed includes: malformed packet SrcProxyMalformed // SrcProxyNoBackend includes: no backends SrcProxyNoBackend diff --git a/pkg/proxy/net/packetio.go b/pkg/proxy/net/packetio.go index 201a6b33..1e5fe865 100644 --- a/pkg/proxy/net/packetio.go +++ b/pkg/proxy/net/packetio.go @@ -246,11 +246,7 @@ func (p *PacketIO) readOnePacket() ([]byte, bool, error) { if err := ReadFull(p.readWriter, p.header); err != nil { return nil, false, errors.Wrap(ErrReadConn, err) } - sequence, pktSequence := p.header[3], p.readWriter.Sequence() - if sequence != pktSequence { - return nil, false, ErrInvalidSequence.GenWithStack("invalid sequence, expected %d, actual %d", pktSequence, sequence) - } - p.readWriter.SetSequence(sequence + 1) + p.readWriter.SetSequence(p.header[3] + 1) length := int(p.header[0]) | int(p.header[1])<<8 | int(p.header[2])<<16 data := make([]byte, length) @@ -350,11 +346,7 @@ func (p *PacketIO) ForwardUntil(dest *PacketIO, isEnd func(firstByte byte, first } } else { for { - sequence, pktSequence := header[3], p.readWriter.Sequence() - if sequence != pktSequence { - return p.wrapErr(ErrInvalidSequence.GenWithStack("invalid sequence, expected %d, actual %d", pktSequence, sequence)) - } - p.readWriter.SetSequence(sequence + 1) + p.readWriter.SetSequence(header[3] + 1) // Sequence may be different (e.g. with compression) so we can't just copy the data to the destination. dest.readWriter.SetSequence(dest.readWriter.Sequence() + 1) p.limitReader.N = int64(length + 4)