diff --git a/client/v3/client.go b/client/v3/client.go index 60cbdf21454..efa44e8902d 100644 --- a/client/v3/client.go +++ b/client/v3/client.go @@ -502,7 +502,7 @@ func (c *Client) checkVersion() (err error) { return } } - if maj < 3 || (maj == 3 && min < 2) { + if maj < 3 || (maj == 3 && min < 4) { rerr = ErrOldCluster } errc <- rerr @@ -510,7 +510,7 @@ func (c *Client) checkVersion() (err error) { } // wait for success for range eps { - if err = <-errc; err == nil { + if err = <-errc; err != nil { break } } diff --git a/client/v3/client_test.go b/client/v3/client_test.go index e441476374b..c6e7c7a977d 100644 --- a/client/v3/client_test.go +++ b/client/v3/client_test.go @@ -17,7 +17,9 @@ package clientv3 import ( "context" "fmt" + "io" "net" + "sync" "testing" "time" @@ -294,3 +296,99 @@ func (mc *mockCluster) MemberUpdate(ctx context.Context, id uint64, peerAddrs [] func (mc *mockCluster) MemberPromote(ctx context.Context, id uint64) (*MemberPromoteResponse, error) { return nil, nil } + +func TestClientRejectOldCluster(t *testing.T) { + testutil.RegisterLeakDetection(t) + var tests = []struct { + name string + endpoints []string + versions []string + expectedError error + }{ + { + name: "all new versions with the same value", + endpoints: []string{"192.168.3.41:22379", "192.168.3.41:22479", "192.168.3.41:22579"}, + versions: []string{"3.5.4", "3.5.4", "3.5.4"}, + expectedError: nil, + }, + { + name: "all new versions with different values", + endpoints: []string{"192.168.3.41:22379", "192.168.3.41:22479", "192.168.3.41:22579"}, + versions: []string{"3.5.4", "3.5.4", "3.4.0"}, + expectedError: nil, + }, + { + name: "all old versions with different values", + endpoints: []string{"192.168.3.41:22379", "192.168.3.41:22479", "192.168.3.41:22579"}, + versions: []string{"3.3.0", "3.3.0", "3.4.0"}, + expectedError: ErrOldCluster, + }, + { + name: "all old versions with the same value", + endpoints: []string{"192.168.3.41:22379", "192.168.3.41:22479", "192.168.3.41:22579"}, + versions: []string{"3.3.0", "3.3.0", "3.3.0"}, + expectedError: ErrOldCluster, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if len(tt.endpoints) != len(tt.versions) || len(tt.endpoints) == 0 { + t.Errorf("Unexpected endpoints and versions length, len(endpoints):%d, len(versions):%d", len(tt.endpoints), len(tt.versions)) + return + } + endpointToVersion := make(map[string]string) + for j := range tt.endpoints { + endpointToVersion[tt.endpoints[j]] = tt.versions[j] + } + c := &Client{ + ctx: context.Background(), + cfg: Config{ + Endpoints: tt.endpoints, + }, + mu: new(sync.RWMutex), + Maintenance: &mockMaintenance{ + Version: endpointToVersion, + }, + } + + if err := c.checkVersion(); err != tt.expectedError { + t.Errorf("heckVersion err:%v", err) + } + }) + + } + +} + +type mockMaintenance struct { + Version map[string]string +} + +func (mm mockMaintenance) Status(ctx context.Context, endpoint string) (*StatusResponse, error) { + return &StatusResponse{Version: mm.Version[endpoint]}, nil +} + +func (mm mockMaintenance) AlarmList(ctx context.Context) (*AlarmResponse, error) { + return nil, nil +} + +func (mm mockMaintenance) AlarmDisarm(ctx context.Context, m *AlarmMember) (*AlarmResponse, error) { + return nil, nil +} + +func (mm mockMaintenance) Defragment(ctx context.Context, endpoint string) (*DefragmentResponse, error) { + return nil, nil +} + +func (mm mockMaintenance) HashKV(ctx context.Context, endpoint string, rev int64) (*HashKVResponse, error) { + return nil, nil +} + +func (mm mockMaintenance) Snapshot(ctx context.Context) (io.ReadCloser, error) { + return nil, nil +} + +func (mm mockMaintenance) MoveLeader(ctx context.Context, transfereeID uint64) (*MoveLeaderResponse, error) { + return nil, nil +}