From decbb66be1e3d71fc673553775bb5c1c8f83b784 Mon Sep 17 00:00:00 2001 From: Raven Black Date: Fri, 20 Sep 2024 19:16:59 -0400 Subject: [PATCH] Refactor cache_filter to expect caches to post cb (#36184) Refactor cache_filter to expect caches to post cb Additional Description: This is a bit of an unintuitive change in that it moves some work from the common class to the plugin, meaning that work will be duplicated. However, there's a good reason for this - if the cache class needs to use a dispatcher to get back onto its own thread, having cache_filter also post means that actions end up being queued in the dispatcher twice. A different possible solution would be to have the cache_filter callbacks only post if the callback comes in on the wrong thread, but there's a wrinkle in that model too - if the callback executes immediately, on the same thread, as was the case with the simple_http_cache, it executes too soon, trying to resume a connection that hasn't yet stopped, which is an error. That, too, could be covered with *another* workaround, either intercepting when that happens and posting the resume, or intercepting when that happens and replacing the resume with returning `Continue` instead, but both of those options make the cache filter itself more complicated (and therefore error prone). Having just one consistent path, where the cache implementation always posts the callback (and never calls it if cancelled), and the cache always performs the callback outside of the initial call's context and on its own thread, is the least complexity, and avoids the performance impact of posting twice, at a cost of a bit more verbosity in the simple cache implementation. This PR also wraps the UpdateHeadersCallback into a declared type, and makes it an `AnyInvocable` instead of a `std::function`, which enforces that callbacks are called only once and that they're moved not copied, avoiding accidental performance drains. Risk Level: Low; WIP filter, existing tests still pass. Testing: Existing tests should be covering all cases. Added tests to enforce that all cache implementations' `LookupContext` correctly posts callback actions, and correctly cancels calling the callback if the context is deleted before the post resolves. Docs Changes: Code-comments only. Release Notes: Maybe? Platform Specific Features: n/a --------- Signed-off-by: Raven Black --- .../filters/http/cache/cache_filter.cc | 69 ++----- .../filters/http/cache/cache_filter.h | 2 + .../filters/http/cache/cache_insert_queue.cc | 114 +++++------ .../filters/http/cache/cache_insert_queue.h | 2 +- .../filters/http/cache/http_cache.h | 40 ++-- .../file_system_http_cache.cc | 17 +- .../file_system_http_cache.h | 3 +- .../file_system_http_cache/insert_context.cc | 11 +- .../file_system_http_cache/insert_context.h | 1 - .../file_system_http_cache/lookup_context.cc | 4 +- .../simple_http_cache/simple_http_cache.cc | 98 ++++++--- .../simple_http_cache/simple_http_cache.h | 3 +- .../filters/http/cache/cache_filter_test.cc | 189 +++++++++++++----- .../http_cache_implementation_test_common.cc | 56 ++++++ .../http_cache_implementation_test_common.h | 3 + test/extensions/filters/http/cache/mocks.h | 2 +- 16 files changed, 382 insertions(+), 232 deletions(-) diff --git a/source/extensions/filters/http/cache/cache_filter.cc b/source/extensions/filters/http/cache/cache_filter.cc index 2e9c073a6680..88aed69f9f08 100644 --- a/source/extensions/filters/http/cache/cache_filter.cc +++ b/source/extensions/filters/http/cache/cache_filter.cc @@ -303,38 +303,19 @@ CacheFilter::resolveLookupStatus(absl::optional cache_entry_st void CacheFilter::getHeaders(Http::RequestHeaderMap& request_headers) { ASSERT(lookup_, "CacheFilter is trying to call getHeaders with no LookupContext"); - - // If the cache posts a callback to the dispatcher then the CacheFilter is destroyed for any - // reason (e.g client disconnected and HTTP stream terminated), then there is no guarantee that - // the posted callback will run before the filter is deleted. Hence, a weak_ptr to the CacheFilter - // is captured and used to make sure the CacheFilter is still alive before accessing it in the - // posted callback. - // TODO(yosrym93): Look into other options for handling this (also in getBody and getTrailers) as - // they arise, e.g. cancellable posts, guaranteed ordering of posted callbacks and deletions, etc. - CacheFilterWeakPtr self = weak_from_this(); - - // The dispatcher needs to be captured because there's no guarantee that - // decoder_callbacks_->dispatcher() is thread-safe. - lookup_->getHeaders([self, &request_headers, &dispatcher = decoder_callbacks_->dispatcher()]( + callback_called_directly_ = true; + lookup_->getHeaders([this, &request_headers, &dispatcher = decoder_callbacks_->dispatcher()]( LookupResult&& result, bool end_stream) { - // The callback is posted to the dispatcher to make sure it is called on the worker thread. - dispatcher.post([self, &request_headers, result = std::move(result), end_stream]() mutable { - if (CacheFilterSharedPtr cache_filter = self.lock()) { - cache_filter->onHeaders(std::move(result), request_headers, end_stream); - } - }); + ASSERT(!callback_called_directly_ && dispatcher.isThreadSafe(), + "caches must post the callback to the filter's dispatcher"); + onHeaders(std::move(result), request_headers, end_stream); }); + callback_called_directly_ = false; } void CacheFilter::getBody() { ASSERT(lookup_, "CacheFilter is trying to call getBody with no LookupContext"); ASSERT(!remaining_ranges_.empty(), "No reason to call getBody when there's no body to get."); - // If the cache posts a callback to the dispatcher then the CacheFilter is destroyed for any - // reason (e.g client disconnected and HTTP stream terminated), then there is no guarantee that - // the posted callback will run before the filter is deleted. Hence, a weak_ptr to the CacheFilter - // is captured and used to make sure the CacheFilter is still alive before accessing it in the - // posted callback. - CacheFilterWeakPtr self = weak_from_this(); // We don't want to request more than a buffer-size at a time from the cache. uint64_t fetch_size_limit = encoder_callbacks_->encoderBufferLimit(); @@ -347,41 +328,27 @@ void CacheFilter::getBody() { ? (remaining_ranges_[0].begin() + fetch_size_limit) : remaining_ranges_[0].end()}; - // The dispatcher needs to be captured because there's no guarantee that - // decoder_callbacks_->dispatcher() is thread-safe. - lookup_->getBody(fetch_range, [self, &dispatcher = decoder_callbacks_->dispatcher()]( + callback_called_directly_ = true; + lookup_->getBody(fetch_range, [this, &dispatcher = decoder_callbacks_->dispatcher()]( Buffer::InstancePtr&& body, bool end_stream) { - // The callback is posted to the dispatcher to make sure it is called on the worker thread. - dispatcher.post([self, body = std::move(body), end_stream]() mutable { - if (CacheFilterSharedPtr cache_filter = self.lock()) { - cache_filter->onBody(std::move(body), end_stream); - } - }); + ASSERT(!callback_called_directly_ && dispatcher.isThreadSafe(), + "caches must post the callback to the filter's dispatcher"); + onBody(std::move(body), end_stream); }); + callback_called_directly_ = false; } void CacheFilter::getTrailers() { ASSERT(lookup_, "CacheFilter is trying to call getTrailers with no LookupContext"); - // If the cache posts a callback to the dispatcher then the CacheFilter is destroyed for any - // reason (e.g client disconnected and HTTP stream terminated), then there is no guarantee that - // the posted callback will run before the filter is deleted. Hence, a weak_ptr to the CacheFilter - // is captured and used to make sure the CacheFilter is still alive before accessing it in the - // posted callback. - CacheFilterWeakPtr self = weak_from_this(); - - // The dispatcher needs to be captured because there's no guarantee that - // decoder_callbacks_->dispatcher() is thread-safe. - lookup_->getTrailers([self, &dispatcher = decoder_callbacks_->dispatcher()]( + callback_called_directly_ = true; + lookup_->getTrailers([this, &dispatcher = decoder_callbacks_->dispatcher()]( Http::ResponseTrailerMapPtr&& trailers) { - // The callback is posted to the dispatcher to make sure it is called on the worker thread. - // The lambda must be mutable as it captures trailers as a unique_ptr. - dispatcher.post([self, trailers = std::move(trailers)]() mutable { - if (CacheFilterSharedPtr cache_filter = self.lock()) { - cache_filter->onTrailers(std::move(trailers)); - } - }); + ASSERT(!callback_called_directly_ && dispatcher.isThreadSafe(), + "caches must post the callback to the filter's dispatcher"); + onTrailers(std::move(trailers)); }); + callback_called_directly_ = false; } void CacheFilter::onHeaders(LookupResult&& result, Http::RequestHeaderMap& request_headers, diff --git a/source/extensions/filters/http/cache/cache_filter.h b/source/extensions/filters/http/cache/cache_filter.h index 3641797b5c7a..923a1cf1975e 100644 --- a/source/extensions/filters/http/cache/cache_filter.h +++ b/source/extensions/filters/http/cache/cache_filter.h @@ -164,6 +164,8 @@ class CacheFilter : public Http::PassThroughFilter, FilterState filter_state_ = FilterState::Initial; bool is_head_request_ = false; + // This toggle is used to detect callbacks being called directly and not posted. + bool callback_called_directly_ = false; // The status of the insert operation or header update, or decision not to insert or update. // If it's too early to determine the final status, this is empty. absl::optional insert_status_; diff --git a/source/extensions/filters/http/cache/cache_insert_queue.cc b/source/extensions/filters/http/cache/cache_insert_queue.cc index 66cb9d41ea12..80de4b61280c 100644 --- a/source/extensions/filters/http/cache/cache_insert_queue.cc +++ b/source/extensions/filters/http/cache/cache_insert_queue.cc @@ -14,7 +14,7 @@ class CacheInsertFragment { // on_complete is called when the cache completes the operation. virtual void send(InsertContext& context, - std::function on_complete) PURE; + absl::AnyInvocable on_complete) PURE; virtual ~CacheInsertFragment() = default; }; @@ -27,14 +27,14 @@ class CacheInsertFragmentBody : public CacheInsertFragment { CacheInsertFragmentBody(const Buffer::Instance& buffer, bool end_stream) : buffer_(buffer), end_stream_(end_stream) {} - void - send(InsertContext& context, - std::function on_complete) override { + void send(InsertContext& context, + absl::AnyInvocable on_complete) + override { size_t sz = buffer_.length(); context.insertBody( std::move(buffer_), - [on_complete, end_stream = end_stream_, sz](bool cache_success) { - on_complete(cache_success, end_stream, sz); + [cb = std::move(on_complete), end_stream = end_stream_, sz](bool cache_success) mutable { + std::move(cb)(cache_success, end_stream, sz); }, end_stream_); } @@ -52,14 +52,15 @@ class CacheInsertFragmentTrailers : public CacheInsertFragment { Http::ResponseTrailerMapImpl::copyFrom(*trailers_, trailers); } - void - send(InsertContext& context, - std::function on_complete) override { + void send(InsertContext& context, + absl::AnyInvocable on_complete) + override { // While zero isn't technically true for the size of trailers, it doesn't // matter at this point because watermarks after the stream is complete // aren't useful. - context.insertTrailers( - *trailers_, [on_complete](bool cache_success) { on_complete(cache_success, true, 0); }); + context.insertTrailers(*trailers_, [cb = std::move(on_complete)](bool cache_success) mutable { + std::move(cb)(cache_success, true, 0); + }); } private: @@ -72,7 +73,7 @@ CacheInsertQueue::CacheInsertQueue(std::shared_ptr cache, : dispatcher_(encoder_callbacks.dispatcher()), insert_context_(std::move(insert_context)), low_watermark_bytes_(encoder_callbacks.encoderBufferLimit() / 2), high_watermark_bytes_(encoder_callbacks.encoderBufferLimit()), - encoder_callbacks_(encoder_callbacks), abort_callback_(abort), cache_(cache) {} + encoder_callbacks_(encoder_callbacks), abort_callback_(std::move(abort)), cache_(cache) {} void CacheInsertQueue::insertHeaders(const Http::ResponseHeaderMap& response_headers, const ResponseMetadata& metadata, bool end_stream) { @@ -123,59 +124,54 @@ void CacheInsertQueue::insertTrailers(const Http::ResponseTrailerMap& trailers) } void CacheInsertQueue::onFragmentComplete(bool cache_success, bool end_stream, size_t sz) { - // If the cache implementation is asynchronous, this may be called from whatever - // thread that cache implementation runs on. Therefore, we post it to the - // dispatcher to be certain any callbacks and updates are called on the filter's - // thread (and therefore we don't have to mutex-guard anything). - dispatcher_.post([this, cache_success, end_stream, sz]() { - fragment_in_flight_ = false; - if (aborting_) { - // Parent filter was destroyed, so we can quit this operation. - fragments_.clear(); - self_ownership_.reset(); - return; + ASSERT(dispatcher_.isThreadSafe()); + fragment_in_flight_ = false; + if (aborting_) { + // Parent filter was destroyed, so we can quit this operation. + fragments_.clear(); + self_ownership_.reset(); + return; + } + ASSERT(queue_size_bytes_ >= sz, "queue can't be emptied by more than its size"); + queue_size_bytes_ -= sz; + if (watermarked_ && queue_size_bytes_ <= low_watermark_bytes_) { + if (encoder_callbacks_.has_value()) { + encoder_callbacks_.value().get().onEncoderFilterBelowWriteBufferLowWatermark(); } - ASSERT(queue_size_bytes_ >= sz, "queue can't be emptied by more than its size"); - queue_size_bytes_ -= sz; - if (watermarked_ && queue_size_bytes_ <= low_watermark_bytes_) { + watermarked_ = false; + } + if (!cache_success) { + // canceled by cache; unwatermark if necessary, inform the filter if + // it's still around, and delete the queue. + if (watermarked_) { if (encoder_callbacks_.has_value()) { encoder_callbacks_.value().get().onEncoderFilterBelowWriteBufferLowWatermark(); } watermarked_ = false; } - if (!cache_success) { - // canceled by cache; unwatermark if necessary, inform the filter if - // it's still around, and delete the queue. - if (watermarked_) { - if (encoder_callbacks_.has_value()) { - encoder_callbacks_.value().get().onEncoderFilterBelowWriteBufferLowWatermark(); - } - watermarked_ = false; - } - fragments_.clear(); - // Clearing self-ownership might provoke the destructor, so take a copy of the - // abort callback to avoid reading from 'this' after it may be deleted. - auto abort_callback = abort_callback_; - self_ownership_.reset(); - abort_callback(); - return; - } - if (end_stream) { - ASSERT(fragments_.empty(), "ending a stream with the queue not empty is a bug"); - ASSERT(!watermarked_, "being over the high watermark when the queue is empty makes no sense"); - self_ownership_.reset(); - return; - } - if (!fragments_.empty()) { - // If there's more in the queue, push the next fragment to the cache. - auto fragment = std::move(fragments_.front()); - fragments_.pop_front(); - fragment_in_flight_ = true; - fragment->send(*insert_context_, [this](bool cache_success, bool end_stream, size_t sz) { - onFragmentComplete(cache_success, end_stream, sz); - }); - } - }); + fragments_.clear(); + // Clearing self-ownership might provoke the destructor, so take a copy of the + // abort callback to avoid reading from 'this' after it may be deleted. + auto abort_callback = std::move(abort_callback_); + self_ownership_.reset(); + std::move(abort_callback)(); + return; + } + if (end_stream) { + ASSERT(fragments_.empty(), "ending a stream with the queue not empty is a bug"); + ASSERT(!watermarked_, "being over the high watermark when the queue is empty makes no sense"); + self_ownership_.reset(); + return; + } + if (!fragments_.empty()) { + // If there's more in the queue, push the next fragment to the cache. + auto fragment = std::move(fragments_.front()); + fragments_.pop_front(); + fragment_in_flight_ = true; + fragment->send(*insert_context_, [this](bool cache_success, bool end_stream, size_t sz) { + onFragmentComplete(cache_success, end_stream, sz); + }); + } } void CacheInsertQueue::setSelfOwned(std::unique_ptr self) { diff --git a/source/extensions/filters/http/cache/cache_insert_queue.h b/source/extensions/filters/http/cache/cache_insert_queue.h index feae50414a63..22297a70a528 100644 --- a/source/extensions/filters/http/cache/cache_insert_queue.h +++ b/source/extensions/filters/http/cache/cache_insert_queue.h @@ -12,7 +12,7 @@ namespace Cache { using OverHighWatermarkCallback = std::function; using UnderLowWatermarkCallback = std::function; -using AbortInsertCallback = std::function; +using AbortInsertCallback = absl::AnyInvocable; class CacheInsertFragment; // This queue acts as an intermediary between CacheFilter and the cache diff --git a/source/extensions/filters/http/cache/http_cache.h b/source/extensions/filters/http/cache/http_cache.h index 59face7598bb..e01fbf7d162a 100644 --- a/source/extensions/filters/http/cache/http_cache.h +++ b/source/extensions/filters/http/cache/http_cache.h @@ -122,19 +122,20 @@ struct CacheInfo { bool supports_range_requests_ = false; }; -using LookupBodyCallback = std::function; -using LookupHeadersCallback = std::function; -using LookupTrailersCallback = std::function; -using InsertCallback = std::function; +using LookupBodyCallback = absl::AnyInvocable; +using LookupHeadersCallback = absl::AnyInvocable; +using LookupTrailersCallback = absl::AnyInvocable; +using InsertCallback = absl::AnyInvocable; +using UpdateHeadersCallback = absl::AnyInvocable; // Manages the lifetime of an insertion. class InsertContext { public: // Accepts response_headers for caching. Only called once. // - // Implementations MUST call insert_complete(true) on success, or - // insert_complete(false) to attempt to abort the insertion. This - // call may be made asynchronously, but any async operation that can + // Implementations MUST post to the filter's dispatcher insert_complete(true) + // on success, or insert_complete(false) to attempt to abort the insertion. + // This call may be made asynchronously, but any async operation that can // potentially silently fail must include a timeout, to avoid memory leaks. virtual void insertHeaders(const Http::ResponseHeaderMap& response_headers, const ResponseMetadata& metadata, InsertCallback insert_complete, @@ -149,17 +150,17 @@ class InsertContext { // InsertContextPtr. A cache can abort the insertion by passing 'false' into // ready_for_next_fragment. // - // The cache implementation MUST call ready_for_next_fragment. This call may be - // made asynchronously, but any async operation that can potentially silently - // fail must include a timeout, to avoid memory leaks. + // The cache implementation MUST post ready_for_next_fragment to the filter's + // dispatcher. This post may be made asynchronously, but any async operation + // that can potentially silently fail must include a timeout, to avoid memory leaks. virtual void insertBody(const Buffer::Instance& fragment, InsertCallback ready_for_next_fragment, bool end_stream) PURE; // Inserts trailers into the cache. // - // The cache implementation MUST call insert_complete. This call may be - // made asynchronously, but any async operation that can potentially silently - // fail must include a timeout, to avoid memory leaks. + // The cache implementation MUST post insert_complete to the filter's dispatcher. + // This call may be made asynchronously, but any async operation that can + // potentially silently fail must include a timeout, to avoid memory leaks. virtual void insertTrailers(const Http::ResponseTrailerMap& trailers, InsertCallback insert_complete) PURE; @@ -199,6 +200,9 @@ class LookupContext { // implementation should wait until that is known before calling the callback, // and must pass a LookupResult with range_details_->satisfiable_ = false // if the request is invalid. + // + // A cache that posts the callback must wrap it such that if the LookupContext is + // destroyed before the callback is executed, the callback is not executed. virtual void getHeaders(LookupHeadersCallback&& cb) PURE; // Reads the next fragment from the cache, calling cb when the fragment is ready. @@ -228,11 +232,17 @@ class LookupContext { // getBody requests bytes 0-23 .......... callback with bytes 0-9 // getBody requests bytes 10-23 .......... callback with bytes 10-19 // getBody requests bytes 20-23 .......... callback with bytes 20-23 + // + // A cache that posts the callback must wrap it such that if the LookupContext is + // destroyed before the callback is executed, the callback is not executed. virtual void getBody(const AdjustedByteRange& range, LookupBodyCallback&& cb) PURE; // Get the trailers from the cache. Only called if the request reached the end of // the body and LookupBodyCallback did not pass true for end_stream. The // Http::ResponseTrailerMapPtr passed to cb must not be null. + // + // A cache that posts the callback must wrap it such that if the LookupContext is + // destroyed before the callback is executed, the callback is not executed. virtual void getTrailers(LookupTrailersCallback&& cb) PURE; // This routine is called prior to a LookupContext being destroyed. LookupContext is responsible @@ -248,7 +258,7 @@ class LookupContext { // 5. [Other thread] RPC completes and calls RPCLookupContext::onRPCDone. // --> RPCLookupContext's destructor and onRpcDone cause a data race in RPCLookupContext. // onDestroy() should cancel any outstanding async operations and, if necessary, - // it should block on that cancellation to avoid data races. InsertContext must not invoke any + // it should block on that cancellation to avoid data races. LookupContext must not invoke any // callbacks to the CacheFilter after having onDestroy() invoked. virtual void onDestroy() PURE; @@ -289,7 +299,7 @@ class HttpCache { virtual void updateHeaders(const LookupContext& lookup_context, const Http::ResponseHeaderMap& response_headers, const ResponseMetadata& metadata, - std::function on_complete) PURE; + UpdateHeadersCallback on_complete) PURE; // Returns statically known information about a cache. virtual CacheInfo cacheInfo() const PURE; diff --git a/source/extensions/http/cache/file_system_http_cache/file_system_http_cache.cc b/source/extensions/http/cache/file_system_http_cache/file_system_http_cache.cc index 1e5a49c6c6f6..5e220e2b4598 100644 --- a/source/extensions/http/cache/file_system_http_cache/file_system_http_cache.cc +++ b/source/extensions/http/cache/file_system_http_cache/file_system_http_cache.cc @@ -127,13 +127,13 @@ class HeaderUpdateContext : public Logger::Loggable { HeaderUpdateContext(Event::Dispatcher& dispatcher, const FileSystemHttpCache& cache, const Key& key, std::shared_ptr cleanup, const Http::ResponseHeaderMap& response_headers, - const ResponseMetadata& metadata, std::function on_complete) + const ResponseMetadata& metadata, UpdateHeadersCallback on_complete) : dispatcher_(dispatcher), filepath_(absl::StrCat(cache.cachePath(), cache.generateFilename(key))), cache_path_(cache.cachePath()), cleanup_(cleanup), async_file_manager_(cache.asyncFileManager()), response_headers_(Http::createHeaderMap(response_headers)), - response_metadata_(metadata), on_complete_(on_complete) {} + response_metadata_(metadata), on_complete_(std::move(on_complete)) {} void begin(std::shared_ptr ctx) { async_file_manager_->openExistingFile( @@ -278,14 +278,14 @@ class HeaderUpdateContext : public Logger::Loggable { fail("failed to link new cache file", link_result); return; } - on_complete_(true); + std::move(on_complete_)(true); }); ASSERT(queued.ok()); } void fail(absl::string_view msg, absl::Status status) { ENVOY_LOG(warn, "file_system_http_cache: {} for update cache file {}: {}", msg, filepath_, status); - on_complete_(false); + std::move(on_complete_)(false); } Event::Dispatcher* dispatcher() { return &dispatcher_; } Event::Dispatcher& dispatcher_; @@ -300,13 +300,13 @@ class HeaderUpdateContext : public Logger::Loggable { CacheFileHeader header_proto_; AsyncFileHandle read_handle_; AsyncFileHandle write_handle_; - std::function on_complete_; + UpdateHeadersCallback on_complete_; }; void FileSystemHttpCache::updateHeaders(const LookupContext& base_lookup_context, const Http::ResponseHeaderMap& response_headers, const ResponseMetadata& metadata, - std::function on_complete) { + UpdateHeadersCallback on_complete) { const FileLookupContext& lookup_context = dynamic_cast(base_lookup_context); const Key& key = lookup_context.key(); @@ -314,8 +314,9 @@ void FileSystemHttpCache::updateHeaders(const LookupContext& base_lookup_context if (!cleanup) { return; } - auto ctx = std::make_shared( - *lookup_context.dispatcher(), *this, key, cleanup, response_headers, metadata, on_complete); + auto ctx = + std::make_shared(*lookup_context.dispatcher(), *this, key, cleanup, + response_headers, metadata, std::move(on_complete)); ctx->begin(ctx); } diff --git a/source/extensions/http/cache/file_system_http_cache/file_system_http_cache.h b/source/extensions/http/cache/file_system_http_cache/file_system_http_cache.h index 9e81bf0410b0..be4c59402444 100644 --- a/source/extensions/http/cache/file_system_http_cache/file_system_http_cache.h +++ b/source/extensions/http/cache/file_system_http_cache/file_system_http_cache.h @@ -79,8 +79,7 @@ class FileSystemHttpCache : public HttpCache, */ void updateHeaders(const LookupContext& lookup_context, const Http::ResponseHeaderMap& response_headers, - const ResponseMetadata& metadata, - std::function on_complete) override; + const ResponseMetadata& metadata, UpdateHeadersCallback on_complete) override; /** * The config of this cache. Used by the factory to ensure there aren't incompatible diff --git a/source/extensions/http/cache/file_system_http_cache/insert_context.cc b/source/extensions/http/cache/file_system_http_cache/insert_context.cc index 98e05657734e..55503ef6dd25 100644 --- a/source/extensions/http/cache/file_system_http_cache/insert_context.cc +++ b/source/extensions/http/cache/file_system_http_cache/insert_context.cc @@ -32,7 +32,7 @@ void FileInsertContext::insertHeaders(const Http::ResponseHeaderMap& response_he const ResponseMetadata& metadata, InsertCallback insert_complete, bool end_stream) { ASSERT(dispatcher()->isThreadSafe()); - callback_in_flight_ = insert_complete; + callback_in_flight_ = std::move(insert_complete); const VaryAllowList& vary_allow_list = lookup_context_->lookup().varyAllowList(); const Http::RequestHeaderMap& request_headers = lookup_context_->lookup().requestHeaders(); if (VaryHeaderUtils::hasVary(response_headers)) { @@ -59,7 +59,6 @@ void FileInsertContext::insertHeaders(const Http::ResponseHeaderMap& response_he } cache_file_header_proto_ = makeCacheFileHeaderProto(key_, response_headers, metadata); end_stream_after_headers_ = end_stream; - on_insert_complete_ = std::move(insert_complete); createFile(); } @@ -140,10 +139,10 @@ void FileInsertContext::insertBody(const Buffer::Instance& fragment, ASSERT(!callback_in_flight_); if (!cleanup_) { // Already cancelled, do nothing, return failure. - ready_for_next_fragment(false); + std::move(ready_for_next_fragment)(false); return; } - callback_in_flight_ = ready_for_next_fragment; + callback_in_flight_ = std::move(ready_for_next_fragment); size_t sz = fragment.length(); Buffer::OwnedImpl consumable_fragment(fragment); auto queued = file_handle_->write( @@ -172,10 +171,10 @@ void FileInsertContext::insertTrailers(const Http::ResponseTrailerMap& trailers, ASSERT(!callback_in_flight_); if (!cleanup_) { // Already cancelled, do nothing, return failure. - insert_complete(false); + std::move(insert_complete)(false); return; } - callback_in_flight_ = insert_complete; + callback_in_flight_ = std::move(insert_complete); CacheFileTrailer file_trailer = makeCacheFileTrailerProto(trailers); Buffer::OwnedImpl consumable_buffer = bufferFromProto(file_trailer); size_t sz = consumable_buffer.length(); diff --git a/source/extensions/http/cache/file_system_http_cache/insert_context.h b/source/extensions/http/cache/file_system_http_cache/insert_context.h index 1f8e6665338f..a70fbbae1d06 100644 --- a/source/extensions/http/cache/file_system_http_cache/insert_context.h +++ b/source/extensions/http/cache/file_system_http_cache/insert_context.h @@ -88,7 +88,6 @@ class FileInsertContext : public InsertContext, public Logger::Loggable lookup_context_; Key key_; std::shared_ptr cache_; diff --git a/source/extensions/http/cache/file_system_http_cache/lookup_context.cc b/source/extensions/http/cache/file_system_http_cache/lookup_context.cc index a58b13a44119..9f3d7e028f5f 100644 --- a/source/extensions/http/cache/file_system_http_cache/lookup_context.cc +++ b/source/extensions/http/cache/file_system_http_cache/lookup_context.cc @@ -140,7 +140,7 @@ void FileLookupContext::getBody(const AdjustedByteRange& range, LookupBodyCallba ASSERT(file_handle_); auto queued = file_handle_->read( dispatcher(), header_block_.offsetToBody() + range.begin(), range.length(), - [this, cb = std::move(cb), range](absl::StatusOr read_result) { + [this, cb = std::move(cb), range](absl::StatusOr read_result) mutable { ASSERT(dispatcher()->isThreadSafe()); cancel_action_in_flight_ = nullptr; if (!read_result.ok() || read_result.value()->length() != range.length()) { @@ -164,7 +164,7 @@ void FileLookupContext::getTrailers(LookupTrailersCallback&& cb) { ASSERT(file_handle_); auto queued = file_handle_->read( dispatcher(), header_block_.offsetToTrailers(), header_block_.trailerSize(), - [this, cb = std::move(cb)](absl::StatusOr read_result) { + [this, cb = std::move(cb)](absl::StatusOr read_result) mutable { ASSERT(dispatcher()->isThreadSafe()); cancel_action_in_flight_ = nullptr; if (!read_result.ok() || read_result.value()->length() != header_block_.trailerSize()) { diff --git a/source/extensions/http/cache/simple_http_cache/simple_http_cache.cc b/source/extensions/http/cache/simple_http_cache/simple_http_cache.cc index a2be23281c86..04c70bba9d39 100644 --- a/source/extensions/http/cache/simple_http_cache/simple_http_cache.cc +++ b/source/extensions/http/cache/simple_http_cache/simple_http_cache.cc @@ -33,35 +33,57 @@ absl::optional variedRequestKey(const LookupRequest& request, class SimpleLookupContext : public LookupContext { public: - SimpleLookupContext(SimpleHttpCache& cache, LookupRequest&& request) - : cache_(cache), request_(std::move(request)) {} + SimpleLookupContext(Event::Dispatcher& dispatcher, SimpleHttpCache& cache, + LookupRequest&& request) + : dispatcher_(dispatcher), cache_(cache), request_(std::move(request)) {} void getHeaders(LookupHeadersCallback&& cb) override { auto entry = cache_.lookup(request_); body_ = std::move(entry.body_); trailers_ = std::move(entry.trailers_); - cb(entry.response_headers_ ? request_.makeLookupResult(std::move(entry.response_headers_), - std::move(entry.metadata_), body_.size()) - : LookupResult{}, - body_.empty() && trailers_ == nullptr); + LookupResult result = entry.response_headers_ + ? request_.makeLookupResult(std::move(entry.response_headers_), + std::move(entry.metadata_), body_.size()) + : LookupResult{}; + bool end_stream = body_.empty() && trailers_ == nullptr; + dispatcher_.post([result = std::move(result), cb = std::move(cb), end_stream, + cancelled = cancelled_]() mutable { + if (!*cancelled) { + std::move(cb)(std::move(result), end_stream); + } + }); } void getBody(const AdjustedByteRange& range, LookupBodyCallback&& cb) override { ASSERT(range.end() <= body_.length(), "Attempt to read past end of body."); - cb(std::make_unique(&body_[range.begin()], range.length()), - trailers_ == nullptr && range.end() == body_.length()); + auto result = std::make_unique(&body_[range.begin()], range.length()); + bool end_stream = trailers_ == nullptr && range.end() == body_.length(); + dispatcher_.post([result = std::move(result), cb = std::move(cb), end_stream, + cancelled = cancelled_]() mutable { + if (!*cancelled) { + std::move(cb)(std::move(result), end_stream); + } + }); } // The cache must call cb with the cached trailers. void getTrailers(LookupTrailersCallback&& cb) override { ASSERT(trailers_); - cb(std::move(trailers_)); + dispatcher_.post( + [cb = std::move(cb), trailers = std::move(trailers_), cancelled = cancelled_]() mutable { + if (!*cancelled) { + std::move(cb)(std::move(trailers)); + } + }); } const LookupRequest& request() const { return request_; } - void onDestroy() override {} + void onDestroy() override { *cancelled_ = true; } + Event::Dispatcher& dispatcher() const { return dispatcher_; } private: + Event::Dispatcher& dispatcher_; + std::shared_ptr cancelled_ = std::make_shared(false); SimpleHttpCache& cache_; const LookupRequest request_; std::string body_; @@ -70,13 +92,18 @@ class SimpleLookupContext : public LookupContext { class SimpleInsertContext : public InsertContext { public: - SimpleInsertContext(LookupContext& lookup_context, SimpleHttpCache& cache) - : key_(dynamic_cast(lookup_context).request().key()), - request_headers_( - dynamic_cast(lookup_context).request().requestHeaders()), - vary_allow_list_( - dynamic_cast(lookup_context).request().varyAllowList()), - cache_(cache) {} + SimpleInsertContext(SimpleLookupContext& lookup_context, SimpleHttpCache& cache) + : dispatcher_(lookup_context.dispatcher()), key_(lookup_context.request().key()), + request_headers_(lookup_context.request().requestHeaders()), + vary_allow_list_(lookup_context.request().varyAllowList()), cache_(cache) {} + + void post(InsertCallback cb, bool result) { + dispatcher_.post([cb = std::move(cb), result = result, cancelled = cancelled_]() mutable { + if (!*cancelled) { + std::move(cb)(result); + } + }); + } void insertHeaders(const Http::ResponseHeaderMap& response_headers, const ResponseMetadata& metadata, InsertCallback insert_success, @@ -85,9 +112,9 @@ class SimpleInsertContext : public InsertContext { response_headers_ = Http::createHeaderMap(response_headers); metadata_ = metadata; if (end_stream) { - insert_success(commit()); + post(std::move(insert_success), commit()); } else { - insert_success(true); + post(std::move(insert_success), true); } } @@ -98,9 +125,9 @@ class SimpleInsertContext : public InsertContext { body_.add(chunk); if (end_stream) { - ready_for_next_chunk(commit()); + post(std::move(ready_for_next_chunk), commit()); } else { - ready_for_next_chunk(true); + post(std::move(ready_for_next_chunk), true); } } @@ -108,10 +135,10 @@ class SimpleInsertContext : public InsertContext { InsertCallback insert_complete) override { ASSERT(!committed_); trailers_ = Http::createHeaderMap(trailers); - insert_complete(commit()); + post(std::move(insert_complete), commit()); } - void onDestroy() override {} + void onDestroy() override { *cancelled_ = true; } private: bool commit() { @@ -126,6 +153,8 @@ class SimpleInsertContext : public InsertContext { } } + Event::Dispatcher& dispatcher_; + std::shared_ptr cancelled_ = std::make_shared(false); Key key_; const Http::RequestHeaderMap& request_headers_; const VaryAllowList& vary_allow_list_; @@ -139,32 +168,38 @@ class SimpleInsertContext : public InsertContext { } // namespace LookupContextPtr SimpleHttpCache::makeLookupContext(LookupRequest&& request, - Http::StreamDecoderFilterCallbacks&) { - return std::make_unique(*this, std::move(request)); + Http::StreamDecoderFilterCallbacks& callbacks) { + return std::make_unique(callbacks.dispatcher(), *this, std::move(request)); } void SimpleHttpCache::updateHeaders(const LookupContext& lookup_context, const Http::ResponseHeaderMap& response_headers, const ResponseMetadata& metadata, - std::function on_complete) { + UpdateHeadersCallback on_complete) { const auto& simple_lookup_context = static_cast(lookup_context); const Key& key = simple_lookup_context.request().key(); absl::WriterMutexLock lock(&mutex_); auto iter = map_.find(key); + auto post_complete = [on_complete = std::move(on_complete), + &dispatcher = simple_lookup_context.dispatcher()](bool result) mutable { + dispatcher.post([on_complete = std::move(on_complete), result]() mutable { + std::move(on_complete)(result); + }); + }; if (iter == map_.end() || !iter->second.response_headers_) { - on_complete(false); + std::move(post_complete)(false); return; } if (VaryHeaderUtils::hasVary(*iter->second.response_headers_)) { absl::optional varied_key = variedRequestKey(simple_lookup_context.request(), *iter->second.response_headers_); if (!varied_key.has_value()) { - on_complete(false); + std::move(post_complete)(false); return; } iter = map_.find(varied_key.value()); if (iter == map_.end() || !iter->second.response_headers_) { - on_complete(false); + std::move(post_complete)(false); return; } } @@ -172,7 +207,7 @@ void SimpleHttpCache::updateHeaders(const LookupContext& lookup_context, applyHeaderUpdate(response_headers, *entry.response_headers_); entry.metadata_ = metadata; - on_complete(true); + std::move(post_complete)(true); } SimpleHttpCache::Entry SimpleHttpCache::lookup(const LookupRequest& request) { @@ -278,7 +313,8 @@ bool SimpleHttpCache::varyInsert(const Key& request_key, InsertContextPtr SimpleHttpCache::makeInsertContext(LookupContextPtr&& lookup_context, Http::StreamEncoderFilterCallbacks&) { ASSERT(lookup_context != nullptr); - return std::make_unique(*lookup_context, *this); + return std::make_unique(dynamic_cast(*lookup_context), + *this); } constexpr absl::string_view Name = "envoy.extensions.http.cache.simple"; diff --git a/source/extensions/http/cache/simple_http_cache/simple_http_cache.h b/source/extensions/http/cache/simple_http_cache/simple_http_cache.h index bfad38d9fdaf..515ccc5f0183 100644 --- a/source/extensions/http/cache/simple_http_cache/simple_http_cache.h +++ b/source/extensions/http/cache/simple_http_cache/simple_http_cache.h @@ -40,8 +40,7 @@ class SimpleHttpCache : public HttpCache, public Singleton::Instance { Http::StreamEncoderFilterCallbacks& callbacks) override; void updateHeaders(const LookupContext& lookup_context, const Http::ResponseHeaderMap& response_headers, - const ResponseMetadata& metadata, - std::function on_complete) override; + const ResponseMetadata& metadata, UpdateHeadersCallback on_complete) override; CacheInfo cacheInfo() const override; Entry lookup(const LookupRequest& request); diff --git a/test/extensions/filters/http/cache/cache_filter_test.cc b/test/extensions/filters/http/cache/cache_filter_test.cc index 27bd88716b80..14edc705c36c 100644 --- a/test/extensions/filters/http/cache/cache_filter_test.cc +++ b/test/extensions/filters/http/cache/cache_filter_test.cc @@ -291,6 +291,40 @@ TEST_F(CacheFilterTest, CacheMissWithTrailers) { dispatcher_->run(Event::Dispatcher::RunType::Block); } +TEST_F(CacheFilterTest, CacheMissWithTrailersWhenCacheRespondsQuickerThanUpstream) { + request_headers_.setHost("CacheMissWithTrailers"); + const std::string body = "abc"; + Buffer::OwnedImpl body_buffer(body); + Http::TestResponseTrailerMapImpl trailers; + + for (int request = 1; request <= 2; request++) { + // Each iteration a request is sent to a different host, therefore the second one is a miss + request_headers_.setHost("CacheMissWithTrailers" + std::to_string(request)); + + // Create filter for request 1 + CacheFilterSharedPtr filter = makeFilter(simple_cache_); + + testDecodeRequestMiss(filter); + + // Encode response header + EXPECT_EQ(filter->encodeHeaders(response_headers_, false), Http::FilterHeadersStatus::Continue); + // Resolve cache response + dispatcher_->run(Event::Dispatcher::RunType::Block); + EXPECT_EQ(filter->encodeData(body_buffer, false), Http::FilterDataStatus::Continue); + // Resolve cache response + dispatcher_->run(Event::Dispatcher::RunType::Block); + EXPECT_EQ(filter->encodeTrailers(trailers), Http::FilterTrailersStatus::Continue); + // Resolve cache response + dispatcher_->run(Event::Dispatcher::RunType::Block); + + filter->onStreamComplete(); + EXPECT_THAT(lookupStatus(), IsOkAndHolds(LookupStatus::CacheMiss)); + EXPECT_THAT(insertStatus(), IsOkAndHolds(InsertStatus::InsertSucceeded)); + } + // Clear events off the dispatcher. + dispatcher_->run(Event::Dispatcher::RunType::Block); +} + TEST_F(CacheFilterTest, CacheHitNoBody) { request_headers_.setHost("CacheHitNoBody"); @@ -372,22 +406,25 @@ TEST_F(CacheFilterTest, WatermarkEventsAreSentIfCacheBlocksStreamAndLimitExceede return std::move(mock_insert_context); }); EXPECT_CALL(*mock_lookup_context, getHeaders(_)).WillOnce([&](LookupHeadersCallback&& cb) { - cb(LookupResult{}, false); + dispatcher_->post([cb = std::move(cb)]() mutable { std::move(cb)(LookupResult{}, false); }); }); EXPECT_CALL(*mock_insert_context, insertHeaders(_, _, _, false)) .WillOnce([&](const Http::ResponseHeaderMap&, const ResponseMetadata&, - InsertCallback insert_complete, bool) { insert_complete(true); }); + InsertCallback insert_complete, bool) { + dispatcher_->post([cb = std::move(insert_complete)]() mutable { std::move(cb)(true); }); + }); InsertCallback captured_insert_body_callback; // The first time insertBody is called, block until the test is ready to call it. // For completion chunk, complete immediately. EXPECT_CALL(*mock_insert_context, insertBody(_, _, false)) .WillOnce([&](const Buffer::Instance&, InsertCallback ready_for_next_chunk, bool) { EXPECT_THAT(captured_insert_body_callback, IsNull()); - captured_insert_body_callback = ready_for_next_chunk; + captured_insert_body_callback = std::move(ready_for_next_chunk); }); EXPECT_CALL(*mock_insert_context, insertBody(_, _, true)) .WillOnce([&](const Buffer::Instance&, InsertCallback ready_for_next_chunk, bool) { - ready_for_next_chunk(true); + dispatcher_->post( + [cb = std::move(ready_for_next_chunk)]() mutable { std::move(cb)(true); }); }); EXPECT_CALL(*mock_insert_context, onDestroy()); EXPECT_CALL(*mock_lookup_context, onDestroy()); @@ -444,18 +481,20 @@ TEST_F(CacheFilterTest, FilterDestroyedWhileWatermarkedSendsLowWatermarkEvent) { return std::move(mock_insert_context); }); EXPECT_CALL(*mock_lookup_context, getHeaders(_)).WillOnce([&](LookupHeadersCallback&& cb) { - cb(LookupResult{}, false); + dispatcher_->post([cb = std::move(cb)]() mutable { std::move(cb)(LookupResult{}, false); }); }); EXPECT_CALL(*mock_insert_context, insertHeaders(_, _, _, false)) .WillOnce([&](const Http::ResponseHeaderMap&, const ResponseMetadata&, - InsertCallback insert_complete, bool) { insert_complete(true); }); + InsertCallback insert_complete, bool) { + dispatcher_->post([cb = std::move(insert_complete)]() mutable { std::move(cb)(true); }); + }); InsertCallback captured_insert_body_callback; // The first time insertBody is called, block until the test is ready to call it. // Cache aborts, so there is no second call. EXPECT_CALL(*mock_insert_context, insertBody(_, _, false)) .WillOnce([&](const Buffer::Instance&, InsertCallback ready_for_next_chunk, bool) { EXPECT_THAT(captured_insert_body_callback, IsNull()); - captured_insert_body_callback = ready_for_next_chunk; + captured_insert_body_callback = std::move(ready_for_next_chunk); }); EXPECT_CALL(*mock_insert_context, onDestroy()); EXPECT_CALL(*mock_lookup_context, onDestroy()); @@ -477,15 +516,15 @@ TEST_F(CacheFilterTest, FilterDestroyedWhileWatermarkedSendsLowWatermarkEvent) { Buffer::OwnedImpl body1buf(body1); Buffer::OwnedImpl body2buf(body2); EXPECT_EQ(filter->encodeData(body1buf, false), Http::FilterDataStatus::Continue); + dispatcher_->run(Event::Dispatcher::RunType::Block); EXPECT_EQ(filter->encodeData(body2buf, true), Http::FilterDataStatus::Continue); + dispatcher_->run(Event::Dispatcher::RunType::Block); ASSERT_THAT(captured_insert_body_callback, NotNull()); // When the filter is destroyed, a low watermark event should be sent. EXPECT_CALL(encoder_callbacks_, onEncoderFilterBelowWriteBufferLowWatermark()); filter->onDestroy(); filter.reset(); captured_insert_body_callback(false); - // The cache insertBody callback should be posted to the dispatcher. - // Run events on the dispatcher so that the callback is invoked. dispatcher_->run(Event::Dispatcher::RunType::Block); } } @@ -507,19 +546,28 @@ TEST_F(CacheFilterTest, CacheEntryStreamedWithTrailersAndNoContentLengthCanDeliv }); // response_headers_ intentionally has no content length, LookupResult also has no content length. EXPECT_CALL(*mock_lookup_context, getHeaders(_)).WillOnce([&](LookupHeadersCallback&& cb) { - cb(LookupResult{CacheEntryStatus::Ok, - std::make_unique(response_headers_), - absl::nullopt, absl::nullopt}, - /* end_stream = */ false); + dispatcher_->post([cb = std::move(cb), this]() mutable { + std::move(cb)( + LookupResult{CacheEntryStatus::Ok, + std::make_unique(response_headers_), + absl::nullopt, absl::nullopt}, + /* end_stream = */ false); + }); }); EXPECT_CALL(*mock_lookup_context, getBody(RangeMatcher(0, Gt(5)), _)) .WillOnce([&](AdjustedByteRange, LookupBodyCallback&& cb) { - cb(std::make_unique(body), false); + dispatcher_->post([cb = std::move(cb), &body]() mutable { + std::move(cb)(std::make_unique(body), false); + }); }); EXPECT_CALL(*mock_lookup_context, getBody(RangeMatcher(5, Gt(5)), _)) - .WillOnce([&](AdjustedByteRange, LookupBodyCallback&& cb) { cb(nullptr, false); }); + .WillOnce([&](AdjustedByteRange, LookupBodyCallback&& cb) { + dispatcher_->post([cb = std::move(cb)]() mutable { std::move(cb)(nullptr, false); }); + }); EXPECT_CALL(*mock_lookup_context, getTrailers(_)).WillOnce([&](LookupTrailersCallback&& cb) { - cb(std::make_unique()); + dispatcher_->post([cb = std::move(cb)]() mutable { + std::move(cb)(std::make_unique()); + }); }); EXPECT_CALL(*mock_lookup_context, onDestroy()); { @@ -558,7 +606,11 @@ TEST_F(CacheFilterTest, OnDestroyBeforeOnHeadersAbortsAction) { EXPECT_CALL(*mock_lookup_context, getHeaders(_)).WillOnce([&](LookupHeadersCallback&& cb) { std::unique_ptr response_headers = std::make_unique(response_headers_); - cb(LookupResult{CacheEntryStatus::Ok, std::move(response_headers), 8, absl::nullopt}, false); + dispatcher_->post([cb = std::move(cb), + response_headers = std::move(response_headers)]() mutable { + std::move(cb)( + LookupResult{CacheEntryStatus::Ok, std::move(response_headers), 8, absl::nullopt}, false); + }); }); auto filter = makeFilter(mock_http_cache, false); EXPECT_EQ(filter->decodeHeaders(request_headers_, true), @@ -581,19 +633,27 @@ TEST_F(CacheFilterTest, OnDestroyBeforeOnBodyAbortsAction) { EXPECT_CALL(*mock_lookup_context, getHeaders(_)).WillOnce([&](LookupHeadersCallback&& cb) { std::unique_ptr response_headers = std::make_unique(response_headers_); - cb(LookupResult{CacheEntryStatus::Ok, std::move(response_headers), 5, absl::nullopt}, false); + dispatcher_->post([cb = std::move(cb), + response_headers = std::move(response_headers)]() mutable { + std::move(cb)( + LookupResult{CacheEntryStatus::Ok, std::move(response_headers), 5, absl::nullopt}, false); + }); }); LookupBodyCallback body_callback; EXPECT_CALL(*mock_lookup_context, getBody(RangeMatcher(0, 5), _)) - .WillOnce([&](const AdjustedByteRange&, LookupBodyCallback&& cb) { body_callback = cb; }); + .WillOnce([&](const AdjustedByteRange&, LookupBodyCallback&& cb) { + body_callback = std::move(cb); + }); + EXPECT_CALL(*mock_lookup_context, onDestroy()); auto filter = makeFilter(mock_http_cache, false); EXPECT_EQ(filter->decodeHeaders(request_headers_, true), Http::FilterHeadersStatus::StopAllIterationAndWatermark); dispatcher_->run(Event::Dispatcher::RunType::NonBlock); filter->onDestroy(); - // onBody should do nothing because the filter was destroyed. - body_callback(std::make_unique("abcde"), true); - dispatcher_->run(Event::Dispatcher::RunType::NonBlock); + ::testing::Mock::VerifyAndClearExpectations(mock_lookup_context.get()); + EXPECT_THAT(body_callback, NotNull()); + // body_callback should not be called because LookupContext::onDestroy, + // correctly implemented, should have aborted it. } TEST_F(CacheFilterTest, OnDestroyBeforeOnTrailersAbortsAction) { @@ -608,15 +668,21 @@ TEST_F(CacheFilterTest, OnDestroyBeforeOnTrailersAbortsAction) { EXPECT_CALL(*mock_lookup_context, getHeaders(_)).WillOnce([&](LookupHeadersCallback&& cb) { std::unique_ptr response_headers = std::make_unique(response_headers_); - cb(LookupResult{CacheEntryStatus::Ok, std::move(response_headers), 5, absl::nullopt}, false); + dispatcher_->post([cb = std::move(cb), + response_headers = std::move(response_headers)]() mutable { + std::move(cb)( + LookupResult{CacheEntryStatus::Ok, std::move(response_headers), 5, absl::nullopt}, false); + }); }); EXPECT_CALL(*mock_lookup_context, getBody(RangeMatcher(0, 5), _)) .WillOnce([&](const AdjustedByteRange&, LookupBodyCallback&& cb) { - cb(std::make_unique("abcde"), false); + dispatcher_->post([cb = std::move(cb)]() mutable { + std::move(cb)(std::make_unique("abcde"), false); + }); }); LookupTrailersCallback trailers_callback; EXPECT_CALL(*mock_lookup_context, getTrailers(_)).WillOnce([&](LookupTrailersCallback&& cb) { - trailers_callback = cb; + trailers_callback = std::move(cb); }); auto filter = makeFilter(mock_http_cache, false); EXPECT_EQ(filter->decodeHeaders(request_headers_, true), @@ -643,15 +709,23 @@ TEST_F(CacheFilterTest, BodyReadFromCacheLimitedToBufferSizeChunks) { EXPECT_CALL(*mock_lookup_context, getHeaders(_)).WillOnce([&](LookupHeadersCallback&& cb) { std::unique_ptr response_headers = std::make_unique(response_headers_); - cb(LookupResult{CacheEntryStatus::Ok, std::move(response_headers), 8, absl::nullopt}, false); + dispatcher_->post([cb = std::move(cb), + response_headers = std::move(response_headers)]() mutable { + std::move(cb)( + LookupResult{CacheEntryStatus::Ok, std::move(response_headers), 8, absl::nullopt}, false); + }); }); EXPECT_CALL(*mock_lookup_context, getBody(RangeMatcher(0, 5), _)) .WillOnce([&](const AdjustedByteRange&, LookupBodyCallback&& cb) { - cb(std::make_unique("abcde"), false); + dispatcher_->post([cb = std::move(cb)]() mutable { + std::move(cb)(std::make_unique("abcde"), false); + }); }); EXPECT_CALL(*mock_lookup_context, getBody(RangeMatcher(5, 8), _)) .WillOnce([&](const AdjustedByteRange&, LookupBodyCallback&& cb) { - cb(std::make_unique("fgh"), true); + dispatcher_->post([cb = std::move(cb)]() mutable { + std::move(cb)(std::make_unique("fgh"), true); + }); }); EXPECT_CALL(*mock_lookup_context, onDestroy()); @@ -703,14 +777,17 @@ TEST_F(CacheFilterTest, CacheInsertAbortedByCache) { return std::move(mock_insert_context); }); EXPECT_CALL(*mock_lookup_context, getHeaders(_)).WillOnce([&](LookupHeadersCallback&& cb) { - cb(LookupResult{}, false); + dispatcher_->post([cb = std::move(cb)]() mutable { std::move(cb)(LookupResult{}, false); }); }); EXPECT_CALL(*mock_insert_context, insertHeaders(_, _, _, false)) .WillOnce([&](const Http::ResponseHeaderMap&, const ResponseMetadata&, - InsertCallback insert_complete, bool) { insert_complete(true); }); + InsertCallback insert_complete, bool) { + dispatcher_->post([cb = std::move(insert_complete)]() mutable { std::move(cb)(true); }); + }); EXPECT_CALL(*mock_insert_context, insertBody(_, _, true)) .WillOnce([&](const Buffer::Instance&, InsertCallback ready_for_next_chunk, bool) { - ready_for_next_chunk(false); + dispatcher_->post( + [cb = std::move(ready_for_next_chunk)]() mutable { std::move(cb)(false); }); }); EXPECT_CALL(*mock_insert_context, onDestroy()); EXPECT_CALL(*mock_lookup_context, onDestroy()); @@ -753,13 +830,13 @@ TEST_F(CacheFilterTest, FilterDeletedWhileIncompleteCacheWriteInQueueShouldAband return std::move(mock_insert_context); }); EXPECT_CALL(*mock_lookup_context, getHeaders(_)).WillOnce([&](LookupHeadersCallback&& cb) { - cb(LookupResult{}, false); + dispatcher_->post([cb = std::move(cb)]() mutable { std::move(cb)(LookupResult{}, false); }); }); InsertCallback captured_insert_header_callback; EXPECT_CALL(*mock_insert_context, insertHeaders(_, _, _, false)) .WillOnce([&](const Http::ResponseHeaderMap&, const ResponseMetadata&, InsertCallback insert_complete, - bool) { captured_insert_header_callback = insert_complete; }); + bool) { captured_insert_header_callback = std::move(insert_complete); }); EXPECT_CALL(*mock_insert_context, onDestroy()); EXPECT_CALL(*mock_lookup_context, onDestroy()); { @@ -772,16 +849,14 @@ TEST_F(CacheFilterTest, FilterDeletedWhileIncompleteCacheWriteInQueueShouldAband // Encode header of response. response_headers_.setContentLength(body.size()); EXPECT_EQ(filter->encodeHeaders(response_headers_, false), Http::FilterHeadersStatus::Continue); - // Destroy the filter prematurely. + // Destroy the filter prematurely (it goes out of scope). } ASSERT_THAT(captured_insert_header_callback, NotNull()); EXPECT_THAT(weak_cache_pointer.lock(), NotNull()) << "cache instance was unexpectedly destroyed when filter was destroyed"; + // The callback should now do nothing visible, because the filter has been destroyed. + // Calling it allows the CacheInsertQueue to discard its self-ownership. captured_insert_header_callback(true); - // The callback should be posted to the dispatcher. - // Run events on the dispatcher so that the callback is invoked, - // where it should now do nothing due to the filter being destroyed. - dispatcher_->run(Event::Dispatcher::RunType::Block); } TEST_F(CacheFilterTest, FilterDeletedWhileCompleteCacheWriteInQueueShouldContinueWrite) { @@ -801,17 +876,17 @@ TEST_F(CacheFilterTest, FilterDeletedWhileCompleteCacheWriteInQueueShouldContinu return std::move(mock_insert_context); }); EXPECT_CALL(*mock_lookup_context, getHeaders(_)).WillOnce([&](LookupHeadersCallback&& cb) { - cb(LookupResult{}, false); + dispatcher_->post([cb = std::move(cb)]() mutable { std::move(cb)(LookupResult{}, false); }); }); InsertCallback captured_insert_header_callback; InsertCallback captured_insert_body_callback; EXPECT_CALL(*mock_insert_context, insertHeaders(_, _, _, false)) .WillOnce([&](const Http::ResponseHeaderMap&, const ResponseMetadata&, InsertCallback insert_complete, - bool) { captured_insert_header_callback = insert_complete; }); + bool) { captured_insert_header_callback = std::move(insert_complete); }); EXPECT_CALL(*mock_insert_context, insertBody(_, _, true)) .WillOnce([&](const Buffer::Instance&, InsertCallback ready_for_next_chunk, bool) { - captured_insert_body_callback = ready_for_next_chunk; + captured_insert_body_callback = std::move(ready_for_next_chunk); }); EXPECT_CALL(*mock_insert_context, onDestroy()); EXPECT_CALL(*mock_lookup_context, onDestroy()); @@ -1391,12 +1466,15 @@ TEST_F(CacheFilterDeathTest, BadRangeRequestLookup) { return std::move(mock_lookup_context); }); EXPECT_CALL(*mock_lookup_context, getHeaders(_)).WillOnce([&](LookupHeadersCallback&& cb) { - // LookupResult with unknown length and an unsatisfiable RangeDetails is invalid. - cb(LookupResult{CacheEntryStatus::Ok, - std::make_unique(response_headers_), - absl::nullopt, - RangeDetails{/*satisfiable_ = */ false, {AdjustedByteRange{0, 5}}}}, - false); + dispatcher_->post([cb = std::move(cb), this]() mutable { + // LookupResult with unknown length and an unsatisfiable RangeDetails is invalid. + std::move(cb)( + LookupResult{CacheEntryStatus::Ok, + std::make_unique(response_headers_), + absl::nullopt, + RangeDetails{/*satisfiable_ = */ false, {AdjustedByteRange{0, 5}}}}, + false); + }); }); EXPECT_CALL(*mock_lookup_context, onDestroy()); { @@ -1429,15 +1507,20 @@ TEST_F(CacheFilterTest, RangeRequestSatisfiedBeforeLengthKnown) { }); EXPECT_CALL(*mock_lookup_context, getHeaders(_)).WillOnce([&](LookupHeadersCallback&& cb) { // LookupResult with unknown length and an unsatisfiable RangeDetails is invalid. - cb(LookupResult{CacheEntryStatus::Ok, - std::make_unique(response_headers_), - absl::nullopt, - RangeDetails{/*satisfiable_ = */ true, {AdjustedByteRange{0, 5}}}}, - false); + dispatcher_->post([cb = std::move(cb), this]() mutable { + std::move(cb)( + LookupResult{CacheEntryStatus::Ok, + std::make_unique(response_headers_), + absl::nullopt, + RangeDetails{/*satisfiable_ = */ true, {AdjustedByteRange{0, 5}}}}, + false); + }); }); EXPECT_CALL(*mock_lookup_context, getBody(RangeMatcher(0, 5), _)) .WillOnce([&](AdjustedByteRange, LookupBodyCallback&& cb) { - cb(std::make_unique(body), false); + dispatcher_->post([cb = std::move(cb), &body]() mutable { + cb(std::make_unique(body), false); + }); }); EXPECT_CALL(*mock_lookup_context, onDestroy()); { diff --git a/test/extensions/filters/http/cache/http_cache_implementation_test_common.cc b/test/extensions/filters/http/cache/http_cache_implementation_test_common.cc index c5df0fa728f6..d4c6bb578b01 100644 --- a/test/extensions/filters/http/cache/http_cache_implementation_test_common.cc +++ b/test/extensions/filters/http/cache/http_cache_implementation_test_common.cc @@ -153,6 +153,20 @@ absl::Status HttpCacheImplementationTest::insert( return absl::OkStatus(); } +LookupContextPtr HttpCacheImplementationTest::lookupContextWithAllParts() { + absl::string_view path = "/common"; + Http::TestResponseHeaderMapImpl response_headers{ + {":status", "200"}, + {"date", formatter_.fromTime(time_system_.systemTime())}, + {"cache-control", "public,max-age=3600"}}; + Http::TestResponseTrailerMapImpl response_trailers{ + {"common-trailer", "irrelevant value"}, + }; + EXPECT_THAT(insert(lookup(path), response_headers, "commonbody", response_trailers), IsOk()); + LookupRequest request = makeLookupRequest(path); + return cache()->makeLookupContext(std::move(request), decoder_callbacks_); +} + absl::Status HttpCacheImplementationTest::insert(absl::string_view request_path, const Http::TestResponseHeaderMapImpl& headers, const absl::string_view body) { @@ -777,6 +791,48 @@ TEST_P(HttpCacheImplementationTest, EmptyTrailers) { EXPECT_TRUE(expectLookupSuccessWithBodyAndTrailers(name_lookup_context.get(), body1)); } +TEST_P(HttpCacheImplementationTest, DoesNotRunHeadersCallbackWhenCancelledAfterPosted) { + bool was_called = false; + { + LookupContextPtr context = lookupContextWithAllParts(); + context->getHeaders([&was_called](LookupResult&&, bool) { was_called = true; }); + pumpIntoDispatcher(); + context->onDestroy(); + } + pumpDispatcher(); + EXPECT_FALSE(was_called); +} + +TEST_P(HttpCacheImplementationTest, DoesNotRunBodyCallbackWhenCancelledAfterPosted) { + bool was_called = false; + { + LookupContextPtr context = lookupContextWithAllParts(); + context->getHeaders([](LookupResult&&, bool) {}); + pumpDispatcher(); + context->getBody({0, 10}, [&was_called](Buffer::InstancePtr&&, bool) { was_called = true; }); + pumpIntoDispatcher(); + context->onDestroy(); + } + pumpDispatcher(); + EXPECT_FALSE(was_called); +} + +TEST_P(HttpCacheImplementationTest, DoesNotRunTrailersCallbackWhenCancelledAfterPosted) { + bool was_called = false; + { + LookupContextPtr context = lookupContextWithAllParts(); + context->getHeaders([](LookupResult&&, bool) {}); + pumpDispatcher(); + context->getBody({0, 10}, [](Buffer::InstancePtr&&, bool) {}); + pumpDispatcher(); + context->getTrailers([&was_called](Http::ResponseTrailerMapPtr&&) { was_called = true; }); + pumpIntoDispatcher(); + context->onDestroy(); + } + pumpDispatcher(); + EXPECT_FALSE(was_called); +} + } // namespace Cache } // namespace HttpFilters } // namespace Extensions diff --git a/test/extensions/filters/http/cache/http_cache_implementation_test_common.h b/test/extensions/filters/http/cache/http_cache_implementation_test_common.h index 497bfcad81eb..26f20312fd70 100644 --- a/test/extensions/filters/http/cache/http_cache_implementation_test_common.h +++ b/test/extensions/filters/http/cache/http_cache_implementation_test_common.h @@ -63,6 +63,7 @@ class HttpCacheImplementationTest std::shared_ptr cache() const { return delegate_->cache(); } bool validationEnabled() const { return delegate_->validationEnabled(); } + void pumpIntoDispatcher() { delegate_->beforePumpingDispatcher(); } void pumpDispatcher() { delegate_->pumpDispatcher(); } LookupContextPtr lookup(absl::string_view request_path); @@ -91,6 +92,8 @@ class HttpCacheImplementationTest LookupRequest makeLookupRequest(absl::string_view request_path); + LookupContextPtr lookupContextWithAllParts(); + testing::AssertionResult expectLookupSuccessWithHeaders(LookupContext* lookup_context, const Http::TestResponseHeaderMapImpl& headers); diff --git a/test/extensions/filters/http/cache/mocks.h b/test/extensions/filters/http/cache/mocks.h index 0a243fab1910..d1739f958448 100644 --- a/test/extensions/filters/http/cache/mocks.h +++ b/test/extensions/filters/http/cache/mocks.h @@ -17,7 +17,7 @@ class MockHttpCache : public HttpCache { (LookupContextPtr && lookup_context, Http::StreamEncoderFilterCallbacks& callbacks)); MOCK_METHOD(void, updateHeaders, (const LookupContext& lookup_context, const Http::ResponseHeaderMap& response_headers, - const ResponseMetadata& metadata, std::function on_complete)); + const ResponseMetadata& metadata, absl::AnyInvocable on_complete)); MOCK_METHOD(CacheInfo, cacheInfo, (), (const)); };