Skip to content

Commit

Permalink
Fix key_prefix issue with stats and cache_memlimit
Browse files Browse the repository at this point in the history
Add integration tests to reproduce the issue and add an argument to
_fetch_cmd to skip the key prefix logic as needed.

Closes #430
  • Loading branch information
jogo committed Oct 14, 2022
1 parent 805e813 commit 3dafd67
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 44 deletions.
36 changes: 22 additions & 14 deletions pymemcache/client/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,10 +368,10 @@ def __init__(
self.encoding = encoding
self.tls_context = tls_context

def check_key(self, key: Key) -> bytes:
def check_key(self, key: Key, key_prefix: bytes) -> bytes:
"""Checks key and add key_prefix."""
return check_key_helper(
key, allow_unicode_keys=self.allow_unicode_keys, key_prefix=self.key_prefix
key, allow_unicode_keys=self.allow_unicode_keys, key_prefix=key_prefix
)

def _connect(self) -> None:
Expand Down Expand Up @@ -684,7 +684,9 @@ def get(self, key: Key, default: Optional[Any] = None) -> Any:
Returns:
The value for the key, or default if the key wasn't found.
"""
return self._fetch_cmd(b"get", [key], False).get(key, default)
return self._fetch_cmd(b"get", [key], False, key_prefix=self.key_prefix).get(
key, default
)

def get_many(self, keys: Iterable[Key]) -> Dict[Key, Any]:
"""
Expand All @@ -701,7 +703,7 @@ def get_many(self, keys: Iterable[Key]) -> Dict[Key, Any]:
if not keys:
return {}

return self._fetch_cmd(b"get", keys, False)
return self._fetch_cmd(b"get", keys, False, key_prefix=self.key_prefix)

get_multi = get_many

Expand All @@ -721,7 +723,9 @@ def gets(
or (default, cas_defaults) if the key was not found.
"""
defaults = (default, cas_default)
return self._fetch_cmd(b"gets", [key], True).get(key, defaults)
return self._fetch_cmd(b"gets", [key], True, key_prefix=self.key_prefix).get(
key, defaults
)

def gets_many(self, keys: Iterable[Key]) -> Dict[Key, Tuple[Any, Any]]:
"""
Expand All @@ -738,7 +742,7 @@ def gets_many(self, keys: Iterable[Key]) -> Dict[Key, Tuple[Any, Any]]:
if not keys:
return {}

return self._fetch_cmd(b"gets", keys, True)
return self._fetch_cmd(b"gets", keys, True, key_prefix=self.key_prefix)

def delete(self, key: Key, noreply: Optional[bool] = None) -> bool:
"""
Expand All @@ -756,7 +760,7 @@ def delete(self, key: Key, noreply: Optional[bool] = None) -> bool:
"""
if noreply is None:
noreply = self.default_noreply
cmd = b"delete " + self.check_key(key)
cmd = b"delete " + self.check_key(key, self.key_prefix)
if noreply:
cmd += b" noreply"
cmd += b"\r\n"
Expand Down Expand Up @@ -790,7 +794,7 @@ def delete_many(self, keys: Iterable[Key], noreply: Optional[bool] = None) -> bo
for key in keys:
cmds.append(
b"delete "
+ self.check_key(key)
+ self.check_key(key, self.key_prefix)
+ (b" noreply" if noreply else b"")
+ b"\r\n"
)
Expand All @@ -814,7 +818,7 @@ def incr(
If noreply is True, always returns None. Otherwise returns the new
value of the key, or None if the key wasn't found.
"""
key = self.check_key(key)
key = self.check_key(key, self.key_prefix)
val = self._check_integer(value, "value")
cmd = b"incr " + key + b" " + val
if noreply:
Expand Down Expand Up @@ -842,7 +846,7 @@ def decr(
If noreply is True, always returns None. Otherwise returns the new
value of the key, or None if the key wasn't found.
"""
key = self.check_key(key)
key = self.check_key(key, self.key_prefix)
val = self._check_integer(value, "value")
cmd = b"decr " + key + b" " + val
if noreply:
Expand Down Expand Up @@ -872,7 +876,7 @@ def touch(self, key: Key, expire: int = 0, noreply: Optional[bool] = None) -> bo
"""
if noreply is None:
noreply = self.default_noreply
key = self.check_key(key)
key = self.check_key(key, self.key_prefix)
expire_bytes = self._check_integer(expire, "expire")
cmd = b"touch " + key + b" " + expire_bytes
if noreply:
Expand Down Expand Up @@ -1109,9 +1113,13 @@ def _extract_value(
return original_key, value, buf

def _fetch_cmd(
self, name: bytes, keys: Iterable[Key], expect_cas: bool
self,
name: bytes,
keys: Iterable[Key],
expect_cas: bool,
key_prefix: bytes = b"",
) -> Dict[Key, Any]:
prefixed_keys = [self.check_key(k) for k in keys]
prefixed_keys = [self.check_key(k, key_prefix=key_prefix) for k in keys]
remapped_keys = dict(zip(prefixed_keys, keys))

# It is important for all keys to be listed in their original order.
Expand Down Expand Up @@ -1184,7 +1192,7 @@ def _store_cmd(
# must be able to reliably map responses back to the original order
keys.append(key)

key = self.check_key(key)
key = self.check_key(key, self.key_prefix)
data, data_flags = self.serde.serialize(key, data)

# If 'flags' was explicitly provided, it overrides the value
Expand Down
10 changes: 8 additions & 2 deletions pymemcache/test/conftest.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import os.path
import pytest
import socket
import ssl

import pytest


def pytest_addoption(parser):
parser.addoption(
Expand Down Expand Up @@ -100,11 +101,16 @@ def pytest_generate_tests(metafunc):
metafunc.parametrize("socket_module", socket_modules)

if "client_class" in metafunc.fixturenames:
from pymemcache.client.base import PooledClient, Client
from pymemcache.client.base import Client, PooledClient
from pymemcache.client.hash import HashClient

class HashClientSingle(HashClient):
def __init__(self, server, *args, **kwargs):
super().__init__([server], *args, **kwargs)

metafunc.parametrize("client_class", [Client, PooledClient, HashClientSingle])

if "key_prefix" in metafunc.fixturenames:
mark = metafunc.definition.get_closest_marker("parametrize")
if not mark or "key_prefix" not in mark.args[0]:
metafunc.parametrize("key_prefix", [b"", b"prefix"])
78 changes: 50 additions & 28 deletions pymemcache/test/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import defaultdict
import json
import pytest
from collections import defaultdict

import pytest
from pymemcache.client.base import Client
from pymemcache.exceptions import (
MemcacheIllegalInputError,
MemcacheClientError,
MemcacheIllegalInputError,
MemcacheServerError,
)
from pymemcache.serde import (
compressed_serde,
PickleSerde,
pickle_serde,
)
from pymemcache.serde import PickleSerde, compressed_serde, pickle_serde


def get_set_helper(client, key, value, key2, value2):
Expand Down Expand Up @@ -56,8 +52,10 @@ def get_set_helper(client, key, value, key2, value2):
compressed_serde,
],
)
def test_get_set(client_class, host, port, serde, socket_module):
client = client_class((host, port), serde=serde, socket_module=socket_module)
def test_get_set(client_class, host, port, serde, socket_module, key_prefix):
client = client_class(
(host, port), serde=serde, socket_module=socket_module, key_prefix=key_prefix
)
client.flush_all()

key = b"key"
Expand All @@ -75,9 +73,15 @@ def test_get_set(client_class, host, port, serde, socket_module):
compressed_serde,
],
)
def test_get_set_unicode_key(client_class, host, port, serde, socket_module):
def test_get_set_unicode_key(
client_class, host, port, serde, socket_module, key_prefix
):
client = client_class(
(host, port), serde=serde, socket_module=socket_module, allow_unicode_keys=True
(host, port),
serde=serde,
socket_module=socket_module,
allow_unicode_keys=True,
key_prefix=key_prefix,
)
client.flush_all()

Expand All @@ -96,8 +100,10 @@ def test_get_set_unicode_key(client_class, host, port, serde, socket_module):
compressed_serde,
],
)
def test_add_replace(client_class, host, port, serde, socket_module):
client = client_class((host, port), serde=serde, socket_module=socket_module)
def test_add_replace(client_class, host, port, serde, socket_module, key_prefix):
client = client_class(
(host, port), serde=serde, socket_module=socket_module, key_prefix=key_prefix
)
client.flush_all()

result = client.add(b"key", b"value", noreply=False)
Expand All @@ -122,8 +128,10 @@ def test_add_replace(client_class, host, port, serde, socket_module):


@pytest.mark.integration()
def test_append_prepend(client_class, host, port, socket_module):
client = client_class((host, port), socket_module=socket_module)
def test_append_prepend(client_class, host, port, socket_module, key_prefix):
client = client_class(
(host, port), socket_module=socket_module, key_prefix=key_prefix
)
client.flush_all()

result = client.append(b"key", b"value", noreply=False)
Expand All @@ -150,8 +158,10 @@ def test_append_prepend(client_class, host, port, socket_module):


@pytest.mark.integration()
def test_cas(client_class, host, port, socket_module):
client = client_class((host, port), socket_module=socket_module)
def test_cas(client_class, host, port, socket_module, key_prefix):
client = client_class(
(host, port), socket_module=socket_module, key_prefix=key_prefix
)
client.flush_all()
result = client.cas(b"key", b"value", b"1", noreply=False)
assert result is None
Expand All @@ -178,8 +188,10 @@ def test_cas(client_class, host, port, socket_module):


@pytest.mark.integration()
def test_gets(client_class, host, port, socket_module):
client = client_class((host, port), socket_module=socket_module)
def test_gets(client_class, host, port, socket_module, key_prefix):
client = client_class(
(host, port), socket_module=socket_module, key_prefix=key_prefix
)
client.flush_all()

result = client.gets(b"key")
Expand All @@ -192,8 +204,10 @@ def test_gets(client_class, host, port, socket_module):


@pytest.mark.integration()
def test_delete(client_class, host, port, socket_module):
client = client_class((host, port), socket_module=socket_module)
def test_delete(client_class, host, port, socket_module, key_prefix):
client = client_class(
(host, port), socket_module=socket_module, key_prefix=key_prefix
)
client.flush_all()

result = client.delete(b"key", noreply=False)
Expand All @@ -210,8 +224,8 @@ def test_delete(client_class, host, port, socket_module):


@pytest.mark.integration()
def test_incr_decr(client_class, host, port, socket_module):
client = Client((host, port), socket_module=socket_module)
def test_incr_decr(client_class, host, port, socket_module, key_prefix):
client = Client((host, port), socket_module=socket_module, key_prefix=key_prefix)
client.flush_all()

result = client.incr(b"key", 1, noreply=False)
Expand All @@ -238,8 +252,10 @@ def _bad_int():


@pytest.mark.integration()
def test_touch(client_class, host, port, socket_module):
client = client_class((host, port), socket_module=socket_module)
def test_touch(client_class, host, port, socket_module, key_prefix):
client = client_class(
(host, port), socket_module=socket_module, key_prefix=key_prefix
)
client.flush_all()

result = client.touch(b"key", noreply=False)
Expand All @@ -256,10 +272,16 @@ def test_touch(client_class, host, port, socket_module):


@pytest.mark.integration()
def test_misc(client_class, host, port, socket_module):
client = Client((host, port), socket_module=socket_module)
def test_misc(client_class, host, port, socket_module, key_prefix):
client = Client((host, port), socket_module=socket_module, key_prefix=key_prefix)
client.flush_all()

# Ensure no exceptions are thrown
client.stats("cachedump", "1", "1")

success = client.cache_memlimit(50)
assert success


@pytest.mark.integration()
def test_serialization_deserialization(host, port, socket_module):
Expand Down

0 comments on commit 3dafd67

Please sign in to comment.