diff --git a/client/http/client.go b/client/http/client.go index 389c5ee5bef..5ac00a8a43b 100644 --- a/client/http/client.go +++ b/client/http/client.go @@ -305,7 +305,8 @@ func NewClient( } sd := pd.NewDefaultPDServiceDiscovery(ctx, cancel, pdAddrs, c.inner.tlsConf) if err := sd.Init(); err != nil { - log.Error("[pd] init service discovery failed", zap.String("source", source), zap.Strings("pd-addrs", pdAddrs), zap.Error(err)) + log.Error("[pd] init service discovery failed", + zap.String("source", source), zap.Strings("pd-addrs", pdAddrs), zap.Error(err)) return nil } c.inner.init(sd) @@ -382,9 +383,8 @@ func NewHTTPClientWithRequestChecker(checker requestChecker) *http.Client { } } -// newClientWithoutInitServiceDiscovery creates a PD HTTP client -// with the given PD addresses and TLS config without init service discovery. -func newClientWithoutInitServiceDiscovery( +// newClientWithMockServiceDiscovery creates a new PD HTTP client with a mock PD service discovery. +func newClientWithMockServiceDiscovery( source string, pdAddrs []string, opts ...ClientOption, @@ -395,7 +395,12 @@ func newClientWithoutInitServiceDiscovery( for _, opt := range opts { opt(c) } - sd := pd.NewDefaultPDServiceDiscovery(ctx, cancel, pdAddrs, c.inner.tlsConf) + sd := pd.NewMockPDServiceDiscovery(pdAddrs, c.inner.tlsConf) + if err := sd.Init(); err != nil { + log.Error("[pd] init mock service discovery failed", + zap.String("source", source), zap.Strings("pd-addrs", pdAddrs), zap.Error(err)) + return nil + } c.inner.init(sd) return c } diff --git a/client/http/client_test.go b/client/http/client_test.go index 02fce93838e..49faefefaec 100644 --- a/client/http/client_test.go +++ b/client/http/client_test.go @@ -28,6 +28,7 @@ import ( func TestPDAllowFollowerHandleHeader(t *testing.T) { re := require.New(t) + checked := 0 httpClient := NewHTTPClientWithRequestChecker(func(req *http.Request) error { var expectedVal string if req.URL.Path == HotHistory { @@ -38,16 +39,19 @@ func TestPDAllowFollowerHandleHeader(t *testing.T) { re.Failf("PD allow follower handler header check failed", "should be %s, but got %s", expectedVal, val) } + checked++ return nil }) - c := newClientWithoutInitServiceDiscovery("test-header", []string{"http://127.0.0.1"}, WithHTTPClient(httpClient)) + c := newClientWithMockServiceDiscovery("test-header", []string{"http://127.0.0.1"}, WithHTTPClient(httpClient)) + defer c.Close() c.GetRegions(context.Background()) c.GetHistoryHotRegions(context.Background(), &HistoryHotRegionsRequest{}) - c.Close() + re.Equal(2, checked) } func TestCallerID(t *testing.T) { re := require.New(t) + checked := 0 expectedVal := atomic.NewString(defaultCallerID) httpClient := NewHTTPClientWithRequestChecker(func(req *http.Request) error { val := req.Header.Get(xCallerIDKey) @@ -56,20 +60,23 @@ func TestCallerID(t *testing.T) { re.Failf("Caller ID header check failed", "should be %s, but got %s", expectedVal, val) } + checked++ return nil }) - c := newClientWithoutInitServiceDiscovery("test-caller-id", []string{"http://127.0.0.1"}, WithHTTPClient(httpClient)) + c := newClientWithMockServiceDiscovery("test-caller-id", []string{"http://127.0.0.1"}, WithHTTPClient(httpClient)) + defer c.Close() c.GetRegions(context.Background()) expectedVal.Store("test") c.WithCallerID(expectedVal.Load()).GetRegions(context.Background()) - c.Close() + re.Equal(2, checked) } func TestWithBackoffer(t *testing.T) { re := require.New(t) ctx, cancel := context.WithCancel(context.Background()) defer cancel() - c := newClientWithoutInitServiceDiscovery("test-with-backoffer", []string{"http://127.0.0.1"}) + c := newClientWithMockServiceDiscovery("test-with-backoffer", []string{"http://127.0.0.1"}) + defer c.Close() base := 100 * time.Millisecond max := 500 * time.Millisecond @@ -88,5 +95,4 @@ func TestWithBackoffer(t *testing.T) { _, err = c.WithBackoffer(bo).GetPDVersion(timeoutCtx) re.InDelta(3*time.Second, time.Since(start), float64(250*time.Millisecond)) re.ErrorIs(err, context.DeadlineExceeded) - c.Close() } diff --git a/client/mock_pd_service_discovery.go b/client/mock_pd_service_discovery.go new file mode 100644 index 00000000000..10f7f080106 --- /dev/null +++ b/client/mock_pd_service_discovery.go @@ -0,0 +1,74 @@ +// Copyright 2024 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pd + +import ( + "crypto/tls" + "sync" + + "google.golang.org/grpc" +) + +var _ ServiceDiscovery = (*mockPDServiceDiscovery)(nil) + +type mockPDServiceDiscovery struct { + urls []string + tlsCfg *tls.Config + clients []ServiceClient +} + +// NewMockPDServiceDiscovery creates a mock PD service discovery. +func NewMockPDServiceDiscovery(urls []string, tlsCfg *tls.Config) *mockPDServiceDiscovery { + return &mockPDServiceDiscovery{ + urls: urls, + tlsCfg: tlsCfg, + } +} + +// Init directly creates the service clients with the given URLs. +func (m *mockPDServiceDiscovery) Init() error { + m.clients = make([]ServiceClient, 0, len(m.urls)) + for _, url := range m.urls { + m.clients = append(m.clients, newPDServiceClient(url, url, m.tlsCfg, nil, false)) + } + return nil +} + +// Close clears the service clients. +func (m *mockPDServiceDiscovery) Close() { + clear(m.clients) +} + +// GetAllServiceClients returns all service clients init in the mock PD service discovery. +func (m *mockPDServiceDiscovery) GetAllServiceClients() []ServiceClient { + return m.clients +} + +func (m *mockPDServiceDiscovery) GetClusterID() uint64 { return 0 } +func (m *mockPDServiceDiscovery) GetKeyspaceID() uint32 { return 0 } +func (m *mockPDServiceDiscovery) GetKeyspaceGroupID() uint32 { return 0 } +func (m *mockPDServiceDiscovery) GetServiceURLs() []string { return nil } +func (m *mockPDServiceDiscovery) GetServingEndpointClientConn() *grpc.ClientConn { return nil } +func (m *mockPDServiceDiscovery) GetClientConns() *sync.Map { return nil } +func (m *mockPDServiceDiscovery) GetServingAddr() string { return "" } +func (m *mockPDServiceDiscovery) GetBackupAddrs() []string { return nil } +func (m *mockPDServiceDiscovery) GetServiceClient() ServiceClient { return nil } +func (m *mockPDServiceDiscovery) GetOrCreateGRPCConn(addr string) (*grpc.ClientConn, error) { + return nil, nil +} +func (m *mockPDServiceDiscovery) ScheduleCheckMemberChanged() {} +func (m *mockPDServiceDiscovery) CheckMemberChanged() error { return nil } +func (m *mockPDServiceDiscovery) AddServingAddrSwitchedCallback(callbacks ...func()) {} +func (m *mockPDServiceDiscovery) AddServiceAddrsSwitchedCallback(callbacks ...func()) {} diff --git a/client/pd_service_discovery.go b/client/pd_service_discovery.go index 137a5c640c0..5d9105e7681 100644 --- a/client/pd_service_discovery.go +++ b/client/pd_service_discovery.go @@ -424,8 +424,10 @@ type tsoAllocatorEventSource interface { SetTSOGlobalServAddrUpdatedCallback(callback tsoGlobalServAddrUpdatedFunc) } -var _ ServiceDiscovery = (*pdServiceDiscovery)(nil) -var _ tsoAllocatorEventSource = (*pdServiceDiscovery)(nil) +var ( + _ ServiceDiscovery = (*pdServiceDiscovery)(nil) + _ tsoAllocatorEventSource = (*pdServiceDiscovery)(nil) +) // pdServiceDiscovery is the service discovery client of PD/API service which is quorum based type pdServiceDiscovery struct {