diff --git a/internal/broker/conn.go b/internal/broker/conn.go index dc2d5ee5..324f2d76 100644 --- a/internal/broker/conn.go +++ b/internal/broker/conn.go @@ -313,7 +313,8 @@ func (c *Conn) sendResponse(topic string, resp response, requestID uint16) { func (c *Conn) CanSubscribe(ssid message.Ssid, channel []byte) bool { c.Lock() defer c.Unlock() - return c.subs.Increment(ssid, channel) + + return c.subs.IncrementOnce(ssid, channel) } // CanUnsubscribe decrements the internal counters and checks if the cluster diff --git a/internal/message/sub.go b/internal/message/sub.go index 9cbb7848..b7e50627 100644 --- a/internal/message/sub.go +++ b/internal/message/sub.go @@ -249,6 +249,19 @@ func (s *Counters) Increment(ssid Ssid, channel []byte) (first bool) { return m.Counter == 1 } +// IncrementOnce increments the subscription counter. +func (s *Counters) IncrementOnce(ssid Ssid, channel []byte) (first bool) { + s.Lock() + defer s.Unlock() + + m := s.getOrCreate(ssid, channel) + first = m.Counter == 0 + if first { + m.Counter++ + } + return first +} + // Decrement decrements a subscription counter. func (s *Counters) Decrement(ssid Ssid) (last bool) { s.Lock() diff --git a/internal/message/sub_test.go b/internal/message/sub_test.go index 4b3eb6b6..66142e86 100644 --- a/internal/message/sub_test.go +++ b/internal/message/sub_test.go @@ -181,6 +181,24 @@ func TestSub_Increment(t *testing.T) { assert.True(t, isDecremented) } +func TestSub_IncrementOnce(t *testing.T) { + // Preparation. + counters := NewCounters() + ssid1 := make([]uint32, 1) + key1 := (Ssid(ssid1)).GetHashCode() + + counters.getOrCreate(ssid1, []byte("test")) + + // Test previously created counter. + isFirst := counters.IncrementOnce(ssid1, []byte("test")) + assert.True(t, isFirst) + assert.Equal(t, 1, counters.m[key1].Counter) + + isFirst = counters.IncrementOnce(ssid1, []byte("test")) + assert.False(t, isFirst) + assert.Equal(t, 1, counters.m[key1].Counter) +} + func TestCollisions(t *testing.T) { subs := newSubscribers() count := 100000