Add WatchSet for watching a dynamic set of channels #72

Merged (1 commit) on Jan 28, 2025
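For orientation, a rough usage sketch of the API this PR adds (not part of the diff; ch1 and ch2 are hypothetical placeholder channels, where in practice the channels come from statedb query results as in the tests below):

package main

import (
    "context"
    "time"

    "github.com/cilium/statedb"
)

func main() {
    // Hypothetical watch channels; real ones come from statedb queries.
    ch1 := make(chan struct{})
    ch2 := make(chan struct{})

    ws := statedb.NewWatchSet()
    ws.Add(ch1, ch2)

    ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
    defer cancel()

    // Wait returns nil as soon as any added channel closes, or ctx.Err() if
    // the context is cancelled or times out first. The set is cleared when
    // Wait returns, so channels must be re-added before the next Wait.
    err := ws.Wait(ctx)
    _ = err // context.DeadlineExceeded here, since neither channel closes
}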
103 changes: 103 additions & 0 deletions watchset.go
@@ -0,0 +1,103 @@
package statedb

import (
    "context"
    "slices"
    "sync"
)

// watchSetChunkSize is the number of channels a single goroutine watches in
// one select statement (see watch16).
const watchSetChunkSize = 16

// WatchSet is a set of watch channels that can be waited on.
type WatchSet struct {
    mu    sync.Mutex
    chans []<-chan struct{}
}

// NewWatchSet returns a new empty watch set.
func NewWatchSet() *WatchSet {
    return &WatchSet{
        chans: make([]<-chan struct{}, 0, watchSetChunkSize),
    }
}

// Add channels to the watch set.
func (ws *WatchSet) Add(chans ...<-chan struct{}) {
    ws.mu.Lock()
    ws.chans = append(ws.chans, chans...)
    ws.mu.Unlock()
}

// Clear removes all channels from the watch set.
func (ws *WatchSet) Clear() {
    ws.mu.Lock()
    ws.chans = ws.chans[:0]
    ws.mu.Unlock()
}

// Wait blocks until any channel in the watch set closes or the context is
// cancelled. The watch set is cleared when this method returns.
func (ws *WatchSet) Wait(ctx context.Context) error {
    ws.mu.Lock()
    defer func() {
        ws.chans = ws.chans[:0]
        ws.mu.Unlock()
    }()

    // No channels to watch? Just watch the context.
    if len(ws.chans) == 0 {
        <-ctx.Done()
        return ctx.Err()
    }

    // Pad the channel slice to a full chunk with nil channels, so that each
    // watcher can index a complete chunk. A nil channel never becomes ready
    // in a select, so the padding slots never fire.
    chunkSize := watchSetChunkSize
    numChans := len(ws.chans)
    roundedSize := numChans + (chunkSize - numChans%chunkSize)
    ws.chans = slices.Grow(ws.chans, roundedSize)[:roundedSize]
    // Explicitly clear the padding region: the backing array may still hold
    // channels from a previous Wait, which would cause spurious wake-ups.
    clear(ws.chans[numChans:])

    if len(ws.chans) <= chunkSize {
        watch16(ctx.Done(), ws.chans)
        return ctx.Err()
    }

    // More than one chunk. Fork goroutines to watch each chunk. The first
    // chunk that completes cancels the inner context and stops the other
    // goroutines.
    innerCtx, cancel := context.WithCancel(ctx)
    defer cancel()

    var wg sync.WaitGroup
    for chunk := range slices.Chunk(ws.chans, chunkSize) {
        wg.Add(1)
        go func() {
            defer cancel()
            defer wg.Done()
            chunk = slices.Clone(chunk)
            watch16(innerCtx.Done(), chunk)
        }()
    }
    wg.Wait()
    return ctx.Err()
}

// watch16 blocks until the stop channel or any of the 16 watch channels
// closes. Nil channels never become ready, so padding slots are harmless.
func watch16(stop <-chan struct{}, chans []<-chan struct{}) {
    select {
    case <-stop:
    case <-chans[0]:
    case <-chans[1]:
    case <-chans[2]:
    case <-chans[3]:
    case <-chans[4]:
    case <-chans[5]:
    case <-chans[6]:
    case <-chans[7]:
    case <-chans[8]:
    case <-chans[9]:
    case <-chans[10]:
    case <-chans[11]:
    case <-chans[12]:
    case <-chans[13]:
    case <-chans[14]:
    case <-chans[15]:
    }
}
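A standalone sketch (not part of the diff) of the two properties Wait relies on above: the channel slice is padded up to the next multiple of the chunk size, and the padding entries are nil channels, which never become ready in a select, so the extra cases in watch16 simply never fire:

package main

import "fmt"

func main() {
    const chunkSize = 16 // mirrors watchSetChunkSize above

    // The rounding used by Wait: pad up to a full chunk. Note that an exact
    // multiple still gains one extra (all-nil) chunk.
    for _, n := range []int{1, 15, 16, 17, 33} {
        rounded := n + (chunkSize - n%chunkSize)
        fmt.Printf("len=%d -> rounded=%d\n", n, rounded) // 16, 16, 32, 32, 48
    }

    // A nil channel blocks forever in a select, so the nil padding slots
    // handed to watch16 are harmless.
    var nilCh <-chan struct{}
    done := make(chan struct{})
    close(done)
    select {
    case <-nilCh:
        fmt.Println("never happens")
    case <-done:
        fmt.Println("the closed channel wins")
    }
}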
99 changes: 99 additions & 0 deletions watchset_test.go
@@ -0,0 +1,99 @@
package statedb

import (
    "context"
    "testing"
    "time"

    "github.com/cilium/statedb/part"
    "github.com/stretchr/testify/require"
)

func TestWatchSet(t *testing.T) {
    t.Parallel()
    // NOTE: TestMain calls goleak.VerifyTestMain so we know this test doesn't leak goroutines.

    ws := NewWatchSet()

    // Empty watch set, cancelled context.
    ctx, cancel := context.WithCancel(context.Background())
    go cancel()
    err := ws.Wait(ctx)
    require.ErrorIs(t, err, context.Canceled)

    // Few channels, cancelled context.
    ch1 := make(chan struct{})
    ch2 := make(chan struct{})
    ch3 := make(chan struct{})
    ws.Add(ch1, ch2, ch3)
    ctx, cancel = context.WithCancel(context.Background())
    go cancel()
    err = ws.Wait(ctx)
    require.ErrorIs(t, err, context.Canceled)

    // Many channels
    for _, numChans := range []int{0, 1, 8, 12, 16, 31, 32, 61, 64, 121} {
        for i := range numChans {
            var chans []chan struct{}
            var rchans []<-chan struct{}
            for range numChans {
                ch := make(chan struct{})
                chans = append(chans, ch)
                rchans = append(rchans, ch)
            }
            ws.Add(rchans...)

            close(chans[i])
            ctx, cancel = context.WithCancel(context.Background())
            err = ws.Wait(ctx)
            require.NoError(t, err)
            cancel()
        }
    }
}

func TestWatchSetInQueries(t *testing.T) {
    t.Parallel()
    db, table := newTestDBWithMetrics(t, &NopMetrics{}, tagsIndex)

    ws := NewWatchSet()
    txn := db.ReadTxn()
    _, watchAll := table.AllWatch(txn)

    // Should time out as the watches should not have closed yet.
    ws.Add(watchAll)
    ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond)
    err := ws.Wait(ctx)
    require.ErrorIs(t, err, context.DeadlineExceeded)
    cancel()

    // Insert some objects.
    wtxn := db.WriteTxn(table)
    table.Insert(wtxn, testObject{ID: 1})
    table.Insert(wtxn, testObject{ID: 2})
    table.Insert(wtxn, testObject{ID: 3})
    txn = wtxn.Commit()

    // The 'watchAll' channel should now have closed and Wait() returns.
    ws.Add(watchAll)
    err = ws.Wait(context.Background())
    require.NoError(t, err)

    // Try watching specific objects for changes.
    _, _, watch1, _ := table.GetWatch(txn, idIndex.Query(1))
    _, _, watch2, _ := table.GetWatch(txn, idIndex.Query(2))
    _, _, watch3, _ := table.GetWatch(txn, idIndex.Query(3))
    ws.Add(watch3, watch2, watch1)
    ctx, cancel = context.WithTimeout(context.Background(), 5*time.Millisecond)
    err = ws.Wait(ctx)
    require.ErrorIs(t, err, context.DeadlineExceeded)
    cancel()

    wtxn = db.WriteTxn(table)
    table.Insert(wtxn, testObject{ID: 1, Tags: part.NewSet("foo")})
    wtxn.Commit()

    ws.Add(watch3, watch2, watch1)
    err = ws.Wait(context.Background())
    require.NoError(t, err)
}
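Finally, a sketch of the intended watch-loop pattern suggested by the test above, written as if it lived next to the test (reusing its testObject and idIndex, which are package-internal test helpers, not exported API): re-query and re-add the watch channels on every iteration, because the set is cleared whenever Wait returns.

func watchLoop(ctx context.Context, db *DB, objects Table[testObject]) error {
    ws := NewWatchSet()
    for {
        txn := db.ReadTxn()

        // Watch the whole table plus one specific object.
        _, watchAll := objects.AllWatch(txn)
        _, _, watchOne, _ := objects.GetWatch(txn, idIndex.Query(1))
        ws.Add(watchAll, watchOne)

        // Blocks until any of the watched query results may have changed,
        // then loop around and re-query.
        if err := ws.Wait(ctx); err != nil {
            return err // context cancelled or deadline exceeded
        }
    }
}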