diff --git a/pkg/proxy/backend/authenticator.go b/pkg/proxy/backend/authenticator.go index 90249107..2689e649 100644 --- a/pkg/proxy/backend/authenticator.go +++ b/pkg/proxy/backend/authenticator.go @@ -21,7 +21,6 @@ import ( var ( ErrCapabilityNegotiation = errors.New("capability negotiation failed") - ErrTLSConfigRequired = errors.New("require TLS config on TiProxy when require-backend-tls=true") ) const unknownAuthPlugin = "auth_unknown_plugin" @@ -74,17 +73,15 @@ func (auth *Authenticator) writeProxyProtocol(clientIO, backendIO *pnet.PacketIO func (auth *Authenticator) verifyBackendCaps(logger *zap.Logger, backendCapability pnet.Capability) error { requiredBackendCaps := defRequiredBackendCaps & auth.capability - if auth.requireBackendTLS { - requiredBackendCaps |= pnet.ClientSSL - } - if commonCaps := backendCapability & requiredBackendCaps; commonCaps != requiredBackendCaps { // The error cannot be sent to the client because the client only expects an initial handshake packet. // The only way is to log it and disconnect. logger.Error("require backend capabilities", zap.Stringer("common", commonCaps), zap.Stringer("required", requiredBackendCaps^commonCaps)) return errors.Wrapf(ErrCapabilityNegotiation, "require %s from backend", requiredBackendCaps^commonCaps) } - + if auth.requireBackendTLS && (backendCapability&pnet.ClientSSL == 0) { + return pnet.WrapUserError(errors.New("backend doesn't enable TLS"), requireTiDBTLSErrMsg) + } return nil } @@ -315,7 +312,7 @@ func (auth *Authenticator) writeAuthHandshake( var enableTLS bool if auth.requireBackendTLS { if backendTLSConfig == nil { - return ErrTLSConfigRequired + return pnet.WrapUserError(errors.New("tiproxy doesn't enable TLS"), requireProxyTLSErrMsg) } enableTLS = true } else { diff --git a/pkg/proxy/backend/authenticator_test.go b/pkg/proxy/backend/authenticator_test.go index b0323484..c5da07c0 100644 --- a/pkg/proxy/backend/authenticator_test.go +++ b/pkg/proxy/backend/authenticator_test.go @@ -8,6 +8,7 @@ import ( "testing" "github.com/pingcap/tidb/parser/mysql" + "github.com/pingcap/tiproxy/lib/util/errors" pnet "github.com/pingcap/tiproxy/pkg/proxy/net" "github.com/stretchr/testify/require" ) @@ -220,7 +221,7 @@ func TestCustomAuth(t *testing.T) { require.Equal(t, ts.mc.username, inUser) require.Equal(t, reUser, ts.mb.username) require.Equal(t, reAttrs, ts.mb.attrs) - require.Equal(t, reCap&pnet.ClientDeprecateEOF, pnet.Capability(ts.mb.capability)&pnet.ClientDeprecateEOF) + require.Equal(t, reCap&pnet.ClientDeprecateEOF, ts.mb.capability&pnet.ClientDeprecateEOF) } ts.authenticateFirstTime(t, func(t *testing.T, ts *testSuite) {}) checker() @@ -290,3 +291,50 @@ func TestAuthFail(t *testing.T) { clean() } } + +func TestRequireBackendTLS(t *testing.T) { + tests := []struct { + cfg cfgOverrider + errMsg string + }{ + { + cfg: func(cfg *testConfig) { + cfg.proxyConfig.bcConfig.RequireBackendTLS = true + cfg.proxyConfig.backendTLSConfig = nil + cfg.backendConfig.capability |= pnet.ClientSSL + }, + errMsg: requireProxyTLSErrMsg, + }, + { + cfg: func(cfg *testConfig) { + cfg.proxyConfig.bcConfig.RequireBackendTLS = true + cfg.backendConfig.tlsConfig = nil + cfg.backendConfig.capability &= ^pnet.ClientSSL + }, + errMsg: requireTiDBTLSErrMsg, + }, + { + cfg: func(cfg *testConfig) { + cfg.proxyConfig.bcConfig.RequireBackendTLS = false + cfg.proxyConfig.backendTLSConfig = nil + cfg.backendConfig.tlsConfig = nil + cfg.backendConfig.capability &= ^pnet.ClientSSL + }, + }, + } + + tc := newTCPConnSuite(t) + for _, tt := range tests { + ts, clean := newTestSuite(t, tc, tt.cfg) + ts.authenticateFirstTime(t, func(t *testing.T, ts *testSuite) { + if len(tt.errMsg) > 0 { + var userError *pnet.UserError + require.True(t, errors.As(ts.mp.err, &userError)) + require.Equal(t, tt.errMsg, userError.UserMsg()) + } else { + require.NoError(t, ts.mp.err) + } + }) + clean() + } +} diff --git a/pkg/proxy/backend/backend_conn_mgr_test.go b/pkg/proxy/backend/backend_conn_mgr_test.go index 169aa7f6..69dc8e01 100644 --- a/pkg/proxy/backend/backend_conn_mgr_test.go +++ b/pkg/proxy/backend/backend_conn_mgr_test.go @@ -784,7 +784,7 @@ func TestOnTraffic(t *testing.T) { 0xce, } ts := newBackendMgrTester(t, func(config *testConfig) { - config.proxyConfig.checkBackendInterval = 10 * time.Millisecond + config.proxyConfig.bcConfig.CheckBackendInterval = 10 * time.Millisecond config.proxyConfig.handler.onTraffic = func(cc ConnContext) { require.Equal(t, uint64(inbytes[i]), cc.ClientInBytes()) require.Equal(t, uint64(outbytes[i]), cc.ClientOutBytes()) @@ -873,7 +873,7 @@ func TestGetBackendIO(t *testing.T) { func TestBackendInactive(t *testing.T) { ts := newBackendMgrTester(t, func(config *testConfig) { - config.proxyConfig.checkBackendInterval = 10 * time.Millisecond + config.proxyConfig.bcConfig.CheckBackendInterval = 10 * time.Millisecond }) runners := []runner{ // 1st handshake @@ -957,7 +957,7 @@ func TestBackendInactive(t *testing.T) { func TestKeepAlive(t *testing.T) { ts := newBackendMgrTester(t, func(config *testConfig) { - config.proxyConfig.checkBackendInterval = 10 * time.Millisecond + config.proxyConfig.bcConfig.CheckBackendInterval = 10 * time.Millisecond }) runners := []runner{ { diff --git a/pkg/proxy/backend/error.go b/pkg/proxy/backend/error.go index aa85d7ac..b2eb3295 100644 --- a/pkg/proxy/backend/error.go +++ b/pkg/proxy/backend/error.go @@ -8,10 +8,12 @@ import ( ) const ( - connectErrMsg = "No available TiDB instances, please check TiDB cluster" - parsePktErrMsg = "TiProxy fails to parse the packet, please contact PingCAP" - handshakeErrMsg = "TiProxy fails to connect to TiDB, please check network" - capabilityErrMsg = "Verify TiDB capability failed, please upgrade TiDB" + connectErrMsg = "No available TiDB instances, please check TiDB cluster" + parsePktErrMsg = "TiProxy fails to parse the packet, please contact PingCAP" + handshakeErrMsg = "TiProxy fails to connect to TiDB, please check network" + capabilityErrMsg = "Verify TiDB capability failed, please upgrade TiDB" + requireProxyTLSErrMsg = "Require TLS config on TiProxy when require-backend-tls=true" + requireTiDBTLSErrMsg = "Require TLS config on TiDB when require-backend-tls=true" ) var ( diff --git a/pkg/proxy/backend/mock_backend_test.go b/pkg/proxy/backend/mock_backend_test.go index 5b15fc79..9a8ec1bc 100644 --- a/pkg/proxy/backend/mock_backend_test.go +++ b/pkg/proxy/backend/mock_backend_test.go @@ -91,7 +91,7 @@ func (mb *mockBackend) authenticate(packetIO *pnet.PacketIO) error { mb.db = resp.DB mb.authData = resp.AuthData mb.attrs = resp.Attrs - mb.capability = pnet.Capability(resp.Capability) + mb.capability = resp.Capability // verify password return mb.verifyPassword(packetIO, resp) } diff --git a/pkg/proxy/backend/mock_proxy_test.go b/pkg/proxy/backend/mock_proxy_test.go index d212d063..95d81bf7 100644 --- a/pkg/proxy/backend/mock_proxy_test.go +++ b/pkg/proxy/backend/mock_proxy_test.go @@ -15,22 +15,22 @@ import ( ) type proxyConfig struct { - frontendTLSConfig *tls.Config - backendTLSConfig *tls.Config - handler *CustomHandshakeHandler - checkBackendInterval time.Duration - sessionToken string - capability pnet.Capability - waitRedirect bool - connectionID uint64 + frontendTLSConfig *tls.Config + backendTLSConfig *tls.Config + handler *CustomHandshakeHandler + bcConfig *BCConfig + sessionToken string + capability pnet.Capability + waitRedirect bool + connectionID uint64 } func newProxyConfig() *proxyConfig { return &proxyConfig{ - handler: &CustomHandshakeHandler{}, - capability: defaultTestBackendCapability, - sessionToken: mockToken, - checkBackendInterval: CheckBackendInterval, + handler: &CustomHandshakeHandler{}, + capability: defaultTestBackendCapability, + sessionToken: mockToken, + bcConfig: &BCConfig{}, } } @@ -49,11 +49,9 @@ type mockProxy struct { func newMockProxy(t *testing.T, cfg *proxyConfig) *mockProxy { lg, _ := logger.CreateLoggerForTest(t) mp := &mockProxy{ - proxyConfig: cfg, - logger: lg.Named("mockProxy"), - BackendConnManager: NewBackendConnManager(lg, cfg.handler, cfg.connectionID, &BCConfig{ - CheckBackendInterval: cfg.checkBackendInterval, - }), + proxyConfig: cfg, + logger: lg.Named("mockProxy"), + BackendConnManager: NewBackendConnManager(lg, cfg.handler, cfg.connectionID, cfg.bcConfig), } mp.cmdProcessor.capability = cfg.capability return mp