Skip to content

Commit

Permalink
Fix the reconnection logic for parallel writes with HA Producer. (#335)
Browse files Browse the repository at this point in the history
* fix lock on concurrent writes while reconnecting
* refactor test
* example: update reliable client example
---------

Co-authored-by: Gabriele Santomaggio <[email protected]>
  • Loading branch information
hiimjako and Gsantomaggio authored Jul 22, 2024
1 parent 18547a0 commit cb4e0d2
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 34 deletions.
57 changes: 31 additions & 26 deletions examples/reliable/reliable_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,20 @@ import (
"bufio"
"errors"
"fmt"
"github.com/rabbitmq/rabbitmq-stream-go-client/pkg/amqp"
"github.com/rabbitmq/rabbitmq-stream-go-client/pkg/ha"
"github.com/rabbitmq/rabbitmq-stream-go-client/pkg/message"
"github.com/rabbitmq/rabbitmq-stream-go-client/pkg/stream"
"os"
"sync"
"sync/atomic"
"time"
)

// The ha producer and consumer provide a way to auto-reconnect in case of connection problems

import (
"github.com/rabbitmq/rabbitmq-stream-go-client/pkg/amqp"
"github.com/rabbitmq/rabbitmq-stream-go-client/pkg/ha"
"github.com/rabbitmq/rabbitmq-stream-go-client/pkg/logs"
"github.com/rabbitmq/rabbitmq-stream-go-client/pkg/message"
"github.com/rabbitmq/rabbitmq-stream-go-client/pkg/stream"
)

// The ha producer and consumer provide a way to auto-reconnect in case of connection problems

func CheckErr(err error) {
if err != nil {
fmt.Printf("%s ", err)
Expand All @@ -37,6 +35,7 @@ func main() {
// Tune the parameters to test the reliability
const messagesToSend = 5_000_000
const numberOfProducers = 2
const concurrentProducers = 2
const numberOfConsumers = 2
const sendDelay = 1 * time.Millisecond
const delayEachMessages = 200
Expand Down Expand Up @@ -81,8 +80,9 @@ func main() {
go func() {
for isRunning {
totalConfirmed := atomic.LoadInt32(&confirmed) + atomic.LoadInt32(&fail)
fmt.Printf("%s - ToSend: %d - nProducers: %d - nConsumers %d \n", time.Now().Format(time.RFC822),
messagesToSend*numberOfProducers, numberOfProducers, numberOfConsumers)
expectedMessages := messagesToSend * numberOfProducers * concurrentProducers
fmt.Printf("%s - ToSend: %d - nProducers: %d - concurrentProducers: %d - nConsumers %d \n", time.Now().Format(time.RFC822),
expectedMessages, numberOfProducers, concurrentProducers, numberOfConsumers)
fmt.Printf("Sent:%d - ReSent %d - Confirmed:%d - Not confirmed:%d - Fail+Confirmed :%d \n",
sent, atomic.LoadInt32(&reSent), atomic.LoadInt32(&confirmed), atomic.LoadInt32(&fail), totalConfirmed)
fmt.Printf("Total Consumed: %d - Per consumer: %d \n", atomic.LoadInt32(&consumed),
Expand Down Expand Up @@ -120,22 +120,27 @@ func main() {
CheckErr(err)
producers = append(producers, rProducer)
go func() {
for i := 0; i < messagesToSend; i++ {
msg := amqp.NewMessage([]byte("ha"))
mutex.Lock()
for _, confirmedMessage := range unConfirmedMessages {
err := rProducer.Send(confirmedMessage)
atomic.AddInt32(&reSent, 1)
CheckErr(err)
}
unConfirmedMessages = []message.StreamMessage{}
mutex.Unlock()
err := rProducer.Send(msg)
if i%delayEachMessages == 0 {
time.Sleep(sendDelay)
}
atomic.AddInt32(&sent, 1)
CheckErr(err)
for i := 0; i < concurrentProducers; i++ {
go func() {
for i := 0; i < messagesToSend; i++ {
msg := amqp.NewMessage([]byte("ha"))
mutex.Lock()
for _, confirmedMessage := range unConfirmedMessages {
err := rProducer.Send(confirmedMessage)
atomic.AddInt32(&reSent, 1)
CheckErr(err)
}
unConfirmedMessages = []message.StreamMessage{}
mutex.Unlock()
err := rProducer.Send(msg)
if i%delayEachMessages == 0 {
time.Sleep(sendDelay)
}
atomic.AddInt32(&sent, 1)
CheckErr(err)

}
}()
}
}()
}
Expand Down
17 changes: 9 additions & 8 deletions pkg/ha/ha_publisher.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,18 +42,17 @@ func (p *ReliableProducer) handleNotifyClose(channelClose stream.ChannelClose) {
p.setStatus(StatusClosed)
}

select {
case p.reconnectionSignal <- struct{}{}:
case <-time.After(2 * time.Second):
}
p.reconnectionSignal.L.Lock()
p.reconnectionSignal.Broadcast()
p.reconnectionSignal.L.Unlock()
}
}()
}

// ReliableProducer is a producer that can reconnect in case of connection problems
// the function handlePublishConfirm is mandatory
// in case of problems the messages have the message.Confirmed == false
// The function `send` is blocked during the reconnection
// The functions `Send` and `BatchSend` are blocked during the reconnection
type ReliableProducer struct {
env *stream.Environment
producer *stream.Producer
Expand All @@ -64,7 +63,7 @@ type ReliableProducer struct {
mutex *sync.Mutex
mutexStatus *sync.Mutex
status int
reconnectionSignal chan struct{}
reconnectionSignal *sync.Cond
}

type ConfirmMessageHandler func(messageConfirm []*stream.ConfirmationStatus)
Expand All @@ -81,7 +80,7 @@ func NewReliableProducer(env *stream.Environment, streamName string,
mutex: &sync.Mutex{},
mutexStatus: &sync.Mutex{},
confirmMessageHandler: confirmMessageHandler,
reconnectionSignal: make(chan struct{}),
reconnectionSignal: sync.NewCond(&sync.Mutex{}),
}
if confirmMessageHandler == nil {
return nil, fmt.Errorf("the confirmation message handler is mandatory")
Expand Down Expand Up @@ -121,7 +120,9 @@ func (p *ReliableProducer) isReadyToSend() error {

if p.GetStatus() == StatusReconnecting {
logs.LogDebug("[Reliable] %s is reconnecting. The send is blocked", p.getInfo())
<-p.reconnectionSignal
p.reconnectionSignal.L.Lock()
p.reconnectionSignal.Wait()
p.reconnectionSignal.L.Unlock()
logs.LogDebug("[Reliable] %s reconnected. The send is unlocked", p.getInfo())
}

Expand Down
63 changes: 63 additions & 0 deletions pkg/ha/ha_publisher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,69 @@ var _ = Describe("Reliable Producer", func() {
Expect(producer.Close()).NotTo(HaveOccurred())
})

// Checks that concurrent writers are all released after a reconnection:
// the producer's connection is dropped, two goroutines call BatchSend while
// the producer reports StatusReconnecting, and the spec completes only once
// both messages have been confirmed.
It("unblock all Reliable Producer sends while restarting with concurrent writes", func() {
	const expectedMessages = 2
	// signal is written by the confirmation handler when the confirmation
	// counter reaches expectedMessages.
	signal := make(chan struct{})
	var confirmed int32
	// A unique client-provided name lets the spec find this producer's
	// connection among the ones returned by the test helper.
	clientProvidedName := uuid.New().String()
	producer, err := NewReliableProducer(envForRProducer,
		streamForRProducer,
		NewProducerOptions().
			SetClientProvidedName(clientProvidedName),
		func(messageConfirm []*ConfirmationStatus) {
			for _, confirm := range messageConfirm {
				Expect(confirm.IsConfirmed()).To(BeTrue())
			}
			// Confirmations may arrive in one or more callbacks, so the
			// running total is kept atomically.
			if atomic.AddInt32(&confirmed, int32(len(messageConfirm))) == expectedMessages {
				signal <- struct{}{}
			}
		})
	Expect(err).NotTo(HaveOccurred())

	time.Sleep(1 * time.Second)
	connectionToDrop := ""
	// Poll until the connection carrying clientProvidedName is listed.
	Eventually(func() bool {
		connections, err := test_helper.Connections("15672")
		if err != nil {
			return false
		}
		for _, connection := range connections {
			if connection.ClientProperties.Connection_name == clientProvidedName {
				connectionToDrop = connection.Name
				return true
			}
		}
		return false
	}, time.Second*5).
		Should(BeTrue())

	Expect(connectionToDrop).NotTo(BeEmpty())

	// concurrent writes while reconnecting
	sendMsg := func() {
		// sendMsg only runs in its own goroutine (see the `go sendMsg()`
		// calls below); GinkgoRecover lets a failed assertion here be
		// reported as a spec failure instead of an unrecovered panic.
		defer GinkgoRecover()
		msg := amqp.NewMessage([]byte("ha"))
		batch := []message.StreamMessage{msg}
		err := producer.BatchSend(batch)
		Expect(err).NotTo(HaveOccurred())
	}

	// kill the connection
	errDrop := test_helper.DropConnection(connectionToDrop, "15672")
	Expect(errDrop).NotTo(HaveOccurred())

	// wait for the producer to be in reconnecting state
	Eventually(func() bool {
		return producer.GetStatus() == StatusReconnecting
	}, time.Second*5, time.Millisecond).
		Should(BeTrue())

	// Start both writers only after the reconnecting state was observed.
	go sendMsg()
	go sendMsg()

	// Blocks until the confirmation handler has counted expectedMessages.
	<-signal
	Expect(producer.Close()).NotTo(HaveOccurred())
})

It("Delete the stream should close the producer", func() {
producer, err := NewReliableProducer(envForRProducer,
streamForRProducer, NewProducerOptions(), func(messageConfirm []*ConfirmationStatus) {
Expand Down

0 comments on commit cb4e0d2

Please sign in to comment.