From 14b4c0f3c2ef4315c5c24dd6d5b93bd66cd3b06f Mon Sep 17 00:00:00 2001
From: Jussi Maki
Date: Tue, 28 Jan 2025 09:52:47 +0100
Subject: [PATCH] Add WatchSet for watching a dynamic set of channels

When writing a controller that needs to react to specific queries
changing, it's useful to be able to watch a dynamic set of channels.
Add the [WatchSet] utility for this. Example use:

  var names []string = ...
  ws := statedb.NewWatchSet()
  for {
    for _, name := range names {
      things, watch := myTable.ListWatch(txn, ThingsByName(name))
      ws.Add(watch)
      processThings(things)
    }
    // Wait for things that have a name in 'names' to change.
    if err := ws.Wait(ctx); err != nil {
      // Context cancelled
      break
    }
  }

Signed-off-by: Jussi Maki
---
 watchset.go      | 103 +++++++++++++++++++++++++++++++++++++++++++++
 watchset_test.go |  99 ++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 202 insertions(+)
 create mode 100644 watchset.go
 create mode 100644 watchset_test.go

diff --git a/watchset.go b/watchset.go
new file mode 100644
index 0000000..41d6d2c
--- /dev/null
+++ b/watchset.go
@@ -0,0 +1,103 @@
+package statedb
+
+import (
+	"context"
+	"slices"
+	"sync"
+)
+
+const watchSetChunkSize = 16
+
+// WatchSet is a set of watch channels that can be waited on.
+type WatchSet struct {
+	mu    sync.Mutex
+	chans []<-chan struct{}
+}
+
+// NewWatchSet returns a new empty watch set.
+func NewWatchSet() *WatchSet {
+	return &WatchSet{
+		chans: make([]<-chan struct{}, 0, watchSetChunkSize),
+	}
+}
+
+// Add channels to the watch set.
+func (ws *WatchSet) Add(chans ...<-chan struct{}) {
+	ws.mu.Lock()
+	ws.chans = append(ws.chans, chans...)
+	ws.mu.Unlock()
+}
+
+// Clear the channels from the watch set.
+func (ws *WatchSet) Clear() {
+	ws.mu.Lock()
+	clear(ws.chans)
+	ws.chans = ws.chans[:0]
+	ws.mu.Unlock()
+}
+
+// Wait for any channel in the watch set to close or for the context
+// to be cancelled. The watch set is cleared when this method returns.
+func (ws *WatchSet) Wait(ctx context.Context) error {
+	ws.mu.Lock()
+	defer func() {
+		// Clear the references so that the backing array holds no
+		// stale channels from this round; the nil-channel padding
+		// below relies on truncated slots staying zeroed.
+		clear(ws.chans)
+		ws.chans = ws.chans[:0]
+		ws.mu.Unlock()
+	}()
+
+	// No channels to watch? Just watch the context.
+	if len(ws.chans) == 0 {
+		<-ctx.Done()
+		return ctx.Err()
+	}
+
+	// Round the slice length up to a full chunk so that every chunk
+	// handed to watch16 has exactly chunkSize entries. The padding
+	// slots are nil channels, which block forever in a select.
+	chunkSize := watchSetChunkSize
+	roundedSize := (len(ws.chans) + chunkSize - 1) / chunkSize * chunkSize
+	ws.chans = slices.Grow(ws.chans, roundedSize-len(ws.chans))[:roundedSize]
+
+	if len(ws.chans) <= chunkSize {
+		watch16(ctx.Done(), ws.chans)
+		return ctx.Err()
+	}
+
+	// More than one chunk. Fork goroutines to watch each chunk. The
+	// first chunk that completes cancels the context and stops the
+	// other goroutines.
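+	// Fixed 16-way selects are used instead of one dynamic
+	// reflect.Select over all the channels, as reflect.Select is far
+	// slower for large channel sets.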
+	innerCtx, cancel := context.WithCancel(ctx)
+	defer cancel()
+
+	var wg sync.WaitGroup
+	for chunk := range slices.Chunk(ws.chans, chunkSize) {
+		wg.Add(1)
+		go func() {
+			defer cancel()
+			defer wg.Done()
+			watch16(innerCtx.Done(), chunk)
+		}()
+	}
+	wg.Wait()
+	return ctx.Err()
+}
+
+// watch16 blocks until the stop channel or any of the 16 watch
+// channels closes. Nil channels in chans block forever and are
+// effectively ignored.
+func watch16(stop <-chan struct{}, chans []<-chan struct{}) {
+	select {
+	case <-stop:
+	case <-chans[0]:
+	case <-chans[1]:
+	case <-chans[2]:
+	case <-chans[3]:
+	case <-chans[4]:
+	case <-chans[5]:
+	case <-chans[6]:
+	case <-chans[7]:
+	case <-chans[8]:
+	case <-chans[9]:
+	case <-chans[10]:
+	case <-chans[11]:
+	case <-chans[12]:
+	case <-chans[13]:
+	case <-chans[14]:
+	case <-chans[15]:
+	}
+}
diff --git a/watchset_test.go b/watchset_test.go
new file mode 100644
index 0000000..2f57993
--- /dev/null
+++ b/watchset_test.go
@@ -0,0 +1,99 @@
+package statedb
+
+import (
+	"context"
+	"testing"
+	"time"
+
+	"github.com/cilium/statedb/part"
+	"github.com/stretchr/testify/require"
+)
+
+func TestWatchSet(t *testing.T) {
+	t.Parallel()
+	// NOTE: TestMain calls goleak.VerifyTestMain so we know this test
+	// doesn't leak goroutines.
+
+	ws := NewWatchSet()
+
+	// Empty watch set, cancelled context.
+	ctx, cancel := context.WithCancel(context.Background())
+	go cancel()
+	err := ws.Wait(ctx)
+	require.ErrorIs(t, err, context.Canceled)
+
+	// A few channels, cancelled context.
+	ch1 := make(chan struct{})
+	ch2 := make(chan struct{})
+	ch3 := make(chan struct{})
+	ws.Add(ch1, ch2, ch3)
+	ctx, cancel = context.WithCancel(context.Background())
+	go cancel()
+	err = ws.Wait(ctx)
+	require.ErrorIs(t, err, context.Canceled)
+
+	// Many channels: close the i'th channel and check that Wait returns.
+	for _, numChans := range []int{0, 1, 8, 12, 16, 31, 32, 61, 64, 121} {
+		for i := range numChans {
+			var chans []chan struct{}
+			var rchans []<-chan struct{}
+			for range numChans {
+				ch := make(chan struct{})
+				chans = append(chans, ch)
+				rchans = append(rchans, ch)
+			}
+			ws.Add(rchans...)
+
+			close(chans[i])
+			ctx, cancel = context.WithCancel(context.Background())
+			err = ws.Wait(ctx)
+			require.NoError(t, err)
+			cancel()
+		}
+	}
+}
+
+func TestWatchSetInQueries(t *testing.T) {
+	t.Parallel()
+	db, table := newTestDBWithMetrics(t, &NopMetrics{}, tagsIndex)
+
+	ws := NewWatchSet()
+	txn := db.ReadTxn()
+	_, watchAll := table.AllWatch(txn)
+
+	// Should time out as the watch channel should not have closed yet.
+	ws.Add(watchAll)
+	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond)
+	err := ws.Wait(ctx)
+	require.ErrorIs(t, err, context.DeadlineExceeded)
+	cancel()
+
+	// Insert some objects.
+	wtxn := db.WriteTxn(table)
+	table.Insert(wtxn, testObject{ID: 1})
+	table.Insert(wtxn, testObject{ID: 2})
+	table.Insert(wtxn, testObject{ID: 3})
+	txn = wtxn.Commit()
+
+	// The 'watchAll' channel should now have closed and Wait() returns.
+	ws.Add(watchAll)
+	err = ws.Wait(context.Background())
+	require.NoError(t, err)
+
+	// Try watching specific objects for changes.
+	_, _, watch1, _ := table.GetWatch(txn, idIndex.Query(1))
+	_, _, watch2, _ := table.GetWatch(txn, idIndex.Query(2))
+	_, _, watch3, _ := table.GetWatch(txn, idIndex.Query(3))
+	ws.Add(watch3, watch2, watch1)
+	ctx, cancel = context.WithTimeout(context.Background(), 5*time.Millisecond)
+	err = ws.Wait(ctx)
+	require.ErrorIs(t, err, context.DeadlineExceeded)
+	cancel()
+
+	// Update object 1; its watch channel should close and Wait return.
+	wtxn = db.WriteTxn(table)
+	table.Insert(wtxn, testObject{ID: 1, Tags: part.NewSet("foo")})
+	wtxn.Commit()
+
+	ws.Add(watch3, watch2, watch1)
+	err = ws.Wait(context.Background())
+	require.NoError(t, err)
+}
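
As a complement to the ListWatch example in the commit message, here is a
minimal sketch of pairing WatchSet with per-object GetWatch queries, as the
tests above do. reconcileLoop and the reconcile callback are illustrative
names, not part of the patch; Table, Query and GetWatch are assumed to have
the signatures used in the tests:

  // reconcileLoop re-runs the given queries whenever any of their
  // results change, until ctx is cancelled.
  func reconcileLoop[Obj any](
  	ctx context.Context,
  	db *statedb.DB,
  	table statedb.Table[Obj],
  	queries []statedb.Query[Obj],
  	reconcile func(Obj),
  ) error {
  	ws := statedb.NewWatchSet()
  	for {
  		txn := db.ReadTxn()
  		for _, q := range queries {
  			obj, _, watch, found := table.GetWatch(txn, q)
  			ws.Add(watch)
  			if found {
  				reconcile(obj)
  			}
  		}
  		// Blocks until any watched object changes or ctx is
  		// cancelled; the set is cleared on return and refilled
  		// on the next round.
  		if err := ws.Wait(ctx); err != nil {
  			return err
  		}
  	}
  }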