diff --git a/mq/producer.go b/mq/producer.go index 32555e3..ea13e5b 100644 --- a/mq/producer.go +++ b/mq/producer.go @@ -3,15 +3,23 @@ package mq import ( "context" "crypto/md5" + "errors" "fmt" + "math" + "time" + amqp "github.com/rabbitmq/amqp091-go" "go.opentelemetry.io/otel" + "golang.org/x/sync/singleflight" ) type Producer struct { Conn *amqp.Connection Channel *amqp.Channel + amqpURI string // AMQP URI for RabbitMQ reconnection + sf *singleflight.Group + appId string } @@ -21,7 +29,9 @@ func NewProducer(appId string, amqpURI string) (*Producer, error) { } p := &Producer{ - appId: appId, + appId: appId, + amqpURI: amqpURI, + sf: new(singleflight.Group), } var err error @@ -34,6 +44,53 @@ func NewProducer(appId string, amqpURI string) (*Producer, error) { return p, nil } +func (p *Producer) isConnected() bool { + return !p.Conn.IsClosed() && !p.Channel.IsClosed() +} + +func (p *Producer) connectFn() error { + if p.isConnected() { + return nil + } + + _, err, _ := p.sf.Do("reconnect", func() (interface{}, error) { + var lastErr error + for i := 0; i < 3; i++ { + if p.isConnected() { + return nil, nil + } + + if i > 0 { + time.Sleep(time.Second * time.Duration(math.Pow(2, float64(i-1)))) + } + + conn, channel, err := initConnection(p.amqpURI) + if err != nil { + lastErr = fmt.Errorf("reconnect attempt %d failed: %s", i+1, err) + continue + } + + oldConn := p.Conn + oldChannel := p.Channel + p.Conn = conn + p.Channel = channel + + if oldChannel != nil { + _ = oldChannel.Close() + } + if oldConn != nil { + _ = oldConn.Close() + } + + return nil, nil + } + + return nil, lastErr + }) + + return err +} + func (p *Producer) PublishNotice(ctx context.Context, data *NoticeTemplate, options ...string) error { if data == nil { @@ -107,6 +164,28 @@ func (p *Producer) publish(ctx context.Context, key string, msg []byte, opts map Headers: headers, }) + if err != nil && errors.Is(err, amqp.ErrClosed) { + if err = p.connectFn(); err != nil { + return err + } + + err = p.Channel.PublishWithContext( + ctx, + exchangeName, + key, + false, + false, + amqp.Publishing{ + ContentType: "application/json", + DeliveryMode: amqp.Persistent, + Body: msg, + AppId: p.appId, + UserId: opts[UserIdKey], + MessageId: fmt.Sprintf("%x", md5.Sum(msg)), + Headers: headers, + }) + } + return err }