run ruff formatting
seapagan committed Mar 18, 2024
1 parent 381cecf commit 6084275
Showing 7 changed files with 218 additions and 54 deletions.
39 changes: 31 additions & 8 deletions fastapi_redis_cache/cache.py
@@ -1,4 +1,5 @@
 """cache.py"""
+
 import asyncio
 from datetime import timedelta
 from functools import partial, update_wrapper, wraps
@@ -40,14 +41,21 @@ async def inner_wrapper(*args, **kwargs):
             if create_response_directly:
                 response = Response()
             redis_cache = FastApiRedisCache()
-            if redis_cache.not_connected or redis_cache.request_is_not_cacheable(request):
+            if (
+                redis_cache.not_connected
+                or redis_cache.request_is_not_cacheable(request)
+            ):
                 # if the redis client is not connected or request is not cacheable, no caching behavior is performed.
                 return await get_api_response_async(func, *args, **kwargs)
             key = redis_cache.get_cache_key(func, *args, **kwargs)
             ttl, in_cache = redis_cache.check_cache(key)
             if in_cache:
-                redis_cache.set_response_headers(response, True, deserialize_json(in_cache), ttl)
-                if redis_cache.requested_resource_not_modified(request, in_cache):
+                redis_cache.set_response_headers(
+                    response, True, deserialize_json(in_cache), ttl
+                )
+                if redis_cache.requested_resource_not_modified(
+                    request, in_cache
+                ):
                     response.status_code = int(HTTPStatus.NOT_MODIFIED)
                     return (
                         Response(
@@ -60,18 +68,29 @@ async def inner_wrapper(*args, **kwargs):
                         else response
                     )
                 return (
-                    Response(content=in_cache, media_type="application/json", headers=response.headers)
+                    Response(
+                        content=in_cache,
+                        media_type="application/json",
+                        headers=response.headers,
+                    )
                     if create_response_directly
                     else deserialize_json(in_cache)
                 )
             response_data = await get_api_response_async(func, *args, **kwargs)
             ttl = calculate_ttl(expire)
             cached = redis_cache.add_to_cache(key, response_data, ttl)
             if cached:
-                redis_cache.set_response_headers(response, cache_hit=False, response_data=response_data, ttl=ttl)
+                redis_cache.set_response_headers(
+                    response,
+                    cache_hit=False,
+                    response_data=response_data,
+                    ttl=ttl,
+                )
             return (
                 Response(
-                    content=serialize_json(response_data), media_type="application/json", headers=response.headers
+                    content=serialize_json(response_data),
+                    media_type="application/json",
+                    headers=response.headers,
                 )
                 if create_response_directly
                 else response_data
@@ -85,11 +104,15 @@ async def inner_wrapper(*args, **kwargs):
 
 async def get_api_response_async(func, *args, **kwargs):
     """Helper function that allows decorator to work with both async and non-async functions."""
-    return await func(*args, **kwargs) if asyncio.iscoroutinefunction(func) else func(*args, **kwargs)
+    return (
+        await func(*args, **kwargs)
+        if asyncio.iscoroutinefunction(func)
+        else func(*args, **kwargs)
+    )
 
 
 def calculate_ttl(expire: Union[int, timedelta]) -> int:
-    """"Converts expire time to total seconds and ensures that ttl is capped at one year."""
+    """ "Converts expire time to total seconds and ensures that ttl is capped at one year."""
     if isinstance(expire, timedelta):
         expire = int(expire.total_seconds())
     return min(expire, ONE_YEAR_IN_SECONDS)
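For reference, a minimal sketch of how the decorator reformatted above is applied to a FastAPI endpoint. The route, payload, and Redis URL are illustrative; the `cache(expire=...)` decorator, the `init(host_url=...)` call, and the one-year TTL cap in `calculate_ttl()` follow the code in this diff and the test app further down.

from datetime import timedelta

from fastapi import FastAPI, Request, Response

from fastapi_redis_cache import FastApiRedisCache, cache

app = FastAPI()


@app.on_event("startup")
def startup():
    # Illustrative local Redis URL; the client's init()/_connect() methods
    # shown in this diff handle the actual connection and logging.
    FastApiRedisCache().init(host_url="redis://127.0.0.1:6379")


@app.get("/items")
@cache(expire=timedelta(minutes=10))  # calculate_ttl() converts this to seconds, capped at one year
async def read_items(request: Request, response: Response):
    # The request/response parameters let the decorator honour Cache-Control
    # headers and attach cache-status and ETag headers to the response.
    return {"items": ["a", "b", "c"]}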
53 changes: 43 additions & 10 deletions fastapi_redis_cache/client.py
@@ -76,23 +76,40 @@ def init(
         self._connect()
 
     def _connect(self):
-        self.log(RedisEvent.CONNECT_BEGIN, msg="Attempting to connect to Redis server...")
+        self.log(
+            RedisEvent.CONNECT_BEGIN,
+            msg="Attempting to connect to Redis server...",
+        )
         self.status, self.redis = redis_connect(self.host_url)
         if self.status == RedisStatus.CONNECTED:
-            self.log(RedisEvent.CONNECT_SUCCESS, msg="Redis client is connected to server.")
+            self.log(
+                RedisEvent.CONNECT_SUCCESS,
+                msg="Redis client is connected to server.",
+            )
         if self.status == RedisStatus.AUTH_ERROR:  # pragma: no cover
-            self.log(RedisEvent.CONNECT_FAIL, msg="Unable to connect to redis server due to authentication error.")
+            self.log(
+                RedisEvent.CONNECT_FAIL,
+                msg="Unable to connect to redis server due to authentication error.",
+            )
         if self.status == RedisStatus.CONN_ERROR:  # pragma: no cover
-            self.log(RedisEvent.CONNECT_FAIL, msg="Redis server did not respond to PING message.")
+            self.log(
+                RedisEvent.CONNECT_FAIL,
+                msg="Redis server did not respond to PING message.",
+            )
 
     def request_is_not_cacheable(self, request: Request) -> bool:
         return request and (
             request.method not in ALLOWED_HTTP_TYPES
-            or any(directive in request.headers.get("Cache-Control", "") for directive in ["no-store", "no-cache"])
+            or any(
+                directive in request.headers.get("Cache-Control", "")
+                for directive in ["no-store", "no-cache"]
+            )
         )
 
     def get_cache_key(self, func: Callable, *args: List, **kwargs: Dict) -> str:
-        return get_cache_key(self.prefix, self.ignore_arg_types, func, *args, **kwargs)
+        return get_cache_key(
+            self.prefix, self.ignore_arg_types, func, *args, **kwargs
+        )
 
     def check_cache(self, key: str) -> Tuple[int, str]:
         pipe = self.redis.pipeline()
@@ -101,10 +118,16 @@ def check_cache(self, key: str) -> Tuple[int, str]:
             self.log(RedisEvent.KEY_FOUND_IN_CACHE, key=key)
         return (ttl, in_cache)
 
-    def requested_resource_not_modified(self, request: Request, cached_data: str) -> bool:
+    def requested_resource_not_modified(
+        self, request: Request, cached_data: str
+    ) -> bool:
         if not request or "If-None-Match" not in request.headers:
             return False
-        check_etags = [etag.strip() for etag in request.headers["If-None-Match"].split(",") if etag]
+        check_etags = [
+            etag.strip()
+            for etag in request.headers["If-None-Match"].split(",")
+            if etag
+        ]
         if len(check_etags) == 1 and check_etags[0] == "*":
             return True
         return self.get_etag(cached_data) in check_etags
@@ -125,7 +148,11 @@ def add_to_cache(self, key: str, value: Dict, expire: int) -> bool:
         return cached
 
     def set_response_headers(
-        self, response: Response, cache_hit: bool, response_data: Dict = None, ttl: int = None
+        self,
+        response: Response,
+        cache_hit: bool,
+        response_data: Dict = None,
+        ttl: int = None,
     ) -> None:
         response.headers[self.response_header] = "Hit" if cache_hit else "Miss"
         expires_at = datetime.utcnow() + timedelta(seconds=ttl)
@@ -135,7 +162,13 @@ def set_response_headers(
         if "last_modified" in response_data:  # pragma: no cover
             response.headers["Last-Modified"] = response_data["last_modified"]
 
-    def log(self, event: RedisEvent, msg: Optional[str] = None, key: Optional[str] = None, value: Optional[str] = None):
+    def log(
+        self,
+        event: RedisEvent,
+        msg: Optional[str] = None,
+        key: Optional[str] = None,
+        value: Optional[str] = None,
+    ):
         """Log `RedisEvent` using the configured `Logger` object"""
         message = f" {self.get_log_time()} | {event.name}"
         if msg:
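The `If-None-Match` handling reformatted above can be read in isolation. Below is a standalone sketch of the same matching rules, with illustrative ETag strings; in the library the comparison value comes from `self.get_etag(cached_data)`.

def etags_match(if_none_match: str, current_etag: str) -> bool:
    # Mirrors requested_resource_not_modified(): split the header on commas,
    # drop empty entries, strip whitespace, and treat a lone "*" as a match.
    check_etags = [etag.strip() for etag in if_none_match.split(",") if etag]
    if len(check_etags) == 1 and check_etags[0] == "*":
        return True
    return current_etag in check_etags


assert etags_match("*", 'W/"abc123"')                       # wildcard matches anything
assert etags_match('W/"abc123", W/"def456"', 'W/"def456"')  # any listed ETag matches
assert not etags_match('W/"abc123"', 'W/"zzz999"')          # no match, send full response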
23 changes: 19 additions & 4 deletions fastapi_redis_cache/key_gen.py
@@ -1,4 +1,5 @@
 """cache.py"""
+
 from collections import OrderedDict
 from inspect import signature, Signature
 from typing import Any, Callable, Dict, List
@@ -10,7 +11,13 @@
 ALWAYS_IGNORE_ARG_TYPES = [Response, Request]
 
 
-def get_cache_key(prefix: str, ignore_arg_types: List[ArgType], func: Callable, *args: List, **kwargs: Dict) -> str:
+def get_cache_key(
+    prefix: str,
+    ignore_arg_types: List[ArgType],
+    func: Callable,
+    *args: List,
+    **kwargs: Dict,
+) -> str:
     """Ganerate a string that uniquely identifies the function and values of all arguments.
 
     Args:
@@ -40,15 +47,23 @@ def get_cache_key(prefix: str, ignore_arg_types: List[ArgType], func: Callable,
     return f"{prefix}{func.__module__}.{func.__name__}({args_str})"
 
 
-def get_func_args(sig: Signature, *args: List, **kwargs: Dict) -> "OrderedDict[str, Any]":
+def get_func_args(
+    sig: Signature, *args: List, **kwargs: Dict
+) -> "OrderedDict[str, Any]":
     """Return a dict object containing the name and value of all function arguments."""
     func_args = sig.bind(*args, **kwargs)
     func_args.apply_defaults()
     return func_args.arguments
 
 
-def get_args_str(sig_params: SigParameters, func_args: "OrderedDict[str, Any]", ignore_arg_types: List[ArgType]) -> str:
+def get_args_str(
+    sig_params: SigParameters,
+    func_args: "OrderedDict[str, Any]",
+    ignore_arg_types: List[ArgType],
+) -> str:
     """Return a string with the name and value of all args whose type is not included in `ignore_arg_types`"""
     return ",".join(
-        f"{arg}={val}" for arg, val in func_args.items() if sig_params[arg].annotation not in ignore_arg_types
+        f"{arg}={val}"
+        for arg, val in func_args.items()
+        if sig_params[arg].annotation not in ignore_arg_types
     )
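The key layout produced by `get_cache_key()` can be illustrated without installing the package. A minimal sketch that reproduces the two return statements above; the example function and the `myapi-cache:` prefix are made up, and `Request`/`Response` arguments are omitted since `ALWAYS_IGNORE_ARG_TYPES` excludes them from the key.

def example_key(prefix: str, func, **func_args) -> str:
    # Same layout as get_cache_key()/get_args_str(): "arg=value" pairs joined
    # by commas, wrapped in "<prefix><module>.<function>(...)".
    args_str = ",".join(f"{arg}={val}" for arg, val in func_args.items())
    return f"{prefix}{func.__module__}.{func.__name__}({args_str})"


def get_user(id: int):
    ...


print(example_key("myapi-cache:", get_user, id=42))
# -> "myapi-cache:__main__.get_user(id=42)" when run as a script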
11 changes: 9 additions & 2 deletions fastapi_redis_cache/redis.py
@@ -1,4 +1,5 @@
 """redis.py"""
+
 import os
 from typing import Tuple
 
@@ -9,10 +10,16 @@
 
 def redis_connect(host_url: str) -> Tuple[RedisStatus, redis.client.Redis]:
     """Attempt to connect to `host_url` and return a Redis client instance if successful."""
-    return _connect(host_url) if os.environ.get("CACHE_ENV") != "TEST" else _connect_fake()
+    return (
+        _connect(host_url)
+        if os.environ.get("CACHE_ENV") != "TEST"
+        else _connect_fake()
+    )
 
 
-def _connect(host_url: str) -> Tuple[RedisStatus, redis.client.Redis]:  # pragma: no cover
+def _connect(
+    host_url: str,
+) -> Tuple[RedisStatus, redis.client.Redis]:  # pragma: no cover
     try:
         redis_client = redis.from_url(host_url)
         if redis_client.ping():
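For context, `redis_connect()` above switches to a fake client whenever `CACHE_ENV=TEST` is set; the real path boils down to the connect-and-ping pattern in `_connect()`. A standalone sketch of that pattern, with an illustrative URL and the `RedisStatus` bookkeeping left out:

import redis


def ping_redis(host_url: str = "redis://127.0.0.1:6379") -> bool:
    """Return True if a Redis server answers PING at host_url."""
    try:
        client = redis.from_url(host_url)
        return bool(client.ping())
    except redis.AuthenticationError:  # wrong or missing password
        return False
    except redis.ConnectionError:      # server unreachable or no PONG
        return False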
9 changes: 7 additions & 2 deletions fastapi_redis_cache/util.py
@@ -23,7 +23,10 @@
 class BetterJsonEncoder(json.JSONEncoder):
     def default(self, obj):
         if isinstance(obj, datetime):
-            return {"val": obj.strftime(DATETIME_AWARE), "_spec_type": str(datetime)}
+            return {
+                "val": obj.strftime(DATETIME_AWARE),
+                "_spec_type": str(datetime),
+            }
         elif isinstance(obj, date):
             return {"val": obj.strftime(DATE_ONLY), "_spec_type": str(date)}
         elif isinstance(obj, Decimal):
@@ -37,7 +40,9 @@ def object_hook(obj):
         return obj
     _spec_type = obj["_spec_type"]
     if _spec_type not in SERIALIZE_OBJ_MAP:  # pragma: no cover
-        raise TypeError(f'"{obj["val"]}" (type: {_spec_type}) is not JSON serializable')
+        raise TypeError(
+            f'"{obj["val"]}" (type: {_spec_type}) is not JSON serializable'
+        )
     return SERIALIZE_OBJ_MAP[_spec_type](obj["val"])
 
 
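The encoder/decoder pair above round-trips datetimes through a small {"val": ..., "_spec_type": ...} envelope. A self-contained sketch of that idea; the format string is an assumption, since the `DATETIME_AWARE` constant is defined outside this diff.

import json
from datetime import datetime

DT_FORMAT = "%Y-%m-%d %H:%M:%S"  # assumed; the library uses its own DATETIME_AWARE constant


class DemoEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, datetime):
            # Same envelope as BetterJsonEncoder above.
            return {"val": obj.strftime(DT_FORMAT), "_spec_type": str(datetime)}
        return super().default(obj)


def demo_object_hook(obj):
    if obj.get("_spec_type") == str(datetime):
        return datetime.strptime(obj["val"], DT_FORMAT)
    return obj


payload = {"created": datetime(2024, 3, 18, 12, 30)}
encoded = json.dumps(payload, cls=DemoEncoder)
decoded = json.loads(encoded, object_hook=demo_object_hook)
assert decoded["created"] == payload["created"]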
10 changes: 8 additions & 2 deletions tests/main.py
@@ -18,7 +18,10 @@ def cache_never_expire(request: Request, response: Response):
 @app.get("/cache_expires")
 @cache(expire=timedelta(seconds=5))
 async def cache_expires():
-    return {"success": True, "message": "this data should be cached for five seconds"}
+    return {
+        "success": True,
+        "message": "this data should be cached for five seconds",
+    }
 
 
 @app.get("/cache_json_encoder")
@@ -35,7 +38,10 @@ def cache_json_encoder():
 @app.get("/cache_one_hour")
 @cache_one_hour()
 def partial_cache_one_hour(response: Response):
-    return {"success": True, "message": "this data should be cached for one hour"}
+    return {
+        "success": True,
+        "message": "this data should be cached for one hour",
+    }
 
 
 @app.get("/cache_invalid_type")
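A sketch of how these endpoints are typically exercised: hit the same route twice and compare the cache-status header. The header name is an assumption (it is configurable via `FastApiRedisCache.init()` and not shown here), and the test app's startup must have initialised the cache, for example with `CACHE_ENV=TEST` set so the fake Redis client from redis.py is used.

from fastapi.testclient import TestClient

from tests.main import app

client = TestClient(app)  # the app's startup hook should call FastApiRedisCache().init(...)

first = client.get("/cache_expires")
second = client.get("/cache_expires")

assert first.json()["message"] == "this data should be cached for five seconds"
# "Hit"/"Miss" values come from set_response_headers() above; the header
# *name* below is an assumed default and may differ in this project.
assert first.headers.get("X-FastAPI-Cache") == "Miss"
assert second.headers.get("X-FastAPI-Cache") == "Hit"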
