diff --git a/pkg/lib/collections/hashed_priority_queue.go b/pkg/lib/collections/hashed_priority_queue.go index 66f0b2d232..22ba6aeb86 100644 --- a/pkg/lib/collections/hashed_priority_queue.go +++ b/pkg/lib/collections/hashed_priority_queue.go @@ -1,14 +1,60 @@ package collections -import "sync" - +import ( + "sync" +) + +// HashedPriorityQueue is a priority queue that maintains only a single item per unique key. +// It combines the functionality of a hash map and a priority queue to provide efficient +// operations with the following key characteristics: +// +// 1. Single Item Per Key: The queue maintains only the latest version of an item for each +// unique key. When a new item with an existing key is enqueued, it replaces the old item +// instead of adding a duplicate. +// +// 2. Lazy Dequeuing: Outdated items (those that have been replaced by a newer version) are +// not immediately removed from the underlying queue. Instead, they are filtered out +// during dequeue operations. This approach improves enqueue performance +// at the cost of potentially slower dequeue operations. +// +// 3. Higher Enqueue Throughput: By avoiding immediate removal of outdated items, the +// HashedPriorityQueue achieves higher enqueue throughput. This makes it particularly +// suitable for scenarios with frequent updates to existing items. +// +// 4. Eventually Consistent: The queue becomes consistent over time as outdated items are +// lazily removed during dequeue operations. This means that the queue's length and the +// items it contains become accurate as items are dequeued. +// +// 5. Memory Consideration: Due to the lazy removal of outdated items, the underlying queue +// may temporarily hold more items than there are unique keys. This trade-off allows for +// better performance but may use more memory compared to a strictly consistent queue. +// +// Use HashedPriorityQueue when you need a priority queue that efficiently handles updates +// to existing items and can tolerate some latency in removing outdated entries in favor +// of higher enqueue performance. type HashedPriorityQueue[K comparable, T any] struct { - identifiers map[K]struct{} - queue *PriorityQueue[T] + identifiers map[K]int64 + queue *PriorityQueue[versionedItem[T]] mu sync.RWMutex indexer IndexerFunc[K, T] } +// versionedItem wraps the actual data item with a version number. +// This structure is used internally by HashedPriorityQueue to implement +// the versioning mechanism that allows for efficient updates and +// lazy removal of outdated items. The queue is only interested in +// the latest version of an item for each unique key: +// - data: The actual item of type T stored in the queue. +// - version: A monotonically increasing number representing the +// version of this item. When an item with the same key is enqueued, +// its version is incremented. This allows the queue to identify +// the most recent version during dequeue operations and discard +// any older versions of the same item. +type versionedItem[T any] struct { + data T + version int64 +} + // IndexerFunc is used to find the key (of type K) from the provided // item (T). This will be used for the item lookup in `Contains` type IndexerFunc[K comparable, T any] func(item T) K @@ -18,12 +64,24 @@ type IndexerFunc[K comparable, T any] func(item T) K // be used on Enqueue/Dequeue to keep the index up to date. func NewHashedPriorityQueue[K comparable, T any](indexer IndexerFunc[K, T]) *HashedPriorityQueue[K, T] { return &HashedPriorityQueue[K, T]{ - identifiers: make(map[K]struct{}), - queue: NewPriorityQueue[T](), + identifiers: make(map[K]int64), + queue: NewPriorityQueue[versionedItem[T]](), indexer: indexer, } } +// isLatestVersion checks if the given item is the latest version +func (q *HashedPriorityQueue[K, T]) isLatestVersion(item versionedItem[T]) bool { + k := q.indexer(item.data) + currentVersion := q.identifiers[k] + return item.version == currentVersion +} + +// unwrapQueueItem converts a versionedItem to a QueueItem +func (q *HashedPriorityQueue[K, T]) unwrapQueueItem(item *QueueItem[versionedItem[T]]) *QueueItem[T] { + return &QueueItem[T]{Value: item.Value.data, Priority: item.Priority} +} + // Contains will return true if the provided identifier (of type K) // will be found in this queue, false if it is not present. func (q *HashedPriorityQueue[K, T]) Contains(id K) bool { @@ -40,9 +98,9 @@ func (q *HashedPriorityQueue[K, T]) Enqueue(data T, priority int64) { defer q.mu.Unlock() k := q.indexer(data) - - q.identifiers[k] = struct{}{} - q.queue.Enqueue(data, priority) + version := q.identifiers[k] + 1 + q.identifiers[k] = version + q.queue.Enqueue(versionedItem[T]{data: data, version: version}, priority) } // Dequeue returns the next highest priority item, returning both @@ -53,16 +111,39 @@ func (q *HashedPriorityQueue[K, T]) Dequeue() *QueueItem[T] { q.mu.Lock() defer q.mu.Unlock() - item := q.queue.Dequeue() - if item == nil { - return nil + for { + item := q.queue.Dequeue() + if item == nil { + return nil + } + + if q.isLatestVersion(item.Value) { + k := q.indexer(item.Value.data) + delete(q.identifiers, k) + return q.unwrapQueueItem(item) + } } +} + +// Peek returns the next highest priority item without removing it from the queue. +// It returns nil if the queue is empty. +func (q *HashedPriorityQueue[K, T]) Peek() *QueueItem[T] { + q.mu.RLock() + defer q.mu.RUnlock() - // Find the key for the item and delete it from the presence map - k := q.indexer(item.Value) - delete(q.identifiers, k) + for { + item := q.queue.Peek() + if item == nil { + return nil + } - return item + if q.isLatestVersion(item.Value) { + return q.unwrapQueueItem(item) + } + + // If the peeked item is outdated, remove it and continue + q.queue.Dequeue() + } } // DequeueWhere allows the caller to iterate through the queue, in priority order, and @@ -74,26 +155,32 @@ func (q *HashedPriorityQueue[K, T]) DequeueWhere(matcher MatchingFunction[T]) *Q q.mu.Lock() defer q.mu.Unlock() - item := q.queue.DequeueWhere(matcher) - if item == nil { - return nil - } + for { + item := q.queue.DequeueWhere(func(vi versionedItem[T]) bool { + return matcher(vi.data) + }) - k := q.indexer(item.Value) - delete(q.identifiers, k) + if item == nil { + return nil + } - return item + if q.isLatestVersion(item.Value) { + k := q.indexer(item.Value.data) + delete(q.identifiers, k) + return q.unwrapQueueItem(item) + } + } } // Len returns the number of items currently in the queue func (q *HashedPriorityQueue[K, T]) Len() int { - return q.queue.Len() + return len(q.identifiers) } // IsEmpty returns a boolean denoting whether the queue is // currently empty or not. func (q *HashedPriorityQueue[K, T]) IsEmpty() bool { - return q.queue.Len() == 0 + return q.Len() == 0 } var _ PriorityQueueInterface[struct{}] = (*HashedPriorityQueue[string, struct{}])(nil) diff --git a/pkg/lib/collections/hashed_priority_queue_test.go b/pkg/lib/collections/hashed_priority_queue_test.go index 04340b939a..f5838fbe99 100644 --- a/pkg/lib/collections/hashed_priority_queue_test.go +++ b/pkg/lib/collections/hashed_priority_queue_test.go @@ -5,12 +5,21 @@ package collections_test import ( "testing" - "github.com/bacalhau-project/bacalhau/pkg/lib/collections" "github.com/stretchr/testify/suite" + + "github.com/bacalhau-project/bacalhau/pkg/lib/collections" ) type HashedPriorityQueueSuite struct { - suite.Suite + PriorityQueueTestSuite +} + +func (s *HashedPriorityQueueSuite) SetupTest() { + s.NewQueue = func() collections.PriorityQueueInterface[TestData] { + return collections.NewHashedPriorityQueue[string, TestData](func(t TestData) string { + return t.id + }) + } } func TestHashedPriorityQueueSuite(t *testing.T) { @@ -18,20 +27,120 @@ func TestHashedPriorityQueueSuite(t *testing.T) { } func (s *HashedPriorityQueueSuite) TestContains() { - type TestData struct { - id string - data int - } - - indexer := func(t TestData) string { - return t.id - } - - q := collections.NewHashedPriorityQueue[string, TestData](indexer) + q := s.NewQueue().(*collections.HashedPriorityQueue[string, TestData]) s.Require().False(q.Contains("A")) - q.Enqueue(TestData{id: "A", data: 0}, 1) + q.Enqueue(TestData{"A", 0}, 1) s.Require().True(q.Contains("A")) _ = q.Dequeue() s.Require().False(q.Contains("A")) } + +func (s *HashedPriorityQueueSuite) TestPeek() { + q := s.NewQueue().(*collections.HashedPriorityQueue[string, TestData]) + + q.Enqueue(TestData{"A", 1}, 1) + q.Enqueue(TestData{"B", 2}, 2) + + item := q.Peek() + s.Require().NotNil(item) + s.Require().Equal(TestData{"B", 2}, item.Value) + s.Require().True(q.Contains("A"), "Item A should still be in the queue after Peek") + s.Require().True(q.Contains("B"), "Item B should still be in the queue after Peek") + + _ = q.Dequeue() + s.Require().False(q.Contains("B"), "Item B should not be in the queue after Dequeue") + s.Require().True(q.Contains("A"), "Item A should still be in the queue after Dequeue") +} + +func (s *HashedPriorityQueueSuite) TestSingleItemPerKey() { + q := s.NewQueue().(*collections.HashedPriorityQueue[string, TestData]) + + q.Enqueue(TestData{"A", 1}, 1) + q.Enqueue(TestData{"A", 2}, 2) + q.Enqueue(TestData{"A", 3}, 3) + + s.Require().Equal(1, q.Len(), "Queue should only contain one item for key 'A'") + + item := q.Dequeue() + s.Require().NotNil(item) + s.Require().Equal(TestData{"A", 3}, item.Value, "Should return the latest version of item 'A'") + s.Require().Equal(int64(3), item.Priority, "Should have the priority of the latest enqueue") + + s.Require().Nil(q.Dequeue(), "Queue should be empty after dequeuing the single item") +} + +func (s *HashedPriorityQueueSuite) TestPeekReturnsLatestVersion() { + q := s.NewQueue().(*collections.HashedPriorityQueue[string, TestData]) + + q.Enqueue(TestData{"A", 1}, 1) + q.Enqueue(TestData{"B", 1}, 3) + q.Enqueue(TestData{"A", 2}, 2) + + item := q.Peek() + s.Require().NotNil(item) + s.Require().Equal(TestData{"B", 1}, item.Value, "Peek should return 'B' as it has the highest priority") + s.Require().Equal(int64(3), item.Priority) + + q.Enqueue(TestData{"B", 2}, 1) // Lower priority, but newer version + + item = q.Peek() + s.Require().NotNil(item) + s.Require().Equal(TestData{"A", 2}, item.Value, "Peek should now return 'A' as 'B' has lower priority") + s.Require().Equal(int64(2), item.Priority) +} + +func (s *HashedPriorityQueueSuite) TestDequeueWhereReturnsLatestVersion() { + q := s.NewQueue().(*collections.HashedPriorityQueue[string, TestData]) + + q.Enqueue(TestData{"A", 1}, 1) + q.Enqueue(TestData{"B", 1}, 2) + q.Enqueue(TestData{"A", 2}, 3) + + item := q.DequeueWhere(func(td TestData) bool { + return td.id == "A" + }) + + s.Require().NotNil(item) + s.Require().Equal(TestData{"A", 2}, item.Value, "DequeueWhere should return the latest version of 'A'") + s.Require().Equal(int64(3), item.Priority) + + s.Require().False(q.Contains("A"), "A should no longer be in the queue") + s.Require().True(q.Contains("B"), "B should still be in the queue") +} + +func (s *HashedPriorityQueueSuite) TestDuplicateKeys() { + inputs := []struct { + v TestData + p int64 + }{ + {TestData{"A", 1}, 3}, + {TestData{"B", 2}, 2}, + {TestData{"A", 3}, 1}, // Duplicate key with lower priority + {TestData{"C", 4}, 4}, + {TestData{"B", 5}, 5}, // Duplicate key with higher priority + } + + pq := s.NewQueue() + for _, tc := range inputs { + pq.Enqueue(tc.v, tc.p) + } + + expected := []struct { + v TestData + p int64 + }{ + {TestData{"B", 5}, 5}, + {TestData{"C", 4}, 4}, + {TestData{"A", 3}, 1}, + } + + for _, exp := range expected { + qitem := pq.Dequeue() + s.Require().NotNil(qitem) + s.Require().Equal(exp.v, qitem.Value) + s.Require().Equal(exp.p, qitem.Priority) + } + + s.Require().True(pq.IsEmpty()) +} diff --git a/pkg/lib/collections/priority_queue.go b/pkg/lib/collections/priority_queue.go index 72129a414e..8448111d3e 100644 --- a/pkg/lib/collections/priority_queue.go +++ b/pkg/lib/collections/priority_queue.go @@ -35,6 +35,10 @@ type PriorityQueueInterface[T any] interface { // extra PriorityQueue) for the dequeued items. DequeueWhere(matcher MatchingFunction[T]) *QueueItem[T] + // Peek returns the next highest priority item without removing it from the queue. + // It returns nil if the queue is empty. + Peek() *QueueItem[T] + // Len returns the number of items currently in the queue Len() int @@ -117,6 +121,22 @@ func (pq *PriorityQueue[T]) dequeue() *QueueItem[T] { return &QueueItem[T]{Value: item, Priority: heapItem.priority} } +// Peek returns the next highest priority item without removing it from the queue. +// It returns nil if the queue is empty. +func (pq *PriorityQueue[T]) Peek() *QueueItem[T] { + pq.mu.Lock() + defer pq.mu.Unlock() + + if pq.IsEmpty() { + return nil + } + + heapItem := pq.internalQueue[0] + item, _ := heapItem.value.(T) + + return &QueueItem[T]{Value: item, Priority: heapItem.priority} +} + // DequeueWhere allows the caller to iterate through the queue, in priority order, and // attempt to match an item using the provided `MatchingFunction`. This method has a high // time cost as dequeued but non-matching items must be held and requeued once the process diff --git a/pkg/lib/collections/priority_queue_base_test.go b/pkg/lib/collections/priority_queue_base_test.go new file mode 100644 index 0000000000..0a5db6f5cd --- /dev/null +++ b/pkg/lib/collections/priority_queue_base_test.go @@ -0,0 +1,153 @@ +//go:build unit || !integration + +package collections_test + +import ( + "github.com/stretchr/testify/suite" + + "github.com/bacalhau-project/bacalhau/pkg/lib/collections" +) + +type TestData struct { + id string + data int +} + +type PriorityQueueTestSuite struct { + suite.Suite + NewQueue func() collections.PriorityQueueInterface[TestData] +} + +func (s *PriorityQueueTestSuite) TestSimple() { + type testcase struct { + v TestData + p int64 + } + inputs := []testcase{ + {TestData{"B", 2}, 2}, {TestData{"A", 1}, 3}, {TestData{"C", 3}, 1}, + } + expected := []testcase{ + {TestData{"A", 1}, 3}, {TestData{"B", 2}, 2}, {TestData{"C", 3}, 1}, + } + + pq := s.NewQueue() + for _, tc := range inputs { + pq.Enqueue(tc.v, tc.p) + } + + for _, tc := range expected { + qitem := pq.Dequeue() + s.Require().NotNil(qitem) + s.Require().Equal(tc.v, qitem.Value) + s.Require().Equal(tc.p, qitem.Priority) + } + + s.Require().True(pq.IsEmpty()) +} + +func (s *PriorityQueueTestSuite) TestSimpleMin() { + type testcase struct { + v TestData + p int64 + } + inputs := []testcase{ + {TestData{"B", 2}, -2}, {TestData{"A", 1}, -3}, {TestData{"C", 3}, -1}, + } + expected := []testcase{ + {TestData{"C", 3}, -1}, {TestData{"B", 2}, -2}, {TestData{"A", 1}, -3}, + } + + pq := s.NewQueue() + for _, tc := range inputs { + pq.Enqueue(tc.v, tc.p) + } + + for _, tc := range expected { + qitem := pq.Dequeue() + s.Require().NotNil(qitem) + s.Require().Equal(tc.v, qitem.Value) + s.Require().Equal(tc.p, qitem.Priority) + } + + s.Require().True(pq.IsEmpty()) +} + +func (s *PriorityQueueTestSuite) TestEmpty() { + pq := s.NewQueue() + qitem := pq.Dequeue() + s.Require().Nil(qitem) + s.Require().True(pq.IsEmpty()) +} + +func (s *PriorityQueueTestSuite) TestDequeueWhere() { + pq := s.NewQueue() + pq.Enqueue(TestData{"A", 1}, 4) + pq.Enqueue(TestData{"D", 4}, 1) + pq.Enqueue(TestData{"D", 4}, 1) + pq.Enqueue(TestData{"D", 4}, 1) + pq.Enqueue(TestData{"D", 4}, 1) + pq.Enqueue(TestData{"B", 2}, 3) + pq.Enqueue(TestData{"C", 3}, 2) + + count := pq.Len() + + qitem := pq.DequeueWhere(func(possibleMatch TestData) bool { + return possibleMatch.id == "B" + }) + + s.Require().NotNil(qitem) + s.Require().Equal(TestData{"B", 2}, qitem.Value) + s.Require().Equal(int64(3), qitem.Priority) + s.Require().Equal(count-1, pq.Len()) +} + +func (s *PriorityQueueTestSuite) TestDequeueWhereFail() { + pq := s.NewQueue() + pq.Enqueue(TestData{"A", 1}, 4) + + qitem := pq.DequeueWhere(func(possibleMatch TestData) bool { + return possibleMatch.id == "Z" + }) + + s.Require().Nil(qitem) +} + +func (s *PriorityQueueTestSuite) TestPeek() { + pq := s.NewQueue() + + // Test 1: Peek on an empty queue + item := pq.Peek() + s.Require().Nil(item, "Peek on an empty queue should return nil") + + // Test 2: Peek after adding one item + pq.Enqueue(TestData{"A", 1}, 1) + item = pq.Peek() + s.Require().NotNil(item, "Peek should return an item") + s.Require().Equal(TestData{"A", 1}, item.Value, "Peek should return the correct value") + s.Require().Equal(int64(1), item.Priority, "Peek should return the correct priority") + s.Require().Equal(1, pq.Len(), "Peek should not remove the item from the queue") + + // Test 3: Peek with multiple items + pq.Enqueue(TestData{"B", 2}, 3) + pq.Enqueue(TestData{"C", 3}, 2) + item = pq.Peek() + s.Require().NotNil(item, "Peek should return an item") + s.Require().Equal(TestData{"B", 2}, item.Value, "Peek should return the highest priority item") + s.Require().Equal(int64(3), item.Priority, "Peek should return the correct priority") + s.Require().Equal(3, pq.Len(), "Peek should not remove any items from the queue") + + // Test 4: Peek after dequeue + dequeuedItem := pq.Dequeue() + s.Require().Equal(TestData{"B", 2}, dequeuedItem.Value, "Dequeue should return the highest priority item") + item = pq.Peek() + s.Require().NotNil(item, "Peek should return an item") + s.Require().Equal(TestData{"C", 3}, item.Value, "Peek should return the new highest priority item after dequeue") + s.Require().Equal(int64(2), item.Priority, "Peek should return the correct priority") + s.Require().Equal(2, pq.Len(), "Queue length should be reduced after dequeue") + + // Test 5: Multiple peeks should return the same item + item1 := pq.Peek() + item2 := pq.Peek() + s.Require().Equal(item1, item2, "Multiple peeks should return the same item") + s.Require().Equal(2, pq.Len(), "Multiple peeks should not change the queue length") +} diff --git a/pkg/lib/collections/priority_queue_test.go b/pkg/lib/collections/priority_queue_test.go index 9a2d115ab8..ec3adeb23d 100644 --- a/pkg/lib/collections/priority_queue_test.go +++ b/pkg/lib/collections/priority_queue_test.go @@ -5,109 +5,59 @@ package collections_test import ( "testing" - "github.com/bacalhau-project/bacalhau/pkg/lib/collections" "github.com/stretchr/testify/suite" + + "github.com/bacalhau-project/bacalhau/pkg/lib/collections" ) type PriorityQueueSuite struct { - suite.Suite + PriorityQueueTestSuite +} + +func (s *PriorityQueueSuite) SetupTest() { + s.NewQueue = func() collections.PriorityQueueInterface[TestData] { + return collections.NewPriorityQueue[TestData]() + } } func TestPriorityQueueSuite(t *testing.T) { suite.Run(t, new(PriorityQueueSuite)) } -func (s *PriorityQueueSuite) TestSimple() { - type testcase struct { - v string +func (s *PriorityQueueSuite) TestDuplicateKeys() { + inputs := []struct { + v TestData p int64 - } - inputs := []testcase{ - {"B", 2}, {"A", 3}, {"C", 1}, {"A", 3}, {"C", 1}, {"B", 2}, - } - expected := []testcase{ - {"A", 3}, {"A", 3}, {"B", 2}, {"B", 2}, {"C", 1}, {"C", 1}, + }{ + {TestData{"A", 1}, 3}, + {TestData{"B", 2}, 2}, + {TestData{"A", 3}, 1}, // Duplicate key with lower priority + {TestData{"C", 4}, 4}, + {TestData{"B", 5}, 5}, // Duplicate key with higher priority } - pq := collections.NewPriorityQueue[string]() + pq := s.NewQueue() for _, tc := range inputs { - pq.Enqueue(tc.v, int64(tc.p)) + pq.Enqueue(tc.v, tc.p) } - for _, tc := range expected { - qitem := pq.Dequeue() - s.Require().NotNil(qitem) - s.Require().Equal(tc.v, qitem.Value) - s.Require().Equal(tc.p, qitem.Priority) - } - - s.Require().True(pq.IsEmpty()) -} - -func (s *PriorityQueueSuite) TestSimpleMin() { - type testcase struct { - v string + expected := []struct { + v TestData p int64 - } - inputs := []testcase{ - {"B", -2}, {"A", -3}, {"C", -1}, {"A", -3}, {"C", -1}, {"B", -2}, - } - expected := []testcase{ - {"C", -1}, {"C", -1}, {"B", -2}, {"B", -2}, {"A", -3}, {"A", -3}, + }{ + {TestData{"B", 5}, 5}, + {TestData{"C", 4}, 4}, + {TestData{"A", 1}, 3}, + {TestData{"B", 2}, 2}, + {TestData{"A", 3}, 1}, } - pq := collections.NewPriorityQueue[string]() - for _, tc := range inputs { - pq.Enqueue(tc.v, int64(tc.p)) - } - - for _, tc := range expected { + for _, exp := range expected { qitem := pq.Dequeue() s.Require().NotNil(qitem) - s.Require().Equal(tc.v, qitem.Value) - s.Require().Equal(tc.p, qitem.Priority) + s.Require().Equal(exp.v, qitem.Value) + s.Require().Equal(exp.p, qitem.Priority) } s.Require().True(pq.IsEmpty()) } - -func (s *PriorityQueueSuite) TestEmpty() { - pq := collections.NewPriorityQueue[string]() - qitem := pq.Dequeue() - s.Require().Nil(qitem) - s.Require().True(pq.IsEmpty()) -} - -func (s *PriorityQueueSuite) TestDequeueWhere() { - pq := collections.NewPriorityQueue[string]() - pq.Enqueue("A", 4) - pq.Enqueue("D", 1) - pq.Enqueue("D", 1) - pq.Enqueue("D", 1) - pq.Enqueue("D", 1) - pq.Enqueue("B", 3) - pq.Enqueue("C", 2) - - count := pq.Len() - - qitem := pq.DequeueWhere(func(possibleMatch string) bool { - return possibleMatch == "B" - }) - - s.Require().NotNil(qitem) - s.Require().Equal("B", qitem.Value) - s.Require().Equal(int64(3), qitem.Priority) - s.Require().Equal(count-1, pq.Len()) - -} - -func (s *PriorityQueueSuite) TestDequeueWhereFail() { - pq := collections.NewPriorityQueue[string]() - pq.Enqueue("A", 4) - - qitem := pq.DequeueWhere(func(possibleMatch string) bool { - return possibleMatch == "Z" - }) - - s.Require().Nil(qitem) -} diff --git a/pkg/node/heartbeat/heartbeat_test.go b/pkg/node/heartbeat/heartbeat_test.go index e9b5743620..a1a814cf66 100644 --- a/pkg/node/heartbeat/heartbeat_test.go +++ b/pkg/node/heartbeat/heartbeat_test.go @@ -6,12 +6,14 @@ import ( "context" "fmt" "strconv" + "sync" "testing" "time" "github.com/benbjohnson/clock" "github.com/nats-io/nats-server/v2/server" "github.com/nats-io/nats.go" + "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" "github.com/bacalhau-project/bacalhau/pkg/lib/ncl" @@ -226,3 +228,122 @@ func (s *HeartbeatTestSuite) TestSendHeartbeatError() { err = client.SendHeartbeat(ctx, 1) s.Error(err) } + +func (s *HeartbeatTestSuite) TestConcurrentHeartbeats() { + ctx := context.Background() + numNodes := 10 + numHeartbeatsPerNode := 100 + + var wg sync.WaitGroup + wg.Add(numNodes) + + for i := 0; i < numNodes; i++ { + go func(nodeID string) { + defer wg.Done() + client, err := NewClient(s.natsConn, nodeID, s.publisher) + require.NoError(s.T(), err) + + for j := 0; j < numHeartbeatsPerNode; j++ { + s.Require().NoError(client.SendHeartbeat(ctx, uint64(j))) + time.Sleep(time.Millisecond) // Small delay to simulate real-world scenario + } + }(fmt.Sprintf("node-%d", i)) + } + + wg.Wait() + + // Allow time for all heartbeats to be processed + time.Sleep(100 * time.Millisecond) + + // Verify that all nodes are marked as HEALTHY + for i := 0; i < numNodes; i++ { + nodeID := fmt.Sprintf("node-%d", i) + nodeState := &models.NodeState{Info: models.NodeInfo{NodeID: nodeID}} + s.heartbeatServer.UpdateNodeInfo(nodeState) + s.Require().Equal(models.NodeStates.HEALTHY, nodeState.Connection) + } +} + +func (s *HeartbeatTestSuite) TestConcurrentHeartbeatsWithDisconnection() { + ctx := context.Background() + numNodes := 5 + numHeartbeatsPerNode := 50 + + var wg sync.WaitGroup + wg.Add(numNodes) + + for i := 0; i < numNodes; i++ { + go func(nodeID string) { + defer wg.Done() + client, err := NewClient(s.natsConn, nodeID, s.publisher) + require.NoError(s.T(), err) + + for j := 0; j < numHeartbeatsPerNode; j++ { + s.Require().NoError(client.SendHeartbeat(ctx, uint64(j))) + time.Sleep(time.Millisecond) + + if j == numHeartbeatsPerNode/2 { + // Simulate a disconnection by advancing the clock + s.clock.Add(10 * time.Second) + } + } + }(fmt.Sprintf("node-%d", i)) + } + + wg.Wait() + + // Allow time for all heartbeats to be processed + time.Sleep(100 * time.Millisecond) + + // Verify node states + for i := 0; i < numNodes; i++ { + nodeID := fmt.Sprintf("node-%d", i) + nodeState := &models.NodeState{Info: models.NodeInfo{NodeID: nodeID}} + s.heartbeatServer.UpdateNodeInfo(nodeState) + + // The exact state might vary depending on timing, but it should be either HEALTHY or DISCONNECTED + s.Require().Contains([]models.NodeConnectionState{models.NodeStates.HEALTHY, models.NodeStates.DISCONNECTED}, nodeState.Connection) + } +} + +func (s *HeartbeatTestSuite) TestConcurrentHeartbeatsAndChecks() { + ctx := context.Background() + numNodes := 5 + numHeartbeatsPerNode := 30 + checkInterval := 50 * time.Millisecond + + var wg sync.WaitGroup + wg.Add(numNodes + 1) // +1 for the checker goroutine + + // Start the checker goroutine + go func() { + defer wg.Done() + for i := 0; i < numHeartbeatsPerNode; i++ { + s.heartbeatServer.checkQueue(ctx) + time.Sleep(checkInterval) + } + }() + + for i := 0; i < numNodes; i++ { + go func(nodeID string) { + defer wg.Done() + client, err := NewClient(s.natsConn, nodeID, s.publisher) + require.NoError(s.T(), err) + + for j := 0; j < numHeartbeatsPerNode; j++ { + s.Require().NoError(client.SendHeartbeat(ctx, uint64(j))) + time.Sleep(checkInterval / 2) // Send heartbeats faster than checks + } + }(fmt.Sprintf("node-%d", i)) + } + + wg.Wait() + + // Verify final node states + for i := 0; i < numNodes; i++ { + nodeID := fmt.Sprintf("node-%d", i) + nodeState := &models.NodeState{Info: models.NodeInfo{NodeID: nodeID}} + s.heartbeatServer.UpdateNodeInfo(nodeState) + s.Require().Equal(models.NodeStates.HEALTHY, nodeState.Connection) + } +} diff --git a/pkg/node/heartbeat/server.go b/pkg/node/heartbeat/server.go index 5d77751ff3..ef3d901f69 100644 --- a/pkg/node/heartbeat/server.go +++ b/pkg/node/heartbeat/server.go @@ -113,7 +113,7 @@ func (h *HeartbeatServer) Start(ctx context.Context) error { case <-ctx.Done(): return case <-ticker.C: - h.CheckQueue(ctx) + h.checkQueue(ctx) } } }(ctx) @@ -125,34 +125,49 @@ func (h *HeartbeatServer) Start(ctx context.Context) error { return nil } -// CheckQueue will check the queue for old heartbeats that might make a node's +// checkQueue will check the queue for old heartbeats that might make a node's // liveness either unhealthy or unknown, and will update the node's status accordingly. -func (h *HeartbeatServer) CheckQueue(ctx context.Context) { - // These are the timestamps, below which we'll consider the item in one of those two - // states - nowStamp := h.clock.Now().UTC().Unix() - disconnectedUnder := nowStamp - int64(h.disconnectedAfter.Seconds()) +// This method is not thread-safe and should be called from a single goroutine. +func (h *HeartbeatServer) checkQueue(ctx context.Context) { + // Calculate the timestamp threshold for considering a node as disconnected + disconnectedUnder := h.clock.Now().Add(-h.disconnectedAfter).UTC().Unix() for { - // Dequeue anything older than the unknown timestamp - item := h.pqueue.DequeueWhere(func(item TimestampedHeartbeat) bool { - return item.Timestamp < disconnectedUnder - }) - - // We haven't found anything old enough yet. We can stop the loop and wait - // for the next cycle. - if item == nil { + // Peek at the next (oldest) item in the queue + peek := h.pqueue.Peek() + + // If the queue is empty, we're done + if peek == nil { + break + } + + // If the oldest item is recent enough, we're done + log.Ctx(ctx).Trace(). + Dur("LastHeartbeatAge", h.clock.Now().Sub(time.Unix(peek.Value.Timestamp, 0))). + Msgf("Peeked at %+v", peek) + if peek.Value.Timestamp >= disconnectedUnder { break } + // Dequeue the item and mark the node as disconnected + item := h.pqueue.Dequeue() + if item == nil || item.Value.Timestamp >= disconnectedUnder { + // This should never happen, but we'll check just in case + log.Warn().Msgf("Unexpected item dequeued: %+v didn't match previously peeked item: %+v", item, peek) + continue + } + if item.Value.NodeID == h.nodeID { // We don't want to mark ourselves as disconnected continue } - if item.Value.Timestamp < disconnectedUnder { - h.markNodeAs(item.Value.NodeID, models.NodeStates.DISCONNECTED) - } + log.Ctx(ctx).Debug(). + Str("NodeID", item.Value.NodeID). + Int64("LastHeartbeat", item.Value.Timestamp). + Dur("LastHeartbeatAge", h.clock.Now().Sub(time.Unix(item.Value.Timestamp, 0))). + Msg("Marking node as disconnected") + h.markNodeAs(item.Value.NodeID, models.NodeStates.DISCONNECTED) } } @@ -200,36 +215,14 @@ func (h *HeartbeatServer) ShouldProcess(ctx context.Context, message *ncl.Messag // Handle will handle a message received through the legacy heartbeat topic func (h *HeartbeatServer) Handle(ctx context.Context, heartbeat Heartbeat) error { - log.Ctx(ctx).Trace().Msgf("heartbeat received from %s", heartbeat.NodeID) - timestamp := h.clock.Now().UTC().Unix() + th := TimestampedHeartbeat{Heartbeat: heartbeat, Timestamp: timestamp} + log.Ctx(ctx).Trace().Msgf("Enqueueing heartbeat from %s with seq %d. %+v", th.NodeID, th.Sequence, th) - if h.pqueue.Contains(heartbeat.NodeID) { - // If we think we already have a heartbeat from this node, we'll update the - // timestamp of the entry so it is re-prioritized in the queue by dequeuing - // and re-enqueuing it (this will ensure it is heapified correctly). - result := h.pqueue.DequeueWhere(func(item TimestampedHeartbeat) bool { - return item.NodeID == heartbeat.NodeID - }) - - if result == nil { - log.Ctx(ctx).Warn().Msgf("consistency error in heartbeat heap, node %s not found", heartbeat.NodeID) - return nil - } - - log.Ctx(ctx).Trace().Msgf("Re-enqueueing heartbeat from %s", heartbeat.NodeID) - result.Value.Timestamp = timestamp - h.pqueue.Enqueue(result.Value, timestamp) - } else { - log.Ctx(ctx).Trace().Msgf("Enqueueing heartbeat from %s", heartbeat.NodeID) - - // We'll enqueue the heartbeat message with the current timestamp. The older - // the entry, the lower the timestamp (trending to 0) and the higher the priority. - h.pqueue.Enqueue(TimestampedHeartbeat{Heartbeat: heartbeat, Timestamp: timestamp}, timestamp) - } - + // We'll enqueue the heartbeat message with the current timestamp in reverse priority so that + // older heartbeats are dequeued first. + h.pqueue.Enqueue(th, -timestamp) h.markNodeAs(heartbeat.NodeID, models.NodeStates.HEALTHY) - return nil }