diff --git a/pubsub/rabbitpubsub/amqp.go b/pubsub/rabbitpubsub/amqp.go index d28831ce6b..5a5a174885 100644 --- a/pubsub/rabbitpubsub/amqp.go +++ b/pubsub/rabbitpubsub/amqp.go @@ -62,6 +62,7 @@ type amqpChannel interface { QueueDeclareAndBind(qname, ename string) error ExchangeDelete(string) error QueueDelete(qname string) error + Qos(prefetchCount, prefetchSize int, global bool) error } // connection adapts an *amqp.Connection to the amqpConnection interface. @@ -79,6 +80,7 @@ func (c *connection) Channel() (amqpChannel, error) { if err := ch.Confirm(wait); err != nil { return nil, err } + return &channel{ch}, nil } @@ -168,3 +170,7 @@ func (ch *channel) QueueDelete(qname string) error { _, err := ch.ch.QueueDelete(qname, false, false, false) return err } + +func (ch *channel) Qos(prefetchCount, prefetchSize int, global bool) error { + return ch.ch.Qos(prefetchCount, prefetchSize, global) +} diff --git a/pubsub/rabbitpubsub/fake_test.go b/pubsub/rabbitpubsub/fake_test.go index 9cd11f3b32..97c3badd61 100644 --- a/pubsub/rabbitpubsub/fake_test.go +++ b/pubsub/rabbitpubsub/fake_test.go @@ -389,6 +389,14 @@ func (ch *fakeChannel) QueueDelete(name string) error { return nil } +func (ch *fakeChannel) Qos(_, _ int, _ bool) error { + if ch.isClosed() { + return amqp.ErrClosed + } + + return nil +} + // Assumes nothing is ever written to the channel. func chanIsClosed(ch chan struct{}) bool { select { diff --git a/pubsub/rabbitpubsub/rabbit.go b/pubsub/rabbitpubsub/rabbit.go index 7bfed94554..3142f37c27 100644 --- a/pubsub/rabbitpubsub/rabbit.go +++ b/pubsub/rabbitpubsub/rabbit.go @@ -21,6 +21,7 @@ import ( "net/url" "os" "path" + "strconv" "strings" "sync" "sync/atomic" @@ -96,7 +97,9 @@ const Scheme = "rabbit" // // For subscriptions, the URL's host+path is used as the queue name. // -// No query parameters are supported. +// An optional query string can be used to set the Qos consumer prefetch on subscriptions +// like "rabbit://myqueue?prefetch_count=1000" to set the consumer prefetch count to 1000 +// see also https://www.rabbitmq.com/docs/consumer-prefetch type URLOpener struct { // Connection to use for communication with the server. Connection *amqp.Connection @@ -118,11 +121,27 @@ func (o *URLOpener) OpenTopicURL(ctx context.Context, u *url.URL) (*pubsub.Topic // OpenSubscriptionURL opens a pubsub.Subscription based on u. func (o *URLOpener) OpenSubscriptionURL(ctx context.Context, u *url.URL) (*pubsub.Subscription, error) { - for param := range u.Query() { - return nil, fmt.Errorf("open subscription %v: invalid query parameter %q", u, param) + opts := o.SubscriptionOptions + for param, value := range u.Query() { + switch param { + case "prefetch_count": + if len(value) != 1 || len(value[0]) == 0 { + return nil, fmt.Errorf("open subscription %v: invalid query parameter %q", u, param) + } + + prefetchCount, err := strconv.Atoi(value[0]) + if err != nil { + return nil, fmt.Errorf("open subscription %v: invalid query parameter %q: %w", u, param, err) + } + + opts.PrefetchCount = &prefetchCount + default: + return nil, fmt.Errorf("open subscription %v: invalid query parameter %q", u, param) + } } + queueName := path.Join(u.Host, u.Path) - return OpenSubscription(o.Connection, queueName, &o.SubscriptionOptions), nil + return OpenSubscription(o.Connection, queueName, &opts), nil } type topic struct { @@ -142,7 +161,10 @@ type TopicOptions struct{} // SubscriptionOptions sets options for constructing a *pubsub.Subscription // backed by RabbitMQ. -type SubscriptionOptions struct{} +type SubscriptionOptions struct { + // Qos property prefetch count. Optional. + PrefetchCount *int +} // OpenTopic returns a *pubsub.Topic corresponding to the named exchange. // See the package documentation for an example. @@ -515,7 +537,7 @@ func (*topic) Close() error { return nil } // The documentation of the amqp package recommends using separate connections for // publishing and subscribing. func OpenSubscription(conn *amqp.Connection, name string, opts *SubscriptionOptions) *pubsub.Subscription { - return pubsub.NewSubscription(newSubscription(&connection{conn}, name), nil, nil) + return pubsub.NewSubscription(newSubscription(&connection{conn}, name, opts), nil, nil) } type subscription struct { @@ -523,6 +545,8 @@ type subscription struct { queue string // the AMQP queue name consumer string // the client-generated name for this particular subscriber + opts *SubscriptionOptions + mu sync.Mutex ch amqpChannel // AMQP channel used for all communication. delc <-chan amqp.Delivery @@ -533,11 +557,16 @@ type subscription struct { var nextConsumer int64 // atomic -func newSubscription(conn amqpConnection, name string) *subscription { +func newSubscription(conn amqpConnection, name string, opts *SubscriptionOptions) *subscription { + if opts == nil { + opts = &SubscriptionOptions{} + } + return &subscription{ conn: conn, queue: name, consumer: fmt.Sprintf("c%d", atomic.AddInt64(&nextConsumer, 1)), + opts: opts, receiveBatchHook: func() {}, } } @@ -564,6 +593,11 @@ func (s *subscription) establishChannel(ctx context.Context) error { if err != nil { return err } + // Apply subscription options to channel. + err = applyOptionsToChannel(s.opts, ch) + if err != nil { + return err + } // Subscribe to messages from the queue. s.delc, err = ch.Consume(s.queue, s.consumer) return err @@ -571,8 +605,22 @@ func (s *subscription) establishChannel(ctx context.Context) error { if err != nil { return err } + s.ch = ch s.closec = ch.NotifyClose(make(chan *amqp.Error, 1)) // closec will get at most one element + + return nil +} + +func applyOptionsToChannel(opts *SubscriptionOptions, ch amqpChannel) error { + if opts.PrefetchCount == nil { + return nil + } + + if err := ch.Qos(*opts.PrefetchCount, 0, false); err != nil { + return fmt.Errorf("unable to set channel Qos: %w", err) + } + return nil } diff --git a/pubsub/rabbitpubsub/rabbit_test.go b/pubsub/rabbitpubsub/rabbit_test.go index 8f246a843b..ec25be339a 100644 --- a/pubsub/rabbitpubsub/rabbit_test.go +++ b/pubsub/rabbitpubsub/rabbit_test.go @@ -43,6 +43,8 @@ const rabbitURL = "amqp://guest:guest@localhost:5672/" var logOnce sync.Once func mustDialRabbit(t testing.TB) amqpConnection { + t.Helper() + if !setup.HasDockerTestEnvironment() { logOnce.Do(func() { t.Log("using the fake because the RabbitMQ server is not available") @@ -61,6 +63,8 @@ func mustDialRabbit(t testing.TB) amqpConnection { func TestConformance(t *testing.T) { harnessMaker := func(_ context.Context, t *testing.T) (drivertest.Harness, error) { + t.Helper() + return &harness{conn: mustDialRabbit(t)}, nil } _, isFake := mustDialRabbit(t).(*fakeConnection) @@ -73,6 +77,8 @@ func TestConformance(t *testing.T) { } t.Logf("now running tests with the fake") harnessMaker = func(_ context.Context, t *testing.T) (drivertest.Harness, error) { + t.Helper() + return &harness{conn: newFakeConnection()}, nil } asTests = []drivertest.AsTest{rabbitAsTest{true}} @@ -138,12 +144,12 @@ func (h *harness) CreateSubscription(_ context.Context, dt driver.Topic, testNam } ch.QueueDelete(queue) } - ds = newSubscription(h.conn, queue) + ds = newSubscription(h.conn, queue, nil) return ds, cleanup, nil } func (h *harness) MakeNonexistentSubscription(_ context.Context) (driver.Subscription, func(), error) { - return newSubscription(h.conn, "nonexistent-subscription"), func() {}, nil + return newSubscription(h.conn, "nonexistent-subscription", nil), func() {}, nil } func (h *harness) Close() { @@ -379,62 +385,103 @@ func (rabbitAsTest) AfterSend(as func(interface{}) bool) error { return nil } -func fakeConnectionStringInEnv() func() { - oldEnvVal := os.Getenv("RABBIT_SERVER_URL") - os.Setenv("RABBIT_SERVER_URL", "amqp://localhost:10000/vhost") - return func() { - os.Setenv("RABBIT_SERVER_URL", oldEnvVal) - } -} - func TestOpenTopicFromURL(t *testing.T) { - cleanup := fakeConnectionStringInEnv() - defer cleanup() + t.Setenv("RABBIT_SERVER_URL", rabbitURL) tests := []struct { - URL string - WantErr bool + label string + URLTemplate string + WantErr bool }{ - // OK, but still error because Dial fails. - {"rabbit://myexchange", true}, - // Invalid parameter. - {"rabbit://myexchange?param=value", true}, + {"valid url", "rabbit://%s", false}, + {"invalid url with parameters", "rabbit://%s?param=value", true}, } - ctx := context.Background() for _, test := range tests { - topic, err := pubsub.OpenTopic(ctx, test.URL) - if (err != nil) != test.WantErr { - t.Errorf("%s: got error %v, want error %v", test.URL, err, test.WantErr) - } - if topic != nil { - topic.Shutdown(ctx) - } + t.Run(test.label, func(t *testing.T) { + conn := mustDialRabbit(t) + _, isFake := conn.(*fakeConnection) + if isFake { + t.Skip("test requires real rabbitmq") + } + + h := &harness{conn: conn} + + ctx := context.Background() + + dt, cleanupTopic, err := h.CreateTopic(ctx, t.Name()) + if err != nil { + t.Fatalf("unable to create topic: %v", err) + } + + t.Cleanup(cleanupTopic) + + exchange := dt.(*topic).exchange + url := fmt.Sprintf(test.URLTemplate, exchange) + + topic, err := pubsub.OpenTopic(ctx, url) + if (err != nil) != test.WantErr { + t.Errorf("%s: got error %v, want error %v", test.URLTemplate, err, test.WantErr) + } + if topic != nil { + topic.Shutdown(ctx) + } + }) } } func TestOpenSubscriptionFromURL(t *testing.T) { - cleanup := fakeConnectionStringInEnv() - defer cleanup() + t.Setenv("RABBIT_SERVER_URL", rabbitURL) tests := []struct { - URL string - WantErr bool + label string + URLTemplate string + WantErr bool }{ - // OK, but error because Dial fails. - {"rabbit://myqueue", true}, - // Invalid parameter. - {"rabbit://myqueue?param=value", true}, + + {"url with no QoS prefetch count", "rabbit://%s", false}, + {"invalid parameters", "rabbit://%s?param=value", true}, + {"valid url with QoS prefetch count", "rabbit://%s?prefetch_count=1024", false}, + {"invalid url with QoS prefetch count", "rabbit://%s?prefetch_count=value", true}, } - ctx := context.Background() for _, test := range tests { - sub, err := pubsub.OpenSubscription(ctx, test.URL) - if (err != nil) != test.WantErr { - t.Errorf("%s: got error %v, want error %v", test.URL, err, test.WantErr) - } - if sub != nil { - sub.Shutdown(ctx) - } + t.Run(test.label, func(t *testing.T) { + conn := mustDialRabbit(t) + _, isFake := conn.(*fakeConnection) + if isFake { + t.Skip("test requires real rabbitmq") + } + + h := &harness{conn: conn} + + ctx := context.Background() + + dt, cleanupTopic, err := h.CreateTopic(ctx, t.Name()) + if err != nil { + t.Fatalf("unable to create topic: %v", err) + } + + t.Cleanup(cleanupTopic) + + ds, cleanupSubscription, err := h.CreateSubscription(ctx, dt, t.Name()) + if err != nil { + t.Fatalf("unable to create subscription: %v", err) + } + + t.Cleanup(cleanupSubscription) + + queue := ds.(*subscription).queue + url := fmt.Sprintf(test.URLTemplate, queue) + + sub, err := pubsub.OpenSubscription(ctx, url) + if (err != nil) != test.WantErr { + t.Errorf("%s: got error %v, want error %v", test.URLTemplate, err, test.WantErr) + } + + if sub != nil { + sub.Shutdown(ctx) + } + }) } }