Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

backend, net: replace tidb/parser/mysql #415

Merged
merged 3 commits into from
Dec 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions pkg/proxy/backend/authenticator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"strings"
"testing"

"github.com/pingcap/tidb/parser/mysql"
pnet "github.com/pingcap/tiproxy/pkg/proxy/net"
"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -107,18 +106,18 @@ func TestAuthPlugin(t *testing.T) {
},
{
func(cfg *testConfig) {
cfg.clientConfig.authPlugin = mysql.AuthNativePassword
cfg.clientConfig.authPlugin = pnet.AuthNativePassword
},
func(cfg *testConfig) {
cfg.clientConfig.authPlugin = mysql.AuthCachingSha2Password
cfg.clientConfig.authPlugin = pnet.AuthCachingSha2Password
},
},
{
func(cfg *testConfig) {
cfg.backendConfig.authPlugin = mysql.AuthNativePassword
cfg.backendConfig.authPlugin = pnet.AuthNativePassword
},
func(cfg *testConfig) {
cfg.backendConfig.authPlugin = mysql.AuthCachingSha2Password
cfg.backendConfig.authPlugin = pnet.AuthCachingSha2Password
},
},
{
Expand Down
8 changes: 4 additions & 4 deletions pkg/proxy/backend/backend_conn_mgr.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import (
"unsafe"

"github.com/cenkalti/backoff/v4"
gomysql "github.com/go-mysql-org/go-mysql/mysql"
"github.com/go-mysql-org/go-mysql/mysql"
"github.com/pingcap/tiproxy/lib/config"
"github.com/pingcap/tiproxy/lib/util/errors"
"github.com/pingcap/tiproxy/lib/util/waitgroup"
Expand Down Expand Up @@ -264,7 +264,7 @@ func (mgr *BackendConnManager) ExecuteCmd(ctx context.Context, request []byte) (
mgr.handshakeHandler.OnTraffic(mgr)
}()
if len(request) < 1 {
err = gomysql.ErrMalformPacket
err = mysql.ErrMalformPacket
return
}
cmd := pnet.Command(request[0])
Expand Down Expand Up @@ -304,7 +304,7 @@ func (mgr *BackendConnManager) ExecuteCmd(ctx context.Context, request []byte) (
mgr.authenticator.capability &^= pnet.ClientMultiStatements
mgr.cmdProcessor.capability &^= pnet.ClientMultiStatements
default:
err = errors.Wrapf(gomysql.ErrMalformPacket, "unrecognized set_option value:%d", val)
err = errors.Wrapf(mysql.ErrMalformPacket, "unrecognized set_option value:%d", val)
return
}
case pnet.ComChangeUser:
Expand Down Expand Up @@ -359,7 +359,7 @@ func (mgr *BackendConnManager) initSessionStates(backendIO *pnet.PacketIO, sessi

func (mgr *BackendConnManager) querySessionStates(backendIO *pnet.PacketIO) (sessionStates, sessionToken string, err error) {
// Do not lock here because the caller already locks.
var result *gomysql.Resultset
var result *mysql.Resultset
if result, _, err = mgr.cmdProcessor.query(backendIO, sqlQueryState); err != nil {
return
}
Expand Down
3 changes: 1 addition & 2 deletions pkg/proxy/backend/backend_conn_mgr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (
"testing"
"time"

"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tiproxy/lib/util/errors"
"github.com/pingcap/tiproxy/lib/util/logger"
"github.com/pingcap/tiproxy/lib/util/waitgroup"
Expand Down Expand Up @@ -202,7 +201,7 @@ func (ts *backendMgrTester) respondWithNoTxn4Backend(packetIO *pnet.PacketIO) er

func (ts *backendMgrTester) startTxn4Backend(packetIO *pnet.PacketIO) error {
ts.mb.respondType = responseTypeOK
ts.mb.status = mysql.ServerStatusInTrans
ts.mb.status = pnet.ServerStatusInTrans
return ts.mb.respond(packetIO)
}

Expand Down
7 changes: 3 additions & 4 deletions pkg/proxy/backend/cmd_processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ package backend
import (
"encoding/binary"

"github.com/pingcap/tidb/parser/mysql"
pnet "github.com/pingcap/tiproxy/pkg/proxy/net"
"go.uber.org/zap"
)
Expand Down Expand Up @@ -58,7 +57,7 @@ func (cp *CmdProcessor) updateServerStatus(request []byte, serverStatus uint16)
}

func (cp *CmdProcessor) updateTxnStatus(serverStatus uint16) {
if serverStatus&mysql.ServerStatusInTrans > 0 {
if serverStatus&pnet.ServerStatusInTrans > 0 {
cp.serverStatus |= StatusInTrans
} else {
cp.serverStatus &^= StatusInTrans
Expand All @@ -84,11 +83,11 @@ func (cp *CmdProcessor) updatePrepStmtStatus(request []byte, serverStatus uint16
case pnet.ComStmtSendLongData:
prepStmtStatus = StatusPrepareWaitExecute
case pnet.ComStmtExecute:
if serverStatus&mysql.ServerStatusCursorExists > 0 {
if serverStatus&pnet.ServerStatusCursorExists > 0 {
prepStmtStatus = StatusPrepareWaitFetch
}
case pnet.ComStmtFetch:
if serverStatus&mysql.ServerStatusLastRowSend == 0 {
if serverStatus&pnet.ServerStatusLastRowSend == 0 {
prepStmtStatus = StatusPrepareWaitFetch
}
}
Expand Down
35 changes: 17 additions & 18 deletions pkg/proxy/backend/cmd_processor_exec.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,8 @@ import (
"encoding/binary"
"strings"

gomysql "github.com/go-mysql-org/go-mysql/mysql"
"github.com/go-mysql-org/go-mysql/mysql"
"github.com/pingcap/tidb/parser"
"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tiproxy/lib/util/errors"
pnet "github.com/pingcap/tiproxy/pkg/proxy/net"
"github.com/siddontang/go/hack"
Expand Down Expand Up @@ -73,12 +72,12 @@ func (cp *CmdProcessor) forwardCommand(clientIO, backendIO *pnet.PacketIO, reque
return err
}
switch response[0] {
case mysql.OKHeader:
case pnet.OKHeader.Byte():
cp.handleOKPacket(request, response)
return nil
case mysql.ErrHeader:
case pnet.ErrHeader.Byte():
return cp.handleErrorPacket(response)
case mysql.EOFHeader:
case pnet.EOFHeader.Byte():
if cp.capability&pnet.ClientDeprecateEOF == 0 {
cp.handleEOFPacket(request, response)
} else {
Expand Down Expand Up @@ -132,7 +131,7 @@ func (cp *CmdProcessor) forwardPrepareCmd(clientIO, backendIO *pnet.PacketIO) er
return err
}
switch response[0] {
case mysql.OKHeader:
case pnet.OKHeader.Byte():
// The OK packet doesn't contain a server status.
// See https://mariadb.com/kb/en/com_stmt_prepare/
numColumns := binary.LittleEndian.Uint16(response[5:])
Expand All @@ -158,7 +157,7 @@ func (cp *CmdProcessor) forwardPrepareCmd(clientIO, backendIO *pnet.PacketIO) er
}
}
return clientIO.Flush()
case mysql.ErrHeader:
case pnet.ErrHeader.Byte():
if err := clientIO.Flush(); err != nil {
return err
}
Expand All @@ -185,24 +184,24 @@ func (cp *CmdProcessor) forwardQueryCmd(clientIO, backendIO *pnet.PacketIO, requ
err := backendIO.ForwardUntil(clientIO, func(firstByte byte, _ int) (end, needData bool) {
first = firstByte
switch firstByte {
case mysql.OKHeader, mysql.ErrHeader:
case pnet.OKHeader.Byte(), pnet.ErrHeader.Byte():
return true, true
default:
return true, false
}
}, func(response []byte) error {
var err error
switch first {
case mysql.OKHeader:
case pnet.OKHeader.Byte():
status := cp.handleOKPacket(request, response)
serverStatus, err = status, clientIO.Flush()
case mysql.ErrHeader:
case pnet.ErrHeader.Byte():
if err = clientIO.Flush(); err != nil {
return err
}
// Subsequent statements won't be executed even if it's a multi-statement.
return cp.handleErrorPacket(response)
case mysql.LocalInFileHeader:
case pnet.LocalInFileHeader.Byte():
serverStatus, err = cp.forwardLoadInFile(clientIO, backendIO, request)
default:
serverStatus, err = cp.forwardResultSet(clientIO, backendIO, request)
Expand All @@ -213,7 +212,7 @@ func (cp *CmdProcessor) forwardQueryCmd(clientIO, backendIO *pnet.PacketIO, requ
return err
}
// If it's not the last statement in multi-statements, continue.
if serverStatus&mysql.ServerMoreResultsExists == 0 {
if serverStatus&pnet.ServerMoreResultsExists == 0 {
break
}
}
Expand Down Expand Up @@ -243,9 +242,9 @@ func (cp *CmdProcessor) forwardLoadInFile(clientIO, backendIO *pnet.PacketIO, re
return
}
switch response[0] {
case mysql.OKHeader:
case pnet.OKHeader.Byte():
return cp.handleOKPacket(request, response), nil
case mysql.ErrHeader:
case pnet.ErrHeader.Byte():
return serverStatus, cp.handleErrorPacket(response)
}
// impossible here
Expand All @@ -262,7 +261,7 @@ func (cp *CmdProcessor) forwardResultSet(clientIO, backendIO *pnet.PacketIO, req
serverStatus = binary.LittleEndian.Uint16(response[3:])
// If a cursor exists, only columns are sent this time. The client will then send COM_STMT_FETCH to fetch rows.
// Otherwise, columns and rows are both sent once.
if serverStatus&mysql.ServerStatusCursorExists > 0 {
if serverStatus&pnet.ServerStatusCursorExists > 0 {
serverStatus = cp.handleEOFPacket(request, response)
return clientIO.Flush()
}
Expand Down Expand Up @@ -294,7 +293,7 @@ func (cp *CmdProcessor) forwardChangeUserCmd(clientIO, backendIO *pnet.PacketIO,
cp.logger.Warn("parse COM_CHANGE_USER packet encounters error", zap.Error(err))
var warning *errors.Warning
if !errors.As(err, &warning) {
return gomysql.ErrMalformPacket
return mysql.ErrMalformPacket
}
}
// The client may use the TiProxy salt to generate the auth data instead of using the TiDB salt,
Expand All @@ -312,10 +311,10 @@ func (cp *CmdProcessor) forwardChangeUserCmd(clientIO, backendIO *pnet.PacketIO,
return err
}
switch response[0] {
case mysql.OKHeader:
case pnet.OKHeader.Byte():
cp.handleOKPacket(request, response)
return nil
case mysql.ErrHeader:
case pnet.ErrHeader.Byte():
return cp.handleErrorPacket(response)
default:
// If the server sends a switch-auth request, the proxy forwards the auth data to the server.
Expand Down
27 changes: 13 additions & 14 deletions pkg/proxy/backend/cmd_processor_query.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@ package backend
import (
"encoding/binary"

gomysql "github.com/go-mysql-org/go-mysql/mysql"
"github.com/pingcap/tidb/parser/mysql"
"github.com/go-mysql-org/go-mysql/mysql"
"github.com/pingcap/tiproxy/lib/util/errors"
pnet "github.com/pingcap/tiproxy/pkg/proxy/net"
"github.com/siddontang/go/hack"
Expand All @@ -16,7 +15,7 @@ import (
// query is called when the proxy sends requests to the backend by itself,
// such as querying session states, committing the current transaction.
// It only supports limited cases, excluding loading file, cursor fetch, multi-statements, etc.
func (cp *CmdProcessor) query(packetIO *pnet.PacketIO, sql string) (result *gomysql.Resultset, response []byte, err error) {
func (cp *CmdProcessor) query(packetIO *pnet.PacketIO, sql string) (result *mysql.Resultset, response []byte, err error) {
// send request
packetIO.ResetSequence()
data := hack.Slice(sql)
Expand All @@ -32,29 +31,29 @@ func (cp *CmdProcessor) query(packetIO *pnet.PacketIO, sql string) (result *gomy
return
}
switch response[0] {
case mysql.OKHeader:
case pnet.OKHeader.Byte():
cp.handleOKPacket(request, response)
case mysql.ErrHeader:
case pnet.ErrHeader.Byte():
err = cp.handleErrorPacket(response)
case mysql.LocalInFileHeader:
case pnet.LocalInFileHeader.Byte():
err = errors.WithStack(mysql.ErrMalformPacket)
default:
var rs *gomysql.Result
var rs *mysql.Result
rs, err = cp.readResultSet(packetIO, response)
result = rs.Resultset
}
return
}

// readResultSet is only used for reading the results of `show session_states` currently.
func (cp *CmdProcessor) readResultSet(packetIO *pnet.PacketIO, data []byte) (*gomysql.Result, error) {
func (cp *CmdProcessor) readResultSet(packetIO *pnet.PacketIO, data []byte) (*mysql.Result, error) {
columnCount, _, n := pnet.ParseLengthEncodedInt(data)
if n-len(data) != 0 {
return nil, errors.WithStack(mysql.ErrMalformPacket)
}

result := &gomysql.Result{
Resultset: gomysql.NewResultset(int(columnCount)),
result := &mysql.Result{
Resultset: mysql.NewResultset(int(columnCount)),
}
if err := cp.readResultColumns(packetIO, result); err != nil {
return nil, err
Expand All @@ -65,7 +64,7 @@ func (cp *CmdProcessor) readResultSet(packetIO *pnet.PacketIO, data []byte) (*go
return result, nil
}

func (cp *CmdProcessor) readResultColumns(packetIO *pnet.PacketIO, result *gomysql.Result) (err error) {
func (cp *CmdProcessor) readResultColumns(packetIO *pnet.PacketIO, result *mysql.Result) (err error) {
var fieldIndex int
var data []byte

Expand All @@ -86,7 +85,7 @@ func (cp *CmdProcessor) readResultColumns(packetIO *pnet.PacketIO, result *gomys
return err
}
if result.Fields[fieldIndex] == nil {
result.Fields[fieldIndex] = &gomysql.Field{}
result.Fields[fieldIndex] = &mysql.Field{}
}
if err = result.Fields[fieldIndex].Parse(data); err != nil {
return errors.WithStack(err)
Expand All @@ -97,7 +96,7 @@ func (cp *CmdProcessor) readResultColumns(packetIO *pnet.PacketIO, result *gomys
}
}

func (cp *CmdProcessor) readResultRows(packetIO *pnet.PacketIO, result *gomysql.Result) (err error) {
func (cp *CmdProcessor) readResultRows(packetIO *pnet.PacketIO, result *mysql.Result) (err error) {
var data []byte

for {
Expand All @@ -123,7 +122,7 @@ func (cp *CmdProcessor) readResultRows(packetIO *pnet.PacketIO, result *gomysql.
}

if cap(result.Values) < len(result.RowDatas) {
result.Values = make([][]gomysql.FieldValue, len(result.RowDatas))
result.Values = make([][]mysql.FieldValue, len(result.RowDatas))
} else {
result.Values = result.Values[:len(result.RowDatas)]
}
Expand Down
Loading