backend, net: Support compression protocol #373

Merged
16 commits, merged on Oct 9, 2023
1 change: 1 addition & 0 deletions go.mod
@@ -11,6 +11,7 @@ require (
github.com/gin-gonic/gin v1.8.1
github.com/go-mysql-org/go-mysql v1.6.0
github.com/go-sql-driver/mysql v1.7.0
github.com/klauspost/compress v1.16.6
github.com/pingcap/tidb v1.1.0-beta.0.20230103132820-3ccff46aa3bc
github.com/pingcap/tidb/parser v0.0.0-20230103132820-3ccff46aa3bc
github.com/pingcap/tiproxy/lib v0.0.0-00010101000000-000000000000
3 changes: 2 additions & 1 deletion go.sum
@@ -415,7 +415,8 @@ github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
github.com/klauspost/compress v1.8.2/go.mod h1:RyIbtBH6LamlWaDj8nUwkbUhJ87Yi3uG0guNDohfE1A=
github.com/klauspost/compress v1.9.0/go.mod h1:RyIbtBH6LamlWaDj8nUwkbUhJ87Yi3uG0guNDohfE1A=
github.com/klauspost/compress v1.15.13 h1:NFn1Wr8cfnenSJSA46lLq4wHCcBzKTSjnBIexDMMOV0=
github.com/klauspost/compress v1.16.6 h1:91SKEy4K37vkp255cJ8QesJhjyRO0hn9i9G0GoUwLsk=
github.com/klauspost/compress v1.16.6/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE=
github.com/klauspost/cpuid v1.2.1/go.mod h1:Pj4uuM528wm8OyEC2QMXAi2YiTZ96dNQPGgoMS4s3ek=
github.com/klauspost/cpuid v1.3.1 h1:5JNjFYYQrZeKRJ0734q51WCEEn2huer72Dc7K+R/b6s=
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
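The dependency bump above pulls in github.com/klauspost/compress, which supplies the zstd codec for the new compression support. As a rough, standalone illustration of compressing a packet payload at a negotiated zstd level — the helper name and the absence of any wire framing are assumptions of this sketch, not TiProxy's actual code:

package main

import (
	"bytes"
	"fmt"

	"github.com/klauspost/compress/zstd"
)

// compressPayload compresses a packet payload with zstd at the given level
// (the level a client advertises in its handshake response). Sketch only.
func compressPayload(payload []byte, level int) ([]byte, error) {
	var buf bytes.Buffer
	enc, err := zstd.NewWriter(&buf, zstd.WithEncoderLevel(zstd.EncoderLevelFromZstd(level)))
	if err != nil {
		return nil, err
	}
	if _, err := enc.Write(payload); err != nil {
		return nil, err
	}
	if err := enc.Close(); err != nil {
		return nil, err
	}
	return buf.Bytes(), nil
}

func main() {
	payload := bytes.Repeat([]byte("SELECT 1;"), 100)
	out, err := compressPayload(payload, 3)
	if err != nil {
		panic(err)
	}
	fmt.Printf("raw=%d bytes, compressed=%d bytes\n", len(payload), len(out))
}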
33 changes: 27 additions & 6 deletions pkg/proxy/backend/authenticator.go
@@ -29,11 +29,12 @@ const defRequiredBackendCaps = pnet.ClientDeprecateEOF

// SupportedServerCapabilities is the default supported capabilities. Other server capabilities are not supported.
// TiDB supports ClientDeprecateEOF since v6.3.0.
// TiDB supports ClientCompress and ClientZstdCompressionAlgorithm since v7.2.0.
const SupportedServerCapabilities = pnet.ClientLongPassword | pnet.ClientFoundRows | pnet.ClientConnectWithDB |
pnet.ClientODBC | pnet.ClientLocalFiles | pnet.ClientInteractive | pnet.ClientLongFlag | pnet.ClientSSL |
pnet.ClientTransactions | pnet.ClientReserved | pnet.ClientSecureConnection | pnet.ClientMultiStatements |
pnet.ClientMultiResults | pnet.ClientPluginAuth | pnet.ClientConnectAttrs | pnet.ClientPluginAuthLenencClientData |
requiredFrontendCaps | defRequiredBackendCaps
pnet.ClientCompress | pnet.ClientZstdCompressionAlgorithm | requiredFrontendCaps | defRequiredBackendCaps

// Authenticator handshakes with the client and the backend.
type Authenticator struct {
@@ -42,6 +43,7 @@ type Authenticator struct {
attrs map[string]string
salt []byte
capability pnet.Capability
zstdLevel int
collation uint8
proxyProtocol bool
requireBackendTLS bool
@@ -64,9 +66,7 @@ func (auth *Authenticator) writeProxyProtocol(clientIO, backendIO *pnet.PacketIO
}
// either from another proxy or directly from clients, we are acting as a proxy
proxy.Command = proxyprotocol.ProxyCommandProxy
if err := backendIO.WriteProxyV2(proxy); err != nil {
return err
}
backendIO.EnableProxyClient(proxy)
}
return nil
}
@@ -157,6 +157,7 @@ func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, cctx ConnConte
auth.dbname = clientResp.DB
auth.collation = clientResp.Collation
auth.attrs = clientResp.Attrs
auth.zstdLevel = clientResp.ZstdLevel

// In case of testing, backendIO is passed manually so that we don't need to bother with the routing logic.
backendIO, err := getBackendIO(cctx, auth, clientResp, 15*time.Second)
@@ -225,6 +226,12 @@ loop:
pktIdx++
switch serverPkt[0] {
case pnet.OKHeader.Byte():
if err := setCompress(clientIO, auth.capability, auth.zstdLevel); err != nil {
return err
}
if err := setCompress(backendIO, auth.capability&backendCapability, auth.zstdLevel); err != nil {
return err
}
return nil
case pnet.ErrHeader.Byte():
return pnet.ParseErrorPacket(serverPkt)
@@ -277,7 +284,10 @@ func (auth *Authenticator) handshakeSecondTime(logger *zap.Logger, clientIO, bac
return err
}

return auth.handleSecondAuthResult(backendIO)
if err = auth.handleSecondAuthResult(backendIO); err == nil {
return setCompress(backendIO, auth.capability&backendCapability, auth.zstdLevel)
}
return err
}

func (auth *Authenticator) readInitialHandshake(backendIO *pnet.PacketIO) (serverPkt []byte, capability pnet.Capability, err error) {
@@ -307,8 +317,9 @@ func (auth *Authenticator) writeAuthHandshake(
Attrs: auth.attrs,
Collation: auth.collation,
AuthData: authData,
Capability: auth.capability | authCap,
Capability: auth.capability&backendCapability | authCap,
AuthPlugin: authPlugin,
ZstdLevel: auth.zstdLevel,
}

if len(resp.Attrs) > 0 {
@@ -382,3 +393,13 @@ func (auth *Authenticator) changeUser(req *pnet.ChangeUserReq) {
func (auth *Authenticator) updateCurrentDB(db string) {
auth.dbname = db
}

func setCompress(packetIO *pnet.PacketIO, capability pnet.Capability, zstdLevel int) error {
algorithm := pnet.CompressionNone
if capability&pnet.ClientCompress > 0 {
algorithm = pnet.CompressionZlib
} else if capability&pnet.ClientZstdCompressionAlgorithm > 0 {
algorithm = pnet.CompressionZstd
}
return packetIO.SetCompressionAlgorithm(algorithm, zstdLevel)
}
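For context on what SetCompressionAlgorithm turns on: once negotiated, every packet travels inside a MySQL compressed-protocol frame, a 7-byte header (3-byte little-endian compressed length, 1-byte compressed sequence number, 3-byte little-endian uncompressed length, where 0 means the payload was left uncompressed) followed by the payload. A minimal sketch of that framing, for illustration only; the real framing sits behind pnet.PacketIO and is not part of this diff:

// buildCompressedFrame wraps an already-compressed payload in the 7-byte
// MySQL compressed-protocol header. uncompressedLen is the original payload
// size, or 0 if the payload is sent uncompressed.
func buildCompressedFrame(compressed []byte, uncompressedLen int, seq byte) []byte {
	frame := make([]byte, 7+len(compressed))
	frame[0] = byte(len(compressed))
	frame[1] = byte(len(compressed) >> 8)
	frame[2] = byte(len(compressed) >> 16)
	frame[3] = seq
	frame[4] = byte(uncompressedLen)
	frame[5] = byte(uncompressedLen >> 8)
	frame[6] = byte(uncompressedLen >> 16)
	copy(frame[7:], compressed)
	return frame
}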
159 changes: 159 additions & 0 deletions pkg/proxy/backend/authenticator_test.go
@@ -164,6 +164,30 @@ func TestCapability(t *testing.T) {
cfg.clientConfig.capability |= pnet.ClientSecureConnection
},
},
{
func(cfg *testConfig) {
cfg.backendConfig.capability &= ^pnet.ClientCompress
cfg.backendConfig.capability &= ^pnet.ClientZstdCompressionAlgorithm
},
func(cfg *testConfig) {
cfg.backendConfig.capability |= pnet.ClientCompress
cfg.backendConfig.capability |= pnet.ClientZstdCompressionAlgorithm
},
},
{
func(cfg *testConfig) {
cfg.clientConfig.capability &= ^pnet.ClientCompress
cfg.clientConfig.capability &= ^pnet.ClientZstdCompressionAlgorithm
},
func(cfg *testConfig) {
cfg.clientConfig.capability |= pnet.ClientCompress
cfg.clientConfig.capability |= pnet.ClientZstdCompressionAlgorithm
},
func(cfg *testConfig) {
cfg.clientConfig.capability |= pnet.ClientCompress
cfg.clientConfig.capability &= ^pnet.ClientZstdCompressionAlgorithm
},
},
}

tc := newTCPConnSuite(t)
@@ -387,3 +411,138 @@ func TestProxyProtocol(t *testing.T) {
clean()
}
}

func TestCompressProtocol(t *testing.T) {
cfgs := [][]cfgOverrider{
{
func(cfg *testConfig) {
cfg.backendConfig.capability &= ^pnet.ClientCompress
cfg.backendConfig.capability &= ^pnet.ClientZstdCompressionAlgorithm
},
func(cfg *testConfig) {
cfg.backendConfig.capability |= pnet.ClientCompress
cfg.backendConfig.capability |= pnet.ClientZstdCompressionAlgorithm
},
},
{
func(cfg *testConfig) {
cfg.clientConfig.capability &= ^pnet.ClientCompress
cfg.clientConfig.capability &= ^pnet.ClientZstdCompressionAlgorithm
},
func(cfg *testConfig) {
cfg.clientConfig.capability |= pnet.ClientCompress
cfg.clientConfig.capability |= pnet.ClientZstdCompressionAlgorithm
cfg.clientConfig.zstdLevel = 3
},
func(cfg *testConfig) {
cfg.clientConfig.capability |= pnet.ClientCompress
cfg.clientConfig.capability |= pnet.ClientZstdCompressionAlgorithm
cfg.clientConfig.zstdLevel = 9
},
func(cfg *testConfig) {
cfg.clientConfig.capability |= pnet.ClientCompress
cfg.clientConfig.capability &= ^pnet.ClientZstdCompressionAlgorithm
},
},
}

checker := func(t *testing.T, ts *testSuite, referCfg *testConfig) {
// If the client enables compression, client <-> proxy enables compression.
if referCfg.clientConfig.capability&pnet.ClientCompress > 0 {
require.Greater(t, ts.mp.authenticator.capability&pnet.ClientCompress, pnet.Capability(0))
require.Greater(t, ts.mc.capability&pnet.ClientCompress, pnet.Capability(0))
} else {
require.Equal(t, pnet.Capability(0), ts.mp.authenticator.capability&pnet.ClientCompress)
require.Equal(t, pnet.Capability(0), ts.mc.capability&pnet.ClientCompress)
}
// If both the client and the backend enable compression, proxy <-> backend enables compression.
if referCfg.clientConfig.capability&referCfg.backendConfig.capability&pnet.ClientCompress > 0 {
require.Greater(t, ts.mb.capability&pnet.ClientCompress, pnet.Capability(0))
} else {
require.Equal(t, pnet.Capability(0), ts.mb.capability&pnet.ClientCompress)
}
// If the client enables zstd compression, client <-> proxy enables zstd compression.
zstdCap := pnet.ClientCompress | pnet.ClientZstdCompressionAlgorithm
if referCfg.clientConfig.capability&zstdCap == zstdCap {
require.Greater(t, ts.mp.authenticator.capability&pnet.ClientZstdCompressionAlgorithm, pnet.Capability(0))
require.Greater(t, ts.mc.capability&pnet.ClientZstdCompressionAlgorithm, pnet.Capability(0))
require.Equal(t, referCfg.clientConfig.zstdLevel, ts.mp.authenticator.zstdLevel)
} else {
require.Equal(t, pnet.Capability(0), ts.mp.authenticator.capability&pnet.ClientZstdCompressionAlgorithm)
require.Equal(t, pnet.Capability(0), ts.mc.capability&pnet.ClientZstdCompressionAlgorithm)
}
// If both the client and the backend enable zstd compression, proxy <-> backend enables zstd compression.
if referCfg.clientConfig.capability&referCfg.backendConfig.capability&zstdCap == zstdCap {
require.Greater(t, ts.mb.capability&pnet.ClientZstdCompressionAlgorithm, pnet.Capability(0))
require.Equal(t, referCfg.clientConfig.zstdLevel, ts.mb.zstdLevel)
} else {
require.Equal(t, pnet.Capability(0), ts.mb.capability&pnet.ClientZstdCompressionAlgorithm)
}
}

tc := newTCPConnSuite(t)
cfgOverriders := getCfgCombinations(cfgs)
for _, cfgs := range cfgOverriders {
referCfg := newTestConfig(cfgs...)
ts, clean := newTestSuite(t, tc, cfgs...)
ts.authenticateFirstTime(t, func(t *testing.T, ts *testSuite) {
checker(t, ts, referCfg)
})
ts.authenticateSecondTime(t, func(t *testing.T, ts *testSuite) {
checker(t, ts, referCfg)
})
clean()
}
}

// After upgrading the backend, the backend capability may change.
func TestUpgradeBackendCap(t *testing.T) {
cfgs := [][]cfgOverrider{
{
func(cfg *testConfig) {
cfg.clientConfig.capability &= ^pnet.ClientCompress
cfg.clientConfig.capability &= ^pnet.ClientZstdCompressionAlgorithm
},
func(cfg *testConfig) {
cfg.clientConfig.capability |= pnet.ClientCompress
cfg.clientConfig.capability |= pnet.ClientZstdCompressionAlgorithm
cfg.clientConfig.zstdLevel = 3
},
func(cfg *testConfig) {
cfg.clientConfig.capability |= pnet.ClientCompress
cfg.clientConfig.capability &= ^pnet.ClientZstdCompressionAlgorithm
},
},
{
func(cfg *testConfig) {
cfg.backendConfig.capability &= ^pnet.ClientCompress
cfg.backendConfig.capability &= ^pnet.ClientZstdCompressionAlgorithm
},
},
}

tc := newTCPConnSuite(t)
cfgOverriders := getCfgCombinations(cfgs)
for _, cfgs := range cfgOverriders {
referCfg := newTestConfig(cfgs...)
ts, clean := newTestSuite(t, tc, cfgs...)
// Before upgrade, the backend doesn't support compression.
ts.authenticateFirstTime(t, func(t *testing.T, ts *testSuite) {
require.Equal(t, referCfg.clientConfig.capability&pnet.ClientCompress, ts.mp.authenticator.capability&pnet.ClientCompress)
require.Equal(t, referCfg.clientConfig.capability&pnet.ClientCompress, ts.mc.capability&pnet.ClientCompress)
require.Equal(t, pnet.Capability(0), ts.mb.capability&pnet.ClientCompress)
})
// After upgrade, the backend also supports compression.
ts.mb.backendConfig.capability |= pnet.ClientCompress
ts.mb.backendConfig.capability |= pnet.ClientZstdCompressionAlgorithm
ts.authenticateSecondTime(t, func(t *testing.T, ts *testSuite) {
require.Equal(t, referCfg.clientConfig.capability&pnet.ClientCompress, ts.mc.capability&pnet.ClientCompress)
require.Equal(t, referCfg.clientConfig.capability&pnet.ClientCompress, ts.mp.authenticator.capability&pnet.ClientCompress)
require.Equal(t, referCfg.clientConfig.capability&pnet.ClientCompress, ts.mb.capability&pnet.ClientCompress)
require.Equal(t, referCfg.clientConfig.capability&pnet.ClientZstdCompressionAlgorithm, ts.mc.capability&pnet.ClientZstdCompressionAlgorithm)
require.Equal(t, referCfg.clientConfig.capability&pnet.ClientZstdCompressionAlgorithm, ts.mp.authenticator.capability&pnet.ClientZstdCompressionAlgorithm)
require.Equal(t, referCfg.clientConfig.capability&pnet.ClientZstdCompressionAlgorithm, ts.mb.capability&pnet.ClientZstdCompressionAlgorithm)
})
clean()
}
}
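Both tests above expand cfgs into every combination that picks one overrider from each group before running the handshake. The getCfgCombinations helper is not part of this diff; assuming it is a plain cross product, it could look roughly like this generic sketch:

// crossProduct returns every slice that takes exactly one element from each
// group, which is what getCfgCombinations presumably does for the overrider
// groups (an assumption; the real helper lives elsewhere in the test package).
func crossProduct[T any](groups [][]T) [][]T {
	combos := [][]T{{}}
	for _, group := range groups {
		next := make([][]T, 0, len(combos)*len(group))
		for _, combo := range combos {
			for _, item := range group {
				c := append(append([]T{}, combo...), item)
				next = append(next, c)
			}
		}
		combos = next
	}
	return combos
}

For the two groups in TestCompressProtocol (2 backend overriders x 4 client overriders) this yields 8 configurations, each exercised through both authenticateFirstTime and authenticateSecondTime.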
20 changes: 9 additions & 11 deletions pkg/proxy/backend/backend_conn_mgr_test.go
@@ -527,8 +527,8 @@ func TestSpecialCmds(t *testing.T) {
require.NoError(t, ts.redirectSucceed4Backend(packetIO))
require.Equal(t, "another_user", ts.mb.username)
require.Equal(t, "session_db", ts.mb.db)
expectCap := pnet.Capability(ts.mp.handshakeHandler.GetCapability() &^ (pnet.ClientMultiStatements | pnet.ClientPluginAuthLenencClientData))
gotCap := pnet.Capability(ts.mb.capability &^ pnet.ClientPluginAuthLenencClientData)
expectCap := ts.mp.handshakeHandler.GetCapability() & defaultTestClientCapability &^ (pnet.ClientMultiStatements | pnet.ClientPluginAuthLenencClientData)
gotCap := ts.mb.capability &^ pnet.ClientPluginAuthLenencClientData
require.Equal(t, expectCap, gotCap, "expected=%s,got=%s", expectCap, gotCap)
return nil
},
@@ -793,18 +793,16 @@ func TestHandlerReturnError(t *testing.T) {
}

func TestOnTraffic(t *testing.T) {
i := 0
inbytes, outbytes := []int{
0x99,
}, []int{
0xce,
}
var inBytes, outBytes uint64
ts := newBackendMgrTester(t, func(config *testConfig) {
config.proxyConfig.bcConfig.CheckBackendInterval = 10 * time.Millisecond
config.proxyConfig.handler.onTraffic = func(cc ConnContext) {
require.Equal(t, uint64(inbytes[i]), cc.ClientInBytes())
require.Equal(t, uint64(outbytes[i]), cc.ClientOutBytes())
i++
require.Greater(t, cc.ClientInBytes(), uint64(0))
require.GreaterOrEqual(t, cc.ClientInBytes(), inBytes)
inBytes = cc.ClientInBytes()
require.Greater(t, cc.ClientOutBytes(), uint64(0))
require.GreaterOrEqual(t, cc.ClientOutBytes(), outBytes)
outBytes = cc.ClientOutBytes()
}
})
runners := []runner{
18 changes: 18 additions & 0 deletions pkg/proxy/backend/common_test.go
@@ -84,6 +84,24 @@ func (tc *tcpConnSuite) newConn(t *testing.T, enableRoute bool) func() {
}
}

func (tc *tcpConnSuite) reconnectBackend(t *testing.T) {
lg, _ := logger.CreateLoggerForTest(t)
var wg waitgroup.WaitGroup
wg.Run(func() {
_ = tc.backendIO.Close()
conn, err := tc.backendListener.Accept()
require.NoError(t, err)
tc.backendIO = pnet.NewPacketIO(conn, lg)
})
wg.Run(func() {
_ = tc.proxyBIO.Close()
backendConn, err := net.Dial("tcp", tc.backendListener.Addr().String())
require.NoError(t, err)
tc.proxyBIO = pnet.NewPacketIO(backendConn, lg)
})
wg.Wait()
}

func (tc *tcpConnSuite) run(clientRunner, backendRunner func(*pnet.PacketIO) error, proxyRunner func(*pnet.PacketIO, *pnet.PacketIO) error) (cerr, berr, perr error) {
var wg waitgroup.WaitGroup
if clientRunner != nil {
13 changes: 9 additions & 4 deletions pkg/proxy/backend/mock_backend_test.go
@@ -47,10 +47,11 @@ type mockBackend struct {
// Inputs that are assigned by the test and will be sent to the client.
*backendConfig
// Outputs that are received from the client and will be checked by the test.
username string
db string
attrs map[string]string
authData []byte
username string
db string
attrs map[string]string
authData []byte
zstdLevel int
}

func newMockBackend(cfg *backendConfig) *mockBackend {
@@ -98,6 +99,7 @@ func (mb *mockBackend) authenticate(packetIO *pnet.PacketIO) error {
mb.authData = resp.AuthData
mb.attrs = resp.Attrs
mb.capability = resp.Capability
mb.zstdLevel = resp.ZstdLevel
// verify password
return mb.verifyPassword(packetIO, resp)
}
@@ -125,6 +127,9 @@ func (mb *mockBackend) verifyPassword(packetIO *pnet.PacketIO, resp *pnet.Handsh
if err := packetIO.WriteOKPacket(mb.status, pnet.OKHeader); err != nil {
return err
}
if err := setCompress(packetIO, mb.capability, mb.zstdLevel); err != nil {
return err
}
} else {
if err := packetIO.WriteErrPacket(mysql.NewDefaultError(mysql.ER_ACCESS_DENIED_ERROR)); err != nil {
return err
2 changes: 2 additions & 0 deletions pkg/proxy/backend/mock_client_test.go
@@ -26,6 +26,7 @@ type clientConfig struct {
capability pnet.Capability
collation uint8
cmd pnet.Command
zstdLevel int
// for both auth and cmd
abnormalExit bool
}
@@ -82,6 +83,7 @@ func (mc *mockClient) authenticate(packetIO *pnet.PacketIO) error {
AuthData: mc.authData,
Capability: mc.capability,
Collation: mc.collation,
ZstdLevel: mc.zstdLevel,
}
pkt = pnet.MakeHandshakeResponse(resp)
if mc.capability&pnet.ClientSSL > 0 {