From ed3b7494ccc4e4216cdce8d42b0233e0f678250c Mon Sep 17 00:00:00 2001 From: Krukov D Date: Sun, 5 May 2024 20:57:45 +0300 Subject: [PATCH] fix: use earcly ttl with early cache in etag middleware --- cashews/contrib/fastapi.py | 33 +++++++++++++++++-------- tests/test_intergations/test_fastapi.py | 21 ++++++++++++++++ 2 files changed, 44 insertions(+), 10 deletions(-) diff --git a/cashews/contrib/fastapi.py b/cashews/contrib/fastapi.py index 9ffed8a..b72672a 100644 --- a/cashews/contrib/fastapi.py +++ b/cashews/contrib/fastapi.py @@ -3,6 +3,7 @@ import contextlib from contextlib import nullcontext from contextvars import ContextVar +from datetime import datetime from hashlib import blake2s from typing import Any, ContextManager, Sequence @@ -130,7 +131,7 @@ async def dispatch(self, request: Request, call_next: RequestResponseEndpoint): if etag and await self._cache.exists(etag): return Response(status_code=304) - set_key = None + set_key: None | str = None def set_callback(key: str, result: Any): nonlocal set_key @@ -140,21 +141,29 @@ def set_callback(key: str, result: Any): response = await call_next(request) calls = detector.calls_list if not calls: - if set_key: - etag = await self._get_etag(set_key) - if etag: - response.headers[_ETAG_HEADER] = etag + if set_key is not None: + _etag = await self._get_etag(set_key) + if _etag == etag: + return Response(status_code=304) + if _etag: + response.headers[_ETAG_HEADER] = _etag return response key, _ = calls[0] - etag = await self._get_etag(key) - if etag: - response.headers[_ETAG_HEADER] = etag + _etag = await self._get_etag(key) + if _etag == etag: + return Response(status_code=304) + if _etag: + response.headers[_ETAG_HEADER] = _etag return response async def _get_etag(self, key: str) -> str: - data = await self._cache.get_raw(key) - expire = await self._cache.get_expire(key) + data = await self._cache.get(key) + if _is_early_cache(data): + expire = (data[0] - datetime.utcnow()).total_seconds() # type: ignore[index] + data = data[1] # type: ignore[index] + else: + expire = await self._cache.get_expire(key) if not isinstance(data, bytes): data = data.body if isinstance(data, Response) else DEFAULT_PICKLER.dumps(data) etag = blake2s(data).hexdigest() @@ -162,6 +171,10 @@ async def _get_etag(self, key: str) -> str: return etag +def _is_early_cache(data: Any) -> bool: + return isinstance(data, list) and isinstance(data[0], datetime) + + class CacheDeleteMiddleware(BaseHTTPMiddleware): async def dispatch(self, request: Request, call_next: RequestResponseEndpoint): if request.headers.get(_CLEAR_CACHE_HEADER) == _CLEAR_CACHE_HEADER_VALUE: diff --git a/tests/test_intergations/test_fastapi.py b/tests/test_intergations/test_fastapi.py index 38cb31a..efe7afa 100644 --- a/tests/test_intergations/test_fastapi.py +++ b/tests/test_intergations/test_fastapi.py @@ -152,6 +152,27 @@ async def rand(): assert etag == response3.headers["ETag"] +def test_cache_etag_early(client_with_middleware, app, cache): + from cashews.contrib.fastapi import CacheEtagMiddleware + + @app.get("/to_cache") + @cache.early(ttl="10s", early_ttl="7s", key="to_cache") + async def rand(): + return str(random()).encode() + + with client_with_middleware(CacheEtagMiddleware, cache_instance=cache) as client: + response = client.get("/to_cache") + etag = response.headers["ETag"] + + response2 = client.get("/to_cache", headers={"If-None-Match": etag}) + assert response2.status_code == 304 + + response3 = client.get("/to_cache", headers={"If-None-Match": str(random())}) + assert response3.status_code == 200 + assert response.content == response3.content + assert etag == response3.headers["ETag"] + + @pytest.fixture(name="app_with_cache_control") def _app_with_cache_control(cache, app): from cashews.contrib.fastapi import CacheRequestControlMiddleware, cache_control_ttl