Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

backend: refine the error message when require TLS #359

Merged
merged 1 commit into from
Sep 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 4 additions & 7 deletions pkg/proxy/backend/authenticator.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import (

var (
ErrCapabilityNegotiation = errors.New("capability negotiation failed")
ErrTLSConfigRequired = errors.New("require TLS config on TiProxy when require-backend-tls=true")
)

const unknownAuthPlugin = "auth_unknown_plugin"
Expand Down Expand Up @@ -74,17 +73,15 @@ func (auth *Authenticator) writeProxyProtocol(clientIO, backendIO *pnet.PacketIO

func (auth *Authenticator) verifyBackendCaps(logger *zap.Logger, backendCapability pnet.Capability) error {
requiredBackendCaps := defRequiredBackendCaps & auth.capability
if auth.requireBackendTLS {
requiredBackendCaps |= pnet.ClientSSL
}

if commonCaps := backendCapability & requiredBackendCaps; commonCaps != requiredBackendCaps {
// The error cannot be sent to the client because the client only expects an initial handshake packet.
// The only way is to log it and disconnect.
logger.Error("require backend capabilities", zap.Stringer("common", commonCaps), zap.Stringer("required", requiredBackendCaps^commonCaps))
return errors.Wrapf(ErrCapabilityNegotiation, "require %s from backend", requiredBackendCaps^commonCaps)
}

if auth.requireBackendTLS && (backendCapability&pnet.ClientSSL == 0) {
return pnet.WrapUserError(errors.New("backend doesn't enable TLS"), requireTiDBTLSErrMsg)
}
return nil
}

Expand Down Expand Up @@ -315,7 +312,7 @@ func (auth *Authenticator) writeAuthHandshake(
var enableTLS bool
if auth.requireBackendTLS {
if backendTLSConfig == nil {
return ErrTLSConfigRequired
return pnet.WrapUserError(errors.New("tiproxy doesn't enable TLS"), requireProxyTLSErrMsg)
}
enableTLS = true
} else {
Expand Down
50 changes: 49 additions & 1 deletion pkg/proxy/backend/authenticator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"testing"

"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tiproxy/lib/util/errors"
pnet "github.com/pingcap/tiproxy/pkg/proxy/net"
"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -220,7 +221,7 @@ func TestCustomAuth(t *testing.T) {
require.Equal(t, ts.mc.username, inUser)
require.Equal(t, reUser, ts.mb.username)
require.Equal(t, reAttrs, ts.mb.attrs)
require.Equal(t, reCap&pnet.ClientDeprecateEOF, pnet.Capability(ts.mb.capability)&pnet.ClientDeprecateEOF)
require.Equal(t, reCap&pnet.ClientDeprecateEOF, ts.mb.capability&pnet.ClientDeprecateEOF)
}
ts.authenticateFirstTime(t, func(t *testing.T, ts *testSuite) {})
checker()
Expand Down Expand Up @@ -290,3 +291,50 @@ func TestAuthFail(t *testing.T) {
clean()
}
}

func TestRequireBackendTLS(t *testing.T) {
tests := []struct {
cfg cfgOverrider
errMsg string
}{
{
cfg: func(cfg *testConfig) {
cfg.proxyConfig.bcConfig.RequireBackendTLS = true
cfg.proxyConfig.backendTLSConfig = nil
cfg.backendConfig.capability |= pnet.ClientSSL
},
errMsg: requireProxyTLSErrMsg,
},
{
cfg: func(cfg *testConfig) {
cfg.proxyConfig.bcConfig.RequireBackendTLS = true
cfg.backendConfig.tlsConfig = nil
cfg.backendConfig.capability &= ^pnet.ClientSSL
},
errMsg: requireTiDBTLSErrMsg,
},
{
cfg: func(cfg *testConfig) {
cfg.proxyConfig.bcConfig.RequireBackendTLS = false
cfg.proxyConfig.backendTLSConfig = nil
cfg.backendConfig.tlsConfig = nil
cfg.backendConfig.capability &= ^pnet.ClientSSL
},
},
}

tc := newTCPConnSuite(t)
for _, tt := range tests {
ts, clean := newTestSuite(t, tc, tt.cfg)
ts.authenticateFirstTime(t, func(t *testing.T, ts *testSuite) {
if len(tt.errMsg) > 0 {
var userError *pnet.UserError
require.True(t, errors.As(ts.mp.err, &userError))
require.Equal(t, tt.errMsg, userError.UserMsg())
} else {
require.NoError(t, ts.mp.err)
}
})
clean()
}
}
6 changes: 3 additions & 3 deletions pkg/proxy/backend/backend_conn_mgr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -784,7 +784,7 @@ func TestOnTraffic(t *testing.T) {
0xce,
}
ts := newBackendMgrTester(t, func(config *testConfig) {
config.proxyConfig.checkBackendInterval = 10 * time.Millisecond
config.proxyConfig.bcConfig.CheckBackendInterval = 10 * time.Millisecond
config.proxyConfig.handler.onTraffic = func(cc ConnContext) {
require.Equal(t, uint64(inbytes[i]), cc.ClientInBytes())
require.Equal(t, uint64(outbytes[i]), cc.ClientOutBytes())
Expand Down Expand Up @@ -873,7 +873,7 @@ func TestGetBackendIO(t *testing.T) {

func TestBackendInactive(t *testing.T) {
ts := newBackendMgrTester(t, func(config *testConfig) {
config.proxyConfig.checkBackendInterval = 10 * time.Millisecond
config.proxyConfig.bcConfig.CheckBackendInterval = 10 * time.Millisecond
})
runners := []runner{
// 1st handshake
Expand Down Expand Up @@ -957,7 +957,7 @@ func TestBackendInactive(t *testing.T) {

func TestKeepAlive(t *testing.T) {
ts := newBackendMgrTester(t, func(config *testConfig) {
config.proxyConfig.checkBackendInterval = 10 * time.Millisecond
config.proxyConfig.bcConfig.CheckBackendInterval = 10 * time.Millisecond
})
runners := []runner{
{
Expand Down
10 changes: 6 additions & 4 deletions pkg/proxy/backend/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@ import (
)

const (
connectErrMsg = "No available TiDB instances, please check TiDB cluster"
parsePktErrMsg = "TiProxy fails to parse the packet, please contact PingCAP"
handshakeErrMsg = "TiProxy fails to connect to TiDB, please check network"
capabilityErrMsg = "Verify TiDB capability failed, please upgrade TiDB"
connectErrMsg = "No available TiDB instances, please check TiDB cluster"
parsePktErrMsg = "TiProxy fails to parse the packet, please contact PingCAP"
handshakeErrMsg = "TiProxy fails to connect to TiDB, please check network"
capabilityErrMsg = "Verify TiDB capability failed, please upgrade TiDB"
requireProxyTLSErrMsg = "Require TLS config on TiProxy when require-backend-tls=true"
requireTiDBTLSErrMsg = "Require TLS config on TiDB when require-backend-tls=true"
)

var (
Expand Down
2 changes: 1 addition & 1 deletion pkg/proxy/backend/mock_backend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ func (mb *mockBackend) authenticate(packetIO *pnet.PacketIO) error {
mb.db = resp.DB
mb.authData = resp.AuthData
mb.attrs = resp.Attrs
mb.capability = pnet.Capability(resp.Capability)
mb.capability = resp.Capability
// verify password
return mb.verifyPassword(packetIO, resp)
}
Expand Down
32 changes: 15 additions & 17 deletions pkg/proxy/backend/mock_proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,22 @@ import (
)

type proxyConfig struct {
frontendTLSConfig *tls.Config
backendTLSConfig *tls.Config
handler *CustomHandshakeHandler
checkBackendInterval time.Duration
sessionToken string
capability pnet.Capability
waitRedirect bool
connectionID uint64
frontendTLSConfig *tls.Config
backendTLSConfig *tls.Config
handler *CustomHandshakeHandler
bcConfig *BCConfig
sessionToken string
capability pnet.Capability
waitRedirect bool
connectionID uint64
}

func newProxyConfig() *proxyConfig {
return &proxyConfig{
handler: &CustomHandshakeHandler{},
capability: defaultTestBackendCapability,
sessionToken: mockToken,
checkBackendInterval: CheckBackendInterval,
handler: &CustomHandshakeHandler{},
capability: defaultTestBackendCapability,
sessionToken: mockToken,
bcConfig: &BCConfig{},
}
}

Expand All @@ -49,11 +49,9 @@ type mockProxy struct {
func newMockProxy(t *testing.T, cfg *proxyConfig) *mockProxy {
lg, _ := logger.CreateLoggerForTest(t)
mp := &mockProxy{
proxyConfig: cfg,
logger: lg.Named("mockProxy"),
BackendConnManager: NewBackendConnManager(lg, cfg.handler, cfg.connectionID, &BCConfig{
CheckBackendInterval: cfg.checkBackendInterval,
}),
proxyConfig: cfg,
logger: lg.Named("mockProxy"),
BackendConnManager: NewBackendConnManager(lg, cfg.handler, cfg.connectionID, cfg.bcConfig),
}
mp.cmdProcessor.capability = cfg.capability
return mp
Expand Down
Loading