-
Notifications
You must be signed in to change notification settings - Fork 720
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add basic tests for rewritten tsoStream
Signed-off-by: MyonKeminta <[email protected]>
- Loading branch information
1 parent
140a7c2
commit 88d9f5b
Showing
2 changed files
with
307 additions
and
17 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,263 @@ | ||
// Copyright 2024 TiKV Project Authors. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
package pd | ||
|
||
import ( | ||
"context" | ||
"io" | ||
"testing" | ||
"time" | ||
|
||
"github.com/pingcap/errors" | ||
"github.com/stretchr/testify/suite" | ||
"github.com/tikv/pd/client/errs" | ||
) | ||
|
||
type resultMsg struct { | ||
r tsoRequestResult | ||
err error | ||
breakStream bool | ||
} | ||
|
||
type mockTSOStreamImpl struct { | ||
requestCh chan struct{} | ||
resultCh chan resultMsg | ||
keyspaceID uint32 | ||
} | ||
|
||
func newMockTSOStreamImpl() *mockTSOStreamImpl { | ||
return &mockTSOStreamImpl{ | ||
requestCh: make(chan struct{}, 64), | ||
resultCh: make(chan resultMsg, 64), | ||
keyspaceID: 0, | ||
} | ||
} | ||
|
||
func (s *mockTSOStreamImpl) Send(clusterID uint64, keyspaceID, keyspaceGroupID uint32, dcLocation string, count int64) error { | ||
s.requestCh <- struct{}{} | ||
return nil | ||
} | ||
|
||
func (s *mockTSOStreamImpl) Recv() (tsoRequestResult, error) { | ||
res := <-s.resultCh | ||
if !res.breakStream { | ||
<-s.requestCh | ||
} | ||
return res.r, res.err | ||
} | ||
|
||
func (s *mockTSOStreamImpl) returnResult(physical int64, logical int64, count uint32) { | ||
s.resultCh <- resultMsg{ | ||
r: tsoRequestResult{ | ||
physical: physical, | ||
logical: logical, | ||
count: count, | ||
suffixBits: 0, | ||
respKeyspaceGroupID: s.keyspaceID, | ||
}, | ||
} | ||
} | ||
|
||
func (s *mockTSOStreamImpl) returnError(err error) { | ||
s.resultCh <- resultMsg{ | ||
err: err, | ||
} | ||
} | ||
|
||
func (s *mockTSOStreamImpl) breakStream(err error) { | ||
s.resultCh <- resultMsg{ | ||
err: err, | ||
breakStream: true, | ||
} | ||
} | ||
|
||
func (s *mockTSOStreamImpl) stop() { | ||
s.breakStream(io.EOF) | ||
} | ||
|
||
type callbackInvocation struct { | ||
result tsoRequestResult | ||
streamURL string | ||
err error | ||
} | ||
|
||
type testTSOStreamSuite struct { | ||
suite.Suite | ||
|
||
inner *mockTSOStreamImpl | ||
stream *tsoStream | ||
} | ||
|
||
func (s *testTSOStreamSuite) SetupTest() { | ||
s.inner = newMockTSOStreamImpl() | ||
s.stream = newTSOStream("mock:///", s.inner) | ||
} | ||
|
||
func (s *testTSOStreamSuite) TearDownTest() { | ||
s.inner.stop() | ||
s.stream.WaitForClosed() | ||
s.inner = nil | ||
s.stream = nil | ||
} | ||
|
||
func TestTSOStreamTestSuite(t *testing.T) { | ||
suite.Run(t, new(testTSOStreamSuite)) | ||
} | ||
|
||
func (s *testTSOStreamSuite) noResult(ch <-chan callbackInvocation) { | ||
select { | ||
case res := <-ch: | ||
s.FailNowf("result received unexpectedly", "received result: %+v", res) | ||
case <-time.After(time.Millisecond * 20): | ||
} | ||
} | ||
|
||
func (s *testTSOStreamSuite) getResult(ch <-chan callbackInvocation) callbackInvocation { | ||
select { | ||
case res := <-ch: | ||
return res | ||
case <-time.After(time.Second * 10000): | ||
s.FailNow("result not ready in time") | ||
panic("result not ready in time") | ||
} | ||
} | ||
|
||
func (s *testTSOStreamSuite) processRequestWithResultCh(count int64) <-chan callbackInvocation { | ||
ch := make(chan callbackInvocation, 1) | ||
err := s.stream.processRequests(1, 2, 3, globalDCLocation, count, time.Now(), func(result tsoRequestResult, reqKeyspaceGroupID uint32, streamURL string, err error) { | ||
if err == nil { | ||
s.Equal(uint32(3), reqKeyspaceGroupID) | ||
s.Equal(uint32(0), result.suffixBits) | ||
} | ||
ch <- callbackInvocation{ | ||
result: result, | ||
streamURL: streamURL, | ||
err: err, | ||
} | ||
}) | ||
s.NoError(err) | ||
return ch | ||
} | ||
|
||
func (s *testTSOStreamSuite) TestTSOStreamBasic() { | ||
ch := s.processRequestWithResultCh(1) | ||
s.noResult(ch) | ||
s.inner.returnResult(10, 1, 1) | ||
res := s.getResult(ch) | ||
|
||
s.NoError(res.err) | ||
s.Equal("mock:///", res.streamURL) | ||
s.Equal(int64(10), res.result.physical) | ||
s.Equal(int64(1), res.result.logical) | ||
s.Equal(uint32(1), res.result.count) | ||
|
||
ch = s.processRequestWithResultCh(2) | ||
s.noResult(ch) | ||
s.inner.returnResult(20, 3, 2) | ||
res = s.getResult(ch) | ||
|
||
s.NoError(res.err) | ||
s.Equal("mock:///", res.streamURL) | ||
s.Equal(int64(20), res.result.physical) | ||
s.Equal(int64(3), res.result.logical) | ||
s.Equal(uint32(2), res.result.count) | ||
|
||
ch = s.processRequestWithResultCh(3) | ||
s.noResult(ch) | ||
s.inner.returnError(errors.New("mock rpc error")) | ||
res = s.getResult(ch) | ||
s.Error(res.err) | ||
s.Equal("mock rpc error", res.err.Error()) | ||
|
||
// After an error from the (simulated) RPC stream, the tsoStream should be in a broken status and can't accept | ||
// new request anymore. | ||
err := s.stream.processRequests(1, 2, 3, globalDCLocation, 1, time.Now(), func(result tsoRequestResult, reqKeyspaceGroupID uint32, streamURL string, err error) { | ||
panic("unreachable") | ||
}) | ||
s.Error(err) | ||
} | ||
|
||
func (s *testTSOStreamSuite) testTSOStreamBrokenImpl(err error, pendingRequests int) { | ||
var resultCh []<-chan callbackInvocation | ||
|
||
for i := 0; i < pendingRequests; i++ { | ||
ch := s.processRequestWithResultCh(1) | ||
resultCh = append(resultCh, ch) | ||
s.noResult(ch) | ||
} | ||
|
||
s.inner.breakStream(err) | ||
closedCh := make(chan struct{}) | ||
go func() { | ||
s.stream.WaitForClosed() | ||
closedCh <- struct{}{} | ||
}() | ||
select { | ||
case <-closedCh: | ||
case <-time.After(time.Second): | ||
s.FailNow("stream receiver loop didn't exit") | ||
} | ||
|
||
for _, ch := range resultCh { | ||
res := s.getResult(ch) | ||
s.Error(res.err) | ||
if err == io.EOF { | ||
s.ErrorIs(res.err, errs.ErrClientTSOStreamClosed) | ||
} else { | ||
s.ErrorIs(res.err, err) | ||
} | ||
} | ||
} | ||
|
||
func (s *testTSOStreamSuite) TestTSOStreamBrokenWithEOFNoPendingReq() { | ||
s.testTSOStreamBrokenImpl(io.EOF, 0) | ||
} | ||
|
||
func (s *testTSOStreamSuite) TestTSOStreamCanceledNoPendingReq() { | ||
s.testTSOStreamBrokenImpl(context.Canceled, 0) | ||
} | ||
|
||
func (s *testTSOStreamSuite) TestTSOStreamBrokenWithEOFWithPendingReq() { | ||
s.testTSOStreamBrokenImpl(io.EOF, 5) | ||
} | ||
|
||
func (s *testTSOStreamSuite) TestTSOStreamCanceledWithPendingReq() { | ||
s.testTSOStreamBrokenImpl(context.Canceled, 5) | ||
} | ||
|
||
func (s *testTSOStreamSuite) TestTSOStreamFIFO() { | ||
var resultChs []<-chan callbackInvocation | ||
const COUNT = 5 | ||
for i := 0; i < COUNT; i++ { | ||
ch := s.processRequestWithResultCh(int64(i + 1)) | ||
resultChs = append(resultChs, ch) | ||
} | ||
|
||
for _, ch := range resultChs { | ||
s.noResult(ch) | ||
} | ||
|
||
for i := 0; i < COUNT; i++ { | ||
s.inner.returnResult(int64((i+1)*10), int64(i), uint32(i+1)) | ||
} | ||
|
||
for i, ch := range resultChs { | ||
res := s.getResult(ch) | ||
s.NoError(res.err) | ||
s.Equal(int64((i+1)*10), res.result.physical) | ||
s.Equal(int64(i), res.result.logical) | ||
s.Equal(uint32(i+1), res.result.count) | ||
} | ||
} |