diff --git a/changelogs/current.yaml b/changelogs/current.yaml index d4568aa6bdf9..79947a08c0e2 100644 --- a/changelogs/current.yaml +++ b/changelogs/current.yaml @@ -232,6 +232,9 @@ removed_config_or_runtime: Removed runtime flag ``envoy.reloadable_features.exclude_host_in_eds_status_draining``. new_features: +- area: redis + change: | + Added support for keys and select. - area: wasm change: | Added the wasm vm reload support to reload wasm vm when the wasm vm is failed with runtime errors. See diff --git a/docs/root/intro/arch_overview/other_protocols/redis.rst b/docs/root/intro/arch_overview/other_protocols/redis.rst index 9c5fe12ea687..c3890510653b 100644 --- a/docs/root/intro/arch_overview/other_protocols/redis.rst +++ b/docs/root/intro/arch_overview/other_protocols/redis.rst @@ -157,11 +157,13 @@ For details on each command's usage see the official EXISTS, Generic EXPIRE, Generic EXPIREAT, Generic + KEYS, String PERSIST, Generic PEXPIRE, Generic PEXPIREAT, Generic PTTL, Generic RESTORE, Generic + SELECT, Generic TOUCH, Generic TTL, Generic TYPE, Generic @@ -300,7 +302,7 @@ Envoy can also generate its own errors in response to the client. the connection." invalid request, "Command was rejected by the first stage of the command splitter due to datatype or length." - unsupported command, "The command was not recognized by Envoy and therefore cannot be serviced + ERR unknown command, "The command was not recognized by Envoy and therefore cannot be serviced because it cannot be hashed to a backend server." finished with n errors, "Fragmented commands which sum the response (e.g. DEL) will return the total number of errors received if any were received." diff --git a/source/extensions/clusters/redis/redis_cluster_lb.cc b/source/extensions/clusters/redis/redis_cluster_lb.cc index 476fda7cd59c..934c4028d052 100644 --- a/source/extensions/clusters/redis/redis_cluster_lb.cc +++ b/source/extensions/clusters/redis/redis_cluster_lb.cc @@ -1,5 +1,7 @@ #include "source/extensions/clusters/redis/redis_cluster_lb.h" +#include + namespace Envoy { namespace Extensions { namespace Clusters { @@ -138,8 +140,17 @@ Upstream::HostConstSharedPtr RedisClusterLoadBalancerFactory::RedisClusterLoadBa return nullptr; } - auto shard = shard_vector_->at( - slot_array_->at(hash.value() % Envoy::Extensions::Clusters::Redis::MaxSlot)); + RedisShardSharedPtr shard; + if (dynamic_cast(context)) { + if (hash.value() < shard_vector_->size()) { + shard = shard_vector_->at(hash.value()); + } else { + return nullptr; + } + } else { + shard = shard_vector_->at( + slot_array_->at(hash.value() % Envoy::Extensions::Clusters::Redis::MaxSlot)); + } auto redis_context = dynamic_cast(context); if (redis_context && redis_context->isReadCommand()) { @@ -213,6 +224,12 @@ absl::string_view RedisLoadBalancerContextImpl::hashtag(absl::string_view v, boo return v.substr(start + 1, end - start - 1); } +RedisSpecifyShardContextImpl::RedisSpecifyShardContextImpl( + uint64_t shard_index, const NetworkFilters::Common::Redis::RespValue& request, + NetworkFilters::Common::Redis::Client::ReadPolicy read_policy) + : RedisLoadBalancerContextImpl(std::to_string(shard_index), true, true, request, read_policy), + shard_index_(shard_index) {} + } // namespace Redis } // namespace Clusters } // namespace Extensions diff --git a/source/extensions/clusters/redis/redis_cluster_lb.h b/source/extensions/clusters/redis/redis_cluster_lb.h index bd86f8e07398..496218e8d080 100644 --- a/source/extensions/clusters/redis/redis_cluster_lb.h +++ b/source/extensions/clusters/redis/redis_cluster_lb.h @@ -113,6 +113,26 @@ class RedisLoadBalancerContextImpl : public RedisLoadBalancerContext, const NetworkFilters::Common::Redis::Client::ReadPolicy read_policy_; }; +class RedisSpecifyShardContextImpl : public RedisLoadBalancerContextImpl { +public: + /** + * The redis specify Shard load balancer context for Redis requests. + * @param shard_index specify the shard index for the Redis request. + * @param request specify the Redis request. + * @param read_policy specify the read policy. + */ + RedisSpecifyShardContextImpl(uint64_t shard_index, + const NetworkFilters::Common::Redis::RespValue& request, + NetworkFilters::Common::Redis::Client::ReadPolicy read_policy = + NetworkFilters::Common::Redis::Client::ReadPolicy::Primary); + + // Upstream::LoadBalancerContextBase + absl::optional computeHashKey() override { return shard_index_; } + +private: + const absl::optional shard_index_; +}; + class ClusterSlotUpdateCallBack { public: virtual ~ClusterSlotUpdateCallBack() = default; diff --git a/source/extensions/filters/network/common/redis/supported_commands.h b/source/extensions/filters/network/common/redis/supported_commands.h index fc5d5b15dea1..16df8ab41602 100644 --- a/source/extensions/filters/network/common/redis/supported_commands.h +++ b/source/extensions/filters/network/common/redis/supported_commands.h @@ -79,6 +79,11 @@ struct SupportedCommands { */ static const std::string& mset() { CONSTRUCT_ON_FIRST_USE(std::string, "mset"); } + /** + * @return keys command + */ + static const std::string& keys() { CONSTRUCT_ON_FIRST_USE(std::string, "keys"); } + /** * @return ping command */ @@ -94,6 +99,11 @@ struct SupportedCommands { */ static const std::string& quit() { CONSTRUCT_ON_FIRST_USE(std::string, "quit"); } + /** + * @return select command + */ + static const std::string& select() { CONSTRUCT_ON_FIRST_USE(std::string, "select"); } + /** * @return commands which alters the state of redis */ @@ -112,6 +122,14 @@ struct SupportedCommands { static bool isReadCommand(const std::string& command) { return !writeCommands().contains(command); } + + static bool isSupportedCommand(const std::string& command) { + return (simpleCommands().contains(command) || evalCommands().contains(command) || + hashMultipleSumResultCommands().contains(command) || + transactionCommands().contains(command) || auth() == command || echo() == command || + mget() == command || mset() == command || keys() == command || ping() == command || + time() == command || quit() == command || select() == command); + } }; } // namespace Redis diff --git a/source/extensions/filters/network/redis_proxy/command_splitter_impl.cc b/source/extensions/filters/network/redis_proxy/command_splitter_impl.cc index 2bed56c82744..ff39761e523d 100644 --- a/source/extensions/filters/network/redis_proxy/command_splitter_impl.cc +++ b/source/extensions/filters/network/redis_proxy/command_splitter_impl.cc @@ -1,5 +1,7 @@ #include "source/extensions/filters/network/redis_proxy/command_splitter_impl.h" +#include + #include "source/common/common/logger.h" #include "source/extensions/filters/network/common/redis/supported_commands.h" @@ -75,6 +77,35 @@ makeFragmentedRequest(const RouteSharedPtr& route, const std::string& command, return handler; } +/** + * Make request and maybe mirror the request based on the mirror policies of the route. + * @param route supplies the route matched with the request. + * @param command supplies the command of the request. + * @param key supplies the key of the request. + * @param incoming_request supplies the request. + * @param callbacks supplies the request completion callbacks. + * @param transaction supplies the transaction info of the current connection. + * @return PoolRequest* a handle to the active request or nullptr if the request could not be made + * for some reason. + */ +Common::Redis::Client::PoolRequest* +makeFragmentedRequestToShard(const RouteSharedPtr& route, const std::string& command, + uint16_t shard_index, const Common::Redis::RespValue& incoming_request, + ConnPool::PoolCallbacks& callbacks, + Common::Redis::Client::Transaction& transaction) { + auto handler = route->upstream(command)->makeRequestToShard( + shard_index, ConnPool::RespVariant(incoming_request), callbacks, transaction); + if (handler) { + for (auto& mirror_policy : route->mirrorPolicies()) { + if (mirror_policy->shouldMirror(command)) { + mirror_policy->upstream()->makeRequestToShard( + shard_index, ConnPool::RespVariant(incoming_request), null_pool_callbacks, transaction); + } + } + } + return handler; +} + // Send a string response downstream. void localResponse(SplitCallbacks& callbacks, std::string response) { Common::Redis::RespValuePtr res(new Common::Redis::RespValue()); @@ -385,6 +416,80 @@ void MSETRequest::onChildResponse(Common::Redis::RespValuePtr&& value, uint32_t } } +SplitRequestPtr KeysRequest::create(Router& router, Common::Redis::RespValuePtr&& incoming_request, + SplitCallbacks& callbacks, CommandStats& command_stats, + TimeSource& time_source, bool delay_command_latency, + const StreamInfo::StreamInfo& stream_info) { + if (incoming_request->asArray().size() != 2) { + onWrongNumberOfArguments(callbacks, *incoming_request); + command_stats.error_.inc(); + return nullptr; + } + const auto route = router.upstreamPool(incoming_request->asArray()[1].asString(), stream_info); + uint32_t shard_size = + route ? route->upstream(incoming_request->asArray()[0].asString())->shardSize() : 0; + if (shard_size == 0) { + command_stats.error_.inc(); + callbacks.onResponse(Common::Redis::Utility::makeError(Response::get().NoUpstreamHost)); + return nullptr; + } + + std::unique_ptr request_ptr{ + new KeysRequest(callbacks, command_stats, time_source, delay_command_latency)}; + request_ptr->num_pending_responses_ = shard_size; + request_ptr->pending_requests_.reserve(request_ptr->num_pending_responses_); + + request_ptr->pending_response_ = std::make_unique(); + request_ptr->pending_response_->type(Common::Redis::RespType::Array); + + Common::Redis::RespValueSharedPtr base_request = std::move(incoming_request); + for (uint32_t shard_index = 0; shard_index < shard_size; shard_index++) { + request_ptr->pending_requests_.emplace_back(*request_ptr, shard_index); + PendingRequest& pending_request = request_ptr->pending_requests_.back(); + + ENVOY_LOG(debug, "keys request shard index {}: {}", shard_index, base_request->toString()); + pending_request.handle_ = + makeFragmentedRequestToShard(route, base_request->asArray()[0].asString(), shard_index, + *base_request, pending_request, callbacks.transaction()); + + if (!pending_request.handle_) { + pending_request.onResponse(Common::Redis::Utility::makeError(Response::get().NoUpstreamHost)); + } + } + + if (request_ptr->num_pending_responses_ > 0) { + return request_ptr; + } + + return nullptr; +} + +void KeysRequest::onChildResponse(Common::Redis::RespValuePtr&& value, uint32_t index) { + pending_requests_[index].handle_ = nullptr; + switch (value->type()) { + case Common::Redis::RespType::Array: { + pending_response_->asArray().insert(pending_response_->asArray().end(), + value->asArray().begin(), value->asArray().end()); + break; + } + default: { + error_count_++; + break; + } + } + + ASSERT(num_pending_responses_ > 0); + if (--num_pending_responses_ == 0) { + updateStats(error_count_ == 0); + if (error_count_ == 0) { + callbacks_.onResponse(std::move(pending_response_)); + } else { + callbacks_.onResponse(Common::Redis::Utility::makeError( + fmt::format("finished with {} error(s)", error_count_))); + } + } +} + SplitRequestPtr SplitKeysSumResultRequest::create(Router& router, Common::Redis::RespValuePtr&& incoming_request, SplitCallbacks& callbacks, CommandStats& command_stats, @@ -593,7 +698,7 @@ InstanceImpl::InstanceImpl(RouterPtr&& router, Stats::Scope& scope, const std::s Common::Redis::FaultManagerPtr&& fault_manager) : router_(std::move(router)), simple_command_handler_(*router_), eval_command_handler_(*router_), mget_handler_(*router_), mset_handler_(*router_), - split_keys_sum_result_handler_(*router_), + keys_handler_(*router_), split_keys_sum_result_handler_(*router_), transaction_handler_(*router_), stats_{ALL_COMMAND_SPLITTER_STATS( POOL_COUNTER_PREFIX(scope, stat_prefix + "splitter."))}, time_source_(time_source), fault_manager_(std::move(fault_manager)) { @@ -616,6 +721,9 @@ InstanceImpl::InstanceImpl(RouterPtr&& router, Stats::Scope& scope, const std::s addHandler(scope, stat_prefix, Common::Redis::SupportedCommands::mset(), latency_in_micros, mset_handler_); + addHandler(scope, stat_prefix, Common::Redis::SupportedCommands::keys(), latency_in_micros, + keys_handler_); + for (const std::string& command : Common::Redis::SupportedCommands::transactionCommands()) { addHandler(scope, stat_prefix, command, latency_in_micros, transaction_handler_); } @@ -637,6 +745,15 @@ SplitRequestPtr InstanceImpl::makeRequest(Common::Redis::RespValuePtr&& request, } std::string command_name = absl::AsciiStrToLower(request->asArray()[0].asString()); + // Compatible with redis behavior, if there is an unsupported command, return immediately, + // this action must be performed before verifying auth, some redis clients rely on this behavior. + if (!Common::Redis::SupportedCommands::isSupportedCommand(command_name)) { + stats_.unsupported_command_.inc(); + callbacks.onResponse(Common::Redis::Utility::makeError(fmt::format( + "ERR unknown command '{}', with args beginning with: {}", request->asArray()[0].asString(), + request->asArray().size() > 1 ? request->asArray()[1].asString() : ""))); + return nullptr; + } if (command_name == Common::Redis::SupportedCommands::auth()) { if (request->asArray().size() < 2) { @@ -704,6 +821,16 @@ SplitRequestPtr InstanceImpl::makeRequest(Common::Redis::RespValuePtr&& request, return nullptr; } + if (command_name == Common::Redis::SupportedCommands::select()) { + // Respond to OK locally. + if (request->asArray().size() != 2) { + onInvalidRequest(callbacks); + return nullptr; + } + localResponse(callbacks, "OK"); + return nullptr; + } + if (command_name == Common::Redis::SupportedCommands::quit()) { callbacks.onQuit(); return nullptr; @@ -718,12 +845,7 @@ SplitRequestPtr InstanceImpl::makeRequest(Common::Redis::RespValuePtr&& request, // Get the handler for the downstream request auto handler = handler_lookup_table_.find(command_name.c_str()); - if (handler == nullptr) { - stats_.unsupported_command_.inc(); - callbacks.onResponse(Common::Redis::Utility::makeError( - fmt::format("unsupported command '{}'", request->asArray()[0].asString()))); - return nullptr; - } + ASSERT(handler != nullptr); // If we are within a transaction, forward all requests to the transaction handler (i.e. handler // of "multi" command). diff --git a/source/extensions/filters/network/redis_proxy/command_splitter_impl.h b/source/extensions/filters/network/redis_proxy/command_splitter_impl.h index fdc94717e80c..b4a598c10786 100644 --- a/source/extensions/filters/network/redis_proxy/command_splitter_impl.h +++ b/source/extensions/filters/network/redis_proxy/command_splitter_impl.h @@ -284,6 +284,26 @@ class MGETRequest : public FragmentedRequest { void onChildResponse(Common::Redis::RespValuePtr&& value, uint32_t index) override; }; +/** + * KeysRequest sends the command to all Redis server. The response from each Redis (which + * must be an array) is merged and returned to the user. If there is any error or failure in + * processing the fragmented commands, an error will be returned. + */ +class KeysRequest : public FragmentedRequest { +public: + static SplitRequestPtr create(Router& router, Common::Redis::RespValuePtr&& incoming_request, + SplitCallbacks& callbacks, CommandStats& command_stats, + TimeSource& time_source, bool delay_command_latency, + const StreamInfo::StreamInfo& stream_info); + +private: + KeysRequest(SplitCallbacks& callbacks, CommandStats& command_stats, TimeSource& time_source, + bool delay_command_latency) + : FragmentedRequest(callbacks, command_stats, time_source, delay_command_latency) {} + // RedisProxy::CommandSplitter::FragmentedRequest + void onChildResponse(Common::Redis::RespValuePtr&& value, uint32_t index) override; +}; + /** * SplitKeysSumResultRequest takes each key from the command and sends the same incoming command * with each key to the appropriate Redis server. The response from each Redis (which must be an @@ -390,6 +410,7 @@ class InstanceImpl : public Instance, Logger::Loggable { CommandHandlerFactory eval_command_handler_; CommandHandlerFactory mget_handler_; CommandHandlerFactory mset_handler_; + CommandHandlerFactory keys_handler_; CommandHandlerFactory split_keys_sum_result_handler_; CommandHandlerFactory transaction_handler_; TrieLookupTable handler_lookup_table_; diff --git a/source/extensions/filters/network/redis_proxy/conn_pool.h b/source/extensions/filters/network/redis_proxy/conn_pool.h index 3beeebd0a2b8..d98e6a362d02 100644 --- a/source/extensions/filters/network/redis_proxy/conn_pool.h +++ b/source/extensions/filters/network/redis_proxy/conn_pool.h @@ -52,6 +52,7 @@ class Instance { public: virtual ~Instance() = default; + virtual uint16_t shardSize() PURE; /** * Makes a redis request. * @param hash_key supplies the key to use for consistent hashing. @@ -64,6 +65,18 @@ class Instance { virtual Common::Redis::Client::PoolRequest* makeRequest(const std::string& hash_key, RespVariant&& request, PoolCallbacks& callbacks, Common::Redis::Client::Transaction& transaction) PURE; + /** + * Makes a redis request. + * @param shard_index supplies the key to use for consistent hashing. + * @param request supplies the request to make. + * @param callbacks supplies the request completion callbacks. + * @param transaction supplies the transaction info of the current connection. + * @return PoolRequest* a handle to the active request or nullptr if the request could not be made + * for some reason. + */ + virtual Common::Redis::Client::PoolRequest* + makeRequestToShard(uint16_t shard_index, RespVariant&& request, PoolCallbacks& callbacks, + Common::Redis::Client::Transaction& transaction) PURE; }; using InstanceSharedPtr = std::shared_ptr; diff --git a/source/extensions/filters/network/redis_proxy/conn_pool_impl.cc b/source/extensions/filters/network/redis_proxy/conn_pool_impl.cc index f1e86fee798c..02e13794d385 100644 --- a/source/extensions/filters/network/redis_proxy/conn_pool_impl.cc +++ b/source/extensions/filters/network/redis_proxy/conn_pool_impl.cc @@ -69,6 +69,8 @@ void InstanceImpl::init() { }); } +uint16_t InstanceImpl::shardSize() { return tls_->getTyped().shardSize(); } + // This method is always called from a InstanceSharedPtr we don't have to worry about tls_->getTyped // failing due to InstanceImpl going away. Common::Redis::Client::PoolRequest* @@ -87,6 +89,16 @@ InstanceImpl::makeRequestToHost(const std::string& host_address, return tls_->getTyped().makeRequestToHost(host_address, request, callbacks); } +// This method is always called from a InstanceSharedPtr we don't have to worry about tls_->getTyped +// failing due to InstanceImpl going away. +Common::Redis::Client::PoolRequest* +InstanceImpl::makeRequestToShard(uint16_t shard_index, RespVariant&& request, + PoolCallbacks& callbacks, + Common::Redis::Client::Transaction& transaction) { + return tls_->getTyped().makeRequestToShard(shard_index, std::move(request), + callbacks, transaction); +} + InstanceImpl::ThreadLocalPool::ThreadLocalPool( std::shared_ptr parent, Event::Dispatcher& dispatcher, std::string cluster_name, const Extensions::Common::DynamicForwardProxy::DnsCacheSharedPtr& dns_cache) @@ -268,6 +280,25 @@ InstanceImpl::ThreadLocalPool::threadLocalActiveClient(Upstream::HostConstShared return client; } +uint16_t InstanceImpl::ThreadLocalPool::shardSize() { + if (cluster_ == nullptr) { + ASSERT(client_map_.empty()); + ASSERT(host_set_member_update_cb_handle_ == nullptr); + return 0; + } + + Common::Redis::RespValue request; + for (uint16_t size = 0;; size++) { + Clusters::Redis::RedisSpecifyShardContextImpl lb_context( + size, request, Common::Redis::Client::ReadPolicy::Primary); + Upstream::HostConstSharedPtr host = cluster_->loadBalancer().chooseHost(&lb_context); + if (!host) { + return size; + } + } + return 0; +} + Common::Redis::Client::PoolRequest* InstanceImpl::ThreadLocalPool::makeRequest(const std::string& key, RespVariant&& request, PoolCallbacks& callbacks, @@ -288,42 +319,29 @@ InstanceImpl::ThreadLocalPool::makeRequest(const std::string& key, RespVariant&& return nullptr; } - uint32_t client_idx = transaction.current_client_idx_; - // If there is an active transaction, establish a new connection if necessary. - if (transaction.active_ && !transaction.connection_established_) { - transaction.clients_[client_idx] = - client_factory_.create(host, dispatcher_, config_, redis_command_stats_, *(stats_scope_), - auth_username_, auth_password_, true); - if (transaction.connection_cb_) { - transaction.clients_[client_idx]->addConnectionCallbacks(*transaction.connection_cb_); - } - } - - pending_requests_.emplace_back(*this, std::move(request), callbacks, host); - PendingRequest& pending_request = pending_requests_.back(); + return makeRequestToHost(host, std::move(request), callbacks, transaction); +} - if (!transaction.active_) { - ThreadLocalActiveClientPtr& client = this->threadLocalActiveClient(host); - if (!client) { - ENVOY_LOG(debug, "redis connection is rate limited, erasing empty client"); - pending_request.request_handler_ = nullptr; - onRequestCompleted(); - client_map_.erase(host); - return nullptr; - } - pending_request.request_handler_ = client->redis_client_->makeRequest( - getRequest(pending_request.incoming_request_), pending_request); - } else { - pending_request.request_handler_ = transaction.clients_[client_idx]->makeRequest( - getRequest(pending_request.incoming_request_), pending_request); +Common::Redis::Client::PoolRequest* +InstanceImpl::ThreadLocalPool::makeRequestToShard(uint16_t shard_index, RespVariant&& request, + PoolCallbacks& callbacks, + Common::Redis::Client::Transaction& transaction) { + if (cluster_ == nullptr) { + ASSERT(client_map_.empty()); + ASSERT(host_set_member_update_cb_handle_ == nullptr); + return nullptr; } - if (pending_request.request_handler_) { - return &pending_request; - } else { - onRequestCompleted(); + Clusters::Redis::RedisSpecifyShardContextImpl lb_context( + shard_index, getRequest(request), + transaction.active_ ? Common::Redis::Client::ReadPolicy::Primary : config_->readPolicy()); + + Upstream::HostConstSharedPtr host = cluster_->loadBalancer().chooseHost(&lb_context); + if (!host) { + ENVOY_LOG(debug, "host not found: '{}'", shard_index); return nullptr; } + return makeRequestToHost(host, std::move(request), callbacks, transaction); } Common::Redis::Client::PoolRequest* InstanceImpl::ThreadLocalPool::makeRequestToHost( @@ -402,6 +420,48 @@ Common::Redis::Client::PoolRequest* InstanceImpl::ThreadLocalPool::makeRequestTo return client->redis_client_->makeRequest(request, callbacks); } +Common::Redis::Client::PoolRequest* +InstanceImpl::ThreadLocalPool::makeRequestToHost(Upstream::HostConstSharedPtr& host, + RespVariant&& request, PoolCallbacks& callbacks, + Common::Redis::Client::Transaction& transaction) { + uint32_t client_idx = transaction.current_client_idx_; + // If there is an active transaction, establish a new connection if necessary. + if (transaction.active_ && !transaction.connection_established_) { + transaction.clients_[client_idx] = + client_factory_.create(host, dispatcher_, config_, redis_command_stats_, *(stats_scope_), + auth_username_, auth_password_, true); + if (transaction.connection_cb_) { + transaction.clients_[client_idx]->addConnectionCallbacks(*transaction.connection_cb_); + } + } + + pending_requests_.emplace_back(*this, std::move(request), callbacks, host); + PendingRequest& pending_request = pending_requests_.back(); + + if (!transaction.active_) { + ThreadLocalActiveClientPtr& client = this->threadLocalActiveClient(host); + if (!client) { + ENVOY_LOG(debug, "redis connection is rate limited, erasing empty client"); + pending_request.request_handler_ = nullptr; + onRequestCompleted(); + client_map_.erase(host); + return nullptr; + } + pending_request.request_handler_ = client->redis_client_->makeRequest( + getRequest(pending_request.incoming_request_), pending_request); + } else { + pending_request.request_handler_ = transaction.clients_[client_idx]->makeRequest( + getRequest(pending_request.incoming_request_), pending_request); + } + + if (pending_request.request_handler_) { + return &pending_request; + } else { + onRequestCompleted(); + return nullptr; + } +} + void InstanceImpl::ThreadLocalPool::onRequestCompleted() { ASSERT(!pending_requests_.empty()); diff --git a/source/extensions/filters/network/redis_proxy/conn_pool_impl.h b/source/extensions/filters/network/redis_proxy/conn_pool_impl.h index 82c152793911..4d6d073c330b 100644 --- a/source/extensions/filters/network/redis_proxy/conn_pool_impl.h +++ b/source/extensions/filters/network/redis_proxy/conn_pool_impl.h @@ -66,10 +66,14 @@ class InstanceImpl : public Instance, public std::enable_shared_from_this& hosts_added); diff --git a/test/extensions/clusters/redis/redis_cluster_lb_test.cc b/test/extensions/clusters/redis/redis_cluster_lb_test.cc index f837f774f059..f3e4fe957d8d 100644 --- a/test/extensions/clusters/redis/redis_cluster_lb_test.cc +++ b/test/extensions/clusters/redis/redis_cluster_lb_test.cc @@ -152,6 +152,49 @@ TEST_F(RedisClusterLoadBalancerTest, Basic) { validateAssignment(hosts, expected_assignments); } +TEST_F(RedisClusterLoadBalancerTest, Shard) { + Upstream::HostVector hosts{Upstream::makeTestHost(info_, "tcp://127.0.0.1:90", simTime()), + Upstream::makeTestHost(info_, "tcp://127.0.0.1:91", simTime()), + Upstream::makeTestHost(info_, "tcp://127.0.0.1:92", simTime())}; + + ClusterSlotsPtr slots = std::make_unique>(std::vector{ + ClusterSlot(0, 1000, hosts[0]->address()), + ClusterSlot(1001, 2000, hosts[1]->address()), + ClusterSlot(2001, 16383, hosts[2]->address()), + }); + Upstream::HostMap all_hosts{ + {hosts[0]->address()->asString(), hosts[0]}, + {hosts[1]->address()->asString(), hosts[1]}, + {hosts[2]->address()->asString(), hosts[2]}, + }; + init(); + factory_->onClusterSlotUpdate(std::move(slots), all_hosts); + + // A list of (hash: host_index) pair + // Simple read command + std::vector get_foo(2); + get_foo[0].type(NetworkFilters::Common::Redis::RespType::BulkString); + get_foo[0].asString() = "get"; + get_foo[1].type(NetworkFilters::Common::Redis::RespType::BulkString); + get_foo[1].asString() = "foo"; + + NetworkFilters::Common::Redis::RespValue get_request; + get_request.type(NetworkFilters::Common::Redis::RespType::Array); + get_request.asArray().swap(get_foo); + + Upstream::LoadBalancerPtr lb = lb_->factory()->create(lb_params_); + for (uint16_t i = 0; i < 5; i++) { + RedisSpecifyShardContextImpl context(i, get_request); + auto host = lb->chooseHost(&context); + if (i < 3) { + EXPECT_FALSE(host == nullptr); + EXPECT_EQ(hosts[i]->address()->asString(), host->address()->asString()); + } else { + EXPECT_TRUE(host == nullptr); + } + } +} + TEST_F(RedisClusterLoadBalancerTest, ReadStrategiesHealthy) { Upstream::HostVector hosts{ Upstream::makeTestHost(info_, "tcp://127.0.0.1:90", simTime()), diff --git a/test/extensions/filters/network/redis_proxy/command_splitter_impl_test.cc b/test/extensions/filters/network/redis_proxy/command_splitter_impl_test.cc index 7a24ed7416ba..aec098bde159 100644 --- a/test/extensions/filters/network/redis_proxy/command_splitter_impl_test.cc +++ b/test/extensions/filters/network/redis_proxy/command_splitter_impl_test.cc @@ -1,6 +1,7 @@ #include #include #include +#include #include #include "source/common/common/fmt.h" @@ -91,6 +92,16 @@ TEST_F(RedisCommandSplitterImplTest, QuitSuccess) { EXPECT_EQ(0UL, store_.counter("redis.foo.splitter.invalid_request").value()); } +TEST_F(RedisCommandSplitterImplTest, AuthWithUser) { + EXPECT_CALL(callbacks_, onAuth("user", "password")); + Common::Redis::RespValuePtr request{new Common::Redis::RespValue()}; + makeBulkStringArray(*request, {"auth", "user", "password"}); + EXPECT_EQ(nullptr, + splitter_.makeRequest(std::move(request), callbacks_, dispatcher_, stream_info_)); + + EXPECT_EQ(0UL, store_.counter("redis.foo.splitter.invalid_request").value()); +} + TEST_F(RedisCommandSplitterImplTest, AuthWithNoPassword) { Common::Redis::RespValue response; response.type(Common::Redis::RespType::Error); @@ -172,8 +183,7 @@ TEST_F(RedisCommandSplitterImplTest, InvalidRequestArrayNotStrings) { TEST_F(RedisCommandSplitterImplTest, UnsupportedCommand) { Common::Redis::RespValue response; response.type(Common::Redis::RespType::Error); - response.asString() = "unsupported command 'newcommand'"; - EXPECT_CALL(callbacks_, connectionAllowed()).WillOnce(Return(true)); + response.asString() = "ERR unknown command 'newcommand', with args beginning with: hello"; EXPECT_CALL(callbacks_, onResponse_(PointeesEq(&response))); Common::Redis::RespValuePtr request{new Common::Redis::RespValue()}; makeBulkStringArray(*request, {"newcommand", "hello"}); @@ -422,6 +432,22 @@ TEST_F(RedisSingleServerRequestTest, EchoSuccess) { EXPECT_EQ(nullptr, handle_); }; +TEST_F(RedisSingleServerRequestTest, EchoInvalid) { + InSequence s; + + Common::Redis::RespValuePtr request{new Common::Redis::RespValue()}; + makeBulkStringArray(*request, {"echo", "hello", "world"}); + + Common::Redis::RespValue response; + response.type(Common::Redis::RespType::Error); + response.asString() = RedisProxy::CommandSplitter::Response::get().InvalidRequest; + + EXPECT_CALL(callbacks_, connectionAllowed()).WillOnce(Return(true)); + EXPECT_CALL(callbacks_, onResponse_(PointeesEq(&response))); + handle_ = splitter_.makeRequest(std::move(request), callbacks_, dispatcher_, stream_info_); + EXPECT_EQ(nullptr, handle_); +}; + TEST_F(RedisSingleServerRequestTest, Time) { InSequence s; @@ -534,6 +560,53 @@ TEST_F(RedisSingleServerRequestTest, EvalNoUpstream) { EXPECT_EQ(1UL, store_.counter("redis.foo.command.eval.error").value()); }; +TEST_F(RedisSingleServerRequestTest, Select) { + InSequence s; + + Common::Redis::RespValuePtr request{new Common::Redis::RespValue()}; + makeBulkStringArray(*request, {"select", "1"}); + + Common::Redis::RespValue response; + response.type(Common::Redis::RespType::SimpleString); + response.asString() = Response::get().OK; + + EXPECT_CALL(callbacks_, connectionAllowed()).WillOnce(Return(true)); + EXPECT_CALL(callbacks_, onResponse_(PointeesEq(&response))); + handle_ = splitter_.makeRequest(std::move(request), callbacks_, dispatcher_, stream_info_); + EXPECT_EQ(nullptr, handle_); +}; + +TEST_F(RedisSingleServerRequestTest, SelectInvalid) { + InSequence s; + + Common::Redis::RespValuePtr request{new Common::Redis::RespValue()}; + makeBulkStringArray(*request, {"select", "1", "2"}); + + Common::Redis::RespValue response; + response.type(Common::Redis::RespType::Error); + response.asString() = RedisProxy::CommandSplitter::Response::get().InvalidRequest; + + EXPECT_CALL(callbacks_, connectionAllowed()).WillOnce(Return(true)); + EXPECT_CALL(callbacks_, onResponse_(PointeesEq(&response))); + handle_ = splitter_.makeRequest(std::move(request), callbacks_, dispatcher_, stream_info_); + EXPECT_EQ(nullptr, handle_); +}; + +TEST_F(RedisSingleServerRequestTest, Hello) { + InSequence s; + + Common::Redis::RespValuePtr request{new Common::Redis::RespValue()}; + makeBulkStringArray(*request, {"hello", "2", "auth", "mypass"}); + + Common::Redis::RespValue response; + response.type(Common::Redis::RespType::Error); + response.asString() = "ERR unknown command 'hello', with args beginning with: 2"; + + EXPECT_CALL(callbacks_, onResponse_(PointeesEq(&response))); + handle_ = splitter_.makeRequest(std::move(request), callbacks_, dispatcher_, stream_info_); + EXPECT_EQ(nullptr, handle_); +}; + MATCHER_P(CompositeArrayEq, rhs, "CompositeArray should be equal") { const ConnPool::RespVariant& obj = arg; const auto& lhs = absl::get(obj); @@ -591,6 +664,47 @@ class FragmentedRequestCommandHandlerTest : public RedisCommandSplitterImplTest handle_ = splitter_.makeRequest(std::move(request), callbacks_, dispatcher_, stream_info_); } + void makeRequestToShard(uint16_t shard_size, std::vector& request_strings, + const std::list& null_handle_indexes, bool mirrored) { + Common::Redis::RespValuePtr request{new Common::Redis::RespValue()}; + makeBulkStringArray(*request, request_strings); + + pool_callbacks_.resize(shard_size); + mirror_pool_callbacks_.resize(shard_size); + std::vector tmp_pool_requests(shard_size); + pool_requests_.swap(tmp_pool_requests); + std::vector tmp_mirrored_pool_requests(shard_size); + mirror_pool_requests_.swap(tmp_mirrored_pool_requests); + EXPECT_CALL(callbacks_, connectionAllowed()).WillOnce(Return(true)); + std::vector dummy_requests(shard_size); + + EXPECT_CALL(*conn_pool_, shardSize_()).WillRepeatedly(Return(shard_size)); + if (mirrored) { + EXPECT_CALL(*mirror_conn_pool_, shardSize_()).WillRepeatedly(Return(shard_size)); + } + ConnPool::RespVariant keys(*request); + for (uint32_t i = 0; i < shard_size; i++) { + Common::Redis::Client::PoolRequest* request_to_use = nullptr; + if (std::find(null_handle_indexes.begin(), null_handle_indexes.end(), i) == + null_handle_indexes.end()) { + request_to_use = &pool_requests_[i]; + } + Common::Redis::Client::PoolRequest* mirror_request_to_use = nullptr; + if (std::find(null_handle_indexes.begin(), null_handle_indexes.end(), i) == + null_handle_indexes.end()) { + mirror_request_to_use = &dummy_requests[i]; + } + EXPECT_CALL(*conn_pool_, makeRequestToShard_(i, keys, _)) + .WillOnce(DoAll(WithArg<2>(SaveArgAddress(&pool_callbacks_[i])), Return(request_to_use))); + if (mirrored) { + EXPECT_CALL(*mirror_conn_pool_, makeRequestToShard_(i, keys, _)) + .WillOnce(DoAll(WithArg<2>(SaveArgAddress(&mirror_pool_callbacks_[i])), + Return(mirror_request_to_use))); + } + } + handle_ = splitter_.makeRequest(std::move(request), callbacks_, dispatcher_, stream_info_); + } + std::vector> expected_requests_; std::vector pool_callbacks_; std::vector pool_requests_; @@ -932,6 +1046,145 @@ TEST_F(RedisMSETCommandHandlerTest, WrongNumberOfArgs) { EXPECT_EQ(1UL, store_.counter("redis.foo.command.mset.error").value()); }; +class KeysHandlerTest : public FragmentedRequestCommandHandlerTest, + public testing::WithParamInterface { +public: + void setup(uint16_t shard_size, const std::list& null_handle_indexes, + bool mirrored = false) { + std::vector request_strings = {"keys", "*"}; + makeRequestToShard(shard_size, request_strings, null_handle_indexes, mirrored); + } + + Common::Redis::RespValuePtr response() { + Common::Redis::RespValuePtr response = std::make_unique(); + response->type(Common::Redis::RespType::Array); + return response; + } +}; + +TEST_P(KeysHandlerTest, Normal) { + InSequence s; + + setup(2, {}); + EXPECT_NE(nullptr, handle_); + Common::Redis::RespValue expected_response; + expected_response.type(Common::Redis::RespType::Array); + pool_callbacks_[1]->onResponse(response()); + time_system_.setMonotonicTime(std::chrono::milliseconds(10)); + EXPECT_CALL( + store_, + deliverHistogramToSinks( + Property(&Stats::Metric::name, "redis.foo.command." + GetParam() + ".latency"), 10)); + EXPECT_CALL(callbacks_, onResponse_(PointeesEq(&expected_response))); + pool_callbacks_[0]->onResponse(response()); + EXPECT_EQ(1UL, store_.counter("redis.foo.command." + GetParam() + ".total").value()); + EXPECT_EQ(1UL, store_.counter("redis.foo.command." + GetParam() + ".success").value()); +}; + +TEST_P(KeysHandlerTest, Mirrored) { + InSequence s; + + setupMirrorPolicy(); + setup(2, {}, true); + EXPECT_NE(nullptr, handle_); + + Common::Redis::RespValue expected_response; + expected_response.type(Common::Redis::RespType::Array); + + pool_callbacks_[1]->onResponse(response()); + mirror_pool_callbacks_[1]->onResponse(response()); + + time_system_.setMonotonicTime(std::chrono::milliseconds(10)); + EXPECT_CALL( + store_, + deliverHistogramToSinks( + Property(&Stats::Metric::name, "redis.foo.command." + GetParam() + ".latency"), 10)); + EXPECT_CALL(callbacks_, onResponse_(PointeesEq(&expected_response))); + pool_callbacks_[0]->onResponse(response()); + mirror_pool_callbacks_[0]->onResponse(response()); + + EXPECT_EQ(1UL, store_.counter("redis.foo.command." + GetParam() + ".total").value()); + EXPECT_EQ(1UL, store_.counter("redis.foo.command." + GetParam() + ".success").value()); +}; + +TEST_F(KeysHandlerTest, Cancel) { + InSequence s; + + setup(2, {}); + EXPECT_NE(nullptr, handle_); + + EXPECT_CALL(pool_requests_[0], cancel()); + EXPECT_CALL(pool_requests_[1], cancel()); + handle_->cancel(); +}; + +TEST_P(KeysHandlerTest, NormalOneZero) { + InSequence s; + + setup(2, {}); + EXPECT_NE(nullptr, handle_); + + Common::Redis::RespValue expected_response; + expected_response.type(Common::Redis::RespType::Array); + + pool_callbacks_[1]->onResponse(response()); + + EXPECT_CALL(callbacks_, onResponse_(PointeesEq(&expected_response))); + pool_callbacks_[0]->onResponse(response()); + + EXPECT_EQ(1UL, store_.counter("redis.foo.command." + GetParam() + ".total").value()); + EXPECT_EQ(1UL, store_.counter("redis.foo.command." + GetParam() + ".success").value()); +}; + +TEST_P(KeysHandlerTest, UpstreamError) { + Common::Redis::RespValue expected_response; + expected_response.type(Common::Redis::RespType::Error); + expected_response.asString() = "finished with 2 error(s)"; + + EXPECT_CALL(callbacks_, onResponse_(PointeesEq(&expected_response))); + setup(2, {0, 1}); + EXPECT_EQ(nullptr, handle_); + EXPECT_EQ(1UL, store_.counter("redis.foo.command." + GetParam() + ".total").value()); + EXPECT_EQ(1UL, store_.counter("redis.foo.command." + GetParam() + ".error").value()); +}; + +TEST_P(KeysHandlerTest, NoUpstreamHostForAll) { + Common::Redis::RespValue expected_response; + expected_response.type(Common::Redis::RespType::Error); + expected_response.asString() = "no upstream host"; + + EXPECT_CALL(callbacks_, onResponse_(PointeesEq(&expected_response))); + setup(0, {}); + EXPECT_EQ(nullptr, handle_); + EXPECT_EQ(1UL, store_.counter("redis.foo.command." + GetParam() + ".total").value()); + EXPECT_EQ(1UL, store_.counter("redis.foo.command." + GetParam() + ".error").value()); +}; + +TEST_F(KeysHandlerTest, KeysWrongNumberOfArgs) { + InSequence s; + + Common::Redis::RespValuePtr request1{new Common::Redis::RespValue()}; + Common::Redis::RespValuePtr request2{new Common::Redis::RespValue()}; + Common::Redis::RespValue response; + response.type(Common::Redis::RespType::Error); + + response.asString() = "wrong number of arguments for 'keys' command"; + EXPECT_CALL(callbacks_, connectionAllowed()).WillOnce(Return(true)); + EXPECT_CALL(callbacks_, onResponse_(PointeesEq(&response))); + makeBulkStringArray(*request1, {"keys", "a*", "b*"}); + EXPECT_EQ(nullptr, + splitter_.makeRequest(std::move(request1), callbacks_, dispatcher_, stream_info_)); + + response.asString() = "invalid request"; + EXPECT_CALL(callbacks_, connectionAllowed()).WillOnce(Return(true)); + EXPECT_CALL(callbacks_, onResponse_(PointeesEq(&response))); + makeBulkStringArray(*request2, {"keys"}); + EXPECT_EQ(nullptr, + splitter_.makeRequest(std::move(request2), callbacks_, dispatcher_, stream_info_)); +}; + +INSTANTIATE_TEST_SUITE_P(KeysHandlerTest, KeysHandlerTest, testing::Values("keys")); + class RedisSplitKeysSumResultHandlerTest : public FragmentedRequestCommandHandlerTest, public testing::WithParamInterface { public: diff --git a/test/extensions/filters/network/redis_proxy/conn_pool_impl_test.cc b/test/extensions/filters/network/redis_proxy/conn_pool_impl_test.cc index b52f8eafd3ee..f22a8fc50401 100644 --- a/test/extensions/filters/network/redis_proxy/conn_pool_impl_test.cc +++ b/test/extensions/filters/network/redis_proxy/conn_pool_impl_test.cc @@ -365,6 +365,96 @@ TEST_F(RedisConnPoolImplTest, Basic) { tls_.shutdownThread(); }; +TEST_F(RedisConnPoolImplTest, ShardSize) { + InSequence s; + + setup(); + + Common::Redis::RespValueSharedPtr value = std::make_shared(); + MockPoolCallbacks callbacks; + Common::Redis::Client::MockClient* client = new NiceMock(); + + uint16_t shard_size = 3; + EXPECT_CALL(cm_.thread_local_cluster_.lb_, chooseHost(_)) + .WillRepeatedly( + Invoke([&](Upstream::LoadBalancerContext* context) -> Upstream::HostConstSharedPtr { + EXPECT_EQ(context->metadataMatchCriteria(), nullptr); + EXPECT_EQ(context->downstreamConnection(), nullptr); + std::cout << (context->computeHashKey().value()) << std::endl; + if (context->computeHashKey() < shard_size) { + return cm_.thread_local_cluster_.lb_.host_; + } + return nullptr; + })); + EXPECT_CALL(*this, create_(_)).WillRepeatedly(Return(client)); + EXPECT_CALL(*cm_.thread_local_cluster_.lb_.host_, address()) + .WillRepeatedly(Return(test_address_)); + EXPECT_EQ(conn_pool_->shardSize(), shard_size); + + for (uint16_t i = 0; i < 100; i++) { + shard_size = i; + EXPECT_EQ(conn_pool_->shardSize(), shard_size); + } + + delete client; + tls_.shutdownThread(); +}; + +TEST_F(RedisConnPoolImplTest, ShardHost) { + InSequence s; + + setup(); + + Common::Redis::RespValueSharedPtr value = std::make_shared(); + Common::Redis::Client::MockPoolRequest active_request; + MockPoolCallbacks callbacks; + Common::Redis::Client::MockClient* client = new NiceMock(); + + EXPECT_CALL(cm_.thread_local_cluster_.lb_, chooseHost(_)) + .WillOnce(Invoke([&](Upstream::LoadBalancerContext* context) -> Upstream::HostConstSharedPtr { + EXPECT_EQ(context->computeHashKey().value(), 0); + EXPECT_EQ(context->metadataMatchCriteria(), nullptr); + EXPECT_EQ(context->downstreamConnection(), nullptr); + return cm_.thread_local_cluster_.lb_.host_; + })); + EXPECT_CALL(*this, create_(_)).WillOnce(Return(client)); + EXPECT_CALL(*cm_.thread_local_cluster_.lb_.host_, address()) + .WillRepeatedly(Return(test_address_)); + EXPECT_CALL(*client, makeRequest_(Ref(*value), _)).WillOnce(Return(&active_request)); + Common::Redis::Client::PoolRequest* request = + conn_pool_->makeRequestToShard(0, value, callbacks, transaction_); + EXPECT_NE(nullptr, request); + + EXPECT_CALL(active_request, cancel()); + EXPECT_CALL(callbacks, onFailure_()); + EXPECT_CALL(*client, close()); + tls_.shutdownThread(); +}; + +TEST_F(RedisConnPoolImplTest, ShardNoHost) { + InSequence s; + + setup(); + + Common::Redis::RespValueSharedPtr value = std::make_shared(); + MockPoolCallbacks callbacks; + + EXPECT_CALL(cm_.thread_local_cluster_.lb_, chooseHost(_)) + .WillOnce(Invoke([&](Upstream::LoadBalancerContext* context) -> Upstream::HostConstSharedPtr { + EXPECT_EQ(context->computeHashKey().value(), 0); + EXPECT_EQ(context->metadataMatchCriteria(), nullptr); + EXPECT_EQ(context->downstreamConnection(), nullptr); + return nullptr; + })); + EXPECT_CALL(*cm_.thread_local_cluster_.lb_.host_, address()) + .WillRepeatedly(Return(test_address_)); + Common::Redis::Client::PoolRequest* request = + conn_pool_->makeRequestToShard(0, value, callbacks, transaction_); + EXPECT_EQ(nullptr, request); + + tls_.shutdownThread(); +}; + TEST_F(RedisConnPoolImplTest, BasicRespVariant) { InSequence s; @@ -396,6 +486,35 @@ TEST_F(RedisConnPoolImplTest, BasicRespVariant) { tls_.shutdownThread(); }; +TEST_F(RedisConnPoolImplTest, ShardRequestFailed) { + InSequence s; + + setup(); + + Common::Redis::RespValue value; + MockPoolCallbacks callbacks; + Common::Redis::Client::MockClient* client = new NiceMock(); + + EXPECT_CALL(cm_.thread_local_cluster_.lb_, chooseHost(_)) + .WillOnce(Invoke([&](Upstream::LoadBalancerContext* context) -> Upstream::HostConstSharedPtr { + EXPECT_EQ(context->computeHashKey().value(), 0); + EXPECT_EQ(context->metadataMatchCriteria(), nullptr); + EXPECT_EQ(context->downstreamConnection(), nullptr); + return cm_.thread_local_cluster_.lb_.host_; + })); + EXPECT_CALL(*this, create_(_)).WillOnce(Return(client)); + EXPECT_CALL(*cm_.thread_local_cluster_.lb_.host_, address()) + .WillRepeatedly(Return(test_address_)); + EXPECT_CALL(*client, makeRequest_(Eq(value), _)).WillOnce(Return(nullptr)); + Common::Redis::Client::PoolRequest* request = + conn_pool_->makeRequestToShard(0, ConnPool::RespVariant(value), callbacks, transaction_); + + // the request should be null and the callback is not called + EXPECT_EQ(nullptr, request); + EXPECT_CALL(*client, close()); + tls_.shutdownThread(); +}; + TEST_F(RedisConnPoolImplTest, ClientRequestFailed) { InSequence s; diff --git a/test/extensions/filters/network/redis_proxy/mocks.h b/test/extensions/filters/network/redis_proxy/mocks.h index f1e0c03b8e67..679587e224cf 100644 --- a/test/extensions/filters/network/redis_proxy/mocks.h +++ b/test/extensions/filters/network/redis_proxy/mocks.h @@ -86,14 +86,25 @@ class MockInstance : public Instance { MockInstance(); ~MockInstance() override; + uint16_t shardSize() override { return shardSize_(); } + Common::Redis::Client::PoolRequest* makeRequest(const std::string& hash_key, RespVariant&& request, PoolCallbacks& callbacks, Common::Redis::Client::Transaction&) override { return makeRequest_(hash_key, request, callbacks); } + Common::Redis::Client::PoolRequest* + makeRequestToShard(uint16_t shard_index, RespVariant&& request, PoolCallbacks& callbacks, + Common::Redis::Client::Transaction&) override { + return makeRequestToShard_(shard_index, request, callbacks); + } + + MOCK_METHOD(uint16_t, shardSize_, ()); MOCK_METHOD(Common::Redis::Client::PoolRequest*, makeRequest_, (const std::string& hash_key, RespVariant& request, PoolCallbacks& callbacks)); + MOCK_METHOD(Common::Redis::Client::PoolRequest*, makeRequestToShard_, + (uint16_t shard_index, RespVariant& request, PoolCallbacks& callbacks)); MOCK_METHOD(bool, onRedirection, ()); }; } // namespace ConnPool diff --git a/test/extensions/filters/network/redis_proxy/redis_proxy_integration_test.cc b/test/extensions/filters/network/redis_proxy/redis_proxy_integration_test.cc index ef188a699343..d5bb7ea230ec 100644 --- a/test/extensions/filters/network/redis_proxy/redis_proxy_integration_test.cc +++ b/test/extensions/filters/network/redis_proxy/redis_proxy_integration_test.cc @@ -878,6 +878,45 @@ TEST_P(RedisProxyIntegrationTest, QUITRequestAndResponse) { redis_client->close(); } +// This test sends an invalid Redis command from a fake +// downstream client to the envoy proxy. Envoy will respond +// with an ERR unknown command error. + +TEST_P(RedisProxyIntegrationTest, UnknownCommand) { + std::stringstream error_response; + error_response << "-" + << "ERR unknown command 'foo', with args beginning with: " + << "\r\n"; + initialize(); + simpleProxyResponse(makeBulkStringArray({"foo"}), error_response.str()); +} + +// This test sends an invalid Redis command from a fake +// downstream client to the envoy proxy. Envoy will respond +// with an ERR unknown command error. + +TEST_P(RedisProxyIntegrationTest, UnknownCommandWithArgs) { + std::stringstream error_response; + error_response << "-" + << "ERR unknown command 'hello', with args beginning with: world" + << "\r\n"; + initialize(); + simpleProxyResponse(makeBulkStringArray({"hello", "world"}), error_response.str()); +} + +// This test sends an invalid Redis command from a fake +// downstream client to the envoy proxy. Envoy will respond +// with an ERR unknown command error. + +TEST_P(RedisProxyIntegrationTest, HelloCommand) { + std::stringstream error_response; + error_response << "-" + << "ERR unknown command 'hello', with args beginning with: world" + << "\r\n"; + initialize(); + simpleProxyResponse(makeBulkStringArray({"hello", "world"}), error_response.str()); +} + // This test sends an invalid Redis command from a fake // downstream client to the envoy proxy. Envoy will respond // with an invalid request error. @@ -886,7 +925,20 @@ TEST_P(RedisProxyIntegrationTest, InvalidRequest) { std::stringstream error_response; error_response << "-" << RedisCmdSplitter::Response::get().InvalidRequest << "\r\n"; initialize(); - simpleProxyResponse(makeBulkStringArray({"foo"}), error_response.str()); + simpleProxyResponse(makeBulkStringArray({"keys"}), error_response.str()); +} + +// This test sends an invalid Redis command from a fake +// downstream client to the envoy proxy. Envoy will respond +// with an invalid request error. + +TEST_P(RedisProxyIntegrationTest, InvalidArgsRequest) { + std::stringstream error_response; + error_response << "-" + << "wrong number of arguments for 'keys' command" + << "\r\n"; + initialize(); + simpleProxyResponse(makeBulkStringArray({"keys", "a*", "b*"}), error_response.str()); } // This test sends a simple Redis command to a fake upstream