diff --git a/internal/distributed/streaming/internal/consumer/consumer_impl.go b/internal/distributed/streaming/internal/consumer/consumer_impl.go index 082c61afd1758..14632e3187ab9 100644 --- a/internal/distributed/streaming/internal/consumer/consumer_impl.go +++ b/internal/distributed/streaming/internal/consumer/consumer_impl.go @@ -74,11 +74,12 @@ func (rc *resumableConsumerImpl) resumeLoop() { // consumer need to resume when error occur, so message handler shouldn't close if the internal consumer encounter failure. nopCloseMH := nopCloseHandler{ Handler: rc.mh, - HandleInterceptor: func(ctx context.Context, msg message.ImmutableMessage, handle handleFunc) (bool, error) { - g := rc.metrics.StartConsume(msg.EstimateSize()) - ok, err := handle(ctx, msg) - g.Finish() - return ok, err + HandleInterceptor: func(handleParam message.HandleParam, h message.Handler) message.HandleResult { + if handleParam.Message != nil { + g := rc.metrics.StartConsume(handleParam.Message.EstimateSize()) + defer func() { g.Finish() }() + } + return h.Handle(handleParam) }, } diff --git a/internal/distributed/streaming/internal/consumer/consumer_test.go b/internal/distributed/streaming/internal/consumer/consumer_test.go index c6a85792b320c..ffe4d01177de2 100644 --- a/internal/distributed/streaming/internal/consumer/consumer_test.go +++ b/internal/distributed/streaming/internal/consumer/consumer_test.go @@ -12,6 +12,7 @@ import ( "github.com/milvus-io/milvus/internal/streamingnode/client/handler" "github.com/milvus-io/milvus/internal/streamingnode/client/handler/consumer" "github.com/milvus-io/milvus/pkg/streaming/util/message" + "github.com/milvus-io/milvus/pkg/streaming/util/message/adaptor" "github.com/milvus-io/milvus/pkg/streaming/util/options" "github.com/milvus-io/milvus/pkg/streaming/walimpls/impls/walimplstest" ) @@ -22,22 +23,25 @@ func TestResumableConsumer(t *testing.T) { ch := make(chan struct{}) c.EXPECT().Done().Return(ch) c.EXPECT().Error().Return(errors.New("test")) - c.EXPECT().Close().Return() + c.EXPECT().Close().Return(nil) rc := NewResumableConsumer(func(ctx context.Context, opts *handler.ConsumerOptions) (consumer.Consumer, error) { if i == 0 { i++ - ok, err := opts.MessageHandler.Handle(context.Background(), message.NewImmutableMesasge( - walimplstest.NewTestMessageID(123), - []byte("payload"), - map[string]string{ - "key": "value", - "_t": "1", - "_tt": message.EncodeUint64(456), - "_v": "1", - "_lc": walimplstest.NewTestMessageID(123).Marshal(), - })) - assert.True(t, ok) - assert.NoError(t, err) + result := opts.MessageHandler.Handle(message.HandleParam{ + Ctx: context.Background(), + Message: message.NewImmutableMesasge( + walimplstest.NewTestMessageID(123), + []byte("payload"), + map[string]string{ + "key": "value", + "_t": "1", + "_tt": message.EncodeUint64(456), + "_v": "1", + "_lc": walimplstest.NewTestMessageID(123).Marshal(), + }), + }) + assert.True(t, result.MessageHandled) + assert.NoError(t, result.Error) return c, nil } else if i == 1 { i++ @@ -46,7 +50,7 @@ func TestResumableConsumer(t *testing.T) { newC := mock_consumer.NewMockConsumer(t) newC.EXPECT().Done().Return(make(<-chan struct{})) newC.EXPECT().Error().Return(errors.New("test")) - newC.EXPECT().Close().Return() + newC.EXPECT().Close().Return(nil) return newC, nil }, &ConsumerOptions{ PChannel: "test", @@ -54,7 +58,7 @@ func TestResumableConsumer(t *testing.T) { DeliverFilters: []options.DeliverFilter{ options.DeliverFilterTimeTickGT(1), }, - MessageHandler: message.ChanMessageHandler(make(chan 
message.ImmutableMessage, 2)), + MessageHandler: adaptor.ChanMessageHandler(make(chan message.ImmutableMessage, 2)), }) select { @@ -76,10 +80,13 @@ func TestResumableConsumer(t *testing.T) { func TestHandler(t *testing.T) { ch := make(chan message.ImmutableMessage, 100) hNop := nopCloseHandler{ - Handler: message.ChanMessageHandler(ch), + Handler: adaptor.ChanMessageHandler(ch), } - hNop.Handle(context.Background(), nil) - assert.Nil(t, <-ch) + hNop.Handle(message.HandleParam{ + Ctx: context.Background(), + Message: message.NewImmutableMesasge(walimplstest.NewTestMessageID(123), []byte("payload"), nil), + }) + assert.NotNil(t, <-ch) hNop.Close() select { case <-ch: diff --git a/internal/distributed/streaming/internal/consumer/handler.go b/internal/distributed/streaming/internal/consumer/handler.go index d106b9da4d5fe..c78bbd2de826d 100644 --- a/internal/distributed/streaming/internal/consumer/handler.go +++ b/internal/distributed/streaming/internal/consumer/handler.go @@ -1,25 +1,21 @@ package consumer import ( - "context" - "github.com/milvus-io/milvus/pkg/streaming/util/message" ) -type handleFunc func(ctx context.Context, msg message.ImmutableMessage) (bool, error) - // nopCloseHandler is a handler that do nothing when close. type nopCloseHandler struct { message.Handler - HandleInterceptor func(ctx context.Context, msg message.ImmutableMessage, handle handleFunc) (bool, error) + HandleInterceptor func(handleParam message.HandleParam, h message.Handler) message.HandleResult } // Handle is the callback for handling message. -func (nch nopCloseHandler) Handle(ctx context.Context, msg message.ImmutableMessage) (bool, error) { +func (nch nopCloseHandler) Handle(handleParam message.HandleParam) message.HandleResult { if nch.HandleInterceptor != nil { - return nch.HandleInterceptor(ctx, msg, nch.Handler.Handle) + return nch.HandleInterceptor(handleParam, nch.Handler) } - return nch.Handler.Handle(ctx, msg) + return nch.Handler.Handle(handleParam) } // Close is called after all messages are handled or handling is interrupted. 
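Note on the interceptor change above: `nopCloseHandler.HandleInterceptor` now receives the whole `message.HandleParam` plus the downstream `message.Handler` and must return a `message.HandleResult`, so wrappers have to guard against a nil `Message` (the param may only carry upstream or time-tick events). The following is a minimal sketch of such a wrapper under the new signature; the latency callback `recordFn` is hypothetical, everything else follows the types used in this diff.

```go
package example

import (
	"time"

	"github.com/milvus-io/milvus/pkg/streaming/util/message"
)

// newLatencyInterceptor builds an interceptor (in the shape expected by
// nopCloseHandler.HandleInterceptor) that measures how long the downstream
// handler takes for each non-nil message. recordFn is a hypothetical metrics
// callback supplied by the caller.
func newLatencyInterceptor(recordFn func(time.Duration)) func(message.HandleParam, message.Handler) message.HandleResult {
	return func(param message.HandleParam, h message.Handler) message.HandleResult {
		if param.Message == nil {
			// No message attached (e.g. only an upstream poll); pass through untouched.
			return h.Handle(param)
		}
		start := time.Now()
		result := h.Handle(param)
		recordFn(time.Since(start))
		return result
	}
}
```

Wiring it in would mirror how the metrics interceptor is installed in `resumeLoop` above, e.g. `nopCloseHandler{Handler: mh, HandleInterceptor: newLatencyInterceptor(record)}`.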
diff --git a/internal/distributed/streaming/internal/consumer/message_handler.go b/internal/distributed/streaming/internal/consumer/message_handler.go index 538052ee174c0..e790ad2c6d773 100644 --- a/internal/distributed/streaming/internal/consumer/message_handler.go +++ b/internal/distributed/streaming/internal/consumer/message_handler.go @@ -1,8 +1,6 @@ package consumer import ( - "context" - "github.com/milvus-io/milvus/pkg/streaming/util/message" ) @@ -13,16 +11,20 @@ type timeTickOrderMessageHandler struct { lastTimeTick uint64 } -func (mh *timeTickOrderMessageHandler) Handle(ctx context.Context, msg message.ImmutableMessage) (bool, error) { - lastConfirmedMessageID := msg.LastConfirmedMessageID() - timetick := msg.TimeTick() +func (mh *timeTickOrderMessageHandler) Handle(handleParam message.HandleParam) message.HandleResult { + var lastConfirmedMessageID message.MessageID + var lastTimeTick uint64 + if handleParam.Message != nil { + lastConfirmedMessageID = handleParam.Message.LastConfirmedMessageID() + lastTimeTick = handleParam.Message.TimeTick() + } - ok, err := mh.inner.Handle(ctx, msg) - if ok { + result := mh.inner.Handle(handleParam) + if result.MessageHandled { mh.lastConfirmedMessageID = lastConfirmedMessageID - mh.lastTimeTick = timetick + mh.lastTimeTick = lastTimeTick } - return ok, err + return result } func (mh *timeTickOrderMessageHandler) Close() { diff --git a/internal/distributed/streaming/internal/producer/producer.go b/internal/distributed/streaming/internal/producer/producer.go index 1feaab90a6fdf..19eb0aaef8527 100644 --- a/internal/distributed/streaming/internal/producer/producer.go +++ b/internal/distributed/streaming/internal/producer/producer.go @@ -15,6 +15,7 @@ import ( "github.com/milvus-io/milvus/internal/util/streamingutil/status" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/streaming/util/message" + "github.com/milvus-io/milvus/pkg/streaming/util/types" "github.com/milvus-io/milvus/pkg/util/syncutil" "github.com/milvus-io/milvus/pkg/util/typeutil" ) @@ -77,7 +78,7 @@ type ResumableProducer struct { } // Produce produce a new message to log service. 
-func (p *ResumableProducer) Produce(ctx context.Context, msg message.MutableMessage) (result *producer.ProduceResult, err error) { +func (p *ResumableProducer) Produce(ctx context.Context, msg message.MutableMessage) (result *types.AppendResult, err error) { if !p.lifetime.Add(typeutil.LifetimeStateWorking) { return nil, errors.Wrapf(errs.ErrClosed, "produce on closed producer") } @@ -94,7 +95,7 @@ func (p *ResumableProducer) Produce(ctx context.Context, msg message.MutableMess return nil, err } - produceResult, err := producerHandler.Produce(ctx, msg) + produceResult, err := producerHandler.Append(ctx, msg) if err == nil { return produceResult, nil } diff --git a/internal/distributed/streaming/internal/producer/producer_test.go b/internal/distributed/streaming/internal/producer/producer_test.go index d98be5dde3d32..7ef50b7fc2ec4 100644 --- a/internal/distributed/streaming/internal/producer/producer_test.go +++ b/internal/distributed/streaming/internal/producer/producer_test.go @@ -14,12 +14,13 @@ import ( "github.com/milvus-io/milvus/internal/streamingnode/client/handler/producer" "github.com/milvus-io/milvus/pkg/mocks/streaming/util/mock_message" "github.com/milvus-io/milvus/pkg/streaming/util/message" + "github.com/milvus-io/milvus/pkg/streaming/util/types" ) func TestResumableProducer(t *testing.T) { p := mock_producer.NewMockProducer(t) msgID := mock_message.NewMockMessageID(t) - p.EXPECT().Produce(mock.Anything, mock.Anything).Return(&producer.ProduceResult{ + p.EXPECT().Append(mock.Anything, mock.Anything).Return(&types.AppendResult{ MessageID: msgID, TimeTick: 100, }, nil) @@ -47,11 +48,11 @@ func TestResumableProducer(t *testing.T) { } else if i == 2 { p := mock_producer.NewMockProducer(t) msgID := mock_message.NewMockMessageID(t) - p.EXPECT().Produce(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, mm message.MutableMessage) (*producer.ProduceResult, error) { + p.EXPECT().Append(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, mm message.MutableMessage) (*types.AppendResult, error) { if ctx.Err() != nil { return nil, ctx.Err() } - return &producer.ProduceResult{ + return &types.AppendResult{ MessageID: msgID, TimeTick: 100, }, nil diff --git a/internal/distributed/streaming/streaming_test.go b/internal/distributed/streaming/streaming_test.go index c24f65261636d..3f166e466d845 100644 --- a/internal/distributed/streaming/streaming_test.go +++ b/internal/distributed/streaming/streaming_test.go @@ -9,6 +9,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/milvus-io/milvus/internal/distributed/streaming" "github.com/milvus-io/milvus/pkg/streaming/util/message" + "github.com/milvus-io/milvus/pkg/streaming/util/message/adaptor" "github.com/milvus-io/milvus/pkg/streaming/util/options" _ "github.com/milvus-io/milvus/pkg/streaming/walimpls/impls/pulsar" "github.com/milvus-io/milvus/pkg/util/paramtable" @@ -100,7 +101,7 @@ func TestStreamingConsume(t *testing.T) { t.Skip() streaming.Init() defer streaming.Release() - ch := make(message.ChanMessageHandler, 10) + ch := make(adaptor.ChanMessageHandler, 10) s := streaming.WAL().Read(context.Background(), streaming.ReadOption{ VChannel: vChannel, DeliverPolicy: options.DeliverPolicyAll(), diff --git a/internal/distributed/streaming/wal_test.go b/internal/distributed/streaming/wal_test.go index db527c044eddb..da8e08306e179 100644 --- a/internal/distributed/streaming/wal_test.go +++ b/internal/distributed/streaming/wal_test.go @@ -56,7 +56,7 @@ func TestWAL(t *testing.T) { return 
true } }) - p.EXPECT().Produce(mock.Anything, mock.Anything).Return(&types.AppendResult{ + p.EXPECT().Append(mock.Anything, mock.Anything).Return(&types.AppendResult{ MessageID: walimplstest.NewTestMessageID(1), TimeTick: 10, TxnCtx: &message.TxnContext{ diff --git a/internal/mocks/streamingnode/client/handler/mock_consumer/mock_Consumer.go b/internal/mocks/streamingnode/client/handler/mock_consumer/mock_Consumer.go index efa7eb0f7f894..e9328568f4a89 100644 --- a/internal/mocks/streamingnode/client/handler/mock_consumer/mock_Consumer.go +++ b/internal/mocks/streamingnode/client/handler/mock_consumer/mock_Consumer.go @@ -18,8 +18,21 @@ func (_m *MockConsumer) EXPECT() *MockConsumer_Expecter { } // Close provides a mock function with given fields: -func (_m *MockConsumer) Close() { - _m.Called() +func (_m *MockConsumer) Close() error { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for Close") + } + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 } // MockConsumer_Close_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Close' @@ -39,12 +52,12 @@ func (_c *MockConsumer_Close_Call) Run(run func()) *MockConsumer_Close_Call { return _c } -func (_c *MockConsumer_Close_Call) Return() *MockConsumer_Close_Call { - _c.Call.Return() +func (_c *MockConsumer_Close_Call) Return(_a0 error) *MockConsumer_Close_Call { + _c.Call.Return(_a0) return _c } -func (_c *MockConsumer_Close_Call) RunAndReturn(run func()) *MockConsumer_Close_Call { +func (_c *MockConsumer_Close_Call) RunAndReturn(run func() error) *MockConsumer_Close_Call { _c.Call.Return(run) return _c } diff --git a/internal/mocks/streamingnode/client/handler/mock_producer/mock_Producer.go b/internal/mocks/streamingnode/client/handler/mock_producer/mock_Producer.go index b215ccd60cc31..36d0a3714d24f 100644 --- a/internal/mocks/streamingnode/client/handler/mock_producer/mock_Producer.go +++ b/internal/mocks/streamingnode/client/handler/mock_producer/mock_Producer.go @@ -24,47 +24,61 @@ func (_m *MockProducer) EXPECT() *MockProducer_Expecter { return &MockProducer_Expecter{mock: &_m.Mock} } -// Assignment provides a mock function with given fields: -func (_m *MockProducer) Assignment() types.PChannelInfoAssigned { - ret := _m.Called() +// Append provides a mock function with given fields: ctx, msg +func (_m *MockProducer) Append(ctx context.Context, msg message.MutableMessage) (*types.AppendResult, error) { + ret := _m.Called(ctx, msg) if len(ret) == 0 { - panic("no return value specified for Assignment") + panic("no return value specified for Append") } - var r0 types.PChannelInfoAssigned - if rf, ok := ret.Get(0).(func() types.PChannelInfoAssigned); ok { - r0 = rf() + var r0 *types.AppendResult + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, message.MutableMessage) (*types.AppendResult, error)); ok { + return rf(ctx, msg) + } + if rf, ok := ret.Get(0).(func(context.Context, message.MutableMessage) *types.AppendResult); ok { + r0 = rf(ctx, msg) } else { - r0 = ret.Get(0).(types.PChannelInfoAssigned) + if ret.Get(0) != nil { + r0 = ret.Get(0).(*types.AppendResult) + } } - return r0 + if rf, ok := ret.Get(1).(func(context.Context, message.MutableMessage) error); ok { + r1 = rf(ctx, msg) + } else { + r1 = ret.Error(1) + } + + return r0, r1 } -// MockProducer_Assignment_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Assignment' -type 
MockProducer_Assignment_Call struct { +// MockProducer_Append_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Append' +type MockProducer_Append_Call struct { *mock.Call } -// Assignment is a helper method to define mock.On call -func (_e *MockProducer_Expecter) Assignment() *MockProducer_Assignment_Call { - return &MockProducer_Assignment_Call{Call: _e.mock.On("Assignment")} +// Append is a helper method to define mock.On call +// - ctx context.Context +// - msg message.MutableMessage +func (_e *MockProducer_Expecter) Append(ctx interface{}, msg interface{}) *MockProducer_Append_Call { + return &MockProducer_Append_Call{Call: _e.mock.On("Append", ctx, msg)} } -func (_c *MockProducer_Assignment_Call) Run(run func()) *MockProducer_Assignment_Call { +func (_c *MockProducer_Append_Call) Run(run func(ctx context.Context, msg message.MutableMessage)) *MockProducer_Append_Call { _c.Call.Run(func(args mock.Arguments) { - run() + run(args[0].(context.Context), args[1].(message.MutableMessage)) }) return _c } -func (_c *MockProducer_Assignment_Call) Return(_a0 types.PChannelInfoAssigned) *MockProducer_Assignment_Call { - _c.Call.Return(_a0) +func (_c *MockProducer_Append_Call) Return(_a0 *types.AppendResult, _a1 error) *MockProducer_Append_Call { + _c.Call.Return(_a0, _a1) return _c } -func (_c *MockProducer_Assignment_Call) RunAndReturn(run func() types.PChannelInfoAssigned) *MockProducer_Assignment_Call { +func (_c *MockProducer_Append_Call) RunAndReturn(run func(context.Context, message.MutableMessage) (*types.AppendResult, error)) *MockProducer_Append_Call { _c.Call.Return(run) return _c } @@ -193,65 +207,6 @@ func (_c *MockProducer_IsAvailable_Call) RunAndReturn(run func() bool) *MockProd return _c } -// Produce provides a mock function with given fields: ctx, msg -func (_m *MockProducer) Produce(ctx context.Context, msg message.MutableMessage) (*types.AppendResult, error) { - ret := _m.Called(ctx, msg) - - if len(ret) == 0 { - panic("no return value specified for Produce") - } - - var r0 *types.AppendResult - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, message.MutableMessage) (*types.AppendResult, error)); ok { - return rf(ctx, msg) - } - if rf, ok := ret.Get(0).(func(context.Context, message.MutableMessage) *types.AppendResult); ok { - r0 = rf(ctx, msg) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(*types.AppendResult) - } - } - - if rf, ok := ret.Get(1).(func(context.Context, message.MutableMessage) error); ok { - r1 = rf(ctx, msg) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// MockProducer_Produce_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Produce' -type MockProducer_Produce_Call struct { - *mock.Call -} - -// Produce is a helper method to define mock.On call -// - ctx context.Context -// - msg message.MutableMessage -func (_e *MockProducer_Expecter) Produce(ctx interface{}, msg interface{}) *MockProducer_Produce_Call { - return &MockProducer_Produce_Call{Call: _e.mock.On("Produce", ctx, msg)} -} - -func (_c *MockProducer_Produce_Call) Run(run func(ctx context.Context, msg message.MutableMessage)) *MockProducer_Produce_Call { - _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(message.MutableMessage)) - }) - return _c -} - -func (_c *MockProducer_Produce_Call) Return(_a0 *types.AppendResult, _a1 error) *MockProducer_Produce_Call { - _c.Call.Return(_a0, _a1) - return _c -} - -func (_c *MockProducer_Produce_Call) 
RunAndReturn(run func(context.Context, message.MutableMessage) (*types.AppendResult, error)) *MockProducer_Produce_Call { - _c.Call.Return(run) - return _c -} - // NewMockProducer creates a new instance of MockProducer. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. // The first argument is typically a *testing.T value. func NewMockProducer(t interface { diff --git a/internal/streamingcoord/client/assignment/assignment_test.go b/internal/streamingcoord/client/assignment/assignment_test.go index 60f05873e8c54..19882628c2cb8 100644 --- a/internal/streamingcoord/client/assignment/assignment_test.go +++ b/internal/streamingcoord/client/assignment/assignment_test.go @@ -95,6 +95,10 @@ func TestAssignmentService(t *testing.T) { assignmentService.ReportAssignmentError(ctx, types.PChannelInfo{Name: "c1", Term: 1}, errors.New("test")) + // Repeated report error at the same term should be ignored. + assignmentService.ReportAssignmentError(ctx, types.PChannelInfo{Name: "c1", Term: 1}, errors.New("test")) + assignmentService.ReportAssignmentError(ctx, types.PChannelInfo{Name: "c1", Term: 1}, errors.New("test")) + // test close go close(closeCh) time.Sleep(10 * time.Millisecond) diff --git a/internal/streamingcoord/client/assignment/discoverer.go b/internal/streamingcoord/client/assignment/discoverer.go index b9f92f27e5e0b..74bb2ccc898c6 100644 --- a/internal/streamingcoord/client/assignment/discoverer.go +++ b/internal/streamingcoord/client/assignment/discoverer.go @@ -14,13 +14,14 @@ import ( // newAssignmentDiscoverClient creates a new assignment discover client. func newAssignmentDiscoverClient(w *watcher, streamClient streamingpb.StreamingCoordAssignmentService_AssignmentDiscoverClient) *assignmentDiscoverClient { c := &assignmentDiscoverClient{ - lifetime: typeutil.NewLifetime(), - w: w, - streamClient: streamClient, - logger: log.With(), - requestCh: make(chan *streamingpb.AssignmentDiscoverRequest, 16), - exitCh: make(chan struct{}), - wg: sync.WaitGroup{}, + lifetime: typeutil.NewLifetime(), + w: w, + streamClient: streamClient, + logger: log.With(), + requestCh: make(chan *streamingpb.AssignmentDiscoverRequest, 16), + exitCh: make(chan struct{}), + wg: sync.WaitGroup{}, + lastErrorReportedTerm: make(map[string]int64), } c.executeBackgroundTask() return c @@ -28,13 +29,14 @@ func newAssignmentDiscoverClient(w *watcher, streamClient streamingpb.StreamingC // assignmentDiscoverClient is the client for assignment discover. type assignmentDiscoverClient struct { - lifetime *typeutil.Lifetime - w *watcher - logger *log.MLogger - requestCh chan *streamingpb.AssignmentDiscoverRequest - exitCh chan struct{} - wg sync.WaitGroup - streamClient streamingpb.StreamingCoordAssignmentService_AssignmentDiscoverClient + lifetime *typeutil.Lifetime + w *watcher + logger *log.MLogger + requestCh chan *streamingpb.AssignmentDiscoverRequest + exitCh chan struct{} + wg sync.WaitGroup + streamClient streamingpb.StreamingCoordAssignmentService_AssignmentDiscoverClient + lastErrorReportedTerm map[string]int64 } // ReportAssignmentError reports the assignment error to server. @@ -101,12 +103,28 @@ func (c *assignmentDiscoverClient) sendLoop() (err error) { } return c.streamClient.CloseSend() } + if c.shouldIgnore(req) { + continue + } if err := c.streamClient.Send(req); err != nil { return err } } } +// shouldIgnore checks if the request should be ignored. 
+func (c *assignmentDiscoverClient) shouldIgnore(req *streamingpb.AssignmentDiscoverRequest) bool { + switch req := req.Command.(type) { + case *streamingpb.AssignmentDiscoverRequest_ReportError: + if term, ok := c.lastErrorReportedTerm[req.ReportError.Pchannel.Name]; ok && req.ReportError.Pchannel.Term <= term { + // If the error at newer term has been reported, ignore it right now. + return true + } + c.lastErrorReportedTerm[req.ReportError.Pchannel.Name] = req.ReportError.Pchannel.Term + } + return false +} + // recvLoop receives the message from server. // 1. FullAssignment // 2. Close diff --git a/internal/streamingcoord/server/balancer/balancer_impl.go b/internal/streamingcoord/server/balancer/balancer_impl.go index 1b8967653a820..ca1a6827cc410 100644 --- a/internal/streamingcoord/server/balancer/balancer_impl.go +++ b/internal/streamingcoord/server/balancer/balancer_impl.go @@ -226,7 +226,7 @@ func (b *balancerImpl) applyBalanceResultToStreamingNode(ctx context.Context, mo // assign the channel to the target node. if err := resource.Resource().StreamingNodeManagerClient().Assign(ctx, channel.CurrentAssignment()); err != nil { - b.logger.Warn("fail to assign channel", zap.Any("assignment", channel.CurrentAssignment())) + b.logger.Warn("fail to assign channel", zap.Any("assignment", channel.CurrentAssignment()), zap.Error(err)) return err } b.logger.Info("assign channel success", zap.Any("assignment", channel.CurrentAssignment())) diff --git a/internal/streamingnode/client/handler/consumer/consumer.go b/internal/streamingnode/client/handler/consumer/consumer.go index d9fd1deb1abd1..32b0b3e995abe 100644 --- a/internal/streamingnode/client/handler/consumer/consumer.go +++ b/internal/streamingnode/client/handler/consumer/consumer.go @@ -14,5 +14,5 @@ type Consumer interface { Done() <-chan struct{} // Close the consumer, release the underlying resources. - Close() + Close() error } diff --git a/internal/streamingnode/client/handler/consumer/consumer_impl.go b/internal/streamingnode/client/handler/consumer/consumer_impl.go index b880f7064a8d8..8c9ba9233aa16 100644 --- a/internal/streamingnode/client/handler/consumer/consumer_impl.go +++ b/internal/streamingnode/client/handler/consumer/consumer_impl.go @@ -107,7 +107,7 @@ type consumerImpl struct { } // Close close the consumer client. -func (c *consumerImpl) Close() { +func (c *consumerImpl) Close() error { // Send the close request to server. if err := c.grpcStreamClient.Send(&streamingpb.ConsumeRequest{ Request: &streamingpb.ConsumeRequest_Close{}, @@ -118,7 +118,7 @@ func (c *consumerImpl) Close() { if err := c.grpcStreamClient.CloseSend(); err != nil { c.logger.Warn("close grpc stream failed", zap.Error(err)) } - <-c.finishErr.Done() + return c.finishErr.Get() } // Error returns the error of the consumer client. 
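Because `Consumer.Close` now blocks on `finishErr` and returns the stream's terminal error instead of only waiting for it, callers can surface that error. A minimal sketch of one caller pattern, assuming only the `Consumer` interface shown above (`Done`/`Error`/`Close`):

```go
package example

import (
	"github.com/milvus-io/milvus/internal/streamingnode/client/handler/consumer"
)

// closeAfterDone waits for the consume stream to finish and then closes the
// consumer, propagating the terminal error that Close now returns. Purely
// illustrative; the resumable consumer in this diff decides between resuming
// and shutting down at this point instead of returning.
func closeAfterDone(c consumer.Consumer) error {
	<-c.Done() // the underlying grpc stream has terminated
	if err := c.Close(); err != nil {
		return err // terminal error of the consume stream
	}
	return nil
}
```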
@@ -189,9 +189,12 @@ func (c *consumerImpl) recvLoop() (err error) { if c.txnBuilder != nil { panic("unreachable code: txn builder should be nil if we receive a non-txn message") } - if _, err := c.msgHandler.Handle(c.ctx, newImmutableMsg); err != nil { + if result := c.msgHandler.Handle(message.HandleParam{ + Ctx: c.ctx, + Message: newImmutableMsg, + }); result.Error != nil { c.logger.Warn("message handle canceled", zap.Error(err)) - return errors.Wrapf(err, "At Handler") + return errors.Wrapf(result.Error, "At Handler") } } case *streamingpb.ConsumeResponse_Close: @@ -255,7 +258,10 @@ func (c *consumerImpl) handleTxnMessage(msg message.ImmutableMessage) error { c.logger.Warn("failed to build txn message", zap.Any("messageID", commitMsg.MessageID()), zap.Error(err)) return nil } - if _, err := c.msgHandler.Handle(c.ctx, msg); err != nil { + if result := c.msgHandler.Handle(message.HandleParam{ + Ctx: c.ctx, + Message: msg, + }); result.Error != nil { c.logger.Warn("message handle canceled at txn", zap.Error(err)) return errors.Wrap(err, "At Handler Of Txn") } diff --git a/internal/streamingnode/client/handler/consumer/consumer_test.go b/internal/streamingnode/client/handler/consumer/consumer_test.go index 8481beea30260..801b52c1479aa 100644 --- a/internal/streamingnode/client/handler/consumer/consumer_test.go +++ b/internal/streamingnode/client/handler/consumer/consumer_test.go @@ -14,6 +14,7 @@ import ( "github.com/milvus-io/milvus/pkg/streaming/proto/messagespb" "github.com/milvus-io/milvus/pkg/streaming/proto/streamingpb" "github.com/milvus-io/milvus/pkg/streaming/util/message" + "github.com/milvus-io/milvus/pkg/streaming/util/message/adaptor" "github.com/milvus-io/milvus/pkg/streaming/util/options" "github.com/milvus-io/milvus/pkg/streaming/util/types" "github.com/milvus-io/milvus/pkg/streaming/walimpls/impls/walimplstest" @@ -21,7 +22,7 @@ import ( ) func TestConsumer(t *testing.T) { - resultCh := make(message.ChanMessageHandler, 1) + resultCh := make(adaptor.ChanMessageHandler, 1) c := newMockedConsumerImpl(t, context.Background(), resultCh) mmsg, _ := message.NewInsertMessageBuilderV1(). 
@@ -70,7 +71,7 @@ func TestConsumer(t *testing.T) { } func TestConsumerWithCancellation(t *testing.T) { - resultCh := make(message.ChanMessageHandler, 1) + resultCh := make(adaptor.ChanMessageHandler, 1) ctx, cancel := context.WithCancel(context.Background()) c := newMockedConsumerImpl(t, ctx, resultCh) diff --git a/internal/streamingnode/client/handler/handler_client_impl.go b/internal/streamingnode/client/handler/handler_client_impl.go index d2a52f66fad4f..d805914680c3f 100644 --- a/internal/streamingnode/client/handler/handler_client_impl.go +++ b/internal/streamingnode/client/handler/handler_client_impl.go @@ -11,6 +11,8 @@ import ( "github.com/milvus-io/milvus/internal/streamingnode/client/handler/assignment" "github.com/milvus-io/milvus/internal/streamingnode/client/handler/consumer" "github.com/milvus-io/milvus/internal/streamingnode/client/handler/producer" + "github.com/milvus-io/milvus/internal/streamingnode/client/handler/registry" + "github.com/milvus-io/milvus/internal/streamingnode/server/wal" "github.com/milvus-io/milvus/internal/util/streamingutil/service/balancer/picker" "github.com/milvus-io/milvus/internal/util/streamingutil/service/lazygrpc" "github.com/milvus-io/milvus/internal/util/streamingutil/service/resolver" @@ -21,7 +23,11 @@ import ( "github.com/milvus-io/milvus/pkg/util/typeutil" ) -var errWaitNextBackoff = errors.New("wait for next backoff") +var ( + errWaitNextBackoff = errors.New("wait for next backoff") + _ producer.Producer = wal.WAL(nil) + _ consumer.Consumer = wal.Scanner(nil) +) type handlerClientImpl struct { lifetime *typeutil.Lifetime @@ -40,15 +46,27 @@ func (hc *handlerClientImpl) CreateProducer(ctx context.Context, opts *ProducerO } defer hc.lifetime.Done() - p, err := hc.createHandlerAfterStreamingNodeReady(ctx, opts.PChannel, func(ctx context.Context, assign *types.PChannelInfoAssigned) (any, error) { + p, err := hc.createHandlerAfterStreamingNodeReady(ctx, opts.PChannel, func(ctx context.Context, assign *types.PChannelInfoAssigned) (*handlerCreateResult, error) { + // Check if the localWAL is assigned at local + localWAL, err := registry.GetAvailableWAL(assign.Channel) + if err == nil { + return localResult(localWAL), nil + } + if !shouldUseRemoteWAL(err) { + return nil, err + } // Wait for handler service is ready. 
handlerService, err := hc.service.GetService(ctx) if err != nil { return nil, err } - return hc.newProducer(ctx, &producer.ProducerOptions{ + remoteWAL, err := hc.newProducer(ctx, &producer.ProducerOptions{ Assignment: assign, }, handlerService) + if err != nil { + return nil, err + } + return remoteResult(remoteWAL), nil }) if err != nil { return nil, err @@ -63,19 +81,41 @@ func (hc *handlerClientImpl) CreateConsumer(ctx context.Context, opts *ConsumerO } defer hc.lifetime.Done() - c, err := hc.createHandlerAfterStreamingNodeReady(ctx, opts.PChannel, func(ctx context.Context, assign *types.PChannelInfoAssigned) (any, error) { + c, err := hc.createHandlerAfterStreamingNodeReady(ctx, opts.PChannel, func(ctx context.Context, assign *types.PChannelInfoAssigned) (*handlerCreateResult, error) { + // Check if the localWAL is assigned at local + localWAL, err := registry.GetAvailableWAL(assign.Channel) + if err == nil { + localScanner, err := localWAL.Read(ctx, wal.ReadOption{ + VChannel: opts.VChannel, + DeliverPolicy: opts.DeliverPolicy, + MessageFilter: opts.DeliverFilters, + MesasgeHandler: opts.MessageHandler, + }) + if err != nil { + return nil, err + } + return localResult(localScanner), nil + } + if !shouldUseRemoteWAL(err) { + return nil, err + } + // Wait for handler service is ready. handlerService, err := hc.service.GetService(ctx) if err != nil { return nil, err } - return hc.newConsumer(ctx, &consumer.ConsumerOptions{ + remoteScanner, err := hc.newConsumer(ctx, &consumer.ConsumerOptions{ Assignment: assign, VChannel: opts.VChannel, DeliverPolicy: opts.DeliverPolicy, DeliverFilters: opts.DeliverFilters, MessageHandler: opts.MessageHandler, }, handlerService) + if err != nil { + return nil, err + } + return remoteResult(remoteScanner), nil }) if err != nil { return nil, err @@ -83,9 +123,24 @@ func (hc *handlerClientImpl) CreateConsumer(ctx context.Context, opts *ConsumerO return c.(Consumer), nil } +func localResult(result any) *handlerCreateResult { + return &handlerCreateResult{result: result, isLocal: true} +} + +func remoteResult(result any) *handlerCreateResult { + return &handlerCreateResult{result: result, isLocal: false} +} + +type handlerCreateResult struct { + result any + isLocal bool +} + +type handlerCreateFunc func(ctx context.Context, assign *types.PChannelInfoAssigned) (*handlerCreateResult, error) + // createHandlerAfterStreamingNodeReady creates a handler until streaming node ready. // If streaming node is not ready, it will block until new assignment term is coming or context timeout. -func (hc *handlerClientImpl) createHandlerAfterStreamingNodeReady(ctx context.Context, pchannel string, create func(ctx context.Context, assign *types.PChannelInfoAssigned) (any, error)) (any, error) { +func (hc *handlerClientImpl) createHandlerAfterStreamingNodeReady(ctx context.Context, pchannel string, create handlerCreateFunc) (any, error) { logger := log.With(zap.String("pchannel", pchannel)) // TODO: backoff should be configurable. backoff := backoff.NewExponentialBackOff() @@ -93,9 +148,10 @@ func (hc *handlerClientImpl) createHandlerAfterStreamingNodeReady(ctx context.Co assign := hc.watcher.Get(ctx, pchannel) if assign != nil { // Find assignment, try to create producer on this assignment. 
- c, err := create(ctx, assign) + createResult, err := create(ctx, assign) if err == nil { - return c, nil + logger.Info("create handler success", zap.Any("assignment", assign), zap.Bool("isLocal", createResult.isLocal)) + return createResult.result, nil } logger.Warn("create handler failed", zap.Any("assignment", assign), zap.Error(err)) @@ -158,3 +214,18 @@ func isPermanentFailureUntilNewAssignment(err error) bool { streamingServiceErr := status.AsStreamingError(err) return streamingServiceErr.IsWrongStreamingNode() } + +// shouldUseRemoteWAL checks if use remote wal when given error happens. +func shouldUseRemoteWAL(err error) bool { + if err == nil { + panic("the incoming error should never be nil") + } + // When following error happens, we should try to make a remote wal fetch. + // 1. If current node didn't deploy any streaming node. + if errors.Is(err, registry.ErrNoStreamingNodeDeployed) { + return true + } + // 2. If the wal is not exist at current streaming node. + streamingServiceErr := status.AsStreamingError(err) + return streamingServiceErr.IsWrongStreamingNode() +} diff --git a/internal/streamingnode/client/handler/handler_client_test.go b/internal/streamingnode/client/handler/handler_client_test.go index 3aa571e142cfe..a8be47d7c01ff 100644 --- a/internal/streamingnode/client/handler/handler_client_test.go +++ b/internal/streamingnode/client/handler/handler_client_test.go @@ -19,7 +19,7 @@ import ( "github.com/milvus-io/milvus/pkg/mocks/streaming/proto/mock_streamingpb" "github.com/milvus-io/milvus/pkg/mocks/streaming/util/mock_types" "github.com/milvus-io/milvus/pkg/streaming/proto/streamingpb" - "github.com/milvus-io/milvus/pkg/streaming/util/message" + "github.com/milvus-io/milvus/pkg/streaming/util/message/adaptor" "github.com/milvus-io/milvus/pkg/streaming/util/options" "github.com/milvus-io/milvus/pkg/streaming/util/types" "github.com/milvus-io/milvus/pkg/util/paramtable" @@ -41,9 +41,9 @@ func TestHandlerClient(t *testing.T) { w.EXPECT().Close().Run(func() {}) p := mock_producer.NewMockProducer(t) - p.EXPECT().Close().Run(func() {}) + p.EXPECT().Close().RunAndReturn(func() {}) c := mock_consumer.NewMockConsumer(t) - c.EXPECT().Close().Run(func() {}) + c.EXPECT().Close().RunAndReturn(func() error { return nil }) rebalanceTrigger := mock_types.NewMockAssignmentRebalanceTrigger(t) rebalanceTrigger.EXPECT().ReportAssignmentError(mock.Anything, mock.Anything, mock.Anything).Return(nil) @@ -104,7 +104,7 @@ func TestHandlerClient(t *testing.T) { options.DeliverFilterTimeTickGT(10), options.DeliverFilterTimeTickGTE(10), }, - MessageHandler: make(message.ChanMessageHandler), + MessageHandler: make(adaptor.ChanMessageHandler), }) assert.NoError(t, err) assert.NotNil(t, consumer) diff --git a/internal/streamingnode/client/handler/producer/producer.go b/internal/streamingnode/client/handler/producer/producer.go index 41dec673d9d6c..6e31d2b479010 100644 --- a/internal/streamingnode/client/handler/producer/producer.go +++ b/internal/streamingnode/client/handler/producer/producer.go @@ -9,18 +9,13 @@ import ( var _ Producer = (*producerImpl)(nil) -type ProduceResult = types.AppendResult - // Producer is the interface that wraps the basic produce method on grpc stream. // Producer is work on a single stream on grpc, // so Producer cannot recover from failure because of the stream is broken. type Producer interface { - // Assignment returns the assignment of the producer. - Assignment() types.PChannelInfoAssigned - - // Produce sends the produce message to server. 
+ // Append sends the produce message to server. // TODO: Support Batch produce here. - Produce(ctx context.Context, msg message.MutableMessage) (*ProduceResult, error) + Append(ctx context.Context, msg message.MutableMessage) (*types.AppendResult, error) // Check if a producer is available. IsAvailable() bool diff --git a/internal/streamingnode/client/handler/producer/producer_impl.go b/internal/streamingnode/client/handler/producer/producer_impl.go index 54ec3224d0f02..91783b3a419fa 100644 --- a/internal/streamingnode/client/handler/producer/producer_impl.go +++ b/internal/streamingnode/client/handler/producer/producer_impl.go @@ -114,17 +114,12 @@ type produceRequest struct { } type produceResponse struct { - result *ProduceResult + result *types.AppendResult err error } -// Assignment returns the assignment of the producer. -func (p *producerImpl) Assignment() types.PChannelInfoAssigned { - return p.assignment -} - -// Produce sends the produce message to server. -func (p *producerImpl) Produce(ctx context.Context, msg message.MutableMessage) (*ProduceResult, error) { +// Append sends the produce message to server. +func (p *producerImpl) Append(ctx context.Context, msg message.MutableMessage) (*types.AppendResult, error) { if !p.lifetime.Add(typeutil.LifetimeStateWorking) { return nil, status.NewOnShutdownError("producer client is shutting down") } @@ -293,7 +288,7 @@ func (p *producerImpl) recvLoop() (err error) { return err } result = produceResponse{ - result: &ProduceResult{ + result: &types.AppendResult{ MessageID: msgID, TimeTick: produceResp.Result.GetTimetick(), TxnCtx: message.NewTxnContextFromProto(produceResp.Result.GetTxnContext()), diff --git a/internal/streamingnode/client/handler/producer/producer_test.go b/internal/streamingnode/client/handler/producer/producer_test.go index bea7eda13da53..8accd38e890a4 100644 --- a/internal/streamingnode/client/handler/producer/producer_test.go +++ b/internal/streamingnode/client/handler/producer/producer_test.go @@ -61,12 +61,12 @@ func TestProducer(t *testing.T) { ch := make(chan struct{}) go func() { msg := message.CreateTestEmptyInsertMesage(1, nil) - msgID, err := producer.Produce(ctx, msg) + msgID, err := producer.Append(ctx, msg) assert.Error(t, err) assert.Nil(t, msgID) msg = message.CreateTestEmptyInsertMesage(1, nil) - msgID, err = producer.Produce(ctx, msg) + msgID, err = producer.Append(ctx, msg) assert.NoError(t, err) assert.NotNil(t, msgID) close(ch) @@ -100,7 +100,7 @@ func TestProducer(t *testing.T) { ctx, cancel := context.WithTimeout(ctx, 100*time.Millisecond) defer cancel() msg := message.CreateTestEmptyInsertMesage(1, nil) - _, err = producer.Produce(ctx, msg) + _, err = producer.Append(ctx, msg) assert.ErrorIs(t, err, context.DeadlineExceeded) assert.True(t, producer.IsAvailable()) producer.Close() diff --git a/internal/streamingnode/client/handler/registry/wal_manager.go b/internal/streamingnode/client/handler/registry/wal_manager.go new file mode 100644 index 0000000000000..36249e1616902 --- /dev/null +++ b/internal/streamingnode/client/handler/registry/wal_manager.go @@ -0,0 +1,44 @@ +package registry + +import ( + "context" + + "github.com/cockroachdb/errors" + + "github.com/milvus-io/milvus/internal/streamingnode/server/wal" + "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/streaming/util/types" + "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/pkg/util/syncutil" + "github.com/milvus-io/milvus/pkg/util/typeutil" +) + +var ( + registry = 
syncutil.NewFuture[WALManager]() + ErrNoStreamingNodeDeployed = errors.New("no streaming node deployed") +) + +// RegisterLocalWALManager registers the local wal manager. +// When the streaming node is started, it should call this function to register the wal manager. +func RegisterLocalWALManager(manager WALManager) { + if !paramtable.IsLocalComponentEnabled(typeutil.StreamingNodeRole) { + panic("unreachable: streaming node is not enabled but wal setup") + } + registry.Set(manager) + log.Ctx(context.Background()).Info("register local wal manager done") +} + +// GetAvailableWAL returns a available wal instance for the channel. +func GetAvailableWAL(channel types.PChannelInfo) (wal.WAL, error) { + if !paramtable.IsLocalComponentEnabled(typeutil.StreamingNodeRole) { + return nil, ErrNoStreamingNodeDeployed + } + return registry.Get().GetAvailableWAL(channel) +} + +// WALManager is a hint type for wal manager at streaming node. +type WALManager interface { + // GetAvailableWAL returns a available wal instance for the channel. + // Return nil if the wal instance is not found. + GetAvailableWAL(channel types.PChannelInfo) (wal.WAL, error) +} diff --git a/internal/streamingnode/server/flusher/flusherimpl/channel_lifetime.go b/internal/streamingnode/server/flusher/flusherimpl/channel_lifetime.go index 994a856daf214..0f7764b087088 100644 --- a/internal/streamingnode/server/flusher/flusherimpl/channel_lifetime.go +++ b/internal/streamingnode/server/flusher/flusherimpl/channel_lifetime.go @@ -30,7 +30,6 @@ import ( "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/streamingnode/server/resource" "github.com/milvus-io/milvus/internal/streamingnode/server/wal" - adaptor2 "github.com/milvus-io/milvus/internal/streamingnode/server/wal/adaptor" "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/segment/stats" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/streaming/util/message/adaptor" @@ -112,7 +111,7 @@ func (c *channelLifetime) Run() error { // Create scanner. policy := options.DeliverPolicyStartFrom(messageID) - handler := adaptor2.NewMsgPackAdaptorHandler() + handler := adaptor.NewMsgPackAdaptorHandler() ro := wal.ReadOption{ VChannel: c.vchannel, DeliverPolicy: policy, diff --git a/internal/streamingnode/server/server.go b/internal/streamingnode/server/server.go index 8956d8d78eaac..52d349d2a5895 100644 --- a/internal/streamingnode/server/server.go +++ b/internal/streamingnode/server/server.go @@ -5,6 +5,7 @@ import ( "google.golang.org/grpc" + "github.com/milvus-io/milvus/internal/streamingnode/client/handler/registry" "github.com/milvus-io/milvus/internal/streamingnode/server/resource" "github.com/milvus-io/milvus/internal/streamingnode/server/service" "github.com/milvus-io/milvus/internal/streamingnode/server/walmanager" @@ -66,6 +67,8 @@ func (s *Server) initBasicComponent(_ context.Context) { if err != nil { panic("open wal manager failed") } + // Register the wal manager to the local registry. + registry.RegisterLocalWALManager(s.walManager) } // initService initializes the grpc service. 
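Taken together, the registration in `server.go` above and the lookup in `handler_client_impl.go` give the handler client a local-first path: try the in-process WAL registry and only fall back to a remote streaming node when the channel is not served locally. A minimal sketch of that pattern, using only the APIs shown in this diff; `openRemote` is a hypothetical fallback standing in for the grpc producer/consumer path.

```go
package example

import (
	"github.com/cockroachdb/errors"

	"github.com/milvus-io/milvus/internal/streamingnode/client/handler/registry"
	"github.com/milvus-io/milvus/internal/streamingnode/server/wal"
	"github.com/milvus-io/milvus/internal/util/streamingutil/status"
	"github.com/milvus-io/milvus/pkg/streaming/util/types"
)

// openWAL prefers the WAL registered in this process and only falls back to a
// remote streaming node when the channel is not served locally.
func openWAL(channel types.PChannelInfo, openRemote func() (wal.WAL, error)) (wal.WAL, error) {
	localWAL, err := registry.GetAvailableWAL(channel)
	if err == nil {
		return localWAL, nil // the channel is assigned to the WAL of this process
	}
	// Same two conditions as shouldUseRemoteWAL in handler_client_impl.go:
	// no streaming node deployed locally, or the channel lives on another node.
	if errors.Is(err, registry.ErrNoStreamingNodeDeployed) || status.AsStreamingError(err).IsWrongStreamingNode() {
		return openRemote()
	}
	return nil, err
}
```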
diff --git a/internal/streamingnode/server/wal/adaptor/message_handler.go b/internal/streamingnode/server/wal/adaptor/message_handler.go deleted file mode 100644 index 8ec28014a623b..0000000000000 --- a/internal/streamingnode/server/wal/adaptor/message_handler.go +++ /dev/null @@ -1,107 +0,0 @@ -package adaptor - -import ( - "github.com/milvus-io/milvus/internal/streamingnode/server/wal" - "github.com/milvus-io/milvus/pkg/mq/msgstream" - "github.com/milvus-io/milvus/pkg/streaming/util/message" - "github.com/milvus-io/milvus/pkg/streaming/util/message/adaptor" -) - -var ( - _ wal.MessageHandler = defaultMessageHandler(nil) - _ wal.MessageHandler = (*MsgPackAdaptorHandler)(nil) -) - -type defaultMessageHandler chan message.ImmutableMessage - -func (h defaultMessageHandler) Handle(param wal.HandleParam) wal.HandleResult { - var sendingCh chan message.ImmutableMessage - if param.Message != nil { - sendingCh = h - } - select { - case <-param.Ctx.Done(): - return wal.HandleResult{Error: param.Ctx.Err()} - case msg, ok := <-param.Upstream: - if !ok { - return wal.HandleResult{Error: wal.ErrUpstreamClosed} - } - return wal.HandleResult{Incoming: msg} - case sendingCh <- param.Message: - return wal.HandleResult{MessageHandled: true} - case <-param.TimeTickChan: - return wal.HandleResult{TimeTickUpdated: true} - } -} - -func (d defaultMessageHandler) Close() { - close(d) -} - -// NewMsgPackAdaptorHandler create a new message pack adaptor handler. -func NewMsgPackAdaptorHandler() *MsgPackAdaptorHandler { - return &MsgPackAdaptorHandler{ - base: adaptor.NewBaseMsgPackAdaptorHandler(), - } -} - -type MsgPackAdaptorHandler struct { - base *adaptor.BaseMsgPackAdaptorHandler -} - -// Chan is the channel for message. -func (m *MsgPackAdaptorHandler) Chan() <-chan *msgstream.MsgPack { - return m.base.Channel -} - -// Handle is the callback for handling message. -func (m *MsgPackAdaptorHandler) Handle(param wal.HandleParam) wal.HandleResult { - messageHandled := false - // not handle new message if there are pending msgPack. - if param.Message != nil && m.base.PendingMsgPack.Len() == 0 { - m.base.GenerateMsgPack(param.Message) - messageHandled = true - } - - for { - var sendCh chan<- *msgstream.MsgPack - if m.base.PendingMsgPack.Len() != 0 { - sendCh = m.base.Channel - } - - select { - case <-param.Ctx.Done(): - return wal.HandleResult{ - MessageHandled: messageHandled, - Error: param.Ctx.Err(), - } - case msg, notClose := <-param.Upstream: - if !notClose { - return wal.HandleResult{ - MessageHandled: messageHandled, - Error: wal.ErrUpstreamClosed, - } - } - return wal.HandleResult{ - Incoming: msg, - MessageHandled: messageHandled, - } - case sendCh <- m.base.PendingMsgPack.Next(): - m.base.PendingMsgPack.UnsafeAdvance() - if m.base.PendingMsgPack.Len() > 0 { - continue - } - return wal.HandleResult{MessageHandled: messageHandled} - case <-param.TimeTickChan: - return wal.HandleResult{ - MessageHandled: messageHandled, - TimeTickUpdated: true, - } - } - } -} - -// Close closes the handler. 
-func (m *MsgPackAdaptorHandler) Close() { - close(m.base.Channel) -} diff --git a/internal/streamingnode/server/wal/adaptor/message_handler_test.go b/internal/streamingnode/server/wal/adaptor/message_handler_test.go deleted file mode 100644 index b3c7dedafddda..0000000000000 --- a/internal/streamingnode/server/wal/adaptor/message_handler_test.go +++ /dev/null @@ -1,93 +0,0 @@ -package adaptor - -import ( - "context" - "testing" - - "github.com/stretchr/testify/assert" - - "github.com/milvus-io/milvus/internal/streamingnode/server/wal" - "github.com/milvus-io/milvus/pkg/mocks/streaming/util/mock_message" - "github.com/milvus-io/milvus/pkg/streaming/util/message" - "github.com/milvus-io/milvus/pkg/streaming/walimpls/impls/rmq" -) - -func TestMsgPackAdaptorHandler(t *testing.T) { - messageID := rmq.NewRmqID(1) - tt := uint64(100) - msg := message.CreateTestInsertMessage( - t, - 1, - 1000, - tt, - messageID, - ) - immutableMsg := msg.IntoImmutableMessage(messageID) - - upstream := make(chan message.ImmutableMessage, 1) - - ctx := context.Background() - h := NewMsgPackAdaptorHandler() - done := make(chan struct{}) - go func() { - for range h.Chan() { - } - close(done) - }() - upstream <- immutableMsg - resp := h.Handle(wal.HandleParam{ - Ctx: ctx, - Upstream: upstream, - Message: nil, - }) - assert.Equal(t, resp.Incoming, immutableMsg) - assert.False(t, resp.MessageHandled) - assert.NoError(t, resp.Error) - - resp = h.Handle(wal.HandleParam{ - Ctx: ctx, - Upstream: upstream, - Message: resp.Incoming, - }) - assert.NoError(t, resp.Error) - assert.Nil(t, resp.Incoming) - assert.True(t, resp.MessageHandled) - h.Close() - - <-done -} - -func TestDefaultHandler(t *testing.T) { - h := make(defaultMessageHandler, 1) - done := make(chan struct{}) - go func() { - for range h { - } - close(done) - }() - - upstream := make(chan message.ImmutableMessage, 1) - msg := mock_message.NewMockImmutableMessage(t) - upstream <- msg - resp := h.Handle(wal.HandleParam{ - Ctx: context.Background(), - Upstream: upstream, - Message: nil, - }) - assert.NotNil(t, resp.Incoming) - assert.NoError(t, resp.Error) - assert.False(t, resp.MessageHandled) - assert.Equal(t, resp.Incoming, msg) - - resp = h.Handle(wal.HandleParam{ - Ctx: context.Background(), - Upstream: upstream, - Message: resp.Incoming, - }) - assert.NoError(t, resp.Error) - assert.Nil(t, resp.Incoming) - assert.True(t, resp.MessageHandled) - - h.Close() - <-done -} diff --git a/internal/streamingnode/server/wal/adaptor/scanner_adaptor.go b/internal/streamingnode/server/wal/adaptor/scanner_adaptor.go index 6e293353c40e4..0c2b232b9f770 100644 --- a/internal/streamingnode/server/wal/adaptor/scanner_adaptor.go +++ b/internal/streamingnode/server/wal/adaptor/scanner_adaptor.go @@ -11,6 +11,7 @@ import ( "github.com/milvus-io/milvus/internal/streamingnode/server/wal/utility" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/streaming/util/message" + "github.com/milvus-io/milvus/pkg/streaming/util/message/adaptor" "github.com/milvus-io/milvus/pkg/streaming/util/options" "github.com/milvus-io/milvus/pkg/streaming/util/types" "github.com/milvus-io/milvus/pkg/streaming/walimpls" @@ -32,7 +33,7 @@ func newScannerAdaptor( panic("vchannel of scanner must be set") } if readOption.MesasgeHandler == nil { - readOption.MesasgeHandler = defaultMessageHandler(make(chan message.ImmutableMessage)) + readOption.MesasgeHandler = adaptor.ChanMessageHandler(make(chan message.ImmutableMessage)) } options.GetFilterFunc(readOption.MessageFilter) logger := 
log.With(zap.String("name", name), zap.String("channel", l.Channel().Name)) @@ -75,7 +76,7 @@ func (s *scannerAdaptorImpl) Channel() types.PChannelInfo { // Chan returns the message channel of the scanner. func (s *scannerAdaptorImpl) Chan() <-chan message.ImmutableMessage { - return s.readOption.MesasgeHandler.(defaultMessageHandler) + return s.readOption.MesasgeHandler.(adaptor.ChanMessageHandler) } // Close the scanner, release the underlying resources. @@ -107,7 +108,7 @@ func (s *scannerAdaptorImpl) executeConsume() { for { // generate the event channel and do the event loop. // TODO: Consume from local cache. - handleResult := s.readOption.MesasgeHandler.Handle(wal.HandleParam{ + handleResult := s.readOption.MesasgeHandler.Handle(message.HandleParam{ Ctx: s.Context(), Upstream: s.getUpstream(innerScanner), TimeTickChan: s.getTimeTickUpdateChan(timeTickNotifier), diff --git a/internal/streamingnode/server/wal/interceptors/segment/inspector/impls.go b/internal/streamingnode/server/wal/interceptors/segment/inspector/impls.go index de8ed3e119b2f..8e138aa8d3a54 100644 --- a/internal/streamingnode/server/wal/interceptors/segment/inspector/impls.go +++ b/internal/streamingnode/server/wal/interceptors/segment/inspector/impls.go @@ -4,8 +4,12 @@ import ( "context" "time" + "go.uber.org/zap" + "github.com/milvus-io/milvus/internal/streamingnode/server/resource" "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/segment/stats" + "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/syncutil" "github.com/milvus-io/milvus/pkg/util/typeutil" ) @@ -118,9 +122,16 @@ func (s *sealOperationInspectorImpl) background() { return true }) case <-mustSealTicker.C: - segmentBelongs := resource.Resource().SegmentAssignStatsManager().SealByTotalGrowingSegmentsSize() + threshold := paramtable.Get().DataCoordCfg.GrowingSegmentsMemSizeInMB.GetAsUint64() * 1024 * 1024 + segmentBelongs := resource.Resource().SegmentAssignStatsManager().SealByTotalGrowingSegmentsSize(threshold) + if segmentBelongs == nil { + continue + } + log.Info("seal by total growing segments size", zap.String("vchannel", segmentBelongs.VChannel), + zap.Uint64("sealThreshold", threshold), + zap.Int64("sealSegment", segmentBelongs.SegmentID)) if pm, ok := s.managers.Get(segmentBelongs.PChannel); ok { - pm.MustSealSegments(s.taskNotifier.Context(), segmentBelongs) + pm.MustSealSegments(s.taskNotifier.Context(), *segmentBelongs) } } } diff --git a/internal/streamingnode/server/wal/interceptors/segment/manager/partition_manager.go b/internal/streamingnode/server/wal/interceptors/segment/manager/partition_manager.go index bce92f57960d6..c7111c6a2890a 100644 --- a/internal/streamingnode/server/wal/interceptors/segment/manager/partition_manager.go +++ b/internal/streamingnode/server/wal/interceptors/segment/manager/partition_manager.go @@ -162,19 +162,14 @@ func (m *partitionSegmentManager) collectShouldBeSealedWithPolicy(predicates fun return shouldBeSealedSegments } -// CollectDirtySegmentsAndClear collects all segments in the manager and clear the maanger. -func (m *partitionSegmentManager) CollectDirtySegmentsAndClear() []*segmentAllocManager { +// CollectAllSegmentsAndClear collects all segments in the manager and clear the manager. 
+func (m *partitionSegmentManager) CollectAllSegmentsAndClear() []*segmentAllocManager { m.mu.Lock() defer m.mu.Unlock() - dirtySegments := make([]*segmentAllocManager, 0, len(m.segments)) - for _, segment := range m.segments { - if segment.IsDirtyEnough() { - dirtySegments = append(dirtySegments, segment) - } - } - m.segments = make([]*segmentAllocManager, 0) - return dirtySegments + segments := m.segments + m.segments = nil + return segments } // CollectAllCanBeSealedAndClear collects all segments that can be sealed and clear the manager. diff --git a/internal/streamingnode/server/wal/interceptors/segment/manager/pchannel_manager.go b/internal/streamingnode/server/wal/interceptors/segment/manager/pchannel_manager.go index e942ffae35c55..ac441cf07d9b9 100644 --- a/internal/streamingnode/server/wal/interceptors/segment/manager/pchannel_manager.go +++ b/internal/streamingnode/server/wal/interceptors/segment/manager/pchannel_manager.go @@ -265,31 +265,35 @@ func (m *PChannelSegmentAllocManager) Close(ctx context.Context) { // Try to seal all wait m.helper.SealAllWait(ctx) - m.logger.Info("seal all waited segments done", zap.Int("waitCounter", m.helper.WaitCounter())) + m.logger.Info("seal all waited segments done, may be some not done here", zap.Int("waitCounter", m.helper.WaitCounter())) segments := make([]*segmentAllocManager, 0) m.managers.Range(func(pm *partitionSegmentManager) { - segments = append(segments, pm.CollectDirtySegmentsAndClear()...) + segments = append(segments, pm.CollectAllSegmentsAndClear()...) }) - // commitAllSegmentsOnSamePChannel commits all segments on the same pchannel. + // Try to seal the dirty segment to avoid generate too large segment. protoSegments := make([]*streamingpb.SegmentAssignmentMeta, 0, len(segments)) + growingCnt := 0 for _, segment := range segments { - protoSegments = append(protoSegments, segment.Snapshot()) + if segment.GetState() == streamingpb.SegmentAssignmentState_SEGMENT_ASSIGNMENT_STATE_GROWING { + growingCnt++ + } + if segment.IsDirtyEnough() { + // Only persist the dirty segment. + protoSegments = append(protoSegments, segment.Snapshot()) + } } - - m.logger.Info("segment assignment manager save all dirty segment assignments info", zap.Int("segmentCount", len(protoSegments))) + m.logger.Info("segment assignment manager save all dirty segment assignments info", + zap.Int("dirtySegmentCount", len(protoSegments)), + zap.Int("growingSegmentCount", growingCnt), + zap.Int("segmentCount", len(segments))) if err := resource.Resource().StreamingNodeCatalog().SaveSegmentAssignments(ctx, m.pchannel.Name, protoSegments); err != nil { m.logger.Warn("commit segment assignment at pchannel failed", zap.Error(err)) } // remove the stats from stats manager. 
- m.logger.Info("segment assignment manager remove all segment stats from stats manager") - for _, segment := range segments { - if segment.GetState() == streamingpb.SegmentAssignmentState_SEGMENT_ASSIGNMENT_STATE_GROWING { - resource.Resource().SegmentAssignStatsManager().UnregisterSealedSegment(segment.GetSegmentID()) - } - } - + removedStatsSegmentCnt := resource.Resource().SegmentAssignStatsManager().UnregisterAllStatsOnPChannel(m.pchannel.Name) + m.logger.Info("segment assignment manager remove all segment stats from stats manager", zap.Int("removedStatsSegmentCount", removedStatsSegmentCnt)) m.metrics.Close() } diff --git a/internal/streamingnode/server/wal/interceptors/segment/stats/stats_manager.go b/internal/streamingnode/server/wal/interceptors/segment/stats/stats_manager.go index 43676953bc873..ec310f3d363c7 100644 --- a/internal/streamingnode/server/wal/interceptors/segment/stats/stats_manager.go +++ b/internal/streamingnode/server/wal/interceptors/segment/stats/stats_manager.go @@ -5,10 +5,6 @@ import ( "sync" "github.com/cockroachdb/errors" - "github.com/pingcap/log" - "go.uber.org/zap" - - "github.com/milvus-io/milvus/pkg/util/paramtable" ) var ( @@ -24,8 +20,9 @@ type StatsManager struct { totalStats InsertMetrics pchannelStats map[string]*InsertMetrics vchannelStats map[string]*InsertMetrics - segmentStats map[int64]*SegmentStats // map[SegmentID]SegmentStats - segmentIndex map[int64]SegmentBelongs // map[SegmentID]channels + segmentStats map[int64]*SegmentStats // map[SegmentID]SegmentStats + segmentIndex map[int64]SegmentBelongs // map[SegmentID]channels + pchannelIndex map[string]map[int64]struct{} // map[PChannel]SegmentID sealNotifier *SealSignalNotifier } @@ -46,6 +43,7 @@ func NewStatsManager() *StatsManager { vchannelStats: make(map[string]*InsertMetrics), segmentStats: make(map[int64]*SegmentStats), segmentIndex: make(map[int64]SegmentBelongs), + pchannelIndex: make(map[string]map[int64]struct{}), sealNotifier: NewSealSignalNotifier(), } } @@ -62,6 +60,10 @@ func (m *StatsManager) RegisterNewGrowingSegment(belongs SegmentBelongs, segment m.segmentStats[segmentID] = stats m.segmentIndex[segmentID] = belongs + if _, ok := m.pchannelIndex[belongs.PChannel]; !ok { + m.pchannelIndex[belongs.PChannel] = make(map[int64]struct{}) + } + m.pchannelIndex[belongs.PChannel][segmentID] = struct{}{} m.totalStats.Collect(stats.Insert) if _, ok := m.pchannelStats[belongs.PChannel]; !ok { m.pchannelStats[belongs.PChannel] = &InsertMetrics{} @@ -145,6 +147,10 @@ func (m *StatsManager) UnregisterSealedSegment(segmentID int64) *SegmentStats { m.mu.Lock() defer m.mu.Unlock() + return m.unregisterSealedSegment(segmentID) +} + +func (m *StatsManager) unregisterSealedSegment(segmentID int64) *SegmentStats { // Must be exist, otherwise it's a bug. 
 	info, ok := m.segmentIndex[segmentID]
 	if !ok {
@@ -156,6 +162,13 @@ func (m *StatsManager) UnregisterSealedSegment(segmentID int64) *SegmentStats {
 	m.totalStats.Subtract(stats.Insert)
 	delete(m.segmentStats, segmentID)
 	delete(m.segmentIndex, segmentID)
+	if _, ok := m.pchannelIndex[info.PChannel]; ok {
+		delete(m.pchannelIndex[info.PChannel], segmentID)
+		if len(m.pchannelIndex[info.PChannel]) == 0 {
+			delete(m.pchannelIndex, info.PChannel)
+		}
+	}
+
 	if _, ok := m.pchannelStats[info.PChannel]; ok {
 		m.pchannelStats[info.PChannel].Subtract(stats.Insert)
 		if m.pchannelStats[info.PChannel].BinarySize == 0 {
@@ -171,15 +184,29 @@ func (m *StatsManager) UnregisterSealedSegment(segmentID int64) *SegmentStats {
 	return stats
 }
 
+// UnregisterAllStatsOnPChannel unregisters all segment stats on the given pchannel.
+func (m *StatsManager) UnregisterAllStatsOnPChannel(pchannel string) int {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+
+	segmentIDs, ok := m.pchannelIndex[pchannel]
+	if !ok {
+		return 0
+	}
+	// Snapshot the count first: unregisterSealedSegment removes entries from this
+	// same map while we iterate, so len(segmentIDs) would be 0 afterwards.
+	cnt := len(segmentIDs)
+	for segmentID := range segmentIDs {
+		m.unregisterSealedSegment(segmentID)
+	}
+	return cnt
+}
+
 // SealByTotalGrowingSegmentsSize seals the largest growing segment
 // if the total size of growing segments in ANY vchannel exceeds the threshold.
-func (m *StatsManager) SealByTotalGrowingSegmentsSize() SegmentBelongs {
+func (m *StatsManager) SealByTotalGrowingSegmentsSize(vchannelThreshold uint64) *SegmentBelongs {
 	m.mu.Lock()
 	defer m.mu.Unlock()
 
-	for vchannel, metrics := range m.vchannelStats {
-		threshold := paramtable.Get().DataCoordCfg.GrowingSegmentsMemSizeInMB.GetAsUint64() * 1024 * 1024
-		if metrics.BinarySize >= threshold {
+	for _, metrics := range m.vchannelStats {
+		if metrics.BinarySize >= vchannelThreshold {
 			var (
 				largestSegment     int64  = 0
 				largestSegmentSize uint64 = 0
@@ -190,13 +217,14 @@ func (m *StatsManager) SealByTotalGrowingSegmentsSize() SegmentBelongs {
 					largestSegment = segmentID
 				}
 			}
-			log.Info("seal by total growing segments size", zap.String("vchannel", vchannel),
-				zap.Uint64("vchannelGrowingSize", metrics.BinarySize), zap.Uint64("sealThreshold", threshold),
-				zap.Int64("sealSegment", largestSegment), zap.Uint64("sealSegmentSize", largestSegmentSize))
-			return m.segmentIndex[largestSegment]
+			belongs, ok := m.segmentIndex[largestSegment]
+			if !ok {
+				panic("unreachable: the segmentID should always be found in segmentIndex")
+			}
+			return &belongs
 		}
 	}
-	return SegmentBelongs{}
+	return nil
 }
 
 // InsertOpeatationMetrics is the metrics of insert operation.
diff --git a/internal/streamingnode/server/wal/interceptors/segment/stats/stats_manager_test.go b/internal/streamingnode/server/wal/interceptors/segment/stats/stats_manager_test.go
index efee056b66448..d4bda9cf12985 100644
--- a/internal/streamingnode/server/wal/interceptors/segment/stats/stats_manager_test.go
+++ b/internal/streamingnode/server/wal/interceptors/segment/stats/stats_manager_test.go
@@ -106,6 +106,25 @@ func TestStatsManager(t *testing.T) {
 	assert.Panics(t, func() {
 		m.UnregisterSealedSegment(1)
 	})
+	m.UnregisterAllStatsOnPChannel("pchannel")
+	m.UnregisterAllStatsOnPChannel("pchannel2")
+}
+
+func TestSealByTotalGrowingSegmentsSize(t *testing.T) {
+	m := NewStatsManager()
+	m.RegisterNewGrowingSegment(SegmentBelongs{PChannel: "pchannel", VChannel: "vchannel", CollectionID: 1, PartitionID: 2, SegmentID: 3}, 3, createSegmentStats(100, 100, 300))
+	m.RegisterNewGrowingSegment(SegmentBelongs{PChannel: "pchannel", VChannel: "vchannel", CollectionID: 1, PartitionID: 2, SegmentID: 4}, 4, createSegmentStats(100, 200, 300))
+	m.RegisterNewGrowingSegment(SegmentBelongs{PChannel: "pchannel", VChannel: "vchannel", CollectionID: 1, PartitionID: 2, SegmentID: 5}, 5, createSegmentStats(100, 100, 300))
+	belongs := m.SealByTotalGrowingSegmentsSize(401)
+	assert.Nil(t, belongs)
+	belongs = m.SealByTotalGrowingSegmentsSize(400)
+	assert.NotNil(t, belongs)
+	assert.Equal(t, int64(4), belongs.SegmentID)
+	m.UnregisterAllStatsOnPChannel("pchannel")
+	assert.Empty(t, m.pchannelStats)
+	assert.Empty(t, m.vchannelStats)
+	assert.Empty(t, m.segmentStats)
+	assert.Empty(t, m.segmentIndex)
 }
 
 func createSegmentStats(row uint64, binarySize uint64, maxBinarSize uint64) *SegmentStats {
diff --git a/internal/streamingnode/server/wal/scanner.go b/internal/streamingnode/server/wal/scanner.go
index 89ca04460f909..8396e00fd1d5e 100644
--- a/internal/streamingnode/server/wal/scanner.go
+++ b/internal/streamingnode/server/wal/scanner.go
@@ -1,8 +1,6 @@
 package wal
 
 import (
-	"context"
-
 	"github.com/cockroachdb/errors"
 
 	"github.com/milvus-io/milvus/pkg/streaming/util/message"
@@ -19,7 +17,7 @@ type ReadOption struct {
 	VChannel       string // vchannel name
 	DeliverPolicy  options.DeliverPolicy
 	MessageFilter  []options.DeliverFilter
-	MesasgeHandler MessageHandler // message handler for message processing.
+	MesasgeHandler message.Handler // message handler for message processing.
 	// If the message handler is nil (no redundant operation need to apply),
 	// the default message handler will be used, and the receiver will be returned from Chan.
 	// Otherwise, Chan will panic.
@@ -45,27 +43,3 @@ type Scanner interface {
 	// Return the error same with `Error`
 	Close() error
 }
-
-type HandleParam struct {
-	Ctx          context.Context
-	Upstream     <-chan message.ImmutableMessage
-	Message      message.ImmutableMessage
-	TimeTickChan <-chan struct{}
-}
-
-type HandleResult struct {
-	Incoming        message.ImmutableMessage // Not nil if upstream return new message.
-	MessageHandled  bool                     // True if Message is handled successfully.
-	TimeTickUpdated bool                     // True if TimeTickChan is triggered.
-	Error           error                    // Error is context is canceled.
-}
-
-// MessageHandler is used to handle message read from log.
-// TODO: should be removed in future after msgstream is removed.
-type MessageHandler interface {
-	// Handle is the callback for handling message.
-	Handle(param HandleParam) HandleResult
-
-	// Close is called after all messages are handled or handling is interrupted.
-	Close()
-}
diff --git a/internal/streamingnode/server/walmanager/wal_lifetime.go b/internal/streamingnode/server/walmanager/wal_lifetime.go
index 616c1bc7c4b07..524044cf0be90 100644
--- a/internal/streamingnode/server/walmanager/wal_lifetime.go
+++ b/internal/streamingnode/server/walmanager/wal_lifetime.go
@@ -3,6 +3,7 @@ package walmanager
 import (
 	"context"
 
+	"github.com/cockroachdb/errors"
 	"go.uber.org/zap"
 
 	"github.com/milvus-io/milvus/internal/streamingnode/server/resource"
@@ -72,7 +73,14 @@ func (w *walLifetime) Remove(ctx context.Context, term int64) error {
 	}
 
 	// Wait until the WAL state is ready or term expired or error occurs.
-	return w.statePair.WaitCurrentStateReachExpected(ctx, expected)
+	err := w.statePair.WaitCurrentStateReachExpected(ctx, expected)
+	if errors.IsAny(err, context.Canceled, context.DeadlineExceeded) {
+		return err
+	}
+	if err != nil {
+		w.logger.Info("remove wal succeeded because the previous open operation failed", zap.NamedError("previousOpenError", err))
+	}
+	return nil
 }
 
 // Close closes the wal lifetime.
diff --git a/pkg/streaming/util/message/adaptor/handler.go b/pkg/streaming/util/message/adaptor/handler.go
index 80fd72be0766d..2d9ae00c64910 100644
--- a/pkg/streaming/util/message/adaptor/handler.go
+++ b/pkg/streaming/util/message/adaptor/handler.go
@@ -1,8 +1,6 @@
 package adaptor
 
 import (
-	"context"
-
 	"go.uber.org/zap"
 
 	"github.com/milvus-io/milvus/pkg/log"
@@ -11,6 +9,32 @@ import (
 	"github.com/milvus-io/milvus/pkg/util/typeutil"
 )
 
+// ChanMessageHandler is a handler that just forwards messages into a channel.
+type ChanMessageHandler chan message.ImmutableMessage
+
+// Handle is the callback for handling message.
+func (h ChanMessageHandler) Handle(param message.HandleParam) message.HandleResult {
+	var sendingCh chan message.ImmutableMessage
+	if param.Message != nil {
+		sendingCh = h
+	}
+	select {
+	case <-param.Ctx.Done():
+		return message.HandleResult{Error: param.Ctx.Err()}
+	case msg, ok := <-param.Upstream:
+		if !ok {
+			return message.HandleResult{Error: message.ErrUpstreamClosed}
+		}
+		return message.HandleResult{Incoming: msg}
+	case sendingCh <- param.Message:
+		return message.HandleResult{MessageHandled: true}
+	case <-param.TimeTickChan:
+		return message.HandleResult{TimeTickUpdated: true}
+	}
+}
+
+// Close closes the handler and the underlying channel.
+func (h ChanMessageHandler) Close() {
+	close(h)
+}
+
 // NewMsgPackAdaptorHandler create a new message pack adaptor handler.
 func NewMsgPackAdaptorHandler() *MsgPackAdaptorHandler {
 	return &MsgPackAdaptorHandler{
@@ -18,7 +42,6 @@ func NewMsgPackAdaptorHandler() *MsgPackAdaptorHandler {
 	}
 }
 
-// MsgPackAdaptorHandler is the handler for message pack.
 type MsgPackAdaptorHandler struct {
 	base *BaseMsgPackAdaptorHandler
 }
@@ -29,20 +52,53 @@ func (m *MsgPackAdaptorHandler) Chan() <-chan *msgstream.MsgPack {
 }
 
 // Handle is the callback for handling message.
-func (m *MsgPackAdaptorHandler) Handle(ctx context.Context, msg message.ImmutableMessage) (bool, error) {
-	m.base.GenerateMsgPack(msg)
-	for m.base.PendingMsgPack.Len() > 0 {
+func (m *MsgPackAdaptorHandler) Handle(param message.HandleParam) message.HandleResult {
+	messageHandled := false
+	// Do not handle a new message while there are still pending msgPacks.
+	if param.Message != nil && m.base.PendingMsgPack.Len() == 0 {
+		m.base.GenerateMsgPack(param.Message)
+		messageHandled = true
+	}
+
+	for {
+		var sendCh chan<- *msgstream.MsgPack
+		if m.base.PendingMsgPack.Len() != 0 {
+			sendCh = m.base.Channel
+		}
+
 		select {
-		case <-ctx.Done():
-			return true, ctx.Err()
-		case m.base.Channel <- m.base.PendingMsgPack.Next():
+		case <-param.Ctx.Done():
+			return message.HandleResult{
+				MessageHandled: messageHandled,
+				Error:          param.Ctx.Err(),
+			}
+		case msg, notClose := <-param.Upstream:
+			if !notClose {
+				return message.HandleResult{
+					MessageHandled: messageHandled,
+					Error:          message.ErrUpstreamClosed,
+				}
+			}
+			return message.HandleResult{
+				Incoming:       msg,
+				MessageHandled: messageHandled,
+			}
+		case sendCh <- m.base.PendingMsgPack.Next():
 			m.base.PendingMsgPack.UnsafeAdvance()
+			if m.base.PendingMsgPack.Len() > 0 {
+				continue
+			}
+			return message.HandleResult{MessageHandled: messageHandled}
+		case <-param.TimeTickChan:
+			return message.HandleResult{
+				MessageHandled:  messageHandled,
+				TimeTickUpdated: true,
+			}
 		}
 	}
-	return true, nil
 }
 
-// Close is the callback for closing message.
+// Close closes the handler.
 func (m *MsgPackAdaptorHandler) Close() {
 	close(m.base.Channel)
 }
diff --git a/pkg/streaming/util/message/adaptor/handler_test.go b/pkg/streaming/util/message/adaptor/handler_test.go
index 1c5909a079739..cf34c37bd83c7 100644
--- a/pkg/streaming/util/message/adaptor/handler_test.go
+++ b/pkg/streaming/util/message/adaptor/handler_test.go
@@ -3,167 +3,244 @@ package adaptor
 import (
 	"context"
 	"testing"
-	"time"
 
 	"github.com/stretchr/testify/assert"
 
-	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
-	"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
-	"github.com/milvus-io/milvus/pkg/mq/msgstream"
+	"github.com/milvus-io/milvus/pkg/mocks/streaming/util/mock_message"
 	"github.com/milvus-io/milvus/pkg/streaming/util/message"
 	"github.com/milvus-io/milvus/pkg/streaming/walimpls/impls/rmq"
 )
 
 func TestMsgPackAdaptorHandler(t *testing.T) {
-	id := rmq.NewRmqID(1)
-
+	messageID := rmq.NewRmqID(1)
+	tt := uint64(100)
+	msg := message.CreateTestInsertMessage(
+		t,
+		1,
+		1000,
+		tt,
+		messageID,
+	)
+	immutableMsg := msg.IntoImmutableMessage(messageID)
+
+	upstream := make(chan message.ImmutableMessage, 1)
+
+	ctx := context.Background()
 	h := NewMsgPackAdaptorHandler()
-	insertMsg := message.CreateTestInsertMessage(t, 1, 100, 10, id)
-	insertImmutableMessage := insertMsg.IntoImmutableMessage(id)
-	ch := make(chan *msgstream.MsgPack, 1)
+	done := make(chan struct{})
 	go func() {
-		for msgPack := range h.Chan() {
-			ch <- msgPack
+		for range h.Chan() {
 		}
-		close(ch)
+		close(done)
 	}()
-	ok, err := h.Handle(context.Background(), insertImmutableMessage)
-	assert.True(t, ok)
-	assert.NoError(t, err)
-	msgPack := <-ch
-
-	assert.Equal(t, uint64(10), msgPack.BeginTs)
-	assert.Equal(t, uint64(10), msgPack.EndTs)
-	for _, tsMsg := range msgPack.Msgs {
-		assert.Equal(t, uint64(10), tsMsg.BeginTs())
-		assert.Equal(t, uint64(10), tsMsg.EndTs())
-		for _, ts := range tsMsg.(*msgstream.InsertMsg).Timestamps {
-			assert.Equal(t, uint64(10), ts)
-		}
-	}
-
-	deleteMsg, err := message.NewDeleteMessageBuilderV1().
-		WithVChannel("vchan1").
-		WithHeader(&message.DeleteMessageHeader{
-			CollectionId: 1,
-		}).
-		WithBody(&msgpb.DeleteRequest{
-			Base: &commonpb.MsgBase{
-				MsgType: commonpb.MsgType_Delete,
-			},
-			CollectionID: 1,
-			PartitionID:  1,
-			Timestamps:   []uint64{10},
-		}).
-		BuildMutable()
-	assert.NoError(t, err)
+	upstream <- immutableMsg
+	resp := h.Handle(message.HandleParam{
+		Ctx:      ctx,
+		Upstream: upstream,
+		Message:  nil,
+	})
+	assert.Equal(t, resp.Incoming, immutableMsg)
+	assert.False(t, resp.MessageHandled)
+	assert.NoError(t, resp.Error)
+
+	resp = h.Handle(message.HandleParam{
+		Ctx:      ctx,
+		Upstream: upstream,
+		Message:  resp.Incoming,
+	})
+	assert.NoError(t, resp.Error)
+	assert.Nil(t, resp.Incoming)
+	assert.True(t, resp.MessageHandled)
+	h.Close()
 
-	deleteImmutableMsg := deleteMsg.
-		WithTimeTick(11).
-		WithLastConfirmedUseMessageID().
-		IntoImmutableMessage(id)
+	<-done
+}
 
-	ok, err = h.Handle(context.Background(), deleteImmutableMsg)
-	assert.True(t, ok)
-	assert.NoError(t, err)
-	msgPack = <-ch
-	assert.Equal(t, uint64(11), msgPack.BeginTs)
-	assert.Equal(t, uint64(11), msgPack.EndTs)
-	for _, tsMsg := range msgPack.Msgs {
-		assert.Equal(t, uint64(11), tsMsg.BeginTs())
-		assert.Equal(t, uint64(11), tsMsg.EndTs())
-		for _, ts := range tsMsg.(*msgstream.DeleteMsg).Timestamps {
-			assert.Equal(t, uint64(11), ts)
+func TestDefaultHandler(t *testing.T) {
+	h := make(ChanMessageHandler, 1)
+	done := make(chan struct{})
+	go func() {
+		for range h {
 		}
-	}
-
-	// Create a txn message
-	msg, err := message.NewBeginTxnMessageBuilderV2().
-		WithVChannel("vchan1").
-		WithHeader(&message.BeginTxnMessageHeader{
-			KeepaliveMilliseconds: 1000,
-		}).
-		WithBody(&message.BeginTxnMessageBody{}).
-		BuildMutable()
-	assert.NoError(t, err)
-	assert.NotNil(t, msg)
-
-	txnCtx := message.TxnContext{
-		TxnID:     1,
-		Keepalive: time.Second,
-	}
-
-	beginImmutableMsg, err := message.AsImmutableBeginTxnMessageV2(msg.WithTimeTick(9).
-		WithTxnContext(txnCtx).
-		WithLastConfirmedUseMessageID().
-		IntoImmutableMessage(rmq.NewRmqID(2)))
-	assert.NoError(t, err)
-
-	msg, err = message.NewCommitTxnMessageBuilderV2().
-		WithVChannel("vchan1").
-		WithHeader(&message.CommitTxnMessageHeader{}).
-		WithBody(&message.CommitTxnMessageBody{}).
-		BuildMutable()
-	assert.NoError(t, err)
-
-	commitImmutableMsg, err := message.AsImmutableCommitTxnMessageV2(msg.WithTimeTick(12).
-		WithTxnContext(txnCtx).
-		WithTxnContext(message.TxnContext{}).
-		WithLastConfirmedUseMessageID().
-		IntoImmutableMessage(rmq.NewRmqID(3)))
-	assert.NoError(t, err)
-
-	txn, err := message.NewImmutableTxnMessageBuilder(beginImmutableMsg).
-		Add(insertMsg.WithTxnContext(txnCtx).IntoImmutableMessage(id)).
-		Add(deleteMsg.WithTxnContext(txnCtx).IntoImmutableMessage(id)).
-		Build(commitImmutableMsg)
-	assert.NoError(t, err)
-
-	ok, err = h.Handle(context.Background(), txn)
-	assert.True(t, ok)
-	assert.NoError(t, err)
-	msgPack = <-ch
-
-	assert.Equal(t, uint64(12), msgPack.BeginTs)
-	assert.Equal(t, uint64(12), msgPack.EndTs)
-
-	// Create flush message
-	msg, err = message.NewFlushMessageBuilderV2().
-		WithVChannel("vchan1").
-		WithHeader(&message.FlushMessageHeader{}).
-		WithBody(&message.FlushMessageBody{}).
-		BuildMutable()
-	assert.NoError(t, err)
-
-	flushMsg := msg.
-		WithTimeTick(13).
-		WithLastConfirmedUseMessageID().
-		IntoImmutableMessage(rmq.NewRmqID(4))
-
-	ok, err = h.Handle(context.Background(), flushMsg)
-	assert.True(t, ok)
-	assert.NoError(t, err)
-
-	msgPack = <-ch
+		close(done)
+	}()
 
-	assert.Equal(t, uint64(13), msgPack.BeginTs)
-	assert.Equal(t, uint64(13), msgPack.EndTs)
+	upstream := make(chan message.ImmutableMessage, 1)
+	msg := mock_message.NewMockImmutableMessage(t)
+	upstream <- msg
+	resp := h.Handle(message.HandleParam{
+		Ctx:      context.Background(),
+		Upstream: upstream,
+		Message:  nil,
+	})
+	assert.NotNil(t, resp.Incoming)
+	assert.NoError(t, resp.Error)
+	assert.False(t, resp.MessageHandled)
+	assert.Equal(t, resp.Incoming, msg)
+
+	resp = h.Handle(message.HandleParam{
+		Ctx:      context.Background(),
+		Upstream: upstream,
+		Message:  resp.Incoming,
+	})
+	assert.NoError(t, resp.Error)
+	assert.Nil(t, resp.Incoming)
+	assert.True(t, resp.MessageHandled)
 	h.Close()
-	<-ch
+	<-done
 }
 
-func TestMsgPackAdaptorHandlerTimeout(t *testing.T) {
-	id := rmq.NewRmqID(1)
-
-	insertMsg := message.CreateTestInsertMessage(t, 1, 100, 10, id)
-	insertImmutableMessage := insertMsg.IntoImmutableMessage(id)
-
-	h := NewMsgPackAdaptorHandler()
-	ctx, cancel := context.WithCancel(context.Background())
-	cancel()
-
-	ok, err := h.Handle(ctx, insertImmutableMessage)
-	assert.True(t, ok)
-	assert.ErrorIs(t, err, ctx.Err())
-}
+// func TestMsgPackAdaptorHandler(t *testing.T) {
+// 	id := rmq.NewRmqID(1)
+//
+// 	h := NewMsgPackAdaptorHandler()
+// 	insertMsg := message.CreateTestInsertMessage(t, 1, 100, 10, id)
+// 	insertImmutableMessage := insertMsg.IntoImmutableMessage(id)
+// 	ch := make(chan *msgstream.MsgPack, 1)
+// 	go func() {
+// 		for msgPack := range h.Chan() {
+// 			ch <- msgPack
+// 		}
+// 		close(ch)
+// 	}()
+// 	ok, err := h.Handle(context.Background(), insertImmutableMessage)
+// 	assert.True(t, ok)
+// 	assert.NoError(t, err)
+// 	msgPack := <-ch
+//
+// 	assert.Equal(t, uint64(10), msgPack.BeginTs)
+// 	assert.Equal(t, uint64(10), msgPack.EndTs)
+// 	for _, tsMsg := range msgPack.Msgs {
+// 		assert.Equal(t, uint64(10), tsMsg.BeginTs())
+// 		assert.Equal(t, uint64(10), tsMsg.EndTs())
+// 		for _, ts := range tsMsg.(*msgstream.InsertMsg).Timestamps {
+// 			assert.Equal(t, uint64(10), ts)
+// 		}
+// 	}
+//
+// 	deleteMsg, err := message.NewDeleteMessageBuilderV1().
+// 		WithVChannel("vchan1").
+// 		WithHeader(&message.DeleteMessageHeader{
+// 			CollectionId: 1,
+// 		}).
+// 		WithBody(&msgpb.DeleteRequest{
+// 			Base: &commonpb.MsgBase{
+// 				MsgType: commonpb.MsgType_Delete,
+// 			},
+// 			CollectionID: 1,
+// 			PartitionID:  1,
+// 			Timestamps:   []uint64{10},
+// 		}).
+// 		BuildMutable()
+// 	assert.NoError(t, err)
+//
+// 	deleteImmutableMsg := deleteMsg.
+// 		WithTimeTick(11).
+// 		WithLastConfirmedUseMessageID().
+// 		IntoImmutableMessage(id)
+//
+// 	ok, err = h.Handle(context.Background(), deleteImmutableMsg)
+// 	assert.True(t, ok)
+// 	assert.NoError(t, err)
+// 	msgPack = <-ch
+// 	assert.Equal(t, uint64(11), msgPack.BeginTs)
+// 	assert.Equal(t, uint64(11), msgPack.EndTs)
+// 	for _, tsMsg := range msgPack.Msgs {
+// 		assert.Equal(t, uint64(11), tsMsg.BeginTs())
+// 		assert.Equal(t, uint64(11), tsMsg.EndTs())
+// 		for _, ts := range tsMsg.(*msgstream.DeleteMsg).Timestamps {
+// 			assert.Equal(t, uint64(11), ts)
+// 		}
+// 	}
+//
+// 	// Create a txn message
+// 	msg, err := message.NewBeginTxnMessageBuilderV2().
+// 		WithVChannel("vchan1").
+// 		WithHeader(&message.BeginTxnMessageHeader{
+// 			KeepaliveMilliseconds: 1000,
+// 		}).
+// 		WithBody(&message.BeginTxnMessageBody{}).
// 		BuildMutable()
// 	assert.NoError(t, err)
// 	assert.NotNil(t, msg)
//
// 	txnCtx := message.TxnContext{
// 		TxnID:     1,
// 		Keepalive: time.Second,
// 	}
//
// 	beginImmutableMsg, err := message.AsImmutableBeginTxnMessageV2(msg.WithTimeTick(9).
// 		WithTxnContext(txnCtx).
// 		WithLastConfirmedUseMessageID().
// 		IntoImmutableMessage(rmq.NewRmqID(2)))
// 	assert.NoError(t, err)
//
// 	msg, err = message.NewCommitTxnMessageBuilderV2().
// 		WithVChannel("vchan1").
// 		WithHeader(&message.CommitTxnMessageHeader{}).
// 		WithBody(&message.CommitTxnMessageBody{}).
// 		BuildMutable()
// 	assert.NoError(t, err)
//
// 	commitImmutableMsg, err := message.AsImmutableCommitTxnMessageV2(msg.WithTimeTick(12).
// 		WithTxnContext(txnCtx).
// 		WithTxnContext(message.TxnContext{}).
// 		WithLastConfirmedUseMessageID().
// 		IntoImmutableMessage(rmq.NewRmqID(3)))
// 	assert.NoError(t, err)
//
// 	txn, err := message.NewImmutableTxnMessageBuilder(beginImmutableMsg).
// 		Add(insertMsg.WithTxnContext(txnCtx).IntoImmutableMessage(id)).
// 		Add(deleteMsg.WithTxnContext(txnCtx).IntoImmutableMessage(id)).
// 		Build(commitImmutableMsg)
// 	assert.NoError(t, err)
//
// 	ok, err = h.Handle(context.Background(), txn)
// 	assert.True(t, ok)
// 	assert.NoError(t, err)
// 	msgPack = <-ch
//
// 	assert.Equal(t, uint64(12), msgPack.BeginTs)
// 	assert.Equal(t, uint64(12), msgPack.EndTs)
//
// 	// Create flush message
// 	msg, err = message.NewFlushMessageBuilderV2().
// 		WithVChannel("vchan1").
// 		WithHeader(&message.FlushMessageHeader{}).
// 		WithBody(&message.FlushMessageBody{}).
// 		BuildMutable()
// 	assert.NoError(t, err)
//
// 	flushMsg := msg.
// 		WithTimeTick(13).
// 		WithLastConfirmedUseMessageID().
// 		IntoImmutableMessage(rmq.NewRmqID(4))
//
// 	ok, err = h.Handle(context.Background(), flushMsg)
// 	assert.True(t, ok)
// 	assert.NoError(t, err)
//
// 	msgPack = <-ch
//
// 	assert.Equal(t, uint64(13), msgPack.BeginTs)
// 	assert.Equal(t, uint64(13), msgPack.EndTs)
//
// 	h.Close()
// 	<-ch
// }
//
// func TestMsgPackAdaptorHandlerTimeout(t *testing.T) {
// 	id := rmq.NewRmqID(1)
//
// 	insertMsg := message.CreateTestInsertMessage(t, 1, 100, 10, id)
// 	insertImmutableMessage := insertMsg.IntoImmutableMessage(id)
//
// 	h := NewMsgPackAdaptorHandler()
// 	ctx, cancel := context.WithCancel(context.Background())
// 	cancel()
//
// 	ok, err := h.Handle(ctx, insertImmutableMessage)
// 	assert.True(t, ok)
// 	assert.ErrorIs(t, err, ctx.Err())
// }
diff --git a/pkg/streaming/util/message/message_handler.go b/pkg/streaming/util/message/message_handler.go
index c6b6355c6a511..8af20f2598437 100644
--- a/pkg/streaming/util/message/message_handler.go
+++ b/pkg/streaming/util/message/message_handler.go
@@ -1,6 +1,28 @@
 package message
 
-import "context"
+import (
+	"context"
+
+	"github.com/cockroachdb/errors"
+)
+
+var ErrUpstreamClosed = errors.New("upstream closed")
+
+// HandleParam is the parameter for the handler.
+type HandleParam struct {
+	Ctx          context.Context
+	Upstream     <-chan ImmutableMessage
+	Message      ImmutableMessage
+	TimeTickChan <-chan struct{}
+}
+
+// HandleResult is the result of the handler.
+type HandleResult struct {
+	Incoming        ImmutableMessage // Not nil if the upstream returned a new message.
+	MessageHandled  bool             // True if the Message was handled successfully.
+	TimeTickUpdated bool             // True if TimeTickChan was triggered.
+	Error           error            // Set if the context is canceled or the upstream is closed.
+}
 
 // Handler is used to handle message read from log.
 type Handler interface {
@@ -8,29 +30,9 @@ type Handler interface {
 	// Return true if the message is consumed, false if the message is not consumed.
 	// Should return error if and only if ctx is done.
 	// !!! It's a bad implementation for compatibility for msgstream,
-	// should be removed in the future.
-	Handle(ctx context.Context, msg ImmutableMessage) (bool, error)
+	// will be removed in the future.
+	Handle(param HandleParam) HandleResult
 
 	// Close is called after all messages are handled or handling is interrupted.
 	Close()
 }
-
-var _ Handler = ChanMessageHandler(nil)
-
-// ChanMessageHandler is a handler just forward the message into a channel.
-type ChanMessageHandler chan ImmutableMessage
-
-// Handle is the callback for handling message.
-func (cmh ChanMessageHandler) Handle(ctx context.Context, msg ImmutableMessage) (bool, error) {
-	select {
-	case <-ctx.Done():
-		return false, ctx.Err()
-	case cmh <- msg:
-		return true, nil
-	}
-}
-
-// Close is called after all messages are handled or handling is interrupted.
-func (cmh ChanMessageHandler) Close() {
-	close(cmh)
-}
diff --git a/pkg/streaming/util/message/message_handler_test.go b/pkg/streaming/util/message/message_handler_test.go
deleted file mode 100644
index 12b02810227b9..0000000000000
--- a/pkg/streaming/util/message/message_handler_test.go
+++ /dev/null
@@ -1,27 +0,0 @@
-package message
-
-import (
-	"context"
-	"testing"
-
-	"github.com/stretchr/testify/assert"
-)
-
-func TestMessageHandler(t *testing.T) {
-	ch := make(chan ImmutableMessage, 1)
-	h := ChanMessageHandler(ch)
-	ok, err := h.Handle(context.Background(), nil)
-	assert.NoError(t, err)
-	assert.True(t, ok)
-
-	ctx, cancel := context.WithCancel(context.Background())
-	cancel()
-	ok, err = h.Handle(ctx, nil)
-	assert.ErrorIs(t, err, ctx.Err())
-	assert.False(t, ok)
-
-	assert.Nil(t, <-ch)
-	h.Close()
-	_, ok = <-ch
-	assert.False(t, ok)
-}
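Note for reviewers: below is a minimal, illustrative sketch of how a caller is expected to drive the new message.Handler contract introduced by this patch. It is not code from the patch; the function name pump and the variables upstream, timeTickCh, and pending are hypothetical, and error handling is reduced to the essentials.

```go
package example

import (
	"context"

	"github.com/milvus-io/milvus/pkg/streaming/util/message"
)

// pump is a hypothetical driver loop (not part of this patch). It repeatedly calls
// Handle with the oldest undelivered message (or nil) and reacts to HandleResult:
// MessageHandled drops the delivered message, Incoming enqueues a new one, and
// Error (context cancellation or message.ErrUpstreamClosed) terminates the loop.
func pump(ctx context.Context, upstream <-chan message.ImmutableMessage, timeTickCh <-chan struct{}, h message.Handler) error {
	defer h.Close()

	var pending []message.ImmutableMessage
	for {
		var next message.ImmutableMessage
		if len(pending) > 0 {
			next = pending[0]
		}
		result := h.Handle(message.HandleParam{
			Ctx:          ctx,
			Upstream:     upstream,
			Message:      next, // nil when there is nothing to deliver yet
			TimeTickChan: timeTickCh,
		})
		if result.MessageHandled {
			pending = pending[1:] // the handler consumed the pending message
		}
		if result.Incoming != nil {
			pending = append(pending, result.Incoming) // a new message arrived from upstream
		}
		if result.Error != nil {
			return result.Error
		}
		// result.TimeTickUpdated may be used here to refresh time-tick-driven state.
	}
}
```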