diff --git a/pymemcache/client/base.py b/pymemcache/client/base.py index 8791d5b9..04ae0520 100644 --- a/pymemcache/client/base.py +++ b/pymemcache/client/base.py @@ -368,10 +368,10 @@ def __init__( self.encoding = encoding self.tls_context = tls_context - def check_key(self, key: Key) -> bytes: + def check_key(self, key: Key, key_prefix: bytes) -> bytes: """Checks key and add key_prefix.""" return check_key_helper( - key, allow_unicode_keys=self.allow_unicode_keys, key_prefix=self.key_prefix + key, allow_unicode_keys=self.allow_unicode_keys, key_prefix=key_prefix ) def _connect(self) -> None: @@ -684,7 +684,9 @@ def get(self, key: Key, default: Optional[Any] = None) -> Any: Returns: The value for the key, or default if the key wasn't found. """ - return self._fetch_cmd(b"get", [key], False).get(key, default) + return self._fetch_cmd(b"get", [key], False, key_prefix=self.key_prefix).get( + key, default + ) def get_many(self, keys: Iterable[Key]) -> Dict[Key, Any]: """ @@ -701,7 +703,7 @@ def get_many(self, keys: Iterable[Key]) -> Dict[Key, Any]: if not keys: return {} - return self._fetch_cmd(b"get", keys, False) + return self._fetch_cmd(b"get", keys, False, key_prefix=self.key_prefix) get_multi = get_many @@ -721,7 +723,9 @@ def gets( or (default, cas_defaults) if the key was not found. """ defaults = (default, cas_default) - return self._fetch_cmd(b"gets", [key], True).get(key, defaults) + return self._fetch_cmd(b"gets", [key], True, key_prefix=self.key_prefix).get( + key, defaults + ) def gets_many(self, keys: Iterable[Key]) -> Dict[Key, Tuple[Any, Any]]: """ @@ -738,7 +742,7 @@ def gets_many(self, keys: Iterable[Key]) -> Dict[Key, Tuple[Any, Any]]: if not keys: return {} - return self._fetch_cmd(b"gets", keys, True) + return self._fetch_cmd(b"gets", keys, True, key_prefix=self.key_prefix) def delete(self, key: Key, noreply: Optional[bool] = None) -> bool: """ @@ -756,7 +760,7 @@ def delete(self, key: Key, noreply: Optional[bool] = None) -> bool: """ if noreply is None: noreply = self.default_noreply - cmd = b"delete " + self.check_key(key) + cmd = b"delete " + self.check_key(key, self.key_prefix) if noreply: cmd += b" noreply" cmd += b"\r\n" @@ -790,7 +794,7 @@ def delete_many(self, keys: Iterable[Key], noreply: Optional[bool] = None) -> bo for key in keys: cmds.append( b"delete " - + self.check_key(key) + + self.check_key(key, self.key_prefix) + (b" noreply" if noreply else b"") + b"\r\n" ) @@ -814,7 +818,7 @@ def incr( If noreply is True, always returns None. Otherwise returns the new value of the key, or None if the key wasn't found. """ - key = self.check_key(key) + key = self.check_key(key, self.key_prefix) val = self._check_integer(value, "value") cmd = b"incr " + key + b" " + val if noreply: @@ -842,7 +846,7 @@ def decr( If noreply is True, always returns None. Otherwise returns the new value of the key, or None if the key wasn't found. """ - key = self.check_key(key) + key = self.check_key(key, self.key_prefix) val = self._check_integer(value, "value") cmd = b"decr " + key + b" " + val if noreply: @@ -872,7 +876,7 @@ def touch(self, key: Key, expire: int = 0, noreply: Optional[bool] = None) -> bo """ if noreply is None: noreply = self.default_noreply - key = self.check_key(key) + key = self.check_key(key, self.key_prefix) expire_bytes = self._check_integer(expire, "expire") cmd = b"touch " + key + b" " + expire_bytes if noreply: @@ -1109,9 +1113,13 @@ def _extract_value( return original_key, value, buf def _fetch_cmd( - self, name: bytes, keys: Iterable[Key], expect_cas: bool + self, + name: bytes, + keys: Iterable[Key], + expect_cas: bool, + key_prefix: bytes = b"", ) -> Dict[Key, Any]: - prefixed_keys = [self.check_key(k) for k in keys] + prefixed_keys = [self.check_key(k, key_prefix=key_prefix) for k in keys] remapped_keys = dict(zip(prefixed_keys, keys)) # It is important for all keys to be listed in their original order. @@ -1184,7 +1192,7 @@ def _store_cmd( # must be able to reliably map responses back to the original order keys.append(key) - key = self.check_key(key) + key = self.check_key(key, self.key_prefix) data, data_flags = self.serde.serialize(key, data) # If 'flags' was explicitly provided, it overrides the value diff --git a/pymemcache/test/conftest.py b/pymemcache/test/conftest.py index e8b8bdf4..ce532e96 100644 --- a/pymemcache/test/conftest.py +++ b/pymemcache/test/conftest.py @@ -1,8 +1,9 @@ import os.path -import pytest import socket import ssl +import pytest + def pytest_addoption(parser): parser.addoption( @@ -100,7 +101,7 @@ def pytest_generate_tests(metafunc): metafunc.parametrize("socket_module", socket_modules) if "client_class" in metafunc.fixturenames: - from pymemcache.client.base import PooledClient, Client + from pymemcache.client.base import Client, PooledClient from pymemcache.client.hash import HashClient class HashClientSingle(HashClient): @@ -108,3 +109,8 @@ def __init__(self, server, *args, **kwargs): super().__init__([server], *args, **kwargs) metafunc.parametrize("client_class", [Client, PooledClient, HashClientSingle]) + + if "key_prefix" in metafunc.fixturenames: + mark = metafunc.definition.get_closest_marker("parametrize") + if not mark or "key_prefix" not in mark.args[0]: + metafunc.parametrize("key_prefix", [b"", b"prefix"]) diff --git a/pymemcache/test/test_integration.py b/pymemcache/test/test_integration.py index 961beb67..19d04f20 100644 --- a/pymemcache/test/test_integration.py +++ b/pymemcache/test/test_integration.py @@ -12,21 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -from collections import defaultdict import json -import pytest +from collections import defaultdict +import pytest from pymemcache.client.base import Client from pymemcache.exceptions import ( - MemcacheIllegalInputError, MemcacheClientError, + MemcacheIllegalInputError, MemcacheServerError, ) -from pymemcache.serde import ( - compressed_serde, - PickleSerde, - pickle_serde, -) +from pymemcache.serde import PickleSerde, compressed_serde, pickle_serde def get_set_helper(client, key, value, key2, value2): @@ -56,8 +52,10 @@ def get_set_helper(client, key, value, key2, value2): compressed_serde, ], ) -def test_get_set(client_class, host, port, serde, socket_module): - client = client_class((host, port), serde=serde, socket_module=socket_module) +def test_get_set(client_class, host, port, serde, socket_module, key_prefix): + client = client_class( + (host, port), serde=serde, socket_module=socket_module, key_prefix=key_prefix + ) client.flush_all() key = b"key" @@ -75,9 +73,15 @@ def test_get_set(client_class, host, port, serde, socket_module): compressed_serde, ], ) -def test_get_set_unicode_key(client_class, host, port, serde, socket_module): +def test_get_set_unicode_key( + client_class, host, port, serde, socket_module, key_prefix +): client = client_class( - (host, port), serde=serde, socket_module=socket_module, allow_unicode_keys=True + (host, port), + serde=serde, + socket_module=socket_module, + allow_unicode_keys=True, + key_prefix=key_prefix, ) client.flush_all() @@ -96,8 +100,10 @@ def test_get_set_unicode_key(client_class, host, port, serde, socket_module): compressed_serde, ], ) -def test_add_replace(client_class, host, port, serde, socket_module): - client = client_class((host, port), serde=serde, socket_module=socket_module) +def test_add_replace(client_class, host, port, serde, socket_module, key_prefix): + client = client_class( + (host, port), serde=serde, socket_module=socket_module, key_prefix=key_prefix + ) client.flush_all() result = client.add(b"key", b"value", noreply=False) @@ -122,8 +128,10 @@ def test_add_replace(client_class, host, port, serde, socket_module): @pytest.mark.integration() -def test_append_prepend(client_class, host, port, socket_module): - client = client_class((host, port), socket_module=socket_module) +def test_append_prepend(client_class, host, port, socket_module, key_prefix): + client = client_class( + (host, port), socket_module=socket_module, key_prefix=key_prefix + ) client.flush_all() result = client.append(b"key", b"value", noreply=False) @@ -150,8 +158,10 @@ def test_append_prepend(client_class, host, port, socket_module): @pytest.mark.integration() -def test_cas(client_class, host, port, socket_module): - client = client_class((host, port), socket_module=socket_module) +def test_cas(client_class, host, port, socket_module, key_prefix): + client = client_class( + (host, port), socket_module=socket_module, key_prefix=key_prefix + ) client.flush_all() result = client.cas(b"key", b"value", b"1", noreply=False) assert result is None @@ -178,8 +188,10 @@ def test_cas(client_class, host, port, socket_module): @pytest.mark.integration() -def test_gets(client_class, host, port, socket_module): - client = client_class((host, port), socket_module=socket_module) +def test_gets(client_class, host, port, socket_module, key_prefix): + client = client_class( + (host, port), socket_module=socket_module, key_prefix=key_prefix + ) client.flush_all() result = client.gets(b"key") @@ -192,8 +204,10 @@ def test_gets(client_class, host, port, socket_module): @pytest.mark.integration() -def test_delete(client_class, host, port, socket_module): - client = client_class((host, port), socket_module=socket_module) +def test_delete(client_class, host, port, socket_module, key_prefix): + client = client_class( + (host, port), socket_module=socket_module, key_prefix=key_prefix + ) client.flush_all() result = client.delete(b"key", noreply=False) @@ -210,8 +224,8 @@ def test_delete(client_class, host, port, socket_module): @pytest.mark.integration() -def test_incr_decr(client_class, host, port, socket_module): - client = Client((host, port), socket_module=socket_module) +def test_incr_decr(client_class, host, port, socket_module, key_prefix): + client = Client((host, port), socket_module=socket_module, key_prefix=key_prefix) client.flush_all() result = client.incr(b"key", 1, noreply=False) @@ -238,8 +252,10 @@ def _bad_int(): @pytest.mark.integration() -def test_touch(client_class, host, port, socket_module): - client = client_class((host, port), socket_module=socket_module) +def test_touch(client_class, host, port, socket_module, key_prefix): + client = client_class( + (host, port), socket_module=socket_module, key_prefix=key_prefix + ) client.flush_all() result = client.touch(b"key", noreply=False) @@ -256,10 +272,16 @@ def test_touch(client_class, host, port, socket_module): @pytest.mark.integration() -def test_misc(client_class, host, port, socket_module): - client = Client((host, port), socket_module=socket_module) +def test_misc(client_class, host, port, socket_module, key_prefix): + client = Client((host, port), socket_module=socket_module, key_prefix=key_prefix) client.flush_all() + # Ensure no exceptions are thrown + client.stats("cachedump", "1", "1") + + success = client.cache_memlimit(50) + assert success + @pytest.mark.integration() def test_serialization_deserialization(host, port, socket_module):