Skip to content

Commit

Permalink
fix scheme
Browse files Browse the repository at this point in the history
Signed-off-by: Cabinfever_B <[email protected]>
  • Loading branch information
CabinfeverB committed Mar 8, 2024
1 parent e72d49b commit 6438a65
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 67 deletions.
2 changes: 1 addition & 1 deletion client/http/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ func (ci *clientInner) requestWithRetry(
return errs.ErrClientNoAvailableMember
}
for _, cli := range clients {
addr := cli.GetHTTPAddress()
addr := cli.GetAddress()
statusCode, err = ci.doRequest(ctx, addr, reqInfo, headerOpts...)
if err == nil || noNeedRetry(statusCode) {
return err
Expand Down
79 changes: 37 additions & 42 deletions client/pd_service_discovery.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ const (
updateMemberTimeout = time.Second // Use a shorter timeout to recover faster from network isolation.
updateMemberBackOffBaseTime = 100 * time.Millisecond

httpScheme = "http"
httpsScheme = "https"
httpScheme = "http://"
httpsScheme = "https://"
)

// MemberHealthCheckInterval might be changed in the unit to shorten the testing time.
Expand Down Expand Up @@ -124,10 +124,8 @@ type ServiceDiscovery interface {

// ServiceClient is an interface that defines a set of operations for a raw PD gRPC client to specific PD server.
type ServiceClient interface {
// GetAddress returns the address information of the PD server.
// GetAddress returns the address with HTTP scheme of the PD server.
GetAddress() string
// GetHTTPAddress returns the address with HTTP scheme of the PD server.
GetHTTPAddress() string
// GetClientConn returns the gRPC connection of the service client
GetClientConn() *grpc.ClientConn
// BuildGRPCTargetContext builds a context object with a gRPC context.
Expand Down Expand Up @@ -158,34 +156,15 @@ type pdServiceClient struct {
networkFailure atomic.Bool
}

func newPDServiceClient(addr, leaderAddr string, tlsCfg *tls.Config, conn *grpc.ClientConn, isLeader bool) ServiceClient {
var httpAddress string
if tlsCfg == nil {
if strings.HasPrefix(addr, httpsScheme) {
addr = strings.TrimPrefix(addr, httpsScheme)
httpAddress = fmt.Sprintf("%s%s", httpScheme, addr)
} else if strings.HasPrefix(addr, httpScheme) {
httpAddress = addr
} else {
httpAddress = fmt.Sprintf("%s://%s", httpScheme, addr)
}
} else {
if strings.HasPrefix(addr, httpsScheme) {
httpAddress = addr
} else if strings.HasPrefix(addr, httpScheme) {
addr = strings.TrimPrefix(addr, httpScheme)
httpAddress = fmt.Sprintf("%s%s", httpsScheme, addr)
} else {
httpAddress = fmt.Sprintf("%s://%s", httpsScheme, addr)
}
}

// NOTE: In the current implementation, the address passed in is bound to have an http scheme,
// because it is processed in `newPDServiceDiscovery`, and the url returned by etcd member is its own.
// When testing, the address is also bound to have an http scheme.
func newPDServiceClient(addr, leaderAddr string, conn *grpc.ClientConn, isLeader bool) ServiceClient {
cli := &pdServiceClient{
addr: addr,
httpAddress: httpAddress,
conn: conn,
isLeader: isLeader,
leaderAddr: leaderAddr,
addr: addr,
conn: conn,
isLeader: isLeader,
leaderAddr: leaderAddr,
}
if conn == nil {
cli.networkFailure.Store(true)
Expand Down Expand Up @@ -504,7 +483,7 @@ func newPDServiceDiscovery(
tlsCfg: tlsCfg,
option: option,
}
urls = addrsToUrls(urls)
urls = addrsToUrls(urls, tlsCfg)
pdsd.urls.Store(urls)
return pdsd
}
Expand Down Expand Up @@ -1030,7 +1009,7 @@ func (c *pdServiceDiscovery) switchLeader(addrs []string) (bool, error) {
// If gRPC connect is created successfully or leader is new, still saves.
if addr != oldLeader.GetAddress() || newConn != nil {
// Set PD leader and Global TSO Allocator (which is also the PD leader)
leaderClient := newPDServiceClient(addr, addr, c.tlsCfg, newConn, true)
leaderClient := newPDServiceClient(addr, addr, newConn, true)
c.leader.Store(leaderClient)
}
// Run callbacks
Expand Down Expand Up @@ -1067,15 +1046,15 @@ func (c *pdServiceDiscovery) updateFollowers(members []*pdpb.Member, leader *pdp
log.Warn("[pd] failed to connect follower", zap.String("follower", addr), errs.ZapError(err))
continue
}
follower := newPDServiceClient(addr, leader.GetClientUrls()[0], c.tlsCfg, conn, false)
follower := newPDServiceClient(addr, leader.GetClientUrls()[0], conn, false)
c.followers.Store(addr, follower)
changed = true
}
delete(followers, addr)
} else {
changed = true
conn, err := c.GetOrCreateGRPCConn(addr)
follower := newPDServiceClient(addr, leader.GetClientUrls()[0], c.tlsCfg, conn, false)
follower := newPDServiceClient(addr, leader.GetClientUrls()[0], conn, false)
if err != nil || conn == nil {
log.Warn("[pd] failed to connect follower", zap.String("follower", addr), errs.ZapError(err))
}
Expand Down Expand Up @@ -1148,15 +1127,31 @@ func (c *pdServiceDiscovery) GetOrCreateGRPCConn(addr string) (*grpc.ClientConn,
return grpcutil.GetOrCreateGRPCConn(c.ctx, &c.clientConns, addr, c.tlsCfg, c.option.gRPCDialOptions...)
}

func addrsToUrls(addrs []string) []string {
func addrsToUrls(addrs []string, tlsCfg *tls.Config) []string {
// Add default schema "http://" to addrs.
urls := make([]string, 0, len(addrs))
for _, addr := range addrs {
if strings.Contains(addr, "://") {
urls = append(urls, addr)
} else {
urls = append(urls, "http://"+addr)
}
addr = addrToUrl(addr, tlsCfg)
urls = append(urls, addr)
}
return urls
}

func addrToUrl(addr string, tlsCfg *tls.Config) string {
if tlsCfg == nil {
if strings.HasPrefix(addr, httpsScheme) {
addr = strings.TrimPrefix(addr, httpsScheme)
addr = fmt.Sprintf("%s%s", httpScheme, addr)
} else if !strings.HasPrefix(addr, httpScheme) {
addr = fmt.Sprintf("%s%s", httpScheme, addr)
}
} else {
if strings.HasPrefix(addr, httpScheme) {
addr = strings.TrimPrefix(addr, httpScheme)
addr = fmt.Sprintf("%s%s", httpsScheme, addr)
} else if !strings.HasPrefix(addr, httpsScheme) {
addr = fmt.Sprintf("%s%s", httpsScheme, addr)
}
}
return addr
}
40 changes: 22 additions & 18 deletions client/pd_service_discovery_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,14 @@ func (suite *serviceClientTestSuite) SetupSuite() {
leaderConn, err1 := grpc.Dial(suite.leaderServer.addr, grpc.WithInsecure()) //nolint
followerConn, err2 := grpc.Dial(suite.followerServer.addr, grpc.WithInsecure()) //nolint
if err1 == nil && err2 == nil {
suite.followerClient = newPDServiceClient(suite.followerServer.addr, suite.leaderServer.addr, nil, followerConn, false)
suite.leaderClient = newPDServiceClient(suite.leaderServer.addr, suite.leaderServer.addr, nil, leaderConn, true)
suite.followerClient = newPDServiceClient(
addrToUrl(suite.followerServer.addr, nil),
addrToUrl(suite.leaderServer.addr, nil),
followerConn, false)
suite.leaderClient = newPDServiceClient(
addrToUrl(suite.leaderServer.addr, nil),
addrToUrl(suite.leaderServer.addr, nil),
leaderConn, true)
suite.followerServer.server.leaderConn = suite.leaderClient.GetClientConn()
suite.followerServer.server.leaderAddr = suite.leaderClient.GetAddress()
return
Expand All @@ -165,16 +171,14 @@ func (suite *serviceClientTestSuite) TearDownSuite() {

func (suite *serviceClientTestSuite) TestServiceClient() {
re := suite.Require()
leaderAddress := suite.leaderServer.addr
followerAddress := suite.followerServer.addr
leaderAddress := addrToUrl(suite.leaderServer.addr, nil)
followerAddress := addrToUrl(suite.followerServer.addr, nil)

follower := suite.followerClient
leader := suite.leaderClient

re.Equal(follower.GetAddress(), followerAddress)
re.Equal(leader.GetAddress(), leaderAddress)
re.Equal(follower.GetHTTPAddress(), "http://"+followerAddress)
re.Equal(leader.GetHTTPAddress(), "http://"+leaderAddress)

re.True(follower.Available())
re.True(leader.Available())
Expand Down Expand Up @@ -302,16 +306,16 @@ func (suite *serviceClientTestSuite) TestServiceClientBalancer() {

func TestHTTPScheme(t *testing.T) {
re := require.New(t)
cli := newPDServiceClient("127.0.0.1:2379", "127.0.0.1:2379", nil, nil, false)
re.Equal("http://127.0.0.1:2379", cli.GetHTTPAddress())
cli = newPDServiceClient("https://127.0.0.1:2379", "127.0.0.1:2379", nil, nil, false)
re.Equal("http://127.0.0.1:2379", cli.GetHTTPAddress())
cli = newPDServiceClient("http://127.0.0.1:2379", "127.0.0.1:2379", nil, nil, false)
re.Equal("http://127.0.0.1:2379", cli.GetHTTPAddress())
cli = newPDServiceClient("127.0.0.1:2379", "127.0.0.1:2379", &tls.Config{}, nil, false)
re.Equal("https://127.0.0.1:2379", cli.GetHTTPAddress())
cli = newPDServiceClient("https://127.0.0.1:2379", "127.0.0.1:2379", &tls.Config{}, nil, false)
re.Equal("https://127.0.0.1:2379", cli.GetHTTPAddress())
cli = newPDServiceClient("http://127.0.0.1:2379", "127.0.0.1:2379", &tls.Config{}, nil, false)
re.Equal("https://127.0.0.1:2379", cli.GetHTTPAddress())
cli := newPDServiceClient(addrToUrl("127.0.0.1:2379", nil), addrToUrl("127.0.0.1:2379", nil), nil, false)
re.Equal("http://127.0.0.1:2379", cli.GetAddress())
cli = newPDServiceClient(addrToUrl("https://127.0.0.1:2379", nil), addrToUrl("127.0.0.1:2379", nil), nil, false)
re.Equal("http://127.0.0.1:2379", cli.GetAddress())
cli = newPDServiceClient(addrToUrl("http://127.0.0.1:2379", nil), addrToUrl("127.0.0.1:2379", nil), nil, false)
re.Equal("http://127.0.0.1:2379", cli.GetAddress())
cli = newPDServiceClient(addrToUrl("127.0.0.1:2379", &tls.Config{}), addrToUrl("127.0.0.1:2379", &tls.Config{}), nil, false)
re.Equal("https://127.0.0.1:2379", cli.GetAddress())
cli = newPDServiceClient(addrToUrl("https://127.0.0.1:2379", &tls.Config{}), addrToUrl("127.0.0.1:2379", &tls.Config{}), nil, false)
re.Equal("https://127.0.0.1:2379", cli.GetAddress())
cli = newPDServiceClient(addrToUrl("http://127.0.0.1:2379", &tls.Config{}), addrToUrl("127.0.0.1:2379", &tls.Config{}), nil, false)
re.Equal("https://127.0.0.1:2379", cli.GetAddress())
}
17 changes: 11 additions & 6 deletions tests/integrations/client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,12 @@ func TestClientLeaderChange(t *testing.T) {
defer cluster.Destroy()

endpoints := runServer(re, cluster)
cli := setupCli(re, ctx, endpoints)
endpointsWithWrongURL := append([]string{}, endpoints...)
// inject wrong http scheme
for i := range endpointsWithWrongURL {
endpointsWithWrongURL[i] = "https://" + strings.TrimPrefix(endpointsWithWrongURL[i], "http://")
}
cli := setupCli(re, ctx, endpointsWithWrongURL)
defer cli.Close()
innerCli, ok := cli.(interface{ GetServiceDiscovery() pd.ServiceDiscovery })
re.True(ok)
Expand All @@ -127,14 +132,14 @@ func TestClientLeaderChange(t *testing.T) {
re.True(cluster.CheckTSOUnique(ts1))

leader := cluster.GetLeader()
waitLeader(re, innerCli.GetServiceDiscovery(), cluster.GetServer(leader).GetConfig().ClientUrls)
waitLeader(re, innerCli.GetServiceDiscovery(), cluster.GetServer(leader))

err = cluster.GetServer(leader).Stop()
re.NoError(err)
leader = cluster.WaitLeader()
re.NotEmpty(leader)

waitLeader(re, innerCli.GetServiceDiscovery(), cluster.GetServer(leader).GetConfig().ClientUrls)
waitLeader(re, innerCli.GetServiceDiscovery(), cluster.GetServer(leader))

// Check TS won't fall back after leader changed.
testutil.Eventually(re, func() bool {
Expand Down Expand Up @@ -955,10 +960,10 @@ func setupCli(re *require.Assertions, ctx context.Context, endpoints []string, o
return cli
}

func waitLeader(re *require.Assertions, cli pd.ServiceDiscovery, leader string) {
func waitLeader(re *require.Assertions, cli pd.ServiceDiscovery, leader *tests.TestServer) {
testutil.Eventually(re, func() bool {
cli.ScheduleCheckMemberChanged()
return cli.GetServingAddr() == leader
return cli.GetServingAddr() == leader.GetConfig().ClientUrls && leader.GetAddr() == cli.GetServingAddr()
})
}

Expand Down Expand Up @@ -1853,7 +1858,7 @@ func (suite *clientTestSuite) TestMemberUpdateBackOff() {
re.True(ok)

leader := cluster.GetLeader()
waitLeader(re, innerCli.GetServiceDiscovery(), cluster.GetServer(leader).GetConfig().ClientUrls)
waitLeader(re, innerCli.GetServiceDiscovery(), cluster.GetServer(leader))
memberID := cluster.GetServer(leader).GetLeader().GetMemberId()

re.NoError(failpoint.Enable("github.com/tikv/pd/server/leaderLoopCheckAgain", fmt.Sprintf("return(\"%d\")", memberID)))
Expand Down

0 comments on commit 6438a65

Please sign in to comment.