diff --git a/pkg/cache/v3/simple.go b/pkg/cache/v3/simple.go index 6e5cebcfbd..c34eca3cff 100644 --- a/pkg/cache/v3/simple.go +++ b/pkg/cache/v3/simple.go @@ -233,54 +233,95 @@ func (cache *snapshotCache) SetSnapshot(ctx context.Context, node string, snapsh info.mu.Lock() defer info.mu.Unlock() - // responder callback for SOTW watches - respond := func(watch ResponseWatch, id int64) error { - version := snapshot.GetVersion(watch.Request.TypeUrl) - if version != watch.Request.VersionInfo { - cache.log.Debugf("respond open watch %d %s%v with new version %q", id, watch.Request.TypeUrl, watch.Request.ResourceNames, version) - resources := snapshot.GetResourcesAndTTL(watch.Request.TypeUrl) - err := cache.respond(ctx, watch.Request, watch.Response, resources, version, false) - if err != nil { - return err - } - // discard the watch - delete(info.watches, id) + // Respond to SOTW watches for the node. + if err := cache.respondSOTWWatches(ctx, info, snapshot); err != nil { + return err + } + + // Respond to delta watches for the node. + return cache.respondDeltaWatches(ctx, info, snapshot) + } + + return nil +} + +func (cache *snapshotCache) respondSOTWWatches(ctx context.Context, info *statusInfo, snapshot ResourceSnapshot) error { + // responder callback for SOTW watches + respond := func(watch ResponseWatch, id int64) error { + version := snapshot.GetVersion(watch.Request.TypeUrl) + if version != watch.Request.VersionInfo { + cache.log.Debugf("respond open watch %d %s%v with new version %q", id, watch.Request.TypeUrl, watch.Request.ResourceNames, version) + resources := snapshot.GetResourcesAndTTL(watch.Request.TypeUrl) + err := cache.respond(ctx, watch.Request, watch.Response, resources, version, false) + if err != nil { + return err } - return nil + // discard the watch + delete(info.watches, id) } + return nil + } - // If ADS is enabled we need to order response watches so we guarantee - // sending them in the correct order. Go's default implementation - // of maps are randomized order when ranged over. - if cache.ads { - info.orderResponseWatches() - for _, key := range info.orderedWatches { - err := respond(info.watches[key.ID], key.ID) - if err != nil { - return err - } + // If ADS is enabled we need to order response watches so we guarantee + // sending them in the correct order. Go's default implementation + // of maps are randomized order when ranged over. + if cache.ads { + info.orderResponseWatches() + for _, key := range info.orderedWatches { + err := respond(info.watches[key.ID], key.ID) + if err != nil { + return err } - } else { - for id, watch := range info.watches { - err := respond(watch, id) - if err != nil { - return err - } + } + } else { + for id, watch := range info.watches { + err := respond(watch, id) + if err != nil { + return err } } + } + + return nil +} + +func (cache *snapshotCache) respondDeltaWatches(ctx context.Context, info *statusInfo, snapshot ResourceSnapshot) error { + // We only calculate version hashes when using delta. We don't + // want to do this when using SOTW so we can avoid unnecessary + // computational cost if not using delta. + if len(info.deltaWatches) == 0 { + return nil + } + + err := snapshot.ConstructVersionMap() + if err != nil { + return err + } - // We only calculate version hashes when using delta. We don't - // want to do this when using SOTW so we can avoid unnecessary - // computational cost if not using delta. - if len(info.deltaWatches) > 0 { - err := snapshot.ConstructVersionMap() + // If ADS is enabled we need to order response delta watches so we guarantee + // sending them in the correct order. Go's default implementation + // of maps are randomized order when ranged over. + if cache.ads { + info.orderResponseDeltaWatches() + for _, key := range info.orderedDeltaWatches { + watch := info.deltaWatches[key.ID] + res, err := cache.respondDelta( + ctx, + snapshot, + watch.Request, + watch.Response, + watch.StreamState, + ) if err != nil { return err } + // If we detect a nil response here, that means there has been no state change + // so we don't want to respond or remove any existing resource watches + if res != nil { + delete(info.deltaWatches, key.ID) + } } - - // this won't run if there are no delta watches - // to process. + } else { for id, watch := range info.deltaWatches { res, err := cache.respondDelta( ctx, @@ -299,7 +340,6 @@ func (cache *snapshotCache) SetSnapshot(ctx context.Context, node string, snapsh } } } - return nil } diff --git a/pkg/cache/v3/status.go b/pkg/cache/v3/status.go index 1b3e8f490b..e50f85beff 100644 --- a/pkg/cache/v3/status.go +++ b/pkg/cache/v3/status.go @@ -70,7 +70,8 @@ type statusInfo struct { orderedWatches keys // deltaWatches are indexed channels for the delta response watches and the original requests - deltaWatches map[int64]DeltaResponseWatch + deltaWatches map[int64]DeltaResponseWatch + orderedDeltaWatches keys // the timestamp of the last watch request lastWatchRequestTime time.Time @@ -177,3 +178,22 @@ func (info *statusInfo) orderResponseWatches() { // This is only run when we enable ADS on the cache. sort.Sort(info.orderedWatches) } + +// orderResponseDeltaWatches will track a list of delta watch keys and order them if +// true is passed. +func (info *statusInfo) orderResponseDeltaWatches() { + info.orderedDeltaWatches = make(keys, len(info.deltaWatches)) + + var index int + for id, deltaWatch := range info.deltaWatches { + info.orderedDeltaWatches[index] = key{ + ID: id, + TypeURL: deltaWatch.Request.TypeUrl, + } + index++ + } + + // Sort our list which we can use in the SetSnapshot functions. + // This is only run when we enable ADS on the cache. + sort.Sort(info.orderedDeltaWatches) +} diff --git a/pkg/server/delta/v3/server.go b/pkg/server/delta/v3/server.go index 85935db1ca..74f13e3505 100644 --- a/pkg/server/delta/v3/server.go +++ b/pkg/server/delta/v3/server.go @@ -83,7 +83,7 @@ func (s *server) processDelta(str stream.DeltaStream, reqCh <-chan *discovery.De } }() - // Sends a response, returns the new stream nonce + // sends a response, returns the new stream nonce send := func(resp cache.DeltaResponse) (string, error) { if resp == nil { return "", errors.New("missing response") @@ -103,6 +103,44 @@ func (s *server) processDelta(str stream.DeltaStream, reqCh <-chan *discovery.De return response.Nonce, str.Send(response) } + // process a single delta response + process := func(resp cache.DeltaResponse) error { + typ := resp.GetDeltaRequest().GetTypeUrl() + if resp == deltaErrorResponse { + return status.Errorf(codes.Unavailable, typ+" watch failed") + } + + nonce, err := send(resp) + if err != nil { + return err + } + + watch := watches.deltaWatches[typ] + watch.nonce = nonce + + watch.state.SetResourceVersions(resp.GetNextVersionMap()) + watches.deltaWatches[typ] = watch + return nil + } + + // processAll purges the deltaMuxedResponses channel + processAll := func() error { + for { + select { + // We watch the multiplexed channel for incoming responses. + case resp, more := <-watches.deltaMuxedResponses: + if !more { + break + } + if err := process(resp); err != nil { + return err + } + default: + return nil + } + } + } + if s.callbacks != nil { if err := s.callbacks.OnDeltaStreamOpen(str.Context(), streamID, defaultTypeURL); err != nil { return err @@ -113,35 +151,31 @@ func (s *server) processDelta(str stream.DeltaStream, reqCh <-chan *discovery.De select { case <-s.ctx.Done(): return nil + // We watch the multiplexed channel for incoming responses. case resp, more := <-watches.deltaMuxedResponses: + // input stream ended or errored out if !more { break } - typ := resp.GetDeltaRequest().GetTypeUrl() - if resp == deltaErrorResponse { - return status.Errorf(codes.Unavailable, typ+" watch failed") - } - - nonce, err := send(resp) - if err != nil { + if err := process(resp); err != nil { return err } - - watch := watches.deltaWatches[typ] - watch.nonce = nonce - - watch.state.SetResourceVersions(resp.GetNextVersionMap()) - watches.deltaWatches[typ] = watch case req, more := <-reqCh: // input stream ended or errored out if !more { return nil } + if req == nil { return status.Errorf(codes.Unavailable, "empty request") } + // make sure all existing responses are processed prior to new requests to avoid deadlock + if err := processAll(); err != nil { + return err + } + if s.callbacks != nil { if err := s.callbacks.OnStreamDeltaRequest(streamID, req); err != nil { return err @@ -184,16 +218,8 @@ func (s *server) processDelta(str stream.DeltaStream, reqCh <-chan *discovery.De s.subscribe(req.GetResourceNamesSubscribe(), &watch.state) s.unsubscribe(req.GetResourceNamesUnsubscribe(), &watch.state) - watch.responses = make(chan cache.DeltaResponse, 1) - watch.cancel = s.cache.CreateDeltaWatch(req, watch.state, watch.responses) + watch.cancel = s.cache.CreateDeltaWatch(req, watch.state, watches.deltaMuxedResponses) watches.deltaWatches[typeURL] = watch - - go func() { - resp, more := <-watch.responses - if more { - watches.deltaMuxedResponses <- resp - } - }() } } } diff --git a/pkg/server/delta/v3/watches.go b/pkg/server/delta/v3/watches.go index c88548388a..63c4c2d38d 100644 --- a/pkg/server/delta/v3/watches.go +++ b/pkg/server/delta/v3/watches.go @@ -17,9 +17,13 @@ type watches struct { // newWatches creates and initializes watches. func newWatches() watches { // deltaMuxedResponses needs a buffer to release go-routines populating it + // + // because deltaMuxedResponses can be populated by an update from the cache + // and a request from the client, we need to create the channel with a buffer + // size of 2x the number of types to avoid deadlocks. return watches{ deltaWatches: make(map[string]watch, int(types.UnknownType)), - deltaMuxedResponses: make(chan cache.DeltaResponse, int(types.UnknownType)), + deltaMuxedResponses: make(chan cache.DeltaResponse, int(types.UnknownType)*2), } } @@ -28,13 +32,14 @@ func (w *watches) Cancel() { for _, watch := range w.deltaWatches { watch.Cancel() } + + close(w.deltaMuxedResponses) } // watch contains the necessary modifiables for receiving resource responses type watch struct { - responses chan cache.DeltaResponse - cancel func() - nonce string + cancel func() + nonce string state stream.StreamState } @@ -44,9 +49,4 @@ func (w *watch) Cancel() { if w.cancel != nil { w.cancel() } - if w.responses != nil { - // w.responses should never be used by a producer once cancel() has been closed, so we can safely close it here - // This is needed to release resources taken by goroutines watching this channel - close(w.responses) - } } diff --git a/pkg/server/delta/v3/watches_test.go b/pkg/server/delta/v3/watches_test.go index 104b979be3..cee0985ebd 100644 --- a/pkg/server/delta/v3/watches_test.go +++ b/pkg/server/delta/v3/watches_test.go @@ -5,8 +5,6 @@ import ( "testing" "github.com/stretchr/testify/assert" - - "github.com/envoyproxy/go-control-plane/pkg/cache/v3" ) func TestDeltaWatches(t *testing.T) { @@ -14,14 +12,11 @@ func TestDeltaWatches(t *testing.T) { watches := newWatches() cancelCount := 0 - var channels []chan cache.DeltaResponse // create a few watches, and ensure that the cancel function are called and the channels are closed for i := 0; i < 5; i++ { newWatch := watch{} if i%2 == 0 { newWatch.cancel = func() { cancelCount++ } - newWatch.responses = make(chan cache.DeltaResponse) - channels = append(channels, newWatch.responses) } watches.deltaWatches[strconv.Itoa(i)] = newWatch @@ -30,13 +25,5 @@ func TestDeltaWatches(t *testing.T) { watches.Cancel() assert.Equal(t, 3, cancelCount) - for _, channel := range channels { - select { - case _, ok := <-channel: - assert.False(t, ok, "a channel was not closed") - default: - assert.Fail(t, "a channel was not closed") - } - } }) } diff --git a/pkg/server/v3/delta_test.go b/pkg/server/v3/delta_test.go index f8429f2997..870b0a85fd 100644 --- a/pkg/server/v3/delta_test.go +++ b/pkg/server/v3/delta_test.go @@ -345,6 +345,8 @@ func TestDeltaAggregatedHandlers(t *testing.T) { resp.recv <- r } + // We create the server with the optional ordered ADS flag so we guarantee resource + // ordering over the stream. s := server.NewServer(context.Background(), config, server.CallbackFuncs{}) go func() { err := s.DeltaAggregatedResources(resp)