diff --git a/fastapi_redis_cache/cache.py b/fastapi_redis_cache/cache.py index 5257a14..4466c9e 100644 --- a/fastapi_redis_cache/cache.py +++ b/fastapi_redis_cache/cache.py @@ -1,4 +1,5 @@ """cache.py""" + import asyncio from datetime import timedelta from functools import partial, update_wrapper, wraps @@ -40,14 +41,21 @@ async def inner_wrapper(*args, **kwargs): if create_response_directly: response = Response() redis_cache = FastApiRedisCache() - if redis_cache.not_connected or redis_cache.request_is_not_cacheable(request): + if ( + redis_cache.not_connected + or redis_cache.request_is_not_cacheable(request) + ): # if the redis client is not connected or request is not cacheable, no caching behavior is performed. return await get_api_response_async(func, *args, **kwargs) key = redis_cache.get_cache_key(func, *args, **kwargs) ttl, in_cache = redis_cache.check_cache(key) if in_cache: - redis_cache.set_response_headers(response, True, deserialize_json(in_cache), ttl) - if redis_cache.requested_resource_not_modified(request, in_cache): + redis_cache.set_response_headers( + response, True, deserialize_json(in_cache), ttl + ) + if redis_cache.requested_resource_not_modified( + request, in_cache + ): response.status_code = int(HTTPStatus.NOT_MODIFIED) return ( Response( @@ -60,7 +68,11 @@ async def inner_wrapper(*args, **kwargs): else response ) return ( - Response(content=in_cache, media_type="application/json", headers=response.headers) + Response( + content=in_cache, + media_type="application/json", + headers=response.headers, + ) if create_response_directly else deserialize_json(in_cache) ) @@ -68,10 +80,17 @@ async def inner_wrapper(*args, **kwargs): ttl = calculate_ttl(expire) cached = redis_cache.add_to_cache(key, response_data, ttl) if cached: - redis_cache.set_response_headers(response, cache_hit=False, response_data=response_data, ttl=ttl) + redis_cache.set_response_headers( + response, + cache_hit=False, + response_data=response_data, + ttl=ttl, + ) return ( Response( - content=serialize_json(response_data), media_type="application/json", headers=response.headers + content=serialize_json(response_data), + media_type="application/json", + headers=response.headers, ) if create_response_directly else response_data @@ -85,11 +104,15 @@ async def inner_wrapper(*args, **kwargs): async def get_api_response_async(func, *args, **kwargs): """Helper function that allows decorator to work with both async and non-async functions.""" - return await func(*args, **kwargs) if asyncio.iscoroutinefunction(func) else func(*args, **kwargs) + return ( + await func(*args, **kwargs) + if asyncio.iscoroutinefunction(func) + else func(*args, **kwargs) + ) def calculate_ttl(expire: Union[int, timedelta]) -> int: - """"Converts expire time to total seconds and ensures that ttl is capped at one year.""" + """ "Converts expire time to total seconds and ensures that ttl is capped at one year.""" if isinstance(expire, timedelta): expire = int(expire.total_seconds()) return min(expire, ONE_YEAR_IN_SECONDS) diff --git a/fastapi_redis_cache/client.py b/fastapi_redis_cache/client.py index d763488..dcae6ef 100644 --- a/fastapi_redis_cache/client.py +++ b/fastapi_redis_cache/client.py @@ -76,23 +76,40 @@ def init( self._connect() def _connect(self): - self.log(RedisEvent.CONNECT_BEGIN, msg="Attempting to connect to Redis server...") + self.log( + RedisEvent.CONNECT_BEGIN, + msg="Attempting to connect to Redis server...", + ) self.status, self.redis = redis_connect(self.host_url) if self.status == RedisStatus.CONNECTED: - self.log(RedisEvent.CONNECT_SUCCESS, msg="Redis client is connected to server.") + self.log( + RedisEvent.CONNECT_SUCCESS, + msg="Redis client is connected to server.", + ) if self.status == RedisStatus.AUTH_ERROR: # pragma: no cover - self.log(RedisEvent.CONNECT_FAIL, msg="Unable to connect to redis server due to authentication error.") + self.log( + RedisEvent.CONNECT_FAIL, + msg="Unable to connect to redis server due to authentication error.", + ) if self.status == RedisStatus.CONN_ERROR: # pragma: no cover - self.log(RedisEvent.CONNECT_FAIL, msg="Redis server did not respond to PING message.") + self.log( + RedisEvent.CONNECT_FAIL, + msg="Redis server did not respond to PING message.", + ) def request_is_not_cacheable(self, request: Request) -> bool: return request and ( request.method not in ALLOWED_HTTP_TYPES - or any(directive in request.headers.get("Cache-Control", "") for directive in ["no-store", "no-cache"]) + or any( + directive in request.headers.get("Cache-Control", "") + for directive in ["no-store", "no-cache"] + ) ) def get_cache_key(self, func: Callable, *args: List, **kwargs: Dict) -> str: - return get_cache_key(self.prefix, self.ignore_arg_types, func, *args, **kwargs) + return get_cache_key( + self.prefix, self.ignore_arg_types, func, *args, **kwargs + ) def check_cache(self, key: str) -> Tuple[int, str]: pipe = self.redis.pipeline() @@ -101,10 +118,16 @@ def check_cache(self, key: str) -> Tuple[int, str]: self.log(RedisEvent.KEY_FOUND_IN_CACHE, key=key) return (ttl, in_cache) - def requested_resource_not_modified(self, request: Request, cached_data: str) -> bool: + def requested_resource_not_modified( + self, request: Request, cached_data: str + ) -> bool: if not request or "If-None-Match" not in request.headers: return False - check_etags = [etag.strip() for etag in request.headers["If-None-Match"].split(",") if etag] + check_etags = [ + etag.strip() + for etag in request.headers["If-None-Match"].split(",") + if etag + ] if len(check_etags) == 1 and check_etags[0] == "*": return True return self.get_etag(cached_data) in check_etags @@ -125,7 +148,11 @@ def add_to_cache(self, key: str, value: Dict, expire: int) -> bool: return cached def set_response_headers( - self, response: Response, cache_hit: bool, response_data: Dict = None, ttl: int = None + self, + response: Response, + cache_hit: bool, + response_data: Dict = None, + ttl: int = None, ) -> None: response.headers[self.response_header] = "Hit" if cache_hit else "Miss" expires_at = datetime.utcnow() + timedelta(seconds=ttl) @@ -135,7 +162,13 @@ def set_response_headers( if "last_modified" in response_data: # pragma: no cover response.headers["Last-Modified"] = response_data["last_modified"] - def log(self, event: RedisEvent, msg: Optional[str] = None, key: Optional[str] = None, value: Optional[str] = None): + def log( + self, + event: RedisEvent, + msg: Optional[str] = None, + key: Optional[str] = None, + value: Optional[str] = None, + ): """Log `RedisEvent` using the configured `Logger` object""" message = f" {self.get_log_time()} | {event.name}" if msg: diff --git a/fastapi_redis_cache/key_gen.py b/fastapi_redis_cache/key_gen.py index 4e11e18..6077fae 100644 --- a/fastapi_redis_cache/key_gen.py +++ b/fastapi_redis_cache/key_gen.py @@ -1,4 +1,5 @@ """cache.py""" + from collections import OrderedDict from inspect import signature, Signature from typing import Any, Callable, Dict, List @@ -10,7 +11,13 @@ ALWAYS_IGNORE_ARG_TYPES = [Response, Request] -def get_cache_key(prefix: str, ignore_arg_types: List[ArgType], func: Callable, *args: List, **kwargs: Dict) -> str: +def get_cache_key( + prefix: str, + ignore_arg_types: List[ArgType], + func: Callable, + *args: List, + **kwargs: Dict, +) -> str: """Ganerate a string that uniquely identifies the function and values of all arguments. Args: @@ -40,15 +47,23 @@ def get_cache_key(prefix: str, ignore_arg_types: List[ArgType], func: Callable, return f"{prefix}{func.__module__}.{func.__name__}({args_str})" -def get_func_args(sig: Signature, *args: List, **kwargs: Dict) -> "OrderedDict[str, Any]": +def get_func_args( + sig: Signature, *args: List, **kwargs: Dict +) -> "OrderedDict[str, Any]": """Return a dict object containing the name and value of all function arguments.""" func_args = sig.bind(*args, **kwargs) func_args.apply_defaults() return func_args.arguments -def get_args_str(sig_params: SigParameters, func_args: "OrderedDict[str, Any]", ignore_arg_types: List[ArgType]) -> str: +def get_args_str( + sig_params: SigParameters, + func_args: "OrderedDict[str, Any]", + ignore_arg_types: List[ArgType], +) -> str: """Return a string with the name and value of all args whose type is not included in `ignore_arg_types`""" return ",".join( - f"{arg}={val}" for arg, val in func_args.items() if sig_params[arg].annotation not in ignore_arg_types + f"{arg}={val}" + for arg, val in func_args.items() + if sig_params[arg].annotation not in ignore_arg_types ) diff --git a/fastapi_redis_cache/redis.py b/fastapi_redis_cache/redis.py index 9d70bcb..181bc35 100644 --- a/fastapi_redis_cache/redis.py +++ b/fastapi_redis_cache/redis.py @@ -1,4 +1,5 @@ """redis.py""" + import os from typing import Tuple @@ -9,10 +10,16 @@ def redis_connect(host_url: str) -> Tuple[RedisStatus, redis.client.Redis]: """Attempt to connect to `host_url` and return a Redis client instance if successful.""" - return _connect(host_url) if os.environ.get("CACHE_ENV") != "TEST" else _connect_fake() + return ( + _connect(host_url) + if os.environ.get("CACHE_ENV") != "TEST" + else _connect_fake() + ) -def _connect(host_url: str) -> Tuple[RedisStatus, redis.client.Redis]: # pragma: no cover +def _connect( + host_url: str, +) -> Tuple[RedisStatus, redis.client.Redis]: # pragma: no cover try: redis_client = redis.from_url(host_url) if redis_client.ping(): diff --git a/fastapi_redis_cache/util.py b/fastapi_redis_cache/util.py index 663bf3c..2653398 100644 --- a/fastapi_redis_cache/util.py +++ b/fastapi_redis_cache/util.py @@ -23,7 +23,10 @@ class BetterJsonEncoder(json.JSONEncoder): def default(self, obj): if isinstance(obj, datetime): - return {"val": obj.strftime(DATETIME_AWARE), "_spec_type": str(datetime)} + return { + "val": obj.strftime(DATETIME_AWARE), + "_spec_type": str(datetime), + } elif isinstance(obj, date): return {"val": obj.strftime(DATE_ONLY), "_spec_type": str(date)} elif isinstance(obj, Decimal): @@ -37,7 +40,9 @@ def object_hook(obj): return obj _spec_type = obj["_spec_type"] if _spec_type not in SERIALIZE_OBJ_MAP: # pragma: no cover - raise TypeError(f'"{obj["val"]}" (type: {_spec_type}) is not JSON serializable') + raise TypeError( + f'"{obj["val"]}" (type: {_spec_type}) is not JSON serializable' + ) return SERIALIZE_OBJ_MAP[_spec_type](obj["val"]) diff --git a/tests/main.py b/tests/main.py index be2b693..b95730b 100644 --- a/tests/main.py +++ b/tests/main.py @@ -18,7 +18,10 @@ def cache_never_expire(request: Request, response: Response): @app.get("/cache_expires") @cache(expire=timedelta(seconds=5)) async def cache_expires(): - return {"success": True, "message": "this data should be cached for five seconds"} + return { + "success": True, + "message": "this data should be cached for five seconds", + } @app.get("/cache_json_encoder") @@ -35,7 +38,10 @@ def cache_json_encoder(): @app.get("/cache_one_hour") @cache_one_hour() def partial_cache_one_hour(response: Response): - return {"success": True, "message": "this data should be cached for one hour"} + return { + "success": True, + "message": "this data should be cached for one hour", + } @app.get("/cache_invalid_type") diff --git a/tests/test_cache.py b/tests/test_cache.py index bdc0ccb..6034221 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -19,8 +19,14 @@ def test_cache_never_expire(): # Initial request, X-FastAPI-Cache header field should equal "Miss" response = client.get("/cache_never_expire") assert response.status_code == 200 - assert response.json() == {"success": True, "message": "this data can be cached indefinitely"} - assert "x-fastapi-cache" in response.headers and response.headers["x-fastapi-cache"] == "Miss" + assert response.json() == { + "success": True, + "message": "this data can be cached indefinitely", + } + assert ( + "x-fastapi-cache" in response.headers + and response.headers["x-fastapi-cache"] == "Miss" + ) assert "cache-control" in response.headers assert "expires" in response.headers assert "etag" in response.headers @@ -28,8 +34,14 @@ def test_cache_never_expire(): # Send request to same endpoint, X-FastAPI-Cache header field should now equal "Hit" response = client.get("/cache_never_expire") assert response.status_code == 200 - assert response.json() == {"success": True, "message": "this data can be cached indefinitely"} - assert "x-fastapi-cache" in response.headers and response.headers["x-fastapi-cache"] == "Hit" + assert response.json() == { + "success": True, + "message": "this data can be cached indefinitely", + } + assert ( + "x-fastapi-cache" in response.headers + and response.headers["x-fastapi-cache"] == "Hit" + ) assert "cache-control" in response.headers assert "expires" in response.headers assert "etag" in response.headers @@ -42,8 +54,14 @@ def test_cache_expires(): # Initial request, X-FastAPI-Cache header field should equal "Miss" response = client.get("/cache_expires") assert response.status_code == 200 - assert response.json() == {"success": True, "message": "this data should be cached for five seconds"} - assert "x-fastapi-cache" in response.headers and response.headers["x-fastapi-cache"] == "Miss" + assert response.json() == { + "success": True, + "message": "this data should be cached for five seconds", + } + assert ( + "x-fastapi-cache" in response.headers + and response.headers["x-fastapi-cache"] == "Miss" + ) assert "cache-control" in response.headers assert "expires" in response.headers assert "etag" in response.headers @@ -54,8 +72,14 @@ def test_cache_expires(): # Send request, X-FastAPI-Cache header field should now equal "Hit" response = client.get("/cache_expires") assert response.status_code == 200 - assert response.json() == {"success": True, "message": "this data should be cached for five seconds"} - assert "x-fastapi-cache" in response.headers and response.headers["x-fastapi-cache"] == "Hit" + assert response.json() == { + "success": True, + "message": "this data should be cached for five seconds", + } + assert ( + "x-fastapi-cache" in response.headers + and response.headers["x-fastapi-cache"] == "Hit" + ) # Verify eTag value matches the value stored from the initial response assert "etag" in response.headers @@ -89,8 +113,14 @@ def test_cache_expires(): # Send request, X-FastAPI-Cache header field should equal "Miss" since the cached value has been evicted response = client.get("/cache_expires") assert response.status_code == 200 - assert response.json() == {"success": True, "message": "this data should be cached for five seconds"} - assert "x-fastapi-cache" in response.headers and response.headers["x-fastapi-cache"] == "Miss" + assert response.json() == { + "success": True, + "message": "this data should be cached for five seconds", + } + assert ( + "x-fastapi-cache" in response.headers + and response.headers["x-fastapi-cache"] == "Miss" + ) assert "cache-control" in response.headers assert "expires" in response.headers assert "etag" in response.headers @@ -107,8 +137,14 @@ def test_cache_json_encoder(): response_json = response.json() assert response_json == { "success": True, - "start_time": {"_spec_type": "", "val": "04/20/2021 07:17:17 AM "}, - "finish_by": {"_spec_type": "", "val": "04/21/2021"}, + "start_time": { + "_spec_type": "", + "val": "04/20/2021 07:17:17 AM ", + }, + "finish_by": { + "_spec_type": "", + "val": "04/21/2021", + }, "final_calc": { "_spec_type": "", "val": "3.140000000000000124344978758017532527446746826171875", @@ -126,9 +162,14 @@ def test_cache_json_encoder(): def test_cache_control_no_cache(): # Simple test that verifies if a request is recieved with the cache-control header field containing "no-cache", # no caching behavior is performed - response = client.get("/cache_never_expire", headers={"cache-control": "no-cache"}) + response = client.get( + "/cache_never_expire", headers={"cache-control": "no-cache"} + ) assert response.status_code == 200 - assert response.json() == {"success": True, "message": "this data can be cached indefinitely"} + assert response.json() == { + "success": True, + "message": "this data can be cached indefinitely", + } assert "x-fastapi-cache" not in response.headers assert "cache-control" not in response.headers assert "expires" not in response.headers @@ -138,9 +179,14 @@ def test_cache_control_no_cache(): def test_cache_control_no_store(): # Simple test that verifies if a request is recieved with the cache-control header field containing "no-store", # no caching behavior is performed - response = client.get("/cache_never_expire", headers={"cache-control": "no-store"}) + response = client.get( + "/cache_never_expire", headers={"cache-control": "no-store"} + ) assert response.status_code == 200 - assert response.json() == {"success": True, "message": "this data can be cached indefinitely"} + assert response.json() == { + "success": True, + "message": "this data can be cached indefinitely", + } assert "x-fastapi-cache" not in response.headers assert "cache-control" not in response.headers assert "expires" not in response.headers @@ -151,8 +197,14 @@ def test_if_none_match(): # Initial request, response data is added to cache response = client.get("/cache_never_expire") assert response.status_code == 200 - assert response.json() == {"success": True, "message": "this data can be cached indefinitely"} - assert "x-fastapi-cache" in response.headers and response.headers["x-fastapi-cache"] == "Miss" + assert response.json() == { + "success": True, + "message": "this data can be cached indefinitely", + } + assert ( + "x-fastapi-cache" in response.headers + and response.headers["x-fastapi-cache"] == "Miss" + ) assert "cache-control" in response.headers assert "expires" in response.headers assert "etag" in response.headers @@ -163,10 +215,16 @@ def test_if_none_match(): invalid_etag = "W/-5480454928453453778" # Send request to same endpoint where If-None-Match header contains both valid and invalid eTag values - response = client.get("/cache_never_expire", headers={"if-none-match": f"{etag}, {invalid_etag}"}) + response = client.get( + "/cache_never_expire", + headers={"if-none-match": f"{etag}, {invalid_etag}"}, + ) assert response.status_code == 304 assert not response.content - assert "x-fastapi-cache" in response.headers and response.headers["x-fastapi-cache"] == "Hit" + assert ( + "x-fastapi-cache" in response.headers + and response.headers["x-fastapi-cache"] == "Hit" + ) assert "cache-control" in response.headers assert "expires" in response.headers assert "etag" in response.headers @@ -175,16 +233,27 @@ def test_if_none_match(): response = client.get("/cache_never_expire", headers={"if-none-match": "*"}) assert response.status_code == 304 assert not response.content - assert "x-fastapi-cache" in response.headers and response.headers["x-fastapi-cache"] == "Hit" + assert ( + "x-fastapi-cache" in response.headers + and response.headers["x-fastapi-cache"] == "Hit" + ) assert "cache-control" in response.headers assert "expires" in response.headers assert "etag" in response.headers # Send request to same endpoint where If-None-Match header contains only the invalid eTag value - response = client.get("/cache_never_expire", headers={"if-none-match": invalid_etag}) + response = client.get( + "/cache_never_expire", headers={"if-none-match": invalid_etag} + ) assert response.status_code == 200 - assert response.json() == {"success": True, "message": "this data can be cached indefinitely"} - assert "x-fastapi-cache" in response.headers and response.headers["x-fastapi-cache"] == "Hit" + assert response.json() == { + "success": True, + "message": "this data can be cached indefinitely", + } + assert ( + "x-fastapi-cache" in response.headers + and response.headers["x-fastapi-cache"] == "Hit" + ) assert "cache-control" in response.headers assert "expires" in response.headers assert "etag" in response.headers @@ -195,8 +264,14 @@ def test_partial_cache_one_hour(): # is working correctly. response = client.get("/cache_one_hour") assert response.status_code == 200 - assert response.json() == {"success": True, "message": "this data should be cached for one hour"} - assert "x-fastapi-cache" in response.headers and response.headers["x-fastapi-cache"] == "Miss" + assert response.json() == { + "success": True, + "message": "this data should be cached for one hour", + } + assert ( + "x-fastapi-cache" in response.headers + and response.headers["x-fastapi-cache"] == "Miss" + ) assert "cache-control" in response.headers match = MAX_AGE_REGEX.search(response.headers.get("cache-control")) assert match and int(match.groupdict()["ttl"]) == 3600