From 81d2faed0cbf1fe3bfdbc662f59a45985e469ac1 Mon Sep 17 00:00:00 2001 From: Piotr Piotrowski Date: Tue, 4 Feb 2025 17:16:42 +0100 Subject: [PATCH] [ADDED] PullMaxMessagesWithBytesLimit option for Consume and Messages (#1789) Signed-off-by: Piotr Piotrowski --- jetstream/README.md | 8 + jetstream/errors.go | 5 + jetstream/jetstream_options.go | 54 +++++++ jetstream/message.go | 3 + jetstream/pull.go | 84 ++++++----- jetstream/test/pull_test.go | 261 +++++++++++++++++++++++++++++++++ 6 files changed, 379 insertions(+), 36 deletions(-) diff --git a/jetstream/README.md b/jetstream/README.md index 28926a842..ac11f328f 100644 --- a/jetstream/README.md +++ b/jetstream/README.md @@ -494,6 +494,10 @@ request. An error will be triggered if at least 2 heartbeats are missed - `WithConsumeErrHandler(func (ConsumeContext, error))` - when used, sets a custom error handler on `Consume()`, allowing e.g. tracking missing heartbeats. +- `PullMaxMessagesWithBytesLimit` - up to the provided number of messages + will be buffered and a single fetch size will be limited to the provided + value. This is an advanced option and should be used with caution. Most of the + time, `PullMaxMessages` or `PullMaxBytes` should be used instead. > __NOTE__: `Stop()` should always be called on `ConsumeContext` to avoid > leaking goroutines. @@ -536,6 +540,10 @@ type PullThresholdMessages int - `PullHeartbeat(time.Duration)` - idle heartbeat duration for a single pull request. An error will be triggered if at least 2 heartbeats are missed (unless `WithMessagesErrOnMissingHeartbeat(false)` is used) +- `PullMaxMessagesWithBytesLimit` - up to the provided number of messages + will be buffered and a single fetch size will be limited to the provided + value. This is an advanced option and should be used with caution. Most of the + time, `PullMaxMessages` or `PullMaxBytes` should be used instead. ##### Using `Messages()` to fetch single messages one by one diff --git a/jetstream/errors.go b/jetstream/errors.go index 8d2fec642..4b8697b1f 100644 --- a/jetstream/errors.go +++ b/jetstream/errors.go @@ -199,6 +199,11 @@ var ( // on a pull request. ErrMaxBytesExceeded JetStreamError = &jsError{message: "message size exceeds max bytes"} + // ErrBatchCompleted is returned when a fetch request sent the whole batch, + // but there are still bytes left. This is applicable only when MaxBytes is + // set on a pull request. + ErrBatchCompleted JetStreamError = &jsError{message: "batch completed"} + // ErrConsumerDeleted is returned when attempting to send pull request to a // consumer which does not exist. ErrConsumerDeleted JetStreamError = &jsError{message: "consumer deleted"} diff --git a/jetstream/jetstream_options.go b/jetstream/jetstream_options.go index 78cd36c77..e23596be2 100644 --- a/jetstream/jetstream_options.go +++ b/jetstream/jetstream_options.go @@ -125,6 +125,60 @@ func (max PullMaxMessages) configureMessages(opts *consumeOpts) error { return nil } +type pullMaxMessagesWithBytesLimit struct { + maxMessages int + maxBytes int +} + +// PullMaxMessagesWithBytesLimit limits the number of messages to be buffered +// in the client. Additionally, it sets the maximum size a single fetch request +// can have. Note that this will not limit the total size of messages buffered +// in the client, but rather can serve as a way to limit what nats server will +// have to internally buffer for a single fetch request. +// +// This is an advanced option and should be used with caution. 
Most users should +// use [PullMaxMessages] or [PullMaxBytes] instead. +// +// PullMaxMessagesWithBytesLimit implements both PullConsumeOpt and +// PullMessagesOpt, allowing it to configure Consumer.Consume and Consumer.Messages. +func PullMaxMessagesWithBytesLimit(maxMessages, byteLimit int) pullMaxMessagesWithBytesLimit { + return pullMaxMessagesWithBytesLimit{maxMessages, byteLimit} +} + +func (m pullMaxMessagesWithBytesLimit) configureConsume(opts *consumeOpts) error { + if m.maxMessages <= 0 { + return fmt.Errorf("%w: maxMessages size must be at least 1", ErrInvalidOption) + } + if m.maxBytes <= 0 { + return fmt.Errorf("%w: maxBytes size must be at least 1", ErrInvalidOption) + } + if opts.MaxMessages > 0 { + return fmt.Errorf("%w: maxMessages already set", ErrInvalidOption) + } + opts.MaxMessages = m.maxMessages + opts.MaxBytes = m.maxBytes + opts.LimitSize = true + + return nil +} + +func (m pullMaxMessagesWithBytesLimit) configureMessages(opts *consumeOpts) error { + if m.maxMessages <= 0 { + return fmt.Errorf("%w: maxMessages size must be at least 1", ErrInvalidOption) + } + if m.maxBytes <= 0 { + return fmt.Errorf("%w: maxBytes size must be at least 1", ErrInvalidOption) + } + if opts.MaxMessages > 0 { + return fmt.Errorf("%w: maxMessages already set", ErrInvalidOption) + } + opts.MaxMessages = m.maxMessages + opts.MaxBytes = m.maxBytes + opts.LimitSize = true + + return nil +} + // PullExpiry sets timeout on a single pull request, waiting until at least one // message is available. // If not provided, a default of 30 seconds will be used. diff --git a/jetstream/message.go b/jetstream/message.go index 217d2a483..720a4a577 100644 --- a/jetstream/message.go +++ b/jetstream/message.go @@ -418,6 +418,9 @@ func checkMsg(msg *nats.Msg) (bool, error) { if strings.Contains(strings.ToLower(descr), "message size exceeds maxbytes") { return false, ErrMaxBytesExceeded } + if strings.Contains(strings.ToLower(descr), "batch completed") { + return false, ErrBatchCompleted + } if strings.Contains(strings.ToLower(descr), "consumer deleted") { return false, ErrConsumerDeleted } diff --git a/jetstream/pull.go b/jetstream/pull.go index 8ee17f2ea..2bc3bf5c1 100644 --- a/jetstream/pull.go +++ b/jetstream/pull.go @@ -101,6 +101,7 @@ type ( Expires time.Duration MaxMessages int MaxBytes int + LimitSize bool Heartbeat time.Duration ErrHandler ConsumeErrHandlerFunc ReportMissingHeartbeats bool @@ -160,9 +161,10 @@ type ( ) const ( - DefaultMaxMessages = 500 - DefaultExpires = 30 * time.Second - unset = -1 + DefaultMaxMessages = 500 + DefaultExpires = 30 * time.Second + defaultBatchMaxBytesOnly = 1_000_000 + unset = -1 ) func min(x, y int) int { @@ -193,7 +195,7 @@ func (p *pullConsumer) Consume(handler MessageHandler, opts ...PullConsumeOpt) ( sub := &pullSubscription{ id: consumeID, consumer: p, - errs: make(chan error, 1), + errs: make(chan error, 10), done: make(chan struct{}, 1), fetchNext: make(chan *pullRequest, 1), consumeOpts: consumeOpts, @@ -373,7 +375,7 @@ func (s *pullSubscription) resetPendingMsgs() { // lock should be held before calling this method func (s *pullSubscription) decrementPendingMsgs(msg *nats.Msg) { s.pending.msgCount-- - if s.consumeOpts.MaxBytes != 0 { + if s.consumeOpts.MaxBytes != 0 && !s.consumeOpts.LimitSize { s.pending.byteCount -= msg.Size() } } @@ -388,18 +390,23 @@ func (s *pullSubscription) incrementDeliveredMsgs() { // the buffer to trigger a new pull request. 
// lock should be held before calling this method func (s *pullSubscription) checkPending() { + // check if we went below any threshold + // we don't want to track bytes threshold if either it's not set or we used + // PullMaxMessagesWithBytesLimit if (s.pending.msgCount < s.consumeOpts.ThresholdMessages || - (s.pending.byteCount < s.consumeOpts.ThresholdBytes && s.consumeOpts.MaxBytes != 0)) && + (s.pending.byteCount < s.consumeOpts.ThresholdBytes && s.consumeOpts.MaxBytes != 0 && !s.consumeOpts.LimitSize)) && s.fetchInProgress.Load() == 0 { var batchSize, maxBytes int - if s.consumeOpts.MaxBytes == 0 { - // if using messages, calculate appropriate batch size - batchSize = s.consumeOpts.MaxMessages - s.pending.msgCount - } else { - // if using bytes, use the max value - batchSize = s.consumeOpts.MaxMessages - maxBytes = s.consumeOpts.MaxBytes - s.pending.byteCount + batchSize = s.consumeOpts.MaxMessages - s.pending.msgCount + if s.consumeOpts.MaxBytes != 0 { + if s.consumeOpts.LimitSize { + maxBytes = s.consumeOpts.MaxBytes + } else { + maxBytes = s.consumeOpts.MaxBytes - s.pending.byteCount + // when working with max bytes only, always ask for full batch + batchSize = s.consumeOpts.MaxMessages + } } if s.consumeOpts.StopAfter > 0 { batchSize = min(batchSize, s.consumeOpts.StopAfter-s.delivered-s.pending.msgCount) @@ -440,7 +447,7 @@ func (p *pullConsumer) Messages(opts ...PullMessagesOpt) (MessagesContext, error consumer: p, done: make(chan struct{}, 1), msgs: msgs, - errs: make(chan error, 1), + errs: make(chan error, 10), fetchNext: make(chan *pullRequest, 1), consumeOpts: consumeOpts, } @@ -584,7 +591,7 @@ func (s *pullSubscription) Next() (Msg, error) { } func (s *pullSubscription) handleStatusMsg(msg *nats.Msg, msgErr error) error { - if !errors.Is(msgErr, nats.ErrTimeout) && !errors.Is(msgErr, ErrMaxBytesExceeded) { + if !errors.Is(msgErr, nats.ErrTimeout) && !errors.Is(msgErr, ErrMaxBytesExceeded) && !errors.Is(msgErr, ErrBatchCompleted) { if errors.Is(msgErr, ErrConsumerDeleted) || errors.Is(msgErr, ErrBadRequest) { return msgErr } @@ -605,7 +612,7 @@ func (s *pullSubscription) handleStatusMsg(msg *nats.Msg, msgErr error) error { if s.pending.msgCount < 0 { s.pending.msgCount = 0 } - if s.consumeOpts.MaxBytes > 0 { + if s.consumeOpts.MaxBytes > 0 && !s.consumeOpts.LimitSize { s.pending.byteCount -= bytesLeft if s.pending.byteCount < 0 { s.pending.byteCount = 0 @@ -712,7 +719,7 @@ func (p *pullConsumer) Fetch(batch int, opts ...FetchOpt) (MessageBatch, error) // FetchBytes is used to retrieve up to a provided bytes from the stream. 
func (p *pullConsumer) FetchBytes(maxBytes int, opts ...FetchOpt) (MessageBatch, error) { req := &pullRequest{ - Batch: 1000000, + Batch: defaultBatchMaxBytesOnly, MaxBytes: maxBytes, Expires: DefaultExpires, Heartbeat: unset, @@ -761,7 +768,7 @@ func (p *pullConsumer) fetch(req *pullRequest) (MessageBatch, error) { consumer: p, done: make(chan struct{}, 1), msgs: msgs, - errs: make(chan error, 1), + errs: make(chan error, 10), } inbox := p.js.conn.NewInbox() var err error @@ -985,40 +992,45 @@ func parseMessagesOpts(ordered bool, opts ...PullMessagesOpt) (*consumeOpts, err } func (consumeOpts *consumeOpts) setDefaults(ordered bool) error { - if consumeOpts.MaxBytes != unset && consumeOpts.MaxMessages != unset { + // we cannot use both max messages and max bytes unless we're using max bytes as fetch size limiter + if consumeOpts.MaxBytes != unset && consumeOpts.MaxMessages != unset && !consumeOpts.LimitSize { return errors.New("only one of MaxMessages and MaxBytes can be specified") } - if consumeOpts.MaxBytes != unset { - // when max_bytes is used, set batch size to a very large number - consumeOpts.MaxMessages = 1000000 - } else if consumeOpts.MaxMessages != unset { + if consumeOpts.MaxBytes != unset && !consumeOpts.LimitSize { + // we used PullMaxBytes setting, set MaxMessages to a high value + consumeOpts.MaxMessages = defaultBatchMaxBytesOnly + } else if consumeOpts.MaxMessages == unset { + // otherwise, if max messages is not set, set it to default value + consumeOpts.MaxMessages = DefaultMaxMessages + } + // if user did not set max bytes, set it to 0 + if consumeOpts.MaxBytes == unset { consumeOpts.MaxBytes = 0 - } else { - if consumeOpts.MaxBytes == unset { - consumeOpts.MaxBytes = 0 - } - if consumeOpts.MaxMessages == unset { - consumeOpts.MaxMessages = DefaultMaxMessages - } } if consumeOpts.ThresholdMessages == 0 { + // half of the max messages, rounded up consumeOpts.ThresholdMessages = int(math.Ceil(float64(consumeOpts.MaxMessages) / 2)) } if consumeOpts.ThresholdBytes == 0 { + // half of the max bytes, rounded up consumeOpts.ThresholdBytes = int(math.Ceil(float64(consumeOpts.MaxBytes) / 2)) } + + // set default heartbeats if consumeOpts.Heartbeat == unset { + // by default, use 50% of expiry time + consumeOpts.Heartbeat = consumeOpts.Expires / 2 if ordered { - consumeOpts.Heartbeat = 5 * time.Second + // for ordered consumers, the default heartbeat is 5 seconds if consumeOpts.Expires < 10*time.Second { consumeOpts.Heartbeat = consumeOpts.Expires / 2 + } else { + consumeOpts.Heartbeat = 5 * time.Second } - } else { - consumeOpts.Heartbeat = consumeOpts.Expires / 2 - if consumeOpts.Heartbeat > 30*time.Second { - consumeOpts.Heartbeat = 30 * time.Second - } + } else if consumeOpts.Heartbeat > 30*time.Second { + // cap the heartbeat to 30 seconds + consumeOpts.Heartbeat = 30 * time.Second } } if consumeOpts.Heartbeat > consumeOpts.Expires/2 { diff --git a/jetstream/test/pull_test.go b/jetstream/test/pull_test.go index 657e4db06..c1ffb523c 100644 --- a/jetstream/test/pull_test.go +++ b/jetstream/test/pull_test.go @@ -14,6 +14,7 @@ package test import ( + "bytes" "context" "errors" "fmt" @@ -1945,6 +1946,115 @@ func TestPullConsumerMessages(t *testing.T) { t.Fatalf("Unexpected received message count after drain; want %d; got %d", len(testMsgs), len(msgs)) } }) + + t.Run("with max messages and per fetch size limit", func(t *testing.T) { + srv := RunBasicJetStreamServer() + defer shutdownJSServerAndRemoveStorage(t, srv) + nc, err := nats.Connect(srv.ClientURL()) + if err != nil { + 
t.Fatalf("Unexpected error: %v", err) + } + + js, err := jetstream.New(nc) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + defer nc.Close() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + // subscribe to next request subject to verify how many next requests were sent + // and whether both thresholds work as expected + sub, err := nc.SubscribeSync(fmt.Sprintf("$JS.API.CONSUMER.MSG.NEXT.foo.%s", c.CachedInfo().Name)) + if err != nil { + t.Fatalf("Error on subscribe: %v", err) + } + defer sub.Unsubscribe() + + it, err := c.Messages(jetstream.PullMaxMessagesWithBytesLimit(10, 1024)) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + smallMsg := nats.Msg{ + Subject: "FOO.A", + Data: []byte("msg"), + } + // publish 10 small messages + for i := 0; i < 10; i++ { + if _, err := js.PublishMsg(ctx, &smallMsg); err != nil { + t.Fatalf("Unexpected error during publish: %s", err) + } + } + + for i := 0; i < 10; i++ { + msg, err := it.Next() + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + msg.Ack() + } + + // we should get 2 pull requests + for range 2 { + fetchReq, err := sub.NextMsg(100 * time.Millisecond) + if err != nil { + t.Fatalf("Error on next msg: %v", err) + } + if !bytes.Contains(fetchReq.Data, []byte(`"max_bytes":1024`)) { + t.Fatalf("Unexpected fetch request: %s", fetchReq.Data) + } + } + // make sure no more requests were sent + _, err = sub.NextMsg(100 * time.Millisecond) + if !errors.Is(err, nats.ErrTimeout) { + t.Fatalf("Expected timeout error; got: %v", err) + } + + // now publish 10 large messages, almost hitting the limit + // we need to account for the total message size (which includes js ack reply subject) + largeMsg := nats.Msg{ + Subject: "FOO.B", + Data: make([]byte, 950), + } + for range 10 { + if _, err := js.PublishMsg(ctx, &largeMsg); err != nil { + t.Fatalf("Unexpected error during publish: %s", err) + } + } + + for i := 0; i < 10; i++ { + msg, err := it.Next() + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + msg.Ack() + } + // we expect 10 pull requests + for range 9 { + fetchReq, err := sub.NextMsg(100 * time.Millisecond) + if err != nil { + t.Fatalf("Error on next msg: %v", err) + } + if !bytes.Contains(fetchReq.Data, []byte(`"max_bytes":1024`)) { + t.Fatalf("Unexpected fetch request: %s", fetchReq.Data) + } + } + _, err = sub.NextMsg(100 * time.Millisecond) + if !errors.Is(err, nats.ErrTimeout) { + t.Fatalf("Expected timeout error; got: %v", err) + } + + it.Stop() + }) } func TestPullConsumerConsume(t *testing.T) { @@ -2871,6 +2981,157 @@ func TestPullConsumerConsume(t *testing.T) { t.Fatalf("Timeout waiting for consume to be closed") } }) + + t.Run("with max messages and per fetch size limit", func(t *testing.T) { + srv := RunBasicJetStreamServer() + defer shutdownJSServerAndRemoveStorage(t, srv) + nc, err := nats.Connect(srv.ClientURL()) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + js, err := jetstream.New(nc) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + defer nc.Close() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + s, err := js.CreateStream(ctx, 
jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + // subscribe to next request subject to verify how many next requests were sent + // and whether both thresholds work as expected + sub, err := nc.SubscribeSync(fmt.Sprintf("$JS.API.CONSUMER.MSG.NEXT.foo.%s", c.CachedInfo().Name)) + if err != nil { + t.Fatalf("Error on subscribe: %v", err) + } + defer sub.Unsubscribe() + + wg := &sync.WaitGroup{} + msgs := make([]jetstream.Msg, 0) + cc, err := c.Consume(func(msg jetstream.Msg) { + msg.Ack() + msgs = append(msgs, msg) + wg.Done() + }, jetstream.PullMaxMessagesWithBytesLimit(10, 1024)) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + smallMsg := nats.Msg{ + Subject: "FOO.A", + Data: []byte("msg"), + } + wg.Add(10) + // publish 10 small messages + for i := 0; i < 10; i++ { + if _, err := js.PublishMsg(ctx, &smallMsg); err != nil { + t.Fatalf("Unexpected error during publish: %s", err) + } + } + wg.Wait() + + // we should get 2 pull requests + for range 2 { + fetchReq, err := sub.NextMsg(100 * time.Millisecond) + if err != nil { + t.Fatalf("Error on next msg: %v", err) + } + if !bytes.Contains(fetchReq.Data, []byte(`"max_bytes":1024`)) { + t.Fatalf("Unexpected fetch request: %s", fetchReq.Data) + } + } + // make sure no more requests were sent + _, err = sub.NextMsg(100 * time.Millisecond) + if !errors.Is(err, nats.ErrTimeout) { + t.Fatalf("Expected timeout error; got: %v", err) + } + + // now publish 10 large messages, almost hitting the limit + // we need to account for the total message size (which includes js ack reply subject) + largeMsg := nats.Msg{ + Subject: "FOO.B", + Data: make([]byte, 950), + } + wg.Add(10) + for range 10 { + if _, err := js.PublishMsg(ctx, &largeMsg); err != nil { + t.Fatalf("Unexpected error during publish: %s", err) + } + } + wg.Wait() + + // we expect 10 pull requests + for range 10 { + fetchReq, err := sub.NextMsg(100 * time.Millisecond) + if err != nil { + t.Fatalf("Error on next msg: %v", err) + } + if !bytes.Contains(fetchReq.Data, []byte(`"max_bytes":1024`)) { + t.Fatalf("Unexpected fetch request: %s", fetchReq.Data) + } + } + _, err = sub.NextMsg(100 * time.Millisecond) + if !errors.Is(err, nats.ErrTimeout) { + t.Fatalf("Expected timeout error; got: %v", err) + } + + cc.Stop() + }) + + t.Run("avoid stall on batch completed status", func(t *testing.T) { + srv := RunBasicJetStreamServer() + defer shutdownJSServerAndRemoveStorage(t, srv) + nc, err := nats.Connect(srv.ClientURL()) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + js, err := jetstream.New(nc) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + defer nc.Close() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + wg := &sync.WaitGroup{} + msgs := make([]jetstream.Msg, 0) + // use consume with small max messages and large max bytes + // to make sure we don't stall on batch completed status + cc, err := c.Consume(func(msg jetstream.Msg) { + 
msg.Ack() + msgs = append(msgs, msg) + wg.Done() + }, jetstream.PullMaxMessagesWithBytesLimit(2, 1024)) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + wg.Add(10) + for i := 0; i < 10; i++ { + if _, err := js.Publish(ctx, "FOO.A", []byte("msg")); err != nil { + t.Fatalf("Unexpected error during publish: %s", err) + } + } + wg.Wait() + cc.Stop() + }) } func TestPullConsumerConsume_WithCluster(t *testing.T) {
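
---

For reviewers who want to try the patch locally, here is a minimal usage sketch of the new option with `Consume()`. It is not part of the patch: the stream name (`ORDERS`), consumer name (`processor`), and the 100/4096 limits are hypothetical, and error handling is elided for brevity.

```go
package main

import (
	"context"
	"fmt"
	"time"

	"github.com/nats-io/nats.go"
	"github.com/nats-io/nats.go/jetstream"
)

func main() {
	nc, _ := nats.Connect(nats.DefaultURL)
	defer nc.Close()

	js, _ := jetstream.New(nc)
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	// Hypothetical pre-existing stream and durable consumer.
	c, _ := js.Consumer(ctx, "ORDERS", "processor")

	// Buffer up to 100 messages in the client, while capping every single
	// fetch request at 4096 bytes. Unlike PullMaxBytes, the byte value here
	// bounds one pull request (what the server must buffer per fetch),
	// not the total size buffered in the client.
	cc, _ := c.Consume(func(msg jetstream.Msg) {
		fmt.Println(string(msg.Data()))
		msg.Ack()
	}, jetstream.PullMaxMessagesWithBytesLimit(100, 4096))
	defer cc.Stop()

	<-ctx.Done()
}
```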
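The same option wired into the `Messages()` iterator — again a sketch rather than part of the patch, reusing the hypothetical consumer `c` from the snippet above. The behavior the tests above verify is that each pull request carries `"max_bytes":<limit>`; terminal statuses such as the new batch-completed one are absorbed by `handleStatusMsg` and never surface from `Next()`.

```go
	// Iterate messages one by one; each underlying fetch asks for at most
	// 10 messages and at most 1024 bytes.
	it, err := c.Messages(jetstream.PullMaxMessagesWithBytesLimit(10, 1024))
	if err != nil {
		// handle error
	}
	defer it.Stop()

	for {
		msg, err := it.Next()
		if err != nil {
			// jetstream.ErrMsgIteratorClosed is returned after Stop()
			break
		}
		fmt.Println(string(msg.Data()))
		msg.Ack()
	}
```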