Skip to content

Commit

Permalink
Switch to DeferredConfirmations and implement Batch Publishing jxsl13#36
Browse files Browse the repository at this point in the history
  • Loading branch information
escb005 committed Apr 24, 2024
1 parent 28526e2 commit 2401e98
Show file tree
Hide file tree
Showing 7 changed files with 594 additions and 85 deletions.
18 changes: 18 additions & 0 deletions amqpx.go
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,18 @@ func (a *AMQPX) Publish(ctx context.Context, exchange string, routingKey string,
return a.pub.Publish(ctx, exchange, routingKey, msg)
}

// Publishes a batch of messages.
// Each messages can be published to a different exchange and routing key.
func (a *AMQPX) PublishBatch(ctx context.Context, msgs []pool.BatchPublishing) error {
a.mu.RLock()
defer a.mu.RUnlock()
if a.pub == nil {
panic("amqpx package was not started")
}

return a.pub.PublishBatch(ctx, msgs)
}

// Get is only supposed to be used for testing, do not use get for polling any broker queues.
func (a *AMQPX) Get(ctx context.Context, queue string, autoAck bool) (msg pool.Delivery, ok bool, err error) {
a.mu.RLock()
Expand Down Expand Up @@ -379,6 +391,12 @@ func Publish(ctx context.Context, exchange string, routingKey string, msg pool.P
return amqpx.Publish(ctx, exchange, routingKey, msg)
}

// Publishes a batch of messages.
// Each messages can be published to a different exchange and routing key.
func PublishBatch(ctx context.Context, msgs []pool.BatchPublishing) error {
return amqpx.PublishBatch(ctx, msgs)
}

// Get is only supposed to be used for testing, do not use get for polling any broker queues.
func Get(ctx context.Context, queue string, autoAck bool) (msg pool.Delivery, ok bool, err error) {
return amqpx.Get(ctx, queue, autoAck)
Expand Down
120 changes: 120 additions & 0 deletions pool/confirmation.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
package pool

import (
"context"
"fmt"

"github.com/rabbitmq/amqp091-go"
)

// Confirmation is a deferred confirmation of a single publish.
// We use DeferredConfirmation instead of Transactions due to the performance implications of transactions.
// See: https://www.rabbitmq.com/docs/confirms#publisher-confirms
type Confirmation struct {
dc *amqp091.DeferredConfirmation
s *Session
}

// BatchConfirmation is a deferred confirmation of a batch of publishes.
// We use DeferredConfirmation instead of Transactions due to the performance implications of transactions.
// See: https://www.rabbitmq.com/docs/confirms#publisher-confirms
type BatchConfirmation struct {
dc []*amqp091.DeferredConfirmation
s *Session
}

func (c *Confirmation) Session() *Session {
return c.s
}

func (c *Confirmation) DeferredConfirmation() *amqp091.DeferredConfirmation {
return c.dc
}

func (bc *BatchConfirmation) Session() *Session {
return bc.s
}

func (bc *BatchConfirmation) DeferredConfirmations() []*amqp091.DeferredConfirmation {
return bc.dc
}

// Wait blocks until the server confirms the publish, the channel/connection is closed, the context is cancelled, or an error occurs.
func (c *Confirmation) Wait(ctx context.Context) error {
select {
case <-c.dc.Done():
err := c.s.error()
if err != nil {
return fmt.Errorf("await confirm failed: confirms channel closed: %w", err)
}
if !c.dc.Acked() {
return fmt.Errorf("await confirm failed: %w", ErrNack)
}
return nil
case returned, ok := <-c.s.returned:
if !ok {
err := c.s.error()
if err != nil {
return fmt.Errorf("await confirm failed: returned channel closed: %w", err)
}
return fmt.Errorf("await confirm failed: %w", errReturnedClosed)
}
return fmt.Errorf("await confirm failed: %w: %s", ErrReturned, returned.ReplyText)
case blocking, ok := <-c.s.conn.BlockingFlowControl():
if !ok {
err := c.s.error()
if err != nil {
return fmt.Errorf("await confirm failed: blocking channel closed: %w", err)
}
return fmt.Errorf("await confirm failed: %w", errBlockingFlowControlClosed)
}
return fmt.Errorf("await confirm failed: %w: %s", ErrBlockingFlowControl, blocking.Reason)
case <-ctx.Done():
return fmt.Errorf("await confirm: failed context %w: %w", ErrClosed, ctx.Err())
case <-c.s.ctx.Done():
return fmt.Errorf("await confirm failed: session %w: %w", ErrClosed, c.s.ctx.Err())
}
}

// Wait blocks until the server confirms all publishes, the channel/connection is closed, the context is cancelled, or an error occurs.
func (bc *BatchConfirmation) Wait(ctx context.Context) error {
for {
select {
case returned, ok := <-bc.s.returned:
if !ok {
err := bc.s.error()
if err != nil {
return fmt.Errorf("await confirm failed: returned channel closed: %w", err)
}
return fmt.Errorf("await confirm failed: %w", errReturnedClosed)
}
return fmt.Errorf("await confirm failed: %w: %s", ErrReturned, returned.ReplyText)
case blocking, ok := <-bc.s.conn.BlockingFlowControl():
if !ok {
err := bc.s.error()
if err != nil {
return fmt.Errorf("await confirm failed: blocking channel closed: %w", err)
}
return fmt.Errorf("await confirm failed: %w", errBlockingFlowControlClosed)
}
return fmt.Errorf("await confirm failed: %w: %s", ErrBlockingFlowControl, blocking.Reason)
case <-ctx.Done():
return fmt.Errorf("await confirm: failed context %w: %w", ErrClosed, ctx.Err())
case <-bc.s.ctx.Done():
return fmt.Errorf("await confirm failed: session %w: %w", ErrClosed, bc.s.ctx.Err())
default:
err := bc.s.error()
if err != nil {
return fmt.Errorf("await confirm failed: confirms channel closed: %w", err)
}

for _, dc := range bc.dc {
if !dc.Acked() {
continue // not all acked yet, keep waiting
}
}
}

return nil
}
}
62 changes: 60 additions & 2 deletions pool/publisher.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,31 @@ func NewPublisher(p *Pool, options ...PublisherOption) *Publisher {
return pub
}

// Publishes a batch of messages.
// Each messages can be published to a different exchange and routing key.
func (p *Publisher) PublishBatch(ctx context.Context, msgs []BatchPublishing) error {
for {
err := p.publishBatch(ctx, msgs)
switch {
case err == nil:
return nil
case errors.Is(err, ErrNack):
return err
case errors.Is(err, ErrDeliveryTagMismatch):
return err
default:
if recoverable(err) {
for _, msg := range msgs {
p.warn(msg.Exchange, msg.RoutingKey, err, "publish failed due to recoverable error, retrying")
}
// retry
} else {
return err
}
}
}
}

// Publish a message to a specific exchange with a given routingKey.
// You may set exchange to "" and routingKey to your queue name in order to publish directly to a queue.
func (p *Publisher) Publish(ctx context.Context, exchange string, routingKey string, msg Publishing) error {
Expand All @@ -86,6 +111,39 @@ func (p *Publisher) Publish(ctx context.Context, exchange string, routingKey str
}
}

func (p *Publisher) publishBatch(ctx context.Context, msgs []BatchPublishing) (err error) {
defer func() {
if err != nil {
for _, msg := range msgs {
p.warn(msg.Exchange, msg.RoutingKey, err, "failed to publish message")
}
} else {
for _, msg := range msgs {
p.info(msg.Exchange, msg.RoutingKey, "published a message")
}
}
}()

s, err := p.pool.GetSession(ctx)
if err != nil {
return err
}
defer func() {
p.pool.ReturnSession(s, err)
}()

confirm, err := s.PublishBatch(ctx, msgs)
if err != nil {
return err
}

if !s.IsConfirmable() {
return nil
}

return confirm.Wait(ctx)
}

func (p *Publisher) publish(ctx context.Context, exchange string, routingKey string, msg Publishing) (err error) {
defer func() {
if err != nil {
Expand All @@ -103,7 +161,7 @@ func (p *Publisher) publish(ctx context.Context, exchange string, routingKey str
p.pool.ReturnSession(s, err)
}()

tag, err := s.Publish(ctx, exchange, routingKey, msg)
confirm, err := s.Publish(ctx, exchange, routingKey, msg)
if err != nil {
return err
}
Expand All @@ -112,7 +170,7 @@ func (p *Publisher) publish(ctx context.Context, exchange string, routingKey str
return nil
}

return s.AwaitConfirm(ctx, tag)
return confirm.Wait(ctx)
}

// Get is only supposed to be used for testing, do not use get for polling any broker queues.
Expand Down
80 changes: 80 additions & 0 deletions pool/publisher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,86 @@ func TestSinglePublisher(t *testing.T) {
wg.Wait()
}

func TestSingleBatchPublisher(t *testing.T) {
t.Parallel()

var (
_, connectURL, _ = testutils.NextConnectURL()
ctx = context.TODO()
log = logging.NewTestLogger(t)
nextConnName = testutils.ConnectionNameGenerator()
numMsgs = 5
)

hs, hsclose := NewSession(
t,
ctx,
testutils.HealthyConnectURL,
nextConnName(),
)
defer hsclose()

p, err := pool.New(
ctx,
connectURL,
1,
1,
pool.WithLogger(logging.NewTestLogger(t)),
pool.WithConfirms(true),
pool.WithConnectionRecoverCallback(func(name string, retry int, err error) {
log.Warnf("connection %s is broken, retry %d, error: %s", name, retry, err)
}),
)
if err != nil {
assert.NoError(t, err)
return
}
defer p.Close()

var (
nextExchangeName = testutils.ExchangeNameGenerator(hs.Name())
nextQueueName = testutils.QueueNameGenerator(hs.Name())
exchangeName = nextExchangeName()
queueName = nextQueueName()
)
cleanup := DeclareExchangeQueue(t, ctx, hs, exchangeName, queueName)
defer cleanup()

var (
nextConsumerName = testutils.ConsumerNameGenerator(queueName)
publisherMsgGen = testutils.MessageGenerator(queueName)
consumerMsgGen = testutils.MessageGenerator(queueName)
wg sync.WaitGroup
)

pub := pool.NewPublisher(p)
defer pub.Close()

// TODO: currently this test allows duplication of messages
ConsumeAsyncN(t, ctx, &wg, hs, queueName, nextConsumerName(), consumerMsgGen, numMsgs, true)

msgs := []pool.BatchPublishing{}
for i := 0; i < numMsgs; i++ {
msg := publisherMsgGen()
msgs = append(msgs, pool.BatchPublishing{
Exchange: exchangeName,
RoutingKey: "",
Publishing: pool.Publishing{
Mandatory: true,
ContentType: "text/plain",
Body: []byte(msg),
},
})
}
err = pub.PublishBatch(ctx, msgs)
if err != nil {
assert.NoError(t, err, "when publishing batch message")
return

}
wg.Wait()
}

/*
// TODO: out of memory rabbitmq tests are disabled until https://github.com/rabbitmq/amqp091-go/issues/253 is resolved
func TestPublishAwaitFlowControl(t *testing.T) {
Expand Down
Loading

0 comments on commit 2401e98

Please sign in to comment.